# Importing RNA Protein Dataset

In [None]:
import pandas as pd 
import numpy as np
import pickle


# Load the pre-computed protein embeddings, RNA embeddings and labels.
# NOTE: the variables are called "*_1460" although the files live in the
# "1807" directory, and "lebels" is a misspelling of "labels"; both names are
# kept unchanged because later cells reference them.
# NOTE(security): pickle.load can execute arbitrary code - only load files
# from a trusted source.
# Each file is opened in a `with` block so the handle is closed right after
# loading instead of being left open until garbage collection.
with open(r'C:\Research\MultiModal_Biological_LLM\Datasets\Protein_RNA\Processed_Files\1807\protein_embeddings_1807.pkl', 'rb') as _fh:
    protein_1460 = pickle.load(_fh)
with open(r'C:\Research\MultiModal_Biological_LLM\Datasets\Protein_RNA\Processed_Files\1807\rna_embeddings_1807.pkl', 'rb') as _fh:
    rna_1460 = pickle.load(_fh)
with open(r'C:\Research\MultiModal_Biological_LLM\Datasets\Protein_RNA\Processed_Files\1807\labels_1807.pkl', 'rb') as _fh:
    lebels = pickle.load(_fh)



In [None]:
# Notebook inspection cell: display the shape of the first loaded RNA entry.
rna_1460[0].shape

# Creating the dataloader

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader

class InteractionDataset(Dataset):
    """Map-style dataset pairing (protein, RNA) embeddings with a label."""

    def __init__(self, interaction_data, interaction_labels):
        """Keep references to the paired embeddings and their labels.

        Args:
            interaction_data (list): One entry per sample; each entry unpacks
                into (protein embedding, RNA embedding).
            interaction_labels (list): One label per sample; each label is
                passed through int() when the sample is fetched.
        """
        self.interaction_labels = interaction_labels
        self.interaction_data = interaction_data

    def __len__(self):
        """Return the number of samples, measured on the label list."""
        return len(self.interaction_labels)

    def __getitem__(self, idx):
        """Return ``((protein, rna), label)`` for the sample at ``idx``."""
        sample = self.interaction_data[idx]
        protein_emb, rna_emb = sample
        target = torch.tensor(int(self.interaction_labels[idx]))
        # The protein embedding loses a size-1 leading axis (squeeze(0) is a
        # no-op otherwise); for the RNA embedding only index 0 along its
        # leading axis is kept.
        features = (protein_emb.squeeze(0), rna_emb[0])
        return features, target

# Pair the i-th protein embedding with the i-th RNA embedding (zip stops at the
# shorter of the two sequences) and wrap the pairs plus labels in a dataset.
interaction_dataset = InteractionDataset(list(zip(protein_1460, rna_1460)), lebels)


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from sklearn.metrics import matthews_corrcoef
import numpy as np


# Split sizes. The train fraction is 1, so train_size equals dataset_size and
# val_size is 0: the validation split below is empty.
dataset_size = len(interaction_dataset)
train_size = int(dataset_size * 1)
val_size = dataset_size - train_size

# Random partition of the dataset into the two sizes computed above.
train_dataset, val_dataset = random_split(interaction_dataset, [train_size, val_size])

# Batches of 32; only the training loader reshuffles every epoch.
# With val_size == 0, val_dataloader yields no batches.
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=False)


In [None]:
# Sanity check: print the shapes of both batched embeddings and the labels of
# the first training batch, then stop (prints nothing if the loader is empty).
for data, label in train_dataloader:
    print(data[0].shape)
    print(data[1].shape)
    print(label)
    break

In [None]:
# Embedding widths; the same values (1024 / 768) are passed as protein_dim /
# rna_dim when the models are built further down.
# NOTE(review): these two names are not referenced in the visible code.
protein_emb_size = 1024
rna_emb_size = 768

In [None]:
import torch
import torch.nn as nn

class GatedFeatureFusion(nn.Module):
    """Blend two equally wide feature tensors with a learned per-feature gate."""

    def __init__(self, input_dim):
        """Create one raw (pre-sigmoid) gate value per feature.

        Args:
            input_dim: Number of features in each of the two inputs.
        """
        super().__init__()
        # Raw gate values start uniformly in [0, 1).
        self.gate = nn.Parameter(torch.rand(input_dim))

    def forward(self, x1, x2):
        """Return ``(x1 * g + x2 * (1 - g), g)`` with ``g = sigmoid(gate)``."""
        g = torch.sigmoid(self.gate)
        share_of_x1 = x1 * g
        share_of_x2 = x2 * (1 - g)
        return share_of_x1 + share_of_x2, g

