# Stanford RNA 3D Folding - Training Notebook

**Competition**: Stanford RNA 3D Folding Part 2

**Goal**: Predict 5 diverse 3D structures (C1' coordinates) for each RNA sequence

**Metric**: TM-score (for each target, the best of its 5 predictions is scored; these best scores are then averaged across targets)

In [None]:
import os
import gc
import random
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from tqdm.auto import tqdm
import warnings
warnings.filterwarnings('ignore')

def set_seed(seed=42):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU generator plus all CUDA devices), then sets the cuDNN flags
    ``deterministic=True`` and ``benchmark=False``.

    Args:
        seed: Integer seed handed unchanged to every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# Fix all RNG seeds first, then choose the compute device for the run.
set_seed(42)

has_cuda = torch.cuda.is_available()
device = torch.device('cuda') if has_cuda else torch.device('cpu')
print(f'Running on {device}')
if has_cuda:
    # Report which GPU (index 0) is in use and how much memory it has.
    print(f'CUDA available: {has_cuda}')
    print(f'GPU: {torch.cuda.get_device_name(0)}')
    gpu_props = torch.cuda.get_device_properties(0)
    print(f'GPU Memory: {gpu_props.total_memory / 1e9:.2f} GB')

In [None]:
# Configuration: every tunable of the run is collected in this one dict.
CONFIG = dict(
    # --- data ---
    data_dir='../input/stanford-rna-3d-folding-2',
    max_len=512,
    batch_size=8,
    # --- optimisation ---
    epochs=20,
    lr=3e-4,
    min_lr=1e-6,
    weight_decay=0.01,
    gradient_clip=1.0,
    warmup_epochs=2,
    # --- model ---
    vocab_size=5,
    embed_dim=256,
    nhead=8,
    num_layers=6,
    num_predictions=5,
    dropout=0.1,
    # --- misc ---
    seed=42,
    num_workers=2,
    save_path='./model.pth',
)

# Echo the settings, one indented "key: value" line each, under a header line.
print(
    'Configuration:',
    *(f'  {name}: {value}' for name, value in CONFIG.items()),
    sep='\n',
)

In [None]:
# Load data
print('Loading data...')


def _load_table(file_name):
    """Read one CSV file from CONFIG['data_dir'] into a DataFrame."""
    return pd.read_csv(os.path.join(CONFIG['data_dir'], file_name))


# Sequence tables (one row per target).
train_seq = _load_table('train_sequences.csv')
val_seq = _load_table('validation_sequences.csv')
print(f'Train sequences: {len(train_seq)}')
print(f'Validation sequences: {len(val_seq)}')

# Label tables (row counts are reported separately from sequence counts).
train_labels = _load_table('train_labels.csv')
val_labels = _load_table('validation_labels.csv')
print(f'Train label rows: {len(train_labels)}')
print(f'Validation label rows: {len(val_labels)}')

In [None]:
# Preprocess labels
print('Preprocessing labels...')

def preprocess_labels(labels_df):
    """Collect per-residue label rows into one coordinate array per target.

    The target id of a row is the part of its ``ID`` value before the last
    underscore. Rows are grouped by that id in a single pass, ordered by
    ``resid`` within each target, and the ``x_1`` / ``y_1`` / ``z_1`` columns
    are stacked into an ``(n_rows, 3)`` float32 array.

    Args:
        labels_df: DataFrame with ``ID``, ``resid`` and the coordinate
            columns. It is not modified.

    Returns:
        Dict mapping target id -> float32 array of shape ``(n_rows, 3)``.
        Empty when the frame has no ``x_1`` column.

    Note:
        NaN coordinates are replaced by 0.0, so the returned arrays cannot
        tell a missing residue apart from one located at the origin.
    """
    # Derive the grouping key without copying the whole frame.
    target_ids = labels_df['ID'].map(lambda s: s.rpartition('_')[0]).rename('target_id')
    # sort=False keeps targets in order of first appearance.
    by_target = labels_df.groupby(target_ids, sort=False)
    print(f'Processing {by_target.ngroups} targets...')

    coords_dict = {}
    # The column check is loop-invariant: without coordinates there is nothing to build.
    if 'x_1' not in labels_df.columns:
        return coords_dict

    coord_cols = ['x_1', 'y_1', 'z_1']
    for target_id, rows in tqdm(by_target, total=by_target.ngroups, desc='Building coords'):
        ordered = rows.sort_values('resid')
        coords = ordered[coord_cols].to_numpy(dtype=np.float64)
        coords = np.nan_to_num(coords, nan=0.0)
        coords_dict[target_id] = coords.astype(np.float32)
    return coords_dict

# Build the per-target coordinate dicts: train first, then validation.
train_coords, val_coords = (
    preprocess_labels(frame) for frame in (train_labels, val_labels)
)

