# **CLASS ACTIVATION MAPS IN OXFORD PET III DATASET**

## Introduction

This notebook is my personal experimentation with Class Activation Maps (https://arxiv.org/abs/1512.04150), using the Oxford-IIIT Pet dataset (available here: https://www.kaggle.com/tanlikesmath/the-oxfordiiit-pet-dataset) for the evaluation.

A reference implementation of this algorithm by Chae Young Lee can be found in this GitHub repository: https://github.com/chaeyoung-lee/pytorch-CAM


**Disclaimer: this notebook is not finished yet. So far, the model, the `get_cam` function and the training loop have been defined. The ultimate goal is to propose a weakly supervised segmentation algorithm that performs region growing on the Class Activation Maps.**

## Development

### Mount Drive

In [None]:
# Mount Google Drive inside the Colab runtime; force_remount=True re-mounts
# even if a mount already exists at that path.
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

# Make the Drive root the working directory for the rest of the notebook.
import os
os.chdir('/content/drive/My Drive/')

### Import libraries

In [None]:
import os
import matplotlib.pyplot as plt
import numpy as np

from tqdm import tqdm

import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import models
from torch.autograd import Function
from torch.autograd import Variable
import torch.cuda.amp as amp
import json

import random
import scipy.io
from PIL import Image
import cv2 as cv

### Auxiliary functions

In [None]:
def one_hot(mask):
  """
    One-hot encode a 2-D segmentation mask into three binary channels.

    Channel k of the result is 1 wherever ``mask == k + 1`` (i.e. the labels
    1, 2 and 3) and 0 elsewhere; pixels holding any other value stay 0 in all
    three channels.

    Returns a float array of shape (h, w, 3), channels last.
  """
  labels = (1, 2, 3)
  planes = np.zeros((len(labels), mask.shape[0], mask.shape[1]))

  for channel, label in enumerate(labels):
    planes[channel][mask == label] = 1

  # Built channels-first, handed back channels-last.
  return planes.transpose(1, 2, 0)

In [None]:
#Kaiming weight initialization
def weights_init(m):
    """Kaiming-uniform (ReLU gain) init for 2-D conv / transposed-conv modules.

    The weight is re-drawn and the bias, when the module has one, is set to 0.
    Modules of any other type are left untouched, which makes this function
    suitable as the callback of ``Module.apply``.
    """
    conv_types = (torch.nn.Conv2d, torch.nn.ConvTranspose2d)
    if not isinstance(m, conv_types):
        return

    torch.nn.init.kaiming_uniform_(m.weight.data, nonlinearity='relu')
    if m.bias is not None:
        nn.init.constant_(m.bias.data, 0)

### Preprocessing

In [None]:
# ImageNet channel statistics.
# TODO: compute the mean/std of the pet dataset itself and use those instead.
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

# Base pipeline: convert to a tensor, then resize to 256x256.
# Disabled augmentation options (random crop / flips), kept for reference:
#   transforms.RandomVerticalFlip(p=0.5)
#   transforms.RandomHorizontalFlip(p=0.5)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize([256, 256]),
])

# Augmentation pipeline: tensor -> normalise with mean/std -> back to a PIL image.
# Disabled option: transforms.ColorJitter(brightness=0.15, contrast=0.15,
# saturation=0.15, hue=0.15) -- high computational cost, epoch time x1.5.
data_aug = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
    transforms.ToPILImage(),
])


### Oxford Pet III Database

