In [None]:
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader, random_split, Subset
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import MultiheadAttention, Sequential
import matplotlib.pyplot as plt
from tqdm import tqdm
from Bio import SeqIO
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit import DataStructs
from rdkit.Chem import rdFingerprintGenerator
import numpy as np
import time
import warnings
from sklearn import metrics
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split, StratifiedKFold
import os

In [None]:
class DrugVirusDataset(Dataset):
    """Map-style dataset over (drug 1, drug 2, virus) samples.

    Six parallel containers are stored; sample ``idx`` is produced by indexing
    every container at the same position. The dataset length is taken from
    ``drug1_tensor`` only, so all containers are expected to be at least that
    long.
    """

    def __init__(self, drug1_sim, drug1_tensor, drug2_sim, drug2_tensor, virus_tensor, label):
        # Note the argument order: each similarity container comes BEFORE the
        # matching feature container, while __getitem__ returns feature first.
        self.drug1_tensor, self.drug1_sim = drug1_tensor, drug1_sim
        self.drug2_tensor, self.drug2_sim = drug2_tensor, drug2_sim
        self.virus_tensor, self.label = virus_tensor, label

    def __len__(self):
        """Return the number of samples (length of ``drug1_tensor``)."""
        return len(self.drug1_tensor)

    def __getitem__(self, idx):
        """Return ``(drug1, drug1_sim, drug2, drug2_sim, virus, label)`` for position ``idx``."""
        columns = (
            self.drug1_tensor,
            self.drug1_sim,
            self.drug2_tensor,
            self.drug2_sim,
            self.virus_tensor,
            self.label,
        )
        return tuple(column[idx] for column in columns)


In [None]:
def parse_scientific_notation(value):
    """Convert a numeric token (plain or scientific notation) to an ``int``.

    The token is parsed as a float first, so strings such as ``"1.0e+02"`` are
    accepted, and is then truncated toward zero by ``int``.

    Args:
        value: Text (or any object accepted by ``float``) holding a number.

    Returns:
        The integer value, or ``None`` when ``value`` cannot be represented as
        an integer: non-numeric text, an empty string, ``None``, NaN, or an
        infinity.
    """
    try:
        return int(float(value))
    except (TypeError, ValueError, OverflowError):
        # ValueError: non-numeric text or NaN; OverflowError: +/-inf;
        # TypeError: a non-string, non-number object such as None.
        return None

In [None]:
def read_combinations_index_file(file_path):
    """Read a whitespace-separated table of numeric tokens into nested lists.

    Args:
        file_path: Path of the text file to read.

    Returns:
        A list with one entry per line of the file. Each entry is the list of
        that line's tokens, converted with ``parse_scientific_notation`` (so a
        token that cannot be parsed becomes ``None``). A blank line yields an
        empty list.
    """
    with open(file_path, 'r') as handle:
        return [
            [parse_scientific_notation(token) for token in row.strip().split()]
            for row in handle
        ]

In [None]:
# Load the sample index table: one list of integers per line of the file.
# Later cells read positions 0 and 1 of every row as row indices into the drug
# fingerprint / drug similarity tables, and position 2 as a row index into the
# virus similarity table.
combinations = read_combinations_index_file('Data/XIndex_ratio100.txt')

In [None]:
def read_label_file(file_path):
    """Read one numeric label per line from a text file.

    Args:
        file_path: Path of the label file.

    Returns:
        A list with one entry per line, each converted with
        ``parse_scientific_notation`` after stripping surrounding whitespace
        (unparsable lines, including blank ones, become ``None``).
    """
    with open(file_path, 'r') as handle:
        return [parse_scientific_notation(raw.strip()) for raw in handle]

In [None]:
# Labels: one value per line of the label file, stored as a float32 vector.
label_file_path = 'Data/Y_ratio100.txt'
labels_list = read_label_file(label_file_path)
label_tensor = torch.tensor(labels_list, dtype=torch.float32)

In [None]:
# Drug table; the 'smiles' column holds one SMILES string per drug.
drugs_df = pd.read_excel('Data/small_mulecule_drug_similarity.xlsx')
drugs = drugs_df['smiles']
mols = [Chem.MolFromSmiles(smi) for smi in drugs]

