In [1]:
"""
Generate ProtT5 (XL-UniRef50) sequence embeddings for PIDs that appear under
PFP_Testing/data/PDBCH <split>/*, reading sequences from
<split>/<PID>/sequence.txt and writing the embeddings to
<split>/<PID>/prot_t5_emb.pt.

Output per PID:
    <split>/<PID>/prot_t5_emb.pt   # torch.Tensor, shape (1, L+2, DIM)
                                   # CLS/EOS slots are zero-filled

Notes
-----
- We *only* embed PIDs that exist under <split> (splits: train_pdbch/val_pdbch/test_pdbch).
- Tokenization follows the official ProtT5 recipe: space-separated residues.
- Ambiguous residues U,Z,O,B are mapped to 'X' (per official ProtT5 preprocessing).
- If L <= WINDOW_EFF, run a single forward pass (fast path).
- If L > WINDOW_EFF, use sliding windows with overlap and **context-trim**:
  in each window keep only the non-overlap core (no weighting). Only indices
  0 and L+1 are zero.
- Existing files are skipped silently; progress bar shows overall progress.
- Preserves original embedding logic; only I/O and PID enumeration changed.
"""
from __future__ import annotations

from pathlib import Path
from typing import Iterable, Set, List, Tuple
import re
import math

import torch
from tqdm.auto import tqdm
from transformers import AutoTokenizer, T5EncoderModel

# ─────────────────────────────── config ─────────────────────────────── #
BASE          = Path("/teamspace/studios/this_studio/PFP_Testing/data/PDBCH")  # CHANGED
SPLITS        = ["train_pdbch", "val_pdbch", "test_pdbch"]                     # CHANGED
OUT_FILENAME  = "prot_t5_emb.pt"           # (1, L+2, DIM) with zero CLS/EOS rows
MODEL_NAME    = "Rostlab/prot_t5_xl_uniref50"

DEVICE        = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DTYPE         = torch.float16
DIM           = 1024                       # ProtT5 embedding dimension

# Sliding-window hyperparams for ProtT5 (T5 has EOS but no BOS; we keep your (1,L+2,*) interface)
MAX_LEN_TOKENS = 512     # effective model token limit (incl. EOS)
N_SPECIALS_WIN = 1        # EOS only inside each window
WINDOW_EFF     = MAX_LEN_TOKENS - N_SPECIALS_WIN  # residues per window
OVERLAP        = 128      # residues of overlap between adjacent windows (64–128 is a good default)
# ────────────────────────────────────────────────────────────────────── #


def iter_pid_dirs(base: Path, splits: Iterable[str]) -> Iterable[Path]:        # CHANGED
    """Yield each PID directory under <base>/<split>/* for all splits."""
    for split in splits:
        split_dir = base / split
        if not split_dir.exists():
            tqdm.write(f"[WARN] Missing folder: {split_dir} — skipping this split")
            continue
        for item in split_dir.iterdir():
            if item.is_dir() and not item.name.startswith("."):
                yield item


def read_sequence(pid_dir: Path) -> str:                                       # CHANGED
    seq_path = pid_dir / "sequence.txt"                                        # CHANGED
    if not seq_path.exists():
        return ""
    seq_parts = []
    with seq_path.open("r", encoding="utf-8") as fh:
        for line in fh:
            if line.startswith(">"):
                continue
            seq_parts.append(line.strip())
    # Build raw sequence (uppercase), then map ambiguous residues per official ProtT5 recipe
    seq = ("".join(seq_parts)).upper()
    seq = re.sub(r"[UZOB]", "X", seq)
    if not seq:
        return ""
    return seq


@torch.inference_mode()
def _embed_window(model: T5EncoderModel, tok: AutoTokenizer, subseq: str) -> torch.Tensor:
    """Embed a *subsequence* with ProtT5, returning (Li, DIM) residue embeddings.

    - Space-separate residues.
    - add_special_tokens=True to append EOS.
    - Drop the final EOS row to keep only residues.
    """
    spaced = " ".join(subseq)
    enc = tok(spaced, return_tensors="pt", add_special_tokens=True)
    enc = {k: v.to(DEVICE) for k, v in enc.items()}

    out = model(**enc).last_hidden_state[0]  # [T, DIM], where T = Li + 1 (EOS)
    # Drop EOS (last row)
    kept = out[:-1]  # [Li, DIM]
    if kept.size(0) != len(subseq):
        raise RuntimeError(
            f"[ProtT5] Token/AA mismatch in window: tokens_no_eos={kept.size(0)} vs Li={len(subseq)}"
        )
    return kept.to(DTYPE)


def _plan_windows(L: int) -> List[Tuple[int, int]]:
    """Return inclusive-exclusive windows [(s,e), ...] covering [0, L) with overlap."""
    if L <= WINDOW_EFF:
        return [(0, L)]
    stride = max(1, WINDOW_EFF - OVERLAP)
    windows: List[Tuple[int, int]] = []
    s = 0
    while s < L:
        e = min(s + WINDOW_EFF, L)
        windows.append((s, e))
        if e == L:
            break
        s += stride
    return windows


@torch.inference_mode()
def embed_prott5_stitched(model: T5EncoderModel, tok: AutoTokenizer, seq: str) -> torch.Tensor:
    """Return ProtT5 embeddings as (1, L+2, DIM) with zero CLS/EOS rows."""
    L = len(seq)

    # Fast path: one shot, identical to your original logic.
    if L <= WINDOW_EFF:
        kept = _embed_window(model, tok, seq)  # (L, DIM)
        emb = torch.zeros(1, L + 2, kept.size(-1), dtype=DTYPE, device="cpu")
        emb[0, 1 : L + 1] = kept.detach().cpu()
        return emb

    # Sliding-window path
    windows = _plan_windows(L)

    # Output buffer (CPU to control memory)
    E = torch.empty(L, DIM, dtype=DTYPE, device="cpu")

    for (s, e) in windows:
        subseq = seq[s:e]
        Li = e - s
        if Li > WINDOW_EFF:
            raise RuntimeError(f"Internal window larger than WINDOW_EFF: Li={Li} > {WINDOW_EFF}")

        R = _embed_window(model, tok, subseq).detach().cpu()  # (Li, DIM)

        # Context-trim stitching: keep only the non-overlap core of each window.
        # First window keeps its full left edge; last window keeps its full right edge.
        left_trim  = 0 if s == 0 else OVERLAP // 2
        right_trim = 0 if e == L else (OVERLAP - OVERLAP // 2)

        gs = s + left_trim           # global start to write
        ge = e - right_trim          # global end   to write (exclusive)
        ls = left_trim               # local  start to read
        le = Li - right_trim         # local  end   to read (exclusive)

        if ge <= gs:
            raise RuntimeError(f"Empty keep range for window {(s,e)} with trims {left_trim},{right_trim}")
        if (le - ls) != (ge - gs):
            raise RuntimeError("Keep lengths disagree (local vs global) during trim-stitch.")

        E[gs:ge] = R[ls:le]

    # Pack to (1, L+2, DIM) with zeros at [0] and [L+1]
    emb = torch.zeros(1, L + 2, DIM, dtype=DTYPE, device="cpu")
    emb[0, 1 : L + 1] = E
    return emb


def main() -> None:
    if not BASE.exists():                                                       # CHANGED
        raise FileNotFoundError(f"Base dataset not found: {BASE}")              # CHANGED

    # Load model & tokenizer once
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, do_lower_case=False, use_fast=False)
    model = T5EncoderModel.from_pretrained(MODEL_NAME).to(DEVICE).eval()

    made, skipped_exist, skipped_empty = 0, 0, 0                                # CHANGED
    for pid_dir in tqdm(list(iter_pid_dirs(BASE, SPLITS)), desc="ProtT5 embedding"):  # CHANGED
        out_path = pid_dir / OUT_FILENAME                                       # CHANGED
        if out_path.exists():
            skipped_exist += 1
            continue  # skip silently
        seq = read_sequence(pid_dir)                                            # CHANGED
        if not seq:
            skipped_empty += 1
            continue

        emb = embed_prott5_stitched(model, tokenizer, seq)  # (1, L+2, DIM); zeros only at [0] and [L+1]
        out_path.parent.mkdir(parents=True, exist_ok=True)                      # unchanged
        torch.save(emb, out_path)
        made += 1

    print(
        f"Done. created={made} | existing_skipped={skipped_exist} | missing_seq_skipped={skipped_empty}"
    )


