In [1]:
import sys 
import os
import h5py
sys.path.insert(1, '/scratch/work/masooda1/mocop')

from typing import Dict
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, Subset
from rdkit import Chem
import pyarrow.parquet as pq
import pyarrow.compute as pc

from featurizer.smiles_transformation import (inchi2smiles, smiles2fp,
                                              smiles2graph)

class SupervisedGraphDataset(Dataset):
    """Map-style dataset pairing a molecular graph with a label vector.

    The table at ``data_path`` is indexed by ``cmpd_col``; every column whose
    name does not start with ``"Metadata"`` is treated as a label column.
    """

    def __init__(
        self, data_path, cmpd_col="smiles", cmpd_col_is_inchikey=False, pad_length=0
    ):
        """Load the label table and remember the padding target.

        Args:
            data_path: Path to the table. Read with ``pd.read_parquet`` when the
                path contains the substring ``"parquet"``, otherwise with
                ``pd.read_csv``.
            cmpd_col: Column holding the compound identifier; becomes the index.
            cmpd_col_is_inchikey: When true, every index entry is passed through
                ``inchi2smiles`` before use.
            pad_length: Number of atoms graphs are padded up to (see ``_pad``).
        """
        reader = pd.read_parquet if "parquet" in data_path else pd.read_csv
        table = reader(data_path).set_index(cmpd_col)
        if cmpd_col_is_inchikey:
            table.index = [inchi2smiles(key) for key in table.index]

        label_columns = [
            name for name in table.columns if not name.startswith("Metadata")
        ]
        self.df = table[label_columns]
        # One entry per table row (duplicates are kept in this class).
        self.unique_smiles = self.df.index
        self.pad_length = pad_length

    def __len__(self):
        """Return the number of index entries."""
        return len(self.unique_smiles)

    def _pad(self, adj_mat, node_feat, atom_vec):
        """Zero-pad the three graph tensors up to ``self.pad_length`` atoms.

        Graphs that already have more atoms than ``pad_length`` are returned
        unchanged (they are not truncated).
        """
        missing = self.pad_length - len(atom_vec)
        if missing < 0:
            return adj_mat, node_feat, atom_vec
        padded_adj = F.pad(adj_mat, (0, missing, 0, missing), "constant", 0)
        padded_nodes = F.pad(node_feat, (0, 0, 0, missing), "constant", 0)
        padded_atoms = F.pad(atom_vec, (0, 0, 0, missing), "constant", 0)
        return padded_adj, padded_nodes, padded_atoms

    def __getitem__(self, index):
        """Return ``{"inputs": {"x_a": [adj, node_feat, atom_vec]}, "labels": ...}``.

        When several rows share the same identifier, one of them is drawn at
        random with ``DataFrame.sample``.
        """
        smiles = self.unique_smiles[index]

        adjacency, atom_features = smiles2graph(smiles)
        adjacency = torch.FloatTensor(adjacency)
        atom_features = torch.FloatTensor(atom_features)
        # Column of ones, one per atom; padding rows are filled with zeros.
        atom_presence = torch.ones(len(atom_features), 1)
        graph_tensors = self._pad(adjacency, atom_features, atom_presence)

        label_rows = self.df.loc[smiles]
        has_several_rows = len(label_rows.shape) > 1 and len(label_rows) > 1
        if has_several_rows:
            label_rows = label_rows.sample(1).iloc[0]
        label_tensor = torch.FloatTensor(label_rows.values)

        graph_input = [torch.FloatTensor(part) for part in graph_tensors]
        return {"inputs": {"x_a": graph_input}, "labels": label_tensor}


class SupervisedGraphDatasetJUMP(SupervisedGraphDataset):
    """Variant of ``SupervisedGraphDataset`` that exposes each compound once.

    The parent keeps one entry per table row; here ``unique_smiles`` is
    de-duplicated, so the random row sampling in the parent's ``__getitem__``
    picks among the rows of a repeated compound.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to ``SupervisedGraphDataset.__init__``."""
        # The first argument to super() must be THIS class. Passing the parent
        # class starts the lookup after the parent, which skips
        # SupervisedGraphDataset.__init__ and leaves ``self.df`` undefined.
        super(SupervisedGraphDatasetJUMP, self).__init__(*args, **kwargs)
        self.unique_smiles = self.df.index.unique()


class DualInputDatasetJUMP(Dataset):
    """Dataset pairing a compound fingerprint with one morphology profile.

    Each item corresponds to one distinct SMILES string; the morphology
    profile is drawn at random from the rows recorded for that compound.
    """

    def __init__(self, data_path):
        """Load the profile table and collect the distinct SMILES.

        Args:
            data_path: Path to the table. Read with ``pd.read_parquet`` when the
                path contains the substring ``"parquet"``, otherwise with
                ``pd.read_csv``. If it has no ``Metadata_SMILES`` column, one is
                derived from ``Metadata_InChI`` via ``inchi2smiles``.
        """
        if "parquet" in data_path:
            self.df = pd.read_parquet(data_path)
        else:
            self.df = pd.read_csv(data_path)

        self.smiles_col = "Metadata_SMILES"
        if self.smiles_col not in self.df.columns:
            self.df[self.smiles_col] = [
                inchi2smiles(s) if s is not None else None
                for s in self.df["Metadata_InChI"]
            ]

        self.unique_smiles = [
            s for s in self.df[self.smiles_col].unique() if s is not None
        ]

        # Every column without the "Metadata_" prefix is a morphology feature.
        self.morph_col = [c for c in self.df.columns if not c.startswith("Metadata_")]
        # Cache of per-SMILES row masks; filled lazily by subclasses.
        self.smiles2mask = {}

    def _create_index(self):
        """Return ``{smiles: array of row positions}`` for the profile table.

        The position array is computed once per distinct key rather than once
        per table row; the resulting mapping (keys, key order, values) is the
        same as a per-row build, without the quadratic cost.
        """
        smiles = self.df[self.smiles_col].values
        # dict.fromkeys de-duplicates with dict-key semantics and keeps the
        # first-occurrence order.
        distinct_keys = dict.fromkeys(smiles)
        return {s: np.argwhere(smiles == s).reshape(-1) for s in distinct_keys}

    def __len__(self):
        """Return the number of distinct SMILES."""
        return len(self.unique_smiles)

    def __getitem__(self, index):
        """Return fingerprint (``x_a``), one sampled profile (``x_b``) and a dummy label.

        The label is always ``tensor([-1.])``.
        """
        smiles = self.unique_smiles[index]
        cmpd_feat = smiles2fp(smiles)

        df = self.df[self.df[self.smiles_col] == smiles]
        # One randomly chosen replicate row for this compound.
        morph_feat = df.sample(1)[self.morph_col].values.astype(float).flatten()

        labels = torch.Tensor([-1])
        return {
            "inputs": {
                "x_a": torch.FloatTensor(cmpd_feat),
                "x_b": torch.FloatTensor(morph_feat),
            },
            "labels": labels,
        }


class DualInputGraphDatasetJUMP(DualInputDatasetJUMP):
    """Graph-input flavour of ``DualInputDatasetJUMP``.

    ``x_a`` is a padded ``[adjacency, node features, atom vector]`` triple
    built by ``smiles2graph`` instead of a fingerprint.
    """

    def __init__(self, pad_length, *args, **kwargs):
        """Store ``pad_length`` and pass everything else to the parent."""
        super().__init__(*args, **kwargs)
        self.pad_length = pad_length

    def _pad(self, adj_mat, node_feat, atom_vec):
        """Zero-pad the graph tensors up to ``self.pad_length`` atoms.

        Graphs with more atoms than ``pad_length`` are returned untouched.
        """
        missing = self.pad_length - len(atom_vec)
        if missing < 0:
            return adj_mat, node_feat, atom_vec
        padded_adj = F.pad(adj_mat, (0, missing, 0, missing), "constant", 0)
        padded_nodes = F.pad(node_feat, (0, 0, 0, missing), "constant", 0)
        padded_atoms = F.pad(atom_vec, (0, 0, 0, missing), "constant", 0)
        return padded_adj, padded_nodes, padded_atoms

    def __getitem__(self, index):
        """Return padded graph (``x_a``), one sampled profile (``x_b``), dummy label."""
        smiles = self.unique_smiles[index]

        adjacency, atom_features = smiles2graph(smiles)
        adjacency = torch.FloatTensor(adjacency)
        atom_features = torch.FloatTensor(atom_features)
        atom_presence = torch.ones(len(atom_features), 1)
        graph_tensors = self._pad(adjacency, atom_features, atom_presence)

        # Row masks are memoised per SMILES so the column comparison runs
        # only the first time a compound is requested.
        if smiles not in self.smiles2mask:
            self.smiles2mask[smiles] = self.df[self.smiles_col] == smiles
        compound_rows = self.df[self.smiles2mask[smiles]]

        sampled_row = compound_rows.sample(1)[self.morph_col]
        morph_feat = sampled_row.values.astype(float).flatten()
        dummy_label = torch.Tensor([-1])

        inputs = {
            "x_a": [torch.FloatTensor(part) for part in graph_tensors],
            "x_b": torch.FloatTensor(morph_feat),
        }
        return {"inputs": inputs, "labels": dummy_label}


