In [4]:
import copy
import os
import shutil

import numpy as np
import importlib
import pickle
import matplotlib.cm as cm
from sklearn.manifold import TSNE
from plotting import plot_tsne_selection_grid
import matplotlib.pyplot as plt
### (from: https://github.com/eiriniar/CellCnn/blob/0413a9f49fe0831c8fe3280957fb341f9e028d2d/cellCnn/examples/NK_cell_ungated.ipynb ) AND https://github.com/eiriniar/CellCnn/blob/0413a9f49fe0831c8fe3280957fb341f9e028d2d/cellCnn/examples/PBMC.ipynb
import pandas as pd
import glob
import seaborn as sns

from cellCnn.ms.utils.helpers import *
from cellCnn.ms.utils.helpers import calc_frequencies, handle_directories
from cellCnn.ms.utils.helpers import get_fitted_model
from cellCnn.plotting import *
from cellCnn.utils import mkdir_p
from cellCnn.utils import save_results, get_selected_cells
from cellCnn.plotting import plot_filters, discriminative_filters
from sklearn.metrics import accuracy_score, mean_squared_error, r2_score


def reload_modules():
    """Reload the cellCnn modules used by this notebook so on-disk edits take effect.

    Every submodule is imported explicitly before it is reloaded. Previously
    ``cellCnn.model`` was reloaded without being imported here, which raises
    ``AttributeError`` unless some other module happened to import it first.

    The reload order is unchanged: the helpers module is reloaded a second time
    at the end, after the ``cellCnn`` package itself.

    Note: names that the notebook bound earlier with ``from ... import ...``
    keep referring to the pre-reload objects until those import lines are
    executed again.

    Returns:
        None
    """
    import cellCnn.revised_uncertainty_loss
    import cellCnn.utils
    import cellCnn.plotting
    import cellCnn.model
    import cellCnn.loss_history
    import cellCnn.ms.utils.helpers

    reload_order = (
        cellCnn.revised_uncertainty_loss,
        cellCnn.ms.utils.helpers,
        cellCnn.utils,
        cellCnn.plotting,
        cellCnn.model,
        cellCnn.loss_history,
        cellCnn,
        cellCnn.ms.utils.helpers,  # second pass, after the package reload
    )
    for module in reload_order:
        importlib.reload(module)

In [5]:
# ---- Notebook configuration -------------------------------------------------
# Marker column names. The order matters: a later cell takes the first
# len(cytokines) columns of the cohort table as the measurement matrix.
cytokines = [
    'CCR2', 'CCR4', 'CCR6', 'CCR7', 'CXCR4', 'CXCR5',
    'CD103', 'CD14', 'CD20', 'CD25', 'CD27', 'CD28', 'CD3', 'CD4', 'CD45RA', 'CD45RO',
    'CD56', 'CD57', 'CD69', 'CD8', 'TCRgd', 'PD.1',
    'GM.CSF', 'IFN.g', 'IL.10', 'IL.13', 'IL.17A', 'IL.2', 'IL.21', 'IL.22', 'IL.3',
    'IL.4', 'IL.6', 'IL.9', 'TNF.a',
]

# Input / output locations.
infile = 'cohort_denoised_clustered_diagnosis_patients.csv'
indir = 'data/input'
outdir = 'test5_stl'

# Seed and split fractions.
rand_seed = 123
train_perc = 0.7
test_perc = 0.3

# Both batch sizes are marked deprecated by the original author.
batch_size_pheno = 840  # a size of 8425 is about equally sized in batches
batch_size_cd4 = 550  # a size of 550 gives 16 batches for cd4

# Cluster id -> cell-type label (information from the ms_data project).
cluster_to_celltype_dict = {
    0: 'b',
    1: 'cd4',
    3: 'nkt',
    8: 'cd8',
    10: 'nk',
    11: 'my',
    16: 'dg',
}

np.random.seed(rand_seed)
mkdir_p(outdir)

