# RELAX: Representation Learning Explainability

This notebook contains pre-made code for the explainable machine learning exercises at the 2023 DTU summer school on self-supervised learning and illustrates the usage of RELAX, a framework for explainability in representation learning. RELAX is based on measuring similarities in the representation space between an input and occluded versions of itself. For more information, see https://link.springer.com/article/10.1007/s11263-023-01773-2 (or https://arxiv.org/abs/2112.10161 if you cannot access the journal version).

You are tasked with implementing the main algorithm of RELAX and investigating the representations produced by some well-known feature extractors. Remember to enable GPU support to speed up computation (Edit -> Notebook settings -> Hardware accelerator -> GPU).

In [None]:
#@title Downloading example images from Wikimedia

!wget 'https://upload.wikimedia.org/wikipedia/commons/e/ee/Cat_in_Cat_Caf%C3%A9_Nekokaigi%2C_Tokyo%2C_February_2013.jpg'
!wget 'https://upload.wikimedia.org/wikipedia/commons/thumb/a/ae/Tringa_totanus-pjt.jpg/640px-Tringa_totanus-pjt.jpg'
!wget 'https://upload.wikimedia.org/wikipedia/commons/thumb/e/eb/Two_Cats_in_a_Corner.jpg/640px-Two_Cats_in_a_Corner.jpg'
!wget 'https://upload.wikimedia.org/wikipedia/commons/thumb/2/2f/Kerry_Hill_ewe_and_lamb.jpg/540px-Kerry_Hill_ewe_and_lamb.jpg'



In [72]:
#@title Loading the necessary packages

import torch
import torchvision
import torch.nn as nn
import tqdm.notebook as tqdm
import torch.nn.functional as F
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

from PIL import Image
from matplotlib import pyplot as plt
from torchvision.transforms.functional import pil_to_tensor
from torchvision.transforms import ToTensor, Normalize, Resize


In [73]:
#@title Function for rescaling and displaying images
#@markdown This function is taken from the TorchRay library (https://github.com/facebookresearch/TorchRay) and is used to rescale and plot the examples together with their explanations.

def imsc(img, *args, quiet=False, lim=None, interpolation='lanczos', **kwargs):
    """Rescale an image to the [0, 1] range and return it as a NumPy bitmap.

    Args:
        img: A PIL image (converted with ``pil_to_tensor``) or a 3-D tensor
            whose first dimension can be expanded to 3, e.g. ``(1, H, W)`` or
            ``(3, H, W)``.
        *args: Accepted for signature compatibility; not used here.
        quiet: Accepted for signature compatibility. The bitmap is computed
            and returned regardless of its value.
        lim: Optional ``[low, high]`` pair mapped to 0 and 1 respectively;
            values outside the pair are clamped. When omitted (or empty),
            the minimum and maximum of ``img`` are used.
        interpolation: Accepted for signature compatibility; not used here.
        **kwargs: Accepted for signature compatibility; not used here.

    Returns:
        A NumPy array of shape ``(H, W, 3)`` with values in [0, 1]. The input
        tensor is not modified.
    """
    if isinstance(img, Image.Image):
        img = pil_to_tensor(img)
    with torch.no_grad():
        # Integer images (e.g. the uint8 output of pil_to_tensor) cannot be
        # scaled in place by a fractional factor, so work in floating point.
        if not torch.is_floating_point(img):
            img = img.float()
        if not lim:
            lim = [img.min(), img.max()]
        low, high = lim[0], lim[1]
        span = high - low
        # A constant image has zero range; dividing by it would give NaN/inf.
        if span == 0:
            span = 1
        img = img - low  # out-of-place: leaves the caller's tensor untouched
        img.mul_(1 / span)
        img = torch.clamp(img, min=0, max=1)
        # (C, H, W) -> (H, W, 3): channels-last layout expected by plt.imshow.
        bitmap = img.expand(3, *img.shape[1:]).permute(1, 2, 0).cpu().numpy()
    return bitmap

In [None]:
#@title Function for loading data
#@markdown This function loads the example images downloaded in one of the previous cells.

def load_img(img, shape=224, device=None, img_dir='/content'):
    """Load one of the downloaded example images as a normalised batch tensor.

    Args:
        img: Keyword selecting the example image: 'Ex1', 'Ex2', 'Ex3' or 'Ex4'.
        shape: Side length in pixels; the image is resized to shape x shape.
        device: Device the tensor is moved to. Defaults to 'cuda' when a GPU
            is available and to 'cpu' otherwise.
        img_dir: Directory holding the downloaded image files.

    Returns:
        A float tensor of shape (1, 3, shape, shape), normalised per channel
        with the mean/std listed in the transform below.

    Raises:
        ValueError: If ``img`` is not one of the known keywords.
    """
    import os  # local import: the notebook's import cell does not provide os

    filenames = {
        'Ex1': 'Cat_in_Cat_Café_Nekokaigi,_Tokyo,_February_2013.jpg',
        'Ex2': '640px-Tringa_totanus-pjt.jpg',
        'Ex3': '640px-Two_Cats_in_a_Corner.jpg',
        'Ex4': '540px-Kerry_Hill_ewe_and_lamb.jpg',
    }
    # Validate first so a bad keyword fails fast with a meaningful error
    # (a bare `raise` here only produced "No active exception to reraise").
    if not isinstance(img, str) or img not in filenames:
        raise ValueError(
            f"Incorrect keyword {img!r}; expected one of {sorted(filenames)}")

    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize((shape, shape)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225]),
    ])

    path = os.path.join(img_dir, filenames[img])
    # The context manager closes the file handle once the pixels are read;
    # convert('RGB') guarantees the 3 channels that Normalize expects.
    with Image.open(path) as pil_img:
        x = transform(pil_img.convert('RGB')).unsqueeze(0)

    return x.to(device)

