# Finding Saints in Paintings
This work is based on the paper:\
*Milani F, Fraternali P (2021) A Data Set and a Convolutional Model for Iconography Classification in Paintings. J Comput Cult Herit 14:1–18. https://doi.org/10.1145/3458885*\
The data set can be found here: http://www.artdl.org/ \
\
The goal of this notebook is to replicate the results of Milani et al., but instead of their CNN model I use the Vision Transformer (ViT).
The code is based on these two tutorials on fine-tuning ViT:\
https://colab.research.google.com/github/NielsRogge/Transformers-Tutorials/blob/master/VisionTransformer/Fine_tuning_the_Vision_Transformer_on_CIFAR_10_with_the_%F0%9F%A4%97_Trainer.ipynb#scrollTo=szWwJmqPHZ-r
\
https://colab.research.google.com/github/NielsRogge/Transformers-Tutorials/blob/master/VisionTransformer/Fine_tuning_the_Vision_Transformer_on_CIFAR_10_with_PyTorch_Lightning.ipynb#scrollTo=VV8_9IjhKiDh

The Class Activation Map (CAM) implementation is based on this:
https://github.com/jacobgil/pytorch-grad-cam/blob/master/usage_examples/vit_example.py

# Import libraries
(install requirements, if necessary)

In [None]:
import os
import random
import re
from zipfile import ZipFile
from tqdm import tqdm
import numpy as np
import pandas as pd
from sklearn.metrics import classification_report, precision_score, recall_score, f1_score, accuracy_score, average_precision_score
from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt
import glob
from PIL import Image as PIL_Image
import cv2
from matplotlib import pyplot as plt
import torch
from torch.utils.data import DataLoader as TorchLoader
import torch.nn as nn
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger
from torchvision.transforms import ToTensor, RandomHorizontalFlip
from datasets import (Dataset, 
                      load_metric, 
                      load_dataset, 
                      Features, 
                      ClassLabel, 
                      Array3D,
                      Image)
from transformers import (ViTModel, 
                          ViTForImageClassification,
                          ViTFeatureExtractor, 
                          TrainingArguments, 
                          AdamW) 
from transformers import (DeiTModel,  
                          DeiTFeatureExtractor)
from transformers.modeling_outputs import SequenceClassifierOutput
from pytorch_grad_cam import GradCAM, \
    ScoreCAM, \
    GradCAMPlusPlus, \
    AblationCAM, \
    XGradCAM, \
    EigenCAM, \
    EigenGradCAM, \
    LayerCAM, \
    FullGrad
from pytorch_grad_cam import GuidedBackpropReLUModel
from pytorch_grad_cam.utils.image import show_cam_on_image, preprocess_image
from pytorch_grad_cam.ablation_layer import AblationLayerVit

# Sanity check: True when PyTorch can see a CUDA device (the Trainer below is created with gpus=1).
torch.cuda.is_available()

# Download data 
(if not already done)

In [None]:
# Download the DEVKitArt archive (about 3.6 GB) from Google Drive with gdown.
# `!gdown` is an IPython shell escape, so this cell only runs inside a notebook.
'''
Dataset is stored in Gdrive of the creators
This can take a minute or two. You should get an output like this:

>>> Downloading...
>>> From: https://drive.google.com/uc?id=16FK1YnHPhGqCHf_EpovzcH0v90yXcCer
>>> To: /home/jovyan/DEVKitArtDL.zip
>>> 100% 3.62G/3.62G [01:03<00:00, 68.7MB/s]

Please copy the zip-file location into the variable in the next cell.
'''
!gdown https://drive.google.com/uc?id=16FK1YnHPhGqCHf_EpovzcH0v90yXcCer

In [None]:
# Path of the downloaded DEVKitArtDL.zip (the gdown "To:" location); adjust if it was saved elsewhere.
zip_file_location = '/home/jovyan/DEVKitArtDL.zip'

# Choose Model

In [None]:
# Hugging Face checkpoint names of the candidate backbones.
# NOTE(review): the name `ViT` is later rebound by `class ViT(pl.LightningModule)`.
# Re-running this cell after that class cell makes `ViT` a string again, so the
# class cell must be re-run as well before a trainer is built.
ViT = 'google/vit-base-patch16-224-in21k'
DeIT = 'facebook/deit-base-distilled-patch16-224'
tiny_ViT = 'lysandre/tiny-vit-random'


# Feature extractor belonging to the ViT checkpoint (used by PreProcess).
ViT_extractor = ViTFeatureExtractor.from_pretrained(ViT)
#DeIT_extractor = DeiTFeatureExtractor.from_pretrained(DeIT)


# Bare ViT backbone loaded from the same checkpoint.
Vit_model = ViTModel.from_pretrained(ViT)
#DeIT_model = DeiTModel.from_pretrained(DeIT)


#########################
# Active selection: `version` is the checkpoint name handed to the trainer and
# `extractor` is handed to PreProcess.
# NOTE(review): the global `model` does not appear to be used by later cells (the
# classifier is rebuilt from `version`) — confirm before removing it.
version = ViT
extractor = ViT_extractor
model = Vit_model

# Load Data


