In [None]:
import numpy as np
from sklearn.datasets import make_classification
import torch.nn as nn
from skorch import NeuralNetClassifier

In [None]:
import matplotlib.pyplot as plt
import glob
from imageio import imread
# Ignore warnings
import warnings
warnings.filterwarnings("ignore")
import os
import numpy as np
from tqdm import tqdm
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report
import torch
from torchvision import transforms
from util import calculate_weights, train_validation_test_split, get_statistics
from dataset import DatasetGenerator
from custom_transforms import ShuffleChannel
from IPython.core.debugger import Tracer
from torch.utils.data import DataLoader, Dataset
from sklearn.pipeline import Pipeline, FeatureUnion
# Compare Algorithms
from sklearn.model_selection import train_test_split
from imblearn.over_sampling import RandomOverSampler
from models import PretrainedModel, resnet18
from sklearn.metrics import f1_score

In [None]:
from skorch.callbacks import LRScheduler
import torch.optim as optim
from skorch.helper import predefined_split
from skorch.callbacks import Checkpoint, TrainEndCheckpoint
from collections import Counter

In [None]:
import iflai

#### Set all random seeds to a fixed value so the results are more reproducible

In [None]:
seed_value = 42

# NOTE: PYTHONHASHSEED is read at interpreter start-up, so setting it here only
# affects Python processes started afterwards, not the running kernel.
os.environ['PYTHONHASHSEED'] = str(seed_value)
import random
random.seed(seed_value)

np.random.seed(seed_value)
torch.manual_seed(seed_value)
# Also seed every CUDA device (silently ignored on CPU-only machines) and make
# cuDNN choose deterministic kernels, so GPU runs are reproducible too.
torch.cuda.manual_seed_all(seed_value)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

#### Define all necessary parameters

In [None]:
dataset_name = "wbc"
selected_channels = np.arange(12)
# Portable path: the previous literal mixed "\" and "/" separators and relied on the
# invalid escape "\." staying a backslash (Deprecation/SyntaxWarning on recent Pythons),
# which only resolved as intended on Windows.
path_to_data = os.path.join("..", "..", "data", "WBC")
model_dir = "models_remote"
scaling_factor = 4095.  # passed to DatasetGenerator; presumably the 12-bit max intensity (2**12 - 1) — confirm
reshape_size = 64
num_channels = len(selected_channels)
# Augmentations used only for training images; the test pipeline starts empty.
# A Normalize step is appended to both lists once the training statistics are known.
train_transform = [
    transforms.RandomVerticalFlip(),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(45),
]
test_transform = []
# Readable channel names ("Ch0", "Ch1", ...), aligned with selected_channels.
channels = np.asarray([f"Ch{idx}" for idx in selected_channels])

In [None]:
# DataLoader settings and the torch device used for evaluation/interpretation.
batch_size, num_workers = 64, 2
device = "cpu"

#### Load data

In [None]:
%time

metadata = iflai.metadata_generator(path_to_data)

In [None]:
# Drop samples without a usable class, then renumber rows 0..n-1.
# Labels may be stored as bytes (class names are later decoded with .decode("utf-8")),
# and comparing bytes against the str "unknown" alone would silently keep every row,
# so match both spellings.
is_known = ~metadata["label"].isin(["unknown", b"unknown"])
metadata = metadata.loc[is_known, :].reset_index(drop=True)

#### Split data

In [None]:
train_index, validation_index, test_index = train_validation_test_split(metadata.index, metadata["label"], random_state=seed_value)

In [None]:
label_map = dict(zip(sorted(set(metadata.loc[train_index, "label"])), np.arange(len(set(metadata.loc[train_index, "label"])))))

In [None]:
label_map

#### Oversample and compute class weights for imbalanced data (skip if not required)

In [None]:
y_train = [label_map.get(metadata.loc[i, "label"]) for i in train_index]
weights = calculate_weights(y_train)
oversample = RandomOverSampler(random_state=seed_value, sampling_strategy='all')

