In [3]:
!nvidia-smi

Wed Jul 23 08:40:39 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 565.57.01              Driver Version: 565.57.01      CUDA Version: 12.7     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A40                     Off |   00000000:48:00.0 Off |                    0 |
|  0%   28C    P8             21W /  250W |       1MiB /  46068MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [4]:
# code for table 3 (generalizing to unseen small molecule perturbations - cell expression as cell line context, fingerprint for pert context)

import torch
import lightning as pl
import numpy as np
import pandas as pd
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler, RobustScaler, QuantileTransformer
from sklearn.model_selection import train_test_split, ParameterGrid
# from contextualized.easy import ContextualizedCorrelationNetworks
from rdkit import Chem
from rdkit.Chem import rdFingerprintGenerator
import warnings

from contextualized.regression.lightning_modules import ContextualizedCorrelation
from contextualized.data import CorrelationDataModule
from lightning import seed_everything, Trainer

## Configuration and Data Loading

In [5]:
PATH_L1000   = 'data/trt_cp_smiles.csv' #file filtered with only the trt_cp perts with smiles
PATH_CTLS    = 'data/ctrls.csv'     
N_DATA_PCS   = 50   
PERTURBATION_HOLDOUT_SIZE = 0.2  
RANDOM_STATE = 42
SUBSAMPLE_FRACTION = None  # None for using full data, or decimal for percent subsample

morgan_gen = rdFingerprintGenerator.GetMorganGenerator(radius=3, fpSize=4096)  

# Function to generate Morgan fingerprints from SMILES
def smiles_to_morgan_fp(smiles, generator=morgan_gen):
    """
    Convert SMILES string to Morgan fingerprint using MorganGenerator.
    
    Args:
        smiles (str): SMILES string
        generator: RDKit MorganGenerator instance
    
    Returns:
        np.array: Binary fingerprint array, or array of zeros if invalid SMILES
    """
    try:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            warnings.warn(f"Invalid SMILES: {smiles}")
            return np.zeros(generator.GetOptions().fpSize)
        
        fp = generator.GetFingerprint(mol)
        return np.array(fp)
        # return np.zeros(generator.GetOptions().fpSize)
    except Exception as e:
        warnings.warn(f"Error processing SMILES {smiles}: {str(e)}")
        return np.zeros(generator.GetOptions().fpSize)

#load data
df = pd.read_csv(PATH_L1000, engine='pyarrow')

# pick the perturbation to fit model on here
pert_to_fit_on = ['trt_cp']
df = df[df['pert_type'].isin(pert_to_fit_on)]

# quality filters
bad = (
    (df['distil_cc_q75'] < 0.2) | (df['distil_cc_q75'] == -666) | (df['distil_cc_q75'].isna()) |
    (df['pct_self_rank_q25'] > 5) | (df['pct_self_rank_q25'] == -666) | (df['pct_self_rank_q25'].isna())
)
df = df[~bad]

# Filter out samples with missing SMILES
df = df.dropna(subset=['canonical_smiles'])
df = df[df['canonical_smiles'] != '']

print(f"Processing {len(df)} samples with valid SMILES...")

if SUBSAMPLE_FRACTION is not None:
    df = df.sample(frac=SUBSAMPLE_FRACTION, random_state=RANDOM_STATE)
    print(f"Subsampled to {len(df)} samples ({SUBSAMPLE_FRACTION*100}% of data)")

# PERTURBATION HOLDOUT: Split unique perturbations first
unique_smiles = df['canonical_smiles'].unique()
print(f"Found {len(unique_smiles)} unique perturbations (SMILES)")

# Split unique SMILES into train and test sets
smiles_train, smiles_test = train_test_split(
    unique_smiles, 
    test_size=PERTURBATION_HOLDOUT_SIZE, 
    random_state=RANDOM_STATE
)

print(f"Perturbation split: {len(smiles_train)} train, {len(smiles_test)} test perturbations")

# Create train and test dataframes based on perturbation split
df_train = df[df['canonical_smiles'].isin(smiles_train)].copy()
df_test = df[df['canonical_smiles'].isin(smiles_test)].copy()

print(f"Sample split: {len(df_train)} train, {len(df_test)} test samples")

