# DL Project 2023/24

## Introduction

Description of the method chosen and the work done

## Install packages

In [None]:
!pip install torch torchvision matplotlib tqdm Pillow numpy --quiet

## Import packages

In [None]:
# Import PyTorch and related modules
import torch
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
from torch.utils.tensorboard import SummaryWriter

# Import torchvision and related modules
import torchvision
from torchvision.models import resnet50, ResNet50_Weights
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF

# Import other libraries
import matplotlib.pyplot as plt
import os
import tarfile
import shutil
import re
from tqdm import tqdm
from PIL import Image, ImageOps, ImageEnhance
import numpy as np
import copy
import random
import pandas as pd

import torch
from torch.utils.data import DataLoader, Dataset
import multiprocessing
import torch.nn as nn
import math

import numpy as np
import torch.multiprocessing as mp

In [None]:
# Detect Google Colab by probing for its `drive` helper module
try:
    from google.colab import drive
except ImportError:
    IN_COLAB = False
else:
    IN_COLAB = True

# Choose where the dataset archive lives for the current environment
if not IN_COLAB:
    # Local Jupyter: the archive is expected next to the notebook
    tar_file = "./imagenet-a.tar"
else:
    # Colab: the archive is read from the mounted Google Drive
    drive.mount('/content/drive')
    tar_file = "/content/drive/MyDrive/DL_project/imagenet-a.tar"

# Name of the folder the archive unpacks into
data_folder = "imagenet-a"

## Reading Data

In [None]:
# functions to untar the dataset and give the class folders readable names
def _rename_class_folders(readme_file, dataset_root):
    """Rename the WordNet-ID folders in ``dataset_root`` to their descriptions.

    Every README line of the form ``n<digits> <description>`` triggers a
    rename of ``dataset_root/n<digits>`` to ``dataset_root/<description>``.
    Lines whose folder is missing, or whose target name is already taken,
    are skipped instead of aborting the whole extraction.
    """
    with open(readme_file, 'r') as f:
        for line in f:
            # Match lines containing a WordNet ID followed by its description
            match = re.match(r'(n\d+) (.+)', line)
            if not match:
                continue
            wordnet_id = match.group(1)
            # Collapse repeated whitespace exactly like split()/join() would
            description = ' '.join(match.group(2).split())
            if not description:
                continue
            source = os.path.join(dataset_root, wordnet_id)
            target = os.path.join(dataset_root, description)
            if not os.path.isdir(source):
                # The README mentions a class that is not present on disk
                continue
            if os.path.exists(target):
                print(f"Skipping {wordnet_id}: '{description}' already exists.")
                continue
            os.rename(source, target)


def extract_dataset(compress_file, destination_folder):
    """Extract ``compress_file`` into the working directory and rename its class folders.

    Args:
        compress_file: path of the ``.tar`` archive.
        destination_folder: folder the archive is expected to unpack into
            (relative to the current working directory). An existing folder
            with this name is deleted first, so every call starts clean.

    Returns:
        None. Prints a message and returns early when the archive is missing.
    """
    if not os.path.exists(compress_file):
        print("Compress file doesn't exist.")
        return

    if os.path.exists(destination_folder):
        # remove the folder if already exists one
        shutil.rmtree(destination_folder)

    # extract content from the .tar file
    with tarfile.open(compress_file, 'r') as tar_ref:
        if hasattr(tarfile, 'data_filter'):
            # Refuse absolute paths / links escaping the target when the
            # interpreter supports extraction filters.
            tar_ref.extractall("./", filter='data')
        else:
            tar_ref.extractall("./")
    print("All the data is extracted.")

    readme_file = os.path.join(destination_folder, "README.txt")
    if not os.path.isfile(readme_file):
        print("README.txt not found: class folders keep their WordNet IDs.")
        return
    _rename_class_folders(readme_file, destination_folder)

# Unpack the archive at `tar_file` and rename the class folders inside `data_folder`
extract_dataset(tar_file, data_folder)

In [None]:
# Sanity check: count the entries of the extracted dataset folder
ids_list = os.listdir(data_folder)
len(ids_list) # 200 folders + 1 readme

In [None]:
def get_data(batch_size, dataset_path, transform, subset_size=100):
    """Build a DataLoader over the first ``subset_size`` images of the dataset.

    NOTE(review): a later cell defines ``get_data`` again; on "Run All" that
    later definition replaces this one, so keep only the variant you need.

    Args:
        batch_size: number of samples per batch.
        dataset_path: root folder laid out as one sub-folder per class.
        transform: transform applied to every loaded image.
        subset_size: how many leading images to keep (default 100, the value
            previously hard-coded); capped at the dataset size.

    Returns:
        (test_loader, class_labels): the DataLoader and the list of class names.
    """
    # Load the entire dataset
    data = torchvision.datasets.ImageFolder(root=dataset_path, transform=transform)

    # Get the class labels
    class_labels = data.classes
    print(f"The dataset contains {len(data)} images.")
    print(f"The dataset contains {len(class_labels)} labels.")

    # Keep only the first `subset_size` images, never more than are available
    subset_indices = list(range(min(subset_size, len(data))))
    data_subset = torch.utils.data.Subset(data, subset_indices)

    # Create a DataLoader for the subset
    test_loader = torch.utils.data.DataLoader(data_subset, batch_size=batch_size, shuffle=False, num_workers=multiprocessing.cpu_count())

    return test_loader, class_labels

In [None]:
# function that returns a DataLoader for the dataset
def get_data(batch_size, dataset_path, transform, max_samples=None):
    """Build a sequential (non-shuffled) DataLoader over an image-folder dataset.

    Args:
        batch_size: number of samples per batch.
        dataset_path: root folder laid out as one sub-folder per class.
        transform: transform applied to every loaded image.
        max_samples: optional cap on the number of leading images to load;
            ``None`` (default) loads the whole dataset. Useful for quick runs.

    Returns:
        (test_loader, class_labels): the DataLoader and the list of class names.
    """
    data = torchvision.datasets.ImageFolder(root=dataset_path, transform=transform)

    class_labels = data.classes
    print(f"The dataset contains {len(data)} images.")
    print(f"The dataset contains {len(class_labels)} labels.")

    dataset = data
    if max_samples is not None:
        # Restrict to the first images, never asking for more than exist
        dataset = torch.utils.data.Subset(data, range(min(max_samples, len(data))))

    test_loader = torch.utils.data.DataLoader(dataset, batch_size, shuffle=False, num_workers=multiprocessing.cpu_count())

    return test_loader, class_labels

