This notebook contains...

In [None]:
!pip install torchsummary

In [None]:
import os                                      # for working with files
import sys
import shap                                    # for checking feature importances
import torch                                   # Pytorch module 
import shutil
import optuna
import warnings
import numpy as np                             # for numerical computationss
import pandas as pd                            # for working with dataframes
import torch.nn as nn                          # for creating  neural networks
from PIL import Image                          # for checking images
import matplotlib.pyplot as plt                # for plotting informations on graph and images using tensors
import torch.nn.functional as F                # for functions for calculating loss
from torchsummary import summary               # for getting the summary of our model
from torchvision.utils import make_grid        # for data checking
from torch.utils.data import DataLoader        # for dataloaders 
import torchvision.transforms as transforms    # for transforming images into tensors 
from torchvision.datasets import ImageFolder   # for working with classes and images

%matplotlib inline

### Data exploration!

In [None]:
os.listdir('/kaggle/input/dataset/idata/Image Dataset/ImageDataset/')

In [None]:
data_dir = '/kaggle/input/dataset/idata/Image Dataset/ImageDataset/'

In [None]:
# print(f"Number of image directories are {len(os.listdir(data_fpath))+len(os.listdir('/kaggle/input/newds/ImageDataset_new/ImageDataset_new/'))}\n")
print('Number of unique plants are 2, potato and tomato\n')
print('Number of diseases are 4, early and late blight disease for tomato, early and late blight for potato\n')

In [None]:
data_dir

In [None]:
train_dir = data_dir + "train/"
valid_dir = data_dir + "valid/"
# test_dir
diseases_tr = os.listdir(train_dir)
diseases_va = os.listdir(valid_dir)


In [None]:
valid_dir

In [None]:
diseases_tr

In [None]:
# Collect the distinct plant names and count the non-healthy (disease) classes.
# The class folders appear to be named "<plant>_<condition>" (e.g. "potato_early");
# splitting on the first underscore and stripping any further leading underscores
# also copes with the "<plant>___<condition>" convention used by some datasets.
plants = []
NumberOfDiseases = 0
for class_name in diseases_tr:
    plant_name, _, condition = class_name.partition('_')
    condition = condition.lstrip('_')
    if plant_name not in plants:
        plants.append(plant_name)
    if condition != 'healthy':
        NumberOfDiseases += 1

In [None]:
# Number of images for each class in the training data
nums_train = {}
for folder in sorted(os.listdir(train_dir)):
    # build the path from train_dir so no doubled or stray leading slash is introduced
    nums_train[folder] = len(os.listdir(os.path.join(train_dir, folder)))

# Convert the counts to a DataFrame: class folder names as the index, image counts as the only column
img_per_training_class = pd.DataFrame(list(nums_train.values()), index=list(nums_train.keys()), columns=["no. of images"])
img_per_training_class

In [None]:
# Number of images for each class in the validation data
nums_valid = {}
for folder in sorted(os.listdir(valid_dir)):
    # build the path from valid_dir so no doubled slash is introduced
    nums_valid[folder] = len(os.listdir(os.path.join(valid_dir, folder)))

# Convert the counts to a DataFrame: class folder names as the index, image counts as the only column
img_per_valid_class = pd.DataFrame(list(nums_valid.values()), index=list(nums_valid.keys()), columns=["no. of images"])
img_per_valid_class

In [None]:
# Bar chart of the number of training images available for each class
index = list(range(len(nums_train)))  # one bar per class actually found, instead of a hard-coded 6
plt.figure(figsize=(20, 5))
plt.bar(index, list(nums_train.values()), color='#8528B0')
plt.xlabel('Classes', fontsize=15)
plt.ylabel('No of images available', fontsize=15)
plt.xticks(index, list(nums_train.keys()), fontsize=15, rotation=90)
plt.title('Images per class for training dataset')

In [None]:
# Bar chart of the number of validation images available for each class
index = list(range(len(nums_valid)))  # one bar per class actually found, instead of a hard-coded 6
plt.figure(figsize=(20, 5))
plt.bar(index, list(nums_valid.values()), color='#8528B0')
plt.xlabel('Classes', fontsize=15)
plt.ylabel('No of images available', fontsize=15)
plt.xticks(index, list(nums_valid.keys()), fontsize=15, rotation=90)
plt.title('Images per class for validation dataset')


