In [22]:
import os
import pickle
import scanpy as sc
import numpy as np
import process_data
import scipy.stats as stats

import warnings
warnings.filterwarnings("ignore")

import importlib
importlib.reload(process_data)

<module 'process_data' from '/home/masse/work/PsychADxD/7_trajectory/process_data.py'>

## Coverting PyTorch Lightning model outputs to an h5ad
This required for all downstream analysis.  
process_data will add donor-related information (e.g. average predicted Braak, average gene expression) to the h5ad.  
This worksheet shows an example of how to perform this conversion.  

In [5]:
# Helper function to measure model accuracy

def explained_var(x_pred, x_real):
    
    idx = ~np.isnan(x_real) * (x_real > -1)
    x = x_real[idx]
    y = x_pred[idx]
    ex_var = 1 - np.nanvar(x - y) / np.nanvar(x)
    r, p = stats.pearsonr(x, y)
    return ex_var, r, p

def classification_score(x_pred, x_real):
    
    idx = ~np.isnan(x_real) * (x_real > -1)
    s0 = np.sum((x_real[idx] == 0) * (x_pred[idx] < 0.5)) / np.sum(x_real[idx] == 0)
    s1 = np.sum((x_real[idx] > 0.99) * (x_pred[idx] >= 0.5)) / np.sum(x_real[idx] > 0.99)
    return (s0 + s1) / 2

In [None]:
# When save data is True, model results will be saved as an h5ad after running create_data below
# Can be set to False to first measure model accuracy across different epochs
save_data = False

In [None]:
# Data and metadata are created by create_dataset.py
# Must fill in these file names!!!
data_fn = "XXXX"
meta_fn = "XXXX"

In [4]:
mr = process_data.ModelResults(
    data_fn = data_fn,
    meta_fn = meta_fn,
    obs_list = ["pred_BRAAK_AD", "pred_Dementia",],  
    include_analysis_only = True, # only include donors with no comorbidities, used for paper
    normalize_gene_counts = True,
    log_gene_counts = True,
)

In [24]:
# where PyTorch Lightning model outputs are saved
base_path = "lightning_logs"

# min_cell_count only used when measuring Braak and Dementia accuracy, does not affect saved model
min_cell_count = 5 

# the 20 splits be saved in different lightning_logs/version_XX directories
# version_nums indicate which versions to look at
version_nums = np.arange(20, 40)

# can average the results over ultiple epochs if desired
n_epochs_to_avg = 1

model_save_fn = "test.h5ad"

for epoch in range(21 - n_epochs_to_avg):

    fns = []
    for n in range(n_epochs_to_avg):
        fns_temp = []
        for v in version_nums:
            fns_temp.append(os.path.join(base_path, f"version_{v}/test_results_ep{epoch + n}.pkl"))
        fns.append(fns_temp)

    adata = mr.create_data(fns, model_average=True)
    # when measuring BRAAK accuracy, select donors with at least min_cell_count cells, and BRAAK value must exist (>= 0)
    idx = (adata.uns["donor_BRAAK_AD"] > -1) * (adata.uns["donor_cell_count"] >= min_cell_count)
    braak_acc, _ = stats.pearsonr(adata.uns["donor_pred_BRAAK_AD"][idx], adata.uns["donor_BRAAK_AD"][idx])
    
     # when measuring Dementia accuracy, select donors with at least min_cell_count cells, and Dementia value must exist (>= 0)
    idx = (adata.uns["donor_Dementia"] > -1) * (adata.uns["donor_cell_count"] >= min_cell_count)
    dementia_acc = classification_score(adata.uns["donor_pred_Dementia"], adata.uns["donor_Dementia"])

    bd_corr, _ = stats.pearsonr(adata.uns["donor_pred_Dementia"], adata.uns["donor_pred_BRAAK_AD"])

    print(f"Epoch: {epoch}, dementia acc: {dementia_acc:1.4f}, braak corr: {braak_acc:1.4f}, braak-dementia corr: {bd_corr:1.4f}")

if save_data:
    adata.write(model_save_fn)
    print(f"{model_save_fn} saved")
else:
    print("Model not saved!")

Subclasses present: ['PVM' 'Micro' 'Adaptive']
Epoch: 0, dementia acc: 0.6205, braak corr: 0.4601, braak-dementia corr: 0.7603
Subclasses present: ['PVM' 'Micro' 'Adaptive']
Epoch: 1, dementia acc: 0.6405, braak corr: 0.5089, braak-dementia corr: 0.8459
Subclasses present: ['PVM' 'Micro' 'Adaptive']
Epoch: 2, dementia acc: 0.6607, braak corr: 0.5135, braak-dementia corr: 0.6809
Subclasses present: ['PVM' 'Micro' 'Adaptive']
Epoch: 3, dementia acc: 0.6451, braak corr: 0.5199, braak-dementia corr: 0.5999
Subclasses present: ['PVM' 'Micro' 'Adaptive']
Epoch: 4, dementia acc: 0.6341, braak corr: 0.5200, braak-dementia corr: 0.5693
Subclasses present: ['PVM' 'Micro' 'Adaptive']
Epoch: 5, dementia acc: 0.6275, braak corr: 0.5250, braak-dementia corr: 0.5629
Subclasses present: ['PVM' 'Micro' 'Adaptive']
Epoch: 6, dementia acc: 0.6475, braak corr: 0.5244, braak-dementia corr: 0.5410
Subclasses present: ['PVM' 'Micro' 'Adaptive']
Epoch: 7, dementia acc: 0.6545, braak corr: 0.5210, braak-dement