In [9]:
# code for table 2 (Contextualized post perturbation one hot encoded cell line, one hot perturbation, with/without dose time)

import torch
import lightning as pl
import pandas as pd
import numpy as np
# from contextualized.easy import ContextualizedCorrelationNetworks
from sklearn.decomposition import PCA
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

from contextualized.regression.lightning_modules import ContextualizedCorrelation
from contextualized.data import CorrelationDataModule
from lightning import seed_everything, Trainer

## Configuration and Data Loading

In [11]:
# Load and preprocess the dataset
df = pd.read_csv('data/merged_output4_head.csv')

# Filter pert_to_fit
pert_to_fit_on = ['trt_cp']

#False for context being just one hot encoded cell type and perturbation, True for adding dose and time
full_context = True


mask = df['pert_type'].isin(pert_to_fit_on)
df = df[mask]

# Condition to drop rows
condition = (
    (df['distil_cc_q75'] < 0.2) |
    (df['distil_cc_q75'] == -666) |
    (df['distil_cc_q75'].isna()) |  # Check for NaN
    (df['pct_self_rank_q25'] > 5) |
    (df['pct_self_rank_q25'] == -666) |
    (df['pct_self_rank_q25'].isna())  # Check for NaN
)
df = df[~condition]

pert_dummies = pd.get_dummies(df['pert_id'], drop_first=True)
pert_unit_dummies = pd.get_dummies(df['pert_dose_unit'], drop_first=True)

# Extract numeric columns as features
feature_cols = df.select_dtypes(include=[np.number]).columns.tolist()
columns_to_drop = ['pert_dose', 'pert_dose_unit', 'pert_time', 'distil_cc_q75', 'pct_self_rank_q25']
feature_cols = [col for col in feature_cols if col not in columns_to_drop]
feature_df = df[feature_cols]

# Scale features
scaler = StandardScaler()
scaler.fit(feature_df)
feature_df = scaler.transform(feature_df)
X = feature_df

# Create context matrix and get unique cell IDs
cell_ids = df['cell_id'].values
unique_cells = np.unique(cell_ids)
print(len(cell_ids))
print(cell_ids)

