# In[ ]:
import pickle 
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader, Subset
import torch.nn as nn
from sklearn.metrics import r2_score, mean_absolute_error, matthews_corrcoef, f1_score
from sklearn.model_selection import KFold
import numpy as np
import pickle
import pandas as pd
import random
import os

from sklearn.metrics import matthews_corrcoef, accuracy_score, f1_score, precision_score, recall_score
import numpy as np
from sklearn.model_selection import KFold


# Function to calculate MCC
def _as_numpy(values):
    """Convert a torch tensor or any array-like into a NumPy array."""
    # A tensor that carries autograd history or lives on an accelerator cannot
    # be converted by np.asarray() directly, so detach and copy to host first.
    if isinstance(values, torch.Tensor):
        values = values.detach().cpu().numpy()
    return np.asarray(values)


def calculate_mcc(y_true, y_pred):
    """Return the Matthews correlation coefficient of rounded predictions.

    Args:
        y_true: Ground-truth labels as a list, NumPy array or torch tensor.
        y_pred: Predicted scores in any of the same container types. They are
            rounded with ``np.round`` before scoring, which for scores in
            [0, 1] acts as a 0.5 threshold. ``np.round`` rounds exact halves
            to the nearest even value, so a score of exactly 0.5 maps to 0.

    Returns:
        The value of ``sklearn.metrics.matthews_corrcoef`` for the labels and
        the rounded predictions.
    """
    # Convert explicitly instead of relying on NumPy's fallback path for
    # foreign array types, which fails for GPU / grad-tracking tensors.
    y_true = _as_numpy(y_true)
    y_pred = np.round(_as_numpy(y_pred))
    return matthews_corrcoef(y_true, y_pred)

# Pearson correlation
def pearson_correlation(x, y):
    """Compute the Pearson correlation coefficient of two tensors.

    Each input is centred on its own mean. The result is the sum of the
    element-wise product of the centred values, divided by the product of
    their L2 norms. If either input is constant, the denominator is zero and
    the result is NaN.
    """
    x_centred = x - x.mean()
    y_centred = y - y.mean()
    covariance_sum = (x_centred * y_centred).sum()
    x_norm = x_centred.pow(2).sum().sqrt()
    y_norm = y_centred.pow(2).sum().sqrt()
    return covariance_sum / (x_norm * y_norm)

# InteractionDataset class
class InteractionDataset(Dataset):
    """Map-style dataset of ``((first_embedding, rna_embedding), label)`` samples.

    Args:
        interaction_data: Indexable sequence whose items are 2-tuples of
            embedding tensors. The caller in this file passes
            ``(compound embedding, RNA embedding)`` pairs, even though the
            local variable below is named ``protein_emb``.
        interaction_labels: Indexable sequence of numeric labels, one per
            item of ``interaction_data``.
    """

    def __init__(self, interaction_data, interaction_labels):
        self.interaction_data = interaction_data
        self.interaction_labels = interaction_labels

    def __len__(self):
        return len(self.interaction_labels)

    def __getitem__(self, idx):
        """Return ``((first_emb, rna_emb), label)`` for sample ``idx``.

        ``label`` is always a float32 scalar tensor.
        """
        protein_emb, rna_emb = self.interaction_data[idx]
        # Cast explicitly: a Python int (or int64) label would otherwise become
        # an integer tensor, and nn.BCELoss only accepts floating-point targets
        # matching the model's float32 output.
        label = torch.as_tensor(self.interaction_labels[idx], dtype=torch.float32)
        # squeeze(0) drops a leading size-1 dim only when one is present, while
        # rna_emb[0] always selects the first entry along dim 0.
        # NOTE(review): the asymmetry depends on how each embedding file was
        # stored — confirm the intended shapes.
        return (protein_emb.squeeze(0), rna_emb[0]), label




class GatedFeatureFusion(nn.Module):
    """Blend two equally wide feature tensors using a learned per-feature gate.

    The module holds a trainable vector of ``input_dim`` gate logits,
    initialised with ``torch.rand`` (uniform on [0, 1)). In ``forward`` the
    logits pass through a sigmoid to give ``g``, and the output is
    ``x1 * g + x2 * (1 - g)``. ``g`` broadcasts against the trailing
    dimension of the inputs.
    """

    def __init__(self, input_dim):
        super().__init__()
        initial_logits = torch.rand(input_dim)
        self.gate = nn.Parameter(initial_logits)

    def forward(self, x1, x2):
        weight = self.gate.sigmoid()
        complement = 1 - weight
        return x1 * weight + x2 * complement

