In [None]:
import os
import hashlib
import time
import torch
import logging
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F
from Bio import SeqIO
from tqdm import tqdm
from functools import partial
from datetime import datetime
from torch.utils.data import DataLoader, Dataset, Subset
from transformers import (
    EsmTokenizer,
    EsmForMaskedLM,
    AutoTokenizer,
    AutoModel
)
from peft import PeftModel
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score, accuracy_score, classification_report

In [None]:
# Distance Definition
class Cosine(nn.Module):
    """Cosine similarity between ``x1`` and ``x2`` along their last dimension.

    Larger values mean the two vectors point in more similar directions.
    """

    def forward(self, x1, x2):
        # Functional equivalent of nn.CosineSimilarity(dim=-1); the inputs
        # only need to be broadcastable against each other.
        return F.cosine_similarity(x1, x2, dim=-1)

class SquaredCosine(nn.Module):
    """Square of the cosine similarity between ``x1`` and ``x2`` (last dimension).

    Squaring discards the sign, so opposite vectors score the same as
    identical ones.
    """

    def forward(self, x1, x2):
        similarity = F.cosine_similarity(x1, x2, dim=-1)
        return similarity.pow(2)

class Euclidean(nn.Module):
    """Pairwise Euclidean (L2) distance between the rows of ``x1`` and ``x2``.

    Unlike the cosine metrics this is a distance: smaller values mean the
    vectors are closer together.
    """

    def forward(self, x1, x2):
        # torch.cdist compares every row of x1 with every row of x2.
        pairwise_l2 = torch.cdist(x1, x2, p=2.0)
        return pairwise_l2

class SquaredEuclidean(nn.Module):
    """Squared pairwise Euclidean distance between the rows of ``x1`` and ``x2``.

    Like ``torch.cdist`` itself, smaller values mean closer vectors.
    """

    def forward(self, x1, x2):
        pairwise_l2 = torch.cdist(x1, x2, p=2.0)
        return pairwise_l2.pow(2)

# Registry used to look a metric up by name: each metric class is keyed by
# its own class name, in the order Cosine, SquaredCosine, Euclidean,
# SquaredEuclidean.
DISTANCE_METRICS = {
    metric_cls.__name__: metric_cls
    for metric_cls in (Cosine, SquaredCosine, Euclidean, SquaredEuclidean)
}