class TripleInputGraphDatasetJUMP(DualInputGraphDatasetJUMP):
    """Dataset yielding a molecular graph, a morphology profile and a genomic profile.

    Items are indexed by the union of the SMILES found in the morphology
    table and in the genomic table. When a compound is missing from one of the
    tables, that modality is returned as a vector filled with ``-1``.
    """

    def __init__(self, data_path, genomic_data_path, pad_length, *args, **kwargs):
        """Load both tables and build the list of compounds.

        The parent constructors are not invoked; all attributes used by this
        class are set here. Extra ``*args`` / ``**kwargs`` are accepted and
        ignored.

        Args:
            data_path: Morphology table (parquet if the path contains
                ``"parquet"``, otherwise CSV).
            genomic_data_path: Genomic table (parquet).
            pad_length: Number of atoms graphs are padded up to.
        """
        # Define SMILES column name first
        self.smiles_col = "Metadata_SMILES"

        # Load cell data
        if "parquet" in data_path:
            self.df = pd.read_parquet(data_path)
        else:
            self.df = pd.read_csv(data_path)

        # Load genomic data
        self.genomic_df = pd.read_parquet(genomic_data_path)

        # Ensure SMILES column exists in both datasets
        for df in [self.df, self.genomic_df]:
            if self.smiles_col not in df.columns:
                df[self.smiles_col] = [
                    inchi2smiles(s) if s is not None else None
                    for s in df["Metadata_InChI"]
                ]

        # Get unique SMILES from each dataset (excluding None)
        cell_smiles = {s for s in self.df[self.smiles_col].unique() if s is not None}
        genomic_smiles = {
            s for s in self.genomic_df[self.smiles_col].unique() if s is not None
        }

        # Print initial statistics
        print("\nSMILES Statistics:")
        print(f"Cell data unique SMILES: {len(cell_smiles)}")
        print(f"Genomic data unique SMILES: {len(genomic_smiles)}")
        print(f"Common SMILES (intersection): {len(cell_smiles.intersection(genomic_smiles))}")

        # Combine all unique SMILES (excluding None). Set iteration order for
        # strings changes between interpreter runs (hash randomisation), which
        # would silently remap index -> compound and invalidate index-based
        # split files, so the union is sorted. key=str keeps the sort total
        # even if a non-string missing-value marker slipped through.
        self.unique_smiles = sorted(cell_smiles | genomic_smiles, key=str)

        # Print final dataset composition
        print("\nFinal Dataset:")
        print(f"Total unique SMILES: {len(self.unique_smiles)}")

        # Set remaining attributes
        self.pad_length = pad_length
        self.genomic_cols = [c for c in self.genomic_df.columns if not c.startswith("Metadata_")]
        self.morph_cols = [c for c in self.df.columns if not c.startswith("Metadata_")]

    def _pad(self, adj_mat, node_feat, atom_vec):
        """Zero-pad the graph tensors up to ``self.pad_length`` atoms.

        Graphs with more atoms than ``pad_length`` are returned unchanged.
        """
        p = self.pad_length - len(atom_vec)
        if p >= 0:
            adj_mat = F.pad(adj_mat, (0, p, 0, p), "constant", 0)
            node_feat = F.pad(node_feat, (0, 0, 0, p), "constant", 0)
            atom_vec = F.pad(atom_vec, (0, 0, 0, p), "constant", 0)
        return adj_mat, node_feat, atom_vec

    def __len__(self):
        """Return the number of distinct compounds across both tables."""
        return len(self.unique_smiles)

    def __getitem__(self, index):
        """Return graph (``x_a``), morphology (``x_b``), genomic (``x_c``) and a dummy label.

        For each modality one matching row is sampled at random; a compound
        absent from a table gets a ``-1``-filled vector of that table's
        feature width.
        """
        smiles = self.unique_smiles[index]
        # Get graph features (x_a)
        adj_mat, node_feat = smiles2graph(smiles)
        adj_mat = torch.FloatTensor(adj_mat)
        node_feat = torch.FloatTensor(node_feat)
        atom_vec = torch.ones(len(node_feat), 1)
        cmpd_feat = self._pad(adj_mat, node_feat, atom_vec)

        # Get morphological features (x_b)
        morph_mask = self.df[self.smiles_col] == smiles
        if morph_mask.any():
            morph_feat = self.df[morph_mask].sample(1)[self.morph_cols].values.astype(float).flatten()
        else:
            morph_feat = -1 * np.ones(len(self.morph_cols))

        # Get genomic features (x_c)
        genomic_mask = self.genomic_df[self.smiles_col] == smiles
        if genomic_mask.any():
            genomic_feat = self.genomic_df[genomic_mask].sample(1)[self.genomic_cols].values.astype(float).flatten()
        else:
            genomic_feat = -1 * np.ones(len(self.genomic_cols))

        return {
            "inputs": {
                "x_a": [torch.FloatTensor(f) for f in cmpd_feat],
                "x_b": torch.FloatTensor(morph_feat),
                "x_c": torch.FloatTensor(genomic_feat)
            },
            "labels": torch.Tensor([-1])
        }


class dataset_used_for_hdf5_conversion(DualInputGraphDatasetJUMP):
    """Dataset class for handling molecular data with cell line-specific genomic features.
    
    This class processes three types of data:
    1. Molecular graph features (x_a): Structural information about compounds
    2. Morphological features (x_b): Cell morphology measurements
    3. Genomic features (x_c): Gene expression data for different cell lines
    
    Each compound (SMILES) can have:
    - Morphological data only
    - Genomic data only
    - Both morphological and genomic data
    - Genomic data for multiple cell lines

    Unlike the sampling datasets above, ``__getitem__`` returns ALL morphology
    rows and ALL genomic replicates of a compound so that they can be written
    to disk once and sampled at read time.
    """

    def __init__(self, data_path, genomic_data_path, pad_length, *args, **kwargs):
        """Load both parquet tables into memory and build the lookup tables.

        The parent constructors are not invoked; extra ``*args`` / ``**kwargs``
        are accepted and ignored.

        Args:
            data_path: Parquet file with morphology profiles.
            genomic_data_path: Parquet file with genomic profiles; must contain
                ``Metadata_cell_iname``, ``Metadata_Dose_Level`` and
                ``Metadata_pert_time``.
            pad_length: Number of atoms graphs are padded up to.
        """
        # Define column names
        self.smiles_col = "Metadata_SMILES"
        self.cell_line_col = "Metadata_cell_iname"
        # Width that morphology rows are reshaped to in __getitem__.
        self.num_input_features_morph = 3479

        # Load data using pandas
        self.morph_df = pd.read_parquet(data_path)
        self.genomic_df = pd.read_parquet(genomic_data_path)

        # Get unique valid SMILES
        cell_smiles = set(self.morph_df.dropna(subset=[self.smiles_col])[self.smiles_col].unique())
        genomic_smiles = set(self.genomic_df.dropna(subset=[self.smiles_col])[self.smiles_col].unique())

        # Combine SMILES while preserving genomic data priority. Each part is
        # sorted because set iteration order for strings differs between
        # interpreter runs; without it the index -> compound mapping (and thus
        # the HDF5 row order and any index-based split) is not reproducible.
        self.unique_smiles = (
            sorted(genomic_smiles) +
            sorted(cell_smiles - genomic_smiles)
        )

        # Get feature columns
        self.morph_cols = [c for c in self.morph_df.columns if not c.startswith("Metadata_")]
        self.genomic_cols = [c for c in self.genomic_df.columns if not c.startswith("Metadata_")]

        # Create mappings for cell lines, doses, and times.
        # Codes start at 1; 0 is reserved for "no data" padding.
        self.unique_cell_lines = sorted(self.genomic_df[self.cell_line_col].unique())
        self.cell_line_to_idx = {cell: idx + 1 for idx, cell in enumerate(self.unique_cell_lines)}

        self.unique_doses = sorted(self.genomic_df['Metadata_Dose_Level'].unique())
        self.dose_to_idx = {dose: idx + 1 for idx, dose in enumerate(self.unique_doses)}

        self.unique_times = sorted(self.genomic_df['Metadata_pert_time'].unique())
        self.time_to_idx = {time: idx + 1 for idx, time in enumerate(self.unique_times)}

        print(self.cell_line_to_idx)
        print("dose",self.dose_to_idx)
        print("time", self.time_to_idx)

        self.pad_length = pad_length

    def __getitem__(self, index):
        """Return every measurement available for one compound.

        Returns:
            dict with
              - ``inputs["x_a"]``: padded ``[adjacency, node features, atom vector]``.
              - ``inputs["x_b"]``: ``(n_rows, num_input_features_morph)`` morphology
                rows, or a single ``-1`` row when the compound has none.
              - ``inputs["x_c"]``: ``(cell lines, doses, times, max_replicates,
                features)`` array, ``-1`` where no replicate exists.
              - ``inputs["cell_indices"]`` / ``["doses"]`` / ``["times"]``:
                1-based codes per replicate slot, ``0`` where no replicate exists.
              - ``labels``: ``tensor([-1.])``; ``SMILES``: the compound string.
        """
        smiles = self.unique_smiles[index]

        # Get graph features
        adj_mat, node_feat = smiles2graph(smiles)
        adj_mat = torch.FloatTensor(adj_mat)
        node_feat = torch.FloatTensor(node_feat)
        atom_vec = torch.ones(len(node_feat), 1)
        cmpd_feat = self._pad(adj_mat, node_feat, atom_vec)

        # Get morphological features
        morph_data = self.morph_df[self.morph_df[self.smiles_col] == smiles][self.morph_cols]
        morph_feat = (
            morph_data.values.astype(float).reshape(-1, self.num_input_features_morph)
            if not morph_data.empty
            else -1 * np.ones((1, self.num_input_features_morph))
        )

        # Get genomic features
        genomic_data = self.genomic_df[self.genomic_df[self.smiles_col] == smiles]

        # Calculate genomic output array size
        n_cell_lines = len(self.cell_line_to_idx)
        n_doses = len(self.dose_to_idx)
        n_times = len(self.time_to_idx)
        n_features = len(self.genomic_cols)

        # Group replicates once by (cell line, dose, time); the same grouping
        # gives both the replicate-axis length and the rows to store.
        if genomic_data.empty:
            replicate_groups = []
        else:
            replicate_groups = list(genomic_data.groupby([
                self.cell_line_col,
                'Metadata_Dose_Level',
                'Metadata_pert_time'
            ]))
        # default=1 keeps a length-1 replicate axis when nothing was grouped.
        max_replicates = max((len(group) for _, group in replicate_groups), default=1)

        # Initialize arrays with padding values
        cell_features = -1 * np.ones((n_cell_lines, n_doses, n_times, max_replicates, n_features))
        cell_indices = np.zeros((n_cell_lines, n_doses, n_times, max_replicates))
        doses = np.zeros((n_cell_lines, n_doses, n_times, max_replicates))
        times = np.zeros((n_cell_lines, n_doses, n_times, max_replicates))

        for (cell_line, dose, time), group in replicate_groups:
            # Convert 1-based codes to 0-based positions for array indexing
            cell_idx = self.cell_line_to_idx[cell_line] - 1
            dose_idx = self.dose_to_idx[dose] - 1
            time_idx = self.time_to_idx[time] - 1

            # Store all replicates of this condition at once, in row order.
            n_rep = len(group)
            cell_features[cell_idx, dose_idx, time_idx, :n_rep] = group[self.genomic_cols].values
            # Keep 1-based codes for embedding lookups
            cell_indices[cell_idx, dose_idx, time_idx, :n_rep] = self.cell_line_to_idx[cell_line]
            doses[cell_idx, dose_idx, time_idx, :n_rep] = self.dose_to_idx[dose]
            times[cell_idx, dose_idx, time_idx, :n_rep] = self.time_to_idx[time]

        return {
            "inputs": {
                "x_a": [torch.FloatTensor(f) for f in cmpd_feat],
                "x_b": torch.FloatTensor(morph_feat),
                "x_c": torch.FloatTensor(cell_features),
                "cell_indices": torch.LongTensor(cell_indices),
                "doses": torch.FloatTensor(doses),
                "times": torch.FloatTensor(times)
            },
            "labels": torch.Tensor([-1]),
            "SMILES": smiles
        }
    '''
    def collate_fn(self, batch):
        """Collate function for DataLoader.
        
        Handles batching of variable-length cell line data by:
        1. Padding to maximum number of possible cell lines
        2. Using 0 to indicate padding/missing data
        
        Args:
            batch (list): List of items from __getitem__
            
        Returns:
            dict: Contains:
                - inputs: Dict with batched x_a, x_b, x_c features, cell_indices, doses, and times
                - labels: Batched labels
        """
        # Stack molecular features
        adj_mats = torch.stack([item["inputs"]["x_a"][0] for item in batch])
        node_feats = torch.stack([item["inputs"]["x_a"][1] for item in batch])
        atom_vecs = torch.stack([item["inputs"]["x_a"][2] for item in batch])
        x_a_batch = [adj_mats, node_feats, atom_vecs]
        
        # Stack morphological features and labels
        x_b_batch = torch.stack([item["inputs"]["x_b"] for item in batch])
        labels_batch = torch.stack([item["labels"] for item in batch])
        
        # Stack genomic features and cell indices
        x_c_batch = torch.stack([item["inputs"]["x_c"] for item in batch])
        cell_indices_batch = torch.stack([item["inputs"]["cell_indices"] for item in batch])
        
        # Stack doses and times
        doses_batch = torch.stack([item["inputs"]["doses"] for item in batch])
        times_batch = torch.stack([item["inputs"]["times"] for item in batch])

        # stack SMILES
        SMILES_batch = [item["SMILES"] for item in batch]
        
        return {
            "inputs": {
                "x_a": x_a_batch,
                "x_b": x_b_batch,
                "x_c": x_c_batch,
                "cell_indices": cell_indices_batch,
                "doses": doses_batch,  # Add batched doses
                "times": times_batch   # Add batched times
            },
            "labels": labels_batch,
            "SMILES": SMILES_batch
        }
    '''

