In [None]:
import multiprocessing as mp
import os
import time
from functools import partial
from typing import Any, List, Optional

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from rdkit import Chem
from rdkit.Chem import AllChem
from torch_geometric.data import (
    Batch,
    Data,
    Dataset,
)
from torch_geometric.loader import DataLoader
from torch_geometric.nn import global_mean_pool
from tqdm import tqdm
from transformers import EsmModel, EsmTokenizer
try:
    import rdkit.Chem as Chem
    import rdkit.Chem.AllChem as AllChem
except ImportError:
    print("Warning: RDKit not found. Please install it (`pip install rdkit-pypi`)")
    exit()

import numpy as np
from rdkit import rdBase

rdBase.DisableLog("rdApp.warning")
# Restrict this process to the GPU with index 1.
# NOTE(review): presumably this has to run before the first CUDA call for the
# restriction to take effect -- confirm if this cell is ever moved.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
# Output directories used by the precomputation steps further down.
os.makedirs("esm_save", exist_ok=True)
os.makedirs("process", exist_ok=True)
import torch

# Example: report the GPUs visible to this process.
print(torch.cuda.device_count())
if torch.cuda.is_available():
    # Query device 0 only after CUDA is known to be usable; calling
    # get_device_name() without a visible GPU raises instead of reporting.
    print(torch.cuda.get_device_name(0))
    for i in range(torch.cuda.device_count()):
        total = torch.cuda.get_device_properties(i).total_memory / 1024**3  # GB
        reserved = torch.cuda.memory_reserved(i) / 1024**3
        allocated = torch.cuda.memory_allocated(i) / 1024**3
        print(f"GPU {i}: Total: {total:.2f} GB | Reserved: {reserved:.2f} GB | Allocated: {allocated:.2f} GB")
else:
    print("No GPU available.")

In [None]:
def load_esm_model(model_name, device):
    """Load the pretrained ESM tokenizer and model named ``model_name``.

    The model is moved to ``device`` and put into evaluation mode before it
    is returned.

    Returns:
        tuple: ``(tokenizer, model)``.
    """
    print(f"Loading ESM tokenizer for {model_name}...")
    esm_tokenizer = EsmTokenizer.from_pretrained(model_name)

    print(f"Loading ESM model {model_name}...")
    esm_network = EsmModel.from_pretrained(model_name)
    esm_network = esm_network.to(device)
    esm_network.eval()
    print(f"ESM Model loaded on {device}.")

    return esm_tokenizer, esm_network


# --- Function to Get ESM Embeddings ---
def get_esm_embedding(sequence: str, model, tokenizer, device):
    """Embed one sequence with ESM and return the per-token vectors on CPU.

    The sequence is tokenized with special tokens, run through ``model``
    without gradients, and the last hidden state is taken. The first and the
    last token position are dropped before returning, so the result is
    expected to have one row per character of ``sequence`` (a warning is
    printed when it does not).

    Returns:
        The embedding tensor moved to CPU, or ``None`` when ``sequence`` is
        empty / not a string or when any step raises.
    """
    usable = isinstance(sequence, str) and bool(sequence)
    if not usable:
        return None

    try:
        encoded = tokenizer(sequence, return_tensors="pt", add_special_tokens=True)
        encoded = {name: value.to(device) for name, value in encoded.items()}

        with torch.no_grad():
            result = model(**encoded, output_hidden_states=True)

        # Last layer -> drop the batch dimension -> strip first/last token.
        token_vectors = result.hidden_states[-1].squeeze(0)
        residue_vectors = token_vectors[1:-1, :]

        if residue_vectors.shape[0] != len(sequence):
            print(
                f"Warning: ESM embedding length mismatch for sequence {sequence[:30]}..."
            )
        return residue_vectors.cpu()
    except Exception as err:
        print(f"Error generating ESM embedding for sequence {sequence[:50]}...: {err}")
        return None


# --- Function to Compute and Save ESM Embeddings ---
def compute_and_save_esm(csv_path, seq_col, pkl_path, model, tokenizer, device):
    """Compute ESM embeddings for the unique sequences of a CSV column and save them.

    The result is a ``dict`` mapping each sequence value to the tensor returned
    by ``get_esm_embedding``; sequences for which that function returns
    ``None`` are reported and left out. The dictionary is written to
    ``pkl_path`` with ``torch.save``.

    Args:
        csv_path: Path to the input CSV file.
        seq_col: Name of the column holding the protein sequences.
        pkl_path: Destination file for the embedding dictionary.
        model: ESM model forwarded to ``get_esm_embedding``.
        tokenizer: ESM tokenizer forwarded to ``get_esm_embedding``.
        device: Device forwarded to ``get_esm_embedding``.

    Returns:
        None. If the CSV cannot be read or ``seq_col`` is missing, the problem
        is printed and the function returns without writing anything.
    """
    print("\n--- Starting ESM Embedding Computation ---")
    print(f"Loading CSV file: {csv_path}...")
    try:
        df = pd.read_csv(csv_path)
    except Exception as e:
        print(f"Error loading CSV {csv_path}: {e}")
        return
    if seq_col not in df.columns:
        print(f"Error: Sequence column '{seq_col}' not found in CSV.")
        return

    unique_sequences = df[seq_col].dropna().unique()
    print(f"Found {len(unique_sequences)} unique protein sequences.")

    esm_embedding_dict = {}
    # The values are already unique, so no "seen before" check is needed.
    for seq in tqdm(unique_sequences, desc="Computing ESM embeddings"):
        embedding = get_esm_embedding(seq, model, tokenizer, device)
        if embedding is not None:
            esm_embedding_dict[seq] = embedding
        else:
            # str() first: a non-string cell value (e.g. a number) cannot be
            # sliced and would otherwise crash the whole run here.
            print(f"Failed to compute embedding for sequence: {str(seq)[:50]}...")

    print(
        f"Computed ESM embeddings for {len(esm_embedding_dict)} of "
        f"{len(unique_sequences)} sequences."
    )
    print(f"Saving ESM embeddings to file: {pkl_path}...")
    output_dir = os.path.dirname(pkl_path)
    if output_dir:
        # Create the destination directory so a long computation is not lost
        # to a missing folder.
        os.makedirs(output_dir, exist_ok=True)
    torch.save(esm_embedding_dict, pkl_path)
    print("ESM Embeddings computation and saving finished.")


