# training and evaluating the CPA

In [None]:
import sys
sys.path.append('..')

import wandb
wandb.login()

from chemCPA.data import DataModule
from chemCPA.model import ComPert
import torch
import numpy as np
import time
import yaml
import lightning as L
from lightning.pytorch import seed_everything
from pytorch_lightning.loggers import WandbLogger
from lightning.pytorch.callbacks.early_stopping import EarlyStopping

## load configuration

In [3]:
# Read the experiment configuration; its 'dataset' and 'model' sections are
# consumed by the cells below.
# The file is decoded explicitly as UTF-8 (the YAML default) so that parsing
# does not depend on the platform's locale encoding.
with open('../chemCPA/config/config.yaml', encoding='utf-8') as config_file:
    config_data = yaml.safe_load(config_file)

#seed_everything()

## initialize the dataset module

In [4]:
# Build the data module from the configured batch size, the
# full-evaluation-during-training flag and the dataset parameters.
model_cfg = config_data['model']
dm = DataModule(
    model_cfg['hparams']['batch_size'],
    model_cfg['training_params']['full_eval_during_train'],
    **config_data['dataset']['data_params'],
)

In [None]:
# Print the length of every dataset split held by the data module,
# one value per line, in the order listed here.
for split_name in (
    'training',
    'training_control',
    'training_treated',
    'test',
    'test_control',
    'test_treated',
    'ood',
):
    print(len(dm.datasets[split_name]))

In [None]:
# Print the covariate, drug and knockout counts exposed by the training set.
for attribute_name in ('num_covariates', 'num_drugs', 'num_knockouts'):
    print(getattr(dm.datasets['training'], attribute_name))

In [None]:
# Pull a single batch from the training dataloader; as the last expression
# of the cell, the batch is displayed by the notebook.
train_loader = dm.train_dataloader()
next(iter(train_loader))

## initialize the model

In [8]:
# Instantiate the ComPert model.
# Input sizes and embedding dimensions are read from the training dataset;
# hyper-parameters, training and test settings come from the configuration.
# The knockout embeddings are randomly initialized.
training_set = dm.datasets['training']
model_cfg = config_data['model']

model = ComPert(
    training_set.num_genes,
    training_set.num_drugs,
    training_set.num_knockouts,
    training_set.num_covariates,
    model_cfg['hparams'],
    model_cfg['training_params'],
    model_cfg['test_params'],
    drug_embedding_dimension=training_set.drug_embedding_dimension,
    knockout_embedding_dimension=training_set.knockout_embedding_dimension,
    **model_cfg['additional_params'],
)
    

## Initialize the trainer

In [None]:
# Stop training when the monitored 'average_r2_score' stops increasing.
# NOTE: EarlyStopping's `patience` counts validation checks, not epochs, so
# the effective patience in epochs is patience * check_val_every_n_epoch.
early_stop_callback = EarlyStopping(
    'average_r2_score',
    patience=model.hparams.training_params['patience'],
    mode='max',
)

# Inference mode is switched off as soon as disentanglement evaluation is
# enabled for either training or testing; otherwise the Trainer default
# (inference_mode=True) is kept.
# NOTE(review): presumably the disentanglement evaluation needs autograd
# inside the validation/test loops -- confirm against the model code.
run_eval_disentangle = (
    model.hparams.training_params['run_eval_disentangle']
    or model.hparams.test_params['run_eval_disentangle']
)

training_cfg = config_data['model']['training_params']

# NOTE(review): WandbLogger is imported from `pytorch_lightning` while the
# Trainer and EarlyStopping come from `lightning`; consider importing the
# logger from `lightning.pytorch.loggers` so both come from one package.
trainer = L.Trainer(
    logger=WandbLogger(log_model="all"),
    max_epochs=training_cfg['num_epochs'],
    max_time=training_cfg['max_time'],
    # Validation runs every `checkpoint_freq` epochs.
    check_val_every_n_epoch=training_cfg['checkpoint_freq'],
    default_root_dir=config_data['model']['save_dir'],
    profiler="advanced",
    callbacks=[early_stop_callback],
    inference_mode=not run_eval_disentangle,
)

## train the model

In [None]:
# Fit `model` with the trainer configured above, using `dm` as the datamodule.
trainer.fit(model, datamodule=dm)

## test the model

In [None]:
# Run the trainer's test loop on `model`, using `dm` as the datamodule.
trainer.test(model, datamodule=dm)

## reload the model from a checkpoint

In [None]:
# reload the model from a checkpoint
#model = ComPert.load_from_checkpoint('train_data/CPA/3gm2eppz/checkpoints/epoch=14-step=10560.ckpt')