# Test the data loading, model initialization, and manual training functions
Test the code on randomly generated data.

In [2]:
import wandb
wandb.login()

from chemCPA.data import load_dataset_splits
from chemCPA.model import ComPert
import torch


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mkemingzhang[0m. Use [1m`wandb login --relogin`[0m to force relogin


## initialize the datasets
7 sub-datasets in total: `training`, `training_control`, `training_treated`, `test`, `test_control`, `test_treated`, and `ood`.

In [3]:
# Use Apple's MPS backend when it is usable and fall back to the CPU otherwise,
# so the notebook can also be run on machines without MPS support.
device = "mps" if torch.backends.mps.is_available() else "cpu"

# a collection of datasets generated from some random adata
# the knockout embeddings will be generated randomly
# (no drug information is supplied: every drug-related argument is None)
datasets = load_dataset_splits(
    "fake.h5ad",
    drug_key=None,
    dose_key=None,
    drugs_embeddings=None,
    knockout_key="guide_merged",
    knockouts_embeddings=None,
    covariate_keys=["cell_type"],
    smiles_key=None,
    pert_category="pert_category",
    split_key="split",
)

 

## initialize the model

In [4]:
# initialize the model
# the knockout embeddings are initialized randomly
# NOTE(review): the four positional arguments are the gene / drug / knockout /
# covariate counts exposed by the training split; how ComPert interprets them is
# inferred from the attribute names only -- confirm against ComPert's signature.
model = ComPert(
    datasets["training"].num_genes,
    datasets["training"].num_drugs,
    datasets["training"].num_knockouts,
    datasets["training"].num_covariates,
    device=device,
    # presumably None because the datasets were loaded with drug_key=None -- TODO confirm
    drug_embedding_dimension=None,
    knockout_embedding_dimension=256,
    knockout_effect_type="sigm"
)
    

## train functions and collate function

In [5]:
def train_epoch(model, loader, epoch):
    """Run one pass over ``loader`` and return the mean of each loss.

    Args:
        model: object exposing ``update(...)``; the call must return a mapping
            that contains the keys ``loss_reconstruction``, ``loss_adv_drugs``,
            ``loss_adv_knockouts`` and ``loss_adv_covariates``.
        loader: sized iterable of batches. Each batch is a sequence laid out as
            ``[genes, drugs_idx, dosages, drugs_embeddings, knockouts_idx,
            knockouts_embeddings, *covariates_idx]``.
        epoch: zero-based epoch index; only used to offset the global batch
            counter that is logged as ``batch_ct``.

    Returns:
        dict mapping each loss name to its average over the batches of this
        epoch. For an empty ``loader`` every average is reported as 0 instead
        of raising ``ZeroDivisionError``.
    """
    loss_names = (
        "loss_reconstruction",
        "loss_adv_drugs",
        "loss_adv_knockouts",
        "loss_adv_covariates",
    )
    num_batches = len(loader)
    batch_ct = epoch * num_batches
    cumu_loss = {name: 0 for name in loss_names}

    # torch.mps.empty_cache() fails on installations without the MPS backend,
    # so the availability is checked once, outside the loop.
    mps_available = torch.backends.mps.is_available()

    for data in loader:
        (
            genes,
            drugs_idx,
            dosages,
            drugs_embeddings,
            knockouts_idx,
            knockouts_embeddings,
        ) = data[:6]
        # everything after the sixth entry is treated as covariate indices
        covariates_idx = data[6:]

        training_stats = model.update(
            genes=genes,
            drugs_idx=drugs_idx,
            dosages=dosages,
            drugs_embeddings=drugs_embeddings,
            knockouts_idx=knockouts_idx,
            knockouts_embeddings=knockouts_embeddings,
            covariates_idx=covariates_idx,
        )

        batch_ct += 1
        # Log all batch metrics together with their custom step in one call.
        batch_log = {f"batch_{name}": training_stats[name] for name in loss_names}
        batch_log["batch_ct"] = batch_ct
        wandb.log(batch_log)

        cumu_loss = {name: cumu_loss[name] + training_stats[name] for name in loss_names}

        if mps_available:
            torch.mps.empty_cache()

    # max(..., 1) keeps the division defined when the loader yielded nothing
    divisor = max(num_batches, 1)
    return {name: cumu_loss[name] / divisor for name in loss_names}

In [6]:
# a naive collate function for only one covariate
# todo: improve the collate function

def custom_collate(batch):
    """Collate 7-field samples into one batch placed on the global ``device``.

    Each sample must unpack into (genes, drugs_idx, dosages, drugs_emb,
    knockouts_idx, knockouts_emb, cov). Genes and covariates are stacked along
    a new leading dimension; the remaining fields are kept as per-sample lists.
    A field whose value is None in the first sample is returned as None.

    Returns:
        list ``[genes, drugs_idx, dosages, drugs_emb, knockouts_idx,
        knockouts_emb, cov]``.
    """

    def _move_each(column):
        # The first sample decides whether the whole field is absent.
        if column[0] is None:
            return None
        return [entry.to(device) for entry in column]

    (
        gene_col,
        drug_idx_col,
        dose_col,
        drug_emb_col,
        ko_idx_col,
        ko_emb_col,
        cov_col,
    ) = zip(*batch)

    collated = [torch.stack(gene_col, 0).to(device)]
    for column in (drug_idx_col, dose_col, drug_emb_col, ko_idx_col, ko_emb_col):
        collated.append(_move_each(column))

    if cov_col[0] is None:
        collated.append(None)
    else:
        collated.append(torch.stack(cov_col, 0).to(device))
    return collated



def train(model, datasets, num_epochs=1):
    """Train ``model`` on ``datasets["training"]`` inside a wandb run.

    Args:
        model: object providing ``hparams`` (used as the wandb config and as
            the source of ``batch_size``) and whatever ``train_epoch`` needs.
        datasets: dict holding at least the ``"training"`` dataset. The
            training DataLoader is stored back into it under ``"loader_tr"``.
        num_epochs: number of passes over the training loader. Defaults to 1,
            which is the previously hard-coded value.

    Returns:
        The same ``model`` object after training.
    """
    loss_names = (
        "loss_reconstruction",
        "loss_adv_drugs",
        "loss_adv_knockouts",
        "loss_adv_covariates",
    )

    with wandb.init(project="cpa", config=model.hparams):

        datasets.update(
            {
                "loader_tr": torch.utils.data.DataLoader(
                    datasets["training"],
                    batch_size=model.hparams["batch_size"],
                    collate_fn=custom_collate,
                    shuffle=True,
                )
            }
        )

        # Batch-level metrics are plotted against "batch_ct", epoch-level
        # metrics against "epoch".
        for name in loss_names:
            wandb.define_metric(f"batch_{name}", step_metric="batch_ct")
        for name in loss_names:
            wandb.define_metric(f"epoch_{name}", step_metric="epoch")

        for epoch in range(num_epochs):
            avg_stats = train_epoch(model, datasets["loader_tr"], epoch)
            # Log all epoch averages together with their custom step in one call.
            epoch_log = {f"epoch_{name}": avg_stats[name] for name in loss_names}
            epoch_log["epoch"] = epoch
            wandb.log(epoch_log)

    return model

## manual run of the train function

In [7]:
# Run the training once with the model and datasets built above; train() returns
# the model, so its repr is displayed as this cell's output.
train(model, datasets)



VBox(children=(Label(value='0.030 MB of 0.030 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
batch_ct,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
batch_loss_adv_covariates,▇▇▅▆▁█▃▆▇▆▄▄▆▄▅█▇▅▆▅▆▅▆▂▄▅▃▅▃▂▆▃▅▂▃▁▁▅▁█
batch_loss_adv_drugs,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
batch_loss_adv_knockouts,▄▄▅▆▄▆▇▅▃▅▅▁▃▂▃▅▄▆▄▄▂▇▆▄█▆▇▇▄▅▆▃▃▅▄█▄▅▄▂
batch_loss_reconstruction,█▇▆▅▅▄▄▃▃▃▃▂▂▂▂▂▂▃▂▂▂▂▂▂▂▁▂▁▁▁▁▁▁▂▁▂▂▁▂▂
epoch,▁▁▁▁
epoch_loss_adv_covariates,▁
epoch_loss_adv_drugs,▁
epoch_loss_adv_knockouts,▁
epoch_loss_reconstruction,▁

0,1
batch_ct,118.0
batch_loss_adv_covariates,2.12337
batch_loss_adv_drugs,0.0
batch_loss_adv_knockouts,6.5844
batch_loss_reconstruction,-1.31836
epoch,0.0
epoch_loss_adv_covariates,2.08669
epoch_loss_adv_drugs,0.0
epoch_loss_adv_knockouts,6.80522
epoch_loss_reconstruction,-1.12191


ComPert(
  (encoder): MLP(
    (network): Sequential(
      (0): Linear(in_features=2000, out_features=512, bias=True)
      (1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): Linear(in_features=512, out_features=512, bias=True)
      (4): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU()
      (6): Linear(in_features=512, out_features=512, bias=True)
      (7): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (8): ReLU()
      (9): Linear(in_features=512, out_features=512, bias=True)
      (10): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (11): ReLU()
      (12): Linear(in_features=512, out_features=256, bias=True)
    )
  )
  (decoder): MLP(
    (network): Sequential(
      (0): Linear(in_features=256, out_features=512, bias=True)
      (1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=Tru

## Test the evaluation code

In [8]:
## currently, the fake data do not support DEGs, so the R2 scores based on DEGs cannot be tested
## todo: get a complete AnnData set and test the evaluation code



genes = datasets['test'].genes
knockouts_idx = datasets['test'].knockouts_idx
# look up one embedding per entry of knockouts_idx
knockouts_embeddings = [datasets['test'].knockouts_embeddings(idx) for idx in knockouts_idx]
covariate_idx = datasets['test'].covariates_idx

# Predict in eval mode without tracking gradients; the try/finally makes sure the
# model is switched back to training mode even if predict() raises.
model.eval()
try:
    with torch.no_grad():
        pred = model.predict(
            genes,
            drugs_idx=None,
            dosages=None,
            drugs_embeddings=None,
            knockouts_idx=knockouts_idx,
            knockouts_embeddings=knockouts_embeddings,
            covariates_idx=covariate_idx,
        )
finally:
    model.train()

In [None]:
from train import compute_r2

# Store the averages under new names instead of overwriting `pred` and `genes`,
# so this cell can be re-run without first re-running the prediction cell.
# 2000 equals the encoder's in_features in the model repr printed above;
# presumably the first 2000 columns of `pred` are the predicted means -- TODO confirm.
pred_mean = pred[:, 0:2000].mean(dim=0)
genes_mean = genes.mean(dim=0)
compute_r2(genes_mean.to(device), pred_mean)