In [2]:
import RNA
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt

from main import sequences
from src.dataset import load_benchmark_dataset, Species, Modification

In [3]:
test_dataset = load_benchmark_dataset(Species.human, Modification.psi, True)
train_dataset = load_benchmark_dataset(Species.human, Modification.psi)

In [6]:
import torch
import torch.nn.functional as F
from torch_geometric.loader import DataLoader
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
import numpy as np
from sklearn.metrics import classification_report
from src.dataset import load_benchmark_dataset, Species, Modification
import RNA

# Step 1: Data Preparation
def sequence_to_graph(sequence):
    # Generate dot-bracket notation
    (dot_bracket, _) = RNA.fold(sequence)
    
    # Create node features (one-hot encoding of bases)
    bases = ['A', 'U', 'G', 'C']
    node_features = np.zeros((len(sequence), len(bases) + 1))
    for i, base in enumerate(sequence):
        node_features[i, bases.index(base)] = 1
        node_features[i, -1] = int(i == len(sequence) // 2)  # Mark central nucleotide
    
    # Create edges
    edges = []
    stack = []
    for i, char in enumerate(dot_bracket):
        if i > 0:
            edges.append((i-1, i))  # Backbone connections
        if char == '(':
            stack.append(i)
        elif char == ')':
            if stack:
                edges.append((stack.pop(), i))  # Base pairs
    
    edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
    
    return torch.tensor(node_features, dtype=torch.float), edge_index

class RNADataset(torch.utils.data.Dataset):
    def __init__(self, sequences, labels):
        self.sequences = sequences
        self.labels = labels

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        sequence = self.sequences.iloc[idx]
        label = self.labels.iloc[idx]
        x, edge_index = sequence_to_graph(sequence[0])
        return Data(x=x, edge_index=edge_index, y=torch.tensor([label], dtype=torch.long))

# Step 2: Model Definitions
class ShallowGCN(torch.nn.Module):
    def __init__(self, num_node_features):
        super(ShallowGCN, self).__init__()
        self.conv1 = GCNConv(num_node_features, 16)
        self.conv2 = GCNConv(16, 2)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = F.relu(self.conv1(x, edge_index))
        x = self.conv2(x, edge_index)
        return F.log_softmax(x[len(x)//2], dim=0)  # Only central node

class ComplexGCN(torch.nn.Module):
    def __init__(self, num_node_features):
        super(ComplexGCN, self).__init__()
        self.conv1 = GCNConv(num_node_features, 64)
        self.conv2 = GCNConv(64, 32)
        self.conv3 = GCNConv(32, 16)
        self.fc = torch.nn.Linear(16, 2)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = F.relu(self.conv3(x, edge_index))
        x = self.fc(x[len(x)//2])  # Only central node
        return F.log_softmax(x, dim=0)

# Step 3: Training Function
def train(model, loader, optimizer, device):
    model.train()
    for data in loader:
        data = data.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, data.y)
        loss.backward()
        optimizer.step()

# Step 4: Evaluation Function
def evaluate(model, loader, device):
    model.eval()
    correct = 0
    for data in loader:
        data = data.to(device)
        with torch.no_grad():
            pred = model(data).max(1)[1]
        correct += pred.eq(data.y).sum().item()
    return correct / len(loader.dataset)

# Step 5: Main Pipeline
def run_pipeline():
    # Load datasets
    train_dataset = load_benchmark_dataset(Species.human, Modification.psi)
    test_dataset = load_benchmark_dataset(Species.human, Modification.psi, True)

    # Create PyTorch Geometric datasets
    train_data = RNADataset(train_dataset.samples, train_dataset.targets)
    test_data = RNADataset(test_dataset.samples, test_dataset.targets)

    # Create data loaders
    train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
    test_loader = DataLoader(test_data, batch_size=32)

    # Set up device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Initialize models
    num_node_features = train_data[0].num_node_features
    shallow_gcn = ShallowGCN(num_node_features).to(device)
    complex_gcn = ComplexGCN(num_node_features).to(device)

    # Set up optimizers
    shallow_optimizer = torch.optim.Adam(shallow_gcn.parameters(), lr=0.01)
    complex_optimizer = torch.optim.Adam(complex_gcn.parameters(), lr=0.01)

    # Training loop
    for epoch in range(100):
        train(shallow_gcn, train_loader, shallow_optimizer, device)
        train(complex_gcn, train_loader, complex_optimizer, device)
        
        if (epoch + 1) % 10 == 0:
            shallow_acc = evaluate(shallow_gcn, test_loader, device)
            complex_acc = evaluate(complex_gcn, test_loader, device)
            print(f'Epoch {epoch+1}: Shallow GCN Acc: {shallow_acc:.4f}, Complex GCN Acc: {complex_acc:.4f}')

    # Final evaluation and classification report
    def get_predictions(model, loader, device):
        model.eval()
        predictions = []
        true_labels = []
        for data in loader:
            data = data.to(device)
            with torch.no_grad():
                pred = model(data).max(1)[1]
            predictions.extend(pred.cpu().numpy())
            true_labels.extend(data.y.cpu().numpy())
        return np.array(predictions), np.array(true_labels)

    shallow_preds, true_labels = get_predictions(shallow_gcn, test_loader, device)
    complex_preds, _ = get_predictions(complex_gcn, test_loader, device)

    print("Shallow GCN Classification Report:")
    print(classification_report(true_labels, shallow_preds))
    
    print("\nComplex GCN Classification Report:")
    print(classification_report(true_labels, complex_preds))

if __name__ == "__main__":
    run_pipeline()

  x, edge_index = sequence_to_graph(sequence[0])


RuntimeError: size mismatch (got input: [2], target: [32])

In [5]:
train_sequences = train_dataset.samples['sequence'].values