In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn import datasets, linear_model
from sklearn.linear_model import LinearRegression
import statsmodels.api as sm
from scipy import stats

from scipy.stats import ranksums

In [None]:
import os
from pathlib import Path

# The relative paths used in this notebook ('dataset/...', the 'src' package)
# are resolved from the directory above 'notebooks', so step up one level when
# the kernel was started inside <...>/ADProgModel/notebooks.
# Path components are compared instead of a hard-coded '\\'-separated suffix so
# the check also fires on POSIX systems, not only on Windows.
working_dir = os.getcwd()
if Path(working_dir).parts[-2:] == ('ADProgModel', 'notebooks'):
    os.chdir(os.pardir)

In [None]:
import src.models.param_estimation_v1 as prestm

In [None]:
def get_errordf(pmdf, df):
    """Join estimated parameters with their ground truth and add squared errors.

    Parameters
    ----------
    pmdf : pandas.DataFrame
        Estimated parameters, one row per subject. Must contain 'RID' and the
        columns 'beta_estm', 'alpha1_estm', 'alpha2_gamma_estm' and 'tpo_estm'.
    df : pandas.DataFrame
        Data containing 'RID', 'Years' and the ground-truth columns 'beta',
        'tpo', 'alpha1', 'alpha2' and 'gamma'. Only rows with ``Years == 0``
        are used.

    Returns
    -------
    pandas.DataFrame
        The columns of ``pmdf`` and the ground-truth columns aligned on 'RID'
        (returned as a regular column), plus 'alpha2_gamma'
        (``alpha2 * gamma``) and one '<param>_se' squared-error column for each
        of beta, alpha1, alpha2_gamma and tpo.
    """
    truth_columns = ['RID', 'beta', 'tpo', 'alpha1', 'alpha2', 'gamma']

    # combine estimated and ground-truth values, aligned on the subject id
    estimated = pmdf.set_index('RID')
    ground_truth = df.loc[df.Years == 0, truth_columns].set_index('RID')
    pmerrdf = pd.concat((estimated, ground_truth), axis=1)

    # the estimate is of the product alpha2*gamma, so build the matching truth column
    pmerrdf['alpha2_gamma'] = pmerrdf['alpha2'] * pmerrdf['gamma']
    pmerrdf = pmerrdf.reset_index()

    # squared error between ground truth and estimate, per parameter
    for param_name in ('beta', 'alpha1', 'alpha2_gamma', 'tpo'):
        residual = pmerrdf[param_name] - pmerrdf[f'{param_name}_estm']
        pmerrdf[f'{param_name}_se'] = residual ** 2

    return pmerrdf

### Read data from the Excel workbook and create data array for each patient

In [None]:
# Choose which processed dataset and which sheet (split) of its workbook to load.
datatype = 'synthetic' # 'adni', 'synthetic'
sheetname = 'train'

# Workbook path for every supported dataset type.
dataset_files = {
    'adni': 'dataset/processed/adni_split0.xls',
    'synthetic': 'dataset/processed/synthetic_split0.xls',
}

# Fail loudly on an unsupported value: with a plain if/elif chain `filename`
# would be left unbound, or, in a live kernel, would silently keep the value
# from a previous run and load the wrong dataset.
if datatype not in dataset_files:
    raise ValueError(
        "Unknown datatype %r; expected one of %s" % (datatype, sorted(dataset_files))
    )
filename = dataset_files[datatype]

df = pd.read_excel(filename, sheet_name=sheetname)

### Parameter estimation with missing data

##### Create synthetic dataset with missing data

In [None]:
# Fraction of subjects with no sample in each year, measured on the ADNI data.
# All three sheets (train / valid / test) of the same workbook are stacked first.
adni_workbook = 'dataset/processed/adni_split0.xls'
adtraindf, advaldf, adtestdf = (
    pd.read_excel(adni_workbook, sheet_name=split_name)
    for split_name in ('train', 'valid', 'test')
)
addf = pd.concat((adtraindf, advaldf, adtestdf), axis=0, ignore_index=True)

# rows per year divided by the number of distinct subjects gives the observed
# fraction; its complement is the per-year missing fraction
n_adni_subjects = len(addf.RID.unique())
rows_per_year = addf.Years.value_counts()
miss_prob_series = 1 - rows_per_year / n_adni_subjects

In [None]:
# For each follow-up year 1..10, randomly pick the rows of `df` to remove so
# that the fraction removed equals the missing fraction stored for that year
# in `miss_prob_series`.
drop_labels = []

for yearnum in range(1, 11):
    rows_this_year = df.loc[df.Years == yearnum]
    # fixed random_state keeps the selection reproducible across runs
    sampled_rows = rows_this_year.sample(frac=miss_prob_series[yearnum],
                                         replace=False, random_state=2)
    drop_labels.extend(sampled_rows.index)

# flat array of the index labels to drop
drop_index_list = np.array(drop_labels)

In [None]:
# Copy of `df` without the rows selected for removal (the version with missing visits).
missdf = df.drop(labels=drop_index_list, axis=0).copy()

