### GradCam Analysis

### Imports

In [206]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# re-imported automatically before each cell executes.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [207]:
import os
import cv2
import numpy as np
import pandas as pd
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tensorboardX import SummaryWriter

from sklearn.model_selection import train_test_split

import albumentations as albu
from albumentations.pytorch.transforms import ToTensorV2

from torchvision.models import resnet
from tqdm import tqdm 
import torchvision
import torchvision.transforms as T

from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from matplotlib import cm
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import PIL

import catalyst.dl as dl
from catalyst.contrib.nn import (
    ArcFace,
    CosFace,
    AdaCos,
    SubCenterArcFace,
    CurricularFace,
    ArcMarginProduct,
)

# gradcam: https://github.com/jacobgil/pytorch-grad-cam
from pytorch_grad_cam import GradCAM, HiResCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, FullGrad
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
from torchvision.models import resnet50

# # all model labels per input
from src.dir_paths.figures_paths import get_figures_path

# Pick the compute device: first CUDA GPU when one is visible, CPU otherwise.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0") if use_cuda else torch.device("cpu")
print("Device: ", device)

if use_cuda:
    # Short hardware report for the GPU that will be used.
    gpu_report = (
        ('__CUDNN VERSION:', torch.backends.cudnn.version()),
        ('__Number CUDA Devices:', torch.cuda.device_count()),
        ('__CUDA Device Name:', torch.cuda.get_device_name(0)),
        ('__CUDA Device Total Memory [GB]:', torch.cuda.get_device_properties(0).total_memory/1e9),
    )
    for caption, value in gpu_report:
        print(caption, value)

Device:  cpu


## Define species label dicts

In [208]:
## EratoNet
# Class index -> subspecies name for H. erato. The position in the list is the
# class index.
index2species_erato = dict(enumerate([
    'Heliconius erato ssp. phyllis',
    'Heliconius erato ssp. dignus',
    'Heliconius erato ssp. lativitta',
    'Heliconius erato ssp. etylus',
    'Heliconius erato ssp. cyrbia',
    'Heliconius erato ssp. amalfreda',  # its melpomene mimic is meriana (same index below)
    'Heliconius erato ssp. venus',
    'Heliconius erato ssp. hydara',
    'Heliconius erato ssp. petiverana',
    'Heliconius erato ssp. notabilis',
]))

# Class index -> subspecies name for H. melpomene, listed in the same order as
# the erato list so that position i holds the mimic of erato class i.
index2species_melpomene = dict(enumerate([
    'Heliconius melpomene ssp. nanna',
    'Heliconius melpomene ssp. bellula',
    'Heliconius melpomene ssp. malleti',
    'Heliconius melpomene ssp. ecuadorensis',
    'Heliconius melpomene ssp. cythera',
    'Heliconius melpomene ssp. meriana',  # original note: "Missing from index2species in dorsal images" - TODO confirm, meriana is index 15 below
    'Heliconius melpomene ssp. vulcanus',
    'Heliconius melpomene ssp. melpomene',
    'Heliconius melpomene ssp. rosina',
    'Heliconius melpomene ssp. plesseni',
]))

# Class index -> subspecies name when all 20 subspecies are separate classes.
index2species = dict(enumerate([
    'Heliconius melpomene ssp. ecuadorensis',
    'Heliconius melpomene ssp. nanna',
    'Heliconius melpomene ssp. rosina',
    'Heliconius erato ssp. phyllis',
    'Heliconius erato ssp. dignus',
    'Heliconius melpomene ssp. bellula',
    'Heliconius erato ssp. lativitta',
    'Heliconius melpomene ssp. vulcanus',
    'Heliconius erato ssp. etylus',
    'Heliconius melpomene ssp. plesseni',
    'Heliconius melpomene ssp. malleti',
    'Heliconius erato ssp. cyrbia',
    'Heliconius erato ssp. amalfreda',
    'Heliconius erato ssp. venus',
    'Heliconius erato ssp. hydara',
    'Heliconius melpomene ssp. meriana',
    'Heliconius erato ssp. petiverana',
    'Heliconius melpomene ssp. cythera',
    'Heliconius erato ssp. notabilis',
    'Heliconius melpomene ssp. melpomene',
]))