# DualPathNetworkRegression class
class DualPathNetworkRegression(nn.Module):
    """Two-branch network that scores a pair of embeddings.

    Despite its name, the network ends in a sigmoid, so ``forward`` returns
    probabilities in [0, 1] of the kind ``nn.BCELoss`` expects. It is a binary
    classifier, not a regressor. The class and submodule names are kept so
    callers and saved ``state_dict``s stay compatible.

    Each input has its own Linear -> ReLU -> Dropout branch projecting to
    ``hidden_dim``. The two branch outputs are blended by
    ``GatedFeatureFusion`` and then passed through an MLP head
    (hidden_dim -> 128 -> 64 -> 1).

    Args:
        protein_dim: Feature size of the first input. In this file's training
            code, the first input is a compound embedding.
        rna_dim: Feature size of the second input.
        hidden_dim: Output width of both branches and width of the fusion gate.
        path_dropout: Dropout probability inside each input branch (default 0.1).
        head_dropout: Dropout probability inside the MLP head (default 0.5).
    """

    def __init__(self, protein_dim, rna_dim, hidden_dim, path_dropout=0.1, head_dropout=0.5):
        super().__init__()
        self.protein_path = nn.Sequential(
            nn.Linear(protein_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(path_dropout),
        )
        self.rna_path = nn.Sequential(
            nn.Linear(rna_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(path_dropout),
        )
        self.fusion = GatedFeatureFusion(hidden_dim)
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, 128),
            nn.ReLU(),
            nn.Dropout(head_dropout),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(head_dropout),
            nn.Linear(64, 1),
        )

    def forward(self, protein_emb, rna_emb):
        """Return one probability per sample.

        Args:
            protein_emb: Tensor of shape ``(..., protein_dim)``.
            rna_emb: Tensor of shape ``(..., rna_dim)`` with the same leading
                dimensions as ``protein_emb``.

        Returns:
            Tensor with the inputs' leading dimensions, e.g. ``(batch,)`` for
            2-D inputs, or a 0-d tensor for a single unbatched pair.
        """
        protein_features = self.protein_path(protein_emb)
        rna_features = self.rna_path(rna_emb)
        combined_features = self.fusion(protein_features, rna_features)
        logits = self.classifier(combined_features)
        # squeeze(-1) removes the trailing size-1 output dimension for any
        # input rank. squeeze(1) raised on unbatched inputs and left the extra
        # dim in place for 3-D inputs; for 2-D batches both give the same result.
        return torch.sigmoid(logits).squeeze(-1)




# Train the dual-path classifier on all interactions and print its MCC after each epoch (no cross-validation is performed)
def _load_interactions(file_rna, file_drug, file_interaction):
    """Load the embedding pickles and interaction CSV into aligned sample lists.

    Args:
        file_rna: Path to a pickle whose object is indexed by the CSV's
            ``Protein`` values to get RNA embeddings (a dict, judging by the
            file names used in this notebook).
        file_drug: Path to a pickle whose object is indexed by the CSV's
            ``Compound`` values to get compound embeddings.
        file_interaction: Path to a CSV with ``Compound``, ``Protein`` and
            ``Label`` columns.

    Returns:
        ``(drug_embeddings, rna_embeddings, labels)``: three equal-length
        lists. A row is skipped if Compound, Protein or Label is missing, or
        if its Label is a string starting with ``'-'``. Labels above 10 are
        capped at ``10.0``. Lookup errors from the embedding objects (e.g.
        ``KeyError``) propagate.
    """
    # SECURITY: pickle.load can execute arbitrary code. Only load embedding
    # files that come from a trusted source.
    with open(file_rna, "rb") as handle:
        embeddings_rna = pickle.load(handle)
    with open(file_drug, "rb") as handle:
        embeddings_drug = pickle.load(handle)
    interaction = pd.read_csv(file_interaction)

    dataset_drug, dataset_rna, dataset_label = [], [], []
    for _, row in interaction.iterrows():
        compound, target, raw_label = row['Compound'], row['Protein'], row['Label']
        if pd.isna(compound) or pd.isna(target) or pd.isna(raw_label):
            continue
        # String labels beginning with '-' are dropped.
        # NOTE(review): presumably placeholder/invalid entries — confirm.
        if isinstance(raw_label, str) and raw_label.startswith('-'):
            continue
        label = float(raw_label)
        dataset_rna.append(embeddings_rna[target])
        dataset_drug.append(embeddings_drug[compound])
        # min() against a float keeps every label a float. Assigning the int 10
        # would yield an int64 target tensor, which nn.BCELoss rejects.
        # NOTE(review): BCELoss expects targets in [0, 1]. A cap of 10 looks
        # like a leftover from a regression setup — confirm the label range.
        dataset_label.append(min(label, 10.0))
    return dataset_drug, dataset_rna, dataset_label


