# PottsMPNN sanity checks

This notebook runs quick smoke tests of the Potts and structure losses, plus the optional ESM-based MSA similarity loss and the FAPE loss (which requires OpenFold).

In [None]:
import torch
from training.struct_potts_losses import (
    msa_similarity_loss,
    msa_similarity_loss_esm,
    structure_consistency_loss,
    structure_fape_loss,
)

# Fix the RNG seed so the random inputs (and therefore the printed losses) are reproducible.
torch.manual_seed(0)
# Toy sizes: B batch items, L positions each, V token classes, M MSA rows per batch item
# (meanings inferred from the tensor shapes built below).
B, L, V, M = 2, 8, 22, 4
# Random per-position log-probabilities over the V tokens, shape (B, L, V).
log_probs = torch.log_softmax(torch.randn(B, L, V), dim=-1)
# Random MSA token ids in [0, V), shape (B, M, L); every MSA entry and position is marked valid.
msa_tokens = torch.randint(0, V, (B, M, L))
msa_mask = torch.ones(B, M, L)
seq_mask = torch.ones(B, L)

baseline_loss = msa_similarity_loss(log_probs, msa_tokens, msa_mask, seq_mask, margin=0.1)
print(f"Baseline MSA loss: {baseline_loss.item():.4f}")
# A sanity check should fail loudly instead of just printing NaN/inf.
assert torch.isfinite(baseline_loss).all(), "Baseline MSA loss is not finite"

# Random coordinates of shape (B, L, 4, 3) -- presumably 4 backbone atoms per residue
# (TODO confirm) -- and a copy perturbed with Gaussian noise of std 0.1.
positions = torch.randn(B, L, 4, 3)
X = positions + 0.1 * torch.randn_like(positions)
mask = torch.ones(B, L)
ca_loss = structure_consistency_loss(positions, X, mask)
print(f"CA loss: {ca_loss.item():.4f}")
assert torch.isfinite(ca_loss).all(), "CA loss is not finite"

In [None]:
# Optional: ESM-based MSA similarity loss.
# Only failures to import ESM or load its pretrained weights count as "not available".
# Errors raised by msa_similarity_loss_esm itself propagate, so real regressions are not hidden.
try:
    import esm

    model, alphabet = esm.pretrained.esm2_t6_8M_UR50D()
except Exception as exc:  # best-effort: ImportError, or e.g. a failure while fetching weights
    print(f"ESM not available: {exc}")
else:
    model.eval()
    # Map each of the V=22 local token ids to its ESM alphabet index. Assumes the local
    # vocabulary is ordered 'ACDEFGHIKLMNPQRSTVWYX-' -- TODO confirm against the project's alphabet.
    token_map = torch.tensor([alphabet.get_idx(aa) for aa in 'ACDEFGHIKLMNPQRSTVWYX-'])
    esm_loss = msa_similarity_loss_esm(
        log_probs,
        msa_tokens,
        msa_mask,
        seq_mask,
        model,
        token_map,
        margin=0.1,
    )
    print(f"ESM MSA loss: {esm_loss.item():.4f}")
    assert torch.isfinite(esm_loss).all(), "ESM MSA loss is not finite"

In [None]:
# Optional: FAPE loss (requires OpenFold)


def _random_rigid_4x4(*batch_shape):
    """Return random proper rigid transforms as homogeneous matrices of shape (*batch_shape, 4, 4).

    Each matrix has an orthonormal rotation block with determinant +1, a random
    translation column and a last row of [0, 0, 0, 1]. Plain torch.randn 4x4 matrices
    are not rigid transforms, so inverting them as frames would be meaningless.
    """
    rot, _ = torch.linalg.qr(torch.randn(*batch_shape, 3, 3))
    # QR may return a reflection (det = -1); flipping one column turns it into a rotation.
    rot[..., 0] *= torch.linalg.det(rot).unsqueeze(-1)
    frame = torch.zeros(*batch_shape, 4, 4)
    frame[..., :3, :3] = rot
    frame[..., :3, 3] = torch.randn(*batch_shape, 3)
    frame[..., 3, 3] = 1.0
    return frame


frames = _random_rigid_4x4(1, B, L)
backbone_4x4 = _random_rigid_4x4(B, L)
# Only a missing dependency is reported as "not available". This assumes structure_fape_loss
# surfaces a missing OpenFold as ImportError -- TODO confirm. Any other error propagates.
try:
    fape_loss = structure_fape_loss(frames, backbone_4x4, mask)
except ImportError as exc:
    print(f"OpenFold not available: {exc}")
else:
    print(f"FAPE loss: {fape_loss.item():.4f}")
    assert torch.isfinite(fape_loss).all(), "FAPE loss is not finite"