In [3]:
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.metrics import mean_absolute_error, r2_score
import numpy as np
import math
from torch.optim.lr_scheduler import ReduceLROnPlateau, CosineAnnealingLR
import warnings
warnings.filterwarnings('ignore')

# ----------------- Load Dataset -----------------
df = pd.read_csv("new_codon_torsions_freq.csv")
df = df.dropna(subset=["Phi", "Psi", "Omega"]).reset_index(drop=True)

print(f"Dataset shape: {df.shape}")
print(f"Unique amino acids: {df['AA_from_cDNA'].nunique()}")
print(f"Unique codons: {df['Codon'].nunique()}")

# ----------------- Enhanced Data Preprocessing -----------------
# Handle Relative Frequency with better normalization
if "Relative_Frequency" not in df.columns:
    df["Relative_Frequency"] = 0.0

df["Relative_Frequency"] = df["Relative_Frequency"].fillna(0.0)
df["Relative_Frequency"] = df["Relative_Frequency"].clip(0, 1).astype(np.float32)

# Convert angles to radians and use circular encoding
def circular_encode(angles):
    """Convert angles to sin/cos representation for circular data"""
    angles_rad = np.radians(angles)
    return np.cos(angles_rad), np.sin(angles_rad)

# Circular encoding for torsion angles
phi_cos, phi_sin = circular_encode(df["Phi"])
psi_cos, psi_sin = circular_encode(df["Psi"])
omega_cos, omega_sin = circular_encode(df["Omega"])

df["Phi_cos"], df["Phi_sin"] = phi_cos, phi_sin
df["Psi_cos"], df["Psi_sin"] = psi_cos, psi_sin
df["Omega_cos"], df["Omega_sin"] = omega_cos, omega_sin

# Normalize original angles to [-1, 1] for auxiliary loss
df[["Phi_norm", "Psi_norm", "Omega_norm"]] = df[["Phi", "Psi", "Omega"]] / 180.0

# Label encoding with proper handling
AA_encoder = LabelEncoder()
Codon_encoder = LabelEncoder()

df["AA_encoded"] = AA_encoder.fit_transform(df["AA_from_cDNA"])
df["Codon_encoded"] = Codon_encoder.fit_transform(df["Codon"])

# Add codon-AA consistency feature
df["Codon_AA_consistent"] = (df.apply(lambda row: 
    row["AA_from_cDNA"] in ["A", "R", "N", "D", "C", "Q", "E", "G", "H", "I", "L", "K", "M", "F", "P", "S", "T", "W", "Y", "V"], axis=1)).astype(float)

# ----------------- Enhanced Dataset Class -----------------
class EnhancedProteinDataset(Dataset):
    def __init__(self, df, mode="both"):
        self.mode = mode
        self.aa = torch.tensor(df["AA_encoded"].values, dtype=torch.long)
        self.codon = torch.tensor(df["Codon_encoded"].values, dtype=torch.long)
        self.rfreq = torch.tensor(df["Relative_Frequency"].values, dtype=torch.float32).unsqueeze(1)
        self.consistency = torch.tensor(df["Codon_AA_consistent"].values, dtype=torch.float32).unsqueeze(1)
        
        # Circular targets
        self.y_circular = torch.tensor(df[["Phi_cos", "Phi_sin", "Psi_cos", "Psi_sin", "Omega_cos", "Omega_sin"]].values, dtype=torch.float32)
        # Original angles (normalized) for auxiliary loss
        self.y_angles = torch.tensor(df[["Phi_norm", "Psi_norm", "Omega_norm"]].values, dtype=torch.float32)

    def __len__(self):
        return len(self.y_circular)

    def __getitem__(self, idx):
        if self.mode == "aa":
            # AA with padding features to match dimensionality
            return (self.aa[idx], self.rfreq[idx], self.consistency[idx]), (self.y_circular[idx], self.y_angles[idx])
        elif self.mode == "codon":
            # Codon with relative frequency and consistency
            return (self.codon[idx], self.rfreq[idx], self.consistency[idx]), (self.y_circular[idx], self.y_angles[idx])
        else:  # both
            return (self.aa[idx], self.codon[idx], self.rfreq[idx], self.consistency[idx]), (self.y_circular[idx], self.y_angles[idx])

