In [1]:
# === ESM-MSA introspection for LoRA targeting ===
import sys, os, platform, re, textwrap, json
from pathlib import Path

import torch
import torch.nn as nn

# If you have esm installed as `esm` (facebookresearch/esm)
import esm

OUT_PATH = Path("msa_introspection.txt")

def header(title):
    line = "=" * len(title)
    return f"\n{title}\n{line}\n"

def env_report():
    try:
        import importlib.metadata as im
        esm_ver = im.version("esm")
    except Exception:
        esm_ver = "unknown"
    return {
        "python": sys.version.replace("\n", " "),
        "platform": platform.platform(),
        "torch": torch.__version__,
        "esm": esm_ver,
        "cuda_available": torch.cuda.is_available(),
        "cuda_device_count": torch.cuda.device_count(),
    }

def print_tree(model, max_lines=400):
    """Render a readable tree of named_modules (truncated)."""
    lines = []
    for name, module in model.named_modules():
        mod = type(module).__name__
        lines.append(f"{name or '<root>'} :: {mod}")
    truncated = lines[:max_lines]
    body = "\n".join(truncated)
    if len(lines) > max_lines:
        body += f"\n... ({len(lines)-max_lines} more lines truncated)"
    return body

def find_proj_modules(model):
    """
    Find attention projection linears named q_proj/k_proj/v_proj/out_proj
    and group them by their parent block (e.g., row_attn / col_attn).
    """
    hits = []
    for name, module in model.named_modules():
        for proj in ("q_proj", "k_proj", "v_proj", "out_proj"):
            if hasattr(module, proj) and isinstance(getattr(module, proj), nn.Linear):
                lin = getattr(module, proj)
                hits.append({
                    "parent": name,
                    "proj_name": proj,
                    "shape": tuple(lin.weight.shape),
                    "bias": getattr(lin, "bias", None) is not None,
                    "type": type(lin).__name__,
                })
    return hits

def group_row_col(hits):
    """
    Try to separate row/col attention based on parent name substrings.
    This is robust to common ESM naming like 'layers.0.row_attn' / 'layers.0.col_attn'.
    """
    row, col, other = [], [], []
    for h in hits:
        parent = h["parent"].lower()
        if "row" in parent:
            row.append(h)
        elif "col" in parent or "column" in parent:
            col.append(h)
        else:
            other.append(h)
    return row, col, other

def list_all_linears(model):
    items = []
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            items.append({
                "name": name,
                "shape": tuple(module.weight.shape),
                "bias": module.bias is not None
            })
    return items

def dump_named_parameters(model, pattern=None, max_lines=200):
    lines = []
    for name, p in model.named_parameters():
        if (pattern is None) or re.search(pattern, name):
            lines.append(f"{name} :: {tuple(p.shape)} :: requires_grad={p.requires_grad}")
    lines = lines[:max_lines] + (["... (truncated)"] if len(lines) > max_lines else [])
    return "\n".join(lines)

# ---------- Load model ----------
msa_model, msa_alphabet = esm.pretrained.esm_msa1b_t12_100M_UR50S()
msa_model.eval().requires_grad_(False)

# ---------- Collect info ----------
env = env_report()
tree = print_tree(msa_model, max_lines=800)  # bump if you want more
hits = find_proj_modules(msa_model)
row_hits, col_hits, other_hits = group_row_col(hits)
linears = list_all_linears(msa_model)
params_glimpse = dump_named_parameters(msa_model, pattern=r"(q_proj|k_proj|v_proj|out_proj)", max_lines=300)

# ---------- Build a human-readable report ----------
report = []
report.append(header("Environment"))
report.append(json.dumps(env, indent=2))

report.append(header("Model Class"))
report.append(repr(type(msa_model)))

report.append(header("Named Modules (tree; truncated)"))
report.append(tree)

report.append(header("Attention Projection Linears (q/k/v/out) — ALL HITS"))
for h in hits:
    report.append(f"{h['parent']}.{h['proj_name']} :: {h['shape']} :: bias={h['bias']}")

report.append(header("Row Attention Projections"))
for h in row_hits:
    report.append(f"{h['parent']}.{h['proj_name']} :: {h['shape']} :: bias={h['bias']}")