In [None]:
class DataLoader():
    '''
    Class for loading data and labels as provided by Milani et al 2021

    Reads 'DEVKitArt/info.csv' from the zip archive. The CSV is expected to have
    an 'item' column (image file name without '.jpg') first, the label indicator
    columns in between, and, as its last column, presumably the 'set' split
    column ('train' / 'val' / 'test').

    returns (from load()):
    data (dict): (train:[PIL Images], val:[PIL Images], test:[PIL Images])
    labels (dict): (train:[Label IDs], val:[Label IDs], test:[Label IDs])
    num_labels (int): number of labels
    ID2Label (dict): {label_as_number : label_as_string}
    Label2ID (dict): {label_as_string : label_as_number}
    '''

    def __init__(self, zip_file_location):
        self.zip_file_location = zip_file_location
        # Read info.csv and close both the archive member and the archive itself.
        with ZipFile(zip_file_location) as zf, zf.open('DEVKitArt/info.csv') as csv_file:
            self.info_frame = pd.read_csv(csv_file)
        self.drop_ambiguous()
        # Initial label count; drop_by_occurrence() replaces it with the number of
        # labels that actually remain in the data.
        self.num_labels = 20

    def drop_ambiguous(self):
        '''
        Drops all images that have more than one label.

        Only the label columns (between 'item' and the last column) are summed:
        summing whole rows would include the string columns, which recent pandas
        versions reject instead of silently skipping them.
        '''
        label_columns = self.info_frame.iloc[:, 1:-1]
        self.drop = self.info_frame.loc[label_columns.sum(axis=1) > 1, 'item']
        self.info_frame.drop(self.drop.index, inplace=True)

    @staticmethod
    def _clean_label(label):
        '''
        Returns the text inside the first pair of parentheses of a column name,
        or the column name unchanged if it contains none.
        '''
        match = re.search(r'\(.*?\)', label)
        return label if match is None else match[0][1:-1]

    def load_labels(self):
        '''
        Extracts label details from info.csv.

        Label columns are numbered from 1 in column order; id 0 is reserved for
        the None label added by add_None_Label().
        '''
        names = [self._clean_label(label) for label in self.info_frame.columns[1:-1]]
        self.ID2Label = {idx: name for idx, name in enumerate(names, start=1)}
        self.Label2ID = {name: idx for idx, name in self.ID2Label.items()}

    def update_frame(self):
        '''
        Adds the columns 'label_name' and 'label_id' (from self.IDs) to the
        info.csv data frame.
        '''
        self.info_frame["label_name"] = [self.ID2Label[label_id] for label_id in self.IDs]
        self.info_frame["label_id"] = self.IDs

    def add_None_Label(self):
        '''
        Adds a None label (id 0) for images that do not depict any saint and
        stores each image's single label id in self.IDs.
        '''
        self.ID2Label[0] = "None"
        self.Label2ID["None"] = 0
        array = self.info_frame.iloc[:,1:-1].to_numpy()
        # A leading all-zero column makes argmax return 0 for rows without any label.
        array = np.insert(array, 0, np.zeros(array.shape[0]), axis=1)
        self.IDs = np.argmax(array, axis=1)

    def shrink(self, frac:float = 1):
        '''
        Since there are far more Mary and None labels than any other, this function
        allows to shrink the examples for those two labels (ids 0 and 1) to a
        smaller fraction. The result is shuffled with a fixed seed for reproducibility.
        '''
        frequent = self.info_frame.label_id < 2
        self.info_frame = pd.concat([
            self.info_frame[~frequent],
            self.info_frame[frequent].sample(frac=frac, replace=False, random_state=1),
        ]).sample(frac=1, random_state=1)

    def drop_by_occurrence(self, drop:int = 10):
        '''
        Drops the labels that are rarer than the label at position `drop` in
        ascending frequency order, i.e. roughly the `drop` least frequent labels.

        All labels tied with that cut-off frequency are kept, so num_labels is
        recomputed from the labels that actually remain instead of being
        decremented by `drop`.
        '''
        occurrences = sorted(self.info_frame.groupby('label_name').count()['item'].tolist())
        threshold = occurrences[drop]
        self.info_frame = self.info_frame.groupby('label_name').filter(lambda group: len(group) >= threshold)
        self.num_labels = self.info_frame['label_name'].nunique()

    def split(self, resize_factor:tuple = (32,32) ):
        '''
        Splits data set into train, val, and test set (one entry per value of the
        'set' column), loading every image from the zip archive and resizing it to
        `resize_factor` (width, height). Label frequencies are counted in self.stats.
        '''
        self.stats = {lab: 0 for lab in self.Label2ID}
        self.data = {}
        self.labels = {}
        with ZipFile(self.zip_file_location) as zf:
            for split_name in self.info_frame["set"].unique():
                subset = self.info_frame[self.info_frame["set"] == split_name]
                self.data[split_name] = []
                self.labels[split_name] = []
                for row in tqdm(subset.itertuples(index=False, name=None),
                                desc=split_name,
                                total=subset.shape[0]):
                    # row[0] is the 'item' file stem, row[-1] the 'label_id' column.
                    with zf.open('DEVKitArt/JPEGImages/' + row[0] + '.jpg') as fh, PIL_Image.open(fh) as img:
                        # resize() returns an independent, loaded image, so the
                        # source file can be closed afterwards.
                        self.data[split_name].append(img.resize(resize_factor))
                    self.labels[split_name].append(row[-1])
                    self.stats[self.ID2Label[row[-1]]] += 1

    def refresh(self):
        '''
        Renumbers the remaining labels to consecutive ids 0..k-1 in both dicts
        and in the 'label_id' column of the data frame.

        Labels are numbered in ascending order of their previous ids (so the None
        label, if present, keeps id 0). This makes the mapping independent of the
        row order, which shrink() shuffles; numbering by first appearance could
        give a different mapping in every session and mismatch saved checkpoints.
        '''
        ordered_names = (self.info_frame
                         .drop_duplicates('label_name')
                         .sort_values('label_id')['label_name'])
        self.ID2Label = dict(enumerate(ordered_names))
        self.Label2ID = {name: idx for idx, name in self.ID2Label.items()}
        self.info_frame['label_id'] = self.info_frame['label_name'].map(self.Label2ID)

    def augmentation(self):
        '''
        Performs horizontal flip on images that are less frequent (= that are neither NONE nor MARY)
        and appends the flipped copies to the training set, updating self.stats.
        '''
        train_labels = np.array(self.labels['train'])
        rare_indices = np.flatnonzero((train_labels != self.Label2ID['MARY']) &
                                      (train_labels != self.Label2ID['None']))
        flip = RandomHorizontalFlip(1)
        aug_data = []
        aug_label = []
        for idx in tqdm(rare_indices,
                        desc='Augmentation',
                        total=rare_indices.shape[0]):
            label = self.labels['train'][idx]
            aug_data.append(flip(self.data['train'][idx]))
            aug_label.append(label)
            self.stats[self.ID2Label[label]] += 1

        self.data['train'] = self.data['train'] + aug_data
        self.labels['train'] = self.labels['train'] + aug_label

    def load(self, frac:float = 1, drop:int = 9, augment = False, resize_factor:tuple = (128,128)):
        '''
        Calls relevant functions in correct order.
        '''
        self.load_labels()
        self.add_None_Label()
        self.update_frame()
        self.shrink(frac)
        self.drop_by_occurrence(drop)
        self.refresh()
        self.split(resize_factor)
        if augment:
            self.augmentation()
        self.get_stats()
        return self.data, self.labels, self.num_labels, self.ID2Label, self.Label2ID

    def demo(self, resize_factor:int = 3):
        '''
        Displays a random training image scaled by resize_factor and prints its
        label as string and int.
        '''
        # randrange excludes the upper bound; randint(0, n) could return index n.
        rand = random.randrange(len(self.data['train']))
        img = self.data['train'][rand]
        label = self.labels['train'][rand]
        width, height = img.size
        img = img.resize((int(width * resize_factor), int(height * resize_factor)))
        # IPython's display() renders the image inline in the notebook.
        display(img)
        print(self.ID2Label[label], label)

    def get_stats(self):
        '''
        Returns the label frequencies counted by split() (and augmentation()) as a
        DataFrame sorted from most to least frequent.
        '''
        return pd.DataFrame.from_dict(self.stats,orient='index', columns = ['Frequency']).sort_values(by=['Frequency'], ascending=False)
        

