In [None]:
import os
import hashlib
import time
import torch
import logging
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F
from Bio import SeqIO
from tqdm import tqdm
from functools import partial
from datetime import datetime
from torch.utils.data import DataLoader, Dataset, Subset
from transformers import (
    EsmTokenizer,
    EsmForMaskedLM,
    AutoTokenizer,
    AutoModel
)
from peft import PeftModel
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score, accuracy_score, classification_report

In [None]:
class ContrastiveLoss:
    """Symmetric softmax contrastive loss over a batch of paired projections.

    Row ``i`` of ``anchor_projection`` and row ``i`` of ``positive_projection``
    are treated as the matching pair; all other rows of the batch serve as
    negatives. The loss is the average of the cross-entropy taken along the
    rows and along the columns of the temperature-scaled similarity matrix.
    """

    def __init__(self, temperature=0.1):
        # Divisor applied to the raw dot-product similarities before the softmax.
        self.temperature = temperature

    def __call__(self, anchor_projection, positive_projection):
        return self.loss_fn(anchor_projection, positive_projection)

    def loss_fn(self, anchor_projection, positive_projection):
        """Return the scalar loss for one batch of (anchor, positive) projections."""
        scaled_scores = (anchor_projection @ positive_projection.T) / self.temperature

        # The matching pairs sit on the diagonal of the score matrix, so the
        # negative log-likelihood of the correct match is the diagonal of the
        # log-softmax, taken once per direction.
        row_wise_nll = -torch.diagonal(F.log_softmax(scaled_scores, dim=1)).mean()
        column_wise_nll = -torch.diagonal(F.log_softmax(scaled_scores.T, dim=1)).mean()

        return 0.5 * (row_wise_nll + column_wise_nll)

# Distance Definition
class Cosine(nn.Module):
    """Cosine similarity of ``x1`` and ``x2`` along their last dimension (inputs broadcast)."""

    def forward(self, x1, x2):
        return F.cosine_similarity(x1, x2, dim=-1)

class SquaredCosine(nn.Module):
    """Element-wise square of the cosine similarity along the last dimension."""

    def forward(self, x1, x2):
        return F.cosine_similarity(x1, x2, dim=-1).pow(2)

class Euclidean(nn.Module):
    """Pairwise 2-norm distances between the rows of ``x1`` and the rows of ``x2`` via ``torch.cdist``."""

    def forward(self, x1, x2):
        # ``p`` defaults to 2.0 in torch.cdist, i.e. the Euclidean norm.
        return torch.cdist(x1, x2)

class SquaredEuclidean(nn.Module):
    """Element-wise square of the pairwise 2-norm distances returned by ``torch.cdist``."""

    def forward(self, x1, x2):
        pairwise_distance = torch.cdist(x1, x2)
        return pairwise_distance.pow(2)

# Registry of the scoring modules above, keyed by class name. ``Coembedding``
# looks its ``latent_distance`` argument up in this table.
DISTANCE_METRICS = {
    metric_cls.__name__: metric_cls
    for metric_cls in (Cosine, SquaredCosine, Euclidean, SquaredEuclidean)
}