In [None]:
# function to display images from the DataLoader
def show_images(dataloader, num_images=5):
    """Plot the first images of the first batch with their class names.

    Args:
        dataloader: loader yielding ``(images, labels)`` batches of CPU tensors.
        num_images: how many images to show; capped at the batch size.
    """
    # get a batch of data
    images, labels = next(iter(dataloader))

    # convert images to numpy array
    images = images.numpy()

    num_images = min(num_images, len(images))
    if num_images <= 0:
        return

    # Subset-style wrappers keep the real dataset under `.dataset`
    dataset = dataloader.dataset
    while not hasattr(dataset, 'classes') and hasattr(dataset, 'dataset'):
        dataset = dataset.dataset
    class_names = getattr(dataset, 'classes', None)

    # squeeze=False keeps `axes` two-dimensional even for a single image
    fig, axes = plt.subplots(1, num_images, figsize=(15, 3), squeeze=False)
    for i, ax in enumerate(axes[0]):
        image = np.transpose(images[i], (1, 2, 0))  # move channels in last position
        image = np.clip(image, 0, 1)
        ax.imshow(image)
        ax.axis('off')
        label = int(labels[i])
        ax.set_title(class_names[label] if class_names is not None else str(label))
    plt.show()

## RIMs

In [None]:
class blocked_grad(torch.autograd.Function):
    """Identity in the forward pass; multiplies the incoming gradient by ``mask`` in backward.

    Used to stop gradients from flowing into entries where ``mask`` is 0.
    """
    @staticmethod
    def forward(ctx, x, mask):
        # Only the mask is needed in backward; saving x as well would keep
        # the activations alive for no benefit.
        ctx.save_for_backward(mask)
        return x

    @staticmethod
    def backward(ctx, grad_output):
        (mask,) = ctx.saved_tensors
        # The mask itself receives an all-zero gradient
        return grad_output * mask, mask * 0.0

class GroupLinearLayer(nn.Module):
    """Bias-free linear maps, one per block, evaluated with a single batched matmul.

    The weight has shape ``(num_blocks, din, dout)``; the input's second
    dimension must therefore equal ``num_blocks``.
    """
    def __init__(self, din, dout, num_blocks):
        super().__init__()
        initial_weights = torch.randn(num_blocks, din, dout)
        self.w = nn.Parameter(0.01 * initial_weights)

    def forward(self, x):
        # Put the block axis first so bmm pairs each block with its own weight matrix
        per_block = x.transpose(0, 1)
        projected = torch.bmm(per_block, self.w)
        # Restore the original (batch, block, feature) ordering
        return projected.transpose(0, 1)

