In [None]:
# Run-level parameters (defaults).
# NOTE(review): N_LAYERS is used by later cells (type check, wandb run name, output file
# names) but has no default in this cell, so it must be supplied from outside — confirm.
# Size of the model's latent space (passed to the model as `latent_dim`).
LATENT_DIM: int = 20
# Free-text tag added to the wandb run name/tags and to the output file names.
RUN_TAG: str = "test"
# Learning rate passed to training as `lr`.
LEARNING_RATE: float = 1e-3
# Second entry of `embedding_dims` (listed alongside the 'sampleID' condition key).
PATIENT_EMBEDDING_DIM: int = 20
# Base directory that the notebook's `here()` helper prepends to relative paths.
BUCKET_DIRPATH: str = ""

In [None]:
# Sanity-check the run parameters: both values must be integers.
assert all(isinstance(param, int) for param in (N_LAYERS, LATENT_DIM))

In [None]:
import os
import sys

import json

import numpy as np
import scanpy as sc
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import classification_report
from scarches.models.scpoli import scPoli

import torch

from dotenv import load_dotenv

from lightning.pytorch.loggers import WandbLogger
import wandb

import session_info
import warnings
from pyprojroot.here import here

# Set PyTorch's float32 matmul precision setting to 'high'.
torch.set_float32_matmul_precision('high')

#plt.style.use(['science','nature','no-latex'])
# DPI used when scanpy saves figures; figures shown inline use dpi=100.
dpi_fig_save = 300
sc.set_figure_params(dpi=100, dpi_save=dpi_fig_save, vector_friendly=True)

# Silence all Python warnings for the rest of the notebook.
warnings.filterwarnings("ignore")

from sklearn.model_selection import StratifiedKFold

# Output-overwrite switches: `overwriteData` gates saving the trained model further down;
# `overwriteFigures` is not referenced in the visible cells.
overwriteData = True
overwriteFigures = True

# Seed value recorded together with the run parameters.
# NOTE(review): in the visible cells this value is only stored in `RunParams`; nothing
# seeds an RNG with it — confirm whether it should also be passed to torch / the trainer.
random_seed = 5

import warnings
warnings.filterwarnings('ignore')
%load_ext autoreload
%autoreload 2

In [None]:
def here(fpath):
    """Return ``fpath`` joined onto the ``BUCKET_DIRPATH`` base directory.

    Defining this function rebinds the name ``here`` in the notebook
    namespace, replacing the ``here`` imported from ``pyprojroot`` in the
    imports cell.
    """
    base_dir = BUCKET_DIRPATH
    return os.path.join(base_dir, fpath)

## Custom scPoli

In [None]:
from scarches.trainers.scpoli.trainer import scPoliTrainer
from scarches.models.scpoli.scpoli_model import scPoli

from scarches.dataset.scpoli.anndata import MultiConditionAnnotatedDataset