# ----------------- Positional Encoding for better embeddings -----------------
class PositionalEmbedding(nn.Module):
    def __init__(self, vocab_size, embed_dim, max_norm=1.0):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, max_norm=max_norm)
        self.layer_norm = nn.LayerNorm(embed_dim)
        
    def forward(self, x):
        emb = self.embedding(x)
        return self.layer_norm(emb)

# ----------------- Enhanced Model with Attention -----------------
class EnhancedAnglePredictor(nn.Module):
    def __init__(self, n_aa, n_codon, embed_dim=64, hidden_dim=256, n_heads=8, n_layers=3, dropout=0.1, mode="both"):
        super().__init__()
        self.mode = mode
        self.embed_dim = embed_dim
        
        # Always create both embeddings to ensure same parameter count
        self.aa_embed = PositionalEmbedding(n_aa, embed_dim)
        self.codon_embed = PositionalEmbedding(n_codon, embed_dim)
        
        if mode == "aa":
            # AA embedding + relative frequency + consistency
            self.feature_fusion = nn.Sequential(
                nn.Linear(embed_dim + 2, embed_dim),  # +2 for rfreq and consistency
                nn.LayerNorm(embed_dim),
                nn.ReLU(),
                nn.Dropout(dropout)
            )
            in_dim = embed_dim
            
        elif mode == "codon":
            # Codon embedding + relative frequency + consistency
            self.feature_fusion = nn.Sequential(
                nn.Linear(embed_dim + 2, embed_dim),  # +2 for rfreq and consistency
                nn.LayerNorm(embed_dim),
                nn.ReLU(),
                nn.Dropout(dropout)
            )
            in_dim = embed_dim
            
        else:  # both
            # Cross-attention between AA and codon embeddings
            self.cross_attention = nn.MultiheadAttention(embed_dim, n_heads, dropout=dropout, batch_first=True)
            
            # Feature fusion for all inputs
            self.feature_fusion = nn.Sequential(
                nn.Linear(embed_dim * 2 + 2, embed_dim),  # +2 for rfreq and consistency
                nn.LayerNorm(embed_dim),
                nn.ReLU(),
                nn.Dropout(dropout)
            )
            in_dim = embed_dim

        # Enhanced MLP with residual connections (same for all modes)
        self.layers = nn.ModuleList()
        current_dim = in_dim
        
        for i in range(n_layers):
            self.layers.append(nn.Sequential(
                nn.Linear(current_dim, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(hidden_dim, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.ReLU(),
                nn.Dropout(dropout)
            ))
            
            if i == 0:
                self.input_projection = nn.Linear(current_dim, hidden_dim) if current_dim != hidden_dim else nn.Identity()
            current_dim = hidden_dim

        # Output heads (same for all modes)
        self.circular_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, 6)  # 2 per angle (cos, sin)
        )
        
        self.angle_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, 3)  # Direct angle prediction
        )

    def forward(self, x):
        if self.mode == "aa":
            aa, rfreq, consistency = x
            aa_emb = self.aa_embed(aa)
            # Combine AA embedding with additional features
            combined = torch.cat([aa_emb, rfreq, consistency], dim=1)
            emb = self.feature_fusion(combined)
            
        elif self.mode == "codon":
            codon, rfreq, consistency = x
            codon_emb = self.codon_embed(codon)
            # Combine codon embedding with additional features
            combined = torch.cat([codon_emb, rfreq, consistency], dim=1)
            emb = self.feature_fusion(combined)
            
        else:  # both
            aa, codon, rfreq, consistency = x
            aa_emb = self.aa_embed(aa).unsqueeze(1)  # Add sequence dimension
            codon_emb = self.codon_embed(codon).unsqueeze(1)
            
            # Cross-attention between AA and codon
            attended_aa, _ = self.cross_attention(aa_emb, codon_emb, codon_emb)
            attended_aa = attended_aa.squeeze(1)
            codon_emb = codon_emb.squeeze(1)
            
            # Combine all features
            combined = torch.cat([attended_aa, codon_emb, rfreq, consistency], dim=1)
            emb = self.feature_fusion(combined)

        # Pass through enhanced MLP with residual connections (same for all modes)
        x = emb
        for i, layer in enumerate(self.layers):
            if i == 0:
                residual = self.input_projection(x)
            else:
                residual = x
            x = layer(x) + residual

        # Dual outputs (same for all modes)
        circular_out = self.circular_head(x)
        angle_out = self.angle_head(x)
        
        # Normalize circular outputs to unit vectors
        circular_out = circular_out.view(-1, 3, 2)  # Reshape to (batch, 3_angles, 2_components)
        circular_out = F.normalize(circular_out, p=2, dim=2)
        circular_out = circular_out.view(-1, 6)  # Flatten back
        
        return circular_out, angle_out