### Data Augmentation

The data has already been augmented. see https://github.com/Alyeko/potato-tomato-blight-disease-detection

### Images available for training

In [None]:
# Total number of training images, summed over all classes
n_train = sum(nums_train.values())
print(f"There are {n_train} images for training")

In [None]:
# Total number of validation images, summed over all classes
n_valid = sum(nums_valid.values())
print(f"There are {n_valid} images for validation")

### Checking if there are non-image files in the training data folder


In [None]:
folds = [folder for folder in os.listdir(train_dir)]
folds

In [None]:
# Print a marker for every training file whose name does not end in '.JPG'
for class_folder in folds:
    for file_name in os.listdir(train_dir + class_folder):
        if file_name.endswith('.JPG'):
            continue
        print('yes!')

In [None]:
# Print a marker for every validation file whose name does not end in '.JPG'
for class_folder in folds:
    for file_name in os.listdir(valid_dir + class_folder):
        if file_name.endswith('.JPG'):
            continue
        print('yes!')

In [None]:
data_dir

In [None]:
print(f"There are {len(os.listdir('/kaggle/input/dataset/idata/Image Dataset/test_data/test'))} images for test")

In [None]:
print(f"Training dir: {os.listdir('/kaggle/input/dataset/idata/Image Dataset/ImageDataset/')}")
print(f"All: {os.listdir('/kaggle/input/dataset/idata/Image Dataset')}")

In [None]:
test_dir = '/kaggle/input/dataset/idata/Image Dataset/test_data/'
# print(f"There are {len(os.listdir('/kaggle/input/newds/ImageDataset_new/ImageDataset_new/test_data'))} images for training")
os.listdir(test_dir)

In [None]:
# Print a marker for every entry in the test folder whose name does not end in '.JPG'
for file_name in os.listdir(test_dir + 'test'):
    if file_name.endswith('.JPG'):
        continue
    print('Yes! I knew it!')

### Data Preparation for training 

# datasets for validation and training
train = ImageFolder(train_dir, 
                    transform = transforms.Compose([
                                transforms.ToTensor(),
                                transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))]))

print(train, '\n')
valid = ImageFolder(valid_dir,
                      transform = transforms.Compose([
                                transforms.ToTensor(),
                                transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))]))
print(valid, '\n')

