# PottsMPNN sanity checks

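This notebook runs quick smoke tests for the Potts and structure losses, the Boltz2 feature pipeline, and the optional ESM-C and FAPE (OpenFold) components. It finishes with an integration check that runs the full model on one training batch.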

In [1]:
import importlib.util

import torch
from training.struct_potts_losses import (
    msa_similarity_loss,
    msa_similarity_loss_esm,
    structure_consistency_loss,
    msa_similarity_loss_esmc,
    structure_fape_loss,
    potts_consistency_loss,
    expand_etab_dense,
)

# Synthetic inputs for the MSA-similarity and CA-consistency losses.
# Seeding first makes every random tensor below (and the printed values) repeatable.
torch.manual_seed(0)
B = 2   # batch size
L = 8   # sequence length
V = 22  # token vocabulary size (also the range sampled for MSA tokens)
M = 4   # MSA rows per sequence

# Per-position log-probabilities over the V tokens.
log_probs = torch.randn(B, L, V).log_softmax(dim=-1)
msa_tokens = torch.randint(0, V, (B, M, L))
# Every MSA entry and every position is marked valid.
msa_mask = torch.ones((B, M, L))
seq_mask = torch.ones((B, L))

baseline_loss = msa_similarity_loss(
    log_probs, msa_tokens, msa_mask, seq_mask, margin=0.1
)
print(f"Baseline MSA loss: {baseline_loss.item():.4f}")

# Random (B, L, 4, 3) coordinates and a copy X jittered by 0.1-scaled Gaussian noise.
positions = torch.randn(B, L, 4, 3)
X = positions + torch.randn_like(positions) * 0.1
mask = torch.ones((B, L))
ca_loss = structure_consistency_loss(positions, X, mask)
print(f"CA loss: {ca_loss.item():.4f}")

# Standalone Potts consistency check on synthetic tables. Two properties are checked:
#   * a target built by expand_etab_dense from the same (etab_geom, e_idx) gives ~zero loss;
#   * perturbing that target makes the loss larger than the matched value.
K = 3          # neighbour indices per position in e_idx
H = 21 * 21    # entries per edge table
etab_geom = torch.randn(B, L, K, H)
e_idx = torch.randint(0, L, (B, L, K))

etab_seq_dense_match = expand_etab_dense(etab_geom, e_idx)
loss_potts_match = potts_consistency_loss(
    etab_geom, e_idx, etab_seq_dense_match, mask
)
print(f"Potts loss (matched targets): {loss_potts_match.item():.6f}")
assert torch.isclose(loss_potts_match, torch.tensor(0.0), atol=1e-6)

# Same target plus 0.1-scaled Gaussian noise.
etab_seq_dense_noisy = etab_seq_dense_match + torch.randn_like(etab_seq_dense_match) * 0.1
loss_potts_noisy = potts_consistency_loss(
    etab_geom, e_idx, etab_seq_dense_noisy, mask
)
print(f"Potts loss (noisy targets): {loss_potts_noisy.item():.6f}")
assert loss_potts_noisy > loss_potts_match


Baseline MSA loss: 0.2342
CA loss: 0.0302


In [2]:
import numpy as np
from training.boltz2_features import build_boltz2_item_feats, collate_boltz2_feats
from training.training_struct_potts import validate_boltz2_alignment

# Minimal single-chain item for the Boltz2 feature builder. Every length below
# is derived from the sequence, so the fixture stays self-consistent if the
# sequence is edited.
dummy_seq = 'ACD'
n_res = len(dummy_seq)
dummy_item = {
    'seq': dummy_seq,
    'chain_order': ['A'],
    'chain_lengths': [n_res],
    'seq_chain_A': dummy_seq,
    'atom14_xyz': np.zeros((n_res, 14, 3), dtype=np.float32),
    'atom14_mask': np.ones((n_res, 14), dtype=np.float32),
}
dummy_feats = build_boltz2_item_feats(dummy_item)
batched_feats = collate_boltz2_feats([dummy_feats])
# Validate against an all-ones (batch=1, n_res) mask.
validate_boltz2_alignment(batched_feats, torch.ones(1, n_res))
msa_width = batched_feats['msa'].shape[-1]
assert msa_width == n_res, f'expected MSA width {n_res}, got {msa_width}'
print('Boltz2 feature alignment check passed.')


length:  3
seq:  ACD
Boltz2 feature alignment check passed.


In [4]:
# Optional: ESM-C-based MSA similarity loss.
# Only availability problems (the `esm` package is missing, or the pretrained
# weights cannot be loaded) are reported as a skip. Errors raised by the loss
# itself propagate, so a broken loss is not mistaken for a missing dependency.
from training.training_struct_potts import build_esm_token_map
from training.struct_potts_losses import msa_similarity_loss_esmc

try:
    import esm  # noqa: F401  (imported only to confirm the package is installed)
    from esm.models.esmc import ESMC

    esmc_model = ESMC.from_pretrained("esmc_300m")
except Exception as exc:
    print(f"ESM-C not available: {exc}")
else:
    esmc_model.eval().cpu()
    token_map = build_esm_token_map(esmc_model.tokenizer, 'ACDEFGHIKLMNPQRSTVWYX-')

    # Fresh leaf copy of the synthetic log-probs, so the gradient can be inspected.
    log_probs_esmc = log_probs.detach().clone().requires_grad_(True)
    esm_loss = msa_similarity_loss_esmc(
        log_probs_esmc,
        msa_tokens,
        msa_mask,
        seq_mask,
        esmc_model,
        token_map,
        margin=0.1,
    )
    esm_loss.backward()
    assert log_probs_esmc.grad is not None, 'ESM-C loss did not backpropagate into log_probs'
    grad_norm = log_probs_esmc.grad.norm().item()
    print(f"ESM-C MSA loss: {esm_loss.item():.4f}, grad norm: {grad_norm:.4f}")
    assert torch.isfinite(esm_loss).all(), 'ESM-C loss is not finite'


Fetching 4 files: 100%|██████████| 4/4 [00:00<00:00, 57456.22it/s]


ESM-C MSA loss: 0.1013, grad norm: 0.0031


In [5]:
# Optional: FAPE loss (requires OpenFold).
# Availability is checked up front, so only a missing OpenFold install is
# reported as a skip. Any error raised by structure_fape_loss itself propagates
# instead of being mislabelled as "not available".
if importlib.util.find_spec("openfold") is None:
    print("OpenFold not available: module 'openfold' not found")
else:
    # NOTE(review): these are random 4x4 matrices, not proper rigid transforms,
    # so this only checks that the loss runs and returns a finite value.
    frames = torch.randn(1, B, L, 4, 4)
    backbone_4x4 = torch.randn(B, L, 4, 4)
    fape_loss = structure_fape_loss(frames, backbone_4x4, mask)
    print(f"FAPE loss: {fape_loss.item():.4f}")
    assert torch.isfinite(fape_loss).all(), 'FAPE loss is not finite'

FAPE loss: 0.4707


In [None]:
# Integration smoke test: load one training batch, run the model, and evaluate
# every major loss. etab_seq_dense is computed the same way as in training
# (Boltz2 trunk + SequencePottsHead).
import os

import torch

from training.utils import (
    worker_init_fn,
    loader_pdb,
    build_training_clusters,
    PDB_dataset,
    StructureDataset,
    StructureLoader,
    get_pdbs,
)
from training.model_utils_struct import ProteinMPNN, featurize
from training.boltz2_adapter import Boltz2TrunkAdapter, SequencePottsHead
from training.training_struct_potts import (
    load_boltz2_checkpoint,
    get_boltz2_feats,
    validate_boltz2_alignment,
)
from training.struct_potts_losses import (
    msa_similarity_loss,
    potts_consistency_loss,
    structure_consistency_loss,
    structure_fape_loss,
)

# Local paths. Override them through the environment rather than editing the
# notebook; the defaults are the original hard-coded locations.
data_path = os.environ.get(
    'POTTSMPNN_DATA_PATH', '/mnt/shared/fosterb/ProteinMPNN/data/pdb_2021aug02'
)
boltz2_checkpoint = os.environ.get(
    'BOLTZ2_CHECKPOINT', '/mnt/shared/fosterb/boltz2/boltz2.ckpt'
)
boltz2_recycles = 1
SEED = 0
params = {
    'LIST': f'{data_path}/list.csv',
    'VAL': f'{data_path}/valid_clusters.txt',
    'TEST': f'{data_path}/test_clusters.txt',
    'DIR': f'{data_path}',
    'DATCUT': '2030-Jan-01',
    'RESCUT': 3.5,
    'HOMO': 0.70,
}

# Fail early with an actionable message when the local inputs are missing.
if not os.path.exists(params['LIST']):
    raise FileNotFoundError(
        f"PDB training list not found at {params['LIST']}. "
        'Set data_path (or POTTSMPNN_DATA_PATH) to your local data to run this integration test.'
    )
if not os.path.exists(boltz2_checkpoint):
    raise FileNotFoundError(
        f'Boltz2 checkpoint not found at {boltz2_checkpoint}. '
        'Set boltz2_checkpoint to your local checkpoint to run this integration test.'
    )

# Seed torch so the DataLoader shuffle and the randomly initialised ProteinMPNN /
# SequencePottsHead weights repeat across runs, independent of earlier cells.
torch.manual_seed(SEED)

train, _, _ = build_training_clusters(params, debug=True)
train_set = PDB_dataset(list(train.keys()), loader_pdb, train, params)
train_loader = torch.utils.data.DataLoader(
    train_set, batch_size=1, shuffle=True, worker_init_fn=worker_init_fn
)

pdb_dict_train = get_pdbs(train_loader, repeat=1, max_length=1000, num_units=1)
dataset = StructureDataset(pdb_dict_train, truncate=1, max_length=1000)
loader = StructureLoader(dataset, batch_size=1)
batch = next(iter(loader))

# Use the same Boltz2 features consumed in training.
boltz2_feats = get_boltz2_feats(batch)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
boltz2_feats = {
    k: (v.to(device) if torch.is_tensor(v) else v)
    for k, v in boltz2_feats.items()
}

model = ProteinMPNN(
    node_features=128,
    edge_features=128,
    hidden_dim=128,
    num_encoder_layers=3,
    num_decoder_layers=3,
    k_neighbors=48,
    dropout=0.1,
    augment_eps=0.2,
    use_potts=True,
    struct_predict=True,
    struct_use_decoder_one_hot=True,
).to(device)
model.eval()

boltz2_model = load_boltz2_checkpoint(boltz2_checkpoint, device)
boltz2_model = boltz2_model.to(device)
boltz2_model.eval()
boltz2_trunk = Boltz2TrunkAdapter.from_boltz2_model(boltz2_model).to(device)
seq_potts_head = SequencePottsHead(
    pair_dim=boltz2_model.hparams.token_z,
    # NOTE(review): 400 == 20 * 20, whereas the synthetic Potts check in the first
    # code cell uses 21 * 21 entries per edge. Confirm which size etab_geom has.
    potts_dim=400,
).to(device)
seq_potts_head.eval()

# Build OpenFold backbones so FAPE can be tested on the same batch.
(
    X,
    S,
    _,
    mask,
    lengths,
    chain_M,
    residue_idx,
    mask_self,
    chain_encoding_all,
    _,
    backbone_4x4,
    _,
) = featurize(
    batch,
    device,
    augment_type='atomic',
    augment_eps=0.2,
    replicate=1,
    epoch=0,
    openfold_backbone=True,
)

X = X.to(device)
S = S.to(device)
mask = mask.to(device)
chain_M = chain_M.to(device)
residue_idx = residue_idx.to(device)
mask_self = mask_self.to(device)
chain_encoding_all = chain_encoding_all.to(device)
backbone_4x4 = backbone_4x4.to(device)

validate_boltz2_alignment(boltz2_feats, mask)

with torch.no_grad():
    trunk_out = boltz2_trunk(boltz2_feats, boltz2_recycles)
    etab_seq_dense = seq_potts_head(trunk_out.z_trunk)

    log_probs, etab_geom, e_idx, frames, positions, logits = model(
        X,
        S,
        mask,
        chain_M,
        residue_idx,
        chain_encoding_all,
        return_logits=True,
    )

print(f'positions shape: {None if positions is None else tuple(positions.shape)}')
print('The first dimension in positions is trajectory/recycle step; the loss uses the final step.')

# The structure losses below need both outputs. Fail with a clear message instead
# of an opaque error inside the loss functions.
if positions is None or frames is None:
    raise RuntimeError(
        'ProteinMPNN(struct_predict=True) returned no structure outputs '
        '(positions/frames is None); cannot evaluate the structure losses.'
    )

loss_msa = msa_similarity_loss(
    log_probs,
    boltz2_feats['msa'],
    boltz2_feats['msa_mask'],
    mask,
)

# Main-path Potts check: use the same Boltz2-derived dense Potts target as training.
loss_potts = potts_consistency_loss(etab_geom, e_idx, etab_seq_dense, mask)
loss_struct_ca = structure_consistency_loss(positions, X, mask)
loss_struct_fape = structure_fape_loss(frames, backbone_4x4, mask)

losses = {
    'MSA': loss_msa,
    'Potts': loss_potts,
    'Struct (CA)': loss_struct_ca,
    'Struct (FAPE)': loss_struct_fape,
}
print('Losses -> ' + ', '.join(f'{name}: {value.item():.4f}' for name, value in losses.items()))

# Every loss, not only Potts, must be finite for the smoke test to pass.
non_finite = [name for name, value in losses.items() if not torch.isfinite(value).all()]
assert not non_finite, f'non-finite losses: {non_finite}'
