## Running Analyses

In [48]:
import tensorflow as tf
import numpy as np
import pandas as pd
import keras

from keras.models import load_model
from sklearn.preprocessing import StandardScaler
from sklearn.preprocessing import OneHotEncoder

import sys
# sys.path.append("/Users/Work/Developer/interpretDL/interprettensor")
root_logdir = "./tf_logs"

# To plot pretty figures
%matplotlib widget
import matplotlib
import matplotlib.pyplot as plt
plt.rcParams['axes.labelsize'] = 14
plt.rcParams['xtick.labelsize'] = 12
plt.rcParams['ytick.labelsize'] = 12

# to make this notebook's output stable across runs
def reset_graph(seed=42):
    tf.reset_default_graph()
    tf.set_random_seed(seed)
    np.random.seed(seed)
    
tf.__version__

'1.13.1'

In [75]:
def perform_analysis(model, analyzer, data, labels):
    analysis = analyzer.analyze(data)
    prediction = model.predict(data)
    
    df_anal = pd.DataFrame(analysis)
    
    return df_anal

def get_relevant_cols(df, thresh = 1e-2):

    all_above_thresh = (df < thresh).all(0) #Check if all values in columns satisfy the criteria
    max_above_thresh = (df.max() < thresh)
    quantile_above_thresh = (df.quantile(0.8) < thresh)

    criteria = quantile_above_thresh
    irrelevant_cols = df.columns[criteria] 
    irrelevant_cols
    
    relevant_features_only = all_lrp.drop(columns = irrelevant_cols)
    
    return relevant_features_only

In [132]:
filename = "AD_CN_TH.csv"
raw_data = pd.read_csv(filename, index_col= 0)
features = raw_data.drop(["labels"],axis=1,)
labels = raw_data["labels"]

In [135]:
test_idx = pd.read_csv("test_indices.csv", dtype=int, index_col=0)["0"]
test_samples = features.iloc[test_idx].index
train_samples = features.drop(test_samples, axis="index")
test_samples

Int64Index([12, 62, 65, 108, 118, 126, 133], dtype='int64')

## Setting up the model and data

In [136]:
# Selecting a DNN
model = load_model("best_dnn.h5")

hot_encoder = OneHotEncoder(categories="auto", sparse=False)
hot_encoder.fit(labels.values.reshape(-1,1))
sample_labels = hot_encoder.transform([[label] for label in labels])
print("Categories:", hot_encoder.categories_)

ZScaler = StandardScaler()
ZScaler.fit(train_samples)
samples = ZScaler.transform(features)

Categories: [array(['AD', 'CN'], dtype=object)]


In [137]:
predictions = model.predict(samples)
preds = np.array([np.argmax(x) for x in predictions])
true_labels = np.array([np.argmax(x) for x in sample_labels])

correct = preds == true_labels
AD_Sample = true_labels == 0

correct_preds = preds[correct]
correct_preds.shape

print("SANITY CHECK")
loss_and_metrics = model.evaluate(samples[correct], sample_labels[correct])
print("Scores on test set: loss={:0.3f} accuracy={:.4f}".format(*loss_and_metrics))

SANITY CHECK
Scores on test set: loss=0.246 accuracy=1.0000


## Starting LRP

In [138]:
import innvestigate
import innvestigate.utils as iutils


# Stripping the softmax activation from the model
model_wo_sm = iutils.keras.graph.model_wo_softmax(model)

# Creating an analyzer
lrp_E = innvestigate.analyzer.relevance_based.relevance_analyzer.LRPEpsilon(model=model_wo_sm)

lrp = innvestigate.analyzer.relevance_based.relevance_analyzer.LRPAlpha2Beta1(model=model_wo_sm)

# Getting all the samples that can be correctly predicted
test_idx = correct
all_samples = samples[test_idx] 
all_labels = sample_labels[test_idx]

# perform_analysis(nn,gradient_analyzer,flowers,types)
all_lrp = perform_analysis(model,lrp, all_samples, all_labels)

all_lrp_E = perform_analysis(model,lrp_E, all_samples, all_labels)

In [139]:
lrp_results = all_lrp
# lrp_E_results = all_lrp_E
population = lrp_results.mean()
population.plot()
# plt.show()

<matplotlib.axes._subplots.Axes3DSubplot at 0x14844a208>

In [140]:
sorted_features = population.sort_values(ascending=False)
best_features = sorted_features[:6]

features.columns[best_features.index]

Index(['G_oc-temp_med-Parahip_TH_rh', 'S_precentral-sup-part_TH_lh',
       'G_oc-temp_med-Parahip_TH_lh', 'G_front_inf-Orbital_TH_lh',
       'G_cingul-Post-dorsal_TH_rh', 'G_orbital_TH_rh'],
      dtype='object')