Next, after loading the data, we need to transform the pixel values of each image (0-255) to 0-1, as neural networks work quite well with normalized data. The entire array of pixel values is converted to a torch [tensor](https://pytorch.org/tutorials/beginner/examples_tensor/two_layer_net_tensor.html#:~:text=A%20PyTorch%20Tensor%20is%20basically,used%20for%20arbitrary%20numeric%20computation.) and then divided by 255.
If you are not familiar with why normalizing inputs helps a neural network, read [this](https://towardsdatascience.com/why-data-should-be-normalized-before-training-a-neural-network-c626b7f66c7d) post.

In [None]:
print(train_dir)
print(valid_dir)

In [None]:
# datasets for validation and training
train = ImageFolder(train_dir, transform=transforms.Compose(
                                        [transforms.Resize([256, 256]),
                                         transforms.ToTensor()]))

valid = ImageFolder(valid_dir, transform=transforms.Compose(
                                        [transforms.Resize([256, 256]),
                                         transforms.ToTensor()]))

In [None]:
#Image shape
img, label = train[4590]
print(img.shape, label)

img, label = train[0]
print(img.shape, label)

We can see the shape (3, 256, 256) of the image: 3 is the number of channels (RGB) and 256 x 256 is the height and width of the image.

In [None]:
len(train.classes) #multiclass classification with 6 classes

In [None]:
# for checking some images from training dataset
def show_image(image, label):
    """Print the class name and index for `label`, then draw `image` with matplotlib.

    The class name is looked up in the notebook-level `train.classes`. The image
    is permuted with (1, 2, 0) before plotting, i.e. its first axis is moved last.
    """
    class_name = train.classes[label]
    print("Label :" + class_name + "(" + str(label) + ")")
    channels_last = image.permute(1, 2, 0)
    plt.imshow(channels_last)
    

In [None]:
# Setting the seed value
random_seed = 7
torch.manual_seed(random_seed)

In [None]:
show_image(*train[10000])

In [None]:
show_image(*train[6580])

In [None]:
show_image(*train[1000])

In [None]:
show_image(*train[5000])

In [None]:
print(train, '\n')
print(valid)

In [None]:
# DataLoaders for training and validation
# setting the batch size
batch_size = 32
train_dl = DataLoader(train, batch_size, shuffle=True, num_workers=2, pin_memory=True)
valid_dl = DataLoader(valid, batch_size, num_workers=2, pin_memory=True)

In [None]:
# helper function to show a batch of training instances
def show_batch(data):
    """Draw the first batch yielded by `data` as one image grid with 8 images per row.

    Does nothing when `data` yields no batches.
    """
    first_batch = next(iter(data), None)
    if first_batch is None:
        return
    images, labels = first_batch
    fig, ax = plt.subplots(figsize=(30, 30))
    ax.set_xticks([])
    ax.set_yticks([])
    grid = make_grid(images, nrow=8)
    ax.imshow(grid.permute(1, 2, 0))
        
# Images for first batch of training
show_batch(train_dl) 

In [None]:
# 🏗️ Modelling 🏗️

In [None]:
# for moving data into GPU (if available)
def get_default_device():
    """Pick GPU if available, else CPU.

    Returns:
        torch.device: ``cuda`` when CUDA is usable on this machine, otherwise ``cpu``.
    """
    # is_available must be *called*: the bare function object is always truthy,
    # which would select "cuda" even on CPU-only machines.
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")

# for moving data to device (CPU or GPU)
def to_device(data, device):
    """Recursively move a tensor, or a list/tuple of tensors, to `device`.

    Lists and tuples are both returned as lists; any other object must
    provide a `.to(device, non_blocking=...)` method.
    """
    if not isinstance(data, (list, tuple)):
        return data.to(device, non_blocking=True)
    moved = []
    for item in data:
        moved.append(to_device(item, device))
    return moved

# for loading in the device (GPU if available else CPU)
class DeviceDataLoader():
    """Thin wrapper around a dataloader that moves every batch to a device on iteration."""

    def __init__(self, dl, device):
        self.dl, self.device = dl, device

    def __len__(self):
        """Return the number of batches in the wrapped dataloader."""
        return len(self.dl)

    def __iter__(self):
        """Yield each batch of the wrapped dataloader after moving it to the device."""
        yield from (to_device(batch, self.device) for batch in self.dl)

In [None]:
device = get_default_device()
device

In [None]:
# Moving data into GPU
train_dl = DeviceDataLoader(train_dl, device)
valid_dl = DeviceDataLoader(valid_dl, device)

In [None]:
class SimpleResidualBlock(nn.Module):
    """Two 3x3 convolutions (3 -> 3 channels, stride 1, padding 1), each followed by a ReLU, plus a skip connection."""

    def __init__(self):
        super().__init__()
        conv_kwargs = dict(in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1)
        self.conv1 = nn.Conv2d(**conv_kwargs)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(**conv_kwargs)
        self.relu2 = nn.ReLU()

    def forward(self, x):
        """Return relu2(conv2(relu1(conv1(x)))) + x."""
        hidden = self.relu1(self.conv1(x))
        activated = self.relu2(self.conv2(hidden))
        # the input is added after the final ReLU (it could equally be added before it)
        return activated + x

In [None]:
# for calculating the accuracy
def accuracy(outputs, labels):
    """Return, as a 0-dim tensor, the fraction of rows of `outputs` whose arg-max along dim 1 equals `labels`."""
    preds = torch.max(outputs, dim=1).indices
    num_correct = (preds == labels).sum().item()
    return torch.tensor(num_correct / len(preds))


# base class for the model
class ImageClassificationBase(nn.Module):
    """Shared training/validation helpers for image classifiers; subclasses supply `forward`."""

    def training_step(self, batch):
        """Return the cross-entropy loss of the model output on one (images, labels) batch."""
        inputs, targets = batch
        predictions = self(inputs)
        return F.cross_entropy(predictions, targets)

    def validation_step(self, batch):
        """Return the detached loss and the accuracy for one (images, labels) batch."""
        inputs, targets = batch
        predictions = self(inputs)
        batch_loss = F.cross_entropy(predictions, targets)
        batch_acc = accuracy(predictions, targets)
        return {"val_loss": batch_loss.detach(), "val_accuracy": batch_acc}

    def validation_epoch_end(self, outputs):
        """Average the per-batch losses and accuracies produced by `validation_step`."""
        losses = [step["val_loss"] for step in outputs]
        accuracies = [step["val_accuracy"] for step in outputs]
        mean_loss = torch.stack(losses).mean()
        mean_accuracy = torch.stack(accuracies).mean()
        return {"val_loss": mean_loss, "val_accuracy": mean_accuracy}

    def epoch_end(self, epoch, result):
        """Print a one-line summary of the epoch: last learning rate, training loss, validation loss and accuracy."""
        template = "Epoch [{}], last_lr: {:.5f}, train_loss: {:.4f}, val_loss: {:.4f}, val_acc: {:.4f}"
        last_lr = result['lrs'][-1]
        print(template.format(epoch, last_lr, result['train_loss'], result['val_loss'], result['val_accuracy']))
        

In [None]:
# Architecture for training

# convolution block with BatchNormalization
def ConvBlock(in_channels, out_channels, pool=False):
    """Build Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU(inplace), optionally followed by MaxPool2d(4)."""
    pooling = [nn.MaxPool2d(4)] if pool else []
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
        *pooling
    )


# resnet architecture 
class ResNet9(ImageClassificationBase):
    """Small ResNet-style classifier: four ConvBlocks, two residual stages and a linear head.

    The out_dim comments below assume a 256 x 256 input (the notebook's INPUT_SHAPE);
    each pooled ConvBlock divides height and width by 4.

    NOTE(review): the classifier ends with nn.Softmax, while the base class computes
    its loss with F.cross_entropy, which per the PyTorch docs expects unnormalized
    logits. Confirm whether the Softmax is intended during training; later cells
    read the model output as per-class probabilities.
    """
    def __init__(self, in_channels, num_diseases):
        """Create the layers.

        in_channels: number of channels of the input images.
        num_diseases: number of output classes (size of the final Linear layer).
        """
        super().__init__()
        
        self.conv1 = ConvBlock(in_channels, 64)
        self.conv2 = ConvBlock(64, 128, pool=True) # out_dim : 128 x 64 x 64 
        self.res1 = nn.Sequential(ConvBlock(128, 128), ConvBlock(128, 128))
        
        self.conv3 = ConvBlock(128, 256, pool=True) # out_dim : 256 x 16 x 16
        self.conv4 = ConvBlock(256, 512, pool=True) # out_dim : 512 x 4 x 4
        self.res2 = nn.Sequential(ConvBlock(512, 512), ConvBlock(512, 512))
        
        # MaxPool2d(4) then Flatten must leave exactly 512 features for the Linear layer
        self.classifier = nn.Sequential(nn.MaxPool2d(4),
                                       nn.Flatten(),
                                       nn.Linear(512, num_diseases), 
                                       nn.Softmax(dim=1))
        
    def forward(self, xb): # xb is the loaded batch
        """Run the batch through the network and return the classifier output (one row per image)."""
        out = self.conv1(xb)
        out = self.conv2(out)
        out = self.res1(out) + out  # residual connection around res1
        out = self.conv3(out)
        out = self.conv4(out)
        out = self.res2(out) + out  # residual connection around res2
        out = self.classifier(out)
        return out        

In [None]:
# defining the model and moving it to the GPU
model = to_device(ResNet9(3, len(train.classes)), device) 
model

In [None]:
# getting summary of the model
INPUT_SHAPE = (3, 256, 256)
# The model was already moved to `device` when it was created, so pass that device's
# type to torchsummary instead of forcing .cuda(), which raises on CPU-only machines.
# summary() prints the table itself (wrapping it in print() appears to only add a stray "None").
summary(model, INPUT_SHAPE, device=device.type)

In [None]:
# for training
@torch.no_grad()
def evaluate(model, val_loader):
    """Put `model` in eval mode, run `validation_step` on every batch and return the aggregated result."""
    model.eval()
    step_outputs = []
    for batch in val_loader:
        step_outputs.append(model.validation_step(batch))
    return model.validation_epoch_end(step_outputs)


def get_lr(optimizer):
    """Return the 'lr' of the optimizer's first parameter group, or None when it has no groups."""
    first_group = next(iter(optimizer.param_groups), None)
    if first_group is None:
        return None
    return first_group['lr']
    

def fit_OneCycle(epochs, max_lr, model, train_loader, val_loader, weight_decay=0,
                grad_clip=None, opt_func=torch.optim.SGD):
    """Train `model` with a one-cycle learning-rate schedule and validate after every epoch.

    Args:
        epochs: number of passes over `train_loader`.
        max_lr: peak learning rate for OneCycleLR (also given to the optimizer as its lr).
        model: module providing `training_step`, `validation_step`,
            `validation_epoch_end` and `epoch_end`.
        train_loader: iterable of training batches; must support `len()`.
        val_loader: iterable of validation batches.
        weight_decay: weight decay passed to the optimizer.
        grad_clip: if truthy, gradient values are clipped to this bound.
        opt_func: optimizer class, called as opt_func(params, lr, weight_decay=...).

    Returns:
        A list with one dict per epoch holding the validation results plus
        'train_loss' (mean training loss, float) and 'lrs' (learning rate per batch).
    """
    torch.cuda.empty_cache()
    history = []

    optimizer = opt_func(model.parameters(), max_lr, weight_decay=weight_decay)

    # scheduler for the one-cycle learning rate policy (stepped once per batch)
    sched = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr, epochs=epochs,
                                                steps_per_epoch=len(train_loader))

    for epoch in range(epochs):
        # Training
        model.train()
        train_losses = []
        lrs = []
        for batch in train_loader:
            loss = model.training_step(batch)
            loss.backward()
            # keep only a detached copy so the stored losses do not hold on to autograd history
            train_losses.append(loss.detach())

            # gradient clipping
            if grad_clip:
                nn.utils.clip_grad_value_(model.parameters(), grad_clip)

            optimizer.step()
            optimizer.zero_grad()

            # record the learning rate used for this batch, then advance the schedule
            lrs.append(get_lr(optimizer))
            sched.step()

        # validation
        result = evaluate(model, val_loader)
        result['train_loss'] = torch.stack(train_losses).mean().item()
        result['lrs'] = lrs
        model.epoch_end(epoch, result)
        history.append(result)

    return history
    

