# Training and evaluating the CPA (ComPert) model

In [1]:
# Standard library
import sys
import time

# Third-party
import numpy as np
import torch
import wandb
import yaml
import lightning as L
from lightning.pytorch import seed_everything
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
# Import the logger from the same `lightning` namespace as the Trainer and the
# callbacks; mixing `lightning.pytorch` with the standalone `pytorch_lightning`
# package is discouraged by Lightning and can lead to type mismatches.
from lightning.pytorch.loggers import WandbLogger

# Make the parent directory importable; presumably this is where the project's
# `data` and `model` modules live.
sys.path.append('..')

from data import DataModule
from model import ComPert

# Authenticate with Weights & Biases so the WandbLogger can upload the run.
wandb.login()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mmanuelgander1[0m ([33mi_selbr[0m). Use [1m`wandb login --relogin`[0m to force relogin


## Load the configuration

In [2]:
# Load the experiment configuration from YAML. `safe_load` only builds plain
# Python containers, so the file cannot instantiate arbitrary objects.
with open('config_hparam.yaml', encoding='utf-8') as file:
    config_data = yaml.safe_load(file)

# Seed Python's `random`, NumPy and PyTorch before any dataset is built or any
# weight is initialized, so that Restart & Run All reproduces the same run.
# NOTE(review): a top-level `seed` key in the config is optional here; 42 is
# only a fallback - confirm whether the project already defines a seed elsewhere.
seed_everything(config_data.get('seed', 42), workers=True)

## Initialize the dataset module

In [3]:
# Build the data module from the loaded configuration: the batch size and the
# `full_eval_during_train` flag are passed positionally, and every entry of
# `dataset.data_params` is forwarded as a keyword argument.
model_cfg = config_data['model']
dm = DataModule(
    model_cfg['hparams']['batch_size'],
    model_cfg['training_hparams']['full_eval_during_train'],
    **config_data['dataset']['data_params'],
)

In [4]:
# Print the length of every dataset split held by the data module, one per line.
split_names = (
    'training',
    'training_control',
    'training_treated',
    'test',
    'test_control',
    'test_treated',
    'ood',
)
for split_name in split_names:
    print(len(dm.datasets[split_name]))

# Then print the covariate / drug / knockout counts exposed by the training split.
training_set = dm.datasets['training']
for attribute in ('num_covariates', 'num_drugs', 'num_knockouts'):
    print(getattr(training_set, attribute))

75503
10072
65431
13324
1763
11561
22428
[1]
0
106


## Initialize the model

In [6]:
# Build the ComPert model. The gene/drug/knockout/covariate counts and the
# embedding dimensions are read from the training split; the hyper-parameter
# groups come from the `model` section of the configuration.
# Original author's note: the knockout embeddings are initialized at random.
training_set = dm.datasets['training']
model_cfg = config_data['model']

model = ComPert(
    training_set.num_genes,
    training_set.num_drugs,
    training_set.num_knockouts,
    training_set.num_covariates,
    model_cfg['hparams'],
    model_cfg['training_hparams'],
    model_cfg['test_hparams'],
    **model_cfg['additional_params'],
    drug_embedding_dimension=training_set.drug_embedding_dimension,
    knockout_embedding_dimension=training_set.knockout_embedding_dimension,
)

## Initialize the trainer

In [7]:
# Stop training when the monitored `average_r2_score` has not improved
# (higher is better) for `patience` consecutive validation checks.
early_stop_callback = EarlyStopping(
    'average_r2_score',
    patience=model.hparams.training_hparams['patience'],
    mode='max',
)

# The disentanglement evaluation may be requested for training and/or testing.
# When it is, evaluation must not run under `torch.inference_mode()`, so the
# Trainer is told to fall back to `torch.no_grad()` (`inference_mode=False`).
# NOTE(review): presumably the disentanglement evaluation fits an auxiliary
# model inside the evaluation loop - confirm against the ComPert implementation.
disentangle_eval_requested = bool(
    model.hparams.training_hparams['run_eval_disentangle']
    or model.hparams.test_hparams['run_eval_disentangle']
)

# A single Trainer definition replaces two copy-pasted ones that differed only
# in `inference_mode`; `True` is the Lightning default, so behaviour is unchanged.
# NOTE(review): the recorded run warns that there are only 8 training batches
# per epoch while `log_every_n_steps` defaults to 50 - consider lowering it if
# per-step training logs are wanted.
training_hparams = config_data['model']['training_hparams']
trainer = L.Trainer(
    logger=WandbLogger(log_model="all"),
    max_epochs=training_hparams['num_epochs'],
    max_time=training_hparams['max_time'],
    check_val_every_n_epoch=training_hparams['checkpoint_freq'],
    default_root_dir=training_hparams['save_dir'],
    profiler="advanced",
    callbacks=[early_stop_callback],
    inference_mode=not disentangle_eval_requested,
)

/home/manu/miniconda3/envs/g/lib/python3.10/site-packages/lightning/fabric/accelerators/cuda.py:239: Can't initialize NVML
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


## Train the model

In [8]:
# Run the training loop; the data module supplies the train and validation dataloaders.
trainer.fit(model=model, datamodule=dm)


  | Name                       | Type               | Params
------------------------------------------------------------------
0 | loss_autoencoder           | GaussianNLLLoss    | 0     
1 | encoder                    | MLP                | 1.9 M 
2 | decoder                    | MLP                | 2.9 M 
3 | knockout_embedding_encoder | MLP                | 1.9 M 
4 | knockout_effects           | GeneralizedSigmoid | 0     
5 | adversary_knockouts        | MLP                | 64.0 K
6 | loss_adversary_knockout    | CELoss             | 0     
------------------------------------------------------------------
6.8 M     Trainable params
0         Non-trainable params
6.8 M     Total params
27.046    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

/home/manu/miniconda3/envs/g/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:442: The dataloader, val_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 12 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.


0 combinations had '-inf' R2 scores:
	 set()


/home/manu/miniconda3/envs/g/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:442: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 12 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
/home/manu/miniconda3/envs/g/lib/python3.10/site-packages/lightning/pytorch/loops/fit_loop.py:281: The number of training batches (8) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

0 combinations had '-inf' R2 scores:
	 set()


Validation: 0it [00:00, ?it/s]

0 combinations had '-inf' R2 scores:
	 set()


Validation: 0it [00:00, ?it/s]

0 combinations had '-inf' R2 scores:
	 set()


Validation: 0it [00:00, ?it/s]

0 combinations had '-inf' R2 scores:
	 set()


`Trainer.fit` stopped: `max_epochs=60` reached.
FIT Profiler Report
Profile stats for: [LightningModule]ComPert.configure_callbacks
         7 function calls in 0.000 seconds

   Ordered by: cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000    0.000    0.000 contextlib.py:139(__exit__)
        1    0.000    0.000    0.000    0.000 {built-in method builtins.next}
        1    0.000    0.000    0.000    0.000 profiler.py:54(profile)
        1    0.000    0.000    0.000    0.000 advanced.py:67(stop)
        1    0.000    0.000    0.000    0.000 module.py:899(configure_callbacks)
        1    0.000    0.000    0.000    0.000 {method 'get' of 'dict' objects}
        1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}



Profile stats for: [LightningDataModule]DataModule.prepare_data
         7 function calls in 0.000 seconds

   Ordered by: cumulative time

   ncalls  tottime  percall  cumt

## Test the model

In [9]:
# Evaluate the trained model with the data module's test dataloader.
# `Trainer.test` returns a list of metric dictionaries (one per test
# dataloader); only the first entry is kept.
test_metrics = trainer.test(model=model, datamodule=dm)
result = test_metrics[0]

/home/manu/miniconda3/envs/g/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:442: The dataloader, test_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 12 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.


Testing: 0it [00:00, ?it/s]

0 combinations had '-inf' R2 scores:
	 set()
0 combinations had '-inf' R2 scores:
	 set()
0 combinations had '-inf' R2 scores:
	 set()


TEST Profiler Report
Profile stats for: [LightningModule]ComPert.configure_callbacks
         7 function calls in 0.000 seconds

   Ordered by: cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000    0.000    0.000 contextlib.py:139(__exit__)
        1    0.000    0.000    0.000    0.000 {built-in method builtins.next}
        1    0.000    0.000    0.000    0.000 profiler.py:54(profile)
        1    0.000    0.000    0.000    0.000 advanced.py:67(stop)
        1    0.000    0.000    0.000    0.000 module.py:899(configure_callbacks)
        1    0.000    0.000    0.000    0.000 {method 'get' of 'dict' objects}
        1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}



Profile stats for: [LightningDataModule]DataModule.prepare_data
         7 function calls in 0.000 seconds

   Ordered by: cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        

In [10]:
# Display the metrics dictionary returned by the test run.
result

{'stats_disent_knockout': 0.1579105406999588,
 'optimal_disent_score_knockout': 0.13231762233563493,
 'stats_disent_cov_0': 1.0,
 'optimal_disent_cov_0': 1.0,
 'training_mean_score': 0.9733860431430496,
 'training_mean_score_de': 0.8492184768770342,
 'training_var_score': 0.7704700540716403,
 'training_var_score_de': 0.31402403004815643,
 'test_mean_score': 0.9702203083486064,
 'test_mean_score_de': 0.8541510997803559,
 'test_var_score': 0.7337480889799449,
 'test_var_score_de': 0.3038047126761065,
 'ood_mean_score': 0.9854941015893762,
 'ood_mean_score_de': 0.8849404102021997,
 'ood_var_score': 0.7795604115182703,
 'ood_var_score_de': 0.1886568394574252,
 'training_sc_mean_score': 0.8272536335704482,
 'training_sc_mean_score_de': 0.7371142606172606,
 'training_sc_var_score': 0.7988694574788352,
 'training_sc_var_score_de': 0.06766487059192122,
 'test_sc_mean_score': 0.8289370662729505,
 'test_sc_mean_score_de': 0.7398652739717927,
 'test_sc_var_score': 0.7575349035397382,
 'test_sc_va