In [None]:
import os
os.chdir('../')
import logging
import numpy as np
import pytorch_lightning as pl
import pandas as pd
logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s'
    )
%load_ext autoreload
%autoreload 2

In [None]:
import scanpy as sc

# Load the meta-set. Set the METASET_PATH environment variable to use a different
# file without editing this cell, e.g.
#   METASET_PATH=data/subset_top_100_all_datasets.h5ad
metaset_path = os.environ.get(
    'METASET_PATH',
    '/home/xlv0877/proj_home/dl/data/test/1000/mixscale/mixscale_filtered_w_emb.h5ad',
)
adata = sc.read(metaset_path)

### Assign parameters and prepare data

In [None]:
# Data-loading / filtering parameters
n_cpus = 10         # number of CPUs intended for the data loaders
seed = 42           # random seed
min_cells = 50      # classes with fewer cells than this are dropped
N = 100             # number of most frequent perturbations kept for training
M = 10              # number of held-out perturbations: the next M after the top N
model_dir = 'models/mixscale'

# NOTE: this cell subsets `adata` in place and is therefore not idempotent.
# Re-load `adata` before re-running it; otherwise only the top-N perturbations
# remain and `unseen_adata` comes out empty.

# Labels to classify on; they are joined into one composite class label of the
# form '<celltype>;<perturbation_type>;<perturbation>'.
cls_labels = ['celltype', 'perturbation_type', 'perturbation']
adata.obs['cls_label'] = adata.obs[cls_labels].agg(';'.join, axis=1)
cls_label = 'cls_label'     # name of the .obs column to classify
batch_key = 'dataset'

# Status label: 'ctrl' for control cells, 'perturbed' for everything else
adata.obs['status'] = 'perturbed'
adata.obs.loc[adata.obs['perturbation'] == 'control', 'status'] = 'ctrl'

logging.info(f'Initializing dataset with {adata.obs.cls_label.nunique()} classes')

# Drop composite classes with fewer than `min_cells` cells
p_summary = adata.obs.cls_label.value_counts()
valid_perturbations = p_summary[p_summary >= min_cells].index
adata._inplace_subset_obs(adata.obs.cls_label.isin(valid_perturbations))

# Rank perturbations by cell count (descending, computed once). The top N are
# kept for training; the next M are held out in `unseen_adata` to predict on
# later. Slicing the Index (not the Series) keeps the slice unambiguously
# positional.
perturbation_col = adata.obs[cls_labels[-1]]
perturbation_counts = perturbation_col.value_counts()
idx = np.where(perturbation_col.isin(perturbation_counts.index[:N]))[0]
idx_unseen = np.where(perturbation_col.isin(perturbation_counts.index[N:N + M]))[0]
# The held-out cells are taken before the embedding filter below.
unseen_adata = adata[idx_unseen].copy()
adata._inplace_subset_obs(idx)

# Keep cells that have a non-zero gene embedding, plus all control cells.
# np.asarray(...).ravel() handles both scipy sparse matrices (whose row sums are
# an np.matrix) and dense arrays/DataFrames; `.A1` only exists on np.matrix.
embedding_row_sums = np.asarray(adata.obsm['gene_embedding'].sum(axis=1)).ravel()
cells_annotated_mask = embedding_row_sums != 0
ctrl_mask = adata.obs['perturbation'] == 'control'
embedding_mask = (cells_annotated_mask | ctrl_mask)
adata._inplace_subset_obs(embedding_mask)

logging.info(
    f'Kept {adata.n_obs} cells for training; held out {unseen_adata.n_obs} cells '
    f'from {unseen_adata.obs[cls_labels[-1]].nunique()} unseen perturbations'
)

### Prepare model set

In [6]:
# Work on a copy so the transformations below leave `adata` untouched.
model_set = adata.copy()
# Observation likelihood for the model
likelihood = 'normal'

# With a normal likelihood, log1p-transform and then scale the data so it is
# closer to normally distributed; other likelihoods use the data as-is.
if likelihood == 'normal':
    logging.info('Applying log1p and scaling data to force normal distribution')
    for preprocess in (sc.pp.log1p, sc.pp.scale):
        preprocess(model_set)

