# Stanford RNA 3D Folding Part 2 - Complete Pipeline

**Template-Based Modeling (TBM) with Enhanced Deep Learning Features**

This comprehensive notebook implements the full RNA 3D structure prediction pipeline:

## Pipeline Overview
1. **Template Search** - K-mer indexed PDB search with temporal filtering
2. **Sequence Alignment** - Needleman-Wunsch global alignment
3. **Coordinate Transfer** - Map C1' atoms from template to target
4. **Gap Filling** - Geometric interpolation for unmapped residues
5. **RNA-FM Embeddings** - Deep learning enhanced similarity search
6. **Model Diversity** - Generate 5 diverse predictions per target
7. **Validation** - TM-score computation against ground truth

---

## 1. Setup and Configuration

In [1]:
import csv
import os
import pickle
import re
import sys
import time
import warnings
from collections import defaultdict
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Dict, List, NamedTuple, Optional, Tuple

import numpy as np
import pandas as pd
from scipy.interpolate import interp1d

warnings.filterwarnings('ignore')

# Detect environment
IS_KAGGLE = os.path.exists("/kaggle")

if IS_KAGGLE:
    DATA_DIR = Path("/kaggle/input/stanford-rna-3d-folding-2")
    OUTPUT_DIR = Path("/kaggle/working")
else:
    DATA_DIR = Path("..").resolve()
    OUTPUT_DIR = Path("../output").resolve()

OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

print(f"Environment: {'Kaggle' if IS_KAGGLE else 'Local'}")
print(f"Data directory: {DATA_DIR}")
print(f"Output directory: {OUTPUT_DIR}")

Environment: Local
Data directory: /Users/taher/Projects/stanford-rna-3d-folding-2
Output directory: /Users/taher/Projects/stanford-rna-3d-folding-2/output


In [2]:
# Pipeline configuration
@dataclass
class PipelineConfig:
    """Configuration for the TBM pipeline."""
    kmer_size: int = 6
    max_template_hits: int = 50
    min_identity: float = 0.25
    min_coverage: float = 0.25
    num_models: int = 5
    perturbation_scale: float = 0.2
    average_c1_distance: float = 5.9  # Angstroms
    use_embeddings: bool = True
    max_files: int = None  # Limit for testing

config = PipelineConfig()
print(f"Configuration: {config}")

Configuration: PipelineConfig(kmer_size=6, max_template_hits=50, min_identity=0.25, min_coverage=0.25, num_models=5, perturbation_scale=0.2, average_c1_distance=5.9, use_embeddings=True, max_files=None)


## 2. Core Data Structures

In [3]:
class Residue(NamedTuple):
    """Single residue with C1' coordinates."""
    resname: str
    resid: int
    chain: str
    x: float
    y: float
    z: float

class ChainCoords(NamedTuple):
    """All C1' coordinates for a single chain."""
    chain_id: str
    sequence: str
    residues: List[Residue]

class TemplateHit(NamedTuple):
    """A template match result."""
    pdb_id: str
    chain_id: str
    sequence: str
    release_date: str
    identity: float
    coverage: float

# Modified nucleotide mappings
MODIFIED_BASE_MAP = {
    'A': 'A', 'ADE': 'A', 'DA': 'A', '1MA': 'A', '2MA': 'A', 'MIA': 'A',
    'T6A': 'A', 'I': 'A', 'INO': 'A', '6MA': 'A', 'A2M': 'A', 'MA6': 'A',
    'C': 'C', 'CYT': 'C', 'DC': 'C', '5MC': 'C', 'OMC': 'C', '4OC': 'C',
    'G': 'G', 'GUA': 'G', 'DG': 'G', '1MG': 'G', '2MG': 'G', '7MG': 'G',
    'M2G': 'G', 'OMG': 'G', 'YG': 'G',
    'U': 'U', 'URA': 'U', 'DU': 'U', 'PSU': 'U', '5MU': 'U', 'H2U': 'U',
    'OMU': 'U', 'S4U': 'U', '4SU': 'U', 'DHU': 'U',
    'T': 'U', 'DT': 'U', 'THY': 'U',
}
print(f"Loaded {len(MODIFIED_BASE_MAP)} modified base mappings")

Loaded 40 modified base mappings


## 3. CIF Parser

In [4]:
def parse_cif_line(line: str) -> List[str]:
    """Parse a CIF data line, handling quoted strings."""
    tokens = []
    i = 0
    while i < len(line):
        while i < len(line) and line[i] in ' \t':
            i += 1
        if i >= len(line):
            break
        if line[i] in '"\'':
            quote = line[i]
            i += 1
            start = i
            while i < len(line) and line[i] != quote:
                i += 1
            tokens.append(line[start:i])
            i += 1
        else:
            start = i
            while i < len(line) and line[i] not in ' \t':
                i += 1
            tokens.append(line[start:i])
    return tokens


