# RNA Structure Visualization with TM-score Analysis

This notebook provides tools for visualizing RNA 3D structures predicted by our models and analyzing their quality using TM-score metrics.

In [None]:
import os
import sys
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
try:
    import py3Dmol
except ImportError:
    !pip install py3Dmol
    import py3Dmol

# Add the src directory to the path to import our modules
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

# Import our modules
from src.stanford_rna_folding.models.rna_folding_model import RNAFoldingModel
from src.stanford_rna_folding.data.data_processing import StanfordRNADataset, rna_collate_fn
from src.stanford_rna_folding.evaluation.metrics import batch_rmsd, batch_tm_score, tm_score, structure_core_identification
from src.stanford_rna_folding.evaluation.alignment import save_pdb_from_coords, tm_score_with_us_align

In [None]:
# Select the compute device: CUDA when PyTorch reports an available GPU,
# otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

## Load Model and Data

In [None]:
def load_config(config_path):
    """Load configuration from a YAML file.

    Args:
        config_path: Path of the YAML file to read (anything accepted by
            the built-in ``open``).

    Returns:
        The parsed document exactly as returned by ``yaml.safe_load``.

    Raises:
        FileNotFoundError: If ``config_path`` does not exist.
    """
    import yaml

    # Pin the encoding so that parsing does not depend on the platform's
    # locale default (which is not UTF-8 everywhere).
    with open(config_path, 'r', encoding='utf-8') as f:
        return yaml.safe_load(f)

# Choose a config file. The path is relative, so it is resolved against the
# notebook's current working directory.
config_path = "../configs/biophysics_config.yaml"
# `config` is kept as a notebook-level variable; load_model() reads
# config['model']['params'] from it.
config = load_config(config_path)
print(f"Loaded configuration from {config_path}")

In [None]:
def load_model(checkpoint_path, config):
    """Rebuild an ``RNAFoldingModel`` and restore its weights from disk.

    Args:
        checkpoint_path: Path of a checkpoint readable by ``torch.load``.
            It must contain a ``'model_state_dict'`` entry; ``'val_rmsd'``
            and ``'val_tm_score'`` are reported when present.
        config: Configuration mapping; the constructor keyword arguments
            are taken from ``config['model']['params']``.

    Returns:
        The model, moved to the notebook-level ``device`` and switched to
        evaluation mode.
    """
    # Build the network from the hyper-parameters stored in the config
    net = RNAFoldingModel(**config['model']['params'])

    # Deserialize the checkpoint onto the notebook-level device and restore
    # the trained weights
    state = torch.load(checkpoint_path, map_location=device)
    net.load_state_dict(state['model_state_dict'])
    net = net.to(device)
    net.eval()

    # Report where the weights came from and the validation scores that
    # were saved alongside them (if any)
    best_rmsd = state.get('val_rmsd', 'N/A')
    best_tm = state.get('val_tm_score', 'N/A')
    print(f"Loaded model from {checkpoint_path}")
    print(f"Best validation RMSD: {best_rmsd}")
    print(f"Best validation TM-score: {best_tm}")

    return net

# Choose a checkpoint file
checkpoint_path = "../models/best_model_tm_score.pt"
# Only a missing file is handled here. In that case `model` is left
# undefined, and the later cells that use it catch the resulting NameError.
try:
    model = load_model(checkpoint_path, config)
except FileNotFoundError:
    print(f"Checkpoint file {checkpoint_path} not found. You can specify the correct path above.")

In [None]:
def load_dataset(data_dir, split):
    """Build the ``StanfordRNADataset`` for one data split.

    Args:
        data_dir: Directory passed to the dataset as ``data_dir``.
        split: Name of the split passed to the dataset as ``split``.

    Returns:
        The constructed dataset, configured with an ``RNADataTransform``
        created with ``normalize_coords=True``.
    """
    from src.stanford_rna_folding.data.transforms import RNADataTransform

    # The transform is constructed inline and handed to the dataset
    dataset = StanfordRNADataset(
        data_dir=data_dir,
        split=split,
        transform=RNADataTransform(normalize_coords=True),
    )

    print(f"Loaded {split} dataset with {len(dataset)} examples")
    return dataset

# Load the validation dataset
data_dir = "../datasets/stanford-rna-3d-folding"
# Only FileNotFoundError is handled here. When it is raised, `val_dataset`
# stays undefined and the cells that use it report that via their own
# NameError handler.
try:
    val_dataset = load_dataset(data_dir, "validation")
except FileNotFoundError:
    print(f"Data directory {data_dir} not found. You can specify the correct path above.")

