## Feature Visualization of Deep Neural Networks
- Prepared by Engin Deniz Erkan

In [None]:
!pip install timm

The code below mounts Google Drive so that the generated images can be saved there.

In [None]:
from google.colab import drive
import os
import sys

drive.mount('/content/drive')

# Folder (inside the mounted Drive) where the final visualizations are saved.
output_dir = "/content/drive/MyDrive/MMI727_project_visualizations"

# Fail fast: the training cell writes its images into this folder, so a missing
# folder should be reported now rather than after training has finished.
if os.path.exists(output_dir):
    print(f"Directory exists: {output_dir}")
else:
    missing_msg = f"Directory does not exist: {output_dir}"
    print(missing_msg)
    sys.exit(f"Exiting program: {missing_msg}")

In [None]:
import torch
import torchvision
import numpy as np
import PIL
import random
import matplotlib.pyplot as plt
from PIL import Image
from IPython.display import display

# Target device for every model and visualization: the GPU when available.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

############################################################################################################
##                       Utility functions and classes                                           ##
############################################################################################################
class NetworkWrapper(torch.nn.Module):
    """Bundle a pretrained classifier with its input preprocessing.

    The wrapped network is switched to eval mode and its parameters are frozen
    (``requires_grad_(False)``, which mutates the network passed in). Only the
    input image is optimized during feature visualization. Without freezing,
    every backward pass would also compute gradients for all network weights,
    and since no optimizer ever zeroes them they would accumulate in ``.grad``.
    Gradients still flow back to the input ``x``.

    Args:
        network: The classifier to query.
        preprocess_fn: Callable applied to the input before the network
            (e.g. a ``torchvision.transforms.Normalize``).
    """

    def __init__(self, network, preprocess_fn):
        super(NetworkWrapper, self).__init__()

        self.preprocess_fn = preprocess_fn
        self.network = network
        self.network.eval()
        # The weights are never updated; skip computing/storing their gradients.
        self.network.requires_grad_(False)

    def forward(self, x):
        """Return ``network(preprocess_fn(x))``."""
        return self.network(self.preprocess_fn(x))

class Visualization(torch.nn.Module):
    """A learnable RGB image that is optimized to excite a classifier.

    The image is stored as an unconstrained parameter of shape ``(1, 3, h, w)``
    and squashed into ``(0, 1)`` with a sigmoid. This lets the optimizer move
    it freely while pixel values stay valid. Each forward pass returns a batch
    of independently, randomly augmented views of size ``(out_h, out_w)``.

    Args:
        h: Height of the learnable image.
        w: Width of the learnable image.
    """

    def __init__(self, h, w):
        super(Visualization, self).__init__()
        # Name-mangled to ``_Visualization__data``, which is also its
        # state_dict key; keep the name so saved states stay loadable.
        self.__data = torch.nn.Parameter(torch.randn(1, 3, h, w))
        # Default the output size to the image size so forward() works even
        # if set_output_shape() is never called.
        self.out_h = h
        self.out_w = w

    def __augment(self, x, batch_size, apply_all=True):
        """Return ``batch_size`` independently augmented copies of ``x``.

        Args:
            x: Image batch to replicate; forward() passes the sigmoid-mapped
                ``(1, 3, h, w)`` image.
            batch_size: Number of copies of ``x`` to build before augmenting.
            apply_all: If True, apply the full augmentation chain; otherwise
                only a random resized crop.

        Returns:
            Tensor of shape ``(batch_size * len(x), 3, out_h, out_w)``.
        """
        crop = torchvision.transforms.RandomResizedCrop([self.out_h, self.out_w])
        if apply_all:
            transform = torchvision.transforms.Compose([
                crop,
                torchvision.transforms.RandomHorizontalFlip(),
                torchvision.transforms.RandomRotation(360),
                torchvision.transforms.RandomPerspective(),
                torchvision.transforms.GaussianBlur(3),
                torchvision.transforms.RandomGrayscale(p=0.1),
            ])
        else:
            transform = crop

        copies = torch.cat([x] * batch_size, dim=0)
        # Transform one image at a time so each copy draws its own random
        # parameters (a single batched call would share them across the batch).
        return torch.stack([transform(img) for img in copies])

    def __reparameterize(self, x):
        """Map unconstrained values to ``(0, 1)`` pixel intensities."""
        return torch.sigmoid(x)

    def set_output_shape(self, h, w):
        """Set the height and width of the augmented views returned by forward()."""
        self.out_h = h
        self.out_w = w

    def forward(self, batch_size, apply_all=True):
        """Return ``batch_size`` augmented views of the current image.

        Args:
            batch_size: Number of augmented views to produce.
            apply_all: If True, use every augmentation; otherwise only a random
                resized crop.

        Returns:
            Tensor of shape ``(batch_size, 3, out_h, out_w)``.
        """
        x = self.__reparameterize(self.__data)
        return self.__augment(x, batch_size, apply_all)

    def to_img(self):
        """Return the current (un-augmented) image as an 8-bit RGB PIL image."""
        with torch.no_grad():
            # Index the batch axis explicitly; squeeze() would also drop an
            # h == 1 or w == 1 axis and break the channel reordering below.
            img = self.__reparameterize(self.__data)[0]  # (3, h, w), values in (0, 1)
            # Round (instead of truncating) when quantizing to 0..255.
            pixels = (img * 255).round().clamp(0, 255).to(torch.uint8)
            # PIL expects (height, width, channels).
            array = pixels.permute(1, 2, 0).cpu().numpy()

        return PIL.Image.fromarray(array)