# Exact duplicate rows are dropped (reduces overfitting at the cost of fewer data).
df = pd.read_csv(os.path.join(indir, infile), index_col=0).drop_duplicates()
df.shape

(16889, 38)

In [6]:
# Build one (measurements, cell-type frequencies, disease label) triple per patient.
# Patients are identified by the 'gate_source' column.

def _frames_by_patient(diagnosis_df):
    """Map each gate_source value to that patient's rows, without the label columns."""
    return {
        patient_id: frame.drop(columns=['diagnosis', 'gate_source'])
        for patient_id, frame in diagnosis_df.groupby('gate_source')
    }


def _frequencies_by_patient(patients2df):
    """Run calc_frequencies (return_list=True) once per patient frame."""
    return {
        patient_id: calc_frequencies(frame, cluster_to_celltype_dict, return_list=True)
        for patient_id, frame in patients2df.items()
    }


def _labelled_pool(patients2df, patients_freq, label):
    """List (frame without the 'cluster' column, frequencies, label) for every patient."""
    return [
        (frame.loc[:, frame.columns != 'cluster'], patients_freq[patient_id], label)
        for patient_id, frame in patients2df.items()
    ]


rrms_df = df[df['diagnosis'] == 'RRMS']
rrms_patients2df = _frames_by_patient(rrms_df)
nindc_df = df[df['diagnosis'] == 'NINDC']
nindc_patients2df = _frames_by_patient(nindc_df)

# Per-patient frequencies; this is where differences between the two groups would show.
print('Frequencies: ')
rrms_patients_freq = _frequencies_by_patient(rrms_patients2df)
nindc_patients_freq = _frequencies_by_patient(nindc_patients2df)
print('DONE')

# Disease state labels: 1 = RRMS, 0 = NINDC.
selection_pool_rrms_cd8 = _labelled_pool(rrms_patients2df, rrms_patients_freq, 1)
selection_pool_nindc_cd8 = _labelled_pool(nindc_patients2df, nindc_patients_freq, 0)

# Balance the groups: truncate the longer pool to the length of the shorter one.
n_per_group = min(len(selection_pool_rrms_cd8), len(selection_pool_nindc_cd8))
selection_pool_rrms_cd8 = selection_pool_rrms_cd8[:n_per_group]
selection_pool_nindc_cd8 = selection_pool_nindc_cd8[:n_per_group]

all_chunks = selection_pool_rrms_cd8 + selection_pool_nindc_cd8
np.random.shuffle(all_chunks)  # mix the two phenotypes

X = [features.to_numpy() for features, _, _ in all_chunks]
freqs = [frequencies for _, frequencies, _ in all_chunks]
Y = [label for _, _, label in all_chunks]
print('DONE: batches created')

Frequencies: 
DONE
DONE: batches created


In [7]:
    # TODO: we need to save x_train (and the related splits) if we want to plot them here

In [9]:
# Reload the cohort and compute a 2-D t-SNE embedding of all cells.
# `infile` and `cytokines` come from the configuration cell.
indir = 'data/input'
df = pd.read_csv(os.path.join(indir, infile), index_col=0)
df = df.drop_duplicates()

# Integer cluster id per cell, re-indexed 0..n-1 so it lines up with `x` below.
cluster = df.loc[:, 'cluster'].astype(int).reset_index(drop=True)

# Cluster id -> colour name.
# Fix: 'honeydue' is not a valid colour name; the CSS/matplotlib name is 'honeydew'.
cluster_to_color_dict = {0: 'orchid', 1: 'darkcyan', 3: 'grey', 8: 'dodgerblue', 10: 'honeydew', 11: 'turquoise',
                         16: 'darkviolet'}

# Plain equality replacement. The previous regex=True was dropped: regex matching
# applies to string values only, so on this integer Series it is at best ignored.
# NOTE(review): on some pandas versions it may prevent the replacement -- confirm.
cluster_to_color_series = cluster.replace(cluster_to_color_dict)
diagnosis = df.loc[:, 'diagnosis'].astype(str).reset_index(drop=True)