# MolFromSmiles returns None (instead of raising) for a SMILES string it cannot
# parse. Fail here with the offending rows rather than with an opaque argument
# error inside GetFingerprint. Rows must not be dropped silently: the sample
# index file refers to drugs by their row position in this table.
unparsable_rows = [row for row, mol in enumerate(mols) if mol is None]
if unparsable_rows:
    raise ValueError(
        f"Could not parse the SMILES string at row position(s) {unparsable_rows} "
        "of the 'smiles' column"
    )

# RDKit path-based bit fingerprints (paths up to 7 bonds), one per drug.
rdkit_gen = rdFingerprintGenerator.GetRDKitFPGenerator(maxPath=7)
fps = [rdkit_gen.GetFingerprint(mol) for mol in mols]
# One float32 row of fingerprint bits per drug.
smile_tensor = torch.tensor(fps,dtype = torch.float)

In [None]:
# Virus similarity matrix, read from a whitespace-delimited text file and
# converted to a float32 tensor (row k is the feature vector of virus k).
viruses_df = np.loadtxt('Data/Similarity_Matrix_Viruses.txt')
viruses = torch.tensor(data=viruses_df, dtype=torch.float32)

In [None]:
# Drug similarity matrix, read from a whitespace-delimited text file and
# converted to a float32 tensor (row k is the similarity profile of drug k).
drugs_sim_df = np.loadtxt('Data/Similarity_Matrix_Drugs.txt')
drugs_sim = torch.tensor(data=drugs_sim_df, dtype=torch.float32)

In [None]:
# Build the per-sample feature tensors by gathering rows of the lookup tables.
# Each entry of `combinations` is read as (drug 1 index, drug 2 index, virus index).
# Index gathering replaces random placeholders filled row by row in a Python
# loop: the result holds the same values, its width follows the lookup tables
# instead of hard-coded sizes, and no random numbers are consumed.
drug1_index = torch.tensor([row[0] for row in combinations], dtype=torch.long)
drug2_index = torch.tensor([row[1] for row in combinations], dtype=torch.long)
virus_index = torch.tensor([row[2] for row in combinations], dtype=torch.long)

# Indexing with an index tensor returns a new tensor (a copy, not a view).
drug1_tensor = smile_tensor[drug1_index]      # fingerprint of drug 1
drug1_sim_tensor = drugs_sim[drug1_index]     # similarity profile of drug 1
drug2_tensor = smile_tensor[drug2_index]      # fingerprint of drug 2
drug2_sim_tensor = drugs_sim[drug2_index]     # similarity profile of drug 2
virus_tensor = viruses[virus_index]           # similarity profile of the virus

In [None]:
# DrugVirusDataset takes each similarity tensor before the matching feature tensor.
dataset = DrugVirusDataset(drug1_sim_tensor, drug1_tensor, drug2_sim_tensor, drug2_tensor, virus_tensor, label_tensor)

indices = list(range(len(dataset)))

# scikit-learn works on array-likes; hand it a NumPy view of the labels rather
# than relying on implicit conversion of a torch tensor.
split_labels = label_tensor.numpy()

# Stratified 80 / 10 / 10 split: first hold out 20 %, then halve the hold-out
# into test and validation. Both splits are seeded for reproducibility.
train_indices, temp_indices, train_labels, temp_labels = train_test_split(
    indices, split_labels, test_size=0.2, stratify=split_labels, random_state=42
)
test_indices, val_indices, _, _ = train_test_split(
    temp_indices, temp_labels, test_size=0.5, stratify=temp_labels, random_state=42
)

train_set = Subset(dataset, train_indices)
test_set = Subset(dataset, test_indices)
val_set = Subset(dataset, val_indices)

batch_size = 32
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
# Evaluation does not benefit from shuffling; a fixed order keeps validation and
# test passes deterministic and does not consume random state.
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)
val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False)

