In [124]:
import os                                      # for working with files
import cv2                                     #for reading the images
import shap                                    # for checking feature importances
import torch                                   # Pytorch module 
import shutil                                  #copying and moving files
import warnings                                #ignoring warnings
import numpy as np                             # for numerical computations
import pandas as pd                            # for working with dataframes
from PIL import Image                          # for checking images
import seaborn as sns
import torch.nn as nn                          # for creating  neural networks
import matplotlib.pyplot as plt                # for plotting informations on graph and images using tensors
import torch.nn.functional as F                # for functions for calculating loss
#from torchsummary import summary              # for getting the summary of our model
import sklearn.metrics as metrics              # using the sklearn metrics
import torchvision.transforms as T             #pytorch transformations
from skimage.segmentation import slic          #skimage slic method
from torch.utils.data import DataLoader        # for dataloaders 
from torchvision.utils import make_grid        # for data checking
import torchvision.transforms as transforms    # for transforming images into tensors 
from torchvision.datasets import ImageFolder   # for working with classes and images
from sklearn.metrics import confusion_matrix
from matplotlib.colors import LinearSegmentedColormap

%matplotlib inline

----------------

In [2]:
#changes the text on matplotlib plots to the computer modern font style
from matplotlib import font_manager
font_path = '../input/compumodern/cmu.serif-roman.ttf'
font_manager.fontManager.addfont(font_path)
prop = font_manager.FontProperties(fname=font_path)

plt.rcParams['font.family'] = 'sans-serif'
plt.rcParams['font.sans-serif'] = prop.get_name()

----------------------------------------

In [34]:
# for moving data into GPU (if available)
def get_default_device():
    """Pick the GPU if one is available, otherwise fall back to the CPU.

    Returns:
        torch.device: ``cuda`` when ``torch.cuda.is_available()`` is True, else ``cpu``.
    """
    # is_available must be *called*: the bare function object is always truthy,
    # which made this helper return "cuda" even on CPU-only machines.
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


# for moving data to device (CPU or GPU)
def to_device(data, device):
    """Move a tensor, or a (possibly nested) list/tuple of tensors, to ``device``.

    Lists and tuples come back as lists whose elements have been moved.
    """
    if isinstance(data, (list, tuple)):
        return [to_device(x, device) for x in data]
    return data.to(device, non_blocking=True)


# for loading in the device (GPU if available else CPU)
class DeviceDataLoader():
    """Wrap an iterable of batches (e.g. a DataLoader) so every batch is moved to ``device``."""
    def __init__(self, dl, device):
        self.dl = dl
        self.device = device

    def __iter__(self):
        """Yield a batch of data after moving it to device"""
        for b in self.dl:
            yield to_device(b, self.device)

    def __len__(self):
        """Number of batches"""
        return len(self.dl)


In [4]:
class SimpleResidualBlock(nn.Module):
    """Two same-padded 3x3 convolutions over 3 channels with an identity skip.

    Output is ``relu(conv2(relu(conv1(x)))) + x``; channel count and spatial
    size of the input are preserved.
    """

    def __init__(self):
        super().__init__()
        conv_opts = dict(in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1)
        self.conv1 = nn.Conv2d(**conv_opts)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(**conv_opts)
        self.relu2 = nn.ReLU()

    def forward(self, x):
        hidden = self.relu1(self.conv1(x))
        hidden = self.relu2(self.conv2(hidden))
        # ReLU can be applied before or after adding the input; here it is applied before.
        return hidden + x

In [5]:
# for calculating the accuracy
def accuracy(outputs, labels):
    """Fraction of rows whose highest-scoring column equals the label, as a 0-d CPU tensor."""
    predicted = torch.max(outputs, dim=1).indices
    n_correct = (predicted == labels).sum().item()
    return torch.tensor(n_correct / len(predicted))


