In [None]:
%matplotlib inline
import numpy as np
import torch
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
import platform
import sys
import warnings
import pySuStaIn
from sklearn.metrics import roc_curve, auc
import matplotlib.image as mpimg
from pathlib import Path

import statsmodels.formula.api as smf
import os
from DPMoSt import DPMoSt
from utility import data_creation, plot_data, plot_solution, error_eval

# Location of the local GPPM checkout. Both alternatives are hard-coded absolute paths
# on an external drive (one used on Linux, the other on any non-Linux platform), so this
# cell only works on a machine where that folder exists.
folder='/media/aviani/External HD/postdoc/gppm' if platform.system()=='Linux' else '/Volumes/External HD/postdoc/gppm'
# Put the checkout on the module search path (position 1) so the import below resolves.
sys.path.insert(1, f'{folder}')
import GP_progression_model # type: ignore

In [None]:
# Silence every warning for the rest of the session. NOTE(review): this is a blanket
# filter, so warnings raised by the libraries used below are hidden as well.
warnings.filterwarnings("ignore")

In [None]:
# Fix the random seeds so that stochastic steps are repeatable across kernel restarts.
random_state=42
torch.manual_seed(random_state)
# Seed NumPy's legacy global generator too: previously only torch was seeded, so any
# code drawing from np.random produced different numbers on every run.
np.random.seed(random_state)

# A GPU is detected when available, but the next line deliberately forces the CPU.
# Remove the override to actually run on the detected device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device='cpu'
print(f'Device used: {device}')

# Data Creation

In [None]:
# ---- Experiment configuration ----
# Size of the synthetic cohort.
n_features=2
n_subjects=20
n_time_for_subject=1

# Compute device and resolution (dots per inch) used when figures are saved.
device='cpu'
dpi=500

# Stage switches. NOTE(review): none of these four flags is referenced by the cells
# visible in this notebook; the cells check for cached files on disk instead.
create_data=False
run_dpmost=False
run_gppm=False
run_sustain=False

# Every artefact of this configuration lives in its own folder, named after the sizes above.
name=f'n_features_{n_features}_n_subjects_{n_subjects}_n_time_for_subject_{n_time_for_subject}'
output_folder=f'examples/example_{name}'
Path(output_folder).mkdir(parents=True, exist_ok=True)

In [None]:
# Reuse the cached synthetic dataset when it is already on disk; otherwise generate it.
# NOTE(review): presumably data_creation writes dpmost_data.pkl under name_path - confirm.
data_file = f'{output_folder}/dpmost_data.pkl'
if os.path.isfile(data_file):
    # The pickle is a local cache produced by this notebook; never load an untrusted one.
    with open(data_file, 'rb') as f:
        dict_data = pickle.load(f)
else:
    dict_data = data_creation(
        n_subjects=n_subjects,
        n_time_points=n_time_for_subject,
        n_features=n_features,
        noise_std=0.5,
        max_dist=1,
        time_shifted=True,
        name_path=output_folder,
        device='cpu',
    )

# The observations themselves are stored under the 'data' key of the dictionary.
data = dict_data['data']
plot_data(data, dict_data=dict_data, dpi=dpi, name_path=f'{output_folder}/data', save=True)

# DP-MoSt

In [None]:
# Load a previously saved DP-MoSt solution if one exists; otherwise fit and save one.
# NOTE(review): presumably dpmost.save() writes dpmost_sol.pkl under name_path - confirm.
solution_file = f'{output_folder}/dpmost_sol.pkl'
if os.path.isfile(solution_file):
    # Local cache written by an earlier run; do not load pickles from untrusted sources.
    with open(solution_file, 'rb') as f:
        dpmost = pickle.load(f)
else:
    dpmost = DPMoSt(
        data=data,
        device=device,
        name_path=output_folder,
        benchmarks=False,
        stopping_criteria=True,
        verbose=True,
    )
    dpmost.optimise(lr=1e-1, n_outer_iterations=20, n_inner_iterations=20)
    dpmost.save()