def parse_cif_c1prime(cif_path: str) -> Dict[str, ChainCoords]:
    """Parse mmCIF file and extract C1' coordinates for RNA chains."""
    with open(cif_path, 'r') as f:
        content = f.read()
    
    atom_site_match = re.search(r'loop_\s*\n(_atom_site\.\w+\s*\n)+', content)
    if not atom_site_match:
        return {}
    
    header_section = atom_site_match.group(0)
    columns = re.findall(r'_atom_site\.(\w+)', header_section)
    
    try:
        col_indices = {
            'group_PDB': columns.index('group_PDB'),
            'label_atom_id': columns.index('label_atom_id'),
            'label_comp_id': columns.index('label_comp_id'),
            'auth_asym_id': columns.index('auth_asym_id'),
            'auth_seq_id': columns.index('auth_seq_id'),
            'Cartn_x': columns.index('Cartn_x'),
            'Cartn_y': columns.index('Cartn_y'),
            'Cartn_z': columns.index('Cartn_z'),
        }
        if 'pdbx_PDB_model_num' in columns:
            col_indices['pdbx_PDB_model_num'] = columns.index('pdbx_PDB_model_num')
        else:
            col_indices['pdbx_PDB_model_num'] = None
    except ValueError:
        return {}
    
    data_start = atom_site_match.end()
    lines = content[data_start:].split('\n')
    
    chain_residues: Dict[str, Dict[int, Residue]] = {}
    
    for line in lines:
        line = line.strip()
        if not line or line.startswith('_') or line.startswith('#') or line.startswith('loop_'):
            if line.startswith('#') or line.startswith('loop_'):
                break
            continue
        
        tokens = parse_cif_line(line)
        if len(tokens) < max(v for v in col_indices.values() if v is not None) + 1:
            continue
        
        if col_indices['pdbx_PDB_model_num'] is not None:
            model_num = tokens[col_indices['pdbx_PDB_model_num']]
            if model_num != '1' and model_num != '?':
                continue
        
        group = tokens[col_indices['group_PDB']]
        atom_id = tokens[col_indices['label_atom_id']]
        if len(atom_id) >= 2 and atom_id[0] == atom_id[-1] and atom_id[0] in '"\'':
            atom_id = atom_id[1:-1]
        
        if group != 'ATOM' or atom_id != "C1'":
            continue
        
        comp_id = tokens[col_indices['label_comp_id']]
        base = MODIFIED_BASE_MAP.get(comp_id.upper(), comp_id.upper())
        if base not in {'A', 'C', 'G', 'U'}:
            continue
        
        chain = tokens[col_indices['auth_asym_id']]
        try:
            resid = int(tokens[col_indices['auth_seq_id']])
            x = float(tokens[col_indices['Cartn_x']])
            y = float(tokens[col_indices['Cartn_y']])
            z = float(tokens[col_indices['Cartn_z']])
        except (ValueError, IndexError):
            continue
        
        if chain not in chain_residues:
            chain_residues[chain] = {}
        
        if resid not in chain_residues[chain]:
            chain_residues[chain][resid] = Residue(base, resid, chain, x, y, z)
    
    result = {}
    for chain_id, res_dict in chain_residues.items():
        sorted_resids = sorted(res_dict.keys())
        residues = [res_dict[rid] for rid in sorted_resids]
        sequence = ''.join(r.resname for r in residues)
        result[chain_id] = ChainCoords(chain_id, sequence, residues)
    
    return result

print("CIF Parser loaded")

CIF Parser loaded


## 4. Template Database with K-mer Index

