<a href="https://colab.research.google.com/github/John1495/RNA-3D/blob/main/GCNConv_Model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install optuna
!pip install joblib


Collecting optuna
  Downloading optuna-4.3.0-py3-none-any.whl.metadata (17 kB)
Collecting alembic>=1.5.0 (from optuna)
  Downloading alembic-1.15.2-py3-none-any.whl.metadata (7.3 kB)
Collecting colorlog (from optuna)
  Downloading colorlog-6.9.0-py3-none-any.whl.metadata (10 kB)
Collecting sqlalchemy>=1.4.2 (from optuna)
  Downloading sqlalchemy-2.0.40-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (9.6 kB)
Collecting greenlet>=1 (from sqlalchemy>=1.4.2->optuna)
  Downloading greenlet-3.2.1-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl.metadata (4.1 kB)
Downloading optuna-4.3.0-py3-none-any.whl (386 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m386.6/386.6 kB[0m [31m5.4 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading alembic-1.15.2-py3-none-any.whl (231 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m231.9/231.9 kB[0m [31m16.8 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading sqlalchemy-2.0.40-cp311-cp311-manyli

In [2]:
!pip install torch torch-geometric torch-scatter torch-sparse torch-cluster torch-spline-conv -f https://data.pyg.org/whl/torch-2.1.0+cpu.html


Looking in links: https://data.pyg.org/whl/torch-2.1.0+cpu.html
Collecting torch-geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m1.8 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torch-scatter
  Downloading https://data.pyg.org/whl/torch-2.1.0%2Bcpu/torch_scatter-2.1.2%2Bpt21cpu-cp311-cp311-linux_x86_64.whl (500 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m500.4/500.4 kB[0m [31m6.9 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torch-sparse
  Downloading https://data.pyg.org/whl/torch-2.1.0%2Bcpu/torch_sparse-0.6.18%2Bpt21cpu-cp311-cp311-linux_x86_64.whl (1.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.2/1.2 MB[0m [31m31.8 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torch-cluster
  Downloading https://data.pyg.org/whl/torch-2.1.0%2Bcpu/torch_cluster-1.6.3%2Bpt21cpu-cp311-cp311-linux_x86_64.whl (753 kB)
[2K     [9

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error, mean_absolute_error
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import GCNConv
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import seaborn as sns
import joblib
from tqdm import tqdm

# Load data
seq_df = pd.read_csv('/kaggle/cleaned_train_sequences2 (1).csv')
label_df = pd.read_csv('/kaggle/train_labels1.csv')
label_df['resname'] = label_df['resname'].str.extract(r'([AUGC])')
label_df = label_df.dropna(subset=['resname'])
label_df['target_id'] = label_df['ID'].str.extract(r'(.+)_\d+')
merged = pd.merge(label_df, seq_df[['target_id', 'sequence']], on='target_id', how='left')

# Filter complete cases only
complete_ids = merged.groupby('target_id')['resid'].count()
valid_ids = complete_ids[complete_ids > 10].index
merged = merged[merged['target_id'].isin(valid_ids)]

# Train-validation split
all_ids = merged['target_id'].unique()
train_ids, val_ids = train_test_split(all_ids, test_size=0.1, random_state=42)

residue_mapping = {'A': 0, 'U': 1, 'G': 2, 'C': 3}

def create_graph_from_group(df_group, scaler=None, fit_scaler=False):
    sequence = df_group['sequence'].values[0]
    coords = df_group[['x_1', 'y_1', 'z_1']].values
    df_group = df_group.sort_values('resid')
    node_features = torch.eye(4)[[residue_mapping[r] for r in df_group['resname']]]

    if scaler is not None:
        if fit_scaler:
            coords = scaler.fit_transform(coords)
        else:
            coords = scaler.transform(coords)

    pos = torch.tensor(coords, dtype=torch.float)
    y = pos  # predict absolute coordinates

    n = len(df_group)
    edge_index = torch.tensor(
        [[i, i+1] for i in range(n-1)] + [[i+1, i] for i in range(n-1)],
        dtype=torch.long
    ).t().contiguous()

    return Data(x=node_features, edge_index=edge_index, pos=pos, y=y)

scaler = StandardScaler()
train_graphs, val_graphs = [], []
for tid in tqdm(train_ids):
    g = create_graph_from_group(merged[merged['target_id'] == tid], scaler, fit_scaler=True)
    train_graphs.append(g)

for tid in tqdm(val_ids):
    g = create_graph_from_group(merged[merged['target_id'] == tid], scaler, fit_scaler=False)
    val_graphs.append(g)

train_loader = DataLoader(train_graphs, batch_size=1)
val_loader = DataLoader(val_graphs, batch_size=1)

class GCNNet(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(GCNNet, self).__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, output_dim)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = F.relu(self.conv1(x, edge_index))
        x = self.conv2(x, edge_index)
        return x

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GCNNet(input_dim=4, hidden_dim=32, output_dim=3).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_fn = nn.MSELoss()

best_val_loss = float('inf')
patience = 10
epochs_without_improvement = 0

model.train()
for epoch in range(100):
    total_loss = 0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        pred = model(batch)
        loss = loss_fn(pred, batch.y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    avg_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch}, Loss: {avg_loss:.6f}")

    if avg_loss < best_val_loss:
        best_val_loss = avg_loss
        epochs_without_improvement = 0
        torch.save(model.state_dict(), "MODEL_GCN.pth")
    else:
        epochs_without_improvement += 1
        if epochs_without_improvement >= patience:
            print(f"Early stopping at epoch {epoch}")
            break

model.eval()
predictions, true = [], []
with torch.no_grad():
    for batch in val_loader:
        batch = batch.to(device)
        pred = model(batch)
        predictions.append(pred.cpu().numpy())
        true.append(batch.y.cpu().numpy())

predictions = np.concatenate(predictions, axis=0)
true = np.concatenate(true, axis=0)

rmse = np.sqrt(mean_squared_error(true, predictions))
mae = mean_absolute_error(true, predictions)

def calculate_tm_score(true, pred):
    d = np.linalg.norm(true - pred, axis=1)
    N = true.shape[0]
    return np.mean(np.exp(-d / (0.5 * N)))

def calculate_rmsd(true, pred):
    return np.sqrt(np.mean((true - pred) ** 2))

tm_score = calculate_tm_score(true, predictions)
rmsd = calculate_rmsd(true, predictions)

print("\nValidation Metrics:")
print(f"RMSE: {rmse:.6f}")
print(f"MAE: {mae:.6f}")
print(f"TM-Score: {tm_score:.6f}")
print(f"RMSD: {rmsd:.6f}")

# Save scaler
joblib.dump(scaler, "scaler_gcn.pkl")


100%|██████████| 747/747 [00:07<00:00, 96.00it/s]
100%|██████████| 83/83 [00:00<00:00, 97.12it/s]


Epoch 0, Loss: 0.950368
Epoch 1, Loss: 0.948379
Epoch 2, Loss: 0.948051
Epoch 3, Loss: 0.947850
Epoch 4, Loss: 0.947716
Epoch 5, Loss: 0.947637
Epoch 6, Loss: 0.947593
Epoch 7, Loss: 0.947563
Epoch 8, Loss: 0.947533
Epoch 9, Loss: 0.947516
Epoch 10, Loss: 0.947484
Epoch 11, Loss: 0.947457
Epoch 12, Loss: 0.947419
Epoch 13, Loss: 0.947412
Epoch 14, Loss: 0.947384
Epoch 15, Loss: 0.947383
Epoch 16, Loss: 0.947354
Epoch 17, Loss: 0.947365
Epoch 18, Loss: 0.947334
Epoch 19, Loss: 0.947318
Epoch 20, Loss: 0.947293
Epoch 21, Loss: 0.947303
Epoch 22, Loss: 0.947270
Epoch 23, Loss: 0.947272
Epoch 24, Loss: 0.947252
Epoch 25, Loss: 0.947234
Epoch 26, Loss: 0.947231
Epoch 27, Loss: 0.947212
Epoch 28, Loss: 0.947209
Epoch 29, Loss: 0.947191
Epoch 30, Loss: 0.947196
Epoch 31, Loss: 0.947171
Epoch 32, Loss: 0.947168
Epoch 33, Loss: 0.947153
Epoch 34, Loss: 0.947139
Epoch 35, Loss: 0.947145
Epoch 36, Loss: 0.947125
Epoch 37, Loss: 0.947118
Epoch 38, Loss: 0.947110
Epoch 39, Loss: 0.947101
Epoch 40, 

['scaler_gcn.pkl']

In [5]:
from google.colab import drive
drive.mount('/content/drive')

# Then save it to your drive
torch.save(model.state_dict(), '/content/drive/MyDrive/GCNConv_Model.pth')
joblib.dump(scaler, '/content/drive/MyDrive/GCNConc_Scaler.save')

print("Saved to Google Drive as 'GCNConv_Model.pth'")


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Saved to Google Drive as 'GCNConv_Model.pth'