In [None]:
class Self_Attention(nn.Module):
    """Linear projection followed by multi-head self-attention with a residual LayerNorm.

    Args:
        input_dim: Size of the last dimension of the input.
        embedding_dim: Size of the projected representation (also the
            attention embedding size).
        num_heads: Number of attention heads.
    """

    def __init__(self, input_dim, embedding_dim, num_heads):
        super().__init__()
        self.projection = nn.Linear(in_features=input_dim, out_features=embedding_dim)
        self.self_attention = MultiheadAttention(
            embedding_dim,
            num_heads,
            dropout=0.2,
            batch_first=True,
        )
        self.layerNorm = nn.LayerNorm(embedding_dim)

    def forward(self, x):
        """Embed ``x`` and return the normalised attention output.

        Assumes ``x`` is ``(batch, input_dim)`` - TODO confirm. In that case the
        inserted axis gives every sample a sequence of length one, so the
        attention weights over that single position are trivially 1.
        """
        # Insert an axis at position 1 so the projection can be fed to attention.
        tokens = self.projection(x).unsqueeze(dim=1)

        # Query, key and value are all the same projected tensor.
        attended = self.self_attention(tokens, tokens, tokens)[0]

        # Residual connection, then layer normalisation; finally drop axis 1 again.
        normalised = self.layerNorm(attended + tokens)
        return normalised.squeeze(dim=1)

In [None]:
class Cross_Attention(nn.Module):
    """Multi-head cross-attention between two separately projected inputs.

    Args:
        query_dim: Size of the last dimension of the query input.
        key_dim: Size of the last dimension of the key input.
        embed_dim: Common embedding size both inputs are projected to.
        num_heads: Number of attention heads.
    """

    def __init__(self, query_dim, key_dim, embed_dim, num_heads):
        super().__init__()
        self.cross_attn = nn.MultiheadAttention(
            embed_dim,
            num_heads,
            dropout=0.2,
            batch_first=True,
        )
        self.embed_dim = embed_dim
        self.query_projection = nn.Linear(in_features=query_dim, out_features=embed_dim)
        self.key_projection = nn.Linear(in_features=key_dim, out_features=embed_dim)
        self.layerNorm = nn.LayerNorm(embed_dim)

    def forward(self, query, key):
        """Attend from ``query`` to ``key`` and return the normalised result.

        Assumes both inputs are ``(batch, features)`` - TODO confirm. The
        projected key serves as both key and value of the attention.
        """
        # Project both inputs to the shared embedding size and add an axis at position 1.
        q_tokens = self.query_projection(query).unsqueeze(dim=1)
        k_tokens = self.key_projection(key).unsqueeze(dim=1)

        attended = self.cross_attn(q_tokens, k_tokens, k_tokens)[0]

        # Residual connection on the query side, layer norm, then drop axis 1.
        normalised = self.layerNorm(attended + q_tokens)
        return normalised.squeeze(dim=1)

In [None]:
class attentionCancat(nn.Module):
    """Bundle of three attention encoders: one for the virus, two for the drug pair.

    Args:
        embeddding_dim: Embedding size shared by all three encoders.
        drug1_dim: Accepted but not used by this module.
        drug2_dim: Accepted but not used by this module.
        virus_sim_dim: Input feature size of the virus encoder.
        drug_sim_dim: Input feature size of both drug inputs.
        num_heads: Number of attention heads in every encoder.
    """

    def __init__(self, embeddding_dim, drug1_dim: int, drug2_dim: int, virus_sim_dim: int, drug_sim_dim: int, num_heads: int):
        super().__init__()

        # Virus: self-attention. Drugs: one cross-attention per direction.
        self.v = Self_Attention(virus_sim_dim, embeddding_dim, num_heads)
        self.d1d2 = Cross_Attention(drug_sim_dim, drug_sim_dim, embeddding_dim, num_heads)
        self.d2d1 = Cross_Attention(drug_sim_dim, drug_sim_dim, embeddding_dim, num_heads)

    def forward(self, drug1, drug2, virus):
        """Return ``[virus embedding, drug1->drug2 embedding, drug2->drug1 embedding]``."""
        return [
            self.v(virus),
            self.d1d2(drug1, drug2),   # drug 1 as query, drug 2 as key
            self.d2d1(drug2, drug1),   # drug 2 as query, drug 1 as key
        ]

