In [1]:
# === ESM-MSA introspection for LoRA targeting ===
import sys, os, platform, re, textwrap, json
from pathlib import Path

import torch
import torch.nn as nn

# If you have esm installed as `esm` (facebookresearch/esm)
import esm

OUT_PATH = Path("msa_introspection.txt")

def header(title):
    line = "=" * len(title)
    return f"\n{title}\n{line}\n"

def env_report():
    try:
        import importlib.metadata as im
        esm_ver = im.version("esm")
    except Exception:
        esm_ver = "unknown"
    return {
        "python": sys.version.replace("\n", " "),
        "platform": platform.platform(),
        "torch": torch.__version__,
        "esm": esm_ver,
        "cuda_available": torch.cuda.is_available(),
        "cuda_device_count": torch.cuda.device_count(),
    }

def print_tree(model, max_lines=400):
    """Render a readable tree of named_modules (truncated)."""
    lines = []
    for name, module in model.named_modules():
        mod = type(module).__name__
        lines.append(f"{name or '<root>'} :: {mod}")
    truncated = lines[:max_lines]
    body = "\n".join(truncated)
    if len(lines) > max_lines:
        body += f"\n... ({len(lines)-max_lines} more lines truncated)"
    return body

def find_proj_modules(model):
    """
    Find attention projection linears named q_proj/k_proj/v_proj/out_proj
    and group them by their parent block (e.g., row_attn / col_attn).
    """
    hits = []
    for name, module in model.named_modules():
        for proj in ("q_proj", "k_proj", "v_proj", "out_proj"):
            if hasattr(module, proj) and isinstance(getattr(module, proj), nn.Linear):
                lin = getattr(module, proj)
                hits.append({
                    "parent": name,
                    "proj_name": proj,
                    "shape": tuple(lin.weight.shape),
                    "bias": getattr(lin, "bias", None) is not None,
                    "type": type(lin).__name__,
                })
    return hits

def group_row_col(hits):
    """
    Try to separate row/col attention based on parent name substrings.
    This is robust to common ESM naming like 'layers.0.row_attn' / 'layers.0.col_attn'.
    """
    row, col, other = [], [], []
    for h in hits:
        parent = h["parent"].lower()
        if "row" in parent:
            row.append(h)
        elif "col" in parent or "column" in parent:
            col.append(h)
        else:
            other.append(h)
    return row, col, other

def list_all_linears(model):
    items = []
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            items.append({
                "name": name,
                "shape": tuple(module.weight.shape),
                "bias": module.bias is not None
            })
    return items

def dump_named_parameters(model, pattern=None, max_lines=200):
    lines = []
    for name, p in model.named_parameters():
        if (pattern is None) or re.search(pattern, name):
            lines.append(f"{name} :: {tuple(p.shape)} :: requires_grad={p.requires_grad}")
    lines = lines[:max_lines] + (["... (truncated)"] if len(lines) > max_lines else [])
    return "\n".join(lines)

# ---------- Load model ----------
msa_model, msa_alphabet = esm.pretrained.esm_msa1b_t12_100M_UR50S()
msa_model.eval().requires_grad_(False)

# ---------- Collect info ----------
env = env_report()
tree = print_tree(msa_model, max_lines=800)  # bump if you want more
hits = find_proj_modules(msa_model)
row_hits, col_hits, other_hits = group_row_col(hits)
linears = list_all_linears(msa_model)
params_glimpse = dump_named_parameters(msa_model, pattern=r"(q_proj|k_proj|v_proj|out_proj)", max_lines=300)

# ---------- Build a human-readable report ----------
report = []
report.append(header("Environment"))
report.append(json.dumps(env, indent=2))

report.append(header("Model Class"))
report.append(repr(type(msa_model)))

report.append(header("Named Modules (tree; truncated)"))
report.append(tree)

report.append(header("Attention Projection Linears (q/k/v/out) — ALL HITS"))
for h in hits:
    report.append(f"{h['parent']}.{h['proj_name']} :: {h['shape']} :: bias={h['bias']}")

report.append(header("Row Attention Projections"))
for h in row_hits:
    report.append(f"{h['parent']}.{h['proj_name']} :: {h['shape']} :: bias={h['bias']}")

report.append(header("Column Attention Projections"))
for h in col_hits:
    report.append(f"{h['parent']}.{h['proj_name']} :: {h['shape']} :: bias={h['bias']}")

if other_hits:
    report.append(header("Other Attention-like Projections (neither row nor col by name)"))
    for h in other_hits:
        report.append(f"{h['parent']}.{h['proj_name']} :: {h['shape']} :: bias={h['bias']}")

report.append(header("All nn.Linear layers (for optional LoRA targets; truncated to 500)"))
for item in linears[:500]:
    report.append(f"{item['name']} :: {item['shape']} :: bias={item['bias']}")
if len(linears) > 500:
    report.append(f"... ({len(linears)-500} more linears truncated)")

report.append(header("Named parameters glimpse (q/k/v/out only; truncated)"))
report.append(params_glimpse)

text = "\n".join(report)

# ---------- Save & show ----------
OUT_PATH.write_text(text)
print(text[:4000])
print(f"\n\n=== Saved full report to: {OUT_PATH.resolve()} ===")



Environment

