# POC Explained IA using pytorch and captum

## Load libraries

In [1]:
# Our ML things
import torch
import torch.nn as nn
import torch.nn.functional as F

from captum.attr import IntegratedGradients # Most popular atribution methode

# Visualization
import plotly.graph_objects as go

# Utils
import pandas as pd
import numpy as np
import multiprocessing
import random

import datetime
import time

from sklearn.metrics import classification_report
from sklearn.preprocessing import OneHotEncoder, StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, f1_score, precision_score, recall_score, accuracy_score, auc
from torch.utils.data import DataLoader
import json

# Custom
from tool_box.deviceHandler import deviceHandler
from tool_box.model import simpleDenseNN
from tool_box.utilities import createDataLoader, listSplitter, type_converter, secondsConverter, createLog

run_timestamp = datetime.datetime.now()

## Set up processing device

In [2]:
device_handler = deviceHandler()

GPU isn't available, fallback to CPU


## Process data

### Load data

In [3]:
titanic_dataset = pd.read_csv('titanic_dataset.csv')

### Simple data processing

In [4]:
# Get dummies for categorical variables
sex_dummies = pd.get_dummies(titanic_dataset.Sex, prefix='Sex')
embarked_dummies = pd.get_dummies(titanic_dataset.Embarked, prefix='Embarked')
pclass_dummies = pd.get_dummies(titanic_dataset.Pclass.astype(str), prefix='Pclass')

# Fill NaNs
titanic_dataset.Age = titanic_dataset.Age.fillna(round(titanic_dataset.Age.mean(), 1))
titanic_dataset.Fare = titanic_dataset.Fare.fillna(round(titanic_dataset.Fare.mean(), 1))

### Create features and target datasets

In [5]:
features_df = pd.concat([titanic_dataset[['Age', 'SibSp', 'Parch', 'Fare']], pclass_dummies, sex_dummies, embarked_dummies], axis=1)
target_df = titanic_dataset.Survived

### Scale data

In [6]:
scaler = StandardScaler()
s_features_df = scaler.fit_transform(features_df)

### Transform into tuple list of pytorch tensors

In [7]:
features_tensor = device_handler.data_to_tensor(s_features_df)
target_tensor = device_handler.data_to_tensor(target_df).reshape(-1, 1)

features_tensor = type_converter(features_tensor, torch.float32)
target_tensor = type_converter(target_tensor, torch.float32)

tuple_lst_data = list(zip(features_tensor, target_tensor))

### Split data into test and training

In [8]:
lst_splitter = listSplitter(0.3, shuffle = True)
test_data, train_data = lst_splitter.split(tuple_lst_data)

### Create Dataloader

In [9]:
dataloader_gen = createDataLoader()

test_loader = dataloader_gen.create(test_data, batch_size=50)
train_loader = dataloader_gen.create(train_data, batch_size=50, shuffle = True)

## Model

### Define model parameters

In [10]:
simple_model = simpleDenseNN(features_tensor.shape[1], features_tensor.shape[1]*2+1, 1)

opt = torch.optim.Adam(simple_model.parameters(), 1e-02)

### Run model

In [11]:
batch_cum = 0 
epoch_amt = 100
start_time_VAL = time.time()
run_results = []
model_DICT = {}