# Inverse lookups (subspecies name -> class index). In the erato/melpomene
# dictionaries the indices are assigned so that a mimic pair shares the same
# numerical label.
species2index_erato = dict(zip(index2species_erato.values(), index2species_erato.keys()))
species2index_melpomene = dict(zip(index2species_melpomene.values(), index2species_melpomene.keys()))
species2index = dict(zip(index2species.values(), index2species.keys()))

# Number of classes in the per-species (mimic-paired) setting.
n_classes = num_classes = len(index2species_erato)

## Select Acuity and Model

In [209]:
## ----- uncomment the name of the dataset you want to work with ------
# acuity = 'no_acuity_white_background'
# acuity = 'no_acuity' #Done, Done, done
# acuity = 'heliconius_male_behavioral_acuity' #Done, done, done
# acuity = 'heliconius_female_behavioral_acuity' #Done, done, done
# acuity = 'heliconius_male_morphological_acuity' #Done, done, done
# acuity = 'heliconius_female_morphological_acuity' #Done, done, done
acuity = 'kingfisher_acuity' #Done, done

## ----- uncomment model name ------
# model_name = 'RegularSpeciesClassification'
# model_name = 'EratoNet'
model_name = 'MelpomeneNet'

# Echo the active configuration so it is recorded in the cell output.
for caption, choice in (('Acuity: ', acuity), ('Model: ', model_name)):
    print(caption, choice)

# Boolean shortcuts for the two single-species models; both are False for any
# other model name.
erato_net = model_name == 'EratoNet'
melpomene_net = model_name == 'MelpomeneNet'

print(erato_net, melpomene_net)

Acuity:  kingfisher_acuity
Model:  MelpomeneNet
False True


## Load Data

In [210]:
#read in datasets
from src.dir_paths.dataset_paths import get_image_dataset_path
from src.dir_paths.train_val_test_paths import get_split_csvs

# Path to the images with the selected acuity applied.
dataset_path = get_image_dataset_path(acuity)

# Path prefix of the train/val/test split CSVs for this acuity and model.
# It is concatenated directly with the file names below, so it is expected to
# end with a path separator.
main = get_split_csvs(acuity, model_name)

# Every model uses the same train/val file names; only the test split differs.
train_df = pd.read_csv(main + 'train.csv')
val_df = pd.read_csv(main + 'val.csv')

print(train_df.shape) ##5512, 2 (RegularSpeciesClassification)
print(val_df.shape) #613, 2 (RegularSpeciesClassification)

if erato_net or melpomene_net:
    # Single-species models are evaluated on both species separately.
    test_df_erato = pd.read_csv(main + 'test_erato.csv')
    test_df_melpomene = pd.read_csv(main + 'test_melpomene.csv')

    print(test_df_erato.shape)
    print(test_df_melpomene.shape)
else:
    # Any other model (RegularSpeciesClassification, and the MimicsNet branch
    # handled by the following cell) reads a single test split. Previously
    # only 'RegularSpeciesClassification' was handled, so other model names
    # left train_df/val_df/test_df undefined.
    # NOTE(review): assumes a 'test.csv' exists for those models - confirm.
    test_df = pd.read_csv(main + 'test.csv')

    print(test_df.shape) #1532, 2 (RegularSpeciesClassification)

#EratoNet (no acuity)
# (1728, 2)
# (241, 2)
# (433, 2)
# (1420, 2)

print('Image Dataset:', dataset_path)
print('Data Split: ', main)

(1020, 2)
(142, 2)
(2396, 2)
(256, 2)
Image Dataset: /fs/ess/PAS2136/Butterfly/Model_Mimic/model_mimic_images_256_256_removed_background_kingfisher_acuity_dorsal/
Data Split:  /fs/ess/PAS2136/Butterfly/Model_Mimic/Data_Splits/MelpomeneNet/model_mimic_images_256_256_removed_background_kingfisher_acuity_dorsal/


In [211]:
from src.dataloader import read_image, read_sized_image, get_transforms, get_loader


def _species_of(image_path):
    """Return the name of the directory that directly contains ``image_path``.

    The label lookups below treat that directory name as the subspecies name.
    Using the parent directory instead of a fixed absolute-path component
    keeps the lookup valid when the dataset root sits at a different depth.
    """
    return Path(str(image_path)).parent.name


def _encode_targets(image_paths, species_to_index):
    """Map every image path to the class index of its species directory.

    Raises:
        KeyError: if a directory name is not a key of ``species_to_index``.
    """
    return [species_to_index[_species_of(p)] for p in image_paths]


