# Test the data loading, model initialization, and manual training functions
Test the code on some randomly generated data.

In [1]:
import sys
sys.path.append('/Users/zkm/Desktop/chemCPA_fork')

In [None]:
import wandb
wandb.login()

from chemCPA.data import load_dataset_splits
from chemCPA.model import ComPert
import torch
import numpy as np

## initialize the datasets
Seven sub-datasets in total: training, training_control, training_treated, test, test_control, test_treated, and ood.

In [3]:
device = "mps"

# Dataset splits built from a randomly generated AnnData file.
# No knockout embeddings are passed in (knockouts_embeddings=None); per the
# original note they are then generated randomly.
datasets = load_dataset_splits(
    "Comb.h5ad",
    drug_key=None,
    dose_key=None,
    drugs_embeddings=None,
    knockout_key="treatment",
    knockouts_embeddings=None,
    covariate_keys=["cell_line"],
    smiles_key=None,
    pert_category="pert_category",
    split_key="split",
    degs_key="rank_genes_groups_cov",
)

 

In [4]:
# Print the number of samples in each split, one per line, in a fixed order.
for split_name in (
    "training",
    "training_control",
    "training_treated",
    "test",
    "test_control",
    "test_treated",
    "ood",
):
    print(len(datasets[split_name]))

55050
3724
51326
10485
725
9760
1260


## initialize the model

In [5]:
# Build the model from the dimensions reported by the training split.
# Per the original note, the knockout embeddings are initialized randomly.
training_split = datasets["training"]
model = ComPert(
    training_split.num_genes,
    training_split.num_drugs,
    training_split.num_knockouts,
    training_split.num_covariates,
    device=device,
    drug_embedding_dimension=None,
    knockout_embedding_dimension=256,
    knockout_effect_type="sigm",
)
    

## train functions and collate function

In [6]:
def train_epoch(model, loader, epoch):
    """Run one pass over ``loader`` and return the average of each loss.

    Args:
        model: Object whose ``update(...)`` method performs one training step
            and returns a mapping containing the keys ``loss_reconstruction``,
            ``loss_adv_drugs``, ``loss_adv_knockouts`` and
            ``loss_adv_covariates``.
        loader: Sized iterable of batches. Each batch is indexed positionally:
            items 0-5 are genes, drug indices, dosages, drug embeddings,
            knockout indices and knockout embeddings; everything from index 6
            onwards is passed on as the covariate indices.
        epoch: Epoch index; the global batch counter starts at
            ``epoch * len(loader)``.

    Returns:
        dict mapping each loss name to its sum over all batches divided by
        ``len(loader)``.
    """
    loss_names = (
        "loss_reconstruction",
        "loss_adv_drugs",
        "loss_adv_knockouts",
        "loss_adv_covariates",
    )
    batch_ct = epoch * len(loader)
    cumu_loss = dict.fromkeys(loss_names, 0)

    # Only touch the MPS allocator when the backend actually exists, so the
    # function also runs on machines without MPS (CPU / CUDA).
    mps_available = torch.backends.mps.is_available()

    for data in loader:
        training_stats = model.update(
            genes=data[0],
            drugs_idx=data[1],
            dosages=data[2],
            drugs_embeddings=data[3],
            knockouts_idx=data[4],
            knockouts_embeddings=data[5],
            covariates_idx=data[6:],
        )
        batch_ct += 1

        # Log all batch losses in one call: they share the same "batch_ct"
        # step, and every wandb.log call advances wandb's internal step.
        batch_log = {f"batch_{name}": training_stats[name] for name in loss_names}
        batch_log["batch_ct"] = batch_ct
        wandb.log(batch_log)

        cumu_loss = {name: cumu_loss[name] + training_stats[name] for name in loss_names}

        if mps_available:
            # Release cached MPS memory after each step.
            torch.mps.empty_cache()

    return {name: cumu_loss[name] / len(loader) for name in loss_names}

In [7]:
# a naive collate function for only one covariate
# todo: improve the collate function