{
  "python": "3.10.18 (main, Jun  5 2025, 13:14:17) [GCC 11.2.0]",
  "platform": "Linux-6.8.0-1033-gcp-x86_64-with-glibc2.31",
  "torch": "2.7.1+cu126",
  "esm": "unknown",
  "cuda_available": false,
  "cuda_device_count": 0
}

Model Class

<class 'esm.model.msa_transformer.MSATransformer'>

Named Modules (tree; truncated)

<root> :: MSATransformer
embed_tokens :: Embedding
dropout_module :: Dropout
layers :: ModuleList
layers.0 :: AxialTransformerLayer
layers.0.row_self_attention :: NormalizedResidualBlock
layers.0.row_self_attention.layer :: RowSelfAttention
layers.0.row_self_attention.layer.k_proj :: Linear
layers.0.row_self_attention.layer.v_proj :: Linear
layers.0.row_self_attention.layer.q_proj :: Linear
layers.0.row_self_attention.layer.out_proj :: Linear
layers.0.row_self_attention.layer.dropout_module :: Dropout
layers.0.row_self_attention.dropout_module :: Dropout
layers.0.row_self_attention.layer_norm :: LayerNorm
layers.0.column_self_attention :: NormalizedRe

In [1]:
# --- ESM-C Introspection for LoRA Targeting ---
import torch, re, sys
from typing import List, Tuple
from esm.models.esmc import ESMC
from torch import nn

def is_linear(m): 
    return isinstance(m, nn.Linear)

def find_transformer_blocks(core: nn.Module) -> List[nn.Module]:
    """
    Try to locate the list of transformer blocks.
    Works if the model exposes a ModuleList of blocks or
    otherwise collects any module that has a 'self_attn' attr.
    """
    # Heuristic 1: a ModuleList whose first item has 'self_attn'
    for name, mod in core.named_modules():
        if isinstance(mod, nn.ModuleList) and len(mod) > 0:
            first = mod[0]
            if hasattr(first, "self_attn"):
                return list(mod)
    # Heuristic 2: collect any module that has 'self_attn'
    blocks = []
    for _, mod in core.named_modules():
        if hasattr(mod, "self_attn"):
            blocks.append(mod)
    return blocks

def preview_attn_linears(core: nn.Module) -> List[Tuple[str, nn.Linear]]:
    """
    Return (name, module) for attention-related linears commonly used in ESM/ESMC.
    """
    hits = []
    for name, mod in core.named_modules():
        if is_linear(mod):
            if any(key in name for key in (
                "self_attn.q_proj","self_attn.k_proj","self_attn.v_proj","self_attn.out_proj",
                "attn.q_proj","attn.k_proj","attn.v_proj","attn.out_proj"
            )):
                hits.append((name, mod))
    return hits

def detect_fused_inproj(core: nn.Module) -> List[str]:
    """
    Find modules that look like Fairseq-style fused projection attention
    (i.e., have 'in_proj_weight' or 'in_proj_bias' attributes).
    """
    fused = []
    for name, mod in core.named_modules():
        if hasattr(mod, "in_proj_weight") or hasattr(mod, "in_proj_bias"):
            fused.append(name)
    return fused

def list_layernorms(core: nn.Module) -> List[str]:
    return [name for name, m in core.named_modules() if isinstance(m, nn.LayerNorm)]

def fmt(s, char="="):
    return f"\n{char*8} {s} {char*8}"