# Measurement matrix: the first len(cytokines) columns of the cohort table.
samples = [df.iloc[:, :len(cytokines)]]
sample_names = ['cohort']
x = samples[0].reset_index(drop=True)

# NOTE(review): this 1000-row sample (drawn with replacement) is not passed to TSNE
# below; the embedding is fitted on the full matrix `x`. The line is kept because it
# advances the global NumPy random state -- confirm whether the subsample was
# intended to be used.
x_for_tsne = x.iloc[np.random.choice(x.shape[0], 1000), :]
x_tsne = TSNE(n_components=2).fit_transform(x)
x_tsne_df = pd.DataFrame(x_tsne)



In [131]:
# Collect every results.pkl below the chosen model directory.
# Alternative model directory (kept from the original, disabled):
#indir = '../../../v1_stl/stl_models_freqs/filter_10/stl_models/'
indir = '../../../v2_mtl/revised_uncertainty/v2/mtl_models_reg/'
overlap_dir_outer = '../../../comparisons/overlap_generated_v2/'

# indir ends with '/', so '**' matches any depth of sub-folders below it.
results_pattern = f'{indir}**/results.pkl'
files = glob.glob(results_pattern, recursive=True)
print(files)

#comparison_dir = '../../../comparisons'

['../../../v2_mtl/revised_uncertainty/v2/mtl_models_reg/mtl_cd4/results.pkl', '../../../v2_mtl/revised_uncertainty/v2/mtl_models_reg/mtl_dg/results.pkl', '../../../v2_mtl/revised_uncertainty/v2/mtl_models_reg/mtl_b/results.pkl', '../../../v2_mtl/revised_uncertainty/v2/mtl_models_reg/mtl_nkt/results.pkl', '../../../v2_mtl/revised_uncertainty/v2/mtl_models_reg/mtl_my/results.pkl', '../../../v2_mtl/revised_uncertainty/v2/mtl_models_reg/mtl_nk/results.pkl', '../../../v2_mtl/revised_uncertainty/v2/mtl_models_reg/mtl_cd8/results.pkl']


In [102]:
# For every results.pkl in `files`: determine which cells each learned filter
# selects, write per-filter CSVs, cell-type abundancies, t-SNE plots and a small
# statistics file.
#
# Inputs defined in other cells: files, overlap_dir_outer, indir, x, x_tsne_df,
# cluster, diagnosis, cluster_to_color_dict, cluster_to_celltype_dict.
reload_modules()