In [None]:
# Fraction of MARY/None examples kept by DataLoader.shrink() (1 = keep all).
frac = 1
# Number of least frequent labels removed by DataLoader.drop_by_occurrence().
drop = 9
# If True, horizontally flipped copies of the rarer training images are added.
augment = False
# Target image size (width, height) passed to PIL's resize.
resize_factor = (224,224)
dl = DataLoader(zip_file_location)
data, labels, num_labels, ID2Label, Label2ID = dl.load(frac=frac, 
                                                       drop=drop, 
                                                       augment=augment, 
                                                       resize_factor=resize_factor)

# Some insights on the data

In [None]:
# Label frequencies counted while loading, most frequent first.
dl.get_stats()

In [None]:
## Displays a random training image and prints its label name and id
# The argument is a resize factor for the displayed image (here 2x)
dl.demo(2)

In [None]:
def plot_distribution(ids:dict, id2label, split:str = 'train', save = False, save_as = '_Distribution'):
    '''
    Draws a horizontal count plot of the labels of one split, most frequent
    label on top.

    ids (dict): {split name: [label ids]}, e.g. the `labels` returned by DataLoader.load()
    id2label (dict): {label id: label name}
    split (str): key of `ids` to plot; also used as the plot title
    save (bool): if True, the figure is written to split + save_as + '.png'
    save_as (str): file-name suffix used when saving
    '''
    names = [id2label[label_id] for label_id in ids[split]]
    frame = pd.DataFrame.from_dict({'Label': names})
    ranking = frame['Label'].value_counts().index
    plt.figure(figsize=(10,6))
    plt.title(split)
    axes = sns.countplot(y='Label',
                         data=frame,
                         order=ranking,
                         palette='Set3')
    if not save:
        return
    axes.get_figure().savefig(split + save_as + '.png', bbox_inches="tight")

In [None]:
%matplotlib inline
# Label distribution of the training split, also saved as train_Distribution.png.
plot_distribution(labels, ID2Label, 'train', save=True, save_as ='_Distribution')

In [None]:
%matplotlib inline
# Label distribution of the validation split.
plot_distribution(labels, ID2Label, 'val')

In [None]:
%matplotlib inline
# Label distribution of the test split.
plot_distribution(labels, ID2Label, 'test')

# Preprocess