In [None]:
# Cross-check that the per-year fraction of subjects with a sample in the
# thinned dataset (`missdf`) is close to the fraction the sampling aimed for
# (1 - miss_prob_series).
# Both series would otherwise carry the same name, giving two identically
# labelled columns; `keys` labels them so the table can be read.
observed_frac_missdf = missdf.Years.value_counts() / len(missdf.RID.unique())
observed_frac_target = 1 - miss_prob_series
pd.concat((observed_frac_missdf, observed_frac_target), axis=1,
          keys=['missdf_observed', 'target_observed'])

In [None]:
# Remove every subject that has fewer than 3 rows (visits) left in `missdf`.
viscounts = missdf.RID.value_counts()
drop_pats = viscounts[viscounts < 3].index
rows_of_dropped_subjects = missdf.index[missdf.RID.isin(drop_pats)]
missdf.drop(index=rows_of_dropped_subjects, inplace=True)

##### Parameter estimation for different cases and their comparison

In [None]:
# Names of the dataframe columns, grouped by the role their variable name suggests.
subname = 'RID'
tcname = 'Years'
agename = 'demog1'
apoestatus = 'demog2'
cogvar = 'cogsc'
reg1_mri, reg2_mri = 'reg1_mri', 'reg2_mri'
reg1_av45, reg2_av45 = 'reg1_av45', 'reg2_av45'

# Bundle the column names for the estimation routines (argument order matters).
dfcolnms = prestm.ColumnNames(subname, tcname, agename, cogvar,
                              reg1_mri, reg2_mri, reg1_av45, reg2_av45)

# 2x2 matrix with zeros on the diagonal and ones off it, wrapped for the
# estimation routines. NOTE(review): presumably the connectivity between the
# two regions -- confirm against prestm.DTIMat.
admat = np.matrix([[0, 1], [1, 0]])
dticlinfo = prestm.DTIMat(admat)

# Distinct combinations of the two demographic columns among the Years == 0 rows.
demog_feat_list = [agename, apoestatus]
baseline_demog = df.loc[df.Years == 0, demog_feat_list]
grouptypedf = baseline_demog.value_counts().reset_index()[demog_feat_list]

In [None]:
# Run the parameter estimation for three cases. The second value returned by
# the *_bygroup routine is not needed here and is bound to `ignore`.

# complete data (`df`, no rows removed), estimated per demographic group
group_nomiss_pmdf, ignore = prestm.compute_all_params_woY_bygroup(df, dfcolnms, dticlinfo, grouptypedf)

# data with missing visits (`missdf`), estimated per subject
perpat_pmdf = prestm.compute_all_params_woY_perpat(missdf, dfcolnms, dticlinfo)

# data with missing visits (`missdf`), estimated per demographic group
group_pmdf, ignore = prestm.compute_all_params_woY_bygroup(missdf, dfcolnms, dticlinfo, grouptypedf)


In [None]:
# Squared-error table for each estimation case, tagged with a 'type' label so
# the tables can be told apart once they are stacked.
group_nomiss_pmerrdf = get_errordf(group_nomiss_pmdf, df).assign(type='group_nomiss')

perpat_pmerrdf = get_errordf(perpat_pmdf, missdf).assign(type='perpat_miss')

group_pmerrdf = get_errordf(group_pmdf, missdf).assign(type='group_miss')

In [None]:
# Stack the per-subject and per-group error tables (both from the data with
# missing visits), keeping only the identifier, case label and squared errors.
par_se_list = ['RID', 'type', 'beta_se', 'alpha1_se', 'alpha2_gamma_se']
frames_to_compare = (perpat_pmerrdf, group_pmerrdf)
errordf = pd.concat([frame[par_se_list] for frame in frames_to_compare],
                    axis=0, ignore_index=True)

In [None]:
# Bar chart of the squared error of each estimated parameter, comparing
# per-subject ("Individual") with per-group ("Group") estimation on the data
# with missing visits.
se_columns = ['beta_se', 'alpha1_se', 'alpha2_gamma_se']
# Tick labels in the same order as `se_columns`. The alpha_1 label previously
# had an unbalanced closing brace, which mathtext cannot parse.
se_tick_labels = [r'$\hat{\beta}$', r'$\hat{\alpha_1}$', r'$\hat{\alpha_2 \gamma}$']
# Hue levels in the same order as the legend labels below.
type_order = ['perpat_miss', 'group_miss']
type_legend_labels = ["Individual", "Group"]

# `order` / `hue_order` pin the bar and hue order explicitly, so the manual
# tick and legend labels cannot drift out of sync with the data.
ax = sns.barplot(x='variable', y='value', hue='type',
                 order=se_columns, hue_order=type_order,
                 data=errordf.melt(id_vars=['RID', 'type'], value_vars=se_columns))
plt.xticks(ticks=[0, 1, 2], labels=se_tick_labels, fontsize=14)
plt.yticks(fontsize=14)
plt.xlabel('Parameter', fontsize=15)
plt.ylabel('Squared error', fontsize=15)
# log scale on the y axis (same effect as the argument-less plt.semilogy())
ax.set_yscale('log')
handles, _ = ax.get_legend_handles_labels()
ax.legend(handles, type_legend_labels, fontsize=14)

plt.show()