404
['A375' 'A375' 'A375' 'A375' 'A375' 'A375' 'A375' 'A375' 'A375' 'A375'
 'A375' 'A375' 'A375' 'A375' 'A375' 'A375' 'A375' 'A375' 'A375' 'A375'
 'A375' 'A375' 'A375' 'A375' 'A375' 'A375' 'A375' 'A375' 'A375' 'A375'
 'A375' 'A375' 'A375' 'A375' 'A375' 'A375' 'A375' 'A375' 'A375' 'A375'
 'A375' 'A375' 'A375' 'A375' 'A375' 'A375' 'A375' 'A375' 'A375' 'A375'
 'A375' 'A375' 'A375' 'A375' 'A375' 'A375' 'A375' 'A375' 'A375' 'A375'
 'A375' 'A375' 'A375' 'A375' 'A375' 'A375' 'A375' 'A375' 'A375' 'A375'
 'A375' 'A375' 'A375' 'A375' 'A375' 'A375' 'A375' 'A375' 'A375' 'A375'
 'A375' 'A375' 'A375' 'A375' 'A375' 'A375' 'A375' 'A375' 'A375' 'A375'
 'A375' 'A375' 'A375' 'A375' 'A375' 'A375' 'A375' 'A375' 'A375' 'A375'
 'A375' 'A375' 'A375' 'A375' 'A375' 'A375' 'A375' 'A375' 'A375' 'A375'
 'A375' 'A375' 'A375' 'A375' 'A375' 'A375' 'A375' 'A375' 'A375' 'A375'
 'A375' 'A375' 'A375' 'A375' 'A375' 'A375' 'A375' 'A375' 'A375' 'A375'
 'A375' 'A375' 'A375' 'A375' 'A375' 'A375' 'A375' 'A375' 'A375' 'A375'
 '

## Handle Missing Data and Create Context Features

In [12]:
# Add separate ignore flag columns for 'pert_time' and 'pert_dose'
df['ignore_flag_pert_time'] = np.where(df['pert_time'] == -666, 1, 0)
df['ignore_flag_pert_dose'] = np.where(df['pert_dose'] == -666, 1, 0)

# Replace -666 in 'pert_time' and 'pert_dose' with column mean for those rows
for col in ['pert_time', 'pert_dose']:
    mean_value = df[df[col] != -666][col].mean()  # Calculate mean excluding -666
    df[col] = df[col].replace(-666, mean_value)

print(df[['pert_type','pert_time', 'pert_dose']])
print(df[['ignore_flag_pert_time', 'ignore_flag_pert_dose']])

# Generate perturbation dummies and time/dose values
pert_time = df['pert_time'].values
pert_dose = df['pert_dose'].values
ignore_time = df['ignore_flag_pert_time'].values
ignore_dose = df['ignore_flag_pert_dose'].values
print(pert_dummies.shape)
print(pert_unit_dummies)

# Initialize lists to store split data
X_train_list = []
X_test_list = []
C_train_list = []
C_test_list = []
cell_ids_train_list = []
cell_ids_test_list = []

only1 = 0
# Split data within each context group
for cell_id in unique_cells:
    # Get indices for current cell type
    cell_mask = cell_ids == cell_id
    X_cell = X[cell_mask]
    cell_ids_cell = cell_ids[cell_mask]
    
    # Get corresponding perturbation information
    pert_dummies_cell = pert_dummies.loc[cell_mask].values
    pert_unit_dummies_cell = pert_unit_dummies.loc[cell_mask].values
    pert_time_cell = pert_time[cell_mask].reshape(-1, 1)
    pert_dose_cell = pert_dose[cell_mask].reshape(-1, 1)
    ignore_time_cell = ignore_time[cell_mask].reshape(-1, 1)
    ignore_dose_cell = ignore_dose[cell_mask].reshape(-1, 1)
    
    # Create one-hot encoding for current cell type
    C_cell = np.zeros((X_cell.shape[0], len(unique_cells)))
    C_cell[:, np.where(unique_cells == cell_id)[0]] = 1
    
    # Concatenate all context information
    if full_context:
        C_cell = np.hstack([C_cell, pert_dummies_cell, pert_unit_dummies_cell, pert_time_cell, pert_dose_cell, ignore_time_cell, ignore_dose_cell])
    else:
        C_cell = np.hstack([C_cell, pert_dummies_cell])

    # Split data for current cell type
    if X_cell.shape[0] > 0:  
        X_train_cell, X_test_cell, C_train_cell, C_test_cell, ids_train_cell, ids_test_cell = train_test_split(
            X_cell, C_cell, cell_ids_cell, test_size=0.33, random_state=42
        )
        
        X_train_list.append(X_train_cell)
        X_test_list.append(X_test_cell)
        C_train_list.append(C_train_cell)
        C_test_list.append(C_test_cell)
        cell_ids_train_list.append(ids_train_cell)
        cell_ids_test_list.append(ids_test_cell)
    else:
        only1 = only1 + 1
print('how many cell types removed killed: ', only1)

    pert_type  pert_time  pert_dose
3      trt_cp          6       10.0
4      trt_cp          6       10.0
5      trt_cp          6       10.0
15     trt_cp          6       10.0
16     trt_cp          6       10.0
..        ...        ...        ...
978    trt_cp          6       10.0
979    trt_cp          6       10.0
983    trt_cp          6       10.0
984    trt_cp          6       10.0
985    trt_cp          6       10.0

[404 rows x 3 columns]
     ignore_flag_pert_time  ignore_flag_pert_dose
3                        0                      0
4                        0                      0
5                        0                      0
15                       0                      0
16                       0                      0
..                     ...                    ...
978                      0                      0
979                      0                      0
983                      0                      0
984                      0                  

## Combine Data and Apply PCA

In [13]:
# Combine split data
# np.set_printoptions(threshold=np.inf)
X_train = np.vstack(X_train_list)
X_test = np.vstack(X_test_list)
C_train = np.vstack(C_train_list)
C_test = np.vstack(C_test_list)
cell_ids_train = np.concatenate(cell_ids_train_list)
cell_ids_test = np.concatenate(cell_ids_test_list)
print(C_train.shape)
print(C_test.shape)
print(C_train)

# Apply PCA
pca = PCA(n_components=50)
pca.fit(X_train)
X_train_pca = pca.transform(X_train)
X_test_pca = pca.transform(X_test)

# Normalize train and test PCA data
X_mean = X_train_pca.mean(axis=0)
X_std = X_train_pca.std(axis=0)
X_train_norm = (X_train_pca - X_mean) / X_std
X_test_norm = (X_test_pca - X_mean) / X_std

# Set useful variables
train_group_ids = cell_ids_train
test_group_ids = cell_ids_test
X_train = X_train_norm
X_test = X_test_norm

(270, 140)
(134, 140)
[[ 1.  0.  0. ... 10.  0.  0.]
 [ 1.  0.  0. ... 10.  0.  0.]
 [ 1.  0.  0. ... 10.  0.  0.]
 ...
 [ 1.  0.  0. ... 10.  0.  0.]
 [ 1.  0.  0. ... 10.  0.  0.]
 [ 1.  0.  0. ... 10.  0.  0.]]


## Fit Population Baseline

In [14]:
from contextualized.baselines.networks import CorrelationNetwork
pop_model = CorrelationNetwork()
pop_model.fit(X_train)
print(f"Train MSE: {pop_model.measure_mses(X_train).mean()}")
print(f"Test MSE: {pop_model.measure_mses(X_test).mean()}")

Train MSE: 0.9800000000000001
Test MSE: 0.3981750757558998


## Fit Grouped Baseline

In [15]:
from contextualized.baselines.networks import GroupedNetworks
grouped_model = GroupedNetworks(CorrelationNetwork)
grouped_model.fit(X_train, train_group_ids)
print(f"Grouped Train MSE: {grouped_model.measure_mses(X_train, train_group_ids).mean()}")
print(f"Grouped Test MSE: {grouped_model.measure_mses(X_test, test_group_ids).mean()}")

Grouped Train MSE: 0.9800000000000001
Grouped Test MSE: 0.3981750757558998


## Fit Contextualized Model

In [None]:
import wandb
wandb.login(key='add-your-key-here')  # Add your WandB API key here

contextualized_model = ContextualizedCorrelation(
    context_dim=C_train.shape[1],
    x_dim=X_train.shape[1],
    encoder_type='mlp',
    num_archetypes=50,
)
# Random val split
C_val = train_test_split(C_train, test_size=0.2, random_state=42)[0]
X_val = train_test_split(X_train, test_size=0.2, random_state=42)[0]
datamodule = CorrelationDataModule(
    C_train=C_train,
    X_train=X_train,
    C_val=C_val,
    X_val=X_val,
    C_test=C_test,
    X_test=X_test,
    C_predict=np.concatenate((C_train, C_test), axis=0),
    X_predict=np.concatenate((X_train, X_test), axis=0),
    batch_size=32,
)
checkpoint_callback = pl.pytorch.callbacks.ModelCheckpoint(
    monitor='val_loss',
    mode='min',
    save_top_k=1,
    filename='best_model',
)
logger = pl.pytorch.loggers.WandbLogger(
    project='contextpert',
    name='one_hot_context',
    log_model=True,
    save_dir='logs/',
)
trainer = Trainer(
    max_epochs=10,
    accelerator='auto',
    devices='auto',
    callbacks=[checkpoint_callback],
    logger=logger,
)
trainer.fit(contextualized_model, datamodule=datamodule)

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /mmfs1/home/jiaqiw18/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mjiaqiw[0m ([33mcontextualized[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
/mmfs1/gscratch/ark/jiaqi/miniconda3/envs/cml/lib/python3.10/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /mmfs1/gscratch/ark/jiaqi/miniconda3/envs/cml/lib/py ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA A100 80GB PCIe') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read http

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type             | Params | Mode 
-------------------------------------------------------
0 | metamodel | SubtypeMetamodel | 254 K  | train
-------------------------------------------------------
254 K     Trainable params
0         Non-trainable params
254 K     Total params
1.019     Total estimated model params size (MB)
7         Modules in train mode
0         Modules in eval mode


                                                                                                                                                                                                   

/mmfs1/gscratch/ark/jiaqi/miniconda3/envs/cml/lib/python3.10/site-packages/lightning/pytorch/loops/fit_loop.py:310: The number of training batches (9) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Epoch 0: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:00<00:00, 69.67it/s, v_num=2jlq]
Validation: |                                                                                                                                                                | 0/? [00:00<?, ?it/s][A
Validation:   0%|                                                                                                                                                            | 0/7 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|                                                                                                                                               | 0/7 [00:00<?, ?it/s][A
Validation DataLoader 0:  14%|███████████████████▏                                                                                                                  | 1/7 [00:00<00:00, 125.03it/s][A
Validati

`Trainer.fit` stopped: `max_epochs=10` reached.


Epoch 9: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:00<00:00, 70.88it/s, v_num=2jlq]


In [18]:
print(f"Testing model on training data...")
trainer.test(contextualized_model, datamodule.train_dataloader())
print(f"Testing model on test data...")
trainer.test(contextualized_model, datamodule.test_dataloader())

/mmfs1/gscratch/ark/jiaqi/miniconda3/envs/cml/lib/python3.10/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /mmfs1/gscratch/ark/jiaqi/miniconda3/envs/cml/lib/py ...
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/mmfs1/gscratch/ark/jiaqi/miniconda3/envs/cml/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:476: Your `test_dataloader`'s sampler has shuffling enabled, it is strongly recommended that you turn shuffling off for val/test dataloaders.


Testing model on training data...
Testing DataLoader 0: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:00<00:00, 194.52it/s]

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]



────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_loss           0.9771710634231567
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
Testing model on test data...
Testing DataLoader 0: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 261.79it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_loss           0.398

[{'test_loss': 0.39830678701400757}]

In [19]:
print(checkpoint_callback.best_model_path)

logs/contextpert/35f72jlq/checkpoints/best_model.ckpt


## Predict Networks

In [20]:
# Necessary to save predictions from multiple devices in parallel
from contextualized.callbacks import PredictionWriter
from pathlib import Path

output_dir = Path(checkpoint_callback.best_model_path).parent / 'predictions'
writer_callback = PredictionWriter(
    output_dir=output_dir,
    write_interval='batch',
)
trainer = Trainer(
    accelerator='auto',
    devices='auto',
    callbacks=[checkpoint_callback, writer_callback],
)
_ = trainer.predict(contextualized_model, datamodule=datamodule)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/mmfs1/gscratch/ark/jiaqi/miniconda3/envs/cml/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:76: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting DataLoader 0: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 59.70it/s]


In [21]:
# Compile distributed predictions and put into order
import torch
import glob

# Convert context to hashable type for lookup
C_train_hashable = [tuple(row) for row in C_train]
C_test_hashable = [tuple(row) for row in C_test]

# Gather preds and move to CPU
all_correlations = {}
all_betas = {}
all_mus = {}
pred_files = glob.glob(str(output_dir / 'predictions_*.pt'))
for file in pred_files:
    preds = torch.load(file)
    for context, correlation, beta, mu in zip(preds['contexts'], preds['correlations'], preds['betas'], preds['mus']):
        context_tuple = tuple(context.tolist())
        all_correlations[context_tuple] = correlation.cpu().numpy()
        all_betas[context_tuple] = beta.cpu().numpy()
        all_mus[context_tuple] = mu.cpu().numpy()

# Remake preds in order of C_train and C_test
correlations_train = np.array([all_correlations[c] for c in C_train_hashable])
correlations_test = np.array([all_correlations[c] for c in C_test_hashable])
betas_train = np.array([all_betas[c] for c in C_train_hashable])
betas_test = np.array([all_betas[c] for c in C_test_hashable])
mus_train = np.array([all_mus[c] for c in C_train_hashable])
mus_test = np.array([all_mus[c] for c in C_test_hashable])

In [22]:
# Get individual MSEs by sample
# Sanity check: These should closely match the trainer.test() outputs from earlier
def measure_mses(betas, mus, X):
    mses = np.zeros(len(X))
    for i in range(len(X)):
        sample_mse = 0
        for j in range(X.shape[-1]):
            for k in range(X.shape[-1]):
                residual = X[i, j] - betas[i, j, k] * X[i, k] - mus[i, j, k]
                sample_mse += residual**2 / (X.shape[-1] ** 2)
        mses += sample_mse / len(X)
    return mses

mse_train = measure_mses(betas_train, mus_train, X_train)
mse_test = measure_mses(betas_test, mus_test, X_test)
print(f"Train MSEs: {mse_train.mean()}")
print(f"Test MSEs: {mse_test.mean()}")

# Iterate over the unique cells that were included in the splits for per-cell MSE
print("Per-cell MSE:")
for cell_id in unique_cells:
    tr_mask = cell_ids_train == cell_id
    te_mask = cell_ids_test == cell_id

    if tr_mask.sum() == 0 and te_mask.sum() == 0:
        continue

    tr_mse = mse_train[tr_mask].mean() if tr_mask.any() else np.nan
    te_mse = mse_test[te_mask].mean() if te_mask.any() else np.nan
    print(f'Cell {cell_id:<15}:  train MSE = {tr_mse:7.4f}   '
          f'test MSE = {te_mse:7.4f}   (n={tr_mask.sum():3d}/{te_mask.sum():3d})')

Train MSEs: 0.9771710609955975
Test MSEs: 0.39830678965319255
Per-cell MSE:
Cell A375           :  train MSE =  0.9772   test MSE =  0.3983   (n=270/134)