In [2]:
def generate_hdf5_data(path, dl, file_name):
    """Write every batch of ``dl`` to ``<path>/<file_name>.h5``.

    The function is written for a DataLoader with ``batch_size=1``: batch
    ``i`` is stored in row ``i`` of every dataset, the leading batch axis is
    removed with ``squeeze(0)``, and only ``batch["SMILES"][0]`` is recorded.
    An existing file with the same name is deleted first.

    Layout of the file:
      - ``inputs/x_a/{adj_mats,node_feats,atom_vecs}``: fixed-shape float32.
      - ``inputs/x_b_data`` + ``inputs/x_b_shapes``: flattened morphology rows
        and their original 2-D shape; empty / ``[0, 0]`` when the sample holds
        only ``-1`` placeholders.
      - ``inputs/x_c_features`` + ``inputs/x_c_*_indices``: sparse encoding of
        the replicate slots whose cell index is > 0.
      - ``inputs/x_c_max_replicates``: replicate-axis length per sample.
      - ``inputs/x_c_dimensions``: ``[n_cell_lines, n_doses, n_times, n_features]``.
      - ``labels`` and ``SMILES``.

    Args:
        path: Output directory.
        dl: DataLoader yielding dicts with ``"inputs"``, ``"labels"``, ``"SMILES"``.
        file_name: File stem (``.h5`` is appended).

    Returns:
        None.
    """
    num_images = len(dl)
    dataloader = dl

    h5_path = os.path.join(path, f"{file_name}.h5")
    if os.path.exists(h5_path):
        os.remove(h5_path)

    # The context manager closes the HDF5 handle even when a batch fails
    # part-way through, so no open handle is left behind on a partial file.
    with h5py.File(h5_path, "w") as file:
        # Get sample batch (only used to read off the array shapes)
        sample_batch = next(iter(dataloader))
        inputs = sample_batch["inputs"]
        labels = sample_batch["labels"]

        # Create a group for inputs
        inputs_group = file.create_group("inputs")

        # Store dimensions for x_c reconstruction
        inputs_group.create_dataset("x_c_dimensions", (4,), dtype=np.int32)  # [n_cell_lines, n_doses, n_times, n_features]
        inputs_group["x_c_dimensions"][()] = np.array([
            inputs["x_c"].shape[1],  # n_cell_lines
            inputs["x_c"].shape[2],  # n_doses
            inputs["x_c"].shape[3],  # n_times
            inputs["x_c"].shape[-1]  # n_features
        ])

        # Molecular features (x_a)
        inputs_group.create_dataset(
            "x_a/adj_mats", (num_images, *inputs["x_a"][0].shape[1:]), dtype=np.float32
        )
        inputs_group.create_dataset(
            "x_a/node_feats", (num_images, *inputs["x_a"][1].shape[1:]), dtype=np.float32
        )
        inputs_group.create_dataset(
            "x_a/atom_vecs", (num_images, *inputs["x_a"][2].shape[1:]), dtype=np.float32
        )

        # Variable length for x_b
        inputs_group.create_dataset("x_b_data", (num_images,), dtype=h5py.vlen_dtype(np.dtype('float32')))
        inputs_group.create_dataset("x_b_shapes", (num_images, 2), dtype=np.int32)

        # For x_c and related features
        inputs_group.create_dataset("x_c_features", (num_images,), dtype=h5py.vlen_dtype(np.dtype('float32')))
        inputs_group.create_dataset("x_c_cell_indices", (num_images,), dtype=h5py.vlen_dtype(np.dtype('int32')))
        inputs_group.create_dataset("x_c_dose_indices", (num_images,), dtype=h5py.vlen_dtype(np.dtype('int32')))
        inputs_group.create_dataset("x_c_time_indices", (num_images,), dtype=h5py.vlen_dtype(np.dtype('int32')))
        inputs_group.create_dataset("x_c_replicate_indices", (num_images,), dtype=h5py.vlen_dtype(np.dtype('int32')))
        inputs_group.create_dataset("x_c_max_replicates", (num_images,), dtype=np.int32)

        file.create_dataset("labels", (num_images, *labels.shape[1:]), dtype=np.float32)
        file.create_dataset("SMILES", (num_images,), dtype=h5py.string_dtype())

        # Save all data
        for i, batch in enumerate(dataloader):
            inputs = batch["inputs"]
            labels = batch["labels"]
            smiles = batch["SMILES"]

            # Save molecular features (x_a)
            inputs_group["x_a/adj_mats"][i] = inputs["x_a"][0].numpy()
            inputs_group["x_a/node_feats"][i] = inputs["x_a"][1].numpy()
            inputs_group["x_a/atom_vecs"][i] = inputs["x_a"][2].numpy()

            # Handle x_b data: an all -1 array is the "no morphology" placeholder
            x_b = inputs["x_b"].squeeze(0).numpy()
            if np.all(x_b == -1):
                inputs_group["x_b_data"][i] = np.array([], dtype=np.float32)
                inputs_group["x_b_shapes"][i] = np.array([0, 0])
            else:
                x_b_flat = x_b.reshape(-1)
                inputs_group["x_b_data"][i] = x_b_flat
                inputs_group["x_b_shapes"][i] = np.array(x_b.shape)

            # Handle x_c data efficiently
            x_c = inputs["x_c"].squeeze(0).numpy()
            cell_indices = inputs["cell_indices"].squeeze(0).numpy()

            # Find non-empty entries (where cell_indices > 0)
            valid_mask = cell_indices > 0
            inputs_group["x_c_max_replicates"][i] = x_c.shape[3]  # Save max_replicates for this sample

            if np.any(valid_mask):
                # Get indices where data exists
                idx_arrays = np.where(valid_mask)

                # Store features and corresponding indices
                inputs_group["x_c_features"][i] = x_c[valid_mask].reshape(-1)
                inputs_group["x_c_cell_indices"][i] = idx_arrays[0]
                inputs_group["x_c_dose_indices"][i] = idx_arrays[1]
                inputs_group["x_c_time_indices"][i] = idx_arrays[2]
                inputs_group["x_c_replicate_indices"][i] = idx_arrays[3]
            else:
                # Store empty arrays if no data
                inputs_group["x_c_features"][i] = np.array([], dtype=np.float32)
                inputs_group["x_c_cell_indices"][i] = np.array([], dtype=np.int32)
                inputs_group["x_c_dose_indices"][i] = np.array([], dtype=np.int32)
                inputs_group["x_c_time_indices"][i] = np.array([], dtype=np.int32)
                inputs_group["x_c_replicate_indices"][i] = np.array([], dtype=np.int32)

            file["labels"][i] = labels.numpy()
            file["SMILES"][i] = smiles[0]

    return

In [3]:
def sanitize_smiles(smiles):
    """Return the canonical SMILES for ``smiles``, or ``None`` if it is unusable.

    ``None`` is returned for empty / missing input, for strings RDKit cannot
    parse, and whenever parsing or the property (valence) sanitization step
    raises.
    """
    if smiles and not pd.isna(smiles):
        try:
            mol = Chem.MolFromSmiles(smiles)
            if mol is not None:
                # Run only the property sanitization step (valence check).
                Chem.SanitizeMol(mol, sanitizeOps=Chem.SANITIZE_PROPERTIES)
                return Chem.MolToSmiles(mol, canonical=True)
        except Exception:
            # Best-effort filter: any RDKit failure means "not usable".
            pass
    return None

In [4]:
def _split_data(dataset: Dataset, splits: Dict[str, str]) -> Dict[str, Dataset]:
    print(f"genomic and cell unique SMILES, {len(dataset.unique_smiles)}")
    if splits is None:
        unique_smiles = dataset.unique_smiles
        total_smiles = len(unique_smiles)
        train_idx = np.random.choice(
            total_smiles, size=int(0.9 * total_smiles), replace=False
        )
        val_idx = [i for i in range(total_smiles) if i not in train_idx]
        return {
            "train": Subset(dataset, train_idx),
            "val": Subset(dataset, val_idx),
            "test": Subset(dataset, val_idx),
        }

    assert "train" in splits and "val" in splits
    split_dataset = {}
    for k, v in splits.items():
        print(f"Split {k}: {v}")
        df_split = pd.read_csv(v)
        if "index" in df_split.columns:
            indices = df_split["index"].values
            # Map indices to SMILES
            smiles_list = [dataset.unique_smiles[i] for i in indices if i < len(dataset.unique_smiles)]

            # Sanitize SMILES
            sanitized_smiles = [sanitize_smiles(s) for s in smiles_list]

            # Filter out invalid SMILES
            valid_indices = [i for i, s in zip(indices, sanitized_smiles) if s is not None]

            # Print sanitization statistics
            print(f"Original indices count: {len(indices)}")
            print(f"Valid indices count: {len(valid_indices)}")
            print(f"Removed indices count: {len(indices) - len(valid_indices)}")

        else:
            split_smiles = df_split["SMILES"].unique()
            sanitized_smiles = [sanitize_smiles(s) for s in split_smiles]
            valid_smiles = [s for s in sanitized_smiles if s is not None]

            # Print sanitization statistics
            print("Using SMILE index")
            print(f"Original SMILES count: {len(split_smiles)}")
            print(f"Valid SMILES count: {len(valid_smiles)}")
            print(f"Removed SMILES count: {len(split_smiles) - len(valid_smiles)}")

            valid_indices = [
                i
                for i, smiles in enumerate(dataset.unique_smiles)
                if smiles in valid_smiles
            ]
        split_dataset[k] = Subset(dataset, valid_indices)
    return split_dataset