# base class for the model
class ImageClassificationBase(nn.Module):
    """Training/validation helpers shared by the classification models.

    Subclasses implement ``forward``; batches are ``(images, labels)`` pairs.
    ``validation_step`` uses the module-level ``accuracy`` helper.
    """

    def training_step(self, batch):
        """Return the cross-entropy loss of the model on one batch."""
        images, labels = batch
        return F.cross_entropy(self(images), labels)

    def validation_step(self, batch):
        """Return the detached loss and the accuracy of the model on one batch."""
        images, labels = batch
        predictions = self(images)
        return {
            "val_loss": F.cross_entropy(predictions, labels).detach(),
            "val_accuracy": accuracy(predictions, labels),
        }

    def validation_epoch_end(self, outputs):
        """Average the per-batch dictionaries produced by ``validation_step``."""
        mean_loss = torch.stack([step["val_loss"] for step in outputs]).mean()
        mean_acc = torch.stack([step["val_accuracy"] for step in outputs]).mean()
        return {"val_loss": mean_loss, "val_accuracy": mean_acc}

    def epoch_end(self, epoch, result):
        """Print a one-line summary of an epoch from its result dictionary."""
        template = "Epoch [{}], last_lr: {:.5f}, train_loss: {:.4f}, val_loss: {:.4f}, val_acc: {:.4f}"
        print(template.format(epoch, result['lrs'][-1], result['train_loss'],
                              result['val_loss'], result['val_accuracy']))
        

In [6]:
# Architecture for training
# convolution block with BatchNormalization
def ConvBlock(in_channels, out_channels, pool=False):
    """3x3 same-padded conv -> BatchNorm -> in-place ReLU, optionally followed by 4x4 max-pooling.

    With ``pool=True`` each spatial dimension is divided by 4 (e.g. 256 -> 64).
    """
    stages = [
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
    ]
    stages += [nn.MaxPool2d(4)] if pool else []
    return nn.Sequential(*stages)


# resnet architecture 
class ResNet9(ImageClassificationBase):
    """ResNet-9 style classifier.

    Shapes for a 256x256 input (the size the test transform resizes to):
    conv1 -> 64x256x256, conv2 -> 128x64x64, conv3 -> 256x16x16,
    conv4 -> 512x4x4, classifier max-pool -> 512x1x1 -> num_diseases.

    NOTE(review): the classifier ends in ``nn.Softmax``, so ``forward`` returns
    probabilities. If the checkpoints were trained via ``training_step``,
    ``F.cross_entropy`` applied log-softmax on top of that (a double softmax).
    The layer is kept because downstream cells format the outputs as probabilities.
    """

    def __init__(self, in_channels, num_diseases):
        super().__init__()
        # Attribute names and the layer order inside ``classifier`` form the
        # state_dict keys of the saved checkpoints - do not rename or reorder.
        self.conv1 = ConvBlock(in_channels, 64)
        self.conv2 = ConvBlock(64, 128, pool=True)   # out_dim : 128 x 64 x 64
        self.res1 = nn.Sequential(ConvBlock(128, 128), ConvBlock(128, 128))

        self.conv3 = ConvBlock(128, 256, pool=True)  # out_dim : 256 x 16 x 16
        self.conv4 = ConvBlock(256, 512, pool=True)  # out_dim : 512 x 4 x 4
        self.res2 = nn.Sequential(ConvBlock(512, 512), ConvBlock(512, 512))

        self.classifier = nn.Sequential(
            nn.MaxPool2d(4),
            nn.Flatten(),
            nn.Linear(512, num_diseases),
            nn.Softmax(dim=1),
        )

    def forward(self, xb):
        """Map a batch ``xb`` of images to per-class probabilities."""
        features = self.conv2(self.conv1(xb))
        features = self.res1(features) + features
        features = self.conv4(self.conv3(features))
        features = self.res2(features) + features
        return self.classifier(features)

--------------------------

In [7]:
PATH = '../input/models/pt-mdlsd.pth'

model = ResNet9(3, 6)
# map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only machine;
# the model is moved to `device` later in the notebook.
model.load_state_dict(torch.load(PATH, map_location="cpu"))
model.eval()

-----------------------------------------------

In [8]:
model.conv1[0].weight

In [9]:
model.conv1

In [10]:
for param_tensor in reversed(model.state_dict()):
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

---------------------------

In [None]:
# Second checkpoint (author's note: "modelbfloat16 2" - not a very good model).
PATH_ri = '../input/models/ptri-mdlsd.pth'

