In [71]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATConv
from torch_geometric.data import Data

In [51]:
train_data = pd.read_csv("./data/train_sequences.csv")
train_labels = pd.read_csv("./data/train_labels.csv")
val_data = pd.read_csv("./data/validation_sequences.csv")
val_labels = pd.read_csv("./data/validation_labels.csv")


In [52]:
def preprocess_data(data_df: pd.DataFrame, labels_df: pd.DataFrame):
    """
    Converts RNA sequences from a dataframe (data_df) and their corresponding 3D coordinates 
    from labels_df into a list of PyTorch Geometric Data objects.
    
    - data_df should have columns: target_id, sequence, etc.
    - labels_df should have columns: ID, resname, resid, x_1, y_1, z_1.
    - For each target_id in data_df, rows in labels_df with IDs starting 
      with "target_id_" correspond to that sequence.
      
    Unexpected nucleotides such as '-' or 'X' are mapped to an unknown category.
    Rows with NaN coordinate values are dropped.
    The 3D coordinates are scaled per sequence (zero mean and unit variance),
    and the scaling parameters are stored for later inversion.
    """
    data_list = []
    
    # Define mapping: valid nucleotides plus an unknown category for '-' or 'X'
    nucleotides = {'A': 0, 'U': 1, 'G': 2, 'C': 3, '-': 4, 'X': 4}
    onehot_dim = 5  # 5 categories: A, U, G, C, and unknown
    
    for idx, row in data_df.iterrows():
        target_id = row['target_id']
        sequence = row['sequence'].strip().upper()
        
        # Filter rows from labels_df corresponding to this target_id and drop NaNs in coordinates.
        label_rows = labels_df[labels_df['ID'].str.startswith(target_id + "_")]
        label_rows = label_rows.dropna(subset=['x_1', 'y_1', 'z_1'])
        label_rows = label_rows.sort_values(by='resid')
        
        # Check if the number of labels matches the sequence length
        if len(sequence) != len(label_rows):
            print(f"Warning: length mismatch for {target_id}: sequence length {len(sequence)} vs labels {len(label_rows)}")
            continue  # Skip if there is a mismatch
        
        # One-hot encode the RNA sequence using the 5-dimensional embedding.
        try:
            x_indices = [nucleotides.get(nt, 4) for nt in sequence]
            x = torch.eye(onehot_dim)[x_indices]  # Shape (N, 5)
        except Exception as e:
            print(f"Error encoding sequence for {target_id}: {e}")
            continue

        num_nodes = len(sequence)
        if num_nodes < 2:
            continue  # Skip sequences that are too short to form edges
        
        # Create edge_index: connect each nucleotide to its neighbor (bidirectional)
        edges = []
        for i in range(num_nodes - 1):
            edges.append([i, i+1])
            edges.append([i+1, i])
        edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()

        # Build target 3D coordinates (y) from label_rows.
        coords = label_rows[['x_1', 'y_1', 'z_1']].to_numpy()
        y = torch.tensor(coords, dtype=torch.float32)
        
        # -------------------------------
        # Scale coordinates (per sequence)
        # -------------------------------
        mean = y.mean(dim=0, keepdim=True)
        std = y.std(dim=0, keepdim=True)
        std[std == 0] = 1.0  # Avoid division by zero
        y_scaled = (y - mean) / std

        # Create a PyTorch Geometric Data object
        data_obj = Data(x=x, edge_index=edge_index, y=y_scaled)
        # Store scaling parameters for later inversion
        data_obj.y_mean = mean
        data_obj.y_std = std
        data_obj.target_id = target_id
        data_list.append(data_obj)
    
    return data_list

In [None]:
class RNA_GNN(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim=3):
        """
        input_dim: Dimension of node features (5 for one-hot encoding with unknown).
        hidden_dim: Hidden layer dimension.
        output_dim: Dimension of output (3 for 3D coordinates).
        """
        super(RNA_GNN, self).__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.fc = nn.Linear(hidden_dim, output_dim)  # Map to 3D coordinates

    def forward(self, x, edge_index):
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = self.fc(x)
        return x


