In [None]:
# Specific to colab
# Gives access to the Drive
# from google.colab import drive
# drive.mount('/content/drive')

# import tensorflow as tf
# import sys, os

# # GPU status verification
# tf.test.gpu_device_name()

# # GPU type verification
# gpu_info = !nvidia-smi
# gpu_info = '\n'.join(gpu_info)
# if gpu_info.find('failed') >= 0:
#     print('Select the Runtime > "Change runtime type" menu to enable a GPU accelerator, ')
#     print('and then re-execute this cell.')
# else:
#     print(gpu_info)

# # Need to copy all the files on the local computer
# !cp -r "drive/MyDrive/data/main_dataset.zip" .
# !unzip main_dataset.zip

# sys.path.append('drive/MyDrive/colab_notebooks/')

!pip install tensorflow_addons
!pip install vit-keras

# Vision Transformer (ViT)

Inspired by this [Kaggle notebook](https://www.kaggle.com/raufmomin/vision-transformer-vit-fine-tuning).

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
# import cv2

import sys, os
import itertools
import pickle
from pathlib import Path
# import shutil
# import glob

# from sklearn.model_selection import train_test_split

import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam

from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, TerminateOnNaN, EarlyStopping

from tensorflow.keras.utils import Sequence
from collections import Counter
from sklearn.utils.class_weight import compute_class_weight

from vit_keras import vit, utils
from vit_keras import visualize

from keras.models import load_model
from sklearn.metrics import classification_report, confusion_matrix

# sys.path.append('../')
import leukopy_lib as leuko
from importlib import reload

reload(leuko)

# Setup

## Generate dataframes

In [None]:
# Dataset root. On Kaggle the dataset is mounted under ../input; on Colab (after
# unzipping main_dataset.zip) use Path('main_dataset/') instead.
path = Path('../input/main-dataset/main_dataset/')

# One dataframe per split, built by the project helper from each split folder
# (built in the order train, test, valid).
split_folders = {'train': 'training_set', 'test': 'testing_set', 'valid': 'validation_set'}
split_dfs = {split: leuko.generate_images_df(path / folder)
             for split, folder in split_folders.items()}
df_train, df_test, df_valid = split_dfs['train'], split_dfs['test'], split_dfs['valid']

df_train.head()

### Choose classes

In [None]:
# Keep the requested number of classes in every split (see leukopy_lib.choose_classes);
# the helper returns the class count followed by the three filtered dataframes.
n_classes, df_train, df_test, df_valid = leuko.choose_classes(
    df_train, df_test, df_valid,
    n_classes = 11,
)

In [None]:
# Number of classes returned by choose_classes (used when building the ViT model below).
n_classes

In [None]:
# Number of training images per class (class balance of the training split).
df_train["label"].value_counts()

## Image generator

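ViT splits each image into fixed-size patches that are handled like words in a sentence ([An Image is Worth 16x16 Words](https://arxiv.org/abs/2010.11929)). The ViT-B/32 model used below works on 32x32 patches, so the image size must be a multiple of 32.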

In [None]:
batch_size = 32
# vit.vit_b32 (ViT-B/32, used below) cuts images into 32x32 patches, so the
# input size must be a multiple of 32 (352 = 11 * 32), not merely of 16.
patch_size = 32
img_size = 352
assert img_size % patch_size == 0, 'img_size must be a multiple of the ViT patch size'
epochs = 100

In [None]:
# Random rotations and flips are applied to the training images only.
train_generator = ImageDataGenerator(rotation_range = 90,
                                     horizontal_flip = True,
                                     vertical_flip = True)
valid_generator = ImageDataGenerator()
test_generator = ImageDataGenerator()

# Options shared by the three iterators: image paths and labels read from the
# dataframes, images resized to img_size x img_size RGB, one-hot labels.
flow_kwargs = dict(directory = None,     # uses x_col
                   x_col = 'img_path',
                   y_col = 'label',
                   target_size = (img_size, img_size),
                   color_mode = 'rgb',
                   classes = None,       # uses y_col
                   class_mode = 'categorical',
                   batch_size = batch_size)

training_set = train_generator.flow_from_dataframe(df_train, shuffle = True, **flow_kwargs)
# Validation and test iterators are NOT shuffled: model.predict() outputs must stay
# in the same order as `.classes` for the confusion matrices and reports below.
validation_set = valid_generator.flow_from_dataframe(df_valid, shuffle = False, **flow_kwargs)
testing_set = test_generator.flow_from_dataframe(df_test, shuffle = False, **flow_kwargs)

# Labels/Index connection :
label_map = training_set.class_indices
print('Train :', training_set.class_indices)
print('Valid :', validation_set.class_indices)
print('Test  :', testing_set.class_indices)
# Every split must use the same label -> index mapping, otherwise evaluation is meaningless.
assert validation_set.class_indices == label_map == testing_set.class_indices, 'class indices differ between splits'

In [None]:
# Preview a few augmented training images on a 2x2 grid.
# Fetch ONE batch: every `training_set[0]` call reloads and re-augments a whole
# batch from disk, so indexing it once per image was wasteful.
fig, axes = plt.subplots(2, 2, figsize = (10, 10))
axes = axes.flatten()

preview_images, preview_labels = training_set[0]

for img, ax in zip(preview_images, axes):
    # The generator applies no rescaling, so values are raw pixel intensities.
    ax.imshow(img.astype('uint8'))
    ax.axis('off')

plt.tight_layout()
plt.show()

## ViT model setup

In [None]:
# Pretrained ViT-B/32 backbone (weights 'imagenet21k+imagenet2012') built without
# its classification head; the custom head is added in a later cell.
# NOTE(review): `activation` and `classes` presumably only affect the head, which
# is not built here (include_top=False) -- confirm against vit-keras.
backbone_config = dict(
    image_size = img_size,
    pretrained = True,
    include_top = False,
    pretrained_top = False,
    activation = 'softmax',
    classes = n_classes,
    weights = 'imagenet21k+imagenet2012',
)
vit_model = vit.vit_b32(**backbone_config)

# Reference usage from vit-keras (full ImageNet classifier, for comparison):
# image_size = 384
# classes = utils.get_imagenet_classes()
# model = vit.vit_b16(
#     image_size=image_size,
#     activation='sigmoid',
#     pretrained=True,
#     include_top=True,
#     pretrained_top=True
# )
# url = 'https://upload.wikimedia.org/wikipedia/commons/d/d7/Granny_smith_and_cross_section.jpg'
# image = utils.read(url, image_size)
# X = vit.preprocess_inputs(image).reshape(1, image_size, image_size, 3)
# y = model.predict(X)
# print(classes[y[0].argmax()]) # Granny smith

In [None]:
# Attention map of the pretrained backbone on the first test image.
# Index the (unshuffled) test iterator instead of calling `.next()`: `.next()`
# advances the iterator's internal state, so every re-run showed a different image.
test_images, test_labels = testing_set[0]
image = test_images[0]

attention_map = visualize.attention_map(model = vit_model, image = image)

# Plot the original image next to its attention map
fig, (ax1, ax2) = plt.subplots(ncols = 2)
for ax, title in ((ax1, 'Original'), (ax2, 'Attention Map')):
    ax.axis('off')
    ax.set_title(title)
ax1.imshow(image.astype('uint8'))
ax2.imshow(attention_map)
plt.show()

In [None]:
# Classifier = pretrained ViT backbone + small dense head.
model = tf.keras.Sequential([
        tf.keras.Input(shape = (img_size, img_size, 3)),
        # The pretrained ViT weights expect inputs scaled by vit.preprocess_inputs
        # (see the vit-keras usage example), but the ImageDataGenerators yield raw
        # pixel values, so the preprocessing is done inside the model.
        tf.keras.layers.Lambda(vit.preprocess_inputs, name = 'vit_preprocessing'),
        vit_model,
        tf.keras.layers.Flatten(),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.Dense(64, activation = tfa.activations.gelu),
        tf.keras.layers.BatchNormalization(),
        # One output per class kept by choose_classes (previously hard-coded to 11).
        tf.keras.layers.Dense(n_classes, activation = 'softmax')
    ],
    name = 'vision_transformer')

model.summary()

In [None]:
# Rectified Adam optimizer; cross-entropy with label smoothing on the one-hot targets.
learning_rate = 1e-3
optimizer = tfa.optimizers.RectifiedAdam(learning_rate = learning_rate)
loss_fn = tf.keras.losses.CategoricalCrossentropy(label_smoothing = 0.2)

model.compile(
    optimizer = optimizer,
    loss = loss_fn,
    metrics = ['accuracy'],
)

In [None]:
# Callbacks -- all three monitor validation accuracy (higher is better: mode = 'max')

# Create the output folder before training: the weights-only .hdf5 checkpoint is
# written into it from the first improving epoch, long before model.save() runs.
os.makedirs('../working/model_saved', exist_ok = True)

# Divide the learning rate by 5 when val_accuracy stalls for 2 epochs
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(monitor = 'val_accuracy',
                                                 factor = 0.2,
                                                 patience = 2,
                                                 verbose = 1,
                                                 min_delta = 1e-4,
                                                 min_lr = 1e-6,
                                                 mode = 'max')

# Stop after 5 epochs without improvement and roll back to the best weights
earlystopping = tf.keras.callbacks.EarlyStopping(monitor = 'val_accuracy',
                                                 min_delta = 1e-4,
                                                 patience = 5,
                                                 mode = 'max',
                                                 restore_best_weights = True,
                                                 verbose = 1)

# Keep the weights of the best epoch on disk
checkpointer = tf.keras.callbacks.ModelCheckpoint(filepath = '../working/model_saved/model.hdf5',
                                                  monitor = 'val_accuracy',
                                                  verbose = 1,
                                                  save_best_only = True,
                                                  save_weights_only = True,
                                                  mode = 'max')

callbacks_list = [earlystopping, reduce_lr, checkpointer]

In [None]:

## Compute weights :
# class_weights = compute_weights(method = 3)

## Training :
epochs = 100
# A Keras Sequence knows its number of batches: len() == ceil(n / batch_size), an
# integer that counts the last partial batch (n / batch_size is a float).
training_history = model.fit(x = training_set,
                             steps_per_epoch = len(training_set),
                             validation_data = validation_set,
                             validation_steps = len(validation_set),
                             epochs = epochs,
                             callbacks = callbacks_list,
#                             class_weight = class_weights
                            )

model.save('../working/model_saved/model_vit_test')

In [None]:
# Per-epoch metrics recorded by Keras during training (dict of lists).
training_history.history

In [None]:
# Learning curves over every recorded epoch (the history has one entry per
# completed epoch). `earlystopping.stopped_epoch` is not used for slicing: it is 0
# when early stopping never triggers (empty plot) and 0-based otherwise (the last
# epoch was dropped).
training_accuracy = training_history.history['accuracy']
validation_accuracy = training_history.history['val_accuracy']
epoch_numbers = np.arange(1, len(training_accuracy) + 1)

plt.figure()
plt.plot(epoch_numbers, training_accuracy, label = 'Training Set')
plt.plot(epoch_numbers, validation_accuracy, label = 'Validation Set')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.show()

In [None]:
# Loss and accuracy on the test split.
model.evaluate(testing_set)


In [None]:
# Loss and accuracy on the validation split.
model.evaluate(validation_set)

In [None]:
# Reload the training history pickled by the next cell. On a fresh kernel
# (Restart & Run All) that file does not exist yet, so fall back to the in-memory
# history. The pickle is written by this notebook itself (not untrusted input).
history_path = Path('../working/model_saved/training_hist')
if history_path.exists():
    with history_path.open('rb') as f:
        test = pickle.load(f)
else:
    test = training_history.history
test

In [None]:
import pickle

# Persist the Keras history dict so the curves can be reloaded without retraining.
history_file = Path('../working/model_saved/training_hist')
with history_file.open('wb') as f:
    pickle.dump(training_history.history, f)

In [None]:
# Confusion matrix + classification report on the validation split.
# Predictions must line up with `.classes` (dataframe order), so predict on a
# non-shuffled iterator over df_valid. The step count is left to Keras: an explicit
# `n // batch_size + 1` requests one batch too many when n is a multiple of batch_size.
ordered_valid_set = valid_generator.flow_from_dataframe(df_valid,
                                                        directory = None,
                                                        x_col = 'img_path',
                                                        y_col = 'label',
                                                        target_size = (img_size, img_size),
                                                        color_mode = 'rgb',
                                                        classes = None,
                                                        class_mode = 'categorical',
                                                        batch_size = batch_size,
                                                        shuffle = False)
predicted_classes = np.argmax(model.predict(ordered_valid_set), axis = 1)
true_classes = ordered_valid_set.classes
class_labels = list(ordered_valid_set.class_indices.keys())

confusionmatrix = confusion_matrix(true_classes, predicted_classes)
plt.figure(figsize = (16, 16))
sns.heatmap(confusionmatrix, cmap = 'Blues', annot = True, cbar = True,
            xticklabels = class_labels, yticklabels = class_labels)
plt.xlabel('Predicted label')
plt.ylabel('True label')
plt.show()

print(classification_report(true_classes, predicted_classes, target_names = class_labels))

In [None]:
# Validation predictions, used below against `validation_set.classes` (dataframe
# order). Predict on a non-shuffled iterator over df_valid so both are in the
# same order even if `validation_set` shuffles.
ordered_validation_set = valid_generator.flow_from_dataframe(df_valid,
                                                             directory = None,
                                                             x_col = 'img_path',
                                                             y_col = 'label',
                                                             target_size = (img_size, img_size),
                                                             color_mode = 'rgb',
                                                             classes = None,
                                                             class_mode = 'categorical',
                                                             batch_size = batch_size,
                                                             shuffle = False)
predictions = model.predict(ordered_validation_set)
y_pred = tf.argmax(predictions, axis = 1)

In [None]:
# Predicted class index for each validation image, in dataframe order (the order
# of `validation_set.classes`), computed on a non-shuffled iterator. The step count
# is left to Keras (len() of the iterator); `n // batch_size + 1` would request one
# batch too many whenever n is a multiple of batch_size.
ordered_validation_set = valid_generator.flow_from_dataframe(df_valid,
                                                             directory = None,
                                                             x_col = 'img_path',
                                                             y_col = 'label',
                                                             target_size = (img_size, img_size),
                                                             color_mode = 'rgb',
                                                             classes = None,
                                                             class_mode = 'categorical',
                                                             batch_size = batch_size,
                                                             shuffle = False)
predicted_classes = np.argmax(model.predict(ordered_validation_set), axis = 1)

In [None]:
from sklearn.metrics import classification_report, confusion_matrix

# Raw confusion matrix of the validation split: rows are true classes, columns predicted ones.
confusion_matrix(y_true = validation_set.classes, y_pred = y_pred)

In [None]:
import seaborn as sns

# Heatmap of the validation confusion matrix with class names on both axes,
# followed by the per-class report. class_indices keys are in class-index order.
class_labels = list(validation_set.class_indices.keys())

plt.figure(figsize = (10, 10))
sns.heatmap(confusion_matrix(validation_set.classes, y_pred), cmap = 'Blues', annot = True, cbar = True,
            xticklabels = class_labels, yticklabels = class_labels)
plt.xlabel('Predicted label')
plt.ylabel('True label')
plt.show()

print(classification_report(validation_set.classes, y_pred, target_names = class_labels))

In [None]:
# Learning curves over every recorded epoch. Uses the history length rather than
# `early_stopping.stopped_epoch`: that callback is only created in the fine-tuning
# cell further down (NameError on Restart & Run All), and stopped_epoch is 0 when
# early stopping never triggers and 0-based otherwise.
training_accuracy = training_history.history['accuracy']
validation_accuracy = training_history.history['val_accuracy']
epoch_numbers = np.arange(1, len(training_accuracy) + 1)

plt.figure()
plt.plot(epoch_numbers, training_accuracy, label = 'Training Set')
plt.plot(epoch_numbers, validation_accuracy, label = 'Validation Set')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.show()

In [None]:
## Evaluation on the test data:
model.evaluate(testing_set)

# Slight overfitting or not... that is the question

In [None]:
################################################## ANALYSIS - MODEL (2) - HELPER FUNCTIONS ##################################################

In [None]:
### 1 -- Confusion matrix + classification report ("presentable" version)

def print_classification_report(testing_set, labels):
  """Plot the confusion matrix and display the per-class classification report.

  Predictions are made with the global `model`.

  Args:
    testing_set: Keras iterator to predict on. It must not be shuffled: its
      `.classes` are used as ground truth in the same order as the predictions.
    labels: class names ordered by class index, e.g. the `class_indices` dict
      (its keys are in index order) or a list.

  Returns:
    None -- the report DataFrame is rendered with IPython `display`.
  """
  labels = list(labels)
  # Explicit label ids keep the matrix/report n x n even when a class never
  # appears in the ground truth or in the predictions.
  label_ids = np.arange(len(labels))

  # Prediction on the given set
  predictions = model.predict(testing_set)
  y_pred = np.argmax(predictions, axis = 1)

  # Compute and plot the confusion matrix
  cnf_matrix = confusion_matrix(testing_set.classes, y_pred, labels = label_ids)

  plt.figure(figsize = (12,12))
  plt.imshow(cnf_matrix, interpolation = 'nearest', cmap = 'Blues')
  plt.title("Matrice de confusion")
  plt.colorbar()

  plt.xticks(label_ids, labels)
  plt.yticks(label_ids, labels)

  # Write each count in its cell; white text on the darker half of the colour scale
  threshold = cnf_matrix.max() / 2
  for i, j in itertools.product(range(cnf_matrix.shape[0]),
                                range(cnf_matrix.shape[1])):
    plt.text(j, i, cnf_matrix[i, j],
             horizontalalignment = "center",
             color = "white" if cnf_matrix[i, j] > threshold else "black")

  plt.ylabel('Vrais labels')
  plt.xlabel('Labels prédits')
  plt.show()

  # Classification report: one row per class (summary entries such as accuracy /
  # macro avg / weighted avg are left out), columns precision, recall, f1-score, support.
  report = classification_report(testing_set.classes, y_pred, labels = label_ids,
                                 target_names = labels, output_dict = True)
  df_report = pd.DataFrame({name: report[name] for name in labels}).T

  print("Classification Report : avant Fine-Tuning")
  return display(df_report)

print_classification_report(testing_set, label_map)

# TODO: add a helper to save df_report to a .csv (+ copy it to Drive)

In [None]:
### 2 -- DataFrames for the analysis
predictions = model.predict(testing_set)
y_pred = np.argmax(predictions, axis = 1)

# One row per test image: true label, predicted label and path. The paths come
# from the iterator itself, so they stay aligned with `.classes` even if
# flow_from_dataframe skipped invalid files listed in df_test.
df_results = pd.DataFrame(data = {"real": testing_set.classes,
                                  "pred": y_pred,
                                  "img_path": testing_set.filepaths})

# Split the images: misclassified (df_false), correctly classified (df_true)
df_false = df_results[df_results["real"] != df_results["pred"]].reset_index(drop = True)
df_true = df_results[df_results["real"] == df_results["pred"]].reset_index(drop = True)

# TODO: save df_false and df_true to .csv files (+ copy them to Drive)

In [None]:
### 3 -- Grad-CAM on correctly classified images (one random image per class)
# NOTE(review): `gradcam` is not defined or imported anywhere in this notebook
# (it was written for the VGG model) -- define/import it and check that it
# applies to the ViT model before running this cell.
index_to_label = {index: name for name, index in label_map.items()}

fig = plt.figure(figsize = (20, 40))

for cell_class in range(n_classes):
  candidates = df_true.index[df_true["real"] == cell_class]
  if len(candidates) == 0:
    # No correctly classified image for this class: leave its row empty
    continue
  sample_id = np.random.choice(candidates)
  img_path = df_true.loc[sample_id, "img_path"]

  big_heatmap, superimposed_img = gradcam(model, img_path, alpha = 0.8, plot = False)

  fig.add_subplot(n_classes, 2, 2 * cell_class + 1)
  plt.imshow(plt.imread(img_path))
  plt.title(index_to_label[cell_class])

  fig.add_subplot(n_classes, 2, 2 * cell_class + 2)
  plt.imshow(superimposed_img)

In [None]:
### 4 -- Examples of misclassified images:
# Example: BNE/SNE confusion. Class indices are looked up by name, so the cell
# stays correct if the class selection changes.
index_to_label = {index: name for name, index in label_map.items()}
bne_idx, sne_idx = label_map["BNE"], label_map["SNE"]

conf_bne_sne = df_false[((df_false["real"] == bne_idx) & (df_false["pred"] == sne_idx)) |
                        ((df_false["real"] == sne_idx) & (df_false["pred"] == bne_idx))]

# At most 5 examples: sampling without replacement fails if fewer are available
n_examples = min(5, len(conf_bne_sne))

fig = plt.figure(figsize = (15,15))

for i, sample_id in enumerate(np.random.choice(conf_bne_sne.index, size = n_examples, replace = False)):
  img_path = conf_bne_sne.loc[sample_id, "img_path"]

  big_heatmap, superimposed_img = gradcam(model, img_path, alpha = 0.8, plot = False)

  fig.add_subplot(5, 2, 2 * i + 1)
  plt.imshow(plt.imread(img_path))
  plt.title("Label réel : %s ; Label prédit : %s"%(index_to_label[conf_bne_sne.loc[sample_id, "real"]],
                                                   index_to_label[conf_bne_sne.loc[sample_id, "pred"]]))

  fig.add_subplot(5, 2, 2 * i + 2)
  plt.imshow(superimposed_img)

In [None]:
# Example: MY/MMY confusion. Class indices are looked up by name, so the cell
# stays correct if the class selection changes.
index_to_label = {index: name for name, index in label_map.items()}
my_idx, mmy_idx = label_map["MY"], label_map["MMY"]

conf_my_mmy = df_false[((df_false["real"] == mmy_idx) & (df_false["pred"] == my_idx)) |
                       ((df_false["real"] == my_idx) & (df_false["pred"] == mmy_idx))]

# At most 5 examples: sampling without replacement fails if fewer are available
n_examples = min(5, len(conf_my_mmy))

fig = plt.figure(figsize = (30,15))

for i, sample_id in enumerate(np.random.choice(conf_my_mmy.index, size = n_examples, replace = False)):

  img_path = conf_my_mmy.loc[sample_id, "img_path"]
  big_heatmap, superimposed_img = gradcam(model, img_path, alpha = 0.8, plot = False)

  fig.add_subplot(5, 2, 2 * i + 1)
  plt.imshow(plt.imread(img_path))
  plt.title("Label réel : %s ; Label prédit : %s"%(index_to_label[conf_my_mmy.loc[sample_id, "real"]],
                                                   index_to_label[conf_my_mmy.loc[sample_id, "pred"]]))

  fig.add_subplot(5, 2, 2 * i + 2)
  plt.imshow(superimposed_img)

In [None]:
# Example: PMY/MY confusion. Class indices are looked up by name, so the cell
# stays correct if the class selection changes.
index_to_label = {index: name for name, index in label_map.items()}
pmy_idx, my_idx = label_map["PMY"], label_map["MY"]

conf_pmy_my = df_false[((df_false["real"] == pmy_idx) & (df_false["pred"] == my_idx)) |
                       ((df_false["real"] == my_idx) & (df_false["pred"] == pmy_idx))]

# At most 5 examples: sampling without replacement fails if fewer are available
n_examples = min(5, len(conf_pmy_my))

fig = plt.figure(figsize = (30,15))

for i, sample_id in enumerate(np.random.choice(conf_pmy_my.index, size = n_examples, replace = False)):

  img_path = conf_pmy_my.loc[sample_id, "img_path"]
  big_heatmap, superimposed_img = gradcam(model, img_path, alpha = 0.8, plot = False)

  fig.add_subplot(5, 2, 2 * i + 1)
  plt.imshow(plt.imread(img_path))
  plt.title("Label réel : %s ; Label prédit : %s"%(index_to_label[conf_pmy_my.loc[sample_id, "real"]],
                                                   index_to_label[conf_pmy_my.loc[sample_id, "pred"]]))

  fig.add_subplot(5, 2, 2 * i + 2)
  plt.imshow(superimposed_img)

In [None]:
######################################################################### 05/08/21 - FINE TUNING #######################################################################################

In [None]:
# Callbacks
TON = TerminateOnNaN()

control_lr = ReduceLROnPlateau(monitor = 'val_loss',
                               factor = 0.1, patience = 3, verbose = 1, mode = 'min', min_lr = 1e-7)

early_stopping = EarlyStopping(monitor = "val_loss",
                               patience = 4, mode = 'min', restore_best_weights = True)

callbacks_list = [TON, control_lr, early_stopping]

## Class weights:
# `compute_weights` is not defined in this notebook (NameError), so use sklearn's
# "balanced" weights, n_samples / (n_classes * count), on the training labels.
# NOTE(review): confirm this matches what `compute_weights(method = 3)` computed.
train_labels = np.asarray(training_set.classes)
present_classes = np.unique(train_labels)
balanced_weights = compute_class_weight(class_weight = 'balanced', classes = present_classes, y = train_labels)
class_weights = {int(c): float(w) for c, w in zip(present_classes, balanced_weights)}

## Layers to fine-tune:
# The "block5" unfreezing loop was copied from the VGG19 notebook; none of the
# layers of this model is named "block5", so it did nothing. The ViT backbone was
# never frozen in this notebook, so all weights are fine-tuned here, with a
# smaller learning rate.
optimizer = tf.keras.optimizers.Adam(learning_rate = 1e-4)
model.compile(optimizer = optimizer,
              loss = "categorical_crossentropy",
              metrics = ["accuracy"])

## Training:
epochs = 20
history = model.fit(x = training_set,
                    epochs = epochs,
                    callbacks = callbacks_list,
                    validation_data = validation_set,
                    class_weight = class_weights)
## TODO: rename history to fine_history to avoid overwriting history...

### VGG run (03/08/21) without "base_model.training = False": Accuracy: 8.7%, Val Acc: 6.82%, Test Acc: 8.29% => disaster...
# Not understood why.

In [None]:
model.evaluate(testing_set)
# Slight overfitting.

In [None]:
## Save:
# Save next to the other artefacts of this run. The former target
# "/content/drive/MyDrive/Leukopy/VGG19_TL_11/model_fullsave_trainfalse" was a
# Colab/Drive path left over from the VGG19 notebook, which would have stored
# (or overwritten) this ViT model under the VGG19 folder.
model.save('../working/model_saved/model_vit_finetuned')

In [None]:
def print_classification_report(testing_set, labels):
  """Plot the confusion matrix and display the per-class report after fine-tuning.

  Predictions are made with the global `model`.

  Args:
    testing_set: Keras iterator to predict on. It must not be shuffled: its
      `.classes` are used as ground truth in the same order as the predictions.
    labels: class names ordered by class index, e.g. the `class_indices` dict
      (its keys are in index order) or a list.

  Returns:
    None -- the report DataFrame is rendered with IPython `display`.
  """
  labels = list(labels)
  # Explicit label ids keep the matrix/report n x n even when a class never
  # appears in the ground truth or in the predictions.
  label_ids = np.arange(len(labels))

  # Prediction on the given set
  predictions = model.predict(testing_set)
  y_pred = np.argmax(predictions, axis = 1)

  # Compute and plot the confusion matrix
  cnf_matrix = confusion_matrix(testing_set.classes, y_pred, labels = label_ids)

  plt.figure(figsize = (12,12))
  plt.imshow(cnf_matrix, interpolation = 'nearest', cmap = 'Blues')
  plt.title("Matrice de confusion")
  plt.colorbar()

  plt.xticks(label_ids, labels)
  plt.yticks(label_ids, labels)

  # Write each count in its cell; white text on the darker half of the colour scale
  threshold = cnf_matrix.max() / 2
  for i, j in itertools.product(range(cnf_matrix.shape[0]),
                                range(cnf_matrix.shape[1])):
    plt.text(j, i, cnf_matrix[i, j],
             horizontalalignment = "center",
             color = "white" if cnf_matrix[i, j] > threshold else "black")

  plt.ylabel('Vrais labels')
  plt.xlabel('Labels prédits')
  plt.show()

  # Classification report: one row per class (summary entries such as accuracy /
  # macro avg / weighted avg are left out), columns precision, recall, f1-score, support.
  report = classification_report(testing_set.classes, y_pred, labels = label_ids,
                                 target_names = labels, output_dict = True)
  df_report = pd.DataFrame({name: report[name] for name in labels}).T

  print("Classification Report : après fine-tuning")
  return display(df_report)

print_classification_report(testing_set, label_map)

# TODO: save df_report to a .csv (+ copy it to Drive)

In [None]:
"""
Le tuning de VGG a amélioré le F1 de toutes les classes sans exception. Toujours des problèmes avec MY/MMY/PMY et BNE/SNE.
Pistes : 
1°/ fine-tuning du 'block4' en plus du 'block5' 
ou
2°/ Procéder en deux temps : d'abord dégeler le block5 et entraîner, puis dégeler le block4 et entraîner.
Rq : je ne remonterais pas plus haut que le block4, sinon on risque de trop spécialiser le modèle sur les jolies images de l'hôpital de Barcelone
Rq2 : très peu d'augmentation de données dans ce run du modèle... on pourrait introduire de légères variations sur : luminosité / netteté / zoom
"""

In [None]:
### 2 -- DataFrames for the analysis (recomputed after fine-tuning)
predictions = model.predict(testing_set)
y_pred = np.argmax(predictions, axis = 1)

# One row per test image: true label, predicted label and path. The paths come
# from the iterator itself, so they stay aligned with `.classes` even if
# flow_from_dataframe skipped invalid files listed in df_test.
df_results = pd.DataFrame(data = {"real": testing_set.classes,
                                  "pred": y_pred,
                                  "img_path": testing_set.filepaths})

# Split the images: misclassified (df_false), correctly classified (df_true)
df_false = df_results[df_results["real"] != df_results["pred"]].reset_index(drop = True)
df_true = df_results[df_results["real"] == df_results["pred"]].reset_index(drop = True)

In [None]:
# Grad-CAM on correctly classified images after fine-tuning (one random image per class)
# NOTE(review): `gradcam` is not defined or imported anywhere in this notebook
# (it was written for the VGG model) -- define/import it and check that it
# applies to the ViT model before running this cell.
index_to_label = {index: name for name, index in label_map.items()}

fig = plt.figure(figsize = (20, 40))

for cell_class in range(n_classes):
  candidates = df_true.index[df_true["real"] == cell_class]
  if len(candidates) == 0:
    # No correctly classified image for this class: leave its row empty
    continue
  sample_id = np.random.choice(candidates)
  img_path = df_true.loc[sample_id, "img_path"]

  big_heatmap, superimposed_img = gradcam(model, img_path, alpha = 0.8, plot = False)

  fig.add_subplot(n_classes, 2, 2 * cell_class + 1)
  plt.imshow(plt.imread(img_path))
  plt.title(index_to_label[cell_class])

  fig.add_subplot(n_classes, 2, 2 * cell_class + 2)
  plt.imshow(superimposed_img)

In [None]:
### 4 -- Examples of misclassified images:
# Example: BNE/SNE confusion. Class indices are looked up by name, so the cell
# stays correct if the class selection changes.
index_to_label = {index: name for name, index in label_map.items()}
bne_idx, sne_idx = label_map["BNE"], label_map["SNE"]

conf_bne_sne = df_false[((df_false["real"] == bne_idx) & (df_false["pred"] == sne_idx)) |
                        ((df_false["real"] == sne_idx) & (df_false["pred"] == bne_idx))]

# At most 5 examples: sampling without replacement fails if fewer are available
n_examples = min(5, len(conf_bne_sne))

fig = plt.figure(figsize = (15,15))

for i, sample_id in enumerate(np.random.choice(conf_bne_sne.index, size = n_examples, replace = False)):
  img_path = conf_bne_sne.loc[sample_id, "img_path"]

  big_heatmap, superimposed_img = gradcam(model, img_path, alpha = 0.8, plot = False)

  fig.add_subplot(5, 2, 2 * i + 1)
  plt.imshow(plt.imread(img_path))
  plt.title("Label réel : %s ; Label prédit : %s"%(index_to_label[conf_bne_sne.loc[sample_id, "real"]],
                                                   index_to_label[conf_bne_sne.loc[sample_id, "pred"]]))

  fig.add_subplot(5, 2, 2 * i + 2)
  plt.imshow(superimposed_img)

In [None]:
# Example: MY/MMY confusion. Class indices are looked up by name, so the cell
# stays correct if the class selection changes.
index_to_label = {index: name for name, index in label_map.items()}
my_idx, mmy_idx = label_map["MY"], label_map["MMY"]

conf_my_mmy = df_false[((df_false["real"] == mmy_idx) & (df_false["pred"] == my_idx)) |
                       ((df_false["real"] == my_idx) & (df_false["pred"] == mmy_idx))]

# At most 5 examples: sampling without replacement fails if fewer are available
n_examples = min(5, len(conf_my_mmy))

fig = plt.figure(figsize = (30,15))

for i, sample_id in enumerate(np.random.choice(conf_my_mmy.index, size = n_examples, replace = False)):

  img_path = conf_my_mmy.loc[sample_id, "img_path"]
  big_heatmap, superimposed_img = gradcam(model, img_path, alpha = 0.8, plot = False)

  fig.add_subplot(5, 2, 2 * i + 1)
  plt.imshow(plt.imread(img_path))
  plt.title("Label réel : %s ; Label prédit : %s"%(index_to_label[conf_my_mmy.loc[sample_id, "real"]],
                                                   index_to_label[conf_my_mmy.loc[sample_id, "pred"]]))

  fig.add_subplot(5, 2, 2 * i + 2)
  plt.imshow(superimposed_img)

In [None]:
# Example: PMY/MY confusion. Class indices are looked up by name, so the cell
# stays correct if the class selection changes.
index_to_label = {index: name for name, index in label_map.items()}
pmy_idx, my_idx = label_map["PMY"], label_map["MY"]

conf_pmy_my = df_false[((df_false["real"] == pmy_idx) & (df_false["pred"] == my_idx)) |
                       ((df_false["real"] == my_idx) & (df_false["pred"] == pmy_idx))]

# At most 5 examples: sampling without replacement fails if fewer are available
n_examples = min(5, len(conf_pmy_my))

fig = plt.figure(figsize = (30,15))

for i, sample_id in enumerate(np.random.choice(conf_pmy_my.index, size = n_examples, replace = False)):

  img_path = conf_pmy_my.loc[sample_id, "img_path"]
  big_heatmap, superimposed_img = gradcam(model, img_path, alpha = 0.8, plot = False)

  fig.add_subplot(5, 2, 2 * i + 1)
  plt.imshow(plt.imread(img_path))
  plt.title("Label réel : %s ; Label prédit : %s"%(index_to_label[conf_pmy_my.loc[sample_id, "real"]],
                                                   index_to_label[conf_pmy_my.loc[sample_id, "pred"]]))

  fig.add_subplot(5, 2, 2 * i + 2)
  plt.imshow(superimposed_img)

In [None]:
"""
On a encore du travail pour ce qui relève de l'interprétabilité du modèle. Grad-CAM montre que le modèle se concentre sur le centre de l'image,
ce qui est positif, mais il ne prend pas en compte la totalité de la cellule, seulement une partie (et qui n'est pas nécessairement le noyau)
Pistes :
- entraîner le modèle sur les images découpées "C_NMC_2019 Dataset" et comparer les Grad-CAM / performances

- tenter un clustering / une PCA en sortie de modèle, puis colorer les images mal classées sur la figure produite par le clustering / la PCA
"""

In [None]:
""" 
Autre : 
- on pourrait intégrer les fonctions Grad-CAM à leuko_lib ... ou faire un module à part.