# Stanford RNA 3D Folding - Training v2

**Key Fixes**:
- Coordinate normalization (prevents exploding loss)
- num_workers=0 (prevents multiprocessing errors)
- NaN/Inf checks in loss
- Faster preprocessing

In [None]:
import os, gc, time, random
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
import warnings
warnings.filterwarnings('ignore')

print(f'[{time.strftime("%H:%M:%S")}] Starting...')


def set_seed(seed=42):
    """Seed every random number generator this notebook relies on.

    Args:
        seed: Integer seed handed to Python's ``random``, NumPy, and PyTorch
            (CPU generator and all CUDA generators).
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)


set_seed(42)

# Prefer the GPU when one is visible, otherwise fall back to the CPU.
_cuda_ok = torch.cuda.is_available()
device = torch.device('cuda') if _cuda_ok else torch.device('cpu')
print(f'[{time.strftime("%H:%M:%S")}] Device: {device}')
if _cuda_ok:
    print(f'[{time.strftime("%H:%M:%S")}] GPU: {torch.cuda.get_device_name(0)}')

In [None]:
# Single source of truth for paths and hyper-parameters used by the cells below.
CONFIG = dict(
    # data / batching
    data_dir='../input/stanford-rna-3d-folding-2',
    max_len=384,
    batch_size=16,
    # optimisation
    epochs=30,
    lr=1e-3,
    min_lr=1e-5,
    weight_decay=0.01,
    gradient_clip=1.0,
    warmup_epochs=3,
    # model
    vocab_size=5,
    embed_dim=256,
    nhead=8,
    num_layers=6,
    num_predictions=5,
    dropout=0.1,
    # runtime
    num_workers=0,
    save_path='./model.pth',
)
print(f'[{time.strftime("%H:%M:%S")}] Config loaded')
# One "  key: value" line per entry, in insertion order.
print(*(f'  {k}: {v}' for k, v in CONFIG.items()), sep='\n')

In [None]:
print(f'\n[{time.strftime("%H:%M:%S")}] === LOADING DATA ===')

# Sequence tables for the train and validation splits.
_data_root = CONFIG['data_dir']
train_seq, val_seq = (
    pd.read_csv(os.path.join(_data_root, fname))
    for fname in ('train_sequences.csv', 'validation_sequences.csv')
)
print(f'[{time.strftime("%H:%M:%S")}] Train: {len(train_seq)}, Val: {len(val_seq)}')

# Matching label tables (per-row coordinates keyed by the 'ID' column).
train_labels, val_labels = (
    pd.read_csv(os.path.join(_data_root, fname))
    for fname in ('train_labels.csv', 'validation_labels.csv')
)
print(f'[{time.strftime("%H:%M:%S")}] Labels: train={len(train_labels)}, val={len(val_labels)}')
print(f'Columns: {train_labels.columns.tolist()[:8]}')

In [None]:
print(f'\n[{time.strftime("%H:%M:%S")}] === PREPROCESSING ===')

def preprocess_labels(df, name):
    """Group a label table into one coordinate array per target.

    Each ``ID`` is split at its last underscore into a target id (prefix) and
    a residue index (suffix). Rows of a target are ordered by the *numeric*
    value of that suffix and their x/y/z columns are stacked.

    Args:
        df: Label DataFrame with an ``ID`` column and either ``x_1/y_1/z_1``
            or ``x/y/z`` coordinate columns. It is not modified.
        name: Label used only in progress messages.

    Returns:
        ``(coords_dict, mean, std)`` where ``coords_dict`` maps target id to a
        float32 array of shape ``(n_rows, 3)``, and ``mean`` / ``std`` are the
        per-axis statistics over all rows of all targets (``std`` has 1e-8
        added). For an empty table, ``mean`` is zeros and ``std`` is ones.
    """
    start = time.time()
    df = df.copy()
    id_parts = df['ID'].str.rsplit('_', n=1)
    df['target_id'] = id_parts.str[0]

    # Order rows by the numeric ID suffix. Sorting the raw 'ID' strings is
    # lexicographic ("t_1", "t_10", "t_11", "t_2", ...) and scrambles the
    # residue order of any target with more than nine rows. The sort is
    # stable, so rows whose suffix is not numeric keep their file order
    # (they are placed last).
    df['_res_idx'] = pd.to_numeric(id_parts.str[1], errors='coerce')
    df = df.sort_values('_res_idx', kind='mergesort')

    x_col = 'x_1' if 'x_1' in df.columns else 'x'
    y_col = 'y_1' if 'y_1' in df.columns else 'y'
    z_col = 'z_1' if 'z_1' in df.columns else 'z'

    coords_dict = {}
    # groupby keeps the (now residue-ordered) row order inside every group.
    grouped = df.groupby('target_id')
    total = len(grouped)

    for i, (tid, grp) in enumerate(grouped):
        if i % 1000 == 0:
            print(f'[{time.strftime("%H:%M:%S")}] {name}: {i}/{total}')
        coords = np.stack([grp[x_col].values, grp[y_col].values, grp[z_col].values], axis=1).astype(np.float32)
        # NOTE(review): missing coordinates become 0.0 and are then
        # indistinguishable from real values in the statistics below.
        coords_dict[tid] = np.nan_to_num(coords, nan=0.0)

    if coords_dict:
        all_c = np.concatenate(list(coords_dict.values()), axis=0)
        mean = np.mean(all_c, axis=0)
        std = np.std(all_c, axis=0) + 1e-8
    else:
        # Nothing to normalise: identity statistics instead of a concatenate error.
        mean = np.zeros(3, dtype=np.float32)
        std = np.ones(3, dtype=np.float32)
    print(f'[{time.strftime("%H:%M:%S")}] {name}: {len(coords_dict)} targets in {time.time()-start:.1f}s')
    print(f'Mean: {mean}, Std: {std}')
    return coords_dict, mean, std

train_coords, COORD_MEAN, COORD_STD = preprocess_labels(train_labels, 'train')
val_coords, _, _ = preprocess_labels(val_labels, 'val')
del train_labels, val_labels
gc.collect()
print(f'[{time.strftime("%H:%M:%S")}] Preprocessing done!')

In [None]:
class RNADataset(Dataset):
    """Fixed-length samples of tokenised sequence, normalised coordinates and mask.

    Only rows of ``seq_df`` whose ``target_id`` is a key of ``coords_dict``
    are kept. Every item is truncated or padded to ``max_len`` positions.

    Args:
        seq_df: DataFrame with ``target_id`` and ``sequence`` columns.
        coords_dict: Mapping of target id to an ``(n, 3)`` coordinate array.
        max_len: Number of positions in every returned sample.
        coord_mean: Value subtracted from the coordinates.
        coord_std: Value the centred coordinates are divided by.
        aug: When true, about half of the samples are rotated around the
            z axis by a random angle (applied after normalisation).
    """

    def __init__(self, seq_df, coords_dict, max_len, coord_mean, coord_std, aug=False):
        self.max_len = max_len
        self.aug = aug
        # Token ids; id 4 is used for 'N', for unknown characters and for padding.
        self.base2int = dict(zip('ACGUN', range(5)))
        self.coord_mean = coord_mean
        self.coord_std = coord_std
        self.coords_dict = coords_dict
        has_coords = seq_df['target_id'].isin(set(coords_dict.keys()))
        self.seq_df = seq_df[has_coords].reset_index(drop=True)
        print(f'[{time.strftime("%H:%M:%S")}] Dataset: {len(self.seq_df)}')

    def __len__(self):
        return self.seq_df.shape[0]

    def __getitem__(self, idx):
        """Return ``(tokens, coords, mask, length)`` for row ``idx``.

        ``tokens`` is a long tensor of ``max_len`` ids, ``coords`` a float32
        tensor of shape ``(max_len, 3)``, ``mask`` a bool tensor that is True
        on the first ``length`` positions, and ``length`` the smaller of the
        truncated sequence length and the truncated coordinate count.
        """
        record = self.seq_df.iloc[idx]
        limit = self.max_len

        tokens = [self.base2int.get(ch.upper(), 4) for ch in record['sequence'][:limit]]

        xyz = self.coords_dict[record['target_id']][:limit]
        xyz = (xyz - self.coord_mean) / self.coord_std  # NORMALIZE
        n_valid = min(len(tokens), len(xyz))

        if self.aug and random.random() > 0.5:
            theta = random.uniform(0, 2*np.pi)
            cos_t, sin_t = np.cos(theta), np.sin(theta)
            rot_z = np.array(
                [[cos_t, -sin_t, 0],
                 [sin_t, cos_t, 0],
                 [0, 0, 1]],
                dtype=np.float32,
            )
            xyz = xyz @ rot_z.T

        # Right-pad both streams up to the fixed length.
        missing_tokens = limit - len(tokens)
        if missing_tokens > 0:
            tokens.extend([4] * missing_tokens)
        missing_rows = limit - len(xyz)
        if missing_rows > 0:
            xyz = np.pad(xyz, ((0, missing_rows), (0, 0)))

        valid = np.arange(limit) < n_valid
        return (
            torch.tensor(tokens, dtype=torch.long),
            torch.tensor(xyz, dtype=torch.float32),
            torch.tensor(valid, dtype=torch.bool),
            n_valid,
        )

print(f'\n[{time.strftime("%H:%M:%S")}] === DATASETS ===')
# Both splits are normalised with the statistics computed on the training labels.
train_ds = RNADataset(train_seq, train_coords, CONFIG['max_len'], COORD_MEAN, COORD_STD, aug=True)
val_ds = RNADataset(val_seq, val_coords, CONFIG['max_len'], COORD_MEAN, COORD_STD, aug=False)

# Loader settings shared by both splits. The worker count comes from CONFIG
# (instead of a second hard-coded 0) so the setting has one source of truth,
# and page-locked memory is requested only when a GPU is present, since it
# only speeds up host-to-GPU copies.
_loader_kwargs = dict(
    batch_size=CONFIG['batch_size'],
    num_workers=CONFIG['num_workers'],
    pin_memory=torch.cuda.is_available(),
)
# drop_last=True: the final, smaller training batch is discarded.
train_loader = DataLoader(train_ds, shuffle=True, drop_last=True, **_loader_kwargs)
val_loader = DataLoader(val_ds, shuffle=False, **_loader_kwargs)
print(f'[{time.strftime("%H:%M:%S")}] Train batches: {len(train_loader)}, Val batches: {len(val_loader)}')

In [None]:
print(f'\n[{time.strftime("%H:%M:%S")}] === DATA CHECK ===')
# Draw a single training batch and make sure its coordinates are usable.
seq, coords, mask, lens = batch = next(iter(train_loader))
print(f'Shapes: seq={seq.shape}, coords={coords.shape}')
print(f'Coords normalized: min={coords.min():.3f}, max={coords.max():.3f}, mean={coords.mean():.3f}')
# isfinite is False for NaN as well as +/-Inf.
assert torch.isfinite(coords).all(), 'Bad data!'
print('Data OK!')

In [None]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch-first input.

    Even feature columns hold ``sin(pos * freq)`` and odd columns hold
    ``cos(pos * freq)``. The table is precomputed for ``max_len`` positions
    and stored as the buffer ``pe`` with shape ``(1, max_len, d_model)``.

    Args:
        d_model: Feature dimension of the input (odd values are supported).
        max_len: Largest sequence length the table covers.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        angles = pos * div
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even columns,
        # so trim the cosine term to fit instead of failing on a shape mismatch.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x + pe[:, :seq_len]`` for ``x`` of shape ``(batch, seq_len, d_model)``.

        Raises:
            RuntimeError: If ``seq_len`` exceeds the precomputed table length.
        """
        seq_len = x.size(1)
        table_len = self.pe.size(1)
        if seq_len > table_len:
            raise RuntimeError(
                f'sequence length {seq_len} exceeds positional table length {table_len}'
            )
        return x + self.pe[:, :seq_len]

