# 02_add_crosslinks.ipynb

Load the baseline model coordinates, load the crosslink restraints, and run a crosslink-guided refinement of the coordinates (toy Adam optimizer).

In [None]:
from pathlib import Path
import pandas as pd
import torch
from main import run_af3
from restraints import crosslink_loss

# Load sequences and baseline coords. run_af3 returns a mapping from chain ID
# to a coordinate tensor (presumably (n_residues, 3) CA positions — TODO confirm).
fasta = '../data/sequences/barnase_barstar.fasta'
coords = run_af3(fasta)

# Load crosslinks; the record layout must match what crosslink_loss expects.
df = pd.read_csv('../data/crosslinks/barnase_barstar.tsv', sep='\t')
crosslinks = df.to_records(index=False)

# Prepare optimizable coords. detach() first so each tensor is a fresh leaf:
# torch.optim rejects non-leaf tensors, and without detach() gradients could
# flow back into whatever graph produced the baseline coordinates.
coords_opt = {
    chain: xyz.detach().clone().requires_grad_(True)
    for chain, xyz in coords.items()
}
# NOTE(review): Adam's per-step update is approximately bounded by lr, so
# 100 steps at lr=1e-3 move each coordinate by at most ~0.1 (in run_af3's
# units). Consider a larger lr if the restraints are not being satisfied.
optimizer = torch.optim.Adam(list(coords_opt.values()), lr=1e-3)

# Run a short guided optimization and print loss
n_steps = 100
log_every = 20
for step in range(n_steps):
    optimizer.zero_grad()
    loss = crosslink_loss(coords_opt, crosslinks, weight=10.0)
    loss.backward()
    optimizer.step()
    # The printed loss is the value before this step's update; always report
    # the final step too, so the end state of the run is visible.
    if step % log_every == 0 or step == n_steps - 1:
        print(f'Step {step} loss {loss.item():.3f}')

# Save guided output as a minimal CA-only PDB (every residue written as ALA,
# residues numbered from 1 within each chain, atom serials continuous).
out_pdb = '../outputs/run_crosslink_guided.pdb'
# open(..., 'w') does not create missing parent directories.
Path(out_pdb).parent.mkdir(parents=True, exist_ok=True)
with open(out_pdb, 'w') as f:
    atom_id = 1
    for chain, xyz in coords_opt.items():
        # .numpy() only accepts CPU tensors; .cpu() is a no-op when already on CPU.
        # NOTE(review): chain IDs must be single characters to fit PDB column 22.
        for i, (x, y, z) in enumerate(xyz.detach().cpu().numpy(), start=1):
            f.write(f"ATOM  {atom_id:5d}  CA  ALA {chain}{i:4d}    {x:8.3f}{y:8.3f}{z:8.3f}\n")
            atom_id += 1
    # Terminate the coordinate section so PDB readers see a complete file.
    f.write("END\n")

print('Saved crosslink-guided PDB to', out_pdb)