In [None]:
%%time
history = [evaluate(model, valid_dl)]
history

------------------------

### Hyperparameter tuning with optuna

Code for hyperparameter tuning adapted from https://towardsdatascience.com/hyperparameter-tuning-of-neural-networks-with-optuna-and-pytorch-22e179efc837

In [None]:
len(valid_dl)

In [None]:
def train_and_evaluate(param, model, train_dl, valid_dl):
    """Train `model` with the hyperparameters in `param` and return its validation accuracy.

    Args:
        param: dict of hyperparameters:
            'optimizer'     -- name of a class in torch.optim, e.g. "Adam" or "SGD" (required)
            'learning_rate' -- learning rate given to the optimizer (required)
            'weight_decay'  -- optional weight decay, defaults to 0
            'grad_clip'     -- optional; when given, gradient values are clipped to this bound
            'epochs'        -- optional number of epochs; falls back to the notebook-level EPOCHS
        model: the module to train (modified in place).
        train_dl: iterable of (inputs, labels) training batches.
        valid_dl: iterable of (inputs, labels) validation batches.

    Returns:
        The fraction (0.0 - 1.0) of validation samples classified correctly after the
        last epoch; 0.0 when no epoch ran or the validation loader was empty.
    """
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    criterion = nn.CrossEntropyLoss()
    if use_cuda:
        model = model.cuda()
        criterion = criterion.cuda()

    optimizer_cls = getattr(torch.optim, param['optimizer'])
    optimizer = optimizer_cls(model.parameters(), lr=param['learning_rate'],
                              weight_decay=param.get('weight_decay', 0))
    grad_clip = param.get('grad_clip')
    # honour the sampled number of epochs; only fall back to the global when it is absent
    num_epochs = param['epochs'] if 'epochs' in param else EPOCHS

    val_accuracy = 0.0
    for epoch_num in range(num_epochs):
        # an earlier evaluation may have left the model in eval mode
        model.train()
        for train_input, train_label in train_dl:
            train_label = train_label.to(device)
            train_input = train_input.to(device)

            optimizer.zero_grad()
            output = model(train_input.float())
            batch_loss = criterion(output, train_label.long())
            batch_loss.backward()
            if grad_clip:
                nn.utils.clip_grad_value_(model.parameters(), grad_clip)
            optimizer.step()

        model.eval()
        num_correct = 0
        num_seen = 0
        with torch.no_grad():
            for val_input, val_label in valid_dl:
                val_label = val_label.to(device)
                val_input = val_input.to(device)

                output = model(val_input.float())
                num_correct += (output.argmax(dim=1) == val_label).sum().item()
                num_seen += val_label.size(0)

        # divide by the number of samples (not batches) so the result is a true accuracy
        val_accuracy = num_correct / num_seen if num_seen else 0.0
    return val_accuracy
    

