# Training chemCPA

## Download data

This notebook has the following dependencies, which must be available before running it:

- sciplex3
- wget
- rdkit embeddings

This notebook is a self-contained way to train the chemCPA model.

In [1]:
from pathlib import Path

import lightning as L
import pandas as pd
import scanpy as sc
import wandb
from hydra import compose, initialize
from lightning.pytorch import seed_everything
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from omegaconf import OmegaConf
from plotnine import aes, geom_boxplot, ggplot, scale_y_continuous
from pytorch_lightning.loggers import WandbLogger

from chemCPA.data import DataModule
from chemCPA.model import ComPert
from chemCPA.train import evaluate_logfold_r2, evaluate_r2, evaluate_r2_sc

wandb.login()

In [None]:
### Load config

# Compose the "defaults" Hydra configuration from the experiments folder,
# without applying any overrides.
with initialize(config_path="../experiments/hydra_config", version_base=None):
    config = compose(overrides=[], config_name="defaults")


In [None]:
# Show the dataset parameters that will be forwarded to the data module.
print(OmegaConf.to_yaml(config.dataset.data_params))

# Fail early if a required input file is missing, and report the offending
# path so the config can be corrected without further digging.
assert Path(config.dataset.data_params.dataset_path).exists(), (
    f"Config `dataset_path` is not correct: {config.dataset.data_params.dataset_path} does not exist"
)
assert Path(config.dataset.data_params.drugs_embeddings).exists(), (
    f"Config `drugs_embeddings` is not correct: {config.dataset.data_params.drugs_embeddings} does not exist"
)

dataset_path: /nfs/homedirs/hetzell/code/chemCPA_fork/project_folder/datasets/sciplex_complete_middle_subset_lincs_genes.h5ad
drug_key: condition
dose_key: dose
knockout_key: null
covariate_keys:
- cell_type
smiles_key: SMILES
pert_category: cov_drug_dose_name
split_key: split_ood_multi_task
degs_key: lincs_DEGs
drugs_embeddings: /nfs/homedirs/hetzell/code/chemCPA_fork/project_folder/embeddings/rdkit/data/embeddings/rdkit2D_embedding_lincs_trapnell.parquet
knockouts_embeddings: null
return_dataset: true



In [None]:
### Check adata if required

# # If you want to load the adata from the dataset_path
# ADATA_PATH = Path(config.dataset.data_params.dataset_path)
# adata = sc.read(ADATA_PATH)

# adata

In [None]:
### Load data module

# Dataset-specific arguments are forwarded verbatim from the config; batch
# size and loader settings are read from the model / train sections.
dm = DataModule(
    **config.dataset["data_params"],
    batch_size=config.model["hparams"]["batch_size"],
    full_eval_during_train=config.train["full_eval_during_train"],
    num_workers=config.train["num_workers"],
    # num_workers=19,
)

# # Check basic stats
# print(len(dm.datasets['training']))
# print(len(dm.datasets['training_control']))
# print(len(dm.datasets['training_treated']))
# print(len(dm.datasets['test']))
# print(len(dm.datasets['test_control']))
# print(len(dm.datasets['test_treated']))
# print(len(dm.datasets['ood']))

In [None]:
# for _item in dm.datasets['training']: 
#     print(_item)
#     break

In [None]:
# Report, one value per line, the number of covariates, drugs and knockouts
# exposed by the training dataset.
print(
    dm.datasets['training'].num_covariates,
    dm.datasets['training'].num_drugs,
    dm.datasets['training'].num_knockouts,
    sep="\n",
)

[3]
188
0


In [None]:
### Initialise the model

# The training dataset supplies the sizes (genes, drugs, knockouts,
# covariates) and embedding dimensions that the model is built with.
data_train = dm.datasets["training"]

model = ComPert(
    data_train.num_genes,
    data_train.num_drugs,
    data_train.num_knockouts,
    data_train.num_covariates,
    config.model.hparams,
    config.train,
    config.test,
    drug_embedding_dimension=data_train.drug_embedding_dimension,
    knockout_embedding_dimension=data_train.knockout_embedding_dimension,
    **config.model.additional_params,
)


In [None]:
### Initialise the trainer

# The W&B project name combines the model type and the dataset type.
# NOTE(review): WandbLogger is imported from `pytorch_lightning` while the
# Trainer comes from `lightning` — consider importing both from the same package.
project_str = f"{config.model['model_type']}_{config.dataset['dataset_type']}"
wandb_logger = WandbLogger(
    project=project_str,
    save_dir=config.model["save_dir"],
)

# Inference mode is enabled only when neither the train nor the test config
# requests the disentanglement evaluation.
inference_mode = not (config.train["run_eval_disentangle"] or config.test["run_eval_disentangle"])

# Stop training once the monitored R2 score has stopped increasing for
# `patience` validation checks.
early_stop_callback = EarlyStopping(
    monitor="average_r2_score",
    patience=model.hparams.training_params["patience"],
    mode="max",
)

trainer = L.Trainer(
    logger=wandb_logger,
    callbacks=[early_stop_callback],
    max_epochs=config.train["num_epochs"],
    max_time=config.train["max_time"],
    # Validation runs every `checkpoint_freq` epochs.
    check_val_every_n_epoch=config.train["checkpoint_freq"],
    default_root_dir=config.model["save_dir"],
    # profiler="advanced",
    inference_mode=inference_mode,
    num_sanity_val_steps=0,
)

