<a href="https://colab.research.google.com/github/John1495/RNA-3-D-1/blob/main/GVP_Model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install joblib




In [None]:
!pip install torch torch-geometric torch-scatter torch-sparse torch-cluster torch-spline-conv -f https://data.pyg.org/whl/torch-2.1.0+cpu.html



Looking in links: https://data.pyg.org/whl/torch-2.1.0+cpu.html
Collecting torch-geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m1.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torch-scatter
  Downloading https://data.pyg.org/whl/torch-2.1.0%2Bcpu/torch_scatter-2.1.2%2Bpt21cpu-cp311-cp311-linux_x86_64.whl (500 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m500.4/500.4 kB[0m [31m3.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torch-sparse
  Downloading https://data.pyg.org/whl/torch-2.1.0%2Bcpu/torch_sparse-0.6.18%2Bpt21cpu-cp311-cp311-linux_x86_64.whl (1.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.2/1.2 MB[0m [31m15.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torch-cluster
  Downloading https://data.pyg.org/whl/torch-2.1.0%2Bcpu/torch_cluster-1.6.3%2Bpt21cpu-cp311-cp311-linux_x86_64.whl (753 kB)
[2K     [9

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data, DataLoader
import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, mean_absolute_error
from tqdm import tqdm

# == Load Data ==
seq_df = pd.read_csv('/kaggle/cleaned_train_sequences2 (1).csv')
label_df = pd.read_csv('/kaggle/train_labels1.csv')

label_df['resname'] = label_df['resname'].str.extract(r'([AUGC])')
label_df = label_df.dropna(subset=['resname'])
label_df['target_id'] = label_df['ID'].str.extract(r'(.+)_\d+')

merged = pd.merge(label_df, seq_df[['target_id', 'sequence']], on='target_id', how='left')

# Filter for complete RNAs
valid_ids = merged.groupby('target_id')['resid'].count()
valid_ids = valid_ids[valid_ids > 10].index
merged = merged[merged['target_id'].isin(valid_ids)]

train_ids, val_ids = train_test_split(merged['target_id'].unique(), test_size=0.1, random_state=42)
residue_mapping = {'A': 0, 'U': 1, 'G': 2, 'C': 3}

# == Graph Creator ==
def create_graph(df_group, scaler=None, fit_scaler=False):
    df_group = df_group.sort_values('resid')
    coords = df_group[['x_1', 'y_1', 'z_1']].values

    if scaler:
        coords = scaler.fit_transform(coords) if fit_scaler else scaler.transform(coords)

    node_scalar = torch.eye(4)[[residue_mapping[r] for r in df_group['resname']]]

    # Vector features are placeholder zeros for now (can be enhanced)
    node_vector = torch.zeros((len(df_group), 4))

    node_features = torch.cat([node_scalar, node_vector], dim=1)

    pos = torch.tensor(coords, dtype=torch.float)
    y = pos
    n = len(df_group)

    edge_index = torch.tensor([[i, j] for i in range(n) for j in range(n) if i != j], dtype=torch.long).t().contiguous()
    return Data(x=node_features, edge_index=edge_index, pos=pos, y=y)

scaler = StandardScaler()
train_graphs = [create_graph(merged[merged['target_id'] == tid], scaler, True) for tid in tqdm(train_ids)]
val_graphs = [create_graph(merged[merged['target_id'] == tid], scaler, False) for tid in tqdm(val_ids)]

train_loader = DataLoader(train_graphs, batch_size=1)
val_loader = DataLoader(val_graphs, batch_size=1)

# == GVP Block ==
class GVPBlock(nn.Module):
    def __init__(self, scalar_dim, vector_dim, hidden_dim):
        super().__init__()
        self.scalar_mlp = nn.Sequential(
            nn.Linear(scalar_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        self.vector_mlp = nn.Sequential(
            nn.Linear(vector_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

    def forward(self, x_scalar, x_vector):
        s_out = self.scalar_mlp(x_scalar)
        v_out = self.vector_mlp(x_vector)
        return s_out, v_out

# == Full GVP Model ==
class PowerfulGVPModel(nn.Module):
    def __init__(self, scalar_dim=4, vector_dim=4, hidden_dim=64):
        super().__init__()
        self.gvp1 = GVPBlock(scalar_dim, vector_dim, hidden_dim)
        self.gvp2 = GVPBlock(hidden_dim, hidden_dim, hidden_dim)
        self.gvp3 = GVPBlock(hidden_dim, hidden_dim, hidden_dim)
        self.fc = nn.Linear(hidden_dim, 3)

    def forward(self, data):
        x = data.x.float()
        x_scalar = x[:, :4]
        x_vector = x[:, 4:]

        s, v = self.gvp1(x_scalar, x_vector)
        s, v = self.gvp2(s, v)
        s, v = self.gvp3(s, v)

        x_combined = s + v
        out = self.fc(x_combined)
        return out

# == Training ==
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = PowerfulGVPModel().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()

best_loss = float('inf')
patience = 10
no_improve = 0

for epoch in range(100):
    model.train()
    total_loss = 0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        pred = model(batch)
        loss = loss_fn(pred, batch.y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    avg_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch}: Train Loss = {avg_loss:.6f}")

    if avg_loss < best_loss:
        best_loss = avg_loss
        no_improve = 0
        torch.save(model.state_dict(), "best_gvp_model.pth")
    else:
        no_improve += 1
        if no_improve >= patience:
            print("Early stopping")
            break

# == Evaluation ==
model.load_state_dict(torch.load("best_gvp_model.pth"))
model.eval()
predictions, targets = [], []

with torch.no_grad():
    for batch in val_loader:
        batch = batch.to(device)
        pred = model(batch)
        predictions.append(pred.cpu().numpy())
        targets.append(batch.y.cpu().numpy())

predictions = np.concatenate(predictions)
targets = np.concatenate(targets)

rmse = np.sqrt(mean_squared_error(targets, predictions))
mae = mean_absolute_error(targets, predictions)

def calculate_tm_score(true, pred):
    d = np.linalg.norm(true - pred, axis=1)
    return np.mean(np.exp(-d / (0.5 * len(d))))

tm_score = calculate_tm_score(targets, predictions)
print(f"\nValidation Results:\nRMSE = {rmse:.4f}, MAE = {mae:.4f}, TM-Score = {tm_score:.4f}")


100%|██████████| 747/747 [03:23<00:00,  3.67it/s]
100%|██████████| 83/83 [00:30<00:00,  2.69it/s]


Epoch 0: Train Loss = 0.947850
Epoch 1: Train Loss = 0.947782
Epoch 2: Train Loss = 0.947724
Epoch 3: Train Loss = 0.947745
Epoch 4: Train Loss = 0.947692
Epoch 5: Train Loss = 0.947705
Epoch 6: Train Loss = 0.947678
Epoch 7: Train Loss = 0.947668
Epoch 8: Train Loss = 0.947661
Epoch 9: Train Loss = 0.947675
Epoch 10: Train Loss = 0.947646
Epoch 11: Train Loss = 0.947617
Epoch 12: Train Loss = 0.947606
Epoch 13: Train Loss = 0.947662
Epoch 14: Train Loss = 0.947625
Epoch 15: Train Loss = 0.947601
Epoch 16: Train Loss = 0.947591
Epoch 17: Train Loss = 0.947593
Epoch 18: Train Loss = 0.947586
Epoch 19: Train Loss = 0.947578
Epoch 20: Train Loss = 0.947577
Epoch 21: Train Loss = 0.947572
Epoch 22: Train Loss = 0.947568
Epoch 23: Train Loss = 0.947601
Epoch 24: Train Loss = 0.947578
Epoch 25: Train Loss = 0.947569
Epoch 26: Train Loss = 0.947565
Epoch 27: Train Loss = 0.947565
Epoch 28: Train Loss = 0.947565
Epoch 29: Train Loss = 0.947567
Epoch 30: Train Loss = 0.947559
Epoch 31: Train Lo

In [25]:
#!/usr/bin/env python3
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data, Batch
import numpy as np
from torch.optim.lr_scheduler import OneCycleLR
import pandas as pd
from torch.utils.data import DataLoader, random_split
from sklearn.model_selection import KFold
import copy
import time

class RNAGraphDataset(torch.utils.data.Dataset):
    def __init__(self, sequences_file, labels_file, validation_mode=False):
        self.sequences = pd.read_csv(sequences_file)
        self.labels = pd.read_csv(labels_file)
        self.mean = None
        self.std = None
        self.validation_mode = validation_mode
        self.graphs = self._process_data()
        self.lengths = [len(seq) for seq in self.sequences['sequence']]
        self.max_length = max(self.lengths) if self.lengths else 1

    def _process_data(self):
        graphs = []
        all_coords = []

        if not self.validation_mode:
            for _, row in self.sequences.iterrows():
                target_id = row['target_id']
                matched_labels = self.labels[self.labels['ID'].str.startswith(target_id)]
                if not matched_labels.empty and len(matched_labels) == len(row['sequence']):
                    coords = matched_labels[['x_1', 'y_1', 'z_1']].values.astype(np.float32)
                    all_coords.append(coords)

            all_coords = np.concatenate(all_coords, axis=0) if all_coords else np.zeros((1,3))
            self.mean = all_coords.mean(axis=0)
            self.std = all_coords.std(axis=0) + 1e-8
        else:
            self.mean = np.zeros(3)
            self.std = np.ones(3)

        for _, row in self.sequences.iterrows():
            target_id = row['target_id']
            sequence = row['sequence']
            matched_labels = self.labels[self.labels['ID'].str.startswith(target_id)]

            if matched_labels.empty or len(matched_labels) != len(sequence):
                continue

            node_feats = self._encode_sequence(sequence)
            coords = matched_labels[['x_1', 'y_1', 'z_1']].values.astype(np.float32)
            coords = (coords - self.mean) / self.std

            edge_index, edge_attr = self._build_edges(len(sequence), coords)

            graphs.append(Data(
                x=torch.tensor(node_feats, dtype=torch.float32),
                edge_index=edge_index,
                edge_attr=edge_attr,
                y=torch.tensor(coords, dtype=torch.float32),
                pos=torch.tensor(coords, dtype=torch.float32),
                length=torch.tensor(len(sequence), dtype=torch.long)
            ))
        return graphs

    def _encode_sequence(self, sequence):
        mapping = {
            'A': [1,0,0,0, 0.12, 0.89],
            'U': [0,1,0,0, 0.23, 0.76],
            'G': [0,0,1,0, 0.34, 0.65],
            'C': [0,0,0,1, 0.45, 0.54]
        }
        return [mapping.get(nt, [0]*6) for nt in sequence]

    def _build_edges(self, length, coords):
        edge_index = []
        edge_attr = []
        for i in range(length - 1):
            vec = coords[i+1] - coords[i]
            dist = np.linalg.norm(vec)
            edge_index.append([i, i+1])
            edge_index.append([i+1, i])
            edge_attr.extend([
                [dist, vec[0], vec[1], vec[2], 1.0],
                [dist, -vec[0], -vec[1], -vec[2], 0.0]
            ])
        return (torch.tensor(edge_index).t().contiguous(),
                torch.tensor(np.array(edge_attr), dtype=torch.float32))

    def __len__(self):
        return len(self.graphs)

    def __getitem__(self, idx):
        data = self.graphs[idx]

        if not self.validation_mode and torch.rand(1) < 0.5:
            noise_level = 0.02 * (self.lengths[idx]/self.max_length)
            data.y += torch.randn_like(data.y) * noise_level
            data.pos += torch.randn_like(data.pos) * noise_level
            if hasattr(data, 'edge_attr'):
                data.edge_attr[:, :4] += torch.randn_like(data.edge_attr[:, :4]) * (noise_level/2)

        return data

def compute_tm_score(pred_coords, true_coords, lengths):
    tm_scores = []
    pred_coords = pred_coords.detach().cpu().numpy()
    true_coords = true_coords.detach().cpu().numpy()

    ptr = 0
    for L in lengths:
        if isinstance(L, torch.Tensor):
            L = L.item()

        pred = pred_coords[ptr:ptr+L]
        true = true_coords[ptr:ptr+L]
        ptr += L

        d0 = max(1.24 * (L - 15) ** (1/3) - 1.8, 0.5)
        diff = pred - true
        dist_sq = np.sum(diff**2, axis=1)
        tm_components = 1 / (1 + (dist_sq / (d0**2)))
        tm_scores.append(np.sum(tm_components) / L)

    return np.mean(tm_scores) if tm_scores else 0.0

class RNAD33(nn.Module):
    def __init__(self, in_channels=6, hidden_channels=512, num_layers=6, dropout=0.2):
        super().__init__()
        self.dropout = dropout

        self.input_proj = nn.Sequential(
            nn.Linear(in_channels, hidden_channels),
            nn.LayerNorm(hidden_channels),
            nn.LeakyReLU(0.1),
            nn.Dropout(dropout)
        )

        self.convs = nn.ModuleList()
        self.bns = nn.ModuleList()
        for _ in range(num_layers):
            conv = GCNConv(hidden_channels, hidden_channels, improved=True)
            bn = nn.BatchNorm1d(hidden_channels)
            self.convs.append(conv)
            self.bns.append(bn)

        self.output_net = nn.Sequential(
            nn.Linear(hidden_channels, hidden_channels*2),
            nn.LeakyReLU(0.1),
            nn.Dropout(dropout),
            nn.Linear(hidden_channels*2, hidden_channels),
            nn.LeakyReLU(0.1),
            nn.Linear(hidden_channels, 3)
        )

        self._init_weights()

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, GCNConv):
                nn.init.kaiming_normal_(m.lin.weight, mode='fan_out', nonlinearity='leaky_relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.input_proj(x)
        x_skip = x.clone()

        for conv, bn in zip(self.convs, self.bns):
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = conv(x, edge_index) + x
            x = bn(x)
            x = F.leaky_relu(x, 0.1)

        return self.output_net(x + x_skip)

def validate(model, loader, device):
    model.eval()
    total_loss = 0
    tm_scores = []
    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            pred = model(data)
            loss = F.mse_loss(pred, data.y)
            total_loss += loss.item()
            batch_tm = compute_tm_score(pred, data.y, data.length)
            tm_scores.append(batch_tm)

    return total_loss / len(loader), np.mean(tm_scores) if tm_scores else 0.0

def kfold_validation(train_dataset, config, device):
    kfold = KFold(n_splits=config["k_folds"], shuffle=True, random_state=42)
    fold_results = []

    for fold, (train_idx, val_idx) in enumerate(kfold.split(train_dataset)):
        print(f"\n=== Fold {fold + 1}/{config['k_folds']} ===")

        train_subset = torch.utils.data.Subset(train_dataset, train_idx)
        val_subset = torch.utils.data.Subset(train_dataset, val_idx)

        train_loader = DataLoader(
            train_subset,
            batch_size=config["batch_size"],
            shuffle=True,
            collate_fn=lambda x: Batch.from_data_list(x),
            num_workers=4 if str(device) == 'cuda' else 0
        )
        val_loader = DataLoader(
            val_subset,
            batch_size=config["batch_size"],
            collate_fn=lambda x: Batch.from_data_list(x),
            num_workers=4 if str(device) == 'cuda' else 0
        )

        model = RNAD33(
            hidden_channels=config["hidden_dim"],
            num_layers=config["num_layers"],
            dropout=config["dropout"]
        ).to(device)

        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=config["lr"],
            weight_decay=config["weight_decay"]
        )

        scheduler = OneCycleLR(
            optimizer,
            max_lr=config["max_lr"],
            steps_per_epoch=len(train_loader),
            epochs=config["max_epochs"],
            pct_start=0.3
        )

        best_val = float('inf')
        for epoch in range(1, config["max_epochs"] + 1):
            model.train()
            train_loss = 0

            for data in train_loader:
                data = data.to(device)
                optimizer.zero_grad()

                pred = model(data)
                loss = F.mse_loss(pred, data.y)
                loss += config["l2_lambda"] * sum(p.norm(2) for p in model.parameters())

                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), config["grad_clip"])
                optimizer.step()
                scheduler.step()
                train_loss += loss.item()

            val_loss, val_tm = validate(model, val_loader, device)

            if val_loss < best_val:
                best_val = val_loss
                best_model = copy.deepcopy(model.state_dict())

            print(f"Epoch {epoch:03d} | Train Loss: {train_loss/len(train_loader):.6f} | "
                  f"Val Loss: {val_loss:.6f} | TM-score: {val_tm:.4f}")

        fold_results.append((best_val, best_model))
        print(f"Fold {fold + 1} completed. Best Val Loss: {best_val:.6f}")

    return fold_results

def main():
    config = {
        "train_sequences": "/kaggle/cleaned_train_sequences2 (1).csv",
        "train_labels": "/kaggle/train_labels1.csv",
        "validation_sequences": "/kaggle/validation_sequences.csv",
        "validation_labels": "/kaggle/validation_labels.csv",
        "batch_size": 32 if torch.cuda.is_available() else 8,
        "hidden_dim": 512,
        "num_layers": 6,
        "max_epochs": 200,
        "patience": 30,
        "min_delta": 0.0001,
        "grad_clip": 0.3,
        "lr": 5e-4,
        "max_lr": 3e-4,
        "weight_decay": 1e-5,
        "l2_lambda": 0.01,
        "dropout": 0.2,
        "k_folds": 5
    }

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    print("Loading full training dataset...")
    full_train_dataset = RNAGraphDataset(config["train_sequences"], config["train_labels"], validation_mode=False)
    print(f"Loaded {len(full_train_dataset)} training samples")

    print("\nStarting k-fold cross-validation...")
    fold_results = kfold_validation(full_train_dataset, config, device)

    avg_val_loss = np.mean([res[0] for res in fold_results])
    print(f"\nK-fold validation complete. Average Val Loss: {avg_val_loss:.6f}")

    print("\nTraining final model on full dataset...")
    train_loader = DataLoader(
        full_train_dataset,
        batch_size=config["batch_size"],
        shuffle=True,
        collate_fn=lambda x: Batch.from_data_list(x),
        num_workers=4 if str(device) == 'cuda' else 0
    )

    val_dataset = RNAGraphDataset(
        config["validation_sequences"],
        config["validation_labels"],
        validation_mode=True
    )
    val_dataset.mean = full_train_dataset.mean
    val_dataset.std = full_train_dataset.std

    val_loader = DataLoader(
        val_dataset,
        batch_size=config["batch_size"],
        collate_fn=lambda x: Batch.from_data_list(x),
        num_workers=4 if str(device) == 'cuda' else 0
    )

    model = RNAD33(
        hidden_channels=config["hidden_dim"],
        num_layers=config["num_layers"],
        dropout=config["dropout"]
    ).to(device)

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config["lr"],
        weight_decay=config["weight_decay"]
    )

    scheduler = OneCycleLR(
        optimizer,
        max_lr=config["max_lr"],
        steps_per_epoch=len(train_loader),
        epochs=config["max_epochs"],
        pct_start=0.3
    )

    best_val = float('inf')
    best_tm = 0
    best_epoch = 0

    for epoch in range(1, config["max_epochs"] + 1):
        model.train()
        train_loss = 0
        epoch_start = time.time()

        for data in train_loader:
            data = data.to(device)
            optimizer.zero_grad()

            pred = model(data)
            loss = F.mse_loss(pred, data.y)
            loss += config["l2_lambda"] * sum(p.norm(2) for p in model.parameters())
            loss += 0.001 * sum(p.abs().sum() for p in model.parameters())

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), config["grad_clip"])
            optimizer.step()
            scheduler.step()
            train_loss += loss.item()

        val_loss, val_tm = validate(model, val_loader, device)
        epoch_time = time.time() - epoch_start

        print(f"Epoch {epoch:03d} | Time: {epoch_time:.1f}s | "
              f"Train Loss: {train_loss/len(train_loader):.6f} | "
              f"Val Loss: {val_loss:.6f} | TM-score: {val_tm:.4f}")

        if val_loss < best_val - config["min_delta"]:
            best_val = val_loss
            best_tm = val_tm
            best_epoch = epoch
            torch.save(model.state_dict(), "best_model.pt")
            print(f"New best model saved (Val Loss: {val_loss:.6f}, TM-score: {val_tm:.4f})")

        if epoch - best_epoch > config["patience"]:
            print(f"\nEarly stopping triggered at epoch {epoch}")
            break

    print("\nTraining completed!")
    print(f"Best validation loss: {best_val:.6f}")
    print(f"Best TM-score: {best_tm:.4f}")

if __name__ == "__main__":
    main()

Using device: cpu
Loading full training dataset...
Loaded 844 training samples

Starting k-fold cross-validation...

=== Fold 1/5 ===


TypeError: '>' not supported between instances of 'float' and 'complex'

In [None]:
!pip install joblib



In [None]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import joblib

# Then save it to your drive
torch.save(model.state_dict(), '/content/drive/MyDrive/GVP_Model.pth')
joblib.dump(scaler, '/content/drive/MyDrive/GVP_Scaler.save')

print("Saved to Google Drive as 'GVP_Model.pth'")


Saved to Google Drive as 'GVP_Model.pth'
