In [None]:
from google.colab import drive

# Mount Google Drive so the class-index JSON and model checkpoints under
# /content/drive/MyDrive/... are reachable from this runtime.
drive.mount(mountpoint='/content/drive')

Mounted at /content/drive


In [None]:
# Installations (uncomment if needed)
!pip install git+https://github.com/openai/CLIP.git
!pip install datasets
!pip install captum
!pip install tqdm
!pip install torchcam

# System and OS
import os
import os.path as osp
from collections import OrderedDict

# Basic Libraries
import json
import math
import random
import numpy as np
import cv2
import copy
from tqdm import tqdm

# PyTorch related
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn import functional as F
from torch.cuda.amp import GradScaler, autocast
import torchvision.models as models
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
from torchvision.transforms.functional import to_pil_image
from torch.optim.lr_scheduler import CosineAnnealingLR

# CLIP and related libraries
from clip import clip
from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer

# Model and Data Processing
from transformers import AlignTextModel, AlignProcessor, AlignModel
from PIL import Image

# Visualization
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from captum.attr import Saliency
from captum.attr import visualization as viz

# CAM methods
from torchcam.methods import GradCAM
from torchcam.utils import overlay_mask

# Datasets
from datasets import load_dataset

Collecting git+https://github.com/openai/CLIP.git
  Cloning https://github.com/openai/CLIP.git to /tmp/pip-req-build-b8zb3lge
  Running command git clone --filter=blob:none --quiet https://github.com/openai/CLIP.git /tmp/pip-req-build-b8zb3lge
  Resolved https://github.com/openai/CLIP.git to commit a1d071733d7111c9c014f024669f959182114e33
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting ftfy (from clip==1.0)
  Downloading ftfy-6.1.3-py3-none-any.whl (53 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m53.4/53.4 kB[0m [31m863.7 kB/s[0m eta [36m0:00:00[0m
Building wheels for collected packages: clip
  Building wheel for clip (setup.py) ... [?25l[?25hdone
  Created wheel for clip: filename=clip-1.0-py3-none-any.whl size=1369497 sha256=3265dfb5197ba1d4fbab615c17dee4ef882b9b1c09546ee2a6cf182f633d3abf
  Stored in directory: /tmp/pip-ephem-wheel-cache-c5jovrcn/wheels/da/2b/4c/d6691fa9597aac8bb85d2ac13b112deb897d5b50f5ad9a37e4
Successfully built clip
In

In [None]:
# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(device)

cuda


In [None]:
# TODO (later): load dataloaders with pre-transformed images.

# Standard ImageNet preprocessing: resize the shorter side to 256 px, take the
# central 224x224 patch, convert to a [0, 1] tensor and normalize each channel
# with the ImageNet mean / std.
# TODO: figure out the impact & optimal values of the normalization constants.
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Hugging Face access token for the gated imagenet-1k dataset. It is read from
# the environment rather than hard-coded: a token written into a shared notebook
# is leaked (the token previously embedded here should be revoked on
# huggingface.co). Set it privately, e.g. `%env HF_TOKEN=...` in an unsaved cell.
# If it is unset, this is None and `datasets` falls back to the token stored by
# `huggingface-cli login` / `huggingface_hub.login()`.
access_token = os.environ.get("HF_TOKEN")


def collate_fn(batch):
    """
    Turn a list of dataset records into one (images, labels) batch.

    Each record is a mapping with an 'image' entry (a PIL image, or anything
    `to_pil_image` accepts) and a 'label' entry. Every image is converted to an
    RGB PIL image, passed through the module-level `transform`, and the results
    are stacked.

    Args:
        batch (list[dict]): records produced by the dataset.

    Returns:
        tuple[Tensor, Tensor]: (stacked transformed images, tensor of labels).
    """
    pixel_batch = []
    label_batch = []
    for record in batch:
        picture = record['image']
        target = record['label']

        if not isinstance(picture, Image.Image):
            picture = to_pil_image(picture)
        # Non-RGB modes (e.g. grayscale) are converted so every image yields 3 channels.
        if picture.mode != 'RGB':
            picture = picture.convert('RGB')

        pixel_batch.append(transform(picture))
        label_batch.append(target)

    return torch.stack(pixel_batch), torch.tensor(label_batch)

from itertools import islice

# Number of streamed records kept in memory per split.
# Adjust these based on your needs; 10,000 training images is already close to
# the practical limit here.
train_subset_size = 10000
val_subset_size = 1000

# Training split: stream from the Hub and materialize only the first records.
imagenet_data = load_dataset("imagenet-1k", split="train", streaming=True, token=access_token, trust_remote_code=True)
train_subset = list(islice(imagenet_data, train_subset_size))
dataloader = DataLoader(train_subset, batch_size=32, shuffle=True, collate_fn=collate_fn)

# Validation split: same procedure, kept in a fixed order for comparable evaluations.
validation_data = load_dataset("imagenet-1k", split="validation", streaming=True, token=access_token, trust_remote_code=True)
val_subset = list(islice(validation_data, val_subset_size))
validation_loader = DataLoader(val_subset, batch_size=24, shuffle=False, collate_fn=collate_fn)


Downloading builder script:   0%|          | 0.00/4.72k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/85.4k [00:00<?, ?B/s]

Downloading extra modules:   0%|          | 0.00/46.4k [00:00<?, ?B/s]

In [None]:
def get_ImageNet_ClassNames(text_file='/content/drive/MyDrive/FACT LICO 13/imagenet_class_index.json'):
    """
    Read an ImageNet class-index JSON file and return the human-readable class names.

    The file maps stringified class indices ("0", "1", ...) to a class name.
    Two value formats are accepted:
      * a plain string, e.g. "tench" (the format the original code assumed), or
      * a list whose last element is the name, e.g. ["n01440764", "tench"]
        (the common Keras-style `imagenet_class_index.json`).
    Underscores in names are replaced by spaces.

    Args:
        text_file (str): Path to the JSON file. Defaults to the project's Drive copy.

    Returns:
        list[str]: Class names ordered by class index 0 .. len(file) - 1.

    Raises:
        KeyError: if some index in 0 .. len(file) - 1 is missing from the file.
    """
    with open(text_file, 'r', encoding='utf-8') as f:
        class_index = json.load(f)

    def _name_of(entry):
        # List-style entries carry the name last; string entries are the name itself.
        name = entry[-1] if isinstance(entry, (list, tuple)) else entry
        return name.replace("_", " ")

    # Iterate by numeric index so the order follows class ids, not the file's key order.
    return [_name_of(class_index[str(i)]) for i in range(len(class_index))]

In [None]:
def get_encoded_labels(labels, prompt):
    """
    Select, for every label in the batch, the matching row of the prompt-feature table.

    Args:
        labels (Tensor): class indices of the batch (cast to int64 before indexing).
        prompt (Tensor): per-class prompt features; row i belongs to class i.

    Returns:
        Tensor: ``prompt[labels]``, one row per label.
    """
    return prompt[labels.long()]

In [None]:
class ModifiedResNet(nn.Module):
    """
    Wraps a ResNet so that `forward` returns both the class logits and the last
    convolutional feature maps (output of `layer4`, before global pooling).
    The feature maps feed the OT loss.

    Expects `original_model` to have children named conv1, bn1, relu, maxpool,
    layer1..layer4, avgpool and fc (as torchvision ResNets do). Every child is
    re-registered under the same name, so the state_dict keys match the wrapped
    model's.
    """

    def __init__(self, original_model):
        super().__init__()
        # Re-register each child module (shared, not copied) on this wrapper.
        for child_name, child in original_model.named_children():
            setattr(self, child_name, child)

    def forward(self, x):
        # Stem followed by the four residual stages, applied in order.
        for stage in ("conv1", "bn1", "relu", "maxpool", "layer1", "layer2", "layer3", "layer4"):
            x = getattr(self, stage)(x)
        feature_maps = x  # [Batch_size, number of filters, feature_map_height, feature_map_width]

        pooled = torch.flatten(self.avgpool(feature_maps), start_dim=1)
        return self.fc(pooled), feature_maps

In [None]:
class MLP(nn.Module):
    """
    Two-layer perceptron that maps text features (input_dim) to the image-feature
    size (output_dim), so the CLIP text encoder can be paired with any image
    encoder.

    It also owns `temp`, a trainable scalar initialised to log(1 / 0.07) that the
    training loop uses as the temperature of the manifold-matching (MM) loss.

    Layout: Linear(input_dim, hidden_dim) -> Dropout(p=0.5) -> ReLU
            -> Linear(hidden_dim, output_dim).
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.dropout = nn.Dropout(p=0.5)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

        # Trainable MM-loss temperature (a 0-d float32 parameter).
        self.temp = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

    def forward(self, x):
        hidden = F.relu(self.dropout(self.fc1(x)))
        return self.fc2(hidden)

In [None]:
class TextEncoder(nn.Module):
    """
    Encodes prompt embeddings into text features with the text tower of a CLIP model.

    Attributes:
        transformer (nn.Module): The transformer module from the CLIP model.
        positional_embedding (Tensor): The positional embeddings from the CLIP model.
        ln_final (nn.Module): Layer normalization applied after the transformer.
        text_projection (Tensor): Projection matrix applied to the selected token features.
        dtype (torch.dtype): dtype reported by the CLIP model after the float cast.
    """

    def __init__(self, clip_model):
        """
        Build the encoder from the components of a CLIP model.

        Args:
            clip_model (CLIP): A pre-trained CLIP model.

        Note:
            ``Module.type`` converts in place and returns the same module, so the
            caller's ``clip_model`` itself becomes ``torch.FloatTensor`` (CPU float32).
        """
        super().__init__()
        clip_model = clip_model.type(torch.FloatTensor)

        self.transformer = clip_model.transformer
        self.positional_embedding = clip_model.positional_embedding
        self.ln_final = clip_model.ln_final
        self.text_projection = clip_model.text_projection
        self.dtype = clip_model.dtype

    def forward(self, prompts, tokenized_prompts):
        """
        Encode all prompts.

        Args:
            prompts (Tensor): Prompt embeddings, (batch, length, dim).
            tokenized_prompts (Tensor): Token ids of the same prompts, (batch, length).

        Returns:
            Tensor: One projected feature vector per prompt.
        """
        hidden = prompts + self.positional_embedding.type(self.dtype)
        # The transformer is fed sequence-first: (batch, length, dim) -> (length, batch, dim) and back.
        hidden = self.transformer(hidden.permute(1, 0, 2)).permute(1, 0, 2)
        hidden = self.ln_final(hidden).type(self.dtype)

        # Per prompt, take the hidden state at the position of the largest token id
        # (CLIP's end-of-text token) and project it with the learned text projection.
        eot_positions = tokenized_prompts.argmax(dim=-1)
        rows = torch.arange(hidden.shape[0])
        return hidden[rows, eot_positions] @ self.text_projection

In [None]:
# PromptLearner: builds per-class prompt embeddings with learnable context tokens.
# NOTE(review): still to be checked end-to-end; see the inline notes below.

class PromptLearner(nn.Module):
    """
    A PyTorch module for learning prompt embeddings in the context of a CLIP model.

    For every class it builds the text "X X ... X <classname>." (n_ctx placeholder
    words), embeds it with CLIP's token embedding, and stores the start-of-sequence
    embedding (`token_prefix`) and everything after the placeholders (`token_suffix`).
    In `forward` the placeholder embeddings are replaced by the learnable context
    vectors `ctx`, placed around the class-name tokens according to
    `class_token_position`.

    Requires a module-level `_tokenizer` (CLIP SimpleTokenizer) to exist before
    construction; it is used to count the tokens of each class name.

    Attributes:
        ctx (nn.Parameter): Learnable context vectors, shape (n_cls, n_ctx, ctx_dim)
            (one set per class; see the initialization branch in __init__).
        token_prefix (Tensor): Start-of-sequence token embeddings from CLIP, (N * n_cls, 1, ctx_dim).
        token_suffix (Tensor): Embeddings of every position after the placeholders: class-name
            tokens, '.', end-of-text and the remaining (presumably padding) positions.
        tokenized_prompts (Tensor): Token ids of the full prompts, (N * n_cls, context_length).
        name_lens (list[int]): Number of tokens of each class name.
        n_cls (int): Number of classes.
        n_ctx (int): Number of context tokens.
        class_token_position (str): Position of the class token in the prompt (options: 'middle', 'end', 'front').
    """
    def __init__(self, classnames, clip_model):
        """
        Initializes the PromptLearner module with class names and a CLIP model.

        Args:
            classnames (list): A list of class names (strings).
            clip_model (CLIP): The pre-trained CLIP model whose token embedding and
                ln_final width are used. Its token embedding is applied to the token
                ids built here, so it must live on the same device as those ids.
        """
        super().__init__()
        n_cls = len(classnames)
        n_ctx = 12
        # ctx_init is hard-coded to None, so the word-initialization branch below never runs.
        # NOTE(review): that branch would also produce a 2-D ctx (n_ctx, ctx_dim), which
        # forward() cannot handle (it indexes ctx.shape[3] after at most one unsqueeze).
        ctx_init = None
        dtype = clip_model.dtype
        ctx_dim = clip_model.ln_final.weight.shape[0]
        # Number of prompt copies per class; tokenized_prompts is repeated N times.
        self.N = 1

        if ctx_init:
            # use given words to initialize context vectors
            ctx_init = ctx_init.replace("_", " ")
            n_ctx = len(ctx_init.split(" "))
            prompt = clip.tokenize(ctx_init)
            with torch.no_grad():
                embedding = clip_model.token_embedding(prompt).type(dtype)
            ctx_vectors = embedding[0, 1 : 1 + n_ctx, :]
            prompt_prefix = ctx_init
        else:
            # Random initialization of the context *vectors* (normal, std 0.02). The "X" words
            # are only text placeholders so that tokenization reserves n_ctx token slots; their
            # embeddings are cut out (prefix = position 0, suffix = positions 1 + n_ctx onward)
            # and replaced by ctx in forward().
            if True:
                # Always taken: one context per class.
                print("Initializing class-specific contexts")
                ctx_vectors = torch.empty(n_cls, n_ctx, ctx_dim, dtype=dtype)
            else:
                # Dead branch (generic shared context). NOTE(review): forward() views ctx as
                # (N * n_cls, n_ctx, dim), which this (N, n_ctx, dim) shape would not satisfy.
                print("Initializing a generic context")
                ctx_vectors = torch.empty(self.N, n_ctx, ctx_dim, dtype=dtype)
            nn.init.normal_(ctx_vectors, std=0.02)   # define the prompt to be trained
            prompt_prefix = " ".join(["X"] * n_ctx)

        print(f'Initial context: "{prompt_prefix}"')
        print(f"Number of context words (tokens): {n_ctx}")

        self.ctx = nn.Parameter(ctx_vectors)  # to be optimized (only if passed to an optimizer)

        classnames = [name.replace("_", " ") for name in classnames]
        # Token count of each class name, used to locate the class tokens inside token_suffix.
        name_lens = [len(_tokenizer.encode(name)) for name in classnames]

        # '.' as end of sentence token for representation of whole sentence
        prompts = [prompt_prefix + " " + name + "." for name in classnames]

        tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts]) # (n_cls, context_length)
        tokenized_prompts = tokenized_prompts.repeat(self.N,1)


        with torch.no_grad():
            embedding = clip_model.token_embedding(tokenized_prompts).type(dtype)
        print('tokenized prompts:', embedding.shape, 'ctx: ', self.ctx.shape)

        # Fixed (non-learnable) token embeddings around the context slots. They ARE used in
        # forward() to assemble the prompts. The original note said they should be ignored
        # when loading a saved model in favour of ones recomputed from the current class
        # names; no save/load code for this is visible in this notebook.
        self.register_buffer("token_prefix", embedding[:, :1, :])  # SOS
        self.register_buffer("token_suffix", embedding[:, 1 + n_ctx :, :])  # CLS, EOS

        self.n_cls = n_cls
        self.n_ctx = n_ctx
        self.tokenized_prompts = tokenized_prompts  # torch.Tensor
        self.name_lens = name_lens
        # Hard-coded class-token placement used by forward(): 'end' | 'middle' | 'front'.
        # NOTE(review): could be made a constructor argument to experiment with placements.
        self.class_token_position = 'middle'

    # prefix, suffix, cls_loc and shuffleCLS are accepted but unused; only ctx is shuffled.
    def _ctx_shuffle(self, prefix, suffix, ctx, cls_loc = 'end', shuffleCLS = False):
        """
        Randomly permute the order of the context tokens.

        One permutation of dimension 1 (the n_ctx token positions) is drawn with
        torch.randperm and applied to every class, so each class keeps its own
        context vectors; only their order changes.

        Args:
            prefix (Tensor): Prefix token embeddings (unused).
            suffix (Tensor): Suffix token embeddings (unused).
            ctx (Tensor): Context vectors, (n_prompts, n_ctx, dim).
            cls_loc (str): Position of the class token in the prompt (unused).
            shuffleCLS (bool): Whether to shuffle the class token positions (unused).

        Returns:
            Tensor: ctx with its token positions permuted.
        """

        # shuffle the ctx along 2nd dimension
        rand_idx = torch.randperm(ctx.shape[1])
        shuffled_ctx = ctx[:, rand_idx, :]
        return shuffled_ctx


    def forward(self):
        """
        Forward pass of the PromptLearner to create prompts for each class.

        The context tokens are re-shuffled on every call (see _ctx_shuffle), so the
        output depends on the global torch RNG state.

        Returns:
            Tensor: Prompt embeddings, one per class, (n_cls, context_length, ctx_dim).

        Raises:
            ValueError: if class_token_position is not 'end', 'middle' or 'front'.
        """

        ctx = self.ctx
        if ctx.dim() == 3:
            # (n_cls, n_ctx, dim) -> (1, n_cls, n_ctx, dim) so shape[3] below is the embedding dim.
            ctx = ctx.unsqueeze(0)

        ctx = ctx.contiguous().view(self.N*self.n_cls,self.n_ctx,ctx.shape[3])

        prefix = self.token_prefix
        suffix = self.token_suffix

        # Same permutation of token positions for all classes: class-specific contexts stay
        # distinct per class, only the order of their n_ctx tokens changes per call.
        ctx = self._ctx_shuffle(prefix, suffix, ctx)

        if self.class_token_position == "end":
            prompts = torch.cat(
                [
                    prefix,  # (n_cls, 1, dim)
                    ctx,     # (n_cls, n_ctx, dim)
                    suffix,  # (n_cls, *, dim)
                ],
                dim=1,
            )

        elif self.class_token_position == "middle":
            half_n_ctx = self.n_ctx // 2
            prompts = []
            for i in range(self.n_cls):
                name_len = self.name_lens[i]
                prefix_i = prefix[i : i + 1, :, :]
                # The class-name tokens are the first name_len positions of the suffix.
                class_i = suffix[i : i + 1, :name_len, :]
                suffix_i = suffix[i : i + 1, name_len:, :]
                ctx_i_half1 = ctx[i : i + 1, :half_n_ctx, :]
                ctx_i_half2 = ctx[i : i + 1, half_n_ctx:, :]
                prompt = torch.cat(
                    [
                        prefix_i,     # (1, 1, dim)
                        ctx_i_half1,  # (1, n_ctx//2, dim)
                        class_i,      # (1, name_len, dim)
                        ctx_i_half2,  # (1, n_ctx//2, dim)
                        suffix_i,     # (1, *, dim)
                    ],
                    dim=1,
                )
                prompts.append(prompt)
            prompts = torch.cat(prompts, dim=0)

        elif self.class_token_position == "front":
            prompts = []
            for i in range(self.n_cls):
                name_len = self.name_lens[i]
                prefix_i = prefix[i : i + 1, :, :]
                class_i = suffix[i : i + 1, :name_len, :]
                suffix_i = suffix[i : i + 1, name_len:, :]
                ctx_i = ctx[i : i + 1, :, :]
                prompt = torch.cat(
                    [
                        prefix_i,  # (1, 1, dim)
                        class_i,   # (1, name_len, dim)
                        ctx_i,     # (1, n_ctx, dim)
                        suffix_i,  # (1, *, dim)
                    ],
                    dim=1,
                )
                prompts.append(prompt)
            prompts = torch.cat(prompts, dim=0)

        else:
            raise ValueError
        return prompts

In [None]:
# Baseline text features: CLIP ViT-B/32 embeddings of the bare class names
# (used by the "without PromptLearner" alternative in the training setup).
class_names = get_ImageNet_ClassNames()
model, preprocess = clip.load("ViT-B/32", device=device)

with torch.no_grad():
    # Tokenize every class name at once, encode, and cast to float32.
    text_inputs = clip.tokenize(class_names).to(device)
    text_features_alt = model.encode_text(text_inputs).float()

print(text_features_alt.shape)

100%|████████████████████████████████████████| 338M/338M [00:02<00:00, 152MiB/s]


torch.Size([1000, 512])


In [None]:
# CALCULATIONS FOR OT LOSS
# DAAN can we get sinkhorn loss with a library? Would make the code much simpler


# Adapted from https://github.com/gpeyre/SinkhornAutoDiff
class SinkhornDistance(nn.Module):
    r"""
    Entropy-regularized optimal-transport (Sinkhorn) cost between point clouds,
    computed with log-domain Sinkhorn iterations and uniform marginals.

    Given two empirical measures with :math:`P_1` locations
    :math:`x\in\mathbb{R}^{D}` and :math:`P_2` locations :math:`y\in\mathbb{R}^{D}`,
    outputs an approximation of the regularized OT cost for point clouds.

    Args:
        eps (float): regularization coefficient
        max_iter (int): maximum number of Sinkhorn iterations
        reduction (string, optional): Specifies the reduction to apply to the output:
            'none' | 'mean' | 'sum'. 'none': no reduction will be applied,
            'mean': the per-sample costs are averaged, 'sum': they are summed.
            Default: 'none'

    Shape:
        - Input: :math:`(N, P_1, D)`, :math:`(N, P_2, D)`, or unbatched
          :math:`(P_1, D)`, :math:`(P_2, D)`. A 2-D `y` with a 3-D `x` is
          broadcast: every batch element of `x` is compared with the same cloud `y`.
        - Output: :math:`(N)` or :math:`()`, depending on `reduction`

    Raises:
        ValueError: if `reduction` is not one of 'none', 'mean', 'sum'.
    """
    _REDUCTIONS = ('none', 'mean', 'sum')

    def __init__(self, eps, max_iter, reduction='none'):
        super(SinkhornDistance, self).__init__()
        if reduction not in self._REDUCTIONS:
            raise ValueError(f"reduction must be one of {self._REDUCTIONS}, got {reduction!r}")
        self.eps = eps
        self.max_iter = max_iter
        self.reduction = reduction

    def forward(self, x, y):
        C = self._cost_matrix(x, y)  # (..., P_1, P_2) pairwise costs
        x_points = x.shape[-2]
        y_points = y.shape[-2]
        batch_shape = x.shape[:-2]  # () for unbatched x, (N,) for batched x

        # Both marginals are uniform. They are created on the inputs' device (not a
        # hard-coded CUDA device) with explicit shapes, so N == 1 or P == 1 does not
        # collapse a dimension the way .squeeze() did.
        mu = torch.full((*batch_shape, x_points), 1.0 / x_points, dtype=torch.float, device=x.device)
        nu = torch.full((*batch_shape, y_points), 1.0 / y_points, dtype=torch.float, device=x.device)

        u = torch.zeros_like(mu)
        v = torch.zeros_like(nu)
        # Stopping criterion on the change of the dual potential u.
        thresh = 1e-3

        # Sinkhorn iterations (log domain)
        for _ in range(self.max_iter):
            u_prev = u
            u = self.eps * (torch.log(mu + 1e-8) - torch.logsumexp(self.M(C, u, v), dim=-1)) + u
            v = self.eps * (torch.log(nu + 1e-8) - torch.logsumexp(self.M(C, u, v).transpose(-2, -1), dim=-1)) + v
            err = (u - u_prev).abs().sum(-1).mean()
            if err.item() < thresh:
                break

        # Transport plan pi = diag(a) * K * diag(b), recovered from the log-domain potentials.
        pi = torch.exp(self.M(C, u, v))

        # Sinkhorn distance, one value per batch element (a scalar for unbatched input).
        cost = torch.sum(pi * C, dim=(-2, -1))

        if self.reduction == 'mean':
            return cost.mean()
        if self.reduction == 'sum':
            return cost.sum()
        return cost

    def M(self, C, u, v):
        r"""Modified cost for logarithmic updates: :math:`M_{ij} = (-c_{ij} + u_i + v_j) / \epsilon`."""
        return (-C + u.unsqueeze(-1) + v.unsqueeze(-2)) / self.eps

    @staticmethod
    def _cost_matrix(x, y, p=2):
        r"""Return the matrix of :math:`|x_i - y_j|^p`, summed over the last (feature) dimension."""
        x_col = x.unsqueeze(-2)
        y_lin = y.unsqueeze(-3)
        return torch.sum(torch.abs(x_col - y_lin) ** p, -1)

In [None]:
# Manifold-matching (MM) loss and its adjacency-matrix helper.
def calculate_adjacency_matrix(features, temperature, log=False):
    """
    Row-normalized adjacency matrix of a set of feature vectors.

    1. Pairwise Euclidean distances between the rows of `features`.
    2. Softmax of ``-distance / temperature`` along dim 1, so closer samples get
       larger weight and each row sums to 1.

    Args:
        features (Tensor): feature matrix, assumed (B, D) as in the training loop.
        temperature (float or Tensor): softmax temperature (> 0).
        log (bool): if True, return log-probabilities computed with ``log_softmax``,
            which stays finite where ``softmax(...).log()`` would underflow to -inf.

    Returns:
        Tensor: (B, B) adjacency matrix, or its elementwise log when ``log=True``.
    """
    logits = -torch.cdist(features, features, p=2) / temperature
    if log:
        return F.log_softmax(logits, dim=1)
    return F.softmax(logits, dim=1)


def manifold_matching_loss(image_features, text_features, temperature):
    """
    MM loss of the LICO model: KL divergence between the adjacency matrices of the
    image features (A_F) and of the text features (A_G).

    Computes ``F.kl_div(log A_G, A_F, reduction='batchmean')``, i.e. KL(A_F || A_G)
    averaged over rows. ``log A_G`` comes from ``log_softmax`` rather than
    ``softmax(...).log()``: the latter is -inf wherever an entry underflows to 0
    (e.g. at small temperatures), which makes the loss infinite or NaN.

    Args:
        image_features (Tensor): (B, D) image features.
        text_features (Tensor): (B, D') text features for the same samples.
        temperature (float or Tensor): shared softmax temperature.

    Returns:
        Tensor: scalar loss.
    """
    A_F = calculate_adjacency_matrix(image_features, temperature)
    log_A_G = calculate_adjacency_matrix(text_features, temperature, log=True)
    return F.kl_div(log_A_G, A_F, reduction='batchmean')

In [None]:
def _loss_value(loss):
    """Plain Python number for logging: `.item()` of a tensor, anything else (0, None) unchanged."""
    return loss.item() if torch.is_tensor(loss) else loss


def train_model(modified_resnet, dataloader, manifold_matching_loss, sinkhorn_loss, text2img_dim_transform, num_epochs, device, all_prompt_features, validation_loader, get_encoded_labels, ablation1, ablation2, save_dir='/content/drive/MyDrive/FACT LICO 13/Models'):
    """
    Train the image model with the objective CE + alpha * MM + beta * OT.

    Args:
        modified_resnet: image model whose forward returns (logits, feature_maps).
        dataloader: training batches of (images, labels).
        manifold_matching_loss: callable(image_features, text_features, temperature) -> scalar.
        sinkhorn_loss: callable(feature_maps, text_features) -> scalar OT cost.
        text2img_dim_transform: MLP mapping prompt features to the image side. Its `temp`
            parameter is the MM temperature and is clamped to [0.1, 3] after every step.
        num_epochs: number of epochs; also the period of the cosine LR schedule.
        device: device the batches are moved to.
        all_prompt_features: prompt-feature table indexed by class label.
        validation_loader: passed to validate_model before training and after every epoch.
        get_encoded_labels: callable(labels, all_prompt_features) -> per-sample prompt features.
        ablation1, ablation2: 'mm' and/or 'ot' disable that loss term; other values keep it.
        save_dir: directory for the per-epoch checkpoints (created if missing).
    """
    # Accuracy of the model before any update, as a baseline.
    validate_model(modified_resnet, validation_loader, device)

    # A loss term is disabled when either ablation argument names it.
    use_mm = 'mm' not in (ablation1, ablation2)
    use_ot = 'ot' not in (ablation1, ablation2)

    # Loss weights according to the paper.
    alpha = 10
    beta = 1

    optimizer = optim.SGD([
        {'params': modified_resnet.parameters()},
        {'params': text2img_dim_transform.parameters()},
    ], lr=0.03, momentum=0.9, weight_decay=0.0001)

    # Cosine decay of the learning rate over the run, stepped once per epoch.
    scheduler = CosineAnnealingLR(optimizer, num_epochs)

    os.makedirs(save_dir, exist_ok=True)

    for epoch in range(num_epochs):
        modified_resnet.train()
        text2img_dim_transform.train()

        if (epoch + 1) % 2 == 1:  # Change this to show which epoch we are
            print(f"Epoch {epoch + 1}/{num_epochs}")

        # Losses of the most recent batch, reported after the epoch (None if the loader is empty).
        CE_loss = MM_loss = OT_loss = None

        for images, labels in tqdm(dataloader):
            images = images.to(device)
            labels = labels.to(device)

            # Per-sample prompt features for the batch labels, and the image-model outputs.
            encoded_labels = get_encoded_labels(labels, all_prompt_features)
            predictions, features_resnet = modified_resnet(images)

            # OT input: (B, C, H*W) feature maps, each channel L2-normalized over its positions.
            feature_maps = features_resnet.view(features_resnet.shape[0], features_resnet.shape[1], -1)
            feature_maps = F.normalize(feature_maps, dim=2)

            # MM input: globally average-pooled, L2-normalized image features (B, C).
            image_features = F.adaptive_avg_pool2d(features_resnet, 1).view(images.shape[0], -1)
            image_features = F.normalize(image_features, dim=-1)

            # Map prompt features to the image-feature size and L2-normalize them.
            text_features = F.normalize(text2img_dim_transform(encoded_labels), dim=-1)

            CE_loss = F.cross_entropy(predictions, labels)
            # An ablated term is the constant 0 and is not computed at all.
            MM_loss = manifold_matching_loss(image_features, text_features, text2img_dim_transform.temp) if use_mm else 0
            # NOTE(review): text_features is 2-D (B, output_dim), so SinkhornDistance treats it as one
            # point cloud of B points shared by the whole batch. Each image's channels are therefore
            # transported onto the projected prompts of *all* images in the batch, not only its own
            # label's prompt. Confirm this pairing is intended.
            OT_loss = sinkhorn_loss(feature_maps, text_features) if use_ot else 0

            total_loss = CE_loss + alpha * MM_loss + beta * OT_loss

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

            # Keep the learnable MM temperature within [0.1, 3] (bounds chosen by us).
            with torch.no_grad():
                text2img_dim_transform.temp.clamp_(min=0.1, max=3)

        print(f"temperature after last batch of epoch was:{text2img_dim_transform.temp.item()}")

        validate_model(modified_resnet, validation_loader, device)
        scheduler.step()

        print(f"Loss of last epoch in batch is: CE: {_loss_value(CE_loss)}, OT: {_loss_value(OT_loss)}, MM: {_loss_value(MM_loss)}")

        # Per-epoch checkpoints, tagged with the ablation setting and the epoch index.
        torch.save(modified_resnet.state_dict(), osp.join(save_dir, f'modified_resnet_{ablation1}{ablation2}_{epoch}.pth'))
        torch.save(text2img_dim_transform.state_dict(), osp.join(save_dir, f'text2img_dim_transform_{ablation1}{ablation2}_{epoch}.pth'))

In [None]:
def validate_model(modified_resnet, dataloader, device):
    """
    Compute the top-1 accuracy of `modified_resnet` on `dataloader`.

    Accuracy is the number of correct predictions divided by the number of samples
    actually seen, so a smaller final batch is counted correctly. The old
    `num_batches * batch_size` denominator overcounted that case.

    Args:
        modified_resnet: model whose forward returns (logits, features).
        dataloader: iterable of (images, labels) batches.
        device: device the batches are moved to.

    Returns:
        float: accuracy in [0, 1]; 0.0 if the loader yields no samples.

    Side effects:
        Prints the accuracy and leaves the model in training mode.
    """
    modified_resnet.eval()

    correct = 0
    seen = 0

    with torch.no_grad():  # No need to track gradients during validation
        for images, labels in tqdm(dataloader):
            images = images.to(device)
            labels = labels.to(device)

            logits, _ = modified_resnet(images)
            predicted = logits.argmax(dim=1)

            correct += (predicted == labels).sum().item()
            seen += labels.size(0)

    avg_accuracy = correct / seen if seen else 0.0

    print(f'Validation results: Accuracy: {avg_accuracy}')

    # Return to training mode
    modified_resnet.train()
    return avg_accuracy

In [None]:
# CLIP ViT-B/32 instance used to build the fixed per-class prompt features.
# Separate from `model` loaded earlier, because TextEncoder(clip_model) casts the model
# it receives to torch.FloatTensor in place.
clip_model, _ = clip.load("ViT-B/32", device)
classnames = get_ImageNet_ClassNames()
# Module-level tokenizer: PromptLearner.__init__ reads the global `_tokenizer`, so it must exist first.
_tokenizer = _Tokenizer()

# 1. For promptlearner
# Order matters: TextEncoder(clip_model) converts clip_model to torch.FloatTensor (CPU float32)
# in place, and PromptLearner then embeds its token ids with that converted model.
text_encoder = TextEncoder(clip_model).to(device)
prompt_learner = PromptLearner(classnames, clip_model).to(device)
total_prompt_from_labels = prompt_learner()
tokenized_total_prompt = prompt_learner.tokenized_prompts.to(device)
# Encoded once, without gradients: one prompt-feature row per class.
# NOTE(review): prompt_learner's context vectors are randomly initialized and its parameters are
# not in train_model's optimizer, so these features stay fixed (random context) during training;
# confirm this is intended.
with torch.no_grad():
    all_prompt_features = text_encoder(total_prompt_from_labels, tokenized_total_prompt)

# For without Promptlearner
# all_prompt_features = text_features_alt

all_prompt_features = all_prompt_features.to(device)

# create the resnet model (no pretrained weights)
resnet = models.resnet50(pretrained=False)
resnet = resnet.to(device)
modified_resnet = ModifiedResNet(resnet)
modified_resnet = modified_resnet.to(device)


# Snapshot of the untrained weights for later comparison (deep copy, so training does not change it).
pre_training_weights = copy.deepcopy(modified_resnet.state_dict())

# create mlp mapping 512-d prompt features to one value per spatial position of the feature map
input_dim = 512 # text encoder CLIP
output_dim = 49 # to match 7x7 dimension of the feature maps
hidden_dim = 512 # width of the MLP's hidden layer
text2img_dim_transform = MLP(input_dim, hidden_dim, output_dim)
text2img_dim_transform = text2img_dim_transform.to(device)


# Initialize SinkhornDistance module (OT loss averaged over the batch)
sinkhorn_loss = SinkhornDistance(eps=0.1, max_iter=100, reduction='mean').to(device)

# train model
num_epochs = 90 # CHANGE THIS
# 'mm' / 'ot' disable the corresponding loss term in train_model; "none" keeps all terms.
ablation1 = "none"
ablation2 = "none"



In [None]:
train_model(
    modified_resnet=modified_resnet,
    dataloader=dataloader,
    manifold_matching_loss=manifold_matching_loss,
    sinkhorn_loss=sinkhorn_loss,
    text2img_dim_transform=text2img_dim_transform,
    num_epochs=num_epochs,
    device=device,
    all_prompt_features=all_prompt_features,
    validation_loader=validation_loader,
    get_encoded_labels=get_encoded_labels,
    ablation1=ablation1,
    ablation2=ablation2,
)

  0%|          | 0/32 [00:00<?, ?it/s]

Validation results: Accuracy: 0.0009765625
Epoch 1/90


  0%|          | 0/32 [00:00<?, ?it/s]

temperature after last batch of epoch was:2.7385759353637695


  0%|          | 0/32 [00:00<?, ?it/s]

Validation results: Accuracy: 0.001953125
Loss of last epoch in batch is: CE: 8.309207916259766, OT: 0.20374803245067596, MM: 0.003949718549847603


  0%|          | 0/32 [00:00<?, ?it/s]

KeyboardInterrupt: 

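Quick check: verify that the `ModifiedResNet` weights transfer into a standard torchvision ResNet-50, and that the results can be visualized. The check first loads the untrained weights saved before training, then a trained checkpoint.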

In [None]:
# from torchvision.models import resnet50
# from torchcam.methods import GradCAM
# from torchcam.utils import overlay_mask
# import torch
# from torchvision.transforms.functional import to_pil_image
# import torchvision.models as models

# Sanity check: the untrained ModifiedResNet state_dict must load into a plain ResNet-50
# (ModifiedResNet re-registers the same child modules, so the keys are expected to match).
resnet50_for_gradcam = models.resnet50(pretrained=False).to(device)
resnet50_for_gradcam.load_state_dict(pre_training_weights)

# Load the trained weights back into the model (previously 'modified_resnet.pth').
# map_location keeps this working when the checkpoint was saved on another device than `device`.
# NOTE(review): train_model saves checkpoints as 'modified_resnet_{ablation1}{ablation2}_{epoch}.pth';
# this file name follows a different scheme. Confirm it is the intended checkpoint.
checkpoint_path = '/content/drive/MyDrive/FACT LICO 13/Models/modified_resnet_90_no_mm.pth'
resnet50_for_gradcam.load_state_dict(torch.load(checkpoint_path, map_location=device))

resnet50_for_gradcam = resnet50_for_gradcam.eval()


# # Alternative if layers are not all the same anymore
# # Instantiate a new standard ResNet50 model
# standard_resnet = models.resnet50(pretrained=False).to(device)

# # Get the names of the layers in the standard ResNet50
# standard_resnet_layer_names = [name for name, _ in standard_resnet.named_children()]

# # Transfer the weights from ModifiedResNet to the standard ResNet50
# for name, module in modified_resnet.named_children():
#     if name in standard_resnet_layer_names:
#         # print("YES")
#         # Transfer the state dictionary of each corresponding layer
#         getattr(standard_resnet, name).load_state_dict(module.state_dict())

# # Initializing separate instance of the model fro Grad-Cam since it overlaps with the other one and causes errors with Quantitative Test
# resnet50_for_gradcam = standard_resnet.eval()
# resnet50_for_gradcam = resnet50_for_gradcam.to(device)


# # Initializing the CAM extractor
# for param in resnet50_for_gradcam.layer4.parameters():
#     param.requires_grad = True

# #COEN HZ LAYER 4 HIER DAAN Thats the last convlution layer (so where the last features are extracted)
# cam_extractor = GradCAM(resnet50_for_gradcam, 'layer4')

# gradcam_maps = []
# original_images = []
# outputs = []


# # Defining the inverse transform for visualization of original image
# inverse_transform = transforms.Compose([
#     transforms.Normalize(mean=[0., 0., 0.], std=[1/0.229, 1/0.224, 1/0.225]),
#     transforms.Normalize(mean=[-0.485, -0.456, -0.406], std=[1., 1., 1.]),
#     transforms.ToPILImage(),
# ])

# for images, labels_batch in validation_loader:
#     images.requires_grad_()   #finding the gradient w.r.t the input image
#     images = images.to(device)
#     for i in range(images.size(0)):
#         # Generate Grad-CAM map
#         with torch.enable_grad():
#             gc_outputs = resnet50_for_gradcam(images[i].unsqueeze(0))
#             activation_map = cam_extractor(gc_outputs.squeeze(0).argmax().item(), gc_outputs)

#             # Overlay the CAM on the image
#             gradcam_overlay = overlay_mask(to_pil_image(images[i]), to_pil_image(activation_map[0].squeeze(0), mode='F'), alpha=0.5)
#             gradcam_overlay_np = np.array(gradcam_overlay)

#             original_image = inverse_transform(images[i].cpu().detach())
#             original_image_np = np.array(original_image.convert('RGB'))

#             # print('original_image_np shape:', original_image_np.shape)
#             # print('gradcam_overlay_np shape:', gradcam_overlay_np.shape)

#             pred_class = gc_outputs.argmax(dim=1)  # Get the index of the max logit


#             original_images.append(original_image)
#             gradcam_maps.append(gradcam_overlay_np)
#             img_labels = labels_batch
#             outputs.append(pred_class)


#         # Reset the gradients for the next image
#         resnet50_for_gradcam.zero_grad()
#     break

In [None]:
# for i in range(10): #currently 'subset_size' is 10 which is just used to quickly viusalize/inspect it
#     plt.figure(figsize=(18, 6))



#     # Plot original image
#     plt.subplot(1, 3, 1)
#     plt.imshow(original_images[i])
#     plt.title(f'Original Image {i}, {classnames[img_labels[i]]}')
#     plt.axis('off')

#     # Plot Grad-CAM overlay
#     plt.subplot(1, 3, 2)
#     plt.imshow(gradcam_maps[i])
#     plt.title(f'Grad-CAM {i}, prediction: {classnames[outputs[i]]}')
#     plt.axis('off')

#     plt.show()

In [None]:
!pip install torchray
from torchray.attribution.grad_cam import grad_cam
from torchray.benchmark import get_example_data, plot_example

# Obtain example data: torchray's example image and its ImageNet category id. The model it
# also returns (VGG-16, per the download log) is discarded.
# NOTE(review): get_example_data fetches its image over the network and failed here with
# UnidentifiedImageError, so the image source may need replacing.
_, x, category_id, _ = get_example_data()
x = x.to(device)  # the ResNet lives on `device`

# Grad-CAM backprop on the last convolutional stage of ResNet-50.
# 'features.29' is a VGG-16 module name and does not exist in torchvision's ResNet-50;
# the last conv block there is 'layer4'.
saliency = grad_cam(resnet50_for_gradcam, x, category_id, saliency_layer='layer4')

# Plots (CPU tensors for matplotlib).
plot_example(x.cpu(), saliency.cpu(), 'grad-cam backprop', category_id)

Collecting torchray
  Downloading torchray-1.0.0.2.tar.gz (376 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m376.2/376.2 kB[0m [31m2.9 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting pymongo (from torchray)
  Downloading pymongo-4.6.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (677 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m677.1/677.1 kB[0m [31m12.2 MB/s[0m eta [36m0:00:00[0m
Collecting dnspython<3.0.0,>=1.16.0 (from pymongo->torchray)
  Downloading dnspython-2.5.0-py3-none-any.whl (305 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m305.4/305.4 kB[0m [31m28.1 MB/s[0m eta [36m0:00:00[0m
Building wheels for collected packages: torchray
  Building wheel for torchray (setup.py) ... [?25l[?25hdone
  Created wheel for torchray: filename=torchray-1.0.0.2-py3-none-any.whl size=444010 sha256=2f6359fb94743a15691d1c3b1278f032aec5d540729992ee1823814

Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth
100%|██████████| 528M/528M [00:05<00:00, 99.8MB/s]


UnidentifiedImageError: cannot identify image file <_io.BytesIO object at 0x78c70264e3e0>