In [14]:
"""
Generate ProtT5 (XL-UniRef50) sequence embeddings for PIDs that appear under
CAFA_Style <split>/mf/*, reading sequences from
CAFA_Style/sequences/<PID>/seq.txt and writing the embeddings to
CAFA_Style/sequences/<PID>/prot_t5_emb.pt.

Output per PID:
    sequences/<PID>/prot_t5_emb.pt   # torch.Tensor, shape (1, L+2, DIM)
                                     # CLS/EOS slots are zero-filled

Notes
-----
- We *only* embed PIDs that exist under <split>/mf (splits: train/val/test).
- Tokenization follows the official ProtT5 recipe: space-separated residues.
- Ambiguous residues U,Z,O,B are mapped to 'X' (per official ProtT5 preprocessing).
- If L <= WINDOW_EFF, run a single forward pass (fast path).
- If L > WINDOW_EFF, use sliding windows with overlap and **context-trim**:
  in each window keep only the non-overlap core (no weighting). Only indices
  0 and L+1 are zero.
- Existing files are skipped silently; progress bar shows overall progress.
- Preserves original embedding logic; only I/O and PID enumeration changed.
"""
from __future__ import annotations

from pathlib import Path
from typing import Iterable, Set, List, Tuple
import re
import math

import torch
from tqdm.auto import tqdm
from transformers import AutoTokenizer, T5EncoderModel

# ─────────────────────────────── config ─────────────────────────────── #
# Module-level settings read by every function in this cell.
# NOTE(review): BASE is a hard-coded absolute path; main() raises
# FileNotFoundError when it does not exist — consider making it configurable.
BASE          = Path("/teamspace/studios/this_studio/CAFA_Style")
ONTOLOGY      = "mf"                       # sub-folder of each split that is scanned for PID directories
SPLITS        = ["train", "val", "test"]
SEQS_DIR      = BASE / "sequences"         # <PID>/seq.txt is read from here; embeddings are written next to it
OUT_FILENAME  = "prot_t5_emb.pt"           # saved tensor is (1, L+2, DIM); rows 0 and L+1 are zeros
MODEL_NAME    = "Rostlab/prot_t5_xl_uniref50"

DEVICE        = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DTYPE         = torch.float32              # dtype of the returned/saved embeddings
DIM           = 1024                       # width of the stitched buffers; must equal the model's hidden size

# Sliding-window settings. Each window is tokenized with one trailing EOS
# token, so the residue capacity of a window is MAX_LEN_TOKENS - 1.
MAX_LEN_TOKENS = 512     # per-window token budget used by this script (incl. EOS)
N_SPECIALS_WIN = 1        # special tokens per window (EOS only)
WINDOW_EFF     = MAX_LEN_TOKENS - N_SPECIALS_WIN  # residues per window
OVERLAP        = 128      # residues shared by adjacent windows; split floor/ceil between the two neighbours when trimming
# ────────────────────────────────────────────────────────────────────── #


def gather_pids_from_mf(base: Path, splits: Iterable[str]) -> Set[str]:
    """Collect PID directory names found under ``<base>/<split>/<ONTOLOGY>``.

    Args:
        base: Root folder holding one sub-folder per split.
        splits: Split names to scan.

    Returns:
        The union, over all splits, of the names of the sub-directories in
        each split's ontology folder. Plain files are ignored. A split whose
        ontology folder does not exist contributes nothing and produces one
        ``[WARN]`` line through ``tqdm.write``.
    """
    collected: Set[str] = set()
    for split_name in splits:
        ontology_dir = base / split_name / ONTOLOGY
        if ontology_dir.exists():
            collected.update(entry.name for entry in ontology_dir.iterdir() if entry.is_dir())
        else:
            tqdm.write(f"[WARN] Missing folder: {ontology_dir} — skipping this split")
    return collected


def read_sequence(seqs_dir: Path, pid: str) -> str:
    """Load the residue string for ``pid`` from ``<seqs_dir>/<pid>/seq.txt``.

    The file may be plain text or FASTA-like: lines starting with ``>`` are
    skipped and the remaining lines are concatenated.

    All whitespace is removed, including whitespace *inside* a line. The
    embedding step joins residues with single spaces, so a stray blank or tab
    in the sequence would otherwise produce a token/residue count mismatch.

    Args:
        seqs_dir: Directory that contains one sub-folder per PID.
        pid: Protein identifier (sub-folder name).

    Returns:
        The upper-cased sequence with ``U``, ``Z``, ``O`` and ``B`` replaced by
        ``X``, or ``""`` when the file is missing or holds no residues.
    """
    seq_path = seqs_dir / pid / "seq.txt"
    if not seq_path.exists():
        return ""
    with seq_path.open("r", encoding="utf-8") as fh:
        # "".join(line.split()) drops every whitespace character, not only
        # the leading/trailing ones that str.strip() would remove.
        residues = "".join(
            "".join(line.split()) for line in fh if not line.startswith(">")
        )
    # Map ambiguous/rare residues to X after upper-casing so that lower-case
    # input is covered by the same substitution.
    return re.sub(r"[UZOB]", "X", residues.upper())


@torch.inference_mode()
def _embed_window(model: T5EncoderModel, tok: AutoTokenizer, subseq: str) -> torch.Tensor:
    """Run one forward pass over ``subseq`` and return its residue rows.

    The residues are joined with single spaces before tokenization and the
    tokenizer is asked to add special tokens. The last row of the model
    output is discarded (the original code treats it as the EOS position).

    Args:
        model: Encoder whose output exposes ``last_hidden_state``.
        tok: Tokenizer called with the space-separated residue string.
        subseq: Residues to embed.

    Returns:
        Tensor with ``len(subseq)`` rows, cast to ``DTYPE``.

    Raises:
        RuntimeError: If the number of rows left after dropping the last one
            differs from ``len(subseq)``.
    """
    batch = tok(" ".join(subseq), return_tensors="pt", add_special_tokens=True)
    batch = {name: tensor.to(DEVICE) for name, tensor in batch.items()}

    hidden = model(**batch).last_hidden_state[0]  # first (only) batch item
    residue_rows = hidden[:-1]                    # everything except the trailing row
    n_rows = residue_rows.size(0)
    if n_rows != len(subseq):
        raise RuntimeError(
            f"[ProtT5] Token/AA mismatch in window: tokens_no_eos={n_rows} vs Li={len(subseq)}"
        )
    return residue_rows.to(DTYPE)


def _plan_windows(L: int) -> List[Tuple[int, int]]:
    """Return inclusive-exclusive windows [(s,e), ...] covering [0, L) with overlap."""
    if L <= WINDOW_EFF:
        return [(0, L)]
    stride = max(1, WINDOW_EFF - OVERLAP)
    windows: List[Tuple[int, int]] = []
    s = 0
    while s < L:
        e = min(s + WINDOW_EFF, L)
        windows.append((s, e))
        if e == L:
            break
        s += stride
    return windows


@torch.inference_mode()
def embed_prott5_stitched(model: T5EncoderModel, tok: AutoTokenizer, seq: str) -> torch.Tensor:
    """Embed ``seq`` and pack the residue rows into a ``(1, L+2, dim)`` CPU tensor.

    Rows ``0`` and ``L+1`` of the result are zeros; rows ``1..L`` hold the
    per-residue embeddings.

    * ``len(seq) <= WINDOW_EFF``: a single call to ``_embed_window``.
    * Longer sequences: every window from ``_plan_windows`` is embedded and
      only its core is kept. Interior window edges lose ``OVERLAP // 2``
      residues on the left and ``OVERLAP - OVERLAP // 2`` on the right; the
      first window keeps its left edge and the last keeps its right edge.

    Raises:
        RuntimeError: If a planned window exceeds ``WINDOW_EFF`` or the trim
            leaves nothing to keep.
    """
    n_res = len(seq)

    # Short sequence: one forward pass covers everything.
    if n_res <= WINDOW_EFF:
        residues = _embed_window(model, tok, seq)
        packed = torch.zeros(1, n_res + 2, residues.size(-1), dtype=DTYPE, device="cpu")
        packed[0, 1 : n_res + 1] = residues.detach().cpu()
        return packed

    # Long sequence: embed window by window and stitch the cores on the CPU.
    half_left = OVERLAP // 2
    half_right = OVERLAP - half_left
    stitched = torch.empty(n_res, DIM, dtype=DTYPE, device="cpu")

    for start, end in _plan_windows(n_res):
        span = end - start
        if span > WINDOW_EFF:
            raise RuntimeError(f"Internal window larger than WINDOW_EFF: Li={span} > {WINDOW_EFF}")

        rows = _embed_window(model, tok, seq[start:end]).detach().cpu()

        # Outer edges of the sequence are never trimmed.
        drop_left = half_left if start != 0 else 0
        drop_right = half_right if end != n_res else 0

        dst = slice(start + drop_left, end - drop_right)  # position in the full sequence
        src = slice(drop_left, span - drop_right)         # position inside this window

        if dst.stop <= dst.start:
            raise RuntimeError(f"Empty keep range for window {(start,end)} with trims {drop_left},{drop_right}")
        if (src.stop - src.start) != (dst.stop - dst.start):
            raise RuntimeError("Keep lengths disagree (local vs global) during trim-stitch.")

        stitched[dst] = rows[src]

    # Surround the residue rows with one zero row on each side.
    packed = torch.zeros(1, n_res + 2, DIM, dtype=DTYPE, device="cpu")
    packed[0, 1 : n_res + 1] = stitched
    return packed