In [None]:
def objective(trial):
    """Optuna objective: sample hyperparameters, train a fresh model with them and return its validation accuracy.

    A new ResNet9 is built for every trial so that trials do not keep training the
    weights left behind by the previous one and their scores stay comparable.
    """
    # suggest_float(..., log=True) replaces the deprecated suggest_loguniform
    params = {'learning_rate': trial.suggest_float('learning_rate', 1e-5, 1e-1, log=True),
                  'optimizer': trial.suggest_categorical("optimizer", ["Adam", "SGD"]),
               'weight_decay': trial.suggest_float('weight_decay', 1e-4, 1e-1, log=True),
                  'grad_clip': trial.suggest_float('grad_clip', 0.1, 0.4),
                     'epochs': trial.suggest_int('epochs', 2, 7)
              }

    trial_model = to_device(ResNet9(3, len(train.classes)), device)
    return train_and_evaluate(params, trial_model, train_dl, valid_dl)

In [None]:
optuna.create_study?

In [None]:
%%time
EPOCHS = 8
study = optuna.create_study(direction="maximize", sampler=optuna.samplers.TPESampler(), pruner=optuna.pruners.MedianPruner())
study.optimize(objective, n_trials=8)

In [None]:
best_trial = study.best_trial

for key, value in best_trial.params.items():
    print("{}: {}".format(key, value))