# ----------------- Custom Loss Function -----------------
class CircularAngleLoss(nn.Module):
    def __init__(self, circular_weight=1.0, angle_weight=0.5):
        super().__init__()
        self.circular_weight = circular_weight
        self.angle_weight = angle_weight
        self.mse_loss = nn.MSELoss()
        
    def circular_loss(self, pred_circular, target_circular):
        """Loss for circular representations"""
        return self.mse_loss(pred_circular, target_circular)
    
    def angle_loss(self, pred_angles, target_angles):
        """Auxiliary loss for direct angle prediction"""
        return self.mse_loss(pred_angles, target_angles)
    
    def forward(self, pred_circular, pred_angles, target_circular, target_angles):
        circ_loss = self.circular_loss(pred_circular, target_circular)
        ang_loss = self.angle_loss(pred_angles, target_angles)
        return self.circular_weight * circ_loss + self.angle_weight * ang_loss

# ----------------- Enhanced Training with Better Techniques -----------------
def train_enhanced_model(df, mode="both", epochs=100, batch_size=64, patience=15):
    dataset = EnhancedProteinDataset(df, mode=mode)
    
    # Stratified split to ensure balanced distribution
    train_idx, val_idx = train_test_split(
        np.arange(len(dataset)), 
        test_size=0.2, 
        random_state=42,
        stratify=df["AA_encoded"]  # Stratify by amino acid
    )
    
    train_ds = torch.utils.data.Subset(dataset, train_idx)
    val_ds = torch.utils.data.Subset(dataset, val_idx)

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)
    val_loader = DataLoader(val_ds, batch_size=batch_size, num_workers=2, pin_memory=True)

    model = EnhancedAnglePredictor(
        n_aa=len(AA_encoder.classes_),
        n_codon=len(Codon_encoder.classes_),
        mode=mode
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    print(f"Using device: {device}")
    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

    criterion = CircularAngleLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
    
    # Learning rate scheduling
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=8, verbose=True)
    
    best_val_loss = float('inf')
    patience_counter = 0
    
    train_losses, val_losses = [], []

    for epoch in range(epochs):
        # Training
        model.train()
        total_loss = 0
        num_batches = 0
        
        for batch in train_loader:
            optimizer.zero_grad()

            # Consistent input handling for all modes
            if mode == "both":
                (aa, codon, rfreq, consistency), (y_circular, y_angles) = batch
                inputs = (aa.to(device), codon.to(device), rfreq.to(device), consistency.to(device))
            else:  # aa or codon mode
                (x, rfreq, consistency), (y_circular, y_angles) = batch
                inputs = (x.to(device), rfreq.to(device), consistency.to(device))
            
            y_circular, y_angles = y_circular.to(device), y_angles.to(device)
            
            pred_circular, pred_angles = model(inputs)
            loss = criterion(pred_circular, pred_angles, y_circular, y_angles)
            
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            
            total_loss += loss.item()
            num_batches += 1

        avg_train_loss = total_loss / num_batches
        train_losses.append(avg_train_loss)

        # Validation
        model.eval()
        val_loss = 0
        num_val_batches = 0
        
        with torch.no_grad():
            for batch in val_loader:
                # Consistent input handling for all modes
                if mode == "both":
                    (aa, codon, rfreq, consistency), (y_circular, y_angles) = batch
                    inputs = (aa.to(device), codon.to(device), rfreq.to(device), consistency.to(device))
                else:  # aa or codon mode
                    (x, rfreq, consistency), (y_circular, y_angles) = batch
                    inputs = (x.to(device), rfreq.to(device), consistency.to(device))
                
                y_circular, y_angles = y_circular.to(device), y_angles.to(device)
                
                pred_circular, pred_angles = model(inputs)
                loss = criterion(pred_circular, pred_angles, y_circular, y_angles)
                
                val_loss += loss.item()
                num_val_batches += 1

        avg_val_loss = val_loss / num_val_batches
        val_losses.append(avg_val_loss)
        
        scheduler.step(avg_val_loss)
        
        # # Early stopping
        # if avg_val_loss < best_val_loss:
        #     best_val_loss = avg_val_loss
        #     patience_counter = 0
        #     # Save best model
        #     torch.save(model.state_dict(), f'best_model_{mode}.pth')
        # else:
        #     patience_counter += 1
            
        # if patience_counter >= patience:
        #     print(f"Early stopping at epoch {epoch+1}")
        #     break

        if epoch % 10 == 0 or epoch == epochs - 1:
            print(f"Epoch {epoch+1}: Train Loss = {avg_train_loss:.4f}, Val Loss = {avg_val_loss:.4f}")

    # Load best model
    model.load_state_dict(torch.load(f'best_model_{mode}.pth'))
    return model, train_losses, val_losses