In [None]:
class oxford_Pet(torch.utils.data.Dataset):
  """
    Pet dataset yielding (image, one-hot segmentation mask, species label).

    Files are located by plain string concatenation, so ``root`` must end with
    a path separator. Expected layout under ``root``:
      - paths/<set>.txt             one sample per line with four
                                    whitespace-separated fields:
                                    "name class_id species breed_id"
      - images/<name>.jpg           input image
      - annotations/trimaps/<name>.png   segmentation mask

    Args:
      root: dataset root directory (with trailing separator).
      transform: optional callable applied to both the image and the one-hot
        mask. When None, the PIL image and the raw one-hot array are returned.
      data_aug: stored on the instance; not used by this class.
      set: split name; samples are listed in ``paths/<set>.txt``
        ('train', 'val' and 'test' behave exactly as before).
  """
  def __init__(self, root, transform=None, data_aug=None, set='train'):
    self.root = root
    # Any split name maps to paths/<set>.txt, so an unknown name fails with a
    # clear FileNotFoundError instead of a missing-attribute error.
    self.dataDir = root + 'paths/' + set + '.txt'

    # Read the sample list once and close the file right away; blank lines
    # are dropped because they cannot be unpacked into the four fields.
    with open(self.dataDir) as split_file:
      self.paths = [line for line in split_file if line.strip()]

    self.transform = transform
    self.data_aug = data_aug

  def __getitem__(self, index):
    """Return (image, mask, species one-hot) for the sample at ``index``."""
    name, class_id, species, breed_id = self.paths[index].strip().split()
    # Use the root this dataset was built with, not a global variable.
    imgPath = self.root + 'images/' + name + '.jpg'
    gtPath = self.root + 'annotations/trimaps/' + name + '.png'

    # Force 3 channels so every sample has the same channel count
    # (assumes some files may be grayscale/RGBA -- harmless for RGB files).
    img = Image.open(imgPath).convert('RGB')
    gt = Image.open(gtPath)  # mask image, turned into an (h, w, 3) one-hot array
    gt_t = one_hot(np.array(gt))
    img_t = img  # returned unchanged when no transform is configured

    if self.transform:
      # Re-seed the RNGs with the same value before each call so that any
      # random transform is applied identically to the image and its mask.
      # NOTE: this resets the global `random` and torch seeds on every item.
      seed = np.random.randint(56346346)

      random.seed(seed)
      torch.manual_seed(seed)
      img_t = self.transform(img)

      random.seed(seed)
      torch.manual_seed(seed)
      gt_t = self.transform(gt_t)

    # Species field is 1-based; encode it as a 2-element one-hot float vector.
    species = int(species)
    wl = np.zeros(2)
    wl[species-1] = 1
    wl_t = torch.from_numpy(wl).float()

    return img_t, gt_t, wl_t  # image, segmentation mask, classification label

  def __len__(self):
    """Number of samples listed in the split file."""
    return len(self.paths)



In [None]:
# Dataset root; must end with '/' because paths are built by concatenation.
root='/content/drive/My Drive/Pet/'

#Sets: 'train' (1846), 'val' (1834), 'test'

trainset = oxford_Pet(root=root, transform=transform, set='train')
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True)

# 'val' is the validation split documented above.
validationset = oxford_Pet(root=root, transform=transform, set='val')
validationloader = torch.utils.data.DataLoader(validationset, batch_size=2, shuffle=True)

# The test split uses the same dataset class and root as the other splits
# (the previous db_isic_Dataset / ROOT / idx names are not defined here).
testset = oxford_Pet(root=root, transform=transform, set='test')
testloader = torch.utils.data.DataLoader(testset, batch_size=1, shuffle=True, num_workers=1)

# Loaders keyed by phase name.
data_loaders = {"train": trainloader, "val": validationloader}

print("val: ", len(validationset))
print("train: ", len(trainset))


### Model architecture