class Coembedding(nn.Module):
    """Project molecule and protein embeddings into a shared latent space and score them.

    Two independent two-layer MLPs (``molecule_projector`` and
    ``protein_projector``) map their inputs to ``latent_dimension`` features.
    ``forward`` dispatches to :meth:`classify` when ``classify=True`` and to
    :meth:`regress` otherwise.

    Args:
        molecule_shape: Width of the incoming molecule embeddings.
        protein_shape: Width of the incoming protein embeddings.
        latent_dimension: Width of the shared latent space.
        latent_activation: Zero-argument callable returning the activation
            module placed between the two linear layers of each projector.
        latent_distance: Key into ``DISTANCE_METRICS``; only consulted when
            ``classify`` is true.
        classify: Selects the behaviour of ``forward`` (see above).
        temperature: Initial value of the learnable divisor applied to the
            scores returned by :meth:`classify`.

    Raises:
        ValueError: If ``classify`` is true and ``latent_distance`` is not a
            key of ``DISTANCE_METRICS``.
    """

    def __init__(
        self,
        molecule_shape: int = 768,
        protein_shape: int = 1280,
        latent_dimension: int = 1024,
        latent_activation=nn.ReLU,
        latent_distance: str = "Cosine",
        classify: bool = True,
        temperature: float = 0.1
    ):
        super().__init__()
        self.molecule_shape = molecule_shape
        self.protein_shape = protein_shape
        self.latent_dimension = latent_dimension
        self.do_classify = classify

        # float() so an integer argument (e.g. temperature=1) still yields a
        # floating-point tensor; integer tensors cannot require gradients.
        self.temperature = nn.Parameter(torch.tensor(float(temperature)))

        # Construction order (molecule first, then protein) is kept so that a
        # seeded run draws the same initial weights as before.
        self.molecule_projector = self._build_projector(
            self.molecule_shape, latent_dimension, latent_activation
        )
        self.protein_projector = self._build_projector(
            self.protein_shape, latent_dimension, latent_activation
        )

        if self.do_classify:
            if latent_distance not in DISTANCE_METRICS:
                raise ValueError(f"Unsupported distance metric: {latent_distance}")
            self.distance_metric = latent_distance
            self.activator = DISTANCE_METRICS[self.distance_metric]()

    @staticmethod
    def _build_projector(input_dim, latent_dimension, latent_activation):
        """Build one Linear -> activation -> Linear head with Xavier-normal weights."""
        projector = nn.Sequential(
            nn.Linear(input_dim, latent_dimension),
            latent_activation(),
            nn.Linear(latent_dimension, latent_dimension)
        )
        for layer in projector:
            if isinstance(layer, nn.Linear):
                nn.init.xavier_normal_(layer.weight)
        return projector

    def forward(self, molecule, protein):
        if self.do_classify:
            return self.classify(molecule, protein)
        else:
            return self.regress(molecule, protein)

    def regress(self, molecule, protein):
        """Return ``relu(<molecule_i, protein_i>)`` for each aligned pair.

        The result is always one-dimensional with one entry per pair. The
        previous bare ``.squeeze()`` collapsed a batch of one into a 0-dim
        scalar, which breaks losses that expect a batch axis.
        """
        molecule_projection = self.molecule_projector(molecule)
        protein_projection = self.protein_projector(protein)

        inner_prod = torch.bmm(
            molecule_projection.reshape(-1, 1, self.latent_dimension),
            protein_projection.reshape(-1, self.latent_dimension, 1),
        ).reshape(-1)
        return F.relu(inner_prod)

    def classify(self, molecule, protein):
        """Score every protein against every molecule and divide by ``temperature``.

        With the cosine-based metrics the two projections are broadcast
        against each other, giving one row per protein and one column per
        molecule (assuming both inputs are 2-D ``(count, features)`` tensors).

        NOTE(review): the ``torch.cdist``-based metrics (``Euclidean``,
        ``SquaredEuclidean``) receive the same unsqueezed 3-D inputs and, per
        the cdist shape rules, appear to return an extra trailing axis of
        size 1; they also grow with distance rather than similarity. Confirm
        before using them as classification logits.
        """
        molecule_projection = self.molecule_projector(molecule)
        protein_projection = self.protein_projector(protein)

        molecule_projection = molecule_projection.unsqueeze(0)
        protein_projection = protein_projection.unsqueeze(1)

        distance = self.activator(molecule_projection, protein_projection)

        scaled_distance = distance / self.temperature

        return scaled_distance

