In [None]:
# Standard library
import importlib
import os
import random

# Third-party
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.metrics import accuracy_score, mean_squared_error, roc_auc_score
from sklearn.model_selection import KFold, StratifiedKFold
from sklearn.utils import shuffle

# Project package; reloaded right after import so local edits are picked up without a kernel restart.
### (from: https://github.com/eiriniar/CellCnn/blob/0413a9f49fe0831c8fe3280957fb341f9e028d2d/cellCnn/examples/NK_cell_ungated.ipynb ) AND https://github.com/eiriniar/CellCnn/blob/0413a9f49fe0831c8fe3280957fb341f9e028d2d/cellCnn/examples/PBMC.ipynb
import cellCnn
importlib.reload(cellCnn)
from cellCnn.ms.utils.helpers import (
    calc_frequencies,
    get_chunks,
    get_chunks_from_df,
    get_fitted_model,
    get_min_mean_from_clusters,
    print_regression_model_stats,
    split_test_valid_train,
)
from cellCnn.plotting import plot_results
from cellCnn.utils import mkdir_p, save_results

In [None]:
#%reset

In [None]:
##### state vars
# Marker names.
cytokines = [
    'CCR2', 'CCR4', 'CCR6', 'CCR7', 'CXCR4', 'CXCR5', 'CD103',
    'CD14', 'CD20', 'CD25', 'CD27', 'CD28', 'CD3', 'CD4',
    'CD45RA', 'CD45RO', 'CD56', 'CD57', 'CD69', 'CD8', 'TCRgd',
    'PD.1', 'GM.CSF', 'IFN.g', 'IL.10', 'IL.13', 'IL.17A', 'IL.2',
    'IL.21', 'IL.22', 'IL.3', 'IL.4', 'IL.6', 'IL.9', 'TNF.a',
]

# Input CSV (file name + directory, joined at load time) and the output directory.
infile = 'cohort_denoised_clustered_diagnosis_patients.csv'
indir = 'data/input.nosync'
outdir = 'out_ms_default'

# Seed shared by NumPy's global RNG and the data-splitting helpers.
rand_seed = 123

# Fractions handed to the train/test split helper.
train_perc = 0.7
test_perc = 0.3

# Both batch sizes are marked deprecated by the original author.
batch_size_pheno = 840  # original note: a size of 8425 is about equally sized in batches
batch_size_cd4 = 550  # original note: a size of 550 gets 16 batches for cd4

## information from ms_data project
# Cluster id -> short cell-type label.
cluster_to_celltype_dict = {
    0: 'b',
    1: 'cd4',
    3: 'nkt',
    8: 'cd8',
    10: 'nk',
    11: 'my',
    16: 'dg',
}

In [None]:

# Seed NumPy's global RNG before anything random happens, then call mkdir_p on the
# output directory (presumably creates it if missing — confirm in cellCnn.utils).
np.random.seed(rand_seed)
mkdir_p(outdir)

# Read the cohort table (first CSV column becomes the index) and drop exact duplicate rows.
### reduces overfitting at cost of fewer data
csv_path = os.path.join(indir, infile)
df = pd.read_csv(csv_path, index_col=0).drop_duplicates()
df.shape
##### no duplicates in

In [None]:
# Count rows per (gate_source, cluster) pair: one table row per gate_source value,
# one column per cluster label.
patients_clusters_conf_table = pd.crosstab(index=df['gate_source'], columns=df['cluster'])

# Heatmap of those counts; the colour scale is capped at 100.
plt.figure()
sns.heatmap(data=patients_clusters_conf_table, annot=False, vmax=100)
plt.show()
#plt.savefig('images/patient_vs_cluster_conf_table.png')
patients_clusters_conf_table

In [None]:
# Print the column-wise mean of the gate_source x cluster count table.
# Columns are addressed by position, not by cluster id.
# NOTE(review): the labels assume the table has these seven columns in exactly this order — confirm.
print('Mean abundancies')
column_labels = ['b', 'CD4', 'NKT', 'CD8', 'NK', 'My', 'dg']
for position, label in enumerate(column_labels):
    column_mean = patients_clusters_conf_table.iloc[:, position].mean()
    print(label + ': ' + str(column_mean))


In [None]:
def _frames_by_gate_source(cohort_df):
    """Map each gate_source value in cohort_df to its rows, minus the 'diagnosis' and 'gate_source' columns."""
    frames = {}
    for source_id, source_rows in cohort_df.groupby('gate_source'):
        frames[source_id] = source_rows.drop(columns=['diagnosis', 'gate_source'])
    return frames


