In [None]:
import seaborn as sns
import  matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import pickle as pkl
import pandas as pd
from scipy.stats import rankdata
import warnings
warnings.simplefilter("ignore", UserWarning)
from IPython.display import display

from Utils.data_preparation import get_feature_set
from Utils.plot_utils import plot_diag_lambda_mat, heatmap, annotate_heatmap #plot_full_relevance_matrix,

def normalise_vector(x):
    """Return ``x`` flattened to 1-D and scaled to unit Euclidean (L2) norm.

    Used to put per-feature weights from different models (coefficients,
    LVQ relevances) on a common scale before comparing them.

    Parameters
    ----------
    x : array_like
        Weights to normalise, of any shape. The result is flattened. For a
        2-D input the Frobenius norm equals the L2 norm of the flattened
        vector, so flattening first does not change the result.

    Returns
    -------
    numpy.ndarray
        1-D array whose L2 norm is 1 (up to floating-point rounding).

    Raises
    ------
    ValueError
        If the norm of ``x`` is zero or not finite, because the direction
        of such a vector is undefined (dividing would produce NaNs).
    """
    flat_x = np.asarray(x).ravel()
    norm = np.linalg.norm(flat_x)
    # Explicit check instead of an `assert`, which is stripped under `python -O`.
    if norm == 0 or not np.isfinite(norm):
        raise ValueError('Cannot normalise a vector with norm {!r}.'.format(norm))
    return flat_x / norm

## Compare Feature Ranking

### Examine the kernel type of the final SVM model

In [None]:
# SVM: load the final model from the pooled complete-case results and show its kernel.
# NOTE: pickle.load can execute arbitrary code -- only open result files you produced.
folderpath = '/Users/lirui/Downloads/Cohort_Dementia_Prediction/Classification/Nested_CV_Results/Complete_Case/all_pooled/SVM/Multi/SVM_Multi.pkl'
with open(folderpath, 'rb') as results_file:  # closes the handle even if unpickling fails
    full_results = pkl.load(results_file)
final_model = full_results['final_model']
final_model.kernel

### Examine a single model

In [None]:
# Load the final model and its fitted input scaler for one model / feature set.
model_name = 'Reg_Cox'
analysis_type = 'Survival'  # 'Classification' or 'Survival' -- selects the results tree

featureset_name = 'Multi'
input_variables_to_print, FS_name, var_description, cat_feature_indices = get_feature_set(featureset_name)

# The path is built from featureset_name as well, so switching the feature set
# loads the matching results instead of silently reusing the 'Multi' ones.
folderpath = '/Users/lirui/Downloads/Cohort_Dementia_Prediction/{analysis}/Nested_CV_Results/Complete_Case/all_pooled/{model}/{fs}/{model}_{fs}.pkl'.format(
    analysis=analysis_type, model=model_name, fs=featureset_name)
# NOTE: pickle.load can execute arbitrary code -- only open result files you produced.
with open(folderpath, 'rb') as results_file:
    full_results = pkl.load(results_file)
final_input_scaler = full_results['final_input_scaler']
final_model = full_results['final_model']


In [None]:
# Pair the scaler statistics with the non-categorical feature names.
# Assumes the scaler was fit on exactly these continuous features, in this
# order -- pandas raises if the lengths differ.
cont_feature_indices = [
    idx for idx in np.arange(len(input_variables_to_print))
    if idx not in cat_feature_indices
]
cont_features = [input_variables_to_print[idx] for idx in cont_feature_indices]
scaler_summary = {
    'Features': cont_features,
    'Means': final_input_scaler.mean_,
    'SD': np.sqrt(final_input_scaler.var_),
}
pd.DataFrame.from_dict(scaler_summary)

In [None]:
# Show the fitted coefficients: first the raw flattened array, then paired with feature names.
coefficients = final_model.coef_.ravel()
print('Model coefficients:', coefficients)
pd.DataFrame({
    'Features': input_variables_to_print,
    'Values': coefficients,
})


### Feature ranking by inspecting the final models

In [None]:
# Final-model result files for each model compared below ('Multi' feature set,
# pooled complete-case results).
featureset_name = 'Multi'
input_variables_to_print, FS_name, var_description, cat_feature_indices = get_feature_set(featureset_name)
print(input_variables_to_print)