# Show the four example images side by side, one subplot per keyword.
plt.figure(1)
for position, keyword in enumerate(('Ex1', 'Ex2', 'Ex3', 'Ex4'), start=1):
    plt.subplot(1, 4, position)
    plt.imshow(imsc(load_img(keyword).squeeze(0)))
    plt.axis('off')
    plt.title(keyword)
plt.tight_layout()
plt.show()

In [75]:
#@title Load feature extractors
#@markdown This cell contains the functions for loading the feature extractors used
#@markdown in this notebook. An image is represented by the output of the
#@markdown adaptive pooling layer of a ResNet50 or an AlexNet.

def load_resnet50(device=None):
  """Build a ResNet50 feature extractor in evaluation mode.

  The torchvision ResNet50 is loaded with its "DEFAULT" weights and its last
  child module is dropped, so the encoder ends at the pooling layer; the
  pooled output is then flattened to one vector per image.

  Args:
    device: Device the encoder is moved to. Defaults to 'cuda' when a GPU is
      available and to 'cpu' otherwise.

  Returns:
    An ``nn.Sequential`` encoder in eval mode.
  """
  if device is None:
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
  resnet50 = torchvision.models.resnet50(weights="DEFAULT")
  # Keep every child except the last one (the classification head).
  modules = list(resnet50.children())[:-1]
  encoder = nn.Sequential(*modules, nn.Flatten()).to(device)
  encoder.eval()
  return encoder

def load_alexnet(device=None):
  """Build an AlexNet feature extractor in evaluation mode.

  Only the convolutional ``features`` part of the torchvision AlexNet (loaded
  with its "DEFAULT" weights) is kept. Its output is average-pooled to 1x1
  and flattened, giving one vector per image.

  Args:
    device: Device the encoder is moved to. Defaults to 'cuda' when a GPU is
      available and to 'cpu' otherwise.

  Returns:
    An ``nn.Sequential`` encoder in eval mode.
  """
  if device is None:
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
  alexnet = torchvision.models.alexnet(weights="DEFAULT")
  encoder = nn.Sequential(alexnet.features,
                          nn.AdaptiveAvgPool2d(output_size=(1, 1)),
                          nn.Flatten()).to(device)
  encoder.eval()
  return encoder


In [None]:
#@title Generator for masks
#@markdown This cell contains a generator that will generate random masks following the procedure
#@markdown developed for the RISE algorithm (https://arxiv.org/abs/1806.07421)
#@markdown and used in RELAX. It also contains an example of how to use it.

def MaskGenerator(num_batches, mask_bs, inp_shape=224, num_cells=7, p=0.5, nsd=2, dev=None):
    """Yield batches of random, smooth occlusion masks.

    Each mask starts as a random binary ``num_cells`` x ``num_cells`` grid
    (a cell is 1 with probability ``p``). The grid is bilinearly upsampled to
    ``inp_shape`` x ``inp_shape``, reflect-padded by ``num_cells // 2`` pixels
    on every side, and finally cropped back to ``inp_shape`` x ``inp_shape``
    at a random offset so that each mask is shifted differently.

    Args:
        num_batches: Number of mask batches to yield.
        mask_bs: Number of masks per batch.
        inp_shape: Side length (int) of the square masks.
        num_cells: Side length of the coarse random grid.
        p: Probability that a coarse grid cell is set to 1.
        nsd: Number of spatial dimensions of the coarse grid. The bilinear
            upsampling and the 2-D crop below only work with ``nsd=2``.
        dev: Device the masks are created on. Defaults to 'cuda' when a GPU
            is available and to 'cpu' otherwise.

    Yields:
        A float tensor of shape ``(mask_bs, 1, inp_shape, inp_shape)``.
    """
    if dev is None:
        dev = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Loop-invariant padding (left, right, top, bottom); it provides the
    # margin inside which the crop window below can be shifted.
    margin = num_cells // 2
    pad_size = (margin, margin, margin, margin)

    for _ in range(num_batches):
        grid = (torch.rand(mask_bs, 1, *((num_cells,) * nsd), device=dev) < p).float()

        grid_up = F.interpolate(grid, size=inp_shape, mode='bilinear', align_corners=False)
        grid_up = F.pad(grid_up, pad_size, mode='reflect')

        # Offsets are drawn on the CPU and turned into plain ints so that the
        # slicing below does not index with tensors.
        shift_x = torch.randint(0, num_cells, (mask_bs,), device='cpu').tolist()
        shift_y = torch.randint(0, num_cells, (mask_bs,), device='cpu').tolist()

        masks = torch.empty((mask_bs, 1, inp_shape, inp_shape), device=dev)
        for i, (dx, dy) in enumerate(zip(shift_x, shift_y)):
            masks[i] = grid_up[i, :, dx:dx + inp_shape, dy:dy + inp_shape]

        yield masks

# Load the first example once; every batch of masks below is applied to it.
example_image_1 = load_img('Ex1')

plt.figure(1)
for batch_i, mask in enumerate(MaskGenerator(num_batches=2, mask_bs=3)):
  print(f"Batch {batch_i+1} of masked images.")
  # Broadcasting (1, C, H, W) * (mask_bs, 1, H, W) gives one masked copy of
  # the image per mask in the batch.
  example_image_1_masked = example_image_1*mask
  for column in range(3):
    plt.subplot(1, 3, column + 1)
    plt.imshow(imsc(example_image_1_masked[column]))
    plt.axis('off')
  plt.tight_layout()
  plt.show()


In [None]:
#@title Implement RELAX