class CustomScPoliTrainer(scPoliTrainer):
    """scPoli trainer that reports its metrics to an external run logger.

    On top of the parent trainer it:

    * logs the per-iteration training loss, the per-epoch metrics, the
      coefficient applied to the KL term and the cumulative learning-rate
      factor through ``w_logger.log``, each together with a ``global_step``
      counter;
    * records the reconstruction, KL and MMD terms in ``iter_logs``;
    * stores an independent copy of the best weights in ``check_early_stop``.

    Parameters
    ----------
    *args, **kwargs
        Forwarded unchanged to ``scPoliTrainer.__init__``.
    w_logger
        Logger object exposing ``define_metric`` and ``log`` (this notebook
        passes the run returned by ``wandb.init``). Required, keyword-only.
    prefix
        String prepended to the names of the training-loss, epoch and
        per-epoch metrics (for example ``"train/"``).
    """

    def __init__(self, *args, w_logger, prefix="", **kwargs):
        super().__init__(*args, **kwargs)
        self.w_logger = w_logger
        # Starts at -1 so that the first completed iteration is reported as step 0.
        self.steps = -1
        self.prefix = prefix
        # Cumulative product of all learning-rate reductions applied so far.
        self.lr_factor = 1
        # Use the explicit `global_step` value as the step axis for every metric.
        self.w_logger.define_metric("global_step")
        self.w_logger.define_metric("*", step_metric="global_step", step_sync=True)

    def on_iteration(self, batch_data):
        """Run the parent's iteration, then log the current training loss."""
        super().on_iteration(batch_data)
        self.steps += 1
        key = self.prefix + "train_loss"
        self.w_logger.log({key: self.current_loss, 'global_step': self.steps})

    def on_epoch_end(self):
        """Run the parent's epoch-end hook, then log the latest value of every metric."""
        super().on_epoch_end()

        self.w_logger.log({self.prefix + 'epoch': self.epoch, 'global_step': self.steps})

        # Metrics whose name starts with "val_" keep their name; every other
        # metric gets an additional "epoch_" in front of it.
        # NOTE(review): if the parent already stores training metrics under
        # names beginning with "epoch_", this yields a doubled prefix — confirm.
        for metric in self.logs:
            if metric.startswith('val_'):
                self.w_logger.log({self.prefix + metric: self.logs[metric][-1], 'global_step': self.steps})
            else:
                self.w_logger.log({self.prefix + "epoch_" + metric: self.logs[metric][-1], 'global_step': self.steps})

    def loss(self, total_batch=None):
        """Compute the total loss for one batch and record its components.

        The total is ``recon + alpha * kl + mmd + eta * prototype``, where the
        prototype part is only computed once ``self.epoch`` has reached
        ``self.pretraining_epochs``.

        Parameters
        ----------
        total_batch
            Mapping unpacked into the model call; its ``"labeled"`` and
            ``"celltypes"`` entries are also read here.

        Returns
        -------
        The total loss tensor.
        """
        latent, recon_loss, kl_loss, mmd_loss = self.model(**total_batch)
        self.iter_logs["recon_loss"].append(recon_loss.item())
        self.iter_logs["kl_loss"].append(kl_loss.item())
        self.iter_logs["mmd_loss"].append(mmd_loss.item())

        # Calculate classifier loss for labeled/unlabeled data
        label_categories = total_batch["labeled"].unique().tolist()
        unweighted_prototype_loss = torch.tensor(0.0, device=self.device)
        unlabeled_loss = torch.tensor(0.0, device=self.device)
        labeled_loss = torch.tensor(0.0, device=self.device)
        if self.epoch >= self.pretraining_epochs:
            # Calculate prototype loss for all data
            if self.prototypes_unlabeled is not None:
                unlabeled_loss, _ = self.prototype_unlabeled_loss(
                    latent,
                    torch.stack(self.prototypes_unlabeled).squeeze(),
                )
                unweighted_prototype_loss = (
                    unweighted_prototype_loss + self.unlabeled_weight * unlabeled_loss
                )

            # Calculate prototype loss for labeled data
            if (self.any_labeled_data is True) and (self.prototype_training is True):
                # Row indices of the labeled cells; computed once and reused
                # for both the latent rows and their cell-type labels.
                labeled_rows = torch.where(total_batch["labeled"] == 1)[0]
                labeled_loss = self.prototype_labeled_loss(
                    latent[labeled_rows, :],
                    self.prototypes_labeled,
                    total_batch["celltypes"][labeled_rows, :],
                )
                unweighted_prototype_loss = unweighted_prototype_loss + labeled_loss

        # Loss addition and Logs
        prototype_loss = self.eta * unweighted_prototype_loss
        cvae_loss = recon_loss + self.calc_alpha_coeff() * kl_loss + mmd_loss
        loss = cvae_loss + prototype_loss
        self.iter_logs["loss"].append(loss.item())
        self.iter_logs["unweighted_loss"].append(
            recon_loss.item()
            + kl_loss.item()
            + mmd_loss.item()
            + unweighted_prototype_loss.item()
        )
        self.iter_logs["cvae_loss"].append(cvae_loss.item())
        if self.epoch >= self.pretraining_epochs:
            self.iter_logs["prototype_loss"].append(prototype_loss.item())
            if 0 in label_categories or self.model.unknown_ct_names is not None:
                self.iter_logs["unlabeled_loss"].append(unlabeled_loss.item())
            if 1 in label_categories:
                self.iter_logs["labeled_loss"].append(labeled_loss.item())
        return loss

    def calc_alpha_coeff(self):
        """Return the parent's KL coefficient and log it as ``alpha_coeff``.

        The metric name is fixed; ``self.prefix`` is not applied to it.
        """
        alpha_coeff = super().calc_alpha_coeff()
        self.w_logger.log({'alpha_coeff': alpha_coeff, 'global_step': self.steps})
        return alpha_coeff

    def check_early_stop(self):
        """Update the early-stopping state and report whether to keep training.

        When the monitored metric improves, a snapshot of the model weights
        and the current epoch are stored. When the early-stopping object asks
        for a learning-rate update, every optimizer parameter group is scaled
        by ``early_stopping.lr_factor``. The cumulative factor is logged as
        ``lr_factor`` (fixed name, no prefix) on every call.

        Returns
        -------
        The ``continue_training`` flag returned by ``early_stopping.step``.
        """
        # Local import keeps the class self-contained within its notebook cell.
        import copy

        # Calculate Early Stopping and best state
        metric_name = self.early_stopping.early_stopping_metric
        latest_value = self.logs[metric_name][-1]
        if self.early_stopping.update_state(latest_value):
            # state_dict() returns references to the live parameter tensors,
            # which later optimizer steps modify in place. Deep-copy them so
            # the stored "best" weights really are the weights of this epoch.
            self.best_state_dict = copy.deepcopy(self.model.state_dict())
            self.best_epoch = self.epoch

        continue_training, update_lr = self.early_stopping.step(latest_value)
        if update_lr:
            print('\nADJUSTED LR')
            self.lr_factor *= self.early_stopping.lr_factor
            for param_group in self.optimizer.param_groups:
                param_group["lr"] *= self.early_stopping.lr_factor

        self.w_logger.log({'lr_factor': self.lr_factor, 'global_step': self.steps})

        return continue_training