In [None]:
class PreProcess():
    '''
    Loads data and labels to Huggingface's Dataset class and performs resizing/formatting

    data (dict): {'train'|'val'|'test': [PIL Images]}
    labels (dict): {'train'|'val'|'test': [label ids]}
    ID2Label (dict): {label id: label name}; ids must be 0..len(ID2Label)-1
    extractor: Hugging Face ViT feature extractor producing 'pixel_values'
    '''
    def __init__(self, data, labels, ID2Label, extractor):
        self.data = data
        self.labels = labels
        self.label_names = [ID2Label[ID] for ID in range(len(ID2Label))]
        self.feature_extractor = extractor
        self.features = Features({
                    'label': ClassLabel(
                        names=self.label_names),
                    'img': Image(),
                    # Must match the extractor's output size (3 x 224 x 224).
                    'pixel_values': Array3D(dtype="float32", shape=(3, 224, 224)),
                })
        # Kept for backward compatibility; format_() no longer converts to tensors.
        self.to_tensor = ToTensor()

    def _as_dataset(self, split_name):
        '''Wraps one split's images and labels in a Hugging Face Dataset.'''
        return Dataset.from_dict({'img' : self.data[split_name], 'label' : self.labels[split_name]})

    def split(self):
        '''
        Splits data in training, validation, and test set as huggingface's dataset object
        '''
        print('Load train set as dataset object...')
        self.train_ds = self._as_dataset('train')

        print('Load validation set as dataset object...')
        self.val_ds = self._as_dataset('val')

        print('Load test set as dataset object...\n')
        self.test_ds = self._as_dataset('test')

    def format_(self):
        '''
        Brings data in right format for Vision Transformer by adding the
        extractor's 'pixel_values' to every example.
        '''
        def preprocess_images(examples):
            # Hand PIL images straight to the extractor: it rescales pixel values
            # itself, so pre-converting with ToTensor (already in [0, 1]) makes newer
            # image processors rescale a second time. convert('RGB') guarantees the
            # three channels the pixel_values feature expects.
            images = [image.convert('RGB') for image in examples['img']]
            inputs = self.feature_extractor(images=images)
            examples['pixel_values'] = inputs['pixel_values']
            return examples

        print('Bring train data in right format for Vision Transformer...')
        self.preprocessed_train_ds = self.train_ds.map(preprocess_images, batched=True, features=self.features)
        print('Bring validation data in right format for Vision Transformer...')
        self.preprocessed_val_ds = self.val_ds.map(preprocess_images, batched=True, features=self.features)
        print('Bring test data in right format for Vision Transformer...')
        self.preprocessed_test_ds = self.test_ds.map(preprocess_images, batched=True, features=self.features)

    @staticmethod
    def _collate_fn(examples):
        '''Stacks the pixel values and labels of a list of examples into batch tensors.'''
        pixel_values = torch.stack([torch.tensor(example["pixel_values"]) for example in examples])
        labels = torch.tensor([example["label"] for example in examples])
        return {"pixel_values": pixel_values, "labels": labels}

    def loader(self, train_batch_size = 2, eval_batch_size = 2):
        '''
        Creates Dataloader for training on batches (only the training loader shuffles).
        '''
        self.train_dataloader = TorchLoader(self.preprocessed_train_ds, shuffle=True, collate_fn=self._collate_fn, batch_size=train_batch_size)
        self.val_dataloader = TorchLoader(self.preprocessed_val_ds, collate_fn=self._collate_fn, batch_size=eval_batch_size)
        self.test_dataloader = TorchLoader(self.preprocessed_test_ds, collate_fn=self._collate_fn, batch_size=eval_batch_size)

    def process(self,train_batch_size = 2, eval_batch_size = 2):
        '''
        Calls relevant functions in correct order and returns the
        (train, val, test) DataLoaders.
        '''
        self.split()
        self.format_()
        self.loader(train_batch_size, eval_batch_size)
        return self.train_dataloader, self.val_dataloader, self.test_dataloader
    

In [None]:
# Batch sizes of the training and the validation/test DataLoaders.
train_batch_size, eval_batch_size = 2, 2
pp = PreProcess(data, labels, ID2Label, extractor)
train_dataloader, val_dataloader, test_dataloader = pp.process(train_batch_size, eval_batch_size)

# Vision Transformer


In [None]:
class ViT(pl.LightningModule):
    '''
    Class for Vision Transformer.

    Lightning wrapper around a Hugging Face ViTForImageClassification whose
    classification head has `num_labels` outputs.

    version (str): Hugging Face checkpoint name or path
    ID2Label (dict): {label id: label name}, stored in the model config
    Label2ID (dict): {label name: label id}, stored in the model config
    num_labels (int): number of output classes
    lr (float): learning rate of the AdamW optimizer
    '''
    def __init__(self, 
                 version, 
                 ID2Label, 
                 Label2ID, 
                 num_labels,
                 lr = 5e-5):
        # Zero-argument super() does not look up the global name `ViT`, which the
        # notebook also binds to a checkpoint string.
        super().__init__()
        self.vit = ViTForImageClassification.from_pretrained(version,
                                                              num_labels=num_labels,
                                                              id2label=ID2Label,
                                                              label2id=Label2ID)
        self.lr = lr

    def forward(self, pixel_values):
        '''
        Forward pass; returns the classification logits.
        '''
        return self.vit(pixel_values=pixel_values).logits

    def common_step(self, batch, batch_idx):
        '''
        Unpacks batch and computes loss (mean cross-entropy) and accuracy.
        '''
        pixel_values = batch['pixel_values']
        labels = batch['labels']
        logits = self(pixel_values)

        loss = nn.functional.cross_entropy(logits, labels)
        correct = (logits.argmax(-1) == labels).sum().item()
        accuracy = correct / pixel_values.shape[0]
        return loss, accuracy

    def training_step(self, batch, batch_idx):
        loss, accuracy = self.common_step(batch, batch_idx)
        self.log("training_loss", loss)
        self.log("training_accuracy", accuracy)
        return loss

    def validation_step(self, batch, batch_idx):
        # 'validation_loss' is the metric monitored by early stopping.
        loss, accuracy = self.common_step(batch, batch_idx)
        self.log("validation_loss", loss, on_epoch=True)
        self.log("validation_accuracy", accuracy, on_epoch=True)
        return loss

    def test_step(self, batch, batch_idx):
        loss, accuracy = self.common_step(batch, batch_idx)
        self.log("test_loss", loss, on_epoch=True)
        self.log("test_accuracy", accuracy, on_epoch=True)
        return loss

    def configure_optimizers(self):
        # torch's AdamW with eps=1e-6 and no weight decay matches the defaults of the
        # deprecated transformers.AdamW this module used before.
        return torch.optim.AdamW(self.parameters(), lr=self.lr, eps=1e-6, weight_decay=0.0)