model_ri = ResNet9(3, 6)
# map_location="cpu": GPU-saved weights load on any machine; moved to `device` later.
model_ri.load_state_dict(torch.load(PATH_ri, map_location="cpu"))
model_ri.eval()

In [35]:
# Device used for inference below; displayed to confirm which one was picked.
device = get_default_device()
device

### Testing and Evaluation

Find the maximum value of the saliency map and divide by it so that the maximum becomes 1, then set a threshold and check whether the saliency map changes after thresholding.

Focus on the saliency map of the … (TODO: this note is unfinished).


Reduce the training set (e.g. to 10%) for the model, as a way to deliberately lower the model's performance.

Manual evaluation.

In [11]:
# Layout of the raw test-data input folder.
os.listdir('../input/testing-data')

In [12]:
#Testing model on test data
# ImageFolder treats each sub-folder of test_dir as one class; the images live in
# test_dir/test (listed in the next cell).
test_dir = '../input/testing-data/test_data'
test = ImageFolder(test_dir, transform=transforms.Compose(
                                        [transforms.Resize([256, 256]),
                                         transforms.ToTensor()]))

In [13]:
# NOTE: this raw listing also includes any non-image entry in the folder (e.g. the svn file).
test_images = sorted(os.listdir(test_dir + '/test')) # since images in test folder are not in alphabetical order
#test_images

In [14]:
# Compare the raw directory-entry count with the number of samples ImageFolder found.
print(len(test_images))
print(len(test))

In [15]:
### Create a new test dir because an svn file/folder was found in the original test dir.
# makedirs(..., exist_ok=True) creates both levels and keeps the cell re-runnable;
# os.mkdir raised FileExistsError whenever the folders already existed.
os.makedirs('../test_data/test', exist_ok=True)

In [16]:
# Path of the original test-data folder.
test_dir

In [17]:
os.listdir('../')

In [18]:
# Keep the original folder as the copy source and use the new, clean folder as destination.
test_dir_old = test_dir
test_dir_new = '../test_data'
print(test_dir_old)
print(test_dir_new)

In [19]:
os.listdir(test_dir_new)

In [20]:
# One sub-folder per class, so ImageFolder can recover labels from folder names.
# The list is already alphabetical, which is the order ImageFolder assigns label indices in.
train_classes = ['potato_early', 'potato_healthy', 'potato_late', 'tomato_early', 'tomato_healthy', 'tomato_late']
print(train_classes)
for theclass in train_classes:
    # exist_ok=True keeps the cell re-runnable.
    os.makedirs(f"{test_dir_new}/test/{theclass}", exist_ok=True)

In [21]:
os.listdir('../test_data/test')

In [22]:
os.listdir(test_dir_old+'/test')[-1].split('.')[0].split('_')[0] + '_' + os.listdir(test_dir_old+'/test')[-1].split('.')[0].split('_')[1]

In [23]:
os.listdir('../test_data/test')


In [24]:
img = 'tomato_late_23.JPG'
print(f"{test_dir_new+'/test/'+theclass+'/'+img}")

In [25]:
### Copy every test image from the old test dir into its class sub-folder of the new test dir
num_moved = 0
for img in os.listdir(test_dir_old+'/test'):
    if img.endswith('.JPG'):
        # e.g. 'tomato_late_23.JPG' -> 'tomato_late'
        theclass = '_'.join(img.split('.')[0].split('_')[:2])
        shutil.copy(f"{test_dir_old}/test/{img}", f"{test_dir_new}/test/{theclass}/{img}")
        num_moved += 1
    elif img.endswith('svn'):
        print('not going to move you!')
print(f"Number of files moved: {num_moved}")

# NOTE: a second loop that also copied every image flat into test_dir_new/test was
# removed. ImageFolder only reads class sub-folders, so those copies were never used
# and only doubled the disk usage.

In [26]:
test_dir_old

In [27]:
# Should list exactly the six class folders.
os.listdir(test_dir_new+'/test')

In [28]:
#Testing model on test data
# Re-create the test dataset from the per-class folders; each label is an index
# into the alphabetically sorted folder names.
test = ImageFolder(test_dir_new+'/test', transform=transforms.Compose(
                                        [transforms.Resize([256, 256]),
                                         transforms.ToTensor()]))

