# Stanford RNA 3D Folding Competition - Training Notebook

This notebook trains a model to predict 5 diverse 3D RNA structures per sequence.

## Competition Requirements:
- Predict 5 different structures per target (ensemble predictions)
- Output format: x_1, y_1, z_1, ..., x_5, y_5, z_5 for each residue
- Evaluation: TM-score (best of 5 predictions)
- Coordinates: C1' atom positions in Angstroms
- Runtime limit: 8 hours GPU for inference

## 1. Imports & Config

In [None]:
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm
import warnings
warnings.filterwarnings('ignore')

# Hyperparameters and runtime settings shared by the dataset, model and trainer.
CONFIG = dict(
    data_dir='../input/stanford-rna-3d-folding-2',
    max_len=512,          # fixed residue count every sequence is padded/truncated to
    batch_size=8,         # kept small because of the long padded sequences
    epochs=15,
    lr=5e-4,
    num_predictions=5,    # number of structures predicted per target
    embed_dim=256,
    nhead=8,
    num_layers=6,
    dropout=0.1,
    device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
    seed=42,
)

# Seed NumPy and PyTorch (CPU and, when present, every GPU) for repeatable runs.
np.random.seed(CONFIG['seed'])
torch.manual_seed(CONFIG['seed'])
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(CONFIG['seed'])

# Report the execution environment.
print(f"Running on {CONFIG['device']}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

## 2. Dataset Implementation

Key features:
- Handles variable sequence lengths (pad/truncate)
- Counts the reference structures (x_1, x_2, ...) present in the labels file, but trains on the first one only
- Properly maps sequences to their 3D coordinates
- Returns mask for valid positions

In [None]:
class RNADataset(Dataset):
    """Turn RNA sequences (and, when labels exist, 3D coordinates) into fixed-size tensors.

    Files read from ``data_path`` depend on ``mode``:

    * ``'train'``: ``train_sequences.csv`` and ``train_labels.csv``
    * ``'val'``:   ``validation_sequences.csv`` and ``validation_labels.csv``
    * anything else (test): ``test_sequences.csv`` only, no labels

    Args:
        data_path: Directory containing the CSV files.
        max_len: Length every sequence is padded or truncated to.
        mode: ``'train'``, ``'val'`` or any other value for test mode.
        num_predictions: How many times the single reference structure is
            replicated along the last axis of the target tensor.

    Items:
        With labels: ``(input_ids, target_tensor, mask, target_id)`` where
        ``target_tensor`` has shape ``(max_len, 3, num_predictions)``.
        Without labels: ``(input_ids, mask, target_id)``.
        ``input_ids`` is a long tensor of length ``max_len`` and ``mask`` is a
        bool tensor that is True for real residues and False for padding.
    """

    def __init__(self, data_path, max_len=512, mode='train', num_predictions=5):
        self.mode = mode
        self.max_len = max_len
        self.num_predictions = num_predictions
        self.base2int = {'A': 0, 'C': 1, 'G': 2, 'U': 3, 'N': 4}  # N for unknown
        # These keep their defaults in test mode so __getitem__ can tell that
        # no labels were loaded.
        self.coords_map = None
        self.num_structures = 0

        # Load sequences
        print(f"Loading {mode} sequences...")
        if mode == 'train':
            self.seq_df = pd.read_csv(os.path.join(data_path, 'train_sequences.csv'))
            lbl_df = pd.read_csv(os.path.join(data_path, 'train_labels.csv'))
        elif mode == 'val':
            self.seq_df = pd.read_csv(os.path.join(data_path, 'validation_sequences.csv'))
            lbl_df = pd.read_csv(os.path.join(data_path, 'validation_labels.csv'))
        else:  # test mode
            self.seq_df = pd.read_csv(os.path.join(data_path, 'test_sequences.csv'))
            lbl_df = None

        print(f"Loaded {len(self.seq_df)} sequences")

        if lbl_df is not None:
            self._build_coords_map(lbl_df)

            # Keep only the sequences that have labels
            self.seq_df = self.seq_df[self.seq_df['target_id'].isin(list(self.coords_map))]
            self.seq_df = self.seq_df.reset_index(drop=True)
            print(f"After filtering: {len(self.seq_df)} sequences with labels")

    def _build_coords_map(self, lbl_df):
        """Fill ``self.coords_map`` with one ``(num_residues, 3)`` array per target."""
        # ID column format: targetid_residuenumber; the target id may itself
        # contain underscores, so split on the last one only.
        lbl_df['target_id'] = lbl_df['ID'].apply(lambda x: x.rsplit('_', 1)[0])

        # One x_k column exists per reference structure (x_1, x_2, ...)
        self.num_structures = sum(col.startswith('x_') for col in lbl_df.columns)
        print(f"Found {self.num_structures} reference structures in labels")

        # Note: only the first structure (x_1, y_1, z_1) is used for training
        self.coords_map = {}
        print("Building coordinates map...")
        # A single groupby pass replaces re-filtering the whole labels frame
        # once per target; sort=False keeps first-appearance order.
        grouped = lbl_df.groupby('target_id', sort=False)
        n_targets = grouped.ngroups
        print(f"Processing {n_targets} unique targets...")

        for idx, (target_id, target_data) in enumerate(grouped):
            if idx % 500 == 0:
                print(f"  Processed {idx}/{n_targets} targets")

            target_data = target_data.sort_values('resid')
            coords = target_data[['x_1', 'y_1', 'z_1']].to_numpy(dtype=np.float32)
            # Missing coordinates become 0.0. They are NOT removed from the
            # mask returned by __getitem__, which depends on sequence length only.
            self.coords_map[target_id] = np.nan_to_num(coords, nan=0.0)

        print(f"Completed coordinates map for {len(self.coords_map)} targets")

    def __len__(self):
        return len(self.seq_df)

    def __getitem__(self, idx):
        row = self.seq_df.iloc[idx]
        target_id = row['target_id']
        seq_str = row['sequence']

        # Encode residues (anything unrecognised -> 4), then truncate / pad.
        seq_ids = [self.base2int.get(c, 4) for c in seq_str][:self.max_len]
        actual_len = len(seq_ids)
        seq_ids = seq_ids + [4] * (self.max_len - actual_len)  # pad with unknown token

        input_ids = torch.tensor(seq_ids, dtype=torch.long)
        mask = torch.zeros(self.max_len, dtype=torch.bool)
        mask[:actual_len] = True

        if self.coords_map is None:
            # Test mode: no labels available
            return input_ids, mask, target_id

        # Coordinates: (num_residues, 3), truncated to max_len and zero-padded
        coords = self.coords_map[target_id][:self.max_len]
        target_tensor = torch.zeros((self.max_len, 3), dtype=torch.float32)
        target_tensor[:len(coords)] = torch.tensor(coords, dtype=torch.float32)

        # Replicate the single reference so every prediction head sees the
        # same target: (max_len, 3) -> (max_len, 3, num_predictions)
        target_tensor = target_tensor.unsqueeze(-1).repeat(1, 1, self.num_predictions)

        return input_ids, target_tensor, mask, target_id

## 3. Model Architecture

**Key Design Decisions:**
- **Transformer Encoder**: Captures long-range dependencies in RNA sequences
- **Multi-head prediction**: Outputs 5 diverse structures using different prediction heads
- **Positional Encoding**: Helps model understand residue positions
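- **Ensemble Learning**: All heads are trained against the same reference structure, so any diversity between them comes from their independent random initialization and dropout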

In [None]:
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding for sequence position information.

    A table of shape ``(1, max_len, d_model)`` is computed once and stored as
    the non-trainable buffer ``pe``. Even feature columns hold sines and odd
    columns hold cosines of the position scaled by geometrically spaced
    frequencies.

    Args:
        d_model: Feature size of the embeddings the encoding is added to.
            Both even and odd sizes are supported.
        max_len: Largest sequence length the table covers.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even column, so
        # drop the last frequency; for even d_model this slice is a no-op.
        pe[:, 1::2] = torch.cos(angles[:, :d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(0))  # (1, max_len, d_model)

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

        ``x`` must be broadcastable against ``(1, seq_len, d_model)`` with
        ``seq_len <= max_len``.
        """
        return x + self.pe[:, :x.size(1)]


class RNAStructurePredictor(nn.Module):
    """
    Model to predict several 3D structures for RNA sequences

    Architecture:
    - Embedding layer for RNA nucleotides (index 4 is the padding index)
    - Positional encoding
    - Transformer encoder for sequence understanding
    - One independent MLP prediction head per output structure

    Args:
        vocab_size: Number of distinct token ids; must be at least 5 because
            index 4 is used as the padding index.
        embed_dim: Width of the token embeddings and of the encoder.
        nhead: Number of attention heads per encoder layer.
        num_layers: Number of stacked encoder layers.
        num_predictions: Number of prediction heads (structures per sequence).
        dropout: Dropout probability in the encoder and in each head.
        max_len: Longest sequence length the positional encoding supports.
    """
    def __init__(self, vocab_size=5, embed_dim=256, nhead=8, num_layers=6,
                 num_predictions=5, dropout=0.1, max_len=512):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_predictions = num_predictions

        # Embedding
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=4)
        self.pos_encoder = PositionalEncoding(embed_dim, max_len=max_len)

        # Transformer encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=nhead,
            dim_feedforward=embed_dim * 4,
            dropout=dropout,
            batch_first=True,
            norm_first=True  # Pre-LN for better training stability
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # One separate head per predicted structure; each maps an encoded
        # residue to its (x, y, z) coordinates.
        self.prediction_heads = nn.ModuleList([
            nn.Sequential(
                nn.Linear(embed_dim, embed_dim // 2),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(embed_dim // 2, 3)  # Output (x, y, z) coordinates
            )
            for _ in range(num_predictions)
        ])

        # Initialize weights
        self._init_weights()

    def _init_weights(self):
        """Xavier-initialize every weight matrix, keeping the padding embedding at zero."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

        # The loop above also re-initialized the embedding matrix, which
        # overwrote the all-zero row nn.Embedding creates for padding_idx.
        # That row receives no gradient, so it would otherwise stay a frozen
        # random vector; restore it to zero.
        pad_idx = self.embedding.padding_idx
        if pad_idx is not None:
            with torch.no_grad():
                self.embedding.weight[pad_idx].zero_()

    def forward(self, x, mask=None):
        """
        Args:
            x: (batch, seq_len) - RNA sequence IDs
            mask: (batch, seq_len) - True for valid positions, False for padding

        Returns:
            coords: (batch, seq_len, 3, num_predictions) - Predicted 3D coordinates
        """
        # Embedding + positional encoding
        x_embed = self.embedding(x)  # (batch, seq_len, embed_dim)
        x_embed = self.pos_encoder(x_embed)

        # Transformer expects padding_mask where True = IGNORE
        # Our mask is True = KEEP, so we invert it
        padding_mask = ~mask if mask is not None else None

        # Encode sequence: (batch, seq_len, embed_dim)
        encoded = self.transformer(x_embed, src_key_padding_mask=padding_mask)

        # Each head yields (batch, seq_len, 3); stacking on a new last axis
        # gives (batch, seq_len, 3, num_predictions)
        return torch.stack([head(encoded) for head in self.prediction_heads], dim=-1)

## 4. Training Loop

**Training Strategy:**
- Train model to predict coordinates for all 5 structures simultaneously
- Use MSE loss averaged across all prediction heads
- Mask out padding positions
- Save best model based on validation loss
- Add gradient clipping for stability

In [None]:
def _masked_mse(predictions, targets, mask):
    """Mean squared error over valid (unmasked) residues only.

    ``predictions`` and ``targets`` are (batch, seq_len, 3, num_predictions);
    ``mask`` is (batch, seq_len) with True marking valid residues.
    """
    per_element = F.mse_loss(predictions, targets, reduction='none')
    valid = mask.unsqueeze(-1).unsqueeze(-1)  # (batch, seq_len, 1, 1)
    # clamp avoids a 0/0 NaN if a batch happens to contain no valid residue
    n_terms = mask.sum().clamp(min=1) * predictions.size(2) * predictions.size(3)
    return (per_element * valid).sum() / n_terms


def _evaluate(model, loader, device):
    """Return the masked MSE of ``model`` averaged over the batches of ``loader``."""
    model.eval()
    total_loss = 0.0
    n_batches = 0
    with torch.no_grad():
        for input_ids, targets, mask, _ in tqdm(loader, desc="Validation"):
            input_ids = input_ids.to(device)
            targets = targets.to(device)
            mask = mask.to(device)

            predictions = model(input_ids, mask)
            total_loss += _masked_mse(predictions, targets, mask).item()
            n_batches += 1
    return total_loss / max(n_batches, 1)


def train_model(model, train_loader, val_loader, config):
    """Train the RNA structure prediction model.

    Optimizes a masked MSE between the model output and the targets with
    AdamW, a linear-warmup + cosine-decay learning-rate schedule (stepped once
    per batch) and gradient-norm clipping at 1.0. After every epoch the model
    is evaluated on ``val_loader``; whenever the validation loss improves, a
    checkpoint is written to ``model.pth`` in the current working directory.

    Args:
        model: Module called as ``model(input_ids, mask)`` and returning
            (batch, seq_len, 3, num_predictions) coordinates.
        train_loader: Iterable of ``(input_ids, targets, mask, target_id)``
            batches; must support ``len()``.
        val_loader: Same batch format as ``train_loader``; must support ``len()``.
        config: Dict providing at least ``'lr'``, ``'epochs'`` and ``'device'``.

    Returns:
        The same ``model`` object with the weights from the final epoch
        (not necessarily the best checkpoint).

    Raises:
        ValueError: If either loader yields no batches.
    """
    # Fail before any work is done instead of dividing by zero after an epoch.
    if len(train_loader) == 0 or len(val_loader) == 0:
        raise ValueError("train_loader and val_loader must each contain at least one batch")

    device = config['device']

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config['lr'],
        weight_decay=0.01,
        betas=(0.9, 0.999)
    )

    # Learning rate scheduler - warmup over the first 10% of steps, then cosine decay
    num_training_steps = len(train_loader) * config['epochs']
    num_warmup_steps = num_training_steps // 10

    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
        return max(0.0, 0.5 * (1.0 + np.cos(np.pi * progress)))

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    best_val_loss = float('inf')

    print("\nStarting training...")
    print(f"Total training steps: {num_training_steps}")
    print(f"Warmup steps: {num_warmup_steps}")

    for epoch in range(config['epochs']):
        # Training phase
        model.train()
        train_loss = 0.0
        train_count = 0

        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config['epochs']}")
        for input_ids, targets, mask, _ in progress_bar:
            input_ids = input_ids.to(device)
            targets = targets.to(device)  # (batch, seq_len, 3, num_predictions)
            mask = mask.to(device)

            optimizer.zero_grad()

            predictions = model(input_ids, mask)  # (batch, seq_len, 3, num_predictions)
            loss = _masked_mse(predictions, targets, mask)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            scheduler.step()

            train_loss += loss.item()
            train_count += 1
            progress_bar.set_postfix({
                'loss': train_loss / train_count,
                'lr': scheduler.get_last_lr()[0]
            })

        avg_train_loss = train_loss / max(train_count, 1)

        # Validation phase
        avg_val_loss = _evaluate(model, val_loader, device)

        print(f"\nEpoch {epoch+1} Summary:")
        print(f"  Train Loss: {avg_train_loss:.6f}")
        print(f"  Val Loss: {avg_val_loss:.6f}")
        print(f"  Learning Rate: {scheduler.get_last_lr()[0]:.6f}")

        # Save best model
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_loss': avg_val_loss,
                'config': config
            }, 'model.pth')
            print(f"  ✓ New best model saved! (Val Loss: {avg_val_loss:.6f})")

    print("\nTraining complete!")
    print(f"Best validation loss: {best_val_loss:.6f}")
    return model


def main():
    """Run the full training pipeline.

    Builds the train and validation datasets and loaders from ``CONFIG``,
    instantiates ``RNAStructurePredictor`` on ``CONFIG['device']``, reports its
    parameter counts and hands everything to ``train_model``.
    """
    banner = "=" * 60
    print(banner)
    print("RNA 3D Structure Prediction - Training")
    print(banner)

    # Create datasets
    print("\nLoading datasets...")
    train_dataset = RNADataset(
        CONFIG['data_dir'],
        max_len=CONFIG['max_len'],
        mode='train'
    )
    val_dataset = RNADataset(
        CONFIG['data_dir'],
        max_len=CONFIG['max_len'],
        mode='val'
    )

    print(f"Train samples: {len(train_dataset)}")
    print(f"Validation samples: {len(val_dataset)}")

    # Create data loaders; only the shuffle flag differs between the two
    loader_kwargs = {
        'batch_size': CONFIG['batch_size'],
        'num_workers': 2,
        # Pinned host memory is only requested when a GPU is present
        'pin_memory': torch.cuda.is_available(),
    }
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

    # Create model
    print("\nInitializing model...")
    model = RNAStructurePredictor(
        vocab_size=5,
        embed_dim=CONFIG['embed_dim'],
        nhead=CONFIG['nhead'],
        num_layers=CONFIG['num_layers'],
        num_predictions=CONFIG['num_predictions'],
        dropout=CONFIG['dropout'],
        max_len=CONFIG['max_len']
    ).to(CONFIG['device'])

    # Count parameters
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")

    # Train model
    model = train_model(model, train_loader, val_loader, CONFIG)

    print("\n" + banner)
    print("Training pipeline completed successfully!")
    print(banner)


if __name__ == '__main__':
    main()