In [None]:
class Train_VisionTransformer():
    '''
    Trainer for Vision Transformer.

    Builds a ViT LightningModule, fine-tunes it with a Lightning Trainer (early
    stopping on 'validation_loss'), evaluates it on the test loader, plots a
    confusion matrix and saves the weights.
    '''
    def __init__(self, 
                 version, 
                 train_dataloader, 
                 val_dataloader, 
                 test_dataloader,
                 ID2Label, 
                 Label2ID, 
                 num_labels,
                 lr = 5e-5):
        # Monitor the metric ViT.validation_step actually logs. The callback used to
        # watch the never-logged 'val_loss' and was not handed to the Trainer.
        self.early_stop_callback = EarlyStopping(
                                    monitor='validation_loss',
                                    patience=3,
                                    strict=False,
                                    verbose=False,
                                    mode='min'
                                )
        self.model = ViT(version, 
                         ID2Label, 
                         Label2ID, 
                         num_labels,
                         lr)
        self.target_names = [ID2Label[ID] for ID in range(len(ID2Label))]
        self.train_dataloader = train_dataloader
        self.val_dataloader = val_dataloader
        self.test_dataloader = test_dataloader
        self.trainer = Trainer(gpus=1, callbacks=[self.early_stop_callback])

    def train(self):
        '''
        Fine-tunes the model to the new data.
        '''
        self.trainer.fit(model = self.model, 
                         train_dataloaders = self.train_dataloader, 
                         val_dataloaders = self.val_dataloader)

    def test(self):
        '''
        Tests the fine-tuned model on the test data and prints precision, recall,
        f1, and accuracy. Predictions and ground-truth ids are kept in self.pred
        and self.target (numpy arrays).
        '''
        self.model.eval()
        device = self.model.device
        targets, preds = [], []
        # Evaluation needs no gradients; whole batches are classified at once.
        with torch.no_grad():
            for batch in tqdm(self.test_dataloader, desc = 'Test',total=len(self.test_dataloader)):
                logits = self.model(batch['pixel_values'].to(device))
                preds.extend(logits.argmax(-1).cpu().tolist())
                targets.extend(batch['labels'].tolist())
        self.pred = np.array(preds)
        self.target = np.array(targets)
        # Explicit labels keep target_names aligned even if a class is missing
        # from the test targets and predictions.
        print(classification_report(self.target, self.pred,
                                    labels=list(range(len(self.target_names))),
                                    target_names=self.target_names,
                                    zero_division=0))

    @staticmethod
    def _recall_normalize(cm):
        '''
        Divides every row of a confusion matrix by its row sum (number of
        ground-truth samples), so cell (i, j) is the fraction of class i that was
        predicted as class j. Rows without samples stay all zero.
        '''
        cm = np.asarray(cm, dtype=float)
        row_sums = cm.sum(axis=1, keepdims=True)
        return np.divide(cm, row_sums, out=np.zeros_like(cm), where=row_sums > 0)

    def confusion_matrix(self, save=False, name = 'ViT'):
        '''
        Plots the row-normalized confusion matrix of the results on the test data
        (the diagonal is the per-class recall). Requires test() to have run.

        save (bool): if True, also writes ConfusionMatrix/<name>.png
        name (str): file stem used when saving
        '''
        # All label ids are passed so the matrix is square even for missing classes.
        cm = confusion_matrix(self.target, self.pred, labels=list(range(len(self.target_names))))
        cm = self._recall_normalize(cm)
        fig, ax = plt.subplots(figsize=(10,6))
        sns.heatmap(cm, annot=True, fmt='.3f', cmap='Blues',
                    xticklabels=self.target_names,
                    yticklabels=self.target_names,
                    ax=ax)
        ax.set_title('Confusion Matrix (Recall)')
        ax.set_xlabel('\nPrediction')
        ax.set_ylabel('Groundtruth ')
        ax.tick_params(axis='x', labelrotation=90)
        ax.tick_params(axis='y', labelrotation=0)
        # Save before plt.show(), which may close the figure.
        if save:
            os.makedirs('ConfusionMatrix', exist_ok=True)
            fig.savefig('ConfusionMatrix/' + name + '.png')
        plt.show()

    def save(self, name, checkpoint_dir = '/home/jovyan/Model_Checkpoints/'):
        '''
        Saves the model weights as statedict to <checkpoint_dir><name>_stateDict.pt,
        creating the directory if needed.
        '''
        os.makedirs(checkpoint_dir, exist_ok=True)
        torch.save(self.model.state_dict(), os.path.join(checkpoint_dir, name + '_stateDict.pt'))