results_root = '/Users/lirui/Downloads/Cohort_Dementia_Prediction'
# model name -> analysis folder it lives under. Insertion order is the order
# the later cells iterate over the models.
analysis_for_model = {
    'Logistic': 'Classification',
    'Reg_Logistic': 'Classification',
    'GMLVQ': 'Classification',
    'GRLVQ': 'Classification',
    'CoxPH': 'Survival',
    'Reg_Cox': 'Survival',
}
filepaths = {
    name: '{root}/{analysis}/Nested_CV_Results/Complete_Case/all_pooled/{name}/Multi/{name}_Multi.pkl'.format(
        root=results_root, analysis=analysis, name=name)
    for name, analysis in analysis_for_model.items()
}

In [None]:
# Normalised feature importance: for each model, take the raw per-feature weights
# (coefficients for the linear/Cox models, the diagonal of the relevance matrix
# for GMLVQ, the relevance vector for GRLVQ) and scale them to unit L2 norm so
# models with different weight scales can be compared side by side.
normalised_feature_importance = {}
for model_name, model_path in filepaths.items():
    # NOTE: pickle.load can execute arbitrary code -- only load trusted result files.
    with open(model_path, 'rb') as results_file:
        final_model = pkl.load(results_file)['final_model']
    if model_name in ['Logistic', 'Reg_Logistic', 'CoxPH', 'Reg_Cox']:
        raw_weights = final_model.coef_.ravel()
    elif model_name == 'GMLVQ':
        raw_weights = np.diag(final_model.lambda_)
    elif model_name == 'GRLVQ':
        raw_weights = final_model.lambda_
    else:
        # Fail loudly: continuing would store the previous model's weights
        # under this name (or hit a NameError on the first iteration).
        raise ValueError('Unrecognised model type: {}'.format(model_name))
    normalised_feature_importance[model_name] = normalise_vector(raw_weights)

# One row per model, one column per feature.
normalised_feature_importance_dict = pd.DataFrame.from_dict(
    normalised_feature_importance, orient='index', columns=input_variables_to_print
)

normalised_feature_importance_dict.to_csv(
    '/Users/lirui/Downloads/Cohort_Dementia_Prediction/Feature_Importance/Complete_Case_All_Pooled/Normalised_Feature_Importance.csv',
    index=True,
)


In [None]:
# Feature ranking from the final models: rank features by |weight|
# (rank 1 = largest magnitude; ties share the lowest rank) and lay the
# per-model rankings side by side, each sorted by rank.
per_model_rankings = []
for model_name, model_path in filepaths.items():
    # NOTE: pickle.load can execute arbitrary code -- only load trusted result files.
    with open(model_path, 'rb') as results_file:
        final_model = pkl.load(results_file)['final_model']
    if model_name in ['Logistic', 'Reg_Logistic', 'CoxPH', 'Reg_Cox']:
        coef = final_model.coef_.ravel()
    elif model_name == 'GMLVQ':
        coef = np.diag(final_model.lambda_)
    elif model_name == 'GRLVQ':
        coef = final_model.lambda_.ravel()
    else:
        # Fail loudly instead of silently reusing the previous model's coefficients.
        raise ValueError('Unrecognised model type: {}'.format(model_name))
    weights = np.abs(coef)
    ranks = rankdata(-weights, method='min')

    results_df_for_this_model = (
        pd.DataFrame({
            model_name: input_variables_to_print,
            model_name + ' Original Coefficients': coef,
            model_name + ' Rank': ranks,
        })
        # Stable sort so tied features keep their input order deterministically.
        .sort_values(model_name + ' Rank', ascending=True, kind='mergesort')
        .reset_index(drop=True)
    )
    per_model_rankings.append(results_df_for_this_model)

# Build the wide table with a single concat instead of growing it inside the loop.
# Note: the leading 'Rank' column is just 1..n; with ties, each model's own
# '<model> Rank' column holds the actual (tie-aware) rank.
results_df = pd.concat(
    [pd.DataFrame({'Rank': 1 + np.arange(len(input_variables_to_print))})] + per_model_rankings,
    axis=1,
)
display(results_df)


