In [1]:
# === ESM-MSA introspection for LoRA targeting ===
import sys, os, platform, re, textwrap, json
from pathlib import Path

import torch
import torch.nn as nn

# If you have esm installed as `esm` (facebookresearch/esm)
import esm

OUT_PATH = Path("msa_introspection.txt")

def header(title):
    """Return *title* as a section heading underlined with '=' characters.

    The result starts with a blank line and ends with a newline so it can
    be joined directly with the section body that follows it.
    """
    underline = "=" * len(title)
    return "\n" + title + "\n" + underline + "\n"

def env_report():
    """Collect a JSON-serialisable snapshot of the runtime environment.

    Returns:
        dict with the keys ``python``, ``platform``, ``torch``, ``esm``,
        ``cuda_available`` and ``cuda_device_count``. ``esm`` is the
        string ``"unknown"`` when no version can be determined.
    """
    esm_ver = "unknown"

    try:
        import importlib.metadata as im
    except ImportError:
        # Very old interpreters have no importlib.metadata; stay best-effort.
        im = None

    if im is not None:
        # The facebookresearch/esm package is distributed as "fair-esm";
        # "esm" is only its import name, so querying "esm" alone reports
        # "unknown" for a normal pip install. Try both distribution names.
        for dist_name in ("fair-esm", "esm"):
            try:
                esm_ver = im.version(dist_name)
            except im.PackageNotFoundError:
                continue
            break

    if esm_ver == "unknown":
        # No distribution metadata (e.g. a source checkout): fall back to the
        # already-imported module's __version__ attribute, if it has one.
        esm_ver = str(getattr(sys.modules.get("esm"), "__version__", "unknown"))

    return {
        "python": sys.version.replace("\n", " "),
        "platform": platform.platform(),
        "torch": torch.__version__,
        "esm": esm_ver,
        "cuda_available": torch.cuda.is_available(),
        "cuda_device_count": torch.cuda.device_count(),
    }

def print_tree(model, max_lines=400):
    """Return a flat listing of ``model.named_modules()`` as text.

    Each line has the form ``<dotted name> :: <class name>``; the module with
    an empty name is shown as ``<root>``. Only the first ``max_lines`` entries
    are kept, followed by a note saying how many were dropped.
    """
    entries = [
        f"{path or '<root>'} :: {type(submodule).__name__}"
        for path, submodule in model.named_modules()
    ]
    hidden = len(entries) - max_lines
    rendered = "\n".join(entries[:max_lines])
    if hidden > 0:
        rendered += f"\n... ({hidden} more lines truncated)"
    return rendered

def find_proj_modules(model):
    """
    Collect every ``nn.Linear`` exposed as a q_proj/k_proj/v_proj/out_proj
    attribute of some module in ``model``.

    Returns a list of dicts, one per projection, with the owning module's
    dotted name (``parent``), the attribute name (``proj_name``), the weight
    ``shape``, whether a ``bias`` exists, and the layer's class name (``type``).
    Entries follow ``named_modules()`` order, then q/k/v/out order.
    """
    proj_names = ("q_proj", "k_proj", "v_proj", "out_proj")
    found = []
    for owner_name, owner in model.named_modules():
        for attr in proj_names:
            # A missing attribute and a non-Linear attribute are both skipped.
            layer = getattr(owner, attr, None)
            if not isinstance(layer, nn.Linear):
                continue
            record = {
                "parent": owner_name,
                "proj_name": attr,
                "shape": tuple(layer.weight.shape),
                "bias": getattr(layer, "bias", None) is not None,
                "type": type(layer).__name__,
            }
            found.append(record)
    return found

def group_row_col(hits):
    """
    Split projection hits into (row, col, other) lists by parent name.

    Matching is a case-insensitive substring test on ``hit["parent"]``:
    names containing 'row' go to the first list, names containing 'col'
    (which also covers 'column') go to the second, everything else to the
    third. 'row' takes precedence when both substrings are present.
    """
    buckets = {"row": [], "col": [], "other": []}
    for hit in hits:
        lowered = hit["parent"].lower()
        if "row" in lowered:
            key = "row"
        elif "col" in lowered:  # "column" contains "col", so one test suffices
            key = "col"
        else:
            key = "other"
        buckets[key].append(hit)
    return buckets["row"], buckets["col"], buckets["other"]

def list_all_linears(model):
    """List every ``nn.Linear`` in ``model`` in ``named_modules()`` order.

    Each entry is a dict with the module's dotted ``name``, its weight
    ``shape`` as a tuple, and ``bias`` (True when a bias parameter exists).
    """
    return [
        {
            "name": path,
            "shape": tuple(layer.weight.shape),
            "bias": layer.bias is not None,
        }
        for path, layer in model.named_modules()
        if isinstance(layer, nn.Linear)
    ]