In [29]:
img, label = test[-1]
label

In [30]:
type(test)

test_loader = torch.utils.data.DataLoader(test, 
    batch_size=batch_size, shuffle=True)

DeviceDataLoader(test, device)

In [31]:
test

In [32]:
def predict_image(img, model, classes=None):
    """Predict the class name of a single image tensor.

    Args:
        img: one image tensor without a batch dimension (e.g. C x H x W as
            produced by ``transforms.ToTensor``).
        model: an ``nn.Module`` returning one score per class for each batch row.
        classes: optional list of class names indexed by output position;
            defaults to the six potato/tomato classes.

    Returns:
        str: the class name with the highest score.
    """
    if classes is None:
        classes = ['potato_early', 'potato_healthy', 'potato_late', 'tomato_early', 'tomato_healthy', 'tomato_late']
    # Run on whatever device holds the model's weights (CPU for parameter-free
    # models) rather than depending on a global `device`.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")
    # Convert to a batch of 1
    xb = img.unsqueeze(0).to(model_device)
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        yb = model(xb)
    # Pick index with highest score and map it to its class name
    _, preds = torch.max(yb, dim=1)
    return classes[preds[0].item()]

In [36]:
# Move the model to the selected device before running inference.
model.to(device)

In [None]:
model_ri.to(device)

In [37]:
# Before the rebuild below, test_images still lists the raw (old) test folder.
print(len(test_images))
print(len(test))

In [38]:
# File names from the per-class folders, sorted - presumably to match the order in
# which ImageFolder enumerates samples (sorted classes, sorted files within each).
collected_names = []
for tclass in train_classes:
    collected_names.extend(os.listdir(test_dir_new + '/test/' + tclass))
test_images = sorted(collected_names)

In [39]:
# The rebuilt file-name list and the dataset should now have the same length.
print(len(test_images))
print(len(test))

In [40]:
# predicting last image
# (test_images[-1] is the file name, printed as a human-readable label)
img, label = test[-1]
plt.imshow(img.permute(1, 2, 0))
print('Label:', test_images[-1], ', Predicted:', predict_image(img, model))

In [None]:
# predicting last image with the second model (model_ri)
img, label = test[-1]
plt.imshow(img.permute(1, 2, 0))
print('Label:', test_images[-1], ', Predicted:', predict_image(img, model_ri))

In [41]:
# Same class list as the one hard-coded in predict_image.
train_classes

In [42]:
print(len(test))
print(len(test_images))

In [43]:
# getting all predictions (actual label vs predicted) for model 1
# The true class comes from the dataset's own label (test.classes[label]) rather
# than from parsing a separately built file-name list, so the comparison cannot
# silently drift if the two orderings ever differ.
listt = []
for img, label in test:
    listt.append((test.classes[label], predict_image(img, model)))

count = sum(actual == predicted for actual, predicted in listt)
# NaN instead of ZeroDivisionError for an empty test set
test_accuracy = count / len(listt) * 100 if listt else float('nan')
print(round(test_accuracy, 2), '%')

In [None]:
# getting all predictions (actual label vs predicted) for model 2
# The true class comes from the dataset's own label (test.classes[label]) rather
# than from parsing a separately built file-name list, so the comparison cannot
# silently drift if the two orderings ever differ.
listtri = []
for img, label in test:
    listtri.append((test.classes[label], predict_image(img, model_ri)))

count = sum(actual == predicted for actual, predicted in listtri)
# NaN instead of ZeroDivisionError for an empty test set
test_accuracy = count / len(listtri) * 100 if listtri else float('nan')
print(round(test_accuracy, 2), '%')

### Feature Attribution

batch_size=32
test_loader = torch.utils.data.DataLoader(test, 
    batch_size=batch_size, shuffle=True)

test_loader = DeviceDataLoader(test_loader, device)

#Trying out shap
# since shuffle=True, this is a random sample of test data
batch = next(iter(test_loader))
images, _ = batch

# First 20 images: SHAP background (reference) set; next 4: images to explain.
background = images[:20]
# Separate name so the `test_images` file-name list is not overwritten.
shap_test_images = images[20:24]

e = shap.DeepExplainer(model, background)
shap_values = e.shap_values(shap_test_images)