print(f'Train: {len(train_coords)}, Val: {len(val_coords)}')

# The raw label tables are not used after this point; drop them and collect.
del train_labels, val_labels
gc.collect()

In [None]:
# Dataset
class RNADataset(Dataset):
    """Fixed-length (tokens, coordinates, mask, length) samples for RNA targets.

    Only rows of ``seq_df`` whose ``target_id`` has an entry in
    ``coords_dict`` are kept. Bases map to A=0, C=1, G=2, U=3; any other
    character maps to 4, which is also the value used to pad the token list.

    Args:
        seq_df: DataFrame with ``target_id`` and ``sequence`` columns.
        coords_dict: Mapping target id -> ``(n, 3)`` coordinate array.
        max_len: Every sample is truncated or padded to this length.
        augment: When true, about half of the samples are rotated by a
            random angle around the z axis (NumPy global RNG).
    """

    def __init__(self, seq_df, coords_dict, max_len=512, augment=False):
        self.coords_dict = coords_dict
        self.max_len = max_len
        self.augment = augment
        self.base2int = dict(zip('ACGUN', range(5)))
        # Keep only the targets that actually have coordinates.
        has_coords = seq_df['target_id'].isin(set(coords_dict))
        self.seq_df = seq_df[has_coords].reset_index(drop=True)
        print(f'Dataset: {len(self.seq_df)}')

    def __len__(self):
        return len(self.seq_df)

    def __getitem__(self, idx):
        """Return ``(tokens, coords, mask, orig_len)`` for sample ``idx``.

        ``tokens`` is a long tensor of length ``max_len``, ``coords`` a
        float32 tensor ``(max_len, 3)`` zero-padded at the end, ``mask`` a
        bool tensor that is True for the first ``orig_len`` positions, and
        ``orig_len`` the smallest of sequence length, coordinate rows and
        ``max_len``.
        """
        record = self.seq_df.iloc[idx]
        tokens = [self.base2int.get(base.upper(), 4) for base in record['sequence']]
        xyz = self.coords_dict[record['target_id']][:self.max_len].copy()
        orig_len = min(len(tokens), len(xyz), self.max_len)

        if self.augment and np.random.random() > 0.5:
            # Rotation about the z axis: z is untouched, (x, y) are turned by theta.
            theta = np.random.uniform(0, 2 * np.pi)
            cos_t, sin_t = np.cos(theta), np.sin(theta)
            rot_z = np.array(
                [[cos_t, -sin_t, 0], [sin_t, cos_t, 0], [0, 0, 1]],
                dtype=np.float32,
            )
            xyz = xyz @ rot_z.T

        # Bring the token list to exactly max_len: cut the tail, then pad with 4.
        tokens = tokens[:self.max_len]
        tokens = tokens + [4] * (self.max_len - len(tokens))

        # xyz was already cut to at most max_len rows above; only padding can remain.
        missing_rows = self.max_len - len(xyz)
        if missing_rows > 0:
            xyz = np.pad(xyz, ((0, missing_rows), (0, 0)))

        mask = np.arange(self.max_len) < orig_len

        return (
            torch.tensor(tokens, dtype=torch.long),
            torch.tensor(xyz, dtype=torch.float32),
            torch.tensor(mask, dtype=torch.bool),
            orig_len,
        )

# Only the training split uses the random-rotation augmentation.
train_dataset = RNADataset(train_seq, train_coords, max_len=CONFIG['max_len'], augment=True)
val_dataset = RNADataset(val_seq, val_coords, max_len=CONFIG['max_len'], augment=False)

# The worker count is read from CONFIG so the 'num_workers' setting actually
# takes effect (it was previously hard-coded here).
train_loader = DataLoader(
    train_dataset,
    batch_size=CONFIG['batch_size'],
    shuffle=True,
    num_workers=CONFIG['num_workers'],
    pin_memory=True,
    drop_last=True,  # incomplete final training batch is skipped
)
val_loader = DataLoader(
    val_dataset,
    batch_size=CONFIG['batch_size'],
    shuffle=False,
    num_workers=CONFIG['num_workers'],
    pin_memory=True,
)

print(f'Train batches: {len(train_loader)}, Val batches: {len(val_loader)}')