### Visualizing the hyperparameter tuning process

In [None]:
optuna.visualization.plot_intermediate_values(study)

In [None]:
optuna.visualization.plot_optimization_history(study)  #visualizing the tuning history

In [None]:
optuna.visualization.plot_param_importances(study) #visualizing the parameter importances

To-Do!

1. change hyper params and see how accuracy changes
2. import an image, export with dpi and see the difference
3. You have too many bullet points in overleaf, make them paragraph-y

---------------------

In [None]:
epochs = 5
max_lr = 0.01
grad_clip = 0.1
weight_decay = 1e-4
opt_func = torch.optim.SGD
#opt_func = torch.optim.Adam

In [None]:
%%time
# Train with the settings from the configuration cell above; weight_decay now comes from
# that cell's variable instead of a repeated literal, so editing the config takes effect.
history += fit_OneCycle(epochs, max_lr, model, train_dl, valid_dl, 
                             grad_clip=grad_clip, 
                             weight_decay=weight_decay, 
                             opt_func=opt_func)

In [None]:
def plot_accuracies(history):
    """Plot the 'val_accuracy' entry of every record in `history` against its position in the list."""
    accuracies = []
    for record in history:
        accuracies.append(record['val_accuracy'])
    plt.grid(color='#EAE4E3')
    plt.plot(accuracies, '-x', color='black')
    plt.xlabel('epoch')
    plt.ylabel('accuracy')
    plt.title('Accuracy vs. No. of epochs');

def plot_losses(history):
    """Plot the 'train_loss' and 'val_loss' entries of every record in `history`.

    Records without 'train_loss' contribute None to the training curve; each
    'val_loss' is moved to the CPU and converted to numpy before plotting.
    """
    train_losses = []
    val_losses = []
    for record in history:
        train_losses.append(record.get('train_loss'))
    for record in history:
        val_losses.append(record.get('val_loss').cpu().numpy())
    plt.grid(color='#EAE4E3')
    plt.plot(train_losses, '-bx')
    plt.plot(val_losses, '-rx')
    plt.xlabel('epoch')
    plt.ylabel('loss')
    plt.legend(['Training', 'Validation'])
    plt.title('Loss vs. No. of epochs');
    
def plot_lrs(history):
    """Plot every recorded learning rate, concatenating the 'lrs' lists of all records in `history`."""
    per_record_lrs = [record.get('lrs', []) for record in history]
    lrs = np.concatenate(per_record_lrs)
    plt.grid(color='#EAE4E3')
    plt.plot(lrs)
    plt.xlabel('Batch no.')
    plt.ylabel('Learning rate')
    plt.title('Learning Rate vs. Batch no.');