In [147]:
# fig = plt.figure()
# population = all_lrp.mean()
# sorted_features = population.sort_values(ascending=False)
# sorted_features.plot(kind="bar", figsize=[12,10])
# plt.xticks(rotation=65, fontsize="small")
desc = lrp_results.describe()
desc

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,138,139,140,141,142,143,144,145,146,147
count,76.0,76.0,76.0,76.0,76.0,76.0,76.0,76.0,76.0,76.0,...,76.0,76.0,76.0,76.0,76.0,76.0,76.0,76.0,76.0,76.0
mean,0.042596,-0.00392,-0.003798,0.034732,0.01903,-0.000935,0.005752,0.000653,0.004047,0.0021,...,0.002124,0.005619,0.010609,0.002386,0.011664,0.010957,0.004685,-0.000682,-0.017008,0.00021
std,0.258553,0.223624,0.461686,0.17112,0.092938,0.031058,0.024395,0.024588,0.027262,0.069183,...,0.045212,0.028789,0.063435,0.013141,0.082244,0.045732,0.025372,0.024586,0.135524,0.009963
min,-0.618569,-1.807765,-3.446275,-0.354645,-0.346158,-0.13487,-0.072854,-0.200798,-0.036922,-0.436161,...,-0.10311,-0.028655,-0.037577,-0.016503,-0.135802,-0.079952,-0.028642,-0.197539,-1.137114,-0.064132
25%,-0.010943,-0.002689,-0.025059,-0.007361,-0.005988,-0.001873,-0.003449,-0.0005,-0.002898,-0.002143,...,-0.00679,-0.000364,-0.005512,-0.001171,-0.001448,-0.000221,-0.00139,-0.000161,-0.000899,-0.000863
50%,0.002077,0.01517,0.003013,0.003445,0.009111,0.000328,0.001001,0.000816,0.000916,0.004374,...,0.000347,0.000552,-9.6e-05,-8.2e-05,0.000905,0.001832,0.000597,0.00191,7.1e-05,0.000591
75%,0.015865,0.041253,0.068085,0.03676,0.042037,0.003882,0.011595,0.004271,0.00555,0.012642,...,0.004167,0.003234,0.00634,0.003704,0.007487,0.009538,0.003157,0.004971,0.001288,0.004082
max,1.828964,0.416703,1.396092,1.381291,0.347212,0.153877,0.156249,0.036516,0.211771,0.322382,...,0.251897,0.209,0.406834,0.092465,0.624229,0.336468,0.202937,0.023239,0.097216,0.013877


In [148]:
def get_relevant_cols(df, thresh = 1e-2):

    all_above_thresh = (df < thresh).all(0) #Check if all values in columns satisfy the criteria
    max_above_thresh = (df.max() < thresh)
    quantile_above_thresh = (df.quantile(0.8) < thresh)

    criteria = quantile_above_thresh
    irrelevant_cols = df.columns[criteria] 
    irrelevant_cols
    
    relevant_features_only = all_lrp.drop(columns = irrelevant_cols)
    
    return relevant_features_only

In [149]:
relevant_features_only = get_relevant_cols(all_lrp)
relevant_features_only.describe()

Unnamed: 0,0,1,2,3,4,6,9,10,12,13,...,121,123,125,126,127,132,134,136,137,143
count,76.0,76.0,76.0,76.0,76.0,76.0,76.0,76.0,76.0,76.0,...,76.0,76.0,76.0,76.0,76.0,76.0,76.0,76.0,76.0,76.0
mean,0.042596,-0.00392,-0.003798,0.034732,0.01903,0.005752,0.0021,-0.001536,0.049915,0.013953,...,0.017248,0.007985,0.046566,0.008115,-0.103683,-0.009122,-0.022484,0.012433,0.029794,0.010957
std,0.258553,0.223624,0.461686,0.17112,0.092938,0.024395,0.069183,0.094439,0.598275,0.545583,...,0.157852,0.029446,0.527406,0.145763,1.015396,0.064157,0.304771,0.16129,0.117416,0.045732
min,-0.618569,-1.807765,-3.446275,-0.354645,-0.346158,-0.072854,-0.436161,-0.595094,-1.144035,-3.365649,...,-0.545281,-0.175718,-1.878063,-1.138978,-7.565024,-0.396798,-2.278876,-0.347809,-0.190511,-0.079952
25%,-0.010943,-0.002689,-0.025059,-0.007361,-0.005988,-0.003449,-0.002143,-0.017833,-0.070138,-0.020153,...,0.000292,-0.00026,-0.020477,-0.006386,-0.027307,-0.009577,-0.031321,-0.022712,-0.001606,-0.000221
50%,0.002077,0.01517,0.003013,0.003445,0.009111,0.001001,0.004374,-0.000118,0.011632,0.01376,...,0.005905,0.003911,0.005562,0.009208,0.005444,0.004811,0.001989,0.002938,0.00673,0.001832
75%,0.015865,0.041253,0.068085,0.03676,0.042037,0.011595,0.012642,0.027627,0.086324,0.093517,...,0.01506,0.017853,0.045913,0.038671,0.04407,0.015188,0.060845,0.015839,0.033802,0.009538
max,1.828964,0.416703,1.396092,1.381291,0.347212,0.156249,0.322382,0.304244,4.814804,2.682676,...,1.242268,0.11888,2.91953,0.271285,2.914621,0.067084,0.815949,1.008004,0.912922,0.336468


