In [None]:
## script

## installs

In [None]:
# Use %pip so the install targets the running kernel's environment, and pin the
# version that this notebook was last executed with for reproducibility.
%pip install rdkit==2024.9.6

Collecting rdkit
  Downloading rdkit-2024.9.6-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (4.0 kB)
Downloading rdkit-2024.9.6-cp311-cp311-manylinux_2_28_x86_64.whl (34.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m34.3/34.3 MB[0m [31m66.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: rdkit
Successfully installed rdkit-2024.9.6


# qformer model

## blip2-qformer

In [None]:
%%writefile blip2_polybert_prot5_qformer.py

import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import warnings; warnings.filterwarnings("ignore")

import contextlib
import logging
import os
import re
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.cuda.amp import autocast as autocast
from torch.nn import functional as F
from transformers import AutoTokenizer, AutoModel, T5EncoderModel
from transformers import AutoTokenizer, AutoModel, T5Tokenizer, T5EncoderModel
from transformers import BertTokenizer, BertConfig, BertLMHeadModel
from typing import Dict, List, Optional, Tuple, Union

class LayerNorm(nn.LayerNorm):
    """LayerNorm that always normalises in float32 and returns the input dtype.

    The ``mask`` argument is accepted for call-site compatibility and ignored.
    """

    def forward(self, x: torch.Tensor, mask=None):
        input_dtype = x.dtype
        # Up-cast before normalising, then cast the result back to the caller's dtype.
        normalized = nn.LayerNorm.forward(self, x.to(torch.float32))
        return normalized.to(input_dtype)

def disabled_train(self, mode=True):
    """No-op stand-in for ``nn.Module.train``.

    Assigned onto frozen sub-modules so that a later ``.train()`` call on the
    parent does not switch them out of eval mode. ``mode`` is ignored and the
    first argument is returned unchanged.
    """
    return self

def is_dist_avail_and_initialized():
    """Return True only if torch.distributed is both available and initialised."""
    return bool(dist.is_available() and dist.is_initialized())

class BlipOutput:
    """Plain container for the results of a model forward pass.

    All fields are optional and default to ``None``.
    """

    def __init__(self, loss=None, logits=None, similarity=None, intermediate_output=None):
        self.loss, self.logits = loss, logits
        self.similarity, self.intermediate_output = similarity, intermediate_output

class PolyBERTEncoder(nn.Module):
    """Encode SMILES strings with polyBERT and project every token embedding.

    Args:
        output_dim: Size of the per-token embedding returned by ``forward``.
    """

    def __init__(self, output_dim):
        super().__init__()
        checkpoint = 'kuelumbus/polyBERT'
        self.polybert = AutoModel.from_pretrained(checkpoint)
        self.tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        self.output_dim = output_dim
        # Token-wise linear map from the polyBERT hidden size to ``output_dim``.
        self.projection = nn.Linear(self.polybert.config.hidden_size, output_dim)

    def forward(self, smiles_strings):
        """Return projected token embeddings, ``[batch_size, seq_len, output_dim]``.

        The polyBERT forward pass always runs under ``torch.no_grad()``, so only
        the projection layer receives gradients. The tokenizer's padding mask is
        not returned.
        """
        target_device = next(self.polybert.parameters()).device
        batch_tokens = self.tokenizer(
            smiles_strings,
            padding=True,
            truncation=True,
            return_tensors='pt',
        )
        batch_tokens = batch_tokens.to(target_device)

        with torch.no_grad():
            token_states = self.polybert(**batch_tokens).last_hidden_state

        return self.projection(token_states)

class BlipBaseQFormer(nn.Module):
    """Shared utilities for Q-Former based models: tokenizer / Q-Former
    construction, optional autocast and checkpoint loading."""

    def __init__(self):
        super().__init__()
        # Preferred device, chosen once at construction time. Creating the module
        # does NOT move its parameters there; callers must still call ``.to(...)``.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    @classmethod
    def init_tokenizer(cls, bert_name="bert-base-uncased"):
        """Load the BERT tokenizer and register ``[DEC]`` as its BOS token."""
        tokenizer = BertTokenizer.from_pretrained(bert_name)
        tokenizer.add_special_tokens({"bos_token": "[DEC]"})
        return tokenizer

    def maybe_autocast(self, dtype=torch.float16):
        """Return a CUDA autocast context, or a no-op context when on CPU."""
        enable_autocast = self.device != torch.device("cpu")
        if enable_autocast:
            return torch.cuda.amp.autocast(dtype=dtype)
        return contextlib.nullcontext()

    @classmethod
    def init_Qformer(cls, bert_name, num_query_token, encoder_width, cross_attention_freq=2):
        """Build the Q-Former (a BERT LM-head model with cross-attention) and its
        learnable query tokens.

        Returns:
            ``(Qformer, query_tokens)`` where ``query_tokens`` is an
            ``nn.Parameter`` of shape ``[1, num_query_token, hidden_size]``.
        """
        encoder_config = BertConfig.from_pretrained(bert_name, is_decoder=True)
        # NOTE(review): encoder_width / cross_attention_freq / query_length are
        # stored on the config; whether the stock Hugging Face BERT
        # implementation reads them should be confirmed.
        encoder_config.encoder_width = encoder_width
        encoder_config.add_cross_attention = True
        encoder_config.cross_attention_freq = cross_attention_freq
        encoder_config.query_length = num_query_token

        Qformer = BertLMHeadModel.from_pretrained(
            bert_name, config=encoder_config
        )
        query_tokens = nn.Parameter(
            torch.zeros(1, num_query_token, encoder_config.hidden_size)
        )
        query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
        return Qformer, query_tokens

    def load_from_pretrained(self, url_or_filename):
        """Load weights from a local checkpoint file (non-strict).

        Accepts both the ``{"model": state_dict, ...}`` layout written by the
        per-epoch / best-model checkpoints and a bare ``state_dict`` such as the
        ``final_model.pth`` written by ``main()``.

        Returns:
            The result of ``load_state_dict`` (missing / unexpected keys).

        Raises:
            RuntimeError: if ``url_or_filename`` is not an existing file.
        """
        if not os.path.isfile(url_or_filename):
            raise RuntimeError("checkpoint url or path is invalid")

        checkpoint = torch.load(url_or_filename, map_location="cpu")

        if isinstance(checkpoint, dict) and "model" in checkpoint:
            state_dict = checkpoint["model"]
        else:
            # A bare state_dict was saved without the wrapping dictionary.
            state_dict = checkpoint

        msg = self.load_state_dict(state_dict, strict=False)
        logging.info("load checkpoint from %s" % url_or_filename)
        return msg

class Blip2QformerPolyBERTProtT5(BlipBaseQFormer):
    """Contrastive aligner between SMILES strings and protein sequences.

    SMILES are encoded by ``PolyBERTEncoder`` and summarised by a Q-Former whose
    query tokens cross-attend to the PolyBERT token embeddings. Proteins are
    encoded by a T5 encoder (default ``Rostlab/ProstT5``) and mean-pooled. Both
    sides are projected to 512-d, L2-normalised and trained with a symmetric
    in-batch cross-entropy (InfoNCE-style) loss.

    Args:
        bert_name: Checkpoint used for the Q-Former and its tokenizer.
        temperature: Softmax temperature of the contrastive loss.
        freeze_polybert: Freeze PolyBERT weights and keep it in eval mode.
        freeze_prot5: Freeze the protein encoder and keep it in eval mode.
        polybert_output_dim: Width of the projected PolyBERT token embeddings
            fed to the Q-Former's cross-attention.
            NOTE(review): this most likely has to equal the Q-Former hidden size
            (768 for bert-base) — confirm before changing it.
        prot5_model_name: Checkpoint of the protein encoder.
        tune_qformer: If False, freeze the Q-Former.
        num_query_token: Number of learnable query tokens.
        cross_attention_freq: Forwarded to ``init_Qformer``.
        embed_dim: Stored on the instance; not used by the projection heads,
            whose sizes are fixed at 768 -> 512.

    Both encoders are always run under ``torch.no_grad()`` (see
    ``PolyBERTEncoder.forward`` and ``encode_protein_prot5``), so un-freezing
    them does not by itself make them trainable.
    """

    def __init__(
        self,
        bert_name="bert-base-uncased",
        temperature=0.05,
        freeze_polybert=True,
        freeze_prot5=True,
        polybert_output_dim=768,
        prot5_model_name="Rostlab/ProstT5",
        tune_qformer=True,
        num_query_token=32,
        cross_attention_freq=2,
        embed_dim=1024,
    ):
        super().__init__()

        self.tokenizer = self.init_tokenizer(bert_name)

        # Initialize PolyBERT (SMILES encoder)
        self.polybert_encoder = PolyBERTEncoder(output_dim=polybert_output_dim)

        # Initialize the protein encoder
        self.prot5_tokenizer = T5Tokenizer.from_pretrained(prot5_model_name, do_lower_case=False)
        self.prot5 = T5EncoderModel.from_pretrained(prot5_model_name)

        # Freeze PolyBERT if specified
        self.freeze_polybert = freeze_polybert
        if freeze_polybert:
            for param in self.polybert_encoder.polybert.parameters():
                param.requires_grad = False
            self.polybert_encoder.polybert.eval()
            # Without this override a later ``model.train()`` would put the frozen
            # encoder back into train mode and re-enable its dropout.
            self.polybert_encoder.polybert.train = disabled_train
            logging.info("freeze PolyBERT encoder")

        # Freeze the protein encoder if specified
        self.freeze_prot5 = freeze_prot5
        if freeze_prot5:
            for param in self.prot5.parameters():
                param.requires_grad = False
            self.prot5.eval()
            self.prot5.train = disabled_train
            logging.info("freeze ProtT5 encoder")

        # Initialize QFormer
        self.Qformer, self.query_tokens = self.init_Qformer(
            bert_name,
            num_query_token,
            polybert_output_dim,
            cross_attention_freq
        )
        self.Qformer.resize_token_embeddings(len(self.tokenizer))

        # Copy weights into any "*_query" parameters from their non-query twins.
        # This is a no-op when the Q-Former has no such parameters.
        state_dict = self.Qformer.state_dict()
        for name, param in self.Qformer.named_parameters():
            if "_query" in name:
                key_orig = name.replace("_query", "")
                param.data.copy_(state_dict[key_orig])

        # Projection heads into the shared 512-d contrastive space
        self.polybert_proj = self._projection_head(self.Qformer.config.hidden_size)
        self.prot5_proj = self._projection_head(self.prot5.config.d_model)

        # Freeze QFormer if specified
        if not tune_qformer:
            for name, param in self.Qformer.named_parameters():
                param.requires_grad = False
            self.Qformer.eval()
            self.Qformer.train = disabled_train
            logging.info("freeze QFormer")

        self.temperature = temperature
        self.embed_dim = embed_dim
        self.polybert_output_dim = polybert_output_dim
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    @staticmethod
    def _projection_head(in_features):
        """Linear -> LayerNorm -> GELU -> Linear head mapping to 512 dimensions."""
        return nn.Sequential(
            nn.Linear(in_features, 768),
            nn.LayerNorm(768),
            nn.GELU(),
            nn.Linear(768, 512)
        )

    def encode_protein_prot5(self, protein_sequences):
        """Encode protein sequences with the (no-grad) protein encoder.

        Sequences are upper-cased, ``U/Z/O/B`` are replaced by ``X``, residues
        are space-separated and the ``<AA2fold>`` prefix token is prepended.

        Returns:
            ``(embeddings, attention_mask)``. If the encoder raises, the error is
            printed and an all-zero embedding of length 30 with an all-ones mask
            is returned instead (best-effort fallback kept from the original).
        """
        preprocessed_sequences = [
            "<AA2fold> " + " ".join(re.sub(r"[UZOB]", "X", seq.upper()))
            for seq in protein_sequences
        ]

        # Follow the encoder's actual device instead of the cached ``self.device``,
        # which is wrong whenever the model has not been (or was differently) moved.
        device = next(self.prot5.parameters()).device

        prot5_inputs = self.prot5_tokenizer(
            preprocessed_sequences,
            padding=True,
            truncation=True,
            return_tensors="pt"
        ).to(device)

        with torch.no_grad():
            try:
                prot5_outputs = self.prot5(
                    input_ids=prot5_inputs.input_ids,
                    attention_mask=prot5_inputs.attention_mask,
                    return_dict=True
                )
                prot5_embeds = prot5_outputs.last_hidden_state
                prot5_attention_mask = prot5_inputs.attention_mask
            except Exception as e:
                print(f"Error in ProtT5 encoding: {e}")
                batch_size = len(protein_sequences)
                dim = self.prot5.config.d_model
                seq_len = 30  # Reasonable default
                prot5_embeds = torch.zeros((batch_size, seq_len, dim), device=device)
                prot5_attention_mask = torch.ones((batch_size, seq_len), device=device)

        if torch.isnan(prot5_embeds).any():
            print("NaN values detected in raw ProtT5 embeddings - applying fix")
            prot5_embeds = torch.nan_to_num(prot5_embeds, nan=0.0)

        return prot5_embeds, prot5_attention_mask

    def _encode_smiles(self, smiles_texts):
        """Return PolyBERT token embeddings and the matching padding mask."""
        polybert_embeds = self.polybert_encoder(smiles_texts)

        # PolyBERTEncoder.forward does not expose its padding mask, so rebuild it
        # with the same tokenizer settings. Using an all-ones mask would let the
        # Q-Former attend to padding and make results depend on batch composition.
        polybert_attention_mask = self.polybert_encoder.tokenizer(
            smiles_texts,
            padding=True,
            truncation=True,
            return_tensors='pt'
        )["attention_mask"].to(polybert_embeds.device)

        if polybert_attention_mask.shape != polybert_embeds.shape[:2]:
            # Defensive fallback: keep the previous behaviour if shapes disagree.
            polybert_attention_mask = torch.ones(
                polybert_embeds.shape[:2], device=polybert_embeds.device
            )
        return polybert_embeds, polybert_attention_mask

    @staticmethod
    def _pool_protein(prot5_embeds, prot5_attention_mask):
        """Masked mean over sequence positions, skipping the prefix token at index 0.

        Numerator and denominator use the same positions; the denominator is
        clamped so an all-padding row cannot divide by zero.
        """
        residue_embeds = prot5_embeds[:, 1:, :]
        residue_mask = prot5_attention_mask[:, 1:].unsqueeze(-1).to(residue_embeds.dtype)
        summed = (residue_embeds * residue_mask).sum(dim=1)
        counts = residue_mask.sum(dim=1).clamp(min=1.0)
        return summed / counts

    def _encode_pair(self, smiles_texts, protein_sequences):
        """Return L2-normalised ``(polybert_feats, prot5_feats)``, one row per input.

        Shared by ``forward`` and ``predict_similarity`` so that training and
        inference use exactly the same feature pipeline.
        """
        polybert_embeds, polybert_attention_mask = self._encode_smiles(smiles_texts)
        prot5_embeds, prot5_attention_mask = self.encode_protein_prot5(protein_sequences)

        batch_size = polybert_embeds.size(0)
        query_tokens = self.query_tokens.expand(batch_size, -1, -1)

        # Query tokens cross-attend to the PolyBERT token embeddings.
        query_polybert_output = self.Qformer.bert(
            inputs_embeds=query_tokens,
            attention_mask=torch.ones(query_tokens.size()[:-1], device=query_tokens.device),
            encoder_hidden_states=polybert_embeds,
            encoder_attention_mask=polybert_attention_mask,
            return_dict=True,
        )
        polybert_feats = self.polybert_proj(query_polybert_output.last_hidden_state)
        prot5_feats = self.prot5_proj(self._pool_protein(prot5_embeds, prot5_attention_mask))

        # Average over query tokens, then L2-normalise once (F.normalize already
        # guards against zero norms with its own epsilon).
        polybert_feats = F.normalize(polybert_feats.mean(dim=1), p=2, dim=-1)
        prot5_feats = F.normalize(prot5_feats, p=2, dim=-1)
        return polybert_feats, prot5_feats

    def forward(self, smiles_texts, protein_sequences):
        """Compute the in-batch contrastive loss for paired SMILES / proteins.

        Returns:
            ``BlipOutput`` with ``loss`` and ``similarity`` (keys
            ``"polybert2prot5"`` and ``"prot52polybert"``).
        """
        polybert_feats, prot5_feats = self._encode_pair(smiles_texts, protein_sequences)

        # No cross-process gathering is done; the "all" features are the local batch.
        polybert_feats_all = polybert_feats
        prot5_feats_all = prot5_feats

        sim_prot2poly, sim_poly2prot, loss_contrastive = self.contrast_global(
            polybert_feats, prot5_feats, polybert_feats_all, prot5_feats_all, return_sim=True
        )

        return BlipOutput(
            loss=loss_contrastive,
            similarity={"polybert2prot5": sim_poly2prot, "prot52polybert": sim_prot2poly}
        )

    def contrast_global(self, features_polybert, features_prot5, features_polybert_all, features_prot5_all, return_sim=False):
        """Symmetric cross-entropy over the two similarity matrices.

        Returns:
            ``loss`` when ``return_sim`` is False, otherwise
            ``(sim_prot2poly, sim_poly2prot, loss)``. If the loss cannot be
            computed, ``loss`` is a NaN tensor (and both similarities are
            ``None`` when ``return_sim`` is True).
        """
        batch_size = features_polybert.size(0)
        device = features_polybert.device

        # Debug: Check for NaNs
        if torch.isnan(features_polybert).any():
            print("NaN detected in features_polybert")
        if torch.isnan(features_prot5).any():
            print("NaN detected in features_prot5")

        # Compute similarity scores for polybert->prot5 and prot5->polybert
        sim_poly2prot = torch.matmul(features_polybert, features_prot5_all.transpose(0, 1))
        sim_prot2poly = torch.matmul(features_prot5, features_polybert_all.transpose(0, 1))

        # Debug: Check similarity matrices
        if torch.isnan(sim_poly2prot).any():
            print("NaN detected in sim_poly2prot")
        if torch.isnan(sim_prot2poly).any():
            print("NaN detected in sim_prot2poly")

        # Scale by temperature (guarded against a zero temperature)
        temp = max(self.temperature, 1e-8)
        logits_prot2poly = sim_prot2poly / temp
        logits_poly2prot = sim_poly2prot / temp

        # The i-th SMILES is the positive for the i-th protein. Labels must live
        # on the same device as the logits.
        labels = torch.arange(batch_size, device=device)

        try:
            loss_prot2poly = F.cross_entropy(logits_prot2poly, labels)
            loss_poly2prot = F.cross_entropy(logits_poly2prot, labels)
            loss = (loss_prot2poly + loss_poly2prot) / 2
        except Exception as e:
            print(f"Error in cross entropy: {e}")
            print(f"logits_prot2poly shape: {logits_prot2poly.shape}, min: {logits_prot2poly.min().item()}, max: {logits_prot2poly.max().item()}")
            print(f"logits_poly2prot shape: {logits_poly2prot.shape}, min: {logits_poly2prot.min().item()}, max: {logits_poly2prot.max().item()}")
            print(f"labels shape: {labels.shape}, values: {labels}")
            nan_loss = torch.tensor(float('nan'), device=device)
            # Parenthesised so the return shape really depends on ``return_sim``.
            return (None, None, nan_loss) if return_sim else nan_loss

        if return_sim:
            return sim_prot2poly, sim_poly2prot, loss
        return loss

    @torch.no_grad()
    def predict_similarity(self, smiles_text, protein_sequence):
        """Cosine similarity between each SMILES / protein pair.

        Both arguments may be a single string or a list; the i-th SMILES is
        compared with the i-th protein. Runs without gradient tracking and does
        not change the module's train/eval mode.

        Returns:
            A list of floats, one per pair.
        """
        if isinstance(smiles_text, str):
            smiles_text = [smiles_text]

        if isinstance(protein_sequence, str):
            protein_sequence = [protein_sequence]

        polybert_feats, prot5_feats = self._encode_pair(smiles_text, protein_sequence)

        # Features are unit-norm, so the dot product is the cosine similarity.
        similarity = torch.sum(polybert_feats * prot5_feats, dim=-1)

        return similarity.cpu().tolist()

class MolProtDataset(torch.utils.data.Dataset):
    """Map-style dataset over objects exposing ``.smiles`` and ``.protein``."""

    def __init__(self, pairs):
        self.pairs = pairs

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        """Return the ``idx``-th pair as a dict with the keys the collate function expects."""
        entry = self.pairs[idx]
        return dict(smiles_text=entry.smiles, protein_sequence=entry.protein)

class MolProtPair:
    """A single (SMILES string, protein sequence) training example."""

    def __init__(self, smiles, protein):
        self.smiles, self.protein = smiles, protein

def custom_collate(batch):
    """Collate dataset items into parallel lists of SMILES and protein strings.

    Items whose ``smiles_text`` or ``protein_sequence`` is not a ``str`` are
    reported on stdout and dropped. If nothing survives, a one-element
    placeholder batch is returned so downstream code never sees an empty batch.
    """
    smiles_texts, protein_sequences = [], []

    for item in batch:
        both_strings = isinstance(item["smiles_text"], str) and isinstance(item["protein_sequence"], str)
        if not both_strings:
            print(f"Skipping invalid item: {type(item['smiles_text'])}, {type(item['protein_sequence'])}")
            continue
        smiles_texts.append(item["smiles_text"])
        protein_sequences.append(item["protein_sequence"])

    if len(smiles_texts) == 0:
        # Placeholder: simplest SMILES string and the 20 standard amino acids.
        smiles_texts = ["C"]
        protein_sequences = ["ACDEFGHIKLMNPQRSTVWY"]

    return {
        "smiles_text": smiles_texts,
        "protein_sequence": protein_sequences
    }

def _save_training_checkpoint(path, model, optimizer, epoch, **extra):
    """Write model / optimizer state and the epoch index (plus ``extra``) to ``path``."""
    payload = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "epoch": epoch
    }
    payload.update(extra)
    torch.save(payload, path)


def _mean_validation_loss(model, val_loader):
    """Return the mean finite loss over ``val_loader`` (``inf`` if no batch succeeded)."""
    model.eval()
    total_loss = 0.0
    counted = 0

    with torch.no_grad():
        for batch in val_loader:
            try:
                loss = model(batch["smiles_text"], batch["protein_sequence"]).loss
            except Exception as e:
                print(f"Validation error: {e}")
                continue

            if loss is not None and torch.isfinite(loss):
                total_loss += loss.item()
                counted += 1

    return total_loss / counted if counted > 0 else float('inf')


def train_blip2_polybert_prot5(
    model,
    train_dataset,
    val_dataset=None,
    batch_size=16,
    num_epochs=10,
    learning_rate=5e-5,
    weight_decay=0.05,
    warmup_steps=5000,
    checkpoint_dir="checkpoints",
    checkpoint_interval=1
):
    """Train ``model`` with AdamW and a cosine learning-rate schedule.

    Args:
        model: Module whose ``forward(smiles_texts, protein_sequences)`` returns
            an object with a ``loss`` attribute.
        train_dataset: Dataset yielding ``smiles_text`` / ``protein_sequence`` dicts.
        val_dataset: Optional validation dataset; when given (and non-empty) the
            best model by validation loss is saved as ``best_model.pth``.
        batch_size: Batch size of both loaders (the last partial training batch
            is dropped).
        num_epochs: Number of passes over ``train_dataset``.
        learning_rate, weight_decay: AdamW hyper-parameters.
        warmup_steps: Accepted for interface compatibility; no warm-up is
            currently applied.
        checkpoint_dir: Directory for checkpoints (created if missing).
        checkpoint_interval: Save ``checkpoint_epoch_{n}.pth`` every this many
            epochs; a value <= 0 disables periodic checkpoints.

    Returns:
        The trained ``model``.

    Batches whose forward or backward pass raises, or whose loss is not finite,
    are reported on stdout and skipped.
    """
    os.makedirs(checkpoint_dir, exist_ok=True)

    # The model only records its preferred device; the parameters have to be
    # moved explicitly, otherwise inputs and weights end up on different devices.
    device = getattr(model, "device", None)
    if device is not None:
        model.to(device)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
        collate_fn=custom_collate,
        num_workers=4
    )

    has_validation = bool(val_dataset)
    val_loader = None
    if has_validation:
        val_loader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=batch_size,
            shuffle=False,
            collate_fn=custom_collate,
            num_workers=4
        )

    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(
        trainable_params,
        lr=learning_rate,
        weight_decay=weight_decay
    )

    # max(1, ...) keeps the schedule well-defined when the training set is
    # smaller than one batch (drop_last=True then yields zero steps).
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=max(1, num_epochs * len(train_loader))
    )

    best_val_loss = float('inf')

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0
        num_batches = 0

        for batch in train_loader:
            # custom_collate has already removed non-string entries.
            smiles_texts = batch["smiles_text"]
            protein_sequences = batch["protein_sequence"]

            if len(smiles_texts) == 0:
                print("WARNING: No valid SMILES-protein pairs in batch, skipping")
                continue

            try:
                output = model(smiles_texts, protein_sequences)
                loss = getattr(output, "loss", None)
            except Exception as e:
                print(f"Forward pass error: {e}")
                continue

            # Skip NaN and infinite losses alike; either would corrupt the weights.
            if loss is None or not torch.isfinite(loss):
                print("WARNING: NaN loss detected, skipping backward pass")
                continue

            try:
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(trainable_params, max_norm=1.0)
                optimizer.step()
                lr_scheduler.step()
            except Exception as e:
                print(f"Unexpected error in training loop: {e}")
                continue

            train_loss += loss.item()
            num_batches += 1

        avg_train_loss = train_loss / num_batches if num_batches > 0 else float('inf')

        avg_val_loss = None
        if has_validation:
            avg_val_loss = _mean_validation_loss(model, val_loader)

            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                _save_training_checkpoint(
                    os.path.join(checkpoint_dir, "best_model.pth"),
                    model, optimizer, epoch, val_loss=best_val_loss
                )

        print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {avg_train_loss:.4f}")
        if has_validation:
            print(f"Validation Loss: {avg_val_loss:.4f}")

        if checkpoint_interval > 0 and (epoch + 1) % checkpoint_interval == 0:
            _save_training_checkpoint(
                os.path.join(checkpoint_dir, f"checkpoint_epoch_{epoch+1}.pth"),
                model, optimizer, epoch
            )

    return model

def get_top_similar_proteins(model, smiles, protein_list, top_k=5, batch_size=16):
    """Rank proteins by their predicted similarity to a SMILES string.

    Args:
        model: Object exposing ``predict_similarity(smiles_list, protein_list)``
            that returns one score per pair.
        smiles: A single SMILES string (compared against every protein), a
            one-element list (treated the same way), or a list with exactly one
            SMILES per protein.
        protein_list: Candidate protein sequences.
        top_k: Number of best matches to return.
        batch_size: Number of pairs scored per ``predict_similarity`` call.

    Returns:
        Up to ``top_k`` ``(protein, score)`` tuples, highest score first.

    Raises:
        ValueError: if ``batch_size`` < 1 or ``smiles`` is a list whose length
            matches neither 1 nor ``len(protein_list)``.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    protein_list = list(protein_list)

    if isinstance(smiles, str):
        smiles = [smiles] * len(protein_list)
    else:
        smiles = list(smiles)
        if len(smiles) == 1:
            smiles = smiles * len(protein_list)
        elif len(smiles) != len(protein_list):
            # Mismatched lengths would silently pair the wrong items per batch.
            raise ValueError(
                f"smiles has {len(smiles)} entries but protein_list has {len(protein_list)}"
            )

    if top_k <= 0 or not protein_list:
        return []

    similarities = []
    for start in range(0, len(protein_list), batch_size):
        stop = start + batch_size
        similarities.extend(model.predict_similarity(smiles[start:stop], protein_list[start:stop]))

    # Indices of the top_k highest similarities (stable for ties).
    top_indices = sorted(range(len(similarities)), key=similarities.__getitem__, reverse=True)[:top_k]

    return [(protein_list[i], similarities[i]) for i in top_indices]

def main():
    """Command-line entry point: parse options, build datasets and model, train, save.

    The final weights are written as a bare ``state_dict`` to
    ``<checkpoint_dir>/final_model.pth``.
    """
    import argparse

    # (flag, add_argument keyword arguments) in the order they appear in --help.
    cli_options = (
        ('--data_path', dict(type=str, required=True, help='Path to data CSV file')),
        ('--batch_size', dict(type=int, default=16, help='Batch size')),
        ('--num_epochs', dict(type=int, default=10, help='Number of epochs')),
        ('--learning_rate', dict(type=float, default=5e-5, help='Learning rate')),
        ('--weight_decay', dict(type=float, default=0.05, help='Weight decay')),
        ('--checkpoint_dir', dict(type=str, default='checkpoints', help='Checkpoint directory')),
        ('--embed_dim', dict(type=int, default=1024, help='Embedding dimension')),
        ('--polybert_output_dim', dict(type=int, default=768, help='PolyBERT output dimension')),
        ('--num_query_token', dict(type=int, default=32, help='Number of query tokens')),
        ('--temperature', dict(type=float, default=0.05, help='Temperature for contrastive loss')),
        ('--freeze_polybert', dict(action='store_true', help='Freeze PolyBERT parameters')),
        ('--freeze_prot5', dict(action='store_true', help='Freeze ProtT5 parameters')),
        ('--tune_qformer', dict(action='store_true', help='Tune QFormer parameters')),
    )

    parser = argparse.ArgumentParser(description='Train BLIP2-PolyBERT-ProtT5-QFormer model')
    for flag, options in cli_options:
        parser.add_argument(flag, **options)
    args = parser.parse_args()

    # Imported lazily so the module can be imported without data_processor present.
    from data_processor import prepare_data_from_snp_data

    train_pairs, val_pairs = prepare_data_from_snp_data(args.data_path)
    training_set = MolProtDataset(train_pairs)
    validation_set = MolProtDataset(val_pairs)

    untrained_model = Blip2QformerPolyBERTProtT5(
        temperature=args.temperature,
        freeze_polybert=args.freeze_polybert,
        freeze_prot5=args.freeze_prot5,
        polybert_output_dim=args.polybert_output_dim,
        tune_qformer=args.tune_qformer,
        num_query_token=args.num_query_token,
        embed_dim=args.embed_dim
    )

    trained_model = train_blip2_polybert_prot5(
        model=untrained_model,
        train_dataset=training_set,
        val_dataset=validation_set,
        batch_size=args.batch_size,
        num_epochs=args.num_epochs,
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        checkpoint_dir=args.checkpoint_dir
    )

    final_path = os.path.join(args.checkpoint_dir, "final_model.pth")
    torch.save(trained_model.state_dict(), final_path)

# Run the command-line training entry point only when this file is executed as
# a script, not when it is imported as a module.
if __name__ == "__main__":
    main()

Writing blip2_polybert_prot5_qformer.py


## training and eval script

In [None]:
%%writefile training_script.py

import warnings; warnings.filterwarnings("ignore")

import os
import torch
import argparse
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, random_split
from sklearn.model_selection import train_test_split
from rdkit import Chem
from rdkit.Chem import AllChem, DataStructs
from sklearn.cluster import AgglomerativeClustering


from blip2_polybert_prot5_qformer import (
    Blip2QformerPolyBERTProtT5,
    MolProtDataset,
    MolProtPair,
    custom_collate,
    train_blip2_polybert_prot5,
    get_top_similar_proteins
)

def parse_args():
    """Parse the training / evaluation command-line options from ``sys.argv``.

    Returns:
        The populated ``argparse.Namespace``.
    """
    # (flag, add_argument keyword arguments), grouped by purpose and kept in
    # the order in which they are shown by --help.
    option_specs = [
        # Data
        ("--data_path", dict(type=str, required=True,
                             help="Path to data CSV file with SMILES and protein sequences")),
        ("--output_dir", dict(type=str, default="results",
                              help="Directory for results and metrics")),
        ("--processed_data_dir", dict(type=str, default="processed_data",
                                      help="Directory for processed data")),
        ("--test_size", dict(type=float, default=0.1,
                             help="Test set size ratio")),
        ("--val_size", dict(type=float, default=0.1,
                            help="Validation set size ratio")),
        # Model
        ("--bert_name", dict(type=str, default="bert-base-uncased",
                             help="BERT model name for QFormer")),
        ("--polybert_output_dim", dict(type=int, default=768,
                                       help="PolyBERT output dimension")),
        ("--prot5_model", dict(type=str, default="Rostlab/ProstT5",
                               help="ProtT5 model name")),
        ("--embedding_dim", dict(type=int, default=512,
                                 help="Embedding dimension")),
        ("--num_query_token", dict(type=int, default=256,
                                   help="Number of query tokens for QFormer")),
        ("--cross_attention_freq", dict(type=int, default=1,
                                        help="Cross attention frequency")),
        ("--temperature", dict(type=float, default=0.05,
                               help="Temperature for contrastive loss")),
        # Freezing / tuning switches
        ("--freeze_polybert", dict(action="store_true",
                                   help="Freeze PolyBERT parameters")),
        ("--freeze_prot5", dict(action="store_true",
                                help="Freeze ProtT5 parameters")),
        ("--tune_qformer", dict(action="store_true",
                                help="Tune QFormer parameters")),
        # Optimisation
        ("--batch_size", dict(type=int, default=16,
                              help="Batch size")),
        ("--num_epochs", dict(type=int, default=10,
                              help="Number of training epochs")),
        ("--learning_rate", dict(type=float, default=5e-5,
                                 help="Learning rate")),
        ("--weight_decay", dict(type=float, default=0.05,
                                help="Weight decay")),
        ("--warmup_steps", dict(type=int, default=1000,
                                help="Warmup steps")),
        # Checkpointing
        ("--checkpoint_dir", dict(type=str, default="checkpoints",
                                  help="Directory for checkpoints")),
        ("--checkpoint_interval", dict(type=int, default=1,
                                       help="Checkpoint interval in epochs")),
        ("--resume_checkpoint", dict(type=str, default=None,
                                     help="Checkpoint to resume training from")),
        # Evaluation
        ("--eval_only", dict(action="store_true",
                             help="Run evaluation only")),
        ("--eval_checkpoint", dict(type=str, default=None,
                                   help="Checkpoint to use for evaluation")),
        # Reproducibility
        ("--seed", dict(type=int, default=42,
                        help="Random seed")),
    ]

    parser = argparse.ArgumentParser(description="Train BLIP2-PolyBERT-ProtT5-QFormer model")
    for flag, options in option_specs:
        parser.add_argument(flag, **options)

    return parser.parse_args()

def set_seed(seed):
    """Seed every random number generator this script can draw from.

    Covers Python's ``random`` module, NumPy's legacy global RNG and PyTorch
    (CPU and all CUDA devices), and switches cuDNN to deterministic mode.

    Args:
        seed: Integer seed shared by all generators.
    """
    import random  # local import: `random` is not among the file-level imports

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def compute_fingerprint(smiles):
    """Return the 2048-bit, radius-2 Morgan fingerprint of ``smiles``.

    Returns ``None`` when RDKit cannot parse the SMILES string.
    """
    molecule = Chem.MolFromSmiles(smiles)
    if not molecule:
        return None
    return AllChem.GetMorganFingerprintAsBitVect(molecule, 2, nBits=2048)

def cluster_molecules(df, threshold=0.3):
    fps = df['smiles'].apply(compute_fingerprint)
    valid_idx = [i for i, fp in enumerate(fps) if fp is not None]
    fps = [fps[i] for i in valid_idx]

    n = len(fps)
    sim_matrix = np.zeros((n, n))
    for i in range(n):
        for j in range(i + 1, n):
            sim = DataStructs.FingerprintSimilarity(fps[i], fps[j])
            sim_matrix[i, j] = sim
            sim_matrix[j, i] = sim

    clustering = AgglomerativeClustering(n_clusters=None, affinity='precomputed', linkage='complete', distance_threshold=threshold)
    df.loc[df.index[valid_idx], 'cluster'] = clustering.fit_predict(1 - sim_matrix)
    return df.dropna(subset=['cluster'])

# Remove invalid SMILES and protein sequences
def is_valid_smiles(smiles):
    if not isinstance(smiles, str):
        return False
    # Basic check - could be improved with RDKit validation
    return len(smiles) > 0 and all(c in 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789()[]+=-#:.\\/@' for c in smiles)

def is_valid_protein(seq):
    if not isinstance(seq, str):
        return False
    # Check if the sequence consists of valid amino acid letters
    valid_aa = set('ACDEFGHIKLMNPQRSTVWYXBZOU')
    return len(seq) > 0 and all(aa in valid_aa for aa in seq.upper())

def cluster_molecules(df, threshold=0.3):
    """Cluster molecules by fingerprint distance with complete-linkage agglomeration.

    NOTE: this name is defined more than once in this file; at import time the
    last definition is the one callers actually get.

    Args:
        df: DataFrame with a 'smiles' column. It is not modified; a re-indexed
            copy is returned.
        threshold: Distance threshold (1 - fingerprint similarity) passed to
            AgglomerativeClustering as `distance_threshold`.

    Returns:
        A copy of `df` with a 0..n-1 index, restricted to rows whose SMILES
        produced a fingerprint, plus a 'cluster' column of cluster labels.
    """
    df = df.reset_index(drop=True)

    all_fps = [compute_fingerprint(smiles) for smiles in df['smiles']]
    valid_idx = np.array([i for i, fp in enumerate(all_fps) if fp is not None], dtype=int)
    fps = [all_fps[i] for i in valid_idx]

    n = len(fps)
    if n < 2:
        # AgglomerativeClustering needs at least two samples; zero or one
        # molecule trivially forms (at most) a single cluster.
        labels = np.zeros(n, dtype=int)
    else:
        # Self-similarity is 1, so the distance matrix has a zero diagonal.
        sim_matrix = np.eye(n)
        for i in range(n):
            for j in range(i + 1, n):
                sim = DataStructs.FingerprintSimilarity(fps[i], fps[j])
                sim_matrix[i, j] = sim
                sim_matrix[j, i] = sim

        clustering = AgglomerativeClustering(
            n_clusters=None,
            metric='precomputed',
            linkage='complete',
            distance_threshold=threshold,
        )
        labels = clustering.fit_predict(1 - sim_matrix)

    cluster_col = np.full(len(df), np.nan)
    cluster_col[valid_idx] = labels
    df['cluster'] = cluster_col
    return df.dropna(subset=['cluster'])


def is_valid_protein(sequence):
    """Accept any non-empty string as a protein sequence (no alphabet check)."""
    if not isinstance(sequence, str):
        return False
    return sequence != ""

def cluster_molecules(df, *, n_clusters=40, random_state=42):
    """Assign KMeans cluster labels to molecules based on Morgan fingerprints.

    SMILES that do not already contain "[*]" are wrapped as "[*]<smiles>[*]"
    before fingerprinting, and the wrapped form is what the returned frame
    carries in its 'smiles' column.

    Args:
        df: DataFrame with a 'smiles' column of strings. It is not modified.
        n_clusters: Upper bound on the number of KMeans clusters; capped at
            the number of fingerprinted molecules.
        random_state: Seed passed to KMeans.

    Returns:
        When more than one SMILES could be fingerprinted: a copy of the rows
        with valid fingerprints, with a 'cluster' column of KMeans labels.
        Otherwise: a copy of all rows with 'cluster' set to -1.
    """
    from rdkit import Chem
    from rdkit.Chem import AllChem
    from sklearn.cluster import KMeans
    import numpy as np

    def smiles_to_morgan_fp(smiles, radius=2, nBits=2048):
        """Fingerprint one SMILES as a numpy array, or None if RDKit cannot parse it."""
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return None
        fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=nBits)
        return np.array(fp)

    # Work on a copy so the caller's frame keeps its original SMILES.
    df = df.copy()
    df['smiles'] = df['smiles'].apply(lambda x: x if '[*]' in x else f'[*]{x}[*]')

    fingerprints = []
    valid_indices = []
    for i, smiles in enumerate(df['smiles']):
        fp = smiles_to_morgan_fp(smiles)
        if fp is not None:
            fingerprints.append(fp)
            valid_indices.append(i)

    if len(fingerprints) > 1:
        X = np.array(fingerprints)
        kmeans = KMeans(n_clusters=min(n_clusters, len(X)), random_state=random_state)
        df_valid = df.iloc[valid_indices].copy()
        df_valid['cluster'] = kmeans.fit_predict(X)
        return df_valid

    # Too few molecules to cluster: keep every row and mark it unclustered.
    df['cluster'] = -1
    return df

def preprocess_data(data_path, output_dir):
    """Load the raw CSV, filter and annotate it, cluster it, and save the result.

    Args:
        data_path: CSV file with 'smiles' and 'protein_sequence' columns.
        output_dir: Directory (created if missing) that receives
            "preprocessed_data.csv".

    Returns:
        The clustered DataFrame that was written to disk.
    """
    os.makedirs(output_dir, exist_ok=True)

    complete_rows = pd.read_csv(data_path).dropna(subset=['smiles', 'protein_sequence'])
    complete_rows['valid_protein'] = complete_rows['protein_sequence'].apply(is_valid_protein)

    # Keep rows with an acceptable protein and record both sequence lengths.
    annotated = complete_rows[complete_rows['valid_protein']].assign(
        smiles_len=lambda frame: frame['smiles'].apply(len),
        protein_len=lambda frame: frame['protein_sequence'].apply(len),
    )

    clustered = cluster_molecules(annotated)
    csv_path = os.path.join(output_dir, "preprocessed_data.csv")
    clustered.to_csv(csv_path, index=False)
    return clustered



def create_datasets(df, val_size=0.15, test_size=0.15, seed=42):
    """Split the data by cluster into train/validation/test datasets.

    Whole clusters are assigned to one split, so molecules from one cluster
    never appear in two splits. The fractions apply to the number of clusters,
    not the number of rows.

    Args:
        df: DataFrame with 'cluster', 'smiles' and 'protein_sequence' columns.
        val_size: Fraction of clusters assigned to validation.
        test_size: Fraction of clusters assigned to test.
        seed: Seed for the cluster shuffle. Note that this reseeds NumPy's
            global RNG.

    Returns:
        (train_dataset, val_dataset, test_dataset) as MolProtDataset objects.

    Raises:
        ValueError: if a fraction is negative or the two sum to 1 or more.
    """
    if val_size < 0 or test_size < 0 or val_size + test_size >= 1:
        raise ValueError(
            f"val_size and test_size must be non-negative and sum to less than 1 "
            f"(got val_size={val_size}, test_size={test_size})"
        )

    clusters = df['cluster'].unique()
    np.random.seed(seed)
    np.random.shuffle(clusters)

    # Cutoffs honour the requested fractions; the defaults give a 70/15/15 split.
    train_cutoff = int((1 - val_size - test_size) * len(clusters))
    val_cutoff = int((1 - test_size) * len(clusters))

    def _build(label, cluster_ids):
        """Select the rows of the given clusters, report their size, wrap them as a dataset."""
        subset = df[df['cluster'].isin(cluster_ids)]
        print(f"{label} set: {len(subset)} samples from {len(cluster_ids)} clusters")
        pairs = [
            MolProtPair(smiles, protein)
            for smiles, protein in zip(subset['smiles'], subset['protein_sequence'])
        ]
        return MolProtDataset(pairs)

    return (
        _build("Train", clusters[:train_cutoff]),
        _build("Validation", clusters[train_cutoff:val_cutoff]),
        _build("Test", clusters[val_cutoff:]),
    )


def create_dataloaders(train_dataset, val_dataset, test_dataset, batch_size=16, num_workers=4):
    """Wrap the three datasets in DataLoaders that use `custom_collate`.

    The training loader shuffles and drops the last incomplete batch; the
    validation and test loaders keep order and keep every batch.

    Returns:
        (train_loader, val_loader, test_loader). An entry is None when the
        corresponding dataset is falsy (None or empty).
    """
    def _make_loader(dataset, **loader_options):
        if not dataset:
            return None
        return DataLoader(
            dataset,
            batch_size=batch_size,
            collate_fn=custom_collate,
            num_workers=num_workers,
            **loader_options,
        )

    return (
        _make_loader(train_dataset, shuffle=True, drop_last=True),
        _make_loader(val_dataset, shuffle=False),
        _make_loader(test_dataset, shuffle=False),
    )

def evaluate_model(model, loader):
    """Compute the mean loss of `model` over the batches of `loader`.

    Batches whose forward pass raises are reported and skipped; the mean is
    taken over the batches that succeeded, so failures do not pull the average
    towards zero.

    Args:
        model: Callable as `model(smiles_texts, protein_sequences)` returning
            an object with a scalar `.loss` tensor. It is put in eval mode.
        loader: Iterable of dicts with "smiles_text" and "protein_sequence"
            entries, or None.

    Returns:
        The average loss as a float, or float('inf') when `loader` is None,
        empty, or every batch failed.
    """
    model.eval()
    if loader is None:
        return float('inf')

    total_loss = 0.0
    evaluated_batches = 0

    with torch.no_grad():
        for batch in loader:
            smiles_texts = batch["smiles_text"]
            protein_sequences = batch["protein_sequence"]

            try:
                output = model(smiles_texts, protein_sequences)
                total_loss += output.loss.item()
                evaluated_batches += 1
            except Exception as e:
                # Best effort: one bad batch should not abort the whole evaluation.
                print(f"Error during evaluation: {e}")
                continue

    if evaluated_batches == 0:
        return float('inf')
    return total_loss / evaluated_batches

def plot_training_metrics(train_losses, val_losses, output_dir):
    """Plot per-epoch training and validation loss and save the figure as a PNG.

    Args:
        train_losses: Training loss per epoch; its length defines the x axis.
        val_losses: Validation loss per epoch.
        output_dir: Directory (created if missing) that receives
            'training_metrics.png'.
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    epoch_numbers = range(1, len(train_losses) + 1)

    ax.plot(epoch_numbers, train_losses, 'b-', label='Training Loss')
    ax.plot(epoch_numbers, val_losses, 'r-', label='Validation Loss')

    ax.set_title('Training and Validation Loss')
    ax.set_xlabel('Epochs')
    ax.set_ylabel('Loss')
    ax.legend()
    ax.grid(True)

    os.makedirs(output_dir, exist_ok=True)
    target_path = os.path.join(output_dir, 'training_metrics.png')
    fig.savefig(target_path)
    plt.close(fig)

def calculate_retrieval_metrics(model, test_dataset, k_values=(1, 5, 10), max_samples=1000, batch_size=16):
    """Compute SMILES-to-protein retrieval metrics (Recall@k and MRR).

    For each sampled query SMILES, every protein in the dataset is scored with
    `model.predict_similarity`, and the rank of the query's own protein is
    taken at its first position in the candidate list.

    Args:
        model: Object with `eval()` and
            `predict_similarity(smiles_list, protein_list)` returning one
            score per pair.
        test_dataset: Object whose `.pairs` items expose `.smiles` and `.protein`.
        k_values: Cutoffs for Recall@k.
        max_samples: Maximum number of query molecules, drawn without
            replacement from NumPy's global RNG.
        batch_size: Number of candidate proteins scored per call.

    Returns:
        Dict with one "R@k" entry per k and "MRR", each averaged over the
        sampled queries. All entries are 0.0 for an empty dataset.
    """
    model.eval()

    all_smiles = []
    all_proteins = []
    for pair in test_dataset.pairs:
        all_smiles.append(pair.smiles)
        all_proteins.append(pair.protein)

    metrics = {f"R@{k}": 0.0 for k in k_values}
    metrics["MRR"] = 0.0

    n_samples = min(len(all_smiles), max_samples)
    if n_samples == 0:
        # Nothing to rank; also avoids dividing by zero below.
        return metrics

    indices = np.random.choice(len(all_smiles), n_samples, replace=False)

    for i in indices:
        query_smiles = all_smiles[i]
        true_protein = all_proteins[i]

        # Score the query against every candidate protein, batch by batch.
        similarities = []
        for j in range(0, len(all_proteins), batch_size):
            batch_proteins = all_proteins[j:j + batch_size]
            batch_smiles = [query_smiles] * len(batch_proteins)

            try:
                batch_similarities = model.predict_similarity(batch_smiles, batch_proteins)
                similarities.extend(batch_similarities)
            except Exception as e:
                # Best effort: a failed batch scores 0.0 so ranks stay aligned.
                print(f"Error during similarity prediction: {e}")
                similarities.extend([0.0] * len(batch_proteins))

        # Rank (1-indexed) of the true protein among all candidates.
        true_idx = all_proteins.index(true_protein)
        sorted_indices = np.argsort(similarities)[::-1]
        rank = np.where(sorted_indices == true_idx)[0][0] + 1

        metrics["MRR"] += 1.0 / rank
        for k in k_values:
            if rank <= k:
                metrics[f"R@{k}"] += 1.0

    for key in metrics:
        metrics[key] /= n_samples

    return metrics

def main():
    """Command-line entry point: preprocess data, build the model, then evaluate or train.

    With --eval_only, the (optionally checkpoint-loaded) model is scored on
    the test split and the metrics go to "eval_metrics.csv". Otherwise the
    model is trained, with periodic and best-validation checkpoints, a loss
    plot and a final checkpoint; the best checkpoint is then scored on the
    test split and the metrics go to "final_metrics.csv".
    """
    args = parse_args()
    set_seed(args.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Create directories
    os.makedirs(args.output_dir, exist_ok=True)
    os.makedirs(args.checkpoint_dir, exist_ok=True)
    os.makedirs(args.processed_data_dir, exist_ok=True)

    # Preprocess data
    print("Preprocessing data...")
    df = preprocess_data(args.data_path, args.processed_data_dir)

    # Create datasets
    print("Creating datasets...")
    train_dataset, val_dataset, test_dataset = create_datasets(
        df, val_size=args.val_size, test_size=args.test_size, seed=args.seed
    )

    # Create dataloaders
    print("Creating dataloaders...")
    train_loader, val_loader, test_loader = create_dataloaders(
        train_dataset, val_dataset, test_dataset, batch_size=args.batch_size
    )

    # Initialize model
    print("Initializing model...")
    model = Blip2QformerPolyBERTProtT5(
        bert_name=args.bert_name,
        temperature=args.temperature,
        freeze_polybert=args.freeze_polybert,
        freeze_prot5=args.freeze_prot5,
        polybert_output_dim=args.polybert_output_dim,
        prot5_model_name=args.prot5_model,
        tune_qformer=args.tune_qformer,
        num_query_token=args.num_query_token,
        cross_attention_freq=args.cross_attention_freq,
        embed_dim=args.embedding_dim
    ).to(device)

    if args.eval_only:
        if args.eval_checkpoint:
            print(f"Loading checkpoint {args.eval_checkpoint} for evaluation...")
            checkpoint = torch.load(args.eval_checkpoint, map_location=device)
            model.load_state_dict(checkpoint["model"])

        print("Evaluating model...")
        test_loss = evaluate_model(model, test_loader)
        print(f"Test Loss: {test_loss:.4f}")

        print("Calculating retrieval metrics...")
        retrieval_metrics = calculate_retrieval_metrics(model, test_dataset)
        for metric, value in retrieval_metrics.items():
            print(f"{metric}: {value:.4f}")

        # Save metrics
        metrics = {"test_loss": test_loss, **retrieval_metrics}
        pd.DataFrame([metrics]).to_csv(os.path.join(args.output_dir, "eval_metrics.csv"), index=False)
    else:
        if args.resume_checkpoint:
            print(f"Resuming from checkpoint {args.resume_checkpoint}...")
            checkpoint = torch.load(args.resume_checkpoint, map_location=device)
            model.load_state_dict(checkpoint["model"])
            # Checkpoints store the 0-based index of the last finished epoch.
            start_epoch = checkpoint.get("epoch", 0) + 1
            print(f"Resuming from epoch {start_epoch}")
        else:
            start_epoch = 0

        # Build the optimizer once so AdamW's running moment estimates persist
        # across steps instead of restarting from zero on every batch.
        optimizer = torch.optim.AdamW(
            [p for p in model.parameters() if p.requires_grad],
            lr=args.learning_rate,
            weight_decay=args.weight_decay
        )

        print("Training model...")
        train_losses = []
        val_losses = []
        best_val_loss = float('inf')

        for epoch in range(start_epoch, args.num_epochs):
            model.train()
            epoch_loss = 0.0
            batch_count = 0

            for batch in train_loader:
                try:
                    smiles_texts = batch["smiles_text"]
                    protein_sequences = batch["protein_sequence"]

                    output = model(smiles_texts, protein_sequences)
                    loss = output.loss

                    optimizer.zero_grad()
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                    optimizer.step()

                    epoch_loss += loss.item()
                    batch_count += 1

                    if batch_count % 10 == 0:
                        print(f"Epoch {epoch+1}/{args.num_epochs}, Batch {batch_count}, Loss: {loss.item():.4f}")

                except Exception as e:
                    # Best effort: report the failing batch and keep training.
                    print(f"Error during training: {e}")
                    continue

            avg_train_loss = epoch_loss / batch_count if batch_count > 0 else float('inf')
            train_losses.append(avg_train_loss)

            # Validation
            val_loss = evaluate_model(model, val_loader)
            val_losses.append(val_loss)

            print(f"Epoch {epoch+1}/{args.num_epochs}, Train Loss: {avg_train_loss:.4f}, Val Loss: {val_loss:.4f}")

            # Save checkpoint
            if (epoch + 1) % args.checkpoint_interval == 0:
                checkpoint_path = os.path.join(args.checkpoint_dir, f"checkpoint_epoch_{epoch+1}.pth")
                torch.save({
                    "model": model.state_dict(),
                    "epoch": epoch,
                    "train_loss": avg_train_loss,
                    "val_loss": val_loss
                }, checkpoint_path)
                print(f"Saved checkpoint to {checkpoint_path}")

            # Save best model
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                best_model_path = os.path.join(args.checkpoint_dir, "best_model.pth")
                torch.save({
                    "model": model.state_dict(),
                    "epoch": epoch,
                    "train_loss": avg_train_loss,
                    "val_loss": val_loss
                }, best_model_path)
                print(f"New best model saved with validation loss: {val_loss:.4f}")

        # Plot training metrics
        plot_training_metrics(train_losses, val_losses, args.output_dir)

        # Save final model
        final_model_path = os.path.join(args.checkpoint_dir, "final_model.pth")
        torch.save({
            "model": model.state_dict(),
            "train_losses": train_losses,
            "val_losses": val_losses
        }, final_model_path)
        print(f"Saved final model to {final_model_path}")

        # Load best model for evaluation. The file is missing when no epoch
        # ever improved on the initial best (e.g. no epochs were run).
        best_model_path = os.path.join(args.checkpoint_dir, "best_model.pth")
        if os.path.exists(best_model_path):
            checkpoint = torch.load(best_model_path, map_location=device)
            model.load_state_dict(checkpoint["model"])
        else:
            print(f"No best model found at {best_model_path}; evaluating the final weights instead")

        # Evaluate on test set
        test_loss = evaluate_model(model, test_loader)
        print(f"Test Loss: {test_loss:.4f}")

        # Calculate retrieval metrics
        retrieval_metrics = calculate_retrieval_metrics(model, test_dataset)
        for metric, value in retrieval_metrics.items():
            print(f"{metric}: {value:.4f}")

        # Save metrics
        metrics = {
            "test_loss": test_loss,
            "best_val_loss": best_val_loss,
            **retrieval_metrics
        }
        pd.DataFrame([metrics]).to_csv(os.path.join(args.output_dir, "final_metrics.csv"), index=False)

# Run the pipeline only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()

Writing training_script.py


## data processing and dataset creation

In [None]:
%%writefile data_processor.py

import warnings; warnings.filterwarnings("ignore")

import os
import torch
import argparse
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, random_split
from sklearn.model_selection import train_test_split
from rdkit.Chem import AllChem, DataStructs
from sklearn.cluster import AgglomerativeClustering



from blip2_molt5_prot5_qformer import (
    Blip2QformerMolT5ProtT5,
    MolProtDataset,
    MolProtPair,
    custom_collate,
    train_blip2_molt5_prot5,
    get_top_similar_proteins
)

def parse_args():
    """Parse the command-line options of the training/evaluation script.

    Returns:
        argparse.Namespace with data paths, model names and sizes, freeze/tune
        switches, optimisation settings, checkpoint options, evaluation
        switches and the random seed. Only --data_path is required.
    """
    parser = argparse.ArgumentParser(description="Train BLIP2-MolT5-ProtT5-QFormer model")

    # (flag, add_argument options) in the order they should appear in --help.
    option_specs = [
        # Data and output locations
        ("--data_path", dict(type=str, required=True,
                             help="Path to data CSV file with SMILES and protein sequences")),
        ("--output_dir", dict(type=str, default="results",
                              help="Directory for results and metrics")),
        ("--processed_data_dir", dict(type=str, default="processed_data",
                                      help="Directory for processed data")),
        ("--test_size", dict(type=float, default=0.1, help="Test set size ratio")),
        ("--val_size", dict(type=float, default=0.1, help="Validation set size ratio")),
        # Model configuration
        ("--bert_name", dict(type=str, default="bert-base-uncased",
                             help="BERT model name for QFormer")),
        ("--molt5_model", dict(type=str, default="laituan245/molt5-base",
                               help="MolT5 model name")),
        ("--prot5_model", dict(type=str, default="Rostlab/prot_t5_xl_uniref50",
                               help="ProtT5 model name")),
        ("--embedding_dim", dict(type=int, default=1024, help="Embedding dimension")),
        ("--num_query_token", dict(type=int, default=32,
                                   help="Number of query tokens for QFormer")),
        ("--cross_attention_freq", dict(type=int, default=2,
                                        help="Cross attention frequency")),
        ("--temperature", dict(type=float, default=0.05,
                               help="Temperature for contrastive loss")),
        # Which parts to freeze or tune
        ("--freeze_molt5", dict(action="store_true", help="Freeze MolT5 parameters")),
        ("--freeze_prot5", dict(action="store_true", help="Freeze ProtT5 parameters")),
        ("--tune_qformer", dict(action="store_true", help="Tune QFormer parameters")),
        # Optimisation
        ("--batch_size", dict(type=int, default=16, help="Batch size")),
        ("--num_epochs", dict(type=int, default=10, help="Number of training epochs")),
        ("--learning_rate", dict(type=float, default=5e-5, help="Learning rate")),
        ("--weight_decay", dict(type=float, default=0.05, help="Weight decay")),
        ("--warmup_steps", dict(type=int, default=1000, help="Warmup steps")),
        # Checkpointing
        ("--checkpoint_dir", dict(type=str, default="checkpoints",
                                  help="Directory for checkpoints")),
        ("--checkpoint_interval", dict(type=int, default=1,
                                       help="Checkpoint interval in epochs")),
        ("--resume_checkpoint", dict(type=str, default=None,
                                     help="Checkpoint to resume training from")),
        # Evaluation
        ("--eval_only", dict(action="store_true", help="Run evaluation only")),
        ("--eval_checkpoint", dict(type=str, default=None,
                                   help="Checkpoint to use for evaluation")),
        # Reproducibility
        ("--seed", dict(type=int, default=42, help="Random seed")),
    ]
    for flag, options in option_specs:
        parser.add_argument(flag, **options)

    return parser.parse_args()

def set_seed(seed):
    """Seed NumPy and PyTorch and switch cuDNN to deterministic mode.

    Args:
        seed: Integer seed applied to NumPy's global RNG, PyTorch's CPU RNG
            and the RNG of every CUDA device.
    """
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

    np.random.seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.manual_seed(seed)

def compute_fingerprint(smiles):
    """Return a radius-2, 2048-bit Morgan fingerprint for a SMILES string.

    Args:
        smiles: SMILES text handed to RDKit's parser.

    Returns:
        The RDKit bit vector, or None when the parsed molecule is falsy
        (RDKit returns None for SMILES it cannot parse).
    """
    # Imported locally: this file's top-level imports bring in AllChem and
    # DataStructs but not Chem.
    from rdkit import Chem

    mol = Chem.MolFromSmiles(smiles)
    if mol:
        return AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=2048)
    return None

def cluster_molecules(df, threshold=0.3):
    """Cluster molecules by fingerprint distance with complete-linkage agglomeration.

    NOTE: this name is defined again further down the file; at import time the
    last definition is the one callers actually get.

    Args:
        df: DataFrame with a 'smiles' column. It is not modified; a re-indexed
            copy is returned.
        threshold: Distance threshold (1 - fingerprint similarity) passed to
            AgglomerativeClustering as `distance_threshold`.

    Returns:
        A copy of `df` with a 0..n-1 index, restricted to rows whose SMILES
        produced a fingerprint, plus a 'cluster' column of cluster labels.
    """
    # A fresh 0..n-1 index makes positional and label indexing agree and keeps
    # the caller's frame untouched.
    df = df.reset_index(drop=True)

    all_fps = [compute_fingerprint(smiles) for smiles in df['smiles']]
    valid_idx = np.array([i for i, fp in enumerate(all_fps) if fp is not None], dtype=int)
    fps = [all_fps[i] for i in valid_idx]

    n = len(fps)
    if n < 2:
        # AgglomerativeClustering needs at least two samples; zero or one
        # molecule trivially forms (at most) a single cluster.
        labels = np.zeros(n, dtype=int)
    else:
        # Self-similarity is 1, so the distance matrix has a zero diagonal.
        sim_matrix = np.eye(n)
        for i in range(n):
            for j in range(i + 1, n):
                sim = DataStructs.FingerprintSimilarity(fps[i], fps[j])
                sim_matrix[i, j] = sim
                sim_matrix[j, i] = sim

        # `metric` is the current name of the keyword formerly called `affinity`.
        clustering = AgglomerativeClustering(
            n_clusters=None,
            metric='precomputed',
            linkage='complete',
            distance_threshold=threshold,
        )
        labels = clustering.fit_predict(1 - sim_matrix)

    cluster_col = np.full(len(df), np.nan)
    cluster_col[valid_idx] = labels
    df['cluster'] = cluster_col
    return df.dropna(subset=['cluster'])

# Remove invalid SMILES and protein sequences
def is_valid_smiles(smiles):
    """Cheap character-level screen for SMILES strings (no chemistry parsing).

    Args:
        smiles: Candidate value; anything that is not a str is rejected.

    Returns:
        True when `smiles` is a non-empty string made only of characters that
        can occur in SMILES, False otherwise.
    """
    if not isinstance(smiles, str) or not smiles:
        return False
    allowed = frozenset(
        'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789'
        '()[]+=-#:.\\/@'
        # '*' wildcard atom, '%' two-digit ring closures, '$' quadruple bond.
        '*%$'
    )
    return all(ch in allowed for ch in smiles)

def is_valid_protein(seq):
    """Check that `seq` is a non-empty string of amino-acid letters.

    The comparison is case-insensitive. Accepted letters are the 20 standard
    residues plus X, B, Z, O and U.

    Returns:
        True for a valid sequence, False for non-strings, empty strings, or
        strings containing any other character.
    """
    if not isinstance(seq, str) or len(seq) == 0:
        return False
    allowed_residues = frozenset('ACDEFGHIKLMNPQRSTVWYXBZOU')
    return set(seq.upper()) <= allowed_residues

def cluster_molecules(df, threshold=0.3):
    """Cluster molecules by fingerprint distance with complete-linkage agglomeration.

    Args:
        df: DataFrame with a 'smiles' column. It is not modified; a re-indexed
            copy is returned.
        threshold: Distance threshold (1 - fingerprint similarity) passed to
            AgglomerativeClustering as `distance_threshold`.

    Returns:
        A copy of `df` with a 0..n-1 index, restricted to rows whose SMILES
        produced a fingerprint, plus a 'cluster' column of cluster labels.
    """
    df = df.reset_index(drop=True)

    all_fps = [compute_fingerprint(smiles) for smiles in df['smiles']]
    valid_idx = np.array([i for i, fp in enumerate(all_fps) if fp is not None], dtype=int)
    fps = [all_fps[i] for i in valid_idx]

    n = len(fps)
    if n < 2:
        # AgglomerativeClustering needs at least two samples; zero or one
        # molecule trivially forms (at most) a single cluster.
        labels = np.zeros(n, dtype=int)
    else:
        # Self-similarity is 1, so the distance matrix has a zero diagonal.
        sim_matrix = np.eye(n)
        for i in range(n):
            for j in range(i + 1, n):
                sim = DataStructs.FingerprintSimilarity(fps[i], fps[j])
                sim_matrix[i, j] = sim
                sim_matrix[j, i] = sim

        clustering = AgglomerativeClustering(
            n_clusters=None,
            metric='precomputed',
            linkage='complete',
            distance_threshold=threshold,
        )
        labels = clustering.fit_predict(1 - sim_matrix)

    cluster_col = np.full(len(df), np.nan)
    cluster_col[valid_idx] = labels
    df['cluster'] = cluster_col
    return df.dropna(subset=['cluster'])

def preprocess_data(data_path, output_dir):
    """Load the raw CSV, drop unusable rows, cluster the molecules, and save the result.

    Args:
        data_path: CSV file with 'smiles' and 'protein_sequence' columns.
        output_dir: Directory (created if missing) that receives
            "preprocessed_data.csv".

    Returns:
        The clustered DataFrame that was written to disk.
    """
    os.makedirs(output_dir, exist_ok=True)
    df = pd.read_csv(data_path)
    print(f"Original data shape: {df.shape}")
    df = df.dropna(subset=['smiles', 'protein_sequence'])

    # Both columns must be boolean masks: they are combined with `&` below.
    df['valid_smiles'] = df['smiles'].apply(is_valid_smiles).astype(bool)
    df['valid_protein'] = df['protein_sequence'].apply(is_valid_protein).astype(bool)
    # Copy so the length columns are added to an independent frame, not a slice.
    df = df[df['valid_smiles'] & df['valid_protein']].copy()

    df['smiles_len'] = df['smiles'].apply(len)
    df['protein_len'] = df['protein_sequence'].apply(len)
    df = cluster_molecules(df)
    print(f"Preprocessed data shape: {df.shape}")
    df.to_csv(os.path.join(output_dir, "preprocessed_data.csv"), index=False)
    return df

def create_datasets(df, val_size=0.15, test_size=0.15, seed=42):
    """Split the data by cluster into train/validation/test datasets.

    Whole clusters are assigned to one split, so molecules from one cluster
    never appear in two splits. The fractions apply to the number of clusters,
    not the number of rows.

    Args:
        df: DataFrame with 'cluster', 'smiles' and 'protein_sequence' columns.
        val_size: Fraction of clusters assigned to validation.
        test_size: Fraction of clusters assigned to test.
        seed: Seed for the cluster shuffle. Note that this reseeds NumPy's
            global RNG.

    Returns:
        (train_dataset, val_dataset, test_dataset) as MolProtDataset objects.

    Raises:
        ValueError: if a fraction is negative or the two sum to 1 or more.
    """
    if val_size < 0 or test_size < 0 or val_size + test_size >= 1:
        raise ValueError(
            f"val_size and test_size must be non-negative and sum to less than 1 "
            f"(got val_size={val_size}, test_size={test_size})"
        )

    clusters = df['cluster'].unique()
    np.random.seed(seed)
    np.random.shuffle(clusters)

    # Cutoffs honour the requested fractions; the defaults give a 70/15/15 split.
    train_cutoff = int((1 - val_size - test_size) * len(clusters))
    val_cutoff = int((1 - test_size) * len(clusters))

    def _build(label, cluster_ids):
        """Select the rows of the given clusters, report their size, wrap them as a dataset."""
        subset = df[df['cluster'].isin(cluster_ids)]
        print(f"{label} set: {len(subset)} samples from {len(cluster_ids)} clusters")
        pairs = [
            MolProtPair(smiles, protein)
            for smiles, protein in zip(subset['smiles'], subset['protein_sequence'])
        ]
        return MolProtDataset(pairs)

    return (
        _build("Train", clusters[:train_cutoff]),
        _build("Validation", clusters[train_cutoff:val_cutoff]),
        _build("Test", clusters[val_cutoff:]),
    )


def create_dataloaders(train_dataset, val_dataset, test_dataset, batch_size=16, num_workers=4):
    """Wrap the three datasets in DataLoaders that use `custom_collate`.

    The training loader shuffles and drops the last incomplete batch; the
    validation and test loaders keep order and keep every batch.

    Returns:
        (train_loader, val_loader, test_loader). An entry is None when the
        corresponding dataset is falsy (None or empty).
    """
    def _make_loader(dataset, **loader_options):
        if not dataset:
            return None
        return DataLoader(
            dataset,
            batch_size=batch_size,
            collate_fn=custom_collate,
            num_workers=num_workers,
            **loader_options,
        )

    return (
        _make_loader(train_dataset, shuffle=True, drop_last=True),
        _make_loader(val_dataset, shuffle=False),
        _make_loader(test_dataset, shuffle=False),
    )

def evaluate_model(model, loader):
    """Compute the mean loss of `model` over the batches of `loader`.

    Batches whose forward pass raises are reported and skipped; the mean is
    taken over the batches that succeeded, so failures do not pull the average
    towards zero.

    Args:
        model: Callable as `model(smiles_texts, protein_sequences)` returning
            an object with a scalar `.loss` tensor. It is put in eval mode.
        loader: Iterable of dicts with "smiles_text" and "protein_sequence"
            entries, or None.

    Returns:
        The average loss as a float, or float('inf') when `loader` is None,
        empty, or every batch failed.
    """
    model.eval()
    if loader is None:
        return float('inf')

    total_loss = 0.0
    evaluated_batches = 0

    with torch.no_grad():
        for batch in loader:
            smiles_texts = batch["smiles_text"]
            protein_sequences = batch["protein_sequence"]

            try:
                output = model(smiles_texts, protein_sequences)
                total_loss += output.loss.item()
                evaluated_batches += 1
            except Exception as e:
                # Best effort: one bad batch should not abort the whole evaluation.
                print(f"Error during evaluation: {e}")
                continue

    if evaluated_batches == 0:
        return float('inf')
    return total_loss / evaluated_batches

def plot_training_metrics(train_losses, val_losses, output_dir):
    """Plot per-epoch training and validation loss and save the figure as a PNG.

    Args:
        train_losses: Training loss per epoch; its length defines the x axis.
        val_losses: Validation loss per epoch.
        output_dir: Directory (created if missing) that receives
            'training_metrics.png'.
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    epoch_numbers = range(1, len(train_losses) + 1)

    ax.plot(epoch_numbers, train_losses, 'b-', label='Training Loss')
    ax.plot(epoch_numbers, val_losses, 'r-', label='Validation Loss')

    ax.set_title('Training and Validation Loss')
    ax.set_xlabel('Epochs')
    ax.set_ylabel('Loss')
    ax.legend()
    ax.grid(True)

    os.makedirs(output_dir, exist_ok=True)
    target_path = os.path.join(output_dir, 'training_metrics.png')
    fig.savefig(target_path)
    plt.close(fig)

def calculate_retrieval_metrics(model, test_dataset, k_values=(1, 5, 10), max_samples=1000, batch_size=16):
    """Compute SMILES-to-protein retrieval metrics (Recall@k and MRR).

    For each sampled query SMILES, every protein in the dataset is scored with
    `model.predict_similarity`, and the rank of the query's own protein is
    taken at its first position in the candidate list.

    Args:
        model: Object with `eval()` and
            `predict_similarity(smiles_list, protein_list)` returning one
            score per pair.
        test_dataset: Object whose `.pairs` items expose `.smiles` and `.protein`.
        k_values: Cutoffs for Recall@k.
        max_samples: Maximum number of query molecules, drawn without
            replacement from NumPy's global RNG.
        batch_size: Number of candidate proteins scored per call.

    Returns:
        Dict with one "R@k" entry per k and "MRR", each averaged over the
        sampled queries. All entries are 0.0 for an empty dataset.
    """
    model.eval()

    all_smiles = []
    all_proteins = []
    for pair in test_dataset.pairs:
        all_smiles.append(pair.smiles)
        all_proteins.append(pair.protein)

    metrics = {f"R@{k}": 0.0 for k in k_values}
    metrics["MRR"] = 0.0

    n_samples = min(len(all_smiles), max_samples)
    if n_samples == 0:
        # Nothing to rank; also avoids dividing by zero below.
        return metrics

    indices = np.random.choice(len(all_smiles), n_samples, replace=False)

    for i in indices:
        query_smiles = all_smiles[i]
        true_protein = all_proteins[i]

        # Score the query against every candidate protein, batch by batch.
        similarities = []
        for j in range(0, len(all_proteins), batch_size):
            batch_proteins = all_proteins[j:j + batch_size]
            batch_smiles = [query_smiles] * len(batch_proteins)

            try:
                batch_similarities = model.predict_similarity(batch_smiles, batch_proteins)
                similarities.extend(batch_similarities)
            except Exception as e:
                # Best effort: a failed batch scores 0.0 so ranks stay aligned.
                print(f"Error during similarity prediction: {e}")
                similarities.extend([0.0] * len(batch_proteins))

        # Rank (1-indexed) of the true protein among all candidates.
        true_idx = all_proteins.index(true_protein)
        sorted_indices = np.argsort(similarities)[::-1]
        rank = np.where(sorted_indices == true_idx)[0][0] + 1

        metrics["MRR"] += 1.0 / rank
        for k in k_values:
            if rank <= k:
                metrics[f"R@{k}"] += 1.0

    for key in metrics:
        metrics[key] /= n_samples

    return metrics

def main():
    """Command-line entry point: preprocess data, build the model, then evaluate or train.

    With --eval_only, the (optionally checkpoint-loaded) model is scored on
    the test split and the metrics go to "eval_metrics.csv". Otherwise the
    model is trained, with periodic and best-validation checkpoints, a loss
    plot and a final checkpoint; the best checkpoint is then scored on the
    test split and the metrics go to "final_metrics.csv".
    """
    args = parse_args()
    set_seed(args.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Create directories
    os.makedirs(args.output_dir, exist_ok=True)
    os.makedirs(args.checkpoint_dir, exist_ok=True)
    os.makedirs(args.processed_data_dir, exist_ok=True)

    # Preprocess data
    print("Preprocessing data...")
    df = preprocess_data(args.data_path, args.processed_data_dir)

    # Create datasets
    print("Creating datasets...")
    train_dataset, val_dataset, test_dataset = create_datasets(
        df, val_size=args.val_size, test_size=args.test_size, seed=args.seed
    )

    # Create dataloaders
    print("Creating dataloaders...")
    train_loader, val_loader, test_loader = create_dataloaders(
        train_dataset, val_dataset, test_dataset, batch_size=args.batch_size
    )

    # Initialize model
    print("Initializing model...")
    model = Blip2QformerMolT5ProtT5(
        bert_name=args.bert_name,
        temperature=args.temperature,
        freeze_molt5=args.freeze_molt5,
        freeze_prot5=args.freeze_prot5,
        molt5_model_name=args.molt5_model,
        prot5_model_name=args.prot5_model,
        tune_qformer=args.tune_qformer,
        num_query_token=args.num_query_token,
        cross_attention_freq=args.cross_attention_freq,
        embed_dim=args.embedding_dim
    ).to(device)

    if args.eval_only:
        if args.eval_checkpoint:
            print(f"Loading checkpoint {args.eval_checkpoint} for evaluation...")
            checkpoint = torch.load(args.eval_checkpoint, map_location=device)
            model.load_state_dict(checkpoint["model"])

        print("Evaluating model...")
        test_loss = evaluate_model(model, test_loader)
        print(f"Test Loss: {test_loss:.4f}")

        print("Calculating retrieval metrics...")
        retrieval_metrics = calculate_retrieval_metrics(model, test_dataset)
        for metric, value in retrieval_metrics.items():
            print(f"{metric}: {value:.4f}")

        # Save metrics
        metrics = {"test_loss": test_loss, **retrieval_metrics}
        pd.DataFrame([metrics]).to_csv(os.path.join(args.output_dir, "eval_metrics.csv"), index=False)
    else:
        if args.resume_checkpoint:
            print(f"Resuming from checkpoint {args.resume_checkpoint}...")
            checkpoint = torch.load(args.resume_checkpoint, map_location=device)
            model.load_state_dict(checkpoint["model"])
            # Checkpoints store the 0-based index of the last finished epoch.
            start_epoch = checkpoint.get("epoch", 0) + 1
            print(f"Resuming from epoch {start_epoch}")
        else:
            start_epoch = 0

        # Build the optimizer once so AdamW's running moment estimates persist
        # across steps instead of restarting from zero on every batch.
        optimizer = torch.optim.AdamW(
            [p for p in model.parameters() if p.requires_grad],
            lr=args.learning_rate,
            weight_decay=args.weight_decay
        )

        print("Training model...")
        train_losses = []
        val_losses = []
        best_val_loss = float('inf')

        for epoch in range(start_epoch, args.num_epochs):
            model.train()
            epoch_loss = 0.0
            batch_count = 0

            for batch in train_loader:
                try:
                    smiles_texts = batch["smiles_text"]
                    protein_sequences = batch["protein_sequence"]

                    output = model(smiles_texts, protein_sequences)
                    loss = output.loss

                    optimizer.zero_grad()
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                    optimizer.step()

                    epoch_loss += loss.item()
                    batch_count += 1

                    if batch_count % 10 == 0:
                        print(f"Epoch {epoch+1}/{args.num_epochs}, Batch {batch_count}, Loss: {loss.item():.4f}")

                except Exception as e:
                    # Best effort: report the failing batch and keep training.
                    print(f"Error during training: {e}")
                    continue

            avg_train_loss = epoch_loss / batch_count if batch_count > 0 else float('inf')
            train_losses.append(avg_train_loss)

            # Validation
            val_loss = evaluate_model(model, val_loader)
            val_losses.append(val_loss)

            print(f"Epoch {epoch+1}/{args.num_epochs}, Train Loss: {avg_train_loss:.4f}, Val Loss: {val_loss:.4f}")

            # Save checkpoint
            if (epoch + 1) % args.checkpoint_interval == 0:
                checkpoint_path = os.path.join(args.checkpoint_dir, f"checkpoint_epoch_{epoch+1}.pth")
                torch.save({
                    "model": model.state_dict(),
                    "epoch": epoch,
                    "train_loss": avg_train_loss,
                    "val_loss": val_loss
                }, checkpoint_path)
                print(f"Saved checkpoint to {checkpoint_path}")

            # Save best model
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                best_model_path = os.path.join(args.checkpoint_dir, "best_model.pth")
                torch.save({
                    "model": model.state_dict(),
                    "epoch": epoch,
                    "train_loss": avg_train_loss,
                    "val_loss": val_loss
                }, best_model_path)
                print(f"New best model saved with validation loss: {val_loss:.4f}")

        # Plot training metrics
        plot_training_metrics(train_losses, val_losses, args.output_dir)

        # Save final model
        final_model_path = os.path.join(args.checkpoint_dir, "final_model.pth")
        torch.save({
            "model": model.state_dict(),
            "train_losses": train_losses,
            "val_losses": val_losses
        }, final_model_path)
        print(f"Saved final model to {final_model_path}")

        # Load best model for evaluation. The file is missing when no epoch
        # ever improved on the initial best (e.g. no epochs were run).
        best_model_path = os.path.join(args.checkpoint_dir, "best_model.pth")
        if os.path.exists(best_model_path):
            checkpoint = torch.load(best_model_path, map_location=device)
            model.load_state_dict(checkpoint["model"])
        else:
            print(f"No best model found at {best_model_path}; evaluating the final weights instead")

        # Evaluate on test set
        test_loss = evaluate_model(model, test_loader)
        print(f"Test Loss: {test_loss:.4f}")

        # Calculate retrieval metrics
        retrieval_metrics = calculate_retrieval_metrics(model, test_dataset)
        for metric, value in retrieval_metrics.items():
            print(f"{metric}: {value:.4f}")

        # Save metrics
        metrics = {
            "test_loss": test_loss,
            "best_val_loss": best_val_loss,
            **retrieval_metrics
        }
        pd.DataFrame([metrics]).to_csv(os.path.join(args.output_dir, "final_metrics.csv"), index=False)

# Run the pipeline only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()

Writing data_processor.py


# run

In [None]:
# Mount Google Drive at /content/drive (requires the Colab runtime). The
# training command in this section passes dataset, checkpoint and output
# paths that live under this mount point, so this cell must run first.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
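# Disabled variant of the training command: identical arguments to the active run below, plus the --freeze_polybert and --freeze_prot5 flags. Kept commented out for reference.
# !python training_script.py --data_path "/content/drive/MyDrive/enzymes + plastics/2024/sigma_data.csv" --batch_size 16 --num_epochs 20 --learning_rate 5e-5 --freeze_polybert --freeze_prot5 --checkpoint_dir "/content/drive/MyDrive/enzymes + plastics/2024/checkpoints/" --output_dir "/content/drive/MyDrive/enzymes + plastics/2024/results"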

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Creating datasets...
Train set: 4896 samples from 16 clusters
Validation set: 3366 samples from 3 clusters
Test set: 765 samples from 4 clusters
Creating dataloaders...
Initializing model...
You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
Some weights of BertLMHeadModel were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['bert.encoder.layer.0.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.0.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.0.crossatt

In [None]:
# Run the training script in a shell subprocess with batch size 16, 20 epochs
# and learning rate 5e-5. The dataset CSV, checkpoint directory and output
# directory all point at the mounted Google Drive (/content/drive/MyDrive/...).
# NOTE(review): the preceding %%writefile cell reports "Writing data_processor.py",
# while this command runs training_script.py — confirm that file exists in the
# working directory and contains the intended training code.
!python training_script.py --data_path "/content/drive/MyDrive/enzymes + plastics/2024/sigma_data.csv" --batch_size 16 --num_epochs 20 --learning_rate 5e-5 --checkpoint_dir "/content/drive/MyDrive/enzymes + plastics/2024/checkpoints/" --output_dir "/content/drive/MyDrive/enzymes + plastics/2024/results"

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Creating datasets...
Train set: 4896 samples from 16 clusters
Validation set: 3366 samples from 3 clusters
Test set: 765 samples from 4 clusters
Creating dataloaders...
Initializing model...
tokenizer_config.json: 100% 48.0/48.0 [00:00<00:00, 334kB/s]
vocab.txt: 100% 232k/232k [00:00<00:00, 71.0MB/s]
tokenizer.json: 100% 466k/466k [00:00<00:00, 1.04MB/s]
config.json: 100% 570/570 [00:00<00:00, 4.39MB/s]
config.json: 100% 756/756 [00:00<00:00, 6.82MB/s]
pytorch_model.bin: 100% 101M/101M [00:00<00:00, 226MB/s] 
tokenizer_config.json: 100% 382/382 [00:00<00:00, 3.84MB/s]
spm.model: 100% 242k/242k [00:00<00:00, 345MB/s]
tokenizer.json: 100% 331k/331k [00:00<00:00, 41.1MB/s]
model.safetensors:  31% 31.5M/101M [00:00<00:00, 229MB/s]
added_tokens.json: 100% 84.0/84.0 [00:00<00:00, 788kB/s]
model.safetensors: 100% 101M/101M [00:00<00:00, 238MB/s] 
special_tokens_map.json: 100% 173/173 [00:00<00:00, 1.53MB/s]
tokenizer_config.json