for file in files:
    print(f'File {file}')
    # NOTE(review): unpickling can execute arbitrary code stored in the file; only
    # load results.pkl files produced by your own training runs.
    # The context manager closes the handle (it used to be left open).
    with open(file, 'rb') as results_handle:
        results = pickle.load(results_handle)
    print('Creating needed directories')
    abundancy_dir, filters_dir, plotdir = handle_directories(file)
    # One overlap folder per model, named after the folder containing results.pkl.
    overlap_dir = os.path.join(overlap_dir_outer, file.split('/')[-2])
    mkdir_p(overlap_dir)
    # '<threshold>_<filter idx>' -> (n selected, n selected RRMS, n selected NINDC)
    stats_dict = dict()

    print('Filter Info Stuff')
    filter_diff_thres_pool = [0.0, 0.05]
    for filter_diff_thres in filter_diff_thres_pool:
        print('########### ###########')
        print(f'########### For Filter differency thresold: {filter_diff_thres}')
        print('########### ###########')
        if results['selected_filters'] is None:
            print(f' \n \n \n NO selected filter for {file} \n \n \n')
            continue
        # Pair every filter index with the current threshold.
        filter_info = [(idx, filter_diff_thres) for idx in np.arange(results['selected_filters'].shape[0])]
        _v = discriminative_filters(results, os.path.join(plotdir, 'filter_plots_discriminative'),
                                    filter_diff_thres=filter_diff_thres, show_filters=True)

        print('Start of getting cells per filter...')
        # Two columns per filter; judging by the labels collected in `columns` these
        # are the continuous and the binary response -- TODO confirm against
        # get_selected_cells.
        flags = np.zeros((x.shape[0], 2 * len(filter_info)))
        columns = []

        for i, (filter_idx, thres) in enumerate(filter_info):
            # The filter response threshold only plays a role when cells are not
            # selected via the continuous parameter.
            cells = get_selected_cells(results['selected_filters'][filter_idx], x.to_numpy(), results['scaler'], thres,
                                       True)
            flags[:, 2 * i:2 * (i + 1)] = cells
            columns += ['filter_%d_continuous' % filter_idx, 'filter_%d_binary' % filter_idx]
            flags_df = pd.DataFrame(flags[:, 2 * i:2 * (i + 1)])
            flags_df['cluster'] = cluster
            flags_df['diagnosis'] = diagnosis
            flags_df.to_csv(os.path.join(filters_dir, f'filter_{filter_idx}_selected_cells_thresh_{thres}.csv'),
                            index=False)

            # A cell counts as selected when the second response column is non-zero.
            # Computed once and reused (it used to be evaluated twice).
            selected_cells_filter = flags_df[abs(flags_df.loc[:, 1]) != 0]

            # Only filters selecting more than 100 cells go into the overlap folder.
            # NOTE(review): `indir` is whatever the most recently executed cell
            # assigned, so the file name depends on execution order.
            if len(selected_cells_filter) > 100:
                flags_df.to_csv(os.path.join(overlap_dir, f'F{filter_idx}_{indir.split("/")[-2]}_T:{thres}.csv'),
                                index=False)
            print(f'Saved selected cells with cluster & diagnosis for filter {filter_idx} ...')
            print(f'Filter  {filter_idx} saved {len(selected_cells_filter)} selected cells  ...')

            selected_rrms = selected_cells_filter[selected_cells_filter['diagnosis'] == 'RRMS']
            selected_nindc = selected_cells_filter[selected_cells_filter['diagnosis'] == 'NINDC']
            stats_dict[f'{thres}_{filter_idx}'] = (len(selected_cells_filter), len(selected_rrms),
                                                   len(selected_nindc))

            # Fraction of the selected cells that falls into each cluster.
            selected_cells_filter_grpd_tot = selected_cells_filter.groupby('cluster').count()[0]
            selected_cells_filter_grpd = selected_cells_filter_grpd_tot / selected_cells_filter.shape[0]
            selected_cells_filter_grpd.to_csv(
                os.path.join(abundancy_dir, f'filter_{filter_idx}_cells_{len(selected_cells_filter)}_thresh_{thres}.csv'),
                index=True)
            print(f'Saves the selected cells for filter {filter_idx}')

            print(f'Saved cell-type abundancies for filter {filter_idx} ...')
            plot_tSNE_for_selected_cells(x_tsne_df, selected_cells_filter, cluster, cluster_to_color_dict, filter_idx,
                                         thres, abundancy_dir, cluster_to_celltype_dict)

        print('Filter specific stuff DONE....  \n \n \n ')
        # Fix: pass the threshold of this outer iteration. The old code used `thres`
        # leaked from the inner loop, which is stale or undefined when a model has
        # zero filters.
        plot_abundancy_comparison_barplot(cluster_to_celltype_dict, file, filter_diff_thres, abundancy_dir)
        print('Plotted a comparison barchart for abundancy levels\n\n')
    # One line per threshold/filter pair: '<key>,<tuple repr>'.
    with open(os.path.join(filters_dir, 'selected_cell_statistics.csv'), 'w+') as f:
        for key, counts in stats_dict.items():
            f.write("%s,%s\n" % (key, counts))
    print('Plotted selected cell statistics')
print('all DONE')