In [5]:
from mocop.training import build_dataloaders
from mocop.dataset import CellLineTripleInputGraphDatasetJUMP
from torch.utils.data import DataLoader, Dataset

# Input tables for the conversion run. The active paths are the reduced
# "_1000" / "_test" files; the full-size alternatives are kept commented out.
# Morphology (cell feature) table:
data_path = "/scratch/work/masooda1/mocop/data/jump_data/cell_fetures_with_smiles_1000.parquet"
#data_path = "/scratch/work/masooda1/mocop/data/jump_data/cell_fetures_with_smiles.parquet"

# Genomic (LINCS) table:
#genomic_data_path = "/scratch/cs/pml/AI_drug/molecular_representation_learning/LINCS/landmark_cmp_data_min1000compounds_all_measurements.parquet"
genomic_data_path = "/scratch/cs/pml/AI_drug/molecular_representation_learning/LINCS/landmark_cmp_data_min1000compounds_all_measurements_test.parquet"


#JUMP-LINCS
# Compound split files (CSV).
# NOTE(review): `test` points at the same "-val.csv" file as `val` — confirm
# this is intended.
train = "/scratch/work/masooda1/mocop/data/LINCS_All_cell_lines/JUMP-LINCS-compound-split-0-train.csv"
val = "/scratch/work/masooda1/mocop/data/LINCS_All_cell_lines/JUMP-LINCS-compound-split-0-val.csv"
test = "/scratch/work/masooda1/mocop/data/LINCS_All_cell_lines/JUMP-LINCS-compound-split-0-val.csv"

#JUMP
#train = "/scratch/work/masooda1/mocop/data/jump/jump-compound-split-0-train.csv"
#val = "/scratch/work/masooda1/mocop/data/jump/jump-compound-split-0-val.csv"
#test = "/scratch/work/masooda1/mocop/data/jump/jump-compound-split-0-val.csv"

# LINCS
#train = "/scratch/work/masooda1/mocop/data/LINCS_All_cell_lines/LINCS-compound-split-0-train.csv"
#val = "/scratch/work/masooda1/mocop/data/LINCS_All_cell_lines/LINCS-compound-split-0-val.csv"
#test = "/scratch/work/masooda1/mocop/data/LINCS_All_cell_lines/LINCS-compound-split-0-val.csv"

In [6]:
# Build the full (non-sampling) dataset that is later dumped to HDF5.
# pad_length = 250: graphs are zero-padded up to 250 atoms.
# NOTE: `splits` is not a named parameter of
# dataset_used_for_hdf5_conversion.__init__; it is swallowed by **kwargs and
# has no effect here. The actual splitting happens in build_dataloaders.
ds = dataset_used_for_hdf5_conversion(
                                data_path = data_path, 
                                genomic_data_path = genomic_data_path,
                                pad_length = 250,
                                splits = [train, val, test]
)

{'A375': 1, 'A549': 2, 'ASC': 3, 'HA1E': 4, 'HCC515': 5, 'HEC108': 6, 'HEK293': 7, 'HELA': 8, 'HEPG2': 9, 'HT29': 10, 'JURKAT': 11, 'MCF10A': 12, 'MCF7': 13, 'MDAMB231': 14, 'NEU': 15, 'NPC': 16, 'PC3': 17, 'PHH': 18, 'SKB': 19, 'THP1': 20, 'U2OS': 21, 'VCAP': 22, 'XC.L10': 23, 'YAPC': 24}
dose {3: 1, 4: 2, 5: 3, 6: 4, 7: 5}
time {6.0: 1, 24.0: 2}


In [7]:
# One dataloader per split. batch_size must stay 1: generate_hdf5_data writes
# batch i to row i and strips the batch axis with squeeze(0).
dl = build_dataloaders(
                dataset = ds,
                batch_size = 1,
                num_workers = 1,
                splits = {"train": train, "val": val, "test":test})

Split train: /scratch/work/masooda1/mocop/data/LINCS_All_cell_lines/JUMP-LINCS-compound-split-0-train.csv
Split val: /scratch/work/masooda1/mocop/data/LINCS_All_cell_lines/JUMP-LINCS-compound-split-0-val.csv
Split test: /scratch/work/masooda1/mocop/data/LINCS_All_cell_lines/JUMP-LINCS-compound-split-0-val.csv

Dataloader Sizes:
--------------------------------------------------
TRAIN:
  • Number of samples: 754
  • Batch size: 1
  • Number of batches: 754
  • Drop last: True
--------------------------------------------------
VAL:
  • Number of samples: 35
  • Batch size: 1
  • Number of batches: 35
  • Drop last: True
--------------------------------------------------
TEST:
  • Number of samples: 35
  • Batch size: 1
  • Number of batches: 35
  • Drop last: False
--------------------------------------------------


In [8]:
# Output location for the converted split: <path>/<file_name>.h5
path = "/scratch/work/masooda1/mocop/data/dummy_data/hd5_data"
file_name = "train"
# The conversion call is left disabled; uncomment to (re)write the file.
# generate_hdf5_data deletes any existing file with the same name first.
#generate_hdf5_data(path, dl["train"], file_name)

In [36]:
import torch
import numpy as np
import h5py

def read_hdf5_data(file_path, index):
    """
    Reads and processes data from an HDF5 file for a given index.
    This function performs the random sampling of x_b rows and random replicate
    selection for x_c (and its associated arrays) directly.

    The file is expected to follow the layout written by generate_hdf5_data
    (datasets under 'inputs/...', plus 'labels' and 'SMILES'). The file is
    opened and closed on every call.

    Args:
        file_path: Path to the HDF5 file.
        index: Row (sample) to read.

    Returns:
        dict with keys 'inputs', 'labels' and 'SMILES'. 'inputs' holds
          - 'x_a': list of three tensors (adjacency, node features, atom vector),
          - 'x_b': one randomly chosen morphology row (all -1 if none stored),
          - 'x_c': tensor of shape (n_cell_lines, n_doses, n_times, n_features)
            with one replicate chosen per condition, -1 where none exists,
          - 'cell_indices', 'doses', 'times': tensors of shape
            (n_cell_lines, n_doses, n_times) with 1-based codes, 0 where no
            replicate exists.
        NOTE(review): x_c and the three code tensors are built from default
        NumPy arrays, so they come back as float64 — confirm the consumer
        expects that.
    """
    with h5py.File(file_path, 'r') as file:
        # Initialize the data dictionary.
        data = {
            'inputs': {
                'x_a': [],
                'x_b': None,
                'x_c': None,
                'cell_indices': None,
                'doses': None,
                'times': None
            },
            'labels': None,
            'SMILES': None
        }

        # ------------------------------
        # Load and process x_a
        # ------------------------------
        data['inputs']['x_a'] = [
            torch.from_numpy(file['inputs']['x_a/adj_mats'][index]),
            torch.from_numpy(file['inputs']['x_a/node_feats'][index]),
            torch.from_numpy(file['inputs']['x_a/atom_vecs'][index])
        ]

        # ------------------------------
        # Load, process, and randomly sample x_b
        # ------------------------------
        x_b_data = file['inputs']['x_b_data'][index]
        x_b_shape = file['inputs']['x_b_shapes'][index]
        if len(x_b_data) == 0:
            # If no data is available, create a default tensor.
            # 3479 is the same width as num_input_features_morph in
            # dataset_used_for_hdf5_conversion.
            x_b_tensor = torch.full((1, 3479), -1, dtype=torch.float32)
        else:
            # Restore the (n_rows, n_features) layout that was flattened on write.
            x_b_tensor = torch.from_numpy(np.array(x_b_data).reshape(x_b_shape))

        # Randomly sample one row from x_b_tensor.
        rand_idx = torch.randint(x_b_tensor.shape[0], (1,))
        data['inputs']['x_b'] = x_b_tensor[rand_idx].squeeze()

        # ------------------------------
        # Load and process x_c and associated arrays
        # ------------------------------
        # Get the full dimensions for x_c.
        dims = file['inputs']['x_c_dimensions'][()]
        n_cell_lines, n_doses, n_times, n_features = dims
        max_replicates = file['inputs']['x_c_max_replicates'][index]

        # Initialize the numpy arrays with default values.
        # (-1 marks a missing feature vector, 0 marks a missing code.)
        x_c = -1 * np.ones((n_cell_lines, n_doses, n_times, max_replicates, n_features))
        cell_indices_arr = np.zeros((n_cell_lines, n_doses, n_times, max_replicates))
        doses_arr = np.zeros((n_cell_lines, n_doses, n_times, max_replicates))
        times_arr = np.zeros((n_cell_lines, n_doses, n_times, max_replicates))

        # Retrieve the available x_c features (and related indices) from the file.
        features = file['inputs']['x_c_features'][index]
        if len(features) > 0:
            # Stored positions are 0-based array coordinates.
            cell_idx = file['inputs']['x_c_cell_indices'][index]
            dose_idx = file['inputs']['x_c_dose_indices'][index]
            time_idx = file['inputs']['x_c_time_indices'][index]
            rep_idx = file['inputs']['x_c_replicate_indices'][index]

            # Fill in available data.
            # The 1-based codes are recovered as (0-based position + 1).
            x_c[cell_idx, dose_idx, time_idx, rep_idx] = features.reshape(-1, n_features)
            cell_indices_arr[cell_idx, dose_idx, time_idx, rep_idx] = cell_idx + 1
            doses_arr[cell_idx, dose_idx, time_idx, rep_idx] = dose_idx + 1
            times_arr[cell_idx, dose_idx, time_idx, rep_idx] = time_idx + 1

        # Convert the numpy arrays to torch tensors.
        x_c_tensor = torch.from_numpy(x_c)
        cell_indices_tensor = torch.from_numpy(cell_indices_arr)
        doses_tensor = torch.from_numpy(doses_arr)
        times_tensor = torch.from_numpy(times_arr)

        # ------------------------------
        # Random replicate selection for x_c
        # ------------------------------
        # x_c_tensor shape: [n_cell_lines, n_doses, n_times, max_replicates, n_features]
        n_cell_lines, n_doses, n_times, n_replicates, n_features = x_c_tensor.shape

        # Create a mask of valid replicates (assuming a replicate is valid if its first feature != -1).
        valid_mask = x_c_tensor[..., 0] != -1

        # For each (cell line, dose, time), determine if any replicate is valid.
        valid_positions = valid_mask.any(dim=-1)

        # Randomly choose a replicate index for each (cell line, dose, time).
        random_replicate_idx = torch.randint(0, n_replicates, (n_cell_lines, n_doses, n_times))

        # Check whether the randomly chosen replicate is valid.
        chosen_valid = valid_mask.gather(dim=-1, index=random_replicate_idx.unsqueeze(-1)).squeeze(-1)

        # For positions where the random choice is not valid, select the first valid replicate.
        # (Where no replicate is valid the chosen slot only holds padding and is
        # overwritten with the padding value by the torch.where calls below.)
        first_valid_idx = valid_mask.float().argmax(dim=-1)
        chosen_idx = torch.where(chosen_valid, random_replicate_idx, first_valid_idx)

        # Build index tensors for advanced indexing.
        cell_idx_range = torch.arange(n_cell_lines).view(n_cell_lines, 1, 1).expand(n_cell_lines, n_doses, n_times)
        dose_idx_range = torch.arange(n_doses).view(1, n_doses, 1).expand(n_cell_lines, n_doses, n_times)
        time_idx_range = torch.arange(n_times).view(1, 1, n_times).expand(n_cell_lines, n_doses, n_times)

        # Select the corresponding replicate from x_c and the associated arrays.
        sampled_x_c = x_c_tensor[cell_idx_range, dose_idx_range, time_idx_range, chosen_idx, :]
        sampled_x_c = torch.where(
            valid_positions.unsqueeze(-1),
            sampled_x_c,
            torch.full_like(sampled_x_c, -1)
        )

        sampled_cell_indices = cell_indices_tensor[cell_idx_range, dose_idx_range, time_idx_range, chosen_idx]
        sampled_cell_indices = torch.where(
            valid_positions,
            sampled_cell_indices,
            torch.full_like(sampled_cell_indices, 0)
        )

        sampled_doses = doses_tensor[cell_idx_range, dose_idx_range, time_idx_range, chosen_idx]
        sampled_doses = torch.where(
            valid_positions,
            sampled_doses,
            torch.full_like(sampled_doses, 0)
        )

        sampled_times = times_tensor[cell_idx_range, dose_idx_range, time_idx_range, chosen_idx]
        sampled_times = torch.where(
            valid_positions,
            sampled_times,
            torch.full_like(sampled_times, 0)
        )

        # Update the data dictionary with the processed x_c and associated arrays.
        data['inputs']['x_c'] = sampled_x_c
        data['inputs']['cell_indices'] = sampled_cell_indices
        data['inputs']['doses'] = sampled_doses
        data['inputs']['times'] = sampled_times

        # ------------------------------
        # Process labels and SMILES
        # ------------------------------
        data['labels'] = torch.from_numpy(file['labels'][index])
        data['SMILES'] = file['SMILES'][index].decode('utf-8')

        # (Optional) Print some debugging info.
        #print(f"Valid replicates count: {valid_mask.sum().item()} for SMILES: {data['SMILES']}")

    return data


