# training and evaluating the CPA

In [None]:
import sys
# Put the parent directory on the import path so that the `chemCPA` package
# resolves when this notebook is run from its own sub-directory.
sys.path.append('..')

import wandb
# Log in to Weights & Biases up front; the trainer below logs through it.
wandb.login()

from chemCPA.data import DataModule
from chemCPA.model import ComPert
import torch
import numpy as np
import time
import yaml
import lightning as L
from lightning.pytorch import seed_everything
# Take the logger from the same `lightning` namespace as the Trainer and the
# callbacks; the previous `pytorch_lightning.loggers` import mixed the legacy
# package with the unified one.
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.callbacks.early_stopping import EarlyStopping

## load configuration

In [None]:
# Parse the dataset, model, test and train YAML files into plain Python
# objects. Each file is closed as soon as it has been parsed.
def _read_yaml(path):
    """Return the parsed content of the YAML file at *path*."""
    with open(path) as stream:
        return yaml.safe_load(stream)


config_dataset = _read_yaml('../chemCPA/config/dataset/drug.yaml')
config_model = _read_yaml('../chemCPA/config/model/cpa.yaml')
config_test = _read_yaml('../chemCPA/config/test/test.yaml')
config_train = _read_yaml('../chemCPA/config/train/train.yaml')

#seed_everything()

## initialize the dataset module

In [None]:
# Build the data module from the model batch size, the train-time
# full-evaluation flag, and the dataset parameters from the dataset config.
dm = DataModule(
    config_model['hparams']['batch_size'],
    config_train['full_eval_during_train'],
    **config_dataset['data_params'],
)

In [None]:
# Print the number of entries in every dataset split, one value per line.
for split_name in (
    'training',
    'training_control',
    'training_treated',
    'test',
    'test_control',
    'test_treated',
    'ood',
):
    print(len(dm.datasets[split_name]))

In [None]:
# Print the covariate, drug and knockout counts of the training dataset.
training_set = dm.datasets['training']
print(training_set.num_covariates)
print(training_set.num_drugs)
print(training_set.num_knockouts)

## initialize the model

In [None]:
# Build the ComPert model. Its sizes (genes, drugs, knockouts, covariates) and
# embedding dimensions are read from the training dataset; hyper-parameters
# and train/test settings come from the loaded configs.
# The knockout embeddings are initialized randomly.
training_set = dm.datasets['training']
model = ComPert(
    training_set.num_genes,
    training_set.num_drugs,
    training_set.num_knockouts,
    training_set.num_covariates,
    config_model['hparams'],
    config_train,
    config_test,
    **config_model['additional_params'],
    drug_embedding_dimension=training_set.drug_embedding_dimension,
    knockout_embedding_dimension=training_set.knockout_embedding_dimension,
)
    

## Initialize the trainer

In [None]:
# Stop training once 'average_r2_score' has not increased for `patience`
# consecutive checks (mode='max' means higher is better).
early_stop_callback = EarlyStopping(
    'average_r2_score',
    patience=model.hparams.training_params['patience'],
    mode='max',
)

# Log to Weights & Biases under a "<model_type>_<dataset_type>" project.
wandb_logger = WandbLogger(
    project="_".join([config_model['model_type'], config_dataset['dataset_type']]),
    save_dir=config_model['save_dir'],
)

# Inference mode is only enabled when neither the train nor the test config
# requests the disentanglement evaluation.
# NOTE(review): presumably that evaluation needs gradients — confirm.
inference_mode = not (
    config_train['run_eval_disentangle'] or config_test['run_eval_disentangle']
)

trainer = L.Trainer(
    logger=wandb_logger,
    max_epochs=config_train['num_epochs'],
    max_time=config_train['max_time'],
    # Validation runs every `checkpoint_freq` epochs.
    check_val_every_n_epoch=config_train['checkpoint_freq'],
    default_root_dir=config_model['save_dir'],
    profiler="advanced",
    callbacks=[early_stop_callback],
    inference_mode=inference_mode,
)


## train the model

In [None]:
# Train the model with the trainer configured above; the data module supplies
# the data.
trainer.fit(model, datamodule=dm)

## test the model

In [None]:
# Run the test loop on the model, using the same data module.
trainer.test(model, datamodule=dm)

## reload the model from a checkpoint

In [None]:
# reload the model from a checkpoint
#model = ComPert.load_from_checkpoint('train_data/CPA/3gm2eppz/checkpoints/epoch=14-step=10560.ckpt')

## drawing the evaluation results

In [None]:
from plotnine import *
from chemCPA.train import evaluate_logfold_r2, evaluate_r2, evaluate_r2_sc
import pandas as pd