File ../../../v2_mtl/revised_uncertainty/v3/mtl_models_reg/mtl_cd4/results.pkl
Creating needed directories
Filter Info Stuff
########### ###########
########### For Filter differency thresold: 0.0
########### ###########
Start of getting cells per filter...
Saved selected cells with cluster & diagnosis for filter 0 ...
Filter  0 saved 293 selected cells  ...
Saves the selected cells for filter 0
Saved cell-type abundancies for filter 0 ...
Saved selected cells with cluster & diagnosis for filter 1 ...
Filter  1 saved 0 selected cells  ...
Saves the selected cells for filter 1
Saved cell-type abundancies for filter 1 ...
Saved selected cells with cluster & diagnosis for filter 2 ...
Filter  2 saved 330 selected cells  ...
Saves the selected cells for filter 2
Saved cell-type abundancies for filter 2 ...
Filter specific stuff DONE....  
 
 
 
Plotted a comparison barchart for abundancy levels


########### ###########
########### For Filter differency thresold: 0.05
########### #########

In [19]:
# For every sub-folder of dir_mtl: read the abundancy CSVs located directly in that
# folder and draw them side by side in one grouped barplot.

dir_mtl = '../../../v2_mtl/'
folders = glob.glob(dir_mtl + '*/')
print(folders)
for folder in folders:
    print(folder)
    print('######################################')
    print('######################################')
    print('######################################')

    abundancy_files = glob.glob(folder + '/*.csv')
    if not abundancy_files:
        # Nothing to plot here; pd.concat would raise ValueError on an empty list.
        continue

    # One column per file, aligned on the shared index.
    # The melt below expects that index to be named 'cluster'.
    dfs = [pd.read_csv(filename, index_col=0, header=0) for filename in abundancy_files]
    df = pd.concat(dfs, axis=1, ignore_index=True)
    df.columns = [filename.split('/')[-1] for filename in abundancy_files]
    df = df.reset_index()
    dfm = df.melt('cluster', var_name='cols', value_name='vals')
    if dfm.shape[0] == 0:
        continue

    # The figure is created only after the empty checks, so a skipped folder no
    # longer leaves an open, never-closed figure behind.
    fig, ax = plt.subplots(1, 1, figsize=(10, 10))
    sns.barplot(x='cluster', y='vals', hue='cols', data=dfm, ax=ax)
    available_labels = dfm.iloc[:, 0].unique()
    # NOTE(review): assumes the bars appear in the same order as the keys of
    # cluster_to_celltype_dict and that every plotted cluster has an entry there.
    x_tick_labels = [v for k, v in cluster_to_celltype_dict.items() if k in available_labels]
    ax.set_xticklabels(x_tick_labels)
    ax.set_title('Abundancy comparison barplot')
    fig.savefig(f'{folder}/comparison_barplot_thres.png')
    plt.close(fig)
print('DONE plotting comparison barplot')

['../../../v2_mtl/mtl_models_reg/', '../../../v2_mtl/mtl_models_class_perSample/', '../../../v2_mtl/mtl_models_class_totloss/', '../../../v2_mtl/mtl_models_class/', '../../../v2_mtl/revised_uncertainty/', '../../../v2_mtl/mtl_models_class_perSample_totloss/', '../../../v2_mtl/uncertainty/']


KeyboardInterrupt: 

In [38]:
# Gather the threshold-0.0 comparison barplots of all models in one comparison
# folder, grouped by cell-type folder, so they can be viewed side by side.

import shutil

indir = '../../../v2_mtl/None/'
outdir_2 = '../../../comparisons/None'
files = glob.glob(f'{indir}**/comparison_barplot_thres_0.0.png', recursive=True)
print(files)
for file in files:
    path_parts = file.split('/')
    # Fifth path component from the end, e.g. 'mtl_cd4' in
    # .../mtl_cd4/selected_cells/filters/abundancies/<plot>.png
    celltype_folder = path_parts[-5]
    # Flatten the whole source path into the file name so copies cannot collide.
    file_savestr = '_'.join(path_parts)
    target_dir = f'{outdir_2}/{celltype_folder}'
    mkdir_p(target_dir)
    shutil.copy(file, f'{target_dir}/{file_savestr}')