# ----------------- Evaluation Function -----------------
def evaluate_model(model, df, mode="both"):
    dataset = EnhancedProteinDataset(df, mode=mode)
    loader = DataLoader(dataset, batch_size=64, shuffle=False)
    
    device = next(model.parameters()).device
    model.eval()
    
    all_pred_angles = []
    all_true_angles = []
    
    with torch.no_grad():
        for batch in loader:
            # Consistent input handling for all modes
            if mode == "both":
                (aa, codon, rfreq, consistency), (y_circular, y_angles) = batch
                inputs = (aa.to(device), codon.to(device), rfreq.to(device), consistency.to(device))
            else:  # aa or codon mode
                (x, rfreq, consistency), (y_circular, y_angles) = batch
                inputs = (x.to(device), rfreq.to(device), consistency.to(device))
            
            pred_circular, pred_angles = model(inputs)
            
            # Convert circular predictions back to angles
            pred_circular = pred_circular.view(-1, 3, 2)
            pred_angles_from_circular = torch.atan2(pred_circular[:, :, 1], pred_circular[:, :, 0]) * 180 / math.pi
            
            all_pred_angles.append(pred_angles_from_circular.cpu())
            all_true_angles.append((y_angles * 180).cpu())  # Convert back to degrees
    
    pred_angles = torch.cat(all_pred_angles, dim=0).numpy()
    true_angles = torch.cat(all_true_angles, dim=0).numpy()
    
    # Calculate metrics
    mae_phi = mean_absolute_error(true_angles[:, 0], pred_angles[:, 0])
    mae_psi = mean_absolute_error(true_angles[:, 1], pred_angles[:, 1])
    mae_omega = mean_absolute_error(true_angles[:, 2], pred_angles[:, 2])
    
    r2_phi = r2_score(true_angles[:, 0], pred_angles[:, 0])
    r2_psi = r2_score(true_angles[:, 1], pred_angles[:, 1])
    r2_omega = r2_score(true_angles[:, 2], pred_angles[:, 2])
    
    print(f"\nEvaluation Results for {mode.upper()} model:")
    print(f"Phi   - MAE: {mae_phi:.2f}°, R²: {r2_phi:.3f}")
    print(f"Psi   - MAE: {mae_psi:.2f}°, R²: {r2_psi:.3f}")
    print(f"Omega - MAE: {mae_omega:.2f}°, R²: {r2_omega:.3f}")
    print(f"Average MAE: {(mae_phi + mae_psi + mae_omega)/3:.2f}°")
    
    return {
        'mae': [mae_phi, mae_psi, mae_omega],
        'r2': [r2_phi, r2_psi, r2_omega],
        'predictions': pred_angles,
        'true_values': true_angles
    }

