# Building Struct2Seq from Scratch

This notebook walks through every step required to construct, train, evaluate, and sample from the Struct2Seq model using only the reusable modules provided in this repository. Rather than relying on the pre-written training script, we assemble the full pipeline manually to illustrate how each component fits together.

## 1. Environment Setup

We begin by importing core libraries, seeding the random number generators for reproducibility, and selecting the compute device.

In [None]:
import math
import os
from pathlib import Path

import numpy as np
import torch
from torch import nn
from torch.utils.data import Subset

from struct2seq import data as s2s_data
from struct2seq import struct2seq as s2s_model
from struct2seq import noam_opt
from experiments import utils as exp_utils

print(f'PyTorch version: {torch.__version__}')
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {DEVICE}')

def set_seed(seed: int = 7):
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

set_seed()

## 2. Data Preparation

`StructureDataset` expects a JSONL file in which each line holds the amino acid sequence and the backbone atom coordinates (N, CA, C, O) of a single protein chain. If you do not have the NeurIPS 2019 dataset locally, point `dataset_path` to your own JSONL file produced with `data/build_chain_dataset.py`.

In [None]:
# Path to your JSONL dataset. Replace this with a real file before running.
dataset_path = Path('data/SPIN2/train.jsonl')
assert dataset_path.exists(), ("Dataset not found. Please set 'dataset_path' to your JSONL file.")

full_dataset = s2s_data.StructureDataset(str(dataset_path))
print(f'Total chains available: {len(full_dataset)}')
print('Example entry keys:', full_dataset[0].keys())
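To make the expected input concrete, the snippet below writes and reads back a toy record in the layout we assume `StructureDataset` consumes: a `seq` string plus a `coords` dictionary with one `[x, y, z]` triple per residue for each backbone atom. The exact field names (including the `name` key) and the helper `is_well_formed` are assumptions for illustration; compare them with the keys printed above before relying on this sketch.

```python
import json

BACKBONE_ATOMS = ("N", "CA", "C", "O")

# Hypothetical three-residue chain; coordinates are placeholders, not real data.
seq = "MKV"
record = {
    "name": "toy_chain_A",  # assumed identifier field
    "seq": seq,
    "coords": {atom: [[0.0, 0.0, 0.0] for _ in seq] for atom in BACKBONE_ATOMS},
}

def is_well_formed(rec):
    """True if each backbone atom has exactly one [x, y, z] triple per residue."""
    n = len(rec["seq"])
    return all(
        len(rec["coords"].get(atom, [])) == n
        and all(len(xyz) == 3 for xyz in rec["coords"].get(atom, []))
        for atom in BACKBONE_ATOMS
    )

line = json.dumps(record)  # one line of a JSONL file
parsed = json.loads(line)
print(is_well_formed(parsed))  # True
```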

### Train/Validation Split

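For simplicity we hold out the last 10% of chains as a validation set. This index-based split is deterministic, but it does not reproduce the CATH-topology split of the original paper, so structurally similar chains can land on both sides and validation scores may be optimistic. Length-aware batching comes from `StructureLoader` rather than from the dataset itself: it groups chains of similar length so that little of each batch is padding.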

In [None]:
num_train = int(0.9 * len(full_dataset))
train_indices = np.arange(num_train)
valid_indices = np.arange(num_train, len(full_dataset))

train_dataset = Subset(full_dataset, train_indices)
valid_dataset = Subset(full_dataset, valid_indices)

batch_tokens = 2000  # approximate number of amino acids per batch
train_loader = s2s_data.StructureLoader(train_dataset, batch_tokens=batch_tokens, shuffle=True)
valid_loader = s2s_data.StructureLoader(valid_dataset, batch_tokens=batch_tokens, shuffle=False)

print(f'Train batches: {len(train_loader)}, Validation batches: {len(valid_loader)}')
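The idea behind length-aware batching can be sketched in a few lines: sort chains by length, then keep adding chains to the current batch until its padded size (number of chains times the longest chain) would exceed a token budget. This is an illustrative reimplementation of the idea, not `StructureLoader`'s actual code, and the function name `bucket_by_length` is hypothetical.

```python
def bucket_by_length(lengths, max_tokens):
    """Group indices so that len(batch) * longest_in_batch <= max_tokens.

    Indices are visited in order of increasing length, so each batch holds
    chains of similar size and padding stays small. A chain longer than
    max_tokens ends up alone in its own batch.
    """
    order = sorted(range(len(lengths)), key=lambda i: lengths[i])
    batches, current = [], []
    for idx in order:
        # Lengths are sorted, so the incoming chain is the longest in the batch.
        if current and lengths[idx] * (len(current) + 1) > max_tokens:
            batches.append(current)
            current = []
        current.append(idx)
    if current:
        batches.append(current)
    return batches

lengths = [120, 35, 300, 40, 110, 38]
print(bucket_by_length(lengths, max_tokens=250))  # [[1, 5, 3], [4, 0], [2]]
```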

## 3. Batch Featurization

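`experiments.utils.featurize` pads variable-length chains to the longest chain in the batch and returns the backbone coordinate tensor `X` (shape `[B, L_max, 4, 3]`, one slot per N, CA, C and O atom), the amino acid indices `S` (shape `[B, L_max]`), a `mask` that is 1 for residues with finite coordinates and 0 for padding or missing atoms, and the chain `lengths`.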

In [None]:
def featurized_batches(loader):
    """Yield padded, device-resident (X, S, mask, lengths) tuples for each batch."""
    for batch in loader:
        yield exp_utils.featurize(batch, device=DEVICE)

example_X, example_S, example_mask, example_lengths = next(featurized_batches(train_loader))
print('Coordinate tensor:', example_X.shape)
print('Sequence tensor:', example_S.shape)
print('Mask tensor:', example_mask.shape)
print('Lengths:', example_lengths)
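The sequence side of this padding can be mimicked with plain Python lists: each residue is mapped to an integer index, every chain is right-padded to the longest chain in the batch, and a mask marks the real positions. This sketch assumes the 20-letter alphabet `ACDEFGHIKLMNPQRSTVWY` in that order, and the helper `pad_sequences` is hypothetical; the real `featurize` additionally pads the coordinates and, we assume, zeroes the mask wherever coordinates are missing.

```python
ALPHABET = "ACDEFGHIKLMNPQRSTVWY"  # assumed residue ordering

def pad_sequences(seqs, pad_index=0):
    """Return (S, mask, lengths) as nested lists padded to the longest sequence."""
    lengths = [len(s) for s in seqs]
    max_len = max(lengths)
    S, mask = [], []
    for s in seqs:
        idx = [ALPHABET.index(a) for a in s]
        pad = max_len - len(s)
        S.append(idx + [pad_index] * pad)        # dummy index, ignored via the mask
        mask.append([1.0] * len(s) + [0.0] * pad)
    return S, mask, lengths

S, mask, lengths = pad_sequences(["MKV", "AC"])
print(S)        # [[10, 8, 17], [0, 1, 0]]
print(mask)     # [[1.0, 1.0, 1.0], [1.0, 1.0, 0.0]]
print(lengths)  # [3, 2]
```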

## 4. Model Construction

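We instantiate `Struct2Seq` by specifying the feature dimensions, the number of encoder and decoder layers, and the neighborhood size `k_neighbors`, i.e. how many nearest residues each position is connected to in the structure graph. `use_mpnn=False` selects attention-based layers rather than simple message-passing layers, and `augment_eps=0.0` disables Gaussian noise augmentation of the input coordinates.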

In [None]:
hidden_dim = 128
model = s2s_model.Struct2Seq(
    num_letters=20,
    node_features=hidden_dim,
    edge_features=hidden_dim,
    hidden_dim=hidden_dim,
    k_neighbors=30,
    protein_features='full',
    dropout=0.1,
    augment_eps=0.0,
    num_encoder_layers=3,
    num_decoder_layers=3,
    forward_attention_decoder=True,
    use_mpnn=False
).to(DEVICE)

print(model)
print(f'Total parameters: {sum(p.numel() for p in model.parameters()) / 1e6:.2f}M')
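To make `k_neighbors` concrete, the sketch below builds a k-nearest-neighbor list from CA coordinates by brute force. It assumes distances are measured between CA atoms and excludes each residue from its own neighbor list; the helper `knn_indices` is hypothetical, and the model's own featurization also computes richer edge features on top of the neighbor graph.

```python
import math

def knn_indices(ca_coords, k):
    """For each residue, return the indices of its k nearest other residues."""
    neighbors = []
    for i, xi in enumerate(ca_coords):
        dists = [(math.dist(xi, xj), j) for j, xj in enumerate(ca_coords) if j != i]
        dists.sort()  # ties would be broken by index since tuples compare element-wise
        neighbors.append([j for _, j in dists[:k]])
    return neighbors

# Toy CA positions on a line (arbitrary units), chosen so no distances tie
ca = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (3.0, 0.0, 0.0), (7.0, 0.0, 0.0)]
print(knn_indices(ca, k=2))  # [[1, 2], [0, 2], [1, 0], [2, 1]]
```

Brute force costs O(L²) distance evaluations per chain, which is acceptable for a few hundred residues; batched tensor code would compute the full distance matrix and take the top-k smallest entries instead.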

## 5. Loss Function and Optimizer

Struct2Seq outputs log-probabilities over amino acid types. We therefore use negative log-likelihood with optional label smoothing. The `NoamOpt` wrapper reproduces the learning-rate schedule from the original Transformer paper.
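The schedule from *Attention Is All You Need* sets the learning rate to `factor * d_model^(-0.5) * min(step^(-0.5), step * warmup^(-1.5))`: it grows linearly for `warmup` steps, peaks at `step == warmup`, and then decays with the inverse square root of the step. The function below (the name `noam_rate` is ours) evaluates that formula directly to show the shape of the schedule; it is not a replacement for `NoamOpt`.

```python
def noam_rate(step, d_model, factor=1.0, warmup=4000):
    """Learning rate of the Transformer (Noam) schedule at a 1-indexed step."""
    if step < 1:
        raise ValueError("step must be >= 1")
    return factor * d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)

# Rise during warmup, peak at step 4000, then inverse-square-root decay
for step in (1, 1000, 4000, 16000):
    print(step, f"{noam_rate(step, d_model=128):.2e}")
```

Quadrupling the step after the peak halves the rate, so with `warmup=4000` the rate at step 16000 is half the peak value.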

In [None]:
optimizer = noam_opt.get_std_opt(model.parameters(), hidden_dim, factor=1.0, warmup=4000)
label_smoothing = 0.1

def compute_loss(S_true, log_probs, mask):
    loss, loss_av = exp_utils.loss_smoothed(S_true, log_probs, mask, weight=label_smoothing)
    return loss, loss_av

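# Quick sanity check: with random weights the loss is expected to be near
# ln(20) ≈ 3.0, the value obtained for uniform predictions over 20 amino acids.
model.eval()  # disable dropout for a deterministic forward pass
with torch.no_grad():
    log_probs = model(example_X, example_S, example_lengths, example_mask)
    _, example_loss = compute_loss(example_S, log_probs, example_mask)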
print(f'Initial loss (random weights): {example_loss.item():.3f}')
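Label smoothing replaces the one-hot target with a distribution that keeps most of the mass on the true residue and spreads the remainder over all 20 classes. The sketch below uses the common `(1 - eps) * one_hot + eps / K` formulation; `exp_utils.loss_smoothed` may normalize its targets slightly differently, so treat this as an illustration of the idea rather than a re-implementation. It also shows why uniform predictions give exactly ln 20 regardless of the smoothing weight: the smoothed target always sums to one.

```python
import math

def smoothed_nll(log_probs, target, eps=0.1):
    """Cross-entropy between a label-smoothed target and predicted log-probabilities.

    log_probs: list of K log-probabilities for one position.
    target: index of the true class.
    """
    K = len(log_probs)
    loss = 0.0
    for k, lp in enumerate(log_probs):
        q = eps / K + (1.0 - eps if k == target else 0.0)  # smoothed target probability
        loss -= q * lp
    return loss

uniform = [math.log(1 / 20)] * 20
print(round(smoothed_nll(uniform, target=3), 4))  # 2.9957, i.e. ln(20)
```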

## 6. Training Loop

Below is a minimal training loop. It optimizes the label-smoothed loss and clips the gradient norm to 1.0; mixed precision is not enabled, so add it yourself (for example with `torch.autocast`) if you need it. To keep experiments lightweight, reduce `num_epochs` or stop after a handful of batches.

In [None]:
num_epochs = 20
log_every = 10

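def run_epoch(loader, model, optimizer=None):
    """Run one pass over `loader`, updating weights only if `optimizer` is given.

    Returns the token-weighted average NLL per residue (without label smoothing).
    """
    is_train = optimizer is not None
    model.train(is_train)
    total_nll = 0.0
    total_tokens = 0
    for step, batch in enumerate(loader, start=1):
        X, S, mask, lengths = exp_utils.featurize(batch, device=DEVICE)
        if is_train:
            optimizer.zero_grad()
        log_probs = model(X, S, lengths, mask)
        # `loss` is per residue; `loss_av` is its masked mean (a scalar)
        loss, loss_av = compute_loss(S, log_probs, mask)
        if is_train:
            loss_av.backward()  # backward() needs a scalar, so use the mean
            nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
        # Unsmoothed NLL of the native residues, summed over valid positions
        nll = -log_probs.detach().gather(-1, S.unsqueeze(-1)).squeeze(-1)
        total_nll += (nll * mask).sum().item()
        total_tokens += mask.sum().item()
        if is_train and step % log_every == 0:
            print(f'[train] step={step} smoothed loss={loss_av.item():.3f}')
    return total_nll / max(total_tokens, 1)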

for epoch in range(1, num_epochs + 1):
    train_loss = run_epoch(train_loader, model, optimizer)
    with torch.no_grad():
        valid_loss = run_epoch(valid_loader, model, optimizer=None)
    print(f'Epoch {epoch:02d}: train NLL={train_loss:.4f}, valid NLL={valid_loss:.4f}')
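`run_epoch` weights each batch by its number of valid residues, so the epoch figure is a per-residue average rather than a mean of batch means. This matters because length-aware batches contain very different numbers of residues. The numbers below are made up purely to show the effect, and `token_weighted_mean` is a hypothetical helper.

```python
def token_weighted_mean(batch_means, batch_tokens):
    """Average per-token losses across batches, weighting each batch by its token count."""
    total = sum(m * n for m, n in zip(batch_means, batch_tokens))
    return total / max(sum(batch_tokens), 1)

batch_means = [2.0, 3.0]     # hypothetical mean loss of each batch
batch_tokens = [1800, 200]   # hypothetical number of valid residues per batch
naive = sum(batch_means) / len(batch_means)
weighted = token_weighted_mean(batch_means, batch_tokens)
print(naive, weighted)  # 2.5 2.1
```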

## 7. Evaluation Metrics

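Besides negative log-likelihood, two common metrics are perplexity, the exponential of the per-residue NLL, and accuracy. The cell below computes masked accuracy. Because the model receives the native sequence as decoder input (teacher forcing), each prediction is conditioned on the true preceding residues, so this measures next-residue prediction rather than the recovery of fully sampled sequences.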

In [None]:
def evaluate_accuracy(loader, model):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in loader:
            X, S, mask, lengths = exp_utils.featurize(batch, device=DEVICE)
            log_probs = model(X, S, lengths, mask)
            preds = log_probs.argmax(dim=-1)
            correct += ((preds == S) * mask).sum().item()
            total += mask.sum().item()
    return correct / max(total, 1)

val_accuracy = evaluate_accuracy(valid_loader, model)
print(f'Validation accuracy: {val_accuracy:.3%}')
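Both metrics reduce to a few lines of arithmetic once per-position values are available. The helpers below are a hypothetical plain-Python sketch, independent of the model, that mirror how padding is excluded in the cell above; they assume NLL values use the natural logarithm, as `log_softmax` outputs do.

```python
import math

def perplexity(total_nll, num_tokens):
    """exp of the average negative log-likelihood per token (natural log)."""
    return math.exp(total_nll / num_tokens)

def masked_accuracy(preds, targets, mask):
    """Fraction of unmasked positions where the prediction matches the target."""
    correct = sum(m for p, t, m in zip(preds, targets, mask) if p == t)
    total = sum(mask)
    return correct / max(total, 1)

# Uniform guessing over 20 amino acids has perplexity 20
print(round(perplexity(total_nll=10 * math.log(20), num_tokens=10), 6))  # 20.0
print(masked_accuracy([1, 2, 3, 0], [1, 5, 3, 3], [1, 1, 1, 0]))          # 0.6666666666666666
```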

## 8. Autoregressive Sampling

After training, we can generate sequences conditioned on the backbone coordinates. The sampler predicts one residue at a time using the autoregressive decoder.
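At each decoding step the log-probabilities are turned into a categorical distribution from which one residue is drawn. A common way to apply a temperature is to divide the log-probabilities by it before renormalizing: values below 1 sharpen the distribution and values above 1 flatten it. The sketch below shows that single step in plain Python; we assume, but have not verified, that `model.sample` applies `temperature` the same way, and `sample_with_temperature` is a hypothetical helper.

```python
import math
import random

def sample_with_temperature(log_probs, temperature=1.0, rng=random):
    """Draw one class index from softmax(log_probs / temperature)."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [lp / temperature for lp in log_probs]
    top = max(scaled)  # subtract the max before exponentiating for numerical stability
    weights = [math.exp(s - top) for s in scaled]
    return rng.choices(range(len(weights)), weights=weights, k=1)[0]

rng = random.Random(0)
log_probs = [math.log(p) for p in (0.6, 0.3, 0.1)]
draws = [sample_with_temperature(log_probs, temperature=0.1, rng=rng) for _ in range(1000)]
print(draws.count(0) / len(draws))  # close to 1.0 at low temperature
```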

In [None]:
AMINO_ACIDS = np.array(list('ACDEFGHIKLMNPQRSTVWY'))  # assumed to match the index order used by featurize

def sample_sequences(loader, model, temperature=1.0):
    model.eval()
    sequences = []
    with torch.no_grad():
        for batch in loader:
            X, _, mask, lengths = exp_utils.featurize(batch, device=DEVICE)
            sampled = model.sample(X, lengths, mask, temperature=temperature)
            sampled = sampled.cpu().numpy()
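            # Slice by the true chain length: `mask` can also be 0 at interior
            # residues with missing coordinates, so summing it could cut sequences short.
            for seq_arr, length in zip(sampled, lengths):
                sequences.append(''.join(AMINO_ACIDS[seq_arr[:int(length)]]))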
    return sequences

generated_sequences = sample_sequences(valid_loader, model, temperature=1.0)
print('\n'.join(generated_sequences[:3]))

## 9. Checkpointing

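Finally, save the trained parameters to disk for later reuse. A state dict stores only the weights, so it must be loaded into a model constructed with the same hyperparameters; resuming training would also require the optimizer state and the learning-rate schedule's step count. The `experiments.utils` module already contains helper routines for restoring checkpoints if needed.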

In [None]:
checkpoint_path = Path('struct2seq_from_scratch.pt')
torch.save({'model_state_dict': model.state_dict()}, checkpoint_path)
print(f'Saved checkpoint to {checkpoint_path.resolve()}')

# Example of loading the checkpoint back
state = torch.load(checkpoint_path, map_location=DEVICE)
model.load_state_dict(state['model_state_dict'])
print('Checkpoint restored successfully.')
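Storing the constructor arguments next to the weights makes a checkpoint self-describing. The sketch below writes them to a JSON file and reads them back; the file name and the helpers `save_config` and `load_config` are hypothetical, and the dictionary simply mirrors the arguments used in Section 4.

```python
import json
import tempfile
from pathlib import Path

def save_config(config, path):
    """Write constructor arguments as pretty-printed JSON."""
    Path(path).write_text(json.dumps(config, indent=2, sort_keys=True))

def load_config(path):
    """Read constructor arguments back into a dict."""
    return json.loads(Path(path).read_text())

model_config = {
    "num_letters": 20, "node_features": 128, "edge_features": 128, "hidden_dim": 128,
    "k_neighbors": 30, "protein_features": "full", "dropout": 0.1, "augment_eps": 0.0,
    "num_encoder_layers": 3, "num_decoder_layers": 3,
    "forward_attention_decoder": True, "use_mpnn": False,
}

with tempfile.TemporaryDirectory() as tmp:
    cfg_path = Path(tmp) / "struct2seq_from_scratch.json"
    save_config(model_config, cfg_path)
    restored = load_config(cfg_path)

print(restored == model_config)  # True
```

The restored dictionary can then be unpacked into the constructor, e.g. `s2s_model.Struct2Seq(**restored)`, before calling `load_state_dict`.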