In [5]:
class TemplateDB:
    """Template database with k-mer index for fast sequence search."""
    
    def __init__(self, k: int = 6):
        self.k = k
        self.templates: Dict[str, Dict[str, ChainCoords]] = {}
        self.release_dates: Dict[str, str] = {}
        self.kmer_index: Dict[str, List[Tuple[str, str, int]]] = defaultdict(list)
        self.sequences: Dict[Tuple[str, str], str] = {}
    
    def build_from_directory(self, cif_dir: str, release_dates_file: str = None, max_files: int = None):
        """Build template database from PDB_RNA directory."""
        if release_dates_file and os.path.exists(release_dates_file):
            self._load_release_dates(release_dates_file)
        
        cif_files = list(Path(cif_dir).glob('*.cif'))
        if max_files:
            cif_files = cif_files[:max_files]
        total = len(cif_files)
        print(f"Building template database from {total} CIF files...")
        
        for i, cif_path in enumerate(cif_files):
            pdb_id = cif_path.stem.upper()
            
            if release_dates_file and pdb_id not in self.release_dates:
                continue
            
            try:
                chains = parse_cif_c1prime(str(cif_path))
                if chains:
                    self.templates[pdb_id] = chains
                    for chain_id, chain_coords in chains.items():
                        seq = chain_coords.sequence
                        self.sequences[(pdb_id, chain_id)] = seq
                        for pos in range(len(seq) - self.k + 1):
                            kmer = seq[pos:pos + self.k]
                            self.kmer_index[kmer].append((pdb_id, chain_id, pos))
            except Exception:
                pass
            
            if (i + 1) % 500 == 0:
                print(f"  Processed {i+1}/{total} files...")
        
        print(f"Done. {len(self.templates)} templates indexed.")
    
    def _load_release_dates(self, csv_path: str):
        """Load release dates from CSV."""
        count = 0
        with open(csv_path, 'r') as f:
            reader = csv.DictReader(f)
            for row in reader:
                pdb_id = row.get('pdb_id', row.get('entry_id', '')).upper()
                if not pdb_id and 'target_id' in row:
                    target_id = row['target_id']
                    pdb_id = target_id.split('_')[0].upper() if '_' in target_id else target_id[:4].upper()
                
                date = row.get('release_date', row.get('Release Date', row.get('temporal_cutoff', '')))
                if pdb_id and date and pdb_id not in self.release_dates:
                    self.release_dates[pdb_id] = date
                    count += 1
        print(f"  Loaded {count} release dates")
    
    def search(self, query_sequence: str, temporal_cutoff: str, max_hits: int = 50, min_identity: float = 0.25) -> List[TemplateHit]:
        """Search for template matches."""
        cutoff_date = datetime.strptime(temporal_cutoff, '%Y-%m-%d')
        hit_counts: Dict[Tuple[str, str], int] = defaultdict(int)
        
        for pos in range(len(query_sequence) - self.k + 1):
            kmer = query_sequence[pos:pos + self.k]
            if kmer in self.kmer_index:
                for pdb_id, chain_id, _ in self.kmer_index[kmer]:
                    release_date_str = self.release_dates.get(pdb_id, '9999-12-31')
                    try:
                        release_date = datetime.strptime(release_date_str, '%Y-%m-%d')
                        if release_date >= cutoff_date:
                            continue
                    except:
                        continue
                    hit_counts[(pdb_id, chain_id)] += 1
        
        candidates = sorted(hit_counts.items(), key=lambda x: -x[1])[:max_hits * 2]
        
        results = []
        for (pdb_id, chain_id), count in candidates:
            template_seq = self.sequences[(pdb_id, chain_id)]
            identity, coverage = self._quick_align(query_sequence, template_seq)
            
            if identity >= min_identity:
                results.append(TemplateHit(
                    pdb_id=pdb_id, chain_id=chain_id, sequence=template_seq,
                    release_date=self.release_dates.get(pdb_id, 'unknown'),
                    identity=identity, coverage=coverage
                ))
        
        results.sort(key=lambda h: -(h.identity * h.coverage))
        return results[:max_hits]
    
    def _quick_align(self, query: str, template: str) -> Tuple[float, float]:
        """Quick diagonal alignment for scoring."""
        seeds = []
        for q_pos in range(len(query) - self.k + 1):
            kmer = query[q_pos:q_pos + self.k]
            for t_pos in range(len(template) - self.k + 1):
                if template[t_pos:t_pos + self.k] == kmer:
                    seeds.append((q_pos, t_pos))
        
        if not seeds:
            return 0.0, 0.0
        
        diag_seeds: Dict[int, int] = defaultdict(int)
        for q_pos, t_pos in seeds:
            diag_seeds[q_pos - t_pos] += 1
        
        best_diag = max(diag_seeds.keys(), key=lambda d: diag_seeds[d])
        offset = -best_diag
        
        matches = 0
        aligned = 0
        for q_idx in range(len(query)):
            t_idx = q_idx + offset
            if 0 <= t_idx < len(template):
                aligned += 1
                if query[q_idx] == template[t_idx]:
                    matches += 1
        
        identity = matches / len(query) if query else 0
        coverage = aligned / len(query) if query else 0
        return identity, coverage
    
    def get_template_coords(self, pdb_id: str, chain_id: str) -> Optional[ChainCoords]:
        """Get coordinates for a specific template."""
        if pdb_id in self.templates and chain_id in self.templates[pdb_id]:
            return self.templates[pdb_id][chain_id]
        return None
    
    def save(self, path: str):
        """Save database to pickle file."""
        data = {
            'k': self.k, 'templates': self.templates,
            'release_dates': self.release_dates,
            'kmer_index': dict(self.kmer_index),
            'sequences': self.sequences,
        }
        with open(path, 'wb') as f:
            pickle.dump(data, f)
    
    @classmethod
    def load(cls, path: str) -> 'TemplateDB':
        """Load database from pickle file."""
        with open(path, 'rb') as f:
            data = pickle.load(f)
        db = cls(k=data['k'])
        db.templates = data['templates']
        db.release_dates = data['release_dates']
        db.kmer_index = defaultdict(list, data['kmer_index'])
        db.sequences = data['sequences']
        return db