## Predict Structures

In [None]:
def predict_structure(model, sequence):
    """Predict the 3D structure for a RNA sequence.

    Args:
        model: Model invoked as ``model(seq_tensor, seq_length)``. Inputs
            are placed on the device of its first parameter.
        sequence: Either a nucleotide string or an already-encoded tensor.
            Strings are matched case-insensitively against A/U/G/C/N; any
            other character is encoded as N. A 1-D tensor gets a batch
            dimension added; tensors of any other rank are passed through
            unchanged.

    Returns:
        The first element of the model output (batch dimension removed).
    """
    # Map nucleotides to indices (A=0, U=1, G=2, C=3, N=4)
    nucleotide_map = {'A': 0, 'U': 1, 'G': 2, 'C': 3, 'N': 4}

    if isinstance(sequence, str):
        # Upper-case per character so lowercase input ("acgu") is not
        # silently encoded as all-N, while the sequence length is preserved.
        encoded = [nucleotide_map.get(n.upper(), 4) for n in sequence]
        seq_tensor = torch.tensor(encoded, dtype=torch.long).unsqueeze(0)
    else:
        # Assume it's already a tensor
        seq_tensor = sequence.unsqueeze(0) if sequence.dim() == 1 else sequence

    # Run on the device the model actually lives on, so the inputs cannot
    # end up on a different device than the weights. The notebook-level
    # `device` is only used for a model without parameters.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else device

    model.eval()
    with torch.no_grad():
        seq_tensor = seq_tensor.to(run_device)
        seq_length = torch.tensor(
            [seq_tensor.size(1)], dtype=torch.long, device=run_device
        )

        # Forward pass
        pred_coords = model(seq_tensor, seq_length)

    return pred_coords[0]  # Remove batch dimension

In [None]:
# Draw one validation example uniformly at random and run the model on it
try:
    example_idx = np.random.randint(len(val_dataset))
    example = val_dataset[example_idx]

    # Pull out the fields used by the rest of the notebook
    sequence, true_coords, target_id = (
        example[field] for field in ('sequence', 'coordinates', 'target_id')
    )

    # Summarize the selected example
    print(f"Example {example_idx}: {target_id}")
    print(f"Sequence length: {len(sequence)}")
    print(f"Coordinates shape: {true_coords.shape}")

    # Run the model on the selected sequence
    pred_coords = predict_structure(model, sequence)
    print(f"Predicted coordinates shape: {pred_coords.shape}")
except NameError:
    print("Dataset or model not loaded. Please run the previous cells first.")

## Calculate TM-scores

In [None]:
def calculate_metrics(pred_coords, true_coords):
    """Calculate RMSD and TM-score for a predicted structure.

    Args:
        pred_coords: Predicted coordinates as a torch tensor; it is
            flattened to ``(-1, 3)`` for the core identification.
        true_coords: Reference coordinates as a torch tensor, flattened the
            same way. Its flattened length is used as the reference length
            for the ``d0`` scale.

    Returns:
        A dict with:
            'rmsd': aligned RMSD as a Python number,
            'tm_score': aligned TM-score as a Python number,
            'core_mask': the value returned by
                ``structure_core_identification`` for the flattened
                coordinates and ``d0``.
    """
    from src.stanford_rna_folding.evaluation.metrics import rmsd, tm_score

    # .cpu() is a no-op for tensors already on the CPU, and it also covers
    # accelerators other than CUDA.
    pred_coords = pred_coords.cpu()
    true_coords = true_coords.cpu()

    # Calculate RMSD and TM-score (both after alignment)
    rmsd_value = rmsd(pred_coords, true_coords, align=True).item()
    tm_value = tm_score(pred_coords, true_coords, align=True).item()

    # Flatten coordinates for structure core identification
    pred_flat = pred_coords.reshape(-1, 3)
    true_flat = true_coords.reshape(-1, 3)

    # Calculate d0 parameter for TM-score, floored at 0.5.
    # For Lref < 15 the base (Lref - 15) is negative, and a negative number
    # raised to a fractional power is complex in Python; comparing that with
    # 0.5 raises TypeError. Short structures therefore take the floor
    # directly.
    Lref = true_flat.shape[0]
    if Lref > 15:
        d0 = max(1.24 * (Lref - 15) ** (1 / 3) - 1.8, 0.5)
    else:
        d0 = 0.5

    # Identify core regions (delegated to the project helper)
    core_mask = structure_core_identification(pred_flat, true_flat, d0)

    return {
        'rmsd': rmsd_value,
        'tm_score': tm_value,
        'core_mask': core_mask
    }