Processing 242514 samples with valid SMILES...
Found 12976 unique perturbations (SMILES)
Perturbation split: 10380 train, 2596 test perturbations
Sample split: 195733 train, 46781 test samples


## Data Preprocessing and Feature Engineering

In [6]:
# Process train/test sets - fit preprocessing on training data only
pert_time_mean = None
pert_dose_mean = None

for df_split, split_name in [(df_train, 'train'), (df_test, 'test')]:
    # ignore-flag columns for missing meta-data
    df_split['ignore_flag_pert_time'] = (df_split['pert_time'] == -666).astype(int)
    df_split['ignore_flag_pert_dose'] = (df_split['pert_dose'] == -666).astype(int)

    # replace –666 with column mean (computed from training set only)
    for col in ['pert_time', 'pert_dose']:
        if split_name == 'train':
            mean_val = df_split.loc[df_split[col] != -666, col].mean()
            # Store the mean for use with val/test sets
            if col == 'pert_time':
                pert_time_mean = mean_val
            else:
                pert_dose_mean = mean_val
        else:
            # Use training set means for test set
            mean_val = pert_time_mean if col == 'pert_time' else pert_dose_mean
        
        df_split[col] = df_split[col].replace(-666, mean_val)

# Function to process data split
def process_data_split(df_split, split_name):
    # Getting X (gene expression features)
    numeric_cols   = df_split.select_dtypes(include=[np.number]).columns
    drop_cols      = ['pert_dose', 'pert_dose_unit', 'pert_time',
                      'distil_cc_q75', 'pct_self_rank_q25']
    feature_cols   = [c for c in numeric_cols if c not in drop_cols]
    X_raw          = df_split[feature_cols].values

    # Generate Morgan fingerprints
    print(f"Generating Morgan fingerprints for {split_name} set...")
    morgan_fps = []
    for smiles in df_split['canonical_smiles']:
        fp = smiles_to_morgan_fp(smiles)
        morgan_fps.append(fp)

    morgan_fps = np.array(morgan_fps)
    print(f"Generated Morgan fingerprints for {split_name}: shape {morgan_fps.shape}")

    # Keep other context features
    pert_unit_dummies  = pd.get_dummies(df_split['pert_dose_unit'], drop_first=True)

    pert_time   = df_split['pert_time'  ].to_numpy().reshape(-1, 1)
    pert_dose   = df_split['pert_dose'  ].to_numpy().reshape(-1, 1)
    ignore_time = df_split['ignore_flag_pert_time'].to_numpy().reshape(-1, 1)
    ignore_dose = df_split['ignore_flag_pert_dose'].to_numpy().reshape(-1, 1)

    return X_raw, morgan_fps, pert_unit_dummies, pert_time, pert_dose, ignore_time, ignore_dose

# Process both splits
X_raw_train, morgan_fps_train, pert_unit_dummies_train, pert_time_train, pert_dose_train, ignore_time_train, ignore_dose_train = process_data_split(df_train, 'train')
X_raw_test, morgan_fps_test, pert_unit_dummies_test, pert_time_test, pert_dose_test, ignore_time_test, ignore_dose_test = process_data_split(df_test, 'test')

print("Applying improved scaling strategy...")

# scaling
scaler_genes = StandardScaler()
X_train_scaled = scaler_genes.fit_transform(X_raw_train)
X_test_scaled = scaler_genes.transform(X_raw_test)
print(f"Gene expression scaled: train {X_train_scaled.shape}, test {X_test_scaled.shape}")

scaler_morgan = StandardScaler()
morgan_train_scaled = morgan_fps_train.astype(float)
morgan_test_scaled = morgan_fps_test.astype(float)
print(f"Morgan fingerprints scaled: train {morgan_train_scaled.shape}, test {morgan_test_scaled.shape}")

# Load and process control data
ctrls_df = pd.read_csv(PATH_CTLS, index_col=0)          # index = cell_id
unique_cells_train = np.sort(df_train['cell_id'].unique())
unique_cells_test = np.sort(df_test['cell_id'].unique())
unique_cells_all = np.sort(np.union1d(unique_cells_train, unique_cells_test))

ctrls_df = ctrls_df.loc[ctrls_df.index.intersection(unique_cells_all)]