class Coembedding(nn.Module):
    """Project molecule and protein embeddings into one latent space and score them.

    Each modality goes through its own two-layer MLP
    (``Linear -> activation -> Linear``) that maps it to ``latent_dimension``
    features. ``forward`` then dispatches on ``classify``:

    * ``classify=True``  -> :meth:`classify`: every protein is compared with
      every molecule using the metric named by ``latent_distance`` and the
      result is divided by the learnable ``temperature``.
    * ``classify=False`` -> :meth:`regress`: row ``i`` of the molecules is
      paired with row ``i`` of the proteins and the ReLU of their inner
      product is returned.

    Args:
        molecule_shape: Feature width of the incoming molecule embeddings.
        protein_shape: Feature width of the incoming protein embeddings.
        latent_dimension: Width of the shared latent space.
        latent_activation: Zero-argument callable building the activation
            placed between the two linear layers of each projector.
        latent_distance: Key into ``DISTANCE_METRICS``; only consulted when
            ``classify`` is true.
        classify: Selects the classification or the regression head.
        temperature: Initial value of the learnable temperature parameter.

    Raises:
        ValueError: If ``classify`` is true and ``latent_distance`` is not a
            key of ``DISTANCE_METRICS``.

    NOTE(review): ``temperature`` is an unconstrained parameter used as a
    divisor; nothing here prevents it from reaching zero or changing sign
    during training.
    NOTE(review): the Euclidean metrics return a distance (smaller = closer)
    while the cosine metrics return a similarity (larger = closer); callers
    ranking the output with a "largest first" rule should confirm which one
    they expect. With the Euclidean metrics ``torch.cdist`` also appears to
    produce a trailing singleton dimension in :meth:`classify` - confirm.
    """

    def __init__(
        self,
        molecule_shape: int = 768,
        protein_shape: int = 1280,
        latent_dimension: int = 1024,
        latent_activation=nn.ReLU,
        latent_distance: str = "Cosine",
        classify: bool = True,
        temperature: float = 0.1
    ):
        super().__init__()
        self.molecule_shape = molecule_shape
        self.protein_shape = protein_shape
        self.latent_dimension = latent_dimension
        self.do_classify = classify

        # float() lets an int (e.g. temperature=1) still become a trainable
        # floating-point parameter.
        self.temperature = nn.Parameter(torch.tensor(float(temperature)))

        # Built molecule-first, protein-second so parameter names and the
        # order in which random numbers are consumed match earlier versions.
        self.molecule_projector = self._build_projector(
            self.molecule_shape, latent_dimension, latent_activation
        )
        self.protein_projector = self._build_projector(
            self.protein_shape, latent_dimension, latent_activation
        )

        if self.do_classify:
            if latent_distance not in DISTANCE_METRICS:
                raise ValueError(f"Unsupported distance metric: {latent_distance}")
            self.distance_metric = latent_distance
            self.activator = DISTANCE_METRICS[self.distance_metric]()

    @staticmethod
    def _build_projector(in_features, latent_dimension, latent_activation):
        """Return a ``Linear -> activation -> Linear`` MLP with Xavier-normal weights."""
        projector = nn.Sequential(
            nn.Linear(in_features, latent_dimension),
            latent_activation(),
            nn.Linear(latent_dimension, latent_dimension)
        )
        for layer in projector:
            if isinstance(layer, nn.Linear):
                # Only the weights are re-initialised; biases keep the
                # nn.Linear default.
                nn.init.xavier_normal_(layer.weight)
        return projector

    def forward(self, molecule, protein):
        """Dispatch to :meth:`classify` or :meth:`regress` depending on ``classify``."""
        if self.do_classify:
            return self.classify(molecule, protein)
        return self.regress(molecule, protein)

    def regress(self, molecule, protein):
        """Return ``relu(<molecule_i, protein_i>)`` for each paired row.

        The result is always one-dimensional with one entry per pair. A
        batch containing a single pair yields shape ``(1,)`` rather than a
        0-dim scalar, which is what a bare ``squeeze()`` would produce.
        """
        molecule_projection = self.molecule_projector(molecule)
        protein_projection = self.protein_projector(protein)

        # (N, 1, L) @ (N, L, 1) -> (N, 1, 1); flatten to (N,) explicitly so
        # the batch dimension survives when N == 1.
        inner_prod = torch.bmm(
            molecule_projection.view(-1, 1, self.latent_dimension),
            protein_projection.view(-1, self.latent_dimension, 1),
        ).reshape(-1)
        return torch.relu(inner_prod)

    def classify(self, molecule, protein):
        """Score every protein against every molecule, scaled by ``temperature``.

        Assumes 2-D inputs ``(num_molecules, features)`` and
        ``(num_proteins, features)``. With the cosine metrics the result is
        ``(num_proteins, num_molecules)``.
        """
        molecule_projection = self.molecule_projector(molecule)
        protein_projection = self.protein_projector(protein)

        # (1, M, L) against (P, 1, L): broadcasting pairs each protein with
        # each molecule.
        molecule_projection = molecule_projection.unsqueeze(0)
        protein_projection = protein_projection.unsqueeze(1)

        distance = self.activator(molecule_projection, protein_projection)

        scaled_distance = distance / self.temperature

        return scaled_distance

