# PottsMPNN sanity checks

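This notebook runs quick smoke tests for the Potts/structure losses, plus optional checks for the ESM-based MSA loss (requires the `esm` package) and the FAPE loss (requires OpenFold). The last cell runs the full model on a single training batch and needs a local copy of the PDB training data (set `data_path` in that cell).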

In [None]:
import torch
from training.struct_potts_losses import (
    msa_similarity_loss,
    msa_similarity_loss_esm,
    structure_consistency_loss,
    structure_fape_loss,
)

# Shared synthetic inputs (also reused by the optional cells below).
# Seed first: the random tensors in this cell are drawn from the global RNG
# in a fixed order.
torch.manual_seed(0)
B, L, V, M = 2, 8, 22, 4  # batch, sequence length, vocabulary size, MSA depth

log_probs = torch.randn(B, L, V).log_softmax(dim=-1)
msa_tokens = torch.randint(low=0, high=V, size=(B, M, L))
msa_mask, seq_mask = torch.ones(B, M, L), torch.ones(B, L)

baseline_loss = msa_similarity_loss(
    log_probs, msa_tokens, msa_mask, seq_mask, margin=0.1
)
print(f"Baseline MSA loss: {baseline_loss.item():.4f}")

# Structure check: random (B, L, 4, 3) coordinates versus a copy perturbed by
# Gaussian noise scaled by 0.1.
positions = torch.randn(B, L, 4, 3)
X = positions + torch.randn_like(positions) * 0.1
mask = torch.ones(B, L)
ca_loss = structure_consistency_loss(positions, X, mask)
print(f"CA loss: {ca_loss.item():.4f}")

In [None]:
# Optional: ESM-based MSA similarity loss.
# Loading ESM (the import plus the pretrained weights) is best-effort and is
# reported as "not available" on failure. The loss call is guarded separately,
# so a genuine bug in msa_similarity_loss_esm is reported as a failure instead
# of being mistaken for a missing dependency.
try:
    import esm

    esm_model, alphabet = esm.pretrained.esm2_t6_8M_UR50D()
except Exception as exc:
    print(f"ESM not available: {exc}")
else:
    esm_model.eval()
    # 22 symbols, matching V above (20 amino acids, X, gap).
    # Assumes token index i in msa_tokens/log_probs corresponds to the i-th
    # character here -- TODO confirm against the training alphabet.
    token_map = torch.tensor([alphabet.get_idx(aa) for aa in 'ACDEFGHIKLMNPQRSTVWYX-'])
    try:
        esm_loss = msa_similarity_loss_esm(
            log_probs,
            msa_tokens,
            msa_mask,
            seq_mask,
            esm_model,
            token_map,
            margin=0.1,
        )
        print(f"ESM MSA loss: {esm_loss.item():.4f}")
    except Exception as exc:
        print(f"ESM MSA loss FAILED ({type(exc).__name__}): {exc}")

In [None]:
# Optional: FAPE loss (requires OpenFold)


def random_rigid_4x4(*batch_shape):
    """Return random rigid-body transforms as homogeneous 4x4 matrices.

    Each matrix is ``[[R, t], [0, 0, 0, 1]]``, where ``R`` is a proper rotation
    (orthonormal, det = +1) and ``t`` is a standard-normal translation.

    Args:
        *batch_shape: Leading dimensions of the result (may be empty).

    Returns:
        Float tensor of shape ``(*batch_shape, 4, 4)``.
    """
    # QR of a Gaussian matrix yields an orthonormal Q whose determinant is +/-1.
    rot, _ = torch.linalg.qr(torch.randn(*batch_shape, 3, 3))
    # Flip the first column wherever det = -1, so every R is a rotation rather
    # than a reflection.
    rot[..., :, 0] *= torch.linalg.det(rot).sign().unsqueeze(-1)
    transform = torch.zeros(*batch_shape, 4, 4)
    transform[..., :3, :3] = rot
    transform[..., :3, 3] = torch.randn(*batch_shape, 3)
    transform[..., 3, 3] = 1.0
    return transform


# Plain torch.randn 4x4 matrices are not rigid transforms. FAPE measures
# errors in local frames, so feed it valid frames with the same shapes as
# before: predicted frames (1, B, L, 4, 4) and reference frames (B, L, 4, 4).
try:
    frames = random_rigid_4x4(1, B, L)
    backbone_4x4 = random_rigid_4x4(B, L)
    fape_loss = structure_fape_loss(frames, backbone_4x4, mask)
    print(f"FAPE loss: {fape_loss.item():.4f}")
except ImportError as exc:
    # Presumably structure_fape_loss imports OpenFold lazily; an ImportError is
    # treated as "not installed".
    print(f"OpenFold not available: {exc}")
except Exception as exc:
    print(f"FAPE loss FAILED ({type(exc).__name__}): {exc}")

In [None]:
# Load a single batch, run the model, and evaluate losses.
# Missing project modules or training data => the test is reported as
# "skipped". Any other exception is reported as "FAILED", so a genuine bug in
# the data pipeline, the model, or the losses is not mistaken for a missing
# dependency.
try:
    import os

    import torch
    from training.utils import (
        worker_init_fn,
        loader_pdb,
        build_training_clusters,
        PDB_dataset,
        StructureDataset,
        StructureLoader,
    )
    from training.model_utils_struct import ProteinMPNN, featurize

    data_path = 'my_path/pdb_2021aug02'  # update for your local data
    # Fail fast (-> "skipped") while the placeholder path has not been updated.
    if not os.path.isdir(data_path):
        raise FileNotFoundError(f'training data directory not found: {data_path!r}')
    params = {
        'LIST': f'{data_path}/list.csv',
        'VAL': f'{data_path}/valid_clusters.txt',
        'TEST': f'{data_path}/test_clusters.txt',
        'DIR': f'{data_path}',
        'DATCUT': '2030-Jan-01',
        'RESCUT': 3.5,
        'HOMO': 0.70,
    }

    train, _, _ = build_training_clusters(params, debug=True)
    train_set = PDB_dataset(list(train.keys()), loader_pdb, train, params)
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=1, shuffle=True, worker_init_fn=worker_init_fn
    )

    # NOTE(review): upstream ProteinMPNN converts DataLoader output with
    # get_pdbs() before building StructureDataset. It also treats
    # StructureLoader's batch_size as a residue budget, and there batch_size=1
    # yields empty batches. Confirm both against training.utils if this cell
    # fails here.
    pdb_dict = next(iter(train_loader))
    dataset = StructureDataset(pdb_dict, truncate=None, max_length=1000)
    loader = StructureLoader(dataset, batch_size=1)
    batch = next(iter(loader))

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model = ProteinMPNN(
        node_features=128,
        edge_features=128,
        hidden_dim=128,
        num_encoder_layers=3,
        num_decoder_layers=3,
        k_neighbors=48,
        dropout=0.1,
        augment_eps=0.2,
        use_potts=True,
        struct_predict=True,
        struct_use_decoder_one_hot=True,
    ).to(device)
    model.eval()

    (
        X,
        S,
        _,
        mask,
        lengths,
        chain_M,
        residue_idx,
        mask_self,
        chain_encoding_all,
        _,
        backbone_4x4,
        _,
    ) = featurize(
        batch,
        device,
        augment_type='atomic',
        augment_eps=0.2,
        replicate=1,
        epoch=0,
        openfold_backbone=False,
    )

    with torch.no_grad():
        log_probs, etab_geom, e_idx, frames, positions, logits = model(
            X,
            S,
            mask,
            chain_M,
            residue_idx,
            chain_encoding_all,
            return_logits=True,
        )

    # Random 2-row MSA sized from the model output. Batch size, sequence
    # length and vocabulary come from log_probs, and the tensors are created
    # on log_probs' device (CPU tensors would not mix with CUDA log_probs).
    n_msa_rows = 2
    msa_shape = (log_probs.shape[0], n_msa_rows, log_probs.shape[1])
    dummy_msa = torch.randint(0, log_probs.shape[-1], msa_shape, device=log_probs.device)
    dummy_msa_mask = torch.ones(msa_shape, device=log_probs.device)

    loss_msa = msa_similarity_loss(log_probs, dummy_msa, dummy_msa_mask, mask)
    loss_struct = structure_consistency_loss(positions, X, mask)
    print(f'Losses -> MSA: {loss_msa.item():.4f}, Struct: {loss_struct.item():.4f}')
except (ImportError, FileNotFoundError) as exc:
    print(f'Full model smoke test skipped: {exc}')
except Exception as exc:
    print(f'Full model smoke test FAILED ({type(exc).__name__}): {exc}')