############################################################################################################
##                         Initialize the model, visualization and optimizer                              ##
############################################################################################################
# Output indices of the classifiers whose scores are maximized, one run per label.
class_labels = [2, 76, 107, 340, 440, 479, 546, 836, 852, 947]

# Optimization settings shared by every run.
NUM_ITERATIONS = 1000
BATCH_SIZE = 8          # augmented views per optimization step
DISPLAY_EVERY = 100     # show intermediate results every this many steps
LEARNING_RATE = 0.2
VIS_SIZE = 256          # side length of the learnable image
INPUT_SIZE = 224        # side length of the augmented views fed to the networks
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Load each pretrained backbone once. The networks are only queried (their
# weights are never updated and they stay in eval mode), so the "all
# augmentations" and "one augmentation" runs share them. Reloading the models,
# including two DeiT-Base hub downloads, for every class is pure overhead.
net = torchvision.models.resnet18(pretrained=True)
net2 = torch.hub.load('facebookresearch/deit:main', 'deit_base_patch16_224', pretrained=True)

resnet_preprocess_fn = torchvision.transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
vit_preprocess_fn = torchvision.transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)

resnet_model = NetworkWrapper(net, resnet_preprocess_fn).to(device)
vit_model = NetworkWrapper(net2, vit_preprocess_fn).to(device)

# Keep the per-run names available for any later cells that refer to them.
net1_aug, net2_aug = net, net2
resnet1_aug_preprocess_fn, vit1_aug_preprocess_fn = resnet_preprocess_fn, vit_preprocess_fn
resnet1_aug_model, vit1_aug_model = resnet_model, vit_model

# Panel titles and file-name prefixes, in the order of the four runs.
PANEL_TITLES = [
    "ResNet with all Augmentations",
    "ResNet with one Augmentation",
    "VIT with all Augmentations",
    "VIT with one Augmentation",
]
SAVE_PREFIXES = [
    "ResNet_all_augmentations",
    "ResNet_one_augmentation",
    "Vit_all_augmentations",
    "Vit_one_augmentation",
]


def _new_visualization():
    """Return a fresh random visualization on `device` and an AdamW optimizer over it."""
    vis = Visualization(VIS_SIZE, VIS_SIZE).to(device)
    vis.set_output_shape(INPUT_SIZE, INPUT_SIZE)
    optimizer = torch.optim.AdamW(params=vis.parameters(), lr=LEARNING_RATE)
    return vis, optimizer


def _optimization_step(model, vis, optimizer, class_label, apply_all):
    """Take one step that raises the mean score of `class_label` over a batch of views.

    With ``apply_all=False`` the visualization applies only its random resized crop.
    """
    vis.train()
    optimizer.zero_grad()
    output = model(vis(BATCH_SIZE, apply_all=apply_all))
    loss = -output[:, class_label].mean()  # minimizing the negative score maximizes it
    loss.backward()
    optimizer.step()


def _render(vis):
    """Switch `vis` to eval mode and return its current image as a PIL image."""
    vis.eval()
    with torch.no_grad():
        return vis.to_img()


def _show(images, iteration, class_label):
    """Plot the four runs' current images side by side."""
    plt.figure(figsize=(10, 3))
    for position, (img, title) in enumerate(zip(images, PANEL_TITLES), start=1):
        plt.subplot(1, 4, position)
        plt.imshow(img)
        plt.title(title, fontsize=7)
        plt.axis('off')
    plt.suptitle(f"Visualization {iteration} - Target class: {class_label}", fontsize=10)
    plt.show()


############################################################################################################
##                                            Training loop                                               ##
############################################################################################################
for class_label in class_labels:

    # Fresh, randomly initialized images (and optimizers) for every target class.
    resnet_vis, resnet_optimizer = _new_visualization()
    resnet1_aug_vis, resnet1_aug_optimizer = _new_visualization()
    vit_vis, vit_optimizer = _new_visualization()
    vit1_aug_vis, vit1_aug_optimizer = _new_visualization()

    # (model, visualization, optimizer, apply all augmentations?) in display order.
    runs = [
        (resnet_model, resnet_vis, resnet_optimizer, True),
        (resnet1_aug_model, resnet1_aug_vis, resnet1_aug_optimizer, False),
        (vit_model, vit_vis, vit_optimizer, True),
        (vit1_aug_model, vit1_aug_vis, vit1_aug_optimizer, False),
    ]

    for i in range(NUM_ITERATIONS):
        for model, vis, optimizer, apply_all in runs:
            _optimization_step(model, vis, optimizer, class_label, apply_all)

        if (i + 1) % DISPLAY_EVERY == 0:
            _show([_render(vis) for _, vis, _, _ in runs], i + 1, class_label)

    # Render the final images explicitly rather than relying on the last display
    # step, so saving does not depend on NUM_ITERATIONS being a multiple of
    # DISPLAY_EVERY.
    final_images = [_render(vis) for _, vis, _, _ in runs]
    generated_img_resnet, generated_img_resnet_one, generated_img_vit, generated_img_vit_one = final_images

    for prefix, img in zip(SAVE_PREFIXES, final_images):
        image_name = f"{prefix}_target_class_{class_label}.png"
        image_path = os.path.join(output_dir, image_name)
        img.save(image_path)