# Dispatch on the selected model name rather than on a fixed component of the
# split directory path, so this cell agrees with the erato_net/melpomene_net
# flags that the loader code relies on.
if model_name == "MimicsNet":
    #create lists of image filepaths and their labels for each split
    train_images = list(train_df['path'])
    valid_images = list(val_df['path'])
    test_images = list(test_df['path'])

    #labels were encoded so that mimics share the same label
    print('Using MimicsNet Labels')
    train_targets = list(train_df['label'])
    valid_targets = list(val_df['label'])
    test_targets = list(test_df['label'])

elif model_name == "EratoNet":
    print('EratoNet')
    train_images = list(train_df['path'])
    valid_images = list(val_df['path'])
    test_images_erato = list(test_df_erato['path'])
    test_images_melpomene = list(test_df_melpomene['path'])

    # Train/valid/in-distribution test images are erato; melpomene test images
    # are encoded with the melpomene map, whose indices match their erato mimics.
    train_targets = _encode_targets(train_images, species2index_erato)
    valid_targets = _encode_targets(valid_images, species2index_erato)
    test_targets_erato = _encode_targets(test_images_erato, species2index_erato)
    test_targets_melpomene = _encode_targets(test_images_melpomene, species2index_melpomene)
    # True melpomene subspecies name of each mimic test image.
    test_targets_melpomene_actual = [_species_of(elem) for elem in test_images_melpomene]

    print("Number of train samples -", len(train_images))
    print("Number of valid samples -", len(valid_images))
    print("Number of test erato samples - ",len(test_images_erato))
    print("Number of test melpomene samples - ",len(test_images_melpomene))

elif model_name == "MelpomeneNet":
    print('MelpomeneNet')
    train_images = list(train_df['path'])
    valid_images = list(val_df['path'])
    test_images_erato = list(test_df_erato['path'])
    test_images_melpomene = list(test_df_melpomene['path'])

    # Mirror image of the EratoNet branch: melpomene is the training species.
    train_targets = _encode_targets(train_images, species2index_melpomene)
    valid_targets = _encode_targets(valid_images, species2index_melpomene)
    test_targets_melpomene = _encode_targets(test_images_melpomene, species2index_melpomene)
    test_targets_erato = _encode_targets(test_images_erato, species2index_erato)
    # True erato subspecies name of each mimic test image.
    test_targets_erato_actual = [_species_of(elem) for elem in test_images_erato]

    print("Number of train samples -", len(train_images))
    print("Number of valid samples -", len(valid_images))
    print("Number of test melpomene samples - ",len(test_images_melpomene))
    print("Number of test erato samples - ",len(test_images_erato))

else:
    #regular species classification
    train_images = list(train_df['path'])
    valid_images = list(val_df['path'])
    test_images = list(test_df['path'])

    train_targets = _encode_targets(train_images, species2index)
    valid_targets = _encode_targets(valid_images, species2index)
    test_targets = _encode_targets(test_images, species2index)

    print("Number of train samples -", len(train_images))
    print("Number of valid samples -", len(valid_images))
    print("Number of test samples - ",len(test_images))

# Build a dataset/loader pair for each split of data.
batch = 32
loader_kwargs = dict(batch_size=batch, num_workers=1)

train_dataset, train_loader = get_loader("train", train_images, train_targets, **loader_kwargs)
valid_dataset, valid_loader = get_loader("valid", valid_images, valid_targets, **loader_kwargs)

if erato_net or melpomene_net:
    test_dataset_erato, test_loader_erato = get_loader("test", test_images_erato, test_targets_erato, **loader_kwargs)
    test_dataset_melpomene, test_loader_melpomene = get_loader("test", test_images_melpomene, test_targets_melpomene, **loader_kwargs)

    # Only the loader of the species the model was trained on goes into the
    # dictionary; the mimic loader is kept aside and handled independently.
    primary_test_loader = test_loader_erato if erato_net else test_loader_melpomene
else:
    test_dataset, test_loader = get_loader("test", test_images, test_targets, **loader_kwargs)
    primary_test_loader = test_loader

# Dictionary with each of our dataloaders, keyed by split name.
loaders = {
    "train": train_loader,
    "valid": valid_loader,
    "test": primary_test_loader,
}

