# Struct2Seq with PyTorch Geometric Tutorial

This notebook demonstrates the refactored Struct2Seq model using PyTorch Geometric.

**Struct2Seq** is a graph neural network model for protein sequence design. Given a protein backbone structure (3D coordinates), it generates amino acid sequences that are likely to fold into that structure.

## What's New in This Refactored Version?

1. **PyTorch Geometric Integration**: Uses PyG's efficient graph operations and batching
2. **Cleaner Architecture**: Modular design with separate files for data, features, layers, and models
3. **Better Documentation**: Clear docstrings and comments throughout
4. **Modern Practices**: Uses PyG's `MessagePassing` API for custom GNN layers

## Table of Contents

1. [Setup and Imports](#setup)
2. [Understanding the Data Format](#data-format)
3. [Creating Protein Graphs](#protein-graphs)
4. [Model Architecture](#model-architecture)
5. [Training the Model](#training)
6. [Sampling Sequences](#sampling)
7. [Evaluation](#evaluation)

## 1. Setup and Imports <a name="setup"></a>

In [None]:
import torch
import torch.nn.functional as F
from torch_geometric.loader import DataLoader
import numpy as np
import json
import matplotlib.pyplot as plt

# Import our refactored modules
from struct2seq_pyg import (
    ProteinGraphDataset,
    create_protein_graph,
    Struct2SeqPyG,
    ProteinFeaturizer
)

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

## 2. Understanding the Data Format <a name="data-format"></a>

Protein structures are stored in JSONL format (one JSON object per line), where each record has the following structure:

```json
{
    "name": "protein_id",
    "seq": "ACDEFGH...",
    "coords": {
        "N": [[x, y, z], ...],
        "CA": [[x, y, z], ...],
        "C": [[x, y, z], ...],
        "O": [[x, y, z], ...]
    }
}
```

Each protein has:
- `name`: Unique identifier
- `seq`: Amino acid sequence (single-letter codes)
- `coords`: 3D coordinates for backbone atoms (N, CA, C, O)
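
A minimal sketch of reading this format with only `json` and NumPy, assuming one record per line exactly as shown above. `load_jsonl_proteins` is a hypothetical helper, not part of `struct2seq_pyg`. In this package, `ProteinGraphDataset` presumably handles loading for you.

```python
import io
import json

import numpy as np

BACKBONE_ATOMS = ("N", "CA", "C", "O")


def load_jsonl_proteins(handle):
    """Yield (name, seq, coords) from a JSONL stream (hypothetical helper)."""
    for line in handle:
        line = line.strip()
        if not line:
            continue  # skip blank lines
        record = json.loads(line)
        # Stack each backbone atom's coordinates into an (L, 3) float array
        coords = {atom: np.asarray(record["coords"][atom], dtype=np.float32)
                  for atom in BACKBONE_ATOMS}
        # Every atom list should have exactly one entry per residue
        for atom, xyz in coords.items():
            if xyz.shape != (len(record["seq"]), 3):
                raise ValueError(f"{record['name']}: bad shape for {atom}: {xyz.shape}")
        yield record["name"], record["seq"], coords


# Tiny in-memory example: a 2-residue protein
example = json.dumps({
    "name": "toy",
    "seq": "AG",
    "coords": {atom: [[0.0, 0.0, 0.0], [3.8, 0.0, 0.0]] for atom in BACKBONE_ATOMS},
})
proteins = list(load_jsonl_proteins(io.StringIO(example + "\n")))
name, seq, coords = proteins[0]
print(name, seq, coords["CA"].shape)  # toy AG (2, 3)
```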

In [None]:
# Example: Creating a synthetic protein for demonstration
# In practice, you would load real data from a JSONL file

def create_example_protein():
    """Create a small synthetic protein for demonstration."""
    length = 10
    
    # Random, non-physical backbone coordinates (illustration only)
    coords = {
        'N': np.random.randn(length, 3),
        'CA': np.random.randn(length, 3),
        'C': np.random.randn(length, 3),
        'O': np.random.randn(length, 3),
    }
    
    # Random sequence
    AA_ALPHABET = 'ACDEFGHIKLMNPQRSTVWY'
    sequence = ''.join(np.random.choice(list(AA_ALPHABET), length))
    
    return coords, sequence

coords, sequence = create_example_protein()
print(f"Example sequence: {sequence}")
print(f"Sequence length: {len(sequence)}")
print(f"CA coordinates shape: {coords['CA'].shape}")

## 3. Creating Protein Graphs <a name="protein-graphs"></a>

We convert protein structures into graphs where:
- **Nodes** represent residues (amino acids)
- **Edges** connect k-nearest neighbor residues in 3D space
- **Node features** encode backbone geometry (dihedral angles)
- **Edge features** encode pairwise relationships (distance, orientation)
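
The k-NN connectivity can be sketched in NumPy as follows. This assumes neighbors are chosen by Cα-Cα distance, as in Ingraham et al.; `knn_edge_index` is a hypothetical name, and `create_protein_graph` may differ in details such as tie-breaking. The output follows PyG's `edge_index` convention (row 0 = source, row 1 = target).

```python
import numpy as np


def knn_edge_index(ca, k):
    """Build a k-nearest-neighbor edge_index from CA coordinates.

    ca: (L, 3) array. Returns a (2, L*k) integer array where
    row 0 holds the neighbor j (source) and row 1 the residue i (target).
    """
    num_res = ca.shape[0]
    k = min(k, num_res - 1)  # cannot have more neighbors than other residues
    # Pairwise Euclidean distances, shape (L, L)
    diff = ca[:, None, :] - ca[None, :, :]
    dist = np.sqrt((diff ** 2).sum(-1))
    np.fill_diagonal(dist, np.inf)  # exclude self-loops
    # Indices of the k closest residues for every residue i
    nbrs = np.argsort(dist, axis=1)[:, :k]       # (L, k)
    targets = np.repeat(np.arange(num_res), k)   # residue i, repeated k times
    sources = nbrs.reshape(-1)                   # neighbor j for each (i, j) pair
    return np.stack([sources, targets])


# Four residues on a straight line, 3.8 Å apart
ca = np.array([[0.0, 0, 0], [3.8, 0, 0], [7.6, 0, 0], [11.4, 0, 0]])
edge_index = knn_edge_index(ca, k=2)
print(edge_index.shape)  # (2, 8)
```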

In [None]:
# Create a protein graph
featurizer = ProteinFeaturizer()
data = create_protein_graph(
    coords=coords,
    sequence=sequence,
    k_neighbors=5,  # Small k for this 10-residue toy protein (Ingraham et al. use 30)
    featurizer=featurizer
)

print("\nProtein Graph:")
print(f"  Number of nodes: {data.num_nodes}")
print(f"  Number of edges: {data.edge_index.size(1)}")
print(f"  Node features shape: {data.x.shape}")
print(f"  Edge features shape: {data.edge_attr.shape}")
print(f"  Sequence shape: {data.seq.shape}")
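
The dihedral node features can be sketched as below: φ, ψ and ω per residue, each encoded as sine and cosine, which gives 6 columns and matches `node_feature_dim=6` in Section 4. The angle convention (ω_i taken between residues i and i+1), the zero padding at the chain ends, and the sin-then-cos column order are assumptions; `ProteinFeaturizer` may order or pad them differently.

```python
import numpy as np


def dihedral(p0, p1, p2, p3):
    """Signed dihedral angle (radians) for batches of 4 points, each (M, 3)."""
    b0 = p0 - p1
    b1 = p2 - p1
    b2 = p3 - p2
    b1 = b1 / np.linalg.norm(b1, axis=-1, keepdims=True)
    # Project b0 and b2 onto the plane perpendicular to the central bond b1
    v = b0 - (b0 * b1).sum(-1, keepdims=True) * b1
    w = b2 - (b2 * b1).sum(-1, keepdims=True) * b1
    x = (v * w).sum(-1)
    y = (np.cross(b1, v) * w).sum(-1)
    return np.arctan2(y, x)


def backbone_dihedral_features(n, ca, c):
    """Return (L, 6) features: sin and cos of phi, psi, omega per residue.

    Angles that are undefined at the chain ends are set to 0.
    """
    num_res = ca.shape[0]
    phi = np.zeros(num_res)
    psi = np.zeros(num_res)
    omega = np.zeros(num_res)
    # phi_i uses C_{i-1}, N_i, CA_i, C_i (defined for i >= 1)
    phi[1:] = dihedral(c[:-1], n[1:], ca[1:], c[1:])
    # psi_i uses N_i, CA_i, C_i, N_{i+1} (defined for i <= L-2)
    psi[:-1] = dihedral(n[:-1], ca[:-1], c[:-1], n[1:])
    # omega_i uses CA_i, C_i, N_{i+1}, CA_{i+1}
    omega[:-1] = dihedral(ca[:-1], c[:-1], n[1:], ca[1:])
    angles = np.stack([phi, psi, omega], axis=-1)                     # (L, 3)
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=-1)  # (L, 6)


rng = np.random.default_rng(0)
n, ca, c = (rng.normal(size=(5, 3)) for _ in range(3))
feats = backbone_dihedral_features(n, ca, c)
print(feats.shape)  # (5, 6)
```

Encoding each angle as a (sin, cos) pair avoids the discontinuity at ±180°, where two nearly identical conformations would otherwise get very different raw angle values.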

In [None]:
# Visualize the graph structure
def visualize_protein_graph(data, max_nodes=50):
    """Visualize the graph connectivity as an adjacency matrix."""
    num_nodes = min(data.num_nodes, max_nodes)
    
    # Create adjacency matrix
    adj = torch.zeros(num_nodes, num_nodes)
    edge_index = data.edge_index[:, data.edge_index[0] < num_nodes]
    edge_index = edge_index[:, edge_index[1] < num_nodes]
    adj[edge_index[0], edge_index[1]] = 1
    
    plt.figure(figsize=(8, 8))
    plt.imshow(adj.numpy(), cmap='Blues', interpolation='nearest')
    plt.colorbar(label='Edge')
    plt.xlabel('Residue Index')
    plt.ylabel('Residue Index')
    plt.title('k-NN Graph Structure')
    plt.tight_layout()
    plt.show()

visualize_protein_graph(data)

## 4. Model Architecture <a name="model-architecture"></a>

The Struct2Seq model has an encoder-decoder architecture:

### Encoder (Structure Processing)
- Takes protein backbone structure as input
- Applies graph attention layers without masking, so each residue attends to all of its k nearest neighbors
- Learns structural representations
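
At its core, attention over k-NN neighbors is a softmax restricted to each residue's neighbor list. This single-head NumPy version is illustrative only: the real layers are multi-head and also feed edge features into the keys and values, which this sketch omits. `neighbor_attention` is a hypothetical name.

```python
import numpy as np


def neighbor_attention(h, nbr_idx, w_q, w_k, w_v):
    """Single-head attention where residue i attends only to its k neighbors.

    h: (L, d) node embeddings; nbr_idx: (L, k) neighbor indices;
    w_q, w_k, w_v: (d, d) projections.
    Returns (output (L, d), attention weights (L, k)).
    """
    q = h @ w_q                       # (L, d): one query per residue
    keys = (h @ w_k)[nbr_idx]         # (L, k, d): keys of each residue's neighbors
    vals = (h @ w_v)[nbr_idx]         # (L, k, d): values of each residue's neighbors
    d = h.shape[-1]
    logits = np.einsum("ld,lkd->lk", q, keys) / np.sqrt(d)  # (L, k)
    logits -= logits.max(axis=-1, keepdims=True)            # numerical stability
    attn = np.exp(logits)
    attn /= attn.sum(axis=-1, keepdims=True)                # softmax over neighbors only
    return np.einsum("lk,lkd->ld", attn, vals), attn


rng = np.random.default_rng(0)
num_res, dim, num_nbrs = 6, 8, 3
h = rng.normal(size=(num_res, dim))
nbr_idx = rng.integers(0, num_res, size=(num_res, num_nbrs))
w_q, w_k, w_v = (rng.normal(size=(dim, dim)) for _ in range(3))
out, attn = neighbor_attention(h, nbr_idx, w_q, w_k, w_v)
print(out.shape, attn.shape)  # (6, 8) (6, 3)
```

Restricting attention to k neighbors keeps the cost at O(L·k) rather than O(L²) for a protein of length L.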

### Decoder (Sequence Generation)
- Generates amino acid sequence autoregressively
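- Uses masked graph attention: sequence information comes only from residues earlier in the decoding order, while structural encodings are visible for all neighbors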
- Conditioned on structure encoding
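
The decoder's masking rule can be sketched on an `edge_index`: an edge from residue j to residue i may carry j's sequence embedding only if j precedes i. This assumes left-to-right decoding in chain order, as in the original Struct2Seq; both function names are hypothetical.

```python
import numpy as np


def decoder_edge_masks(edge_index):
    """Split k-NN edges for an autoregressive decoder (sketch).

    edge_index: (2, E) with row 0 = source j, row 1 = target i.
    `backward` marks edges from earlier residues (j < i), whose sequence
    embeddings may be used; `forward` marks the rest (j >= i), for which
    only structural encodings should be visible.
    """
    src, dst = edge_index
    backward = src < dst
    return backward, ~backward


def mask_sequence_messages(seq_emb, edge_index):
    """Gather each edge's source sequence embedding, zeroing non-causal ones."""
    backward, _ = decoder_edge_masks(edge_index)
    msgs = seq_emb[edge_index[0]]      # (E, d): embedding of residue j
    return msgs * backward[:, None]    # zero out edges with j >= i


# 3 residues, fully connected, sequence embeddings of width 2
edge_index = np.array([[1, 2, 0, 2, 0, 1],
                       [0, 0, 1, 1, 2, 2]])
seq_emb = np.arange(6, dtype=float).reshape(3, 2)
msgs = mask_sequence_messages(seq_emb, edge_index)
print(msgs)
```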

### Features
- Multi-head attention over k-NN neighbors
- Rich edge features (positional encodings, RBF distances, orientations)
- Layer normalization and residual connections
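
The first two edge-feature groups can be sketched as below. The constants (16 Gaussian RBFs with centers spanning 0 to 20 Å, and a 16-dimensional sinusoidal encoding of the sequence offset j - i) follow Ingraham et al.'s reference implementation; whether this refactor uses the same values is an assumption. The 7 orientation features are omitted, so the result covers 32 of the 39 edge dimensions.

```python
import numpy as np


def rbf_encode(dist, d_min=0.0, d_max=20.0, num_rbf=16):
    """Expand distances (any shape) into `num_rbf` Gaussian radial basis features."""
    mu = np.linspace(d_min, d_max, num_rbf)   # RBF centers in Å
    sigma = (d_max - d_min) / num_rbf         # shared width
    return np.exp(-(((dist[..., None] - mu) / sigma) ** 2))


def relative_position_encode(src, dst, num_embeddings=16):
    """Sinusoidal encoding of the sequence offset (j - i) for each edge."""
    offset = (src - dst).astype(float)[:, None]   # (E, 1)
    freq = np.exp(np.arange(0, num_embeddings, 2) * -(np.log(10000.0) / num_embeddings))
    angles = offset * freq                        # (E, num_embeddings / 2)
    return np.concatenate([np.cos(angles), np.sin(angles)], axis=-1)


dist = np.array([3.8, 6.5, 12.0])   # CA-CA distances of three edges
src = np.array([1, 3, 10])          # neighbor j
dst = np.array([0, 0, 0])           # residue i
edge_feats = np.concatenate([relative_position_encode(src, dst), rbf_encode(dist)], axis=-1)
print(edge_feats.shape)  # (3, 32)
```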

In [None]:
# Initialize the model
model = Struct2SeqPyG(
    node_feature_dim=6,      # Dihedral angles (phi, psi, omega) as sin/cos
    edge_feature_dim=39,     # Positional encodings (16) + RBF (16) + Orientations (7)
    hidden_dim=128,          # Hidden dimension
    num_encoder_layers=3,    # Number of encoder layers
    num_decoder_layers=3,    # Number of decoder layers
    num_letters=20,          # 20 amino acids
    num_heads=4,             # Attention heads
    dropout=0.1,             # Dropout probability
    use_mpnn=False           # Use GAT instead of MPNN
).to(device)

# Count parameters
num_params = sum(p.numel() for p in model.parameters())
print(f"\nModel has {num_params:,} parameters")

# Show model structure
print("\nModel Architecture:")
print(model)

## 5. Training the Model <a name="training"></a>

Training uses teacher forcing:
1. Feed the structure and ground-truth sequence to the model
2. Model predicts amino acids at each position
3. Compute the cross-entropy (negative log-likelihood) loss against the ground-truth residues
4. Backpropagate and update weights
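
`train_epoch` calls `F.nll_loss` on the model's log-probabilities, which is the same quantity as cross-entropy on the underlying logits. The NumPy sketch below checks that identity on random numbers; `log_softmax` and `nll_loss` here are illustrative re-implementations, not the PyTorch functions.

```python
import numpy as np


def log_softmax(logits):
    """Numerically stable log-softmax over the last axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


def nll_loss(log_probs, targets):
    """Mean negative log-likelihood of the true classes (mirrors F.nll_loss's default)."""
    return -log_probs[np.arange(len(targets)), targets].mean()


rng = np.random.default_rng(0)
logits = rng.normal(size=(5, 20))        # 5 residues, 20 amino-acid classes
targets = np.array([0, 3, 7, 19, 12])    # ground-truth residue indices
loss = nll_loss(log_softmax(logits), targets)

# Same quantity computed directly as cross-entropy: -log(softmax(logits)[true])
probs = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)
ce = -np.log(probs[np.arange(5), targets]).mean()
print(np.isclose(loss, ce))  # True
```

A model that guesses uniformly over 20 amino acids has a loss of ln 20 ≈ 3.00, a useful reference point when reading the training curves.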

In [None]:
# Training function
def train_epoch(model, loader, optimizer, device):
    """Train for one epoch."""
    model.train()
    total_loss = 0
    num_batches = 0
    
    for batch in loader:
        batch = batch.to(device)
        
        # Forward pass
        log_probs = model(batch)
        
        # Compute loss
        loss = F.nll_loss(log_probs, batch.seq)
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        num_batches += 1
    
    return total_loss / num_batches


# Evaluation function
@torch.no_grad()
def evaluate(model, loader, device):
    """Evaluate the model."""
    model.eval()
    total_loss = 0
    total_correct = 0
    total_samples = 0
    
    for batch in loader:
        batch = batch.to(device)
        
        # Forward pass
        log_probs = model(batch)
        
        # Compute metrics
        loss = F.nll_loss(log_probs, batch.seq)
        pred = log_probs.argmax(dim=-1)
        correct = (pred == batch.seq).sum().item()
        
        total_loss += loss.item() * batch.num_nodes
        total_correct += correct
        total_samples += batch.num_nodes
    
    avg_loss = total_loss / total_samples
    accuracy = total_correct / total_samples
    perplexity = np.exp(avg_loss)
    
    return avg_loss, accuracy, perplexity

In [None]:
# Example training loop (with synthetic data)
# In practice, you would use a real dataset

def create_synthetic_dataset(num_proteins=100, length_range=(20, 100)):
    """Create a synthetic dataset for demonstration."""
    data_list = []
    AA_ALPHABET = 'ACDEFGHIKLMNPQRSTVWY'
    
    for i in range(num_proteins):
        length = np.random.randint(*length_range)
        
        # Create random coordinates
        coords = {
            'N': np.random.randn(length, 3),
            'CA': np.random.randn(length, 3),
            'C': np.random.randn(length, 3),
            'O': np.random.randn(length, 3),
        }
        
        # Random sequence
        sequence = ''.join(np.random.choice(list(AA_ALPHABET), length))
        
        # Create graph
        data = create_protein_graph(coords, sequence, k_neighbors=10)
        data_list.append(data)
    
    return data_list

# Create synthetic datasets
train_data = create_synthetic_dataset(100)
val_data = create_synthetic_dataset(20)

# Create data loaders
train_loader = DataLoader(train_data, batch_size=4, shuffle=True)
val_loader = DataLoader(val_data, batch_size=4, shuffle=False)

print(f"Training set: {len(train_data)} proteins")
print(f"Validation set: {len(val_data)} proteins")
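
PyG's `DataLoader` merges the graphs of a mini-batch into one large disconnected graph, which is why `batch.num_nodes` in `evaluate` counts residues across all proteins in the batch. A simplified NumPy sketch of that collation follows; `collate_graphs` is a hypothetical name, and PyG's real `Batch` also handles arbitrary extra attributes.

```python
import numpy as np


def collate_graphs(graphs):
    """Merge graphs into one disjoint graph, PyG-style (simplified sketch).

    graphs: list of (x, edge_index) with x (N_g, F) and edge_index (2, E_g).
    Returns stacked x, offset edge_index, and a `batch` vector mapping
    every node to the index of the graph it came from.
    """
    xs, edges, batch = [], [], []
    offset = 0
    for g, (x, edge_index) in enumerate(graphs):
        xs.append(x)
        edges.append(edge_index + offset)   # shift node ids past earlier graphs
        batch.append(np.full(x.shape[0], g))
        offset += x.shape[0]
    return np.concatenate(xs), np.concatenate(edges, axis=1), np.concatenate(batch)


g1 = (np.zeros((3, 6)), np.array([[1, 2], [0, 0]]))
g2 = (np.ones((2, 6)), np.array([[1], [0]]))
x, edge_index, batch = collate_graphs([g1, g2])
print(x.shape, edge_index.tolist(), batch.tolist())
# (5, 6) [[1, 2, 4], [0, 0, 3]] [0, 0, 0, 1, 1]
```

Because no edge crosses between proteins, message passing on the merged graph is equivalent to processing each protein separately, with no padding required.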

In [None]:
# Training loop
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
num_epochs = 5

train_losses = []
val_losses = []
val_accuracies = []

for epoch in range(num_epochs):
    # Train
    train_loss = train_epoch(model, train_loader, optimizer, device)
    
    # Validate
    val_loss, val_acc, val_perplexity = evaluate(model, val_loader, device)
    
    # Record metrics
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    val_accuracies.append(val_acc)
    
    print(f"Epoch {epoch+1}/{num_epochs}:")
    print(f"  Train Loss: {train_loss:.4f}")
    print(f"  Val Loss: {val_loss:.4f}")
    print(f"  Val Accuracy: {val_acc:.4f}")
    print(f"  Val Perplexity: {val_perplexity:.4f}")

In [None]:
# Plot training curves
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

# Loss curves
ax1.plot(train_losses, label='Train Loss')
ax1.plot(val_losses, label='Val Loss')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.set_title('Training and Validation Loss')
ax1.legend()
ax1.grid(True)

# Accuracy curve
ax2.plot(val_accuracies, label='Val Accuracy', color='green')
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Accuracy')
ax2.set_title('Validation Accuracy')
ax2.legend()
ax2.grid(True)

plt.tight_layout()
plt.show()

## 6. Sampling Sequences <a name="sampling"></a>

Once trained, the model can generate novel sequences for a given structure using autoregressive sampling:
1. Start with an empty sequence
2. At each position, predict an amino acid distribution
3. Sample from the distribution
4. Append the sampled amino acid and move to the next position
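
The four steps can be sketched as a plain loop. Here `logits_fn` is a hypothetical stand-in for the decoder (in the real model, each step conditions on the structure encoding and on the residues chosen so far), and `toy_logits` is a made-up scorer used only to make the example run.

```python
import numpy as np

AA_ALPHABET = "ACDEFGHIKLMNPQRSTVWY"


def autoregressive_sample(logits_fn, length, rng):
    """Decode a sequence one position at a time, left to right.

    logits_fn(prefix, i) -> (20,) logits for position i given the residue
    indices already chosen in `prefix` (hypothetical decoder stand-in).
    """
    seq = []                                       # 1. start with an empty sequence
    for i in range(length):
        logits = logits_fn(seq, i)                 # 2. predict a distribution for position i
        probs = np.exp(logits - logits.max())
        probs /= probs.sum()
        aa = int(rng.choice(len(probs), p=probs))  # 3. sample from it
        seq.append(aa)                             # 4. append and continue
    return seq


def toy_logits(prefix, i):
    """Hypothetical scorer that strongly prefers repeating the previous residue."""
    logits = np.zeros(len(AA_ALPHABET))
    if prefix:
        logits[prefix[-1]] = 50.0
    return logits


rng = np.random.default_rng(0)
seq = autoregressive_sample(toy_logits, length=8, rng=rng)
print("".join(AA_ALPHABET[a] for a in seq))
```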

Temperature controls randomness by sharpening or flattening the predicted amino acid distribution before sampling:
- **Low temperature** (0.1): More deterministic, picks most likely amino acids
- **High temperature** (1.0+): More random, explores diverse sequences
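
A common way to apply temperature is to divide the logits by T before the softmax, sketched below. Whether `model.sample` implements it exactly this way is an assumption; `temperature_softmax` is a hypothetical helper.

```python
import numpy as np


def temperature_softmax(logits, temperature):
    """Softmax of logits / temperature."""
    scaled = np.asarray(logits, dtype=float) / temperature
    scaled -= scaled.max()   # numerical stability
    probs = np.exp(scaled)
    return probs / probs.sum()


logits = np.array([2.0, 1.0, 0.5, 0.0])
for t in (0.1, 0.5, 1.0, 2.0):
    print(t, np.round(temperature_softmax(logits, t), 3))
```

As T approaches 0, the distribution collapses onto the highest-scoring amino acid; T = 1 leaves the model's distribution unchanged; T > 1 flattens it toward uniform.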

In [None]:
# Sample sequences for a test protein
# clone() keeps the copy in val_data on the CPU for val_loader (PyG's Data.to() moves tensors in place)
test_protein = val_data[0].clone().to(device)
AA_ALPHABET = 'ACDEFGHIKLMNPQRSTVWY'

# Sample with different temperatures
temperatures = [0.1, 0.5, 1.0]

print("Sampling sequences with different temperatures:\n")

model.eval()  # disable dropout while sampling
with torch.no_grad():
    for temp in temperatures:
        sampled_seq = model.sample(test_protein, temperature=temp)
        sampled_seq_str = ''.join([AA_ALPHABET[i] for i in sampled_seq[0].cpu().numpy()])

        print(f"Temperature {temp}:")
        print(f"  {sampled_seq_str}")
        print()

In [None]:
# Analyze amino acid composition
def plot_aa_composition(sequences, title="Amino Acid Composition"):
    """Plot amino acid composition of sequences."""
    AA_ALPHABET = 'ACDEFGHIKLMNPQRSTVWY'
    
    # Count amino acids
    counts = {aa: 0 for aa in AA_ALPHABET}
    total = 0
    
    for seq in sequences:
        for aa_idx in seq.tolist():  # plain Python ints index the alphabet string
            aa = AA_ALPHABET[aa_idx]
            counts[aa] += 1
            total += 1
    
    # Convert to frequencies
    freqs = {aa: counts[aa] / total for aa in AA_ALPHABET}
    
    # Plot
    plt.figure(figsize=(12, 4))
    plt.bar(list(freqs.keys()), list(freqs.values()))
    plt.xlabel('Amino Acid')
    plt.ylabel('Frequency')
    plt.title(title)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()

# Sample multiple sequences and analyze composition
sampled_sequences = []
model.eval()
with torch.no_grad():
    for _ in range(10):
        seq = model.sample(test_protein, temperature=1.0)
        sampled_sequences.append(seq[0].cpu())

plot_aa_composition(sampled_sequences, "Sampled Sequences - AA Composition")

## 7. Evaluation <a name="evaluation"></a>

We evaluate the model using:
- **Perplexity**: the exponential of the average per-residue negative log-likelihood of the native sequence; lower is better
- **Recovery**: the percentage of positions where a designed sequence matches the native residue
- **Amino acid composition**: should resemble the composition of natural proteins
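
Recovery and perplexity reduce to a few lines each, sketched below with hypothetical helper names. The uniform baseline is worth keeping in mind: because the synthetic sequences in this tutorial are drawn uniformly at random and independently of the random coordinates, a validation perplexity near 20 (or higher, if the model overfits) and an accuracy near 5% are the expected outcome here. Meaningful numbers require real structures.

```python
import numpy as np


def sequence_recovery(designed, native):
    """Fraction of positions where the designed residue equals the native one."""
    designed, native = np.asarray(designed), np.asarray(native)
    return float((designed == native).mean())


def perplexity(log_probs, native):
    """exp(mean negative log-likelihood of the native residues)."""
    nll = -log_probs[np.arange(len(native)), native]
    return float(np.exp(nll.mean()))


native = np.array([0, 5, 5, 9])
designed = np.array([0, 5, 7, 9])
print(sequence_recovery(designed, native))  # 0.75

uniform = np.log(np.full((4, 20), 1 / 20))  # a model that knows nothing
print(round(perplexity(uniform, native), 6))  # 20.0
```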

In [None]:
# Compute recovery rate (native sequence recovery)
@torch.no_grad()
def compute_recovery(model, data, num_samples=10):
    """Compute sequence recovery rate."""
    model.eval()
    data = data.clone().to(device)  # clone so the dataset's copy stays on the CPU
    
    # Ground truth sequence
    true_seq = data.seq.cpu().numpy()
    
    # Sample sequences
    recoveries = []
    for _ in range(num_samples):
        sampled = model.sample(data, temperature=0.1)  # Low temperature: close to greedy decoding
        sampled_seq = sampled[0].cpu().numpy()
        
        # Compute recovery
        recovery = (sampled_seq == true_seq).mean()
        recoveries.append(recovery)
    
    return np.mean(recoveries), np.std(recoveries)

# Evaluate recovery on validation set
print("Evaluating sequence recovery on validation set:\n")

for i in range(min(5, len(val_data))):
    mean_recovery, std_recovery = compute_recovery(model, val_data[i], num_samples=5)
    print(f"Protein {i+1}: {mean_recovery:.2%} ± {std_recovery:.2%}")

In [None]:
# Final evaluation metrics
val_loss, val_acc, val_perplexity = evaluate(model, val_loader, device)

print("\n" + "="*50)
print("Final Model Performance")
print("="*50)
print(f"Validation Loss: {val_loss:.4f}")
print(f"Validation Accuracy: {val_acc:.4f}")
print(f"Validation Perplexity: {val_perplexity:.4f}")
print("="*50)

## Summary

This tutorial demonstrated:

1. **Data Processing**: Converting protein structures to PyG graphs
2. **Model Architecture**: Encoder-decoder GNN with attention
3. **Training**: Teacher forcing with cross-entropy loss
4. **Sampling**: Autoregressive sequence generation
5. **Evaluation**: Perplexity and sequence recovery metrics

## Next Steps

- Train on real protein datasets (CATH, PDB)
- Experiment with different architectures (MPNN vs GAT)
- Try different hyperparameters (hidden dim, num layers, etc.)
- Evaluate on benchmark datasets (SPIN2, Ollikainen)
- Compare with baseline models

## References

- Ingraham et al. "Generative Models for Graph-Based Protein Design" (NeurIPS 2019)
- PyTorch Geometric Documentation: https://pytorch-geometric.readthedocs.io/