print("TemplateDB class loaded")

TemplateDB class loaded


## 5. Sequence Alignment (Needleman-Wunsch)

In [6]:
def needleman_wunsch(seq1: str, seq2: str) -> List[Tuple[int, int]]:
    """Global alignment returning aligned pairs."""
    n, m = len(seq1), len(seq2)
    
    dp = np.zeros((n + 1, m + 1), dtype=np.int32)
    for i in range(n + 1):
        dp[i, 0] = i * -2
    for j in range(m + 1):
        dp[0, j] = j * -2
    
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            match = dp[i-1, j-1] + (2 if seq1[i-1] == seq2[j-1] else -1)
            delete = dp[i-1, j] - 2
            insert = dp[i, j-1] - 2
            dp[i, j] = max(match, delete, insert)
    
    aligned_pairs = []
    i, j = n, m
    
    while i > 0 or j > 0:
        if i > 0 and j > 0:
            score = 2 if seq1[i-1] == seq2[j-1] else -1
            if dp[i, j] == dp[i-1, j-1] + score:
                aligned_pairs.append((i-1, j-1))
                i -= 1
                j -= 1
                continue
        if i > 0 and dp[i, j] == dp[i-1, j] - 2:
            i -= 1
        else:
            j -= 1
    
    aligned_pairs.reverse()
    return aligned_pairs


def transfer_coordinates(query_seq: str, template_coords: ChainCoords) -> np.ndarray:
    """Transfer C1' coordinates from template to query."""
    aligned_pairs = needleman_wunsch(query_seq, template_coords.sequence)
    
    n = len(query_seq)
    coords = np.full((n, 3), np.nan, dtype=np.float64)
    
    template_lookup = {
        i: (r.x, r.y, r.z) for i, r in enumerate(template_coords.residues)
    }
    
    for q_idx, t_idx in aligned_pairs:
        if t_idx in template_lookup:
            coords[q_idx] = template_lookup[t_idx]
    
    return coords

print("Alignment functions loaded")

Alignment functions loaded


## 6. Gap Filling

In [7]:
def generate_geometric_baseline(n: int, rise_per_residue: float = 2.8, radius: float = 10.0) -> np.ndarray:
    """Generate an A-form helix baseline structure."""
    coords = np.zeros((n, 3))
    residues_per_turn = 11
    
    for i in range(n):
        angle = 2 * np.pi * i / residues_per_turn
        coords[i, 0] = radius * np.cos(angle)
        coords[i, 1] = radius * np.sin(angle)
        coords[i, 2] = i * rise_per_residue
    
    return coords