# N x C x H x W (PyTorch) -> N x H x W x C (the layout image_plot expects)
shap_numpy = [np.swapaxes(np.swapaxes(s, 1, -1), 1, 2) for s in shap_values]
test_numpy = np.swapaxes(np.swapaxes(shap_test_images.cpu().numpy(), 1, -1), 1, 2)
# plot the feature attributions. The RGB values are in [0, 1]; negating them (as in
# SHAP's MNIST example) makes matplotlib clip the first column to black.
shap.image_plot(shap_numpy, test_numpy)
# NOTE: image_plot presumably sizes its own figure, so this only affects later figures.
plt.rcParams['figure.figsize'] = [10, 10]

In [44]:
type(test)

In [45]:
batch_size=32
test_loader_r = torch.utils.data.DataLoader(test, 
                                            batch_size=batch_size,
                                            shuffle=True, 
                                            num_workers=2, 
                                            pin_memory=True)

test_loader_r = DeviceDataLoader(test_loader_r, device)
test_loader_r

-----------------------

### Evaluation Metrics!

#### Confusion Matrix

In [46]:
# Predicted and true class indices over the whole test set.
# (approach adapted from https://stackoverflow.com/questions/63647547/how-to-find-confusion-matrix-and-plot-it-for-image-classifier-in-pytorch)
predictions, targets = [], []
with torch.no_grad():  # inference only - no autograd graph needed
    for images, labels in test_loader_r:
        # The model already ends in a softmax, so its outputs are probabilities;
        # take the argmax directly (torch.exp() on them did not change the argmax).
        pred = torch.argmax(model(images), dim=1)
        predictions.extend(pred.cpu().tolist())
        targets.extend(labels.cpu().tolist())

assert len(predictions) == len(targets) == len(test)
predictions[:10]

In [51]:
# True class indices collected in the previous cell (same order as `predictions`).
targets
#print(len(targets))

In [91]:
def create_confusion_matrix(y_true, y_pred, classes):
    """Build and plot a confusion matrix from two parallel lists of class indices.

    :param list y_true: true class index of every sample
    :param list y_pred: predicted class index of every sample (same length as y_true)
    :param dict classes: mapping of class name -> class index; names label the axes
    :return: float numpy array of shape (len(classes), len(classes)); entry [t][p]
        counts the samples with true class t that were predicted as p
    :raises ValueError: if y_true and y_pred differ in length
    """
    if len(y_true) != len(y_pred):
        raise ValueError(f"y_true has {len(y_true)} entries but y_pred has {len(y_pred)}")

    amount_classes = len(classes)
    matrix = np.zeros((amount_classes, amount_classes))
    for target, output in zip(y_true, y_pred):
        matrix[target][output] += 1

    # Order tick labels by class index, independent of dict insertion order.
    names = sorted(classes, key=classes.get)
    fig, ax = plt.subplots(1)
    ax.matshow(matrix)
    ax.set_xticks(np.arange(len(names)))
    ax.set_yticks(np.arange(len(names)))
    ax.set_xticklabels(names)
    ax.set_yticklabels(names)

    plt.setp(ax.get_xticklabels(), rotation=90, ha="left", rotation_mode="anchor")
    plt.setp(ax.get_yticklabels(), rotation=45, ha="right", rotation_mode="anchor")

    plt.show()

    return matrix

In [92]:
# Class name -> label index (0..5), in train_classes order.
classes = {name: index for index, name in enumerate(train_classes)}
len(classes)

In [93]:
# create_confusion_matrix both plots and returns the matrix: call it once and keep
# the result (it was previously called twice, drawing the same plot twice).
cmatrix = create_confusion_matrix(targets, predictions, classes)
cmatrix

In [100]:
cf_matrix = confusion_matrix(targets, predictions)
cf_matrix

In [101]:
class_names = train_classes
dataframe = pd.DataFrame(cf_matrix, index=class_names, columns=class_names)
dataframe


In [123]:
fig, ax = plt.subplots(figsize=(8, 8))