In [None]:
Counter(y_train)

In [None]:
train_index, y_train = oversample.fit_resample(np.asarray(train_index).reshape(-1, 1), y_train)
train_index = train_index.T[0]

In [None]:
Counter(y_train)

#### Calculate statistics of the training set and normalize the data

In [None]:
train_dataset = DatasetGenerator(metadata=metadata.loc[train_index,:], label_map=label_map, selected_channels=selected_channels, scaling_factor=scaling_factor, transform=transforms.Compose(train_transform))

In [None]:
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

In [None]:
plt.imshow(train_dataset[0][0][0])

In [None]:
statistics = get_statistics(train_loader, selected_channels=np.arange(12))

In [None]:
train_transform.append(transforms.Normalize(mean=statistics["mean"],
                         std=statistics["std"]))

In [None]:
test_transform.append(transforms.Normalize(mean=statistics["mean"],
                         std=statistics["std"]))

In [None]:
train_dataset = DatasetGenerator(metadata=metadata.loc[train_index,:],
                                 label_map=label_map,
                                 selected_channels=selected_channels,
                                 scaling_factor=scaling_factor, 
                                 transform= transforms.Compose(train_transform))

In [None]:
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

In [None]:
num_classes = len(label_map.keys())

In [None]:
validation_dataset = DatasetGenerator(metadata=metadata.loc[validation_index,:], label_map=label_map, selected_channels=np.arange(12), scaling_factor=scaling_factor, transform= transforms.Compose(test_transform))
test_dataset = DatasetGenerator(metadata=metadata.loc[test_index,:], label_map=label_map, selected_channels=np.arange(12), scaling_factor=scaling_factor, transform= transforms.Compose(test_transform))