class HDF5Dataset(torch.utils.data.Dataset):
    """
    Dataset whose samples are stored in a single HDF5 file.

    Every item is obtained by calling ``read_hdf5_data`` on the stored path
    with the requested index; this class adds no processing of its own.
    ``collate_fn`` merges a list of such items into one batch dictionary.
    """

    def __init__(self, file_path):
        self.file_path = file_path
        # The file is opened only long enough to count its 'SMILES' entries.
        with h5py.File(file_path, 'r') as h5_file:
            self.length = len(h5_file['SMILES'])

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        # Everything (reading, processing) is delegated to read_hdf5_data.
        sample = read_hdf5_data(self.file_path, index=idx)
        return sample

    def collate_fn(self, batch):
        """
        Combine a list of samples into a single batch dictionary.

        Args:
            batch (list): Samples as returned by ``__getitem__``; each holds
                an 'inputs' dict plus 'labels' and 'SMILES'.

        Returns:
            dict: 'inputs' (stacked tensors, with 'x_a' kept as a list of
            three stacked tensors), 'labels' (stacked tensor) and 'SMILES'
            (numpy array of the per-sample strings).
        """
        def stack_input(key):
            # Stack one entry of the 'inputs' dict along a new leading axis.
            return torch.stack([sample['inputs'][key] for sample in batch])

        # x_a is a 3-element sequence per sample; each element is squeezed
        # before being stacked across the batch.
        x_a_parts = []
        for part in range(3):
            squeezed = [sample['inputs']['x_a'][part].squeeze() for sample in batch]
            x_a_parts.append(torch.stack(squeezed))

        inputs = {'x_a': x_a_parts}
        for key in ('x_b', 'x_c', 'cell_indices', 'doses', 'times'):
            inputs[key] = stack_input(key)

        return {
            'inputs': inputs,
            'labels': torch.stack([sample['labels'] for sample in batch]),
            'SMILES': np.array([sample['SMILES'] for sample in batch]),
        }


def create_dataloader(file_path, batch_size, num_workers, shuffle):
    """
    Create a DataLoader for the HDF5 dataset.

    Args:
        file_path (str): Path to the HDF5 file.
        batch_size (int): Batch size.
        num_workers (int): Number of worker processes.
        shuffle (bool): Whether to shuffle the data.

    Returns:
        torch.utils.data.DataLoader: DataLoader over an ``HDF5Dataset`` built
        from ``file_path``, batching with that dataset's ``collate_fn``.
    """
    dataset = HDF5Dataset(file_path)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        # HDF5Dataset defines collate_fn as a method. The bare name
        # `collate_fn` would only resolve if some other cell had left a
        # global of that name in the kernel, so bind the method explicitly.
        collate_fn=dataset.collate_fn,
        pin_memory=True
    )
    return dataloader


In [46]:
import time
train_loader = create_dataloader(
    file_path='/scratch/work/masooda1/mocop/data/dummy_data/hd5_data/train.h5',
    batch_size=7,
    num_workers = 0,
    shuffle=False
)

# Time the creation of the iterator plus the first batch. perf_counter is a
# monotonic, high-resolution clock meant for measuring elapsed intervals.
start = time.perf_counter()
batch = next(iter(train_loader))
end = time.perf_counter()
print(end - start)

# read_hdf5_data picks a single replicate per (cell line, dose, time) slot, so
# the batched tensors carry no replicate axis (see the shapes printed below).
x_a = batch['inputs']['x_a']  # List of 3 tensors
x_b = batch['inputs']['x_b']  # [batch_size, ...]
x_c = batch['inputs']['x_c']  # [batch_size, n_cell_lines, n_doses, n_times, n_features]
cell_indices = batch['inputs']['cell_indices']  # [batch_size, n_cell_lines, n_doses, n_times]
doses = batch['inputs']['doses']  # Same shape as cell_indices
times = batch['inputs']['times']  # Same shape as cell_indices
labels = batch['labels']
smiles = batch['SMILES']
print(x_b.shape, x_c.shape, x_a[0].shape, x_a[1].shape, x_a[2].shape)

0.04106926918029785
torch.Size([7, 3479]) torch.Size([7, 24, 5, 2, 978]) torch.Size([7, 250, 250]) torch.Size([7, 250, 75]) torch.Size([7, 250])


In [47]:
# Shapes of the three per-position metadata tensors taken from the batch above.
print(cell_indices.shape, doses.shape, times.shape)

torch.Size([7, 24, 5, 2]) torch.Size([7, 24, 5, 2]) torch.Size([7, 24, 5, 2])


In [48]:
# SMILES string used by the cells below to filter morph_df and genomic_df.
selected_SMILES = "COc1cccc(OC)c1-c1cc(nn1-c1ccnc2cc(Cl)ccc12)C(=O)NC1(C2CC3CC(C2)CC1C3)C(O)=O"

In [49]:
# A position counts as populated when at least one of its features differs
# from -1; reducing over the last axis drops the feature dimension.
valid_positions = (x_c != -1).any(dim=-1)
count = int(valid_positions.sum())
print(f"Number of positions with actual data: {count}")

Number of positions with actual data: 4


In [50]:
# Entries of cell_indices that differ from 0 (read_hdf5_data writes 0 at the
# positions it treats as invalid).
cell_indices[cell_indices != 0]

tensor([15., 15., 16., 16.], dtype=torch.float64)

In [51]:
# Entries of doses that differ from 0 (read_hdf5_data writes 0 at the
# positions it treats as invalid).
doses[doses != 0]

tensor([4., 4., 4., 4.], dtype=torch.float64)

In [52]:
# Entries of times that differ from 0 (read_hdf5_data writes 0 at the
# positions it treats as invalid).
times[times != 0]

tensor([1., 2., 1., 2.], dtype=torch.float64)

In [53]:
# Shape of the batched x_c tensor.
x_c.shape

torch.Size([7, 24, 5, 2, 978])

In [54]:
# First six features of x_c at batch item 6, axis indices (15, 3, 0).
# NOTE(review): these indices look like the 1-based codes printed by the
# conversion cell (NPC=16, dose level 6 -> 4, time 6.0 -> 1) minus one — confirm.
x_c[6,15,3,0,:6]

tensor([ 0.3877,  0.8424, -0.5394, -0.6000,  0.7487, -0.1743],
       dtype=torch.float64)

In [32]:
# Load the two source tables from parquet.
# NOTE(review): data_path and genomic_data_path are not defined in this part of
# the notebook; they must already exist in the kernel from an earlier cell.
morph_df = pd.read_parquet(data_path)
genomic_df = pd.read_parquet(genomic_data_path)

In [33]:
# Rows of morph_df whose Metadata_SMILES equals the selected compound.
morph_df[morph_df.Metadata_SMILES == selected_SMILES]