In [None]:
#Validation accuracy
plot_accuracies(history)

In [None]:
#Validation loss
plot_losses(history)

In [None]:
#Learning Rate overtime
plot_lrs(history)

In [None]:
# os.listdir('../input/imagedataset/ImageDataset/test_data')
# os.listdir('../test')

In [None]:
### Creating a new test dir because there was an svn file or folder found in the test dir
# makedirs creates the parent folder too, and exist_ok lets the cell be re-run safely
os.makedirs('../test_data/test', exist_ok=True)

In [None]:
test_dir_old = test_dir
test_dir_new = '../test_data'
print(test_dir_old)
print(test_dir_new)

In [None]:
test_dir_old

In [None]:
### Copying the .JPG files from the old test dir into the new test dir
num_moved = 0
for file_name in os.listdir(test_dir_old + 'test'):
    if not file_name.endswith('.JPG'):
        print('not going to move you!')
        continue
    shutil.copy(test_dir_old + 'test/' + file_name, test_dir_new + '/test/' + file_name)
    num_moved += 1
print(f"Number of files moved: {num_moved}")

In [None]:
len(os.listdir('../test_data/test')) #files have been moved

In [None]:
#Testing model on test data
test = ImageFolder(test_dir_new, transform=transforms.Compose(
                                        [transforms.Resize([256, 256]),
                                         transforms.ToTensor()]))

In [None]:
test

In [None]:
test_images = sorted(os.listdir(test_dir_new + '/test')) # since images in test folder are not in alphabetical order
#test_images

In [None]:
def predict_image(img, model):
    """Return the name of the class that `model` scores highest for a single image.

    Args:
        img: image tensor without a batch dimension; a batch axis is added here.
        model: the trained classifier.

    Returns:
        The entry of the notebook-level `train.classes` at the predicted index.
    """
    # Convert to a batch of 1
    xb = to_device(img.unsqueeze(0), device)
    # Inference only: use eval-mode BatchNorm statistics (a batch of one cannot
    # provide meaningful batch statistics) and skip building the autograd graph
    model.eval()
    with torch.no_grad():
        yb = model(xb)
    # Pick index with highest probability
    _, preds = torch.max(yb, dim=1)
    # Retrieve the class label
    return train.classes[preds[0].item()]

In [None]:
len(test)

In [None]:
img, label = test[0]
label

In [None]:
print(len(test_images))
print(len(test))

In [None]:
# predicting first image
img, label = test[-100]
plt.imshow(img.permute(1, 2, 0))
print('Label:', test_images[-100], ', Predicted:', predict_image(img, model))

In [None]:
f"{test_images[0].split('_')[0] + '_'  + test_images[0].split('_')[1]}"

In [None]:
# getting all predictions (actual label vs predicted)
# The actual label is the first two '_'-separated parts of each image's file name.
# The file name is read from test.samples so that it always belongs to the image
# being predicted, instead of relying on a separately sorted global list.
listt = []
for i in range(len(test)):
    img, _ = test[i]
    file_name = os.path.basename(test.samples[i][0])
    actual_label = '_'.join(file_name.split('_')[:2])
    listt.append((actual_label, predict_image(img, model)))
    
#listt

In [None]:
# Percentage of test images whose predicted label equals the label taken from the file name
count = sum(1 for actual, predicted in listt if actual == predicted)
test_accuracy = count/len(listt)*100
print(round(test_accuracy, 2))

In [None]:
model #check if the softmax layer is present before you save the model

In [None]:
# saving to the kaggle working directory ###check this again
PATH1 = './pt-mdlsd.pth'  
torch.save(model.state_dict(), PATH1)

In [None]:
PATH2 = './pt-mdl.pth' 
torch.save(model, PATH2)

### Checking feature importances

In [None]:
type(train_dl)


In [None]:
# Moving data into GPU -- the loaders were already wrapped earlier in the notebook,
# so only wrap them if that has not happened yet (avoids stacking wrappers on re-runs)
if not isinstance(train_dl, DeviceDataLoader):
    train_dl = DeviceDataLoader(train_dl, device)