In [None]:
# Build the trainer for the checkpoint selected in `version` using the loaders from PreProcess.
learning_rate = 5e-5
train_vit = Train_VisionTransformer(version,  
                                    train_dataloader, 
                                    val_dataloader, 
                                    test_dataloader,
                                    ID2Label, 
                                    Label2ID, 
                                    num_labels = num_labels,
                                    lr = learning_rate)

In [None]:
# Fine-tune (the Trainer stops early based on 'validation_loss').
train_vit.train()

In [None]:
# Evaluate on the test split and print the per-class classification report.
train_vit.test()

In [None]:
%matplotlib inline
# Plot the confusion matrix of the test predictions (run test() first).
train_vit.confusion_matrix(save=False)

In [None]:
# Writes the state dict to /home/jovyan/Model_Checkpoints/ViT_v3_stateDict.pt.
train_vit.save('ViT_v3')

# Hyper-Parameter Tuning

In [None]:
class TUNE():
    '''
    Class for hyper-parameter tuning.

    Trains, saves and evaluates one model per (learning rate, batch size)
    combination and records its test metrics in self.log / self.results.

    lr (list): learning rates to try
    batch (list): batch sizes to try (used for training and evaluation)
    '''
    def __init__(self, lr, batch):
        self.lr = lr
        self.batch = batch
        self.log = {
            'model_name' : [],
            'learning_rate' : [],
            'batch_size' : [],
            'Precision(macro)' : [],
            'Precision(micro)' : [],
            'Recall(macro)' : [],
            'Recall(micro)' : [],
            'F1(macro)' : [],
            'F1(micro)' : [],
            'Accuracy' : []
        }
        # Empty until tune() runs, so get_results() never hits a missing attribute.
        self.results = pd.DataFrame.from_dict(self.log)

    def get_data(self,
                  preprocessed_train_ds,
                  preprocessed_val_ds,
                  preprocessed_test_ds):
        '''
        Retrieves the preprocessed datasets (e.g. from PreProcess).
        '''
        self.preprocessed_train_ds = preprocessed_train_ds
        self.preprocessed_val_ds = preprocessed_val_ds
        self.preprocessed_test_ds = preprocessed_test_ds

    def get_model_details(self,
                         version, 
                         ID2Label, 
                         Label2ID, 
                         num_labels):
        '''
        Retrieves model details.
        '''
        self.version = version
        self.ID2Label = ID2Label
        self.Label2ID = Label2ID
        self.num_labels = num_labels

    def get_batch(self, train_batch_size, eval_batch_size):
        '''
        Creates (train, val, test) DataLoaders with the given batch sizes.
        '''
        def collate_fn(examples):
            pixel_values = torch.stack([torch.tensor(example["pixel_values"]) for example in examples])
            labels = torch.tensor([example["label"] for example in examples])
            return {"pixel_values": pixel_values, "labels": labels}

        train_dataloader = TorchLoader(self.preprocessed_train_ds, shuffle=True, collate_fn=collate_fn, batch_size=train_batch_size)
        val_dataloader = TorchLoader(self.preprocessed_val_ds, collate_fn=collate_fn, batch_size=eval_batch_size)
        test_dataloader = TorchLoader(self.preprocessed_test_ds, collate_fn=collate_fn, batch_size=eval_batch_size)
        return train_dataloader, val_dataloader, test_dataloader

    def metrics(self, y_pred, y_true):
        '''
        Computes metrics for each model during tuning and appends them to self.log.

        zero_division=0 gives the same value sklearn falls back to for classes
        without predictions, without emitting a warning for every run.
        '''
        self.log['Precision(macro)'].append(precision_score(y_true, y_pred, average='macro', zero_division=0))
        self.log['Precision(micro)'].append(precision_score(y_true, y_pred, average='micro', zero_division=0))
        self.log['Recall(macro)'].append(recall_score(y_true, y_pred, average='macro', zero_division=0))
        self.log['Recall(micro)'].append(recall_score(y_true, y_pred, average='micro', zero_division=0))
        self.log['F1(macro)'].append(f1_score(y_true, y_pred, average='macro', zero_division=0))
        self.log['F1(micro)'].append(f1_score(y_true, y_pred, average='micro', zero_division=0))
        self.log['Accuracy'].append(accuracy_score(y_true, y_pred))

    def tune(self):
        '''
        Trains a model for each learning rate and batch size.

        Each run is saved as 'ViT_<n>' (n counting from 1) and evaluated on the
        test set. The cumulative results are rewritten to 'ViT_HypTune.csv' after
        every run, and the run's confusion matrix is saved.
        '''
        idx = 1
        for learning_rate in self.lr:
            for batch_size in self.batch:
                print(f'Train Model: Learning rate: {learning_rate}, Batch size: {batch_size}...\n')
                train_dataloader, val_dataloader, test_dataloader = self.get_batch(batch_size, batch_size)
                model = Train_VisionTransformer(
                                        self.version,
                                        train_dataloader, 
                                        val_dataloader, 
                                        test_dataloader,
                                        self.ID2Label, 
                                        self.Label2ID, 
                                        self.num_labels,
                                        learning_rate
                )
                model.train()
                name = 'ViT_' + str(idx)
                model.save(name)
                model.test()
                idx += 1
                self.log['model_name'].append(name)
                self.log['learning_rate'].append(learning_rate)
                self.log['batch_size'].append(batch_size)
                self.metrics(model.pred, model.target)
                # Persist after every run so partial results survive an interruption.
                self.results = pd.DataFrame.from_dict(self.log)
                self.results.to_csv('ViT_HypTune.csv', index = False)
                model.confusion_matrix(save = True, name = name)

    def get_results(self):
        '''
        Displays (IPython display) and returns the results DataFrame.
        '''
        display(self.results)
        return self.results


