# A Primer on Artificial Intelligence in Plant Digital Phenomics: Embarking on the Data to Insights Journey (*Tutorial*)

This tutorial is a supplement to the paper **A Primer on Artificial Intelligence in Plant Digital Phenomics: Embarking on the Data to Insights Journey** (submitted to *Trends in Plant Science, 2021*) by Antoine L. Harfouche, Farid Nakhle, Orlando G. Sardella, Antoine H.
Harfouche, Eli Dart, and Daniel Jacobson.

Read the accompanying paper [here](https://doi.org).

This interactive tutorial aims to train, for the first time, an interpretable by design model to identify and classify cassava plant diseases, and to explain its predictions.

This tutorial covers:
- Downloading and extracting a dataset to Google Colab from a remote repository.
- Exploring the dataset class distribution frequency using descriptive data analysis.
- Training the interpretable by design 'this looks like that' explainable artificial intelligence (X-AI) algorithm with an augmented training dataset.
- Analyzing the model performance by generating a confusion matrix using a test dataset.
- Generating explanations for the predictions made by the model.

**NB:** 
- Basic data preprocessing steps, including data splitting, balancing, cropping, and segmenting, are explained in our previous tutorial **Ready, Steady, Go AI: A Practical Tutorial on Fundamentals of Artificial Intelligence and Its Applications in Phenomics Image Analysis**, where their code is implemented in interactive notebooks hosted on our GitHub repository at https://github.com/HarfoucheLab/Ready-Steady-Go-AI. These steps will be referred to in this tutorial where needed, with direct links to the corresponding notebooks.
- The 'this looks like that' algorithm was introduced in: Chaofan Chen, Oscar Li, Chaofan Tao, Alina Jade Barnett, Jonathan Su, and Cynthia Rudin. 2019. This looks like that: deep learning for interpretable image recognition. Proceedings of the 33rd International Conference on Neural Information Processing Systems. Curran Associates Inc., Red Hook, NY, USA, Article 801, 8930–8941.
- The cassava dataset consists of 21,397 labeled images collected during a regular survey in Uganda where images were crowdsourced from farmers taking photos of their gardens, and annotated by experts at the National Crops Resources Research Institute (NaCRRI) in collaboration with the AI lab at Makerere University, Kampala. The dataset is publicly available on the Kaggle repository at https://www.kaggle.com/c/cassava-leaf-disease-classification.



Before diving into the code to train the interpretable by design 'this looks like that' X-AI algorithm, the next section will briefly cover the differences between post-hoc explainable models and interpretable by design models.

# Opening the Black Box *vs.* Designing a Transparent Glass Box

AI models are commonly referred to as black boxes because they do not reveal their internal mechanisms to their users. Such models are created directly from data, and not even the scientists who created them can understand or explain exactly what is happening inside them or how they made a specific prediction.
As AI becomes more advanced and widely adopted, scientists are challenged to comprehend and retrace how a model came to a prediction.



In an attempt to open black box models, approaches that make the inner workings of AI models understandable to humans have been developed. These approaches consist of creating a second (post-hoc) model to explain the first black box model. Post-hoc models can be classified based on whether they are applicable to all AI algorithms (i.e., model-agnostic) or only to one AI algorithm (i.e., model-specific). They often employ data perturbation strategies, which involve modifying the input data and observing the changes in the black box model predictions. Based on these changes, they identify which parts of the data were important for the predictions and thus generate an explanation. However, according to [Cynthia Rudin](https://doi.org/10.1038/s42256-019-0048-x), these explanations are unreliable because they cannot have perfect fidelity with respect to the original model. Rudin explains that if the explanation were completely faithful to what the black box model computes, the post-hoc model would equal the black box model, and one would not need the black box model in the first place, only the post-hoc one. Since this is not the case, any explanation method for a black box model can be an inaccurate representation of the original model. Even a post-hoc model that predicts almost identically to a black box model might use completely different features, and is thus not faithful to the computation of the black box one.
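The perturbation strategy described above can be illustrated with a minimal occlusion sketch. It is not the method of any particular post-hoc library: `toy_black_box` is a hypothetical stand-in for a trained model, and `occlusion_map` is an illustrative helper that masks one patch at a time and records how much the score drops.

```python
import numpy as np

def toy_black_box(image):
    """Hypothetical stand-in for a black box model: the score is the
    mean brightness of the top-left 4x4 corner of the image."""
    return float(image[:4, :4].mean())

def occlusion_map(model, image, patch=4, fill=0.0):
    """Mask one patch at a time and record the drop in the model score."""
    base = model(image)
    h, w = image.shape
    heat = np.zeros((h // patch, w // patch))
    for i in range(0, h - patch + 1, patch):
        for j in range(0, w - patch + 1, patch):
            perturbed = image.copy()
            perturbed[i:i + patch, j:j + patch] = fill  # perturb the input
            heat[i // patch, j // patch] = base - model(perturbed)
    return heat

image = np.ones((8, 8))
heat = occlusion_map(toy_black_box, image)
print(heat)  # only the top-left patch changes the score
```

A large value in `heat` marks a region whose removal changes the prediction. Note that this map is computed from the outside and says nothing about how the model reached its score internally, which is the fidelity concern raised above.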

As a solution, other approaches aimed to develop models that are interpretable by design; they provide their own explanations, which are faithful to what the model actually computes. For example, in image analysis, the 'this looks like that' algorithm appends a special prototype layer to the end of a deep convolutional neural network where, during training, the prototype layer finds parts of training images that act as prototypes for each class. Thus, during testing, when a new test image needs to be evaluated, the network finds parts of the test image that are similar to the prototypes it learned during training. The final class prediction of the network is based on the weighted sum of similarities to the prototypes; this is the sum of evidence throughout the image for a particular class. The explanations given by the network are the prototypes. These explanations are the actual computations of the model, and are not post-hoc explanations.
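The prototype-based computation described above can be sketched in a few lines of NumPy. This is a simplified illustration under stated assumptions, not the code used in this tutorial: the function name `prototype_logits`, the toy feature vectors, and the weights are hypothetical, and the log-shaped similarity is our reading of the original publication.

```python
import numpy as np

def prototype_logits(patches, prototypes, class_weights, eps=1e-4):
    """patches: (n_patches, d) features of image parts;
    prototypes: (n_protos, d); class_weights: (n_classes, n_protos)."""
    # Squared L2 distance between every patch and every prototype
    diff = patches[:, None, :] - prototypes[None, :, :]
    dist = (diff ** 2).sum(axis=-1)            # (n_patches, n_protos)
    # For each prototype, keep the closest image patch
    min_dist = dist.min(axis=0)                # (n_protos,)
    # Assumed log activation: large when the distance is small
    similarity = np.log((min_dist + 1.0) / (min_dist + eps))
    # Class evidence is a weighted sum of prototype similarities
    logits = class_weights @ similarity
    return similarity, logits

# Toy example: 3 patches, 2 prototypes (one per class), 2 classes
patches = np.array([[0.0, 0.0], [1.0, 0.0], [5.0, 5.0]])
prototypes = np.array([[1.0, 0.0],    # prototype tied to class 0
                       [9.0, 9.0]])   # prototype tied to class 1
class_weights = np.array([[1.0, 0.0],
                          [0.0, 1.0]])
similarity, logits = prototype_logits(patches, prototypes, class_weights)
predicted = int(np.argmax(logits))
print(similarity, logits, predicted)
```

Here the second patch coincides with the class 0 prototype, so its similarity dominates and class 0 is predicted. The explanation is simply which patch matched which prototype.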



In this tutorial, we will show how to use the 'this looks like that' interpretable by design algorithm to identify and classify cassava plant diseases, and provide an explanation for the predictions.

# Coding

Before starting, you should note that graphics processing units (GPUs) can dramatically increase training speed thanks to their processing cores initially designed to process visual data such as videos and images.
It is recommended to use a GPU instance for faster training. By default, this notebook runs on GPU. If you would like to change the instance type, check Colab docs [here](https://colab.research.google.com/notebooks/gpu.ipynb).

Let us start by checking the number and type of GPUs that Google Colab assigned us for this session:

In [None]:
import tensorflow as tf
import torch

# Number of GPUs visible to TensorFlow
print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))

# Device that PyTorch will use
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
print()

# Additional info when using CUDA
if device.type == 'cuda':
    print(torch.cuda.get_device_name(0))

As noted above, the original version of the cassava dataset consists of 21,397 labeled images and is publicly available on the Kaggle repository at https://www.kaggle.com/c/cassava-leaf-disease-classification.

However, we will use a manually cleaned version of the dataset consisting of 17,190 images.
In addition, the dataset has been split (60% training, 20% validation, 20% testing), cropped, and balanced (see Figure 4 in the paper). Please visit the following notebooks for our tutorials on:
- [Data Splitting Using split-folders](https://colab.research.google.com/github/faridnakhle/RSG/blob/main/1.%20RSG_Data%20splitter.ipynb)
- [Image Cropping Using the 'you only look once' (YOLO) AI Algorithm](https://colab.research.google.com/github/faridnakhle/RSG/blob/main/2.%20RSG_Leaf%20cropper.ipynb)
- [Image Segmentation Using SegNet AI Algorithm](https://colab.research.google.com/github/faridnakhle/RSG/blob/main/3.%20RSG_Leaf%20segmenter.ipynb)
- [Data Balancing by Oversampling with Geometric Transformations Using Augmentor](https://colab.research.google.com/github/faridnakhle/RSG/blob/main/4.%20RSG_Oversample%20with%20Augmentor.ipynb)
- [Data Balancing by Oversampling with Synthetic Data Using Deep Convolutional Generative Adversarial Network (DCGAN) AI Algorithm](https://colab.research.google.com/github/faridnakhle/RSG/blob/main/5.%20RSG_Oversample%20with%20DCGAN.ipynb)
- [Data Balancing by Downsampling Using K Nearest Neighbor AI Algorithm](https://colab.research.google.com/github/faridnakhle/RSG/blob/main/6.%20RSG_Downsample%20with%20KNN.ipynb)


The following code block will download the prepared cassava dataset which is hosted on Google Drive. 

In [None]:
import requests

def download_file_from_google_drive(file_id, destination):
    """Download a publicly shared Google Drive file to `destination`."""
    URL = "https://docs.google.com/uc?export=download"

    session = requests.Session()

    response = session.get(URL, params={'id': file_id}, stream=True)
    # Large files trigger a confirmation step; its token is stored in a cookie
    token = get_confirm_token(response)

    if token:
        params = {'id': file_id, 'confirm': token}
        response = session.get(URL, params=params, stream=True)

    save_response_content(response, destination)

def get_confirm_token(response):
    for key, value in response.cookies.items():
        if key.startswith('download_warning'):
            return value

    return None

def save_response_content(response, destination):
    CHUNK_SIZE = 32768

    with open(destination, "wb") as f:
        for chunk in response.iter_content(CHUNK_SIZE):
            if chunk: # filter out keep-alive new chunks
                f.write(chunk)



In [None]:
file_id = '13jwC684Sg1wWLhF7SjPIlsfJNuKqJ_IQ'
destination = '/content/dataset.zip'
download_file_from_google_drive(file_id, destination)
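If the request does not return the archive itself (for instance, an HTML page is saved instead), the later unzip step will fail. A minimal sketch of a check that could be run after the download is shown here; `looks_like_zip` is a hypothetical helper, not part of the tutorial's code, and the demonstration uses temporary files so it can run anywhere.

```python
import os
import tempfile
import zipfile

def looks_like_zip(path):
    """Return True if path exists and is a readable zip archive."""
    return os.path.isfile(path) and zipfile.is_zipfile(path)

# Demonstration on temporary files: one real archive, one HTML page
with tempfile.TemporaryDirectory() as tmp:
    good = os.path.join(tmp, 'good.zip')
    with zipfile.ZipFile(good, 'w') as zf:
        zf.writestr('cdsv5/train/0/example.jpg', b'')
    bad = os.path.join(tmp, 'bad.zip')
    with open(bad, 'w') as f:
        f.write('<html>not an archive</html>')
    good_ok, bad_ok = looks_like_zip(good), looks_like_zip(bad)

print(good_ok, bad_ok)
```

In the notebook, calling `looks_like_zip('/content/dataset.zip')` before extracting would flag a failed download early.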

Next, we will create a folder called 'dataset' under /content/ and extract the downloaded dataset to it. As a result, three folders should be created under /content/dataset/cdsv5/ as follows:
- train: the folder containing the training dataset.
- train_aug: the folder containing the augmented training dataset.
- val: the folder containing the validation dataset.

In [None]:
# Unzip the dataset
!mkdir -p /content/dataset
!apt-get install -y unzip
!unzip /content/dataset.zip -d /content/dataset/
!rm /content/dataset.zip  # save some space
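To confirm that the extraction produced the three folders listed above, a small check along the following lines may help. `missing_subfolders` is a hypothetical helper introduced for illustration, and the demonstration runs on a throwaway directory rather than on the real dataset.

```python
import os
import tempfile

def missing_subfolders(root, expected=('train', 'train_aug', 'val')):
    """Return the expected sub-folders that are absent under root."""
    return [name for name in expected
            if not os.path.isdir(os.path.join(root, name))]

# Demonstration on a throwaway directory that lacks 'val'
with tempfile.TemporaryDirectory() as tmp:
    os.makedirs(os.path.join(tmp, 'train'))
    os.makedirs(os.path.join(tmp, 'train_aug'))
    missing = missing_subfolders(tmp)

print(missing)
```

In the notebook, `missing_subfolders('/content/dataset/cdsv5/')` should return an empty list if the archive was extracted as expected.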

Now that our dataset is ready, let us take a quick look at the differences between the class distributions in the original and the balanced training sets. To do so, the next code block will count the images in every class of the original training set, and a later block will do the same for the augmented one. A bar plot will be used to display the results.

In [None]:
import os
import matplotlib.pyplot as plt
import seaborn as sns

# Count the images in each class folder of the original training set
train_dir = '/content/dataset/cdsv5/train/'
train_classes = sorted(os.listdir(train_dir))
train_classes_count = [len(os.listdir(os.path.join(train_dir, c))) for c in train_classes]

# Bar plot of the class distribution
plt.figure(figsize=(15, 10))
g = sns.barplot(x=train_classes, y=train_classes_count)
plt.xticks(rotation=30, ha='right')

We can see that the training set is highly unbalanced. Let us check the distribution in the balanced folder by running the next code block.

In [None]:
# Count the images in each class folder of the augmented (balanced) training set
train_dir = '/content/dataset/cdsv5/train_aug/'
train_classes = sorted(os.listdir(train_dir))
train_classes_count = [len(os.listdir(os.path.join(train_dir, c))) for c in train_classes]

# Bar plot of the class distribution
plt.figure(figsize=(15, 10))
g = sns.barplot(x=train_classes, y=train_classes_count)
plt.xticks(rotation=30, ha='right')

We can see that the data is balanced.
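Beyond reading the bar plots, the degree of imbalance can be summarized with a single number. The sketch below is one possible way to do this; `count_images_per_class` and `imbalance_ratio` are hypothetical helpers, and the demonstration uses a temporary folder with made-up file counts rather than the cassava data.

```python
import os
import tempfile

def count_images_per_class(root):
    """Map each class sub-folder of root to the number of files it contains."""
    return {
        name: len(os.listdir(os.path.join(root, name)))
        for name in sorted(os.listdir(root))
        if os.path.isdir(os.path.join(root, name))
    }

def imbalance_ratio(counts):
    """Largest class size divided by the smallest (1.0 means balanced)."""
    return max(counts.values()) / min(counts.values())

# Demonstration: class '0' holds 2 files and class '1' holds 6
with tempfile.TemporaryDirectory() as tmp:
    for name, n in {'0': 2, '1': 6}.items():
        os.makedirs(os.path.join(tmp, name))
        for k in range(n):
            open(os.path.join(tmp, name, f'{k}.jpg'), 'w').close()
    counts = count_images_per_class(tmp)

ratio = imbalance_ratio(counts)
print(counts, ratio)
```

Applied to the train and train_aug folders, a ratio well above 1 would indicate imbalance, while a ratio close to 1 would indicate a balanced set.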

Now that we have everything set, we will clone our implementation of the 'this looks like that' algorithm hosted on our [GitHub repository](https://github.com/HarfoucheLab/A-Primer-on-AI-in-Plant-Digital-Phenomics).

It is worth mentioning that our version of the code was modified to support distributed computing and thus can run on a cluster with multiple nodes and multiple GPUs.

In [None]:
!git clone https://github.com/HarfoucheLab/A-Primer-on-AI-in-Plant-Digital-Phenomics.git

The code should now be located under /content/A-Primer-on-AI-in-Plant-Digital-Phenomics/py.



Next, we define some settings that specify the architecture of our network and include other parameters, such as the dataset directories (training and validation sets), the batch size, the number of workers, and the learning rates.

Feel free to change those parameters according to your needs and hardware.

In [None]:
settings = """base_architecture = 'densenet161'
img_size = 224
prototype_shape = (2000, 128, 1, 1)
num_classes = 5
prototype_activation_function = 'log'
add_on_layers_type = 'regular'

experiment_run = '001'

data_path = '/content/dataset/cdsv5/'
train_dir = data_path + 'train_aug/'
test_dir = data_path + 'val/'
train_push_dir = data_path + 'train/'
train_batch_size = 40 #80
test_batch_size = 40
train_push_batch_size = 64

num_workers=3
min_saving_accuracy=0.05

joint_optimizer_lrs = {'features': 1e-4,
                       'add_on_layers': 3e-3,
                       'prototype_vectors': 3e-3}
joint_lr_step_size = 5

warm_optimizer_lrs = {'add_on_layers': 3e-3,
                      'prototype_vectors': 3e-3}

last_layer_optimizer_lr = 1e-4

coefs = {
    'crs_ent': 1,
    'clst': 0.8,
    'sep': -0.08,
    'l1': 1e-4,
}

num_train_epochs = 1000
num_warm_epochs = 5

push_start = 10
push_epochs = [i for i in range(num_train_epochs) if i % 10 == 0] """

Now that the settings are defined, we write them to a file called settings.py so that 'this looks like that' can locate and read those settings.

In [None]:
# Write the settings string to settings.py; the file is closed automatically
with open("/content/A-Primer-on-AI-in-Plant-Digital-Phenomics/py/settings.py", "w") as text_file:
    text_file.write(settings)
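Before launching a long training run, it can be useful to sanity-check a few of the values defined in the settings. The sketch below copies the relevant values from the settings above and derives two quantities. It rests on two assumptions taken from the original 'this looks like that' implementation, which may not hold for every variant of the code: prototypes are divided evenly among classes, and prototypes are only pushed at epochs that are at or after `push_start`.

```python
# Values copied from the settings defined above
prototype_shape = (2000, 128, 1, 1)
num_classes = 5
num_train_epochs = 1000
push_start = 10
push_epochs = [i for i in range(num_train_epochs) if i % 10 == 0]

# Assumption: prototypes are split evenly among the classes
assert prototype_shape[0] % num_classes == 0
prototypes_per_class = prototype_shape[0] // num_classes

# Assumption: a push only happens at epochs >= push_start
effective_push_epochs = [e for e in push_epochs if e >= push_start]

print(prototypes_per_class)        # 400
print(len(effective_push_epochs))  # 99
```

If you change `num_classes` or the first entry of `prototype_shape`, keeping the latter a multiple of the former avoids leaving some prototypes unassigned under the even-split assumption.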

Now that we are all set, we are ready to start the training process!
Because this step is data-intensive and time-consuming, we have included our pretrained model with this notebook, so you can skip the next code block.

The parameters indicate the number of nodes and GPUs on which we would like to train our model.

In [None]:
%cd /content/A-Primer-on-AI-in-Plant-Digital-Phenomics/py/
!python3 mainDistributed.py --nodes 1 --gpus 1 --nr 0

The next code block will download and extract our pretrained model hosted on Google Drive. The downloaded archive will be extracted to /content/pretrained.

In [None]:
%cd /content/
file_id = '12ugCaMfPdylDPPmfqzoOMWtB55k0L9tL'
destination = '/content/pretrained.zip'
download_file_from_google_drive(file_id, destination)

In [None]:
!mkdir -p /content/pretrained
!unzip /content/pretrained.zip -d /content/pretrained/
!rm -R /content/pretrained.zip

Now that we have our model ready, it's time to test its performance!
The first step is to download the testing dataset from the same Google Drive.

In [None]:
#download test set:
file_id = '1Ruy2At0G3oLlA1Gb9gz1-aMpcfJ6653B'
destination = '/content/dataset_test.zip'
download_file_from_google_drive(file_id, destination)


In [None]:
!unzip /content/dataset_test.zip -d /content/dataset/cdsv5/
!rm -R /content/dataset_test.zip

The testing set is now located under /content/dataset/cdsv5/test/

The next code block will attempt to classify every image in the test dataset, and calculate the overall accuracy of the model, along with its confusion matrix.

In [None]:
%cd /content/A-Primer-on-AI-in-Plant-Digital-Phenomics/py/
!python3 RunTestAndConfusionMatrix.py
%cd /content/

Let us display the normalized confusion matrix to get a better overview of the model performance.

In [None]:
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

img = mpimg.imread('/content/confusion_matrix.png')
plt.figure(figsize=(10,10))
plt.imshow(img)
plt.show()
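For readers unfamiliar with normalized confusion matrices, the following sketch shows one common way to compute them: each row (true class) is divided by its total, so the diagonal holds the per-class recall. The counts are made up for illustration and are not results from our model, and we do not claim that RunTestAndConfusionMatrix.py uses exactly this normalization.

```python
import numpy as np

def normalize_rows(cm):
    """Divide each row of a confusion matrix by its sum (empty rows stay 0)."""
    cm = np.asarray(cm, dtype=float)
    row_sums = cm.sum(axis=1, keepdims=True)
    return np.divide(cm, row_sums, out=np.zeros_like(cm), where=row_sums != 0)

# Toy 2-class confusion matrix: rows are true classes, columns are predictions
cm = np.array([[8, 2],
               [1, 9]])
cm_norm = normalize_rows(cm)
accuracy = np.trace(cm) / cm.sum()

print(cm_norm)
print(accuracy)
```

Row normalization makes classes of different sizes comparable, which matters when the test set itself is not balanced.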

Now that we are satisfied with the model performance, we can generate the explanation of a specific prediction using the following code:

In [None]:
!python3 /content/A-Primer-on-AI-in-Plant-Digital-Phenomics/py/local_analysis.py -modeldir /content/pretrained/ -model 240_12push0.8884.pth -imgdir /content/dataset/cdsv5/test/1/ -img 931787054.jpg -imgclass 1

The next code block will plot the most activated region in the image, on which the algorithm based its prediction.

In [None]:
img = mpimg.imread('/content/dataset/cdsv5/test/1/pretrained/240_12push0.8884.pth/top-1_class_prototypes/prototype_activation_map_by_top-1_prototype.png')
plt.figure(figsize=(10,10))
plt.imshow(img)
plt.show()

Next, we display some of the prototypes that the activated region resembles, which allows us to interpret the prediction.

In [None]:
# Display the top-1, top-2, and top-17 activated prototypes
prototype_dir = '/content/dataset/cdsv5/test/1/pretrained/240_12push0.8884.pth/top-1_class_prototypes/'
for rank in (1, 2, 17):
    img = mpimg.imread(prototype_dir + 'top-%d_activated_prototype.png' % rank)
    plt.figure(figsize=(3,3))
    plt.imshow(img)
    plt.show()

Let us generate the explanation for another example!

In [None]:
#Explanation
!python3 /content/A-Primer-on-AI-in-Plant-Digital-Phenomics/py/local_analysis.py -modeldir /content/pretrained/ -model 240_12push0.8884.pth -imgdir /content/dataset/cdsv5/test/1/ -img 1074333151.jpg -imgclass 1

In [None]:
img = mpimg.imread('/content/dataset/cdsv5/test/1/pretrained/240_12push0.8884.pth/top-1_class_prototypes/prototype_activation_map_by_top-1_prototype.png')
plt.figure(figsize=(10,10))
plt.imshow(img)
plt.show()

In [None]:
# Display the top-1, top-2, and top-17 activated prototypes
prototype_dir = '/content/dataset/cdsv5/test/1/pretrained/240_12push0.8884.pth/top-1_class_prototypes/'
for rank in (1, 2, 17):
    img = mpimg.imread(prototype_dir + 'top-%d_activated_prototype.png' % rank)
    plt.figure(figsize=(3,3))
    plt.imshow(img)
    plt.show()