class RNAModel(nn.Module):
    """Transformer encoder that emits several 3-D coordinate predictions per position.

    Args:
        vocab: Size of the token vocabulary (token id 4 is the padding index).
        dim: Embedding / model width.
        heads: Number of attention heads per encoder layer.
        layers: Number of encoder layers.
        n_pred: Number of independent prediction heads.
        drop: Dropout probability used throughout.
        max_len: Longest sequence covered by the positional-encoding table.
    """

    def __init__(self, vocab=5, dim=256, heads=8, layers=6, n_pred=5, drop=0.1, max_len=512):
        super().__init__()
        self.n_pred = n_pred
        self.emb = nn.Embedding(vocab, dim, padding_idx=4)
        self.pos = PositionalEncoding(dim, max_len)
        self.drop = nn.Dropout(drop)
        enc = nn.TransformerEncoderLayer(d_model=dim, nhead=heads, dim_feedforward=dim*4, dropout=drop, batch_first=True, norm_first=True, activation='gelu')
        self.transformer = nn.TransformerEncoder(enc, num_layers=layers)
        # One small MLP per prediction, each mapping dim -> 3 coordinates.
        self.heads = nn.ModuleList([nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Dropout(drop), nn.Linear(dim, 3)) for _ in range(n_pred)])
        self._init()

    def _init(self):
        """Xavier-initialise every weight matrix, then restore the zero padding row."""
        for p in self.parameters():
            if p.dim() > 1: nn.init.xavier_uniform_(p)
        # The loop above also re-initialises the embedding table, including the
        # row nn.Embedding zeroed for padding_idx. That row receives no gradient,
        # so without this reset it would stay at its random initial value.
        pad = self.emb.padding_idx
        if pad is not None:
            with torch.no_grad():
                self.emb.weight[pad].zero_()

    def forward(self, x, mask=None):
        """Predict coordinates for token ids ``x``.

        Args:
            x: Token ids of shape ``(batch, seq)``.
            mask: Optional bool tensor of shape ``(batch, seq)``, True at
                positions to attend to; its inverse is passed to the encoder
                as the key-padding mask.

        Returns:
            Tensor of shape ``(batch, seq, 3, n_pred)``.
        """
        x = self.drop(self.pos(self.emb(x)))
        x = self.transformer(x, src_key_padding_mask=~mask if mask is not None else None)
        return torch.stack([h(x) for h in self.heads], dim=3)