In [None]:
def generate_anchor_embeddings_batch(sequences, tokenizer, lora_model, device, mask_padding=False):
    """Embed a batch of protein sequences by mean-pooling the encoder output.

    The sequences are tokenized together with padding, run through
    ``lora_model.esm`` without gradients, and the hidden states of every
    position after the first one are averaged per sequence.

    Args:
        sequences: Protein sequences accepted by ``tokenizer``.
        tokenizer: Callable returning a mapping of tensors
            (``return_tensors="pt"``, ``padding=True``).
        lora_model: Model exposing an ``esm`` sub-module whose output has a
            ``last_hidden_state`` attribute. It is moved to ``device``.
        device: Device on which the forward pass runs.
        mask_padding: When false (default) the mean runs over *all*
            positions after the first, including padding, so a sequence's
            embedding depends on the longest sequence in its batch. When
            true, positions whose ``attention_mask`` is 0 are excluded from
            the mean, making the result independent of batch composition.
            The default is kept for compatibility with embeddings produced
            by earlier runs.

    Returns:
        CPU tensor of shape ``(len(sequences), hidden_size)``.

    Raises:
        KeyError: If ``mask_padding`` is true and the tokenizer output has
            no ``"attention_mask"`` entry.
    """
    lora_model.to(device)
    inputs = tokenizer(sequences, return_tensors="pt", padding=True)
    inputs = {key: value.to(device) for key, value in inputs.items()}
    with torch.no_grad():
        hidden = lora_model.esm(**inputs).last_hidden_state
        # Drop position 0 (presumably the tokenizer's leading special token).
        residue_states = hidden[:, 1:]
        if mask_padding:
            keep = inputs["attention_mask"][:, 1:].unsqueeze(-1).to(residue_states.dtype)
            # clamp avoids 0/0 for a row with no unmasked positions left.
            mean_output = (residue_states * keep).sum(dim=1) / keep.sum(dim=1).clamp(min=1)
        else:
            mean_output = residue_states.mean(dim=1)
    return mean_output.cpu()

def generate_mol_embeddings_batch(smiles_list, tokenizer, mol_model, device, target_dim=1280):
    """Embed a batch of SMILES strings with ``mol_model``'s pooled output.

    On any failure the error is printed and an all-zero tensor is returned
    instead, so one bad batch does not abort the run (best-effort
    behaviour kept from the original implementation).

    Args:
        smiles_list: SMILES strings to embed.
        tokenizer: Callable returning a mapping of tensors
            (``padding=True``, ``return_tensors="pt"``).
        mol_model: Model whose output exposes ``pooler_output``. It is moved
            to ``device``.
        device: Device on which the forward pass runs.
        target_dim: Width of the zero fallback, used only when the model
            does not expose an integer ``config.hidden_size``.

    Returns:
        CPU tensor of shape ``(len(smiles_list), embedding_dim)``; all zeros
        if the batch could not be processed.
    """
    try:
        mol_model.to(device)
        inputs = tokenizer(smiles_list, padding=True, return_tensors="pt")
        inputs = {key: value.to(device) for key, value in inputs.items()}
        with torch.no_grad():
            outputs = mol_model(**inputs)
            mol_embedding = outputs.pooler_output
        return mol_embedding.cpu()  # Move to CPU only after computation
    except Exception as e:
        print(f"Error processing SMILES: {smiles_list}, Error: {e}")
        # The placeholder must be as wide as a real embedding, otherwise the
        # downstream projector rejects it. Prefer the model's own hidden
        # size (assumes pooler_output has that width - TODO confirm) and
        # fall back to target_dim only when the model does not report one.
        fallback_dim = getattr(getattr(mol_model, "config", None), "hidden_size", None)
        if not isinstance(fallback_dim, int) or isinstance(fallback_dim, bool):
            fallback_dim = target_dim
        return torch.zeros((len(smiles_list), fallback_dim))

In [None]:
# Protein encoder: ESM-2 checkpoint wrapped with the PEFT adapter stored in ./plm
model_name = 'esm2/esm2_t33_650M_UR50D'
prot_tokenizer = EsmTokenizer.from_pretrained(model_name)
base_model = EsmForMaskedLM.from_pretrained(model_name)
prot_model = PeftModel.from_pretrained(base_model, './plm')