loaders

MelpomeneNet
Number of train samples - 1020
Number of valid samples - 142
Number of test melpomene samples -  256
Number of test erato samples -  2396


{'train': <torch.utils.data.dataloader.DataLoader at 0x2b49975eaa90>,
 'valid': <torch.utils.data.dataloader.DataLoader at 0x2b48575e2cd0>,
 'test': <torch.utils.data.dataloader.DataLoader at 0x2b4833eaa400>}

## Load Model

In [212]:
#read in the model
from src.utils import load_model
from src.dir_paths.model_log_paths import get_logdir
from src.model import ResNetEncoder, EncoderWithHead

# Checkpoint written by the classification training run for this acuity/model.
model_path = get_logdir(acuity, model_name) + '/checkpoints/classification_ckpt.pth'

# 20 classes when every subspecies is its own class, 10 for the models whose
# labels are shared between mimic pairs (MimicsNet, EratoNet, MelpomeneNet).
num_classes = 20 if model_name == 'RegularSpeciesClassification' else 10

# The Grad-CAM cells below build their inputs on `device` and create GradCAM
# with use_cuda=False, i.e. they run on the CPU. Pin the device to CPU
# *before* loading so the model and its inputs end up on the same device.
# Previously `device` was switched to CPU only after load_model had been
# given the (possibly CUDA) device.
# NOTE(review): assumes load_model places the model on the device it is given - confirm.
device = torch.device('cpu')

# Rebuild the training architecture: ResNet-50 encoder with a 128-d embedding
# followed by an ArcFace head.
encoder = ResNetEncoder("resnet50", 128)
model   = EncoderWithHead(encoder,
                          ArcFace(128, num_classes, s=2**0.5*np.log(num_classes - 1), m=0.25))

model = load_model(model, model_path, device)
# Re-bind `encoder` to the encoder of the object returned by load_model.
encoder = model.encoder

print('Model Ckpt:', model_path)
print('Num Classes: ', num_classes)

Model Ckpt: /fs/ess/PAS2136/Butterfly/Model_Mimic/resnet_arcface_logs/MelpomeneNet/model_mimic_images_256_256_removed_background_kingfisher_acuity_dorsal/checkpoints/classification_ckpt.pth
Num Classes:  10


## Grad-CAM: one target label per image (target = ground-truth class), butterfly images

In [213]:
# model
# for name, layer in model.named_modules():
#     print(name, layer)

In [214]:
# Grad-CAM figures go in a 'gradcam' sub-folder of this run's figure directory.
save_folder = get_figures_path(acuity, model_name)
plot_path = f"{save_folder}/gradcam/"  # trailing slash: later cells append file names directly
os.makedirs(plot_path, exist_ok=True)
print(plot_path)

/fs/ess/PAS2136/Butterfly/Model_Mimic/Figures/MelpomeneNet/model_mimic_images_256_256_removed_background_kingfisher_acuity_dorsal/gradcam/


In [215]:
# Display the 20-class index -> subspecies map for reference (meriana is index 15 here).
index2species

{0: 'Heliconius melpomene ssp. ecuadorensis',
 1: 'Heliconius melpomene ssp. nanna',
 2: 'Heliconius melpomene ssp. rosina',
 3: 'Heliconius erato ssp. phyllis',
 4: 'Heliconius erato ssp. dignus',
 5: 'Heliconius melpomene ssp. bellula',
 6: 'Heliconius erato ssp. lativitta',
 7: 'Heliconius melpomene ssp. vulcanus',
 8: 'Heliconius erato ssp. etylus',
 9: 'Heliconius melpomene ssp. plesseni',
 10: 'Heliconius melpomene ssp. malleti',
 11: 'Heliconius erato ssp. cyrbia',
 12: 'Heliconius erato ssp. amalfreda',
 13: 'Heliconius erato ssp. venus',
 14: 'Heliconius erato ssp. hydara',
 15: 'Heliconius melpomene ssp. meriana',
 16: 'Heliconius erato ssp. petiverana',
 17: 'Heliconius melpomene ssp. cythera',
 18: 'Heliconius erato ssp. notabilis',
 19: 'Heliconius melpomene ssp. melpomene'}