In [None]:
# Benchmark run: a fresh DP-MoSt model with benchmarks enabled and the stopping
# criterion disabled.
# NOTE(review): this cell always runs and rebinds `dpmost`, so whatever object the
# previous cell loaded or fitted is replaced by this one for the rest of the notebook.
dpmost = DPMoSt(
    data=data,
    device=device,
    name_path=output_folder,
    n_prints=5,
    lambda_reg_theta=0.001,
    benchmarks=True,
    stopping_criteria=False,
    verbose=True,
)

dpmost.optimise(lr=1e-1, n_outer_iterations=20, n_inner_iterations=20)

# Error evaluation

In [None]:
# Score the fitted DP-MoSt object against the dictionary that describes the synthetic
# data. error_eval returns two values, reported below as the OSPA error and the noise error.
error_ospa, error_noise = error_eval(dpmost, dict_data)
print(f'Error OSPA: {error_ospa}\nError noise: {error_noise}')

# GPPM

In [None]:
# Work on a copy so the renaming below does not touch `data`.
input_data=data.copy()
reparameterization_model='time_shift'
# One entry equal to 1 for every column from index 2 onwards.
# NOTE(review): presumably 1 means "increasing" for GPPM - confirm against the library.
monotonicity = [1 for k in range(data.iloc[:,2:].shape[1])]
# NOTE(review): the id column is renamed to "RID", presumably the name convert_from_df expects - confirm.
input_data.rename(columns={"subj_id": "RID"}, inplace=True)
Xdata, Ydata, RID, list_biomarkers, group = GP_progression_model.convert_from_df(input_data, input_data.iloc[:,2:].columns, time_var='time')

# Optimisation budget passed to model.Optimize below.
N_outer_iterations=6
N_iterations=200
n_minibatch=5
# create a GPPM object
model = GP_progression_model.GP_Progression_Model(x=Xdata, 
                                                  y=Ydata, 
                                                  names_biomarkers=input_data.iloc[:,2:].columns,
                                                  monotonicity=monotonicity, 
                                                  trade_off=50, 
                                                  reparameterization_model=reparameterization_model, 
                                                  sigma_0=10, 
                                                  device=device)

# Fit and save only when no saved solution exists; otherwise load the saved one.
if not os.path.isfile(f'{output_folder}/gppm_sol.pkl'):
     model.model = model.model.to(device)
     # Optimise the model
     model.Optimize(N_outer_iterations=N_outer_iterations, N_iterations=N_iterations, 
                n_minibatch=n_minibatch, verbose=True, plot=False, benchmark=False)
    
     model.Save(f'{output_folder}/', name=f'gppm_sol')
else:
     model.Load(f'{output_folder}/', name=f'gppm_sol')

In [None]:
# Pull the re-parameterised times and the observations out of the fitted GPPM model and
# rescale them so they can be plotted with the same helper as the original data.
y = []
t = []
model.tr.list_id = np.arange(len(model.x[0])).tolist()

# Running extrema, updated across all biomarkers inside the loop.
x_min, x_max = np.inf, -np.inf
y_min, y_max = np.inf, -np.inf

for bio_pos, biomarker in enumerate(model.names_biomarkers):
    # Index of the first biomarker whose name matches the current one.
    bio_id = np.where([model.names_biomarkers[i] == biomarker for i in range(model.N_biomarkers)])[0][0]
    x_data = model.model.time_reparameterization(model.x_torch)[bio_id].detach().data.cpu().numpy()
    y_data = model.y_torch[bio_id].detach().data.cpu().numpy()
    y.append(y_data)

    x_min = min(x_min, np.float64(np.min(x_data)))
    x_max = max(x_max, np.float64(np.max(x_data)))
    y_min = min(y_min, np.float64(np.min(y_data)))
    y_max = max(y_max, np.float64(np.max(y_data)))

    # NOTE(review): x_range and new_x are computed here but never read afterwards in this cell.
    x_range = torch.autograd.Variable(torch.arange(x_min, x_max, np.float64((x_max - x_min) / 50)))
    x_range = x_range.reshape(x_range.size()[0], 1)

    new_x = model.Transform_subjects()
    # Affine map of the times: scaled by x_mean_std[bio_id][1], shifted by x_mean_std[bio_id][0].
    # NOTE(review): presumably this undoes a (mean, std) standardisation - confirm.
    t.append((x_data * model.x_mean_std[bio_id][1] + model.x_mean_std[bio_id][0]).flatten())