In [None]:
class net(torch.nn.Module):  
  """
    Hybrid classification-segmentation model. Given an input tensor, forward() returns three outputs:
      - x: predicted segmentation masks (3 channels).
      - aspp: concatenated Atrous Spatial Pyramid Pooling feature maps (5 branches x 256 = 1280 channels).
      - output: predicted classification vector (2 values, softmax over the class dimension).

    The classifier is Global Average Pooling over `aspp` followed by a single Linear layer.
    `fc` is deliberately the last module registered, so the Linear weight is the second-to-last
    entry of model.parameters() -- get_cam() depends on that ordering.

    Channel counts of layer1 (256) and layer2 (512) assume the standard torchvision ResNet-101 layout.
  """
  def __init__(self):
    super().__init__()

    self.resnet = models.resnet101(pretrained=True)

    #ENCODER (RESNET101 PRETRAINED)
    #layer1 = resnet stem (first four children) + resnet.layer1; layer2 = resnet.layer2
    self.layer1 = nn.Sequential(*list(self.resnet.children())[0:5])
    self.layer2 = nn.Sequential(*list(self.resnet.layer2))

    #From-scratch bottleneck 512 -> 256 -> 256 -> 512, spatial size unchanged
    self.layer3 = nn.Sequential(
      nn.Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=True),
      nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
      nn.ReLU(),
      nn.Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=True),
      nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
      nn.ReLU(),
      nn.Conv2d(256, 512, kernel_size=1, stride=1, bias=True),
      nn.BatchNorm2d(512),
      nn.ReLU()
    )

    #Consumes cat(layer3 output, layer2 output) = 512 + 512 = 1024 channels
    self.layer4 = nn.Sequential(
        nn.Conv2d(1024, 1024, kernel_size=1, stride=1, bias=True),
        nn.BatchNorm2d(1024),
        nn.ReLU(),
        nn.Conv2d(1024, 1024, kernel_size=3, stride=1, padding=1, bias=True),
        nn.BatchNorm2d(1024),
        nn.ReLU()
    )

    #ASPP layers. Every branch halves the spatial size and outputs 256 channels.
    self.a1 = nn.Sequential( #1x1 Conv Stride=2
        nn.Conv2d(in_channels=1024, out_channels=256, kernel_size=1, stride=2, padding=0),
        nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
        nn.Conv2d(in_channels= 256, out_channels=256, kernel_size=1, stride=1, padding=0),
        nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
        nn.ReLU()
    )
    self.a2 = nn.Sequential( #Dilation 2
        nn.Conv2d(in_channels= 1024, out_channels=256, kernel_size=3, stride=2, dilation=2, padding=2),
        nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
        nn.Conv2d(in_channels= 256, out_channels=256, kernel_size=1, stride=1, padding=0),
        nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
        nn.Conv2d(in_channels= 256, out_channels=256, kernel_size=1, stride=1, padding=0),
        nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
        nn.ReLU()
    )
    self.a3 = nn.Sequential( #Dilation 6
        nn.Conv2d(in_channels= 1024, out_channels=256, kernel_size=3, stride=2, dilation=6, padding=6),
        nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
        nn.Conv2d(in_channels= 256, out_channels=256, kernel_size=1, stride=1, padding=0),
        nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
        nn.Conv2d(in_channels= 256, out_channels=256, kernel_size=1, stride=1, padding=0),
        nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
        nn.ReLU()
    )
    self.a4 = nn.Sequential( #Dilation 8
        nn.Conv2d(in_channels= 1024, out_channels=256, kernel_size=3, stride=2, dilation=8, padding=8),
        nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
        nn.Conv2d(in_channels= 256, out_channels=256, kernel_size=1, stride=1, padding=0),
        nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
        nn.Conv2d(in_channels= 256, out_channels=256, kernel_size=1, stride=1, padding=0),
        nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
        nn.ReLU()   
    )
    self.a5 = nn.Sequential( #Max pool + 1x1 Conv
        nn.MaxPool2d(kernel_size=2),
        nn.Conv2d(in_channels= 1024, out_channels=256, kernel_size=1, stride=1, padding=0),
        nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
        nn.ReLU()
    )

    # Consumes the concatenated a_i: [batch, 1280, h, w]; upsamples x2 to 1024 channels
    self.conv1 = nn.Sequential(
        nn.ConvTranspose2d(in_channels=1280, out_channels=1024, kernel_size=3, stride=2, padding=1, output_padding=1),
        nn.BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),        
        nn.ReLU(),
        nn.ConvTranspose2d(in_channels=1024, out_channels=1024, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
        nn.ReLU()
    )

    # Consumes cat(conv1 output, layer2 output) = 1024 + 512 = 1536 channels; upsamples x2
    self.upsample1 = nn.Sequential(
        nn.ConvTranspose2d(in_channels=1536, out_channels=768, kernel_size=3, stride=2, padding=1, output_padding=1),
        nn.BatchNorm2d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
        nn.ReLU(),
        nn.Conv2d(in_channels= 768, out_channels=768, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
        nn.ReLU(),       
    )

    # Consumes cat(upsample1 output, layer1 output) = 768 + 256 = 1024 channels;
    # upsamples x4 (two stride-2 transposed convs) and ends in 3 mask channels (no activation).
    self.upsample2 = nn.Sequential(
        nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=3, stride=2, padding=1, output_padding=1),
        nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
        nn.ReLU(),
        nn.Conv2d(in_channels= 512, out_channels=512, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
        nn.ReLU(),
        nn.ConvTranspose2d(in_channels=512, out_channels=512, kernel_size=3, stride=2, padding=1, output_padding=1),
        nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
        nn.ReLU(),
        nn.Conv2d(in_channels= 512, out_channels=256, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
        nn.ReLU(),
        nn.ConvTranspose2d(in_channels=256, out_channels=3, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)         
    )
    
    #Global Average Pooling over the ASPP volume
    self.avgpool = nn.AdaptiveAvgPool2d(output_size=(1,1))

    #Classifier head. dim=1 is stated explicitly: the input is [batch, 2], and relying on
    #the implicit dimension choice of nn.Softmax() is deprecated.
    self.fc = nn.Sequential(
        nn.Linear(in_features=1280, out_features=2, bias=True),
        nn.Softmax(dim=1)
    )    

     

  def forward(self, input):
    """
      Run the encoder, the ASPP block, the decoder and the classification head.

      Returns (x, aspp, output):
        - x: segmentation output of upsample2, 3 channels.
        - aspp: concatenated ASPP features, 1280 channels (the feature map used for the CAMs).
        - output: softmax classification vector, shape [batch, 2].
    """

    print_shape = False #For debugging purposes

    #ENCODER--------------------------------------------------
    if(print_shape): print("Input: ", input.shape)
    l1 = self.layer1(input)
    if(print_shape): print("(l1): ", l1.shape)
    l2 = self.layer2(l1)
    if(print_shape): print("(l2): ", l2.shape)
    l3 = self.layer3(l2)
    l3 = torch.cat((l3, l2), dim=1) #Residual-style concat: 512 + 512 channels
    if(print_shape): print("(l3): ", l3.shape)
    l3 = self.layer4(l3)
    if(print_shape): print("(l4): ", l3.shape)

    #ATROUS SPATIAL PYRAMID POOLING
    a1 = self.a1(l3)
    if(print_shape): print("(a1) Conv1x1: ", a1.shape)
    a2 = self.a2(l3)
    if(print_shape): print("(a2) Dilation=2: ", a2.shape)
    a3 = self.a3(l3)
    if(print_shape): print("(a3) Dilation=6: ", a3.shape)
    a4 = self.a4(l3)
    if(print_shape): print("(a4) Dilation=8: ", a4.shape)
    a5 = self.a5(l3)
    if(print_shape): print("(a5) Max pool: ", a5.shape)

    aspp = torch.cat((a1, a2, a3, a4, a5), dim=1) #ASPP concatenation
    if(print_shape): print("(aspp) Concat aspp: ", aspp.shape)
    
    #DECODER ------------------------------------------------------------
    
    x = self.conv1(aspp)
    if(print_shape): print("")
    if(print_shape): print("(conv1) Conv1x1: ", x.shape)
    

    x = torch.cat((x, l2), dim=1) #Skip connection
    if(print_shape): print("Concat skip conv1 + l2 ", x.shape)

    x = self.upsample1(x)
    if(print_shape): print("(upsample1) up1 ", x.shape)

    x = torch.cat((x, l1), dim=1) #Skip connection
    if(print_shape): print("Concat up1 + l1: ", x.shape)

    x = self.upsample2(x)
    if(print_shape): print("(upsample2) Output: ", x.shape)
    
    #CAM----------------------------------------------------------------
    #GAP over the ASPP features, flatten to [batch, 1280], then FC + softmax

    output = self.avgpool(aspp)
    output = output.view(output.size(0), -1)

    output = self.fc(output)

    return x, aspp, output

### Focal Loss

In [None]:
class FocalLoss(nn.Module):
    """Multi-class focal loss computed from per-element cross entropy.

    loss = mean(alpha * (1 - pt) ** gamma * CE), with pt = exp(-CE), so
    well-classified elements (pt close to 1) are down-weighted.
    """
    def __init__(self, **kwargs):
        super(FocalLoss, self).__init__()
        self.kwargs = kwargs  # stored only; not read by forward()

    def forward(self, inputs, targets, smooth=1, alpha=1, gamma=2):
        """
        Args:
            inputs: raw class scores with the class dimension at dim=1.
            targets: one-hot targets with the class dimension at dim=1; they
                are collapsed to class indices with an argmax over that dim.
            smooth: unused; kept for signature compatibility.
            alpha: global scaling factor of the loss.
            gamma: focusing exponent; gamma=0 gives plain mean cross entropy.

        Returns:
            Scalar tensor with the mean focal loss.
        """
        # One-hot -> class indices, e.g. [b, C, h, w] -> [b, h, w]
        _, targets = torch.max(targets, dim=1)

        # Unreduced cross entropy: the focal weighting needs one value per element.
        ce_loss = nn.CrossEntropyLoss(reduction='none')(inputs, targets)

        pt = torch.exp(-ce_loss)  # probability assigned to the true class
        F_loss = alpha * (1 - pt) ** gamma * ce_loss

        return torch.mean(F_loss)

### Model and Parameters

In [None]:
#Architecture: only 'aspp' (the net() class) is available in this notebook.
#loss_fn and optim_ are labels only; the loss and the optimizer are chosen explicitly below.
arquitecture = 'aspp'
loss_fn = 'combo'
optim_ = 'adam'

#To GPU if available -------------------------------------------------------------
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
use_gpu = torch.cuda.is_available()

#Model-----------------------------------------------------------------------------

if(arquitecture == 'aspp'):
  model = net()
  params_e3 = [p for p in model.layer3.parameters() if p.requires_grad] #From scratch
  params_e4 = [p for p in model.layer4.parameters() if p.requires_grad]
  
  params_a1 = [p for p in model.a1.parameters() if p.requires_grad] 
  params_a2 = [p for p in model.a2.parameters() if p.requires_grad]
  params_a3 = [p for p in model.a3.parameters() if p.requires_grad]
  params_a4 = [p for p in model.a4.parameters() if p.requires_grad]
  params_a5 = [p for p in model.a5.parameters() if p.requires_grad]
  params_conv1 = [p for p in model.conv1.parameters() if p.requires_grad] #Decoder
  params_up1 = [p for p in model.upsample1.parameters() if p.requires_grad]
  params_up2 = [p for p in model.upsample2.parameters() if p.requires_grad]
  params_fc = [p for p in model.fc.parameters() if p.requires_grad] #Fully Connected
else:
  #Fail here with a clear message instead of a NameError on `model` further down.
  raise ValueError("Unknown arquitecture: %r (expected 'aspp')" % (arquitecture,))

#Kaiming-initialise only the layers trained from scratch. Calling model.apply(weights_init)
#would also overwrite the pretrained ResNet convolutions (model.resnet / layer1 / layer2),
#which are not in the optimizer and would therefore stay random for the whole training.
for scratch_module in (model.layer3, model.layer4,
                       model.a1, model.a2, model.a3, model.a4, model.a5,
                       model.conv1, model.upsample1, model.upsample2):
  scratch_module.apply(weights_init)
model.to(device)

#Loss--------------------------------------------------------------------------------
criterion = FocalLoss() #Segmentation

#Per-class weights for the 2-element classification vector
weight_class = torch.tensor([0.66, 0.33])
weight_class = weight_class.to(device)
criterion2 = nn.BCELoss(weight=weight_class) #Classification

#Optimizer ------------------------------------------------------------------------------------
#Only the from-scratch layers are optimised; the pretrained layer1/layer2 are left out (frozen).
#NOTE: the conv1 group was previously missing, so that decoder stage was never updated.
#Adding it changes the number of param groups, so optimizer states saved before this
#change cannot be loaded into this optimizer.

optimizer = optim.Adam([{'params': params_e3, 'lr': 5e-5}, #Encoder
                        {'params': params_e4, 'lr': 5e-5},

                        {'params': params_a1, 'lr': 5e-5}, #ASPP
                        {'params': params_a2, 'lr': 5e-5},
                        {'params': params_a3, 'lr': 5e-5},
                        {'params': params_a4, 'lr': 5e-5},
                        {'params': params_a5, 'lr': 5e-5},

                        {'params': params_conv1, 'lr': 5e-5}, #Decoder
                        {'params': params_up1, 'lr': 5e-5},
                        {'params': params_up2, 'lr': 5e-5},
                        {'params': params_fc, 'lr': 1e-4} #Classifier
                        ], lr=1e-3, weight_decay=1e-6)


#Learning rate scheduler---------------------------------------------------------------------------
#Multiplies every group's lr by 0.99 once every 5 calls to scheduler.step()
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.99)

## Get Class Activation Maps

In [None]:
def get_cam(model, feature_map, num_classes=2):
  """
    Compute Class Activation Maps from a feature volume and the classifier weights.

    inputs:
      · model: module whose second-to-last parameter is the weight matrix of the
        GAP --> FC layer, with shape (num_classes, channels).
      · feature_map: tensor of shape [batch, channels, h, w] (the volume that was
        average-pooled and fed to the FC layer).
      · num_classes: number of rows of the FC weight matrix.
    output: cam, NumPy array of shape [batch, num_classes, h, w]; for every sample
      and class, the FC-weighted sum of the feature channels at each location.
  """
  # FC weight is the second-to-last parameter (the last one is its bias).
  fc_weight = list(model.parameters())[-2]
  weights = np.squeeze(fc_weight.data.cpu().numpy())

  n_samples, channels, height, width = feature_map.shape
  cam = np.zeros((n_samples, num_classes, height, width))

  for idx in range(n_samples):
    # Flatten the spatial dims so the weighting is a single matrix product.
    flat = feature_map[idx].detach().cpu().numpy().reshape(channels, height*width)
    cam[idx] = np.dot(weights, flat).reshape(num_classes, height, width)

  return cam

In [None]:
def show_cam(cam):
  """
    Plot the Class Activation Maps of the first sample in a batch, side by side:
    channel 0 titled "Cat" and channel 1 titled "Dog".
      input: cam, NumPy array of shape (batch, classes, h, w) with at least 2 classes.
    The maps are bilinearly upsampled by a factor of 16 before plotting.
  """
  # Unpacking doubles as a check that cam is 4-D.
  n_batch, n_classes, height, width = cam.shape

  # Rescale the CAMs x16 towards the input/output resolution.
  rescale = nn.UpsamplingBilinear2d(scale_factor=16)
  maps = rescale(torch.from_numpy(cam))

  plt.figure(figsize=(8,4))
  for position, (title, class_idx) in enumerate((("Cat", 0), ("Dog", 1)), start=1):
    plt.subplot(1, 2, position)
    plt.title(title)
    plt.imshow(maps[0, class_idx], cmap='gnuplot2')
  plt.show()
  

## Training

In [None]:
showImg = True 
checkpoint_path = "/content/drive/My Drive/Checkpoints/model.tar"

saveCheckpoint = True #Saves model / optimizer / scheduler at the start of every epoch after the first
loadCheckpoint = False
#---------------------------------------------------------------------------------
#Loss
train_loss = []
validation_loss = []

#Jaccard Score
iou_train_score = []
iou_val_score = []

#Learning Rate
lrate_backbone = []
lrate_head = []

running_loss = 0.0
running_iou = 0.0

min_loss = 100
iou = 0
j = 0 #Scheduler 
th = 120


def _batch_iou(y_pred, y_true, num_classes=3):
  """Mean Jaccard index (intersection over union) over the classes that occur in
  either array. Both arguments are integer label arrays of identical shape.
  Returns 0.0 when none of the classes occurs."""
  scores = []
  for c in range(num_classes):
    pred_c = (y_pred == c)
    true_c = (y_true == c)
    union = np.logical_or(pred_c, true_c).sum()
    if union == 0:
      continue #Class absent from both prediction and label: skip it
    scores.append(np.logical_and(pred_c, true_c).sum() / union)
  return float(np.mean(scores)) if scores else 0.0


if(loadCheckpoint):
  print("Loading checkpoint... | Path: ", checkpoint_path)
  checkpoint = torch.load(checkpoint_path, map_location=device)
  model.load_state_dict(checkpoint['model_state_dict'])
  optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
  if('scheduler_state' in checkpoint): #Older checkpoints may not carry it
    scheduler.load_state_dict(checkpoint['scheduler_state'])
 
epochs = 50
for epoch in range(epochs):
    #Checkpoint at the start of every epoch except the first, i.e. only after
    #at least one full epoch has been trained.
    if(epoch > 0 and saveCheckpoint):
        print("Saving checkpoint... | epoch: ", epoch, " | model: ", arquitecture, " | Path: ", checkpoint_path)
        torch.save({'model_state_dict': model.state_dict(),  
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state': scheduler.state_dict()},
                    checkpoint_path)

    # Each epoch has a training and validation phase
    for phase in ['train', 'val']:
        is_train = (phase == 'train')
        if is_train:
            model.train()  # Set model to training mode
            print("Epoch: ", epoch, "| Training phase |")
        else:
            model.eval()  # Set model to evaluate mode
            print("Epoch: ", epoch, "| Validation phase |")

        #Each phase reads its own loader (validation no longer iterates the training set)
        for i, data in enumerate(data_loaders[phase], 0):
            inputs, labels, wl = data
            labels = labels.long()          

            inputs, labels, wl = inputs.to(device), labels.to(device), wl.to(device)

            #Gradients are tracked only while training
            with torch.set_grad_enabled(is_train):
                outputs, feature_map, out_c = model(inputs) #Output mask (b, 3, h, w)
                                                            #Feature map (b, 1280, h/16, w/16)
                                                            #Output vector (b, 2)
                #Segmentation loss only. The classification loss (criterion2 on out_c vs wl)
                #is not part of the objective yet.
                loss = criterion(outputs, labels)

            _, y_pred = torch.max(outputs, dim=1) #Undo one-hot        
            __, labels = torch.max(labels, dim=1) #Undo one-hot

            y_pred = y_pred.data.cpu().numpy() #To numpy
            y_true = labels.data.cpu().numpy()

            iou = _batch_iou(y_pred, y_true)
          
            if(i%10==0 and showImg):
                print("")
                plt.figure(figsize=(15,6))

                plt.subplot(131)
                plt.title('Input')
                plt.imshow(inputs.cpu().data[0].squeeze().permute(1,2,0))
              
                plt.subplot(132)
                plt.title('Prediction')
                plt.imshow(y_pred[0])
                print("Prediction values: ", np.unique(y_pred[0]))
                            
                plt.subplot(133)
                plt.title("Label")            
                print("Label: ", np.unique(y_true[0]))
                plt.imshow(y_true[0])
                    
                plt.show()

                cam = get_cam(model=model, feature_map=feature_map, num_classes=2) #Get Class Activation Maps
                show_cam(cam) #Plot the CAMs

                out_c0 = out_c[0].detach().cpu().numpy()
                print("Prediction: ", out_c0) #Classification prediction
                print("Label: ", wl[0].cpu().numpy()) #Classification label

            if is_train:
                #Weights are updated only in the training phase
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()     
                #NOTE(review): the scheduler steps once per batch, so StepLR(step_size=5)
                #decays every 5 iterations, not every 5 epochs -- confirm this is intended.
                scheduler.step()
          
            #Keep every second sample of the curves, per phase
            if(i%2 == 0):
                if is_train:
                    train_loss.append(loss.item())
                    iou_train_score.append(iou)
                else:
                    validation_loss.append(loss.item())
                    iou_val_score.append(iou)

            if(i%25 ==0):
                print("Loss: %.4f" % loss.item(), " | Jaccard Score: %.4f" % iou)
            
                if is_train:
                    plt.figure(figsize=(12,6))
                    plt.subplot(121)
                    plt.title("Training Loss")
                    plt.plot(train_loss, 'r')
                    plt.subplot(122)
                    plt.title("Training IoU")
                    plt.plot(iou_train_score, 'y')
                    plt.show()          