# ----------------- Run Enhanced Experiments -----------------
if __name__ == "__main__":
    print("="*60)
    print("ENHANCED PROTEIN TORSION ANGLE PREDICTION")
    print("="*60)
    
    # Train models with different configurations
    print("\n1. Training with Amino Acids only...")
    model_aa, train_loss_aa, val_loss_aa = train_enhanced_model(df, mode="aa", epochs=100)
    results_aa = evaluate_model(model_aa, df, mode="aa")
    
    print("\n2. Training with Codons only...")
    model_codon, train_loss_codon, val_loss_codon = train_enhanced_model(df, mode="codon", epochs=100)
    results_codon = evaluate_model(model_codon, df, mode="codon")
    
    print("\n3. Training with Both (Enhanced Architecture)...")
    model_both, train_loss_both, val_loss_both = train_enhanced_model(df, mode="both", epochs=100)
    results_both = evaluate_model(model_both, df, mode="both")
    
    print("\n" + "="*60)
    print("FINAL COMPARISON:")
    print("="*60)
    print(f"AA Only    - Average MAE: {sum(results_aa['mae'])/3:.2f}°")
    print(f"Codon Only - Average MAE: {sum(results_codon['mae'])/3:.2f}°")
    print(f"Both       - Average MAE: {sum(results_both['mae'])/3:.2f}°")

Dataset shape: (6635, 12)
Unique amino acids: 20
Unique codons: 61
ENHANCED PROTEIN TORSION ANGLE PREDICTION

1. Training with Amino Acids only...
Using device: cuda
Model parameters: 442,121
Epoch 1: Train Loss = 0.5179, Val Loss = 0.4896
Epoch 11: Train Loss = 0.4839, Val Loss = 0.4782
Epoch 21: Train Loss = 0.4818, Val Loss = 0.4782
Epoch 31: Train Loss = 0.4817, Val Loss = 0.4783
Epoch 41: Train Loss = 0.4810, Val Loss = 0.4782
Epoch 51: Train Loss = 0.4811, Val Loss = 0.4780
Epoch 61: Train Loss = 0.4803, Val Loss = 0.4778
Epoch 71: Train Loss = 0.4797, Val Loss = 0.4779
Epoch 81: Train Loss = 0.4805, Val Loss = 0.4778
Epoch 91: Train Loss = 0.4796, Val Loss = 0.4778
Epoch 100: Train Loss = 0.4800, Val Loss = 0.4778

Evaluation Results for AA model:
Phi   - MAE: 31.26°, R²: -0.409
Psi   - MAE: 67.41°, R²: -0.196
Omega - MAE: 132.11°, R²: -0.602
Average MAE: 76.93°

2. Training with Codons only...
Using device: cuda
Model parameters: 442,121
Epoch 1: Train Loss = 0.5382, Val Loss =