In [None]:
class ContrastiveDataset(Dataset):
    """Dataset of (protein sequence, SMILES) string pairs backed by a DataFrame.

    The tokenizers and encoders are stored on the instance, and both encoders
    are moved to ``device``, switched to eval mode and frozen. ``__getitem__``
    itself only returns the raw strings from the ``sequence`` and
    ``canonicalsmiles`` columns.
    """

    def __init__(self, dataframe, prot_tokenizer, prot_model, mol_tokenizer, mol_model, device):
        self.data = dataframe
        self.device = device
        self.prot_tokenizer = prot_tokenizer
        self.mol_tokenizer = mol_tokenizer
        self.prot_model = prot_model.to(device)
        self.mol_model = mol_model.to(device)

        # The encoders act as fixed feature extractors: no dropout, no gradients.
        for encoder in (self.prot_model, self.mol_model):
            encoder.eval()
            for weight in encoder.parameters():
                weight.requires_grad = False

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        record = self.data.iloc[idx]
        return record['sequence'], record['canonicalsmiles']

def generate_anchor_embeddings_batch(sequences, tokenizer, lora_model, device):
    """Embed protein sequences by mean-pooling the encoder's last hidden state.

    Args:
        sequences: A single sequence string or an iterable of sequence strings.
        tokenizer: Callable tokenizer; invoked with ``return_tensors="pt"`` and
            ``padding=True`` and expected to return a mapping of tensors.
        lora_model: Model exposing an ``esm`` sub-module whose output has a
            ``last_hidden_state`` attribute.
        device: Device the model and the tokenized inputs are moved to.

    Returns:
        A CPU tensor with one pooled embedding row per input sequence.

    The first token position is dropped before pooling, as before (presumably
    the tokenizer's leading special token — confirm for non-ESM tokenizers).
    Padding positions are now excluded from the mean via the tokenizer's
    attention mask. Previously they were averaged in, so the embedding of a
    sequence depended on the longest sequence it happened to be batched with.
    For an unpadded input the result is unchanged.
    """
    lora_model.to(device)
    # Collate functions pass tuples (from ``zip(*batch)``); hand the tokenizer a
    # plain list, but leave a lone string alone so it is not split into characters.
    texts = sequences if isinstance(sequences, str) else list(sequences)
    inputs = tokenizer(texts, return_tensors="pt", padding=True)
    inputs = {key: value.to(device) for key, value in inputs.items()}
    with torch.no_grad():
        hidden = lora_model.esm(**inputs).last_hidden_state[:, 1:]
        attention_mask = inputs.get("attention_mask")
        if attention_mask is None:
            # No mask available: fall back to the plain mean over positions.
            mean_output = hidden.mean(dim=1)
        else:
            keep = attention_mask[:, 1:].unsqueeze(-1).to(hidden.dtype)
            # clamp guards against division by zero for an all-masked row.
            token_counts = keep.sum(dim=1).clamp(min=1)
            mean_output = (hidden * keep).sum(dim=1) / token_counts
    return mean_output.cpu()

def generate_mol_embeddings_batch(smiles_list, tokenizer, mol_model, device, target_dim=1280):
    """Embed a batch of SMILES strings with ``mol_model``'s pooled output.

    Args:
        smiles_list: Sequence of SMILES strings.
        tokenizer: Callable tokenizer; invoked with ``padding=True`` and
            ``return_tensors="pt"``.
        mol_model: Model whose output exposes ``pooler_output``.
        device: Device the model and the tokenized inputs are moved to.
        target_dim: Width of the zero fallback when the width cannot be read
            from the model (see below).

    Returns:
        A CPU tensor with one row per SMILES. If tokenization or the forward
        pass raises, the error is printed and an all-zero tensor is returned
        instead (best effort, the exception is not re-raised).

    The zero fallback now takes its width from ``mol_model.config.hidden_size``
    when that attribute exists (presumably the width of ``pooler_output`` —
    confirm for the encoder in use), and only falls back to ``target_dim``
    otherwise. The previous fixed default of 1280 did not match the 768-wide
    molecule embeddings the rest of this notebook is configured for, so a
    single failed batch produced a tensor the molecule projector cannot consume.
    """
    try:
        mol_model.to(device)
        inputs = tokenizer(smiles_list, padding=True, return_tensors="pt")
        inputs = {key: value.to(device) for key, value in inputs.items()}
        with torch.no_grad():
            outputs = mol_model(**inputs)
            mol_embedding = outputs.pooler_output
        return mol_embedding.cpu()  # Move to CPU only after computation
    except Exception as e:
        # NOTE(review): this catches every failure (including out-of-memory),
        # not just malformed SMILES; zero rows then flow into training silently.
        print(f"Error processing SMILES: {smiles_list}, Error: {e}")
        fallback_dim = getattr(getattr(mol_model, "config", None), "hidden_size", None)
        if not isinstance(fallback_dim, int):
            fallback_dim = target_dim
        return torch.zeros((len(smiles_list), fallback_dim))