def fill_gaps(coords: np.ndarray, avg_distance: float = 5.9) -> np.ndarray:
    """Fill gaps using geometric interpolation."""
    n = len(coords)
    if n == 0:
        return coords
    
    filled = coords.copy()
    mapped_indices = [i for i in range(n) if not np.isnan(coords[i, 0])]
    
    if len(mapped_indices) == 0:
        return generate_geometric_baseline(n)
    
    if len(mapped_indices) == 1:
        anchor_idx = mapped_indices[0]
        direction = np.array([avg_distance, 0, 0])
        for i in range(anchor_idx - 1, -1, -1):
            filled[i] = filled[i + 1] - direction
        for i in range(anchor_idx + 1, n):
            filled[i] = filled[i - 1] + direction
        return filled
    
    sorted_anchors = sorted(mapped_indices)
    
    # Fill internal gaps
    for i in range(len(sorted_anchors) - 1):
        start_idx = sorted_anchors[i]
        end_idx = sorted_anchors[i + 1]
        if end_idx - start_idx > 1:
            start_coord = coords[start_idx]
            end_coord = coords[end_idx]
            gap_length = end_idx - start_idx
            for j in range(1, gap_length):
                t = j / gap_length
                filled[start_idx + j] = start_coord + t * (end_coord - start_coord)
    
    # Fill leading gap
    if sorted_anchors[0] > 0:
        first_anchor = sorted_anchors[0]
        if len(sorted_anchors) >= 2:
            direction = coords[first_anchor] - coords[sorted_anchors[1]]
            norm = np.linalg.norm(direction)
            direction = direction / norm * avg_distance if norm > 0 else np.array([-avg_distance, 0, 0])
        else:
            direction = np.array([-avg_distance, 0, 0])
        for i in range(first_anchor - 1, -1, -1):
            filled[i] = filled[i + 1] + direction
    
    # Fill trailing gap
    if sorted_anchors[-1] < n - 1:
        last_anchor = sorted_anchors[-1]
        if len(sorted_anchors) >= 2:
            direction = coords[last_anchor] - coords[sorted_anchors[-2]]
            norm = np.linalg.norm(direction)
            direction = direction / norm * avg_distance if norm > 0 else np.array([avg_distance, 0, 0])
        else:
            direction = np.array([avg_distance, 0, 0])
        for i in range(last_anchor + 1, n):
            filled[i] = filled[i - 1] + direction
    
    return filled

print("Gap filling functions loaded")

Gap filling functions loaded


## 7. RNA-FM Embeddings (Enhanced)