class GroupLSTMCell(nn.Module):
    """
    Runs ``num_lstms`` independent LSTM cells in parallel (one weight set per cell).
    """
    def __init__(self, inp_size, hidden_size, num_lstms):
        super().__init__()
        self.inp_size = inp_size
        self.hidden_size = hidden_size
        # Each projection produces the four gate pre-activations side by side
        self.i2h = GroupLinearLayer(inp_size, 4 * hidden_size, num_lstms)
        self.h2h = GroupLinearLayer(hidden_size, 4 * hidden_size, num_lstms)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw every parameter uniformly from [-1/sqrt(hidden_size), 1/sqrt(hidden_size)]."""
        bound = 1.0 / math.sqrt(self.hidden_size)
        for param in self.parameters():
            param.data.uniform_(-bound, bound)

    def forward(self, x, hid_state):
        """
        input: x (batch_size, num_lstms, input_size)
               hid_state (tuple of length 2 with each element of size (batch_size, num_lstms, hidden_state))
        output: h (batch_size, num_lstms, hidden_state)
                c ((batch_size, num_lstms, hidden_state))
        """
        h_prev, c_prev = hid_state
        pre_activations = self.i2h(x) + self.h2h(h_prev)
        # Layout along the last axis: input gate | forget gate | output gate | candidate
        i_pre, f_pre, o_pre, g_pre = torch.split(pre_activations, self.hidden_size, dim=2)
        input_gate = i_pre.sigmoid()
        forget_gate = f_pre.sigmoid()
        output_gate = o_pre.sigmoid()
        candidate = g_pre.tanh()
        c_next = torch.mul(c_prev, forget_gate) + torch.mul(input_gate, candidate)
        h_next = torch.mul(output_gate, c_next.tanh())
        return h_next, c_next

class GroupGRUCell(nn.Module):
    """
    GroupGRUCell can compute the operation of N GRU Cells at once.
    """
    def __init__(self, input_size, hidden_size, num_grus):
        super(GroupGRUCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        # Each projection produces the reset / update / candidate pre-activations side by side
        self.x2h = GroupLinearLayer(input_size, 3 * hidden_size, num_grus)
        self.h2h = GroupLinearLayer(hidden_size, 3 * hidden_size, num_grus)
        self.reset_parameters()

    def reset_parameters(self):
        """Draw every parameter uniformly from [-1/sqrt(hidden_size), 1/sqrt(hidden_size)].

        The previous implementation set all weights to 1.0 (and left ``std``
        unused), which makes every hidden unit identical and saturates the
        gates; this now mirrors ``GroupLSTMCell.reset_parameters``.
        """
        std = 1.0 / math.sqrt(self.hidden_size)
        for w in self.parameters():
            # In-place init keeps the parameter on its current device/dtype
            w.data.uniform_(-std, std)

    def forward(self, x, hidden):
        """
        input: x (batch_size, num_grus, input_size)
               hidden (batch_size, num_grus, hidden_size)
        output: hidden (batch_size, num_grus, hidden_size)
        """
        gate_x = self.x2h(x)
        gate_h = self.h2h(hidden)
        i_r, i_i, i_n = gate_x.chunk(3, 2)
        h_r, h_i, h_n = gate_h.chunk(3, 2)
        resetgate = torch.sigmoid(i_r + h_r)
        inputgate = torch.sigmoid(i_i + h_i)
        newgate = torch.tanh(i_n + (resetgate * h_n))
        # Interpolate between the candidate state and the previous hidden state
        hy = newgate + inputgate * (hidden - newgate)
        return hy

class RIMCell(nn.Module):
    """One time step of Recurrent Independent Mechanisms (RIMs).

    ``num_units`` recurrent units attend over the real input and an all-zero
    "null" input. The ``k`` units with the highest attention score on the real
    input are the only ones updated at this step; afterwards the units exchange
    information through multi-head attention over their hidden states.

    ``rnn_cell == 'GRU'`` selects ``GroupGRUCell``; any other value selects
    ``GroupLSTMCell``.
    """
    def __init__(self, 
        device, input_size, hidden_size, num_units, k, rnn_cell, input_key_size = 64, input_value_size = 400, input_query_size = 64,
        num_input_heads = 1, input_dropout = 0.1, comm_key_size = 32, comm_value_size = 100, comm_query_size = 32, num_comm_heads = 4, comm_dropout = 0.1
    ):
        super().__init__()
        # Queries and keys are combined with a dot product, so their sizes must agree.
        # Mismatched sizes could never get through a forward pass; fail early and clearly.
        if input_key_size != input_query_size:
            raise ValueError("input_key_size and input_query_size must be equal.")
        if comm_key_size != comm_query_size:
            raise ValueError("comm_key_size and comm_query_size must be equal.")
        # The communication output is added to the hidden state (residual),
        # so its value size is always forced to hidden_size.
        comm_value_size = hidden_size
        self.device = device
        self.hidden_size = hidden_size
        self.num_units = num_units
        self.rnn_cell = rnn_cell
        self.key_size = input_key_size
        self.k = k
        self.num_input_heads = num_input_heads
        self.num_comm_heads = num_comm_heads
        self.input_key_size = input_key_size
        self.input_query_size = input_query_size
        self.input_value_size = input_value_size
        self.comm_key_size = comm_key_size
        self.comm_query_size = comm_query_size
        self.comm_value_size = comm_value_size

        # Input attention: keys/values come from the input, queries from the hidden states
        self.key = nn.Linear(input_size, num_input_heads * input_key_size).to(self.device)
        self.value = nn.Linear(input_size, num_input_heads * input_value_size).to(self.device)
        if self.rnn_cell == 'GRU':
            self.rnn = GroupGRUCell(input_value_size, hidden_size, num_units)
        else:
            self.rnn = GroupLSTMCell(input_value_size, hidden_size, num_units)
        self.query = GroupLinearLayer(hidden_size, input_query_size * num_input_heads, self.num_units)

        # Communication attention between the units
        self.query_ = GroupLinearLayer(hidden_size, comm_query_size * num_comm_heads, self.num_units)
        self.key_ = GroupLinearLayer(hidden_size, comm_key_size * num_comm_heads, self.num_units)
        self.value_ = GroupLinearLayer(hidden_size, comm_value_size * num_comm_heads, self.num_units)
        self.comm_attention_output = GroupLinearLayer(num_comm_heads * comm_value_size, comm_value_size, self.num_units)

        # Each dropout uses its own rate (they were previously swapped)
        self.input_dropout = nn.Dropout(p=input_dropout)
        self.comm_dropout = nn.Dropout(p=comm_dropout)

    def transpose_for_scores(self, x, num_attention_heads, attention_head_size):
        """Split the last axis into (heads, head_size) and move the head axis to position 1."""
        new_x_shape = x.size()[:-1] + (num_attention_heads, attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def input_attention_mask(self, x, h):
        """
        Input : x (batch_size, 2, input_size) [The null input is appended along the first dimension]
                h (batch_size, num_units, hidden_size)
        Output: inputs (batch_size, num_units, input_value_size), zeroed for inactive units
                mask_ binary tensor of shape (batch_size, num_units) where 1 indicates active and 0 indicates inactive
        """
        key_layer = self.transpose_for_scores(self.key(x), self.num_input_heads, self.input_key_size)
        value_layer = self.transpose_for_scores(self.value(x), self.num_input_heads, self.input_value_size)
        value_layer = torch.mean(value_layer, dim=1)
        query_layer = self.transpose_for_scores(self.query(h), self.num_input_heads, self.input_query_size)

        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) / math.sqrt(self.input_key_size)
        attention_scores = torch.mean(attention_scores, dim=1)

        # Column 0 is the score on the real input (column 1 is the null input):
        # the k units that attend most to the real input are activated.
        not_null_scores = attention_scores[:, :, 0]
        topk1 = torch.topk(not_null_scores, self.k, dim=1)
        mask_ = torch.zeros(x.size(0), self.num_units, device=attention_scores.device)
        row_index = np.repeat(np.arange(x.size(0)), self.k)
        mask_[row_index, topk1.indices.view(-1)] = 1

        attention_probs = self.input_dropout(nn.Softmax(dim=-1)(attention_scores))
        inputs = torch.matmul(attention_probs, value_layer) * mask_.unsqueeze(2)
        return inputs, mask_

    def communication_attention(self, h, mask):
        """
        Input : h (batch_size, num_units, hidden_size)
                mask obtained from the input_attention_mask() function
        Output: context_layer (batch_size, num_units, hidden_size). New hidden states after communication
        """
        query_layer = self.transpose_for_scores(self.query_(h), self.num_comm_heads, self.comm_query_size)
        key_layer = self.transpose_for_scores(self.key_(h), self.num_comm_heads, self.comm_key_size)
        value_layer = self.transpose_for_scores(self.value_(h), self.num_comm_heads, self.comm_value_size)

        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.comm_key_size)
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        # Inactive units do not read from the others: zero their attention rows.
        # (batch, units) -> (batch, 1, units, 1) broadcasts over heads and attended units.
        attention_probs = attention_probs * mask.unsqueeze(1).unsqueeze(3)
        attention_probs = self.comm_dropout(attention_probs)

        context_layer = torch.matmul(attention_probs, value_layer)
        # Merge the heads back into a single feature axis
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.num_comm_heads * self.comm_value_size,)
        context_layer = context_layer.view(*new_context_layer_shape)
        context_layer = self.comm_attention_output(context_layer)
        # Residual connection around the communication step
        return context_layer + h

    def forward(self, x, hs, cs=None):
        """
        Input : x (batch_size, 1 , input_size)
                hs (batch_size, num_units, hidden_size)
                cs (batch_size, num_units, hidden_size)
        Output: new hs, cs for LSTM
                new hs, None for GRU
        """
        # Append the all-zero null input the units can attend to instead of x
        null_input = x.new_zeros(x.size(0), 1, x.size(2))
        x = torch.cat((x, null_input), dim=1)

        inputs, mask = self.input_attention_mask(x, hs)
        h_old = hs
        c_old = cs

        if cs is not None:
            hs, cs = self.rnn(inputs, (hs, cs))
        else:
            hs = self.rnn(inputs, hs)

        mask = mask.unsqueeze(2)
        # Block gradients through the recurrent update of inactive units
        h_new = blocked_grad.apply(hs, mask)
        h_new = self.communication_attention(h_new, mask.squeeze(2))

        # Active units take the new state, inactive units keep the old one
        hs = mask * h_new + (1 - mask) * h_old
        if cs is not None:
            cs = mask * cs + (1 - mask) * c_old
            return hs, cs
        return hs, None

class RIM(nn.Module):
    """Stack of ``n_layers`` RIM layers, optionally bidirectional, unrolled over a sequence.

    ``device`` may be a ``torch.device`` or a string; strings starting with
    ``'cuda'`` select that CUDA device and every other value selects the CPU.
    ``rnn_cell == 'GRU'`` uses GRU units (no cell state); any other value uses
    LSTM units, matching the choice made inside ``RIMCell``.
    """
    def __init__(self, device, input_size, hidden_size, num_units, k, rnn_cell, n_layers, bidirectional, **kwargs):
        super().__init__()
        self.device = self._resolve_device(device)
        self.n_layers = n_layers
        self.num_directions = 2 if bidirectional else 1
        self.rnn_cell = rnn_cell
        self.num_units = num_units
        self.hidden_size = hidden_size

        # The cells of the first layer read the raw input; deeper layers read the
        # concatenated hidden states of every direction of the previous layer.
        deeper_input_size = self.num_directions * hidden_size * self.num_units
        cells = []
        for i in range(self.n_layers * self.num_directions):
            cell_input_size = input_size if i < self.num_directions else deeper_input_size
            cells.append(RIMCell(self.device, cell_input_size, hidden_size, num_units, k, rnn_cell, **kwargs).to(self.device))
        self.rimcell = nn.ModuleList(cells)

    @staticmethod
    def _resolve_device(device):
        """Turn the ``device`` argument into a ``torch.device`` (CPU unless CUDA is requested)."""
        if isinstance(device, torch.device):
            return device
        if isinstance(device, str) and device.startswith('cuda'):
            return torch.device(device)
        return torch.device('cpu')

    def layer(self, rim_layer, x, h, c=None, direction=0):
        """Unroll one ``RIMCell`` over the sequence ``x``.

        Input : x (seq_len, batch_size, feature_size)
                h (1, batch_size, hidden_size * num_units)
                c same shape as h, or None when the cell has no cell state
                direction 0 processes the sequence forward, 1 backward
        Output: outputs (seq_len, batch_size, hidden_size * num_units)
                final h (and c when given), each (batch_size, hidden_size * num_units)
        """
        batch_size = x.size(1)

        steps = list(torch.split(x, 1, dim=0))
        if direction == 1:
            steps.reverse()

        hs = h.squeeze(0).reshape(batch_size, self.num_units, -1)
        cs = None
        if c is not None:
            cs = c.squeeze(0).reshape(batch_size, self.num_units, -1)

        outputs = []
        for step in steps:
            # step is (1, batch, feature); the cell expects (batch, 1, feature)
            hs, cs = rim_layer(step.squeeze(0).unsqueeze(1), hs, cs)
            outputs.append(hs.reshape(1, batch_size, -1))

        if direction == 1:
            # Put the outputs back in the original time order
            outputs.reverse()

        outputs = torch.cat(outputs, dim=0)

        if c is not None:
            return outputs, hs.reshape(batch_size, -1), cs.reshape(batch_size, -1)
        return outputs, hs.reshape(batch_size, -1)

    def forward(self, x, h=None, c=None):
        """
        Input: x (seq_len, batch_size, feature_size)
               h (num_layers * num_directions, batch_size, hidden_size * num_units)
               c (num_layers * num_directions, batch_size, hidden_size * num_units)
        Output: outputs (seq_len, batch_size, hidden_size * num_units * num_directions)
                h (and c for LSTM units) (num_layers * num_directions, batch_size, hidden_size * num_units)

        Missing initial states are drawn from a standard normal distribution.
        """
        num_states = self.n_layers * self.num_directions
        state_shape = (num_states, x.size(1), self.hidden_size * self.num_units)

        if h is None:
            h = torch.randn(*state_shape).to(self.device)
        hs = list(torch.split(h, 1, 0))

        cs = None
        if self.rnn_cell != 'GRU':
            # Same rule as RIMCell: everything that is not 'GRU' is an LSTM and needs a cell state
            if c is None:
                c = torch.randn(*state_shape).to(self.device)
            cs = list(torch.split(c, 1, 0))

        for n in range(self.n_layers):
            idx = n * self.num_directions
            if cs is not None:
                x_fw, hs[idx], cs[idx] = self.layer(self.rimcell[idx], x, hs[idx], cs[idx])
            else:
                x_fw, hs[idx] = self.layer(self.rimcell[idx], x, hs[idx], c=None)
            if self.num_directions == 2:
                idx = n * self.num_directions + 1
                if cs is not None:
                    x_bw, hs[idx], cs[idx] = self.layer(self.rimcell[idx], x, hs[idx], cs[idx], direction=1)
                else:
                    x_bw, hs[idx] = self.layer(self.rimcell[idx], x, hs[idx], c=None, direction=1)
                # The next layer reads both directions side by side
                x = torch.cat((x_fw, x_bw), dim=2)
            else:
                x = x_fw

        hs = torch.stack(hs, dim=0)
        if cs is not None:
            cs = torch.stack(cs, dim=0)
            return x, hs, cs
        return x, hs

## MEMO

In [None]:
# define some image augmentations
#
# Every augmentation takes an image tensor, converts it to a PIL image,
# applies the change and returns a new tensor produced by TF.to_tensor.

def _to_pil(img):
    """Convert ``img`` to a PIL image, clamping float tensors to [0, 1] first.

    Float inputs are scaled to 8-bit during the conversion; values outside
    [0, 1] (e.g. an already normalised tensor) would otherwise wrap around.
    """
    if torch.is_tensor(img) and img.is_floating_point():
        img = img.clamp(0.0, 1.0)
    return TF.to_pil_image(img)

def _enhance(img, enhancer_cls, factor_range):
    """Apply a PIL ``ImageEnhance`` class with a factor drawn uniformly from ``factor_range``."""
    pil_img = _to_pil(img)
    factor = np.random.uniform(factor_range[0], factor_range[1])
    return TF.to_tensor(enhancer_cls(pil_img).enhance(factor))

def vertical_flip(img):
    """Flip the image top to bottom."""
    return TF.to_tensor(_to_pil(img).transpose(Image.FLIP_TOP_BOTTOM))

def brightness(img, factor_range=(0.5, 1.5)):
    """Scale the brightness by a random factor in ``factor_range`` (1.0 leaves it unchanged)."""
    return _enhance(img, ImageEnhance.Brightness, factor_range)

def color(img, factor_range=(0.5, 1.5)):
    """Scale the colour saturation by a random factor in ``factor_range``."""
    return _enhance(img, ImageEnhance.Color, factor_range)

def sharpness(img, factor_range=(0.5, 1.5)):
    """Scale the sharpness by a random factor in ``factor_range``."""
    return _enhance(img, ImageEnhance.Sharpness, factor_range)

# Pool of augmentations the MEMO routines sample from
augmentations = [vertical_flip, brightness, color, sharpness]

In [None]:
def augment_image_rims(img, augmentations, rims, B=15):
    """Create ``B`` augmented copies of ``img`` with a RIM-driven sampling policy.

    At every iteration the flattened image is fed to ``rims``; its output is
    projected to one score per augmentation (through a freshly initialised,
    untrained linear layer) and an augmentation is sampled from the softmax
    of those scores. For 'brightness', 'color' and 'sharpness' the factor is
    derived from the mean tanh of the RIM output, times a random value in
    [0.5, 1.5].

    Args:
        img: image tensor (or anything ``TF.to_tensor`` accepts).
        augmentations: non-empty list of augmentation callables.
        rims: RIM-like module exposing ``device``, ``n_layers``,
            ``num_directions``, ``hidden_size`` and ``num_units``.
        B: number of augmented images to create.

    Returns:
        (images, df_aug_summary): the original image followed by the ``B``
        augmented ones, and a DataFrame with columns 'Augmentation' and 'Factor'.
    """
    assert len(augmentations) > 0, "No augmentations provided."
    images = [img]
    aug_names = []  # List to store augmentation names
    aug_factors = []  # List to store augmentation factors

    device = rims.device
    if isinstance(img, torch.Tensor):
        img_tensor = img.unsqueeze(0).to(device)
    else:
        img_tensor = TF.to_tensor(img).unsqueeze(0).to(device)

    batch_size = img_tensor.size(0)
    # A sequence of length 1 holding the flattened image: (seq=1, batch=1, C*H*W)
    img_tensor = img_tensor.reshape(1, 1, -1)

    state_shape = (rims.n_layers * rims.num_directions, batch_size, rims.hidden_size * rims.num_units)
    h = torch.zeros(state_shape).to(device)
    c = torch.zeros(state_shape).to(device)

    linear = nn.Linear(rims.hidden_size * rims.num_units, len(augmentations)).to(device)

    # Nothing here is trained and the sampling is not differentiable, so no
    # autograd graph is needed (it would otherwise grow with every iteration).
    with torch.no_grad():
        for _ in range(B):
            result = rims(img_tensor, h, c)
            if len(result) == 3:
                outputs, h, c = result
            else:
                # GRU-based RIMs return no cell state
                outputs, h = result

            attention_weights = torch.softmax(linear(outputs), dim=-1)
            hyperparameters = torch.tanh(outputs).mean(dim=-1)

            # reshape(-1) keeps a 1-D distribution even with a single augmentation
            aug_idx = torch.multinomial(attention_weights.reshape(-1), 1).item()
            augmentation = augmentations[aug_idx]

            if augmentation.__name__ in ['brightness', 'color', 'sharpness']:
                # Map the tanh output from [-1, 1] to [0, 1]
                factor = (hyperparameters.reshape(-1)[0].item() + 1) / 2
                diverse_factor = random.uniform(0.5, 1.5)  # Random spread for diversity
                final_factor = factor * diverse_factor
                # A degenerate range makes the augmentation use exactly this factor
                augmented_img = augmentation(img, factor_range=(final_factor, final_factor))
                aug_factors.append(final_factor)
            else:
                augmented_img = augmentation(img)
                aug_factors.append(None)

            images.append(augmented_img)
            aug_names.append(augmentation.__name__)  # Store the augmentation name

    # Create a summary table
    df_aug_summary = pd.DataFrame({'Augmentation': aug_names, 'Factor': aug_factors})

    return images, df_aug_summary

In [None]:
def augment_image(img, augmentations, B=15, verbose=False):
    """Create ``B`` augmented copies of ``img`` with uniformly sampled augmentations.

    Args:
        img: the image to augment.
        augmentations: non-empty list of augmentation callables.
        B: number of augmented images to create.
        verbose: when True, print the per-image summary table and image count.
            Off by default because this function runs once per test sample and
            the output would otherwise flood the notebook.

    Returns:
        (images, df_aug_summary): the original image followed by the ``B``
        augmented ones, and a DataFrame with columns 'Augmentation' and 'Factor'.
    """
    assert len(augmentations) > 0, "There are no augmentations provided."

    images = [img]
    aug_names = []  # List to store augmentation names
    aug_factors = []  # List to store augmentation factors

    for _ in range(B):
        # Randomly choose an augmentation from the augmentation functions
        augmentation = random.choice(augmentations)

        if augmentation.__name__ in ['brightness', 'color', 'sharpness']:
            factor = random.uniform(0.5, 1.5)  # Generate a random factor between 0.5 and 1.5
            # A degenerate range makes the augmentation use exactly this factor
            augmented_img = augmentation(img, factor_range=(factor, factor))
            aug_factors.append(factor)
        else:
            augmented_img = augmentation(img)
            aug_factors.append(None)

        # Add the augmented image to the list of images to evaluate
        images.append(augmented_img)
        aug_names.append(augmentation.__name__)  # Store the augmentation name

    # Create a summary table
    df_aug_summary = pd.DataFrame({'Augmentation': aug_names, 'Factor': aug_factors})
    if verbose:
        print("\nAugmentation Summary Table:")
        print(df_aug_summary)
        print("Number of images:", len(images))

    return images, df_aug_summary

In [None]:
# build the cost function used to evaluate the model output
def get_cost_function():
    """Return a fresh cross-entropy loss with PyTorch's default settings."""
    return torch.nn.CrossEntropyLoss()