# Optional: cache the prepared set to disk.
# model_set.write_h5ad(os.path.join(model_dir, '.cache.h5ad'))

### Set hyperparameters

In [7]:
# Classes the classifier has to separate; the 'unknown' (unlabeled) category
# does not count as a class.
class_labels = model_set.obs[cls_label].unique()
n_labels = len(class_labels) - int('unknown' in class_labels)

# ---- loss weights ---------------------------------------------------------
recon_weight = 1                # reconstruction weight (-> 'recon_weight')
klr = 0.25                      # KL-divergence weight (-> plan 'kl_weight')
g_weight = 100
g_classification_weight = 1000
cr = 100                        # -> plan 'classification_ratio'
adjust_by_mean = False

# ---- model hyperparameters ------------------------------------------------
# Passed to the model as 'classifier_parameters'
cls_params = {
    'n_hidden': 128,
    'n_layers': 1,
    'dropout_rate': 0.1,
}
# 'log_variational' is enabled only for the 'zinb' likelihood
use_log_variational = likelihood == 'zinb'
gdvae_params = {
    'n_hidden': 128,
    'n_latent': 100,
    'n_latent_g': 100,
    'n_layers_encoder': 1,
    'n_layers_encoder_g': 1,
    'n_layers_decoder': 1,
    'dropout_rate_encoder': 0.2,
    'dropout_rate_encoder_g': 0.2,
    'dispersion': 'gene',
    'use_batch_norm': 'both',
    'gene_likelihood': likelihood,
    'linear_classifier': False,
    'classifier_parameters': cls_params,
    'use_posterior_mean': False,
    'log_variational': use_log_variational,
    'g_weight': g_weight,
    'recon_weight': recon_weight,
    'adjust_by_mean': adjust_by_mean,
    'g_classification_weight': g_classification_weight,
}

# TensorBoard logs go to <model_dir>/n_<number of classes>
tensor_dir = os.path.join(model_dir, f'n_{n_labels}')
tb_logger = pl.loggers.TensorBoardLogger(tensor_dir)

# ---- data / training-plan / trainer settings ------------------------------
data_params = {
    'train_size': 0.9,
    'batch_size': 128,
    'num_workers': 1,
}
plan_params = {
    'lr': 7e-4,
    'weight_decay': 1e-6,
    'kl_weight': klr,
    'classification_ratio': cr,
    'lr_scheduler_metric': 'elbo_validation',
    'plot_cm': True,
}
train_params = {
    'max_epochs': 80,
    'early_stopping': False,
    'check_val_every_n_epoch': 1,
    'logger': tb_logger,
    'plan_kwargs': plan_params,
}

# Covariates: none are used at the moment
# (e.g. cat_covs = ['celltype'] would add cell type as a categorical covariate).
cat_covs = None
cont_covs = None

In [9]:
import torch
from src.models._gdvi import GDVI

# Allow faster, lower-precision internal kernels for float32 matmuls where the
# hardware supports them (see torch.set_float32_matmul_precision).
torch.set_float32_matmul_precision('medium')

# Seed Python/NumPy/torch RNGs so model initialisation (and any RNG-based data
# splitting done by the model) is reproducible; `seed` is otherwise unused.
pl.seed_everything(seed)

# Register the label/batch fields of `model_set` with the model
GDVI.setup_anndata(
    model_set, batch_key=batch_key, labels_key=cls_label, unlabeled_category='unknown',
    cls_labels=cls_labels,
    categorical_covariate_keys=cat_covs,
)
# Build the model from the hyperparameter dict `gdvae_params`
gdvae = GDVI(model_set, **gdvae_params)

In [None]:
# Train the model built above (`gdvae`); TensorBoard logs are written to
# `tensor_dir` via the logger in `train_params`.
logging.info(f'Running at: {tensor_dir}')
gdvae.train(data_params=data_params, model_params=gdvae_params, train_params=train_params, return_runner=False)