# Heatmap of the sklearn confusion matrix: integer annotations plus a colour bar
sns.heatmap(dataframe, annot=True, cbar=True, cmap="OrRd", fmt="d", ax=ax)  # alternatives tried: 'purples', 'PuRd'
ax.set_title("Confusion Matrix", size=12)
ax.set_ylabel("Target Class", size=12)
ax.set_xlabel("Predicted Class", size=12)
fig.tight_layout()
fig.savefig('../working/cmatrx_orangewcbar.png', dpi=600, bbox_inches="tight")
plt.show()

In [None]:
def precision(outputs, labels):
    """Support-weighted precision of a batch of predictions, as a 0-d tensor.

    Args:
        outputs: per-class scores, assumed (batch, num_classes); any device.
        labels: true class indices, assumed (batch,); any device.

    Returns:
        torch.Tensor: sklearn ``precision_score(..., average='weighted')``.
    """
    # Local import: precision_score was referenced but never imported. The original
    # body was also unindented and used curly quotes (both SyntaxErrors).
    from sklearn.metrics import precision_score

    op = outputs.detach().cpu()
    la = labels.detach().cpu()
    _, preds = torch.max(op, dim=1)
    return torch.tensor(precision_score(la.numpy(), preds.numpy(), average='weighted'))

---------------------

# Debug: print the label tensor of every batch (iterates the whole test loader).
for images, labels in test_loader_r:
    print(labels)

In [None]:
warnings.filterwarnings('ignore', category=DeprecationWarning)

# Exploratory version of show_attributions (defined further down). It stays at top
# level because later cells inspect `ti`, `sv`, `output` and `pred_np` from it.
images, targets =  next(iter(test_loader_r))
BACKGROUND_SIZE = 20
background_images = images[:BACKGROUND_SIZE]
background_targets = targets[:BACKGROUND_SIZE].cpu().numpy()
#increase the size after you've fixed everything 

test_images = images[BACKGROUND_SIZE:BACKGROUND_SIZE+3]
test_targets = targets[BACKGROUND_SIZE:BACKGROUND_SIZE+3].cpu().numpy()

# Class probabilities for the images to explain (the model ends in a softmax);
# this forward pass needs no autograd graph.
with torch.no_grad():
    output = model(test_images.to(device))
# Index of the most probable class, shape (n, 1)
pred = output.max(1, keepdim=True)[1]
# convert to numpy only once to save time
pred_np = pred.cpu().numpy()

expl = shap.DeepExplainer(model, background_images)

for i in range(len(test_images)):
    torch.cuda.empty_cache()
    ti = test_images[[i]]  # list index keeps the batch dimension
    # The work must run *inside* catch_warnings for the "ignore" filter to apply;
    # previously the context manager was exited before anything ran.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        sv = expl.shap_values(ti)
        # N x C x H x W -> N x H x W x C, the layout image_plot expects
        sn = [np.swapaxes(np.swapaxes(s, 1, -1), 1, 2) for s in sv]
        tn = np.swapaxes(np.swapaxes(ti.cpu().numpy(), 1, -1), 1, 2)
        # Prepare the attribution plot without drawing it; it is annotated below.
        shap.image_plot(sn, -tn, show=False)

    fig = plt.gcf()
    allaxes = fig.get_axes()

    # Actual/predicted class above the input image
    allaxes[0].set_title('Actual: {}, Pred: {}'.format(train_classes[test_targets[i]], train_classes[pred_np[i][0]]), fontsize=10)
    # image_plot shows the negated input in the first column; redraw the real image.
    allaxes[0].imshow(test_images[i].cpu().permute(1, 2, 0))

    # Probability of each class above its attribution map
    # (the last axis is skipped - presumably image_plot's colour bar)
    prob = output[i].cpu().numpy()
    for x in range(1, len(allaxes)-1):
        allaxes[x].set_title('{}({:.2%})'.format(train_classes[x-1], prob[x-1]), fontsize=10)

#plt.savefig('../working/figure saved.png', dpi=600)
plt.show()

In [None]:
# type(sv[0])

In [None]:
# sv[0].shape
# Largest absolute SHAP value over all classes and all channels of the last explained
# image. (The previous `[:, :-1]` slice dropped the last channel, assuming each sv[k]
# is (1, C, H, W) - TODO confirm.)
max_val = np.max([np.max(np.abs(s)) for s in sv])
print(f'max val: {max_val}')

