<a href="https://colab.research.google.com/github/HadarRosenwald/severity-detection/blob/main/severity_detection_with_cv.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

<a href="https://colab.research.google.com/github/HadarRosenwald/severity-detection/blob/main/TabularModel.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip -q install torchxrayvision
!pip -q install image_tabular

In [2]:
# torch.__version__

In [3]:
import matplotlib.pyplot as plt
import os
import shutil
import torch
import torchxrayvision as xrv
import numpy as np
import pandas as pd
from torch import nn
from torch import optim
import torch.nn.functional as F
from torchvision import datasets, transforms, models
import fastai
from fastai.tabular.data import TabularList
import image_tabular as imtab
from sklearn.model_selection import train_test_split, KFold
from sklearn.utils import class_weight

import math


from fastai.vision import *
from fastai.tabular import *
from image_tabular.core import *
from image_tabular.dataset import *
from image_tabular.model import *
from image_tabular.metric import *

In [4]:
!rm -rf covid-chestxray-dataset

In [5]:
!git clone https://github.com/ieee8023/covid-chestxray-dataset
d = xrv.datasets.COVID19_Dataset(imgpath="covid-chestxray-dataset/images/",csvpath="covid-chestxray-dataset/metadata.csv")

Cloning into 'covid-chestxray-dataset'...
remote: Enumerating objects: 3641, done.[K
remote: Total 3641 (delta 0), reused 0 (delta 0), pack-reused 3641[K
Receiving objects: 100% (3641/3641), 632.96 MiB | 36.09 MiB/s, done.
Resolving deltas: 100% (1450/1450), done.
Checking out files: 100% (1174/1174), done.


# **Configurations**

In [6]:
# data split
test_pct=0.2
valid_pct=0.2

# models param
##TAB
tab_out_sz = 18 # output size that will be concatenated with the CNN, same output size
dropout_prob_tab = 0.2
tab_layers = [100, 200] # the sizes of the hidden fully connected layers between the input (after embedding) and before the classification layer. The number of hidden layers is determined by the length of the list.
# TODO: tune this, for our data size, [100,200] seems highly overfitted. according to the rule of thumb bellow, we should have 0
# https://forums.fast.ai/t/an-attempt-to-find-the-right-hidden-layer-size-for-your-tabular-learner/45714
# len_train = 110; alpha = 2; n_input=8; n_output=18; io=n_input+n_output; numHiddenLayers=2
# tab_layers = [(len_train//(alpha*(io)))//numHiddenLayers]*numHiddenLayers

##CNN
cnn_out_sz = 18 # following xrv.models.DenseNet output layer
image_size = 224 # to fit xrv.models.DenseNet
# image_resize_method = ResizeMethod.SQUISH
# image_convert_mode = 'L' #for greyscale
image_convert_mode = 'RGB'

##CNN_TAB
cnn_tabular_dropout_prob = 0.2
cnn_tabular_layers = [cnn_out_sz + tab_out_sz, 32]
cnn_tabular_out_sz = 6 #number of classes
batch_size = 64
n_epoch = 2


# misc
seed=42
data_path = Path("./covid-chestxray-dataset/")
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# **Creating label**

In [7]:
def generate_label(x):
    # no data
    if np.all(pd.isna([x.survival, x.intubated, x.went_icu, x.needed_supplemental_O2])):
        return np.NaN

    # didn't survive
    if x.survival=='N':
        return '5'

    #either survived or survival is unknown
    if x.intubated == 'Y':
        return '4'
    if x.went_icu == 'Y' and x.needed_supplemental_O2 == 'Y':
        return '3'
    if x.went_icu == 'Y':
        return '2'
    if x.needed_supplemental_O2 == 'Y':
        return '1'
    return '0'

In [8]:
metadata = d.csv
metadata['severity_class']=metadata.apply(generate_label, axis=1)

# **Tabular Data handling**

## **Avoiding confounders**

In [9]:
filtered_metadata = metadata.loc[(metadata.view!="APS") & (metadata.offset>=0) & (metadata.offset<=8) & (metadata.intubation_present != 'Y') & (metadata.in_icu != 'Y')]
filtered_metadata = filtered_metadata[['index','patientid','sex','age','RT_PCR_positive','temperature','pO2_saturation', 'leukocyte_count', 'neutrophil_count', 'lymphocyte_count', 'severity_class' ,'filename']]

## **Handling missing data**

In [10]:
filtered_metadata = filtered_metadata.dropna(subset=['severity_class'], how='any')

In [11]:
filtered_metadata.head()

Unnamed: 0,index,patientid,sex,age,RT_PCR_positive,temperature,pO2_saturation,leukocyte_count,neutrophil_count,lymphocyte_count,severity_class,filename
0,0,2,M,65.0,Y,,,,,,1,auntminnie-a-2020_01_28_23_51_6665_2020_01_28_...
1,1,2,M,65.0,Y,,,,,,1,auntminnie-b-2020_01_28_23_51_6665_2020_01_28_...
2,2,2,M,65.0,Y,,,,,,1,auntminnie-c-2020_01_28_23_51_6665_2020_01_28_...
3,3,2,M,65.0,Y,,,,,,1,auntminnie-d-2020_01_28_23_51_6665_2020_01_28_...
4,4,4,F,52.0,Y,,,,,,0,nejmc2001573_f1a.jpeg


## **Train test and validation split**

In [12]:
# kf = KFold(n_splits=9)
# for train, test in kf.split(filtered_metadata):
#   print("train", filtered_metadata.iloc[train])
#   print("-------------")
#   print("test", filtered_metadata.iloc[test])
#   print("-------------")

In [13]:
# train_df, test_df = train_test_split(filtered_metadata, test_size=test_pct)

# # idx for validation, shared by image and tabular data
# val_idx = get_valid_index(train_df, valid_pct=valid_pct, seed=seed)

## **Preparing fastai LabelLists**

### **Features**

In [14]:
# Features with categorical values
cat_names = ['sex', 'RT_PCR_positive']

# Features with continious values
cont_names = ['age', 'temperature', 'pO2_saturation', 'leukocyte_count', 'neutrophil_count', 'lymphocyte_count']

### **Labels**

In [15]:
# Target
dep_var = ['severity_class']
procs = [FillMissing, Categorify, Normalize]

### **Ensambling the tabular dataset**

In [16]:
def create_tab_data(train_df, test_df): 
# FillMissing fills the missing values in continuous columns. Catagorical 
# variables are left untouched (their missing value will be replaced by code 0 
# in the TabularDataBunch). The fill stratagy is MEDIAN; nans are replaced by 
# the median value of the column
  val_idx = get_valid_index(train_df, valid_pct=valid_pct, seed=seed)
  tab_data = (TabularList.from_df(train_df, path=data_path, cat_names=cat_names, cont_names=cont_names, procs=procs)
            .split_by_idx(val_idx)
            .label_from_df(cols=dep_var))

  test_tab_data = TabularList.from_df(test_df, cat_names=cat_names, cont_names=cont_names, processor=tab_data.train.x.processor)
  tab_data = tab_data.add_test(test_tab_data)

  return tab_data

In [17]:
# Iterating over tab_data items, printing class name and items len, using 
# `show_some()` to return the representation of the first 5 elements in `items`.
# tab_data
# Note that the Test LabelList has no labels. Like in Kaggle competitions.

### one example from the tabular data

In [18]:
# print(f"features: {tab_data.train[8][0]}")
# print(f"class: {tab_data.train[8][1]}")

# **Image Data handling**

## **Creating test and train image folders**

In [19]:
def create_sub_image_folder(dataframe, imgs_type, sub_dir_by_lable : bool):
  filtered_imgpath = d.imgpath + '/' + imgs_type + '/'
  filtered_filenames = dataframe.filename + ';' + dataframe.severity_class
  if not os.path.exists(filtered_imgpath):
      os.mkdir(filtered_imgpath)
  if sub_dir_by_lable:
    for severity_class in list(dataframe.severity_class):
        if not os.path.exists(filtered_imgpath + f'/{severity_class}'):
            os.mkdir(filtered_imgpath + f'/{severity_class}')
  for file_name_label in filtered_filenames:
      file_name, label = file_name_label.split(';')
      src = d.imgpath + file_name
      dst = filtered_imgpath + label + '/' + file_name if sub_dir_by_lable else filtered_imgpath + file_name
      if not os.path.exists(dst):
          shutil.copyfile(src, dst)

In [20]:
def create_image_folders(train_df, test_df):
  create_sub_image_folder(train_df, 'train', False)
  create_sub_image_folder(test_df, 'test', False)
  
filtered_img_base_path = d.imgpath + '/'

## **Preparing fastai LabelLists**

In [21]:
# tfms = get_transforms(xtra_tfms=crop_pad(size=image_size))

def create_image_data(train_df, test_df):
  val_idx = get_valid_index(train_df, valid_pct=valid_pct, seed=seed)
  image_data = (ImageList.from_df(train_df, path=d.imgpath, cols="filename", 
                                folder="train", convert_mode = image_convert_mode)
                        .split_by_idx(val_idx)
                        .label_from_df(cols=dep_var)
                        # ).transform(tfms, size=image_size, resize_method=image_resize_method)
                        # ).transform(tfms)
                        ).transform([crop_pad(), crop_pad()], size=image_size)
            
  test_image_data = ImageList.from_df(test_df, path=d.imgpath, cols="filename",
                                    folder="test", convert_mode = image_convert_mode)
  image_data = image_data.add_test(test_image_data)
  return image_data

### one example from the image data

In [22]:
# print(f"Class: {image_data.train[8][1]}")
# image_data.train[8][0]

# **Integrate image and tabular data**

In [23]:
def create_databunch(image_data, tab_data):
  integrate_train, integrate_val, integrate_test = get_imagetabdatasets(image_data, tab_data)

  return DataBunch.create(integrate_train, integrate_val, integrate_test, path=data_path, bs=batch_size)

In [24]:
# TODO do we need this? I think it ruins the images, making them 3 channels again
# # image normalization with imagenet_stats
# db.norm, db.denorm = normalize_funcs_image_tab(*imagenet_stats)
# db.add_tfm(db.norm)

In [25]:
# x, y = next(iter(db.train_dl))

# print(f"x holds {len(x)} items")
# print(f"first item - batch of images ({x[0].shape})")
# print(f"second item - holds both categorial ({x[1][0].shape}) and continuous ({x[1][1].shape}) tabular data")

# print(f"y is the targets ({y.shape})")

In [26]:
# x, y = next(iter(db.train_dl))
# y

# **The models**

### **The tabular model**

In [27]:
# embedding sizes of categorical data. Return the default embedding sizes suitable for this data. Using the rule of thumb - min(600, round(1.6 * n_cat**0.56))
# TODO think if we want to replace that with one-hot, since they are binary
def create_tab_model(tab_data):
  emb_szs = tab_data.train.get_emb_szs()
  print(f"emb_szs: {emb_szs}")

  # The tabular model
  tabular_model = TabularModel(emb_szs=emb_szs, n_cont = len(cont_names), out_sz=tab_out_sz, layers=tab_layers, ps=dropout_prob_tab)
  return tabular_model

## **The CNN model** 
Using pretrained xrv.models.DenseNet for transfer learning

In [28]:
# cnn_model = xrv.models.DenseNet(weights="densenet121-res224-all")
cnn_model = models.densenet121(pretrained=True)
cnn_model.classifier

cnn_model.classifier = torch.nn.Linear(1024, 18)

In [29]:
# Don't backprop model parameters!
for param in cnn_model.parameters():
    param.requires_grad = False

In [30]:
#TODO check if this is necessary!

# # New classifier
# cnn_model.classifier = nn.Sequential(nn.Linear(1024, 512),
#                           nn.ReLU(),
#                           nn.Dropout(0.2),
#                           nn.Linear(512,256),
#                           nn.ReLU(),
#                           nn.Dropout(0.2),
#                           nn.Linear(256,18))

# #criterion = nn.NLLLoss()

# # Training only the classifier parameters, cnn_model parameters remains unchanged
# optimizer = optim.RMSprop(cnn_model.classifier.parameters(), lr=0.004)

## **The integrated CNN Tabular model** 

In [31]:
def create_img_tab_model(cnn_model, tabular_model):
  return CNNTabularModel(cnn_model,
                                  tabular_model,
                                  layers = cnn_tabular_layers,
                                  ps=cnn_tabular_dropout_prob,
                                  out_sz=cnn_tabular_out_sz).to(device)

In [32]:
# img_tabular_model

In [33]:
# check model output dimension, should be (batch_size, 6)
# img_tabular_model(*x).shape

In [34]:
def create_loss_func(train_df):
  print("Class distribution of train set - unbalanced:")
  print(train_df.severity_class.value_counts().sort_index())

  weights = class_weight.compute_class_weight('balanced', 
                                                  np.unique(train_df.severity_class),
                                                  train_df.severity_class)

  print(f"\nThe weights (calculated with respect to label distribution of train set): {np.round(weights,2)}")
  return CrossEntropyFlat(weight=torch.FloatTensor(weights).to(device))

In [35]:
# adjust loss function weight because the dataset is unbalanced


In [36]:
# def accuracy_multi(preds, targs, thresh=0.5):
#     print(f"preds shape: {preds.shape} \n{preds}")
#     print(f"targs shape: {targs.shape} \n{targs}")
#     return ((preds>thresh)==targs).float().mean()

In [37]:
def train_model():
  kf = KFold(n_splits=9)
  for train, test in kf.split(filtered_metadata):
    train_df = filtered_metadata.iloc[train]
    test_df = filtered_metadata.iloc[test]
    tab_data = create_tab_data(train_df, test_df)
    create_image_folders(train_df, test_df)
    image_data = create_image_data(train_df, test_df)
    db = create_databunch(image_data, tab_data)
    tabular_model = create_tab_model(tab_data)
    img_tabular_model = create_img_tab_model(cnn_model, tabular_model)
    loss_func = create_loss_func(train_df)
    learn = Learner(db, img_tabular_model, metrics=[accuracy, Recall(), Precision(), error_rate], loss_func=loss_func)
    learn.fit_one_cycle(n_epoch, 1e-4)

In [None]:
train_model()

emb_szs: [(3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3)]
Class distribution of train set - unbalanced:
0    71
1     3
2    13
3     3
4    22
5    10
Name: severity_class, dtype: int64

The weights (calculated with respect to label distribution of train set): [0.29 6.78 1.56 6.78 0.92 2.03]


epoch,train_loss,valid_loss,accuracy,recall,precision,error_rate,time
0,1.958286,1.807427,0.0,,,1.0,00:38
1,1.907005,1.807324,0.0,,,1.0,00:39


emb_szs: [(3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3)]
Class distribution of train set - unbalanced:
0    68
1     9
2    10
3     3
4    21
5    11
Name: severity_class, dtype: int64

The weights (calculated with respect to label distribution of train set): [0.3  2.26 2.03 6.78 0.97 1.85]


epoch,train_loss,valid_loss,accuracy,recall,precision,error_rate,time
0,1.923203,1.807605,0.041667,0.166667,,0.958333,00:39
1,1.87084,1.805138,0.041667,0.166667,,0.958333,00:40


emb_szs: [(3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3)]
Class distribution of train set - unbalanced:
0    74
1     9
2    13
3     3
4    15
5     8
Name: severity_class, dtype: int64

The weights (calculated with respect to label distribution of train set): [0.27 2.26 1.56 6.78 1.36 2.54]


epoch,train_loss,valid_loss,accuracy,recall,precision,error_rate,time
0,2.133552,1.807859,0.375,0.115385,,0.625,00:38
1,1.963368,1.808281,0.375,0.269231,,0.625,00:39


emb_szs: [(3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3)]
Class distribution of train set - unbalanced:
0    78
1     9
2     7
3     3
4    17
5     9
Name: severity_class, dtype: int64

The weights (calculated with respect to label distribution of train set): [0.26 2.28 2.93 6.83 1.21 2.28]


epoch,train_loss,valid_loss,accuracy,recall,precision,error_rate,time
0,1.817336,1.826887,0.166667,,,0.833333,00:37
1,1.889769,1.828909,0.166667,,,0.833333,00:38


emb_szs: [(3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3)]
Class distribution of train set - unbalanced:
0    67
1     9
2    12
3     3
4    21
5    11
Name: severity_class, dtype: int64

The weights (calculated with respect to label distribution of train set): [0.31 2.28 1.71 6.83 0.98 1.86]


epoch,train_loss,valid_loss,accuracy,recall,precision,error_rate,time
0,1.99113,1.781175,0.125,,,0.875,00:37
1,1.955652,1.775029,0.166667,,,0.833333,00:37


emb_szs: [(3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3)]
Class distribution of train set - unbalanced:
0    67
1     9
2    12
3     2
4    22
5    11
Name: severity_class, dtype: int64

The weights (calculated with respect to label distribution of train set): [ 0.31  2.28  1.71 10.25  0.93  1.86]


epoch,train_loss,valid_loss,accuracy,recall,precision,error_rate,time
0,1.930975,1.806851,0.083333,,,0.916667,00:39
1,1.967312,1.799544,0.083333,,,0.916667,00:40


emb_szs: [(3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3)]
Class distribution of train set - unbalanced:
0    73
1     7
2    12
3     2
4    18
5    11
Name: severity_class, dtype: int64

The weights (calculated with respect to label distribution of train set): [ 0.28  2.93  1.71 10.25  1.14  1.86]


epoch,train_loss,valid_loss,accuracy,recall,precision,error_rate,time
0,1.800573,1.793766,0.458333,,,0.541667,00:41
1,1.860122,1.798596,0.416667,,,0.583333,00:38


emb_szs: [(3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3)]
Class distribution of train set - unbalanced:
0    71
1     8
2    13
3     3
4    21
5     7
Name: severity_class, dtype: int64

The weights (calculated with respect to label distribution of train set): [0.29 2.56 1.58 6.83 0.98 2.93]


epoch,train_loss,valid_loss,accuracy,recall,precision,error_rate,time
0,1.9192,1.764069,0.125,,,0.875,00:40
1,1.941746,1.764558,0.083333,,,0.916667,00:40


emb_szs: [(3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3)]
Class distribution of train set - unbalanced:
0    71
1     9
2    12
3     2
4    19
5    10
Name: severity_class, dtype: int64

The weights (calculated with respect to label distribution of train set): [ 0.29  2.28  1.71 10.25  1.08  2.05]


epoch,train_loss,valid_loss,accuracy,recall,precision,error_rate,time
0,1.910973,1.819939,0.0,,,1.0,00:41


In [38]:
# learn = Learner(db, img_tabular_model, metrics=[accuracy, ROCAUC()], loss_func=torch.nn.NLLLoss)
# learn = Learner(db, img_tabular_model, metrics=[accuracy, ROCAUC()], loss_func=loss_func) # <- TODO check why ROCAUC doesnt work


#TODO! check optimizer

In [39]:
learn.model

NameError: ignored

In [None]:
#TODO - check if we need this and if so, make this work

# # organize layer groups in order to use differential learning rates provided by fastai
# # the first two layer groups are earlier layers of resnet
# # the last layer group consists of the fully connected layers of cnn model, tabular model,
# # and final fully connected layers for the concatenated data
# learn.layer_groups = [nn.Sequential(*flatten_model(cnn_model.layer_groups[0])),
#                       nn.Sequential(*flatten_model(cnn_model.layer_groups[1])),
#                       nn.Sequential(*(flatten_model(cnn_model.layer_groups[2]) +
#                                       flatten_model(integrate_model.tabular_model) +
#                                       flatten_model(integrate_model.layers)))]

# **Training**

In [None]:
# TODO currently doesnt work. maybe because we don't have layer groups. check
# # find learning rate to train the last layer group first 
# learn.freeze()
# learn.lr_find()
# learn.recorder.plot()

In [None]:
# train


In [None]:
# # unfreeze all layer groups to train the entire model using differential learning rates
# learn.unfreeze()
# learn.fit_one_cycle(n_epoch, slice(1e-6, 1e-4))

# **Prediction**

In [None]:
# make predictions for the test set
preds, y = learn.get_preds(DatasetType.Test)

In [None]:
preds

In [None]:
y

In [None]:
print(y.shape)
print(preds.shape)

In [None]:
# for test_images, test_labels in trainloader:  
#     plt.imshow(test_images[0][0])
#     break

# **Explainability**

In [None]:
# saliency_transform = transforms.Compose([transforms.Grayscale(num_output_channels=1),
#                                        transforms.RandomRotation(30),
#                                        transforms.Resize((224, 224)),
#                                        transforms.RandomHorizontalFlip(),
#                                        transforms.ToTensor(),
#                                        transforms.Normalize((0.5), (0.5)),
#                                        transforms.Lambda(lambda x: x[None])])
# def create_saliency_map(image_filename='000001-17.jpg'):
#   image = PIL.Image.open(f'/content/covid-chestxray-dataset/images/{image_filename}')
#   image = saliency_transform(image)
#   image = image.reshape(1, 1, image_size, image_size)

#   image.requires_grad_()
#   output = cnn_model(image)

#   # Catch the output
#   output_idx = output.argmax()
#   output_max = output[0, output_idx]

#   # Do backpropagation to get the derivative of the output based on the image
#   output_max.backward()

#   saliency, _ = torch.max(image.grad.data.abs(), dim=1) 
#   saliency = saliency.reshape(image_size, image_size)

#   # Reshape the image
#   image = image.reshape(image_size, image_size)

#   # Visualize the image and the saliency map
#   fig, ax = plt.subplots(1, 2)
#   ax[0].imshow(image.cpu().detach().numpy(), cmap="gray")
#   ax[0].axis('off')
#   ax[1].imshow(saliency.cpu(), cmap='hot')
#   ax[1].axis('off')
#   plt.tight_layout()
#   fig.suptitle('The Image and Its Saliency Map')
#   plt.show()




# Preprocess the image
def preprocess(image, size=224):
    transform = transforms.Compose([
        transforms.Resize((size,size)),
        transforms.ToTensor(),
        # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        transforms.Lambda(lambda x: x[None]),
        # transforms.Lambda(lambda x: x.repeat(3, 1, 1) ),
    ])
    return transform(image)

'''
    Y = (X - μ)/(σ) => Y ~ Distribution(0,1) if X ~ Distribution(μ,σ)
    => Y/(1/σ) follows Distribution(0,σ)
    => (Y/(1/σ) - (-μ))/1 is actually X and hence follows Distribution(μ,σ)
'''

def show_img(PIL_IMG):
    plt.imshow(np.asarray(PIL_IMG))


def create_saliency_map(image_filename='000001-17.jpg'):
    img = PIL.Image.open(f'/content/covid-chestxray-dataset/images/{image_filename}', ).convert(image_convert_mode)
    # preprocess the image
    X = preprocess(img)
    

    # we would run the model in evaluation mode
    cnn_model.eval()

    # we need to find the gradient with respect to the input image, so we need to call requires_grad_ on it
    X.requires_grad_()

    '''
    forward pass through the model to get the scores, note that VGG-19 model doesn't perform softmax at the end
    and we also don't need softmax, we need scores, so that's perfect for us.
    '''
    scores = cnn_model(X)
    

    # Get the index corresponding to the maximum score and the maximum score itself.
    score_max_index = scores.argmax()
    score_max = scores[0,score_max_index]

    '''
    backward function on score_max performs the backward pass in the computation graph and calculates the gradient of 
    score_max with respect to nodes in the computation graph
    '''
    score_max.backward()

    '''
    Saliency would be the gradient with respect to the input image now. But note that the input image has 3 channels,
    R, G and B. To derive a single class saliency value for each pixel (i, j),  we take the maximum magnitude
    across all colour channels.
    '''
    saliency, _ = torch.max(X.grad.data.abs(),dim=1)

    # # code to plot the saliency map as a heatmap
    # plt.imshow(saliency[0], cmap=plt.cm.hot)
    # plt.axis('off')
    # plt.show()


    transform_image_for_print = transforms.Compose([
                                    transforms.Grayscale(num_output_channels=1)
    ])
    img_for_visual = transform_image_for_print(X)


    # Visualize the image and the saliency map
    fig, ax = plt.subplots(1, 3)
    ax[0].imshow(np.squeeze(img_for_visual.cpu().detach().numpy()), cmap="gray")
    ax[0].axis('off')
    ax[1].imshow(saliency[0].cpu(), cmap='hot')
    ax[1].axis('off')
    ax[2].imshow(np.squeeze(img_for_visual.cpu().detach().numpy()), cmap="gray")
    ax[2].imshow(saliency[0].cpu(), cmap='hot', alpha=0.5)
    ax[2].axis('off')
    plt.tight_layout()
    fig.suptitle('The Image and Its Saliency Map')
    plt.show()



In [None]:
create_saliency_map()