In [None]:
class ClassificatioLayer(nn.Module):
    """Scores a (drug 1, drug 2, virus) sample from its three attention embeddings.

    Args:
        drug1_dim: Forwarded to ``attentionCancat``.
        drug2_dim: Forwarded to ``attentionCancat``.
        virus_sim_dim: Feature size of the virus input.
        drug_sim_dim: Feature size of each drug similarity input.
        embedding_dim: Embedding size of the attention encoders.
        num_heads: Number of attention heads.
    """

    def __init__(self, drug1_dim: int, drug2_dim: int, virus_sim_dim: int, drug_sim_dim: int, embedding_dim: int, num_heads: int):
        super().__init__()
        self.att_cat = attentionCancat(embedding_dim, drug1_dim, drug2_dim, virus_sim_dim, drug_sim_dim, num_heads)

    def forward(self, d1, d2, v, d1_sim, d2_sim):
        """Return one raw (unbounded) score per sample.

        Args:
            d1: Accepted for interface compatibility; not used here.
            d2: Accepted for interface compatibility; not used here.
            v: Virus features.
            d1_sim: Similarity features of drug 1.
            d2_sim: Similarity features of drug 2.

        Returns:
            The element-wise product of the three embeddings summed over the
            last dimension, i.e. shape ``(batch,)`` for ``(batch, features)``
            inputs.
        """
        virus_emb, d1d2_emb, d2d1_emb = self.att_cat(d1_sim, d2_sim, v)
        # Three-way "dot product" used as the classification score. The sum
        # already removes the embedding axis, so no squeeze is applied: a bare
        # squeeze() would turn a batch of one into a 0-dim tensor whose shape no
        # longer matches the label batch.
        return torch.sum(virus_emb * d1d2_emb * d2d1_emb, dim=-1)

In [None]:
def evaluate(ground_truth, predictions, prediction_scores):
    """Compute binary-classification metrics with scikit-learn.

    Args:
        ground_truth: True labels.
        predictions: Hard predicted labels; every value must also occur in
            ``ground_truth``.
        prediction_scores: Continuous scores used for the ranking metrics.

    Returns:
        ``[accuracy, precision, recall, specificity, f1_score, MCC, AUROC, AUPR]``.

    Raises:
        AssertionError: If a predicted label never occurs in ``ground_truth``.
    """
    set_of_labels = set(ground_truth)
    assert all(label in set_of_labels for label in predictions),\
           'Predicted labels must be valid'

    # Threshold-based metrics, all computed from the hard predictions.
    label_scorers = (
        metrics.accuracy_score,
        metrics.precision_score,
        metrics.recall_score,
        metrics.f1_score,
        metrics.matthews_corrcoef,
    )
    accuracy, precision, recall, f1_score, MCC = (
        scorer(ground_truth, predictions) for scorer in label_scorers
    )

    # Specificity is not provided directly; derive it from the confusion matrix.
    # Unpacking into four counts requires a 2x2 matrix (two distinct labels).
    tn, fp, _fn, _tp = confusion_matrix(ground_truth, predictions).ravel()
    specificity = tn / (tn + fp)

    # Ranking metrics, computed from the continuous scores.
    AUROC = metrics.roc_auc_score(ground_truth, prediction_scores)
    AUPR = metrics.average_precision_score(ground_truth, prediction_scores)

    return [accuracy, precision, recall, specificity, f1_score, MCC, AUROC, AUPR]

In [None]:
#Model, Loss, Optimizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Positional arguments: drug1_dim, drug2_dim, virus_sim_dim, drug_sim_dim, embedding_dim, num_heads.
model = ClassificatioLayer(2048, 2048, 44, 211, 1024 ,4).to(device)
# The model returns ONE raw score per sample (shape (batch,)) and the targets are
# a float vector of the same shape. nn.CrossEntropyLoss would read such a pair as
# a single example whose "classes" are the batch entries, so the per-sample
# binary objective is expressed with BCEWithLogitsLoss instead (sigmoid + binary
# cross-entropy, applied to the raw scores).
# Assumes the labels are 0/1 - the binary metrics in `evaluate` rely on that too.
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