def main() -> None:
    """Embed every PID found under ``<split>/<ONTOLOGY>`` and save one tensor per PID.

    PIDs whose output file already exists are skipped, as are PIDs whose
    sequence is missing or empty. A one-line summary is printed at the end.

    Each tensor is first written to ``<OUT_FILENAME>.tmp`` and then renamed
    onto the final path. Because reruns skip any PID whose output file exists,
    writing directly to the final path would let an interrupted run leave a
    truncated file that is never regenerated.

    Raises:
        FileNotFoundError: If ``BASE`` does not exist.
    """
    if not BASE.exists():
        raise FileNotFoundError(f"CAFA_Style base not found: {BASE}")
    SEQS_DIR.mkdir(parents=True, exist_ok=True)

    pids = sorted(gather_pids_from_mf(BASE, SPLITS))
    if not pids:
        print("No PIDs found under <split>/mf — nothing to do.")
        return

    # Load model & tokenizer once, only after we know there is work to do.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, do_lower_case=False, use_fast=False)
    model = T5EncoderModel.from_pretrained(MODEL_NAME).to(DEVICE).eval()

    made, skipped_exist, skipped_empty = 0, 0, 0
    for pid in tqdm(pids, desc="ProtT5 embedding"):
        out_path = SEQS_DIR / pid / OUT_FILENAME
        if out_path.exists():
            skipped_exist += 1
            continue  # skip silently

        seq = read_sequence(SEQS_DIR, pid)
        if not seq:
            skipped_empty += 1
            continue

        emb = embed_prott5_stitched(model, tokenizer, seq)  # (1, L+2, DIM); zeros only at [0] and [L+1]
        out_path.parent.mkdir(parents=True, exist_ok=True)

        # Write-then-rename so out_path only ever appears fully written.
        tmp_path = out_path.with_name(out_path.name + ".tmp")
        torch.save(emb, tmp_path)
        tmp_path.replace(out_path)
        made += 1

    print(
        f"Done. created={made} | existing_skipped={skipped_exist} | missing_seq_skipped={skipped_empty}"
    )


if __name__ == "__main__":
    main()


ProtT5 embedding:   0%|          | 0/32546 [00:00<?, ?it/s]

Done. created=12708 | existing_skipped=19838 | missing_seq_skipped=0


In [1]:
"""
Generate ESM-C (600M) sequence embeddings for PIDs that appear under the
CAFA_Style <split>/mf/* directories, *reading sequences from*
CAFA_Style/sequences/<PID>/seq.txt and *writing*
CAFA_Style/sequences/<PID>/esmc_emb.pt.

Output per PID:
    sequences/<PID>/esmc_emb.pt     # torch.Tensor, shape (1, L+2, 1152)

Logic
-----
- If L <= WINDOW_EFF, run a single forward pass (fast path). This preserves
  the original ESM-C behavior, including native [CLS]/[EOS] embeddings.
- If L > WINDOW_EFF, use sliding windows with overlap and **context-trim**:
  for each window, embed, drop CLS/EOS to get per-residue rows, then keep
  only the non-overlap core (no weighting). Stitch into a full [L, 1152]
  buffer. We then pack to (1, L+2, 1152) with zeros at positions [0] and
  [L+1] (global CLS/EOS are not available under windowing).

Notes
-----
- We *only* embed PIDs that exist under <split>/mf (splits: train/val/test).
- Existing files are skipped.
- Uses GPU if available.
"""
from __future__ import annotations

import re
from pathlib import Path
from typing import Iterable, Set, List, Tuple

import torch
from tqdm.auto import tqdm

# ─────────────────────────────── config ─────────────────────────────── #
# Module-level settings read by every function in this cell.
# NOTE(review): BASE is a hard-coded absolute path; main() raises
# FileNotFoundError when it does not exist — consider making it configurable.
BASE          = Path("/teamspace/studios/this_studio/CAFA_Style")
ONTOLOGY      = "mf"                       # sub-folder of each split that is scanned for PID directories
SPLITS        = ["train", "val", "test"]
SEQS_DIR      = BASE / "sequences"        # <PID>/seq.txt is read from here; esmc_emb.pt is written next to it
OUT_FILENAME  = "esmc_emb.pt"             # file name written per PID
MODEL_NAME    = "esmc_600m"               # tag passed to ESMC.from_pretrained() in main()

DEVICE        = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DTYPE         = torch.float32             # dtype of the returned/saved embeddings
DIM           = 1152                      # width of the stitched buffers; must equal the model's embedding size

# Sliding-window settings.
# The windowed path drops the first and last row of each window's output
# (treated as CLS and EOS), hence two special tokens per window.
MAX_LEN_TOKENS = 2048     # per-window token budget used by this script (incl. CLS+EOS)
N_SPECIALS_WIN = 2        # CLS + EOS
WINDOW_EFF     = MAX_LEN_TOKENS - N_SPECIALS_WIN   # residues per window
OVERLAP        = 256      # residues shared by adjacent windows; split floor/ceil between the two neighbours when trimming
# ────────────────────────────────────────────────────────────────────── #


def gather_pids_from_mf(base: Path, splits: Iterable[str]) -> Set[str]:
    """Return the set of PID folder names present under ``<base>/<split>/<ONTOLOGY>``.

    Only directories count as PIDs. When a split has no ontology folder a
    ``[WARN]`` line is emitted through ``tqdm.write`` and the split is skipped.
    """
    found: Set[str] = set()
    for split in splits:
        split_root = base / split / ONTOLOGY
        if not split_root.exists():
            tqdm.write(f"[WARN] Missing folder: {split_root} — skipping this split")
            continue
        found |= {child.name for child in split_root.iterdir() if child.is_dir()}
    return found


def read_sequence(seqs_dir: Path, pid: str) -> str:
    """Load the residue string for ``pid`` from ``<seqs_dir>/<pid>/seq.txt``.

    Accepts plain or FASTA-like text: lines starting with ``>`` are skipped
    and the rest are concatenated. All whitespace is removed, including blanks
    or tabs inside a line, so the returned length equals the residue count.

    Args:
        seqs_dir: Directory that contains one sub-folder per PID.
        pid: Protein identifier (sub-folder name).

    Returns:
        The upper-cased sequence, or ``""`` when the file is missing or holds
        no residues.
    """
    seq_path = seqs_dir / pid / "seq.txt"
    if not seq_path.exists():
        return ""
    with seq_path.open("r", encoding="utf-8") as fh:
        # "".join(line.split()) removes every whitespace character in the
        # line, not only the leading/trailing ones str.strip() would drop.
        chunks = ["".join(line.split()) for line in fh if not line.startswith(">")]
    return "".join(chunks).upper()


def _plan_windows(L: int) -> List[Tuple[int, int]]:
    """Return inclusive-exclusive windows [(s,e), ...] covering [0, L) with overlap."""
    if L <= WINDOW_EFF:
        return [(0, L)]
    stride = max(1, WINDOW_EFF - OVERLAP)
    windows: List[Tuple[int, int]] = []
    s = 0
    while s < L:
        e = min(s + WINDOW_EFF, L)
        windows.append((s, e))
        if e == L:
            break
        s += stride
    return windows


@torch.inference_mode()
def _embed_window_esmc(model, subseq: str) -> torch.Tensor:
    """Embed ``subseq`` and return only its residue rows.

    The sequence is wrapped in an ``ESMProtein``, encoded with
    ``model.encode`` and passed to ``model.logits`` with
    ``return_embeddings=True``. The first and last rows of the first batch
    item (treated here as CLS and EOS) are dropped.

    Args:
        model: Object exposing ``encode`` and ``logits``.
        subseq: Residues to embed.

    Returns:
        Tensor with ``len(subseq)`` rows, cast to ``DTYPE``.

    Raises:
        RuntimeError: If the number of remaining rows differs from
            ``len(subseq)``.
    """
    from esm.sdk.api import ESMProtein, LogitsConfig

    request = LogitsConfig(sequence=True, return_embeddings=True)
    encoded = model.encode(ESMProtein(sequence=subseq))
    result = model.logits(encoded, request)

    # Slice off the two boundary rows of the single batch item.
    residue_rows = result.embeddings[0, 1:-1].to(DTYPE)
    n_rows = residue_rows.size(0)
    if n_rows != len(subseq):
        raise RuntimeError(
            f"[ESM-C] Token/AA mismatch in window: residues={len(subseq)} vs got={n_rows} rows"
        )
    return residue_rows