# `sv` holds the attributions of the *last* image explained in the loop above, so use
# that image's predicted class (pred_np[0] belonged to the first image).
inds = pred_np[-1][0]
print(f'inds: {inds}')

In [None]:
# #len(sv) 6 values present here
# plt.imshow(sv[0].reshape(sv[0].shape[-1], sv[0].shape[-2], sv[0].shape[-3]))

In [None]:
output.max(1)

In [None]:
#Dealing with segments slic
img = ti[0].cpu().permute(1, 2, 0).cpu().numpy()
segments_slic = slic(img, n_segments=100, compactness=30, sigma=3)

plt.imshow(segments_slic);
plt.axis('off');

In [None]:
segments_slic.shape

In [None]:
def fill_segmentation(values, segmentation):
    """Paint each segment of a label image with its value.

    Pixels labelled ``k`` (for ``k`` in ``range(len(values))``) receive
    ``values[k]``; pixels with any other label stay 0. Returns a float array
    with the same shape as ``segmentation``.

    NOTE(review): depending on the skimage version, slic labels may start at 1;
    check that value positions line up with segment labels.
    """
    filled = np.zeros(segmentation.shape)
    for label, value in enumerate(values):
        filled[segmentation == label] = value
    return filled

In [None]:
# len(sv[inds][0][0][0])  -> 256([])
sv[inds].shape

In [None]:
# make a color map
from matplotlib.colors import LinearSegmentedColormap
colors = []
for l in np.linspace(1, 0, 100):
    colors.append((245 / 255, 39 / 255, 87 / 255, l))
for l in np.linspace(0, 1, 100):
    colors.append((24 / 255, 196 / 255, 93 / 255, l))
cm = LinearSegmentedColormap.from_list("shap", colors)

In [None]:
#Dealing with segments slic
img = ti[0].cpu().permute(1, 2, 0).numpy()
segments_slic = slic(img, n_segments=100, compactness=30, sigma=3)

fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(12,4))

# Per-superpixel attribution for the predicted class: sum the pixel SHAP values over
# the colour channels, then average them inside each SLIC segment. (The previous
# fill_segmentation(sv[inds][0][0][0], ...) painted segment k with one pixel's value:
# row 0, column k, first channel.)
# Assumes sv[inds] is (1, C, H, W), matching the H x W segmentation - TODO confirm.
pixel_attr = sv[inds][0].sum(axis=0)
m = np.zeros(segments_slic.shape)
for seg_label in np.unique(segments_slic):
    in_segment = segments_slic == seg_label
    m[in_segment] = pixel_attr[in_segment].mean()

axes[0].imshow(img)
# Symmetric colour limits taken from the data (replaces the hard-coded 0.0008).
max_val = np.abs(m).max()
if max_val == 0:
    max_val = 1.0
im = axes[1].imshow(m, cmap=cm, vmin=-max_val, vmax=max_val)
axes[0].axis('off')
axes[1].axis('off')
plt.tight_layout()
plt.show()

In [None]:
plt.imshow(m)

In [None]:
# Image.fromarray(img)
Image.fromarray((img)*255).convert('LA')

In [None]:
# theimg = Image.fromarray(img.astype('uint8')*255, 'RGB')
# theimg

new_im = Image.fromarray(img)
new_im.show()

In [None]:
Image.fromarray((img * 1).astype(np.uint8)).convert('RGB')

In [None]:
plt.imshow(img)

In [None]:
np.zeros((1, 50)).shape

plt.imshow(ti[0].cpu().permute(1, 2, 0)) #works

plt.imshow(np.squeeze(sv[0]).reshape(256, 256, 3))

In [None]:
# Shape of the single-image batch explained last.
ti.shape

In [None]:
plt.imshow(ti[0].cpu().permute(1, 2, 0).cpu().numpy())

-----------

In [None]:
# NOTE: by now `targets` is the label tensor of the batch drawn for SHAP,
# not the full list built for the confusion matrix.
targets

In [None]:
#This is the actual function which works well
images, targets =  next(iter(test_loader_r))
BACKGROUND_SIZE = 20
background_images = images[:BACKGROUND_SIZE]
background_targets = targets[:BACKGROUND_SIZE].cpu().numpy()