class CustomScPoli(scPoli):
    """scPoli model whose ``train`` method uses ``CustomScPoliTrainer``."""

    def train(
        self,
        n_epochs: int = 100,
        pretraining_epochs=None,
        eta: float = 1,
        lr: float = 1e-3,
        eps: float = 0.01,
        alpha_epoch_anneal = 1e2,
        reload_best: bool = False,
        prototype_training = True,
        unlabeled_prototype_training = True,
        w_logger = None,
        prefix = "",
        **kwargs,
    ):
        """Train the model with a ``CustomScPoliTrainer``.

        Parameters
        ----------
        n_epochs
             Number of epochs for training the model.
        pretraining_epochs
             Number of epochs before the prototype loss is used. When ``None``
             it defaults to ``floor(0.9 * n_epochs)``; when the model has no
             cell type keys it is forced to ``n_epochs``.
        eta
             Weight of the prototype loss in the total loss.
        lr
             Learning rate for training the model.
        eps
             torch.optim.Adam eps parameter
        alpha_epoch_anneal
             Forwarded to the trainer.
        reload_best
             Forwarded to the trainer.
        prototype_training
             Whether to train with prototypes; forced to ``False`` when the
             model has no cell type keys.
        unlabeled_prototype_training
             Forwarded to the trainer.
        w_logger
             Logger object exposing ``define_metric`` and ``log`` (for example
             the run returned by ``wandb.init``). Required despite the
             ``None`` default.
        prefix
             Prefix for the metric names logged by the trainer.
        kwargs
             kwargs for the scPoli trainer.

        Raises
        ------
        ValueError
             If ``w_logger`` is ``None``.
        """
        # Fail before touching any state: the trainer calls
        # w_logger.define_metric() in its constructor, so a missing logger
        # would otherwise surface as an opaque AttributeError on None.
        if w_logger is None:
            raise ValueError(
                "w_logger is required: CustomScPoliTrainer calls "
                "w_logger.define_metric() and w_logger.log()."
            )

        self.prototype_training_ = prototype_training
        self.unlabeled_prototype_training_ = unlabeled_prototype_training
        if self.cell_type_keys_ is None:
            # Without cell type keys every epoch counts as pretraining.
            pretraining_epochs = n_epochs
            self.prototype_training_ = False
            print("The model is being trained without using prototypes.")
        elif pretraining_epochs is None:
            pretraining_epochs = int(np.floor(n_epochs * 0.9))

        self.trainer = CustomScPoliTrainer(
            self.model,
            self.adata,
            labeled_indices=self.labeled_indices_,
            pretraining_epochs=pretraining_epochs,
            condition_keys=self.condition_keys_,
            cell_type_keys=self.cell_type_keys_,
            reload_best=reload_best,
            prototype_training=self.prototype_training_,
            unlabeled_prototype_training=self.unlabeled_prototype_training_,
            eta=eta,
            alpha_epoch_anneal=alpha_epoch_anneal,
            w_logger=w_logger,
            prefix=prefix,
            **kwargs,
        )
        print("ScPoliTrainer Initialized")
        self.trainer.train(n_epochs, lr, eps)
        self.is_trained_ = True
        # Expose the prototypes held by the underlying module on the wrapper.
        self.prototypes_labeled_ = self.model.prototypes_labeled
        self.prototypes_unlabeled_ = self.model.prototypes_unlabeled

In [None]:
# Load environment variables via python-dotenv; the assert stops the notebook
# when load_dotenv() returns a falsy value.
assert load_dotenv()

# Loading data


In [None]:
# Load the h5ad file
# Load the h5ad file.
# NOTE(review): this path is relative to the working directory and is not
# passed through `here()` — confirm that is intended.
adata_path = '04_MAIN_geneUniverse.h5ad'
adata = sc.read_h5ad(adata_path)

In [None]:
# Show the loaded AnnData object as the cell output.
adata

### Setting up parameters for scPoli training

In [None]:
# Collects every hyper-parameter of the run; filled in by a later cell.
scPoli_params = {}