In [216]:
# Select the training samples of H. melpomene ssp. meriana.
# (meriana is class 15 in index2species and class 5 in index2species_melpomene.)
#
# Images and labels are filtered together, so the two lists always have the
# same length and order and carry whatever label the active model assigned.
# Previously the labels were filtered with a hard-coded `== 15`, which is only
# valid for the 20-class map and produced an empty target list for the
# 10-class models.
meriana_pairs = [
    (img, tgt)
    for img, tgt in zip(train_images, train_targets)
    if img.split('/')[-2] == 'Heliconius melpomene ssp. meriana'
]
meriana_images = [img for img, _ in meriana_pairs]
meriana_targets = [tgt for _, tgt in meriana_pairs]

# NOTE: this replaces any previously defined test set with the meriana subset.
test_images = meriana_images
test_targets = meriana_targets

In [217]:
# single image example - trying out different methods besides just GradCam
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget, ClassifierOutputSoftmaxTarget
from pytorch_grad_cam import GradCAM, HiResCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, FullGrad, LayerCAM, GradCAMElementWise, EigenGradCAM
import albumentations as albu
from albumentations.pytorch.transforms import ToTensorV2

# Side length (pixels) that images are resized to before CAM computation.
cam_side = 224

# Model input pipeline: resize, normalise with albumentations' default
# Normalize settings, then convert to a tensor.
test_transforms = albu.Compose([
    albu.Resize(cam_side, cam_side),
    albu.Normalize(),
    ToTensorV2(),
])

# Resize-only pipeline, used to get the un-normalised image that the heatmap
# is overlaid on.
resize_transform = albu.Compose([albu.Resize(cam_side, cam_side)])

# create a wrapper for our model
class WrapperModel(torch.nn.Module):
    """Adapter that turns a ``model(x, targets=...)`` network into ``model(x)``.

    The wrapped network needs a ``targets`` argument on every forward pass,
    while the CAM object calls its model with the input alone. This wrapper
    stores a fixed ``target`` and supplies it on each call.
    """

    def __init__(self, model, target):
        """Store the network to wrap and the target passed to it on every call."""
        super().__init__()
        self.model = model
        self.target = target

    def forward(self, x):
        """Switch the wrapped network to eval mode and run it on ``x``."""
        network = self.model
        network.eval()
        return network(x, targets=self.target)
        
# Choose the layers whose activations the CAM is computed from.
encoder.eval()
# target_layers = [encoder.base.layer4]
# NOTE(review): named_modules() yields every sub-module and the root module
# itself, so the CAM is computed over all of them rather than only the last
# convolutional stage shown in the commented-out line above - confirm this is
# the intended setting.
target_layers = [module for _, module in model.named_modules()]

# Pick the label map and the evaluation images for the active model.
if erato_net or melpomene_net:
    # Both single-species models evaluate on the erato and melpomene test
    # images together; only the label map used for display differs.
    label_map = index2species_erato if erato_net else index2species_melpomene
    labels = list(label_map.keys())
    test_images = test_images_erato + test_images_melpomene
    test_targets = test_targets_erato + test_targets_melpomene

    if not erato_net:
        # getting Meriana samples (TEMPORARY - REMOVE ONCE DONE)
        # Overrides the evaluation set chosen above with the meriana training
        # samples. Images and labels are filtered as pairs so the two lists
        # cannot drift out of alignment.
        meriana_pairs = [
            (img, tgt)
            for img, tgt in zip(train_images, train_targets)
            if img.split('/')[-2] == 'Heliconius melpomene ssp. meriana'
        ]
        meriana_images = [img for img, _ in meriana_pairs]
        meriana_targets = [tgt for _, tgt in meriana_pairs]

        test_images = meriana_images
        test_targets = meriana_targets

else:
    # Other models keep whatever test_images/test_targets an earlier cell
    # defined and use the 20-class map.
    labels = list(index2species.keys())
    label_map = index2species


# idx = 70
# for img_path, target in zip(train_images[idx:idx+1], train_targets[idx:idx+1]):

# Put the whole network in eval mode up front so that the first prediction is
# computed in the same mode as every later forward pass (WrapperModel switches
# the model to eval on its first call).
model.eval()

