<a href="https://colab.research.google.com/github/adavy/EvoML/blob/main/deepTrainRNA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import os

# Report how many CPUs the runtime exposes; os.cpu_count() returns None
# when the count cannot be determined, which would print "None" here.
num_cores = os.cpu_count()
print(f"Number of CPU cores: {num_cores}")

Number of CPU cores: 12


In [None]:
!pip install torch-geometric torch_scatter torch_sparse torch_cluster torch_spline_conv

Collecting torch-geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/63.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m1.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torch_scatter
  Downloading torch_scatter-2.1.2.tar.gz (108 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m108.0/108.0 kB[0m [31m4.6 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting torch_sparse
  Downloading torch_sparse-0.6.18.tar.gz (209 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m210.0/210.0 kB[0m [31m9.9 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting torch_cluster
  Downloading torch_cluster-1.6.3.tar.gz (54 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m54.5/54.5 kB[

In [None]:
import os
import json
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torch_geometric.nn import GCNConv
from scipy.spatial.distance import pdist, squareform
from torch.cuda.amp import GradScaler, autocast

# -------------------- Hyperparameters --------------------

# Central run configuration. main() reads most of these keys and overwrites
# "max_graph_size" once the training data has been loaded.
CONFIG = {
    "device": "cuda" if torch.cuda.is_available() else "cpu",  # Prefer the GPU when PyTorch can see one
    "vocab_size": 4,  # A, U, G, C
    "embedding_dim": 32,  # Width of the nucleotide embedding fed to the LSTM
    "hidden_dim": 64,  # LSTM hidden size
    "max_graph_size": None,  # Placeholder; main() sets it to the longest training sequence length
    "subgraph_size": 20,  # Size of subgraphs for hierarchical modeling
    # NOTE(review): no code in view reads "distance_threshold"; generate_adjacency_matrix
    # is called without it and falls back to its own 6.0 default.
    "distance_threshold": 6.0,  # Distance threshold for adjacency matrix (in Ångströms)
    "batch_size": 32,  # Batch size used by both DataLoaders
    "learning_rate": 0.001,  # Shared by both Adam optimizers
    "epochs": 10,  # Number of GraphRNN training epochs
    "num_node_features": 4,  # One-hot encoding of nucleotides
    "gnn_hidden_channels": 64,  # Width of every GCNConv layer in RNA3DGN
    "num_coordinate_dims": 3,  # x, y, z coordinates
    "train_json_path": "rna_train_data.json",  # JSON mapping ids to entries with "sequence" and "nucleotides"
    "val_json_path": "rna_validation_data.json",  # Loaded and consumed the same way as the training file
    "test_csv_path": "test_sequences.csv",  # main() reads the "target_id" and "sequence" columns
    "submission_csv_path": "submission.csv",  # Output file written by main()
}

# -------------------- Define Models --------------------

class HierarchicalGraphRNN(nn.Module):
    """Predict a block-diagonal adjacency matrix from an encoded sequence.

    The input tokens are embedded and run through an LSTM; the LSTM output at
    the last time step is projected to ONE ``subgraph_size x subgraph_size``
    block per batch element. That same block is written onto every complete
    diagonal tile of a ``max_graph_size x max_graph_size`` matrix. Off-diagonal
    tiles, and any trailing rows/columns left over when ``max_graph_size`` is
    not a multiple of ``subgraph_size``, stay zero.

    Args:
        vocab_size: Number of distinct token indices the embedding accepts.
        embedding_dim: Width of each token embedding.
        hidden_dim: LSTM hidden size.
        max_graph_size: Side length of the returned matrices. Must be an
            integer by the time ``forward`` runs.
        subgraph_size: Side length of the predicted diagonal block.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, max_graph_size, subgraph_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.rnn = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
        self.subgraph_fc = nn.Linear(hidden_dim, subgraph_size * subgraph_size)
        self.max_graph_size = max_graph_size
        self.subgraph_size = subgraph_size

    def forward(self, x):
        """Return the block-diagonal adjacency prediction for a batch.

        Args:
            x: Integer tensor of token indices with shape ``(batch, seq_len)``.

        Returns:
            Tensor of shape ``(batch, max_graph_size, max_graph_size)`` on the
            same device as ``x``.
        """
        embedded = self.embedding(x)
        out, _ = self.rnn(embedded)

        # One block per batch element, taken from the final time step only.
        block = self.subgraph_fc(out[:, -1, :])
        block = block.view(-1, self.subgraph_size, self.subgraph_size)

        batch_size = embedded.size(0)
        num_subgraphs = self.max_graph_size // self.subgraph_size
        adj_matrices = torch.zeros(
            batch_size, self.max_graph_size, self.max_graph_size, device=embedded.device
        )

        # Only diagonal tiles are ever populated, so a single loop over tiles
        # with a batched slice assignment replaces the former
        # batch x tiles x tiles Python loop.
        for tile in range(num_subgraphs):
            start = tile * self.subgraph_size
            stop = start + self.subgraph_size
            adj_matrices[:, start:stop, start:stop] = block

        return adj_matrices


class RNA3DGN(nn.Module):
    """Three stacked GCN layers followed by a linear per-node coordinate head.

    Args:
        num_node_features: Width of the input node feature vectors.
        hidden_channels: Output width of every graph convolution.
        num_coordinate_dims: Number of values predicted per node.
    """

    def __init__(self, num_node_features, hidden_channels, num_coordinate_dims):
        super().__init__()
        # Layers are created in the same order as they are applied.
        self.conv1 = GCNConv(num_node_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, hidden_channels)
        self.linear = nn.Linear(hidden_channels, num_coordinate_dims)

    def forward(self, data):
        """Map a graph object exposing ``x`` and ``edge_index`` to per-node outputs."""
        node_state = data.x
        edges = data.edge_index
        # Each convolution is followed by a ReLU; the final linear layer is not.
        for conv in (self.conv1, self.conv2, self.conv3):
            node_state = conv(node_state, edges).relu()
        return self.linear(node_state)


# -------------------- Dataset Classes --------------------

class RNASequenceSubgraphDataset(Dataset):
    """Dataset of (sub-sequence, sub-adjacency) pairs cut from RNA entries.

    For every entry in ``data`` the sequence is index-encoded, an adjacency
    matrix is built from the entry's ``x_1``/``y_1``/``z_1`` coordinates, and
    both are cut into diagonal chunks by ``extract_subgraphs``. All chunks from
    all entries are stored in one flat list.

    Args:
        data: Mapping whose values are dicts with a ``"sequence"`` string and
            a ``"nucleotides"`` mapping; each nucleotide value must provide
            ``"x_1"``, ``"y_1"`` and ``"z_1"``.
        subgraph_size: Chunk length passed to ``extract_subgraphs``.
        nucleotide_to_index: Mapping from nucleotide character to integer id.
        distance_threshold: Optional distance cutoff forwarded to
            ``generate_adjacency_matrix``. When ``None`` (the default) the
            helper's own default is used, exactly as before.

    Raises:
        KeyError: If a sequence contains a character missing from
            ``nucleotide_to_index`` or an entry lacks an expected key.
    """

    def __init__(self, data, subgraph_size, nucleotide_to_index, distance_threshold=None):
        self.subgraphs = []
        self.nucleotide_to_index = nucleotide_to_index

        # Only forward the cutoff when the caller supplied one, so the helper
        # stays the single source of truth for the default value.
        adjacency_kwargs = {}
        if distance_threshold is not None:
            adjacency_kwargs["distance_threshold"] = distance_threshold

        # The mapping keys are not needed, so iterate over the entries directly.
        for entry in data.values():
            sequence = entry["sequence"]
            sequence_encoded = [nucleotide_to_index[n] for n in sequence]
            # Coordinate rows follow the iteration order of entry["nucleotides"].
            # NOTE(review): this presumably matches the order of `sequence` - confirm.
            coordinates = np.array(
                [[n["x_1"], n["y_1"], n["z_1"]] for n in entry["nucleotides"].values()]
            )
            adj_matrix = generate_adjacency_matrix(sequence, coordinates, **adjacency_kwargs)

            # Extract subgraphs
            self.subgraphs.extend(extract_subgraphs(sequence_encoded, adj_matrix, subgraph_size))

    def __len__(self):
        """Return the total number of stored chunks."""
        return len(self.subgraphs)

    def __getitem__(self, idx):
        """Return chunk ``idx`` as ``(long tensor of ids, float32 adjacency tensor)``."""
        sub_seq, sub_adj = self.subgraphs[idx]
        sub_seq = torch.tensor(sub_seq, dtype=torch.long)
        sub_adj = torch.tensor(sub_adj, dtype=torch.float32)
        return sub_seq, sub_adj


# -------------------- Helper Functions --------------------

def extract_subgraphs(sequence, adj_matrix, subgraph_size):
    """Cut a sequence and its adjacency matrix into consecutive diagonal chunks.

    Args:
        sequence: Sliceable sequence of per-nucleotide values.
        adj_matrix: Matrix supporting two-axis slicing (``m[a:b, a:b]``).
        subgraph_size: Chunk length; the final chunk is shorter when the
            sequence length is not a multiple of it.

    Returns:
        List of ``(sub_sequence, sub_adjacency)`` tuples in sequence order.
    """
    chunk_starts = range(0, len(sequence), subgraph_size)
    return [
        (
            sequence[lo:lo + subgraph_size],
            adj_matrix[lo:lo + subgraph_size, lo:lo + subgraph_size],
        )
        for lo in chunk_starts
    ]


def reconstruct_full_graph(model, sequence, subgraph_size, device):
    """Assemble a full adjacency matrix from per-subgraph model predictions.

    The whole sequence is sent through ``model`` as a batch of one, and the
    result is indexed by subgraph number: ``model(...)[i]`` is written onto the
    i-th diagonal tile of the output. Residues past the last complete tile
    (``len(sequence) % subgraph_size``) are left as zeros.

    NOTE(review): this indexing requires the model output to hold one
    ``subgraph_size x subgraph_size`` block per subgraph. HierarchicalGraphRNN
    in this file returns ``(batch, max_graph_size, max_graph_size)`` instead,
    so with that model index ``i >= 1`` is out of range for a batch of one -
    confirm which contract is intended.

    Args:
        model: Callable taking a ``(1, len(sequence))`` long tensor.
        sequence: Encoded sequence, either a list of ints or a 1-D tensor.
        subgraph_size: Side length of each diagonal tile.
        device: Device for the input tensor and the returned matrix.

    Returns:
        Tensor of shape ``(len(sequence), len(sequence))`` on ``device``.
    """
    seq_len = len(sequence)
    # as_tensor accepts both lists and existing tensors without the
    # copy-construct warning that torch.tensor(tensor) emits.
    sequence_tensor = torch.as_tensor(sequence, dtype=torch.long).unsqueeze(0).to(device)
    subgraph_predictions = model(sequence_tensor)

    # Combine subgraphs into full adjacency matrix. Only diagonal tiles are
    # written, so one loop over tiles is sufficient.
    full_adj_matrix = torch.zeros(seq_len, seq_len, device=device)
    for block_idx in range(seq_len // subgraph_size):
        start = block_idx * subgraph_size
        stop = start + subgraph_size
        full_adj_matrix[start:stop, start:stop] = subgraph_predictions[block_idx]
    return full_adj_matrix


def generate_adjacency_matrix(sequence, coordinates, distance_threshold=6.0):
    """Build a symmetric 0/1 adjacency matrix for one RNA chain.

    Two distinct positions ``i`` and ``j`` are connected when EITHER their
    Euclidean distance is strictly below ``distance_threshold`` OR their bases
    form one of the pairs A-U, G-C, G-U (in either order). The diagonal is
    always zero.

    NOTE(review): the base-pair rule is applied regardless of distance, so
    every complementary pair in the chain becomes an edge. That matches the
    previous behaviour; confirm it is intended rather than "close AND
    pairable".

    Args:
        sequence: String or sequence of single-character bases.
        coordinates: Array of shape ``(n, d)`` with one row per position. Extra
            rows beyond ``len(sequence)`` are ignored.
        distance_threshold: Exclusive upper bound on distance for a contact.

    Returns:
        ``float32`` array of shape ``(len(sequence), len(sequence))``.

    Raises:
        ValueError: If fewer coordinate rows than sequence positions are given.
    """
    num_nucleotides = len(sequence)
    if num_nucleotides == 0:
        # pdist cannot handle an empty input; an empty chain has no edges.
        return np.zeros((0, 0), dtype=np.float32)

    # Compute pairwise distances between nucleotides
    pairwise_distances = squareform(pdist(coordinates))
    if pairwise_distances.shape[0] < num_nucleotides:
        raise ValueError(
            f"expected at least {num_nucleotides} coordinate rows, "
            f"got {pairwise_distances.shape[0]}"
        )
    within_threshold = (
        pairwise_distances[:num_nucleotides, :num_nucleotides] < distance_threshold
    )

    # Define valid base pairs
    valid_pairs = {('A', 'U'), ('U', 'A'), ('G', 'C'), ('C', 'G'), ('G', 'U'), ('U', 'G')}

    # One boolean mask per base, built once; the pair matrix is then the
    # union of outer products instead of a Python loop over every (i, j).
    pairing_bases = {base for pair in valid_pairs for base in pair}
    base_masks = {
        base: np.fromiter((b == base for b in sequence), dtype=bool, count=num_nucleotides)
        for base in pairing_bases
    }
    can_pair = np.zeros((num_nucleotides, num_nucleotides), dtype=bool)
    for first, second in valid_pairs:
        can_pair |= base_masks[first][:, None] & base_masks[second][None, :]

    connected = within_threshold | can_pair
    # A position is at distance 0 from itself but must not be its own neighbour.
    np.fill_diagonal(connected, False)
    return connected.astype(np.float32)


# -------------------- Training Functions --------------------

def train_subgraph_model(model, dataloader, optimizer, criterion, device):
    """Run one training epoch and return the mean per-batch loss.

    Mixed precision (autocast plus gradient scaling) is switched on only when
    ``device`` is a CUDA device. On any other device both are explicitly
    disabled, so the scaler never tries to scale non-CUDA losses.

    Args:
        model: Module called as ``model(sub_seq)``.
        dataloader: Iterable yielding ``(sub_seq, sub_adj)`` tensor pairs.
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Loss called as ``criterion(outputs, sub_adj)``.
        device: ``torch.device`` or device string the batches are moved to.

    Returns:
        Mean of the per-batch losses as a float, or ``nan`` when the
        dataloader yields no batches.
    """
    model.train()
    use_amp = torch.device(device).type == "cuda"
    scaler = GradScaler(enabled=use_amp)
    total_loss = 0.0
    num_batches = 0

    for sub_seq, sub_adj in dataloader:
        sub_seq = sub_seq.to(device)
        sub_adj = sub_adj.to(device)

        optimizer.zero_grad()
        with autocast(enabled=use_amp):
            outputs = model(sub_seq)
            loss = criterion(outputs, sub_adj)
        # With the scaler disabled these calls reduce to a plain
        # loss.backward() / optimizer.step().
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item()
        num_batches += 1

    if num_batches == 0:
        # Nothing was trained; report "no value" instead of dividing by zero.
        return float("nan")
    return total_loss / num_batches


def validate_graph_reconstruction(model, dataloader, criterion, device, subgraph_size):
    """Evaluate full-graph reconstruction and return the mean per-batch loss.

    Each batch row is passed on its own to ``reconstruct_full_graph``; the
    reconstructed matrices are stacked and compared against the batch's
    ground-truth adjacencies with ``criterion``. Gradients are disabled and the
    model is left in eval mode.

    Args:
        model: Model forwarded to ``reconstruct_full_graph``.
        dataloader: Sized iterable yielding ``(sequences, ground_truth_adjs)``.
        criterion: Loss called as ``criterion(predicted, ground_truth)``.
        device: Device the batch tensors are moved to.
        subgraph_size: Tile size forwarded to ``reconstruct_full_graph``.

    Returns:
        Sum of batch losses divided by ``len(dataloader)``.
    """
    model.eval()
    running_loss = 0

    with torch.no_grad():
        for sequences, ground_truth_adjs in dataloader:
            sequences = sequences.to(device)
            targets = ground_truth_adjs.to(device)

            # Predict subgraphs and reconstruct the full graph, one row at a time.
            reconstructed = torch.stack(
                [
                    reconstruct_full_graph(model, single_sequence, subgraph_size, device)
                    for single_sequence in sequences
                ]
            )

            # Compute loss
            running_loss += criterion(reconstructed, targets).item()

    return running_loss / len(dataloader)


# -------------------- Main Workflow --------------------

def _build_subgraph_loader(dataset, device, shuffle):
    """Create a DataLoader whose worker count and pinning suit the current host."""
    return DataLoader(
        dataset,
        batch_size=CONFIG["batch_size"],
        shuffle=shuffle,
        # Never request more workers than the host has CPUs (capped at the
        # previous fixed value of 12).
        num_workers=min(12, os.cpu_count() or 1),
        # Pinned host memory only helps host-to-GPU copies.
        pin_memory=device.type == "cuda",
    )


def _predict_coordinates(graphrnn_model, gnn_model, sequence_encoded, device, edge_threshold=0.5):
    """Predict per-residue coordinates for one encoded sequence.

    The GraphRNN output is cropped to the sequence length and binarised into
    an edge list, and the encoded sequence is one-hot encoded into node
    features. Both are wrapped in a ``Data`` object because the GNN reads
    ``data.x`` and ``data.edge_index``.

    NOTE(review): ``edge_threshold=0.5`` is an assumption for turning the
    real-valued adjacency prediction into edges - confirm the intended cutoff.
    """
    # Local import: the file already depends on torch_geometric (GCNConv).
    from torch_geometric.data import Data

    sequence_tensor = torch.tensor(sequence_encoded, dtype=torch.long, device=device)
    num_nodes = sequence_tensor.size(0)

    with torch.no_grad():
        predicted_adj = graphrnn_model(sequence_tensor.unsqueeze(0))
        # Batch of one; keep only rows/columns that correspond to real residues.
        predicted_adj = predicted_adj[0, :num_nodes, :num_nodes]
        edge_index = (predicted_adj > edge_threshold).nonzero().t().contiguous()
        node_features = nn.functional.one_hot(
            sequence_tensor, num_classes=CONFIG["num_node_features"]
        ).float()
        graph = Data(x=node_features, edge_index=edge_index)
        return gnn_model(graph)


def main():
    """Train the GraphRNN, predict coordinates for the test set, write the CSV.

    Steps: load the train/validation JSON files and the test CSV, size the
    model from the longest training sequence, train and validate the GraphRNN
    for ``CONFIG["epochs"]`` epochs, then run GraphRNN + GNN on every test
    sequence and save one row per residue to ``CONFIG["submission_csv_path"]``.
    """
    device = torch.device(CONFIG["device"])
    print(f"Using device: {device}")

    # Load datasets
    print("Loading datasets...")
    with open(CONFIG["train_json_path"], "r", encoding="utf-8") as f:
        train_data = json.load(f)
    with open(CONFIG["val_json_path"], "r", encoding="utf-8") as f:
        val_data = json.load(f)
    test_data = pd.read_csv(CONFIG["test_csv_path"])
    print("Datasets loaded successfully.")

    nucleotide_to_index = {"A": 0, "U": 1, "G": 2, "C": 3}
    max_graph_size = max(len(entry["sequence"]) for entry in train_data.values())
    CONFIG["max_graph_size"] = max_graph_size  # Dynamically set max_graph_size
    print(f"Max graph size set to {max_graph_size}.")

    print("Preparing datasets...")
    train_dataset = RNASequenceSubgraphDataset(train_data, CONFIG["subgraph_size"], nucleotide_to_index)
    val_dataset = RNASequenceSubgraphDataset(val_data, CONFIG["subgraph_size"], nucleotide_to_index)
    print("Datasets prepared.")

    print("Initializing DataLoaders...")
    # NOTE(review): the datasets can yield a shorter final chunk per sequence;
    # confirm that batching tolerates chunks of unequal length.
    train_dataloader = _build_subgraph_loader(train_dataset, device, shuffle=True)
    val_dataloader = _build_subgraph_loader(val_dataset, device, shuffle=False)
    print("DataLoaders initialized.")

    # Initialize models
    print("Initializing models...")
    graphrnn_model = HierarchicalGraphRNN(
        vocab_size=CONFIG["vocab_size"],
        embedding_dim=CONFIG["embedding_dim"],
        hidden_dim=CONFIG["hidden_dim"],
        max_graph_size=max_graph_size,
        subgraph_size=CONFIG["subgraph_size"],
    ).to(device)

    gnn_model = RNA3DGN(
        num_node_features=CONFIG["num_node_features"],
        hidden_channels=CONFIG["gnn_hidden_channels"],
        num_coordinate_dims=CONFIG["num_coordinate_dims"],
    ).to(device)
    print("Models initialized.")

    # Define optimizers and loss functions
    print("Setting up optimizers and loss functions...")
    graphrnn_optimizer = optim.Adam(graphrnn_model.parameters(), lr=CONFIG["learning_rate"])
    # NOTE(review): the GNN optimizer and criterion are created but this
    # function never trains the GNN, so the coordinates below come from its
    # initial weights.
    gnn_optimizer = optim.Adam(gnn_model.parameters(), lr=CONFIG["learning_rate"])
    graphrnn_criterion = nn.MSELoss()
    gnn_criterion = nn.SmoothL1Loss()
    print("Optimizers and loss functions set up.")

    # Train GraphRNN
    # NOTE(review): the model emits max_graph_size-sized matrices while the
    # dataset targets are subgraph-sized; confirm the loss shapes agree.
    print("Starting training for GraphRNN...")
    for epoch in range(CONFIG["epochs"]):
        print(f"Epoch {epoch + 1}/{CONFIG['epochs']} - Training...")
        train_loss = train_subgraph_model(graphrnn_model, train_dataloader, graphrnn_optimizer, graphrnn_criterion, device)
        print(f"Epoch {epoch + 1}/{CONFIG['epochs']} - Validation...")
        val_loss = validate_graph_reconstruction(graphrnn_model, val_dataloader, graphrnn_criterion, device, CONFIG["subgraph_size"])
        print(f"Epoch {epoch + 1}: Train Loss = {train_loss:.4f}, Val Loss = {val_loss:.4f}")

    print("GraphRNN training completed.")

    # Generate predictions for test data
    print("Generating predictions for test data...")
    graphrnn_model.eval()
    gnn_model.eval()

    # The submission header declares this many (x, y, z) triples per residue,
    # but the GNN produces a single triple, so that triple is repeated to fill
    # every column. NOTE(review): confirm repetition is the intended fill.
    num_coordinate_sets = 5
    submission_data = []
    for idx, (_, row) in enumerate(test_data.iterrows(), start=1):
        sequence = row["sequence"]
        sequence_encoded = [nucleotide_to_index[n] for n in sequence]
        predicted_coordinates = _predict_coordinates(graphrnn_model, gnn_model, sequence_encoded, device)

        for i, coords in enumerate(predicted_coordinates.cpu().numpy()):
            row_data = [f"{row['target_id']}_{i + 1}", sequence[i], i + 1]
            row_data.extend(np.tile(coords.flatten(), num_coordinate_sets))
            submission_data.append(row_data)

        if idx % 10 == 0:  # Print progress every 10 sequences
            print(f"Processed {idx}/{len(test_data)} test sequences...")

    print("Test data predictions completed.")

    # Save predictions to submission.csv
    print(f"Saving predictions to {CONFIG['submission_csv_path']}...")
    columns = ["ID", "resname", "resid"] + [
        f"{axis}_{i + 1}" for i in range(num_coordinate_sets) for axis in ["x", "y", "z"]
    ]
    submission_df = pd.DataFrame(submission_data, columns=columns)
    submission_df.to_csv(CONFIG["submission_csv_path"], index=False)
    print(f"Predictions saved to {CONFIG['submission_csv_path']}.")


# Run the full train / predict / export workflow only when this file is
# executed directly, not when it is imported.
if __name__ == "__main__":
    main()