def custom_collate(batch):
    """Collate a list of 7-field samples into one batch on ``device``.

    Each sample is unpacked as (genes, drugs_idx, dosages, drugs_emb,
    knockouts_idx, knockouts_emb, cov). Genes and covariates are stacked along
    dimension 0; the other fields are kept as per-sample lists. A field whose
    first sample value is ``None`` is returned as ``None`` (only the first
    sample is inspected).

    Returns:
        list: ``[genes, drugs_idx, dosages, drugs_emb, knockouts_idx,
        knockouts_emb, cov]``.
    """

    def _each_to_device(items):
        # A None in the first sample marks the whole field as absent.
        if items[0] is None:
            return None
        return [item.to(device) for item in items]

    genes, drugs_idx, dosages, drugs_emb, knockouts_idx, knockouts_emb, cov = zip(*batch)

    stacked_genes = torch.stack(genes, 0).to(device)
    per_sample_fields = [
        _each_to_device(field)
        for field in (drugs_idx, dosages, drugs_emb, knockouts_idx, knockouts_emb)
    ]

    if cov[0] is None:
        stacked_cov = None
    else:
        stacked_cov = torch.stack(cov, 0).to(device)

    return [stacked_genes, *per_sample_fields, stacked_cov]



def train(model, datasets, num_epochs=100):
    """Train ``model`` on ``datasets["training"]`` and log losses to wandb.

    A shuffled training ``DataLoader`` is created and stored in ``datasets``
    under the key ``"loader_tr"`` (the dict is modified in place). Batch-level
    metrics are plotted against ``batch_ct`` and epoch-level metrics against
    ``epoch``.

    Args:
        model: Model to train; ``model.hparams`` is used as the wandb run
            config and must provide ``"batch_size"``.
        datasets: Mapping that contains the ``"training"`` dataset.
        num_epochs: Number of passes over the training loader (default 100,
            the previously hard-coded value).

    Returns:
        The same ``model`` object that was passed in.
    """
    loss_names = (
        "loss_reconstruction",
        "loss_adv_drugs",
        "loss_adv_knockouts",
        "loss_adv_covariates",
    )

    with wandb.init(project="cpa", config=model.hparams):
        datasets.update(
            {
                "loader_tr": torch.utils.data.DataLoader(
                    datasets["training"],
                    batch_size=model.hparams["batch_size"],
                    collate_fn=custom_collate,
                    shuffle=True,
                )
            }
        )

        # Tell wandb which custom x-axis each metric family uses.
        for name in loss_names:
            wandb.define_metric(f"batch_{name}", step_metric="batch_ct")
        for name in loss_names:
            wandb.define_metric(f"epoch_{name}", step_metric="epoch")

        for epoch in range(num_epochs):
            avg_stats = train_epoch(model, datasets["loader_tr"], epoch)

            # One log call per epoch: all four averages share the same
            # "epoch" step value.
            epoch_log = {f"epoch_{name}": avg_stats[name] for name in loss_names}
            epoch_log["epoch"] = epoch
            wandb.log(epoch_log)

    return model

## manual run of the train function

In [None]:
# Run the training loop defined above on the model and dataset splits.
train(model, datasets)

## test the evaluation codes

In [9]:
from train import evaluate_logfold_r2, evaluate_disentanglement, evaluate_r2, evaluate_r2_sc, evaluate

In [10]:
#model.eval()
#evaluate_logfold_r2(model, datasets['test_treated'], datasets['test_control'])
#evaluate_disentanglement(model, datasets['test'])
#evaluate_r2(model, datasets["test_treated"], datasets['test_control'].genes)
#evaluate_r2_sc(model, datasets['test'])

0 combinations had '-inf' R2 scores:
	 set()


[0.6732103911763064, -13.59968694131729, -88.9191639934421, -784.672868597486]

In [None]:
# Call evaluate() with only the disentanglement flag switched on; the R2,
# single-cell R2 and log-fold flags are all off.
# NOTE(review): the meaning of the empty dict passed as the third positional
# argument is not visible here -- confirm against train.evaluate.
evaluate(
    model,
    datasets,
    {},
    run_disentangle=True,
    run_r2=False,
    run_r2_sc=False,
    run_logfold=False,
)