def contrastive_collate_fn(batch, prot_tokenizer, prot_model, mol_tokenizer, mol_model, device):
    """Collate (sequence, SMILES) pairs into a dict of batched embeddings.

    Returns a dict with ``'anchorEmb'`` (protein embeddings from
    ``generate_anchor_embeddings_batch``) and ``'positiveEmb'`` (molecule
    embeddings from ``generate_mol_embeddings_batch``). Row ``i`` of each
    tensor comes from the ``i``-th pair of ``batch``.
    """
    # Transpose the list of pairs into one tuple of sequences and one of SMILES.
    sequences, smiles = zip(*batch)

    return {
        'anchorEmb': generate_anchor_embeddings_batch(
            sequences, prot_tokenizer, prot_model, device
        ),
        'positiveEmb': generate_mol_embeddings_batch(
            smiles, mol_tokenizer, mol_model, device
        ),
    }

In [None]:
class ClassificationDataset(Dataset):
    """Dataset of (protein sequence, class label) items backed by a DataFrame.

    Each distinct value of the ``canonicalsmiles`` column becomes one class,
    numbered in order of first appearance. ``index2label`` and ``label2index``
    hold the two directions of that mapping, and ``self.data`` carries the
    integer class of every row in a ``num_labels`` column.

    The label column is added to a copy of ``dataframe``. Previously it was
    written into the caller's frame in place, which leaked a ``num_labels``
    column into any other code sharing that frame.

    The protein encoder is moved to ``device``, switched to eval mode and frozen.
    """

    def __init__(self, dataframe, prot_tokenizer, prot_model, device):
        self.prot_tokenizer = prot_tokenizer
        self.prot_model = prot_model.to(device)
        self.device = device

        self.prot_model.eval()
        for param in self.prot_model.parameters():
            param.requires_grad = False

        # Compute the class vocabulary once so both mappings are guaranteed to agree.
        unique_smiles = dataframe.canonicalsmiles.unique()
        self.index2label = dict(enumerate(unique_smiles))
        self.label2index = {canonicalsmiles: index for index, canonicalsmiles in self.index2label.items()}

        # assign() returns a new frame; the list keeps the assignment positional.
        self.data = dataframe.assign(
            num_labels=dataframe.canonicalsmiles.map(self.label2index).tolist()
        )

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        # The label is returned as a one-element Series (not a bare int);
        # classification_collate_fn is written against that shape.
        return row['sequence'], row[['num_labels']]
    
def classification_collate_fn(batch, prot_tokenizer, prot_model, device):
    """Collate (sequence, label) items into protein embeddings and a label tensor.

    Returns a dict with ``'prot_emb'`` (output of
    ``generate_anchor_embeddings_batch`` for the batch's sequences) and
    ``'labels'``, a 1-D ``torch.long`` tensor with one class index per item.

    Each label is unwrapped to a plain ``int`` before the tensor is built.
    Handing ``torch.tensor`` a tuple of one-element pandas Series made it probe
    the Series positionally (``series[0]``), which pandas deprecates for
    label-indexed Series. The explicit unwrap also accepts bare ints and
    always yields shape ``(batch,)``, including for a batch of one.
    """
    proteins, labels = zip(*batch)

    prot_embs = generate_anchor_embeddings_batch(proteins, prot_tokenizer, prot_model, device)

    label_ids = [int(np.asarray(label).reshape(-1)[0]) for label in labels]

    return {
        'prot_emb': prot_embs,
        'labels': torch.tensor(label_ids, dtype=torch.long)
    }

In [None]:
# Test Evaluation Function #