class DualPathNetwork(nn.Module):
    """Two-branch classifier: project protein and RNA embeddings to a shared
    width, blend them with a gated fusion, and output a probability."""

    def __init__(self, protein_dim, rna_dim, hidden_dim):
        """Build both projection branches, the fusion gate and the classifier.

        Args:
            protein_dim: Width of the protein embedding.
            rna_dim: Width of the RNA embedding.
            hidden_dim: Shared width both branches project to.
        """
        super().__init__()

        def make_branch(in_features):
            # Linear projection -> ReLU -> light dropout.
            return nn.Sequential(
                nn.Linear(in_features, hidden_dim),
                nn.ReLU(),
                nn.Dropout(0.1),
            )

        # Submodules are created in the same order as before (protein, RNA,
        # fusion, classifier) so seeded initialisation is unaffected.
        self.protein_path = make_branch(protein_dim)
        self.rna_path = make_branch(rna_dim)
        self.fusion = GatedFeatureFusion(hidden_dim)
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, 64),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(64, 1),
        )

    def forward(self, protein_emb, rna_emb):
        """Return ``(sigmoid(classifier(fused)), gate_values)``."""
        fused, gate_values = self.fusion(
            self.protein_path(protein_emb),
            self.rna_path(rna_emb),
        )
        probabilities = torch.sigmoid(self.classifier(fused))
        return probabilities, gate_values


# Stand-alone DualPathNetwork instance (1024-wide protein input, 768-wide RNA
# input, 512-wide hidden layer).
model_dual_path = DualPathNetwork(protein_dim=1024, rna_dim=768, hidden_dim=512)


In [None]:
class ProteinRNAClassifier(nn.Module):
    """Concatenation baseline: project each embedding to 512 features, join
    them, and classify with a three-layer MLP."""

    def __init__(self, protein_dim, rna_dim):
        """Create the per-modality projections and the joint MLP.

        Args:
            protein_dim: Width of the protein embedding.
            rna_dim: Width of the RNA embedding.
        """
        super().__init__()
        # Separate fully connected layers for protein and RNA features
        self.protein_fc = nn.Linear(protein_dim, 512)  # Process protein features
        self.rna_fc = nn.Linear(rna_dim, 512)          # Process RNA features

        # Combined fully connected layers for the concatenated features
        self.combined_fc1 = nn.Linear(512 * 2, 512)
        self.combined_fc2 = nn.Linear(512, 64)
        self.combined_fc3 = nn.Linear(64, 1)  # Output layer for binary classification

    def forward(self, protein, rna):
        """Return one probability per sample.

        The inputs are concatenated along dim 1, so both must be 2-D with the
        same leading (batch) size. The output has the trailing size-1 axis
        removed, giving shape ``(batch,)``.
        """
        # torch.relu is used instead of F.relu: the visible file never imports
        # torch.nn.functional as F, so the former code raised NameError here.
        protein = torch.relu(self.protein_fc(protein))
        rna = torch.relu(self.rna_fc(rna))

        # Concatenate the processed features
        combined = torch.cat((protein, rna), dim=1)

        # Fully connected layers with ReLU activations for the combined features
        x = torch.relu(self.combined_fc1(combined))
        x = torch.relu(self.combined_fc2(x))
        # Sigmoid maps the single logit to a probability; squeeze(1) removes
        # only the feature axis, so a batch of one stays 1-D.
        x = torch.sigmoid(self.combined_fc3(x)).squeeze(1)

        return x

# Input widths for the concatenation baseline.
protein_dim = 1024
rna_dim = 768
# Initialize the model
model = ProteinRNAClassifier(protein_dim, rna_dim)

def count_parameters(model):
    """Return the total element count of all parameters that require gradients."""
    trainable = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable += param.numel()
    return trainable

# Report the size of the baseline model with thousands separators.
print(f'The model has {count_parameters(model):,} trainable parameters')

In [None]:


def set_seed(seed):
    """Seed the PyTorch and NumPy global RNGs and request deterministic cuDNN.

    Args:
        seed: Integer seed handed to both libraries.
    """
    for seed_fn in (torch.manual_seed, np.random.seed):
        seed_fn(seed)

    cudnn = torch.backends.cudnn
    # Prefer deterministic kernels and turn off the autotuner.
    cudnn.deterministic = True
    cudnn.benchmark = False

# Fix the RNG state before any model is built or any fold is drawn below.
set_seed(42)

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split, Subset
from sklearn.metrics import matthews_corrcoef, accuracy_score, f1_score, precision_score, recall_score
import numpy as np
from sklearn.model_selection import KFold


# Mini-batch size used by the cross-validation dataloaders below.
batch_size = 32



def calculate_mcc(y_true, y_pred):
    """Return the Matthews correlation coefficient of rounded predictions.

    Args:
        y_true: Ground-truth labels.
        y_pred: Predicted scores; rounded to the nearest integer with
            np.round (ties go to the even value, so exactly 0.5 becomes 0).
    """
    hard_predictions = np.round(y_pred)
    mcc = matthews_corrcoef(y_true, hard_predictions)
    return mcc

