<a href="https://colab.research.google.com/github/John1495/RNA-3D/blob/main/GVP_Model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [9]:
!pip install joblib




In [1]:
!pip install torch torch-geometric torch-scatter torch-sparse torch-cluster torch-spline-conv -f https://data.pyg.org/whl/torch-2.1.0+cpu.html



Looking in links: https://data.pyg.org/whl/torch-2.1.0+cpu.html
Collecting torch-geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m1.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torch-scatter
  Downloading https://data.pyg.org/whl/torch-2.1.0%2Bcpu/torch_scatter-2.1.2%2Bpt21cpu-cp311-cp311-linux_x86_64.whl (500 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m500.4/500.4 kB[0m [31m3.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torch-sparse
  Downloading https://data.pyg.org/whl/torch-2.1.0%2Bcpu/torch_sparse-0.6.18%2Bpt21cpu-cp311-cp311-linux_x86_64.whl (1.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.2/1.2 MB[0m [31m15.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torch-cluster
  Downloading https://data.pyg.org/whl/torch-2.1.0%2Bcpu/torch_cluster-1.6.3%2Bpt21cpu-cp311-cp311-linux_x86_64.whl (753 kB)
[2K     [9

In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data, DataLoader
import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, mean_absolute_error
from tqdm import tqdm

# == Load Data ==
seq_df = pd.read_csv('/kaggle/cleaned_train_sequences2 (1).csv')
label_df = pd.read_csv('/kaggle/train_labels1.csv')

label_df['resname'] = label_df['resname'].str.extract(r'([AUGC])')
label_df = label_df.dropna(subset=['resname'])
label_df['target_id'] = label_df['ID'].str.extract(r'(.+)_\d+')

merged = pd.merge(label_df, seq_df[['target_id', 'sequence']], on='target_id', how='left')

# Filter for complete RNAs
valid_ids = merged.groupby('target_id')['resid'].count()
valid_ids = valid_ids[valid_ids > 10].index
merged = merged[merged['target_id'].isin(valid_ids)]

train_ids, val_ids = train_test_split(merged['target_id'].unique(), test_size=0.1, random_state=42)
residue_mapping = {'A': 0, 'U': 1, 'G': 2, 'C': 3}

# == Graph Creator ==
def create_graph(df_group, scaler=None, fit_scaler=False):
    df_group = df_group.sort_values('resid')
    coords = df_group[['x_1', 'y_1', 'z_1']].values

    if scaler:
        coords = scaler.fit_transform(coords) if fit_scaler else scaler.transform(coords)

    node_scalar = torch.eye(4)[[residue_mapping[r] for r in df_group['resname']]]

    # Vector features are placeholder zeros for now (can be enhanced)
    node_vector = torch.zeros((len(df_group), 4))

    node_features = torch.cat([node_scalar, node_vector], dim=1)

    pos = torch.tensor(coords, dtype=torch.float)
    y = pos
    n = len(df_group)

    edge_index = torch.tensor([[i, j] for i in range(n) for j in range(n) if i != j], dtype=torch.long).t().contiguous()
    return Data(x=node_features, edge_index=edge_index, pos=pos, y=y)

scaler = StandardScaler()
train_graphs = [create_graph(merged[merged['target_id'] == tid], scaler, True) for tid in tqdm(train_ids)]
val_graphs = [create_graph(merged[merged['target_id'] == tid], scaler, False) for tid in tqdm(val_ids)]

train_loader = DataLoader(train_graphs, batch_size=1)
val_loader = DataLoader(val_graphs, batch_size=1)

# == GVP Block ==
class GVPBlock(nn.Module):
    def __init__(self, scalar_dim, vector_dim, hidden_dim):
        super().__init__()
        self.scalar_mlp = nn.Sequential(
            nn.Linear(scalar_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        self.vector_mlp = nn.Sequential(
            nn.Linear(vector_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

    def forward(self, x_scalar, x_vector):
        s_out = self.scalar_mlp(x_scalar)
        v_out = self.vector_mlp(x_vector)
        return s_out, v_out

# == Full GVP Model ==
class PowerfulGVPModel(nn.Module):
    def __init__(self, scalar_dim=4, vector_dim=4, hidden_dim=64):
        super().__init__()
        self.gvp1 = GVPBlock(scalar_dim, vector_dim, hidden_dim)
        self.gvp2 = GVPBlock(hidden_dim, hidden_dim, hidden_dim)
        self.gvp3 = GVPBlock(hidden_dim, hidden_dim, hidden_dim)
        self.fc = nn.Linear(hidden_dim, 3)

    def forward(self, data):
        x = data.x.float()
        x_scalar = x[:, :4]
        x_vector = x[:, 4:]

        s, v = self.gvp1(x_scalar, x_vector)
        s, v = self.gvp2(s, v)
        s, v = self.gvp3(s, v)

        x_combined = s + v
        out = self.fc(x_combined)
        return out

# == Training ==
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = PowerfulGVPModel().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()

best_loss = float('inf')
patience = 10
no_improve = 0

for epoch in range(100):
    model.train()
    total_loss = 0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        pred = model(batch)
        loss = loss_fn(pred, batch.y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    avg_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch}: Train Loss = {avg_loss:.6f}")

    if avg_loss < best_loss:
        best_loss = avg_loss
        no_improve = 0
        torch.save(model.state_dict(), "best_gvp_model.pth")
    else:
        no_improve += 1
        if no_improve >= patience:
            print("Early stopping")
            break

# == Evaluation ==
model.load_state_dict(torch.load("best_gvp_model.pth"))
model.eval()
predictions, targets = [], []

with torch.no_grad():
    for batch in val_loader:
        batch = batch.to(device)
        pred = model(batch)
        predictions.append(pred.cpu().numpy())
        targets.append(batch.y.cpu().numpy())

predictions = np.concatenate(predictions)
targets = np.concatenate(targets)

rmse = np.sqrt(mean_squared_error(targets, predictions))
mae = mean_absolute_error(targets, predictions)

def calculate_tm_score(true, pred):
    d = np.linalg.norm(true - pred, axis=1)
    return np.mean(np.exp(-d / (0.5 * len(d))))

tm_score = calculate_tm_score(targets, predictions)
print(f"\nValidation Results:\nRMSE = {rmse:.4f}, MAE = {mae:.4f}, TM-Score = {tm_score:.4f}")


100%|██████████| 747/747 [03:23<00:00,  3.67it/s]
100%|██████████| 83/83 [00:30<00:00,  2.69it/s]


Epoch 0: Train Loss = 0.947850
Epoch 1: Train Loss = 0.947782
Epoch 2: Train Loss = 0.947724
Epoch 3: Train Loss = 0.947745
Epoch 4: Train Loss = 0.947692
Epoch 5: Train Loss = 0.947705
Epoch 6: Train Loss = 0.947678
Epoch 7: Train Loss = 0.947668
Epoch 8: Train Loss = 0.947661
Epoch 9: Train Loss = 0.947675
Epoch 10: Train Loss = 0.947646
Epoch 11: Train Loss = 0.947617
Epoch 12: Train Loss = 0.947606
Epoch 13: Train Loss = 0.947662
Epoch 14: Train Loss = 0.947625
Epoch 15: Train Loss = 0.947601
Epoch 16: Train Loss = 0.947591
Epoch 17: Train Loss = 0.947593
Epoch 18: Train Loss = 0.947586
Epoch 19: Train Loss = 0.947578
Epoch 20: Train Loss = 0.947577
Epoch 21: Train Loss = 0.947572
Epoch 22: Train Loss = 0.947568
Epoch 23: Train Loss = 0.947601
Epoch 24: Train Loss = 0.947578
Epoch 25: Train Loss = 0.947569
Epoch 26: Train Loss = 0.947565
Epoch 27: Train Loss = 0.947565
Epoch 28: Train Loss = 0.947565
Epoch 29: Train Loss = 0.947567
Epoch 30: Train Loss = 0.947559
Epoch 31: Train Lo

In [11]:
!pip install joblib



In [14]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [16]:
import joblib

# Then save it to your drive
torch.save(model.state_dict(), '/content/drive/MyDrive/GVP_Model.pth')
joblib.dump(scaler, '/content/drive/MyDrive/GVP_Scaler.save')

print("Saved to Google Drive as 'GVP_Model.pth'")


Saved to Google Drive as 'GVP_Model.pth'