# --- NEW: Function to Precompute Fingerprints ---
def precompute_fingerprints(csv_path, smiles_col, radius, n_bits, output_pkl_path):
    """
    Build a SMILES -> Morgan fingerprint table from a CSV column and store it on disk.

    Every unique, non-null value of ``smiles_col`` is parsed with RDKit; the
    bit-vector fingerprint is converted to a float32 NumPy array. Entries that
    are not non-empty strings, cannot be parsed, or raise during processing
    are counted as failures. The table is written with ``torch.save``.

    Args:
        csv_path (str): Path to the input CSV file.
        smiles_col (str): Name of the column containing SMILES strings.
        radius (int): Morgan fingerprint radius.
        n_bits (int): Morgan fingerprint size (number of bits).
        output_pkl_path (str): Path to save the output dictionary (.pkl file).
    """
    print("\n--- Starting Fingerprint Precomputation ---")
    print(f"Loading CSV file: {csv_path}...")
    try:
        table = pd.read_csv(csv_path)
    except Exception as err:
        print(f"Error loading CSV {csv_path}: {err}")
        return

    if smiles_col not in table.columns:
        print(f"Error: SMILES column '{smiles_col}' not found in CSV.")
        return

    # Distinct, non-null entries of the SMILES column.
    candidates = table[smiles_col].dropna().unique()
    print(f"Found {len(candidates)} unique SMILES strings.")
    print(f"Calculating MorganFP(radius={radius}, nBits={n_bits})")

    fingerprint_by_smiles = {}
    n_ok = 0
    n_failed = 0

    for smi in tqdm(candidates, desc="Computing fingerprints"):
        if not (isinstance(smi, str) and smi):
            n_failed += 1
            continue
        try:
            molecule = Chem.MolFromSmiles(smi)
            if molecule is None:
                # RDKit could not parse this SMILES string.
                n_failed += 1
            else:
                bit_vector = AllChem.GetMorganFingerprintAsBitVect(
                    molecule, radius=radius, nBits=n_bits
                )
                # float32 array, keyed by the original SMILES string.
                fingerprint_by_smiles[smi] = np.array(bit_vector, dtype=np.float32)
                n_ok += 1
        except Exception as err:
            print(f"Error computing fingerprint for SMILES {smi[:50]}...: {err}")
            n_failed += 1

    print("Fingerprint computation finished.")
    print(f"Successfully computed: {n_ok}, Failed: {n_failed}")

    print(f"Saving fingerprint dictionary to: {output_pkl_path}...")
    try:
        torch.save(fingerprint_by_smiles, output_pkl_path)
    except Exception as err:
        print(f"Error saving fingerprint dictionary: {err}")
    else:
        print("Fingerprint dictionary saved successfully.")


