In [None]:
## script

#### installs

In [None]:
!pip install scanpy

Collecting scanpy
  Downloading scanpy-1.11.0-py3-none-any.whl.metadata (9.5 kB)
Collecting anndata>=0.8 (from scanpy)
  Downloading anndata-0.11.3-py3-none-any.whl.metadata (8.2 kB)
Collecting legacy-api-wrap>=1.4 (from scanpy)
  Downloading legacy_api_wrap-1.4.1-py3-none-any.whl.metadata (2.1 kB)
Collecting scikit-learn<1.6.0,>=1.1 (from scanpy)
  Downloading scikit_learn-1.5.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (13 kB)
Collecting session-info2 (from scanpy)
  Downloading session_info2-0.1.2-py3-none-any.whl.metadata (2.5 kB)
Collecting array-api-compat!=1.5,>1.4 (from anndata>=0.8->scanpy)
  Downloading array_api_compat-1.11.1-py3-none-any.whl.metadata (1.8 kB)
Downloading scanpy-1.11.0-py3-none-any.whl (2.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m42.2 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading anndata-0.11.3-py3-none-any.whl (142 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m142.

In [None]:
!pip install pytorch_lightning

Collecting pytorch_lightning
  Downloading pytorch_lightning-2.5.0.post0-py3-none-any.whl.metadata (21 kB)
Collecting torchmetrics>=0.7.0 (from pytorch_lightning)
  Downloading torchmetrics-1.6.3-py3-none-any.whl.metadata (20 kB)
Collecting lightning-utilities>=0.10.0 (from pytorch_lightning)
  Downloading lightning_utilities-0.14.0-py3-none-any.whl.metadata (5.6 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=2.1.0->pytorch_lightning)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=2.1.0->pytorch_lightning)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=2.1.0->pytorch_lightning)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch>=2.1.0->pytorch_lightning)
  Dow

## blip2_molt5_qformer.py

In [23]:
%%writefile blip2_molt5_qformer.py
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import warnings; warnings.filterwarnings("ignore")

import contextlib
import logging
import os
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.cuda.amp import autocast as autocast
from torch.nn import functional as F
from transformers import AutoTokenizer, T5ForConditionalGeneration
from transformers import BertTokenizer, BertConfig, BertLMHeadModel
# from pytorch_lightning.utilities import distributed
from typing import Dict, List, Optional, Tuple, Union

class LayerNorm(nn.LayerNorm):
    """``nn.LayerNorm`` that normalises in float32 and returns the input's dtype."""

    def forward(self, x: torch.Tensor, mask=None):
        # `mask` is accepted by the signature but is not used by this layer.
        input_dtype = x.dtype
        normalized = super().forward(x.to(torch.float32))
        return normalized.to(input_dtype)

def disabled_train(self, mode=True):
    """No-op replacement for a module's ``train`` method.

    Ignores ``mode`` and returns its first argument unchanged.  In this file it
    is assigned over the ``train`` attribute of frozen sub-modules so that later
    ``train()`` calls do not switch them out of eval mode.
    """
    return self

# @torch.no_grad()
# def pl_concat_all_gather(tensor, cat=True):
#     if not is_dist_avail_and_initialized():
#         if not cat:
#             return [tensor]
#         return tensor

#     output = distributed.gather_all_tensors(tensor)
#     if cat:
#         output = torch.cat(output, dim=0)
#     return output

def is_dist_avail_and_initialized():
    """Return True only when torch.distributed is available and already initialized."""
    # Short-circuit keeps `is_initialized` from being called when the
    # distributed package is not available at all.
    return dist.is_available() and dist.is_initialized()

class BlipOutput:
    """Plain attribute bag for what a forward pass returns.

    Every field defaults to ``None``; callers read ``loss``, ``logits``,
    ``similarity`` and ``intermediate_output`` directly as attributes.
    """

    def __init__(self, loss=None, logits=None, similarity=None, intermediate_output=None):
        (
            self.loss,
            self.logits,
            self.similarity,
            self.intermediate_output,
        ) = (loss, logits, similarity, intermediate_output)

class BlipBaseQFormer(nn.Module):
    """Shared plumbing for Q-Former based models: tokenizer, Q-Former and checkpoint loading.

    ``self.device`` is chosen once at construction time (CUDA if available,
    otherwise CPU); it is a plain attribute and is not updated by ``.to()``.
    """

    def __init__(self):
        super().__init__()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    @classmethod
    def init_tokenizer(cls, bert_name="bert-base-uncased"):
        """Load the BERT tokenizer and register ``[DEC]`` as its BOS token.

        Declared as a classmethod (like ``init_Qformer``) so it can be called on
        the class as well as on an instance.
        """
        tokenizer = BertTokenizer.from_pretrained(bert_name)
        tokenizer.add_special_tokens({"bos_token": "[DEC]"})
        return tokenizer

    def maybe_autocast(self, dtype=torch.float16):
        """Return a CUDA autocast context, or a no-op context when running on CPU."""
        if self.device == torch.device("cpu"):
            return contextlib.nullcontext()
        # `torch.autocast` is the current spelling of `torch.cuda.amp.autocast`.
        return torch.autocast(device_type="cuda", dtype=dtype)

    @classmethod
    def init_Qformer(cls, bert_name, num_query_token, encoder_width, cross_attention_freq=2):
        """Build the BERT-based Q-Former and its learnable query tokens.

        Args:
            bert_name: Hugging Face id of the BERT checkpoint to start from.
            num_query_token: Number of query tokens to allocate.
            encoder_width: Stored on the config as ``encoder_width``.
            cross_attention_freq: Stored on the config as ``cross_attention_freq``.

        Returns:
            ``(Qformer, query_tokens)`` where ``query_tokens`` is an
            ``nn.Parameter`` of shape ``(1, num_query_token, hidden_size)``
            drawn from a normal distribution with the config's ``initializer_range``.
        """
        encoder_config = BertConfig.from_pretrained(bert_name, is_decoder=True)
        encoder_config.encoder_width = encoder_width
        encoder_config.add_cross_attention = True
        encoder_config.cross_attention_freq = cross_attention_freq
        encoder_config.query_length = num_query_token

        Qformer = BertLMHeadModel.from_pretrained(
            bert_name, config=encoder_config
        )
        query_tokens = nn.Parameter(
            torch.zeros(1, num_query_token, encoder_config.hidden_size)
        )
        query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
        return Qformer, query_tokens

    def load_from_pretrained(self, url_or_filename):
        """Load weights from a local checkpoint file (non-strict).

        Accepts both checkpoint layouts written by this file: a dict holding the
        weights under ``"model"`` (periodic/best checkpoints) and a bare state
        dict (``final_model.pth``).

        Raises:
            RuntimeError: if ``url_or_filename`` is not an existing file.

        Returns:
            The result of ``load_state_dict`` (missing / unexpected keys).
        """
        if not os.path.isfile(url_or_filename):
            raise RuntimeError("checkpoint url or path is invalid")

        checkpoint = torch.load(url_or_filename, map_location="cpu")
        if isinstance(checkpoint, dict) and "model" in checkpoint:
            state_dict = checkpoint["model"]
        else:
            state_dict = checkpoint

        msg = self.load_state_dict(state_dict, strict=False)
        logging.info("load checkpoint from %s" % url_or_filename)
        return msg

class Blip2QformerMolT5(BlipBaseQFormer):
    """Aligns cell embeddings with MolT5-encoded SMILES through a shared Q-Former.

    Learnable query tokens cross-attend once to the projected cell embedding and
    once to the MolT5 encoder states of the SMILES strings.  Both query outputs
    are projected to ``embed_dim``, L2-normalised and compared with a symmetric
    cross-entropy contrastive loss.

    NOTE(review): MolT5 encoder states are passed to the Q-Former cross-attention
    without a projection, so this presumably requires ``molt5.config.d_model``
    to equal the Q-Former hidden size — confirm before switching MolT5 sizes.

    Args:
        bert_name: Hugging Face id of the BERT checkpoint backing the Q-Former.
        temperature: Softmax temperature of the contrastive loss.
        freeze_molt5: If true, MolT5 parameters are frozen and kept in eval mode.
        molt5_model_name: Hugging Face id of the MolT5 checkpoint.
        tune_qformer: If false, Q-Former parameters are frozen and kept in eval mode.
        num_query_token: Number of learnable query tokens.
        cross_attention_freq: Stored on the Q-Former config.
        embed_dim: Width of the incoming cell embeddings and of the shared feature space.
    """

    def __init__(
        self,
        bert_name="bert-base-uncased",
        temperature=0.05,
        freeze_molt5=True,
        molt5_model_name="laituan245/molt5-base",
        tune_qformer=True,
        num_query_token=32,
        cross_attention_freq=2,
        embed_dim=256,
    ):
        super().__init__()

        self.tokenizer = self.init_tokenizer(bert_name)
        self.molt5_tokenizer = AutoTokenizer.from_pretrained(molt5_model_name, model_max_length=512)
        self.molt5 = T5ForConditionalGeneration.from_pretrained(molt5_model_name)

        self.freeze_molt5 = freeze_molt5
        if freeze_molt5:
            for param in self.molt5.parameters():
                param.requires_grad = False
            self.molt5.eval()
            # Replace `train` so a later model.train() cannot re-enable dropout here.
            self.molt5.train = disabled_train
            logging.info("freeze MolT5 encoder and decoder")

        self.Qformer, self.query_tokens = self.init_Qformer(
            bert_name,
            num_query_token,
            self.molt5.config.d_model,
            cross_attention_freq
        )
        self.Qformer.resize_token_embeddings(len(self.tokenizer))

        # Initialise any "*_query" parameters from their non-query counterparts.
        state_dict = self.Qformer.state_dict()
        for name, param in self.Qformer.named_parameters():
            if "_query" in name:
                key_orig = name.replace("_query", "")
                param.data.copy_(state_dict[key_orig])

        self.cell_proj = nn.Linear(embed_dim, self.Qformer.config.hidden_size)
        self.molt5_proj = nn.Linear(self.Qformer.config.hidden_size, embed_dim)
        # Despite its name, `text_proj` projects the *cell* query output (see
        # `_query_features`); the attribute name is kept for checkpoint compatibility.
        self.text_proj = nn.Linear(self.Qformer.config.hidden_size, embed_dim)

        if not tune_qformer:
            for name, param in self.Qformer.named_parameters():
                param.requires_grad = False
            self.Qformer.eval()
            self.Qformer.train = disabled_train
            logging.info("freeze QFormer")

        self.temperature = temperature

    def encode_molt5(self, smiles_texts):
        """Encode SMILES strings with the MolT5 encoder.

        Returns:
            ``(embeddings, attention_mask)``.  If the encoder raises, zero
            embeddings with an all-ones mask (sequence length 20) are returned
            instead so the caller can continue; NaNs are replaced by 0.
        """
        molt5_inputs = self.molt5_tokenizer(
            smiles_texts,
            padding=True,
            truncation=True,
            return_tensors="pt"
        ).to(self.device)
        logging.debug("MolT5 input shape: %s", tuple(molt5_inputs.input_ids.shape))

        # NOTE(review): the encoder always runs under no_grad, so MolT5 receives
        # no gradients even when freeze_molt5=False — confirm this is intended.
        with torch.no_grad():
            try:
                molt5_embeds = self.molt5.encoder(
                    input_ids=molt5_inputs.input_ids,
                    attention_mask=molt5_inputs.attention_mask,
                    return_dict=True
                ).last_hidden_state
                molt5_attention_mask = molt5_inputs.attention_mask
            except Exception as e:
                # Deliberate best-effort: fall back to neutral embeddings.
                logging.warning("Error in MolT5 encoding: %s", e)
                batch_size = len(smiles_texts)
                dim = self.molt5.config.d_model
                seq_len = 20  # Reasonable default
                molt5_embeds = torch.zeros((batch_size, seq_len, dim), device=self.device)
                molt5_attention_mask = torch.ones((batch_size, seq_len), device=self.device)

        if torch.isnan(molt5_embeds).any():
            logging.warning("NaN values detected in raw MolT5 embeddings - applying fix")
        molt5_embeds = torch.nan_to_num(molt5_embeds, nan=0.0)

        return molt5_embeds, molt5_attention_mask

    def encode_cell(self, cell_embeddings):
        """Project cell embeddings from ``embed_dim`` to the Q-Former hidden size."""
        return self.cell_proj(cell_embeddings)

    def _query_features(self, cell_embeddings, smiles_texts):
        """Run both Q-Former passes and return L2-normalised (cell, molt5) query features.

        Each returned tensor has one feature vector per query token, i.e. shape
        ``(batch, num_query_token, embed_dim)``.
        """
        # The cell embedding becomes a length-1 "sequence" for cross-attention.
        cell_embeds = self.encode_cell(cell_embeddings).unsqueeze(1)
        molt5_embeds, molt5_attention_mask = self.encode_molt5(smiles_texts)

        batch_size = cell_embeds.size(0)
        query_tokens = self.query_tokens.expand(batch_size, -1, -1)
        query_mask = torch.ones(query_tokens.size()[:-1], device=query_tokens.device)

        cell_query = self.Qformer.bert(
            inputs_embeds=query_tokens,
            attention_mask=query_mask,
            encoder_hidden_states=cell_embeds,
            encoder_attention_mask=None,
            return_dict=True,
        )
        cell_feats = self.text_proj(cell_query.last_hidden_state)

        molt5_query = self.Qformer.bert(
            inputs_embeds=query_tokens,
            attention_mask=query_mask,
            encoder_hidden_states=molt5_embeds,
            encoder_attention_mask=molt5_attention_mask,
            return_dict=True,
        )
        molt5_feats = self.molt5_proj(molt5_query.last_hidden_state)

        # F.normalize already guards against zero norms with its own eps, so no
        # further rescaling is applied (dividing the features by themselves
        # would collapse every component to ~1 and erase all information).
        cell_feats = F.normalize(cell_feats, p=2, dim=-1)
        molt5_feats = F.normalize(molt5_feats, p=2, dim=-1)
        return cell_feats, molt5_feats

    def forward(self, cell_embeddings, smiles_texts):
        """Compute the contrastive loss for a batch of matched (cell, SMILES) pairs.

        Returns:
            ``BlipOutput`` with ``loss`` and ``similarity`` (keys ``"cell2molt5"``
            and ``"molt52cell"``).
        """
        cell_feats, molt5_feats = self._query_features(cell_embeddings, smiles_texts)

        # No cross-process gathering is done, so the "all" features are the local ones.
        sim_c2m, sim_m2c, loss_contrastive = self.contrast_global(
            cell_feats, molt5_feats, cell_feats, molt5_feats, return_sim=True
        )

        return BlipOutput(
            loss=loss_contrastive,
            similarity={"cell2molt5": sim_c2m, "molt52cell": sim_m2c}
        )

    def contrast_global(self, features_cell, features_molt5, features_cell_all, features_molt5_all, return_sim=False):
        """Symmetric cross-entropy contrastive loss over flattened query features.

        Matching pairs are expected on the diagonal (row ``i`` matches column ``i``).

        Returns:
            ``(sim_c2m, sim_m2c, loss)`` when ``return_sim`` is true, else ``loss``.
            If the cross-entropy computation raises, the loss is NaN and both
            similarity entries are ``None``.
        """
        batch_size = features_cell.size(0)

        if torch.isnan(features_cell).any():
            logging.warning("NaN detected in features_cell")
        if torch.isnan(features_molt5).any():
            logging.warning("NaN detected in features_molt5")

        # Flatten (query, dim) so each sample is compared as a single vector.
        sim_c2m = torch.matmul(
            features_cell.reshape(batch_size, -1),
            features_molt5_all.reshape(batch_size, -1).transpose(0, 1),
        )
        sim_m2c = torch.matmul(
            features_molt5.reshape(batch_size, -1),
            features_cell_all.reshape(batch_size, -1).transpose(0, 1),
        )

        if torch.isnan(sim_c2m).any():
            logging.warning("NaN detected in sim_c2m")
        if torch.isnan(sim_m2c).any():
            logging.warning("NaN detected in sim_m2c")

        temp = max(self.temperature, 1e-8)  # Prevent division by zero
        logits_c2m = sim_c2m / temp
        logits_m2c = sim_m2c / temp

        # Labels live on the same device as the logits they are compared with.
        labels = torch.arange(batch_size, device=features_cell.device)

        try:
            loss_c2m = F.cross_entropy(logits_c2m, labels)
            loss_m2c = F.cross_entropy(logits_m2c, labels)
            loss = (loss_c2m + loss_m2c) / 2
        except Exception as e:
            print(f"Error in cross entropy: {e}")
            print(f"logits_c2m shape: {logits_c2m.shape}, min: {logits_c2m.min().item()}, max: {logits_c2m.max().item()}")
            print(f"logits_m2c shape: {logits_m2c.shape}, min: {logits_m2c.min().item()}, max: {logits_m2c.max().item()}")
            print(f"labels shape: {labels.shape}, values: {labels}")
            nan_loss = torch.tensor(float('nan'), device=features_cell.device)
            # Keep the return arity consistent with the success path.
            return (None, None, nan_loss) if return_sim else nan_loss

        if return_sim:
            return sim_c2m, sim_m2c, loss
        else:
            return loss

    def predict_similarity(self, cell_embedding, smiles_text):
        """Score cell embeddings against SMILES strings without tracking gradients.

        A single cell embedding is broadcast against every SMILES string when
        several are given; otherwise embeddings and SMILES are paired row by row.

        Returns:
            Nested list ``[pair][query_token]`` of cosine similarities.
        """
        if cell_embedding.dim() == 1:
            cell_embedding = cell_embedding.unsqueeze(0)

        if isinstance(smiles_text, str):
            smiles_text = [smiles_text]

        if cell_embedding.size(0) == 1 and len(smiles_text) > 1:
            cell_embedding = cell_embedding.expand(len(smiles_text), *cell_embedding.shape[1:])

        with torch.no_grad():
            cell_feats, molt5_feats = self._query_features(cell_embedding, smiles_text)
            similarity = torch.sum(cell_feats * molt5_feats, dim=-1)

        return similarity.cpu().tolist()

class CellMolDataset(torch.utils.data.Dataset):
    """Dataset over pair objects exposing ``cell_embedding`` and ``smiles``.

    Args:
        pairs: Indexable collection of pair objects.
        fallback_dim: Length of the zero vector substituted when a pair's
            ``cell_embedding`` turns out to be a string.  Set it to the model's
            cell embedding width so fallback samples can be stacked with real ones.
    """

    def __init__(self, pairs, fallback_dim=512):
        self.pairs = pairs
        self.fallback_dim = fallback_dim

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        """Return ``{"cell_embedding": ..., "smiles_text": ...}`` for sample ``idx``."""
        pair = self.pairs[idx]
        cell_embedding = pair.cell_embedding
        if isinstance(cell_embedding, str):
            # A string here means the embedding was never materialised; keep the
            # sample but replace the embedding with zeros.
            print(f"Warning: Found string in cell_embedding at index {idx}")
            cell_embedding = torch.zeros(self.fallback_dim, dtype=torch.float32)
        return {
            "cell_embedding": cell_embedding,
            "smiles_text": pair.smiles
        }


def custom_collate(batch):
    """Collate samples into a batch dict, dropping samples with unusable fields.

    A sample is kept only if its ``cell_embedding`` is a tensor and its
    ``smiles_text`` is a string.  When nothing survives, a one-sample placeholder
    batch (zero embedding of width 512, SMILES ``"C"``) is returned instead.
    """
    valid_batch = []
    for item in batch:
        usable = torch.is_tensor(item["cell_embedding"]) and isinstance(item["smiles_text"], str)
        if not usable:
            print(f"Skipping invalid item: {type(item['cell_embedding'])}, {type(item['smiles_text'])}")
            continue
        valid_batch.append(item)

    if len(valid_batch) == 0:
        # Placeholder batch so downstream code always receives the same structure.
        placeholder_embedding = torch.zeros((1, 512), dtype=torch.float32)
        return {"cell_embedding": placeholder_embedding, "smiles_text": ["C"]}

    embeddings, texts = zip(
        *((item["cell_embedding"], item["smiles_text"]) for item in valid_batch)
    )
    return {"cell_embedding": torch.stack(embeddings), "smiles_text": list(texts)}

def train_blip2_molt5(
    model,
    train_dataset,
    val_dataset=None,
    batch_size=32,
    num_epochs=10,
    learning_rate=5e-5,
    weight_decay=0.05,
    warmup_steps=5000,
    checkpoint_dir="checkpoints",
    checkpoint_interval=1,
    num_workers=4,
):
    """Train ``model`` with AdamW and a per-step cosine learning-rate schedule.

    Args:
        model: Module exposing a ``device`` attribute whose forward pass takes
            ``(cell_embeddings, smiles_texts)`` and returns an object with ``loss``.
        train_dataset: Dataset of ``{"cell_embedding", "smiles_text"}`` samples.
        val_dataset: Optional validation dataset; skipped when falsy.
        batch_size: Batch size for both loaders (incomplete training batches are dropped).
        num_epochs: Number of passes over the training data.
        learning_rate: Initial AdamW learning rate.
        weight_decay: AdamW weight decay.
        warmup_steps: Accepted for interface compatibility; currently unused.
        checkpoint_dir: Directory for ``checkpoint_epoch_N.pth`` and ``best_model.pth``.
        checkpoint_interval: Save a periodic checkpoint every this many epochs.
        num_workers: Worker processes per DataLoader.

    Returns:
        The trained ``model`` (same object that was passed in).
    """
    os.makedirs(checkpoint_dir, exist_ok=True)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
        collate_fn=custom_collate,
        num_workers=num_workers
    )

    val_loader = None
    if val_dataset:
        val_loader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=batch_size,
            shuffle=False,
            collate_fn=custom_collate,
            num_workers=num_workers
        )

    device = model.device
    # Batches are moved to `model.device` below; move the parameters there too
    # so inputs and weights share a device (no-op if the model is already there).
    model.to(device)

    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(
        trainable_params,
        lr=learning_rate,
        weight_decay=weight_decay
    )

    # T_max counts optimizer steps, so the scheduler is stepped once per batch.
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=max(1, num_epochs * len(train_loader))
    )

    best_val_loss = float('inf')

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0
        num_batches = 0

        for batch in train_loader:
            try:
                cell_embeddings = batch["cell_embedding"].to(device)
                smiles_texts = batch["smiles_text"]

                # Keep embeddings and SMILES aligned while dropping non-string SMILES.
                valid_indices = [i for i, s in enumerate(smiles_texts) if isinstance(s, str)]
                if len(valid_indices) < len(smiles_texts):
                    cell_embeddings = cell_embeddings[valid_indices]
                    smiles_texts = [smiles_texts[i] for i in valid_indices]

                if len(smiles_texts) == 0:
                    print("WARNING: No valid SMILES in batch, skipping")
                    continue

                logging.debug(
                    "Cell embeddings shape: %s, SMILES texts length: %d",
                    tuple(cell_embeddings.shape), len(smiles_texts),
                )

                if torch.isnan(cell_embeddings).any():
                    print("WARNING: NaN values in cell embeddings")
                    cell_embeddings = torch.nan_to_num(cell_embeddings, nan=0.0)

                try:
                    output = model(cell_embeddings, smiles_texts)
                except Exception as e:
                    print(f"Forward pass error: {e}")
                    continue

                loss = getattr(output, "loss", None)
                if loss is None or torch.isnan(loss):
                    print("WARNING: NaN loss detected, skipping backward pass")
                    continue

                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(trainable_params, max_norm=1.0)
                optimizer.step()
                lr_scheduler.step()

                train_loss += loss.item()
                num_batches += 1  # only batches that actually updated the model
            except Exception as e:
                print(f"Unexpected error in training loop: {e}")
                continue

        if num_batches == 0:
            print("WARNING: no training batch completed in this epoch")
            avg_train_loss = float('nan')
        else:
            avg_train_loss = train_loss / num_batches

        avg_val_loss = None
        if val_loader is not None:
            model.eval()
            val_loss = 0.0
            val_num_batches = 0

            with torch.no_grad():
                for batch in val_loader:
                    cell_embeddings = batch["cell_embedding"].to(device)
                    smiles_texts = batch["smiles_text"]

                    loss = model(cell_embeddings, smiles_texts).loss
                    if loss is None or torch.isnan(loss):
                        print("WARNING: NaN validation loss, skipping batch")
                        continue

                    val_loss += loss.item()
                    val_num_batches += 1

            if val_num_batches > 0:
                avg_val_loss = val_loss / val_num_batches

                if avg_val_loss < best_val_loss:
                    best_val_loss = avg_val_loss
                    torch.save(
                        {
                            "model": model.state_dict(),
                            "optimizer": optimizer.state_dict(),
                            "epoch": epoch,
                            "val_loss": best_val_loss
                        },
                        os.path.join(checkpoint_dir, "best_model.pth")
                    )

        print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {avg_train_loss:.4f}")
        if avg_val_loss is not None:
            print(f"Validation Loss: {avg_val_loss:.4f}")

        if (epoch + 1) % checkpoint_interval == 0:
            torch.save(
                {
                    "model": model.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "epoch": epoch
                },
                os.path.join(checkpoint_dir, f"checkpoint_epoch_{epoch+1}.pth")
            )

    return model

def prepare_data_from_scgpt_embeddings(adata_dir, smiles_path):
    """Build one (cell embedding, SMILES) pair per ``.h5ad`` file in ``adata_dir``.

    The drug name is taken from the part of the file name preceding ``_block``;
    files whose drug has no row in the SMILES CSV are skipped.  The cell
    embedding is the mean of ``adata.obsm["X_scGPT"]`` over axis 0, stored as
    float32.

    Args:
        adata_dir: Directory containing ``.h5ad`` files.
        smiles_path: CSV file with ``drug`` and ``SMILES`` columns.

    Returns:
        List of ``CellSMILESPair`` objects whose SMILES is a string.
    """
    # `re` and `CellSMILESPair` are not imported at module level in this file,
    # so bring them into scope here alongside the other local imports.
    import re

    import pandas as pd
    import scanpy as sc

    from data_processor import CellSMILESPair

    smiles_df = pd.read_csv(smiles_path)
    block_pattern = re.compile(r'([A-Za-z0-9-]+)_block')

    pairs = []
    # Sorted for a reproducible pair order across file systems.
    for file in sorted(os.listdir(adata_dir)):
        if not file.endswith('.h5ad'):
            continue
        try:
            match = block_pattern.search(file)
            if not match:
                continue
            drug_name = match.group(1)

            drug_rows = smiles_df.loc[smiles_df['drug'] == drug_name, 'SMILES']
            if drug_rows.empty:
                continue

            adata = sc.read_h5ad(os.path.join(adata_dir, file))
            # float32 matches the dtype of the model's linear layers.
            diff_embedding = torch.tensor(
                adata.obsm["X_scGPT"].mean(axis=0), dtype=torch.float32
            )

            pairs.append(CellSMILESPair(
                cell_embedding=diff_embedding,
                smiles=drug_rows.iloc[0],
                drug_name=drug_name
            ))
        except Exception as e:
            # Best effort: one unreadable file must not abort the whole scan.
            print(f"Error processing {file}: {e}")

    # Filter out invalid pairs (e.g. NaN SMILES read from the CSV).
    valid_pairs = []
    for pair in pairs:
        if torch.is_tensor(pair.cell_embedding) and isinstance(pair.smiles, str):
            valid_pairs.append(pair)
        elif not isinstance(pair.smiles, str):
            print(f"Removing pair with non-string SMILES: {pair.smiles}")
        else:
            print(f"Removing pair with invalid cell embedding type: {type(pair.cell_embedding)}")

    print(f"Kept {len(valid_pairs)} valid pairs out of {len(pairs)} total")

    return valid_pairs

def get_similarity(model, cell_embedding, smiles_list):
    """Rank SMILES strings by their similarity to one cell embedding.

    ``model.predict_similarity`` may return either one score per SMILES or one
    score per (SMILES, query token); per-token scores are averaged so every
    SMILES ends up with a single scalar before sorting.

    Args:
        model: Object exposing ``predict_similarity(cell_embedding, smiles_list)``.
        cell_embedding: Tensor holding one cell embedding (1-D, or 2-D with one row),
            or one row per SMILES.
        smiles_list: SMILES strings to rank (a single string is also accepted).

    Returns:
        ``(sorted_smiles, sorted_scores)`` in descending score order; both are
        empty lists when ``smiles_list`` is empty.
    """
    if isinstance(smiles_list, str):
        smiles_list = [smiles_list]
    smiles_list = list(smiles_list)
    if not smiles_list:
        return [], []

    # predict_similarity pairs embeddings and SMILES row by row, so repeat the
    # single embedding once per candidate.
    if cell_embedding.dim() == 1:
        cell_embedding = cell_embedding.unsqueeze(0)
    if cell_embedding.size(0) == 1 and len(smiles_list) > 1:
        cell_embedding = cell_embedding.expand(len(smiles_list), *cell_embedding.shape[1:])

    raw_scores = model.predict_similarity(cell_embedding, smiles_list)
    scores = torch.as_tensor(raw_scores, dtype=torch.float32)
    scores = scores.reshape(len(smiles_list), -1).mean(dim=1)

    order = torch.argsort(scores, descending=True).tolist()
    sorted_smiles = [smiles_list[idx] for idx in order]
    sorted_scores = [scores[idx].item() for idx in order]

    return sorted_smiles, sorted_scores

def main():
    """Command-line entry point: build the dataset, train the model, save final weights.

    ``--freeze_molt5`` and ``--tune_qformer`` default to true, matching the
    defaults of ``Blip2QformerMolT5``; use ``--no-freeze_molt5`` /
    ``--no-tune_qformer`` to switch them off.  Passing the positive flags
    explicitly keeps working as before.
    """
    import argparse

    parser = argparse.ArgumentParser(description='Train BLIP2-MolT5-QFormer model')
    parser.add_argument('--adata_dir', type=str, required=True, help='Directory with anndata files')
    parser.add_argument('--smiles_path', type=str, required=True, help='Path to SMILES CSV file')
    parser.add_argument('--batch_size', type=int, default=32, help='Batch size')
    parser.add_argument('--num_epochs', type=int, default=10, help='Number of epochs')
    parser.add_argument('--learning_rate', type=float, default=5e-5, help='Learning rate')
    parser.add_argument('--weight_decay', type=float, default=0.05, help='Weight decay')
    parser.add_argument('--checkpoint_dir', type=str, default='checkpoints', help='Checkpoint directory')
    parser.add_argument('--embed_dim', type=int, default=512, help='Cell embedding dimension')
    parser.add_argument('--num_query_token', type=int, default=32, help='Number of query tokens')
    parser.add_argument('--temperature', type=float, default=0.05, help='Temperature for contrastive loss')
    # A plain store_true flag would default to False and silently invert the
    # model's own defaults (frozen MolT5, trainable Q-Former).
    parser.add_argument('--freeze_molt5', action=argparse.BooleanOptionalAction, default=True,
                        help='Freeze MolT5 parameters')
    parser.add_argument('--tune_qformer', action=argparse.BooleanOptionalAction, default=True,
                        help='Tune QFormer parameters')

    args = parser.parse_args()

    valid_pairs = prepare_data_from_scgpt_embeddings(args.adata_dir, args.smiles_path)
    dataset = CellMolDataset(valid_pairs)

    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    if train_size == 0:
        # A shuffling DataLoader cannot be built over an empty training split.
        raise SystemExit(
            f"Not enough valid cell-SMILES pairs to train on (found {len(dataset)})"
        )
    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

    model = Blip2QformerMolT5(
        temperature=args.temperature,
        freeze_molt5=args.freeze_molt5,
        tune_qformer=args.tune_qformer,
        num_query_token=args.num_query_token,
        embed_dim=args.embed_dim
    )

    model = train_blip2_molt5(
        model=model,
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        batch_size=args.batch_size,
        num_epochs=args.num_epochs,
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        checkpoint_dir=args.checkpoint_dir
    )

    # Final weights are stored as a bare state dict (no optimizer state).
    os.makedirs(args.checkpoint_dir, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(args.checkpoint_dir, "final_model.pth"))

if __name__ == "__main__":
    main()

Overwriting blip2_molt5_qformer.py


## data_processor.py

In [24]:
%%writefile data_processor.py
import warnings; warnings.filterwarnings("ignore")

import os
import re
import torch
import numpy as np
import pandas as pd
import scanpy as sc
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split

class CellSMILESPair:
    """One example: a cell embedding, the drug's SMILES and, optionally, the drug name."""

    def __init__(self, cell_embedding, smiles, drug_name=None):
        self.cell_embedding, self.smiles, self.drug_name = cell_embedding, smiles, drug_name

class CellMolDataset(Dataset):
    """Dataset view over a list of pair objects (``cell_embedding`` / ``smiles``)."""

    def __init__(self, pairs):
        self.pairs = pairs

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        """Return ``{"cell_embedding", "smiles_text"}`` for the pair at ``idx``."""
        pair = self.pairs[idx]
        # A non-string SMILES (e.g. NaN) is swapped for "C", the simplest valid SMILES.
        smiles_text = pair.smiles if isinstance(pair.smiles, str) else "C"
        return {"cell_embedding": pair.cell_embedding, "smiles_text": smiles_text}

def load_and_process_data(adata_dir, smiles_path, embedding_key="X_scGPT"):
    """Build one averaged (cell embedding, SMILES) pair per drug.

    Every ``.h5ad`` file in ``adata_dir`` whose name contains ``<drug>_block``
    and whose drug has a row in the SMILES CSV contributes the mean of
    ``adata.obsm[embedding_key]`` over axis 0.  Contributions of the same drug
    are averaged again, and the drug's first SMILES entry is attached.

    Args:
        adata_dir: Directory containing ``.h5ad`` files.
        smiles_path: CSV file with ``drug`` and ``SMILES`` columns.
        embedding_key: Key of the embedding matrix inside ``adata.obsm``.

    Returns:
        List of ``CellSMILESPair`` objects, one per processed drug.
    """
    smiles_df = pd.read_csv(smiles_path)
    block_pattern = re.compile(r'([A-Za-z0-9-]+)_block')

    # drug name -> list of per-file mean embeddings
    drug_embeddings = {}

    for file in os.listdir(adata_dir):
        if not file.endswith('.h5ad'):
            continue

        try:
            hit = block_pattern.search(file)
            if hit is None:
                continue
            drug_name = hit.group(1)

            # Skip drugs without any SMILES row before touching the file.
            if not (smiles_df['drug'] == drug_name).any():
                continue

            adata = sc.read_h5ad(os.path.join(adata_dir, file))
            if embedding_key not in adata.obsm:
                print(f"Warning: {embedding_key} not found in {file}")
                continue

            file_mean = torch.tensor(
                adata.obsm[embedding_key].mean(axis=0), dtype=torch.float32
            )
            drug_embeddings.setdefault(drug_name, []).append(file_mean)

            print(f"Processed {drug_name} from {file}")

        except Exception as e:
            print(f"Error processing {file}: {e}")

    pairs = []
    processed_drugs = set()
    for drug_name, embeddings in drug_embeddings.items():
        if not embeddings:
            continue

        first_smiles = smiles_df.loc[smiles_df['drug'] == drug_name, 'SMILES'].iloc[0]
        pairs.append(CellSMILESPair(
            cell_embedding=torch.stack(embeddings).mean(dim=0),
            smiles=first_smiles,
            drug_name=drug_name
        ))
        processed_drugs.add(drug_name)

    print(f"Created a total of {len(pairs)} cell-SMILES pairs for {len(processed_drugs)} drugs")
    return pairs

def create_datasets(pairs, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15, seed=42):
    """Split pairs into train/val/test datasets so that no drug spans two splits.

    Drug names are shuffled with NumPy's global RNG (seeded with ``seed``) and
    cut into three consecutive groups; the test split receives whatever is left
    after the train and validation cuts.

    Returns:
        ``(train_dataset, val_dataset, test_dataset)`` as ``CellMolDataset`` objects.
    """
    assert abs(train_ratio + val_ratio + test_ratio - 1.0) < 1e-6

    # Group pairs by drug, preserving first-seen order of the drug names.
    by_drug = {}
    for pair in pairs:
        by_drug.setdefault(pair.drug_name, []).append(pair)

    drug_names = list(by_drug)
    np.random.seed(seed)
    np.random.shuffle(drug_names)

    train_end = int(len(drug_names) * train_ratio)
    val_end = train_end + int(len(drug_names) * val_ratio)
    train_drugs = drug_names[:train_end]
    val_drugs = drug_names[train_end:val_end]
    test_drugs = drug_names[val_end:]

    def _collect(names):
        return [pair for name in names for pair in by_drug[name]]

    train_pairs = _collect(train_drugs)
    val_pairs = _collect(val_drugs)
    test_pairs = _collect(test_drugs)

    print(f"Split data: {len(train_pairs)} training, {len(val_pairs)} validation, {len(test_pairs)} test")
    print(f"Using drugs: {len(train_drugs)} training, {len(val_drugs)} validation, {len(test_drugs)} test")

    return CellMolDataset(train_pairs), CellMolDataset(val_pairs), CellMolDataset(test_pairs)

def create_dataloaders(train_dataset, val_dataset, test_dataset, batch_size=32, num_workers=4):
    """Wrap the three datasets in DataLoaders; only the training loader shuffles.

    Returns:
        ``(train_loader, val_loader, test_loader)``, all built with the same
        batch size, worker count and ``pin_memory=True``.
    """
    shared_options = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=True)

    train_loader = DataLoader(train_dataset, shuffle=True, **shared_options)
    val_loader, test_loader = (
        DataLoader(dataset, shuffle=False, **shared_options)
        for dataset in (val_dataset, test_dataset)
    )
    return train_loader, val_loader, test_loader

def create_negative_pairs(pairs, neg_ratio=1.0, seed=42):
    """Return all positive pairs plus randomly mismatched negative pairs as dicts.

    A negative combines the cell embedding of one pair with the SMILES of
    another.  Candidates that pair a drug with itself (same non-``None`` drug
    name, or identical SMILES) are rejected and redrawn, because they would be
    mislabelled positives.  With fewer than two pairs no negative can be formed
    and only the positives are returned.

    Args:
        pairs: Objects exposing ``cell_embedding``, ``smiles`` and ``drug_name``.
        neg_ratio: Number of negatives requested per positive.
        seed: Seed for NumPy's global RNG.

    Returns:
        List of dicts with keys ``cell_embedding``, ``smiles``, ``drug_name``
        and ``label`` (1.0 for positives, 0.0 for negatives).  Fewer negatives
        than requested may be returned if valid mismatches are too rare.
    """
    np.random.seed(seed)

    all_pairs = [
        {
            "cell_embedding": pair.cell_embedding,
            "smiles": pair.smiles,
            "drug_name": pair.drug_name,
            "label": 1.0  # Positive pair
        }
        for pair in pairs
    ]

    n_negatives = int(len(pairs) * neg_ratio)
    if n_negatives <= 0 or len(pairs) < 2:
        # Drawing two distinct indices needs at least two pairs.
        return all_pairs

    indices = np.arange(len(pairs))
    n_added = 0
    # Bounded retries so a data set made of a single drug cannot loop forever.
    attempts_left = 20 * n_negatives

    while n_added < n_negatives and attempts_left > 0:
        attempts_left -= 1
        idx1, idx2 = np.random.choice(indices, 2, replace=False)
        cell_pair, mol_pair = pairs[idx1], pairs[idx2]

        same_drug = cell_pair.drug_name is not None and cell_pair.drug_name == mol_pair.drug_name
        same_smiles = cell_pair.smiles == mol_pair.smiles
        if same_drug or same_smiles:
            continue

        all_pairs.append({
            "cell_embedding": cell_pair.cell_embedding,
            "smiles": mol_pair.smiles,
            "drug_name": f"{cell_pair.drug_name}-{mol_pair.drug_name}",
            "label": 0.0  # Negative pair
        })
        n_added += 1

    if n_added < n_negatives:
        print(f"Warning: generated only {n_added} of {n_negatives} requested negative pairs")

    return all_pairs

def augment_smiles(smiles, augmentation_factor=2):
    """Return the input SMILES plus randomly re-ordered, non-canonical variants.

    Args:
        smiles: SMILES string to augment. Non-string values (for example a NaN
            read from a CSV) are returned unchanged.
        augmentation_factor: Total number of strings to return for a parseable
            molecule (the original plus ``augmentation_factor - 1`` variants).

    Returns:
        list: ``[smiles]`` when the value is not a string or RDKit cannot
        parse it; otherwise the original string followed by the variants.
        Variants are produced with NumPy's global random generator, so they
        depend on its current state.
    """
    # Missing SMILES cannot be handed to RDKit; pass them through untouched,
    # mirroring the fallback used for unparseable strings below.
    if not isinstance(smiles, str):
        return [smiles]

    from rdkit import Chem

    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return [smiles]

    augmented_smiles = [smiles]

    for _ in range(augmentation_factor - 1):
        # Shuffle the atom numbering and write the molecule out without
        # canonicalisation so the new atom order shows up in the string.
        atom_order = list(range(mol.GetNumAtoms()))
        np.random.shuffle(atom_order)
        mol_perm = Chem.RenumberAtoms(mol, atom_order)
        smiles_perm = Chem.MolToSmiles(mol_perm, isomericSmiles=True, canonical=False)
        augmented_smiles.append(smiles_perm)

    return augmented_smiles

def augment_data(pairs, augmentation_factor=2):
    """Expand every pair into one pair per SMILES variant.

    Each input pair keeps its cell embedding and drug name; only the SMILES
    string changes, using the variants returned by ``augment_smiles``.

    Args:
        pairs: Iterable of objects with ``cell_embedding``, ``smiles`` and
            ``drug_name`` attributes.
        augmentation_factor: Forwarded to ``augment_smiles``.

    Returns:
        list: New ``CellSMILESPair`` objects, grouped by source pair in input
        order.
    """
    return [
        CellSMILESPair(
            cell_embedding=source.cell_embedding,
            smiles=variant,
            drug_name=source.drug_name
        )
        for source in pairs
        for variant in augment_smiles(source.smiles, augmentation_factor)
    ]

def save_processed_data(pairs, output_path):
    """Pickle ``pairs`` to ``output_path``, creating the parent directory if needed.

    Args:
        pairs: Sized, picklable collection of pairs.
        output_path: Destination file. May be a bare filename, in which case
            the file is written to the current working directory.
    """
    import pickle

    # os.path.dirname returns '' for a bare filename and os.makedirs('')
    # raises FileNotFoundError, so only create a directory when one is named.
    parent_dir = os.path.dirname(output_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    with open(output_path, 'wb') as f:
        pickle.dump(pairs, f)
    print(f"Saved {len(pairs)} pairs to {output_path}")

def load_processed_data(input_path):
    """Read a pickled collection of pairs from ``input_path`` and return it.

    Prints how many pairs were loaded. Unpickling can execute arbitrary code,
    so only point this at files produced by a trusted run.
    """
    import pickle

    with open(input_path, 'rb') as handle:
        loaded_pairs = pickle.load(handle)

    pair_count = len(loaded_pairs)
    print(f"Loaded {pair_count} pairs from {input_path}")
    return loaded_pairs

def process_and_save_data(adata_dir, smiles_path, output_dir, augmentation_factor=1):
    """Load raw data, optionally augment it, split it and pickle each split.

    Args:
        adata_dir: Forwarded to ``load_and_process_data``.
        smiles_path: Forwarded to ``load_and_process_data``.
        output_dir: Directory that receives ``train_pairs.pkl``,
            ``val_pairs.pkl`` and ``test_pairs.pkl`` (created if missing).
        augmentation_factor: Values greater than 1 run ``augment_data`` on the
            loaded pairs before splitting.

    Returns:
        tuple: ``(train_dataset, val_dataset, test_dataset)`` as produced by
        ``create_datasets``.
    """
    pairs = load_and_process_data(adata_dir, smiles_path)
    if augmentation_factor > 1:
        pairs = augment_data(pairs, augmentation_factor)

    os.makedirs(output_dir, exist_ok=True)

    train_dataset, val_dataset, test_dataset = create_datasets(pairs)

    # Persist the splits in train / val / test order.
    outputs = (
        (train_dataset, "train_pairs.pkl"),
        (val_dataset, "val_pairs.pkl"),
        (test_dataset, "test_pairs.pkl"),
    )
    for split_dataset, file_name in outputs:
        save_processed_data(split_dataset.pairs, os.path.join(output_dir, file_name))

    return train_dataset, val_dataset, test_dataset

Overwriting data_processor.py


## training_script.py

In [25]:
%%writefile training_script.py
import warnings; warnings.filterwarnings("ignore")

import os
import torch
import argparse
import numpy as np
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from data_processor import (
    load_processed_data,
    process_and_save_data,
    CellMolDataset,
    create_dataloaders
)
from blip2_molt5_qformer import (
    Blip2QformerMolT5,
    train_blip2_molt5,
    get_similarity
)

def parse_args():
    """Parse the command-line options of the training script.

    Options cover the processed-data location, model architecture, which
    sub-modules to freeze or tune, optimisation hyper-parameters,
    checkpointing, evaluation-only mode, the random seed and the number of
    dataloader workers.

    Returns:
        argparse.Namespace: Parsed options with the defaults listed below.
    """
    parser = argparse.ArgumentParser(description="Train BLIP2-MolT5-QFormer model")

    # (flag, add_argument keyword arguments), registered in this order.
    option_table = [
        # Data
        ("--processed_data_dir", dict(type=str, default="processed_data",
                                      help="Directory containing the processed data pickle files")),
        # Model architecture
        ("--bert_name", dict(type=str, default="bert-base-uncased",
                             help="BERT model name for QFormer")),
        ("--molt5_model", dict(type=str, default="laituan245/molt5-base",
                               help="MolT5 model name")),
        ("--embedding_dim", dict(type=int, default=512,
                                 help="Cell embedding dimension")),
        ("--num_query_token", dict(type=int, default=32,
                                   help="Number of query tokens for QFormer")),
        ("--cross_attention_freq", dict(type=int, default=2,
                                        help="Cross attention frequency")),
        ("--temperature", dict(type=float, default=0.05,
                               help="Temperature for contrastive loss")),
        # Freezing / tuning switches
        ("--freeze_molt5", dict(action="store_true",
                                help="Freeze MolT5 parameters")),
        ("--tune_qformer", dict(action="store_true",
                                help="Tune QFormer parameters")),
        # Optimisation
        ("--batch_size", dict(type=int, default=8,
                              help="Batch size")),
        ("--num_epochs", dict(type=int, default=20,
                              help="Number of training epochs")),
        ("--learning_rate", dict(type=float, default=1e-5,
                                 help="Learning rate")),
        ("--weight_decay", dict(type=float, default=0.02,
                                help="Weight decay")),
        ("--warmup_steps", dict(type=int, default=100,
                                help="Warmup steps")),
        # Checkpointing
        ("--checkpoint_dir", dict(type=str, default="checkpoints",
                                  help="Directory for checkpoints")),
        ("--checkpoint_interval", dict(type=int, default=1,
                                       help="Checkpoint interval in epochs")),
        # Evaluation-only mode
        ("--eval_only", dict(action="store_true",
                             help="Run evaluation only")),
        ("--eval_checkpoint", dict(type=str, default=None,
                                   help="Checkpoint to use for evaluation")),
        # Reproducibility and data loading
        ("--seed", dict(type=int, default=42,
                        help="Random seed")),
        ("--num_workers", dict(type=int, default=4,
                               help="Number of dataloader workers")),
    ]

    for flag, spec in option_table:
        parser.add_argument(flag, **spec)

    return parser.parse_args()

def set_seed(seed):
    """Seed NumPy and PyTorch (CPU and, when available, every CUDA device).

    Also switches cuDNN to deterministic mode and disables its benchmark
    auto-tuner.
    """
    # Random number generators.
    np.random.seed(seed)
    torch.manual_seed(seed)
    cuda_present = torch.cuda.is_available()
    if cuda_present:
        torch.cuda.manual_seed_all(seed)

    # cuDNN behaviour flags.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

def create_datasets_from_processed_data(processed_data_dir):
    """Load the three pickled splits from ``processed_data_dir`` as datasets.

    Reads ``train_pairs.pkl``, ``val_pairs.pkl`` and ``test_pairs.pkl`` (in
    that order) and wraps each loaded list in a ``CellMolDataset``.

    Returns:
        tuple: ``(train_dataset, val_dataset, test_dataset)``.
    """
    split_files = ("train_pairs.pkl", "val_pairs.pkl", "test_pairs.pkl")

    # Load every split before building any dataset, as the original does.
    loaded_splits = [
        load_processed_data(os.path.join(processed_data_dir, file_name))
        for file_name in split_files
    ]

    train_dataset, val_dataset, test_dataset = [
        CellMolDataset(split_pairs) for split_pairs in loaded_splits
    ]
    return train_dataset, val_dataset, test_dataset

def plot_training_curves(train_losses, val_losses, checkpoint_dir):
    """Plot per-epoch training and validation loss and save the figure.

    The image is written to ``training_curves.png`` inside ``checkpoint_dir``
    and the figure is closed afterwards.

    Args:
        train_losses: Sequence of training losses, one per epoch.
        val_losses: Sequence of validation losses, one per epoch.
        checkpoint_dir: Existing directory that receives the PNG.
    """
    fig, ax = plt.subplots(figsize=(10, 6))

    ax.plot(train_losses, label='Training Loss')
    ax.plot(val_losses, label='Validation Loss')

    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.set_title('Training and Validation Loss')
    ax.legend()
    ax.grid(True)

    target_path = os.path.join(checkpoint_dir, 'training_curves.png')
    fig.savefig(target_path)
    plt.close(fig)

def evaluate_model(model, test_loader, device):
    """Return the mean loss of ``model`` over the batches of ``test_loader``.

    Each batch is a mapping with ``"cell_embedding"`` (a tensor) and
    ``"smiles_text"`` (a sequence). Entries whose SMILES is not a ``str`` are
    dropped, and batches left empty are skipped. If a batch's embeddings
    contain NaN they are passed through ``torch.nan_to_num(..., nan=0.0)``.
    Batches whose loss is NaN do not contribute to the mean.

    Args:
        model: Callable as ``model(cell_embeddings, smiles_texts)`` returning
            an object with a ``loss`` tensor; put into eval mode here.
        test_loader: Iterable of batches as described above.
        device: Device the embeddings are moved to.

    Returns:
        float: Average loss over the counted batches, or ``inf`` when no
        batch was counted.
    """
    model.eval()

    loss_sum = 0.0
    counted_batches = 0

    with torch.no_grad():
        for batch in test_loader:
            embeddings = batch["cell_embedding"].to(device)
            texts = batch["smiles_text"]

            # Keep only rows with a usable SMILES string.
            keep = [pos for pos, text in enumerate(texts) if isinstance(text, str)]
            if not keep:
                continue

            embeddings = embeddings[keep]
            texts = [texts[pos] for pos in keep]

            # Only sanitise when a NaN is actually present.
            if torch.isnan(embeddings).any():
                embeddings = torch.nan_to_num(embeddings, nan=0.0)

            batch_loss = model(embeddings, texts).loss
            if torch.isnan(batch_loss):
                continue

            loss_sum += batch_loss.item()
            counted_batches += 1

    return loss_sum / counted_batches if counted_batches else float('inf')

def calculate_retrieval_metrics(model, test_dataset, k_values=(1, 5, 10), device=None):
    """Compute Recall@k and MRR for cell-embedding -> SMILES retrieval.

    Every pair with a string SMILES acts once as a query: its cell embedding
    is scored against all test SMILES via ``model.predict_similarity``. A
    retrieved item counts as relevant when it belongs to the same drug as the
    query and is not the query itself; the best rank among those relevant
    items determines the hit for each ``k`` and the reciprocal rank.

    Args:
        model: Object exposing ``eval()``, ``predict_similarity(embedding,
            smiles_batch)`` and, when ``device`` is None, ``parameters()``.
        test_dataset: Object whose ``pairs`` attribute is an iterable of items
            with ``cell_embedding`` (tensor), ``smiles`` and optionally
            ``drug_name``.
        k_values: Cut-offs for Recall@k.
        device: Device for the query embedding; defaults to the model's.

    Returns:
        dict: ``{"R@k": float, ..., "MRR": float}``. Values are averaged over
        the queries that have at least one other sample of the same drug;
        queries without any relevant item cannot be scored and are excluded.
        All values are 0.0 when nothing could be evaluated.
    """
    if device is None:
        device = next(model.parameters()).device

    model.eval()

    cell_embeddings = []
    smiles_texts = []
    drug_names = []

    for pair in test_dataset.pairs:
        if not isinstance(pair.smiles, str):
            continue
        cell_embeddings.append(pair.cell_embedding)
        smiles_texts.append(pair.smiles)
        drug_name = getattr(pair, 'drug_name', None)
        drug_names.append(drug_name if drug_name is not None else f"drug_{len(drug_names)}")

    # Build the result skeleton up front so every return path has the same keys.
    metrics = {f"R@{k}": 0.0 for k in k_values}
    metrics["MRR"] = 0.0

    n_samples = len(cell_embeddings)
    if n_samples == 0:
        print("No valid test samples found for retrieval metrics")
        return metrics

    batch_size = 32
    n_evaluated = 0  # queries that had at least one relevant target

    with torch.no_grad():
        for i in range(n_samples):
            query_drug = drug_names[i]

            # Find the relevant items first so that unscoreable queries do
            # not trigger a full pass over the SMILES library.
            target_positions = [
                idx for idx in range(n_samples)
                if idx != i and drug_names[idx] == query_drug
            ]
            if not target_positions:
                continue

            query_embedding = cell_embeddings[i].to(device)

            all_scores = []
            for j in range(0, n_samples, batch_size):
                batch_smiles = smiles_texts[j:j+batch_size]
                batch_scores = model.predict_similarity(query_embedding, batch_smiles)
                all_scores.extend(batch_scores)

            all_scores = torch.tensor(all_scores)
            sorted_indices = torch.argsort(all_scores, descending=True)

            # Invert the permutation: rank_of[idx] is the 0-based position of
            # sample idx in the sorted list.
            rank_of = torch.empty_like(sorted_indices)
            rank_of[sorted_indices] = torch.arange(n_samples, device=sorted_indices.device)

            min_rank = int(rank_of[target_positions].min().item()) + 1  # ranks start from 1

            metrics["MRR"] += 1.0 / min_rank
            for k in k_values:
                if min_rank <= k:
                    metrics[f"R@{k}"] += 1.0

            n_evaluated += 1

    if n_evaluated == 0:
        return metrics

    # Average over the queries that were actually scored; dividing by
    # n_samples would count every skipped query as a miss.
    for metric in metrics:
        metrics[metric] /= n_evaluated

    return metrics

def _build_model(args, device):
    """Instantiate ``Blip2QformerMolT5`` from the parsed CLI options on ``device``."""
    return Blip2QformerMolT5(
        bert_name=args.bert_name,
        temperature=args.temperature,
        freeze_molt5=args.freeze_molt5,
        molt5_model_name=args.molt5_model,
        tune_qformer=args.tune_qformer,
        num_query_token=args.num_query_token,
        cross_attention_freq=args.cross_attention_freq,
        embed_dim=args.embedding_dim
    ).to(device)


def _save_checkpoint(path, model, optimizer, epoch, train_loss, val_loss):
    """Write model/optimizer state plus the epoch index and its losses to ``path``."""
    torch.save({
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "epoch": epoch,
        "train_loss": train_loss,
        "val_loss": val_loss
    }, path)


def _report_test_metrics(model, test_loader, test_dataset, device):
    """Print the test loss followed by the retrieval metrics of ``model``."""
    print("Evaluating on test set...")
    test_loss = evaluate_model(model, test_loader, device)
    print(f"Test Loss: {test_loss:.6f}")

    print("Calculating retrieval metrics...")
    retrieval_metrics = calculate_retrieval_metrics(model, test_dataset, device=device)
    for metric, value in retrieval_metrics.items():
        print(f"{metric}: {value:.4f}")


def main():
    """Train the model, or evaluate a saved checkpoint when ``--eval_only`` is set.

    Training mode: builds the model, optimises it with AdamW and a linear
    warmup / linear decay schedule, validates after every epoch, writes
    ``best_model.pth``, periodic ``checkpoint_epoch_N.pth`` files, the loss
    curves and ``final_model.pth`` into ``--checkpoint_dir``, then reports
    test loss and retrieval metrics.

    Evaluation mode: loads ``--eval_checkpoint`` (default:
    ``<checkpoint_dir>/best_model.pth``) and reports the same test metrics.

    Raises:
        ValueError: In evaluation mode when no checkpoint was given and the
            default ``best_model.pth`` does not exist.
    """
    args = parse_args()
    set_seed(args.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Load datasets directly from pickle files
    train_dataset, val_dataset, test_dataset = create_datasets_from_processed_data(args.processed_data_dir)

    train_loader, val_loader, test_loader = create_dataloaders(
        train_dataset, val_dataset, test_dataset,
        batch_size=args.batch_size,
        num_workers=args.num_workers
    )

    if not args.eval_only:
        # The progress bar below needs tqdm, which the script's top-level
        # imports do not provide; import it here next to the scheduler.
        from tqdm import tqdm
        from torch.optim.lr_scheduler import LambdaLR

        model = _build_model(args, device)

        optimizer = torch.optim.AdamW(
            [p for p in model.parameters() if p.requires_grad],
            lr=args.learning_rate,
            weight_decay=args.weight_decay
        )

        # Create scheduler: linear warmup followed by linear decay to zero.
        total_steps = len(train_loader) * args.num_epochs
        warmup_steps = min(args.warmup_steps, total_steps // 10)

        def lr_lambda(current_step):
            if current_step < warmup_steps:
                return float(current_step) / float(max(1, warmup_steps))
            return max(0.0, float(total_steps - current_step) / float(max(1, total_steps - warmup_steps)))

        scheduler = LambdaLR(optimizer, lr_lambda)

        # Make sure checkpoint directory exists
        os.makedirs(args.checkpoint_dir, exist_ok=True)

        train_losses = []
        val_losses = []
        best_val_loss = float('inf')

        # Training loop
        for epoch in range(args.num_epochs):
            print(f"Epoch {epoch+1}/{args.num_epochs}")
            model.train()
            train_loss = 0.0
            valid_batches = 0

            # Use tqdm for progress bar
            for batch_idx, batch in enumerate(tqdm(train_loader, desc=f"Training epoch {epoch+1}")):
                cell_embeddings = batch["cell_embedding"].to(device)
                smiles_texts = batch["smiles_text"]

                # Filter out invalid SMILES
                valid_indices = [i for i, s in enumerate(smiles_texts) if isinstance(s, str)]
                if len(valid_indices) < 2:  # Need at least 2 samples for contrastive loss
                    print(f"Skipping batch {batch_idx}: not enough valid SMILES")
                    continue

                cell_embeddings = cell_embeddings[valid_indices]
                smiles_texts = [smiles_texts[i] for i in valid_indices]

                # Check for NaN values
                if torch.isnan(cell_embeddings).any():
                    cell_embeddings = torch.nan_to_num(cell_embeddings, nan=0.0)

                optimizer.zero_grad()

                output = model(cell_embeddings, smiles_texts)
                loss = output.loss

                if not torch.isnan(loss):
                    loss.backward()
                    # Clip gradients
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                    optimizer.step()
                    scheduler.step()

                    train_loss += loss.item()
                    valid_batches += 1
                else:
                    print(f"Warning: NaN loss in batch {batch_idx}, skipping")

            # Calculate average training loss
            avg_train_loss = train_loss / valid_batches if valid_batches > 0 else float('inf')
            train_losses.append(avg_train_loss)
            print(f"Train Loss: {avg_train_loss:.6f}")

            # Validation
            val_loss = evaluate_model(model, val_loader, device)
            val_losses.append(val_loss)
            print(f"Validation Loss: {val_loss:.6f}")

            # Save model if it's the best so far
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                _save_checkpoint(
                    os.path.join(args.checkpoint_dir, "best_model.pth"),
                    model, optimizer, epoch, avg_train_loss, val_loss
                )
                print(f"Saved new best model with validation loss: {val_loss:.6f}")

            # Regular checkpoint saving
            if (epoch + 1) % args.checkpoint_interval == 0:
                _save_checkpoint(
                    os.path.join(args.checkpoint_dir, f"checkpoint_epoch_{epoch+1}.pth"),
                    model, optimizer, epoch, avg_train_loss, val_loss
                )

        # Plot and save learning curves
        plot_training_curves(train_losses, val_losses, args.checkpoint_dir)

        # Save final model. The loss lists are empty when num_epochs is 0, so
        # fall back to inf instead of indexing an empty list.
        _save_checkpoint(
            os.path.join(args.checkpoint_dir, "final_model.pth"),
            model, optimizer, args.num_epochs - 1,
            train_losses[-1] if train_losses else float('inf'),
            val_losses[-1] if val_losses else float('inf')
        )

        # Final evaluation and retrieval metrics
        _report_test_metrics(model, test_loader, test_dataset, device)

    else:
        # Evaluation only mode
        if args.eval_checkpoint is None:
            args.eval_checkpoint = os.path.join(args.checkpoint_dir, "best_model.pth")
            if not os.path.exists(args.eval_checkpoint):
                raise ValueError(f"No checkpoint found at {args.eval_checkpoint}")

        print(f"Loading model from {args.eval_checkpoint} for evaluation")
        model = _build_model(args, device)

        checkpoint = torch.load(args.eval_checkpoint, map_location=device)
        model.load_state_dict(checkpoint["model"])

        _report_test_metrics(model, test_loader, test_dataset, device)

if __name__ == "__main__":
    main()

Overwriting training_script.py


## inference_script.py

In [26]:
%%writefile inference_script.py

import os
import torch
import argparse
import pandas as pd
import numpy as np
import scanpy as sc
from tqdm import tqdm
from blip2_molt5_qformer import Blip2QformerMolT5

def parse_args():
    """Parse the command-line options of the inference script.

    Required options: ``--model_checkpoint``, ``--cell_embedding_file`` and
    ``--smiles_library``. The architecture options must match the values the
    checkpoint was trained with, otherwise loading the state dict fails.

    Returns:
        argparse.Namespace: Parsed options.
    """
    parser = argparse.ArgumentParser(description="Run inference with BLIP2-MolT5-QFormer model")

    # Model checkpoint and architecture
    parser.add_argument("--model_checkpoint", type=str, required=True,
                        help="Path to model checkpoint")
    parser.add_argument("--bert_name", type=str, default="bert-base-uncased",
                        help="BERT model name for QFormer")
    parser.add_argument("--molt5_model", type=str, default="laituan245/molt5-base",
                        help="MolT5 model name")
    parser.add_argument("--num_query_token", type=int, default=32,
                        help="Number of query tokens for QFormer")
    parser.add_argument("--cross_attention_freq", type=int, default=2,
                        help="Cross attention frequency")
    # Default matches the training script's --embedding_dim default (512) so a
    # checkpoint trained with default options loads with default options here.
    parser.add_argument("--embedding_dim", type=int, default=512,
                        help="Cell embedding dimension")

    # Inputs
    parser.add_argument("--cell_embedding_file", type=str, required=True,
                        help="Path to AnnData file with cell embeddings")
    parser.add_argument("--smiles_library", type=str, required=True,
                        help="Path to CSV file with SMILES library")
    parser.add_argument("--embedding_key", type=str, default="X_scGPT",
                        help="Key in AnnData.obsm for cell embeddings")

    # Outputs
    parser.add_argument("--output_dir", type=str, default="inference_results",
                        help="Directory for inference results")
    parser.add_argument("--top_k", type=int, default=10,
                        help="Number of top predictions to save")

    return parser.parse_args()

def load_cell_embedding(adata_path, embedding_key="X_scGPT"):
    """Read an AnnData file and return the mean of one ``obsm`` embedding.

    Args:
        adata_path: Path passed to ``sc.read_h5ad``.
        embedding_key: Key looked up in ``adata.obsm``.

    Returns:
        torch.Tensor: ``float32`` tensor holding the mean over axis 0 of
        ``adata.obsm[embedding_key]``.

    Raises:
        ValueError: If ``embedding_key`` is not present in ``adata.obsm``.
    """
    adata = sc.read_h5ad(adata_path)

    if embedding_key in adata.obsm:
        averaged = adata.obsm[embedding_key].mean(axis=0)
        return torch.tensor(averaged, dtype=torch.float32)

    raise ValueError(f"{embedding_key} not found in AnnData object")

def load_smiles_library(smiles_path):
    """Load SMILES strings and drug names from a CSV library.

    Two column layouts are accepted, checked in this order:
    ``SMILES`` + ``drug``, then ``smiles`` + ``name``.

    Args:
        smiles_path: Path to the CSV file.

    Returns:
        tuple[list, list]: ``(smiles_list, drug_names)`` in file order.

    Raises:
        ValueError: If neither column layout is present. This is checked once
            against the header, so it also fires for a file with no rows.
    """
    df = pd.read_csv(smiles_path)

    # Resolve the column layout once instead of re-checking it for every row.
    if 'SMILES' in df.columns and 'drug' in df.columns:
        smiles_column, name_column = 'SMILES', 'drug'
    elif 'smiles' in df.columns and 'name' in df.columns:
        smiles_column, name_column = 'smiles', 'name'
    else:
        raise ValueError("CSV should have columns 'SMILES'/'smiles' and 'drug'/'name'")

    smiles_list = df[smiles_column].tolist()
    drug_names = df[name_column].tolist()

    return smiles_list, drug_names

def batch_predict_similarity(model, cell_embedding, smiles_list, batch_size=32, device=None):
    """Score one cell embedding against a SMILES list in batches.

    Args:
        model: Object exposing ``predict_similarity(embedding, smiles_batch)``
            and, when ``device`` is None, ``parameters()``.
        cell_embedding: Tensor; a 1-D tensor gets a leading dimension added.
        smiles_list: Sequence of SMILES to score.
        batch_size: Number of SMILES passed to the model per call.
        device: Device for the embedding; defaults to the model's device.

    Returns:
        list: Concatenation of the per-batch results of
        ``model.predict_similarity``, in ``smiles_list`` order.
    """
    target_device = device if device is not None else next(model.parameters()).device

    query = cell_embedding.unsqueeze(0) if cell_embedding.dim() == 1 else cell_embedding
    query = query.to(target_device)

    collected = []
    batch_starts = range(0, len(smiles_list), batch_size)
    for start in tqdm(batch_starts, desc="Computing similarities"):
        chunk = smiles_list[start:start + batch_size]
        collected.extend(model.predict_similarity(query, chunk))

    return collected

def main():
    """Rank a SMILES library against one averaged cell embedding.

    Loads the checkpoint into a ``Blip2QformerMolT5`` model, scores every
    library molecule against the mean cell embedding of the AnnData file,
    writes all scores (sorted, best first) to ``similarity_scores.csv`` and
    the best ``--top_k`` rows to ``top_<k>_predictions.csv`` inside
    ``--output_dir``, then prints the ten best predictions.
    """
    args = parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    model = Blip2QformerMolT5(
        bert_name=args.bert_name,
        temperature=0.05,  # Not important for inference
        freeze_molt5=True,
        molt5_model_name=args.molt5_model,
        tune_qformer=False,
        num_query_token=args.num_query_token,
        cross_attention_freq=args.cross_attention_freq,
        embed_dim=args.embedding_dim
    ).to(device)

    print(f"Loading model from {args.model_checkpoint}")
    checkpoint = torch.load(args.model_checkpoint, map_location=device)

    # Accept both a training checkpoint ({"model": state_dict, ...}) and a
    # bare state dict.
    if "model" in checkpoint:
        model.load_state_dict(checkpoint["model"])
    else:
        model.load_state_dict(checkpoint)

    model.eval()

    print(f"Loading cell embedding from {args.cell_embedding_file}")
    cell_embedding = load_cell_embedding(args.cell_embedding_file, args.embedding_key)

    print(f"Loading SMILES library from {args.smiles_library}")
    smiles_list, drug_names = load_smiles_library(args.smiles_library)

    print(f"Computing similarity scores for {len(smiles_list)} molecules")
    similarity_scores = batch_predict_similarity(model, cell_embedding, smiles_list, device=device)

    results = []
    for i, score in enumerate(similarity_scores):
        results.append({
            "drug_name": drug_names[i],
            "smiles": smiles_list[i],
            "similarity_score": score
        })

    results_df = pd.DataFrame(results)
    results_df = results_df.sort_values("similarity_score", ascending=False)

    os.makedirs(args.output_dir, exist_ok=True)
    results_path = os.path.join(args.output_dir, "similarity_scores.csv")
    results_df.to_csv(results_path, index=False)

    top_k_path = os.path.join(args.output_dir, f"top_{args.top_k}_predictions.csv")
    results_df.head(args.top_k).to_csv(top_k_path, index=False)

    print(f"Results saved to {results_path}")
    print(f"Top {args.top_k} predictions saved to {top_k_path}")

    print("\nTop 10 predictions:")
    # After sorting, the DataFrame index still holds each row's position in
    # the input library, so number the rows by their place in the ranking.
    for rank, (_, row) in enumerate(results_df.head(10).iterrows(), start=1):
        print(f"{rank}. {row['drug_name']} (Score: {row['similarity_score']:.4f})")

if __name__ == "__main__":
    main()

Overwriting inference_script.py


## run

In [None]:
# !ls "/content/drive/MyDrive/Colab Notebooks/esm cell state/difference counts/"

In [27]:
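# 1. Process and prepare the data
# NOTE(review): parse_args() in the training_script.py cell above does not define
# --adata_dir, --smiles_path or --reprocess_data, so argparse rejects this command as
# written. The output below must come from an earlier version of the script; re-run the
# %%writefile cell with a version that accepts these flags (or drop them) before executing.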
!python training_script.py --adata_dir "/content/drive/MyDrive/Colab Notebooks/esm cell state/difference counts" --smiles_path "/content/drive/MyDrive/Colab Notebooks/esm cell state/smiles_df.csv" --processed_data_dir  "/content/drive/MyDrive/Colab Notebooks/esm cell state/processed_data" --reprocess_data


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Processed c-Kit-IN-1 from difference_c-Kit-IN-1_block98.h5ad
Processed c-Kit-IN-1 from difference_c-Kit-IN-1_block99.h5ad
Processed c-Kit-IN-1 from difference_c-Kit-IN-1_block100.h5ad
Processed c-Kit-IN-1 from difference_c-Kit-IN-1_block101.h5ad
Processed c-Kit-IN-1 from difference_c-Kit-IN-1_block102.h5ad
Processed c-Kit-IN-1 from difference_c-Kit-IN-1_block103.h5ad
Processed c-Kit-IN-1 from difference_c-Kit-IN-1_block104.h5ad
Processed c-Kit-IN-1 from difference_c-Kit-IN-1_block105.h5ad
Processed c-Kit-IN-1 from difference_c-Kit-IN-1_block106.h5ad
Processed c-Kit-IN-1 from difference_c-Kit-IN-1_block107.h5ad
Processed c-Kit-IN-1 from difference_c-Kit-IN-1_block108.h5ad
Processed c-Kit-IN-1 from difference_c-Kit-IN-1_block109.h5ad
Processed c-Kit-IN-1 from difference_c-Kit-IN-1_block110.h5ad
Processed c-Kit-IN-1 from difference_c-Kit-IN-1_block111.h5ad
Processed c-Kit-IN-1 from difference_c-Kit-IN-1_block112.h5ad
Process

In [21]:

# # 2. Train the model
# !python training_script.py --processed_data_dir processed_data --batch_size 32 --num_epochs 20 --checkpoint_dir checkpoints --embedding_dim 512 --learning_rate 5e-5

In [22]:
import pickle
import torch
import numpy as np

def inspect_pairs_pickle(pickle_path):
    """Print a data-quality summary for a pickled list of pairs.

    Reports how many pairs have a string cell embedding, a non-string SMILES
    or NaN values in the embedding (tensor or NumPy array), then shows the
    type, shape and dtype information and the (truncated) SMILES of the first
    three pairs.

    Args:
        pickle_path: Path to a pickle file holding a sequence of objects with
            ``cell_embedding`` and ``smiles`` attributes. Unpickling can run
            arbitrary code, so only use files from a trusted source.
    """
    # Load the pickle file
    with open(pickle_path, 'rb') as f:
        pairs = pickle.load(f)

    print(f"Loaded {len(pairs)} pairs from {pickle_path}")

    # Check for data type issues
    string_embeddings = []
    non_string_smiles = []
    nan_embeddings = []

    for i, pair in enumerate(pairs):
        # Check embedding type
        if isinstance(pair.cell_embedding, str):
            string_embeddings.append(i)

        # Check for NaN values in embedding
        if torch.is_tensor(pair.cell_embedding) and torch.isnan(pair.cell_embedding).any():
            nan_embeddings.append(i)
        elif isinstance(pair.cell_embedding, np.ndarray) and np.isnan(pair.cell_embedding).any():
            nan_embeddings.append(i)

        # Check SMILES type
        if not isinstance(pair.smiles, str):
            non_string_smiles.append(i)

    # Print summary
    print(f"Issues found:")
    print(f"- String embeddings: {len(string_embeddings)} indices")
    if string_embeddings:
        print(f"  First few indices: {string_embeddings[:5]}")

    print(f"- Non-string SMILES: {len(non_string_smiles)} indices")
    if non_string_smiles:
        print(f"  First few indices: {non_string_smiles[:5]}")
        print(f"  Examples: {[pairs[i].smiles for i in non_string_smiles[:3]]}")

    print(f"- NaN embeddings: {len(nan_embeddings)} indices")
    if nan_embeddings:
        print(f"  First few indices: {nan_embeddings[:5]}")

    # Show some example pairs
    print("\nSample data:")
    for i in range(min(3, len(pairs))):
        print(f"Pair {i}:")
        print(f"  Cell embedding type: {type(pairs[i].cell_embedding)}")
        if torch.is_tensor(pairs[i].cell_embedding):
            print(f"  Cell embedding shape: {pairs[i].cell_embedding.shape}")
            print(f"  Cell embedding dtype: {pairs[i].cell_embedding.dtype}")
        print(f"  SMILES type: {type(pairs[i].smiles)}")
        # Truncate first, then format: the label must be printed for short
        # SMILES too, and str() keeps non-string values (e.g. NaN) printable.
        smiles_text = str(pairs[i].smiles)
        if len(smiles_text) > 50:
            smiles_text = f"{smiles_text[:50]}..."
        print(f"  SMILES: {smiles_text}")
        print()

# Example usage
inspect_pairs_pickle("processed_data/train_pairs.pkl")
inspect_pairs_pickle("processed_data/val_pairs.pkl")
inspect_pairs_pickle("processed_data/test_pairs.pkl")

Loaded 56 pairs from processed_data/train_pairs.pkl
Issues found:
- String embeddings: 0 indices
- Non-string SMILES: 1 indices
  First few indices: [47]
  Examples: [nan]
- NaN embeddings: 0 indices

Sample data:
Pair 0:
  Cell embedding type: <class 'torch.Tensor'>
  Cell embedding shape: torch.Size([512])
  Cell embedding dtype: torch.float32
  SMILES type: <class 'str'>
  SMILES: C[C@@H]1CN(C[C@@H](O1)C)C2=NC=C(C=C2)NC(=O)C3=CC=C...

Pair 1:
  Cell embedding type: <class 'torch.Tensor'>
  Cell embedding shape: torch.Size([512])
  Cell embedding dtype: torch.float32
  SMILES type: <class 'str'>
CC(C)(C#N)C1=CC(=CC(=C1)CN2C=NC=N2)C(C)(C)C#N

Pair 2:
  Cell embedding type: <class 'torch.Tensor'>
  Cell embedding shape: torch.Size([512])
  Cell embedding dtype: torch.float32
  SMILES type: <class 'str'>
  SMILES: COC1=CC(=CC(=C1)C#CC2=NN(C3=NC=NC(=C23)N)[C@H]4CCN...

Loaded 12 pairs from processed_data/val_pairs.pkl
Issues found:
- String embeddings: 0 indices
- Non-string SMILES: 0 in