report.append(header("Column Attention Projections"))
for h in col_hits:
    report.append(f"{h['parent']}.{h['proj_name']} :: {h['shape']} :: bias={h['bias']}")

if other_hits:
    report.append(header("Other Attention-like Projections (neither row nor col by name)"))
    for h in other_hits:
        report.append(f"{h['parent']}.{h['proj_name']} :: {h['shape']} :: bias={h['bias']}")

report.append(header("All nn.Linear layers (for optional LoRA targets; truncated to 500)"))
for item in linears[:500]:
    report.append(f"{item['name']} :: {item['shape']} :: bias={item['bias']}")
if len(linears) > 500:
    report.append(f"... ({len(linears)-500} more linears truncated)")

report.append(header("Named parameters glimpse (q/k/v/out only; truncated)"))
report.append(params_glimpse)

text = "\n".join(report)

# ---------- Save & show ----------
OUT_PATH.write_text(text)
print(text[:4000])
print(f"\n\n=== Saved full report to: {OUT_PATH.resolve()} ===")



Environment

{
  "python": "3.10.18 (main, Jun  5 2025, 13:14:17) [GCC 11.2.0]",
  "platform": "Linux-6.8.0-1033-gcp-x86_64-with-glibc2.31",
  "torch": "2.7.1+cu126",
  "esm": "unknown",
  "cuda_available": false,
  "cuda_device_count": 0
}

Model Class

<class 'esm.model.msa_transformer.MSATransformer'>

Named Modules (tree; truncated)

<root> :: MSATransformer
embed_tokens :: Embedding
dropout_module :: Dropout
layers :: ModuleList
layers.0 :: AxialTransformerLayer
layers.0.row_self_attention :: NormalizedResidualBlock
layers.0.row_self_attention.layer :: RowSelfAttention
layers.0.row_self_attention.layer.k_proj :: Linear
layers.0.row_self_attention.layer.v_proj :: Linear
layers.0.row_self_attention.layer.q_proj :: Linear
layers.0.row_self_attention.layer.out_proj :: Linear
layers.0.row_self_attention.layer.dropout_module :: Dropout
layers.0.row_self_attention.dropout_module :: Dropout
layers.0.row_self_attention.layer_norm :: LayerNorm
layers.0.column_self_attention :: NormalizedRe

In [1]:
# --- ESM-C Introspection for LoRA Targeting ---
import torch, re, sys
from typing import List, Tuple
from esm.models.esmc import ESMC
from torch import nn

def is_linear(m): 
    return isinstance(m, nn.Linear)

def find_transformer_blocks(core: nn.Module) -> List[nn.Module]:
    """
    Try to locate the list of transformer blocks.
    Works if the model exposes a ModuleList of blocks or
    otherwise collects any module that has a 'self_attn' attr.
    """
    # Heuristic 1: a ModuleList whose first item has 'self_attn'
    for name, mod in core.named_modules():
        if isinstance(mod, nn.ModuleList) and len(mod) > 0:
            first = mod[0]
            if hasattr(first, "self_attn"):
                return list(mod)
    # Heuristic 2: collect any module that has 'self_attn'
    blocks = []
    for _, mod in core.named_modules():
        if hasattr(mod, "self_attn"):
            blocks.append(mod)
    return blocks

def preview_attn_linears(core: nn.Module) -> List[Tuple[str, nn.Linear]]:
    """
    Return (name, module) for attention-related linears commonly used in ESM/ESMC.
    """
    hits = []
    for name, mod in core.named_modules():
        if is_linear(mod):
            if any(key in name for key in (
                "self_attn.q_proj","self_attn.k_proj","self_attn.v_proj","self_attn.out_proj",
                "attn.q_proj","attn.k_proj","attn.v_proj","attn.out_proj"
            )):
                hits.append((name, mod))
    return hits

def detect_fused_inproj(core: nn.Module) -> List[str]:
    """
    Find modules that look like Fairseq-style fused projection attention
    (i.e., have 'in_proj_weight' or 'in_proj_bias' attributes).
    """
    fused = []
    for name, mod in core.named_modules():
        if hasattr(mod, "in_proj_weight") or hasattr(mod, "in_proj_bias"):
            fused.append(name)
    return fused