# Standardize controls and do PCA
scaler_ctrls = StandardScaler()
ctrls_scaled = scaler_ctrls.fit_transform(ctrls_df.values)

n_cells = ctrls_scaled.shape[0]
n_ctrl_pcs = min(50, n_cells)

pca_ctrls = PCA(n_components=n_ctrl_pcs, random_state=RANDOM_STATE)
ctrls_pcs = pca_ctrls.fit_transform(ctrls_scaled)        # shape (n_cells, n_ctrl_pcs)

# Build mapping from cell_id → compressed control vector
cell2vec = dict(zip(ctrls_df.index, ctrls_pcs))

if not cell2vec:
    raise ValueError(
        "No common cell IDs found between lincs1000.csv and embeddings/ctrls.csv. "
        "Cannot proceed. Please check your data files."
    )

print(f"Loaded and processed control embeddings for {len(cell2vec)} unique cells.")

Generating Morgan fingerprints for train set...
Generated Morgan fingerprints for train: shape (195733, 4096)
Generating Morgan fingerprints for test set...
Generated Morgan fingerprints for test: shape (46781, 4096)
Applying improved scaling strategy...
Gene expression scaled: train (195733, 983), test (46781, 983)
Morgan fingerprints scaled: train (195733, 4096), test (46781, 4096)
Loaded and processed control embeddings for 70 unique cells.


## Build Context Matrices

In [7]:
def build_context_matrix_improved(df_split, morgan_fps_scaled, pert_time, pert_dose, 
                                 ignore_time, ignore_dose, split_name, scaler_context=None, is_train=False):
    """
    Build context matrix with globally consistent scaling
    """
    cell_ids = df_split['cell_id'].to_numpy()
    unique_cells_split = np.sort(df_split['cell_id'].unique())
    
    all_continuous_context = []
    valid_cells = []
    
    for cell_id in unique_cells_split:
        if cell_id not in cell2vec:
            print(f"Warning: Cell {cell_id} not found in control embeddings, skipping...")
            continue
            
        mask = cell_ids == cell_id
        if mask.sum() == 0:
            continue
            
        valid_cells.append(cell_id)
        
        # Build continuous context matrix (cell embeddings + time + dose)
        C_continuous = np.hstack([
            np.tile(cell2vec[cell_id], (mask.sum(), 1)),  # Cell embeddings
            pert_time[mask],                              # Perturbation time
            pert_dose[mask],                              # Perturbation dose
        ])
        all_continuous_context.append(C_continuous)
    
    # Fit scaler on all continuous context (training data only)
    if is_train:
        all_continuous_combined = np.vstack(all_continuous_context)
        scaler_context = StandardScaler()
        scaler_context.fit(all_continuous_combined)
        print(f"Fitted context scaler on {all_continuous_combined.shape} continuous context features")
    
    if scaler_context is None:
        raise ValueError("scaler_context must be provided for non-training data")
    
    X_lst, C_lst, cell_lst = [], [], []
    
    for i, cell_id in enumerate(valid_cells):
        mask = cell_ids == cell_id
        X_cell = X_train_scaled[mask] if split_name == 'train' else X_test_scaled[mask]
        
        # Scale continuous context consistently
        C_continuous_scaled = scaler_context.transform(all_continuous_context[i])
        
        n_samples = mask.sum()
        
        # Combine all context features
        C_cell = np.hstack([
            C_continuous_scaled,                    # Scaled continuous features
            morgan_fps_scaled[mask],               # Pre-scaled molecular features  
            ignore_time[mask],                     # Binary flags (unscaled)
            ignore_dose[mask],
        ])

        X_lst.append(X_cell)
        C_lst.append(C_cell)
        cell_lst.append(cell_ids[mask])

    if not X_lst:
        raise RuntimeError(f"No data collected for {split_name} set.")
    
    X_final = np.vstack(X_lst)
    C_final = np.vstack(C_lst)
    cell_ids_final = np.concatenate(cell_lst)
    
    return X_final, C_final, cell_ids_final, scaler_context

# Build context matrices for both splits with improved scaling
print("Building context matrices with improved scaling...")

X_train, C_train, cell_ids_train, scaler_context = build_context_matrix_improved(
    df_train, morgan_train_scaled, pert_time_train, pert_dose_train,
    ignore_time_train, ignore_dose_train, 'train', is_train=True
)