In [None]:
# Grid search: every learning rate is combined with every batch size (3 x 4 = 12 runs).
lr = [0.00003, 0.00002, 0.000005]
batch = [2, 8, 16, 64]
tune = TUNE(lr, batch)
# Reuse the datasets already preprocessed by `pp`; only the DataLoaders are rebuilt per run.
tune.get_data(
                    pp.preprocessed_train_ds,
                    pp.preprocessed_val_ds,
                    pp.preprocessed_test_ds
)
tune.get_model_details(
                    version,              
                    ID2Label, 
                    Label2ID, 
                    num_labels
)

In [None]:
# Train and evaluate every combination; results are written to ViT_HypTune.csv after each run.
tune.tune()
tune.get_results()

In [None]:
def highlight_max(data, color='yellow'):
    '''
    Styler helper for DataFrame.style.apply: returns one CSS string per cell
    of the column `data`, colouring the cell(s) that hold the column maximum.
    The hyper-parameter columns ('model_name', 'learning_rate', 'batch_size')
    are left unstyled.
    '''
    untouched = ('model_name', 'learning_rate', 'batch_size')
    if data.name in untouched:
        return [''] * len(data)
    best = data.max()
    highlight = 'background-color: {}'.format(color)
    return [highlight if value == best else '' for value in data]
    
# NOTE(review): this reads 'ViT_HypTune_reduced_Dataset.csv', whereas TUNE.tune() writes
# 'ViT_HypTune.csv'; presumably a renamed copy from an earlier run — confirm the file name.
df = pd.read_csv('ViT_HypTune_reduced_Dataset.csv')
# Highlight the best value of every metric column.
df.style.apply(highlight_max)

# Load Model
(In case you already saved a fine-tuned version)

In [None]:
class Load_ViT():
    '''
    Rebuilds the ViT LightningModule and loads fine-tuned weights from
    '<checkpoint_dir><name>_stateDict.pt'.

    name (str): checkpoint stem
    version (str): Hugging Face checkpoint the model was fine-tuned from
    ID2Label / Label2ID (dict): label mappings; must match the ones used in training
    num_labels (int): number of output classes
    checkpoint_dir (str): directory holding the state-dict file
    '''
    def __init__(self, name, version, ID2Label, Label2ID, num_labels,
                 checkpoint_dir = '/home/jovyan/Model_Checkpoints/'):
        self.FILE = os.path.join(checkpoint_dir, name + '_stateDict.pt')
        # The learning rate only affects the optimizer, so any value works for inference.
        self.loaded_model = ViT(version, ID2Label, Label2ID, num_labels, 5e-5)
        # map_location='cpu' lets weights saved from a GPU load on CPU-only machines;
        # load_state_dict then copies them into the model's own parameters.
        self.loaded_model.load_state_dict(torch.load(self.FILE, map_location='cpu'))

    def get(self):
        '''Returns the model with the restored weights.'''
        return self.loaded_model

In [None]:
# Checkpoint stem; the weights are read from /home/jovyan/Model_Checkpoints/<name>_stateDict.pt.
# NOTE(review): ID2Label/Label2ID/num_labels come from the current DataLoader run and must
# match the mapping used when this checkpoint was trained.
name = 'ViT_tune_corrected_6'
vit = Load_ViT(name, 'google/vit-base-patch16-224-in21k', ID2Label, Label2ID, num_labels).get()

# Class Activation Map