In [None]:
feature_ranking_csv = '/Users/lirui/Downloads/Cohort_Dementia_Prediction/Feature_Importance/Complete_Case_All_Pooled/Feature_Ranking.csv'
results_df.to_csv(feature_ranking_csv, index=False)  # 'Rank' is already a column; skip the RangeIndex

### Feature ranking by ablation study

In [None]:
# Ablation ranking: compare the mean score of each "No_<feature>" model against
# the full multivariable model on the chosen metric and dataset split.
model = 'SVM'
rank_metric = 'ROC-AUC'  # ROC-AUC/Harrell_C
analysis = 'Classification'  # Classification/Survival
rank_dataset = 'test'
model_results_root = '/Users/lirui/Downloads/Cohort_Dementia_Prediction/{}/Nested_CV_Results/Complete_Case/all_pooled/{}'.format(analysis, model)
ablation_results_folderpath = model_results_root + '/Ablation_Study'
multi_var_results_folderpath = model_results_root + '/Multi'

multi_var_results_df = pd.read_csv(multi_var_results_folderpath + '/{}_Multi_summary.csv'.format(model), header=[0, 1], index_col=0)
multi_var_score = multi_var_results_df.loc['Mean', (rank_metric, rank_dataset)]

result = {}
multi_variables_to_print, FS_name, var_description, cat_feature_indices = get_feature_set('Multi')
for feature in multi_variables_to_print:
    ablated_summary_csv = '{}/No_{}/{}_No_{}_summary.csv'.format(ablation_results_folderpath, feature, model, feature)
    this_results_df = pd.read_csv(ablated_summary_csv, header=[0, 1], index_col=0)
    this_score = this_results_df.loc['Mean', (rank_metric, rank_dataset)]
    result[feature] = this_score - multi_var_score

# Ascending rank on the delta: the most negative delta gets Ranking 1. For
# higher-is-better metrics (ROC-AUC, Harrell's C) that is the feature whose
# removal hurt the score most. A lower-is-better metric would invert this.
result_df = pd.DataFrame.from_dict(result, orient='index', columns=['Delta mean metric'])
result_df['Ranking'] = rankdata(result_df['Delta mean metric'].to_numpy(), method='min')
result_df = result_df.sort_values('Ranking', ascending=True, kind='mergesort')  # stable for ties
display(result_df)


In [None]:
ablation_ranking_csv = (
    '/Users/lirui/Downloads/Cohort_Dementia_Prediction/{analysis}/Nested_CV_Results/'
    'Complete_Case/all_pooled/{model}/Ablation_Study/{model}_Feature_Ranking_by_{metric}.csv'
).format(analysis=analysis, model=model, metric=rank_metric)
result_df.to_csv(ablation_ranking_csv)  # index is kept: it holds the feature names

## Inspect a trained LVQ model

In [None]:
def plot_full_relevance_matrix(matrix, labels, savepath=None, showfig=False, cmap='YlGn', center=0):
    """Draw an annotated heatmap of a full relevance matrix.

    Parameters
    ----------
    matrix : array_like
        Matrix to draw. ``labels`` is used for both rows and columns, so it
        presumably needs to be square -- confirm against ``heatmap``.
    labels : sequence of str
        Tick labels used for both axes.
    savepath : str or path-like, optional
        If given, the figure is written to this path.
    showfig : bool, default False
        If true, call ``plt.show()`` after saving.
    cmap : str, default 'YlGn'
        Colormap forwarded to ``heatmap``.
    center : float, default 0
        Currently unused; kept so existing calls keep working.
        NOTE(review): not forwarded because the keywords ``heatmap`` accepts
        are defined in Utils.plot_utils -- confirm before wiring it in.
    """
    fig, ax = plt.subplots(1, 1, figsize=(12, 10))
    im, _cbar = heatmap(matrix, labels, labels, ax=ax, cmap=cmap)
    annotate_heatmap(im, valfmt="{x:.2f}")
    fig.tight_layout()
    # Save before showing: with some backends plt.show() releases the canvas,
    # so a savefig afterwards can write a blank image.
    if savepath is not None:
        fig.savefig(savepath)
    if showfig:
        plt.show()