In [None]:
try:
    # Score the prediction against the reference structure
    metrics = calculate_metrics(pred_coords, true_coords)

    print(f"RMSD: {metrics['rmsd']:.4f}")
    print(f"TM-score: {metrics['tm_score']:.4f}")

    # Size of the identified core relative to the whole mask
    _core_count = metrics['core_mask'].sum().item()
    _core_total = len(metrics['core_mask'])
    print(f"Number of core residues: {_core_count}/{_core_total}")
    print(f"Core percentage: {_core_count / _core_total * 100:.1f}%")
except NameError:
    print("Predicted coordinates not available. Please run the previous cells first.")

## Visualize Structures

In [None]:
def save_structures_as_pdb(pred_coords, true_coords, prefix="structure",
                           output_dir="output_structures"):
    """Save predicted and true structures as PDB files.

    Args:
        pred_coords: Predicted coordinates, forwarded to
            ``save_pdb_from_coords``.
        true_coords: Reference coordinates, forwarded to
            ``save_pdb_from_coords``.
        prefix: File-name stem; the files are named ``<prefix>_pred.pdb``
            and ``<prefix>_true.pdb``.
        output_dir: Directory that receives the files; created if it does
            not exist. Defaults to ``"output_structures"``, the location
            that used to be hard-coded.

    Returns:
        Tuple ``(pred_pdb_path, true_pdb_path)`` with the paths written.
    """
    # Create the output directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)

    pred_pdb_path = os.path.join(output_dir, f"{prefix}_pred.pdb")
    true_pdb_path = os.path.join(output_dir, f"{prefix}_true.pdb")

    # Write the prediction first, then the reference
    save_pdb_from_coords(pred_coords, pred_pdb_path)
    save_pdb_from_coords(true_coords, true_pdb_path)

    return pred_pdb_path, true_pdb_path

In [None]:
def visualize_structures(pred_coords, true_coords):
    """Visualize predicted and true structures with py3Dmol.

    Both structures are written to PDB files through
    ``save_structures_as_pdb`` and then loaded into a single viewer as two
    separate models.

    Args:
        pred_coords: Predicted coordinates (torch tensor or array-like).
        true_coords: Reference coordinates (torch tensor or array-like).

    Returns:
        The ``py3Dmol`` view, with the prediction drawn in blue and the
        reference in green. The caller is responsible for displaying it.
    """
    # Convert to numpy arrays if needed
    if isinstance(pred_coords, torch.Tensor):
        pred_coords = pred_coords.detach().cpu().numpy()
    if isinstance(true_coords, torch.Tensor):
        true_coords = true_coords.detach().cpu().numpy()

    # Save structures to PDB files
    pred_path, true_path = save_structures_as_pdb(pred_coords, true_coords)

    # Read PDB files
    with open(pred_path, 'r') as f:
        pred_pdb = f.read()
    with open(true_path, 'r') as f:
        true_pdb = f.read()

    # Create a py3Dmol viewer
    view = py3Dmol.view(width=800, height=400)

    # Add the structures as separate models. In selections, 3Dmol addresses
    # models by creation order, so the prediction is model 0 and the
    # reference is model 1. String labels such as 'pred' are not valid
    # model selectors.
    view.addModel(pred_pdb, 'pdb')
    view.addModel(true_pdb, 'pdb')

    # Style the prediction (blue)
    view.setStyle({'model': 0}, {'cartoon': {'color': 'blue', 'thickness': 0.8}})

    # Style the true structure (green)
    view.setStyle({'model': 1}, {'cartoon': {'color': 'green', 'thickness': 0.8}})

    # Set up the viewer
    view.zoomTo()

    return view

In [None]:
try:
    # Visualize the structures
    view = visualize_structures(pred_coords, true_coords)
    # IPython only auto-displays an expression that is the last top-level
    # statement of a cell. A bare `view` nested inside `try` renders
    # nothing, so the viewer is shown explicitly.
    view.show()
except NameError:
    print("Predicted coordinates not available. Please run the previous cells first.")

## Conclusion and Next Steps

In this notebook, we've demonstrated how to visualize predicted RNA structures alongside their references and how to assess prediction quality with RMSD and TM-score. Because the TM-score is normalized by structure length and down-weights large local deviations, it is a more robust measure than RMSD of the overall structural similarity between predicted and reference RNA structures.

Next steps could include:
1. Analyzing the performance of our model on different RNA families
2. Enhancing the visualization to highlight regions with high/low TM-score contributions
3. Using these insights to improve the model's performance on challenging structures