In [None]:
validation_loader = DataLoader(validation_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

#### Set all hyperparameters for the model

In [None]:
lrscheduler = LRScheduler(
    policy='StepLR', step_size=7, gamma=0.5)

In [None]:
checkpoint = Checkpoint(
    f_params='wbc_net_all_.pth', monitor='valid_loss_best', dirname='models')
train_end_cp = TrainEndCheckpoint(f_params='final_wbc_net_all_.pth', dirname='models')

In [None]:
class_weights = torch.FloatTensor(weights).to(device)

### Initialize and train the model

In [None]:
num_classes

In [None]:
net = NeuralNetClassifier(
    PretrainedModel, 
    criterion=nn.CrossEntropyLoss,
    #criterion__weight=class_weights,
    lr=0.001,
    batch_size=64,
    max_epochs=10,
    module__output_features=num_classes,
    module__num_classes=num_classes,
    module__num_channels=num_channels, 
    optimizer=optim.SGD,
    optimizer__momentum=0.9,
    iterator_train__shuffle=False,
    iterator_train__num_workers=2,
    iterator_valid__shuffle=False,
    iterator_valid__num_workers=2,
    callbacks=[lrscheduler, checkpoint, train_end_cp],
    train_split=predefined_split(validation_dataset),
    #device='cuda' # comment to train on cpu
)

In [None]:
net.fit(train_dataset, y=None)

### Model Evaluation

In [None]:
#net.save_params(f_params='final_wbc_net_all_.pth')
model = PretrainedModel(num_classes, num_channels)
checkpoint = torch.load('models/wbc_net_all_.pth')
model.load_state_dict(checkpoint)
model = model.to(device)

In [None]:
correct = 0.
total = 0.
y_true = list()
y_pred = list()
y_true_proba = list()
with torch.no_grad():
    for data in test_loader:
        inputs, labels = data[0].to(device).float(), data[1].to(device).long()
        #Tracer()()
        outputs = model(inputs)
        pred = outputs.argmax(dim=1)
        true_proba = np.array([j[i] for (i,j) in zip(pred, outputs)])
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (labels.reshape(-1) == predicted).sum().item()
        for i in range(len(pred)):
            y_true.append(labels[i].item())
            y_pred.append(pred[i].item())
            y_true_proba.append(true_proba[i].item())

In [None]:
class_names_targets = [c.decode("utf-8") for c in label_map.keys()]

In [None]:
print(classification_report(y_true, y_pred, target_names=class_names_targets, digits=4))

### Model Interpretation

In [None]:
# where to save results
model_name = "wbc"
# Figures below are saved under results/<model_name> and results/resnet_all.
# Create both folders up front so savefig does not fail on a fresh checkout.
for results_subdir in (model_name, "resnet_all"):
    os.makedirs(os.path.join("results", results_subdir), exist_ok=True)

#### Pixel-Permutation Tests

In [None]:
from time import process_time

In [None]:
t1_start = process_time()
f1_score_original = f1_score(y_true, y_pred, average=None, labels=np.arange(num_classes))
min_mean_dif = 1.0
candidate = 0
shuffle_times = 5
df_all = pd.DataFrame([], columns=class_names_targets)
for c in range(num_channels):
    f1_score_diff_from_original_per_channel_per_shuffle = []
    transform = test_transform.copy()
    transform.append(ShuffleChannel(channels_to_shuffle=[c]))
    for s in range(shuffle_times):
        dataset = DatasetGenerator(metadata=metadata.loc[test_index,:],
                                   label_map=label_map,
                                   selected_channels=np.arange(12),
                                   scaling_factor=scaling_factor,
                                   transform=transforms.Compose(transform))
        dataloader = DataLoader(dataset,
                                batch_size=batch_size,
                                shuffle=False,
                                num_workers=num_workers)
        y_true = list()
        y_pred = list()
        with torch.no_grad():
            for data in dataloader:
                inputs, labels = data[0].to(device).float(), data[1].to(device).reshape(-1).long()
                outputs = model(inputs)
                pred = outputs.argmax(dim=1)
                for i in range(len(pred)):
                    y_true.append(labels[i].item())
                    y_pred.append(pred[i].item())
            f1_score_per_channel = f1_score(y_true, y_pred, average=None, labels=np.arange(num_classes))
            f1_score_diff_from_original_per_channel_per_shuffle.append(f1_score_original - f1_score_per_channel)
    mean_along_columns = np.mean(f1_score_diff_from_original_per_channel_per_shuffle, axis=0)
    mean_dif = np.mean(mean_along_columns)
    if mean_dif < min_mean_dif and mean_dif > 0 and not selected_channels[c]:
        min_mean_dif = mean_dif
        candidate = selected_channels[c]
    df_diff = pd.DataFrame(np.atleast_2d(f1_score_diff_from_original_per_channel_per_shuffle), columns=class_names_targets)
    df_mean_diff = pd.DataFrame(np.atleast_2d(mean_along_columns), columns=class_names_targets)
    df_all = pd.concat([df_all, df_mean_diff], ignore_index=True, sort=False)
    fig, ax = plt.subplots(figsize=(10, 5))
    ax = df_diff.boxplot()
    ax.set_xticklabels(class_names_targets, rotation=45)
    fig.savefig(os.path.join("results",model_name, "{}-shuffle_method-model-{}-channel-{}.png".format(dataset_name, str(model_name), str(selected_channels[c]))))
print("Candidate channel is {}".format(candidate))

In [None]:
df_all

In [None]:
plt.bar(channels[selected_channels], df_all.T.mean(), color='Grey')
plt.savefig(os.path.join("results",model_name, "{}-pixel-permutation-method-model-all-{}.png".format(dataset_name, str("resnet_all"))))

In [None]:
channel_ranking_pixel_permutation = pd.DataFrame(data={'channels': channels[np.asarray(selected_channels)], 'importance': df_all.T.mean().to_numpy()})

In [None]:
channel_ranking_pixel_permutation

### Evaluate the method with AOPC

In [None]:
def calculate_aopc(channel_ranking, method='', ascending=False, plot=True, perturb=False):
    """Compute the AOPC curve of a channel ranking on the test set.

    Channels are perturbed cumulatively in ranking order (step k perturbs the first
    k channels). At each step the per-sample drop of the top model output is taken,
    summed over steps 1..k, averaged over the test set and divided by k.

    Relies on notebook globals: ``model``, ``device``, ``channels``,
    ``selected_channels``, ``test_transform``, ``metadata``, ``test_index``,
    ``label_map``, ``scaling_factor``, ``batch_size``, ``num_workers``,
    ``y_true_proba`` (unperturbed top outputs from the evaluation cell),
    ``model_name`` and ``dataset_name``.

    NOTE(review): the perturbed value is the top output of the *perturbed* input,
    which may belong to a different class than the original prediction. Standard
    AOPC tracks the originally predicted class — confirm which is intended.

    Args:
        channel_ranking: DataFrame with columns ``"channels"`` (names present in
            ``channels``) and ``"importance"``.
        method: tag inserted into the saved figure's file name.
        ascending: sort order of ``importance``; ``False`` perturbs the most
            important channel first.
        plot: if True, plot the curve in a new figure and save it as SVG under
            ``results/<model_name>``.
        perturb: forwarded to ``ShuffleChannel``.

    Returns:
        1-D array of length ``len(channel_ranking) + 1``; element 0 is 0.0
        (nothing perturbed).
    """
    model.eval()  # make sure any Dropout/BatchNorm layers are in inference mode
    sorted_channels = channel_ranking.sort_values(by="importance", ascending=ascending)
    baseline_output = np.asarray(y_true_proba)
    channels_to_permute = []
    differences = []
    # (f^0 - f^k) for k = 1..L, one row of per-sample differences per step
    for channel_name in sorted_channels["channels"]:
        channels_to_permute.append(np.where(channels == channel_name)[0][0])
        transform = test_transform.copy()
        # Pass a copy: channels_to_permute keeps growing in later iterations.
        transform.append(ShuffleChannel(channels_to_shuffle=list(channels_to_permute), perturb=perturb))
        dataset_ = DatasetGenerator(metadata=metadata.loc[test_index,:],
                                    label_map=label_map,
                                    selected_channels=selected_channels,
                                    transform=transforms.Compose(transform),
                                    scaling_factor=scaling_factor)
        dataloader_ = DataLoader(dataset_,
                                 batch_size=batch_size,
                                 shuffle=False,
                                 num_workers=num_workers)
        permuted_output = []
        with torch.no_grad():
            for data in dataloader_:
                outputs = model(data[0].to(device).float())
                # Output at the argmax index == row maximum.
                permuted_output.extend(outputs.max(dim=1).values.tolist())
        differences.append(baseline_output - np.array(permuted_output))
    stacked_diff = np.stack(differences)
    # sum_{j<=k} (f^0 - f^j): running sum over perturbation steps
    diff_accumulated = np.cumsum(stacked_diff, axis=0)
    # mean over the test set
    diff_accumulated_mean = np.mean(diff_accumulated, axis=-1)
    # normalize step k by k
    diff_accumulated_mean_norm = diff_accumulated_mean / np.arange(1, len(diff_accumulated_mean) + 1)
    # prepend the unperturbed point (0, 0)
    diff_accumulated_mean_norm_started_from_0 = np.insert(diff_accumulated_mean_norm, 0, 0.0)
    if plot:
        # Dedicated figure so the curve never lands on an unrelated current figure.
        fig, ax = plt.subplots()
        x = np.arange(len(diff_accumulated_mean_norm_started_from_0))
        ax.set_xlabel("permutation steps")
        ax.set_ylabel("AOPC")
        ax.plot(x, diff_accumulated_mean_norm_started_from_0, color="red")
        fig.savefig(os.path.join("results",model_name, "{}-aopc-{}-{}.svg".format(dataset_name, method, str("resnet_all"))))
    return diff_accumulated_mean_norm_started_from_0

In [None]:
# AOPC of the pixel-permutation ranking, perturbing the most important channel first.
res_pixel_permutated_perturb_reverse = calculate_aopc(
    channel_ranking_pixel_permutation,
    method='pixel-permutation-perturb',
    ascending=False,
    perturb=True,
)

### Interpretation by methods from captum

In [None]:
from captum.attr import Occlusion, DeepLift, IntegratedGradients, LRP
from time import process_time

#### Occlusion

In [None]:
t1_start = process_time()
ablator = Occlusion(model)
dataset = DatasetGenerator(metadata=metadata.loc[test_index,:],
                           label_map=label_map,
                           selected_channels=np.arange(12),
                           transform=transforms.Compose(test_transform),
                           scaling_factor=scaling_factor)
dataloader = DataLoader(dataset,
                                batch_size=batch_size,
                                shuffle=False,
                                num_workers=num_workers)

heatmaps = torch.empty(0, dtype=torch.float32, device=device)
with torch.no_grad():
    for data in dataloader:
        inputs, labels = data[0].to(device).float(), data[1].to(device).reshape(-1).long()
        attr = ablator.attribute(inputs, target=labels, sliding_window_shapes=(1,3,3))
        heatmaps = torch.cat((heatmaps, torch.from_numpy(np.percentile(torch.flatten(attr, start_dim=-2).cpu().numpy(), q=50, axis=-1)).to(dev)))
heatmaps_mean = torch.mean(heatmaps, dim=0)
plt.bar(channels[np.asarray(only_channels)], heatmaps_mean.cpu(), color='grey')
plt.savefig(os.path.join("results",model_name, "{}-occl_method-model-50-percentile-{}.png".format(dataset_name, str("resnet_all"))))

t1_stop = process_time()
print("Elapsed time:", t1_stop, t1_start) 
   
print("Elapsed time during the whole program in seconds:",
                                         t1_stop-t1_start) 

In [None]:
channel_ranking_occlusion = pd.DataFrame(data={'channels': channels[np.asarray(only_channels)], 'importance': heatmaps_mean.cpu().numpy()})

In [None]:
channel_ranking_occlusion

In [None]:
res_pixel_ocll_perturb_reverse = calculate_aopc(channel_ranking_occlusion, method='pixel-occlusion-perturb-reverse', ascending=False, perturb=True)

#### DeepLift

In [None]:
t1_start = process_time()
ablator = DeepLift(model)

dataset = DatasetGenerator(metadata=metadata.loc[test_index,:],
                           label_map=label_map,
                           selected_channels=np.arange(12),
                           transform=transforms.Compose(test_transform),
                           scaling_factor=scaling_factor)
testloader = DataLoader(test_dataset,
                                batch_size=batch_size,
                                shuffle=False,
                                num_workers=num_workers)

heatmaps_deeplift = torch.empty(0, dtype=torch.float32, device=device)
with torch.no_grad():
    for data in testloader:
        inputs, labels = data[0].to(device).float(), data[1].to(device).reshape(-1).long()
        # baselines=torch.zeros(inputs.shape).to(dev)
        attr = ablator.attribute(inputs, target=labels)
        heatmaps_deeplift = torch.cat((heatmaps_deeplift,  torch.from_numpy(np.percentile(torch.flatten(attr, start_dim=-2).cpu().numpy(), q=50, axis=-1)).to(device)))
heatmaps_deeplift_mean = torch.mean(heatmaps_deeplift, dim=0)
plt.bar(channels, heatmaps_deeplift_mean.cpu(), color='grey')
plt.savefig(os.path.join("results", "resnet_all", "{}-deeplift_method-model-50-percentile-{}.png".format(dataset_name, str("resnet_all"))))

t1_stop = process_time()
print("Elapsed time:", t1_stop, t1_start) 
   
print("Elapsed time during the whole program in seconds:",
                                         t1_stop-t1_start) 

In [None]:
channel_ranking_deep_lift = pd.DataFrame(data={'channels': channels[selected_channels], 'importance': heatmaps_deeplift_mean.cpu().numpy()})

In [None]:
channel_ranking_deep_lift

In [None]:
res_deep_lift_perturb_reverse = calculate_aopc(channel_ranking_deep_lift, method='deep-lift-aopc', ascending=False, perturb=True)

#### Random

In [None]:
channel_ranking_random = pd.DataFrame(data={'channels': channels[selected_channels], 'importance': np.random.randint(12, size=12)})

In [None]:
channel_ranking_random

In [None]:
res_random_perturb_reverse = calculate_aopc(channel_ranking_random, method='random-perturb-aopc', ascending=False, perturb=True)

In [None]:
x = np.arange(len(res_random_perturb_reverse))

In [None]:
plt.rcParams.update({'font.size': 13})
got_label=False
plt.plot(x, res_deep_lift_perturb_reverse, label  = "Channel-wise DeepLift", color="orange")
plt.plot(x, res_pixel_ocll_perturb_reverse, label  = "Channel-wise Occlusion", color="green")
plt.plot(x, res_pixel_permutated_perturb_reverse, label  = "Pixel-Permutation", color="red")
plt.plot(x, res_random_perturb_reverse, label  = "Random Baseline", color="blue")
plt.xlabel('Perturbation steps')
plt.ylabel('AOPC')
plt.legend()
plt.savefig(os.path.join("results", "resnet_all", "{}-aopc-all-methods-{}.svg".format(dataset_name, str("resnet_all"))))
#plt.savefig(os.path.join("results", "resnet_all", "{}-aopc-all-methods-{}.png".format(dataset_name, str("resnet_all"))))

##### Compute the random channel ranking 100 times to estimate the lower and upper bounds

In [None]:
z=2.576

In [None]:
upper_border = mean + (z * (std / np.sqrt(len(random_rankings))))

In [None]:
lower_border = mean - (z * (std / np.sqrt(len(random_rankings))))

In [None]:
for i in range(100):
    channel_ranking_random = pd.DataFrame(data={'channels': channels[np.asarray(only_channels)], 'importance': np.random.randint(12, size=12)})
    random_rankings.append(calculate_aopc(channel_ranking_random, method='random-perturb-reverse', ascending=False, perturb=True, plot=False))

In [None]:
plt.rcParams.update({'font.size': 13})
# Draw every random-baseline curve faintly; only the first gets a label so the
# legend shows a single "Random Baseline" entry.
got_label = False
for ranking in random_rankings:
    extra = {} if got_label else {"label": "Random Baseline"}
    plt.plot(x, ranking, color="grey", linewidth=0.5, alpha=0.1, **extra)
    got_label = True
method_curves = [
    (res_deep_lift_perturb_reverse, "Channel-wise DeepLift", "orange"),
    (res_pixel_ocll_perturb_reverse, "Channel-wise Occlusion", "green"),
    (res_pixel_permutated_perturb_reverse, "Pixel-Permutation", "red"),
]
for curve, curve_label, curve_color in method_curves:
    plt.plot(x, curve, label=curve_label, color=curve_color)
plt.xlabel('Perturbation steps')
plt.ylabel('AOPC')
plt.legend()
plt.savefig(os.path.join("results", "resnet_all", "{}-aopc-all-methods-1010-{}.svg".format(dataset_name, str("resnet_all"))))
plt.savefig(os.path.join("results", "resnet_all", "{}-aopc-all-methods-1010-{}.png".format(dataset_name, str("resnet_all"))))