In [72]:
class RNA_GAT(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim=3, heads=4):
        """
        input_dim: Dimension of node features (5 for one-hot with unknown).
        hidden_dim: Hidden layer dimension.
        output_dim: Dimension of output (3 for 3D coordinates).
        heads: Number of attention heads in the GAT layers.
        """
        super(RNA_GAT, self).__init__()
        # First GAT layer; using multi-head attention (output will be concatenated)
        self.gat1 = GATConv(input_dim, hidden_dim, heads=heads, concat=True)
        # Second GAT layer; we can average the heads by setting concat=False
        self.gat2 = GATConv(hidden_dim * heads, hidden_dim, heads=1, concat=False)
        self.fc = nn.Linear(hidden_dim, output_dim)  # Map to 3D coordinates

    def forward(self, x, edge_index):
        x = F.elu(self.gat1(x, edge_index))
        x = F.elu(self.gat2(x, edge_index))
        x = self.fc(x)
        return x

In [74]:
def train_model(model, data_list, epochs=100, lr=0.001):
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.MSELoss()
    for epoch in range(epochs):
        model.train()
        total_loss = 0
        for data in data_list:
            optimizer.zero_grad()
            pred = model(data.x, data.edge_index)
            loss = loss_fn(pred, data.y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        avg_loss = total_loss / len(data_list)
        if epoch % 10 == 0:
            print(f"Epoch {epoch:03d}, Loss: {avg_loss:.4f}")

In [76]:
def compute_tm_score(pred_coords, true_coords):
    """
    Compute a TM-score between two structures based on the provided formula.
    
    Parameters:
        pred_coords (np.ndarray): Predicted 3D coordinates of shape (L, 3).
        true_coords (np.ndarray): True 3D coordinates of shape (L, 3).
        
    Returns:
        float: TM-score (higher is better).
    """
    L_ref = true_coords.shape[0]  # Number of residues in the reference (ground truth)
    L_align = pred_coords.shape[0]  # Number of aligned residues (same as L_ref for this task)
    
    # Compute d0 (scaling factor)
    if L_ref >= 30:
        d0 = 0.6 * np.sqrt(L_ref - 0.5) - 2.5
    else:
        d0_values = {12: 0.3, 15: 0.4, 19: 0.5, 23: 0.6, 29: 0.7}
        d0 = d0_values.get(L_ref, 0.7)  # Default to 0.7 for any unknown values

    # Compute pairwise distances between corresponding residues
    distances = np.linalg.norm(pred_coords - true_coords, axis=1)

    # Compute the TM-score
    tm_score = np.sum(1 / (1 + (distances / d0)**2)) / L_ref

    return tm_score

def evaluate_tm_score(model, data_list):
    """
    Evaluate the model on a list of Data objects by computing the TM-score
    for each structure and returning the average TM-score.
    
    Parameters:
        model (nn.Module): The trained GNN model.
        data_list (list): List of PyTorch Geometric Data objects representing the validation data.
        
    Returns:
        float: Average TM-score for all validation examples.
    """
    model.eval()
    tm_scores = []
    
    with torch.no_grad():
        for data in data_list:
            # Predict scaled coordinates
            pred_scaled = model(data.x, data.edge_index)
            
            # Invert scaling to get predicted coordinates in original scale
            pred_unscaled = pred_scaled * data.y_std + data.y_mean
            
            # True 3D coordinates in original scale
            true_unscaled = data.y * data.y_std + data.y_mean
            
            # Convert predicted and true coordinates to numpy for distance calculation
            pred_coords = pred_unscaled.cpu().numpy()
            true_coords = true_unscaled.cpu().numpy()
            
            # Compute TM-score for this structure
            tm = compute_tm_score(pred_coords, true_coords)
            tm_scores.append(tm)
    
    # Return the average TM-score across all samples
    avg_tm_score = np.mean(tm_scores)
    return avg_tm_score

In [56]:
train_data_list = preprocess_data(train_data, train_labels)
val_data_list = preprocess_data(val_data, val_labels)



In [58]:
len(train_data_list)


606

In [59]:
len(val_data_list)

12

In [68]:
model = RNA_GNN(input_dim=5, hidden_dim=32)

In [86]:
model = RNA_GAT(input_dim=5, hidden_dim=64, output_dim=3, heads=4)

In [89]:
train_data_list[0].edge_index

tensor([[ 0,  1,  1,  2,  2,  3,  3,  4,  4,  5,  5,  6,  6,  7,  7,  8,  8,  9,
          9, 10, 10, 11, 11, 12, 12, 13, 13, 14, 14, 15, 15, 16, 16, 17, 17, 18,
         18, 19, 19, 20, 20, 21, 21, 22, 22, 23, 23, 24, 24, 25, 25, 26, 26, 27,
         27, 28],
        [ 1,  0,  2,  1,  3,  2,  4,  3,  5,  4,  6,  5,  7,  6,  8,  7,  9,  8,
         10,  9, 11, 10, 12, 11, 13, 12, 14, 13, 15, 14, 16, 15, 17, 16, 18, 17,
         19, 18, 20, 19, 21, 20, 22, 21, 23, 22, 24, 23, 25, 24, 26, 25, 27, 26,
         28, 27]])

In [87]:
train_model(model, train_data_list, epochs=200, lr=0.005)

Epoch 000, Loss: 0.9690
Epoch 010, Loss: 0.9696
Epoch 020, Loss: 0.9691
Epoch 030, Loss: 0.9688
Epoch 040, Loss: 0.9677
Epoch 050, Loss: 0.9680
Epoch 060, Loss: 0.9682
Epoch 070, Loss: 0.9676
Epoch 080, Loss: 0.9674
Epoch 090, Loss: 0.9685
Epoch 100, Loss: 0.9683
Epoch 110, Loss: 0.9690
Epoch 120, Loss: 0.9689


KeyboardInterrupt: 

In [85]:
train_avg_tm_score = evaluate_tm_score(model, train_data_list)
val_avg_tm_score = evaluate_tm_score(model, val_data_list)
print(f"Average TM-score on training set: {train_avg_tm_score:.4f}")
print(f"Average TM-score on validation set: {val_avg_tm_score:.4f}")

Average TM-score on training set: 0.0247
Average TM-score on validation set: 0.0406


In [81]:
model.eval()
with torch.no_grad():
    # For demonstration, predict for the first RNA target in data_list.
    pred = model(val_data_list[0].x, val_data_list[0].edge_index)
    print(f"Predicted 3D coordinates for target {val_data_list[0].target_id}:\n", pred)
    print(f"Original 3D coordinates for the target:\n  {val_data_list[0].y}")

Predicted 3D coordinates for target R1107:
 tensor([[-1.2690e-01, -1.1975e-01,  3.6927e-02],
        [-1.2690e-01, -1.1975e-01,  3.6927e-02],
        [-1.2690e-01, -1.1975e-01,  3.6927e-02],
        [-5.2769e-02, -8.2068e-02, -3.4334e-03],
        [-4.8493e-02, -7.8047e-02, -6.6388e-03],
        [-4.5756e-02, -7.6823e-02, -8.0445e-03],
        [-5.9105e-05, -7.6527e-03, -5.1695e-02],
        [ 4.6751e-02, -2.7077e-03,  2.3361e-02],
        [ 7.9938e-03, -3.3667e-03, -6.3310e-03],
        [ 7.2516e-02,  2.0408e-02, -1.8242e-03],
        [ 3.7805e-02,  1.1794e-02, -2.1151e-03],
        [ 3.7805e-02,  1.1794e-02, -2.1151e-03],
        [ 2.4206e-02,  4.3488e-02,  1.9653e-02],
        [ 2.2825e-02,  4.3290e-02,  1.9575e-02],
        [ 2.1471e-02,  4.3045e-02,  1.9480e-02],
        [-2.1789e-02,  3.2168e-02,  1.8806e-02],
        [-1.7712e-02,  3.1819e-02,  1.7591e-02],
        [-3.9298e-02, -9.4295e-05,  6.1142e-03],
        [-4.0341e-02, -4.8301e-04,  6.6571e-03],
        [-4.0341e-02, -4.

In [34]:
train_labels.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 137095 entries, 0 to 137094
Data columns (total 6 columns):
 #   Column   Non-Null Count   Dtype  
---  ------   --------------   -----  
 0   ID       137095 non-null  object 
 1   resname  137095 non-null  object 
 2   resid    137095 non-null  int64  
 3   x_1      130950 non-null  float64
 4   y_1      130950 non-null  float64
 5   z_1      130950 non-null  float64
dtypes: float64(3), int64(1), object(2)
memory usage: 6.3+ MB