# Only the times of the first biomarker are kept; observations become one column per biomarker.
t=torch.tensor(t[0], dtype=torch.float32, device=device)
y=torch.tensor(np.array(y), dtype=torch.float32, device=device).T

# Min-max rescale every biomarker column onto the range of the same column in `data`.
for fdx in range(input_data.iloc[:,2:].shape[1]):
    y_min, y_max = y[:,fdx].min(), y[:,fdx].max()
    new_max=data.iloc[:,2+fdx].max()
    new_min=data.iloc[:,2+fdx].min()
    y_new = (y[:,fdx] - y_min)/(y_max - y_min)*(new_max - new_min) + new_min
    input_data.iloc[:,2+fdx]=y_new.cpu().numpy()

# Min-max rescale the times onto the fixed interval [0, 20].
t_min, t_max = t.min(), t.max()
new_max=20
new_min=0
t_new = (t - t_min)/(t_max - t_min)*(new_max - new_min) + new_min

input_data['time']=t_new.cpu().numpy()
# Restore the original id column name.
input_data.rename(columns={"RID": "subj_id"}, inplace=True)

In [None]:
# Plot the rescaled GPPM output with the shared plotting helper.
# NOTE(review): `save` is not passed here (the data plot passes save=True), yet a later
# cell reads gppm_sol.png - confirm that plot_data saves by default.
plot_data(input_data, dict_data=dict_data, alpha=0, name_path=f'{output_folder}/gppm_sol', dpi=dpi)

# SuStaIn

In [None]:
# Replace every biomarker column (index 2 onwards) by a w-score: the residual of an
# ordinary-least-squares fit against time, divided by the standard deviation of the
# residuals of that fit.
zdata = data.copy()
for biomarker in zdata.columns[2:]:
    ols_fit = smf.ols('%s ~ time'%biomarker, data=zdata).fit()
    expected = ols_fit.predict(zdata[['time',biomarker]])
    residual = zdata.loc[:,biomarker] - expected
    zdata.loc[:,biomarker] = residual / ols_fit.resid.std()

In [None]:
# SuStaIn settings; every biomarker shares the same z-score levels and maximum.
n_biomarkers = dict_data['n_features']
z_levels = [0.5,1,1.5]
Z_vals = np.array([z_levels for _ in range(n_biomarkers)])   # one row of z-score levels per biomarker
Z_max = np.array([5 for _ in range(n_biomarkers)])           # maximum z-score per biomarker
N_S_max=3
N_iterations_MCMC=int(1e4)

In [None]:
# Build the z-score SuStaIn model on the biomarker columns (index 2 onwards) and fit it.
# NOTE(review): this cell runs unconditionally; the run_sustain flag is not consulted.
biomarker_columns = zdata.columns[2:]
sustain_input = pySuStaIn.ZscoreSustain(
    zdata[biomarker_columns].values,
    Z_vals=Z_vals,
    Z_max=Z_max,
    biomarker_labels=biomarker_columns,
    N_startpoints=15,
    N_S_max=N_S_max,
    N_iterations_MCMC=N_iterations_MCMC,
    output_folder=output_folder,
    dataset_name=f'sustain_sol',
    use_parallel_startpoints=True,
)

# The return value is discarded; later cells read the pickle files under output_folder.
_=sustain_input.run_sustain_algorithm()

In [None]:
# Draw the MCMC log-likelihood trace of every subtype model and record its mean.
all_like_mean=np.zeros(N_S_max)
plt.figure(figsize=(10,3))
for s in range(N_S_max):
    # SuStaIn stores one pickle per subtype model; it holds the sampled log likelihoods.
    pickle_filename_s = f'{output_folder}/pickle_files/sustain_sol_subtype{s}.pickle'
    pk = pd.read_pickle(pickle_filename_s)
    samples_likelihood = pk["samples_likelihood"]
    all_like_mean[s]=samples_likelihood.mean()

    trace_ax = plt.gca()
    trace_ax.plot(range(N_iterations_MCMC), samples_likelihood, label=f"sub-pop{s+1}")
    trace_ax.legend(loc='upper right')
    trace_ax.set_xlabel('MCMC samples')
    trace_ax.set_ylabel('Log likelihood')
    trace_ax.set_title('MCMC trace')