/nfs/staff-hdd/hetzell/miniconda3/envs/chemCPA-test-env-new/lib/python3.10/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /nfs/staff-hdd/hetzell/miniconda3/envs/chemCPA-test- ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [12]:
# Train the model on the data module; logging, validation frequency and
# early stopping follow the settings given to the trainer when it was created.
trainer.fit(model, datamodule=dm)

/nfs/staff-hdd/hetzell/miniconda3/envs/chemCPA-test-env-new/lib/python3.10/site-packages/pytorch_lightning/loggers/wandb.py:391: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [1]

  | Name                   | Type            | Params
-----------------------------------------------------------
0 | loss_autoencoder       | GaussianNLLLoss | 0     
1 | encoder                | MLP             | 1.3 M 
2 | decoder                | MLP             | 1.8 M 
3 | adversary_drugs        | MLP             | 62.3 K
4 | drug_embedding_encoder | MLP             | 644 K 
5 | loss_adversary_drugs   | CELoss          | 0     
6 | dosers                 | MLP             | 12.7 K
-----------------------------------------------------------
3.8 M     Trainable params
0         Non-trainable params
3.8 M     Total params
15.361

Training: |          | 0/? [00:00<?, ?it/s]

In [27]:
### Test the model

# Run the Lightning test loop for the model, using the same trainer and data module.
trainer.test(model, datamodule=dm)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
SLURM auto-requeueing enabled. Setting signal handlers.


Testing: |                                                                                                        | 0/? [00:00<?, ?it/s]

In [None]:
# Load model from checkpoint

#model = ComPert.load_from_checkpoint('train_data/CPA/3gm2eppz/checkpoints/epoch=14-step=10560.ckpt')

In [None]:
### Perform evaluation


In [None]:
# #draw the logfold r2
# def draw_logfold_r2(autoencoder, ds_treated, ds_ctrl):
#     logfold_score, signs_score = evaluate_logfold_r2(autoencoder, ds_treated, ds_ctrl, return_mean=False)
#     df = pd.DataFrame(
#         data = {'logfold_score': logfold_score, 'signs_score': signs_score}
#     )
#     df = pd.melt(df, value_vars=['logfold_score', 'signs_score'], var_name='score_type', value_name='score')
#     p = ggplot(df, aes(x='factor(score_type)', y='score', fill='factor(score_type)')) + geom_boxplot() + scale_y_continuous(limits=(-1,1))
#     return p

In [None]:
# R2
def draw_r2(autoencoder, dataset, genes_control, y_limits=(0, 1)):
    """Draw one boxplot per R2 score type computed by ``evaluate_r2``.

    ``evaluate_r2`` is called with ``return_mean=False`` and must return four
    equally long sequences, in this order: mean score, mean score on DE genes,
    variance score, variance score on DE genes.

    Args:
        autoencoder: Model forwarded to ``evaluate_r2``.
        dataset: Dataset forwarded to ``evaluate_r2``.
        genes_control: Control expression data forwarded to ``evaluate_r2``.
        y_limits: ``(lower, upper)`` limits of the y axis. Defaults to
            ``(0, 1)``. plotnine drops points that fall outside the scale
            limits, so pass wider limits (e.g. ``(-1, 1)``) if negative R2
            scores should stay visible.

    Returns:
        The plotnine ``ggplot`` object with one box per score type.
    """
    score_types = ["mean_score", "mean_score_de", "var_score", "var_score_de"]

    # Unpack explicitly so that an unexpected number of return values fails loudly.
    mean_score, mean_score_de, var_score, var_score_de = evaluate_r2(
        autoencoder, dataset, genes_control, return_mean=False
    )
    scores_wide = pd.DataFrame(
        data=dict(zip(score_types, (mean_score, mean_score_de, var_score, var_score_de)))
    )

    # Long format: one row per (score_type, score) pair, as expected by ggplot.
    scores_long = pd.melt(
        scores_wide,
        value_vars=score_types,
        var_name="score_type",
        value_name="score",
    )

    return (
        ggplot(scores_long, aes(x="factor(score_type)", y="score", fill="factor(score_type)"))
        + geom_boxplot()
        + scale_y_continuous(limits=y_limits)
    )

In [None]:
# #draw the r2 sc
# def draw_r2_sc(autoencoder, dataset):
#     mean_score, mean_score_de, var_score, var_score_de = evaluate_r2_sc(autoencoder, dataset, return_mean=False)
#     df = pd.DataFrame(
#         data = {'mean_score': mean_score, 
#                 'mean_score_de': mean_score_de,
#                 'var_score': var_score,
#                 'var_score_de':var_score_de
#                 }
#     )
#     df = pd.melt(df, value_vars=['mean_score', 'mean_score_de', 'var_score', 'var_score_de'], 
#                  var_name='score_type', value_name='score')
#     p = ggplot(df, aes(x='factor(score_type)', y='score', fill='factor(score_type)')) + geom_boxplot()+ scale_y_continuous(limits=(0,1))
#     return p