def main():
    print(fmt("Loading ESM-C"))
    model = ESMC.from_pretrained("esmc_600m")
    core = getattr(model, "model", model)  # some SDKs wrap the raw torch module here

    # Basic model info
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total params: {total:,} | Trainable (now): {trainable:,}")
    print(f"Model class: {model.__class__.__name__}")
    print(f"Core class:  {core.__class__.__name__}")

    # Locate transformer blocks
    blocks = find_transformer_blocks(core)
    print(fmt("Transformer Blocks"))
    print(f"Num blocks found: {len(blocks)}")
    if len(blocks) > 0:
        # Print a quick sketch of the first, middle, and last block types
        idxs = sorted(set([0, len(blocks)//2, len(blocks)-1]))
        for i in idxs:
            b = blocks[i]
            print(f"- Block[{i}] type: {b.__class__.__name__}  (attrs: {', '.join(sorted(set(dir(b)) & {'self_attn','fc1','fc2','norm1','norm2'}))})")

    # Find attention linears
    print(fmt("Attention Linears (separate q/k/v/out)"))
    hits = preview_attn_linears(core)
    if hits:
        for name, mod in hits[:50]:  # print first 50 for sanity
            print(f"{name:60s}  Linear({mod.in_features} -> {mod.out_features})")
        if len(hits) > 50:
            print(f"... and {len(hits)-50} more")
    else:
        print("No explicit q_proj/k_proj/v_proj/out_proj linears found.")

    # Fused attention?
    print(fmt("Potential Fused In-Proj Attentions"))
    fused = detect_fused_inproj(core)
    if fused:
        for name in fused:
            print(f"{name}")
    else:
        print("No modules with in_proj_weight/in_proj_bias detected.")

    # LayerNorms
    print(fmt("LayerNorm Modules"))
    ln_names = list_layernorms(core)
    print(f"Found {len(ln_names)} LayerNorms.")
    for n in ln_names[:40]:
        print(n)
    if len(ln_names) > 40:
        print(f"... and {len(ln_names)-40} more")

    # Per-block quick look at attention fields
    print(fmt("Per-Block Attention Fields (quick check)"))
    for i, b in enumerate(blocks[:12]):  # first 12 blocks only to keep output manageable
        sa = getattr(b, "self_attn", None)
        fields = []
        if sa is not None:
            for attr in ("q_proj","k_proj","v_proj","out_proj","in_proj_weight","in_proj_bias"):
                if hasattr(sa, attr):
                    fields.append(attr)
        print(f"Block[{i}]: self_attn fields -> {fields if fields else 'N/A'}")

    print(fmt("Instructions"))
    print(
        "Paste this entire output back to me.\n"
        "- If you see names like '*.self_attn.k_proj' and '*.self_attn.v_proj', we can LoRA those directly.\n"
        "- If instead you see 'in_proj_weight', we'll use a fused-projection LoRA patch.\n"
        "- The block count tells us how many 'last_n' layers we can reasonably target."
    )

if __name__ == "__main__":
    main()





Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

Total params: 575,036,992 | Trainable (now): 575,036,992
Model class: ESMC
Core class:  ESMC

Num blocks found: 0

transformer.blocks.0.attn.out_proj                            Linear(1152 -> 1152)
transformer.blocks.1.attn.out_proj                            Linear(1152 -> 1152)
transformer.blocks.2.attn.out_proj                            Linear(1152 -> 1152)
transformer.blocks.3.attn.out_proj                            Linear(1152 -> 1152)
transformer.blocks.4.attn.out_proj                            Linear(1152 -> 1152)
transformer.blocks.5.attn.out_proj                            Linear(1152 -> 1152)
transformer.blocks.6.attn.out_proj                            Linear(1152 -> 1152)
transformer.blocks.7.attn.out_proj                            Linear(1152 -> 1152)
transformer.blocks.8.attn.out_proj                            Linear(1152 -> 1152)
transformer.blocks.9.attn.out_proj                            Linear(1152 -> 1152)
transformer.blocks.10.attn.out_proj                    

In [5]:
# JUPYTER CELL — Generate ProteinGLM-1B-MLM per-residue embeddings and save per PID
from __future__ import annotations
from pathlib import Path
from typing import List
import torch
from tqdm.auto import tqdm
from transformers import AutoTokenizer, AutoModelForMaskedLM

# ─────────────────────────────── config ─────────────────────────────── #
ROOT         = Path("/teamspace/studios/this_studio/PFP_Testing/data/PDBCH")
SPLITS       = ["train_pdbch", "val_pdbch", "test_pdbch"]
DEVICE       = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL_ID     = "biomap-research/proteinglm-1b-mlm"   # HF repo id
OUT_FILENAME = "pglm_emb_1b.pt"                      # per-protein output
DTYPE_SAVE   = torch.float32                         # saved dtype
# ─────────────────────────────────────────────────────────────────────── #

def collect_all_pid_dirs(root: Path) -> List[Path]:
    dirs: List[Path] = []
    for split in SPLITS:
        split_dir = root / split
        if not split_dir.exists():
            raise FileNotFoundError(f"Missing split folder: {split_dir}")
        for d in split_dir.iterdir():
            if d.is_dir():
                dirs.append(d)
    return sorted(dirs)

def collect_missing_pid_dirs(root: Path) -> List[Path]:
    """Return only PID dirs that do NOT already have OUT_FILENAME."""
    all_dirs = collect_all_pid_dirs(root)
    missing = []
    for d in all_dirs:
        out_path = d / OUT_FILENAME
        if not out_path.exists() or out_path.stat().st_size == 0:
            missing.append(d)
    return missing

def load_sequence(pid_dir: Path) -> str:
    seq_file = pid_dir / "sequence.txt"
    if not seq_file.exists():
        raise FileNotFoundError(f"{seq_file} is missing")
    seq = []
    with seq_file.open() as fh:
        for line in fh:
            if line.startswith(">"):
                continue
            seq.append(line.strip())
    seq = "".join(seq).upper().replace(" ", "")
    if not seq:
        raise ValueError(f"{pid_dir.name}: empty/invalid sequence")
    return seq

@torch.inference_mode()
def embed_pglm(model, tok, seq: str) -> torch.Tensor:
    """
    Build per-residue embeddings with ProteinGLM-1B-MLM per the HF model card.
    Returns (1, L+2, d): index 0 and L+1 are all-zero, indices 1..L are residues.
    """
    L = len(seq)

    # Tokenize with special tokens (adds <eos>)
    out = tok(seq, add_special_tokens=True, return_tensors='pt')
    inputs = {
        "input_ids": out["input_ids"].to(model.device),
        "attention_mask": out["attention_mask"].to(model.device),
    }

    # Model forward; drop trailing <eos> token embeddings
    out_m = model(**inputs, output_hidden_states=True, return_last_hidden_state=True)
    token_emb = out_m.hidden_states[-1][0, :-1]  # [L, d]

    # Pack to (1, L+2, d); zero at positions 0 and L+1; save as CPU float32
    token_emb = token_emb.to(DTYPE_SAVE).cpu()
    d = token_emb.size(-1)
    emb = torch.zeros(1, L + 2, d, dtype=DTYPE_SAVE, device="cpu")
    emb[0, 1:L+1] = token_emb
    return emb

def main() -> None:
    # Only work on proteins missing the .pt file
    pid_dirs = collect_missing_pid_dirs(ROOT)
    if len(pid_dirs) == 0:
        print("All proteins already have embeddings — nothing to do.")
        return

    print(f"Found {len(pid_dirs)} proteins without '{OUT_FILENAME}'.")
    print(f"Loading {MODEL_ID} on {DEVICE} …")
    tok = AutoTokenizer.from_pretrained(MODEL_ID)
    model = AutoModelForMaskedLM.from_pretrained(MODEL_ID).to(DEVICE).eval()

    for pid_dir in tqdm(pid_dirs, desc="Processing proteins"):
        out_path = pid_dir / OUT_FILENAME
        try:
            seq = load_sequence(pid_dir)
            emb = embed_pglm(model, tok, seq)
            torch.save(emb, out_path)
        except Exception as e:
            print(f"[ERROR] {pid_dir.name}: {e}")

if __name__ == "__main__":
    main()


Found 4734 proteins without 'pglm_emb_1b.pt'.
Loading biomap-research/proteinglm-1b-mlm on cuda …


Processing proteins:   0%|          | 0/4734 [00:00<?, ?it/s]

In [None]:
"""
Generate ESM-C (600 M) sequence embeddings for every protein chain that lives
in any of the three split folders

    data/PDBCH/{train_pdbch,val_pdbch,test_pdbch}/<PID>/

The script creates one file per protein:

    <PID>/esmc_emb.pt           # (1, L+2, 1152)  – CLS & EOS already present
"""

import os
from pathlib import Path
from typing import List

import torch
from tqdm.auto import tqdm

# ─────────────────────────────── config ─────────────────────────────── #
ROOT = Path("/teamspace/studios/this_studio/PFP_Testing/data/PDBCH")
SPLITS       = ["train_pdbch", "val_pdbch", "test_pdbch"]
DEVICE       = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL_NAME   = "esmc_600m"                 # Fair-ESM registry tag
OUT_FILENAME = "esmc_emb.pt"               # what the datamodule expects
# ─────────────────────────────────────────────────────────────────────── #


def collect_protein_dirs(root: Path) -> List[Path]:
    """Return every <ROOT>/<split>/<PID>/ directory."""
    dirs = []
    for split in SPLITS:
        split_dir = root / split
        if not split_dir.exists():
            raise FileNotFoundError(f"Missing split folder: {split_dir}")
        for d in split_dir.iterdir():
            if d.is_dir():
                dirs.append(d)
    return sorted(dirs)


def load_sequence(pid_dir: Path) -> str:
    seq_file = pid_dir / "sequence.txt"
    if not seq_file.exists():
        raise FileNotFoundError(f"{seq_file} is missing")
    with seq_file.open() as fh:
        return "".join(line.strip() for line in fh if not line.startswith(">"))


@torch.inference_mode()
def embed_sequence(model, seq: str) -> torch.Tensor:
    """
    Run ESM-C and return a CPU tensor shaped (1, L+2, 1152).

    `model.encode` already prepends CLS & appends EOS.
    """
    from esm.sdk.api import ESMProtein, LogitsConfig

    prot   = ESMProtein(sequence=seq)
    hidden = model.encode(prot)                                    # (1, L+2, 1152) on DEVICE
    out    = model.logits(hidden, LogitsConfig(sequence=True, return_embeddings=True))
    return out.embeddings.cpu()                                    # move straight to CPU


def main() -> None:
    pid_dirs = collect_protein_dirs(ROOT)
    if not pid_dirs:
        print("No protein directories found – nothing to do.")
        return

    print(f"Found {len(pid_dirs)} protein chains across all splits\n")

    print(f"⇢ Loading {MODEL_NAME} onto {DEVICE} …")
    from esm.models.esmc import ESMC
    model = ESMC.from_pretrained(MODEL_NAME).to(DEVICE).eval()

    files_created = 0
    for pid_dir in tqdm(pid_dirs, desc="Embedding"):
        out_file = pid_dir / OUT_FILENAME
        if out_file.exists():
            tqdm.write(f"• {pid_dir.name}: {OUT_FILENAME} already exists – skipping")
            continue

        seq = load_sequence(pid_dir)
        emb = embed_sequence(model, seq)                 # (1, L+2, 1152)
        torch.save(emb, out_file)
        files_created += 1
        tqdm.write(f"✓ {pid_dir.name}: len={len(seq):4d}, saved → {OUT_FILENAME}")

    print(f"\nAll done! Created {files_created} new embedding files.")


if __name__ == "__main__":
    main()


Found 36629 protein chains across all splits

⇢ Loading esmc_300m onto cuda …


Embedding:   0%|          | 0/36629 [00:00<?, ?it/s]

• 11AS-A: esmc_300_emb.pt already exists – skipping
• 18GS-A: esmc_300_emb.pt already exists – skipping
• 1A0P-A: esmc_300_emb.pt already exists – skipping
• 1A22-A: esmc_300_emb.pt already exists – skipping
• 1A4E-A: esmc_300_emb.pt already exists – skipping
• 1A6F-A: esmc_300_emb.pt already exists – skipping
• 1A6J-A: esmc_300_emb.pt already exists – skipping
• 1A8Y-A: esmc_300_emb.pt already exists – skipping
• 1A9C-A: esmc_300_emb.pt already exists – skipping
• 1A9W-E: esmc_300_emb.pt already exists – skipping
• 1AD3-A: esmc_300_emb.pt already exists – skipping
• 1AE5-A: esmc_300_emb.pt already exists – skipping
• 1AGR-E: esmc_300_emb.pt already exists – skipping
• 1AHH-A: esmc_300_emb.pt already exists – skipping
• 1AHP-A: esmc_300_emb.pt already exists – skipping
• 1AI3-A: esmc_300_emb.pt already exists – skipping
• 1AII-A: esmc_300_emb.pt already exists – skipping
• 1AIN-A: esmc_300_emb.pt already exists – skipping
• 1AO5-A: esmc_300_emb.pt already exists – skipping
• 1AOS-A: es

KeyboardInterrupt: 

In [1]:
# JUPYTER CELL — Delete specific embedding files with a progress bar
import os
from pathlib import Path
from tqdm.auto import tqdm

BASE_PATH = Path("/teamspace/studios/this_studio/PFP_Testing/data/PDBCH")
SPLITS = ["test_pdbch", "train_pdbch", "val_pdbch"]
TARGET_FILES = ["prot_emb.pt", "esmc_emb.pt"]

def _fmt_bytes(n: int) -> str:
    for unit in ["B","KB","MB","GB","TB"]:
        if n < 1024:
            return f"{n:.2f} {unit}"
        n /= 1024
    return f"{n:.2f} PB"

def collect_targets(base: Path):
    files = []
    for split in SPLITS:
        split_dir = base / split
        if not split_dir.exists():
            continue
        for name in TARGET_FILES:
            files.extend(split_dir.rglob(name))
    # unique + only files
    files = [p for p in dict.fromkeys(files) if p.is_file()]
    return files

def delete_files(files):
    total_size = sum(p.stat().st_size for p in files)
    errors = 0
    for p in tqdm(files, desc="Deleting embeddings", unit="file"):
        try:
            p.unlink()
        except FileNotFoundError:
            # already gone
            pass
        except Exception:
            errors += 1
    deleted = len(files) - errors
    print(f"\nDone. Deleted: {deleted}/{len(files)} files. "
          f"Approx space freed: {_fmt_bytes(total_size)}. "
          f"Errors: {errors}")

if not BASE_PATH.exists():
    raise SystemExit(f"Base path not found: {BASE_PATH}")
print("collecting Targets")
targets = collect_targets(BASE_PATH)
print(f"Found {len(targets)} files to delete under {BASE_PATH}.")
delete_files(targets)


collecting Targets
Found 73258 files to delete under /teamspace/studios/this_studio/PFP_Testing/data/PDBCH.


Deleting embeddings:   0%|          | 0/73258 [00:00<?, ?file/s]


Done. Deleted: 73258/73258 files. Approx space freed: 84.46 GB. Errors: 0


In [2]:
# pip install -q "transformers>=4.38" torch --upgrade
import torch, torch.nn as nn

from transformers import AutoModelForMaskedLM
model = AutoModelForMaskedLM.from_pretrained(
    "Synthyra/ESMplusplus_large", trust_remote_code=True
)  # 600M (use _small for ~300M)
tokenizer = model.tokenizer
model.train()

# Tiny toy batch
seqs = ["MPEPTIDE", "KASDFGH", "GGGGGGGG"]
batch = tokenizer(seqs, padding=True, return_tensors="pt")

# Forward: get attached hidden states
out = model(**batch)                       # returns logits + last_hidden_state
hs  = out.last_hidden_state                # [B, L, 1152] for the large model

# Masked mean pool (ignore padding)
mask = batch["attention_mask"].unsqueeze(-1)   # [B, L, 1]
pooled = (hs * mask).sum(1) / mask.sum(1).clamp_min(1)  # [B, 1152]

# Tiny head + BCE loss
head = nn.Linear(pooled.size(-1), 4)
labels = torch.randint(0, 2, (pooled.size(0), 4)).float()
loss = nn.BCEWithLogitsLoss()(head(pooled), labels)
loss.backward()

# Sanity checks: did grads reach the backbone?
any_backbone_grad = any(
    p.grad is not None and p.grad.abs().sum() > 0
    for n, p in model.named_parameters()
    if "embed_tokens" in n or "layers" in n or "transformer" in n
)
print("Any grad reached ESM++ backbone? ", any_backbone_grad)


config.json:   0%|          | 0.00/771 [00:00<?, ?B/s]

modeling_esm_plusplus.py: 0.00B [00:00, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/Synthyra/ESMplusplus_large:
- modeling_esm_plusplus.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


model.safetensors:   0%|          | 0.00/2.30G [00:00<?, ?B/s]

Any grad reached ESM++ backbone?  True


In [2]:
"""
Generate ProtT5 (XL-UniRef50) sequence embeddings for every protein chain in:

    data/PDBCH/{train_pdbch,val_pdbch,test_pdbch}/<PID>/

Creates one file per protein:

    <PID>/prot_t5_emb.pt        # (1, L+2, 1024) – CLS/EOS slots are zero

Notes
-----
- Tokenization follows the official ProtT5 recipe: space-separated residues.
- Non-canonical residues (e.g., U, O, B, Z, J, X) are mapped to 'X'.
- We only log a progress bar; existing files are skipped silently.
"""

from __future__ import annotations
import os
from pathlib import Path
from typing import List

import torch
from tqdm.auto import tqdm
from transformers import AutoTokenizer, T5EncoderModel

# ─────────────────────────────── config ─────────────────────────────── #
ROOT         = Path("/teamspace/studios/this_studio/PFP_Testing/data/PDBCH")
SPLITS       = ["train_pdbch", "val_pdbch", "test_pdbch"]
DEVICE       = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL_NAME   = "Rostlab/prot_t5_xl_uniref50"     # was: Rostlab/prot_t5_xl_uniref50
OUT_FILENAME = "prot_emb.pt"         # just to keep them separate
DTYPE        = torch.float32              # keep fp32 for stability
# ─────────────────────────────────────────────────────────────────────── #

AA_VOCAB = set("ACDEFGHIKLMNPQRSTVWYX")  # 20 AAs + 'X'

def collect_protein_dirs(root: Path) -> List[Path]:
    """Return every <ROOT>/<split>/<PID>/ directory."""
    dirs: List[Path] = []
    for split in SPLITS:
        split_dir = root / split
        if not split_dir.exists():
            raise FileNotFoundError(f"Missing split folder: {split_dir}")
        for d in split_dir.iterdir():
            if d.is_dir():
                dirs.append(d)
    return sorted(dirs)

def load_sequence(pid_dir: Path) -> str:
    seq_file = pid_dir / "sequence.txt"
    if not seq_file.exists():
        raise FileNotFoundError(f"{seq_file} is missing")
    seq = []
    with seq_file.open() as fh:
        for line in fh:
            if line.startswith(">"):
                continue
            seq.append(line.strip())
    seq = "".join(seq).upper().replace(" ", "")
    # Map non-canonical to X
    seq = "".join(ch if ch in AA_VOCAB else "X" for ch in seq)
    if not seq:
        raise ValueError(f"{pid_dir.name}: empty/invalid sequence")
    return seq

@torch.inference_mode()
def embed_prott5(model: T5EncoderModel, tok: AutoTokenizer, seq: str) -> torch.Tensor:
    spaced = " ".join(seq)
    enc = tok(spaced, return_tensors="pt", add_special_tokens=True)
    enc = {k: v.to(DEVICE) for k, v in enc.items()}

    # [T, 1024] — take the single sequence from the batch
    out = model(**enc).last_hidden_state[0].to(DTYPE)

    # Drop EOS (last non-pad token when add_special_tokens=True)
    kept = out[:-1]                      # [L, 1024]
    if kept.size(0) != len(seq):
        raise RuntimeError(
            f"Tokenizer/sequence length mismatch: got {kept.size(0)} tokens for L={len(seq)}"
        )

    # (1, L+2, 1024): zero at 0 and L+1, fill 1..L
    L = len(seq)
    emb = torch.zeros(1, L + 2, kept.size(-1), dtype=DTYPE, device="cpu")
    emb[0, 1:L+1] = kept.detach().cpu()
    return emb

def main() -> None:
    pid_dirs = collect_protein_dirs(ROOT)
    if not pid_dirs:
        print("No protein directories found – nothing to do.")
        return

    # Load once
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, do_lower_case=False, use_fast=False)
    model = T5EncoderModel.from_pretrained(MODEL_NAME).to(DEVICE).eval()

    files_created = 0
    for pid_dir in tqdm(pid_dirs, desc="ProtT5 embedding", leave=True):
        out_file = pid_dir / OUT_FILENAME
        if out_file.exists():
            continue  # skip silently to keep console clean

        seq = load_sequence(pid_dir)
        emb = embed_prott5(model, tokenizer, seq)  # (1, L+2, 1024)
        torch.save(emb, out_file)
        files_created += 1

    print(f"All done! Created {files_created} new ProtT5 embedding files.")

if __name__ == "__main__":
    main()

tokenizer_config.json:   0%|          | 0.00/24.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/546 [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/238k [00:00<?, ?B/s]

special_tokens_map.json: 0.00B [00:00, ?B/s]

pytorch_model.bin:   0%|          | 0.00/11.3G [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/11.3G [00:00<?, ?B/s]

ProtT5 embedding:   0%|          | 0/36629 [00:00<?, ?it/s]

All done! Created 36629 new ProtT5 embedding files.


In [1]:
# Introspect Synthyra ESM++ module/parameter names and check LoRA target substrings.
# It loads the model, prints a preview of names, and reports hits for:
# - YOUR custom targets
# - The "official-style" targets from their example
#
# You can tweak MODEL_ID and PREVIEW_LIMIT below.

import re
from typing import List, Iterable, Dict

import torch
from transformers import AutoModelForMaskedLM

# ────────────── config ──────────────
MODEL_ID = "Synthyra/ESMplusplus_large"   # or "Synthyra/ESMplusplus_small"
DEVICE   = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE    = torch.float32                  # use torch.bfloat16 on Ampere/Hopper if you want
PREVIEW_LIMIT = 200                       # how many module names to preview

# Your custom substrings (from your code)
CUSTOM_TARGETS = [
    "attn.layernorm_qkv.1",
    "attn.out_proj",
    "ffn.1",
    "ffn.3",
]

# The "official-style" substrings seen in their examples
OFFICIAL_STYLE_TARGETS = [
    "layernorm_qkv.1",
    "out_proj",
    "query", "key", "value",
    "dense",
]

# ────────────── helpers ──────────────
def print_section(title: str):
    bar = "=" * 88
    print(f"\n{bar}\n{title}\n{bar}")

def first_n(items: Iterable[str], n: int) -> List[str]:
    out = []
    for i, x in enumerate(items):
        if i >= n: break
        out.append(x)
    return out

def count_hits(names: Iterable[str], substrings: List[str]) -> Dict[str, int]:
    counts = {s: 0 for s in substrings}
    for n in names:
        for s in substrings:
            if s in n:
                counts[s] += 1
    return counts

def extract_matches(names: Iterable[str], substrings: List[str], limit: int = 200) -> List[str]:
    out = []
    for n in names:
        if any(s in n for s in substrings):
            out.append(n)
            if len(out) >= limit:
                break
    return out

def group_by_layer(names: Iterable[str]) -> Dict[str, List[str]]:
    """
    Try to group names by '.layers.<idx>.' if present.
    """
    buckets = {}
    pat = re.compile(r"\.layers\.(\d+)\.")
    for n in names:
        m = pat.search(n)
        key = f"layer_{m.group(1)}" if m else "no_layer_tag"
        buckets.setdefault(key, []).append(n)
    return buckets

# ────────────── load model ──────────────
print_section("Loading model")
print(f"MODEL_ID: {MODEL_ID}\nDEVICE: {DEVICE}\nDTYPE:  {DTYPE}")
model = AutoModelForMaskedLM.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    torch_dtype=DTYPE,
)
model.to(DEVICE)

# ────────────── collect names ──────────────
module_names = [n for n, _ in model.named_modules()]
param_names  = [n for n, _ in model.named_parameters()]
state_keys   = list(model.state_dict().keys())

print_section("High-level stats")
total_params = sum(p.numel() for _, p in model.named_parameters())
print(f"Total modules:               {len(module_names):,}")
print(f"Total parameter tensors:     {len(param_names):,}")
print(f"Total parameter count:       {total_params:,}")

print_section(f"First {PREVIEW_LIMIT} MODULE names (preview)")
for n in first_n(module_names, PREVIEW_LIMIT):
    print(n)

print_section(f"First {PREVIEW_LIMIT} PARAMETER names (preview)")
for n in first_n(param_names, PREVIEW_LIMIT):
    print(n)

print_section(f"First {PREVIEW_LIMIT} STATE_DICT keys (preview)")
for n in first_n(state_keys, PREVIEW_LIMIT):
    print(n)

print_section("Layer-grouped sample (first 3 buckets)")
buckets = group_by_layer(module_names)
for i, key in enumerate(sorted(buckets.keys(), key=lambda x: (x=="no_layer_tag", x))):
    if i >= 3: break
    subset = first_n(buckets[key], 30)
    print(f"[{key}] ({len(buckets[key])} items)")
    for name in subset:
        print("  ", name)

# ────────────── substring hit counts ──────────────
print_section("Substring hits in NAMED MODULES")
custom_hits  = count_hits(module_names, CUSTOM_TARGETS)
official_hits = count_hits(module_names, OFFICIAL_STYLE_TARGETS)
print("Custom targets → hit counts:")
for k, v in custom_hits.items():
    print(f"  {k:<24} : {v}")
print("\nOfficial-style targets → hit counts:")
for k, v in official_hits.items():
    print(f"  {k:<24} : {v}")

print_section("Substring hits in STATE_DICT KEYS")
custom_state_hits   = count_hits(state_keys, CUSTOM_TARGETS)
official_state_hits = count_hits(state_keys, OFFICIAL_STYLE_TARGETS)
print("Custom targets → state_dict hit counts:")
for k, v in custom_state_hits.items():
    print(f"  {k:<24} : {v}")
print("\nOfficial-style targets → state_dict hit counts:")
for k, v in official_state_hits.items():
    print(f"  {k:<24} : {v}")

# ────────────── matched name lists (easy copy for target_modules) ──────────────
print_section("Matched MODULE names (CUSTOM targets)")
for n in extract_matches(module_names, CUSTOM_TARGETS, limit=300):
    print(n)
print("\n… (truncated if >300)")

print_section("Matched MODULE names (OFFICIAL-STYLE targets)")
for n in extract_matches(module_names, OFFICIAL_STYLE_TARGETS, limit=300):
    print(n)
print("\n… (truncated if >300)")

print_section("Matched STATE_DICT keys (CUSTOM targets)")
for n in extract_matches(state_keys, CUSTOM_TARGETS, limit=300):
    print(n)
print("\n… (truncated if >300)")

print_section("Matched STATE_DICT keys (OFFICIAL-STYLE targets)")
for n in extract_matches(state_keys, OFFICIAL_STYLE_TARGETS, limit=300):
    print(n)
print("\n… (truncated if >300)")

print_section("Tips if you see ZERO hits")
print("* Scan the previews above for the exact attention/FFN linear names (case-sensitive).")
print("* Replace target substrings with those exact fragments in your PEFT LoraConfig.")
print("* If memory is tight, try MODEL_ID='Synthyra/ESMplusplus_small' for introspection.")



Loading model
MODEL_ID: Synthyra/ESMplusplus_large
DEVICE: cpu
DTYPE:  torch.float32

High-level stats
Total modules:               551
Total parameter tensors:     368
Total parameter count:       575,036,992

First 200 MODULE names (preview)

embed
transformer
transformer.blocks
transformer.blocks.0
transformer.blocks.0.attn
transformer.blocks.0.attn.layernorm_qkv
transformer.blocks.0.attn.layernorm_qkv.0
transformer.blocks.0.attn.layernorm_qkv.1
transformer.blocks.0.attn.out_proj
transformer.blocks.0.attn.q_ln
transformer.blocks.0.attn.k_ln
transformer.blocks.0.attn.rotary
transformer.blocks.0.ffn
transformer.blocks.0.ffn.0
transformer.blocks.0.ffn.1
transformer.blocks.0.ffn.2
transformer.blocks.0.ffn.3
transformer.blocks.0.dropout
transformer.blocks.1
transformer.blocks.1.attn
transformer.blocks.1.attn.layernorm_qkv
transformer.blocks.1.attn.layernorm_qkv.0
transformer.blocks.1.attn.layernorm_qkv.1
transformer.blocks.1.attn.out_proj
transformer.blocks.1.attn.q_ln
transformer.block

In [None]:
# Verify ESMplusplus_* block structure across ALL blocks and confirm LoRA targets.
# This cell:
#  - Loads the model
#  - Checks each transformer.blocks.<i> submodule types
#  - Asserts the expected kinds (Linear / LayerNorm / None)
#  - Prints a summary and recommended target_modules for PEFT

import torch
import torch.nn as nn
from transformers import AutoModelForMaskedLM

MODEL_ID = "Synthyra/ESMplusplus_large"  # or "Synthyra/ESMplusplus_small"
DEVICE   = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE    = torch.float32

print(f"Loading {MODEL_ID} on {DEVICE} …")
model = AutoModelForMaskedLM.from_pretrained(MODEL_ID, trust_remote_code=True, torch_dtype=DTYPE)
model.to(DEVICE).eval()

def kind_of(mod: nn.Module):
    if mod is None:
        return "MISSING"
    if isinstance(mod, nn.Linear):
        return "Linear"
    if isinstance(mod, nn.LayerNorm):
        return "LayerNorm"
    # count direct params to detect paramless helpers
    nparams = sum(p.numel() for _, p in mod.named_parameters(recurse=False))
    return "None" if nparams == 0 else type(mod).__name__

# Discover how many blocks exist
try:
    n_blocks = len(model.transformer.blocks)
except Exception:
    # Fallback: infer from names
    n_blocks = max(
        int(n.split(".")[2]) for n, _ in model.named_modules()
        if n.startswith("transformer.blocks.") and n.split(".")[2].isdigit()
    ) + 1

print(f"Discovered {n_blocks} transformer blocks")

expect = {
    "attn.layernorm_qkv.0": "LayerNorm",
    "attn.layernorm_qkv.1": "Linear",
    "attn.out_proj":        "Linear",
    "attn.q_ln":            "LayerNorm",
    "attn.k_ln":            "LayerNorm",
    "attn.rotary":          "None",
    "ffn.0":                "LayerNorm",   # <-- IMPORTANT: LN in this model
    "ffn.1":                "Linear",
    "ffn.2":                "None",
    "ffn.3":                "Linear",
}

name_to_module = {n: m for n, m in model.named_modules()}
fails = []

for i in range(n_blocks):
    prefix = f"transformer.blocks.{i}"
    for sub, exp in expect.items():
        full = f"{prefix}.{sub}"
        mod = name_to_module.get(full)
        got = kind_of(mod)
        if got != exp:
            fails.append((full, exp, got))

if not fails:
    print("ALL BLOCKS MATCH EXPECTATIONS ✅")
else:
    print(f"Found {len(fails)} mismatches ❌")
    for full, exp, got in fails[:50]:
        print(f"  {full}: expected {exp}, got {got}")
    if len(fails) > 50:
        print("  … (truncated)")

# Recommended LoRA targets for this model:
TARGET_MODULES = ["attn.layernorm_qkv.1", "attn.out_proj", "ffn.1", "ffn.3"]
print("\nRecommended target_modules for LoraConfig:")
print(TARGET_MODULES)


Loading Synthyra/ESMplusplus_large on cpu…

=== Inspecting transformer.blocks.0 ===
transformer.blocks.0.attn.layernorm_qkv.0        → kind=LayerNorm   nparams=2304         shapes={'weight': (1152,), 'bias': (1152,)}   [OK]
transformer.blocks.0.attn.layernorm_qkv.1        → kind=Linear      nparams=3981312      shapes={'weight': (3456, 1152)}   [OK]
transformer.blocks.0.attn.out_proj               → kind=Linear      nparams=1327104      shapes={'weight': (1152, 1152)}   [OK]
transformer.blocks.0.attn.q_ln                   → kind=LayerNorm   nparams=1152         shapes={'weight': (1152,)}   [OK]
transformer.blocks.0.attn.k_ln                   → kind=LayerNorm   nparams=1152         shapes={'weight': (1152,)}   [OK]
transformer.blocks.0.attn.rotary                 → kind=None        nparams=0            shapes={}   [OK]
transformer.blocks.0.ffn.0                       → kind=LayerNorm   nparams=2304         shapes={'weight': (1152,), 'bias': (1152,)}   [FAIL: transformer.blocks.0.ffn.0