@torch.inference_mode()
def _embed_full_esmc(model, seq: str) -> torch.Tensor:
    """Embed the whole sequence in one pass and keep every output row.

    Unlike the windowed helper, nothing is sliced off: the embeddings returned
    by ``model.logits`` (including the boundary rows) are cast to ``DTYPE``
    and moved to the CPU as they are.
    """
    from esm.sdk.api import ESMProtein, LogitsConfig

    request = LogitsConfig(sequence=True, return_embeddings=True)
    encoded = model.encode(ESMProtein(sequence=seq))
    result = model.logits(encoded, request)
    return result.embeddings.to(DTYPE).cpu()


@torch.inference_mode()
def embed_esmc_stitched(model, seq: str) -> torch.Tensor:
    """Embed ``seq`` with ESM-C, windowing when it is longer than ``WINDOW_EFF``.

    * Short sequences are delegated to ``_embed_full_esmc`` and returned
      unchanged (boundary rows as produced by the model).
    * Long sequences are embedded window by window; each window contributes
      only its core (interior edges lose ``OVERLAP // 2`` residues on the left
      and ``OVERLAP - OVERLAP // 2`` on the right). The stitched rows are
      placed in rows ``1..L`` of a ``(1, L+2, DIM)`` tensor whose rows ``0``
      and ``L+1`` are zeros.

    Raises:
        ValueError: If ``seq`` is empty.
        RuntimeError: If a planned window exceeds ``WINDOW_EFF`` or the trim
            leaves nothing to keep.
    """
    total = len(seq)
    if not total:
        raise ValueError("Empty sequence.")

    if total <= WINDOW_EFF:
        return _embed_full_esmc(model, seq)

    trim_head = OVERLAP // 2
    trim_tail = OVERLAP - trim_head
    residues = torch.empty(total, DIM, dtype=DTYPE, device="cpu")

    for win_start, win_end in _plan_windows(total):
        win_len = win_end - win_start
        if win_len > WINDOW_EFF:
            raise RuntimeError(f"Internal window larger than WINDOW_EFF: Li={win_len} > {WINDOW_EFF}")

        rows = _embed_window_esmc(model, seq[win_start:win_end]).detach().cpu()

        # The two ends of the full sequence are never trimmed.
        head = 0 if win_start == 0 else trim_head
        tail = 0 if win_end == total else trim_tail

        global_from, global_to = win_start + head, win_end - tail
        local_from, local_to = head, win_len - tail

        if global_to <= global_from:
            raise RuntimeError(f"Empty keep range for window {(win_start, win_end)}")
        if (local_to - local_from) != (global_to - global_from):
            raise RuntimeError("Keep lengths disagree during trim-stitch.")

        residues[global_from:global_to] = rows[local_from:local_to]

    # Zero rows stand in for the boundary tokens in the windowed case.
    packed = torch.zeros(1, total + 2, DIM, dtype=DTYPE, device="cpu")
    packed[0, 1 : total + 1] = residues
    return packed


def main() -> None:
    """Embed every PID found under ``<split>/<ONTOLOGY>`` with ESM-C.

    PIDs whose output already exists are skipped. PIDs with a missing or
    empty ``seq.txt`` are skipped with a warning. One ``✓`` line is logged per
    PID written, followed by a final summary.

    Sequences are read through ``read_sequence`` so the parsing rules live in
    one place. Each tensor is written to ``<OUT_FILENAME>.tmp`` and then
    renamed onto the final path; because reruns skip existing outputs, a
    direct write interrupted mid-way would leave a truncated file that is
    never regenerated.

    Raises:
        FileNotFoundError: If ``BASE`` does not exist.
    """
    if not BASE.exists():
        raise FileNotFoundError(f"CAFA_Style base not found: {BASE}")
    SEQS_DIR.mkdir(parents=True, exist_ok=True)

    pids = sorted(gather_pids_from_mf(BASE, SPLITS))
    if not pids:
        print("No PIDs found under <split>/mf — nothing to do.")
        return

    print(f"Found {len(pids)} unique PIDs under mf across {SPLITS}\n")
    print(f"⇢ Loading {MODEL_NAME} onto {DEVICE} …")
    from esm.models.esmc import ESMC

    model = ESMC.from_pretrained(MODEL_NAME).to(DEVICE).eval()

    made, skipped_exist, skipped_empty = 0, 0, 0
    for pid in tqdm(pids, desc="Embedding (ESM-C, sliding-aware)"):
        out_path = SEQS_DIR / pid / OUT_FILENAME
        if out_path.exists():
            skipped_exist += 1
            continue

        # Check for the file separately so that "missing" and "empty" keep
        # their distinct warnings; read_sequence returns "" for both.
        seq_path = SEQS_DIR / pid / "seq.txt"
        if not seq_path.exists():
            tqdm.write(f"[WARN] {pid}: missing sequences/seq.txt — skipping")
            skipped_empty += 1
            continue

        seq = read_sequence(SEQS_DIR, pid)
        if not seq:
            tqdm.write(f"[WARN] {pid}: empty sequence — skipping")
            skipped_empty += 1
            continue

        # Embed (fast path or sliding-stitch)
        emb = embed_esmc_stitched(model, seq)  # (1, L+2, 1152)
        out_path.parent.mkdir(parents=True, exist_ok=True)

        # Write-then-rename so out_path only ever appears fully written.
        tmp_path = out_path.with_name(out_path.name + ".tmp")
        torch.save(emb, tmp_path)
        tmp_path.replace(out_path)
        made += 1
        tqdm.write(f"✓ {pid}: L={len(seq):4d} → {OUT_FILENAME}")

    print(
        f"\nDone. created={made} | existing_skipped={skipped_exist} | missing/empty_seq_skipped={skipped_empty}"
    )


if __name__ == "__main__":
    main()


Found 32546 unique PIDs under mf across ['train', 'val', 'test']

⇢ Loading esmc_600m onto cuda …


Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

README.md: 0.00B [00:00, ?B/s]

.gitattributes: 0.00B [00:00, ?B/s]

config.json:   0%|          | 0.00/2.00 [00:00<?, ?B/s]