if __name__ == "__main__":
    main()


tokenizer_config.json:   0%|          | 0.00/24.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/546 [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/238k [00:00<?, ?B/s]

special_tokens_map.json: 0.00B [00:00, ?B/s]

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


pytorch_model.bin:   0%|          | 0.00/11.3G [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/11.3G [00:00<?, ?B/s]

ProtT5 embedding:   0%|          | 0/36629 [00:00<?, ?it/s]

Done. created=36629 | existing_skipped=0 | missing_seq_skipped=0


In [2]:
"""
Generate ESM-C (600M) sequence embeddings for PIDs that appear under the
PFP_Testing/data/PDBCH <split>/* directories, *reading sequences from*
<split>/<PID>/sequence.txt and *writing*
<split>/<PID>/esmc_emb.pt.

Output per PID:
    <split>/<PID>/esmc_emb.pt     # torch.Tensor, shape (1, L+2, 1152)

Logic
-----
- If L <= WINDOW_EFF, run a single forward pass (fast path). This preserves
  the original ESM-C behavior, including native [CLS]/[EOS] embeddings.
- If L > WINDOW_EFF, use sliding windows with overlap and **context-trim**:
  for each window, embed, drop CLS/EOS to get per-residue rows, then keep
  only the non-overlap core (no weighting). Stitch into a full [L, 1152]
  buffer. We then pack to (1, L+2, 1152) with zeros at positions [0] and
  [L+1] (global CLS/EOS are not available under windowing).

Notes
-----
- We *only* embed PIDs that exist under <split>/* (splits: train_pdbch/val_pdbch/test_pdbch).
- Existing files are skipped.
- Uses GPU if available.
"""
from __future__ import annotations

import re
from pathlib import Path
from typing import Iterable, Set, List, Tuple

import torch
from tqdm.auto import tqdm

# ─────────────────────────────── config ─────────────────────────────── #
BASE          = Path("/teamspace/studios/this_studio/PFP_Testing/data/PDBCH")   # CHANGED
SPLITS        = ["train_pdbch", "val_pdbch", "test_pdbch"]                      # CHANGED
OUT_FILENAME  = "esmc_emb.pt"             # what downstream expects
MODEL_NAME    = "esmc_600m"               # Fair-ESM registry tag

DEVICE        = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DTYPE         = torch.float16
DIM           = 1152

# Sliding-window hyperparameters (mirroring your ProtT5 setup)
# ESM-C has both CLS and EOS per window → 2 specials inside each window.
MAX_LEN_TOKENS = 2048     # effective model token limit (incl. CLS+EOS)
N_SPECIALS_WIN = 2        # CLS + EOS
WINDOW_EFF     = MAX_LEN_TOKENS - N_SPECIALS_WIN   # residues per window
OVERLAP        = 256      # residues overlapped between adjacent windows
# ────────────────────────────────────────────────────────────────────── #


def iter_pid_dirs(base: Path, splits: Iterable[str]) -> Iterable[Path]:          # CHANGED
    """Yield each PID directory under <base>/<split>/* for all splits."""        # CHANGED
    for split in splits:
        split_dir = base / split
        if not split_dir.exists():
            tqdm.write(f"[WARN] Missing folder: {split_dir} — skipping this split")
            continue
        for item in split_dir.iterdir():
            if item.is_dir() and not item.name.startswith("."):
                yield item


def read_sequence(pid_dir: Path) -> str:                                         # CHANGED
    """Read FASTA-like or plain sequence, strip headers/whitespace, uppercase."""
    seq_path = pid_dir / "sequence.txt"                                          # CHANGED
    if not seq_path.exists():
        return ""
    seq_parts = []
    with seq_path.open("r", encoding="utf-8") as fh:
        for line in fh:
            if line.startswith(">"):
                continue
            seq_parts.append(line.strip())
    seq = ("".join(seq_parts)).upper()
    return seq


def _plan_windows(L: int) -> List[Tuple[int, int]]:
    """Return inclusive-exclusive windows [(s,e), ...] covering [0, L) with overlap."""
    if L <= WINDOW_EFF:
        return [(0, L)]
    stride = max(1, WINDOW_EFF - OVERLAP)
    windows: List[Tuple[int, int]] = []
    s = 0
    while s < L:
        e = min(s + WINDOW_EFF, L)
        windows.append((s, e))
        if e == L:
            break
        s += stride
    return windows


@torch.inference_mode()
def _embed_window_esmc(model, subseq: str) -> torch.Tensor:
    """
    Embed a *subsequence* with ESM-C and return per-residue embeddings [Li, DIM].

    We call model.encode + model.logits(..., return_embeddings=True), which returns
    (1, Li+2, DIM) including CLS/EOS. We drop [CLS]/[EOS] here and only return
    residue rows.
    """
    from esm.sdk.api import ESMProtein, LogitsConfig

    prot   = ESMProtein(sequence=subseq)
    hidden = model.encode(prot)  # device-aware; model already on DEVICE
    out    = model.logits(hidden, LogitsConfig(sequence=True, return_embeddings=True))
    # out.embeddings: (1, Li+2, DIM)
    R = out.embeddings[0, 1:-1].to(DTYPE)  # (Li, DIM), drop CLS/EOS
    if R.size(0) != len(subseq):
        raise RuntimeError(
            f"[ESM-C] Token/AA mismatch in window: residues={len(subseq)} vs got={R.size(0)} rows"
        )
    return R


@torch.inference_mode()
def _embed_full_esmc(model, seq: str) -> torch.Tensor:
    """
    Full-sequence ESM-C embedding → (1, L+2, DIM), keeping native CLS/EOS.
    This matches your original single-pass behavior.
    """
    from esm.sdk.api import ESMProtein, LogitsConfig

    prot   = ESMProtein(sequence=seq)
    hidden = model.encode(prot)
    out    = model.logits(hidden, LogitsConfig(sequence=True, return_embeddings=True))
    return out.embeddings.to(DTYPE).cpu()  # (1, L+2, DIM)


@torch.inference_mode()
def embed_esmc_stitched(model, seq: str) -> torch.Tensor:
    """
    Return ESM-C embeddings as (1, L+2, 1152).

    - Short sequences (L <= WINDOW_EFF): single pass; preserve native CLS/EOS.
    - Long sequences: sliding windows + context-trim stitching for residues;
      pack to (1, L+2, DIM) with zeros at [0] and [L+1].
    """
    L = len(seq)
    if L == 0:
        raise ValueError("Empty sequence.")

    # Fast path (identical to original logic)
    if L <= WINDOW_EFF:
        return _embed_full_esmc(model, seq)

    # Sliding-window path
    windows = _plan_windows(L)
    E = torch.empty(L, DIM, dtype=DTYPE, device="cpu")

    for (s, e) in windows:
        subseq = seq[s:e]
        Li = e - s
        if Li > WINDOW_EFF:
            raise RuntimeError(f"Internal window larger than WINDOW_EFF: Li={Li} > {WINDOW_EFF}")

        R = _embed_window_esmc(model, subseq).detach().cpu()  # (Li, DIM)

        # Context-trim stitching (same as in your ProtT5 script).
        left_trim  = 0 if s == 0 else OVERLAP // 2
        right_trim = 0 if e == L else (OVERLAP - OVERLAP // 2)

        gs = s + left_trim           # global start
        ge = e - right_trim          # global end (exclusive)
        ls = left_trim               # local start
        le = Li - right_trim         # local end (exclusive)

        if ge <= gs:
            raise RuntimeError(f"Empty keep range for window {(s, e)}")
        if (le - ls) != (ge - gs):
            raise RuntimeError("Keep lengths disagree during trim-stitch.")

        E[gs:ge] = R[ls:le]

    # Pack to (1, L+2, DIM); under sliding we zero CLS/EOS to keep interface stable.
    emb = torch.zeros(1, L + 2, DIM, dtype=DTYPE, device="cpu")
    emb[0, 1 : L + 1] = E
    return emb


def main() -> None:
    if not BASE.exists():                                                       # CHANGED
        raise FileNotFoundError(f"Base dataset not found: {BASE}")              # CHANGED

    pid_dirs = list(iter_pid_dirs(BASE, SPLITS))                                # CHANGED
    if not pid_dirs:                                                           
        print("No PID folders found — nothing to do.")                          # CHANGED
        return

    print(f"Found {len(pid_dirs)} PID folders across {SPLITS}\n")               # CHANGED
    print(f"⇢ Loading {MODEL_NAME} onto {DEVICE} …")
    from esm.models.esmc import ESMC

    model = ESMC.from_pretrained(MODEL_NAME).to(DEVICE).eval()

    made, skipped_exist, skipped_empty = 0, 0, 0
    for pid_dir in tqdm(pid_dirs, desc="Embedding (ESM-C, sliding-aware)"):     # CHANGED
        out_path = pid_dir / OUT_FILENAME                                       # CHANGED
        if out_path.exists():
            skipped_exist += 1
            continue
        # Read sequence (accept FASTA-like; strip headers/whitespace)           # CHANGED
        seq = read_sequence(pid_dir)                                            # CHANGED
        if not seq:
            tqdm.write(f"[WARN] {pid_dir.name}: missing/empty sequence.txt — skipping")  # CHANGED
            skipped_empty += 1
            continue

        # Embed (fast path or sliding-stitch)
        emb = embed_esmc_stitched(model, seq)  # (1, L+2, 1152)
        out_path.parent.mkdir(parents=True, exist_ok=True)                      # unchanged
        torch.save(emb, out_path)
        made += 1
        tqdm.write(f"✓ {pid_dir.name}: L={len(seq):4d} → {OUT_FILENAME}")      # CHANGED

    print(
        f"\nDone. created={made} | existing_skipped={skipped_exist} | missing/empty_seq_skipped={skipped_empty}"
    )


if __name__ == "__main__":
    main()


Found 36629 PID folders across ['train_pdbch', 'val_pdbch', 'test_pdbch']

⇢ Loading esmc_600m onto cuda …


Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

README.md: 0.00B [00:00, ?B/s]

.gitattributes: 0.00B [00:00, ?B/s]

config.json:   0%|          | 0.00/2.00 [00:00<?, ?B/s]

esmc_600m_2024_12_v0.pth:   0%|          | 0.00/2.30G [00:00<?, ?B/s]

  state_dict = torch.load(


Embedding (ESM-C, sliding-aware):   0%|          | 0/36629 [00:00<?, ?it/s]

✓ 154L-A: L= 185 → esmc_emb.pt
✓ 155C-A: L= 135 → esmc_emb.pt
✓ 16PK-A: L= 415 → esmc_emb.pt
✓ 16VP-A: L= 366 → esmc_emb.pt
✓ 1914-A: L= 232 → esmc_emb.pt
✓ 19HC-A: L= 292 → esmc_emb.pt
✓ 1A05-A: L= 358 → esmc_emb.pt
✓ 1A0C-A: L= 438 → esmc_emb.pt
✓ 1A0D-A: L= 440 → esmc_emb.pt
✓ 1A0E-A: L= 443 → esmc_emb.pt
✓ 1A0H-A: L= 159 → esmc_emb.pt
✓ 1A0I-A: L= 348 → esmc_emb.pt
✓ 1A0J-A: L= 223 → esmc_emb.pt
✓ 1A0Q-L: L= 212 → esmc_emb.pt
✓ 1A0R-P: L= 245 → esmc_emb.pt
✓ 1A14-H: L= 120 → esmc_emb.pt
✓ 1A14-L: L= 104 → esmc_emb.pt
✓ 1A17-A: L= 166 → esmc_emb.pt
✓ 1A1S-A: L= 314 → esmc_emb.pt
✓ 1A1Z-A: L=  91 → esmc_emb.pt
✓ 1A25-A: L= 149 → esmc_emb.pt
✓ 1A2A-A: L= 122 → esmc_emb.pt
✓ 1A2O-A: L= 349 → esmc_emb.pt
✓ 1A2Z-A: L= 220 → esmc_emb.pt
✓ 1A3W-A: L= 500 → esmc_emb.pt
✓ 1A41-A: L= 234 → esmc_emb.pt
✓ 1A47-A: L= 683 → esmc_emb.pt
✓ 1A4B-A: L= 129 → esmc_emb.pt
✓ 1A57-A: L= 116 → esmc_emb.pt
✓ 1A59-A: L= 378 → esmc_emb.pt
✓ 1A5I-A: L= 265 → esmc_emb.pt
✓ 1A5Z-A: L= 319 → esmc_emb.pt
✓ 1A63-A

In [3]:
"""
Generate Ankh3-XL encoder embeddings for PIDs that appear under the
PDBCH <split>/* directories, *reading sequences from*
PDBCH/<split>/<PID>/sequence.txt and *writing*
PDBCH/<split>/<PID>/ankh_emb_xl.pt.

Output per PID:
    <split>/<PID>/ankh_emb_xl.pt     # torch.Tensor, (1, L+2, d_model)

Logic
-----
- Positions [0] and [L+1] are zeros; positions [1..L] are per-residue embeddings.
- Uses the S2S prefix token and includes EOS in tokenization.
- If L <= WINDOW_EFF (512), run a single forward pass (fast path).
- If L > WINDOW_EFF, use sliding windows with **OVERLAP=128** and
  **context-trim** stitching (keep only the non-overlap core per window; no weighting).
  Only indices 0 and L+1 are zero.

Notes
-----
- We *only* embed PIDs that exist under <split>/* (splits: train_pdbch/val_pdbch/test_pdbch).
- Skips existing outputs and missing/empty sequences.
- Uses GPU if available.
"""
from __future__ import annotations

from pathlib import Path
from typing import Iterable, List, Tuple

import torch
from tqdm.auto import tqdm
from transformers import T5Tokenizer, T5EncoderModel

# ─────────────────────────────── config ─────────────────────────────── #
BASE          = Path("/teamspace/studios/this_studio/PFP_Testing/data/PDBCH")  # CHANGED
SPLITS        = ["train_pdbch", "val_pdbch", "test_pdbch"]                     # CHANGED
OUT_FILENAME  = "ankh_emb_xl.pt"
CKPT          = "ElnaggarLab/ankh3-xl"
PREFIX        = "[S2S]"                   # Use S2S prefix for better embedding quality

DTYPE         = torch.float16
DEVICE        = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hardcoded model dimension for Ankh3-XL
D_MODEL       = 2560

# Sliding-window hyperparameters (requested)
MAX_LEN_TOKENS   = 512      # model token limit incl. specials
N_SPECIALS_WIN   = 2         # assume <bos> and <eos> are present
WINDOW_EFF       = MAX_LEN_TOKENS - N_SPECIALS_WIN  # residues per window
OVERLAP       = 128    # residues of overlap between adjacent windows
# ────────────────────────────────────────────────────────────────────── #


def iter_pid_dirs(base: Path, splits: Iterable[str]) -> Iterable[Path]:        # CHANGED
    """Yield each PID directory under <base>/<split>/* for all splits."""
    for split in splits:
        split_dir = base / split
        if not split_dir.exists():
            tqdm.write(f"[WARN] Missing folder: {split_dir} — skipping this split")
            continue
        for item in split_dir.iterdir():
            if item.is_dir() and not item.name.startswith("."):
                yield item


def read_sequence(pid_dir: Path) -> str:                                       # CHANGED
    """Read FASTA-like or plain sequence, strip headers/whitespace, uppercase."""
    seq_path = pid_dir / "sequence.txt"                                         # CHANGED
    if not seq_path.exists():
        return ""
    parts = []
    with seq_path.open("r", encoding="utf-8") as fh:
        for line in fh:
            if line.startswith(">"):
                continue
            parts.append(line.strip())
    return ("".join(parts)).upper()


def _plan_windows(L: int) -> List[Tuple[int, int]]:
    """Return inclusive-exclusive windows [(s,e), ...] covering [0, L) with overlap."""
    if L <= WINDOW_EFF:
        return [(0, L)]
    stride = max(1, WINDOW_EFF - OVERLAP)
    windows: List[Tuple[int, int]] = []
    s = 0
    while s < L:
        e = min(s + WINDOW_EFF, L)
        windows.append((s, e))
        if e == L:
            break
        s += stride
    return windows


@torch.inference_mode()
def _embed_window_ankh(model: T5EncoderModel, tok: T5Tokenizer, subseq: str) -> torch.Tensor:
    """
    Embed a *subsequence* with Ankh3-XL and return per-residue embeddings [Li, d_model].

    We tokenize with the S2S prefix and add_special_tokens=True (includes EOS).
    The first token corresponds to the prefix; we drop the prefix and EOS,
    keeping exactly Li residue rows.
    """
    enc = tok(PREFIX + subseq, add_special_tokens=True, return_tensors="pt")
    enc = {k: v.to(model.device) for k, v in enc.items()}

    hidden = model(**enc).last_hidden_state.squeeze(0)  # [T, d_model]
    Li = len(subseq)

    # Positions: 0 -> S2S prefix; 1..Li -> residues; Li+1 -> EOS (when present)
    per_res = hidden[1 : 1 + Li]                         # [Li, d_model]
    if per_res.size(0) != Li:
        raise RuntimeError(
            f"[Ankh3-XL] Token/AA mismatch in window: expected {Li} got {per_res.size(0)}"
        )
    return per_res.to(DTYPE)


@torch.inference_mode()
def _embed_full_ankh(model: T5EncoderModel, tok: T5Tokenizer, seq: str) -> torch.Tensor:
    """
    Full-sequence Ankh3-XL embedding → (1, L+2, d_model) with zeros at [0] and [L+1],
    residues in [1..L]. (We still use S2S prefix + EOS at encode time.)
    """
    L = len(seq)
    per_res = _embed_window_ankh(model, tok, seq)        # [L, d_model]
    d = per_res.size(-1)
    out = torch.zeros(1, L + 2, d, dtype=DTYPE, device="cpu")
    out[0, 1 : L + 1] = per_res.detach().cpu()
    return out


@torch.inference_mode()
def embed_ankh_stitched(model: T5EncoderModel, tok: T5Tokenizer, seq: str) -> torch.Tensor:
    """
    Return Ankh3-XL embeddings as (1, L+2, d_model), zeros at [0] and [L+1].
    - Short sequences (L <= 512): single pass (preserves your original interface).
    - Long sequences: sliding windows + context-trim stitching (no weighting).
    """
    L = len(seq)
    if L == 0:
        raise ValueError("Empty sequence.")

    if L <= WINDOW_EFF:
        return _embed_full_ankh(model, tok, seq)

    # Sliding-window path
    windows = _plan_windows(L)
    E = torch.empty(L, D_MODEL, dtype=DTYPE, device="cpu")  # residue buffer

    for (s, e) in windows:
        subseq = seq[s:e]
        Li = e - s
        if Li > WINDOW_EFF:
            raise RuntimeError(f"Internal window larger than WINDOW_EFF: Li={Li} > {WINDOW_EFF}")

        R = _embed_window_ankh(model, tok, subseq).detach().cpu()  # [Li, d_model]

        # Context-trim: first window keeps its left edge; last keeps its right edge.
        left_trim  = 0 if s == 0 else OVERLAP // 2             # 0 or 64
        right_trim = 0 if e == L else (OVERLAP - OVERLAP // 2) # 0 or 64

        gs = s + left_trim
        ge = e - right_trim
        ls = left_trim
        le = Li - right_trim

        if ge <= gs:
            raise RuntimeError(f"Empty keep range for window {(s, e)}")
        if (le - ls) != (ge - gs):
            raise RuntimeError("Keep lengths disagree during trim-stitch.")

        E[gs:ge] = R[ls:le]

    # Pack (1, L+2, d_model) with zeros at [0] and [L+1]
    out = torch.zeros(1, L + 2, D_MODEL, dtype=DTYPE, device="cpu")
    out[0, 1 : L + 1] = E
    return out


def main() -> None:
    if not BASE.exists():                                                       # CHANGED
        raise FileNotFoundError(f"Base dataset not found: {BASE}")              # CHANGED

    # Load model/tokenizer once
    print(f"⇢ Loading {CKPT} encoder on {DEVICE} …")
    tok   = T5Tokenizer.from_pretrained(CKPT)
    model = T5EncoderModel.from_pretrained(CKPT).to(DEVICE).eval()

    pid_dirs = list(iter_pid_dirs(BASE, SPLITS))                                # CHANGED
    if not pid_dirs:                                                             # CHANGED
        print("No PID folders found — nothing to do.")                           # CHANGED
        return

    print(f"Found {len(pid_dirs)} PID folders across {SPLITS}\n")               # CHANGED

    made, skipped_exist, skipped_empty = 0, 0, 0
    for pid_dir in tqdm(pid_dirs, desc="Embedding (Ankh3-XL, sliding-aware)"):  # CHANGED
        out_path = pid_dir / OUT_FILENAME                                       # CHANGED
        if out_path.exists():
            skipped_exist += 1
            continue

        seq = read_sequence(pid_dir)                                            # CHANGED
        if not seq:
            tqdm.write(f"[WARN] {pid_dir.name}: missing/empty sequence.txt — skipping")  # CHANGED
            skipped_empty += 1
            continue

        emb = embed_ankh_stitched(model, tok, seq)  # (1, L+2, d_model)
        out_path.parent.mkdir(parents=True, exist_ok=True)                      # unchanged
        torch.save(emb, out_path)
        made += 1
        tqdm.write(f"✓ {pid_dir.name}: L={len(seq):4d} → {OUT_FILENAME}")       # CHANGED

    print(
        f"\nDone. created={made} | existing_skipped={skipped_exist} | missing/empty_seq_skipped={skipped_empty}"
    )


if __name__ == "__main__":
    main()


⇢ Loading ElnaggarLab/ankh3-xl encoder on cuda …


tokenizer_config.json: 0.00B [00:00, ?B/s]

spiece.model:   0%|          | 0.00/238k [00:00<?, ?B/s]

special_tokens_map.json: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

config.json:   0%|          | 0.00/778 [00:00<?, ?B/s]

pytorch_model.bin.index.json: 0.00B [00:00, ?B/s]

Downloading shards:   0%|          | 0/3 [00:00<?, ?it/s]

pytorch_model-00001-of-00003.bin:   0%|          | 0.00/7.99G [00:00<?, ?B/s]

model.safetensors.index.json: 0.00B [00:00, ?B/s]

pytorch_model-00002-of-00003.bin:   0%|          | 0.00/7.98G [00:00<?, ?B/s]

pytorch_model-00003-of-00003.bin:   0%|          | 0.00/6.95G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Found 36629 PID folders across ['train_pdbch', 'val_pdbch', 'test_pdbch']



Embedding (Ankh3-XL, sliding-aware):   0%|          | 0/36629 [00:00<?, ?it/s]

✓ 154L-A: L= 185 → ankh_emb_xl.pt
✓ 155C-A: L= 135 → ankh_emb_xl.pt
✓ 16PK-A: L= 415 → ankh_emb_xl.pt
✓ 16VP-A: L= 366 → ankh_emb_xl.pt
✓ 1914-A: L= 232 → ankh_emb_xl.pt
✓ 19HC-A: L= 292 → ankh_emb_xl.pt
✓ 1A05-A: L= 358 → ankh_emb_xl.pt
✓ 1A0C-A: L= 438 → ankh_emb_xl.pt
✓ 1A0D-A: L= 440 → ankh_emb_xl.pt
✓ 1A0E-A: L= 443 → ankh_emb_xl.pt
✓ 1A0H-A: L= 159 → ankh_emb_xl.pt
✓ 1A0I-A: L= 348 → ankh_emb_xl.pt
✓ 1A0J-A: L= 223 → ankh_emb_xl.pt
✓ 1A0Q-L: L= 212 → ankh_emb_xl.pt
✓ 1A0R-P: L= 245 → ankh_emb_xl.pt
✓ 1A14-H: L= 120 → ankh_emb_xl.pt
✓ 1A14-L: L= 104 → ankh_emb_xl.pt
✓ 1A17-A: L= 166 → ankh_emb_xl.pt
✓ 1A1S-A: L= 314 → ankh_emb_xl.pt
✓ 1A1Z-A: L=  91 → ankh_emb_xl.pt
✓ 1A25-A: L= 149 → ankh_emb_xl.pt
✓ 1A2A-A: L= 122 → ankh_emb_xl.pt
✓ 1A2O-A: L= 349 → ankh_emb_xl.pt
✓ 1A2Z-A: L= 220 → ankh_emb_xl.pt
✓ 1A3W-A: L= 500 → ankh_emb_xl.pt
✓ 1A41-A: L= 234 → ankh_emb_xl.pt
✓ 1A47-A: L= 683 → ankh_emb_xl.pt
✓ 1A4B-A: L= 129 → ankh_emb_xl.pt
✓ 1A57-A: L= 116 → ankh_emb_xl.pt
✓ 1A59-A: L= 3

In [4]:
"""
Generate ProteinGLM per-residue embeddings for PIDs that appear under the
PFP_Testing/data/PDBCH <split>/* directories, *reading sequences from*
<split>/<PID>/sequence.txt and *writing*
<split>/<PID>/pglm_emb.pt.

This preserves your original per-residue embedding layout and only changes how we
handle long sequences: if L > WINDOW_EFF, we use sliding windows with overlap
and **context-trim** stitching (no weighting). For long sequences, indices 1..L
contain per-residue embeddings; indices 0 and L+1 are zero.

Layout expected:
    /teamspace/studios/this_studio/PFP_Testing/data/PDBCH/
      ├─ train_pdbch/<PID>/
      ├─ val_pdbch/<PID>/
      ├─ test_pdbch/<PID>/
      └─ <split>/<PID>/sequence.txt   # created by your flat builder

Output per PID:
    <split>/<PID>/pglm_emb.pt  # torch.Tensor, shape (1, L+2, 2560)

Run:
    python embed_pglm_cafa_style.py

Notes:
- Skips PIDs whose embeddings already exist.
- Skips PIDs with missing/empty sequences, with a warning.
- Uses GPU if available.
"""
from __future__ import annotations

from pathlib import Path
from typing import Iterable, Set, List, Tuple

import torch
from tqdm.auto import tqdm
from transformers import AutoTokenizer, AutoModelForMaskedLM

# ─────────────────────────────── config ─────────────────────────────── #
BASE          = Path("/teamspace/studios/this_studio/PFP_Testing/data/PDBCH")  # CHANGED
SPLITS        = ["train_pdbch", "val_pdbch", "test_pdbch"]                     # CHANGED

MODEL_ID      = "biomap-research/proteinglm-3b-mlm"   # HF repo id
OUT_FILENAME  = "pglm_emb.pt"                         # per-protein output
DTYPE_SAVE    = torch.float16
DEVICE        = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ProteinGLM specifics (as requested)
DIM              = 2560
MAX_LEN_TOKENS   = 2048      # model token limit incl. specials
N_SPECIALS_WIN   = 1         # assume only <eos> is added; set to 2 if BOS is present
WINDOW_EFF       = MAX_LEN_TOKENS - N_SPECIALS_WIN  # residues per window
OVERLAP          = 256        # residues overlap between windows
# ────────────────────────────────────────────────────────────────────── #


def iter_pid_dirs(base: Path, splits: Iterable[str]):                           # CHANGED
    """Yield each PID directory under <base>/<split>/* for all splits."""       # CHANGED
    for split in splits:                                                         # CHANGED
        split_dir = base / split                                                 # CHANGED
        if not split_dir.exists():                                               # CHANGED
            tqdm.write(f"[WARN] Missing folder: {split_dir} — skipping this split")  # CHANGED
            continue                                                             # CHANGED
        for item in split_dir.iterdir():                                         # CHANGED
            if item.is_dir() and not item.name.startswith("."):                  # CHANGED
                yield item                                                       # CHANGED


def read_sequence(pid_dir: Path) -> str:                                        # CHANGED
    """Read FASTA-like or plain sequence, strip headers/whitespace, uppercase."""# CHANGED
    seq_path = pid_dir / "sequence.txt"                                         # CHANGED
    if not seq_path.exists():
        return ""
    seq_parts = []
    with seq_path.open("r", encoding="utf-8") as fh:
        for line in fh:
            if line.startswith(">"):
                continue
            seq_parts.append(line.strip())
    return ("".join(seq_parts)).upper().replace(" ", "")


def _plan_windows(L: int) -> List[Tuple[int, int]]:
    """Return inclusive-exclusive windows [(s,e), ...] covering [0, L) with overlap."""
    if L <= WINDOW_EFF:
        return [(0, L)]
    stride = max(1, WINDOW_EFF - OVERLAP)
    windows: List[Tuple[int, int]] = []
    s = 0
    while s < L:
        e = min(s + WINDOW_EFF, L)
        windows.append((s, e))
        if e == L:
            break
        s += stride
    return windows


@torch.inference_mode()
def _embed_window_pglm(model: AutoModelForMaskedLM, tok: AutoTokenizer, subseq: str) -> torch.Tensor:
    """
    Embed a *subsequence* with ProteinGLM, returning per-residue embeddings [Li, DIM].
    Uses exact logic from the original embedding function.
    """
    # Tokenize with special tokens (adds <eos>)
    out = tok(subseq, add_special_tokens=True, return_tensors='pt')
    inputs = {
        "input_ids": out["input_ids"].to(model.device),
        "attention_mask": out["attention_mask"].to(model.device),
    }

    # Model forward per Hugging Face usage; drop trailing <eos> token embeddings
    out_m = model(**inputs, output_hidden_states=True, return_last_hidden_state=True)
    token_emb = out_m.hidden_states[:-1, 0]  # per docs: remove <eos>, take batch=0 → [Li, d]

    if token_emb.size(0) != len(subseq):
        raise RuntimeError(
            f"[ProteinGLM] Token/AA mismatch in window: tokens_no_specials={token_emb.size(0)} vs Li={len(subseq)}"
        )
    return token_emb.to(DTYPE_SAVE)


@torch.inference_mode()
def _embed_full_pglm(model: AutoModelForMaskedLM, tok: AutoTokenizer, seq: str) -> torch.Tensor:
    """
    Full-sequence per-residue embeddings packed to (1, L+2, DIM) with zeros at [0] and [L+1].
    Uses exact logic from the original embedding function.
    """
    L = len(seq)

    # Tokenize with special tokens (adds <eos>)
    out = tok(seq, add_special_tokens=True, return_tensors='pt')
    inputs = {
        "input_ids": out["input_ids"].to(model.device),
        "attention_mask": out["attention_mask"].to(model.device),
    }

    # Model forward per Hugging Face usage; drop trailing <eos> token embeddings
    out_m = model(**inputs, output_hidden_states=True, return_last_hidden_state=True)
    token_emb = out_m.hidden_states[:-1, 0]  # per docs: remove <eos>, take batch=0 → [L, d]

    if token_emb.size(0) != L:
        raise RuntimeError(f"[ProteinGLM] Full pass token/AA mismatch: got {token_emb.size(0)} vs L={L}")

    # Pack to (1, L+2, d); zero at positions 0 and L+1; move to CPU float32 for saving
    token_emb = token_emb.to(DTYPE_SAVE).cpu()  # [L, d]
    d = token_emb.size(-1)
    emb = torch.zeros(1, L + 2, d, dtype=DTYPE_SAVE, device="cpu")
    emb[0, 1:L+1] = token_emb
    return emb


@torch.inference_mode()
def embed_pglm_stitched(model: AutoModelForMaskedLM, tok: AutoTokenizer, seq: str) -> torch.Tensor:
    """
    Return ProteinGLM embeddings as (1, L+2, DIM).

    - Short sequences (L <= WINDOW_EFF): single pass; zeros at [0] and [L+1].
    - Long sequences: sliding windows + context-trim stitching for residues;
      pack to (1, L+2, DIM) with zeros at [0] and [L+1].
    """
    L = len(seq)
    if L == 0:
        raise ValueError("Empty sequence.")

    # Fast path
    if L <= WINDOW_EFF:
        return _embed_full_pglm(model, tok, seq)

    # Sliding-window path
    windows = _plan_windows(L)
    E = torch.empty(L, DIM, dtype=DTYPE_SAVE, device="cpu")

    for (s, e) in windows:
        subseq = seq[s:e]
        Li = e - s
        if Li > WINDOW_EFF:
            raise RuntimeError(f"Internal window larger than WINDOW_EFF: Li={Li} > {WINDOW_EFF}")

        R = _embed_window_pglm(model, tok, subseq).detach().cpu()  # (Li, DIM)

        # Context-trim stitching (identical pattern to your ProtT5 script)
        left_trim  = 0 if s == 0 else OVERLAP // 2
        right_trim = 0 if e == L else (OVERLAP - OVERLAP // 2)

        gs = s + left_trim           # global start (inclusive)
        ge = e - right_trim          # global end   (exclusive)
        ls = left_trim               # local  start
        le = Li - right_trim         # local  end

        if ge <= gs:
            raise RuntimeError(f"Empty keep range for window {(s, e)}")
        if (le - ls) != (ge - gs):
            raise RuntimeError("Keep lengths disagree during trim-stitch.")

        E[gs:ge] = R[ls:le]

    # Pack to (1, L+2, DIM) with zeros at [0] and [L+1]
    emb = torch.zeros(1, L + 2, DIM, dtype=DTYPE_SAVE, device="cpu")
    emb[0, 1:L+1] = E
    return emb


# ─────────────────────────────────── main ───────────────────────────── #

def main() -> None:
    if not BASE.exists():                                                       # CHANGED
        raise FileNotFoundError(f"Base dataset not found: {BASE}")              # CHANGED

    # Load model/tokenizer once
    print(f"⇢ Loading {MODEL_ID} onto {DEVICE} …")

    tok = AutoTokenizer.from_pretrained(MODEL_ID)
    model = AutoModelForMaskedLM.from_pretrained(MODEL_ID).to(DEVICE).eval()

    pid_dirs = list(iter_pid_dirs(BASE, SPLITS))                                # CHANGED
    if not pid_dirs:                                                             # CHANGED
        print("No PID folders found — nothing to do.")                           # CHANGED
        return                                                                   # CHANGED

    print(f"Found {len(pid_dirs)} PID folders across {SPLITS}\n")               # CHANGED

    made, skipped_exist, skipped_empty = 0, 0, 0                                # CHANGED
    for pid_dir in tqdm(pid_dirs, desc="Embedding (ProteinGLM, sliding-aware)"): # CHANGED
        out_path = pid_dir / OUT_FILENAME                                       # CHANGED
        if out_path.exists():
            skipped_exist += 1
            continue

        seq = read_sequence(pid_dir)                                            # CHANGED
        if not seq:
            tqdm.write(f"[WARN] {pid_dir.name}: missing/empty sequence.txt — skipping")  # CHANGED
            skipped_empty += 1
            continue

        emb = embed_pglm_stitched(model, tok, seq)  # (1, L+2, 2560)
        out_path.parent.mkdir(parents=True, exist_ok=True)                      # unchanged
        torch.save(emb, out_path)
        made += 1
        tqdm.write(f"✓ {pid_dir.name}: L={len(seq):4d} → {OUT_FILENAME}")       # CHANGED

    print(
        f"\nDone. created={made} | existing_skipped={skipped_exist} | missing/empty_seq_skipped={skipped_empty}"
    )


if __name__ == "__main__":
    main()


⇢ Loading biomap-research/proteinglm-3b-mlm onto cuda …


tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenization_proteinglm.py: 0.00B [00:00, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/biomap-research/proteinglm-3b-mlm:
- tokenization_proteinglm.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


tokenizer.model:   0%|          | 0.00/112 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/260 [00:00<?, ?B/s]

config.json: 0.00B [00:00, ?B/s]

configuration_proteinglm.py: 0.00B [00:00, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/biomap-research/proteinglm-3b-mlm:
- configuration_proteinglm.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


modeling_proteinglm.py: 0.00B [00:00, ?B/s]

quantization.py: 0.00B [00:00, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/biomap-research/proteinglm-3b-mlm:
- quantization.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.
A new version of the following files was downloaded from https://huggingface.co/biomap-research/proteinglm-3b-mlm:
- modeling_proteinglm.py
- quantization.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


[2025-09-16 03:21:45,983] [INFO] [real_accelerator.py:254:get_accelerator] Setting ds_accelerator to cuda (auto detect)


/home/zeus/miniconda3/envs/cloudspace/compiler_compat/ld: cannot find -laio: No such file or directory
collect2: error: ld returned 1 exit status
/home/zeus/miniconda3/envs/cloudspace/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::runtime_error::~runtime_error()@GLIBCXX_3.4'
/home/zeus/miniconda3/envs/cloudspace/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `__gxx_personality_v0@CXXABI_1.3'
/home/zeus/miniconda3/envs/cloudspace/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::ostream::tellp()@GLIBCXX_3.4'
/home/zeus/miniconda3/envs/cloudspace/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::chrono::_V2::steady_clock::now()@GLIBCXX_3.4.19'
/home/zeus/miniconda3/envs/cloudspace/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::string::_M_replace_aux(unsigned long, unsigned long, unsigned long, char)@GLIBCXX_3.4'


[2025-09-16 03:21:47,384] [INFO] [logging.py:107:log_dist] [Rank -1] [TorchCheckpointEngine] Initialized with serialization = False


Failed to load cpm_kernels:No module named 'cpm_kernels'


model.safetensors.index.json: 0.00B [00:00, ?B/s]

Downloading shards:   0%|          | 0/3 [00:00<?, ?it/s]

model-00001-of-00003.safetensors:   0%|          | 0.00/4.97G [00:00<?, ?B/s]

model-00002-of-00003.safetensors:   0%|          | 0.00/4.90G [00:00<?, ?B/s]

model-00003-of-00003.safetensors:   0%|          | 0.00/1.47G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Found 36629 PID folders across ['train_pdbch', 'val_pdbch', 'test_pdbch']



Embedding (ProteinGLM, sliding-aware):   0%|          | 0/36629 [00:00<?, ?it/s]

✓ 154L-A: L= 185 → pglm_emb.pt
✓ 155C-A: L= 135 → pglm_emb.pt
✓ 16PK-A: L= 415 → pglm_emb.pt
✓ 16VP-A: L= 366 → pglm_emb.pt
✓ 1914-A: L= 232 → pglm_emb.pt
✓ 19HC-A: L= 292 → pglm_emb.pt
✓ 1A05-A: L= 358 → pglm_emb.pt
✓ 1A0C-A: L= 438 → pglm_emb.pt
✓ 1A0D-A: L= 440 → pglm_emb.pt
✓ 1A0E-A: L= 443 → pglm_emb.pt
✓ 1A0H-A: L= 159 → pglm_emb.pt
✓ 1A0I-A: L= 348 → pglm_emb.pt
✓ 1A0J-A: L= 223 → pglm_emb.pt
✓ 1A0Q-L: L= 212 → pglm_emb.pt
✓ 1A0R-P: L= 245 → pglm_emb.pt
✓ 1A14-H: L= 120 → pglm_emb.pt
✓ 1A14-L: L= 104 → pglm_emb.pt
✓ 1A17-A: L= 166 → pglm_emb.pt
✓ 1A1S-A: L= 314 → pglm_emb.pt
✓ 1A1Z-A: L=  91 → pglm_emb.pt
✓ 1A25-A: L= 149 → pglm_emb.pt
✓ 1A2A-A: L= 122 → pglm_emb.pt
✓ 1A2O-A: L= 349 → pglm_emb.pt
✓ 1A2Z-A: L= 220 → pglm_emb.pt
✓ 1A3W-A: L= 500 → pglm_emb.pt
✓ 1A41-A: L= 234 → pglm_emb.pt
✓ 1A47-A: L= 683 → pglm_emb.pt
✓ 1A4B-A: L= 129 → pglm_emb.pt
✓ 1A57-A: L= 116 → pglm_emb.pt
✓ 1A59-A: L= 378 → pglm_emb.pt
✓ 1A5I-A: L= 265 → pglm_emb.pt
✓ 1A5Z-A: L= 319 → pglm_emb.pt
✓ 1A63-A

In [1]:
from pathlib import Path
from tqdm.auto import tqdm

SEQS_DIR = Path("/teamspace/studios/this_studio/CAFA_Style/sequences")
TARGET_NAME = "esmc_emb.pt"

files = []
for pid_dir in SEQS_DIR.iterdir():
    if pid_dir.is_dir() and not pid_dir.name.startswith("."):
        f = pid_dir / TARGET_NAME
        if f.exists():
            files.append(f)

deleted = 0
for f in tqdm(files, desc="Deleting ankh embeddings"):
    try:
        f.unlink()
        deleted += 1
    except Exception as e:
        print(f"Failed to delete {f}: {e}")

print(f"Deleted {deleted} files." if deleted else "No ankh embeddings found.")

Deleting ankh embeddings:   0%|          | 0/32546 [00:00<?, ?it/s]

Deleted 32546 files.


In [None]:
from pathlib import Path
from tqdm.auto import tqdm
import torch

SEQS_DIR = Path("/teamspace/studios/this_studio/CAFA_Style/sequences")
TARGET = "ankh_emb_xl.pt"

paths = list(SEQS_DIR.rglob(TARGET))
converted = 0

for f in tqdm(paths, desc="Converting ankh_emb_xl.pt → FP16"):
    try:
        obj = torch.load(f, map_location="cpu")
        if isinstance(obj, torch.Tensor) and obj.dtype != torch.float16:
            torch.save(obj.half(), f)
            converted += 1
    except Exception as e:
        print(f"Skip {f}: {e}")

print(f"Converted {converted} file(s) to FP16.")


Converting ankh_emb_xl.pt → FP16:   0%|          | 0/20821 [00:00<?, ?it/s]

  obj = torch.load(f, map_location="cpu")


Skip /teamspace/studios/this_studio/CAFA_Style/sequences/GRRE1_HUMAN/ankh_emb_xl.pt: File /teamspace/studios/this_studio/CAFA_Style/sequences/GRRE1_HUMAN/ankh_emb_xl.pt cannot be opened.
Skip /teamspace/studios/this_studio/CAFA_Style/sequences/JAZF1_HUMAN/ankh_emb_xl.pt: File /teamspace/studios/this_studio/CAFA_Style/sequences/JAZF1_HUMAN/ankh_emb_xl.pt cannot be opened.
Skip /teamspace/studios/this_studio/CAFA_Style/sequences/KPCD2_HUMAN/ankh_emb_xl.pt: File /teamspace/studios/this_studio/CAFA_Style/sequences/KPCD2_HUMAN/ankh_emb_xl.pt cannot be opened.
Skip /teamspace/studios/this_studio/CAFA_Style/sequences/MMP8_HUMAN/ankh_emb_xl.pt: File /teamspace/studios/this_studio/CAFA_Style/sequences/MMP8_HUMAN/ankh_emb_xl.pt cannot be opened.
Skip /teamspace/studios/this_studio/CAFA_Style/sequences/NMS_RAT/ankh_emb_xl.pt: File /teamspace/studios/this_studio/CAFA_Style/sequences/NMS_RAT/ankh_emb_xl.pt cannot be opened.
Skip /teamspace/studios/this_studio/CAFA_Style/sequences/NMT1_ARATH/ankh_em

In [6]:
import torch

# Load the two embedding files
ankh_emb = torch.load('/teamspace/studios/this_studio/CAFA_Style/sequences/1A1L1_HUMAN/ankh_emb_xl.pt')
pglm_emb = torch.load('/teamspace/studios/this_studio/CAFA_Style/sequences/1A1L1_HUMAN/pglm_emb.pt')

# Compare the sizes
print(f"ANKH embedding shape: {ankh_emb.shape}")
print(f"PGLM embedding shape: {pglm_emb.shape}")
print(f"ANKH embedding size: {ankh_emb.numel()}")
print(f"PGLM embedding size: {pglm_emb.numel()}")

# Calculate memory sizes in gigabytes
ankh_size_gb = ankh_emb.numel() * ankh_emb.element_size() / (1024**3)
pglm_size_gb = pglm_emb.numel() * pglm_emb.element_size() / (1024**3)

print(f"ANKH embedding size: {ankh_size_gb:.4f} GB")
print(f"PGLM embedding size: {pglm_size_gb:.4f} GB")
print(f"Size difference: {abs(ankh_size_gb - pglm_size_gb):.4f} GB")

ANKH embedding shape: torch.Size([1, 503, 2560])
PGLM embedding shape: torch.Size([1, 503, 2560])
ANKH embedding size: 1287680
PGLM embedding size: 1287680
ANKH embedding size: 0.0048 GB
PGLM embedding size: 0.0024 GB
Size difference: 0.0024 GB


  ankh_emb = torch.load('/teamspace/studios/this_studio/CAFA_Style/sequences/1A1L1_HUMAN/ankh_emb_xl.pt')
  pglm_emb = torch.load('/teamspace/studios/this_studio/CAFA_Style/sequences/1A1L1_HUMAN/pglm_emb.pt')