In [None]:
# build the optimizer
def get_optimizer(net, lr, wd, momentum):
    """Return an SGD optimizer over all parameters of ``net``.

    ``lr`` is the learning rate, ``wd`` the weight decay and ``momentum`` the
    SGD momentum.
    """
    sgd_settings = {"lr": lr, "weight_decay": wd, "momentum": momentum}
    return torch.optim.SGD(net.parameters(), **sgd_settings)

In [None]:
# compute the marginal output distribution
def marginal_distribution(images, model, transforms, device):
    """Average the model's softmax predictions over ``images``.

    All images are transformed, stacked and pushed through the model in a
    single batch, then the per-image class probabilities are averaged.

    Args:
        images: non-empty sequence of images accepted by ``transforms``; after
            the transform they must all have the same shape.
        model: classifier returning one row of logits per input.
            NOTE(review): batching matches per-image evaluation only when the
            model is in eval mode (no batch-dependent layers) - confirm at the
            call site.
        transforms: callable applied to every image before the model.
        device: device the batch is moved to.

    Returns:
        1-D tensor of class probabilities (sums to 1).

    Raises:
        ValueError: if ``images`` is empty.
    """
    if len(images) == 0:
        raise ValueError("marginal_distribution needs at least one image.")

    # Images may live on different devices (original vs. augmented copies),
    # so each one is moved to `device` before stacking.
    batch = torch.stack([transforms(img).to(device) for img in images])
    probabilities = model(batch).softmax(dim=1)

    # Mean over the batch axis = marginal output distribution
    return probabilities.mean(dim=0)