def dump_named_parameters(model, pattern=None, max_lines=200):
    """Return a text listing of ``model.named_parameters()``.

    Each line reads ``<name> :: <shape tuple> :: requires_grad=<bool>``.
    When ``pattern`` is given, only names where ``re.search(pattern, name)``
    matches are listed. At most ``max_lines`` lines are kept; if any were
    dropped, a final ``... (truncated)`` marker line is appended.
    """
    selected = []
    for param_name, param in model.named_parameters():
        if pattern is not None and not re.search(pattern, param_name):
            continue
        selected.append(
            f"{param_name} :: {tuple(param.shape)} :: requires_grad={param.requires_grad}"
        )

    overflow = len(selected) > max_lines
    kept = selected[:max_lines]
    if overflow:
        kept.append("... (truncated)")
    return "\n".join(kept)

# ---------- Load model ----------
msa_model, msa_alphabet = esm.pretrained.esm_msa1b_t12_100M_UR50S()
# Introspection only: put the model in eval mode and disable gradients.
msa_model.eval().requires_grad_(False)

# ---------- Collect info ----------
env = env_report()
tree = print_tree(msa_model, max_lines=800)  # bump if you want more
hits = find_proj_modules(msa_model)
row_hits, col_hits, other_hits = group_row_col(hits)
linears = list_all_linears(msa_model)
params_glimpse = dump_named_parameters(msa_model, pattern=r"(q_proj|k_proj|v_proj|out_proj)", max_lines=300)


def _format_hit(hit):
    """Render one projection hit as '<parent>.<proj> :: <shape> :: bias=<bool>'."""
    return f"{hit['parent']}.{hit['proj_name']} :: {hit['shape']} :: bias={hit['bias']}"


# ---------- Build a human-readable report ----------
# `report` is a list of text chunks; they are joined with newlines at the end.
report = []
report.append(header("Environment"))
report.append(json.dumps(env, indent=2))

report.append(header("Model Class"))
report.append(repr(type(msa_model)))

report.append(header("Named Modules (tree; truncated)"))
report.append(tree)

report.append(header("Attention Projection Linears (q/k/v/out) — ALL HITS"))
report.extend(_format_hit(h) for h in hits)

report.append(header("Row Attention Projections"))
report.extend(_format_hit(h) for h in row_hits)

report.append(header("Column Attention Projections"))
report.extend(_format_hit(h) for h in col_hits)

# Only emit the "other" section when something failed the row/col name match.
if other_hits:
    report.append(header("Other Attention-like Projections (neither row nor col by name)"))
    report.extend(_format_hit(h) for h in other_hits)

report.append(header("All nn.Linear layers (for optional LoRA targets; truncated to 500)"))
for item in linears[:500]:
    report.append(f"{item['name']} :: {item['shape']} :: bias={item['bias']}")
if len(linears) > 500:
    report.append(f"... ({len(linears)-500} more linears truncated)")

report.append(header("Named parameters glimpse (q/k/v/out only; truncated)"))
report.append(params_glimpse)

text = "\n".join(report)

# ---------- Save & show ----------
# The report contains non-ASCII characters (the em dash in a section title),
# so write it as UTF-8 explicitly instead of relying on the locale default.
OUT_PATH.write_text(text, encoding="utf-8")
print(text[:4000])
print(f"\n\n=== Saved full report to: {OUT_PATH.resolve()} ===")



Environment

{
  "python": "3.10.18 (main, Jun  5 2025, 13:14:17) [GCC 11.2.0]",
  "platform": "Linux-6.8.0-1033-gcp-x86_64-with-glibc2.31",
  "torch": "2.7.1+cu126",
  "esm": "unknown",
  "cuda_available": false,
  "cuda_device_count": 0
}

Model Class

<class 'esm.model.msa_transformer.MSATransformer'>

Named Modules (tree; truncated)

<root> :: MSATransformer
embed_tokens :: Embedding
dropout_module :: Dropout
layers :: ModuleList
layers.0 :: AxialTransformerLayer
layers.0.row_self_attention :: NormalizedResidualBlock
layers.0.row_self_attention.layer :: RowSelfAttention
layers.0.row_self_attention.layer.k_proj :: Linear
layers.0.row_self_attention.layer.v_proj :: Linear
layers.0.row_self_attention.layer.q_proj :: Linear
layers.0.row_self_attention.layer.out_proj :: Linear
layers.0.row_self_attention.layer.dropout_module :: Dropout
layers.0.row_self_attention.dropout_module :: Dropout
layers.0.row_self_attention.layer_norm :: LayerNorm
layers.0.column_self_attention :: NormalizedRe