# Training chemCPA

Either use the provided `env.yml` or, if you are using pip, install the following packages:

pip install pandas scanpy torch anndata numpy torchmetrics seaborn matplotlib wandb lightning rdkit hydra-core plotnine pyarrow ipywidgets



## Imports

In [12]:
# optional autoreload
%load_ext autoreload
%autoreload 2


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [13]:
from pathlib import Path

import lightning as L
import pandas as pd
import scanpy as sc
import wandb
from hydra import compose, initialize
from lightning.pytorch import seed_everything
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
# The logger is imported from the same `lightning.pytorch` namespace as the
# Trainer and callbacks above. The standalone `pytorch_lightning` package is a
# separate distribution and should not be mixed with `lightning`.
from lightning.pytorch.loggers import WandbLogger
from omegaconf import OmegaConf
from plotnine import aes, geom_boxplot, ggplot, scale_y_continuous

from chemCPA.data import DataModule
from chemCPA.model import ComPert
from chemCPA.train import evaluate_logfold_r2, evaluate_r2, evaluate_r2_sc

# Log in to Weights & Biases; a WandbLogger is attached to the trainer later on.
wandb.login()



True

## Load configs

Before loading the config, double-check the paths where the datasets are expected to be located. 
These are configured <a href="../../experiments/hydra_config/dataset/sciplex_lincs.yaml">here</a>.
You may also want to change the model output path <a href="../../experiments/hydra_config/model/defaults.yaml">here</a>.



In [3]:
### Load config

# Compose the Hydra configuration named "defaults" without any overrides.
# The config directory is given as a relative path.
with initialize(
    version_base=None,
    config_path="../../experiments/hydra_config",
):
    config = compose(
        config_name="defaults",
        overrides=[],
    )
    


## (Down)Load data

In [4]:
# Pull the dataset location out of the config; its parent folder is kept
# as `dataset_dir`.
data_params_cfg = config.dataset.data_params
dataset_path = data_params_cfg.dataset_path
dataset_dir = Path(dataset_path).parent

# Show the configured locations of the dataset and the drug embeddings.
print(f"Dataset path: {dataset_path}")
print(f"Embeddings path: {data_params_cfg.drugs_embeddings}")




Dataset path: /root/chemCPA_fork/sciplex_complete_middle_subset_lincs_genes.h5ad
Embeddings path: /root/chemCPA_fork/rdkit2D_embedding_lincs_trapnell.parquet


The following cell checks that the files exist at the expected paths; if they don't, they are downloaded. Feel free to run this cell multiple times.


In [5]:
import sys

# Make the repository's `data/` directory importable so the dataset helpers
# below can be found. The membership check keeps the cell idempotent:
# re-running it no longer appends a duplicate entry to `sys.path` each time.
DATA_MODULE_DIR = '../../data'
if DATA_MODULE_DIR not in sys.path:
    sys.path.append(DATA_MODULE_DIR)

from datasets.datasets import sciplex_complete_middle_subset_lincs_genes, rdkit2D_embedding_lincs_trapnell

data_params_cfg = config.dataset.data_params

# Each helper receives the configured target path for its file. According to
# the note accompanying this cell, a file is downloaded only if it is missing.
sciplex_complete_middle_subset_lincs_genes(data_params_cfg.dataset_path)
rdkit2D_embedding_lincs_trapnell(data_params_cfg.drugs_embeddings)

# Print the effective dataset parameters for reference.
print(OmegaConf.to_yaml(data_params_cfg))

# Fail early if either file is still not present at its configured location.
assert Path(data_params_cfg.dataset_path).exists(), "Config `dataset_path` is not correct!"
assert Path(data_params_cfg.drugs_embeddings).exists(), "Config `drugs_embeddings` is not correct"

dataset_path: /root/chemCPA_fork/sciplex_complete_middle_subset_lincs_genes.h5ad
drug_key: condition
dose_key: dose
knockout_key: null
covariate_keys:
- cell_type
smiles_key: SMILES
pert_category: cov_drug_dose_name
split_key: split_ood_multi_task
degs_key: lincs_DEGs
drugs_embeddings: /root/chemCPA_fork/rdkit2D_embedding_lincs_trapnell.parquet
knockouts_embeddings: null
return_dataset: true