def evaluate_model(model, mol_embs, test_loader, loss_fn, device):
    """Compute the mean batch loss and the accuracy of ``model`` over ``test_loader``.

    Each batch must provide ``'prot_emb'`` and ``'labels'`` tensors. Logits are
    obtained from ``model(mol_embs, prot_embs)`` and the predicted class is the
    arg-max of their softmax. The model is switched to eval mode and is left
    in that mode on return.

    Returns:
        ``(avg_test_loss, accuracy)``: the per-batch loss averaged over
        ``len(test_loader)`` and ``accuracy_score`` over all items.
    """
    model.eval()
    batch_losses = []
    predicted, expected = [], []

    with torch.no_grad():  # evaluation only, no graph needed
        for batch in test_loader:
            prot_embs = batch['prot_emb'].to(device)
            labels = batch['labels'].to(device)

            logits = model(mol_embs, prot_embs)
            batch_losses.append(loss_fn(logits, labels).item())

            batch_preds = F.softmax(logits, dim=-1).argmax(dim=-1)
            predicted.extend(batch_preds.cpu().numpy())
            expected.extend(labels.cpu().numpy())

    avg_test_loss = sum(batch_losses) / len(test_loader)

    return avg_test_loss, accuracy_score(expected, predicted)

In [None]:
# Logging Setup
def setup_logging():
    log_dir = 'training_log'
    os.makedirs(log_dir, exist_ok=True)
    log_filename = os.path.join(log_dir, f"training_log_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log")
    logging.basicConfig(
        filename=log_filename,
        level=logging.INFO,
        format="%(asctime)s - %(levelname)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S"
    )

    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    formatter = logging.Formatter(
        "%(asctime)s - %(levelname)s - %(message)s", 
        "%Y-%m-%d %H:%M:%S"
    )
    console.setFormatter(formatter)
    logging.getLogger().addHandler(console)


