In [1]:
import sys
# Make the parent directory importable so that the local `src_trainer`
# package (imported at the bottom of this cell) can be found.
sys.path.append("../")

import os
# Show the working directory; the relative paths used in this notebook
# ('../input/...') are resolved against it.
current_path = os.getcwd()
print(current_path)

import pickle
import anndata

# Project helpers: plotting/evaluation utilities, plus the model class and
# data-preparation helper from the training module.
from src_trainer.plotting import plot_train_val, plot_HVG, plot_single_value, plot_gene_embeddings, load_anndata, reconstruct_data, marker_expressions
from src_trainer.new_train import MyModelClass, prepare_data 

/Users/franci/MasterThesis/multiomics-perturbation/new_model/src_trainer


Global seed set to 0


# Evaluation of a trained model

This notebook is a guide to evaluating a trained model.
It provides the following evaluations:

- Plot the training history of the model (train and validation loss)
- Compare embeddings predicted by the model with embeddings of the original data for specific conditions (cell type, perturbation condition and population group)

## Load the model arguments

In [2]:
# Timestamp identifying the trained model under ../input/.
model_name = '20220316-101844'

# The training arguments are pickled next to the trained model.
# NOTE(review): pickle.load can execute arbitrary code; only load files you
# produced yourself.
args_path = f'../input/model_{model_name}/{model_name}_args.pickle'
with open(args_path, 'rb') as handle:
    args = pickle.load(handle)

## Load trained model

In [3]:
# When the AnnData object was not saved together with the model
# (`save_anndata` is False at training time), rebuild the train/val data with
# the same training arguments and hand it to `load`. Previously the rebuilt
# `adata` was computed but never passed on, so it was silently discarded.
# NOTE(review): assumes MyModelClass.load follows the scvi-tools
# `load(dir_path, adata=None, ...)` signature; passing adata=None keeps the
# original behaviour when the AnnData was saved. Confirm against MyModelClass.
adata = None
if not args['save_anndata']:
    splits, dataset = prepare_data(args)
    adata = dataset.data[dataset.indices["train/val"]].copy()

my_model = MyModelClass.load(f"../input/model_{model_name}", adata=adata)

print(my_model.summary_stats)
print(my_model.history)


AnnDataReadError: Above error raised while reading key '/layers' of type <class 'h5py._hl.group.Group'> from /.

## Loss 

Define the arguments for each plot as a dictionary with the following keys:

- var: name of the value in the model history to plot
- var_label: label used for the plot (y-axis)
- ylim: tuple giving the y-axis range; if None, all values are included
- modality: modality used for model training (RNA, protein, or CITEseq)

In [None]:
# One plot configuration per loss curve in the model history, given as
# (history key, y-axis label). Every curve shows the full y-range
# (ylim=None) and is tagged with the modality from the training args.
loss_curves = [
    ("elbo", "elbo"),
    ("reconstruction_loss", "rl"),
    ("kl_local", "kld"),
]
list_args = [
    {'var': var, 'var_label': var_label, 'ylim': None, 'modality': args['modality']}
    for var, var_label in loss_curves
]
args_elbo, args_rl, args_kld = list_args

# Plot train and validation curves together only when validation was enabled
# during training (check_val_every_n_epoch set); otherwise plot train only.
has_validation = args['check_val_every_n_epoch'] is not None
for plot_args in list_args:
    if has_validation:
        plot_train_val(my_model, model_name, plot_args)
    else:
        plot_single_value(my_model, model_name, plot_args, 'train')

# Embeddings

Plot embeddings of the original data and of data sampled from the trained model.
The embeddings are plotted for the following covariates:

Print the possible covariate values from the data.

- `cell_type`: cell type
- `pert_cond`: perturbation condition
- `pop_group`: population group


## Marker genes

Mikhael's expectations:

Lineage markers for CD4T cells:
- RNA: CD3D, CD4
- ADT: CD3, CD4

In CD4T cells, you should not find expression for the following:

- RNA: MS4A1, CD19, CD79A (B cell markers); S100A9, CD14 (monocyte markers)
- ADT: CD19, CD20 (B cell markers); CD14, CD16 (monocyte markers)

Regarding the stimulation response, I would expect to see upregulation of the following:

- RNA: IFNG, TNF, IL4, IL5, IL13
- ADT: CD69, CD62L

For differences across the groups in response to stimulation, we should see a decreasing trend in gene counts from RT to DK to LD for the RNAs listed above.
So these RNAs should be produced by all groups, but to varying degrees.

In [None]:
# Expected lineage markers per cell type (from Mikhael's expectations),
# by modality.
# RNA gene names:
marker_genes_dict = dict(
    CD4T=['CD3D', 'CD4'],          # CD4 T cell lineage markers
    B=['MS4A1', 'CD19', 'CD79A'],  # B cell markers
    Monocytes=['S100A9', 'CD14'],  # monocyte markers
)

# ADT (protein) names:
marker_proteins_dict = dict(
    CD4T=['CD3', 'CD4'],
    B=['CD19', 'CD20'],
    Monocytes=['CD14', 'CD16'],
)

# Perturbation condition and population group whose cells are evaluated below.
pert_cond = 'medium'
pop_group = 'RT'

### Plot protein markers

In [None]:
# Plot protein (ADT) markers per cell type. The expression matrix is swapped
# for the ADT values and markers are passed as `marker_proteins`, so the loop
# must iterate the protein marker names ('CD3', 'CD20', ...) in
# `marker_proteins_dict`. It previously iterated the RNA gene names
# ('CD3D', 'MS4A1', ...) in `marker_genes_dict`, leaving the protein
# dictionary unused.
for cell_type, markers_list in marker_proteins_dict.items():
    org_adata = load_anndata(model_name, pert_cond, cell_type, pop_group)
    # Replace the expression matrix with the raw ADT values stored in obsm.
    org_adata.X = org_adata.obsm['adt_raw']
    my_model, new_adata = reconstruct_data(model_name, cell_type, pert_cond, pop_group, n_samples=1000)
    print("Plot markers from Mikhaels list")
    for marker in markers_list:
        plot_gene_embeddings(org_adata, new_adata, my_model, model_name, cell_type, [marker])
        marker_expressions(org_adata, my_model, new_adata.X, [], marker_proteins=[marker])
    

In [None]:
### Highly-variable genes

In [None]:
# For every cell type in marker_genes_dict, compare the top highly-variable
# genes of the original data with those of data sampled from the trained model.
for cell_type in marker_genes_dict:
    org_adata = load_anndata(model_name, pert_cond, cell_type, pop_group)
    my_model, new_adata = reconstruct_data(
        model_name, cell_type, pert_cond, pop_group, n_samples=1000
    )
    plot_HVG(org_adata, new_adata, my_model, n_top_genes=10)
