# Building Struct2Seq from Scratch: A Step-by-Step Guide

This notebook teaches you how to build the Struct2Seq model from the ground up. We'll construct each component step by step, explaining the algorithms, mathematics, and implementation details.

**Learning Objectives:**
- Understand protein structure representation as graphs
- Learn how to extract geometric features from 3D coordinates
- Build graph attention mechanisms from scratch
- Construct the full autoregressive sequence generation model
- Understand how each component works together

**Table of Contents:**
1. [Introduction: Proteins as Graphs](#section1)
2. [Part 1: Protein Featurization](#section2)
3. [Part 2: Graph Operations](#section3)
4. [Part 3: Attention Mechanisms](#section4)
5. [Part 4: Transformer Layers](#section5)
6. [Part 5: The Complete Struct2Seq Model](#section6)
7. [Part 6: Training and Inference](#section7)

In [None]:
# Import required libraries
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from matplotlib import pyplot as plt
import json

# Set random seeds for reproducibility
np.random.seed(42)
torch.manual_seed(42)

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

---
<a id='section1'></a>
## 1. Introduction: Proteins as Graphs

### What is a Protein?

A protein is a chain of amino acids that folds into a 3D structure. Each amino acid (residue) contributes four backbone heavy atoms:
- **N** (Nitrogen)
- **Cα** (Alpha Carbon) 
- **C** (Carbon)
- **O** (Oxygen)

### Graph Representation

We represent proteins as graphs:
- **Nodes (V)**: Each residue (amino acid) is a node
- **Edges (E)**: Connect each residue to its k nearest neighbors in 3D space (by Cα distance)
- **Node Features**: Per-residue geometric properties (e.g., backbone dihedral angles)
- **Edge Features**: Pairwise relationships (distances, orientations)

```
Mathematical Notation:
G = (V, E)  where:
- V = {v₁, v₂, ..., vₙ}  (n residues)
- E ⊆ V × V               (edges between neighbors)
```
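In code, the edge set E is usually not stored as an explicit list of pairs. A k-NN graph stores, for every node i, the indices of its k neighbors in a tensor `E_idx` of shape [N, k]. The minimal sketch below (a hand-written toy `E_idx`; `edges_from_neighbor_idx` is a hypothetical helper) converts that fixed-width representation into the explicit (i, j) pairs of E ⊆ V × V.

```python
import torch

def edges_from_neighbor_idx(E_idx):
    """Convert a [N, K] neighbor-index tensor into a [N*K, 2] tensor of (i, j) edges."""
    n_nodes, k = E_idx.shape
    # Source node i, repeated once per neighbor slot: [N, K]
    src = torch.arange(n_nodes).unsqueeze(1).expand(n_nodes, k)
    # Each row of the result is one directed edge (i, j)
    return torch.stack([src.reshape(-1), E_idx.reshape(-1)], dim=1)

# Hypothetical toy graph: 4 nodes, 2 neighbors each (each node is its own nearest neighbor)
E_idx_toy = torch.tensor([[0, 1], [1, 0], [2, 3], [3, 2]])
edge_list = edges_from_neighbor_idx(E_idx_toy)
print(edge_list.tolist())
# [[0, 0], [0, 1], [1, 1], [1, 0], [2, 2], [2, 3], [3, 3], [3, 2]]
```

The fixed width K makes batching easy: every node has exactly K edges, so the whole graph fits in one dense tensor.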

In [None]:
# Example: Load a protein structure
def load_example_protein():
    """Load the first protein from the dataset, or fall back to a synthetic example."""
    try:
        with open('../data/chain_set.jsonl', 'r') as f:
            # Load first protein
            protein = json.loads(f.readline())
        return protein
    except (FileNotFoundError, json.JSONDecodeError):
        print("Dataset not found. Creating synthetic example (random coordinates)...")
        # Create a simple synthetic protein
        n_residues = 10
        return {
            'name': 'example',
            'seq': 'ACDEFGHIKL',
            'coords': {
                'N': np.random.randn(n_residues, 3).tolist(),
                'CA': np.random.randn(n_residues, 3).tolist(),
                'C': np.random.randn(n_residues, 3).tolist(),
                'O': np.random.randn(n_residues, 3).tolist()
            }
        }

protein = load_example_protein()
print(f"Protein: {protein['name']}")
print(f"Sequence: {protein['seq']}")
print(f"Length: {len(protein['seq'])} residues")
print(f"\nCoordinate shape: {len(protein['coords']['CA'])} x 3")

In [None]:
# Visualize the protein backbone
def visualize_backbone(coords_dict, title="Protein Backbone"):
    """Simple 3D visualization of protein backbone"""
    from mpl_toolkits.mplot3d import Axes3D
    
    fig = plt.figure(figsize=(10, 8))
    ax = fig.add_subplot(111, projection='3d')
    
    # Extract CA coordinates
    ca_coords = np.array(coords_dict['CA'])
    
    # Plot backbone trace
    ax.plot(ca_coords[:, 0], ca_coords[:, 1], ca_coords[:, 2], 
            'b-', linewidth=2, label='Backbone')
    ax.scatter(ca_coords[:, 0], ca_coords[:, 1], ca_coords[:, 2], 
               c='red', s=50, label='Cα atoms')
    
    ax.set_xlabel('X (Å)')
    ax.set_ylabel('Y (Å)')
    ax.set_zlabel('Z (Å)')
    ax.set_title(title)
    ax.legend()
    plt.tight_layout()
    plt.show()

visualize_backbone(protein['coords'])

---
<a id='section2'></a>
## 2. Part 1: Protein Featurization

The first step is converting 3D coordinates into meaningful features. We'll implement:

- 2.1 Distance calculations and the k-NN graph
- 2.2 Radial basis functions (RBF)
- 2.3 Dihedral angles (φ, ψ, ω)
- 2.4 Positional encodings

The full `protein_features.py` also computes relative orientation features (encoded with quaternions); this notebook omits them.

**Reference:** See `struct2seq/protein_features.py`

### 2.1 Building the k-NN Graph

**Algorithm:**
1. Compute pairwise Euclidean distances between all Cα atoms
2. For each residue, find the k nearest neighbors (the residue itself is always included, at distance ≈ 0)
3. Store neighbor indices for graph construction

**Mathematical Formulation:**
```
D[i,j] = ||X[i] - X[j]||₂  (Euclidean distance)
E_idx[i] = argsort(D[i])[:k]  (k nearest neighbors)
```
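The same two steps can be cross-checked with PyTorch's built-in `torch.cdist`. This minimal sketch uses random coordinates (a stand-in for real Cα positions) and also shows that, because D[i,i] = 0, every residue's nearest neighbor is itself.

```python
import torch

torch.manual_seed(0)
n_residues, k_toy = 8, 3
X_toy = torch.randn(n_residues, 3) * 5.0  # random stand-in for Cα coordinates (Å)

# Step 1: pairwise distances D[i, j] = ||X[i] - X[j]||  (cdist expects [B, P, M])
D_toy = torch.cdist(X_toy.unsqueeze(0), X_toy.unsqueeze(0)).squeeze(0)  # [N, N]

# Step 2: k nearest neighbors per residue, sorted from closest to farthest
D_nbrs_toy, E_idx_toy = torch.topk(D_toy, k_toy, dim=-1, largest=False)

print(E_idx_toy[:, 0])  # each residue's closest "neighbor" is itself
```

The functions below do the same computation by hand so that they can also handle padding masks and the `eps` term.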

In [None]:
def compute_pairwise_distances(X, mask, eps=1e-6):
    """
    Compute pairwise Euclidean distances between all residues.
    
    Args:
        X: Coordinates [batch_size, n_residues, 3]
        mask: Valid residue mask [batch_size, n_residues]
        eps: Small constant for numerical stability
    
    Returns:
        D: Distance matrix [batch_size, n_residues, n_residues]
        mask_2D: Valid pair mask [batch_size, n_residues, n_residues]
    """
    # Create 2D mask: valid if both residues are valid
    mask_2D = mask.unsqueeze(1) * mask.unsqueeze(2)  # [B, N, N]
    
    # Compute pairwise differences: dX[b, i, j] = X[b, j] - X[b, i]
    dX = X.unsqueeze(1) - X.unsqueeze(2)  # [B, N, N, 3]
    
    # Compute distances ||X[i] - X[j]||; eps avoids the infinite gradient of sqrt at 0
    D = torch.sqrt(torch.sum(dX**2, dim=-1) + eps)  # [B, N, N]
    
    # Mask invalid distances
    D = mask_2D * D
    
    return D, mask_2D

# Test the function
# Create example coordinates
X_example = torch.tensor(protein['coords']['CA']).unsqueeze(0).float()  # [1, N, 3]
mask_example = torch.ones(1, len(protein['seq']))  # [1, N]

D, mask_2D = compute_pairwise_distances(X_example, mask_example)
print(f"Distance matrix shape: {D.shape}")
print(f"\nDistance matrix (first 5x5):")
print(D[0, :5, :5])

In [None]:
def build_knn_graph(D, mask_2D, k=30):
    """
    Build k-nearest neighbors graph.
    
    Args:
        D: Distance matrix [batch_size, n_residues, n_residues]
        mask_2D: Valid pair mask [batch_size, n_residues, n_residues]
        k: Number of nearest neighbors
    
    Returns:
        D_neighbors: Distances to k neighbors [B, N, k]
        E_idx: Indices of k neighbors [B, N, k]
    """
    # Set invalid distances to large value so they won't be selected
    D_max = torch.max(D, dim=-1, keepdim=True)[0]
    D_adjusted = D + (1.0 - mask_2D) * D_max
    
    # Find k nearest neighbors (smallest distances); k cannot exceed the number of residues
    k = min(k, D.shape[-1])
    D_neighbors, E_idx = torch.topk(D_adjusted, k, dim=-1, largest=False)
    
    return D_neighbors, E_idx

# Test k-NN graph construction
k = 5  # Use smaller k for visualization
D_neighbors, E_idx = build_knn_graph(D, mask_2D, k=k)

print(f"Neighbor distances shape: {D_neighbors.shape}")
print(f"Neighbor indices shape: {E_idx.shape}")
print(f"\nNeighbors of residue 0: {E_idx[0, 0]}")
print(f"Distances to neighbors: {D_neighbors[0, 0]}")

In [None]:
# Visualize the k-NN graph
def visualize_knn_graph(E_idx, title="k-NN Graph Connectivity"):
    """
    Visualize the k-NN graph as an adjacency matrix.
    """
    n_residues = E_idx.shape[1]
    
    # Create adjacency matrix
    adj_matrix = torch.zeros(n_residues, n_residues)
    for i in range(n_residues):
        for j in E_idx[0, i]:
            adj_matrix[i, j] = 1
    
    plt.figure(figsize=(8, 8))
    plt.imshow(adj_matrix.numpy(), cmap='Blues', interpolation='nearest')
    plt.colorbar(label='Connected')
    plt.xlabel('Residue j')
    plt.ylabel('Residue i')
    plt.title(title)
    plt.tight_layout()
    plt.show()

visualize_knn_graph(E_idx)

### 2.2 Radial Basis Functions (RBF)

RBF encoding expands each scalar distance into a vector of Gaussian activations, one per center μᵢ:

**Mathematical Formulation:**
```
RBF(d) = exp(-((d - μᵢ) / σ)²)  for i = 1..num_rbf
μᵢ = linspace(0, 20Å, num_rbf)  (Gaussian centers)
σ = 20 / num_rbf                (Gaussian width)
```

This creates a smooth, differentiable encoding of distance information.
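A quick worked example of the formula, assuming the defaults used below (16 centers between 0 and 20 Å, so neighboring centers are 20/15 ≈ 1.33 Å apart, and σ = 20/16 = 1.25 Å). A distance of 5 Å activates the center at 5.33 Å most strongly, its neighbors less, and distant centers essentially not at all.

```python
import torch

num_rbf, d_min, d_max = 16, 0.0, 20.0
mu = torch.linspace(d_min, d_max, num_rbf)    # Gaussian centers μ_i
sigma = (d_max - d_min) / num_rbf             # Gaussian width σ = 1.25 Å

dist = torch.tensor(5.0)                      # a single distance in Å
rbf = torch.exp(-((dist - mu) / sigma) ** 2)  # RBF(d) for every center

print(mu[4].item(), rbf[4].item())            # closest center and its activation
print(rbf.argmax().item())                    # → 4
```

Because neighboring Gaussians overlap, a small change in distance shifts activation smoothly between adjacent channels instead of jumping between bins.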

In [None]:
def rbf_encoding(D, num_rbf=16, d_min=0.0, d_max=20.0):
    """
    Encode distances using Radial Basis Functions.
    
    Args:
        D: Distances [batch_size, n_residues, k_neighbors]
        num_rbf: Number of RBF kernels
        d_min: Minimum distance for RBF centers
        d_max: Maximum distance for RBF centers
    
    Returns:
        RBF features [batch_size, n_residues, k_neighbors, num_rbf]
    """
    # Create RBF centers (μ)
    rbf_centers = torch.linspace(d_min, d_max, num_rbf, device=D.device)  # [num_rbf]
    rbf_centers = rbf_centers.view(1, 1, 1, -1)  # [1, 1, 1, num_rbf]
    
    # RBF width (σ)
    rbf_width = (d_max - d_min) / num_rbf
    
    # Expand distances for broadcasting
    D_expanded = D.unsqueeze(-1)  # [B, N, K, 1]
    
    # Compute RBF: exp(-((d - μ) / σ)²)
    RBF = torch.exp(-((D_expanded - rbf_centers) / rbf_width) ** 2)
    
    return RBF

# Test RBF encoding
RBF_features = rbf_encoding(D_neighbors, num_rbf=16)
print(f"RBF features shape: {RBF_features.shape}")
print(f"RBF features for first neighbor of residue 0:")
print(RBF_features[0, 0, 0, :])

In [None]:
# Visualize RBF encoding
def visualize_rbf():
    """Visualize how RBF encoding works"""
    distances = torch.linspace(0, 20, 100)
    rbf_features = rbf_encoding(distances.unsqueeze(0).unsqueeze(0), num_rbf=8)
    
    plt.figure(figsize=(12, 5))
    for i in range(8):
        plt.plot(distances.numpy(), rbf_features[0, 0, :, i].numpy(), 
                label=f'RBF {i+1}')
    
    plt.xlabel('Distance (Å)')
    plt.ylabel('RBF Activation')
    plt.title('Radial Basis Function Encoding')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()

visualize_rbf()

### 2.3 Dihedral Angles

Dihedral angles (φ, ψ, ω) describe the backbone geometry:

**Mathematical Formulation:**
```
Given 4 consecutive atoms A-B-C-D:
1. Compute unit vectors: u₁ = (B-A)/||B-A||, u₂ = (C-B)/||C-B||, u₃ = (D-C)/||D-C||
2. Compute normals: n₁ = u₁ × u₂, n₂ = u₂ × u₃
3. Dihedral angle: θ = sign(u₁·n₂) × arccos(n₁·n₂)
4. Encode as: [cos(θ), sin(θ)] (circular representation)
```
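The four steps can be checked on hand-built geometry. This minimal sketch (the `dihedral` helper and test coordinates are hypothetical) places B at the origin, C on the +z axis and A on +x, then moves only D; the computed angles match the rotation of D around the B-C axis.

```python
import torch
import torch.nn.functional as F

def dihedral(a, b, c, d):
    """Signed dihedral angle (radians) of atoms A-B-C-D, following steps 1-3 above."""
    # 1. Unit vectors along the three bonds
    u1 = F.normalize(b - a, dim=-1)
    u2 = F.normalize(c - b, dim=-1)
    u3 = F.normalize(d - c, dim=-1)
    # 2. Unit normals of the planes (A, B, C) and (B, C, D)
    n1 = F.normalize(torch.cross(u1, u2, dim=-1), dim=-1)
    n2 = F.normalize(torch.cross(u2, u3, dim=-1), dim=-1)
    # 3. Angle between the planes, signed by u1 · n2
    cos_theta = torch.clamp((n1 * n2).sum(-1), -1.0, 1.0)
    return torch.sign((u1 * n2).sum(-1)) * torch.acos(cos_theta)

# Hypothetical test geometry: B at the origin, C on +z, A on +x; only D moves
atom_a = torch.tensor([1.0, 0.0, 0.0])
atom_b = torch.tensor([0.0, 0.0, 0.0])
atom_c = torch.tensor([0.0, 0.0, 1.0])
for d_xyz in ([0.0, 1.0, 1.0], [0.0, -1.0, 1.0], [0.5, 3 ** 0.5 / 2, 1.0]):
    theta = dihedral(atom_a, atom_b, atom_c, torch.tensor(d_xyz))
    print(d_xyz, round(torch.rad2deg(theta).item(), 1))  # 90.0, -90.0, 60.0
```

One caveat of the sign × arccos form: at exactly 180° the term u₁·n₂ is zero, so `torch.sign` returns 0 and the angle collapses to 0. An `atan2`-based formula avoids that edge case.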

**In proteins:**
- φ (phi): C(i-1)-N(i)-Cα(i)-C(i) dihedral
- ψ (psi): N(i)-Cα(i)-C(i)-N(i+1) dihedral
- ω (omega): Cα(i)-C(i)-N(i+1)-Cα(i+1) dihedral

**Reference:** `protein_features.py:294-337`
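`compute_dihedrals` (next cell) never refers to these definitions by name. It flattens the backbone into the interleaved order N0, CA0, C0, N1, CA1, C1, ..., computes one dihedral per window of four consecutive atoms, pads one zero in front and two at the end, and reshapes into triples. The label-only sketch below (plain Python, no coordinates; `dihedral_windows` is a hypothetical helper) traces that bookkeeping to show why each triple comes out as [φᵢ, ψᵢ, ωᵢ].

```python
def dihedral_windows(n_residues):
    """Label each 4-atom window of the flattened backbone, padded as in compute_dihedrals."""
    atoms = [f"{name}{i}" for i in range(n_residues) for name in ("N", "CA", "C")]
    # One dihedral per window of 4 consecutive atoms: 3N - 3 windows
    windows = ["-".join(atoms[t:t + 4]) for t in range(len(atoms) - 3)]
    # Pad 1 undefined angle in front and 2 at the end -> 3N entries
    padded = ["undefined"] + windows + ["undefined", "undefined"]
    # Reshape into one (phi, psi, omega) triple per residue
    return [tuple(padded[3 * i:3 * i + 3]) for i in range(n_residues)]

for i, (phi, psi, omega) in enumerate(dihedral_windows(3)):
    print(f"residue {i}: phi={phi}, psi={psi}, omega={omega}")
```

Residue 1 receives C0-N1-CA1-C1 (φ₁), N1-CA1-C1-N2 (ψ₁) and CA1-C1-N2-CA2 (ω₁); only φ of the first residue and ψ, ω of the last residue are padding.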

In [None]:
def compute_dihedrals(X, eps=1e-7):
    """
    Compute backbone dihedral angles (φ, ψ, ω).
    
    Args:
        X: Backbone coordinates [batch_size, n_residues, 4, 3]
           where 4 atoms are [N, CA, C, O]
    
    Returns:
        Dihedral features [batch_size, n_residues, 6]
        (cos φ, cos ψ, cos ω, sin φ, sin ψ, sin ω)
    """
    # Take only N, CA, C (first 3 atoms)
    X = X[:, :, :3, :].reshape(X.shape[0], 3 * X.shape[1], 3)
    
    # Compute unit vectors between consecutive atoms
    dX = X[:, 1:, :] - X[:, :-1, :]  # [B, 3N-1, 3]
    U = F.normalize(dX, dim=-1)  # Unit vectors
    
    # Get consecutive triplets of unit vectors
    u_0 = U[:, 2:, :]   # Third vector
    u_1 = U[:, 1:-1, :] # Middle vector
    u_2 = U[:, :-2, :]  # First vector
    
    # Compute normal vectors (perpendicular to planes)
    n_1 = F.normalize(torch.cross(u_1, u_0, dim=-1), dim=-1)
    n_2 = F.normalize(torch.cross(u_2, u_1, dim=-1), dim=-1)
    
    # Compute dihedral angle
    # cos(θ) = n₁ · n₂
    cos_D = (n_2 * n_1).sum(-1)
    cos_D = torch.clamp(cos_D, -1 + eps, 1 - eps)
    
    # θ = sign(u₂ · n₁) × arccos(cos(θ))
    D = torch.sign((u_2 * n_1).sum(-1)) * torch.acos(cos_D)
    
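    # Pad: φ of the first residue and ψ, ω of the last residue are undefined,
    # so insert one zero in front and two at the end (3N - 3 -> 3N angles)
    D = F.pad(D, (1, 2), 'constant', 0)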
    
    # Reshape into [phi, psi, omega] per residue
    D = D.view(D.size(0), D.size(1) // 3, 3)
    
    # Encode as circular features: [cos(θ), sin(θ)]
    D_features = torch.cat([torch.cos(D), torch.sin(D)], dim=-1)
    
    return D_features

# Test dihedral computation
# Create backbone coordinate tensor
X_backbone = torch.zeros(1, len(protein['seq']), 4, 3)
for i, atom in enumerate(['N', 'CA', 'C', 'O']):
    X_backbone[0, :, i, :] = torch.tensor(protein['coords'][atom])

dihedral_features = compute_dihedrals(X_backbone)
print(f"Dihedral features shape: {dihedral_features.shape}")
print(f"Features per residue: 6 (cos φ, cos ψ, cos ω, sin φ, sin ψ, sin ω)")
print(f"\nFirst residue dihedrals:")
print(dihedral_features[0, 0, :])

In [None]:
# Visualize a Ramachandran plot (only meaningful for real structures, not random coordinates)
def plot_ramachandran(dihedral_features):
    """Plot phi-psi angles (Ramachandran plot)"""
    # Feature layout is [cos φ, cos ψ, cos ω, sin φ, sin ψ, sin ω]
    phi = torch.atan2(dihedral_features[:, :, 3], dihedral_features[:, :, 0])  # atan2(sin φ, cos φ)
    psi = torch.atan2(dihedral_features[:, :, 4], dihedral_features[:, :, 1])  # atan2(sin ψ, cos ψ)
    
    phi_deg = phi.numpy().flatten() * 180 / np.pi
    psi_deg = psi.numpy().flatten() * 180 / np.pi
    
    plt.figure(figsize=(8, 8))
    plt.scatter(phi_deg, psi_deg, alpha=0.6, s=30)
    plt.xlabel('φ (phi) [degrees]')
    plt.ylabel('ψ (psi) [degrees]')
    plt.title('Ramachandran Plot')
    plt.grid(True, alpha=0.3)
    plt.xlim(-180, 180)
    plt.ylim(-180, 180)
    plt.axhline(y=0, color='k', linestyle='-', alpha=0.3)
    plt.axvline(x=0, color='k', linestyle='-', alpha=0.3)
    plt.tight_layout()
    plt.show()

plot_ramachandran(dihedral_features)

### 2.4 Positional Encodings

Positional encodings tell the model how far apart two residues are along the chain, information that 3D distances alone do not provide:

**Mathematical Formulation (from Transformer):**
```
PE(pos, 2i)   = sin(pos / 10000^(2i/d))
PE(pos, 2i+1) = cos(pos / 10000^(2i/d))
```

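For edges, we use the relative sequence offset `pos = j - i` (neighbor index minus current index). The implementation below uses the same frequencies but concatenates all cosine terms followed by all sine terms instead of interleaving them.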

**Reference:** `protein_features.py:14-42`
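Two properties of this encoding are easy to check numerically. At offset 0 every cosine term is 1 and every sine term is 0. Flipping the sign of the offset leaves the cosine half unchanged and negates the sine half, so the model can tell upstream from downstream neighbors. The sketch below re-implements the formula for a single scalar offset with the same cos-then-sin layout as `positional_encodings`; `encode_offset` is a hypothetical helper name.

```python
import math
import torch

def encode_offset(d, num_embeddings=16):
    """Sinusoidal encoding of one relative offset d = j - i (cos half, then sin half)."""
    # Frequencies 1 / 10000^(2k / num_embeddings) for k = 0 .. num_embeddings/2 - 1
    frequency = torch.exp(
        torch.arange(0, num_embeddings, 2, dtype=torch.float32)
        * -(math.log(10000.0) / num_embeddings)
    )
    angles = float(d) * frequency
    return torch.cat([torch.cos(angles), torch.sin(angles)])

e0, e_plus, e_minus = encode_offset(0), encode_offset(3), encode_offset(-3)
print(e0)
print(torch.allclose(e_plus[:8], e_minus[:8]), torch.allclose(e_plus[8:], -e_minus[8:]))
```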

In [None]:
def positional_encodings(E_idx, num_embeddings=16):
    """
    Compute positional encodings for edges based on sequence distance.
    
    Args:
        E_idx: Neighbor indices [batch_size, n_residues, k_neighbors]
        num_embeddings: Dimensionality of encoding
    
    Returns:
        Positional encodings [B, N, K, num_embeddings]
    """
    batch_size = E_idx.size(0)
    n_nodes = E_idx.size(1)
    n_neighbors = E_idx.size(2)
    
    # Current position indices
    ii = torch.arange(n_nodes, dtype=torch.float32, device=E_idx.device).view(1, -1, 1)
    
    # Relative position: j - i
    d = (E_idx.float() - ii).unsqueeze(-1)  # [B, N, K, 1]
    
    # Compute frequencies (from original Transformer)
    frequency = torch.exp(
        torch.arange(0, num_embeddings, 2, dtype=torch.float32, device=E_idx.device)
        * -(np.log(10000.0) / num_embeddings)
    )
    
    # Compute angles
    angles = d * frequency.view(1, 1, 1, -1)  # [B, N, K, num_embeddings/2]
    
    # Concatenate sin and cos
    E_pos = torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)
    
    return E_pos

# Test positional encodings
pos_encodings = positional_encodings(E_idx, num_embeddings=16)
print(f"Positional encodings shape: {pos_encodings.shape}")
print(f"\nPositional encoding for first neighbor of residue 5:")
print(pos_encodings[0, 5, 0, :])

In [None]:
# Visualize positional encodings
def visualize_positional_encodings():
    """Visualize positional encoding patterns"""
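    # Use a single query node (i = 0) so that E_idx - i equals the offsets themselves;
    # with one node per offset, every row would encode the same value j - i = -20
    positions = torch.arange(-20, 21).view(1, 1, 41)  # [B=1, N=1, K=41]
    encodings = positional_encodings(positions, num_embeddings=16)  # [1, 1, 41, 16]
    
    plt.figure(figsize=(12, 6))
    plt.imshow(encodings[0, 0, :, :].T.numpy(), aspect='auto', cmap='RdBu', 
               interpolation='nearest')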
    plt.colorbar(label='Encoding value')
    plt.xlabel('Relative Position (j - i)')
    plt.ylabel('Encoding Dimension')
    plt.title('Positional Encoding Pattern')
    plt.xticks(range(0, 41, 5), range(-20, 21, 5))
    plt.tight_layout()
    plt.show()

visualize_positional_encodings()

---
<a id='section3'></a>
## 3. Part 2: Graph Operations

Graph neural networks require specialized operations to gather information from neighbors.

**Key Operations:**
1. `gather_nodes`: Gather node features at specified indices
2. `gather_edges`: Gather edge features at specified indices
3. `cat_neighbors_nodes`: Concatenate neighbor and node features

**Reference:** `self_attention.py:11-36`
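`gather_edges` is listed above but not implemented in this section. A minimal sketch, assuming dense pairwise edge features of shape [B, N, N, C] (for example, a full distance matrix with a channel axis), selects the K entries per row named by `neighbor_idx`. It mirrors the `gather_nodes` pattern but gathers along the second node axis.

```python
import torch

def gather_edges(edges, neighbor_idx):
    """
    Select edge features at neighbor indices.

    Args:
        edges: Dense edge features [B, N, N, C]
        neighbor_idx: Neighbor indices [B, N, K]
    Returns:
        [B, N, K, C], where out[b, i, k] = edges[b, i, neighbor_idx[b, i, k]]
    """
    # Broadcast the indices over the channel dimension: [B, N, K, C]
    idx = neighbor_idx.unsqueeze(-1).expand(-1, -1, -1, edges.size(-1))
    # Gather along dim 2 (the "j" axis of edge (i, j))
    return torch.gather(edges, 2, idx)

# Toy example: 1 protein, 4 residues, 2 channels, 2 neighbors per residue
edges = torch.arange(1 * 4 * 4 * 2, dtype=torch.float32).view(1, 4, 4, 2)
neighbor_idx = torch.tensor([[[1, 2], [0, 3], [1, 3], [0, 2]]])
out = gather_edges(edges, neighbor_idx)
print(out.shape)   # torch.Size([1, 4, 2, 2])
print(out[0, 0])   # edges[0, 0, 1] and edges[0, 0, 2]
```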

In [None]:
def gather_nodes(nodes, neighbor_idx):
    """
    Gather node features at neighbor indices.
    
    Given node features and neighbor indices, extract features
    of neighboring nodes for each node.
    
    Args:
        nodes: Node features [batch_size, n_nodes, feature_dim]
        neighbor_idx: Neighbor indices [batch_size, n_nodes, k_neighbors]
    
    Returns:
        Neighbor features [batch_size, n_nodes, k_neighbors, feature_dim]
    
    Example:
        nodes = [[h₀, h₁, h₂, h₃], ...]  # Node features
        neighbor_idx = [[1, 2], [0, 3], [1, 3], [0, 2]]  # Who are my neighbors?
        output[0] = [h₁, h₂]  # Features of node 0's neighbors
    """
    # Flatten neighbor indices for gathering
    neighbors_flat = neighbor_idx.view(neighbor_idx.shape[0], -1)  # [B, N*K]
    neighbors_flat = neighbors_flat.unsqueeze(-1).expand(-1, -1, nodes.size(2))  # [B, N*K, C]
    
    # Gather features
    neighbor_features = torch.gather(nodes, 1, neighbors_flat)  # [B, N*K, C]
    
    # Reshape to [B, N, K, C]
    neighbor_features = neighbor_features.view(
        list(neighbor_idx.shape)[:3] + [-1]
    )
    
    return neighbor_features

# Test gather_nodes
# Create dummy node features
n_nodes = 5
feature_dim = 4
dummy_nodes = torch.arange(n_nodes * feature_dim).view(1, n_nodes, feature_dim).float()
print("Node features:")
print(dummy_nodes[0])

# Create neighbor indices
dummy_neighbors = torch.tensor([[[1, 2], [0, 3], [1, 4], [0, 2], [1, 3]]])
print("\nNeighbor indices:")
print(dummy_neighbors[0])

# Gather neighbor features
neighbor_features = gather_nodes(dummy_nodes, dummy_neighbors)
print("\nGathered neighbor features for node 0:")
print(neighbor_features[0, 0])  # Should be features of nodes 1 and 2

In [None]:
def cat_neighbors_nodes(h_nodes, h_edges, E_idx):
    """
    Concatenate node features with edge features.
    
    For each edge (i,j), concatenate:
    - Edge features h_edges[i,j]
    - Destination node features h_nodes[j]
    
    Args:
        h_nodes: Node features [B, N, C_node]
        h_edges: Edge features [B, N, K, C_edge]
        E_idx: Edge indices [B, N, K]
    
    Returns:
        Combined features [B, N, K, C_edge + C_node]
    """
    # Gather node features at neighbor positions
    h_nodes_neighbors = gather_nodes(h_nodes, E_idx)
    
    # Concatenate edge and node features
    h_combined = torch.cat([h_edges, h_nodes_neighbors], dim=-1)
    
    return h_combined

# Test cat_neighbors_nodes
dummy_edges = torch.randn(1, n_nodes, 2, 3)  # Edge features
combined = cat_neighbors_nodes(dummy_nodes, dummy_edges, dummy_neighbors)
print(f"Combined features shape: {combined.shape}")
print(f"Expected: [1, {n_nodes}, 2, {3 + feature_dim}] (edge_dim + node_dim)")

---
<a id='section4'></a>
## 4. Part 3: Attention Mechanisms

Attention lets each node weight its neighbors by learned relevance instead of averaging their features uniformly.

### Multi-Head Attention

**Mathematical Formulation:**
```
Q = W_Q × h_V         (Query from current node)
K = W_K × h_E         (Keys from edges/neighbors)
V = W_V × h_E         (Values from edges/neighbors)

Attention(Q,K,V) = softmax(QK^T / √d) × V

Multi-Head: Run multiple attention heads in parallel
```

**Reference:** `self_attention.py:155-210`
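For a single node i with K neighbors, the formula reduces to a few tensor operations. The sketch below uses random vectors as stand-ins for the projected query, keys and values (sizes chosen only for illustration), splits them into heads, applies softmax(q·kⱼ/√d) over the neighbors in each head, and checks a vectorized `einsum` version against an explicit per-head loop. It omits masking and the output projection W_O.

```python
import math
import torch

torch.manual_seed(0)
n_nbrs, num_hidden, num_heads = 5, 8, 2      # neighbors, hidden size, heads
d = num_hidden // num_heads                  # dimension per head

q_i = torch.randn(num_hidden)                # stands in for W_Q h_V (node i)
keys = torch.randn(n_nbrs, num_hidden)       # stands in for W_K h_E (its neighbors)
values = torch.randn(n_nbrs, num_hidden)     # stands in for W_V h_E

# Vectorized: split into heads, softmax over neighbors within each head
qh = q_i.view(num_heads, d)
kh = keys.view(n_nbrs, num_heads, d)
vh = values.view(n_nbrs, num_heads, d)
logits = torch.einsum("hd,khd->hk", qh, kh) / math.sqrt(d)      # [heads, K]
attn = torch.softmax(logits, dim=-1)                             # each row sums to 1
out = torch.einsum("hk,khd->hd", attn, vh).reshape(num_hidden)   # concatenate heads

# Reference: one head at a time with plain matrix products
ref = []
for h in range(num_heads):
    s = slice(h * d, (h + 1) * d)
    w = torch.softmax(keys[:, s] @ q_i[s] / math.sqrt(d), dim=0)  # [K]
    ref.append(w @ values[:, s])                                   # [d]
ref = torch.cat(ref)
print(torch.allclose(out, ref, atol=1e-5))
```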

In [None]:
class NeighborAttention(nn.Module):
    """
    Multi-head attention over graph neighbors.
    
    This is the core mechanism that allows nodes to aggregate
    information from their neighbors adaptively.
    """
    def __init__(self, num_hidden, num_in, num_heads=4):
        super().__init__()
        self.num_heads = num_heads
        self.num_hidden = num_hidden
        
        # Linear transformations for Q, K, V
        self.W_Q = nn.Linear(num_hidden, num_hidden, bias=False)  # Query
        self.W_K = nn.Linear(num_in, num_hidden, bias=False)      # Key
        self.W_V = nn.Linear(num_in, num_hidden, bias=False)      # Value
        self.W_O = nn.Linear(num_hidden, num_hidden, bias=False)  # Output
    
    def forward(self, h_V, h_E, mask_attend=None):
        """
        Args:
            h_V: Node features [B, N, num_hidden]
            h_E: Edge features [B, N, K, num_in]
            mask_attend: Attention mask [B, N, K] (optional)
        
        Returns:
            Updated node features [B, N, num_hidden]
        """
        batch_size, n_nodes, n_neighbors = h_E.shape[:3]
        n_heads = self.num_heads
        d = self.num_hidden // n_heads  # Dimension per head
        
        # === Step 1: Compute Q, K, V ===
        # Query: from current node (broadcast to all neighbors)
        Q = self.W_Q(h_V)  # [B, N, num_hidden]
        Q = Q.view(batch_size, n_nodes, 1, n_heads, 1, d)
        
        # Keys: from edge features
        K = self.W_K(h_E)  # [B, N, K, num_hidden]
        K = K.view(batch_size, n_nodes, n_neighbors, n_heads, d, 1)
        
        # Values: from edge features
        V = self.W_V(h_E)  # [B, N, K, num_hidden]
        V = V.view(batch_size, n_nodes, n_neighbors, n_heads, d)
        
        # === Step 2: Compute attention scores ===
        # Attention logits: QK^T
        attend_logits = torch.matmul(Q, K)  # [B, N, K, n_heads, 1, 1]
        attend_logits = attend_logits.view(batch_size, n_nodes, n_neighbors, n_heads)
        attend_logits = attend_logits.transpose(-2, -1)  # [B, N, n_heads, K]
        
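        # Scale by √d so logits have roughly unit variance and the softmax does not saturate
        attend_logits = attend_logits / np.sqrt(d)
        
        # === Step 3: Apply attention mask (if provided) ===
        if mask_attend is not None:
            # Expand mask for multiple heads: [B, N, K] -> [B, N, n_heads, K]
            mask = mask_attend.unsqueeze(2).expand(-1, -1, n_heads, -1)
            # Use a large finite negative value rather than -inf: a row with no valid
            # neighbors would otherwise give softmax(-inf, ..., -inf) = NaN
            attend_logits = attend_logits.masked_fill(
                mask <= 0, torch.finfo(attend_logits.dtype).min
            )
        
        # === Step 4: Softmax to get attention weights ===
        attend = F.softmax(attend_logits, dim=-1)  # [B, N, n_heads, K]
        if mask_attend is not None:
            # Zero out masked weights (fully masked rows become all zeros)
            attend = mask * attend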
        
        # === Step 5: Weighted sum of values ===
        # attend: [B, N, n_heads, K] → [B, N, n_heads, 1, K]
        # V: [B, N, K, n_heads, d] → [B, N, n_heads, K, d]
        h_V_update = torch.matmul(
            attend.unsqueeze(-2),  # [B, N, n_heads, 1, K]
            V.transpose(2, 3)       # [B, N, n_heads, K, d]
        )  # [B, N, n_heads, 1, d]
        
        # Reshape back to [B, N, num_hidden]
        h_V_update = h_V_update.view(batch_size, n_nodes, self.num_hidden)
        
        # === Step 6: Output projection ===
        h_V_update = self.W_O(h_V_update)
        
        return h_V_update

# Test the attention mechanism
attention = NeighborAttention(num_hidden=32, num_in=48, num_heads=4)

# Create dummy inputs
h_V_test = torch.randn(1, 10, 32)   # 10 nodes, 32 features
h_E_test = torch.randn(1, 10, 5, 48)  # Each node has 5 neighbors, 48 edge features

# Apply attention
h_V_updated = attention(h_V_test, h_E_test)
print(f"Input node features shape: {h_V_test.shape}")
print(f"Edge features shape: {h_E_test.shape}")
print(f"Output node features shape: {h_V_updated.shape}")
print(f"\nAttention successfully updated node features!")

### Understanding Multi-Head Attention

**Why multiple heads?**
- Different heads can attend to different aspects of the structure
- Head 1 might focus on nearby residues
- Head 2 might focus on specific structural motifs
- Head 3 might focus on long-range contacts

**Visualization:**

In [None]:
def visualize_attention_pattern(attention_weights, head_idx=0):
    """
    Visualize attention pattern for a specific head.
    
    Args:
        attention_weights: [B, N, num_heads, K]
        head_idx: Which attention head to visualize
    """
    # Extract attention for specific head
    attn = attention_weights[0, :, head_idx, :].detach().numpy()
    
    plt.figure(figsize=(10, 6))
    plt.imshow(attn, aspect='auto', cmap='viridis', interpolation='nearest')
    plt.colorbar(label='Attention Weight')
    plt.xlabel('Neighbor Index')
    plt.ylabel('Residue')
    plt.title(f'Attention Pattern (Head {head_idx})')
    plt.tight_layout()
    plt.show()

print("Note: Actual attention weights are computed inside the forward pass.")
print("In practice, you would extract them during model execution.")
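
`NeighborAttention` above does not return its weights. One way to produce input for `visualize_attention_pattern` is to recompute them from the layer's own `W_Q` and `W_K`. The standalone sketch below does this with freshly initialized `nn.Linear` layers of the same sizes as the earlier test (hypothetical stand-ins, so it runs on its own); with a trained model you would pass `attention.W_Q` and `attention.W_K` instead. The helper `attention_weights` is a hypothetical name.

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

def attention_weights(W_Q, W_K, h_V, h_E, num_heads, mask_attend=None):
    """Recompute neighbor-attention weights [B, N, num_heads, K] (same math as NeighborAttention)."""
    B, N, K = h_E.shape[:3]
    d = W_Q.out_features // num_heads
    Q = W_Q(h_V).view(B, N, num_heads, d)            # one query per node
    Kt = W_K(h_E).view(B, N, K, num_heads, d)        # one key per neighbor
    logits = torch.einsum("bnhd,bnkhd->bnhk", Q, Kt) / math.sqrt(d)
    if mask_attend is not None:
        mask = mask_attend.unsqueeze(2)              # [B, N, 1, K], broadcast over heads
        logits = logits.masked_fill(mask <= 0, torch.finfo(logits.dtype).min)
        return mask * F.softmax(logits, dim=-1)
    return F.softmax(logits, dim=-1)

torch.manual_seed(0)
num_hidden, num_in, num_heads = 32, 48, 4
W_Q = nn.Linear(num_hidden, num_hidden, bias=False)  # stand-in for attention.W_Q
W_K = nn.Linear(num_in, num_hidden, bias=False)      # stand-in for attention.W_K
h_V_demo = torch.randn(1, 10, num_hidden)
h_E_demo = torch.randn(1, 10, 5, num_in)

with torch.no_grad():
    weights = attention_weights(W_Q, W_K, h_V_demo, h_E_demo, num_heads)
print(weights.shape)  # torch.Size([1, 10, 4, 5])
# visualize_attention_pattern(weights, head_idx=0)
```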

---
<a id='section5'></a>
## 5. Part 4: Transformer Layers

A complete Transformer layer combines attention with feed-forward networks and normalization.

**Architecture:**
```
Input → Attention → Add & Norm → Feed-Forward → Add & Norm → Output
        ↓                           ↓
     Residual                   Residual
```

**Reference:** `self_attention.py:60-101`
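Written out, `TransformerLayer` below applies a post-norm update (normalization after each residual sum). With $h_E$ the edge input passed to the layer, dropout applied to each sublayer output, and a feed-forward network of inner width 4 × `num_hidden`:

$$
\begin{aligned}
h_V' &= \mathrm{Norm}\big(h_V + \mathrm{Dropout}(\mathrm{Attention}(h_V, h_E))\big) \\
h_V'' &= \mathrm{Norm}\big(h_V' + \mathrm{Dropout}(\mathrm{FFN}(h_V'))\big) \\
\mathrm{FFN}(x) &= W_{\mathrm{out}}\,\mathrm{ReLU}(W_{\mathrm{in}}\,x + b_{\mathrm{in}}) + b_{\mathrm{out}}
\end{aligned}
$$

When a node mask is given, $h_V''$ is finally multiplied by it, which zeroes padded residues.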

In [None]:
class Normalize(nn.Module):
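    """Layer normalization with learnable gain and bias (uses the unbiased variance, so it differs slightly from nn.LayerNorm)."""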
    def __init__(self, features, epsilon=1e-6):
        super().__init__()
        self.gain = nn.Parameter(torch.ones(features))
        self.bias = nn.Parameter(torch.zeros(features))
        self.epsilon = epsilon
    
    def forward(self, x, dim=-1):
        # Compute mean and std
        mu = x.mean(dim, keepdim=True)
        sigma = torch.sqrt(x.var(dim, keepdim=True) + self.epsilon)
        
        # Normalize
        return self.gain * (x - mu) / sigma + self.bias

class PositionWiseFeedForward(nn.Module):
    """Feed-forward network applied to each position."""
    def __init__(self, num_hidden, num_ff):
        super().__init__()
        self.W_in = nn.Linear(num_hidden, num_ff)
        self.W_out = nn.Linear(num_ff, num_hidden)
    
    def forward(self, h):
        return self.W_out(F.relu(self.W_in(h)))

class TransformerLayer(nn.Module):
    """
    Complete Transformer layer with attention and feed-forward.
    
    Architecture:
    1. Multi-head attention
    2. Residual connection + Layer norm
    3. Position-wise feed-forward
    4. Residual connection + Layer norm
    """
    def __init__(self, num_hidden, num_in, num_heads=4, dropout=0.1):
        super().__init__()
        self.num_heads = num_heads
        self.num_hidden = num_hidden
        self.dropout = nn.Dropout(dropout)
        
        # Two normalization layers
        self.norm = nn.ModuleList([Normalize(num_hidden) for _ in range(2)])
        
        # Attention and feed-forward
        self.attention = NeighborAttention(num_hidden, num_in, num_heads)
        self.dense = PositionWiseFeedForward(num_hidden, num_hidden * 4)
    
    def forward(self, h_V, h_E, mask_V=None, mask_attend=None):
        """
        Args:
            h_V: Node features [B, N, num_hidden]
            h_E: Edge features [B, N, K, num_in]
            mask_V: Node mask [B, N] (optional)
            mask_attend: Attention mask [B, N, K] (optional)
        
        Returns:
            Updated node features [B, N, num_hidden]
        """
        # === Step 1: Self-attention ===
        dh = self.attention(h_V, h_E, mask_attend)
        
        # Residual connection + normalization
        h_V = self.norm[0](h_V + self.dropout(dh))
        
        # === Step 2: Feed-forward ===
        dh = self.dense(h_V)
        
        # Residual connection + normalization
        h_V = self.norm[1](h_V + self.dropout(dh))
        
        # === Step 3: Apply node mask ===
        if mask_V is not None:
            mask_V = mask_V.unsqueeze(-1)
            h_V = mask_V * h_V
        
        return h_V

# Test the transformer layer
transformer = TransformerLayer(num_hidden=32, num_in=48, num_heads=4, dropout=0.1)

h_V_test = torch.randn(2, 10, 32)     # 2 proteins, 10 residues each
h_E_test = torch.randn(2, 10, 5, 48)  # 5 neighbors per residue

h_V_out = transformer(h_V_test, h_E_test)
print(f"Input shape: {h_V_test.shape}")
print(f"Output shape: {h_V_out.shape}")
print(f"\nTransformer layer executed successfully!")

### Why Residual Connections?

Residual connections (skip connections) are crucial:

- **Without residual:** `h_new = Transform(h_old)`
- **With residual:** `h_new = h_old + Transform(h_old)`

**Benefits:**
1. Gradients reach earlier layers directly through the identity path, which mitigates vanishing gradients in deep stacks
2. Each layer only needs to learn an incremental refinement of its input rather than a complete re-encoding
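A toy experiment makes the gradient argument concrete. This sketch is not part of Struct2Seq: it stacks 20 `tanh` layers whose weight matrices are rescaled to spectral norm 0.25 (an assumption chosen so the bounds are guaranteed) and compares the gradient reaching the input. Without residuals each layer's Jacobian has norm at most 0.25, so the gradient shrinks at least as fast as 0.25²⁰. With residuals each Jacobian is I + (small term), whose smallest singular value is at least 0.75, so the gradient stays at least 0.75²⁰ times its starting size.

```python
import torch

torch.manual_seed(0)
depth, width = 20, 16

# Random layer weights, each rescaled to spectral norm 0.25
layer_weights = []
for _ in range(depth):
    W = torch.randn(width, width)
    layer_weights.append(0.25 * W / torch.linalg.norm(W, ord=2))

def input_grad_norm(residual):
    """Norm of d(sum of outputs)/d(input) through `depth` tanh layers."""
    h0 = torch.randn(width, requires_grad=True)
    h = h0
    for W in layer_weights:
        f = torch.tanh(W @ h)
        h = h + f if residual else f   # h_new = h_old + T(h_old)  vs  h_new = T(h_old)
    h.sum().backward()
    return h0.grad.norm().item()

plain, resid = input_grad_norm(False), input_grad_norm(True)
print(f"plain: {plain:.2e}   residual: {resid:.2e}")
```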

---
<a id='section6'></a>
## 6. Part 5: The Complete Struct2Seq Model

Now we assemble everything into the complete model!

**Architecture Overview:**
```
Input: 3D Structure (X) and Sequence (S)
    ↓
1. Featurization → Node features (V), Edge features (E)
    ↓
2. Encoder (3 layers) → Encode structure
    ↓
3. Decoder (3 layers) → Generate sequence autoregressively
    ↓
Output: Log probabilities over amino acids
```

**Reference:** `struct2seq/struct2seq.py:14-219`
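Before reading the model code, it helps to see what the autoregressive decoder mask looks like on a toy neighbor list. The sketch below (hand-written `E_idx_toy`, hypothetical) applies the rule used by `_autoregressive_mask`: neighbor j is visible to residue i only if j < i, which also hides residue i from itself.

```python
import torch

def autoregressive_mask(E_idx):
    """mask[b, i, k] = 1.0 if neighbor E_idx[b, i, k] precedes residue i in the sequence."""
    n_nodes = E_idx.size(1)
    ii = torch.arange(n_nodes, device=E_idx.device).view(1, -1, 1)  # current index i
    return (E_idx < ii).float()

# Toy neighbor lists for 4 residues, 3 neighbors each (self included, as in k-NN)
E_idx_toy = torch.tensor([[[0, 1, 2],
                           [1, 0, 3],
                           [2, 3, 0],
                           [3, 2, 1]]])
print(autoregressive_mask(E_idx_toy)[0])
```

Row 0 is all zeros: if attention were restricted to this mask alone, residue 0 would have no visible neighbors, so the attention softmax must cope with an all-zero mask row without producing NaNs.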

In [None]:
class SimpleProteinFeatures(nn.Module):
    """
    Simplified protein featurization for demonstration.
    
    Uses only dihedral angles (nodes) and RBF distances plus positional
    encodings (edges). The full implementation in protein_features.py
    computes a richer feature set, including relative orientations.
    """
    def __init__(self, node_features, edge_features, top_k=30):
        super().__init__()
        self.top_k = top_k
        self.node_embedding = nn.Linear(6, node_features)  # 6D dihedral features
        self.edge_embedding = nn.Linear(32, edge_features)  # 16 RBF + 16 positional channels
    
    def forward(self, X, mask):
        """
        Args:
            X: Coordinates [B, N, 4, 3] (N, CA, C, O)
            mask: Valid residues [B, N]
        
        Returns:
            V: Node features [B, N, node_features]
            E: Edge features [B, N, K, edge_features]
            E_idx: Edge indices [B, N, K]
        """
        # Extract CA coordinates
        X_ca = X[:, :, 1, :]  # [B, N, 3]
        
        # Compute distances and build k-NN graph
        D, mask_2D = compute_pairwise_distances(X_ca, mask)
        D_neighbors, E_idx = build_knn_graph(D, mask_2D, k=self.top_k)
        
        # Node features: dihedral angles
        V = compute_dihedrals(X)
        V = self.node_embedding(V)
        
        # Edge features: RBF + positional encodings
        rbf = rbf_encoding(D_neighbors, num_rbf=16)
        pos = positional_encodings(E_idx, num_embeddings=16)
        E = torch.cat([rbf, pos], dim=-1)
        E = self.edge_embedding(E)
        
        return V, E, E_idx

# Test featurization
featurizer = SimpleProteinFeatures(node_features=32, edge_features=48, top_k=10)
V, E, E_idx = featurizer(X_backbone, mask_example)
print(f"Node features: {V.shape}")
print(f"Edge features: {E.shape}")
print(f"Edge indices: {E_idx.shape}")

In [None]:
class Struct2Seq(nn.Module):
    """
    Complete Struct2Seq model for protein design.
    
    The model consists of:
    1. Featurization: Convert 3D structure to graph
    2. Encoder: Process structure information
    3. Decoder: Generate sequence autoregressively
    """
    def __init__(self, 
                 num_letters=20,      # Amino acid vocabulary size
                 node_features=128,   # Node feature dimension
                 edge_features=128,   # Edge feature dimension
                 hidden_dim=128,      # Hidden dimension
                 num_encoder_layers=3,
                 num_decoder_layers=3,
                 num_heads=4,
                 k_neighbors=30,
                 dropout=0.1):
        super().__init__()
        
        self.hidden_dim = hidden_dim
        
        # === 1. Featurization ===
        self.features = SimpleProteinFeatures(
            node_features, edge_features, top_k=k_neighbors
        )
        
        # === 2. Embeddings ===
        # Node and edge embeddings
        self.W_v = nn.Linear(node_features, hidden_dim)
        self.W_e = nn.Linear(edge_features, hidden_dim)
        # Sequence embedding
        self.W_s = nn.Embedding(num_letters, hidden_dim)
        
        # === 3. Encoder layers ===
        # Process structural information (unmasked)
        self.encoder_layers = nn.ModuleList([
            TransformerLayer(hidden_dim, hidden_dim * 2, num_heads, dropout)
            for _ in range(num_encoder_layers)
        ])
        
        # === 4. Decoder layers ===
        # Generate sequence autoregressively (masked)
        self.decoder_layers = nn.ModuleList([
            TransformerLayer(hidden_dim, hidden_dim * 3, num_heads, dropout)
            for _ in range(num_decoder_layers)
        ])
        
        # === 5. Output projection ===
        self.W_out = nn.Linear(hidden_dim, num_letters)
        
        # Initialize parameters
        self._init_params()
    
    def _init_params(self):
        """Xavier initialization for better training."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
    
    def _autoregressive_mask(self, E_idx):
        """
        Create mask for autoregressive decoding.
        
        Residue i may use the sequence identity of its neighbor j only
        if j < i, i.e. j comes earlier in the decoding order.
        
        Args:
            E_idx: Neighbor indices [B, N, K]
        
        Returns:
            Mask [B, N, K] where mask[b, i, k] = 1 if E_idx[b, i, k] < i
        """
        n_nodes = E_idx.size(1)
        ii = torch.arange(n_nodes, device=E_idx.device).view(1, -1, 1)
        
        # mask = 1 if neighbor_idx < current_idx
        mask = (E_idx - ii < 0).float()
        
        return mask
    
    def forward(self, X, S, mask):
        """
        Forward pass of Struct2Seq.
        
        Args:
            X: Coordinates [B, N, 4, 3]
            S: Sequence [B, N] (amino acid indices)
            mask: Valid residues [B, N]
        
        Returns:
            Log probabilities [B, N, 20]
        """
        # === Step 1: Featurization ===
        V, E, E_idx = self.features(X, mask)
        h_V = self.W_v(V)
        h_E = self.W_e(E)
        
        # === Step 2: Encoder (unmasked attention) ===
        # Build attention mask (attend to all valid neighbors)
        mask_attend = gather_nodes(mask.unsqueeze(-1), E_idx).squeeze(-1)
        mask_attend = mask.unsqueeze(-1) * mask_attend
        
        for layer in self.encoder_layers:
            # Combine edge features with node features
            h_EV = cat_neighbors_nodes(h_V, h_E, E_idx)
            h_V = layer(h_V, h_EV, mask_V=mask, mask_attend=mask_attend)
        
        # === Step 3: Prepare decoder ===
        # Embed the sequence
        h_S = self.W_s(S)
        h_ES = cat_neighbors_nodes(h_S, h_E, E_idx)
        
        # Structure-only features (zero sequence embedding) for neighbors not yet decoded
        h_ES_encoder = cat_neighbors_nodes(torch.zeros_like(h_S), h_E, E_idx)
        h_ESV_encoder = cat_neighbors_nodes(h_V, h_ES_encoder, E_idx)
        
        # === Step 4: Decoder (autoregressive) ===
        # Create autoregressive mask
        mask_ar = self._autoregressive_mask(E_idx).unsqueeze(-1)
        mask_1D = mask.view(mask.size(0), mask.size(1), 1, 1)
        mask_bw = mask_1D * mask_ar  # Backward (already generated)
        mask_fw = mask_1D * (1 - mask_ar)  # Forward (encoder info)
        
        h_ESV_encoder_fw = mask_fw * h_ESV_encoder
        
        for layer in self.decoder_layers:
            # Combine structure + sequence information
            h_ESV = cat_neighbors_nodes(h_V, h_ES, E_idx)
            # Mask: attend backward to generated sequence,
            # forward to structure
            h_ESV = mask_bw * h_ESV + h_ESV_encoder_fw
            h_V = layer(h_V, h_ESV, mask_V=mask)
        
        # === Step 5: Output projection ===
        logits = self.W_out(h_V)
        log_probs = F.log_softmax(logits, dim=-1)
        
        return log_probs

# Create the model
model = Struct2Seq(
    num_letters=20,
    node_features=32,
    edge_features=48,
    hidden_dim=32,
    num_encoder_layers=2,
    num_decoder_layers=2,
    num_heads=4,
    k_neighbors=10,
    dropout=0.1
)

print("Struct2Seq model created!")
print(f"Number of parameters: {sum(p.numel() for p in model.parameters()):,}")

In [None]:
# Test the complete model
# Create dummy input
batch_size = 2
n_residues = 20

X_test = torch.randn(batch_size, n_residues, 4, 3)  # Coordinates
S_test = torch.randint(0, 20, (batch_size, n_residues))  # Sequence
mask_test = torch.ones(batch_size, n_residues)  # All valid

# Forward pass
with torch.no_grad():
    log_probs = model(X_test, S_test, mask_test)

print(f"Input shape: {X_test.shape}")
print(f"Output log probabilities shape: {log_probs.shape}")
print(f"Expected: [batch_size={batch_size}, n_residues={n_residues}, vocab_size=20]")
print(f"\nModel forward pass successful!")

### Understanding the Autoregressive Mask

The decoder uses **autoregressive masking** to generate sequences:

```
Position:  0  1  2  3  4
           ↓  ↓  ↓  ↓  ↓
Gen 0:    [?  ?  ?  ?  ?]  → Predict position 0
Gen 1:    [A  ?  ?  ?  ?]  → Predict position 1 (seeing A)
Gen 2:    [A  C  ?  ?  ?]  → Predict position 2 (seeing A, C)
Gen 3:    [A  C  D  ?  ?]  → Predict position 3 (seeing A, C, D)
...
```

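**Key insight:** when predicting residue i, the decoder sees the amino-acid identities of only those k-NN neighbors j with j < i; all other neighbors (including residue i itself) contribute structure features only. During training the true sequence is fed in all at once, so this mask is what prevents "cheating", i.e. reading the very residue the model is asked to predict.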
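To see how this rule plays out on a k-NN graph (rather than on the full sequence), here is a small standalone sketch. It re-implements the same comparison as `_autoregressive_mask` on a hand-written, hypothetical neighbor list for a 4-residue chain. We assume, as is typical for a k-NN search on distances, that each residue appears as its own nearest neighbor.

```python
import torch

def autoregressive_mask(E_idx: torch.Tensor) -> torch.Tensor:
    """Same rule as Struct2Seq._autoregressive_mask:
    mask[b, i, k] = 1 if neighbor E_idx[b, i, k] precedes residue i."""
    n_nodes = E_idx.size(1)
    ii = torch.arange(n_nodes, device=E_idx.device).view(1, -1, 1)
    return (E_idx < ii).float()

# Hypothetical k-NN lists (K = 3) for a 4-residue chain, nearest first.
# Each residue is assumed to be its own nearest neighbor (distance 0).
E_idx = torch.tensor([[[0, 1, 3],
                       [1, 0, 2],
                       [2, 3, 1],
                       [3, 0, 2]]])

mask_bw = autoregressive_mask(E_idx)  # 1: neighbor already decoded, residue type visible
mask_fw = 1.0 - mask_bw               # 1: neighbor not yet decoded (or self), structure only
print(mask_bw[0].tolist())
# [[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [0.0, 1.0, 1.0]]
```

Residue 0 is predicted from structure alone, while residue 3 sees the identities of residues 0 and 2 but never its own. How many identities a residue sees depends on which earlier residues happen to be spatially close, not simply on i.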

In [None]:
# Visualize autoregressive mask
def visualize_ar_mask():
    """Show the autoregressive masking pattern."""
    n = 15
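    # Toy neighbor list where every residue lists all n positions in order,
    # so the [1, n, K=n] mask reads as an n x n matrix
    E_idx = torch.arange(n).unsqueeze(0).unsqueeze(0).expand(1, n, n)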
    mask = model._autoregressive_mask(E_idx)
    
    plt.figure(figsize=(8, 8))
    plt.imshow(mask[0].numpy(), cmap='Blues', interpolation='nearest')
    plt.colorbar(label='Can Attend')
    plt.xlabel('Neighbor Position')
    plt.ylabel('Current Position')
    plt.title('Autoregressive Attention Mask\n(1 = can attend, 0 = masked)')
    
    # Annotate the top-left 5x5 corner with the mask values
    for i in range(min(5, n)):
        for j in range(min(5, n)):
            plt.text(j, i, int(mask[0, i, j].item()),
                     ha="center", va="center", color="black", fontsize=10)
    
    plt.tight_layout()
    plt.show()

visualize_ar_mask()

---
<a id='section7'></a>
## 7. Part 6: Training and Inference

### Training the Model

**Loss Function:** Negative log-likelihood
```
L = -Σᵢ log P(sᵢ | structure, s₁, ..., sᵢ₋₁)
```

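where sᵢ is the true amino acid at position i. `compute_loss` below averages this sum over the valid (unmasked) residues, and perplexity is the exponential of that average negative log-likelihood.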
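A small standalone check of these definitions, using toy tensors rather than model output. The sketch computes the masked mean NLL the same way `compute_loss` does (the helper name `masked_nll_and_perplexity` is ours) for a uniform "model" and for one that always gives the true residue probability 0.25. Since perplexity is exp(mean NLL), i.e. the inverse geometric mean of the probabilities assigned to the true residues, the two cases should come out at 20 and 4.

```python
import math
import torch
import torch.nn.functional as F

def masked_nll_and_perplexity(log_probs, S_true, mask):
    """Mean NLL over valid positions, and perplexity = exp(mean NLL)."""
    vocab = log_probs.size(-1)
    nll = F.nll_loss(log_probs.reshape(-1, vocab), S_true.reshape(-1),
                     reduction="none").view(S_true.shape)
    loss = (nll * mask).sum() / mask.sum()
    return loss, torch.exp(loss)

vocab, B, N = 20, 1, 5
S_true = torch.randint(0, vocab, (B, N))
mask = torch.tensor([[1.0, 1.0, 1.0, 0.0, 0.0]])  # last two positions are padding

# Case 1: identical logits everywhere -> uniform over the 20 amino acids
uniform = F.log_softmax(torch.zeros(B, N, vocab), dim=-1)
loss_u, ppl_u = masked_nll_and_perplexity(uniform, S_true, mask)

# Case 2: the true residue always gets probability 0.25; the remaining
# 0.75 is spread evenly over the other 19 letters
probs = torch.full((B, N, vocab), 0.75 / (vocab - 1))
probs.scatter_(-1, S_true.unsqueeze(-1), 0.25)
loss_q, ppl_q = masked_nll_and_perplexity(probs.log(), S_true, mask)

print(f"uniform:      loss = {loss_u.item():.4f} (log 20 = {math.log(20):.4f}), perplexity = {ppl_u.item():.2f}")
print(f"p(true)=0.25: loss = {loss_q.item():.4f} (log 4 = {math.log(4):.4f}), perplexity = {ppl_q.item():.2f}")
```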

In [None]:
def compute_loss(log_probs, S_true, mask):
    """
    Compute negative log-likelihood loss.
    
    Args:
        log_probs: Model predictions [B, N, 20]
        S_true: True sequence [B, N]
        mask: Valid positions [B, N]
    
    Returns:
        loss: Scalar loss value
        perplexity: Perplexity metric
    """
    criterion = nn.NLLLoss(reduction='none')
    
    # Compute loss per position
    loss = criterion(
        log_probs.contiguous().view(-1, log_probs.size(-1)),
        S_true.contiguous().view(-1)
    ).view(S_true.size())
    
    # Mask invalid positions
    loss = loss * mask
    
    # Average over valid positions
    loss_avg = loss.sum() / mask.sum()
    
    # Compute perplexity
    perplexity = torch.exp(loss_avg)
    
    return loss_avg, perplexity

# Test loss computation
loss, perplexity = compute_loss(log_probs, S_test, mask_test)
print(f"Loss: {loss.item():.4f}")
print(f"Perplexity: {perplexity.item():.4f}")
print(f"\nPerplexity interpretation:")
print(f"- Random baseline: 20.0 (uniform over 20 amino acids)")
print(f"- Trained structure-based models: single digits (Ingraham et al., 2019)")
print(f"- Perfect model: 1.0")

In [None]:
def training_step_example(model, X, S, mask, optimizer):
    """
    Example training step.
    
    In practice, you would:
    1. Load batches from DataLoader
    2. Repeat for many epochs
    3. Validate on held-out set
    4. Save best model
    """
    model.train()  # make sure dropout is active (e.g. after model.eval() was called)
    
    # Forward pass
    log_probs = model(X, S, mask)
    
    # Compute loss
    loss, perplexity = compute_loss(log_probs, S, mask)
    
    # Backward pass
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    return loss.item(), perplexity.item()

# Example: one training step
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss, perplexity = training_step_example(model, X_test, S_test, mask_test, optimizer)
print("One training step (metrics from the forward pass before the update):")
print(f"Loss: {loss:.4f}")
print(f"Perplexity: {perplexity:.4f}")

### Sequence Generation (Sampling)

To design new sequences, we sample from the model autoregressively:

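```python
# Straightforward (but slow) sampler: rerun the full model once per position.
# The autoregressive mask ensures the prediction at position i depends only on
# S[:, :i], so the not-yet-sampled entries of S can hold any value.
model.eval()  # disable dropout
S = torch.zeros(batch_size, n_residues, dtype=torch.long)
with torch.no_grad():
    for i in range(n_residues):
        log_probs = model(X, S, mask)  # [B, N, 20]
        # softmax(log_probs / T) == softmax(logits / T): log_softmax only shifts by a constant
        probs = F.softmax(log_probs[:, i] / temperature, dim=-1)
        S[:, i] = torch.multinomial(probs, 1).squeeze(-1)  # sample residue i
```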

In [None]:
def sample_sequence(model, X, mask, temperature=1.0):
    """
    Sample a protein sequence for a given structure.
    
    Args:
        model: Trained Struct2Seq model
        X: Structure coordinates [B, N, 4, 3]
        mask: Valid residues [B, N]
        temperature: Sampling temperature (higher = more random)
    
    Returns:
        S: Sampled sequence [B, N]
    """
    model.eval()
    
    batch_size = X.size(0)
    n_residues = X.size(1)
    
    # Initialize empty sequence
    S = torch.zeros(batch_size, n_residues, dtype=torch.long, device=X.device)
    
    with torch.no_grad():
        # Featurize structure (done once)
        V, E, E_idx = model.features(X, mask)
        h_V_encoder = model.W_v(V)
        h_E = model.W_e(E)
        
        # Encode structure
        mask_attend = gather_nodes(mask.unsqueeze(-1), E_idx).squeeze(-1)
        mask_attend = mask.unsqueeze(-1) * mask_attend
        
        for layer in model.encoder_layers:
            h_EV = cat_neighbors_nodes(h_V_encoder, h_E, E_idx)
            h_V_encoder = layer(h_V_encoder, h_EV, mask_V=mask, 
                               mask_attend=mask_attend)
        
        # Decoder inputs that do not change between steps:
        # structure-only neighbor features and the autoregressive masks
        h_ES_encoder = cat_neighbors_nodes(torch.zeros_like(h_V_encoder), h_E, E_idx)
        h_ESV_encoder = cat_neighbors_nodes(h_V_encoder, h_ES_encoder, E_idx)
        mask_ar = model._autoregressive_mask(E_idx).unsqueeze(-1)
        mask_1D = mask.view(mask.size(0), mask.size(1), 1, 1)
        mask_bw = mask_1D * mask_ar
        mask_fw = mask_1D * (1 - mask_ar)
        h_ESV_encoder_fw = mask_fw * h_ESV_encoder
        
        # Sample sequence position by position
        for t in range(n_residues):
            # Embed the current sequence; entries at positions >= t are
            # placeholders, and mask_bw hides them from position t
            h_S = model.W_s(S)
            h_ES = cat_neighbors_nodes(h_S, h_E, E_idx)
            
            # Decode (rerun all decoder layers with the updated prefix)
            h_V = h_V_encoder.clone()
            for layer in model.decoder_layers:
                h_ESV = cat_neighbors_nodes(h_V, h_ES, E_idx)
                h_ESV = mask_bw * h_ESV + h_ESV_encoder_fw
                h_V = layer(h_V, h_ESV, mask_V=mask)
            
            # Predict position t
            logits_t = model.W_out(h_V[:, t, :])
            probs_t = F.softmax(logits_t / temperature, dim=-1)
            
            # Sample
            S[:, t] = torch.multinomial(probs_t, 1).squeeze(-1)
    
    return S

# Example: Sample a sequence
amino_acids = 'ACDEFGHIKLMNPQRSTVWY'

# sample_sequence already runs under torch.no_grad()
S_sampled = sample_sequence(model, X_test[:1], mask_test[:1], temperature=1.0)

# Convert indices to one-letter codes
seq_str = ''.join(amino_acids[s] for s in S_sampled[0].tolist())
print(f"Sampled sequence: {seq_str}")
print(f"\nNote: This is from an untrained model, so it's random!")

### Temperature Sampling

**Temperature** controls randomness:

```
probs = softmax(logits / T)
```

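- **T = 0.1**: very peaked; as T → 0, sampling approaches greedy argmax decoding
- **T = 1.0**: samples from the model's distribution unchanged
- **T = 2.0**: flatter distribution, giving more random and diverse sequences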
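Two properties of the temperature formula can be checked numerically. The standalone sketch below uses the same toy logits as `demonstrate_temperature`. It shows that the top probability falls and the entropy rises as T grows, approaching log 5 for five classes. It also shows that dividing the model's log-probabilities by T gives the same distribution as dividing raw logits by T. That second point follows from softmax being invariant to adding a constant, and it is what lets a sampler apply temperature directly to the output of `F.log_softmax`.

```python
import math
import torch
import torch.nn.functional as F

logits = torch.tensor([2.0, 1.0, 0.5, 0.3, 0.1])  # toy logits from demonstrate_temperature
log_probs = F.log_softmax(logits, dim=-1)          # the kind of output Struct2Seq.forward returns

temperatures = [0.1, 0.5, 1.0, 2.0, 100.0]
top_probs, entropies = [], []
for T in temperatures:
    p = F.softmax(logits / T, dim=-1)
    # log_softmax only subtracts a constant, so tempering the model's
    # log-probabilities gives the same distribution as tempering raw logits
    assert torch.allclose(p, F.softmax(log_probs / T, dim=-1), atol=1e-6)
    top_probs.append(p.max().item())
    entropies.append(-(p * p.log()).sum().item())
    print(f"T = {T:>5}: p(top) = {top_probs[-1]:.3f}, entropy = {entropies[-1]:.3f} nats")

print(f"max entropy for 5 classes: log 5 = {math.log(5):.3f} nats")
```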

In [None]:
def demonstrate_temperature():
    """Show effect of temperature on sampling."""
    # Create example logits
    logits = torch.tensor([2.0, 1.0, 0.5, 0.3, 0.1])
    
    temperatures = [0.1, 0.5, 1.0, 2.0]
    
    plt.figure(figsize=(12, 4))
    for i, T in enumerate(temperatures):
        probs = F.softmax(logits / T, dim=0)
        
        plt.subplot(1, 4, i+1)
        plt.bar(range(5), probs.numpy())
        plt.title(f'Temperature = {T}')
        plt.xlabel('Amino Acid')
        plt.ylabel('Probability')
        plt.ylim(0, 1)
    
    plt.tight_layout()
    plt.show()

demonstrate_temperature()

---
## Summary and Key Takeaways

**What we learned:**

1. **Protein Representation**
   - Proteins as graphs with nodes (residues) and edges (spatial neighbors)
   - Features: distances, angles, orientations

2. **Model Architecture**
   - **Encoder**: Process structure with unmasked attention
   - **Decoder**: Generate sequence with autoregressive masking
   - **Attention**: Aggregate information from neighbors adaptively

3. **Key Components**
   - k-NN graph construction
   - RBF and positional encodings
   - Multi-head attention
   - Transformer layers with residual connections
   - Autoregressive generation

4. **Training and Inference**
   - Loss: Negative log-likelihood
   - Metric: Perplexity
   - Sampling: Temperature-controlled generation

**Next Steps:**
1. Train on real protein data
2. Evaluate on test structures
3. Generate diverse sequences
4. Validate designs experimentally

**References:**
- Paper: Ingraham et al., NeurIPS 2019
- Code: `struct2seq/` directory
- Tutorial: TUTORIAL.md

---
## Exercises

**Try these to deepen your understanding:**

1. **Modify k in k-NN graph**
   - How does changing k affect the graph connectivity?
   - Plot connectivity matrices for k=5, 10, 20, 30

2. **Visualize attention weights**
   - Extract attention weights during forward pass
   - Visualize which residues attend to which
   - Compare different attention heads

3. **Implement MPNN variant**
   - Replace TransformerLayer with MPNNLayer
   - Compare performance

4. **Sequence recovery analysis**
   - Load real protein structure
   - Generate sequences
   - Compute sequence recovery (% of positions matching the native sequence)

5. **Feature ablation study**
   - Remove different features (RBF, dihedrals, etc.)
   - Measure impact on perplexity
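As a starting point for exercise 4, a minimal sketch of the recovery metric, assuming sequences are integer-encoded with the same `amino_acids` ordering used earlier and that the mask is a float tensor. The helper names `sequence_recovery` and `encode` are ours, and both sequences are invented for illustration.

```python
import torch

def sequence_recovery(S_pred: torch.Tensor, S_native: torch.Tensor,
                      mask: torch.Tensor) -> float:
    """Percentage of valid positions where the designed residue equals the native one."""
    matches = (S_pred == S_native).float() * mask
    return 100.0 * (matches.sum() / mask.sum()).item()

amino_acids = 'ACDEFGHIKLMNPQRSTVWY'

def encode(seq: str) -> torch.Tensor:
    """One-letter codes -> [1, N] tensor of indices into amino_acids."""
    return torch.tensor([[amino_acids.index(c) for c in seq]])

native = encode("MKTAYIAK")      # made-up sequences, for illustration only
designed = encode("MKSAYLAK")    # differs at positions 2 (T->S) and 5 (I->L)
mask = torch.ones(native.shape)  # no padding in this toy example
print(f"recovery = {sequence_recovery(designed, native, mask):.1f}%")  # 6 of 8 match -> 75.0%
```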

---
## Additional Resources

**Papers:**
- Ingraham et al., "Generative Models for Graph-Based Protein Design" (NeurIPS 2019)
- Vaswani et al., "Attention Is All You Need" (NeurIPS 2017)
- Gilmer et al., "Neural Message Passing for Quantum Chemistry" (ICML 2017)

**Code:**
- `struct2seq/struct2seq.py` - Complete model implementation
- `struct2seq/protein_features.py` - Full featurization code
- `struct2seq/self_attention.py` - Attention mechanisms
- `experiments/train_s2s.py` - Training script

**Datasets:**
- CATH: Protein structure classification
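- SPIN2 benchmark sets: test sets from the SPIN2 sequence-profile prediction method (O'Connell et al., 2018), used for comparison in Ingraham et al.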

**Tools:**
- PyTorch: Deep learning framework
- PyMOL: Protein visualization
- BioPython: Protein structure parsing