In [None]:
# Main Training Script
def main():
    """Train the ``Coembedding`` model with alternating contrastive and classification phases.

    Each epoch runs one pass over ``contrastive_loader`` (optimising
    ``ContrastiveLoss`` on the normalised outputs of the two projector heads),
    then one pass over ``train_classification_loader`` (optimising
    cross-entropy on ``model(mol_embs, prot_embs)``). The model is then
    evaluated on the train and test splits, and a checkpoint is written to
    ``model_weight/best_model.pth`` whenever test accuracy improves.

    Returns ``None``. Returns early, after logging an error, if the data file
    does not exist.
    """
    # Set up logging
    setup_logging()
    logging.info("Starting training process.")

    # Configuration parameters
    # 'T_0' is passed to both CosineAnnealingWarmRestarts schedulers below;
    # 'lr_triplet' is the learning rate of the contrastive optimizer and
    # 'lr_ce' that of the classification optimizer.
    config = {
        'epochs': 100,
        'T_0': 10,
        'lr_triplet': 1e-4,
        'lr_ce': 1e-4,
        'latent_distance': 'Cosine'
    }

    # Load classification data
    classification_data_path = "T2_data_normalized.xlsx"
    if not os.path.exists(classification_data_path):
        logging.error(f"Data file not found: {classification_data_path}")
        return
    classification_df = pd.read_excel(classification_data_path)
    # NOTE(review): after this de-duplication every SMILES (and therefore every
    # class built by ClassificationDataset) has exactly one row, so the classes
    # in the test split never occur in the training split — confirm intended.
    classification_df = classification_df.drop_duplicates(subset = ['canonicalsmiles'])
    
    logging.info(f"Successfully loaded classification data, sample size: {len(classification_df)}")

    # Load contrastive learning data
    # Same file and same de-duplication as the classification data above; it is
    # read a second time into an independent frame.
    contrastive_data_path = "T2_data_normalized.xlsx"
    if not os.path.exists(contrastive_data_path):
        logging.error(f"Data file not found: {contrastive_data_path}")
        return
    contrastive_df = pd.read_excel(contrastive_data_path)
    contrastive_df = contrastive_df.drop_duplicates(subset = ['canonicalsmiles'])
    logging.info(f"Successfully loaded contrastive learning data, sample size: {len(contrastive_df)}")

    # Load protein model and tokenizer (base model wrapped with the adapter stored in ./plm)
    model_name = 'esm2/esm2_t33_650M_UR50D'
    prot_tokenizer = EsmTokenizer.from_pretrained(model_name)
    base_model = EsmForMaskedLM.from_pretrained(model_name)
    prot_model = PeftModel.from_pretrained(base_model, './plm')
    logging.info("Successfully loaded protein model and tokenizer.")

    # Load molecule model and tokenizer
    mol_model_path = "./ibm/MoLFormer-XL-both-10pct"
    mol_tokenizer = AutoTokenizer.from_pretrained(mol_model_path, trust_remote_code=True)
    mol_model = AutoModel.from_pretrained(mol_model_path, deterministic_eval=True, trust_remote_code=True)
    logging.info("Successfully loaded molecule model and tokenizer.")

    # Device configuration
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f"Using device: {device}")

    # Initialize model
    # These widths must match the embeddings produced by the two encoders above.
    protein_embedding_dim = 1280
    molecule_embedding_dim = 768
    projection_dim = 1024

    model = Coembedding(
        molecule_shape=molecule_embedding_dim,
        protein_shape=protein_embedding_dim,
        latent_dimension=projection_dim,
        latent_activation=nn.ReLU,
        latent_distance=config.get('latent_distance', "Cosine"),
        classify=True
    ).to(device)
    logging.info("Successfully initialized co-embedding model.")

    # Create ClassificationDataset and ContrastiveDataset instances
    # Both constructors move the encoders to `device`, put them in eval mode
    # and freeze their parameters.
    classification_dataset = ClassificationDataset(
        dataframe=classification_df,
        prot_tokenizer=prot_tokenizer,
        prot_model=prot_model,
        device=device
    )
    
    contrastive_dataset = ContrastiveDataset(
        dataframe=contrastive_df,
        prot_tokenizer=prot_tokenizer,
        prot_model=prot_model,
        mol_tokenizer=mol_tokenizer,
        mol_model=mol_model,
        device=device
    )

    logging.info(f"ClassificationDataset and ContrastiveDataset creation completed")

    # Split classification dataset into training and test sets
    # The split is over row positions with a fixed seed, so it is reproducible
    # for the same de-duplicated frame.
    train_indices, test_indices = train_test_split(
        list(range(len(classification_dataset))),
        test_size=0.2,
        random_state=42
    )

    train_classification_dataset = Subset(classification_dataset, train_indices)
    test_classification_dataset = Subset(classification_dataset, test_indices)

    # Create DataLoaders
    batch_size = 32

    # DataLoader for classification task
    # The collate function runs the frozen protein encoder, so embeddings are
    # recomputed for every batch of every epoch.
    train_classification_loader = DataLoader(
        train_classification_dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=lambda batch: classification_collate_fn(batch, prot_tokenizer, prot_model, device),
        pin_memory=True
    )

    test_classification_loader = DataLoader(
        test_classification_dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=lambda batch: classification_collate_fn(batch, prot_tokenizer, prot_model, device),
        pin_memory=True
    )
    
    # One molecule embedding per class, in index2label order, so column j of the
    # classification logits corresponds to class index j. Computed once, in a
    # single batch.
    mol_embs = generate_mol_embeddings_batch(list(classification_dataset.index2label.values()), mol_tokenizer, mol_model, device).to(device)

    # DataLoader for contrastive learning task
    contrastive_loader = DataLoader(
        contrastive_dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=lambda batch: contrastive_collate_fn(batch, prot_tokenizer, prot_model, mol_tokenizer, mol_model, device),
        pin_memory=True
    )

    logging.info(f"DataLoader creation completed, batch size: {batch_size}")

    # Define optimizers and schedulers for both loss functions
    # NOTE(review): both optimizers are built over the same model.parameters(),
    # each keeping its own AdamW state — confirm this is intended rather than a
    # single shared optimizer.
    contrastive_opt = torch.optim.AdamW(model.parameters(), lr=config['lr_triplet'])
    contrastive_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(contrastive_opt, T_0=config['T_0'])

    classfication_opt = torch.optim.AdamW(model.parameters(), lr=config['lr_ce'])
    classfication_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(classfication_opt, T_0=config['T_0'])

    # Define loss functions
    loss_fn = ContrastiveLoss()
    classfication_fn = nn.CrossEntropyLoss()
    
    logging.info("Optimizers, schedulers and loss functions initialization completed.")

    os.makedirs('model_weight', exist_ok=True)

    # Best test accuracy seen so far and the (1-based) epoch it occurred in.
    best_test_acc = 0.0 
    best_epoch = 0 

    for epo in range(config['epochs']):
        model.train()
        total_loss = 0

        # Contrastive learning training
        for batch_idx, batch in enumerate(tqdm(contrastive_loader, total=len(contrastive_loader), desc=f"Contrastive Epoch {epo+1}/{config['epochs']}")):
            anchor = batch['anchorEmb'].to(device)
            positive = batch['positiveEmb'].to(device)
        
            # The projector heads are called directly (not model.forward) and
            # their outputs are L2-normalised before the loss.
            anchor_projection = F.normalize(model.protein_projector(anchor), p=2, dim=1)
            positive_projection = F.normalize(model.molecule_projector(positive), p=2, dim=1)
        
            loss = loss_fn(anchor_projection, positive_projection)

            contrastive_opt.zero_grad()
            loss.backward()
            contrastive_opt.step()

            total_loss += loss.item()

        avg_contrastive_loss = total_loss / len(contrastive_loader)
        # Schedulers are stepped once per epoch, not per batch.
        contrastive_scheduler.step()

        logging.info(f"Contrastive Epoch {epo+1}/{config['epochs']}, Loss: {avg_contrastive_loss:.4f}")

        # Classification task training
        model.train()
        total_loss = 0

        for batch_idx, batch in enumerate(tqdm(train_classification_loader, total=len(train_classification_loader), desc=f"Classification Epoch {epo+1}/{config['epochs']}")):
            prot_embs = batch['prot_emb'].to(device)
            labels = batch['labels'].to(device)
            
            # model.forward dispatches to classify() because classify=True above.
            logits = model(mol_embs, prot_embs)    
            loss = classfication_fn(logits, labels)

            classfication_opt.zero_grad()
            loss.backward()
            classfication_opt.step()

            total_loss += loss.item()

        avg_classification_loss = total_loss / len(train_classification_loader)
        classfication_scheduler.step()

        logging.info(f"Classification Epoch {epo+1}/{config['epochs']}, Loss: {avg_classification_loss:.4f}")
        
        # Evaluate the model on the training set
        # evaluate_model leaves the model in eval mode; model.train() at the top
        # of the next epoch switches it back.
        avg_train_loss, train_acc = evaluate_model(model, mol_embs, train_classification_loader, classfication_fn, device)
        logging.info(f"Train Loss: {avg_train_loss:.4f}, Train Accuracy: {train_acc:.4f}")

        # Evaluate the model on the test set
        avg_test_loss, test_acc = evaluate_model(model, mol_embs, test_classification_loader, classfication_fn, device)
        logging.info(f"Test Loss: {avg_test_loss:.4f}, Test Accuracy: {test_acc:.4f}")

        # Save the model if test accuracy improves
        # The checkpoint path is fixed, so each improvement overwrites the last one.
        if test_acc > best_test_acc:
            best_test_acc = test_acc
            best_epoch = epo + 1
            
            torch.save({
                'epoch': epo + 1,
                'model_state_dict': model.state_dict(),
                'contrastive_opt_state_dict': contrastive_opt.state_dict(),
                'contrastive_scheduler_state_dict': contrastive_scheduler.state_dict(),
                'classfication_opt_state_dict': classfication_opt.state_dict(),
                'classfication_scheduler_state_dict': classfication_scheduler.state_dict(),
                'loss': avg_test_loss,
                'accuracy': best_test_acc,
            }, 'model_weight/best_model.pth')

            logging.info(f"New best model saved at epoch {best_epoch} with test accuracy: {best_test_acc:.4f}")

    logging.info(f"Training completed. Best test accuracy: {best_test_acc:.4f} at epoch {best_epoch}")