print(f'\n[{time.strftime("%H:%M:%S")}] === MODEL ===')
# Build the network from CONFIG and move it to the selected device.
model = RNAModel(
    vocab=CONFIG['vocab_size'],
    dim=CONFIG['embed_dim'],
    heads=CONFIG['nhead'],
    layers=CONFIG['num_layers'],
    n_pred=CONFIG['num_predictions'],
    drop=CONFIG['dropout'],
    max_len=CONFIG['max_len'],
).to(device)
_n_params = sum(p.numel() for p in model.parameters())
print(f'[{time.strftime("%H:%M:%S")}] Params: {_n_params:,}')

In [None]:
class SafeLoss(nn.Module):
    """Masked squared-error loss mixing the best prediction with the average.

    For every sample and every prediction head, the squared coordinate error
    is summed over x/y/z, summed over the valid positions and divided by the
    number of valid positions (at least 1).
    """

    def forward(self, pred, target, mask):
        """Compute the loss.

        Args:
            pred: Predictions of shape ``(batch, seq, 3, n_pred)``.
            target: Reference coordinates of shape ``(batch, seq, 3)``.
            mask: ``(batch, seq)`` tensor, truthy at positions that count.

        Returns:
            ``(loss, best, avg)``: ``best`` is the batch mean of each sample's
            lowest per-head error, ``avg`` is the mean over all samples and
            heads, and ``loss`` is ``0.5 * best + 0.5 * avg`` clamped to at
            most 1e6.
        """
        n = pred.shape[3]
        valid_pos = mask.bool()
        t_exp = target.unsqueeze(3).expand(-1, -1, -1, n)
        # Zero the residual at padded positions *before* squaring. Multiplying
        # by a 0/1 mask afterwards would keep NaN/Inf alive (NaN * 0 is NaN),
        # both in the loss value and in the gradient of the square.
        pad_pos = ~valid_pos[:, :, None, None]
        diff = (pred - t_exp).masked_fill(pad_pos, 0.0)
        sq_err = (diff ** 2).sum(dim=2)
        valid = valid_pos.sum(dim=1, keepdim=True).clamp(min=1).float()
        per_pred = sq_err.sum(dim=1) / valid
        best = per_pred.min(dim=1)[0].mean()
        avg = per_pred.mean()
        return torch.clamp(0.5 * best + 0.5 * avg, max=1e6), best, avg