In [None]:
#draw the logfold r2
def draw_logfold_r2(autoencoder, ds_treated, ds_ctrl):
    """Build a box plot of the log-fold and sign scores of a model.

    The scores come from ``evaluate_logfold_r2`` called with
    ``return_mean=False`` (presumably individual rather than averaged
    values — confirm against ``chemCPA.train``).

    Args:
        autoencoder: Model passed through to ``evaluate_logfold_r2``.
        ds_treated: Treated dataset passed through to the evaluation.
        ds_ctrl: Control dataset passed through to the evaluation.

    Returns:
        A plotnine ``ggplot`` object with one box per score type and the
        y axis limited to [-1, 1].
    """
    logfold, signs = evaluate_logfold_r2(
        autoencoder, ds_treated, ds_ctrl, return_mean=False
    )

    # One column per score type, then reshape to long form for plotting.
    wide = pd.DataFrame({'logfold_score': logfold, 'signs_score': signs})
    long = wide.melt(
        value_vars=['logfold_score', 'signs_score'],
        var_name='score_type',
        value_name='score',
    )

    mapping = aes(x='factor(score_type)', y='score', fill='factor(score_type)')
    # NOTE(review): scores outside the limits are likely dropped by the
    # scale's default out-of-bounds handling — confirm.
    y_axis = scale_y_continuous(limits=(-1, 1))
    return ggplot(long, mapping) + geom_boxplot() + y_axis

In [None]:
#draw the r2
def draw_r2(autoencoder, dataset, genes_control):
    """Build a box plot of the four R2 scores returned by ``evaluate_r2``.

    ``evaluate_r2`` is called with ``return_mean=False`` and its four
    results are plotted as ``mean_score``, ``mean_score_de``, ``var_score``
    and ``var_score_de``.

    Args:
        autoencoder: Model passed through to ``evaluate_r2``.
        dataset: Dataset passed through to ``evaluate_r2``.
        genes_control: Control expression data passed through to
            ``evaluate_r2``.

    Returns:
        A plotnine ``ggplot`` object with one box per score type and the
        y axis limited to [0, 1].
    """
    mean_all, mean_de, var_all, var_de = evaluate_r2(
        autoencoder, dataset, genes_control, return_mean=False
    )

    # One column per score type, then reshape to long form for plotting.
    wide = pd.DataFrame(
        {
            'mean_score': mean_all,
            'mean_score_de': mean_de,
            'var_score': var_all,
            'var_score_de': var_de,
        }
    )
    long = wide.melt(
        value_vars=['mean_score', 'mean_score_de', 'var_score', 'var_score_de'],
        var_name='score_type',
        value_name='score',
    )

    mapping = aes(x='factor(score_type)', y='score', fill='factor(score_type)')
    # NOTE(review): scores outside the limits (e.g. negative R2) are likely
    # dropped by the scale's default out-of-bounds handling — confirm.
    y_axis = scale_y_continuous(limits=(0, 1))
    return ggplot(long, mapping) + geom_boxplot() + y_axis

In [None]:
#draw the r2 sc
def draw_r2_sc(autoencoder, dataset):
    """Build a box plot of the four R2 scores returned by ``evaluate_r2_sc``.

    ``evaluate_r2_sc`` is called with ``return_mean=False`` and its four
    results are plotted as ``mean_score``, ``mean_score_de``, ``var_score``
    and ``var_score_de``.

    Args:
        autoencoder: Model passed through to ``evaluate_r2_sc``.
        dataset: Dataset passed through to ``evaluate_r2_sc``.

    Returns:
        A plotnine ``ggplot`` object with one box per score type and the
        y axis limited to [0, 1].
    """
    mean_all, mean_de, var_all, var_de = evaluate_r2_sc(
        autoencoder, dataset, return_mean=False
    )

    # One column per score type, then reshape to long form for plotting.
    wide = pd.DataFrame(
        {
            'mean_score': mean_all,
            'mean_score_de': mean_de,
            'var_score': var_all,
            'var_score_de': var_de,
        }
    )
    long = wide.melt(
        value_vars=['mean_score', 'mean_score_de', 'var_score', 'var_score_de'],
        var_name='score_type',
        value_name='score',
    )

    mapping = aes(x='factor(score_type)', y='score', fill='factor(score_type)')
    # NOTE(review): scores outside the limits (e.g. negative R2) are likely
    # dropped by the scale's default out-of-bounds handling — confirm.
    y_axis = scale_y_continuous(limits=(0, 1))
    return ggplot(long, mapping) + geom_boxplot() + y_axis