esmc_600m_2024_12_v0.pth:   0%|          | 0.00/2.30G [00:00<?, ?B/s]

  state_dict = torch.load(


Embedding (ESM-C, sliding-aware):   0%|          | 0/32546 [00:00<?, ?it/s]

✓ 128UP_DROME: L= 368 → esmc_emb.pt
✓ 14310_ARATH: L= 254 → esmc_emb.pt
✓ 14311_ARATH: L= 252 → esmc_emb.pt
✓ 14333_ARATH: L= 255 → esmc_emb.pt
✓ 14335_ARATH: L= 268 → esmc_emb.pt
✓ 1433B_HUMAN: L= 246 → esmc_emb.pt
✓ 1433B_MOUSE: L= 246 → esmc_emb.pt
✓ 1433B_RAT: L= 246 → esmc_emb.pt
✓ 1433E_DROME: L= 262 → esmc_emb.pt
✓ 1433E_HUMAN: L= 255 → esmc_emb.pt
✓ 1433E_MOUSE: L= 255 → esmc_emb.pt
✓ 1433E_RAT: L= 255 → esmc_emb.pt
✓ 1433F_HUMAN: L= 246 → esmc_emb.pt
✓ 1433F_MOUSE: L= 246 → esmc_emb.pt
✓ 1433G_HUMAN: L= 247 → esmc_emb.pt
✓ 1433G_MOUSE: L= 247 → esmc_emb.pt
✓ 1433G_RAT: L= 247 → esmc_emb.pt
✓ 1433S_HUMAN: L= 248 → esmc_emb.pt
✓ 1433S_MOUSE: L= 248 → esmc_emb.pt
✓ 1433T_HUMAN: L= 245 → esmc_emb.pt
✓ 1433T_MOUSE: L= 245 → esmc_emb.pt
✓ 1433T_RAT: L= 245 → esmc_emb.pt
✓ 1433Z_HUMAN: L= 245 → esmc_emb.pt
✓ 1433Z_MOUSE: L= 245 → esmc_emb.pt
✓ 1433Z_RAT: L= 245 → esmc_emb.pt
✓ 1433_DICDI: L= 252 → esmc_emb.pt
✓ 1A111_ARATH: L= 460 → esmc_emb.pt
✓ 1A12_ARATH: L= 496 → esmc_emb.pt
✓ 1A

In [1]:
"""
Generate Ankh3-XL encoder embeddings for PIDs that appear under the
CAFA_Style <split>/mf/* directories, *reading sequences from*
CAFA_Style/sequences/<PID>/seq.txt and *writing*
CAFA_Style/sequences/<PID>/ankh_emb_xl.pt.

Output per PID:
    sequences/<PID>/ankh_emb_xl.pt     # torch.Tensor, (1, L+2, d_model)

Logic
-----
- Positions [0] and [L+1] are zeros; positions [1..L] are per-residue embeddings.
- Uses the S2S prefix token and includes EOS in tokenization.
- If L <= WINDOW_EFF (512), run a single forward pass (fast path).
- If L > WINDOW_EFF, use sliding windows with **OVERLAP=128** and
  **context-trim** stitching (keep only the non-overlap core per window; no weighting).
  Only indices 0 and L+1 are zero.

Notes
-----
- We *only* embed PIDs that exist under <split>/mf (splits: train/val/test).
- Skips existing outputs and missing/empty sequences.
- Uses GPU if available.
"""
from __future__ import annotations

from pathlib import Path
from typing import Iterable, Set, List, Tuple

import torch
from tqdm.auto import tqdm
from transformers import T5Tokenizer, T5EncoderModel

# ─────────────────────────────── config ─────────────────────────────── #
# Module-level settings read by every function in this cell.
# NOTE(review): BASE is a hard-coded absolute path; main() raises
# FileNotFoundError when it does not exist — consider making it configurable.
BASE          = Path("/teamspace/studios/this_studio/CAFA_Style")
ONTOLOGY      = "mf"                       # sub-folder of each split that is scanned for PID directories
SPLITS        = ["train", "val", "test"]
SEQS_DIR      = BASE / "sequences"        # <PID>/seq.txt is read from here; outputs are written next to it
OUT_FILENAME  = "ankh_emb_xl.pt"
CKPT          = "ElnaggarLab/ankh3-xl"
PREFIX        = "[S2S]"                   # string prepended to every (sub)sequence before tokenization

# Dtype of the returned/saved tensors. The model itself is loaded in main()
# with from_pretrained defaults; outputs are cast to DTYPE afterwards.
DTYPE         = torch.float16
DEVICE        = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Width of the stitched buffers; must equal the model's hidden size.
D_MODEL       = 2560

# Sliding-window settings. WINDOW_EFF counts residues only; the prefix and
# any special tokens added by the tokenizer come on top of it.
WINDOW_EFF    = 512    # residues per window
OVERLAP       = 128    # residues shared by adjacent windows; split floor/ceil between the two neighbours when trimming
# ────────────────────────────────────────────────────────────────────── #


def gather_pids_from_mf(base: Path, splits: Iterable[str]) -> Set[str]:
    """Union of the directory names under ``<base>/<split>/<ONTOLOGY>`` for all splits.

    Non-directory entries are ignored. A split without an ontology folder is
    reported once through ``tqdm.write`` and otherwise skipped.
    """
    names: Set[str] = set()
    for split in splits:
        folder = base / split / ONTOLOGY
        if folder.exists():
            names.update(p.name for p in filter(Path.is_dir, folder.iterdir()))
            continue
        tqdm.write(f"[WARN] Missing folder: {folder} — skipping this split")
    return names


def read_sequence(seqs_dir: Path, pid: str) -> str:
    """Load the residue string for ``pid`` from ``<seqs_dir>/<pid>/seq.txt``.

    Accepts plain or FASTA-like text: lines starting with ``>`` are skipped
    and the rest are concatenated. All whitespace is removed, including blanks
    or tabs inside a line, so the returned length equals the residue count
    that the embedding step slices by.

    Args:
        seqs_dir: Directory that contains one sub-folder per PID.
        pid: Protein identifier (sub-folder name).

    Returns:
        The upper-cased sequence, or ``""`` when the file is missing or holds
        no residues.
    """
    seq_path = seqs_dir / pid / "seq.txt"
    if not seq_path.exists():
        return ""
    with seq_path.open("r", encoding="utf-8") as fh:
        # "".join(line.split()) removes every whitespace character in the
        # line, not only the leading/trailing ones str.strip() would drop.
        parts = ["".join(line.split()) for line in fh if not line.startswith(">")]
    return "".join(parts).upper()


def _plan_windows(L: int) -> List[Tuple[int, int]]:
    """Return inclusive-exclusive windows [(s,e), ...] covering [0, L) with overlap."""
    if L <= WINDOW_EFF:
        return [(0, L)]
    stride = max(1, WINDOW_EFF - OVERLAP)
    windows: List[Tuple[int, int]] = []
    s = 0
    while s < L:
        e = min(s + WINDOW_EFF, L)
        windows.append((s, e))
        if e == L:
            break
        s += stride
    return windows


@torch.inference_mode()
def _embed_window_ankh(model: T5EncoderModel, tok: T5Tokenizer, subseq: str) -> torch.Tensor:
    """Embed ``subseq`` (prefixed with ``PREFIX``) and return its residue rows.

    The encoder output for the single batch item is sliced as rows
    ``1 .. len(subseq)``, i.e. the first row is assumed to belong to the
    prefix token and anything after the residues (EOS, when present) is
    ignored.

    NOTE(review): the length check below runs *after* slicing, so it only
    fires when the tokenizer yields too few tokens. If the tokenizer ever
    emitted extra tokens (for example a multi-token prefix), the slice would
    still have ``len(subseq)`` rows and the misalignment would go unnoticed —
    confirm the token layout for this checkpoint.

    Args:
        model: Encoder exposing ``device`` and returning ``last_hidden_state``.
        tok: Tokenizer called on ``PREFIX + subseq``.
        subseq: Residues to embed.

    Returns:
        Tensor with ``len(subseq)`` rows, cast to ``DTYPE``.

    Raises:
        RuntimeError: If fewer than ``len(subseq)`` rows are available after
            the prefix row.
    """
    n_res = len(subseq)
    batch = tok(PREFIX + subseq, add_special_tokens=True, return_tensors="pt")
    batch = {key: value.to(model.device) for key, value in batch.items()}

    states = model(**batch).last_hidden_state.squeeze(0)  # drop the batch axis
    residue_rows = states[1 : n_res + 1]                  # skip row 0 (prefix)
    got = residue_rows.size(0)
    if got != n_res:
        raise RuntimeError(
            f"[Ankh3-XL] Token/AA mismatch in window: expected {n_res} got {got}"
        )
    return residue_rows.to(DTYPE)


@torch.inference_mode()
def _embed_full_ankh(model: T5EncoderModel, tok: T5Tokenizer, seq: str) -> torch.Tensor:
    """Embed the whole sequence in one pass and pack it as ``(1, L+2, d)``.

    The residue rows come from ``_embed_window_ankh``; they are placed in rows
    ``1..L`` of a zero-initialised CPU tensor, leaving rows ``0`` and ``L+1``
    as zeros. ``d`` is taken from the returned rows.
    """
    residues = _embed_window_ankh(model, tok, seq).detach().cpu()
    n_res = len(seq)
    packed = torch.zeros(1, n_res + 2, residues.size(-1), dtype=DTYPE, device="cpu")
    packed[0, 1:-1] = residues  # rows 1..L; the first and last rows stay zero
    return packed


@torch.inference_mode()
def embed_ankh_stitched(model: T5EncoderModel, tok: T5Tokenizer, seq: str) -> torch.Tensor:
    """Embed ``seq`` with Ankh3-XL and return a ``(1, L+2, d_model)`` CPU tensor.

    Rows ``0`` and ``L+1`` are zeros; rows ``1..L`` hold residue embeddings.

    * ``len(seq) <= WINDOW_EFF``: delegated to ``_embed_full_ankh``.
    * Longer sequences: each window from ``_plan_windows`` is embedded and
      only its core is kept (interior edges lose ``OVERLAP // 2`` residues on
      the left and ``OVERLAP - OVERLAP // 2`` on the right).

    Raises:
        ValueError: If ``seq`` is empty.
        RuntimeError: If a planned window exceeds ``WINDOW_EFF`` or the trim
            leaves nothing to keep.
    """
    n_res = len(seq)
    if n_res == 0:
        raise ValueError("Empty sequence.")
    if n_res <= WINDOW_EFF:
        return _embed_full_ankh(model, tok, seq)

    left_cut = OVERLAP // 2
    right_cut = OVERLAP - left_cut
    buffer = torch.empty(n_res, D_MODEL, dtype=DTYPE, device="cpu")

    for w_start, w_end in _plan_windows(n_res):
        w_len = w_end - w_start
        if w_len > WINDOW_EFF:
            raise RuntimeError(f"Internal window larger than WINDOW_EFF: Li={w_len} > {WINDOW_EFF}")

        rows = _embed_window_ankh(model, tok, seq[w_start:w_end]).detach().cpu()

        # Keep the outer edges of the sequence; trim only where windows meet.
        cut_l = left_cut if w_start > 0 else 0
        cut_r = right_cut if w_end < n_res else 0

        dst_lo, dst_hi = w_start + cut_l, w_end - cut_r
        src_lo, src_hi = cut_l, w_len - cut_r

        if dst_hi <= dst_lo:
            raise RuntimeError(f"Empty keep range for window {(w_start, w_end)}")
        if (src_hi - src_lo) != (dst_hi - dst_lo):
            raise RuntimeError("Keep lengths disagree during trim-stitch.")

        buffer[dst_lo:dst_hi] = rows[src_lo:src_hi]

    packed = torch.zeros(1, n_res + 2, D_MODEL, dtype=DTYPE, device="cpu")
    packed[0, 1 : n_res + 1] = buffer
    return packed


def main() -> None:
    """Embed every PID found under ``<split>/<ONTOLOGY>`` with Ankh3-XL.

    PIDs whose output already exists are skipped. PIDs with a missing or
    empty sequence are skipped with a warning. One ``✓`` line is logged per
    PID written, followed by a final summary.

    PID discovery happens *before* the checkpoint is loaded, so an empty or
    mis-pointed dataset returns immediately instead of first loading the
    model. Each tensor is written to ``<OUT_FILENAME>.tmp`` and then renamed
    onto the final path; because reruns skip existing outputs, a direct write
    interrupted mid-way would leave a truncated file that is never
    regenerated.

    Raises:
        FileNotFoundError: If ``BASE`` does not exist.
    """
    if not BASE.exists():
        raise FileNotFoundError(f"CAFA_Style base not found: {BASE}")
    SEQS_DIR.mkdir(parents=True, exist_ok=True)

    pids = sorted(gather_pids_from_mf(BASE, SPLITS))
    if not pids:
        print("No PIDs found under <split>/mf — nothing to do.")
        return

    print(f"Found {len(pids)} unique PIDs under mf across {SPLITS}\n")

    # Load model/tokenizer once, only after we know there is work to do.
    print(f"⇢ Loading {CKPT} encoder on {DEVICE} …")
    tok   = T5Tokenizer.from_pretrained(CKPT)
    model = T5EncoderModel.from_pretrained(CKPT).to(DEVICE).eval()

    made, skipped_exist, skipped_empty = 0, 0, 0
    for pid in tqdm(pids, desc="Embedding (Ankh3-XL, sliding-aware)"):
        out_path = SEQS_DIR / pid / OUT_FILENAME
        if out_path.exists():
            skipped_exist += 1
            continue

        seq = read_sequence(SEQS_DIR, pid)
        if not seq:
            tqdm.write(f"[WARN] {pid}: missing/empty seq.txt in sequences — skipping")
            skipped_empty += 1
            continue

        emb = embed_ankh_stitched(model, tok, seq)  # (1, L+2, d_model)
        out_path.parent.mkdir(parents=True, exist_ok=True)

        # Write-then-rename so out_path only ever appears fully written.
        tmp_path = out_path.with_name(out_path.name + ".tmp")
        torch.save(emb, tmp_path)
        tmp_path.replace(out_path)
        made += 1
        tqdm.write(f"✓ {pid}: L={len(seq):4d} → {OUT_FILENAME}")

    print(
        f"\nDone. created={made} | existing_skipped={skipped_exist} | missing_seq_skipped={skipped_empty}"
    )


if __name__ == "__main__":
    main()


⇢ Loading ElnaggarLab/ankh3-xl encoder on cuda …


tokenizer_config.json: 0.00B [00:00, ?B/s]

spiece.model:   0%|          | 0.00/238k [00:00<?, ?B/s]

special_tokens_map.json: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


config.json:   0%|          | 0.00/778 [00:00<?, ?B/s]

pytorch_model.bin.index.json: 0.00B [00:00, ?B/s]

Downloading shards:   0%|          | 0/3 [00:00<?, ?it/s]

pytorch_model-00001-of-00003.bin:   0%|          | 0.00/7.99G [00:00<?, ?B/s]

model.safetensors.index.json: 0.00B [00:00, ?B/s]

pytorch_model-00002-of-00003.bin:   0%|          | 0.00/7.98G [00:00<?, ?B/s]

pytorch_model-00003-of-00003.bin:   0%|          | 0.00/6.95G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Found 32546 unique PIDs under mf across ['train', 'val', 'test']



Embedding (Ankh3-XL, sliding-aware):   0%|          | 0/32546 [00:00<?, ?it/s]

✓ PGM3_YEAST: L= 622 → ankh_emb_xl.pt
✓ PGM48_ARATH: L= 344 → ankh_emb_xl.pt
✓ PGMB_ECOLI: L= 219 → ankh_emb_xl.pt
✓ PGMC1_ARATH: L= 583 → ankh_emb_xl.pt
✓ PGMP_ARATH: L= 623 → ankh_emb_xl.pt
✓ PGM_DROME: L= 560 → ankh_emb_xl.pt
✓ PGM_ECOLI: L= 546 → ankh_emb_xl.pt
✓ PGPA_ECOLI: L= 172 → ankh_emb_xl.pt
✓ PGPB_ECOLI: L= 254 → ankh_emb_xl.pt
✓ PGPC_ECOLI: L= 211 → ankh_emb_xl.pt
✓ PGPLB_DROME: L= 232 → ankh_emb_xl.pt
✓ PGPLC_DROME: L= 520 → ankh_emb_xl.pt
✓ PGPLF_DROME: L= 369 → ankh_emb_xl.pt
✓ PGPP1_ARATH: L= 348 → ankh_emb_xl.pt
✓ PGPS1_ARATH: L= 296 → ankh_emb_xl.pt
✓ PGPS1_SCHPO: L= 502 → ankh_emb_xl.pt
✓ PGPS1_YEAST: L= 521 → ankh_emb_xl.pt
✓ PGPS2_ARATH: L= 233 → ankh_emb_xl.pt
✓ PGPSA_DROME: L= 203 → ankh_emb_xl.pt
✓ PGPSD_DROME: L= 186 → ankh_emb_xl.pt
✓ PGP_HUMAN: L= 321 → ankh_emb_xl.pt
✓ PGP_MOUSE: L= 321 → ankh_emb_xl.pt
✓ PGP_RAT: L= 321 → ankh_emb_xl.pt
✓ PGR5_ARATH: L= 133 → ankh_emb_xl.pt
✓ PGRC1_HUMAN: L= 195 → ankh_emb_xl.pt
✓ PGRC1_MOUSE: L= 195 → ankh_emb_xl.pt
✓ PGR

In [6]:
"""
Generate ProteinGLM per-residue embeddings for PIDs that appear under the
CAFA_Style <split>/mf/* directories, *reading sequences from*
CAFA_Style/sequences/<PID>/seq.txt and *writing*
CAFA_Style/sequences/<PID>/pglm_emb.pt.

This preserves your original per-residue embedding layout and only changes how we
handle long sequences: if L > WINDOW_EFF, we use sliding windows with overlap
and **context-trim** stitching (no weighting). For long sequences, indices 1..L
contain per-residue embeddings; indices 0 and L+1 are zero.

Layout expected:
    /teamspace/studios/this_studio/CAFA_Style/
      ├─ train/mf/<PID>/
      ├─ val/mf/<PID>/
      ├─ test/mf/<PID>/
      └─ sequences/<PID>/seq.txt   # created by your flat builder

Output per PID:
    sequences/<PID>/pglm_emb.pt  # torch.Tensor, shape (1, L+2, 2560)

Run:
    python embed_pglm_cafa_style.py

Notes:
- Skips PIDs whose embeddings already exist.
- Skips PIDs with missing/empty sequences, with a warning.
- Uses GPU if available.
"""
from __future__ import annotations

from pathlib import Path
from typing import Iterable, Set, List, Tuple

import torch
from tqdm.auto import tqdm
from transformers import AutoTokenizer, AutoModelForMaskedLM

# ─────────────────────────────── config ─────────────────────────────── #
# Module-level settings read by every function in this cell.
# NOTE(review): BASE is a hard-coded absolute path — consider making it configurable.
BASE          = Path("/teamspace/studios/this_studio/CAFA_Style")
ONTOLOGY      = "mf"                       # sub-folder of each split that is scanned for PID directories
SPLITS        = ["train", "val", "test"]
SEQS_DIR      = BASE / "sequences"        # <PID>/seq.txt is read from here; outputs are written next to it

MODEL_ID      = "biomap-research/proteinglm-3b-mlm"   # HF repo id
OUT_FILENAME  = "pglm_emb.pt"                         # file name written per PID
DTYPE_SAVE    = torch.float32                         # dtype the per-residue embeddings are cast to
DEVICE        = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Sliding-window settings for ProteinGLM.
DIM              = 2560      # embedding width stated in the module docstring: (1, L+2, 2560)
MAX_LEN_TOKENS   = 2048      # per-window token budget used by this script (incl. specials)
N_SPECIALS_WIN   = 1         # assumes the tokenizer appends only <eos>; set to 2 if a BOS token is also added
WINDOW_EFF       = MAX_LEN_TOKENS - N_SPECIALS_WIN  # residues per window
OVERLAP          = 256        # residues shared by adjacent windows
# ────────────────────────────────────────────────────────────────────── #


def gather_pids_from_mf(base: Path, splits: Iterable[str]) -> Set[str]:
    """Collect the PID folder names under ``<base>/<split>/<ONTOLOGY>`` across splits.

    Returns the union of sub-directory names; files are ignored. A split
    whose ontology folder is absent is reported through ``tqdm.write`` and
    contributes nothing.
    """
    result: Set[str] = set()
    for split in splits:
        ontology_root = base / split / ONTOLOGY
        if not ontology_root.exists():
            tqdm.write(f"[WARN] Missing folder: {ontology_root} — skipping this split")
        else:
            for entry in ontology_root.iterdir():
                if not entry.is_dir():
                    continue
                result.add(entry.name)
    return result


def read_sequence(seqs_dir: Path, pid: str) -> str:
    """Load the residue string for ``pid`` from ``<seqs_dir>/<pid>/seq.txt``.

    Accepts plain or FASTA-like text: lines starting with ``>`` are skipped
    and the rest are concatenated. All whitespace is removed — not only
    literal spaces but also tabs and other blanks inside a line — so the
    returned length equals the residue count checked by the embedding step.

    Args:
        seqs_dir: Directory that contains one sub-folder per PID.
        pid: Protein identifier (sub-folder name).

    Returns:
        The upper-cased sequence, or ``""`` when the file is missing or holds
        no residues.
    """
    seq_path = seqs_dir / pid / "seq.txt"
    if not seq_path.exists():
        return ""
    with seq_path.open("r", encoding="utf-8") as fh:
        # "".join(line.split()) removes every whitespace character in the
        # line, which covers the spaces the previous .replace(" ", "") handled.
        seq_parts = ["".join(line.split()) for line in fh if not line.startswith(">")]
    return "".join(seq_parts).upper()


def _plan_windows(L: int) -> List[Tuple[int, int]]:
    """Return inclusive-exclusive windows [(s,e), ...] covering [0, L) with overlap."""
    if L <= WINDOW_EFF:
        return [(0, L)]
    stride = max(1, WINDOW_EFF - OVERLAP)
    windows: List[Tuple[int, int]] = []
    s = 0
    while s < L:
        e = min(s + WINDOW_EFF, L)
        windows.append((s, e))
        if e == L:
            break
        s += stride
    return windows


@torch.inference_mode()
def _embed_window_pglm(model: AutoModelForMaskedLM, tok: AutoTokenizer, subseq: str) -> torch.Tensor:
    """
    Compute per-residue ProteinGLM embeddings for one subsequence.

    The subsequence is tokenized with special tokens enabled, run through the
    model on ``model.device``, and the last position along the first axis of
    ``hidden_states`` is discarded (the original code attributes this to the
    trailing <eos> token, per the model docs).

    Returns a tensor with ``len(subseq)`` rows, cast to ``DTYPE_SAVE`` and left
    on the model's device.

    Raises RuntimeError when the number of remaining rows differs from
    ``len(subseq)``.
    """
    encoded = tok(subseq, add_special_tokens=True, return_tensors='pt')
    target = model.device
    batch = {key: encoded[key].to(target) for key in ("input_ids", "attention_mask")}

    forward = model(**batch, output_hidden_states=True, return_last_hidden_state=True)
    # Assumes hidden_states is laid out (tokens, batch, dim): drop the final
    # token row and select batch element 0. TODO confirm against the model card.
    residue_emb = forward.hidden_states[:-1, 0]

    n_rows, n_res = residue_emb.size(0), len(subseq)
    if n_rows != n_res:
        raise RuntimeError(
            f"[ProteinGLM] Token/AA mismatch in window: tokens_no_specials={n_rows} vs Li={n_res}"
        )
    return residue_emb.to(DTYPE_SAVE)


@torch.inference_mode()
def _embed_full_pglm(model: AutoModelForMaskedLM, tok: AutoTokenizer, seq: str) -> torch.Tensor:
    """
    Embed a whole sequence in one forward pass and pack it to (1, L+2, d).

    The tokenize / forward / drop-<eos> / length-check steps are shared with the
    sliding-window path, so they are delegated to ``_embed_window_pglm`` rather
    than duplicated here. That helper returns ``len(seq)`` rows already cast to
    ``DTYPE_SAVE`` and raises RuntimeError on a token/residue count mismatch.

    Returns
    -------
    torch.Tensor
        CPU tensor of shape (1, L+2, d) in ``DTYPE_SAVE``, where ``d`` is the
        width of the model output. Rows 0 and L+1 are zero; rows 1..L hold the
        per-residue embeddings.
    """
    L = len(seq)

    residue_emb = _embed_window_pglm(model, tok, seq).cpu()  # [L, d]

    # Zero-filled frame keeps the (1, L+2, d) interface used by the other
    # embedding files; only the interior rows are populated.
    emb = torch.zeros(1, L + 2, residue_emb.size(-1), dtype=DTYPE_SAVE, device="cpu")
    emb[0, 1:L + 1] = residue_emb
    return emb


@torch.inference_mode()
def embed_pglm_stitched(model: AutoModelForMaskedLM, tok: AutoTokenizer, seq: str) -> torch.Tensor:
    """
    Produce ProteinGLM embeddings for ``seq`` as a CPU tensor of shape (1, L+2, DIM).

    - ``len(seq) <= WINDOW_EFF``: one forward pass via ``_embed_full_pglm``.
    - Longer sequences: embed each planned window separately and keep only its
      trimmed core, so every residue is taken from exactly one window. The
      first window keeps its left edge and the last window keeps its right
      edge; interior edges are trimmed by the two halves of OVERLAP.

    Rows 0 and L+1 of the result are zero.

    Raises ValueError for an empty sequence and RuntimeError when a window or
    its kept range is inconsistent.
    """
    n_res = len(seq)
    if n_res == 0:
        raise ValueError("Empty sequence.")

    if n_res <= WINDOW_EFF:
        return _embed_full_pglm(model, tok, seq)

    # The overlap is split between neighbours: the left window gives up the
    # upper half, the right window gives up the lower half.
    inner_left = OVERLAP // 2
    inner_right = OVERLAP - inner_left

    stitched = torch.empty(n_res, DIM, dtype=DTYPE_SAVE, device="cpu")
    for win_start, win_end in _plan_windows(n_res):
        span = win_end - win_start
        if span > WINDOW_EFF:
            raise RuntimeError(f"Internal window larger than WINDOW_EFF: Li={span} > {WINDOW_EFF}")

        window_emb = _embed_window_pglm(model, tok, seq[win_start:win_end]).detach().cpu()  # (span, DIM)

        trim_l = inner_left if win_start != 0 else 0
        trim_r = inner_right if win_end != n_res else 0

        # Kept range in global (sequence) and local (window) coordinates.
        keep_from, keep_to = win_start + trim_l, win_end - trim_r
        local_from, local_to = trim_l, span - trim_r

        if keep_to <= keep_from:
            raise RuntimeError(f"Empty keep range for window {(win_start, win_end)}")
        if (local_to - local_from) != (keep_to - keep_from):
            raise RuntimeError("Keep lengths disagree during trim-stitch.")

        stitched[keep_from:keep_to] = window_emb[local_from:local_to]

    packed = torch.zeros(1, n_res + 2, DIM, dtype=DTYPE_SAVE, device="cpu")
    packed[0, 1:n_res + 1] = stitched
    return packed


# ─────────────────────────────────── main ───────────────────────────── #

def main() -> None:
    """Embed every PID found under ``<split>/<ONTOLOGY>`` and save one file per PID.

    For each PID the sequence is read from ``SEQS_DIR/<PID>/seq.txt``, embedded
    with ``embed_pglm_stitched`` and written to ``SEQS_DIR/<PID>/OUT_FILENAME``.
    PIDs whose output file already exists are skipped, as are PIDs with a
    missing or empty sequence file (the latter with a warning).

    Each embedding is first written to a sibling ``*.tmp`` file and then renamed
    onto the final path. Because existing outputs are skipped on later runs, a
    file truncated by an interrupted ``torch.save`` would otherwise never be
    regenerated.

    Raises FileNotFoundError when ``BASE`` does not exist.
    """
    if not BASE.exists():
        raise FileNotFoundError(f"CAFA_Style base not found: {BASE}")
    SEQS_DIR.mkdir(parents=True, exist_ok=True)

    pids = sorted(gather_pids_from_mf(BASE, SPLITS))
    if not pids:
        print("No PIDs found under <split>/mf — nothing to do.")
        return

    print(f"Found {len(pids)} unique PIDs under mf across {SPLITS}\n")
    print(f"⇢ Loading {MODEL_ID} onto {DEVICE} …")

    # NOTE(review): the published usage for this checkpoint loads it with
    # trust_remote_code=True; confirm whether this environment relies on an
    # interactive prompt or a cached approval instead.
    tok = AutoTokenizer.from_pretrained(MODEL_ID)
    model = AutoModelForMaskedLM.from_pretrained(MODEL_ID).to(DEVICE).eval()

    made, skipped_exist, skipped_empty = 0, 0, 0
    for pid in tqdm(pids, desc="Embedding (ProteinGLM, sliding-aware)"):
        out_path = SEQS_DIR / pid / OUT_FILENAME
        if out_path.exists():
            skipped_exist += 1
            continue

        seq = read_sequence(SEQS_DIR, pid)
        if not seq:
            tqdm.write(f"[WARN] {pid}: missing/empty seq.txt in sequences — skipping")
            skipped_empty += 1
            continue

        emb = embed_pglm_stitched(model, tok, seq)  # (1, L+2, 2560)
        out_path.parent.mkdir(parents=True, exist_ok=True)

        # Write-then-rename so the final path only ever holds a complete file.
        tmp_path = out_path.with_name(out_path.name + ".tmp")
        torch.save(emb, tmp_path)
        tmp_path.replace(out_path)

        made += 1
        tqdm.write(f"✓ {pid}: L={len(seq):4d} → {OUT_FILENAME}")

    print(
        f"\nDone. created={made} | existing_skipped={skipped_exist} | missing_seq_skipped={skipped_empty}"
    )


# Entry point: run the embedding job when this code executes as the main module
# (which is how the cell ran here, as the output below the cell shows).
if __name__ == "__main__":
    main()


Found 32546 unique PIDs under mf across ['train', 'val', 'test']

⇢ Loading biomap-research/proteinglm-3b-mlm onto cuda …


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Embedding (ProteinGLM, sliding-aware):   0%|          | 0/32546 [00:00<?, ?it/s]

✓ 7LESS_DROME: L=2554 → pglm_emb.pt
✓ A1M_RAT: L=1500 → pglm_emb.pt
✓ A2MG_ECOLI: L=1653 → pglm_emb.pt
✓ A2MG_HUMAN: L=1474 → pglm_emb.pt
✓ A2MG_MOUSE: L=1474 → pglm_emb.pt
✓ A2MG_RAT: L=1472 → pglm_emb.pt
✓ A2ML1_HUMAN: L=1454 → pglm_emb.pt
✓ AB19B_ARATH: L=1252 → pglm_emb.pt
✓ AB1B_ARATH: L=1286 → pglm_emb.pt
✓ AB1C_ARATH: L=1622 → pglm_emb.pt
✓ AB21B_ARATH: L=1296 → pglm_emb.pt
✓ AB2C_ARATH: L=1623 → pglm_emb.pt
✓ AB30G_ARATH: L=1400 → pglm_emb.pt
✓ AB31G_ARATH: L=1426 → pglm_emb.pt
✓ AB36G_ARATH: L=1469 → pglm_emb.pt
✓ AB37G_ARATH: L=1450 → pglm_emb.pt
✓ AB3C_ARATH: L=1514 → pglm_emb.pt
✓ AB40G_ARATH: L=1423 → pglm_emb.pt
✓ AB4B_ARATH: L=1286 → pglm_emb.pt
✓ AB5C_ARATH: L=1514 → pglm_emb.pt
✓ ABC2_SCHPO: L=1478 → pglm_emb.pt
✓ ABC3_SCHPO: L=1465 → pglm_emb.pt
✓ ABCA1_HUMAN: L=2261 → pglm_emb.pt
✓ ABCA1_MOUSE: L=2261 → pglm_emb.pt
✓ ABCA2_HUMAN: L=2435 → pglm_emb.pt
✓ ABCA2_RAT: L=2434 → pglm_emb.pt
✓ ABCA4_HUMAN: L=2273 → pglm_emb.pt
✓ ABCA4_MOUSE: L=2310 → pglm_emb.pt
✓ ABCA7_HUMA

In [13]:
# JUPYTER CELL — Batched check + confirm + delete (no quarantine, no spam)
from __future__ import annotations
from pathlib import Path
import json
import torch
from tqdm.auto import tqdm

# ── config ─────────────────────────────────────────────────────────────
BASE      = Path("/teamspace/studios/this_studio/CAFA_Style")
SEQS_DIR  = BASE / "sequences"          # one sub-folder per PID, holding seq.txt and the embedding files
STATE_F   = BASE / ".batch_state.json"  # JSON file storing the offset at which the next run starts
BATCH_SIZE = 10_000  # process this many PID folders per run

# Model key → embedding filename expected inside each PID folder.
EMB_FILES = {
    "esmc":      "esmc_emb.pt",
    "prot_t5":   "prot_t5_emb.pt",
    "pglm":      "pglm_emb.pt",
    "ankh_xl":   "ankh_emb_xl.pt",
}
# NOTE(review): the ProteinGLM embedding cell in this notebook windows at
# MAX_LEN_TOKENS = 2048, whereas the "pglm" cap below is 20480 — confirm which
# limit is intended.
CAP = {  # sequences strictly greater than this must be re-embedded with sliding window
    "esmc":    20480,
    "prot_t5": 512,
    "pglm":    20480,
    "ankh_xl": 5120,
}
# Both tables are iterated by model key, so they must describe the same models.
assert set(EMB_FILES) == set(CAP), "CAP/EMB_FILES key mismatch"

# ── helpers ────────────────────────────────────────────────────────────
def load_state() -> int:
    """Read the saved batch offset from ``STATE_F``.

    Returns the integer stored under ``"offset"`` (0 when the key is absent).
    Any failure while reading or parsing the file also yields 0, so a missing
    or unreadable state file restarts from the first PID folder.
    """
    try:
        state = json.loads(STATE_F.read_text(encoding="utf-8"))
        return int(state.get("offset", 0))
    except Exception:
        # Best-effort: any failure means "no usable state", so start at 0.
        return 0

def save_state(offset: int) -> None:
    """Persist ``offset`` to ``STATE_F`` as ``{"offset": <int>}`` (UTF-8 JSON)."""
    payload = {"offset": int(offset)}
    STATE_F.write_text(json.dumps(payload), encoding="utf-8")

def read_seq_len(seq_path: Path) -> int:
    """Count the residues stored in a FASTA-like or plain sequence file.

    Lines starting with ``>`` are ignored. Every other line contributes its
    length after surrounding whitespace is stripped and spaces are removed.
    Returns 0 when the file does not exist.
    """
    if not seq_path.exists():
        return 0
    with seq_path.open("r", encoding="utf-8") as handle:
        return sum(
            len(row.strip().replace(" ", ""))
            for row in handle
            if not row.startswith(">")
        )

def pt_shape(p: Path):
    """Return the shape of the tensor stored in ``p``, or ``None``.

    Accepts either a bare tensor or a dict that holds a tensor under one of the
    keys ``"emb"``, ``"embeddings"`` or ``"tensor"`` (checked in that order).

    The file is loaded with ``weights_only=True`` so that unpickling is limited
    to tensors and plain containers; this avoids executing arbitrary pickled
    code from a damaged or foreign file.

    Returns
    -------
    tuple | None
        The tensor shape as a tuple, or ``None`` when the file cannot be loaded
        or contains no recognisable tensor.
    """
    try:
        obj = torch.load(p, map_location="cpu", weights_only=True)
    except Exception:
        # Unreadable/corrupt/missing file: report "no shape" instead of failing
        # the whole scan.
        return None

    if isinstance(obj, torch.Tensor):
        return tuple(obj.shape)
    if isinstance(obj, dict):
        for k in ("emb", "embeddings", "tensor"):
            candidate = obj.get(k)
            if isinstance(candidate, torch.Tensor):
                return tuple(candidate.shape)
    return None

# ── gather & slice batch ───────────────────────────────────────────────
if not SEQS_DIR.exists():
    raise FileNotFoundError(f"Sequences folder not found: {SEQS_DIR}")

# All non-hidden PID folders, sorted by name so that the saved offset refers to
# the same folder on every run (as long as the folder set does not change).
all_pid_dirs = [p for p in SEQS_DIR.iterdir() if p.is_dir() and not p.name.startswith(".")]
all_pid_dirs.sort(key=lambda p: p.name)
N = len(all_pid_dirs)

# This run covers the half-open index range [start, end).
start = load_state()
end = min(start + BATCH_SIZE, N)
if start >= N:
    print(f"All done. Total PID dirs: {N}. Nothing left to process.")
    # Stops the cell here; the state file is left unchanged.
    raise SystemExit

batch = all_pid_dirs[start:end]
print(f"Processing PIDs {start}..{end-1} of {N} (batch size {BATCH_SIZE})")

# ── scan (batch only) ──────────────────────────────────────────────────
# Per-model tallies for this batch.
missing = {k: 0 for k in EMB_FILES}     # embedding file absent
badshape = {k: 0 for k in EMB_FILES}    # file unreadable or not shaped (1, L+2, *)
to_delete = {k: [] for k in EMB_FILES}  # paths queued for deletion

total_pids = 0
for pid_dir in tqdm(batch, desc="Analyzing proteins"):
    pid = pid_dir.name
    L = read_seq_len(pid_dir / "seq.txt")
    if L <= 0:
        # No usable sequence: the folder is not counted and its files are not checked.
        continue
    total_pids += 1

    for mk, fname in EMB_FILES.items():
        f = pid_dir / fname
        if not f.exists():
            missing[mk] += 1
            continue

        # Expected layout is (1, L+2, dim); the last dimension is not checked.
        shp = pt_shape(f)
        if shp is None or len(shp) != 3 or shp[0] != 1 or shp[1] != (L + 2):
            badshape[mk] += 1
            # NOTE(review): bad-shaped/unreadable files are only counted, not
            # queued for deletion — confirm this is intended.

        # deletion rule: seq length exceeds model cap ⇒ will re-make with sliding window
        if L > CAP[mk]:
            to_delete[mk].append(f)

# ── concise summary ────────────────────────────────────────────────────
# Format a {model: count} dict as "model:count, model:count, ...".
def csum(d): return ", ".join(f"{k}:{v}" for k,v in d.items())

# Number of files queued for deletion across all models in this batch.
total_delete = sum(len(v) for v in to_delete.values())
print("\n=== BATCH SUMMARY ===")
print(f"PIDs scanned in batch: {total_pids}  (global {start}..{end-1} of {N})")
print("Missing      :", csum(missing))
print("Bad shapes   :", csum(badshape))
print("To DELETE    :", ", ".join(f"{k}:{len(v)}" for k,v in to_delete.items()), f"(total {total_delete})")

# ── confirmation & delete (batch only) ─────────────────────────────────
# Deletion is permanent and only happens after an explicit "y" at the prompt;
# any other answer leaves every file in place.
if total_delete > 0:
    ans = input("\nDelete ALL files above caps in THIS BATCH (no undo)? [y/N]: ").strip().lower()
    if ans == "y":
        print("Deleting files...")
        deleted = 0
        # Flatten the per-model lists so one progress bar covers all deletions.
        all_files = []
        for files in to_delete.values():
            all_files.extend(files)
        for f in tqdm(all_files, desc="Deleting files"):
            try:
                f.unlink()
                deleted += 1
            except Exception as e:
                # Keep going: one undeletable file should not abort the batch.
                print(f"Failed to delete {f}: {e}")
        print(f"Deleted {deleted} files.")
    else:
        print("No files deleted in this batch.")
else:
    print("\nNo files to delete in this batch.")

# ── advance to next batch ──────────────────────────────────────────────
# NOTE(review): the offset advances even when the user declined the deletion or
# some deletions failed, so this batch is not revisited — confirm intended.
save_state(end)
remaining = max(0, N - end)
print(f"\nNext start offset saved: {end}. Remaining PID dirs: {remaining}.")
print("Re-run this cell to process the next batch.")


Processing PIDs 20000..29999 of 59397 (batch size 10000)


Analyzing proteins:   0%|          | 0/10000 [00:00<?, ?it/s]

  obj = torch.load(p, map_location="cpu")



=== BATCH SUMMARY ===
PIDs scanned in batch: 10000  (global 20000..29999 of 59397)
Missing      : esmc:4324, prot_t5:6519, pglm:4324, ankh_xl:4324
Bad shapes   : esmc:0, prot_t5:0, pglm:0, ankh_xl:0
To DELETE    : esmc:0, prot_t5:0, pglm:0, ankh_xl:0 (total 0)

No files to delete in this batch.

Next start offset saved: 30000. Remaining PID dirs: 29397.
Re-run this cell to process the next batch.


In [1]:
from pathlib import Path
from tqdm.auto import tqdm

SEQS_DIR = Path("/teamspace/studios/this_studio/CAFA_Style/sequences")
TARGET_NAME = "esmc_emb.pt"  # filename removed from every PID folder

# Collect the target file from each non-hidden PID folder that has one.
files = []
for pid_dir in SEQS_DIR.iterdir():
    if pid_dir.is_dir() and not pid_dir.name.startswith("."):
        f = pid_dir / TARGET_NAME
        if f.exists():
            files.append(f)

# Delete permanently. Labels are derived from TARGET_NAME so the progress bar
# and summary always name the file type that is actually being removed.
deleted = 0
for f in tqdm(files, desc=f"Deleting {TARGET_NAME}"):
    try:
        f.unlink()
        deleted += 1
    except Exception as e:
        # Keep going: one undeletable file should not abort the sweep.
        print(f"Failed to delete {f}: {e}")

# "Nothing found" and "found but could not delete" are different outcomes.
if not files:
    print(f"No {TARGET_NAME} files found.")
else:
    print(f"Deleted {deleted} of {len(files)} {TARGET_NAME} files.")

Deleting ankh embeddings:   0%|          | 0/32546 [00:00<?, ?it/s]

Deleted 32546 files.


In [None]:
from pathlib import Path
from tqdm.auto import tqdm
import torch

SEQS_DIR = Path("/teamspace/studios/this_studio/CAFA_Style/sequences")
TARGET = "ankh_emb_xl.pt"  # embedding filename to down-cast in place

paths = list(SEQS_DIR.rglob(TARGET))
converted = 0

for f in tqdm(paths, desc=f"Converting {TARGET} → FP16"):
    # Temp sibling: the original is replaced only after the FP16 copy has been
    # fully written, so an interruption cannot leave a truncated embedding.
    tmp = f.with_name(f.name + ".tmp")
    try:
        # weights_only=True limits unpickling to tensors/plain containers.
        obj = torch.load(f, map_location="cpu", weights_only=True)
        # Only bare tensors that are not already FP16 are rewritten.
        if isinstance(obj, torch.Tensor) and obj.dtype != torch.float16:
            torch.save(obj.half(), tmp)
            tmp.replace(f)
            converted += 1
    except Exception as e:
        # Best-effort: drop any partial temp file, report, and continue.
        tmp.unlink(missing_ok=True)
        print(f"Skip {f}: {e}")

print(f"Converted {converted} file(s) to FP16.")


Converting ankh_emb_xl.pt → FP16:   0%|          | 0/20821 [00:00<?, ?it/s]

  obj = torch.load(f, map_location="cpu")


Skip /teamspace/studios/this_studio/CAFA_Style/sequences/GRRE1_HUMAN/ankh_emb_xl.pt: File /teamspace/studios/this_studio/CAFA_Style/sequences/GRRE1_HUMAN/ankh_emb_xl.pt cannot be opened.
Skip /teamspace/studios/this_studio/CAFA_Style/sequences/JAZF1_HUMAN/ankh_emb_xl.pt: File /teamspace/studios/this_studio/CAFA_Style/sequences/JAZF1_HUMAN/ankh_emb_xl.pt cannot be opened.
Skip /teamspace/studios/this_studio/CAFA_Style/sequences/KPCD2_HUMAN/ankh_emb_xl.pt: File /teamspace/studios/this_studio/CAFA_Style/sequences/KPCD2_HUMAN/ankh_emb_xl.pt cannot be opened.
Skip /teamspace/studios/this_studio/CAFA_Style/sequences/MMP8_HUMAN/ankh_emb_xl.pt: File /teamspace/studios/this_studio/CAFA_Style/sequences/MMP8_HUMAN/ankh_emb_xl.pt cannot be opened.
Skip /teamspace/studios/this_studio/CAFA_Style/sequences/NMS_RAT/ankh_emb_xl.pt: File /teamspace/studios/this_studio/CAFA_Style/sequences/NMS_RAT/ankh_emb_xl.pt cannot be opened.
Skip /teamspace/studios/this_studio/CAFA_Style/sequences/NMT1_ARATH/ankh_em

In [6]:
import torch

# Load the two embedding files onto the CPU. weights_only=True restricts
# unpickling to tensors/plain containers (and avoids the torch.load
# FutureWarning); map_location keeps the comparison independent of the device
# the tensors were saved from.
ankh_emb = torch.load(
    '/teamspace/studios/this_studio/CAFA_Style/sequences/1A1L1_HUMAN/ankh_emb_xl.pt',
    map_location="cpu",
    weights_only=True,
)
pglm_emb = torch.load(
    '/teamspace/studios/this_studio/CAFA_Style/sequences/1A1L1_HUMAN/pglm_emb.pt',
    map_location="cpu",
    weights_only=True,
)

# Compare the sizes (shape, then total element count)
print(f"ANKH embedding shape: {ankh_emb.shape}")
print(f"PGLM embedding shape: {pglm_emb.shape}")
print(f"ANKH embedding size: {ankh_emb.numel()}")
print(f"PGLM embedding size: {pglm_emb.numel()}")

# Calculate in-memory sizes in gigabytes: elements × bytes per element / 1024³
ankh_size_gb = ankh_emb.numel() * ankh_emb.element_size() / (1024**3)
pglm_size_gb = pglm_emb.numel() * pglm_emb.element_size() / (1024**3)

print(f"ANKH embedding size: {ankh_size_gb:.4f} GB")
print(f"PGLM embedding size: {pglm_size_gb:.4f} GB")
print(f"Size difference: {abs(ankh_size_gb - pglm_size_gb):.4f} GB")

ANKH embedding shape: torch.Size([1, 503, 2560])
PGLM embedding shape: torch.Size([1, 503, 2560])
ANKH embedding size: 1287680
PGLM embedding size: 1287680
ANKH embedding size: 0.0048 GB
PGLM embedding size: 0.0024 GB
Size difference: 0.0024 GB


  ankh_emb = torch.load('/teamspace/studios/this_studio/CAFA_Style/sequences/1A1L1_HUMAN/ankh_emb_xl.pt')
  pglm_emb = torch.load('/teamspace/studios/this_studio/CAFA_Style/sequences/1A1L1_HUMAN/pglm_emb.pt')