criterion = SafeLoss()
optimizer = AdamW(model.parameters(), lr=CONFIG['lr'], weight_decay=CONFIG['weight_decay'])
# The cosine schedule spans the epochs that remain after warm-up. Keep the
# period at >= 1 so a config with warmup_epochs >= epochs cannot hand the
# scheduler a zero or negative T_max.
_cosine_epochs = max(1, CONFIG['epochs'] - CONFIG['warmup_epochs'])
scheduler = CosineAnnealingLR(optimizer, T_max=_cosine_epochs, eta_min=CONFIG['min_lr'])
print(f'[{time.strftime("%H:%M:%S")}] Loss & optimizer ready')

In [None]:
def train_epoch(model, loader, crit, opt, dev, epoch, cfg):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    During the first ``cfg['warmup_epochs']`` epochs the learning rate of every
    parameter group is set to ``cfg['lr'] * (epoch + 1) / cfg['warmup_epochs']``.
    Batches whose loss is NaN or infinite are reported and skipped.

    Args:
        model: Module called as ``model(seq, mask)``.
        loader: Iterable of ``(seq, coords, mask, _)`` batches supporting ``len``.
        crit: Callable returning a tuple whose first item is the loss tensor.
        opt: Optimizer updating ``model``'s parameters.
        dev: Device the batch tensors are moved to.
        epoch: Zero-based epoch index.
        cfg: Mapping with ``warmup_epochs``, ``lr`` and ``gradient_clip``.

    Returns:
        Mean loss over the batches that were actually used for an update, or
        ``nan`` when there was none (empty loader or every batch skipped).
    """
    model.train()
    n = len(loader)
    if epoch < cfg['warmup_epochs']:
        lr = cfg['lr'] * (epoch + 1) / cfg['warmup_epochs']
        for pg in opt.param_groups: pg['lr'] = lr

    total = 0.0
    used = 0  # batches that contributed to `total`
    for i, (seq, coords, mask, _) in enumerate(loader):
        seq, coords, mask = seq.to(dev), coords.to(dev), mask.to(dev)
        opt.zero_grad()
        pred = model(seq, mask)
        loss, _, _ = crit(pred, coords, mask)
        if not torch.isfinite(loss):
            print(f'[{time.strftime("%H:%M:%S")}] WARN: NaN/Inf batch {i}')
            continue
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), cfg['gradient_clip'])
        opt.step()
        loss_val = loss.item()
        total += loss_val
        used += 1
        if i % 50 == 0: print(f'[{time.strftime("%H:%M:%S")}] E{epoch+1} B{i}/{n}: {loss_val:.4f}')
    # Divide by the batches counted, not len(loader): skipped batches add
    # nothing to `total` and would otherwise drag the average towards zero.
    return total / used if used else float('nan')

@torch.no_grad()
def validate(model, loader, crit, dev):
    """Return the mean loss of ``model`` over ``loader`` without gradient tracking.

    Batches whose loss is NaN or infinite are left out of the average.

    Args:
        model: Module called as ``model(seq, mask)``; switched to eval mode.
        loader: Iterable of ``(seq, coords, mask, _)`` batches.
        crit: Callable returning a tuple whose first item is the loss tensor.
        dev: Device the batch tensors are moved to.

    Returns:
        Mean loss over the batches with a finite loss, or ``inf`` when there is
        none. Returning 0.0 in that case would make a run that produced only
        NaNs look like the best result to a caller that minimises this value.
    """
    model.eval()
    total = 0.0
    used = 0  # batches with a finite loss
    for seq, coords, mask, _ in loader:
        seq, coords, mask = seq.to(dev), coords.to(dev), mask.to(dev)
        loss, _, _ = crit(model(seq, mask), coords, mask)
        if torch.isfinite(loss):
            total += loss.item()
            used += 1
    return total / used if used else float('inf')

print(f'[{time.strftime("%H:%M:%S")}] Training functions ready')

In [None]:
print(f'\n[{time.strftime("%H:%M:%S")}] ' + '='*50)
print(f'[{time.strftime("%H:%M:%S")}] STARTING TRAINING')
print(f'[{time.strftime("%H:%M:%S")}] ' + '='*50)

best_val = float('inf')
history = {'train': [], 'val': []}  # per-epoch losses, in epoch order

for epoch in range(CONFIG['epochs']):
    t0 = time.time()
    print(f'\n[{time.strftime("%H:%M:%S")}] === EPOCH {epoch+1}/{CONFIG["epochs"]} ===')

    train_loss = train_epoch(model, train_loader, criterion, optimizer, device, epoch, CONFIG)
    val_loss = validate(model, val_loader, criterion, device)

    # Read the learning rate before the scheduler advances, so the log line
    # reports the rate this epoch trained with rather than the next one.
    epoch_lr = optimizer.param_groups[0]['lr']

    # Warm-up epochs set the rate inside train_epoch; the cosine schedule
    # only advances once warm-up is over.
    if epoch >= CONFIG['warmup_epochs']: scheduler.step()

    history['train'].append(train_loss)
    history['val'].append(val_loss)

    print(f'[{time.strftime("%H:%M:%S")}] Epoch {epoch+1}: train={train_loss:.4f}, val={val_loss:.4f}, lr={epoch_lr:.2e}, time={time.time()-t0:.1f}s')

    # Keep only the checkpoint with the lowest validation loss, bundled with
    # the config and normalisation statistics saved alongside the weights.
    if val_loss < best_val:
        best_val = val_loss
        print(f'[{time.strftime("%H:%M:%S")}] NEW BEST! Saving...')
        torch.save({'epoch': epoch, 'model_state_dict': model.state_dict(), 'val_loss': val_loss, 'config': CONFIG, 'coord_mean': COORD_MEAN, 'coord_std': COORD_STD}, CONFIG['save_path'])

print(f'\n[{time.strftime("%H:%M:%S")}] ' + '='*50)
print(f'[{time.strftime("%H:%M:%S")}] DONE! Best val: {best_val:.6f}')
print(f'[{time.strftime("%H:%M:%S")}] ' + '='*50)

In [None]:
print(f'\n[{time.strftime("%H:%M:%S")}] === FINAL CHECK ===')
# The checkpoint stores NumPy arrays (coord_mean / coord_std) next to the
# weights, which the restricted weights-only loader rejects, so full
# unpickling is requested explicitly.
# SECURITY: weights_only=False runs pickle; use it only on a checkpoint file
# you trust, such as the one written by the training loop in this notebook.
ckpt = torch.load(CONFIG['save_path'], map_location=device, weights_only=False)
model.load_state_dict(ckpt['model_state_dict'])
model.eval()
print(f'Loaded epoch {ckpt["epoch"]+1}, val_loss={ckpt["val_loss"]:.6f}')

with torch.no_grad():
    batch = next(iter(val_loader))
    seq, coords, mask, lens = batch
    pred = model(seq.to(device), mask.to(device))
    print(f'Pred shape: {pred.shape}')
    print(f'Pred range: [{pred.min():.3f}, {pred.max():.3f}]')

    # Denormalized RMSD of the first prediction head on the first sample
    # (raw coordinates are compared directly; no superposition is applied).
    p = pred[0, :, :, 0].cpu().numpy() * COORD_STD + COORD_MEAN
    t = coords[0].numpy() * COORD_STD + COORD_MEAN
    L = lens[0].item()
    # RMSD averages the squared Euclidean distance per position; a plain
    # mean over all 3*L scalars would understate it by a factor of sqrt(3).
    sq_dist = ((p[:L] - t[:L]) ** 2).sum(axis=1)
    rmsd = np.sqrt(sq_dist.mean()) if L > 0 else float('nan')
    print(f'Sample RMSD: {rmsd:.2f} Angstroms')

print(f'\n[{time.strftime("%H:%M:%S")}] Training history:')
for i, (tr, va) in enumerate(zip(history['train'], history['val'])):
    print(f'  Epoch {i+1}: train={tr:.4f}, val={va:.4f}')

print(f'\n[{time.strftime("%H:%M:%S")}] NOTEBOOK COMPLETE!')