# Run the training when this code is executed as the main module (which
# includes running this cell inside a notebook kernel).
if __name__ == "__main__":
    main()

In [None]:
# Stand-alone inference: reload the encoders and a trained Coembedding
# checkpoint, then predict a molecule class for every held-out protein.
# Everything is rebuilt here because main() keeps its objects local.
model_name = 'esm2/esm2_t33_650M_UR50D'
prot_tokenizer = EsmTokenizer.from_pretrained(model_name)
base_model = EsmForMaskedLM.from_pretrained(model_name)
prot_model = PeftModel.from_pretrained(base_model, './plm')

mol_model_path = "./ibm/MoLFormer-XL-both-10pct"
mol_tokenizer = AutoTokenizer.from_pretrained(mol_model_path, trust_remote_code=True)
mol_model = AutoModel.from_pretrained(mol_model_path, deterministic_eval=True, trust_remote_code=True)

# Match the training path, where the Dataset classes put both encoders in
# eval mode before any embedding is computed.
prot_model.eval()
mol_model.eval()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

data = pd.read_excel("T2_data_normalized.xlsx")
data = data.drop_duplicates(subset = ['canonicalsmiles'])

# `test_data` used to be read here without being defined anywhere in the
# notebook, so the cell failed with a NameError on a fresh kernel.
# NOTE(review): it is rebuilt with the same positional split main() uses
# (test_size=0.2, random_state=42 over the de-duplicated rows) — confirm this
# is the held-out set the checkpoint below should be scored on. To score every
# row instead, set `test_data = data`.
test_indices = train_test_split(
    list(range(len(data))),
    test_size=0.2,
    random_state=42
)[1]
test_data = data.iloc[test_indices]