Unnamed: 0,Metadata_SMILES,Metadata_JCP2022,Metadata_InChIKey,Metadata_InChI,Metadata_Source,Metadata_Plate,Metadata_Well,Cells_AreaShape_BoundingBoxMaximum_X,Cells_AreaShape_BoundingBoxMaximum_Y,Cells_AreaShape_BoundingBoxMinimum_X,...,Nuclei_Texture_SumEntropy_RNA_10_02_256,Nuclei_Texture_SumEntropy_RNA_10_03_256,Nuclei_Texture_SumEntropy_RNA_3_00_256,Nuclei_Texture_SumEntropy_RNA_3_01_256,Nuclei_Texture_SumEntropy_RNA_3_02_256,Nuclei_Texture_SumEntropy_RNA_3_03_256,Nuclei_Texture_SumEntropy_RNA_5_00_256,Nuclei_Texture_SumEntropy_RNA_5_01_256,Nuclei_Texture_SumEntropy_RNA_5_02_256,Nuclei_Texture_SumEntropy_RNA_5_03_256


In [34]:
# For the selected compound, count the non-null entries of every remaining
# column within each (cell line, dose level, perturbation time) group.
genomic_df[genomic_df.Metadata_SMILES == selected_SMILES].groupby(["Metadata_cell_iname", "Metadata_Dose_Level", "Metadata_pert_time"]).count()

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Metadata_cid,10007,1001,10013,10038,10046,10049,10051,10057,10058,...,Metadata_det_wells,Metadata_det_plates,Metadata_distil_ids,Metadata_build_name,Metadata_project_code,Metadata_cmap_name_y,Metadata_is_exemplar_sig,Metadata_is_ncs_sig,Metadata_is_null_sig,Metadata_Dose_Bins
Metadata_cell_iname,Metadata_Dose_Level,Metadata_pert_time,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1
NEU,6,6.0,1,1,1,1,1,1,1,1,1,1,...,1,1,1,0,1,1,1,1,1,1
NEU,6,24.0,1,1,1,1,1,1,1,1,1,1,...,1,1,1,0,1,1,1,1,1,1
NPC,6,6.0,2,2,2,2,2,2,2,2,2,2,...,2,2,2,0,2,2,2,2,2,2
NPC,6,24.0,2,2,2,2,2,2,2,2,2,2,...,2,2,2,0,2,2,2,2,2,2


In [35]:
# Rows of genomic_df for the selected compound, restricted to cell line "NPC"
# and perturbation time 6.0.
genomic_df[(genomic_df.Metadata_SMILES == selected_SMILES)
            & (genomic_df.Metadata_cell_iname == "NPC")
            & (genomic_df.Metadata_pert_time == 6.0)]

Unnamed: 0,Metadata_cid,10007,1001,10013,10038,10046,10049,10051,10057,10058,...,Metadata_det_plates,Metadata_distil_ids,Metadata_build_name,Metadata_project_code,Metadata_cmap_name_y,Metadata_is_exemplar_sig,Metadata_is_ncs_sig,Metadata_is_null_sig,Metadata_Dose_Bins,Metadata_Dose_Level
173,NMH001_FIBRNPC_6H:BRD-K25075681-001-01-2:10,-0.42735,1.43825,-0.2945,-2.0058,-0.43305,-0.0918,0.98695,-0.13975,-0.6863,...,NMH001_FIBRNPC_6H_X1_B6_DUO52HI53LO|NMH001_FIB...,NMH001_FIBRNPC_6H_X1_B6_DUO52HI53LO:P11|NMH001...,,NMH,SR-48692,0,0.0,0.0,"(1.0, 10.0]",6
189,NMH001_NPC_6H:BRD-K25075681-001-01-2:10,0.387717,0.842399,-0.539441,-0.599967,0.748698,-0.174282,-0.028204,0.925945,-0.877023,...,NMH001_NPC_6H_X1_B6_DUO52HI53LO|NMH001_NPC_6H_...,NMH001_NPC_6H_X1_B6_DUO52HI53LO:P11|NMH001_NPC...,,NMH,SR-48692,0,1.0,0.0,"(1.0, 10.0]",6


In [172]:
# Instantiate dataset_used_for_hdf5_conversion with the two parquet paths, a
# pad length of 250 and the three splits.
# NOTE(review): data_path, genomic_data_path, train, val and test are defined
# outside this section of the notebook.
ds = dataset_used_for_hdf5_conversion(
                                data_path = data_path, 
                                genomic_data_path = genomic_data_path,
                                pad_length = 250,
                                splits = [train, val, test]
)

{'A375': 1, 'A549': 2, 'ASC': 3, 'HA1E': 4, 'HCC515': 5, 'HEC108': 6, 'HEK293': 7, 'HELA': 8, 'HEPG2': 9, 'HT29': 10, 'JURKAT': 11, 'MCF10A': 12, 'MCF7': 13, 'MDAMB231': 14, 'NEU': 15, 'NPC': 16, 'PC3': 17, 'PHH': 18, 'SKB': 19, 'THP1': 20, 'U2OS': 21, 'VCAP': 22, 'XC.L10': 23, 'YAPC': 24}
dose {3: 1, 4: 2, 5: 3, 6: 4, 7: 5}
time {6.0: 1, 24.0: 2}


In [71]:
import time
train_loader = create_dataloader(
    file_path='/scratch/work/masooda1/mocop/data/dummy_data/hd5_data/train.h5',
    batch_size=128,
    num_workers = 0,
    shuffle=True
)

# Time the creation of the iterator plus the first batch. perf_counter is a
# monotonic, high-resolution clock meant for measuring elapsed intervals.
start = time.perf_counter()
batch = next(iter(train_loader))
end = time.perf_counter()
print(end - start)

# read_hdf5_data picks a single replicate per (cell line, dose, time) slot, so
# the batched tensors carry no replicate axis (see the shapes printed below).
x_a = batch['inputs']['x_a']  # List of 3 tensors
x_b = batch['inputs']['x_b']  # [batch_size, ...]
x_c = batch['inputs']['x_c']  # [batch_size, n_cell_lines, n_doses, n_times, n_features]
cell_indices = batch['inputs']['cell_indices']  # [batch_size, n_cell_lines, n_doses, n_times]
doses = batch['inputs']['doses']  # Same shape as cell_indices
times = batch['inputs']['times']  # Same shape as cell_indices
labels = batch['labels']
smiles = batch['SMILES']
print(x_b.shape, x_c.shape, x_a[0].shape, x_a[1].shape, x_a[2].shape)

1.0559806823730469
torch.Size([128, 3479]) torch.Size([128, 24, 5, 2, 978]) torch.Size([128, 250, 250]) torch.Size([128, 250, 75]) torch.Size([128, 250])


In [170]:
# Read one sample straight from the file and inspect the x_c / x_b shapes.
# NOTE(review): this call passes `indices=` whereas HDF5Dataset.__getitem__
# passes `index=` — confirm which keyword read_hdf5_data currently accepts.
batch = read_hdf5_data('/scratch/work/masooda1/mocop/data/dummy_data/hd5_data/train.h5', indices = 0)
print(batch['inputs']['x_c'].shape)
print(batch['inputs']['x_b'][0].shape)
#torch.all(batch['inputs']['x_c'] != -1, dim = -1).item()

torch.Size([24, 5, 2, 1, 978])
torch.Size([1, 3479])


In [171]:
# Shape of the first entry along x_b's leading axis.
batch['inputs']['x_b'][0].shape

torch.Size([1, 3479])

In [190]:
# Display the x_b tensor of the sample read above.
batch['inputs']['x_b']

tensor([[[ 0.5313,  0.8455,  0.5337,  ..., -1.5320, -1.5451, -1.4384]]])

In [191]:
# Walk the first 700 samples and show x_b (shape, SMILES, values) for every
# sample whose x_b contains no -1 entry.
train_h5 = '/scratch/work/masooda1/mocop/data/dummy_data/hd5_data/train.h5'
for i in range(700):
    batch = read_hdf5_data(train_h5, indices = i)
    x_b_i = batch['inputs']['x_b']
    if not torch.all(x_b_i != -1).item():
        continue
    print(x_b_i.shape, batch['SMILES'])
    print(x_b_i)

torch.Size([1, 1, 3479]) ['COC(=O)c1cc(OC)c(OC)cc1N=C(O)c1c(-c2ccccc2Cl)noc1C']
tensor([[[ 0.4693, -0.1951, -0.1710,  ...,  1.9172,  1.9902,  2.1981]]])
torch.Size([1, 1, 3479]) ['COc1cc(OC)cc(C(O)=NC(C(=O)N2CCc3ccccc32)C(C)C)c1']
tensor([[[ 1.2894,  1.1799,  1.4107,  ..., -0.9646, -0.9922, -0.9553]]])
torch.Size([1, 1, 3479]) ['COc1ccc(-c2ncccc2O)cc1']
tensor([[[-1.0919, -0.2693, -1.1531,  ..., -1.1063, -1.1618, -1.0558]]])
torch.Size([1, 1, 3479]) ['CCOc1ccc(NS(=O)(=O)c2ccc(C(=O)N3CCCC3)s2)cc1']
tensor([[[0.0425, 0.5249, 0.0443,  ..., 0.7980, 0.8656, 0.8351]]])
torch.Size([1, 1, 3479]) ['Cc1cc(F)ccc1NC(=O)C1(C)Oc2cccnc2N=C1O']
tensor([[[ 1.2437,  0.3285,  1.3114,  ..., -0.3301, -0.4357, -0.3865]]])
torch.Size([1, 1, 3479]) ['CCOc1ccc2cc(C3CC(c4ccc(C)cc4)=NN3S(C)(=O)=O)c(Cl)nc2c1']
tensor([[[ 0.1773, -2.6001,  0.2492,  ...,  0.1324,  0.1061,  0.1682]]])
torch.Size([1, 1, 3479]) ['Cc1c[nH]nc1C1CCCN(C)C1']
tensor([[[-1.6859, -1.1257, -1.8251,  ..., -0.1098, -0.0721, -0.1000]]])
torch.Si

In [173]:
# Load the phenomic table and show the rows belonging to one compound.
# NOTE(review): data_path is defined outside this section of the notebook.
phenomic_data = pd.read_parquet(data_path)
SMILE = "c1ccc(-c2n[nH]cc2-c2ccnc3ccccc23)nc1"
selected_mol = phenomic_data[(phenomic_data.Metadata_SMILES == SMILE)]
selected_mol