In [None]:
# Early-stopping configuration handed to the trainer.
early_stopping_kwargs = dict(
    early_stopping_metric="val_prototype_loss",
    mode="min",
    threshold=0.1,
    patience=5,
    reduce_lr=False,
    lr_patience=5,
    lr_factor=0.1,
)

In [None]:
# Constructor arguments for the scPoli model (everything except `adata`).
model_parameters_dict = dict(
    condition_keys=['chemistry', 'sampleID'],
    cell_type_keys=['Level1', 'Level2'],
    embedding_dims=[3, PATIENT_EMBEDDING_DIM],
    latent_dim=LATENT_DIM,
    recon_loss='nb',
)

In [None]:
# Keyword arguments unpacked into the model's `train` call.
train_parameters_dict = dict(
    lr=LEARNING_RATE,
    n_epochs=1000,
    alpha_epoch_anneal=200,
    pretraining_epochs=40,
    eta=1,
    unlabeled_prototype_training=False,
    n_workers=8,
)

In [None]:
# Fold the three parameter groups into the run-wide dictionary
# (later groups win on duplicate keys), then display it.
scPoli_params.update(
    {
        **early_stopping_kwargs,
        **model_parameters_dict,
        **train_parameters_dict,
    }
)
scPoli_params

#### Connect to wandb

In [None]:
# Start a wandb run. The merged hyper-parameter dictionary is stored as the run
# config; the run name encodes N_LAYERS, LATENT_DIM, PATIENT_EMBEDDING_DIM and
# RUN_TAG, and RUN_TAG is also attached as a tag.
logger = wandb.init(
    project='inflammation_atlas_PatientClassifier_scPoly', 
    entity='inflammation',
    config=scPoli_params,
    name = f'scPoli_{N_LAYERS}_{LATENT_DIM}_{PATIENT_EMBEDDING_DIM}_{RUN_TAG}',
    tags = [RUN_TAG]
)

In [None]:
# Keep the wandb run id; it is embedded in the output file names further down.
WANDB_ID = logger.id

In [None]:
# Full record of the run: seed value, wandb run id and every hyper-parameter.
RunParams = {
    "random_seed": random_seed,
    "wandb_run_id": logger.id,
    **scPoli_params,
}
RunParams

# scPoli integration

Reference for the parameter choices: https://github.com/theislab/scPoli_reproduce/blob/main/pbmc8M/lataq_czi_8Mpbmc.py

In [None]:
# Build the model: `model_parameters_dict` holds exactly the constructor
# arguments (condition_keys, cell_type_keys, embedding_dims, latent_dim,
# recon_loss), so it is unpacked directly next to the data.
scpoli_model = CustomScPoli(adata=adata, **model_parameters_dict)

In [None]:
# Train the model. The training parameters are unpacked as keyword arguments;
# the wandb run is the metric logger and "train/" prefixes the logged names.
# NOTE(review): `random_seed` is not passed here — confirm whether the trainer
# should receive it.
scpoli_model.train(
    **train_parameters_dict,
    early_stopping_kwargs=early_stopping_kwargs,
    w_logger=logger,
    prefix="train/"
)

In [None]:
# Close the wandb run now that training is done.
wandb.finish()

## Save the results

In [None]:
# Save the trained model (without the AnnData) when overwriting is enabled.
# NOTE(review): this path is relative to the working directory and, unlike the
# final .h5ad export, is not passed through `here()` — confirm that is intended.
if overwriteData:
    scpoli_model.save(
        f"results/scPoly_model_{N_LAYERS}_{LATENT_DIM}_{PATIENT_EMBEDDING_DIM}_{RUN_TAG}_{WANDB_ID}", 
        overwrite = True, 
        save_anndata = False)

In [None]:
# Latent representation of the dataset the model was built with.
# NOTE(review): mean=True presumably returns the posterior mean rather than a
# sample — confirm against the scPoli documentation.
latents = scpoli_model.get_latent(
    adata,
    mean=True
)

In [None]:
# Wrap the latent matrix in a new AnnData that reuses the original cell annotations.
final_ad = sc.AnnData(
    X=latents, 
    obs=adata.obs
)

In [None]:
# Store the model's conditional embeddings in `.uns`
# (assumes get_conditional_embeddings() returns a mapping — TODO confirm).
final_ad.uns.update(scpoli_model.get_conditional_embeddings())

In [None]:
# Write the latent AnnData, gzip-compressed, under BUCKET_DIRPATH (via `here()`);
# the file name encodes the run parameters and the wandb run id.
final_ad.write(
    here(f"03_Downstream_Analysis/PatientClassifier/scPoli/results/01_reference/output/scPoli_{N_LAYERS}_{LATENT_DIM}_{PATIENT_EMBEDDING_DIM}_{RUN_TAG}_{WANDB_ID}.h5ad"), 
    compression='gzip')