X_test, C_test, cell_ids_test, _ = build_context_matrix_improved(
    df_test, morgan_test_scaled, pert_time_test, pert_dose_test,
    ignore_time_test, ignore_dose_test, 'test', scaler_context=scaler_context
)

print(f'Context matrix:   train {C_train.shape}   test {C_test.shape}')
print(f'Feature matrix:   train {X_train.shape}   test {X_test.shape}')

# IMPROVED PCA WITH BETTER SCALING
print("Applying PCA with improved scaling...")

# PCA on features (fit on training data only)
pca_data = PCA(n_components=N_DATA_PCS, random_state=RANDOM_STATE)
X_train_pca = pca_data.fit_transform(X_train)
X_test_pca = pca_data.transform(X_test)

# Improved scaling in PCA space
pca_scaler = StandardScaler()
X_train_norm = pca_scaler.fit_transform(X_train_pca)
X_test_norm = pca_scaler.transform(X_test_pca)

print(f'Normalized PCA features: train {X_train_norm.shape}   test {X_test_norm.shape}')

# Set useful variables
train_group_ids = cell_ids_train
test_group_ids = cell_ids_test
X_train = X_train_norm
X_test = X_test_norm

Building context matrices with improved scaling...
Fitted context scaler on (195733, 52) continuous context features
Context matrix:   train (195733, 4150)   test (46781, 4150)
Feature matrix:   train (195733, 983)   test (46781, 983)
Applying PCA with improved scaling...
Normalized PCA features: train (195733, 50)   test (46781, 50)


## Fit Population Baseline

In [8]:
from contextualized.baselines.networks import CorrelationNetwork
pop_model = CorrelationNetwork()
pop_model.fit(X_train)
print(f"Train MSE: {pop_model.measure_mses(X_train).mean()}")
print(f"Test MSE: {pop_model.measure_mses(X_test).mean()}")

Train MSE: 0.9800000000000005
Test MSE: 0.9598707173826789


## Fit Grouped Baseline

In [None]:
from contextualized.baselines.networks import GroupedNetworks
grouped_model = GroupedNetworks(CorrelationNetwork)
grouped_model.fit(X_train, train_group_ids)
print(f"Grouped Train MSE: {grouped_model.measure_mses(X_train, train_group_ids).mean()}")
print(f"Grouped Test MSE: {grouped_model.measure_mses(X_test, test_group_ids).mean()}")

Grouped Train MSE: 0.5788541778716761


## Fit Contextualized Model

In [2]:
import wandb
wandb.login(key='443a02df5197cc2c2c579a9ff78179ff45e47824')  # Add your WandB API key here

contextualized_model = ContextualizedCorrelation(
    context_dim=C_train.shape[1],
    x_dim=X_train.shape[1],
    encoder_type='mlp',
    num_archetypes=30,
)
# Random val split
C_val = train_test_split(C_train, test_size=0.2, random_state=RANDOM_STATE)[0]
X_val = train_test_split(X_train, test_size=0.2, random_state=RANDOM_STATE)[0]
datamodule = CorrelationDataModule(
    C_train=C_train,
    X_train=X_train,
    C_val=C_val,
    X_val=X_val,
    C_test=C_test,
    X_test=X_test,
    C_predict=np.concatenate((C_train, C_test), axis=0),
    X_predict=np.concatenate((X_train, X_test), axis=0),
    batch_size=32,
)
checkpoint_callback = pl.pytorch.callbacks.ModelCheckpoint(
    monitor='val_loss',
    mode='min',
    save_top_k=1,
    filename='best_model',
)
logger = pl.pytorch.loggers.WandbLogger(
    project='contextpert',
    name='unseen_perturbations',
    log_model=True,
    save_dir='logs/',
)
trainer = Trainer(
    max_epochs=10,
    accelerator='auto',
    devices='auto',
    callbacks=[checkpoint_callback],
    logger=logger,
)
trainer.fit(contextualized_model, datamodule=datamodule)

