In [None]:
# Guard: both paths are presumably injected as notebook parameters (e.g. by
# papermill) before this cell runs — confirm. Stop early if either is missing.
for param_name in ('adataPATH_ref', 'adataPATH_query'):
    is_defined = param_name in locals() or param_name in globals()
    if not is_defined:
        raise Exception(f"{param_name} not specified")
    # eval() only ever receives one of the two fixed names above, never external input.
    print(f"{param_name} = {eval(param_name)}")

In [None]:
import os
import re
import sys
from glob import glob

import scanpy as sc
import pandas as pd

import pynndescent
import numpy as np
import numba

from sklearn.metrics import balanced_accuracy_score as bas

from pyprojroot import here

sys.path.insert(1, str(here('bin')))
# Import custom functions
from customPythonFunctions import aggregating_features

### Loading data

In [None]:
def _glob_single(pattern):
    """Resolve a project-relative glob pattern to its single matching path.

    Parameters
    ----------
    pattern : str
        Glob pattern relative to the project root (resolved with ``here``).

    Returns
    -------
    list of str
        The list of matches, always of length 1, so downstream cells can keep
        indexing ``[0]``.

    Raises
    ------
    AssertionError
        If zero or more than one file matches the pattern.
    """
    matches = glob(str(here(pattern)))
    assert len(matches) == 1, (
        f"Expected exactly one file matching {pattern!r}, found {len(matches)}: {matches}"
    )
    return matches


adataPATH_ref_list = _glob_single(adataPATH_ref)
adataPATH_query_list = _glob_single(adataPATH_query)

In [None]:
# Load the reference (R) and query (Q) AnnData objects from the resolved paths,
# reference first, and display both as the cell output.
adataR, adataQ = map(sc.read_h5ad, (adataPATH_ref_list[0], adataPATH_query_list[0]))
adataR, adataQ

In [None]:
# Map each Level2 label to its Level1 label as observed in the reference
# annotation. Deduplicate first so the dict is built from the distinct
# (Level1, Level2) pairs rather than from one row per cell.
level_pairs = adataR.obs[['Level1', 'Level2']].drop_duplicates()
# A Level2 label appearing under more than one Level1 makes the mapping
# ambiguous (to_dict would silently keep the last one), so fail loudly instead.
assert level_pairs['Level2'].is_unique, (
    "Some Level2 labels map to more than one Level1 label:\n"
    f"{level_pairs.loc[level_pairs['Level2'].duplicated(keep=False)].sort_values('Level2')}"
)
Level1_dict = level_pairs.set_index('Level2')['Level1'].to_dict()

In [None]:
# Expose the query's `labels` column under an explicit Level2-prediction name.
# NOTE(review): `labels` presumably holds the scANVI-predicted Level2 label per
# query cell — confirm against the upstream step that produced this file.
predicted_level2 = adataQ.obs['labels']
adataQ.obs['Level2_scANVI_pred'] = predicted_level2

In [None]:
# Derive the Level1 prediction from the predicted Level2 label via the
# reference Level2 -> Level1 mapping.
adataQ.obs['Level1_scANVI_pred'] = adataQ.obs['Level2_scANVI_pred'].map(Level1_dict)
# Labels absent from the reference mapping become NaN after .map(); left
# unchecked they would likely break the balanced-accuracy computation and the
# pseudobulk grouping below, so report them here.
unmapped_level2 = set(
    adataQ.obs.loc[adataQ.obs['Level1_scANVI_pred'].isna(), 'Level2_scANVI_pred']
)
assert not unmapped_level2, (
    "Predicted Level2 labels missing from the reference Level1 mapping: "
    f"{sorted(unmapped_level2, key=str)}"
)

In [None]:
# Balanced accuracy of the scANVI-predicted vs. annotated Level2 labels on the query.
level2_score = bas(adataQ.obs['Level2'], adataQ.obs['Level2_scANVI_pred'])
print(f"Label transfer with scANVI predict obtain a BAS ={level2_score}, considering Level2")


In [None]:
# Balanced accuracy of the derived vs. annotated Level1 labels on the query.
level1_score = bas(adataQ.obs['Level1'], adataQ.obs['Level1_scANVI_pred'])
print(f"Label transfer with scANVI predict obtain a BAS ={level1_score}, considering Level1")

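### Generating pseudobulks

Cells are aggregated per (`sampleID`, `Level1`) group with `mode='mean'`. The reference uses its annotated `Level1`, and the query uses the `Level1` label derived from the scANVI predictions.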

In [None]:
# Reference pseudobulks: aggregate adataR.X with mode='mean', grouped by
# sample and annotated Level1 label; 'disease' is carried along in obsDF.
ref_pb_keys = ['sampleID', 'Level1']
adataPB_R = aggregating_features(
    Z=adataR.X,
    obsDF=adataR.obs[ref_pb_keys + ['disease']],
    mode='mean',
    obs_names_col=['sampleID', 'Level1'],
    min_observation=0,
)
adataPB_R

**Query: grouping by the scANVI-predicted `Level1` labels**

In [None]:
# Query pseudobulks: same aggregation settings as the reference, but grouped by
# the Level1 label derived from the scANVI predictions.
adataPB_Q_scANVI = aggregating_features(Z = adataQ.X, 
                             obsDF = adataQ.obs[['sampleID','Level1_scANVI_pred','disease']], 
                             mode = 'mean', 
                             obs_names_col=['sampleID','Level1_scANVI_pred'], 
                             min_observation=0)

# Rename via explicit reassignment (no inplace=True) so the query pseudobulks
# expose the same 'Level1' column name as the reference pseudobulks.
adataPB_Q_scANVI.obs = adataPB_Q_scANVI.obs.rename(columns={'Level1_scANVI_pred': 'Level1'})
adataPB_Q_scANVI

### Saving pseudobulk AnnData objects

In [None]:
# Derive the reference pseudobulk output path from the input path by replacing
# the '/reference/scANVI_SPLIT_<n>_' segment with
# '/PSEUDOBULKs/scANVI_PSEUDOBULK_reference_'.
pb_ref_path, n_ref_subs = re.subn(r'/reference/scANVI_SPLIT_\d+_',
                                  '/PSEUDOBULKs/scANVI_PSEUDOBULK_reference_',
                                  adataPATH_ref_list[0])
# If the pattern does not match, the path is returned unchanged and the write
# below would overwrite the input reference h5ad — refuse instead.
assert n_ref_subs == 1, (
    f"Cannot derive pseudobulk output path from {adataPATH_ref_list[0]!r}"
)
pb_ref_path = here(pb_ref_path)
os.makedirs(os.path.dirname(pb_ref_path), exist_ok=True)
adataPB_R.write(pb_ref_path, compression='gzip')

In [None]:
# Derive the query pseudobulk output path from the input path by replacing
# the '/query/scANVI_SPLIT_<n>_' segment with
# '/PSEUDOBULKs/scANVI_PSEUDOBULK_query_'.
pb_query_path, n_query_subs = re.subn(r'/query/scANVI_SPLIT_\d+_',
                                      '/PSEUDOBULKs/scANVI_PSEUDOBULK_query_',
                                      adataPATH_query_list[0])
# If the pattern does not match, the path is returned unchanged and the write
# below would overwrite the input query h5ad — refuse instead.
assert n_query_subs == 1, (
    f"Cannot derive pseudobulk output path from {adataPATH_query_list[0]!r}"
)
pb_query_path = here(pb_query_path)
os.makedirs(os.path.dirname(pb_query_path), exist_ok=True)
adataPB_Q_scANVI.write(pb_query_path, compression='gzip')