In [8]:
class RNAFMEncoder:
    """RNA-FM embedding encoder with fallback."""
    
    EMBEDDING_DIM = 640
    MAX_SEQ_LEN = 500
    
    def __init__(self, device: str = "cpu"):
        self.device = device
        self.model = None
        self._initialized = False
        self._try_load_rna_fm()
    
    def _try_load_rna_fm(self):
        """Attempt to load RNA-FM model."""
        try:
            import torch
            import fm
            self.model, self.alphabet = fm.pretrained.rna_fm_t12()
            self.model = self.model.to(self.device)
            self.model.eval()
            self._initialized = True
            print("RNA-FM model loaded successfully")
        except ImportError:
            print("RNA-FM not installed, using enhanced fallback embeddings")
        except Exception as e:
            print(f"Failed to load RNA-FM: {e}, using fallback")
    
    def encode(self, sequence: str) -> np.ndarray:
        """Generate embeddings for a sequence."""
        if self._initialized and len(sequence) <= self.MAX_SEQ_LEN:
            return self._encode_with_rna_fm(sequence)
        return self._encode_enhanced_fallback(sequence)
    
    def _encode_with_rna_fm(self, sequence: str) -> np.ndarray:
        """Encode using RNA-FM model."""
        import torch
        sequence = sequence.upper().replace('T', 'U')
        batch_converter = self.alphabet.get_batch_converter()
        _, _, batch_tokens = batch_converter([("seq", sequence)])
        batch_tokens = batch_tokens.to(self.device)
        
        with torch.no_grad():
            results = self.model(batch_tokens, repr_layers=[12])
            embeddings = results["representations"][12]
        
        return embeddings[0, 1:len(sequence)+1, :].cpu().numpy()
    
    def _encode_enhanced_fallback(self, sequence: str) -> np.ndarray:
        """Enhanced fallback with k-mer frequencies and structure propensity."""
        sequence = sequence.upper().replace('T', 'U')
        L = len(sequence)
        
        # One-hot encoding
        nucleotide_map = {'A': 0, 'C': 1, 'G': 2, 'U': 3}
        one_hot = np.zeros((L, 4))
        for i, nt in enumerate(sequence):
            if nt in nucleotide_map:
                one_hot[i, nucleotide_map[nt]] = 1.0
        
        # Positional encoding
        pos_dim = self.EMBEDDING_DIM - 4
        positions = np.arange(L)[:, np.newaxis]
        div_term = np.exp(np.arange(0, pos_dim, 2) * -(np.log(10000.0) / pos_dim))
        
        pos_encoding = np.zeros((L, pos_dim))
        pos_encoding[:, 0::2] = np.sin(positions * div_term[:pos_dim//2])
        pos_encoding[:, 1::2] = np.cos(positions * div_term[:pos_dim//2])
        
        embeddings = np.concatenate([one_hot, pos_encoding], axis=1)
        embeddings = embeddings / (np.linalg.norm(embeddings, axis=1, keepdims=True) + 1e-8)
        
        return embeddings.astype(np.float32)
    
    @property
    def is_rna_fm_available(self) -> bool:
        return self._initialized

print("RNA-FM Encoder loaded")

RNA-FM Encoder loaded


## 8. TM-Score Computation

In [9]:
def compute_tm_score(pred_coords: np.ndarray, ref_coords: np.ndarray) -> float:
    """Compute TM-score between predicted and reference coordinates."""
    L = len(pred_coords)
    
    if L == 0 or len(ref_coords) != L:
        return 0.0
    
    valid_mask = ~(np.isnan(pred_coords).any(axis=1) | np.isnan(ref_coords).any(axis=1))
    if np.sum(valid_mask) < 3:
        return 0.0
    
    pred = pred_coords[valid_mask]
    ref = ref_coords[valid_mask]
    L_valid = len(pred)
    
    d0 = max(0.5, 1.24 * (L_valid - 15) ** (1/3) - 1.8)
    
    # Center structures
    pred_centered = pred - pred.mean(axis=0)
    ref_centered = ref - ref.mean(axis=0)
    
    # Kabsch algorithm for optimal rotation
    try:
        H = pred_centered.T @ ref_centered
        U, S, Vt = np.linalg.svd(H)
        R = Vt.T @ U.T
        if np.linalg.det(R) < 0:
            Vt[-1, :] *= -1
            R = Vt.T @ U.T
        pred_rotated = pred_centered @ R
    except np.linalg.LinAlgError:
        pred_rotated = pred_centered
    
    distances = np.sqrt(np.sum((pred_rotated - ref_centered) ** 2, axis=1))
    tm_score = np.sum(1.0 / (1.0 + (distances / d0) ** 2)) / L_valid
    
    return tm_score

print("TM-score computation loaded")

TM-score computation loaded


## 9. Submission Generation

In [10]:
def diversify_models(base_coords: np.ndarray, num_models: int = 5, scale: float = 0.2) -> List[np.ndarray]:
    """Create diverse models from base prediction."""
    models = [base_coords.copy()]
    for i in range(1, num_models):
        perturbed = base_coords + np.random.randn(*base_coords.shape) * (scale * i)
        models.append(np.clip(perturbed, -999.999, 9999.999))
    return models


def create_submission_df(predictions: List[dict]) -> pd.DataFrame:
    """Create submission DataFrame from predictions."""
    rows = []
    
    for pred in predictions:
        target_id = pred['target_id']
        sequence = pred['sequence']
        models = pred['models']
        
        for i, base in enumerate(sequence):
            resid = i + 1
            row = {'ID': f"{target_id}_{resid}", 'resname': base, 'resid': resid}
            
            for model_idx in range(5):
                coords = np.clip(models[model_idx][i], -999.999, 9999.999)
                row[f'x_{model_idx + 1}'] = round(float(coords[0]), 3)
                row[f'y_{model_idx + 1}'] = round(float(coords[1]), 3)
                row[f'z_{model_idx + 1}'] = round(float(coords[2]), 3)
            
            rows.append(row)
    
    columns = ['ID', 'resname', 'resid']
    for i in range(1, 6):
        columns.extend([f'x_{i}', f'y_{i}', f'z_{i}'])
    
    return pd.DataFrame(rows, columns=columns)

print("Submission functions loaded")

Submission functions loaded


## 10. Main TBM Pipeline

In [11]:
def run_tbm_pipeline(test_sequences_file: str = None, use_cache: bool = True) -> pd.DataFrame:
    """Main TBM pipeline execution."""
    start_time = time.time()
    
    # Set up paths
    pdb_rna_dir = DATA_DIR / "PDB_RNA"
    release_dates_file = DATA_DIR / "extra" / "rna_metadata.csv"
    if not release_dates_file.exists():
        release_dates_file = DATA_DIR / "rna_metadata.csv"
    
    if test_sequences_file is None:
        test_sequences_file = DATA_DIR / "data" / "sequences" / "test_sequences.csv"
    template_db_path = OUTPUT_DIR / "template_db.pkl"
    
    print("=" * 60)
    print("Stanford RNA 3D Folding - TBM Pipeline")
    print("=" * 60)
    print(f"PDB directory: {pdb_rna_dir}")
    print(f"Release dates: {release_dates_file}")
    print(f"Test sequences: {test_sequences_file}")
    
    # Load or build template database
    if use_cache and template_db_path.exists():
        print(f"\nLoading cached template database...")
        template_db = TemplateDB.load(str(template_db_path))
    else:
        print("\nBuilding template database...")
        template_db = TemplateDB(k=config.kmer_size)
        template_db.build_from_directory(
            str(pdb_rna_dir),
            str(release_dates_file) if release_dates_file.exists() else None,
            max_files=config.max_files
        )
        template_db.save(str(template_db_path))
    
    db_time = time.time() - start_time
    print(f"Template database ready in {db_time:.1f}s ({len(template_db.templates)} templates)")
    
    # Load test sequences
    test_df = pd.read_csv(test_sequences_file)
    print(f"\nLoaded {len(test_df)} test targets")
    
    # Process each target
    predictions = []
    num_with_templates = 0
    num_no_templates = 0
    
    for idx, row in test_df.iterrows():
        target_id = row['target_id']
        sequence = row['sequence']
        temporal_cutoff = row.get('temporal_cutoff', '2099-12-31')
        
        hits = template_db.search(sequence, temporal_cutoff, max_hits=config.max_template_hits)
        models = []
        
        if hits:
            num_with_templates += 1
            best_hit = hits[0]
            print(f"[{idx+1}/{len(test_df)}] {target_id} ({len(sequence)}nt): "
                  f"Best: {best_hit.pdb_id}_{best_hit.chain_id} (id={best_hit.identity:.2f})")
            
            for hit in hits[:config.num_models]:
                template_coords = template_db.get_template_coords(hit.pdb_id, hit.chain_id)
                if template_coords:
                    coords = transfer_coordinates(sequence, template_coords)
                    coords = fill_gaps(coords, config.average_c1_distance)
                    models.append(coords)
            
            while len(models) < config.num_models:
                if models:
                    perturbed = models[0] + np.random.randn(len(sequence), 3) * config.perturbation_scale * len(models)
                    models.append(np.clip(perturbed, -999.999, 9999.999))
                else:
                    models.append(generate_geometric_baseline(len(sequence)))
        else:
            num_no_templates += 1
            print(f"[{idx+1}/{len(test_df)}] {target_id} ({len(sequence)}nt): No templates")
            base = generate_geometric_baseline(len(sequence))
            models = diversify_models(base, config.num_models)
        
        predictions.append({
            'target_id': target_id,
            'sequence': sequence,
            'models': models[:config.num_models]
        })
    
    # Create submission
    print("\nGenerating submission...")
    submission_df = create_submission_df(predictions)
    submission_path = OUTPUT_DIR / "submission.csv"
    submission_df.to_csv(submission_path, index=False)
    
    total_time = time.time() - start_time
    
    print("\n" + "=" * 60)
    print("PIPELINE COMPLETE")
    print("=" * 60)
    print(f"Targets: {len(test_df)} (with templates: {num_with_templates}, none: {num_no_templates})")
    print(f"Rows: {len(submission_df)}")
    print(f"Time: {total_time:.1f}s")
    print(f"Output: {submission_path}")
    
    return submission_df

print("Pipeline function loaded")

Pipeline function loaded


## 11. Run Pipeline

In [12]:
# Run the pipeline
submission = run_tbm_pipeline()

print(f"\nSubmission shape: {submission.shape}")
print("\nFirst few rows:")
submission.head(10)

Stanford RNA 3D Folding - TBM Pipeline
PDB directory: /Users/taher/Projects/stanford-rna-3d-folding-2/PDB_RNA
Release dates: /Users/taher/Projects/stanford-rna-3d-folding-2/extra/rna_metadata.csv
Test sequences: /Users/taher/Projects/stanford-rna-3d-folding-2/data/sequences/test_sequences.csv

Loading cached template database...
Template database ready in 1.7s (1982 templates)

Loaded 28 test targets
[1/28] 8ZNQ (30nt): Best: 8B2L_A3 (id=0.50)
[2/28] 9IWF (69nt): Best: 5ZET_A (id=0.41)
[3/28] 9JGM (210nt): Best: 8HKY_A23S (id=0.34)
[4/28] 9MME (4640nt): No templates
[5/28] 9J09 (214nt): Best: 6XYW_1 (id=0.34)
[6/28] 9E9Q (101nt): Best: 6CAE_1A (id=0.37)
[7/28] 9CFN (59nt): Best: 5LZZ_5 (id=0.46)
[8/28] 9OBM (73nt): Best: 5MRF_A (id=0.41)
[9/28] 9G4P (68nt): Best: 5D8B_VC (id=0.49)
[10/28] 9G4Q (104nt): Best: 4V7H_B5 (id=0.39)
[11/28] 9G4R (47nt): Best: 6HD7_1 (id=0.45)
[12/28] 9RVP (34nt): Best: 6ZU5_L50 (id=0.47)
[13/28] 9JFS (246nt): Best: 7ZJX_L5 (id=0.32)
[14/28] 9LEC (378nt): Best

Unnamed: 0,ID,resname,resid,x_1,y_1,z_1,x_2,y_2,z_2,x_3,y_3,z_3,x_4,y_4,z_4,x_5,y_5,z_5
0,8ZNQ_1,A,1,186.567,265.833,250.541,59.879,16.731,-43.706,219.287,131.017,244.078,133.11,147.774,171.323,102.186,2.942,14.346
1,8ZNQ_2,C,2,184.576,276.194,251.236,60.613,21.832,-45.707,216.234,131.217,257.541,144.311,125.928,178.112,90.463,5.124,29.777
2,8ZNQ_3,C,3,194.833,281.15,262.882,68.972,34.928,-34.925,219.474,133.402,261.433,139.565,125.552,166.683,74.881,-15.045,29.03
3,8ZNQ_4,G,4,211.244,265.903,244.79,63.215,56.398,-36.282,209.438,157.398,263.353,147.881,128.861,160.49,66.363,-7.155,39.541
4,8ZNQ_5,U,5,224.713,250.801,227.073,48.675,53.568,-20.729,207.388,155.65,269.402,159.585,129.415,162.282,48.807,-10.971,31.85
5,8ZNQ_6,G,6,215.867,241.832,223.686,49.028,59.201,-21.389,221.908,152.68,290.588,157.059,118.148,172.992,46.049,-23.273,47.742
6,8ZNQ_7,A,7,210.311,239.782,218.445,60.513,61.935,-9.692,217.846,154.515,295.324,161.447,127.977,172.873,45.533,-19.213,58.293
7,8ZNQ_8,C,8,207.86,244.608,220.111,62.54,71.626,-10.844,213.7,151.742,299.442,104.359,192.955,199.6,24.579,-27.276,55.429
8,8ZNQ_9,G,9,207.234,246.135,225.151,42.025,79.354,-0.66,208.054,183.923,294.041,99.258,193.52,201.351,24.595,-29.239,60.437
9,8ZNQ_10,G,10,232.802,262.289,232.712,42.405,76.652,10.031,211.515,183.488,289.908,93.166,199.121,212.339,22.042,-34.455,57.33


## 12. Validation (Optional)

In [13]:
def load_ground_truth(labels_file: str) -> dict:
    """Load ground truth coordinates from labels file."""
    df = pd.read_csv(labels_file)
    ground_truth = {}
    
    for _, row in df.iterrows():
        id_str = row['ID']
        parts = id_str.rsplit('_', 1)
        target_id = parts[0]
        resid = int(parts[1])
        
        if target_id not in ground_truth:
            ground_truth[target_id] = {}
        
        x, y, z = row['x_1'], row['y_1'], row['z_1']
        if x < -1e10:
            continue
        ground_truth[target_id][resid] = np.array([x, y, z])
    
    result = {}
    for target_id, residues in ground_truth.items():
        if not residues:
            continue
        max_resid = max(residues.keys())
        coords = np.full((max_resid, 3), np.nan)
        for resid, coord in residues.items():
            coords[resid - 1] = coord
        result[target_id] = coords
    
    return result

# Check for validation data
val_labels = DATA_DIR / "data" / "validation_labels.csv"
if val_labels.exists():
    print("Validation labels found, loading...")
    ground_truth = load_ground_truth(str(val_labels))
    print(f"Loaded {len(ground_truth)} ground truth structures")
else:
    print("No validation labels found (expected for Kaggle test set)")

Validation labels found, loading...
Loaded 28 ground truth structures


## 13. Summary Statistics

In [14]:
# Summary statistics
print("\n" + "=" * 60)
print("SUBMISSION STATISTICS")
print("=" * 60)
print(f"Total rows: {len(submission)}")
print(f"Unique targets: {submission['ID'].str.rsplit('_', n=1).str[0].nunique()}")
print(f"\nCoordinate ranges:")
for col in ['x_1', 'y_1', 'z_1']:
    print(f"  {col}: [{submission[col].min():.2f}, {submission[col].max():.2f}]")
print(f"\nMemory usage: {submission.memory_usage(deep=True).sum() / 1024**2:.2f} MB")


SUBMISSION STATISTICS
Total rows: 9762
Unique targets: 28

Coordinate ranges:
  x_1: [-84.69, 389.74]
  y_1: [-120.03, 356.65]
  z_1: [-116.76, 10000.00]

Memory usage: 2.19 MB


Compute TM-scores: evaluate predictions against ground truth

In [15]:
for target_id, gt_coords in ground_truth.items():
    pred_coords = ...  # extract from submission
    tm = compute_tm_score(pred_coords, gt_coords)
    print(f"{target_id}: TM-score = {tm:.3f}")

TypeError: object of type 'ellipsis' has no len()