In [None]:
class CAM():
    '''
    Class Activation Map for given image and model.

    model: fine-tuned ViT LightningModule; model.vit.vit.encoder is accessed, so
           its `vit` attribute is expected to be a ViTForImageClassification
    info_frame (DataFrame): DataLoader.info_frame; only the 'test' rows are sampled
    zip_file_location (str): path to the DEVKitArt zip archive
    ID2Label (dict): {label id: label name}
    '''
    def __init__(self, model, info_frame, zip_file_location, ID2Label):
        self.model = model
        self.model.eval()
        self.info_frame = info_frame[info_frame['set'] == 'test']
        self.zf = ZipFile(zip_file_location)
        self.ID2Label = ID2Label
        # CAMs are computed on the layer norm in front of encoder block 11.
        self.target_layers = [self.model.vit.vit.encoder.layer[11].layernorm_before]
        self.methods = {
            "gradcam": GradCAM,
            "scorecam": ScoreCAM,
            "gradcam++": GradCAMPlusPlus,
            "xgradcam": XGradCAM,
            "eigencam": EigenCAM,
            "eigengradcam": EigenGradCAM,
            "layercam": LayerCAM,
        }
        # None: the CAM library chooses the target class itself (presumably the top prediction).
        self.targets = None

    def reshape_transform(self, tensor, height=14, width=14):
        '''
        Drops the first token (the [CLS] token in ViT) of a (batch, tokens, channels)
        tensor and reshapes the remaining height*width patch tokens into a
        (batch, channels, height, width) feature map.
        '''
        result = tensor[:, 1:, :].reshape(tensor.size(0), height, width, tensor.size(2))
        result = result.transpose(2, 3).transpose(1, 2)
        return result

    def _device(self):
        '''Returns the device that holds the model parameters.'''
        return next(self.model.parameters()).device

    def remind_methods(self):
        '''
        Returns all possible CAM methods.
        '''
        print(self.methods.keys())
        return list(self.methods.keys())

    def remind_labels(self):
        '''
        Returns all class names occurring in the test set.
        '''
        print(self.info_frame['label_name'].unique())
        return self.info_frame['label_name'].unique()

    def chose_method(self, method):
        '''
        Selects CAM method (one of remind_methods()).
        '''
        self.cam = self.methods[method](model=self.model,
                   target_layers=self.target_layers,
                   use_cuda=torch.cuda.is_available(),
                   reshape_transform=self.reshape_transform)
        self.cam.batch_size = 32
        self.single_method = method

    def fit_all(self):
        '''
        Instantiates every CAM method for the model (stored in self.cams).
        '''
        self.cams = {}
        for name, method in self.methods.items():
            cam = method(model=self.model,
                         target_layers=self.target_layers,
                         use_cuda=torch.cuda.is_available(),
                         reshape_transform=self.reshape_transform)
            cam.batch_size = 32
            self.cams[name] = cam

    def chose_image(self, label = 'MARY'):
        '''
        Selects a random test image of the given class and prepares both the RGB
        float image (self.rgb_img, 224x224, values in [0, 1]) and the model input
        (self.input_tensor).
        '''
        name = self.info_frame[self.info_frame['label_name'] == label].sample().iloc[0,0]
        img = self.zf.read('DEVKitArt/JPEGImages/' + name + '.jpg')
        # cv2 decodes to BGR; reversing the channel axis gives RGB.
        self.rgb_img = cv2.imdecode(np.frombuffer(img, np.uint8), 1)[:, :, ::-1]

        self.rgb_img = cv2.resize(self.rgb_img, (224, 224))
        self.rgb_img = np.float32(self.rgb_img) / 255
        # NOTE(review): mean/std 0.5 presumably mirror the ViT feature extractor's normalization.
        self.input_tensor = preprocess_image(self.rgb_img, mean=[0.5, 0.5, 0.5],
                                    std=[0.5, 0.5, 0.5])

    def predict(self, label= 'MARY'):
        '''
        Returns a title 'True: <label>, Pred: <predicted label>' for the current image.
        '''
        title = 'True: ' + label + ', Pred: '
        # Use whatever device the model is on instead of assuming CUDA.
        with torch.no_grad():
            logits = self.model(self.input_tensor.to(self._device()))
        title += self.ID2Label[int(logits.argmax(-1))]
        return title

    def show_all(self, label = 'MARY', save=False, save_as='CAM_01'):
        '''
        Displays a random test image of `label` together with the CAM of every
        method; optionally saves the figure as CAM/<save_as>.png.
        '''
        self.fit_all()
        self.chose_image(label)
        imgs = {self.predict(label) : self.rgb_img}
        for name, cam in self.cams.items():
            grayscale_cam = cam(
                input_tensor=self.input_tensor,
                targets=self.targets,
                eigen_smooth=True,
                aug_smooth=True)[0, :]
            # rgb_img is RGB, so the heatmap has to be blended in RGB as well.
            imgs[name] = show_cam_on_image(self.rgb_img, grayscale_cam, use_rgb=True)
        fig, axs = plt.subplots(2,4, figsize=(18,7))
        for ax, (title, image) in zip(axs.ravel(), imgs.items()):
            ax.set_xticks([])
            ax.set_yticks([])
            ax.set_title(title)
            ax.imshow(image)
        if save:
            os.makedirs('CAM', exist_ok=True)
            fig.savefig('CAM/' + save_as + '.png')

    def show_single(self, label = 'MARY', save=False, name='CAM_01'):
        '''
        Displays a random test image of `label` next to its CAM computed with the
        method selected via chose_method(); optionally saves it as CAM/<name>.png.
        '''
        self.chose_image(label)
        grayscale_cam = self.cam(input_tensor=self.input_tensor,
                        targets=self.targets,
                        eigen_smooth=True,
                        aug_smooth=True)[0, :]
        # Blend in RGB directly: the old BGR blend followed by BGR2RGB conversion
        # also swapped the colour channels of the painting itself.
        self.cam_image = show_cam_on_image(self.rgb_img, grayscale_cam, use_rgb=True)
        title = 'Class Activation Map ' + '(' + self.single_method + ')'
        fig, axs = plt.subplots(1,2, figsize=(14,12))
        for ax in axs:
            ax.set_xticks([])
            ax.set_yticks([])
        axs[1].set_title(title)
        axs[0].set_title(self.predict(label))
        axs[0].imshow(self.rgb_img)
        axs[1].imshow(self.cam_image)
        if save:
            os.makedirs('CAM', exist_ok=True)
            fig.savefig('CAM/' + name + '.png')

In [None]:
# CAM helper over the test-set images, using the restored model `vit`.
cam = CAM(vit, dl.info_frame, zip_file_location, ID2Label)

In [None]:
# Single method (see cam.remind_methods()): a random JEROME test image and its CAM,
# saved as CAM/JEROME_03.png.
method = 'gradcam'
cam.chose_method(method)
cam.show_single(label='JEROME', save=True, name='JEROME_03')

In [None]:
# All methods for three random test images of every label except 'None'.
for label in cam.remind_labels():
    for idx in range(1, 4):
        if label in ['None']:
            continue
        name = label + '_ViT_tune_9_' + str(idx)
        cam.show_all(label=label, save=True, save_as=name)