for epoch in range(epoch_amt):
  
  train_loss = []
  train_acc = []
  train_f1 = []
  train_prec = []
  train_rec = []
  train_auc = []
  
  for _i, batch in enumerate(train_loader):
    
    preds = simple_model.training_step(batch)
    
    train_loss.append(preds['loss'])
    train_acc.append(preds['acc'])
    train_f1.append(preds['f1'])
    train_prec.append(preds['prec'])
    train_rec.append(preds['rec'])
    
    preds['loss'].backward()
    opt.step()
    opt.zero_grad()
    
  validation_result = [simple_model.testing_step(batch) for batch in test_loader]
  
  epoch_train_loss = np.mean([val.item() for val in train_loss])
  epoch_train_acc = np.mean([val.item() for val in train_acc])
  epoch_train_f1 = np.mean([val.item() for val in train_f1])
  epoch_train_prec = np.mean([val.item() for val in train_prec])
  epoch_train_rec = np.mean([val.item() for val in train_rec])
  
  epoch_test_loss = np.mean([dic['loss'].item() for dic in validation_result])
  epoch_test_acc = np.mean([dic['acc'].item() for dic in validation_result])
  epoch_test_f1 = np.mean([dic['f1'].item() for dic in validation_result])
  epoch_test_prec = np.mean([dic['prec'].item() for dic in validation_result])
  epoch_test_rec = np.mean([dic['rec'].item() for dic in validation_result])
  
  epoch_results = {'run_id': run_timestamp.strftime('%Y%m%d%H%M%S'), 'calendar_dt': run_timestamp.strftime('%Y-%m-%d'),\
                   'training_cases': len(train_data), 'testing_cases': len(test_data), 'epoch': epoch+1, 'total_epochs': epoch_amt,\
                   'training_run' : {'loss': epoch_train_loss, 'accuracy': epoch_train_acc, 'f1': epoch_train_f1,\
                                     'precision': epoch_train_prec, 'recall': epoch_train_rec},\
                   'test_run' : {'loss': epoch_test_loss, 'accuracy': epoch_test_acc, 'f1': epoch_test_f1,\
                                 'precision': epoch_test_prec, 'recall': epoch_test_rec}}
  
  run_results.append(epoch_results)
  
  print('{} Message time: Epoch {}/{} processed - {} time passed'\
    .format(datetime.datetime.now().strftime("%d/%m/%Y %H:%M:%S"), epoch+1, epoch_amt, secondsConverter(time.time()-start_time_VAL)))
  
  print('TRAINING\tLoss: {:.5f} | Accuracy: {:.2f}% | F1: {:.2f} | Precision: {:.2f} | Recall: {:.2f}'\
    .format(epoch_train_loss.item(), epoch_train_acc * 100, epoch_train_f1, epoch_train_prec, epoch_train_rec))
                 
  print('TESTING\t\tLoss: {:.5f} | Accuracy: {:.2f}% | F1: {:.2f} | Precision: {:.2f} | Recall: {:.2f}'\
    .format(epoch_test_loss.item(), epoch_test_acc * 100, epoch_test_f1, epoch_test_prec, epoch_test_rec) + '\n')

AINING	Loss: 0.38165 | Accuracy: 86.12% | F1: 0.80 | Precision: 0.85 | Recall: 0.76
TESTING	Loss: 0.31680 | Accuracy: 89.11% | F1: 0.85 | Precision: 0.91 | Recall: 0.80

15/03/2021 01:23:26 Message time: Epoch 22/100 processed - 0:01:28 time passed
TRAINING	Loss: 0.37056 | Accuracy: 86.21% | F1: 0.80 | Precision: 0.86 | Recall: 0.76
TESTING	Loss: 0.30960 | Accuracy: 90.11% | F1: 0.86 | Precision: 0.92 | Recall: 0.81

15/03/2021 01:23:30 Message time: Epoch 23/100 processed - 0:01:32 time passed
TRAINING	Loss: 0.37761 | Accuracy: 84.95% | F1: 0.79 | Precision: 0.84 | Recall: 0.75
TESTING	Loss: 0.30998 | Accuracy: 89.36% | F1: 0.85 | Precision: 0.91 | Recall: 0.80

15/03/2021 01:23:34 Message time: Epoch 24/100 processed - 0:01:36 time passed
TRAINING	Loss: 0.37938 | Accuracy: 85.18% | F1: 0.78 | Precision: 0.86 | Recall: 0.74
TESTING	Loss: 0.29931 | Accuracy: 89.86% | F1: 0.86 | Precision: 0.92 | Recall: 0.81

15/03/2021 01:23:38 Message time: Epoch 25/100 processed - 0:01:40 time passe

### Save log file

In [14]:
create_log = createLog(run_results, json = True, file_name = run_timestamp.strftime('%Y%m%d%H%M%S'))
create_log.write()

## Results

### Process results

In [18]:
results_df = pd.DataFrame(run_results)

training_org_cols = [col for col in list(results_df['training_run'][0].keys())]
test_org_cols = [col for col in list(results_df['test_run'][0].keys())]

training_cols = ['train_' + col for col in list(results_df['training_run'][0].keys())]
test_cols = ['test_' + col for col in list(results_df['test_run'][0].keys())]


train_cols_dict = dict()
test_cols_dict = dict()

for _i in range(len(training_cols)):
  train_cols_dict[training_org_cols[_i]] = training_cols[_i]
  