In [None]:
# compute the marginal cross entropy
def marginal_cross_entropy(marginal_dist, labels, cost_function):
    """Sum ``cost_function(marginal_dist, label)`` over every label in ``labels``.

    Args:
        marginal_dist: model output passed as first argument to ``cost_function``.
        labels: iterable of labels; each one is evaluated against ``marginal_dist``.
        cost_function: callable ``(prediction, label) -> loss``.

    Returns:
        The accumulated loss; the float ``0.0`` when ``labels`` is empty.
    """
    entropy = 0.0
    for label in labels:
        entropy += cost_function(marginal_dist, label)
    return entropy


In [None]:
import copy
import torch
import torch.nn.functional as F
from torchvision.models import resnet50, ResNet50_Weights
from tqdm import tqdm
import pandas as pd

def ttr_MEMO_with_rims(model, test_sample, labels, B, cost_function, optimizer, transforms, device, rims):
    """Adapt ``model`` on one test sample with MEMO (RIM-driven augmentations) and predict.

    The model takes a single gradient step that minimises the entropy of the
    marginal prediction over the sample and its ``B`` augmentations; the
    adapted model then classifies the sample. Model weights and optimizer
    state are restored afterwards, so every sample starts from the same model.

    Args:
        model: classifier to adapt (temporarily modified, restored on return).
        test_sample: the image to classify.
        labels: unused. Test-time adaptation must not see ground-truth labels;
            the parameter is kept so existing callers keep working.
        B: number of augmented copies.
        cost_function: unused (the objective is the marginal entropy); kept for
            the same reason as ``labels``.
        optimizer: optimizer over ``model``'s parameters.
        transforms: preprocessing applied to each image before the model.
        device: device used for the forward passes.
        rims: RIM module driving the choice of augmentations.

    Returns:
        (output, summary_table): class probabilities of the adapted model for
        ``test_sample`` and the DataFrame describing the augmentations used.
    """
    original_params = copy.deepcopy(model.state_dict())
    # Momentum buffers would otherwise leak from one test sample into the next
    original_optimizer_state = copy.deepcopy(optimizer.state_dict())

    with torch.enable_grad():
        augmented_images, summary_table = augment_image_rims(test_sample, augmentations, rims, B)
        marginal_dist = marginal_distribution(augmented_images, model, transforms, device)

        # Entropy of the marginal distribution; the clamp avoids log(0)
        loss = -(marginal_dist * torch.log(marginal_dist.clamp_min(1e-12))).sum()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    # Prediction with the adapted weights; no gradients are needed here
    with torch.no_grad():
        adapted_input = transforms(test_sample).unsqueeze(0).to(device)
        output = model(adapted_input).squeeze(0).softmax(0)

    model.load_state_dict(original_params)
    optimizer.load_state_dict(original_optimizer_state)

    return output, summary_table