# Extend the x axis by 15% beyond the last sample, to the right of the traces.
plt.xlim(0,N_iterations_MCMC+0.15*N_iterations_MCMC)
plt.tight_layout()
plt.show()

In [None]:
# Index of the subtype model to analyse. It is hard-coded; the commented-out
# alternative was np.argmax(all_like_mean).
# NOTE(review): presumably index s is the model with s+1 subtypes - confirm against pySuStaIn.
s = 1#np.argmax(all_like_mean)
M = len(zdata)

pickle_filename_s = f'{output_folder}/pickle_files/sustain_sol_subtype{s}.pickle'
pk = pd.read_pickle(pickle_filename_s)

# Copy the per-observation SuStaIn outputs into the z-scored table.
per_row_outputs = ('ml_subtype', 'prob_ml_subtype', 'ml_stage', 'prob_ml_stage')
for variable in per_row_outputs:
    zdata.loc[:,variable] = pk[variable]
# range(s) stores the first s probability columns only (just 'prob_S0' when s is 1).
for i in range(s):
    zdata.loc[:,'prob_S%s'%i] = pk['prob_subtype'][:,i]

zdata.to_csv(f'{output_folder}/sustain_sol')

In [None]:
samples_sequence = pk["samples_sequence"]
samples_f = pk["samples_f"]

# Positional variance diagrams: the subtype order is passed only when s is not 0.
plot_options = {'figsize': (25,5)}
if s!=0:
    plot_options['subtype_order'] = (0,1)
tmp=pySuStaIn.ZscoreSustain._plot_sustain_model(sustain_input, samples_sequence, samples_f, M, **plot_options)

plt.tight_layout()
plt.savefig(f'{output_folder}/sustain_sol',dpi=dpi)

# Analysis

In [None]:
# Hard SuStaIn label per observation: 1 where the stored 'prob_S0' exceeds 0.5, else 0.
sol_aux = data.copy()
sol_aux['sustain_subpop']=[1 if prob>0.5 else 0 for prob in zdata['prob_S0']]

# Per subject, count the observations carrying each label. The columns are reindexed to
# [0, 1] so both always exist: when SuStaIn put every observation in a single label the
# other column was missing and the ratio below raised a KeyError.
sub_pop_counts = (sol_aux.groupby('subj_id')['sustain_subpop']
                  .value_counts()
                  .unstack(fill_value=0)
                  .reindex(columns=[0, 1], fill_value=0)
                  .reset_index())
# Count of label 0 over count of label 1; a subject without label-1 rows gets inf (or NaN for 0/0).
sub_pop_counts['ratio']=sub_pop_counts[0]/sub_pop_counts[1]
# NOTE(review): this column is named 'variance' but is computed as ratio*(ratio-1) - confirm the intent.
sub_pop_counts['variance']=sub_pop_counts['ratio']*(sub_pop_counts['ratio']-1)
# Majority SuStaIn label per subject: 0 when label 0 is strictly more frequent, otherwise 1.
sub_pop_counts['sustain_subpop_mean']=[0 if ratio>1 else 1 for ratio in sub_pop_counts['ratio']]
# Hard DP-MoSt label per subject.
# NOTE(review): assumes dpmost.pi is ordered like the (sorted) subjects of this table - confirm.
sub_pop_counts['dpmost_subpop']=[1 if weight>0.5 else 0 for weight in dpmost.pi]

In [None]:
fpr, tpr, _=roc_curve(sub_pop_counts['sustain_subpop_mean'].values, 1-dpmost.pi)
fpr_2, tpr_2, _=roc_curve(sub_pop_counts['sustain_subpop_mean'].values, dpmost.pi)

if auc(fpr, tpr)>auc(fpr_2, tpr_2):
    tpr=tpr_2
    fpr=fpr_2