# Molecule encoder and its tokenizer (both loaded with remote code enabled)
mol_model_path = "./ibm/MoLFormer-XL-both-10pct"
mol_tokenizer = AutoTokenizer.from_pretrained(mol_model_path, trust_remote_code=True)
mol_model = AutoModel.from_pretrained(mol_model_path, deterministic_eval=True, trust_remote_code=True)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Co-embedding head with default hyper-parameters, restored from the saved checkpoint.
# map_location remaps the stored tensors onto `device`, so a checkpoint written
# from a GPU still loads when this notebook falls back to the CPU.
model = Coembedding().to(device)
checkpoint = torch.load('model_weight/best_model.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# One row per distinct SMILES string
data = pd.read_excel("T2_data_normalized.xlsx")
data = data.drop_duplicates(subset = ['canonicalsmiles'])

# Dataset split
train_data, test_data = train_test_split(data, test_size=0.2, random_state=42)

# Proteins come from the test split only; molecules come from the full table,
# so every row of `data` is a candidate molecule.
test_prot_seq = test_data.sequence.tolist()
mol_smiles = data.canonicalsmiles.tolist()

# Map each product name to the row position of its molecule in `data`.
# NOTE(review): `unique_labels` is only unique if no two rows share a product
# name; with repeated names the dict keeps the last position for that name -
# verify against the data.
unique_labels = data.T2PKproductsname.tolist()
label_to_index = {productsname: idx for idx, productsname in enumerate(unique_labels)}
index_to_label = {idx: productsname for productsname, idx in label_to_index.items()}
true_labels = test_data['T2PKproductsname'].map(label_to_index).tolist()

# Function to calculate top-k accuracy
def calculate_topk_accuracy(model, mol_emb, prot_emb, true_labels, k=3):
    """Fraction of samples whose true label is among the k highest-scoring candidates.

    ``model(mol_emb, prot_emb)`` must return a score tensor with one row per
    sample and one column per candidate; the ``k`` largest scores in each
    row are taken as that sample's predictions.

    Args:
        model: Callable producing the score tensor.
        mol_emb: Molecule embeddings forwarded to ``model``.
        prot_emb: Protein embeddings forwarded to ``model``.
        true_labels: Sequence with the true candidate index of each sample.
        k: Number of top candidates to consider. Values larger than the
            number of candidates are clamped instead of raising.

    Returns:
        Accuracy as a float in ``[0, 1]``; ``0.0`` when ``true_labels`` is
        empty.
    """
    n_samples = len(true_labels)
    if n_samples == 0:
        # Nothing to score; avoids a division by zero below.
        return 0.0

    with torch.no_grad():
        outputs = model(mol_emb, prot_emb)

        # torch.topk raises if k exceeds the size of the ranked dimension.
        effective_k = min(k, outputs.size(1))
        _, top_k_preds = torch.topk(outputs, effective_k, dim=1)
        top_k_preds = top_k_preds.cpu().numpy()

    # A sample counts as a hit when its label appears among its predictions.
    hits = sum(true_labels[i] in top_k_preds[i] for i in range(n_samples))
    return hits / n_samples

# Embed the test-split protein sequences and the full molecule list, keeping both on `device`
test_prot_emb = generate_anchor_embeddings_batch(
    test_prot_seq, prot_tokenizer, prot_model, device
).to(device)
mol_emb = generate_mol_embeddings_batch(
    mol_smiles, mol_tokenizer, mol_model, device
).to(device)

# Score the same embeddings at k = 1, 3 and 5 (evaluated in that order)
test_top1_accuracy, test_top3_accuracy, test_top5_accuracy = (
    calculate_topk_accuracy(model, mol_emb, test_prot_emb, true_labels, k=cutoff)
    for cutoff in (1, 3, 5)
)

print(f"Test Top-1 Accuracy: {test_top1_accuracy:.4f}")
print(f"Test Top-3 Accuracy: {test_top3_accuracy:.4f}")
print(f"Test Top-5 Accuracy: {test_top5_accuracy:.4f}")