Unnamed: 0,Metadata_SMILES,Metadata_JCP2022,Metadata_InChIKey,Metadata_InChI,Metadata_Source,Metadata_Plate,Metadata_Well,Cells_AreaShape_BoundingBoxMaximum_X,Cells_AreaShape_BoundingBoxMaximum_Y,Cells_AreaShape_BoundingBoxMinimum_X,...,Nuclei_Texture_SumEntropy_RNA_10_02_256,Nuclei_Texture_SumEntropy_RNA_10_03_256,Nuclei_Texture_SumEntropy_RNA_3_00_256,Nuclei_Texture_SumEntropy_RNA_3_01_256,Nuclei_Texture_SumEntropy_RNA_3_02_256,Nuclei_Texture_SumEntropy_RNA_3_03_256,Nuclei_Texture_SumEntropy_RNA_5_00_256,Nuclei_Texture_SumEntropy_RNA_5_01_256,Nuclei_Texture_SumEntropy_RNA_5_02_256,Nuclei_Texture_SumEntropy_RNA_5_03_256
558,c1ccc(-c2n[nH]cc2-c2ccnc3ccccc23)nc1,JCP2022_033954,IBCXZJCWDGCXQT-UHFFFAOYSA-N,InChI=1S/C17H12N4/c1-2-6-15-13(5-1)12(8-10-19-...,source_9,GR00003332,M01,1.2493,0.004344,1.633361,...,-1.872513,-1.539756,-1.725201,-1.776512,-1.693937,-1.688993,-1.744569,-1.861845,-1.789947,-1.800001
522,c1ccc(-c2n[nH]cc2-c2ccnc3ccccc23)nc1,JCP2022_033954,IBCXZJCWDGCXQT-UHFFFAOYSA-N,InChI=1S/C17H12N4/c1-2-6-15-13(5-1)12(8-10-19-...,source_9,GR00004367,U24,0.614394,-0.32042,1.049909,...,-0.512075,-0.803961,-0.140347,-0.305482,-0.106452,-0.290548,-0.412683,-0.468575,-0.303284,-0.489874
520,c1ccc(-c2n[nH]cc2-c2ccnc3ccccc23)nc1,JCP2022_033954,IBCXZJCWDGCXQT-UHFFFAOYSA-N,InChI=1S/C17H12N4/c1-2-6-15-13(5-1)12(8-10-19-...,source_9,GR00004368,M48,-0.495308,-1.044813,-0.111481,...,-2.638085,-2.782815,-2.127806,-2.218116,-2.184556,-2.189832,-2.192579,-2.328746,-2.353566,-2.384934
567,c1ccc(-c2n[nH]cc2-c2ccnc3ccccc23)nc1,JCP2022_033954,IBCXZJCWDGCXQT-UHFFFAOYSA-N,InChI=1S/C17H12N4/c1-2-6-15-13(5-1)12(8-10-19-...,source_9,GR00004394,AC48,0.083219,0.17792,0.479504,...,-1.098045,-1.207331,-0.980516,-1.070063,-0.993075,-1.133956,-1.126388,-1.183341,-1.142122,-1.249668


In [132]:
# Scan the first 700 samples. For each sample that has at least one x_c profile
# with no -1 entry, print its SMILES and the number of (cell, dose, time) slots
# that carry data, so the number can be compared with a DataFrame groupby.
for sample_idx in range(700):
    # NOTE(review): HDF5Dataset.__getitem__ passes `index=`; confirm that
    # read_hdf5_data still accepts the `indices=` keyword used here.
    batch = read_hdf5_data('/scratch/work/masooda1/mocop/data/dummy_data/hd5_data/train.h5', indices = sample_idx)
    if torch.all(batch['inputs']['x_c'] != -1, dim = -1).any():
        print(batch['SMILES'])
        cell_indices = batch['inputs']['cell_indices'].squeeze(0)  # Remove batch dimension
        # A slot has data when any value stored under it is non-zero. Whatever
        # follows the first three axes (e.g. a replicate axis) is flattened so
        # a single reduction replaces the former triple Python loop, which
        # also reused the outer loop variable `i`.
        has_data = (cell_indices != 0).reshape(*cell_indices.shape[:3], -1).any(dim=-1)
        print("Number of unique (cell,dose,time) combinations:", int(has_data.sum()))

['CCS(=O)(=O)N1CCN(CC1)c1ccc(Nc2ncc(C(N)=O)c(NC3CC3)n2)cc1']
Number of unique (cell,dose,time) combinations: 14
['CCCN(CCC)C1CCc2ccc(O)cc2C1']
Number of unique (cell,dose,time) combinations: 5
['NC(=O)OCC(O)COc1ccc(Cl)cc1']
Number of unique (cell,dose,time) combinations: 16
['CC(C)OC(=O)[C@H](C)N[P@@](=O)(OC[C@H]1O[C@@H](n2ccc(=O)[nH]c2=O)[C@](C)(F)[C@@H]1O)Oc1ccccc1']
Number of unique (cell,dose,time) combinations: 51
['C[N+](C)(C)[C@@H](Cc1c[nH]c(=S)[nH]1)C(O)=O']
Number of unique (cell,dose,time) combinations: 38
['OC(=O)CCCC=C/C[C@@H]1CO[C@@H](O[C@@H]1c1ccccc1O)C(F)(F)F']
Number of unique (cell,dose,time) combinations: 60
['CN1CCc2cccc-3c2C1Cc1ccc(O)c(O)c-31']
Number of unique (cell,dose,time) combinations: 1
['C[C@H](N)Cc1c[nH]cn1']
Number of unique (cell,dose,time) combinations: 4
['CCNC(=O)CCC(N)C(O)=O']
Number of unique (cell,dose,time) combinations: 1
['NC(CCCNC(N)=O)C(O)=O']
Number of unique (cell,dose,time) combinations: 52
['C[C@@H](N)[C@@H](O)c1ccccc1']
Number of unique (c

In [134]:
# Cross-check against the source table: for one compound, group its rows by
# (SMILES, cell line, dose level, time) and look at how many groups there are
# (first element of the resulting shape).
# NOTE(review): genomic_data_path is defined outside this section.
genomic_data = pd.read_parquet(genomic_data_path)
SMILE = "CCS(=O)(=O)N1CCN(CC1)c1ccc(Nc2ncc(C(N)=O)c(NC3CC3)n2)cc1"
selected_mol = genomic_data[(genomic_data.Metadata_SMILES == SMILE)]
count_data = selected_mol.groupby(["Metadata_SMILES", "Metadata_cell_iname", "Metadata_Dose_Level", "Metadata_pert_time"]).count().reset_index()
count_data.shape

(14, 1024)

In [3]:
# Load the complete genomic table from a hard-coded absolute parquet path.
# This rebinds genomic_data, which another cell assigns from genomic_data_path.
import pandas as pd
complete_genomic_Data = "/scratch/cs/pml/AI_drug/molecular_representation_learning/LINCS/landmark_cmp_data_min1000compounds_all_measurements.parquet"
genomic_data = pd.read_parquet(complete_genomic_Data)

In [10]:
# Number of non-null Metadata_cid values per (SMILES, cell line, dose level,
# time) group. Selecting the column before counting yields the same result as
# counting every column of the frame and then picking one, without the wasted
# work on all the other columns.
count_data = (
    genomic_data
    .groupby(["Metadata_SMILES", "Metadata_cell_iname", "Metadata_Dose_Level", "Metadata_pert_time"])["Metadata_cid"]
    .count()
    .reset_index()
)
# Drop the groups whose SMILES is the literal string "restricted".
count_data = count_data[count_data.Metadata_SMILES != "restricted"]

In [11]:
# Order the per-group counts by SMILES and then by count, both descending.
count_data.sort_values(by = ["Metadata_SMILES","Metadata_cid"], ascending = False)

Unnamed: 0,Metadata_SMILES,Metadata_cell_iname,Metadata_Dose_Level,Metadata_pert_time,Metadata_cid
269093,c1nc(cs1)-c1nc2ccccc2[nH]1,A549,6,24.0,7
269134,c1nc(cs1)-c1nc2ccccc2[nH]1,MCF7,6,24.0,7
269144,c1nc(cs1)-c1nc2ccccc2[nH]1,PC3,6,24.0,7
269091,c1nc(cs1)-c1nc2ccccc2[nH]1,A549,5,24.0,6
269100,c1nc(cs1)-c1nc2ccccc2[nH]1,HA1E,6,24.0,5
...,...,...,...,...,...
6,BrC1C(Br)C(Br)C(Br)C(Br)C1Br,MCF7,6,6.0,1
7,BrC1C(Br)C(Br)C(Br)C(Br)C1Br,MCF7,6,24.0,1
8,BrC1C(Br)C(Br)C(Br)C(Br)C1Br,NPC,6,24.0,1
9,BrC1C(Br)C(Br)C(Br)C(Br)C1Br,PHH,6,24.0,1


In [13]:
# count_data holds one row per (SMILES, cell line, dose level, time) group, so
# counting rows per (SMILES, cell line) gives the number of such groups.
count_data.groupby(["Metadata_SMILES", "Metadata_cell_iname"]).count()

Unnamed: 0_level_0,Unnamed: 1_level_0,Metadata_Dose_Level,Metadata_pert_time,Metadata_cid
Metadata_SMILES,Metadata_cell_iname,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
BrC1C(Br)C(Br)C(Br)C(Br)C1Br,A375,1,1,1
BrC1C(Br)C(Br)C(Br)C(Br)C1Br,A549,2,2,2
BrC1C(Br)C(Br)C(Br)C(Br)C1Br,ASC,1,1,1
BrC1C(Br)C(Br)C(Br)C(Br)C1Br,HEPG2,1,1,1
BrC1C(Br)C(Br)C(Br)C(Br)C1Br,HT29,1,1,1
...,...,...,...,...
c1nc(cs1)-c1nc2ccccc2[nH]1,SKB,1,1,1
c1nc(cs1)-c1nc2ccccc2[nH]1,THP1,4,4,4
c1nc(cs1)-c1nc2ccccc2[nH]1,U2OS,1,1,1
c1nc(cs1)-c1nc2ccccc2[nH]1,VCAP,2,2,2


In [16]:
# Distinct perturbation times present in the table.
genomic_data.Metadata_pert_time.unique()

array([24.,  6.])

In [14]:

# All rows for one compound measured in cell line "THP1".
genomic_data[(genomic_data.Metadata_SMILES == "c1nc(cs1)-c1nc2ccccc2[nH]1") & (genomic_data.Metadata_cell_iname == "THP1")]

Unnamed: 0,Metadata_cid,10007,1001,10013,10038,10046,10049,10051,10057,10058,...,Metadata_det_plates,Metadata_distil_ids,Metadata_build_name,Metadata_project_code,Metadata_cmap_name_y,Metadata_is_exemplar_sig,Metadata_is_ncs_sig,Metadata_is_null_sig,Metadata_Dose_Bins,Metadata_Dose_Level
399541,REP.A024_THP1_24H:K01,-0.8072,-1.1524,0.4497,0.62905,0.4032,-1.33085,0.1963,0.9096,0.5157,...,REP.A024_THP1_24H_X1_B32|REP.A024_THP1_24H_X3_B32,REP.A024_THP1_24H_X1_B32:K01|REP.A024_THP1_24H...,,REP,tiabendazole,0,0.0,0.0,"(1.0, 10.0]",6
399542,REP.A024_THP1_24H:K02,-0.53715,0.77805,-0.63515,-0.4853,-0.03285,0.1237,0.09375,0.18225,0.0176,...,REP.A024_THP1_24H_X1_B32|REP.A024_THP1_24H_X3_B32,REP.A024_THP1_24H_X1_B32:K02|REP.A024_THP1_24H...,,REP,tiabendazole,0,1.0,0.0,"(1.0, 10.0]",6
399543,REP.A024_THP1_24H:K03,-1.19995,-0.6346,-1.8495,-0.1522,-0.30995,-0.9461,-0.76735,0.4861,-0.44495,...,REP.A024_THP1_24H_X1_B32|REP.A024_THP1_24H_X3_B32,REP.A024_THP1_24H_X1_B32:K03|REP.A024_THP1_24H...,,REP,tiabendazole,0,0.0,0.0,"(1.0, 10.0]",6
399544,REP.A024_THP1_24H:K04,-0.69015,-1.04185,2.86895,-0.1272,-0.2975,-5.2138,0.2362,-0.6681,0.4141,...,REP.A024_THP1_24H_X1_B32|REP.A024_THP1_24H_X3_B32,REP.A024_THP1_24H_X1_B32:K04|REP.A024_THP1_24H...,,REP,tiabendazole,0,1.0,0.0,"(0.1, 1.0]",5
399545,REP.A024_THP1_24H:K05,-0.7047,-0.71205,-0.2883,0.33185,-0.49285,-0.9465,-0.6697,0.16545,-0.03315,...,REP.A024_THP1_24H_X1_B32|REP.A024_THP1_24H_X3_B32,REP.A024_THP1_24H_X1_B32:K05|REP.A024_THP1_24H...,,REP,tiabendazole,0,1.0,0.0,"(0.1, 1.0]",5
399546,REP.A024_THP1_24H:K06,-0.2526,4.7908,1.65675,-0.8838,0.38155,-0.50505,0.03915,0.27875,0.1435,...,REP.A024_THP1_24H_X1_B32|REP.A024_THP1_24H_X3_B32,REP.A024_THP1_24H_X1_B32:K06|REP.A024_THP1_24H...,,REP,tiabendazole,0,0.0,0.0,"(0.01, 0.1]",4
510480,REP.B024_THP1_24H:K01,0.10715,-0.43025,0.08775,0.05835,-0.29045,-0.38385,-1.46895,0.52275,-0.20055,...,REP.B024_THP1_24H_X1_B32|REP.B024_THP1_24H_X2_B32,REP.B024_THP1_24H_X1_B32:K01|REP.B024_THP1_24H...,,REP,tiabendazole,0,0.0,0.0,"(1.0, 10.0]",6
510481,REP.B024_THP1_24H:K02,0.38135,-1.15935,0.33045,-0.647,0.2222,-0.4522,-0.11155,-0.20315,0.61915,...,REP.B024_THP1_24H_X1_B32|REP.B024_THP1_24H_X2_B32,REP.B024_THP1_24H_X1_B32:K02|REP.B024_THP1_24H...,,REP,tiabendazole,0,0.0,0.0,"(0.1, 1.0]",5
510482,REP.B024_THP1_24H:K03,-0.4024,-0.4416,-0.6835,-1.2512,-0.43625,0.93205,-0.45825,4.7688,0.1771,...,REP.B024_THP1_24H_X1_B32|REP.B024_THP1_24H_X2_B32,REP.B024_THP1_24H_X1_B32:K03|REP.B024_THP1_24H...,,REP,tiabendazole,0,1.0,0.0,"(0.1, 1.0]",5
510483,REP.B024_THP1_24H:K04,0.6995,-0.10525,1.1241,-2.9821,1.1362,0.62405,0.07815,0.7321,0.68365,...,REP.B024_THP1_24H_X1_B32|REP.B024_THP1_24H_X2_B32,REP.B024_THP1_24H_X1_B32:K04|REP.B024_THP1_24H...,,REP,tiabendazole,0,0.0,0.0,"(0.01, 0.1]",4


In [2]:
# Rows whose Metadata_SMILES is the literal string "restricted".
genomic_data[(genomic_data.Metadata_SMILES == "restricted")]

Unnamed: 0,Metadata_cid,10007,1001,10013,10038,10046,10049,10051,10057,10058,...,Metadata_det_plates,Metadata_distil_ids,Metadata_build_name,Metadata_project_code,Metadata_cmap_name_y,Metadata_is_exemplar_sig,Metadata_is_ncs_sig,Metadata_is_null_sig,Metadata_Dose_Bins,Metadata_Dose_Level
62701,CPC013_A375_6H:BRD-U37049823-000-01-2:10,-0.770983,2.439857,0.982420,-0.187816,0.869009,1.068488,-0.560734,-1.376168,0.015412,...,CPC013_A375_6H_X1_B4_DUO52HI53LO|CPC013_A375_6...,CPC013_A375_6H_X1_B4_DUO52HI53LO:O06|CPC013_A3...,,CPC,HG-6-64-01,1,1.0,0.0,"(1.0, 10.0]",6
62702,CPC013_A375_6H:BRD-U44700465-000-01-4:10,-1.249120,1.460052,0.492889,0.376451,-0.354190,-0.373150,0.974565,-0.675375,0.121491,...,CPC013_A375_6H_X1_B4_DUO52HI53LO|CPC013_A375_6...,CPC013_A375_6H_X1_B4_DUO52HI53LO:O04|CPC013_A3...,,CPC,HG-5-88-01,1,1.0,0.0,"(1.0, 10.0]",6
62703,CPC013_A375_6H:BRD-U68942961-000-01-2:10,0.216492,1.020362,-1.210737,0.298784,0.183157,0.476728,0.909482,0.956407,0.329456,...,CPC013_A375_6H_X1_B4_DUO52HI53LO|CPC013_A375_6...,CPC013_A375_6H_X1_B4_DUO52HI53LO:O10|CPC013_A3...,,CPC,JW-7-24-1,1,1.0,0.0,"(1.0, 10.0]",6
62704,CPC013_A375_6H:BRD-U82589721-000-01-4:10,-3.979200,1.565157,0.356323,-0.902218,0.486713,-0.800978,-1.255811,3.063774,-0.451412,...,CPC013_A375_6H_X1_B4_DUO52HI53LO|CPC013_A375_6...,CPC013_A375_6H_X1_B4_DUO52HI53LO:O02|CPC013_A3...,,CPC,HG-5-113-01,1,1.0,0.0,"(1.0, 10.0]",6
63061,CPC013_A549_24H:BRD-U37049823-000-01-2:10,3.227350,1.303650,0.654050,-0.445100,0.518500,-1.683050,-2.310250,-0.645400,0.549050,...,CPC013_A549_24H_X1_F1B6_DUO52HI53LO|CPC013_A54...,CPC013_A549_24H_X1_F1B6_DUO52HI53LO:O06|CPC013...,,CPC,HG-6-64-01,0,1.0,0.0,"(1.0, 10.0]",6
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
205008,LJP006_PC3_24H:L20,0.353754,0.229798,0.675802,-1.505531,0.129991,-0.228289,0.307487,4.210205,0.582959,...,LJP006_PC3_24H_X1_B19|LJP006_PC3_24H_X2_B19|LJ...,LJP006_PC3_24H_X1_B19:L20|LJP006_PC3_24H_X2_B1...,,LJP,THZ-2-98-01,0,1.0,0.0,"(1.0, 10.0]",6
205009,LJP006_PC3_24H:L21,-0.225144,-1.706290,-0.366275,-0.177664,1.103752,0.361853,0.444221,5.018650,-0.316004,...,LJP006_PC3_24H_X1_B19|LJP006_PC3_24H_X2_B19|LJ...,LJP006_PC3_24H_X1_B19:L21|LJP006_PC3_24H_X2_B1...,,LJP,THZ-2-98-01,0,1.0,0.0,"(1.0, 10.0]",6
205010,LJP006_PC3_24H:L22,0.125931,0.501796,-0.337844,-0.551284,1.937531,-0.460242,-0.357103,7.475909,0.390728,...,LJP006_PC3_24H_X1_B19|LJP006_PC3_24H_X2_B19|LJ...,LJP006_PC3_24H_X1_B19:L22|LJP006_PC3_24H_X2_B1...,,LJP,THZ-2-98-01,0,1.0,0.0,"(0.1, 1.0]",5
205011,LJP006_PC3_24H:L23,-0.148295,0.344115,-0.351625,-0.219762,-0.439817,-0.172699,-0.434149,-0.014210,-0.743227,...,LJP006_PC3_24H_X1_B19|LJP006_PC3_24H_X2_B19|LJ...,LJP006_PC3_24H_X1_B19:L23|LJP006_PC3_24H_X2_B1...,,LJP,THZ-2-98-01,0,1.0,0.0,"(0.1, 1.0]",5


In [59]:
import time
train_dl = dl["train"]
epochs = 1
for epoch in range(epochs):
    timer_per_epoch = time.time()
    for i, batch in enumerate(train_dl):
        #print(batch["inputs"]["x_a"][0].shape)
        #print(batch["inputs"]["x_b"].shape, batch["SMILES"])
print(f"Epoch {epoch} finished in {time.time() - timer_per_epoch}")

IndentationError: expected an indented block (3936786293.py, line 9)

In [60]:
# Display whatever `batch` currently holds (the whole nested structure is printed).
batch

{'inputs': {'x_a': [tensor([[[0.5000, 0.4082, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
            [0.4082, 0.3333, 0.3333,  ..., 0.0000, 0.0000, 0.0000],
            [0.0000, 0.3333, 0.3333,  ..., 0.0000, 0.0000, 0.0000],
            ...,
            [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
            [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
            [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]]]),
   tensor([[[1., 0., 0.,  ..., 0., 0., 0.],
            [1., 0., 0.,  ..., 0., 0., 0.],
            [1., 0., 0.,  ..., 0., 0., 0.],
            ...,
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.]]]),
   tensor([[[1.],
            [1.],
            [1.],
            [1.],
            [1.],
            [1.],
            [1.],
            [1.],
            [1.],
            [1.],
            [1.],
            [1.],
            [1.],
            [1.],
            [1.],
 

In [68]:
# Shape of x_b in the current batch.
batch["inputs"]["x_b"].shape

torch.Size([1, 1, 3479])