def test_with_rims(model, data_loader, B, cost_function, optimizer, transforms, device, rims):
    """Evaluate ``model`` with per-sample MEMO adaptation driven by RIMs.

    Args:
        model: classifier to evaluate.
        data_loader: loader yielding ``(inputs, targets)`` batches.
        B: number of augmentations per test sample.
        cost_function: loss used for reporting; assumed to return the batch
            mean (the default of ``CrossEntropyLoss``) - TODO confirm if a
            different reduction is ever passed.
        optimizer: optimizer used for the adaptation step.
        transforms: preprocessing applied before the model.
        device: device for inputs and targets.
        rims: RIM module driving the choice of augmentations.

    Returns:
        (mean loss per sample, accuracy in percent, DataFrame concatenating the
        augmentation summaries of all samples).

    Raises:
        ValueError: if ``data_loader`` yields no samples.
    """
    samples = 0
    cumulative_loss = 0.0
    cumulative_accuracy = 0.0
    all_summary_tables = []

    model.eval()

    with torch.no_grad():
        for inputs, targets in tqdm(data_loader, total=len(data_loader), desc="Testing", leave=False):
            inputs = inputs.to(device)
            targets = targets.to(device)
            batch_size = inputs.size(0)

            # NOTE(review): `targets` holds the ground-truth labels of the whole
            # batch; the adaptation step must not build its loss from them,
            # otherwise the reported accuracy is contaminated.
            adapted_outputs = []
            for sample in inputs:
                output, summary_table = ttr_MEMO_with_rims(model, sample, targets, B, cost_function, optimizer, transforms, device, rims)
                adapted_outputs.append(output)
                all_summary_tables.append(summary_table)  # Collect summary table for each test sample

            outputs = torch.stack(adapted_outputs).to(device)

            # `outputs` are probabilities while the loss expects logits.
            # log-probabilities are valid logits because softmax(log p) == p.
            loss = cost_function(torch.log(outputs.clamp_min(1e-12)), targets)

            samples += batch_size
            # Undo the batch mean so the final division gives a per-sample mean
            cumulative_loss += loss.item() * batch_size
            _, predicted = outputs.max(1)
            cumulative_accuracy += predicted.eq(targets).sum().item()

    if samples == 0:
        raise ValueError("data_loader yielded no samples.")

    combined_summary_table = pd.concat(all_summary_tables, ignore_index=True)  # Concatenate all summary tables

    return cumulative_loss / samples, cumulative_accuracy / samples * 100, combined_summary_table