In [None]:
num_epochs = 200
train_losses = []                      # mean training loss of every epoch (plotted below)
best_aupr = 0                          # best validation AUPR so far; decides checkpointing
best_model_path = 'best_model.pth'


def _collect_predictions(net, loader):
    """Run ``net`` over ``loader`` without gradients.

    Returns:
        ``(y_true, y_pred, y_score)`` as flat Python lists: the labels, the hard
        0/1 predictions and the raw model scores.
    """
    net.eval()
    truths, hard_preds, scores = [], [], []
    with torch.no_grad():
        for d1, d1_sim, d2, d2_sim, v, labels in loader:
            outputs = net(d1.to(device), d2.to(device), v.to(device), d1_sim.to(device), d2_sim.to(device))
            raw = outputs.view(-1).float()
            # The model output is an unbounded raw score, not a probability, so
            # the 0.5 cut-off is applied to its sigmoid (i.e. to the score 0).
            predicted = (torch.sigmoid(raw) > 0.5).float()
            truths.extend(labels.cpu().numpy())
            hard_preds.extend(predicted.cpu().numpy())
            # AUROC / AUPR only depend on the ordering of the scores, so the raw
            # scores are kept: a sigmoid could saturate and create ties.
            scores.extend(raw.cpu().numpy())
    return truths, hard_preds, scores


def _print_metrics(values):
    """Print the eight metrics returned by ``evaluate``, in its fixed order."""
    accuracy, precision, recall, specificity, f1_score, MCC, AUROC, AUPR = values
    print(f'Accuracy: {accuracy:.4f}')
    print(f'Precision: {precision:.4f}')
    print(f'Recall: {recall:.4f}')
    print(f'specificity: {specificity:.4f}')
    print(f'F1 Score: {f1_score:.4f}')
    print(f'MCC: {MCC:.4f}')
    print(f'AUROC: {AUROC:.4f}')
    print(f'AUPR: {AUPR:.4f}')


for epoch in range(num_epochs):
    # ---- training pass ----
    model.train()
    running_loss = 0.0
    for d1, d1_sim, d2, d2_sim, v, labels in tqdm(train_loader):
        optimizer.zero_grad()
        outputs = model(d1.to(device), d2.to(device), v.to(device), d1_sim.to(device), d2_sim.to(device))
        # view(-1) keeps the scores 1-D even if the model returns a 0-dim tensor
        # for a final batch holding a single sample.
        loss = criterion(outputs.view(-1), labels.to(device))
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    train_loss = running_loss / len(train_loader)
    train_losses.append(train_loss)
    print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}')

    # ---- validation pass ----
    y_true, y_pred, y_prob = _collect_predictions(model, val_loader)
    evaluation_metrics = evaluate(y_true, y_pred, y_prob)
    accuracy, precision, recall, specificity, f1_score, MCC, AUROC, AUPR = evaluation_metrics
    _print_metrics(evaluation_metrics)

    # Checkpoint whenever the validation AUPR improves.
    if AUPR > best_aupr:
        best_aupr = AUPR
        print(epoch, AUPR)
        torch.save(model.state_dict(), best_model_path)

# Plot the training loss
plt.plot(range(len(train_losses)), train_losses, label='Train Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()

# Load the best model for testing
best_model = ClassificatioLayer(2048, 2048, 44, 211, 1024 ,4).to(device)
# map_location keeps the checkpoint loadable on whichever device is active now.
best_model.load_state_dict(torch.load(best_model_path, map_location=device))
best_model.eval()

y_true, y_pred, y_prob = _collect_predictions(best_model, test_loader)

# Compute evaluation metrics
evaluation_metrics = evaluate(y_true, y_pred, y_prob)
accuracy, precision, recall, specificity, f1_score, MCC, AUROC, AUPR = evaluation_metrics

_print_metrics(evaluation_metrics)