def calculate_metrics(y_true, y_pred):
    """Return ``(accuracy, f1, precision, recall, specificity)``.

    Predicted scores are first rounded to the nearest integer with np.round.
    Specificity is computed as the recall of class 0.
    """
    hard_predictions = np.round(y_pred)
    return (
        accuracy_score(y_true, hard_predictions),
        f1_score(y_true, hard_predictions),
        precision_score(y_true, hard_predictions),
        recall_score(y_true, hard_predictions),
        recall_score(y_true, hard_predictions, pos_label=0),
    )

def _run_batches(model, loader, device, criterion=None, optimizer=None):
    """Run one pass over ``loader`` and collect labels and probabilities.

    When ``optimizer`` is given, each batch also performs a training step
    with ``criterion``; otherwise the pass is forward-only. The caller is
    responsible for ``model.train()`` / ``model.eval()`` and ``no_grad``.

    Returns:
        (labels, probabilities) as two flat Python lists.
    """
    all_labels = []
    all_probs = []
    for (protein_emb, rna_emb), labels in loader:
        protein_emb = protein_emb.to(device)
        rna_emb = rna_emb.to(device)
        labels = labels.to(device)

        if optimizer is not None:
            optimizer.zero_grad()

        probs, _ = model(protein_emb, rna_emb)
        # squeeze(-1) removes only the trailing size-1 axis. A bare squeeze()
        # would also collapse the batch axis when the final batch holds a
        # single sample, breaking both the loss and the list extension below.
        probs = probs.squeeze(-1)

        if optimizer is not None:
            loss = criterion(probs, labels.float())
            loss.backward()
            optimizer.step()

        all_labels.extend(labels.tolist())
        all_probs.extend(probs.detach().cpu().tolist())
    return all_labels, all_probs


# Training and validation function
def train_model(train_loader, val_loader, epochs):
    """Train a fresh DualPathNetwork and return its best validation metrics.

    A new ``DualPathNetwork(1024, 768, hidden_dim=512)`` is trained with Adam
    (lr=0.001) and binary cross-entropy on the model's probabilities. After
    every epoch the model is evaluated on ``val_loader``; whenever the
    validation MCC improves, all metrics of that epoch are recorded and a
    progress line is printed.

    Args:
        train_loader: Iterable of ``((protein, rna), labels)`` training batches.
        val_loader: Iterable of ``((protein, rna), labels)`` validation batches.
        epochs: Number of passes over ``train_loader``.

    Returns:
        dict with keys 'mcc', 'acc', 'f1', 'precision', 'recall' and
        'specificity' from the epoch with the highest validation MCC
        (all zeros if ``epochs`` is 0).

    Note:
        The reported numbers come from the best epoch chosen on the same data
        they are measured on, so they are an optimistic estimate.
    """
    best_metrics = {
        'mcc': 0,
        'acc': 0,
        'f1': 0,
        'precision': 0,
        'recall': 0,
        'specificity': 0
    }
    # MCC lies in [-1, 1]. Starting the comparison at -inf (instead of 0)
    # ensures the best epoch is recorded even when no epoch exceeds MCC 0;
    # previously such a run reported all-zero metrics.
    best_val_mcc = float('-inf')

    model = DualPathNetwork(1024, 768, hidden_dim=512)
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.BCELoss()

    # Device configuration
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    for epoch in range(epochs):
        model.train()
        train_labels, train_preds = _run_batches(
            model, train_loader, device, criterion=criterion, optimizer=optimizer
        )
        train_mcc = calculate_mcc(train_labels, train_preds)

        # Validation loop
        model.eval()
        with torch.no_grad():
            val_labels, val_preds = _run_batches(model, val_loader, device)

        val_mcc = calculate_mcc(val_labels, val_preds)
        if val_mcc > best_val_mcc:
            best_val_mcc = val_mcc
            acc, f1, precision, recall, specificity = calculate_metrics(val_labels, val_preds)
            best_metrics.update(
                mcc=val_mcc,
                acc=acc,
                f1=f1,
                precision=precision,
                recall=recall,
                specificity=specificity,
            )
            print(f'Epoch {epoch+1}: Train MCC: {train_mcc:.4f}, Val MCC: {val_mcc:.4f}')

    return best_metrics

# 10-fold cross-validation
# random_state pins the shuffled fold assignment. Without it the folds depend
# on whatever state NumPy's global RNG happens to be in, so re-running this
# cell alone produced different splits.
kf = KFold(n_splits=10, shuffle=True, random_state=42)
cv_results = {}

for fold, (train_idx, val_idx) in enumerate(kf.split(interaction_dataset)):
    # Index-based views of the full dataset for this fold.
    train_subset = Subset(interaction_dataset, train_idx)
    val_subset = Subset(interaction_dataset, val_idx)

    # Only the training loader is reshuffled each epoch.
    train_dataloader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
    val_dataloader = DataLoader(val_subset, batch_size=batch_size, shuffle=False)

    print(f'Fold {fold+1}')
    # Each fold trains a fresh model; results are keyed by 1-based fold number.
    best_metrics = train_model(train_dataloader, val_dataloader, epochs=100)
    cv_results[fold + 1] = best_metrics

print(cv_results)