def main_with_rims(
    run_name,
    batch_size=32,  # Adjusted batch size
    device="cuda",
    learning_rate=0.0001,  # Reduced learning rate
    weight_decay=0.00001,  # Adjusted weight decay
    momentum=0.9,
    num_augmentations=15
):
    """Run the MEMO evaluation with RIM-driven augmentations and print the results.

    Args:
        run_name: label of the run (currently not used).
        batch_size: batch size of the test loader.
        device: device string for the model, the data and the RIM module.
        learning_rate, weight_decay, momentum: SGD settings of the adaptation step.
        num_augmentations: number of augmented copies per test sample.
    """
    device = torch.device(device)
    weights = ResNet50_Weights.DEFAULT
    model = resnet50(weights=weights).to(device)
    preprocess = weights.transforms()

    # The loader must return plain [0, 1] image tensors: the augmentations work
    # on images, and `preprocess` (resize, crop, normalisation) is applied once
    # right before every forward pass. Loading with `preprocess` here would
    # normalise twice and feed normalised tensors to the augmentations.
    # Cropping to the resize size keeps the later centre crop equivalent to
    # running `preprocess` on the original image.
    resize_size = getattr(preprocess, "resize_size", [256])[0]
    loader_transform = transforms.Compose([
        transforms.Resize(resize_size),
        transforms.CenterCrop(resize_size),
        transforms.ToTensor(),
    ])

    # Initialize the test dataloader
    test_loader, _ = get_data(batch_size, data_folder, loader_transform)

    # NOTE(review): ImageFolder numbers the classes after the dataset's own
    # folders, while the ResNet-50 head predicts ImageNet-1k indices; accuracy
    # and loss are only meaningful once the two label spaces are mapped onto
    # each other - TODO confirm / add the mapping.

    # Initialize the optimizer
    optimizer = get_optimizer(model, learning_rate, weight_decay, momentum)

    # Initialize the cost function
    cost_function = get_cost_function()

    # The RIM input is the flattened image exactly as the loader returns it;
    # one dataset sample is enough to read its size (no need to load a batch).
    sample_img, _ = test_loader.dataset[0]
    input_dim = sample_img.numel()  # Channels * Height * Width

    hidden_dim = 128
    num_units = 4
    k = 2

    # Build the RIM on the same kind of device as the model instead of hard-coding CUDA
    rims = RIM(device.type, input_dim, hidden_dim, num_units, k, rnn_cell='LSTM', n_layers=2, bidirectional=False).to(device)
    print(f"Initialized RIMs with input_dim={input_dim}, hidden_dim={hidden_dim}, num_units={num_units}, k={k}")

    # Run the test
    test_loss, test_accuracy, combined_summary_table = test_with_rims(model, test_loader, num_augmentations, cost_function, optimizer, preprocess, device, rims)
    print(f"\tTest loss {test_loss:.5f}, Test accuracy {test_accuracy:.2f}")

    # Print the combined summary table
    print("\nCombined Augmentation Summary Table:")
    print(combined_summary_table)

    # How often each augmentation was used
    aug_counts = combined_summary_table['Augmentation'].value_counts().reset_index()
    aug_counts.columns = ['Augmentation', 'Count']

    # Mean factor per augmentation (NaN for augmentations without a factor)
    aug_factors = combined_summary_table.groupby('Augmentation')['Factor'].mean().reset_index()
    aug_factors.columns = ['Augmentation', 'Average Factor']

    final_summary = pd.merge(aug_counts, aug_factors, on='Augmentation', how='left')

    print("\nFinal Summary Table:")
    print(final_summary)

# Launch the RIM-driven MEMO evaluation with the default hyper-parameters
main_with_rims("resnet_MEMO")

In [None]:
# test time robustness via MEMO algorithm
def ttr_MEMO(model, test_sample, label, B, cost_function, optimizer, transforms, device):
    """Adapt ``model`` on one test sample with MEMO and predict.

    The model takes a single gradient step that minimises the entropy of the
    marginal prediction over the sample and its ``B`` random augmentations;
    the adapted model then classifies the sample. Model weights and optimizer
    state are restored afterwards, so every sample starts from the same model.

    Args:
        model: classifier to adapt (temporarily modified, restored on return).
        test_sample: the image to classify.
        label: unused. Test-time adaptation must not see the ground-truth
            label; the parameter is kept so existing callers keep working.
        B: number of augmented copies.
        cost_function: unused (the objective is the marginal entropy); kept for
            the same reason as ``label``.
        optimizer: optimizer over ``model``'s parameters.
        transforms: preprocessing applied to each image before the model.
        device: device used for the forward passes.

    Returns:
        (output, summary_table): class probabilities of the adapted model for
        ``test_sample`` and the DataFrame describing the augmentations used.
    """
    # Save the original model weights
    original_params = copy.deepcopy(model.state_dict())
    # Momentum buffers would otherwise leak from one test sample into the next
    original_optimizer_state = copy.deepcopy(optimizer.state_dict())

    with torch.enable_grad():
        # Get the B + 1 images
        augmented_images, summary_table = augment_image(test_sample, augmentations, B)

        # Get the marginal output distribution
        marginal_dist = marginal_distribution(augmented_images, model, transforms, device)

        # Update the model weights by minimising the marginal entropy;
        # the clamp avoids log(0)
        loss = -(marginal_dist * torch.log(marginal_dist.clamp_min(1e-12))).sum()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    # Prediction with the adapted weights; no gradients are needed here
    with torch.no_grad():
        adapted_input = transforms(test_sample).unsqueeze(0).to(device)
        output = model(adapted_input).squeeze(0).softmax(0)

    # Reapply original weights and optimizer state
    model.load_state_dict(original_params)
    optimizer.load_state_dict(original_optimizer_state)

    return output, summary_table