print('done')

['../../../v2_mtl/None/mtl_models_class_perSample_accuracy/mtl_cd4/selected_cells/filters/abundancies/comparison_barplot_thres_0.0.png', '../../../v2_mtl/None/mtl_models_class_perSample_accuracy/mtl_dg/selected_cells/filters/abundancies/comparison_barplot_thres_0.0.png', '../../../v2_mtl/None/mtl_models_class_perSample_accuracy/mtl_b/selected_cells/filters/abundancies/comparison_barplot_thres_0.0.png', '../../../v2_mtl/None/mtl_models_class_perSample_accuracy/mtl_my/selected_cells/filters/abundancies/comparison_barplot_thres_0.0.png', '../../../v2_mtl/None/mtl_models_class_perSample_accuracy/mtl_cd8/selected_cells/filters/abundancies/comparison_barplot_thres_0.0.png', '../../../v2_mtl/None/mtl_models_class_acc/mtl_cd4/selected_cells/filters/abundancies/comparison_barplot_thres_0.0.png', '../../../v2_mtl/None/mtl_models_class_acc/mtl_dg/selected_cells/filters/abundancies/comparison_barplot_thres_0.0.png', '../../../v2_mtl/None/mtl_models_class_acc/mtl_b/selected_cells/filters/abundancie

In [None]:
#### TODOs:
# compare filters between disease states -> selected cells: done
### compare performance of models

In [129]:
# ---- Comparison of selected cells between filters ----
# List every CSV below the overlap folder of one cell type, in descending name order.
#indir = '../../../v1_stl/stl_models_freqs/filter_10/stl_models/'
overlapdir = '../../../comparisons/overlap_generated_v2/mtl_nkt/'
overlap_pattern = f'{overlapdir}**/*.csv'
files = sorted(glob.glob(overlap_pattern, recursive=True), reverse=True)
print(files)

['../../../comparisons/overlap_generated_v2/mtl_nkt/filter_1_STL.csv', '../../../comparisons/overlap_generated_v2/mtl_nkt/F8_mtl_models_class_T:0.05.csv', '../../../comparisons/overlap_generated_v2/mtl_nkt/F8_mtl_models_class_T:0.0.csv', '../../../comparisons/overlap_generated_v2/mtl_nkt/F7_mtl_models_class_perSample_T:0.05.csv', '../../../comparisons/overlap_generated_v2/mtl_nkt/F7_mtl_models_class_perSample_T:0.0.csv', '../../../comparisons/overlap_generated_v2/mtl_nkt/F7_mtl_models_class_T:0.05.csv', '../../../comparisons/overlap_generated_v2/mtl_nkt/F7_mtl_models_class_T:0.0.csv', '../../../comparisons/overlap_generated_v2/mtl_nkt/F6_mtl_models_class_perSample_T:0.05.csv', '../../../comparisons/overlap_generated_v2/mtl_nkt/F6_mtl_models_class_perSample_T:0.0.csv', '../../../comparisons/overlap_generated_v2/mtl_nkt/F5_mtl_models_class_perSample_T:0.05.csv', '../../../comparisons/overlap_generated_v2/mtl_nkt/F5_mtl_models_class_perSample_T:0.0.csv', '../../../comparisons/overlap_gene

In [130]:
# Pairwise overlap of the cells selected by each filter file in `files`.
# Entry (row, column) is the number of row positions selected by both files; the
# diagonal is therefore the number of cells a single file selects.
overlap_matrix_name = 'confusion_common_selected_cells.csv'

print(f'Compare selected cells among the files from {overlapdir}')

# Fix: this cell writes its result into `overlapdir`, the same folder `files` was
# globbed from. Skip that file so a re-run does not read its own output as a filter.
filter_files = [path for path in files if os.path.basename(path) != overlap_matrix_name]

idxs = dict()
for i, file in enumerate(filter_files):
    selection_df = pd.read_csv(file, header=0)
    # A row counts as selected when its second column is non-zero.
    selected_cells_filter = selection_df[abs(selection_df.iloc[:, 1]) != 0]
    idxs[i] = selected_cells_filter.index

# Build each set once instead of once per pair.
selected_sets = {key: set(index) for key, index in idxs.items()}

results = dict()
for key_outer, outer_set in selected_sets.items():
    label = filter_files[key_outer].split('/')[-1]
    results[label] = [len(outer_set & inner_set) for inner_set in selected_sets.values()]

# NOTE(review): labels are bare file names; two files with the same name in
# different sub-folders would collide here.
df = pd.DataFrame.from_dict(results)
df.index = list(results)

print('\n\nDF showing differences between selected cells')
print(f'For {overlapdir}')
print('Center diagonal is the amount of selected cells of a filter')
df.to_csv(os.path.join(overlapdir, overlap_matrix_name))
df

Compare selected cells among the files from ../../../comparisons/overlap_generated_v2/mtl_nkt/


DF showing differences between selected cells
For ../../../comparisons/overlap_generated_v2/mtl_nkt/
Center diagonal is the amount of selected cells of a filter


Unnamed: 0,filter_1_STL.csv,F8_mtl_models_class_T:0.05.csv,F8_mtl_models_class_T:0.0.csv,F7_mtl_models_class_perSample_T:0.05.csv,F7_mtl_models_class_perSample_T:0.0.csv,F7_mtl_models_class_T:0.05.csv,F7_mtl_models_class_T:0.0.csv,F6_mtl_models_class_perSample_T:0.05.csv,F6_mtl_models_class_perSample_T:0.0.csv,F5_mtl_models_class_perSample_T:0.05.csv,...,F2_mtl_models_class_T:0.05.csv,F2_mtl_models_class_T:0.0.csv,F1_mtl_models_class_perSample_T:0.05.csv,F1_mtl_models_class_perSample_T:0.0.csv,F1_mtl_models_class_T:0.05.csv,F1_mtl_models_class_T:0.0.csv,F0_mtl_models_class_perSample_T:0.05.csv,F0_mtl_models_class_perSample_T:0.0.csv,F0_mtl_models_class_T:0.05.csv,F0_mtl_models_class_T:0.0.csv
filter_1_STL.csv,6155,34,530,9,115,231,697,94,320,98,...,69,1528,20,62,229,1022,27,125,5,341
F8_mtl_models_class_T:0.05.csv,34,34,34,0,0,22,26,2,5,0,...,0,0,3,4,0,0,3,11,0,0
F8_mtl_models_class_T:0.0.csv,530,34,643,0,3,125,267,14,63,0,...,9,61,21,51,0,36,12,73,0,0
F7_mtl_models_class_perSample_T:0.05.csv,9,0,0,13,13,0,0,0,0,0,...,1,7,0,0,1,5,0,0,0,1
F7_mtl_models_class_perSample_T:0.0.csv,115,0,3,13,165,1,5,2,6,1,...,3,78,0,0,3,46,0,0,0,18
F7_mtl_models_class_T:0.05.csv,231,22,125,0,1,231,231,3,14,0,...,0,7,2,8,0,1,3,17,0,7
F7_mtl_models_class_T:0.0.csv,697,26,267,0,5,231,700,12,45,2,...,0,29,4,14,0,11,5,36,0,14
F6_mtl_models_class_perSample_T:0.05.csv,94,2,14,0,2,3,12,110,110,2,...,8,52,1,1,0,20,15,30,0,2
F6_mtl_models_class_perSample_T:0.0.csv,320,5,63,0,6,14,45,110,459,12,...,30,196,3,9,8,79,30,96,0,10
F5_mtl_models_class_perSample_T:0.05.csv,98,0,0,0,1,0,2,2,12,372,...,50,188,2,3,112,284,1,4,21,186