if not isinstance(valid_dl, DeviceDataLoader):
    valid_dl = DeviceDataLoader(valid_dl, device)

In [None]:
# Shuffled loader over the test set, wrapped so that batches are moved to the active device
test_loader_r = torch.utils.data.DataLoader(test, 
                                            batch_size=batch_size,
                                            shuffle=True)

test_loader_r = DeviceDataLoader(test_loader_r, device)

batch = next(iter(test_loader_r))
images, _ = batch

# the first 20 images form the SHAP background, the next 4 are the images to explain;
# a dedicated name is used so the list of test file names in `test_images` is not overwritten
background = images[:20]
shap_images = images[20:24]

e = shap.DeepExplainer(model, background)
shap_values = e.shap_values(shap_images)

# reorder axes from (N, C, H, W) to (N, H, W, C) for plotting
shap_numpy = [np.swapaxes(np.swapaxes(s, 1, -1), 1, 2) for s in shap_values] #how can you add the predicted label vs true label to plot
test_numpy = np.swapaxes(np.swapaxes(shap_images.cpu().numpy(), 1, -1), 1, 2)
# plot the feature attributions; the images are passed un-negated because negating
# RGB values in [0, 1] (a trick for inverting grayscale digits) pushes them out of
# the displayable range
shap.image_plot(shap_numpy, test_numpy)
#plt.rcParams['figure.figsize'] = [10, 10]

In [None]:
# since shuffle=True, this is a random sample of test data
images, targets =  next(iter(test_loader_r))
BACKGROUND_SIZE = 20
background_images = images[:BACKGROUND_SIZE]
background_targets = targets[:BACKGROUND_SIZE].cpu().numpy()
#increase the size after you've fixed everything 

test_images = images[BACKGROUND_SIZE:BACKGROUND_SIZE+9]
test_targets = targets[BACKGROUND_SIZE:BACKGROUND_SIZE+9].cpu().numpy()
def show_attributions(model):
    """Draw one SHAP attribution figure per image in the notebook-level `test_images`.

    Reads the globals `test_images`, `test_targets`, `background_images` and `device`.
    Each figure shows the original image (titled with the actual and predicted class)
    followed by one SHAP panel per class, titled with the class name and the model
    output for that class formatted as a percentage.

    Returns:
        None -- the figures are shown as a side effect.
    """
    warnings.filterwarnings('ignore')

    # model output for the test images
    output = model(test_images.to(device))
    # index of the highest-scoring class for every image
    pred = output.max(1, keepdim=True)[1] 
    # convert to numpy only once to save time
    pred_np = pred.cpu().numpy() 

    expl = shap.DeepExplainer(model, background_images)
    train_classes = ['potato_early', 'potato_healthy', 'potato_late', 'tomato_early', 'tomato_healthy', 'tomato_late'] 
    for i in range(0, len(test_images)):
        torch.cuda.empty_cache()
        ti = test_images[[i]]
        sv = expl.shap_values(ti)
        # reorder axes from (N, C, H, W) to (N, H, W, C) for plotting
        sn = [np.swapaxes(np.swapaxes(s, 1, -1), 1, 2) for s in sv]
        tn = np.swapaxes(np.swapaxes(ti.cpu().numpy(), 1, -1), 1, 2)

        # Prepare the attribution plot, but do not draw it yet; titles are added below.
        # The image is passed un-negated: negating RGB values in [0, 1] pushes them
        # out of the displayable range.
        shap.image_plot(sn, tn, show=False)

        # Prepare to augment the plot
        fig = plt.gcf()
        allaxes = fig.get_axes()

        # First axis: the original image with the actual/predicted class as title
        allaxes[0].set_title('Actual: {}, Pred: {}'.format(train_classes[test_targets[i]], train_classes[pred_np[i][0]]), fontsize=10)
        allaxes[0].imshow(test_images[i].cpu().permute(1, 2, 0))

        # Axes 1 .. len(allaxes)-2 hold one SHAP panel per class; the last axis is
        # the colour scale and gets no title
        prob = output[i].detach().cpu().numpy()
        for x in range(1, len(allaxes)-1):
            allaxes[x].set_title('{}({:.2%})'.format(train_classes[x-1], prob[x-1]), fontsize=10)

        plt.show()

In [None]:
feature_attributions = show_attributions(model)
feature_attributions

In [None]:
type(feature_attributions)