[34m[1mwandb[0m: [32m[41mERROR[0m Failed to detect the name of this notebook. You can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /mmfs1/home/jiaqiw18/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mjiaqiw[0m ([33mcontextualized[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


NameError: name 'ContextualizedCorrelation' is not defined

In [None]:
print(f"Testing model on training data...")
trainer.test(contextualized_model, datamodule.train_dataloader())
print(f"Testing model on test data...")
trainer.test(contextualized_model, datamodule.test_dataloader())

In [None]:
print(checkpoint_callback.best_model_path)

## Predict Networks

In [None]:
# Necessary to save predictions from multiple devices in parallel
from contextualized.callbacks import PredictionWriter
from pathlib import Path

output_dir = Path(checkpoint_callback.best_model_path).parent / 'predictions'
writer_callback = PredictionWriter(
    output_dir=output_dir,
    write_interval='batch',
)
trainer = Trainer(
    accelerator='auto',
    devices='auto',
    callbacks=[checkpoint_callback, writer_callback],
)
_ = trainer.predict(contextualized_model, datamodule=datamodule)

In [None]:
# Compile distributed predictions and put into order
"""
import torch
import glob

# Convert context to hashable type for lookup
C_train_hashable = [tuple(row) for row in C_train]
C_test_hashable = [tuple(row) for row in C_test]

# Gather preds and move to CPU
all_correlations = {}
all_betas = {}
all_mus = {}
pred_files = glob.glob(str(output_dir / 'predictions_*.pt'))
for file in pred_files:
    preds = torch.load(file)
    for context, correlation, beta, mu in zip(preds['contexts'], preds['correlations'], preds['betas'], preds['mus']):
        context_tuple = tuple(context.tolist())
        all_correlations[context_tuple] = correlation.cpu().numpy()
        all_betas[context_tuple] = beta.cpu().numpy()
        all_mus[context_tuple] = mu.cpu().numpy()

# Remake preds in order of C_train and C_test
correlations_train = np.array([all_correlations[c] for c in C_train_hashable])
correlations_test = np.array([all_correlations[c] for c in C_test_hashable])
betas_train = np.array([all_betas[c] for c in C_train_hashable])
betas_test = np.array([all_betas[c] for c in C_test_hashable])
mus_train = np.array([all_mus[c] for c in C_train_hashable])
mus_test = np.array([all_mus[c] for c in C_test_hashable])
"""

In [None]:
# Compile distributed predictions and put into order
import torch
import glob

C_train = np.round(C_train.astype(np.float32), 6)
C_test  = np.round(C_test.astype(np.float32), 6)

# Convert context to hashable type for lookup
C_train_hashable = [tuple(row) for row in C_train]
C_test_hashable = [tuple(row) for row in C_test]

# Gather preds and move to CPU
all_correlations = {}
all_betas = {}
all_mus = {}
pred_files = glob.glob(str(output_dir / 'predictions_*.pt'))
for file in pred_files:
    preds = torch.load(file)
    for context, correlation, beta, mu in zip(preds['contexts'], preds['correlations'], preds['betas'], preds['mus']):
        context_tuple = tuple(np.round(context.cpu().numpy(), 6))
        # context_tuple = tuple(context.tolist())
        all_correlations[context_tuple] = correlation.cpu().numpy()
        all_betas[context_tuple] = beta.cpu().numpy()
        all_mus[context_tuple] = mu.cpu().numpy()

# Remake preds in order of C_train and C_test
correlations_train = np.array([all_correlations[c] for c in C_train_hashable])
correlations_test = np.array([all_correlations[c] for c in C_test_hashable])
betas_train = np.array([all_betas[c] for c in C_train_hashable])
betas_test = np.array([all_betas[c] for c in C_test_hashable])
mus_train = np.array([all_mus[c] for c in C_train_hashable])
mus_test = np.array([all_mus[c] for c in C_test_hashable])

In [None]:
# Get individual MSEs by sample
# Sanity check: These should closely match the trainer.test() outputs from earlier
def measure_mses(betas, mus, X):
    mses = np.zeros(len(X))
    for i in range(len(X)):
        sample_mse = 0
        for j in range(X.shape[-1]):
            for k in range(X.shape[-1]):
                residual = X[i, j] - betas[i, j, k] * X[i, k] - mus[i, j, k]
                sample_mse += residual**2 / (X.shape[-1] ** 2)
        mses += sample_mse / len(X)
    return mses

mse_train = measure_mses(betas_train, mus_train, X_train)
mse_test = measure_mses(betas_test, mus_test, X_test)
print(f"Train MSEs: {mse_train.mean()}")
print(f"Test MSEs: {mse_test.mean()}")

# Per-cell performance breakdown
print(f"\nPer-cell performance breakdown:")
print("Cell ID          Train MSE    Test MSE     Train N  Test N")
print("─" * 60)

all_unique_cells = np.union1d(cell_ids_train, cell_ids_test)

for cell_id in sorted(all_unique_cells):
    tr_mask = cell_ids_train == cell_id
    te_mask = cell_ids_test == cell_id
    
    tr_mse = mse_train[tr_mask].mean() if tr_mask.any() else np.nan
    te_mse = mse_test[te_mask].mean() if te_mask.any() else np.nan
    tr_n = tr_mask.sum()
    te_n = te_mask.sum()
    
    if tr_n > 0 or te_n > 0:
        print(f'{cell_id:<15}  {tr_mse:8.4f}   {te_mse:8.4f}   {tr_n:6d}   {te_n:6d}')

# Summary statistics about the perturbation split
print(f"\n" + "="*80)
print("PERTURBATION HOLDOUT SUMMARY:")
print(f"  Total unique SMILES: {len(unique_smiles)}")
print(f"  Training SMILES: {len(smiles_train)} ({len(smiles_train)/len(unique_smiles)*100:.1f}%)")
print(f"  Test SMILES: {len(smiles_test)} ({len(smiles_test)/len(unique_smiles)*100:.1f}%)")
print(f"  Training samples: {len(df_train)}")
print(f"  Test samples: {len(df_test)}")
print("="*80)

In [None]:
import os, pathlib, warnings, random
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.metrics import pairwise_distances
from scipy.cluster import hierarchy
from scipy.spatial.distance import squareform

OUT_DIR   = pathlib.Path("figs/pert_similarity_clustermaps")
OUT_DIR.mkdir(parents=True, exist_ok=True)
GROUP_KEY = "canonical_smiles"
MAX_PERTS = 10_000
RNG       = np.random.default_rng(42)

In [None]:
def cdist(mat: np.ndarray):
    xy = mat @ mat.T
    x2 = y2 = (mat * mat).sum(1)
    d2 = np.add.outer(x2, y2) - 2 * xy
    d2.flat[::len(mat)+1] = 0
    d2[d2 < 0] = 0
    return squareform(np.sqrt(d2))

In [None]:
numeric_cols = df_train.select_dtypes(include=[np.number]).columns
drop_cols    = ['pert_dose', 'pert_dose_unit', 'pert_time',
                'distil_cc_q75', 'pct_self_rank_q25']
gene_cols    = [c for c in numeric_cols if c not in drop_cols]

expr_all   = pd.concat([df_train, df_test], axis=0, ignore_index=True)
expr_mat   = expr_all[gene_cols].to_numpy(dtype=np.float32)
smiles_all = expr_all[GROUP_KEY].to_numpy()

print(f"Gene-expression raw matrix: {expr_mat.shape}")

# standardize then PCA
scaler_expr = StandardScaler()
expr_scaled = scaler_expr.fit_transform(expr_mat)

pca_expr    = PCA(n_components=50, random_state=42)
expr_pcs    = pca_expr.fit_transform(expr_scaled)      # (n_samples, 50)

# aggregate
expr_df = pd.DataFrame(expr_pcs)
expr_df[GROUP_KEY] = smiles_all
expr_repr_df = expr_df.groupby(GROUP_KEY).mean()       # (perts × 50)

expr_repr   = expr_repr_df.to_numpy(dtype=np.float32)
smiles_order = expr_repr_df.index.tolist()

print(f"Expression representation: {expr_repr.shape}")


In [None]:
from rdkit import Chem
from rdkit.Chem import rdFingerprintGenerator

morgan_gen = rdFingerprintGenerator.GetMorganGenerator(radius=3, fpSize=4096)

def smiles_to_fp(smiles, gen=morgan_gen):
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        warnings.warn(f"Invalid SMILES: {smiles}")
        return np.zeros(gen.GetOptions().fpSize, dtype=np.float32)
    return np.asarray(gen.GetFingerprint(mol), dtype=np.float32)

fp_repr = np.vstack([smiles_to_fp(sm) for sm in smiles_order])
print(f"Fingerprint representation: {fp_repr.shape}")


In [None]:
def _aligned_smiles(df_split, C_split, key=GROUP_KEY):
    blocks = []
    for cell_id in df_split['cell_id'].unique():
        mask = df_split['cell_id'] == cell_id
        blocks.append(df_split.loc[mask, key].to_numpy())
    return np.concatenate(blocks)

smiles_train_aligned = _aligned_smiles(df_train, C_train)
smiles_test_aligned  = _aligned_smiles(df_test,  C_test)

assert len(smiles_train_aligned) == C_train.shape[0]
assert len(smiles_test_aligned)  == C_test.shape[0]

print(f"Aligned SMILES arrays built: train {smiles_train_aligned.shape}, "
      f"test {smiles_test_aligned.shape}")

In [None]:
def flat_upper(mat3d: np.ndarray):
    p = mat3d.shape[1]
    iu = np.triu_indices(p, k=1)
    return mat3d[:, iu[0], iu[1]].astype(np.float32)   # (n_samples, p*(p-1)/2)

# concatenate train+test, flatten
corr_flat_all = flat_upper(np.concatenate([correlations_train,
                                           correlations_test], axis=0))
smiles_ctx_all = np.concatenate([smiles_train_aligned,
                                 smiles_test_aligned])

# mean per SMILES
corr_df = pd.DataFrame(corr_flat_all)
corr_df[GROUP_KEY] = smiles_ctx_all
corr_df = corr_df.groupby(GROUP_KEY).mean().reindex(smiles_order)

if corr_df.isna().any().any():
    dropped = corr_df.index[corr_df.isna().all(axis=1)]
    print(f"[ctx_corr] dropped {len(dropped)} perturbations lacking predictions.")
    corr_df = corr_df.dropna(how='all')

ctx_repr_corr = corr_df.to_numpy(dtype=np.float32)
print(f"[ctx_corr] representation shape: {ctx_repr_corr.shape}")


In [None]:
import seaborn as sns, matplotlib.pyplot as plt, pathlib, numpy as np
from scipy.cluster import hierarchy
from scipy.spatial.distance import squareform

OUT_DIR   = pathlib.Path("figs/pert_similarity_clustermaps")
OUT_DIR.mkdir(parents=True, exist_ok=True)
MAX_PERTS = 10_000
_rng      = np.random.default_rng(42)

def cdist_fast(mat: np.ndarray):
    xy = mat @ mat.T
    norms = (mat * mat).sum(1)
    d2 = np.add.outer(norms, norms) - 2 * xy
    d2.flat[::len(mat)+1] = 0
    d2[d2 < 0] = 0
    return squareform(np.sqrt(d2, dtype=np.float32))

def make_clustermap(X: np.ndarray, label: str):
    n = X.shape[0]
    if MAX_PERTS and n > MAX_PERTS:
        idx = _rng.choice(n, MAX_PERTS, replace=False)
        X = X[idx]
        print(f"[{label}] sub-sampled {n} - {len(idx)} perturbations.")
    dist_cond = cdist_fast(X.astype(np.float32, copy=False))
    linkage   = hierarchy.linkage(dist_cond, method="average")
    dist_sq   = squareform(dist_cond)

    g = sns.clustermap(dist_sq,
                       row_linkage=linkage,
                       col_linkage=linkage,
                       cmap="vlag",
                       figsize=(12, 12),
                       xticklabels=False,
                       yticklabels=False,
                       cbar_kws={'label': 'Euclidean distance'})
    g.fig.suptitle(f"{label.upper()} - perturbation similarity", y=1.02)

    png = OUT_DIR / f"clustermap_{label}.png"
    npy = OUT_DIR / f"dist_square_{label}.npy"
    g.savefig(png, bbox_inches="tight")
    np.save(npy, dist_sq)
    plt.show()
    print(f"[{label}] saved ➜ {png.name}  |  {npy.name}")


In [None]:
make_clustermap(expr_repr,     "expr")
make_clustermap(fp_repr,       "fp")
make_clustermap(ctx_repr_corr, "ctx_corr")