def list_layernorms(core: nn.Module) -> List[str]:
    return [name for name, m in core.named_modules() if isinstance(m, nn.LayerNorm)]

def fmt(s, char="="):
    return f"\n{char*8} {s} {char*8}"

def main():
    print(fmt("Loading ESM-C"))
    model = ESMC.from_pretrained("esmc_600m")
    core = getattr(model, "model", model)  # some SDKs wrap the raw torch module here

    # Basic model info
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total params: {total:,} | Trainable (now): {trainable:,}")
    print(f"Model class: {model.__class__.__name__}")
    print(f"Core class:  {core.__class__.__name__}")

    # Locate transformer blocks
    blocks = find_transformer_blocks(core)
    print(fmt("Transformer Blocks"))
    print(f"Num blocks found: {len(blocks)}")
    if len(blocks) > 0:
        # Print a quick sketch of the first, middle, and last block types
        idxs = sorted(set([0, len(blocks)//2, len(blocks)-1]))
        for i in idxs:
            b = blocks[i]
            print(f"- Block[{i}] type: {b.__class__.__name__}  (attrs: {', '.join(sorted(set(dir(b)) & {'self_attn','fc1','fc2','norm1','norm2'}))})")

    # Find attention linears
    print(fmt("Attention Linears (separate q/k/v/out)"))
    hits = preview_attn_linears(core)
    if hits:
        for name, mod in hits[:50]:  # print first 50 for sanity
            print(f"{name:60s}  Linear({mod.in_features} -> {mod.out_features})")
        if len(hits) > 50:
            print(f"... and {len(hits)-50} more")
    else:
        print("No explicit q_proj/k_proj/v_proj/out_proj linears found.")

    # Fused attention?
    print(fmt("Potential Fused In-Proj Attentions"))
    fused = detect_fused_inproj(core)
    if fused:
        for name in fused:
            print(f"{name}")
    else:
        print("No modules with in_proj_weight/in_proj_bias detected.")

    # LayerNorms
    print(fmt("LayerNorm Modules"))
    ln_names = list_layernorms(core)
    print(f"Found {len(ln_names)} LayerNorms.")
    for n in ln_names[:40]:
        print(n)
    if len(ln_names) > 40:
        print(f"... and {len(ln_names)-40} more")

    # Per-block quick look at attention fields
    print(fmt("Per-Block Attention Fields (quick check)"))
    for i, b in enumerate(blocks[:12]):  # first 12 blocks only to keep output manageable
        sa = getattr(b, "self_attn", None)
        fields = []
        if sa is not None:
            for attr in ("q_proj","k_proj","v_proj","out_proj","in_proj_weight","in_proj_bias"):
                if hasattr(sa, attr):
                    fields.append(attr)
        print(f"Block[{i}]: self_attn fields -> {fields if fields else 'N/A'}")

    print(fmt("Instructions"))
    print(
        "Paste this entire output back to me.\n"
        "- If you see names like '*.self_attn.k_proj' and '*.self_attn.v_proj', we can LoRA those directly.\n"
        "- If instead you see 'in_proj_weight', we'll use a fused-projection LoRA patch.\n"
        "- The block count tells us how many 'last_n' layers we can reasonably target."
    )

if __name__ == "__main__":
    main()





Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

Total params: 575,036,992 | Trainable (now): 575,036,992
Model class: ESMC
Core class:  ESMC

Num blocks found: 0

transformer.blocks.0.attn.out_proj                            Linear(1152 -> 1152)
transformer.blocks.1.attn.out_proj                            Linear(1152 -> 1152)
transformer.blocks.2.attn.out_proj                            Linear(1152 -> 1152)
transformer.blocks.3.attn.out_proj                            Linear(1152 -> 1152)
transformer.blocks.4.attn.out_proj                            Linear(1152 -> 1152)
transformer.blocks.5.attn.out_proj                            Linear(1152 -> 1152)
transformer.blocks.6.attn.out_proj                            Linear(1152 -> 1152)
transformer.blocks.7.attn.out_proj                            Linear(1152 -> 1152)
transformer.blocks.8.attn.out_proj                            Linear(1152 -> 1152)
transformer.blocks.9.attn.out_proj                            Linear(1152 -> 1152)
transformer.blocks.10.attn.out_proj                    