def test(model, data_loader, B, cost_function, optimizer, transforms, device="cuda"):
    """Evaluate ``model`` with per-sample MEMO adaptation.

    Args:
        model: classifier to evaluate.
        data_loader: loader yielding ``(inputs, targets)`` batches.
        B: number of augmentations per test sample.
        cost_function: loss used for reporting; assumed to return the batch
            mean (the default of ``CrossEntropyLoss``) - TODO confirm if a
            different reduction is ever passed.
        optimizer: optimizer used for the adaptation step.
        transforms: preprocessing applied before the model.
        device: device for inputs and targets.

    Returns:
        (mean loss per sample, accuracy in percent, DataFrame concatenating the
        augmentation summaries of all samples).

    Raises:
        ValueError: if ``data_loader`` yields no samples.
    """
    samples = 0
    cumulative_loss = 0.0
    cumulative_accuracy = 0.0
    all_summary_tables = []

    # Set the network to evaluation mode
    model.eval()

    # Disable gradient computation for testing mode
    with torch.no_grad():
        # Iterate over the test set with tqdm for progress tracking
        for inputs, targets in tqdm(data_loader, total=len(data_loader), desc="Testing", leave=False):
            # Load data into GPU
            inputs = inputs.to(device)
            targets = targets.to(device)
            batch_size = inputs.size(0)

            # Apply MEMO to each test point in the batch.
            # NOTE(review): the ground-truth `target` is forwarded here; the
            # adaptation step must not build its loss from it, otherwise the
            # reported accuracy is contaminated.
            adapted_outputs = []
            for sample, target in zip(inputs, targets):
                output, summary_table = ttr_MEMO(model, sample, target, B, cost_function, optimizer, transforms, device)
                adapted_outputs.append(output)
                all_summary_tables.append(summary_table)  # Collect summary table for each test sample

            outputs = torch.stack(adapted_outputs).to(device)

            # Loss computation: `outputs` are probabilities while the loss
            # expects logits. log-probabilities are valid logits because
            # softmax(log p) == p.
            loss = cost_function(torch.log(outputs.clamp_min(1e-12)), targets)

            # Fetch prediction and loss value
            samples += batch_size
            # Undo the batch mean so the final division gives a per-sample mean
            cumulative_loss += loss.item() * batch_size
            _, predicted = outputs.max(1)

            # Compute accuracy
            cumulative_accuracy += predicted.eq(targets).sum().item()

    if samples == 0:
        raise ValueError("data_loader yielded no samples.")

    combined_summary_table = pd.concat(all_summary_tables, ignore_index=True)  # Concatenate all summary tables

    return cumulative_loss / samples, cumulative_accuracy / samples * 100, combined_summary_table

def main(
    run_name,
    batch_size = 32,
    device = "cuda",
    learning_rate=0.001,
    weight_decay=0.000001,
    momentum=0.9,
    num_augmentations = 15
):
    """Run the MEMO evaluation with random augmentations and print the results.

    Args:
        run_name: label of the run (currently not used).
        batch_size: batch size of the test loader.
        device: device string for the model and the data.
        learning_rate, weight_decay, momentum: SGD settings of the adaptation step.
        num_augmentations: number of augmented copies per test sample.
    """
    device = torch.device(device)

    # Initialize the ResNet model
    weights = ResNet50_Weights.DEFAULT
    model = resnet50(weights=weights).to(device)

    # Initialize the inference transforms
    preprocess = weights.transforms()

    # The loader must return plain [0, 1] image tensors: the augmentations work
    # on images, and `preprocess` (resize, crop, normalisation) is applied once
    # right before every forward pass. Loading with `preprocess` here would
    # normalise twice and feed normalised tensors to the augmentations.
    # Cropping to the resize size keeps the later centre crop equivalent to
    # running `preprocess` on the original image.
    resize_size = getattr(preprocess, "resize_size", [256])[0]
    loader_transform = transforms.Compose([
        transforms.Resize(resize_size),
        transforms.CenterCrop(resize_size),
        transforms.ToTensor(),
    ])

    # Initialize the test dataloader
    test_loader, _ = get_data(batch_size, data_folder, loader_transform)

    # NOTE(review): ImageFolder numbers the classes after the dataset's own
    # folders, while the ResNet-50 head predicts ImageNet-1k indices; accuracy
    # and loss are only meaningful once the two label spaces are mapped onto
    # each other - TODO confirm / add the mapping.

    # Initialize the optimizer
    optimizer = get_optimizer(model, learning_rate, weight_decay, momentum)

    # Initialize the cost function
    cost_function = get_cost_function()

    test_loss, test_accuracy, combined_summary_table = test(model, test_loader, num_augmentations, cost_function, optimizer, preprocess, device)
    print(f"\tTest loss {test_loss:.5f}, Test accuracy {test_accuracy:.2f}")

    # Print the combined summary table
    print("\nCombined Augmentation Summary Table:")
    print(combined_summary_table)

    # How often each augmentation was used
    aug_counts = combined_summary_table['Augmentation'].value_counts().reset_index()
    aug_counts.columns = ['Augmentation', 'Count']

    # Mean factor per augmentation (NaN for augmentations without a factor)
    aug_factors = combined_summary_table.groupby('Augmentation')['Factor'].mean().reset_index()
    aug_factors.columns = ['Augmentation', 'Average Factor']

    final_summary = pd.merge(aug_counts, aug_factors, on='Augmentation', how='left')

    print("\nFinal Summary Table:")
    print(final_summary)

# Launch the baseline MEMO evaluation (random augmentations) with the default hyper-parameters
main("resnet_MEMO")