In [None]:
fig, ax=plt.subplots(1,3, figsize=(15,3))
plt.sca(ax[0])
plt.plot(tpr,fpr)
plt.plot([0, 1], [0, 1], linestyle='--', linewidth=1, color='r', label='Random guess')
plt.xticks([0, 0.25, 0.5, 0.75, 1])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title(r'ROC curve')

plt.sca(ax[1])
plt.plot(sub_pop_counts['sustain_subpop_mean'].values, 'r.', alpha=0.5, label='sustain')
plt.plot(sub_pop_counts['dpmost_subpop'].values, 'bx', alpha=0.5, label='dpmost')
plt.title('Sub-populations differences')
plt.legend()

plt.sca(ax[2])
plt.plot(sub_pop_counts['variance'].values, '.')
plt.title('Variance')

plt.tight_layout()
plt.show()

# Biomarkers Progression

In [None]:
# Stack the figures saved by the previous sections into one comparison image.
names=['data.png', 'gppm_sol.png', 'dpmost_sol.png', 'sustain_sol.png']
titles=['data', 'gppm', 'dp-most', 'sustain']

fig, ax=plt.subplots(len(names),1, figsize=(5*n_features,5*len(names)))
for panel, file_name, panel_title in zip(ax, names, titles):
    # Each panel shows one saved PNG without axes, under the method's name.
    img = mpimg.imread(f'{output_folder}/{file_name}')
    panel.imshow(img)
    panel.set_title(panel_title, fontsize=20)
    panel.axis('off')
    plt.tight_layout()

plt.savefig(f'{output_folder}/sol_comparison',dpi=dpi)
plt.show()

# Distributions

In [None]:
# SuStaIn side of the comparison table: the z-scored observations tagged with the
# per-subject SuStaIn label.
aux=zdata.copy()
aux['algorithm']='sustain'

# Fix: the label is now looked up by subject id from 'sustain_subpop_mean'. The previous
# code (a) indexed sub_pop_counts by the position of the id in data['subj_id'].unique(),
# which only matches the groupby-sorted rows of that table when ids already appear in
# sorted order, and (b) thresholded the 0/1 count ratio at 0.5 although the majority
# label in sub_pop_counts is defined with a threshold of 1.
sustain_label_by_subject = sub_pop_counts.set_index('subj_id')['sustain_subpop_mean']
aux['est_subpop']=aux['subj_id'].map(sustain_label_by_subject).astype(float).to_numpy()

# Drop the SuStaIn bookkeeping columns so both tables share the same layout.
aux = aux.drop(columns=['ml_subtype', 'prob_ml_subtype', 'ml_stage', 'prob_ml_stage', 'prob_S0'])

# DP-MoSt side: its own data table tagged with its estimated sub-population.
aux_2=dpmost.data.copy()
aux_2['algorithm']='dpmost'
aux_2['est_subpop']=dpmost.est_subpop

# Long table holding both methods, used by the violin plots.
aux_3=pd.concat([aux, aux_2])

In [None]:
# One split violin per biomarker: algorithms on the x axis, halves coloured by the
# estimated sub-population. Only the first panel carries the legend.
feature_columns = data.columns[2:]
n_panels = len(feature_columns)
fig, ax=plt.subplots(1,n_panels,figsize=(5*n_panels,5), sharex=True, sharey=True)
for i, feature in enumerate(feature_columns):
    plt.sca(ax[i])
    legend = (i==0)
    sns.violinplot(data=aux_3, x='algorithm', y=feature, hue='est_subpop',
                   orient='v', split=True, gap=.05, inner='quartile',
                   bw_adjust=1.9, saturation=0.7, linewidth=1,
                   palette=['steelblue','lightcoral'], legend=legend)
    if legend:
        # Replace the automatic hue labels with readable sub-population names.
        handles, _ = ax[0].get_legend_handles_labels()
        plt.legend(handles, ['Sub-population 1', 'Sub-population 2'], loc="upper left")

    plt.title(feature)
    plt.xlabel('', fontsize=0)
    plt.ylabel('', fontsize=0)
plt.tight_layout()
plt.savefig(f'{output_folder}/sol_distributions', dpi=dpi)