prot_seq = test_data.sequence.tolist()
# One candidate molecule per de-duplicated row; predicted class j refers to mol_smiles[j].
mol_smiles = data.canonicalsmiles.tolist()

# NOTE(review): ground-truth indices are derived from product names. They line
# up with the SMILES-ordered prediction indices only if every row of `data`
# has a distinct T2PKproductsname; with repeated names the dict keeps the last
# position. Verify against the data.
unique_labels = data.T2PKproductsname.tolist()
label_to_index = {productsname: idx for idx, productsname in enumerate(unique_labels)}
index_to_label = {idx: productsname for productsname, idx in label_to_index.items()}
true_labels = test_data['T2PKproductsname'].map(label_to_index).tolist()

prot_emb = generate_anchor_embeddings_batch(
        prot_seq, prot_tokenizer, prot_model, device
    )

mol_emb = generate_mol_embeddings_batch(
        mol_smiles, mol_tokenizer, mol_model, device
    )

# Default constructor arguments must match the configuration the checkpoint was trained with.
model = Coembedding().to(device)
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
checkpoint = torch.load('model_weight/best_model_1024_ce_triplet_1022_final.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

with torch.no_grad():
    prot_emb = prot_emb.to(device)
    mol_emb = mol_emb.to(device)
    prediction = torch.argmax(model(mol_emb, prot_emb), dim=-1)

In [None]:
# Score the predictions from the inference cell against the ground-truth
# class indices. Precision, recall and F1 are averaged over classes with
# average='weighted'.
pred_labels = prediction.cpu().numpy()

cm = confusion_matrix(true_labels, pred_labels)
print("Confusion Matrix:", cm, sep="\n")

precision, recall, f1 = (
    score_fn(true_labels, pred_labels, average='weighted')
    for score_fn in (precision_score, recall_score, f1_score)
)
accuracy = accuracy_score(true_labels, pred_labels)

print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
print(f"F1-Score: {f1:.4f}")
print(f"Accuracy: {accuracy:.4f}")