# One sub-frame per diagnosis group, then one dict per group keyed by gate_source.
rrms_df = df[df['diagnosis'] == 'RRMS']
rrms_patients2df = _frames_by_gate_source(rrms_df)
nindc_df = df[df['diagnosis'] == 'NINDC']
nindc_patients2df = _frames_by_gate_source(nindc_df)

In [None]:
# Show per-cluster row counts for the RRMS group.
# The prints come first so that the count table is the LAST expression of the cell:
# a notebook only renders the value of a cell's final expression, so a table followed
# by a print is silently discarded.
print('RRMS cell-type abundances')
# NOTE(review): this figure is hard-coded, not computed from the data — confirm it is still current.
print('Mean abundancy / patient is 273,032258065')
rrms_df.groupby('cluster').count()

In [None]:
# Per-cluster row counts for the NINDC group, shown as the cell's result.
print('NINDC cell-type abundances')
nindc_by_cluster = nindc_df.groupby('cluster')
nindc_by_cluster.count()

In [None]:
# Reload the helper module and re-import everything it exports, so edits made to the
# helpers while the kernel is running take effect from here on.
# NOTE(review): the star import hides which names later cells get from the helpers.
importlib.reload(cellCnn.ms.utils.helpers)
from cellCnn.ms.utils.helpers import *


def _frequencies_as_list(frame):
    """Run calc_frequencies on one per-patient frame, requesting the list form of the result."""
    return calc_frequencies(frame, cluster_to_celltype_dict, return_list=True)


#### here we could see freq differences across the 2 groups
print('Frequencies: ')
rrms_patients_freq = {key: _frequencies_as_list(frame) for key, frame in rrms_patients2df.items()}
nindc_patients_freq = {key: _frequencies_as_list(frame) for key, frame in nindc_patients2df.items()}
print('DONE')
# we got 31 patients each

In [None]:
#### To get true frequencies we need to get rid of 0 entries (there are patients without some cell types due to data)
print('RRMS')
# One row per patient, one column per cell-type label (in cluster_to_celltype_dict order).
rrms_freq_df = pd.DataFrame(list(rrms_patients_freq.values()), columns=list(cluster_to_celltype_dict.values()))
# Zeros become NaN so that describe() leaves them out of its statistics.
# np.nan is used because the np.NaN alias was removed in NumPy 2.0.
rrms_freq_df = rrms_freq_df.replace(0, np.nan)
rrms_freq_df.describe()

In [None]:
print('NINDC')
# One row per patient, one column per cell-type label (in cluster_to_celltype_dict order).
nindc_freq_df = pd.DataFrame(list(nindc_patients_freq.values()), columns=list(cluster_to_celltype_dict.values()))
# Zeros become NaN so that describe() leaves them out of its statistics.
# np.nan is used because the np.NaN alias was removed in NumPy 2.0.
nindc_freq_df = nindc_freq_df.replace(0, np.nan)
nindc_freq_df.describe()


In [None]:
# Build, for every batch size, a shuffled and class-balanced pool of chunks for cluster 3.
#batch_sizes = [1,2,3,4,5,7,10]
batch_sizes = [7]
batch_size_dict_nkt = {}
cluster = 3
for batch_size in batch_sizes:
    ### desease states 1 = RRMS and 0 = NINDC
    selection_pool_rrms_nkt, too_few_data_rrms_nkt = get_chunks_from_df(
        rrms_patients2df,
        freq_df=rrms_patients_freq,
        desease_state=1,
        cluster=cluster,
        batch_size=batch_size,
    )
    selection_pool_nindc_nkt, too_few_data_nindc_nkt = get_chunks_from_df(
        nindc_patients2df,
        freq_df=nindc_patients_freq,
        desease_state=0,
        cluster=cluster,
        batch_size=batch_size,
    )

    # Keep both pools the same length by cutting the longer one down to the shorter one.
    common_len = min(len(selection_pool_rrms_nkt), len(selection_pool_nindc_nkt))
    selection_pool_rrms_nkt = selection_pool_rrms_nkt[:common_len]
    selection_pool_nindc_nkt = selection_pool_nindc_nkt[:common_len]

    # Mix the two groups (in place, using NumPy's global RNG).
    all_chunks = selection_pool_rrms_nkt + selection_pool_nindc_nkt
    np.random.shuffle(all_chunks)

    # Split every chunk into its three positional parts.
    # presumably (cell data, frequency entry, disease state) — verify against get_chunks_from_df
    X, nkt, Y = [], [], []
    for chunk in all_chunks:
        X.append(chunk[0])
        nkt.append(chunk[1])
        Y.append(chunk[2])
    batch_size_dict_nkt[batch_size] = (X, nkt, Y)

