# TODO:

- remove unnecessary imports
- add shap
- compare FIMP rankings
- remove 'temporary' stuff

## Interpretability outputs

Here we compare a number of standard interpretability methods and use them to produce bespoke visualisations. We focus on the RF classifier, but this approach (other than the use of TreeInterpreter, which is specific to tree-based models) could be used with any classification algorithm.

To run this notebook, please first run 'fitting_classifiers.ipynb' to train the optimised SVC, LR and RF classifiers. This notebook loads those pre-trained models from their joblib files.

In [1]:
# --- Data selection ---------------------------------------------------------
# Directory handed to the data-loading helpers.
DATA_DIR = "../data/"
# Dataset switch: "CAP" triggers the full preprocessing branch in the loading cell.
DATA = "CAP"
# Name of the column used as the prediction target.
CLASS_LABEL = "pca_death_code"

# --- Train/test split -------------------------------------------------------
# Seed and held-out fraction passed to train_test_split.
RANDOM_STATE = 42
TEST_SIZE = 0.2

# --- Cross-validation -------------------------------------------------------
# Presumably the number of cross-validation folds; not used in the cells visible here.
CV = 5

In [7]:
import numpy as np
import pandas as pd
import pickle
from joblib import dump, load
from collections import Counter

import nltk
from nltk.stem import WordNetLemmatizer
nltk.download('stopwords', quiet=True)
nltk.download('wordnet', quiet=True)
from nltk.corpus import stopwords

from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.metrics import roc_curve, auc, roc_auc_score
from sklearn.calibration import CalibratedClassifierCV, calibration_curve
from sklearn.metrics import brier_score_loss, precision_score, recall_score, f1_score

from helper import (pd_print, 
                    accuracy,
                    lemmatize_text,
                    summarise_gridsearch_classifier,
                    calibrate_random_forest, 
                    plot_calibration_curve,
                    plot_calibration_curve_easy_hard,
                    plot_roc_curve,
                    compute_all_metrics)

from explainability import (get_rf_feature_importances,
                            wordcloud,
                            run_tree_interpreter,
                            get_ti_feature_contributions_for_instance_i,
                            get_ti_feature_contributions_average)

In [3]:
# Build the working frame `df`: run the CAP preprocessing steps when DATA == "CAP",
# otherwise reload the frame that a previous CAP run cached to 'temp_data.pickle'.
if DATA == "CAP":
    
    # NOTE(review): wildcard import; presumably the source of load_data,
    # concatenate_feature_columns, add_dates, add_reviewer_ids and
    # convert_dates_relative used below — confirm, and consider importing by name.
    from cap_helper import *
    
    print("Loading CAP prostate cancer data for preprocessing.")
    df = load_data(DATA_DIR)
    # Combine text from all feature columns into a single string column
    df = concatenate_feature_columns(df)
    # Link to dates of death:
    df = add_dates(df, DATA_DIR)
    # Link to reviewer Ids:
    df = add_reviewer_ids(df, DATA_DIR)
    # Convert all dates to be in units of months before/after death (Note: this regex is not foolproof)
    df = convert_dates_relative(df)  
    
    print("Preprocessing complete.")
    
    # Cache the preprocessed frame in the working directory so the else-branch
    # below can reload it without repeating the preprocessing.
    with open('temp_data.pickle', 'wb') as outfile:
        pickle.dump(df, outfile)
    
    
## Temporary:
# Shortcut for any DATA value other than "CAP": reload the cached frame.
# It fails if 'temp_data.pickle' has not been written by an earlier CAP run.
# NOTE(review): pickle.load can execute arbitrary code from the file, so only
# use it on a cache file produced by this notebook.
else:
    with open('temp_data.pickle', 'rb') as infile:
        df = pickle.load(infile)  

Loading CAP prostate cancer data for preprocessing.
Preprocessing complete.


In [4]:
# Text preparation: `stemmer` is a WordNetLemmatizer instance (the name is kept
# because other cells may refer to it).
stemmer = WordNetLemmatizer()

# Raw combined text and the target column.
X = df.combined
y = df[CLASS_LABEL]

# Run the project's lemmatisation helper over the combined text.
documents = lemmatize_text(X, stemmer)

# Hold out a TEST_SIZE fraction with a fixed seed so the split is repeatable.
X_train, X_test, y_train, y_test = train_test_split(
    documents,
    y,
    test_size=TEST_SIZE,
    random_state=RANDOM_STATE,
)

# Only the training portions are converted to NumPy arrays; X_test and y_test
# stay as returned by train_test_split.
X_train = np.array(X_train)
y_train = np.array(y_train)

In [5]:
# Load the pre-trained classifier from its joblib file. The filename suggests
# this is the grid-searched RF saved by 'fitting_classifiers.ipynb' — TODO confirm.
# NOTE(review): joblib.load unpickles the file, so only load model files from a
# trusted source.
clf = load('models/cap_rf_gridsearch.joblib')

In [8]:
# Feature-importance table for the loaded classifier, built by the project helper.
fimps = get_rf_feature_importances(clf)

# Sanity check: every feature name should occur exactly once in the table.
assert fimps["feature"].is_unique

# Show the first 15 rows as the cell's output.
fimps.head(15)

Unnamed: 0,feature,contribution,magnitude
342,bone scan,0.026547,0.026547
1293,spine,0.025653,0.025653
1482,widespread,0.023809,0.023809
528,docetaxel,0.022729,0.022729
264,androgen,0.02267,0.02267
1222,sclerotic,0.021713,0.021713
1212,scan,0.020712,0.020712
681,hormone,0.016235,0.016235
375,casodex,0.014661,0.014661
1192,rib,0.014009,0.014009