def _collect_predictions(model, dataloader, device):
    """Run ``model`` in eval mode over ``dataloader`` without gradients.

    Returns:
        ``(predictions, targets)`` as 1-D NumPy arrays on the host.
    """
    model.eval()
    predictions, targets = [], []
    with torch.no_grad():
        for (drug_emb, rna_emb), batch_targets in dataloader:
            batch_predictions = model(drug_emb.to(device), rna_emb.to(device))
            predictions.append(batch_predictions.cpu())
            targets.append(batch_targets.cpu())
    return torch.cat(predictions).numpy(), torch.cat(targets).numpy()


def get_scores(file_rna, file_drug, file_interaction, model_path, *,
               epochs=50, batch_size=32, lr=0.001, device="cpu"):
    """Train a ``DualPathNetworkRegression`` classifier and report MCC per epoch.

    The model is trained with ``nn.BCELoss`` and Adam on every usable
    interaction row. After each epoch it is evaluated on that same data and
    the MCC is printed.

    Args:
        file_rna, file_drug, file_interaction: Input files; see
            ``_load_interactions``.
        model_path: Currently unused.
            NOTE(review): presumably intended for saving or loading a
            checkpoint — confirm before relying on it.
        epochs: Number of training epochs (default 50).
        batch_size: Mini-batch size. Batches are reshuffled on every pass.
        lr: Adam learning rate (default 0.001).
        device: Anything ``torch.device`` accepts (default ``"cpu"``).

    Returns:
        List with one MCC value per epoch.

    Raises:
        ValueError: If no usable interaction rows remain after filtering.
    """
    device = torch.device(device)
    print("Currently working on: ", file_interaction)
    dataset_drug, dataset_rna, dataset_label = _load_interactions(
        file_rna, file_drug, file_interaction
    )
    if not dataset_label:
        # Fail early with a clear message; otherwise torch.cat([]) fails
        # later with an obscure error.
        raise ValueError(f"No usable interactions found in {file_interaction!r}")

    interaction_dataset = InteractionDataset(list(zip(dataset_drug, dataset_rna)), dataset_label)
    interaction_dataloader = DataLoader(interaction_dataset, batch_size=batch_size, shuffle=True)

    # Input widths are fixed at 768.
    # NOTE(review): presumably the width of both precomputed embeddings.
    model = DualPathNetworkRegression(768, 768, 256).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.BCELoss()

    mcc_history = []
    for epoch in range(epochs):
        print(epoch)
        model.train()
        for (drug_emb, rna_emb), targets in interaction_dataloader:
            drug_emb, rna_emb = drug_emb.to(device), rna_emb.to(device)
            optimizer.zero_grad()
            predictions = model(drug_emb, rna_emb)
            # BCELoss requires float targets with the same dtype as the input.
            targets = targets.to(device=device, dtype=predictions.dtype)
            loss = criterion(predictions, targets)
            loss.backward()
            optimizer.step()

        # NOTE(review): evaluation reuses the training loader, so this is a
        # training-set MCC rather than a held-out score.
        val_predictions, val_targets = _collect_predictions(model, interaction_dataloader, device)
        mcc = calculate_mcc(val_targets, val_predictions)
        print("MCC: ", mcc)
        mcc_history.append(mcc)
    return mcc_history



# In[ ]:
# Example run on the Aptamers ROBIN v1 curated files.
# NOTE(review): these are absolute Windows paths; adjust for your environment.
# The __main__ guard stops an import of this module from starting training.
# In a notebook __name__ is "__main__", so the cell still runs as before.
if __name__ == "__main__":
    rna_file = r"C:\Research\MultiModal_Biological_LLM\Datasets\RNA_Molecule\Testing\Curated_Files\Aptamers_ROBIN_dataset_v1_rna_embeddings_dict.pkl"
    molecule_file = r"C:\Research\MultiModal_Biological_LLM\Datasets\RNA_Molecule\Testing\Curated_Files\Aptamers_ROBIN_dataset_v1_drug_embeddings_dict.pkl"
    interaction_file = r"C:\Research\MultiModal_Biological_LLM\Datasets\RNA_Molecule\Testing\Curated_Files\Aptamers_ROBIN_dataset_v1.csv"
    model_path = r"C:\Research\MultiModal_Biological_LLM\Code\final_results\Aptamers_dataset_v1\best_model.pth"

    get_scores(rna_file, molecule_file, interaction_file, model_path)