print(len(test_images))
for img_path, target in zip(test_images, test_targets):
    print(img_path)

    # cv2.imread reports an unreadable or missing file by returning None
    # instead of raising, so fail here with a message that names the file.
    bgr_img = cv2.imread(img_path)
    if bgr_img is None:
        raise FileNotFoundError(f"Could not read image: {img_path}")

    #format our image as a tensor for gradcam
    orig_img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB) #(256,256,3)
    orig_img = resize_transform(image=orig_img)["image"] #(224,224,3)

    input_tensor = test_transforms(image=orig_img)["image"] #(3, 224, 224) apply transforms and convert to tensor
    repeated_tensor = input_tensor[None, :] #(1,3,224,224) - add batch dim

    # The ground-truth label is used both as the target the model's head
    # expects and as the class whose CAM is computed.
    target_tensor = torch.tensor(target, device=device)
    targets_for_gradcam = [ClassifierOutputTarget(target)]

    #get the predicted label our model would assign to the current image
    with torch.no_grad():
        out = model(repeated_tensor, targets=target_tensor)
        print(out.shape)
        _, pred = torch.max(out, 1)
        print(f"Prediction: {pred.item()} | {label_map[pred.item()]}")
        print(f"Actual label: {target} | {label_map[target]}")

    #instantiate wrapper model for current target
    wrapped_model = WrapperModel(model, target_tensor)

    # The wrapper is target-specific, so the CAM object has to be rebuilt for
    # every image. The `with` block makes sure it is released before the next
    # iteration creates a new one.
    with GradCAM(model=wrapped_model, target_layers=target_layers, use_cuda=False) as cam:
        # get CAM of ground truth label
        batch_results = cam(input_tensor=repeated_tensor,
                            targets=targets_for_gradcam,
                            aug_smooth=False,
                            eigen_smooth=False)

    # The batch holds a single image, so there is exactly one heatmap.
    grayscale_cam = batch_results[0]
    visualization = show_cam_on_image(np.float32(orig_img)/255,
                                      grayscale_cam,
                                      use_rgb=True)

    # save the visualization under <plot_path>/<species folder>/<file name>
    species_dir = img_path.split('/')[-2]
    img_folder = f"{plot_path}{species_dir}/"
    img_name = species_dir + "/" + img_path.split('/')[-1]
    out_path = f"{plot_path}{img_name}"

    os.makedirs(img_folder, exist_ok=True)
    print(out_path)
    # cv2.imwrite returns False when the file could not be written; surface
    # that instead of silently dropping the figure.
    if not cv2.imwrite(out_path, cv2.cvtColor(visualization, cv2.COLOR_RGB2BGR)):
        print(f"WARNING: failed to write {out_path}")


2
/fs/ess/PAS2136/Butterfly/Model_Mimic/model_mimic_images_256_256_removed_background_kingfisher_acuity_dorsal/Heliconius melpomene ssp. meriana/13819_H_m_meriana_D.JPG.png
torch.Size([1, 10])
Prediction: 5 | Heliconius melpomene ssp. meriana
Actual label: 5 | Heliconius melpomene ssp. meriana
/fs/ess/PAS2136/Butterfly/Model_Mimic/Figures/MelpomeneNet/model_mimic_images_256_256_removed_background_kingfisher_acuity_dorsal/gradcam/Heliconius melpomene ssp. meriana/13819_H_m_meriana_D.JPG.png
/fs/ess/PAS2136/Butterfly/Model_Mimic/model_mimic_images_256_256_removed_background_kingfisher_acuity_dorsal/Heliconius melpomene ssp. meriana/13715_H_m_meriana_D.JPG.png
torch.Size([1, 10])
Prediction: 5 | Heliconius melpomene ssp. meriana
Actual label: 5 | Heliconius melpomene ssp. meriana
/fs/ess/PAS2136/Butterfly/Model_Mimic/Figures/MelpomeneNet/model_mimic_images_256_256_removed_background_kingfisher_acuity_dorsal/gradcam/Heliconius melpomene ssp. meriana/13715_H_m_meriana_D.JPG.png


In [218]:
# Completion marker: where the figures were written and for which configuration.
print('Done.', plot_path, sep='\n')
print(model_name, acuity)

Done.
/fs/ess/PAS2136/Butterfly/Model_Mimic/Figures/MelpomeneNet/model_mimic_images_256_256_removed_background_kingfisher_acuity_dorsal/gradcam/
MelpomeneNet kingfisher_acuity


`To do: `

`Compute the maps for all images, and save each horizontal stack of maps under the image's file name together with its true label.`