def precompute_atom_counts(csv_path, smiles_col, output_pkl_path):
    """
    Computes the number of atoms for each unique SMILES string in a CSV
    and saves the results (SMILES -> atom_count) with ``torch.save``.

    The count is whatever ``mol.GetNumAtoms()`` reports for the molecule RDKit
    parses from the SMILES string. Entries that are not non-empty strings,
    cannot be parsed, or raise during processing are counted as failures and
    left out of the dictionary.

    Args:
        csv_path (str): Path to the input CSV file.
        smiles_col (str): Name of the column containing SMILES strings.
        output_pkl_path (str): Path to save the output dictionary (.pkl file).
            Its parent directory is created if it does not exist.

    Returns:
        None. If the CSV cannot be read (missing file, missing column, other
        error) the problem is printed and nothing is written.
    """
    print("\n--- Starting Atom Count Precomputation ---")
    print(f"Loading CSV file: {csv_path}...")
    try:
        # Load only the necessary SMILES column to save memory if the CSV is large
        df = pd.read_csv(csv_path, usecols=[smiles_col])
    except FileNotFoundError:
        print(f"Error: Input CSV file not found at {csv_path}")
        return
    except ValueError as e:
        # Raised by pandas when smiles_col is not a column of the CSV
        print(f"Error reading CSV: {e}. Ensure '{smiles_col}' column exists.")
        return
    except Exception as e:
        print(f"Error loading CSV {csv_path}: {e}")
        return

    # Get unique, non-null SMILES values. The column is guaranteed to exist
    # here because the usecols load above succeeded.
    unique_smiles = df[smiles_col].dropna().unique()
    print(f"Found {len(unique_smiles)} unique SMILES strings.")

    atom_count_dict = {}
    computed_count = 0
    failed_count = 0

    for smiles in tqdm(unique_smiles, desc="Computing atom counts"):
        if not isinstance(smiles, str) or not smiles:
            # Non-string or empty entries are counted as failures
            failed_count += 1
            continue
        try:
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                # RDKit could not parse this SMILES string
                failed_count += 1
                continue

            # Store in dictionary (SMILES -> atom_count)
            atom_count_dict[smiles] = mol.GetNumAtoms()
            computed_count += 1

        except Exception as e:
            # Catch potential errors during RDKit processing for a specific SMILES
            print(f"Error computing atom count for SMILES {smiles[:50]}...: {e}")
            failed_count += 1

    print("Atom count computation finished.")
    print(
        f"Successfully computed: {computed_count}, Failed to parse/process: {failed_count}"
    )

    print(f"Saving atom count dictionary to: {output_pkl_path}...")
    try:
        # Ensure the output directory exists (an empty dirname means the
        # current directory, which needs no creation)
        output_dir = os.path.dirname(output_pkl_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        torch.save(atom_count_dict, output_pkl_path)
        print("Atom count dictionary saved successfully.")
    except Exception as e:
        print(f"Error saving atom count dictionary: {e}")


def precompute_atom_types(csv_path, smiles_col, output_pkl_path):
    """
    Computes the list of atom types (atomic numbers) for each unique SMILES string
    in a CSV and saves the results (SMILES -> list_of_atom_types) with ``torch.save``.

    The list holds ``atom.GetAtomicNum()`` for every atom of the molecule RDKit
    parses from the SMILES string, in ``mol.GetAtoms()`` order. Entries that
    are not non-empty strings, cannot be parsed, or raise during processing
    are counted as failures and left out of the dictionary.

    Args:
        csv_path (str): Path to the input CSV file.
        smiles_col (str): Name of the column containing SMILES strings.
        output_pkl_path (str): Path to save the output dictionary (.pkl file).
            Its parent directory is created if it does not exist.

    Returns:
        None. If the CSV cannot be read (missing file, missing column, other
        error) the problem is printed and nothing is written.
    """
    print("\n--- Starting Atom Types Precomputation ---")
    print(f"Loading CSV file: {csv_path}...")
    try:
        # Load only the necessary SMILES column to save memory if the CSV is large
        df = pd.read_csv(csv_path, usecols=[smiles_col])
    except FileNotFoundError:
        print(f"Error: Input CSV file not found at {csv_path}")
        return
    except ValueError as e:
        # Raised by pandas when smiles_col is not a column of the CSV
        print(f"Error reading CSV: {e}. Ensure '{smiles_col}' column exists.")
        return
    except Exception as e:
        print(f"Error loading CSV {csv_path}: {e}")
        return

    # Get unique, non-null SMILES values. The column is guaranteed to exist
    # here because the usecols load above succeeded.
    unique_smiles = df[smiles_col].dropna().unique()
    print(f"Found {len(unique_smiles)} unique SMILES strings.")

    atom_types_dict = {}
    computed_count = 0
    failed_count = 0

    for smiles in tqdm(unique_smiles, desc="Computing atom types"):
        if not isinstance(smiles, str) or not smiles:
            # Non-string or empty entries are counted as failures
            failed_count += 1
            continue
        try:
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                # RDKit could not parse this SMILES string
                failed_count += 1
                continue

            # Store in dictionary (SMILES -> list of atomic numbers)
            atom_types_dict[smiles] = [atom.GetAtomicNum() for atom in mol.GetAtoms()]
            computed_count += 1

        except Exception as e:
            # Catch potential errors during RDKit processing for a specific SMILES
            print(f"Error computing atom types for SMILES {smiles[:50]}...: {e}")
            failed_count += 1

    print("Atom type computation finished.")
    print(
        f"Successfully computed: {computed_count}, Failed to parse/process: {failed_count}"
    )

    print(f"Saving atom types dictionary to: {output_pkl_path}...")
    try:
        # Ensure the output directory exists (an empty dirname means the
        # current directory, which needs no creation)
        output_dir = os.path.dirname(output_pkl_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        torch.save(atom_types_dict, output_pkl_path)
        print("Atom types dictionary saved successfully.")
    except Exception as e:
        print(f"Error saving atom types dictionary: {e}")


# --- Worker Function (must be defined at top-level or be picklable) ---
def _compute_single_fingerprint(smiles, radius, n_bits):
    """
    Computes the Morgan fingerprint for a single SMILES string.

    Args:
        smiles (str): The SMILES string.
        radius (int): Morgan fingerprint radius.
        n_bits (int): Morgan fingerprint size.

    Returns:
        tuple: (smiles, np.array or None) - The original SMILES and its
               computed fingerprint as a float32 numpy array, or None if failed.
    """
    if not isinstance(smiles, str) or not smiles:
        return smiles, None  # Return smiles to track failures
    try:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            # Optionally log or handle SMILES parsing errors more explicitly here
            # print(f"Warning: RDKit failed to parse SMILES: {smiles[:50]}...")
            return smiles, None

        # Calculate fingerprint
        fp_bitvect = AllChem.GetMorganFingerprintAsBitVect(
            mol, radius=radius, nBits=n_bits
        )
        # Convert to NumPy array (float32 is common for ML)
        fp_array = np.array(fp_bitvect, dtype=np.float32)
        return smiles, fp_array

    except Exception:
        # Catch errors during fingerprint calculation for a specific SMILES
        # print(f"Error computing fingerprint for SMILES {smiles[:50]}...: {e}")
        # traceback.print_exc() # Uncomment for detailed traceback from worker
        return smiles, None


# --- Main Function using Multiprocessing ---
def precompute_fingerprints_mp(
    csv_path,
    smiles_col,
    radius,
    n_bits,
    output_pkl_path,
    num_workers=None,
    chunksize=100,
):
    """
    Computes Morgan fingerprints in parallel for unique SMILES in a CSV
    and saves them (SMILES -> float32 array) with ``torch.save``.

    Each unique, non-null value of ``smiles_col`` is handed to
    ``_compute_single_fingerprint`` in a ``multiprocessing.Pool``; results
    that come back as ``None`` are counted as failures and left out.

    Args:
        csv_path (str): Path to the input CSV file.
        smiles_col (str): Name of the column containing SMILES strings.
        radius (int): Morgan fingerprint radius.
        n_bits (int): Morgan fingerprint size (number of bits).
        output_pkl_path (str): Path to save the output dictionary (.pkl file).
            Its parent directory is created if it does not exist.
        num_workers (int, optional): Number of worker processes.
            Defaults to os.cpu_count() - 1 (at least 1). The value is never
            allowed to exceed the number of unique SMILES strings.
        chunksize (int, optional): Number of tasks to send to each worker at once.
            Adjusting might impact performance. Defaults to 100.

    Returns:
        None. Load errors, an empty input, or a multiprocessing error are
        printed and end the function without writing a file.
    """
    print("\n--- Starting Fingerprint Precomputation (Multiprocessing) ---")
    start_time = time.time()

    print(f"Loading CSV file: {csv_path}...")
    try:
        df = pd.read_csv(csv_path)
    except FileNotFoundError:
        print(f"Error: Input CSV file not found at {csv_path}")
        return
    except Exception as e:
        print(f"Error loading CSV {csv_path}: {e}")
        return

    if smiles_col not in df.columns:
        print(f"Error: SMILES column '{smiles_col}' not found in CSV.")
        return

    # Get unique, non-null SMILES values
    unique_smiles = df[smiles_col].dropna().unique()
    total_unique = len(unique_smiles)
    print(f"Found {total_unique} unique SMILES strings.")

    if total_unique == 0:
        print("No unique SMILES strings found to process.")
        return

    print(f"Calculating MorganFP(radius={radius}, nBits={n_bits})")

    # Determine number of workers
    if num_workers is None:
        # os.cpu_count() may return None when the count cannot be determined.
        available_cpus = os.cpu_count() or 1
        # Leave one core for the main process and OS, but keep at least one worker.
        num_workers = max(1, available_cpus - 1)
        # A worker without any item to process would only cost start-up time.
        num_workers = min(num_workers, total_unique)
        print(f"Defaulting to {num_workers} worker processes.")
    else:
        num_workers = max(1, num_workers)
        num_workers = min(num_workers, total_unique)
        print(f"Using specified {num_workers} worker processes.")

    worker_func = partial(_compute_single_fingerprint, radius=radius, n_bits=n_bits)

    fingerprint_dict = {}
    computed_count = 0
    failed_count = 0

    print(f"Starting parallel computation with chunksize={chunksize}...")
    try:
        with mp.Pool(processes=num_workers) as pool:
            # Results arrive in completion order; the worker echoes the SMILES
            # back so each result can still be keyed correctly.
            results_iterator = pool.imap_unordered(
                worker_func, unique_smiles, chunksize=chunksize
            )
            for smiles_key, fp_result in tqdm(
                results_iterator, total=total_unique, desc="Computing fingerprints"
            ):
                if fp_result is not None:
                    fingerprint_dict[smiles_key] = fp_result
                    computed_count += 1
                else:
                    failed_count += 1
    except Exception as e:
        print(f"\nAn error occurred during multiprocessing: {e}")
        return

    computation_time = time.time() - start_time
    print(f"\nFingerprint computation finished in {computation_time:.2f} seconds.")
    print(f"Successfully computed: {computed_count}, Failed: {failed_count}")

    if computed_count > 0:
        print(
            f"Saving fingerprint dictionary ({computed_count} entries) to: {output_pkl_path}..."
        )
        save_start_time = time.time()
        try:
            output_dir = os.path.dirname(output_pkl_path)
            if output_dir:
                os.makedirs(output_dir, exist_ok=True)
            torch.save(fingerprint_dict, output_pkl_path)
            save_time = time.time() - save_start_time
            print(
                f"Fingerprint dictionary saved successfully in {save_time:.2f} seconds."
            )
        except Exception as e:
            print(f"Error saving fingerprint dictionary: {e}")
    else:
        print("No fingerprints were successfully computed, skipping save.")

    total_time = time.time() - start_time
    print(f"--- Total precomputation time: {total_time:.2f} seconds ---")


def compute_processed_trunk_file(
    original_trunk_file, trunk_file_dir, saved_processed_trunk_file
):
    """Collect per-protein trunk features into one dictionary keyed by sequence.

    Every file in ``trunk_file_dir`` is expected to be named so that the part
    before the first ``.`` ends in ``_<ID>`` (or is just ``<ID>``), where
    ``<ID>`` is an integer from the ``ID`` column of ``original_trunk_file``.
    Each file is loaded with ``torch.load`` and must provide the keys ``"s"``
    and ``"z"``; ``"s"`` is stored unchanged and ``"z"`` is averaged over
    dimension 1. The result maps ``PROTEIN_SEQUENCE`` -> ``{"s": ..., "z": ...}``
    and is written with ``torch.save``.

    Entries that cannot be matched are skipped with a message instead of
    aborting the run: names without an integer ID, non-files, and IDs that do
    not resolve to exactly one sequence in the CSV.

    Args:
        original_trunk_file: CSV with ``ID`` and ``PROTEIN_SEQUENCE`` columns.
        trunk_file_dir: Directory holding the per-ID feature files.
        saved_processed_trunk_file: Destination of the combined dictionary;
            its parent directory is created if needed.
    """
    df = pd.read_csv(original_trunk_file)
    feats_dict = {}
    # sorted() makes the processing order independent of the filesystem.
    for path in sorted(os.listdir(trunk_file_dir)):
        full_path = os.path.join(trunk_file_dir, path)
        idx = path.split(".")[0].split("_")[-1]
        try:
            sample_id = int(idx)
        except ValueError:
            print(f"Skipping '{path}': no integer ID in the file name.")
            continue
        if not os.path.isfile(full_path):
            print(f"Skipping '{path}': not a regular file.")
            continue

        sequences = df.loc[df["ID"] == sample_id, "PROTEIN_SEQUENCE"].dropna().unique()
        if len(sequences) != 1:
            print(
                f"Skipping '{path}': ID {sample_id} maps to {len(sequences)} "
                f"distinct sequences in {original_trunk_file} (expected 1)."
            )
            continue

        # NOTE: weights_only=False unpickles arbitrary objects; only load
        # feature files that come from a trusted source.
        trunk = torch.load(full_path, weights_only=False, map_location="cpu")
        feats_dict[sequences[0]] = {"s": trunk["s"], "z": trunk["z"].mean(dim=1)}

    output_dir = os.path.dirname(saved_processed_trunk_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    torch.save(feats_dict, saved_processed_trunk_file)

In [None]:
# --- Precomputation settings ---
# Switches for the individual precomputation stages (all enabled).
compute_esm = compute_fingerprint = compute_atom_type = compute_trunk_file = True

# Run on the GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# The ESM, fingerprint and atom-type stages all read the same input CSV.
esm_csv_file = "demo_input.csv"
fingerprint_csv_file = esm_csv_file
atomtype_csv_file = esm_csv_file

# Output files of the ESM, fingerprint and atom-type stages.
esm_output_file = "esm_save/screen_esm_embeddings.pkl"
fingerprint_output_file = "demo_input_fingerprint.pkl"
atomtype_output_file = "demo_input_atomtype.pkl"

# Trunk stage: ID/sequence CSV, directory of feature files, combined output.
original_trunk_file = "wait2pred.csv"
trunk_file_dir = "feats"
saved_processed_trunk_file = "process/trunk_dict_withseq.pkl"

In [None]:
# --- Precomputation driver ---
# Switches for the individual precomputation stages.
compute_esm = True
compute_fingerprint = True
compute_atom_type = True
compute_trunk_file = True
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device", device)


# The ESM, fingerprint and atom-type stages all read the same input CSV.
esm_csv_file = fingerprint_csv_file = atomtype_csv_file = "demo_input.csv"

esm_output_file = "esm_save/screen_esm_embeddings.pkl"
fingerprint_output_file = "demo_input_fingerprint.pkl"
atomtype_output_file = "demo_input_atomtype.pkl"

original_trunk_file = "wait2pred.csv"
trunk_file_dir = "feats"
saved_processed_trunk_file = "process/trunk_dict_withseq.pkl"

if compute_esm:
    # --- Stage 1: ESM embeddings for the protein sequences ---
    print("\nStarting ESM computation...")
    esm_model_name = "facebook/esm2_t36_3B_UR50D"
    sequence_column = "Protein_Sequence"
    esm_tokenizer, esm_model = load_esm_model(esm_model_name, device)

    compute_and_save_esm(
        csv_path=esm_csv_file,
        seq_col=sequence_column,
        pkl_path=esm_output_file,
        model=esm_model,
        tokenizer=esm_tokenizer,
        device=device,
    )

    print("ESM computation finished.")


if compute_fingerprint:
    # --- Stage 2: Morgan fingerprints for the ligands ---
    print("\nStarting Fingerprint precomputation...")

    smiles_column = "Canonical_SMILES"
    fp_radius = 2
    fp_nbits = 1024

    precompute_fingerprints_mp(
        csv_path=fingerprint_csv_file,
        smiles_col=smiles_column,
        radius=fp_radius,
        n_bits=fp_nbits,
        output_pkl_path=fingerprint_output_file,
        # Up to 400 workers, but never more processes than CPU cores.
        num_workers=min(400, os.cpu_count() or 1),
        chunksize=500,
    )

    print("Fingerprint precomputation finished.")

if compute_atom_type:
    # --- Stage 3: atom types (atomic numbers) for the ligands ---
    print("\nStarting Atom Type precomputation...")

    smiles_column = "Canonical_SMILES"

    precompute_atom_types(
        csv_path=atomtype_csv_file,
        smiles_col=smiles_column,
        output_pkl_path=atomtype_output_file,
    )

    print("Atom Type precomputation finished.")

# Test the flag, not the function object of the same-looking name: a function
# is always truthy, so the stage could never be switched off.
if compute_trunk_file:
    # --- Stage 4: combine the per-protein trunk feature files ---
    print("\nStarting Trunk File precomputation...")
    compute_processed_trunk_file(
        original_trunk_file, trunk_file_dir, saved_processed_trunk_file
    )

print("\nPrecomputation script finished.")


In [None]:

class LigandProteinInteractionModel(nn.Module):
    """Graph-level predictor fusing ligand and protein features.

    Per graph, three ``hidden_size``-wide vectors are built -- the encoded
    fingerprint, the projected mean ESM embedding (plus a scaled trunk term)
    and the mean atom-type embedding -- concatenated, mixed by
    ``global_combiner`` and mapped to ``graph_output_dim`` values by
    ``graph_predictor``.
    """

    def __init__(
        self,
        hidden_size=1024,
        esm_dim=2560,
        # Max atomic num + 1
        num_atom_types=118,
        fp_dim=1024,
        graph_output_dim=1,
        combiner_dropout=0.2,
        predictor_dropout=0.2,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        # Weight of the trunk term when it is added to the ESM context.
        self.scaling_factor = 1e-2

        # NOTE: submodules are created in a fixed order so that random
        # initialisation draws stay the same.

        # Ligand atom types -> embedding vectors.
        self.atom_encoder = nn.Embedding(num_atom_types, hidden_size)

        # Fingerprint -> hidden vector (two-layer MLP).
        fp_layers = [
            nn.Linear(fp_dim, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
        ]
        self.fp_encoder = nn.Sequential(*fp_layers)

        # Pooled ESM embedding -> hidden vector.
        self.esm_proj = nn.Linear(esm_dim, hidden_size)

        # Fuses the three concatenated hidden vectors into one.
        fused_width = hidden_size * 3
        combiner_width = hidden_size * 2
        combiner_layers = [
            nn.Linear(fused_width, combiner_width),
            nn.ReLU(),
            nn.Dropout(p=combiner_dropout),
            nn.Linear(combiner_width, hidden_size),
        ]
        self.global_combiner = nn.Sequential(*combiner_layers)

        # Bias-free linear map applied to the broadcast trunk summary.
        self.trunk_encoder = nn.Linear(hidden_size, hidden_size, bias=False)

        # Prediction head.
        head_width = hidden_size // 2
        head_layers = [
            nn.Linear(hidden_size, head_width),
            nn.ReLU(),
            nn.Dropout(p=predictor_dropout),
            nn.Linear(head_width, graph_output_dim),
        ]
        self.graph_predictor = nn.Sequential(*head_layers)

    def forward(self, data: Batch):
        """
        Compute one prediction per graph of ``data``.

        Reads ``atom_types``, ``fp``, ``trunk_s``, ``trunk_z`` and (optionally)
        ``esm_embedding`` from the batch together with the matching
        ``*_batch`` index vectors (``atom_types_batch``, ``trunk_s_batch``,
        ``trunk_z_batch``, ``esm_embedding_batch``).

        Raises:
            AttributeError: if ``esm_embedding`` is present and non-empty but
                ``esm_embedding_batch`` is missing.
        """
        n_graphs = data.num_graphs

        # Ligand atoms: one embedding per atom.
        atom_vectors = self.atom_encoder(data.atom_types)

        # Trunk features: average each tensor over dim 1, mean-pool per graph
        # to a single scalar, then tile that scalar across the hidden width.
        s_summary = global_mean_pool(data.trunk_s.mean(1), batch=data["trunk_s_batch"])
        s_summary = s_summary.unsqueeze(1).repeat(1, self.hidden_size)
        z_summary = global_mean_pool(data.trunk_z.mean(1), batch=data["trunk_z_batch"])
        z_summary = z_summary.unsqueeze(1).repeat(1, self.hidden_size)
        trunk = self.trunk_encoder(s_summary + z_summary)
        # Gate the encoded trunk by its own sigmoid.
        trunk = F.sigmoid(trunk) * trunk

        # Ligand atoms pooled to one vector per graph.
        ligand_atom_pooled = global_mean_pool(atom_vectors, data.atom_types_batch)

        # Fingerprint: drop a singleton middle dimension if there is one.
        fp_input = data.fp
        if fp_input.dim() == 3 and fp_input.size(1) == 1:
            fp_input = fp_input.squeeze(1)
        fp_emb = self.fp_encoder(fp_input)

        # Protein: mean-pool the ESM embeddings per graph and project them.
        has_esm = hasattr(data, "esm_embedding") and data.esm_embedding.shape[0] > 0  # type: ignore
        if has_esm:
            if not hasattr(data, "esm_embedding_batch"):
                # Expected to be provided through the loader's follow_batch.
                raise AttributeError(
                    "ESM embedding exists but 'esm_embedding_batch' is missing. Check DataLoader setup."
                )
            pooled_esm = global_mean_pool(data.esm_embedding, data.esm_embedding_batch)  # type: ignore
            esm_ctx = self.esm_proj(pooled_esm)
        else:
            # No protein information: zeros matching fp_emb's device and dtype.
            esm_ctx = torch.zeros(
                n_graphs,
                self.hidden_size,
                device=fp_emb.device,
                dtype=fp_emb.dtype,
            )

        # Add the down-weighted trunk term to the protein context.
        esm_ctx = esm_ctx + trunk * self.scaling_factor

        # Concatenate the three per-graph vectors, fuse, predict.
        fused_input = torch.cat([fp_emb, esm_ctx, ligand_atom_pooled], dim=-1)
        fused = self.global_combiner(fused_input)
        return self.graph_predictor(fused)



In [None]:
def collate_skip_none(
    batch: List[Optional[Any]],
    follow_batch: Optional[List[str]] = None,
    exclude_keys: Optional[List[str]] = None,
) -> Batch:
    """Collate the non-``None`` items of ``batch`` into a ``Batch``.

    Args:
        batch: Samples as returned by the dataset; ``None`` entries are dropped.
        follow_batch: Attribute names for which ``Batch.from_data_list`` should
            create ``<name>_batch`` assignment vectors. Defaults to ``None``,
            which matches the previous behaviour of creating none; pass the
            same keys the model reads (e.g. via ``functools.partial``).
        exclude_keys: Attribute names to leave out of the batch.

    Returns:
        The collated ``Batch``; an empty ``Batch()`` when no valid item is
        left or when collation raises (the error is printed).
    """
    valid_items = [item for item in batch if item is not None]
    if not valid_items:
        return Batch()
    try:
        return Batch.from_data_list(
            valid_items, follow_batch=follow_batch, exclude_keys=exclude_keys
        )
    except Exception as e:
        print(f"Error during collation: {e}. Skipping batch.")
        return Batch()



class InferenceDataset(Dataset):
    """Inference dataset assembling one ``Data`` object per DataFrame row.

    All features are looked up in pre-loaded dictionaries: fingerprints and
    atom types by SMILES string, ESM embeddings and trunk features by protein
    sequence. A row for which any input is missing or unusable yields ``None``
    from :meth:`get`, so the collate step must be able to drop ``None`` items.
    """

    def __init__(
        self,
        samples_df: pd.DataFrame,
        esm_embeddings_dict: dict,
        fingerprint_dict: dict,
        atom_type_dict: dict,
        trunk_dict: dict,
        esm_key_col="Protein_Sequence",
    ):
        """
        Dataset for inference using pre-loaded dictionaries.

        Args:
            samples_df: One row per sample; must contain ``Canonical_SMILES``,
                ``ID`` and the column named by ``esm_key_col``.
            esm_embeddings_dict: protein key -> ESM embedding tensor.
            fingerprint_dict: SMILES -> fingerprint NumPy array.
            atom_type_dict: SMILES -> sequence of atom types.
            trunk_dict: protein key -> mapping with entries ``"s"`` and ``"z"``.
            esm_key_col: Column holding the protein key.

        Raises:
            ValueError: if a required column is missing from ``samples_df``.
        """
        super().__init__(root=None)
        self.samples_df = samples_df.reset_index(drop=True)
        self.esm_embeddings_dict = esm_embeddings_dict
        self.fingerprint_dict = fingerprint_dict
        self.atom_type_dict = atom_type_dict
        self.esm_key_col = esm_key_col
        self.trunk_dict = trunk_dict
        # Column names
        self.smiles_col = "Canonical_SMILES"
        self.id_col = "ID"

        # Check for required columns
        required_cols = [
            self.smiles_col,
            self.esm_key_col,
            self.id_col,
        ]
        missing = [c for c in required_cols if c not in self.samples_df.columns]
        if missing:
            raise ValueError(f"Missing required columns in DataFrame: {missing}")

    def len(self):
        """Return the number of rows in the sample table."""
        return len(self.samples_df)

    def get(self, idx):
        """Build the ``Data`` object for row ``idx``, or return ``None``.

        ``None`` is returned when the SMILES string or protein key is not a
        non-empty string, when any of the four lookups has no entry for the
        sample, when the atom-type list is empty, or when the ESM embedding is
        not a non-empty tensor.

        Raises:
            IndexError: if ``idx`` is not smaller than the dataset length.
        """
        if idx >= len(self.samples_df):
            # This should ideally not happen if sampler is used correctly
            raise IndexError(
                f"Index {idx} out of bounds for dataset length {len(self.samples_df)}"
            )

        row = self.samples_df.iloc[idx]
        smiles = row.get(self.smiles_col, None)
        protein_key = row.get(self.esm_key_col, None)
        sample_id = row.get(self.id_col, f"sample_{idx}")

        # 1. Validate SMILES and protein key
        if not isinstance(smiles, str) or not smiles:
            return None
        if not isinstance(protein_key, str) or not protein_key:
            return None

        # 2. Lookup pre-computed fingerprint
        if smiles not in self.fingerprint_dict:
            return None
        fp_array = self.fingerprint_dict[smiles]
        # Leading dimension of size 1 so fingerprints stack to one row per graph.
        fp_tensor = torch.from_numpy(fp_array).float().unsqueeze(0)

        # 3. Lookup pre-computed atom types
        if smiles not in self.atom_type_dict:
            return None
        atom_types = self.atom_type_dict[smiles]
        atom_types_tensor = torch.tensor(atom_types, dtype=torch.long)
        num_lig_atoms = len(atom_types)
        if num_lig_atoms == 0:
            return None

        # 4. Lookup pre-loaded ESM embedding
        if protein_key not in self.esm_embeddings_dict:
            return None
        esm_embedding = self.esm_embeddings_dict[protein_key]
        if not isinstance(esm_embedding, torch.Tensor) or esm_embedding.shape[0] == 0:
            return None
        num_prot_res = esm_embedding.shape[0]

        # 5. Lookup pre-computed trunk features. A protein without trunk
        #    features is skipped like every other missing input instead of
        #    raising KeyError.
        if protein_key not in self.trunk_dict:
            return None
        trunk_file = self.trunk_dict[protein_key]

        # 6. Create Data object
        data = Data(
            atom_types=atom_types_tensor,
            fp=fp_tensor,
            # clone() keeps the shared dictionary entry from being modified
            # through the sample.
            esm_embedding=esm_embedding.clone(),
            smiles=smiles,
            trunk_s=trunk_file["s"],
            trunk_z=trunk_file["z"],
            num_lig_atoms=num_lig_atoms,
            num_prot_res=num_prot_res,
            sample_name=sample_id,
        )
        return data



In [None]:
# --- Inference settings ---
BATCH_SIZE = 2048
# Model dimensions.
HIDDEN_SIZE = 1024
ESM_DIM = 2560
FP_DIM = 1024
MAX_ATOMIC_NUM = 118
GRAPH_OUTPUT_DIM = 1
# Mixed-precision switch used by the inference loop.
USE_AMP = True

checkpoint_path = "model_load/model_aurofast.pt"
# device = "cuda:1"
result_csv_path = "inference_result.csv"

# Load the four precomputed feature dictionaries onto the CPU, in the order
# ESM embeddings, fingerprints, atom types, trunk features.
esm_embeddings_ram, fingerprint_dict_ram, atom_type_dict_ram, trunk_dict = (
    torch.load(artifact_path, weights_only=False, map_location="cpu")
    for artifact_path in (
        esm_output_file,
        fingerprint_output_file,
        atomtype_output_file,
        saved_processed_trunk_file,
    )
)

# Samples to score and the columns holding protein key / ligand SMILES.
inference_df = pd.read_csv("demo_input.csv")
esm_key_col = "Protein_Sequence"
smiles_col = "Canonical_SMILES"





In [None]:
# --- Inference driver ---
from contextlib import nullcontext

from torch.utils.data import DataLoader as TorchDataLoader

inference_dataset = InferenceDataset(
    samples_df=inference_df,
    esm_embeddings_dict=esm_embeddings_ram,
    fingerprint_dict=fingerprint_dict_ram,
    atom_type_dict=atom_type_dict_ram,
    esm_key_col=esm_key_col,
    trunk_dict=trunk_dict,
)

print(f"Dataset initialized with {len(inference_dataset)} samples.")

# Attributes for which the model reads a ``<name>_batch`` assignment vector.
FOLLOW_BATCH_KEYS = ["esm_embedding", "atom_types", "trunk_s", "trunk_z"]


def _collate_valid_samples(samples):
    """Collate the non-None samples into a Batch; return None if that is impossible."""
    valid = [sample for sample in samples if sample is not None]
    if not valid:
        return None
    try:
        return Batch.from_data_list(valid, follow_batch=FOLLOW_BATCH_KEYS)
    except Exception as e:
        print(f"Error during collation: {e}. Skipping batch.")
        return None


# torch_geometric's DataLoader replaces any user-supplied collate_fn with its
# own collater, which cannot handle the None samples the dataset returns for
# unusable rows. The plain PyTorch DataLoader honours the custom collate_fn.
loader = TorchDataLoader(
    inference_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=_collate_valid_samples,
)

print(f"DataLoader setup complete with Batch Size: {BATCH_SIZE}.")

model = LigandProteinInteractionModel(
    hidden_size=HIDDEN_SIZE,
    esm_dim=ESM_DIM,
    fp_dim=FP_DIM,
    num_atom_types=MAX_ATOMIC_NUM + 1,
    graph_output_dim=GRAPH_OUTPUT_DIM,
)

model = model.to(device)
print(f"Loading pretrained state_dict from {checkpoint_path}...")
# map_location="cpu" makes loading independent of the device the checkpoint
# was saved from; load_state_dict copies the values onto the model's device.
# NOTE: weights_only=False unpickles arbitrary objects; only load trusted files.
pretrained_state_dict = torch.load(
    checkpoint_path, map_location="cpu", weights_only=False
)["model_state_dict"]
load_result = model.load_state_dict(pretrained_state_dict, strict=False)
# strict=False hides mismatches, so report them explicitly.
if load_result.missing_keys:
    print(f"Warning: keys missing from checkpoint: {load_result.missing_keys}")
if load_result.unexpected_keys:
    print(f"Warning: unexpected keys in checkpoint: {load_result.unexpected_keys}")
model.eval()

local_predictions = []
local_sample_names = []
local_smiles = []

# float16 autocast is only used on CUDA; on any other device the model runs
# in its native precision.
amp_enabled = USE_AMP and torch.device(device).type == "cuda"

progress_bar = tqdm(loader, desc="Inferring", unit="batch", leave=True)

with torch.no_grad():
    for batch in progress_bar:
        if batch is None:
            print("Skipping empty or invalid batch.")
            continue

        try:
            batch = batch.to(device)
        except Exception as e:
            print(f"Error moving batch to {device}: {e}. Skipping batch.")
            continue

        try:
            autocast_ctx = (
                torch.autocast(device_type="cuda", dtype=torch.float16)
                if amp_enabled
                else nullcontext()
            )
            with autocast_ctx:
                preds = model(batch)

            # squeeze(-1) drops the output dimension (size 1 here) but keeps
            # the batch dimension, so a single-sample batch is still a list.
            batch_preds = preds.float().squeeze(-1).cpu().tolist()

            batch_names = getattr(batch, "sample_name", None)
            if batch_names is None:
                print("Warning: Missing original_id in batch.")
                batch_names = ["unknown"] * len(batch_preds)
            elif isinstance(batch_names, torch.Tensor):
                # Numeric IDs may have been collated into a tensor.
                batch_names = batch_names.tolist()

            batch_smiles = getattr(batch, "smiles", None)
            if batch_smiles is None:
                print("Warning: Missing smiles in batch.")
                batch_smiles = [""] * len(batch_preds)

            if not (len(batch_preds) == len(batch_names) == len(batch_smiles)):
                raise ValueError(
                    f"per-batch length mismatch: {len(batch_preds)} predictions, "
                    f"{len(batch_names)} names, {len(batch_smiles)} smiles"
                )

        except Exception as e:
            print(f"\nError during inference on batch: {e}")
            # Placeholders keep the three result lists aligned.
            num_in_batch = batch.num_graphs
            batch_preds = [np.nan] * num_in_batch
            batch_names = ["error"] * num_in_batch
            batch_smiles = ["error"] * num_in_batch

        # Extend all three lists together, only after everything for this
        # batch is known, so they can never get out of step.
        local_predictions.extend(batch_preds)
        local_sample_names.extend(batch_names)
        local_smiles.extend(batch_smiles)

final_len = len(local_sample_names)
if not (len(local_predictions) == final_len and len(local_smiles) == final_len):
    raise RuntimeError(
        "Length mismatch between collected lists after inference loop. Results may be corrupted."
    )

print(f"Inference loop finished. Collected {final_len} results.")

# Create DataFrame with all columns
results_df = pd.DataFrame(
    {
        "ID": local_sample_names,
        "smiles": local_smiles,
        "Prediction": local_predictions,
    }
)

results_df.to_csv(result_csv_path, index=False, float_format="%.8f")
print("Predictions saved successfully.")