for _i in range(len(test_cols)):
  test_cols_dict[test_org_cols[_i]] = test_cols[_i]
  

training_df = pd.DataFrame(list(results_df['training_run'])).rename(columns=train_cols_dict)
testing_df = pd.DataFrame(list(results_df['test_run'])).rename(columns=test_cols_dict)

formatted_results_df = pd.concat([results_df.drop(['training_run', 'test_run'], axis=1), training_df, testing_df], axis=1)

### Graph loss

In [19]:
fig = go.Figure()
fig.add_trace(go.Scatter(x=formatted_results_df['epoch'], y=formatted_results_df['train_loss'], name= 'Training', mode='lines+markers', line_color='blue'))
fig.add_trace(go.Scatter(x=formatted_results_df['epoch'], y=formatted_results_df['test_loss'], name= 'Testing', mode='lines+markers', line_color='red'))
fig.update_layout(
    title="Loss progression",
    xaxis_title="Epoch",
    yaxis_title="Loss",
    font=dict(size=15),
    width=1300, 
    height=600)
fig.show()

### Graph evaluation metrics

In [20]:
fig = go.Figure()

fig.add_trace(go.Scatter(x=formatted_results_df['epoch'], y=formatted_results_df['train_accuracy'], name= 'Training Accuracy',\
                         opacity=0.3, mode='lines+markers', line_color='blue', line_width= 2))
fig.add_trace(go.Scatter(x=formatted_results_df['epoch'], y=formatted_results_df['train_f1'], name= 'Training F1',\
                         opacity=0.3, mode='lines+markers', line_color='red', line_width= 2))
fig.add_trace(go.Scatter(x=formatted_results_df['epoch'], y=formatted_results_df['train_precision'], name= 'Training Precision',\
                         opacity=0.3, mode='lines+markers', line_color='gold', line_width= 2))
fig.add_trace(go.Scatter(x=formatted_results_df['epoch'], y=formatted_results_df['train_recall'], name= 'Training Recall',\
                         opacity=0.3, mode='lines+markers', line_color='green', line_width= 2))

fig.add_trace(go.Scatter(x=formatted_results_df['epoch'], y=formatted_results_df['test_accuracy'], name= 'Testing Accuracy',\
                         mode='lines+markers', line_color='blue', line_width= 3))
fig.add_trace(go.Scatter(x=formatted_results_df['epoch'], y=formatted_results_df['test_f1'], name= 'Testing F1',\
                         mode='lines+markers', line_color='red', line_width= 3))
fig.add_trace(go.Scatter(x=formatted_results_df['epoch'], y=formatted_results_df['test_precision'], name= 'Testing Precision',\
                         mode='lines+markers', line_color='gold', line_width= 3))
fig.add_trace(go.Scatter(x=formatted_results_df['epoch'], y=formatted_results_df['test_recall'], name= 'Testing Recall',\
                         mode='lines+markers', line_color='green', line_width= 3))

fig.update_layout(
    title="Evaluation metrics",
    xaxis_title="Epoch",
    font=dict(size=15),
    width=1300, 
    height=600)
fig.show()

## Explainable IA

## Single prediction interpretation

### Calculate single case attribution

In [30]:
ig = IntegratedGradients(simple_model)

attributions, approximation_error = ig.attribute(batch[0][0:1], target = 0, return_convergence_delta = True)

### Attribution

In [31]:
features_names = features_df.columns.tolist()

attributions_lst = list(zip(features_names, attributions.numpy().tolist()[0]))

for feature in attributions_lst:
    print('{:<10}:\t{:>6.3f}'.format(feature[0], feature[1]))

Age       :	 0.037
SibSp     :	-0.419
Parch     :	 0.048
Fare      :	 0.008
Pclass_1  :	-0.013
Pclass_2  :	 0.012
Pclass_3  :	-0.014
Sex_female:	-0.026
Sex_male  :	-0.038
Embarked_C:	 0.002
Embarked_Q:	-0.003
Embarked_S:	-0.001


In [33]:
fig = go.Figure([go.Bar(x = features_names, y = attributions.numpy().tolist()[0])])

fig.update_layout(
    title="Attribution",
    xaxis_title="Features",
    font=dict(size=15),
    width=1300, 
    height=600)

fig.show()