test_images = images[BACKGROUND_SIZE:BACKGROUND_SIZE+7]
test_targets = targets[BACKGROUND_SIZE:BACKGROUND_SIZE+7].cpu().numpy()

def show_attributions(model):
    """Plot DeepExplainer SHAP attributions of `model` for every image in `test_images`.

    Reads the module-level `test_images` / `test_targets` (images to explain and their
    labels), `background_images` (SHAP reference set) and `device`. Draws one figure
    per image: the image titled with its actual and predicted class, then one
    attribution map per class titled with that class's predicted probability.
    """
    # Class probabilities for the explained images (the model ends in a softmax);
    # this forward pass needs no autograd graph.
    with torch.no_grad():
        output = model(test_images.to(device))
    # Index of the most probable class, shape (n, 1), converted to numpy once
    pred_np = output.max(1, keepdim=True)[1].cpu().numpy()

    expl = shap.DeepExplainer(model, background_images)
    train_classes = ['potato_early', 'potato_healthy', 'potato_late', 'tomato_early', 'tomato_healthy', 'tomato_late'] 

    for i in range(len(test_images)):
        torch.cuda.empty_cache()
        ti = test_images[[i]]  # list index keeps the batch dimension
        # The work must run *inside* catch_warnings for the "ignore" filter to apply
        # (and the filters are restored afterwards); previously the context manager
        # was exited before anything ran.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            sv = expl.shap_values(ti)
            # N x C x H x W -> N x H x W x C, the layout image_plot expects
            sn = [np.swapaxes(np.swapaxes(s, 1, -1), 1, 2) for s in sv]
            tn = np.swapaxes(np.swapaxes(ti.cpu().numpy(), 1, -1), 1, 2)
            # Prepare the attribution plot without drawing it; it is annotated below.
            shap.image_plot(sn, -tn, show=False)

        allaxes = plt.gcf().get_axes()

        # Actual/predicted class above the input image
        allaxes[0].set_title('Actual: {}, Pred: {}'.format(train_classes[test_targets[i]], train_classes[pred_np[i][0]]), fontsize=10)
        # image_plot shows the negated input in the first column; redraw the real image.
        allaxes[0].imshow(test_images[i].cpu().permute(1, 2, 0))

        # Probability of each class above its attribution map
        # (the last axis is skipped - presumably image_plot's colour bar)
        prob = output[i].cpu().numpy()
        for x in range(1, len(allaxes)-1):
            allaxes[x].set_title('{}({:.2%})'.format(train_classes[x-1], prob[x-1]), fontsize=10)

        #plt.savefig('../working/figure saved.png', dpi=600)
    plt.show()

# NOTE: PYTORCH_CUDA_ALLOC_CONF is read from the *environment* by PyTorch's CUDA
# caching allocator, so it must be set before CUDA is first used, e.g.
# os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128" at the top of the
# notebook. The previous plain Python assignment here had no effect.

In [None]:
# SHAP attribution figures for the first model.
show_attributions(model) 

In [None]:
# Same figures for the second model.
show_attributions(model_ri)

In [None]:
# NOTE(review): the savefig calls above are commented out, so this file must come
# from an earlier run - confirm it exists before running this cell.
plt.imshow(plt.imread('../working/figure saved.png'))
plt.axis('off')
fig = plt.gcf()
axes = fig.get_axes()

In [None]:
axes[0]

1. Set a threshold across the pixels to see what happens.
2. Find a way to make a subplot that holds all the individual plots.

What can we do with the saliency maps?
Change the image classification section in the report.
Add more on deep learning.
Avoid large bullet points.

---------

In [None]:
# First, get all the plots onto one plt.subplots figure,
# then access the individual plots,
# then draw the segmentation on each individual plot.

In [None]:
len_ti = 7  # one row per explained image (show_attributions explains 7)
fig, axes = plt.subplots(len_ti, 1)
# Plain loop: a list comprehension used only for its side effect displayed a list
# of None as the cell output.
for axi in axes.ravel():
    axi.set_axis_off()

In [None]:
# NOTE(review): `preprocess_input` is not defined or imported anywhere in this
# notebook, so this help query cannot find it; remove before sharing.
preprocess_input?