print('prep done')

In [None]:
##### NKT trial reg
# Fresh accumulators for the NKT regression run: fitted models and per-batch-size stats.
cluster = 3
model_container_nkt = list()
stats_dict_reg_nkt = {}

In [None]:
# Fit one regression model per K-fold split for every prepared batch size and
# collect the fitted models in model_container_nkt.
for batch_size, values in batch_size_dict_nkt.items():
    # for regression task stratified is wrong since there are no classes
    kf = KFold(n_splits=2, random_state=rand_seed, shuffle=True)
    # Position of the regression target inside each frequency entry.
    # presumably the NKT slot (third label in cluster_to_celltype_dict) — TODO confirm
    freq_idx = 2
    X, nkt = values[0], values[1]
    nkt = [series[freq_idx] for series in nkt]
    # Hold out a test set; the returned train and valid parts are pooled again below.
    X_test, X_train, X_valid, nkt_test, nkt_train, nkt_valid = split_test_valid_train(
        X=X,
        y=nkt,
        test_perc=test_perc,
        train_perc=train_perc,
        valid_perc=0.5, seed=rand_seed)
    # Everything that is not test data is pooled and re-split by KFold.
    X = np.array([*X_train, *X_valid])
    nkt = np.array([*nkt_train, *nkt_valid])
    # i is the 1-based fold number; it only feeds the per-fold output directory name.
    for i, (train_idx, valid_idx) in enumerate(kf.split(X=X), start=1):
        outdir_pheno_reg_nkt = f'ms_pheno_reg_nkt_{batch_size}_v2_{i}'
        # NOTE(review): these names are overwritten on every fold, so after the loop they
        # hold the LAST fold's partition. Code that later scores every model against
        # X_train / X_valid uses that last partition for all models — confirm this is intended.
        X_train, X_valid = X[train_idx], X[valid_idx]
        nkt_train, nkt_valid = nkt[train_idx], nkt[valid_idx]
        model = get_fitted_model(X_train, X_valid, nkt_train, nkt_valid,
                                 nsubset=1000,
                                 nfilters=[3, 15, 25, 35], coeff_l1=0,
                                 max_epochs=100, nrun=30, learning_rate=None,
                                 ncell=batch_size,
                                 per_sample=True, regression=True,
                                 outdir=outdir_pheno_reg_nkt, verbose=False)
        model_container_nkt.append(model)
print('DONE NKT models built')

In [None]:
# Score every fitted NKT model on the test / train / valid data currently in scope and
# write the mean squared errors to a text file, one line per model.
outdir_nkt = 'nkt'
stats_path = f"{outdir_nkt}_stats_file.txt"
mses = []
# Mode "w" instead of "x+": exclusive creation raises FileExistsError as soon as the cell
# is run a second time. The with-block also closes the file if a prediction fails.
with open(stats_path, "w") as stats_file:
    # batch_size is whatever value the preceding loop left behind.
    stats_file.write(f"Batchsize: {batch_size}\n")
    for model in model_container_nkt:
        # NOTE(review): X_train / X_valid (and the matching targets) are shared by all
        # models here; if they were reassigned per fold during training they describe
        # only the last fold — confirm the train/valid numbers mean what is intended.
        test_pred = model.predict(X_test)
        train_pred = model.predict(X_train)
        valid_pred = model.predict(X_valid)
        mse_test = mean_squared_error(nkt_test, test_pred)
        mse_train = mean_squared_error(nkt_train, train_pred)
        mse_valid = mean_squared_error(nkt_valid, valid_pred)
        mses.append(mse_test)
        # Separators added: the three values used to be written back to back.
        stats_file.write(f'MSE test {mse_test}, MSE train {mse_train}, MSE valid {mse_valid}\n')
    # Mean of the test MSEs; NaN instead of a ZeroDivisionError when no model was fitted.
    mean_mse = float(sum(mses) / len(mses)) if mses else float('nan')
    stats_file.write(f"Mean MSE: {mean_mse}\n")
stats_dict_reg_nkt[batch_size] = {'mean_mse': mean_mse}
print('DONE')