In [None]:
# Model
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch-first tensor.

    A table of shape ``(1, max_len, d_model)`` is built once and stored as
    the buffer ``pe``: even feature columns hold sines, odd columns hold
    cosines of ``position * div_term``.

    Args:
        d_model: Feature size of the input. Both even and odd sizes work.
        max_len: Number of positions in the table.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        angles = position * div_term

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer cosine column than sine
        # columns, so trim the angles to fit; for even d_model this is a no-op.
        pe[:, 1::2] = torch.cos(angles[:, :d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the encodings of its first ``x.size(1)`` positions."""
        return x + self.pe[:, :x.size(1)]

class RNAStructurePredictor(nn.Module):
    """Transformer encoder over base tokens with parallel coordinate heads.

    Tokens are embedded (index 4 is the embedding's ``padding_idx``), given
    positional encodings and dropout, passed through a pre-norm GELU
    transformer encoder, and fed to ``num_predictions`` independent MLP heads
    that each emit three values per position.

    Args:
        vocab_size: Number of token ids.
        embed_dim: Model width; the feed-forward width is four times this.
        nhead: Attention heads per encoder layer.
        num_layers: Number of encoder layers.
        num_predictions: Number of output heads.
        dropout: Dropout probability used throughout.
        max_len: Length of the positional-encoding table.
    """

    def __init__(self, vocab_size=5, embed_dim=256, nhead=8, num_layers=6, num_predictions=5, dropout=0.1, max_len=512):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=4)
        self.pos_encoder = PositionalEncoding(embed_dim, max_len)
        self.dropout = nn.Dropout(dropout)

        layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=nhead,
            dim_feedforward=embed_dim * 4,
            dropout=dropout,
            activation='gelu',
            batch_first=True,
            norm_first=True,
        )
        self.transformer = nn.TransformerEncoder(layer, num_layers)

        half_dim = embed_dim // 2
        heads = []
        for _ in range(num_predictions):
            heads.append(
                nn.Sequential(
                    nn.Linear(embed_dim, embed_dim),
                    nn.GELU(),
                    nn.Dropout(dropout),
                    nn.Linear(embed_dim, half_dim),
                    nn.GELU(),
                    nn.Linear(half_dim, 3),
                )
            )
        self.heads = nn.ModuleList(heads)

    def forward(self, x, mask=None):
        """Return the stacked head outputs.

        Args:
            x: Token ids, indexed as (batch, position).
            mask: Optional bool tensor, True at real positions. Its inverse
                is passed to the encoder as the key-padding mask.

        Returns:
            Tensor with the heads stacked on a new last axis (dim 3), i.e.
            (batch, position, 3, num_predictions).
        """
        padding_mask = None if mask is None else ~mask
        hidden = self.embedding(x)
        hidden = self.pos_encoder(hidden)
        hidden = self.dropout(hidden)
        hidden = self.transformer(hidden, src_key_padding_mask=padding_mask)
        per_head = [head(hidden) for head in self.heads]
        return torch.stack(per_head, dim=3)

# Build the network from CONFIG and move it to the selected device.
model = RNAStructurePredictor(
    vocab_size=CONFIG['vocab_size'],
    embed_dim=CONFIG['embed_dim'],
    nhead=CONFIG['nhead'],
    num_layers=CONFIG['num_layers'],
    num_predictions=CONFIG['num_predictions'],
    dropout=CONFIG['dropout'],
    max_len=CONFIG['max_len'],
).to(device)
print(f'Parameters: {sum(weight.numel() for weight in model.parameters()):,}')

In [None]:
# Loss and optimizer
class RNALoss(nn.Module):
    """Masked squared-error loss over several candidate structures.

    For every sample and head, the squared coordinate error is summed over
    xyz, zeroed at masked-out positions, and averaged over the number of
    valid positions (at least 1). The result blends the per-sample best head
    with the mean over all heads, 50/50.
    """

    def forward(self, pred, target, mask):
        """Compute the loss.

        Args:
            pred: Predictions indexed (batch, position, xyz, head).
            target: Reference coordinates indexed (batch, position, xyz).
            mask: Bool tensor (batch, position), True at valid positions.

        Returns:
            Tuple ``(loss, best_loss)`` of scalar tensors, where
            ``loss = 0.5 * best_loss + 0.5 * mean_over_heads``.
        """
        diff = pred - target.unsqueeze(3).expand_as(pred)
        sq_dist = diff.pow(2).sum(dim=2)
        valid = mask.unsqueeze(2).float()
        # clamp avoids dividing by zero for a sample with no valid position
        n_valid = mask.sum(dim=1, keepdim=True).clamp(min=1)
        per_head = (sq_dist * valid).sum(dim=1) / n_valid
        best_loss = per_head.min(dim=1).values.mean()
        avg_loss = per_head.mean()
        return 0.5 * best_loss + 0.5 * avg_loss, best_loss

criterion = RNALoss()
optimizer = AdamW(
    model.parameters(),
    lr=CONFIG['lr'],
    weight_decay=CONFIG['weight_decay'],
)
# The cosine schedule spans only the epochs that remain after warmup.
scheduler = CosineAnnealingLR(
    optimizer,
    T_max=CONFIG['epochs'] - CONFIG['warmup_epochs'],
    eta_min=CONFIG['min_lr'],
)

In [None]:
# Training
def train_epoch(model, loader, criterion, optimizer, epoch):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    While ``epoch`` is below ``CONFIG['warmup_epochs']`` the learning rate of
    every parameter group is first set to a linear fraction of
    ``CONFIG['lr']``. Gradients are clipped to ``CONFIG['gradient_clip']``
    before each optimizer step.

    Args:
        model: Network called as ``model(seq, mask)``.
        loader: Iterable of ``(seq, coords, mask, length)`` batches.
        criterion: Callable returning ``(loss, best_loss)``.
        optimizer: Optimizer updating ``model``.
        epoch: Zero-based epoch index.

    Returns:
        Sum of the per-batch loss values divided by the number of batches.
    """
    model.train()

    warmup_epochs = CONFIG['warmup_epochs']
    if epoch < warmup_epochs:
        warm_lr = CONFIG['lr'] * (epoch + 1) / warmup_epochs
        for group in optimizer.param_groups:
            group['lr'] = warm_lr

    running_loss = 0
    num_batches = 0
    progress = tqdm(loader, desc=f'Train {epoch+1}')
    for batch in progress:
        seq, coords, mask, _ = batch
        seq = seq.to(device)
        coords = coords.to(device)
        mask = mask.to(device)

        optimizer.zero_grad()
        preds = model(seq, mask)
        loss, _ = criterion(preds, coords, mask)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), CONFIG['gradient_clip'])
        optimizer.step()

        running_loss += loss.item()
        num_batches += 1
    return running_loss / num_batches

def validate(model, loader, criterion):
    """Evaluate ``model`` on ``loader`` without gradients; return the mean batch loss.

    Args:
        model: Network called as ``model(seq, mask)``; switched to eval mode.
        loader: Iterable of ``(seq, coords, mask, length)`` batches.
        criterion: Callable returning ``(loss, best_loss)``.

    Returns:
        Sum of the per-batch loss values divided by the number of batches.
    """
    model.eval()
    running_loss = 0
    num_batches = 0
    with torch.no_grad():
        for batch in tqdm(loader, desc='Val'):
            seq, coords, mask, _ = batch
            seq = seq.to(device)
            coords = coords.to(device)
            mask = mask.to(device)
            loss, _ = criterion(model(seq, mask), coords, mask)
            running_loss += loss.item()
            num_batches += 1
    return running_loss / num_batches

In [None]:
# Main training loop: train, validate, advance the schedule, and checkpoint
# whenever the validation loss improves.
print('Starting training...')
best_val = float('inf')
for epoch in range(CONFIG['epochs']):
    train_loss = train_epoch(model, train_loader, criterion, optimizer, epoch)
    val_loss = validate(model, val_loader, criterion)

    # Read the learning rate BEFORE stepping the scheduler so the log line
    # reports the rate this epoch was trained with, not the next epoch's.
    epoch_lr = optimizer.param_groups[0]['lr']
    # The cosine schedule only starts once the warmup epochs are over.
    if epoch >= CONFIG['warmup_epochs']:
        scheduler.step()

    print(f'Epoch {epoch+1}: train={train_loss:.4f}, val={val_loss:.4f}, lr={epoch_lr:.2e}')
    if val_loss < best_val:
        best_val = val_loss
        torch.save({'model': model.state_dict(), 'config': CONFIG, 'val_loss': val_loss}, CONFIG['save_path'])
        print(f'  Saved best model!')

print(f'Training complete! Best val loss: {best_val:.4f}')

In [None]:
# Verify saved model: reload the best checkpoint and report the error of the
# first head on the first validation sample.
print('Verifying model...')
# map_location makes the load work even if the checkpoint tensors were saved
# from a different device than the one currently in use.
ckpt = torch.load(CONFIG['save_path'], map_location=device)
model.load_state_dict(ckpt['model'])
model.eval()
with torch.no_grad():
    for seq, coords, mask, lens in val_loader:
        pred = model(seq.to(device), mask.to(device))[0].cpu().numpy()
        tgt = coords[0].numpy()
        n_res = lens[0].item()
        # RMSD = sqrt(mean over residues of the squared 3D distance): sum over
        # xyz first, then average over residues. Averaging over every scalar
        # coordinate instead would under-report the value by a factor sqrt(3).
        # No superposition is applied; this is a raw-coordinate sanity check.
        sq_dist = ((pred[:n_res, :, 0] - tgt[:n_res]) ** 2).sum(axis=1)
        rmsd = np.sqrt(sq_dist.mean())
        print(f'Sample RMSD: {rmsd:.2f} A')
        break
print(f'Model size: {os.path.getsize(CONFIG["save_path"])/1e6:.1f} MB')
print('Done! Download model.pth for inference.')