In [150]:
from sklearn.decomposition import PCA

pca = PCA(n_components=2, random_state=42)
pos_only = all_lrp.copy()
pos_only[pos_only < 0] = 0
pca.fit(pos_only)

PCA(copy=True, iterated_power='auto', n_components=2, random_state=42,
  svd_solver='auto', tol=0.0, whiten=False)

In [151]:
print("Variance: ", pca.explained_variance_ratio_)
# pca.components_

Variance:  [0.41100113 0.27924447]


In [152]:
X = pca.transform(pos_only)
df = pd.DataFrame(X, columns=["PC1","PC2"])
_labels = np.array([np.argmax(x) for x in all_labels])


# plt.figure(figsize=(20,10))
df.plot.scatter(x="PC1", y="PC2", s= 30, c=_labels, colormap='winter',figsize=(10,8))
# plt.legend(["AD", "CN"])
print("Categories:", hot_encoder.categories_)

FigureCanvasNbAgg()

Categories: [array(['AD', 'CN'], dtype=object)]


# Plot PCA for 3 components

In [176]:
from mpl_toolkits.mplot3d import Axes3D

def plot_2d(X,labels, name="1"):
#     plt.close()
    fig = plt.figure(figsize=(12,10))
    plt.scatter(x=X[:,0], y=X[:,1], s= 30, c=_labels)
    plt.colorbar()

def plot_3d(X, labels):
    fig = plt.figure(figsize=(10,8))
    ax = fig.add_subplot(111, projection='3d')
    ax.scatter(X[:,0], X[:,1], X[:,2], c=labels, s=40)
    ax.set(xlabel="x", ylabel="y", zlabel="z")

In [95]:
pca3 = PCA(n_components=3)
pca3.fit(pos_only)
print("Variance: ", pca3.explained_variance_ratio_)

pc_3d = pca3.transform(pos_only)

Variance:  [0.55622541 0.28497895 0.05635775]


In [97]:
# %matplotlib widget
fig = plt.figure(figsize=(10,8))
ax = fig.add_subplot(111, projection='3d')
ax.scatter(pc_3d[:,0], pc_3d[:,1], pc_3d[:,2], c=_labels, s=40)
ax.set(xlabel="PC1", ylabel="PC2", zlabel="PC3")

FigureCanvasNbAgg()

[Text(0.5, 0, 'PC3'), Text(0.5, 0, 'PC2'), Text(0.5, 0, 'PC1')]

## Try tSNE

In [166]:
from sklearn.manifold import TSNE
# pca_reduced = PCA(n_components=50)
tSNE = TSNE(n_components=2, init="pca", random_state=42)
tSNE_relevance = tSNE.fit_transform(pos_only)

In [167]:
X = tSNE_relevance
plot_2d(X, labels)
# plot_3d(X,_labels)

    
tSNE_relevance.shape

(76, 2)

## Trying UMAP

In [168]:
import umap
# %matplotlib widget



### Figures not being reused

In [171]:
pos_only = get_relevant_cols(all_lrp)
# pos_only = lrp_results.copy()
plt.close()
# pos_only[pos_only < 0] = 0
reducer = umap.UMAP(random_state=42,
                    n_components = 2,
                    n_neighbors=3,
                    min_dist=0)
embedding = reducer.fit_transform(pos_only)
plot_2d(embedding, _labels, name="Alpha")

# fig = plt.figure(figsize=(10,8))
# ax = fig.add_subplot(111, projection='3d')
# ax.scatter(embedding[:,0], embedding[:,1], embedding[:,2], c=_labels, s=40)


plt.show()

FigureCanvasNbAgg()

In [177]:
pos_only = get_relevant_cols(all_lrp_E)
# pos_only = all_lrp_E.copy()
pos_only[pos_only < 0] = 0

reducer = umap.UMAP(random_state=42,
                    n_components = 3,
                    n_neighbors=3,
                    min_dist=0)
embedding = reducer.fit_transform(pos_only)

plot_3d(embedding, _labels)

plot_2d(embedding, _labels, name="Epsilon")

FigureCanvasNbAgg()

FigureCanvasNbAgg()