## Load data

In [6]:
### Load data module

# Batch size comes from the model hyper-parameters, the loader/evaluation
# settings from the training config, and every entry of
# `dataset.data_params` is forwarded unchanged as a keyword argument.
dm = DataModule(
    **config.dataset["data_params"],
    batch_size=config.model["hparams"]["batch_size"],
    full_eval_during_train=config.train["full_eval_during_train"],
    num_workers=config.train["num_workers"],
)

# Report the covariate, drug and knockout counts of the training split,
# one per line.
train_split = dm.datasets['training']
print(
    train_split.num_covariates,
    train_split.num_drugs,
    train_split.num_knockouts,
    sep="\n",
)

[3]
188
0


## Initialize Model

In [14]:
###  Initialise the model

# Seed the random number generators before the network is built so that
# weight initialisation and the subsequent training run start from the same
# RNG state after every kernel restart. `workers=True` also lets Lightning
# derive seeds for dataloader workers.
# NOTE(review): the seed value is arbitrary; if the project config already
# defines a seed, prefer reading it from there.
SEED = 42
seed_everything(SEED, workers=True)

# Sizes and embedding dimensions are taken from the training split.
data_train = dm.datasets['training']

model = ComPert(
    data_train.num_genes,
    data_train.num_drugs,
    data_train.num_knockouts,
    data_train.num_covariates,
    config.model.hparams,
    config.train,
    config.test,
    # Extra model options from the config are forwarded as keyword arguments.
    **config.model.additional_params,
    drug_embedding_dimension=data_train.drug_embedding_dimension,
    knockout_embedding_dimension=data_train.knockout_embedding_dimension,
)


## Train

In [15]:
# Display the configured model output directory; the same setting is passed
# to the W&B logger and the trainer.
config.model.save_dir


'/root/chemCPA_fork/model_outputs'

In [16]:
### Initialise the trainer

# The W&B project is named "<model_type>_<dataset_type>"; the logger writes
# into the configured model output directory.
project_str = f"{config.model.model_type}_{config.dataset.dataset_type}"
wandb_logger = WandbLogger(
    project=project_str,
    save_dir=config.model["save_dir"],
)

# Inference mode is enabled only when neither the training nor the test
# config requests the disentanglement evaluation.
# NOTE(review): presumably that evaluation needs gradient tracking — confirm.
inference_mode = not (
    config.train["run_eval_disentangle"] or config.test["run_eval_disentangle"]
)

# Stop once "average_r2_score" has stopped increasing for `patience`
# validation checks; the patience comes from the model's training params.
early_stop_callback = EarlyStopping(
    monitor="average_r2_score",
    mode="max",
    patience=model.hparams.training_params["patience"],
)

trainer = L.Trainer(
    default_root_dir=config.model["save_dir"],
    logger=wandb_logger,
    callbacks=[early_stop_callback],
    max_epochs=config.train["num_epochs"],
    max_time=config.train["max_time"],
    # Validation runs every `checkpoint_freq` epochs.
    check_val_every_n_epoch=config.train["checkpoint_freq"],
    inference_mode=inference_mode,
    # Skip the sanity-check validation batches before training starts.
    num_sanity_val_steps=0,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [18]:
# Start training; the data module supplies the data to the trainer.
trainer.fit(model=model, datamodule=dm)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name                   | Type            | Params
-----------------------------------------------------------
0 | loss_autoencoder       | GaussianNLLLoss | 0     
1 | encoder                | MLP             | 259 K 
2 | decoder                | MLP             | 511 K 
3 | drug_embedding_encoder | MLP             | 644 K 
4 | loss_adversary_drugs   | CELoss          | 0     
5 | dosers                 | MLP             | 50.9 K
-----------------------------------------------------------
1.5 M     Trainable params
0         Non-trainable params
1.5 M     Total params
5.863     Total estimated model params size (MB)


Training: |          | 0/? [00:00<?, ?it/s]