In [None]:
# Load the final GMLVQ model for the 'Multi' feature set.
featureset_name = 'Multi'
input_variables_to_print, FS_name, var_description, cat_feature_indices = get_feature_set(featureset_name)
folderpath = '/Users/lirui/Downloads/Cohort_Dementia_Prediction/Classification/Nested_CV_Results/Complete_Case/all_pooled/GMLVQ/Multi/GMLVQ_Multi.pkl'
# NOTE: pickle.load can execute arbitrary code -- only open result files you produced.
with open(folderpath, 'rb') as results_file:
    full_results = pkl.load(results_file)
final_model = full_results['final_model']

In [None]:
# Annotated heatmap of the full GMLVQ relevance matrix Lambda.
fig, ax = plt.subplots(1, 1, figsize=(10, 8))
# Colour limits are fixed by hand; values outside [-0.3, 0.35] saturate.
# center=0 keeps the diverging 'coolwarm' map neutral at zero, so red/blue follow
# the sign of each entry (without it the neutral colour sits at the midpoint of
# [vmin, vmax], i.e. 0.025).
sns.heatmap(
    final_model.lambda_,
    vmin=-0.3,
    vmax=0.35,
    center=0,
    annot=True,
    fmt='.2f',
    linewidths=2,
    cmap='coolwarm',
    ax=ax,
)
ax.set_xticklabels(input_variables_to_print, fontsize=12, rotation=36, ha='right', rotation_mode='anchor')
ax.set_yticklabels(input_variables_to_print, fontsize=12, rotation=0)
ax.set_xlabel('Feature', fontsize=14)
ax.set_title(r'Relevance Matrix $\Lambda$ with Multimodal Feature Set 1', fontsize=14)
fig.tight_layout()
plt.show()

In [None]:
# Save next to the GMLVQ model this matrix was drawn from
# (Complete_Case/all_pooled/GMLVQ/Multi), like the diag-lambda figure below.
# The previous target (.../GMLVQ/external_harmo_only/...) appears to belong to a
# different experiment layout and would overwrite (or fail to find) that folder.
fig.savefig('/Users/lirui/Downloads/Cohort_Dementia_Prediction/Classification/Nested_CV_Results/Complete_Case/all_pooled/GMLVQ/Multi/full_relevance_matrix.png')

In [None]:
# Horizontal bar chart of the diagonal of the GMLVQ relevance matrix, one bar per
# feature, with feature names on the right-hand axis.
feature_dim = final_model.lambda_.shape[0]
relevances = np.diag(final_model.lambda_)
fig, ax = plt.subplots(figsize=(2.3, 9))

# Left axes: carries only the x-axis label and vertical grid lines; its y-axis is hidden.
ax.set_xlabel("Weight", fontsize=14)
ax.set_yticks([])
ax.grid(True, axis='x')

# twinx() adds a second axes that shares the x-axis and puts its y-axis on the
# right; the bars and the feature labels are drawn there.
ax1 = ax.twinx()
positions = range(feature_dim)
ax1.barh(positions, relevances, align='center')
ax1.set_yticks(positions)
ax1.set_yticklabels(input_variables_to_print)
ax1.invert_yaxis()  # first feature at the top
ax.invert_xaxis()   # the x-axis is shared, so positive bars extend to the left

plt.show()


In [None]:
vertical_diag_lambda_png = '/Users/lirui/Downloads/Cohort_Dementia_Prediction/Classification/Nested_CV_Results/Complete_Case/all_pooled/GMLVQ/Multi/vertical_diag_lambda.png'
fig.savefig(vertical_diag_lambda_png)

In [None]:
# Load the final GRLVQ model and plot its relevances horizontally.
grlvq_dir = '/Users/lirui/Downloads/Cohort_Dementia_Prediction/Classification/Nested_CV_Results/Complete_Case/all_pooled/GRLVQ/Multi'
# NOTE: pickle.load can execute arbitrary code -- only open result files you produced.
with open(grlvq_dir + '/GRLVQ_Multi.pkl', 'rb') as results_file:
    full_results = pkl.load(results_file)
final_model = full_results['final_model']

input_variables_to_print, FS_name, var_description, cat_feature_indices = get_feature_set('Multi')
plot_diag_lambda_mat(
    final_model.lambda_,
    input_variables_to_print,
    'GRLVQ',
    savepath=grlvq_dir + '/horizontal_diag_lambda.png',
    vertical_plot=False,
)