# Visualizing and Understanding Self-Supervised Learning

In [None]:
import torch
import numpy as np
import torch.nn as nn
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from PIL import Image 
import cv2

%matplotlib inline
# Prefer the GPU whenever PyTorch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
from data_transforms import normal_transforms, no_shift_transforms, ig_transforms, modify_transforms
from utils import overlay_heatmap, viz_map, show_image, deprocess, get_ssl_model
from methods import occlusion, occlusion_context_agnositc, pairwise_occlusion, deepdream, get_difference
from methods import create_mixed_images, averaged_transforms, sailency, smooth_grad 
from methods import get_pixel_invariance_dataset, pixel_invariance, get_gradcam, get_interactioncam

### Load model and transform images

In [None]:
network = 'simclrv2'

# Every backbone other than SimCLRv2 expects ImageNet mean/std-normalized input,
# so those backbones get the extra normalization in their data transforms and
# their images must be de-normalized again before display.
denorm = network != 'simclrv2'

ssl_model = get_ssl_model(network, '1x')

if denorm:
    normal_transforms, no_shift_transforms, ig_transforms = modify_transforms(
        normal_transforms, no_shift_transforms, ig_transforms
    )

In [None]:
img_path = 'val/images/ILSVRC2012_val_00000012.jpeg'

# Open the file in a context manager so the handle is released as soon as the
# RGB copy has been made.
with Image.open(img_path) as raw_img:
    img = raw_img.convert('RGB')

# Two views of the same image, produced by the 'pure' and 'aug' transform
# pipelines, each given a batch dimension and moved to the active device.
img1 = normal_transforms['pure'](img).unsqueeze(0).to(device)
img2 = normal_transforms['aug'](img).unsqueeze(0).to(device)

# The similarity is only printed, so there is no need to build an autograd graph.
with torch.no_grad():
    similarity = nn.CosineSimilarity(dim=-1)(ssl_model(img1), ssl_model(img2))
print("Similarity from model: ", similarity.item())

fig, axs = plt.subplots(1, 2, figsize=(20,5))
# Plain loop for the side effect (np.vectorize would call the first axis twice).
for ax in axs.flat:
    ax.axis('off')

axs[0].imshow(show_image(img1, denormalize = denorm))
axs[1].imshow(show_image(img2, denormalize = denorm))

### Perturbation Methods
*Conditional Occlusion, Context-Agnostic Conditional Occlusion, Pairwise Occlusion*

In [None]:
# Occlusion-style attributions for both views. The two conditional variants use a
# window size of 64 with stride 8; pairwise occlusion applies 100 random erases.
heatmap1, heatmap2 = occlusion(img1, img2, ssl_model, w_size = 64, stride = 8, batch_size = 32)
heatmap1_ca, heatmap2_ca = occlusion_context_agnositc(img1, img2, ssl_model, w_size = 64, stride = 8, batch_size = 32)
heatmap1_po, heatmap2_po = pairwise_occlusion(img1, img2, ssl_model, batch_size = 32, erase_scale = (0.1, 0.3), erase_ratio = (1, 1.5), num_erases = 100)

fig, axs = plt.subplots(4, 2, figsize=(20,20))
# Plain loop for the side effect (np.vectorize would call the first axis twice).
for ax in axs.flat:
    ax.axis('off')

added_image1 = overlay_heatmap(img1, heatmap1, denormalize = denorm)
added_image2 = overlay_heatmap(img2, heatmap2, denormalize = denorm)
added_image1_ca = overlay_heatmap(img1, heatmap1_ca, denormalize = denorm)
added_image2_ca = overlay_heatmap(img2, heatmap2_ca, denormalize = denorm)

axs[0, 0].set_title("Original and Augmented")
axs[0, 0].imshow(show_image(img1, denormalize = denorm))
axs[0, 1].imshow(show_image(img2, denormalize = denorm))
axs[1, 0].set_title("Conditional Occlusion")
axs[1, 0].imshow(added_image1)
axs[1, 1].imshow(added_image2)
axs[2, 0].set_title("Context-Agnostic Conditional Occlusion")
axs[2, 0].imshow(added_image1_ca)
axs[2, 1].imshow(added_image2_ca)
axs[3, 0].set_title("Pairwise Occlusion")
# Pairwise-occlusion maps scale the de-processed image directly: [:, :, None]
# broadcasts the 2-D map across the trailing axis (presumably channels -- assumes
# deprocess returns an H x W x C array).
axs[3, 0].imshow((deprocess(img1, denormalize = denorm) * heatmap1_po[:,:,None]).astype('uint8'))
axs[3, 1].imshow((deprocess(img2, denormalize = denorm) * heatmap2_po[:,:,None]).astype('uint8'))

### Feature Visualization
*Maximize the cosine-similarity score (with varying `minmax_weight`), or maximize the features directly (with a varying exponential moving average, `ema`)*

In [None]:
# Hyperparameters for the DeepDream-style feature visualization; see
# methods.deepdream for what each one controls.
dream_kwargs = dict(
    optimize_score=True, up_until=4, ema=0.5,
    reg_l2=True, reg_l2_weight=1e-3,
    use_tv=True, tv_weight=1e-3,
    minmax_weight=0, blur=True,
    iterations=20, lr=0.01, lr_norm=True,
    octave_scale=1.4, num_octaves=10, init_scale=1e-2,
)
dreamed_image, detail = deepdream(img1, img2, ssl_model, **dream_kwargs)

fig, axs = plt.subplots(figsize=(20, 5))
# NOTE(review): unlike most other displays, denormalize is not passed here --
# confirm whether `detail` needs de-normalizing when denorm is True.
axs.imshow(deprocess(detail, to_numpy=False))
axs.axis("off")

### Model Difference Visualization
*Visualize the difference between the self-supervised model and a baseline model (e.g., an image-classification model)*

In [None]:
imagenet_images, ssl_images = get_difference(ssl_model = ssl_model, baseline = 'imagenet', image = img2, lr = 1e4, 
                                             l2_weight = 0.1, alpha_weight = 1e-7, alpha_power = 6, tv_weight = 1e-8, 
                                             init_scale = 0.1, network = network)

# One grid row per (classification, self-supervised) pair, rather than assuming
# exactly three. squeeze=False keeps axs 2-D even for a single row.
image_pairs = list(zip(imagenet_images, ssl_images))
n_rows = max(len(image_pairs), 1)

fig, axs = plt.subplots(n_rows, 3, figsize=(20, 10 * n_rows / 3), squeeze=False)
for ax in axs.flat:
    ax.axis('off')

# The reference image is the same on every row, so convert it only once.
original_display = deprocess(img2, denormalize = denorm)
for row, (in_img, ssl_img) in enumerate(image_pairs):
    axs[row, 0].imshow(original_display)
    axs[row, 1].imshow(deprocess(in_img))
    axs[row, 2].imshow(deprocess(ssl_img))

axs[0,0].set_title("Original Image")
axs[0,1].set_title("Classification Image")
axs[0,2].set_title("Self-Supervised Image")

### Averaged Transforms

In [None]:
# Build the mixed images consumed by the saliency cells below.
# Supported transform_type values: 'color_jitter', 'blur', 'grayscale',
# 'solarize', 'combine'.
mixed_images = create_mixed_images(
    img_path=img_path,
    ig_transforms=ig_transforms,
    transform_type='combine',
    step=0.1,
    add_noise=True,
)

In [None]:
# One panel per mixed image. squeeze=False keeps axs 2-D so this also works
# when only a single mixed image was produced.
fig, axs = plt.subplots(1, len(mixed_images), figsize=(20,10), squeeze=False)
for ax, mixed in zip(axs[0], mixed_images):
    ax.axis('off')
    ax.imshow(show_image(mixed, denormalize = denorm))

In [None]:
# The first two methods compare the first and last entries of mixed_images;
# name them once instead of re-indexing in every call.
first_mixed = mixed_images[0]
last_mixed = mixed_images[-1]

# Vanilla gradients (for comparison purposes).
# NOTE(review): plotted as "Vanilla Gradients", yet guided=True is passed just as
# for the other methods -- confirm that is intended.
sailency1_van, sailency2_van = sailency(
    guided=True, ssl_model=ssl_model,
    img1=first_mixed, img2=last_mixed,
    blur_output=True,
)

# Smooth gradients over 50 steps (for comparison purposes).
sailency1_s, sailency2_s = smooth_grad(
    guided=True, ssl_model=ssl_model,
    img1=first_mixed, img2=last_mixed,
    blur_output=True, steps=50,
)

# Integrated transform, computed over the whole mixed_images sequence.
sailency1, sailency2 = averaged_transforms(
    guided=True, ssl_model=ssl_model,
    mixed_images=mixed_images,
    blur_output=True,
)

In [None]:
def _plot_saliency_pair(title, saliency_first, saliency_last):
    """Draw one 1x4 row comparing saliency maps of the first and last mixed image.

    Panels: first image, its saliency map (jet colormap) with the image blended
    on top at alpha 0.5, then the same two panels for the last image.
    ``title`` is placed above the first panel.

    Returns the created (fig, axs) pair.
    """
    fig, axs = plt.subplots(1, 4, figsize=(20,20))
    for ax in axs:
        ax.axis('off')

    axs[0].imshow(show_image(mixed_images[0], denormalize = denorm))
    axs[0].set_title(title)
    axs[1].imshow(show_image(saliency_first.detach(), squeeze = False), cmap = plt.cm.jet)
    axs[1].imshow(show_image(mixed_images[0], denormalize = denorm), alpha=0.5)
    axs[2].imshow(show_image(mixed_images[-1], denormalize = denorm))
    axs[3].imshow(show_image(saliency_last.detach(), squeeze = False), cmap = plt.cm.jet)
    axs[3].imshow(show_image(mixed_images[-1], denormalize = denorm), alpha=0.5)
    return fig, axs


fig, axs = _plot_saliency_pair("Vanilla Gradients", sailency1_van, sailency2_van)
fig, axs = _plot_saliency_pair("Smooth Gradients", sailency1_s, sailency2_s)
fig, axs = _plot_saliency_pair("Integrated Transform", sailency1, sailency2)

### Pixel Invariance

In [None]:
# Generate 1000 augmented samples from img_path (processed in batches of 32)
# for the pixel-invariance analysis.
data_samples1, data_samples2, data_labels = get_pixel_invariance_dataset(
    img_path=img_path,
    num_augments=1000,
    batch_size=32,
    no_shift_transforms=no_shift_transforms,
    ssl_model=ssl_model,
)

In [None]:
# Preview the first samples of each set, one figure per set. Show at most 10,
# and never more than either set actually holds.
n_preview = min(10, len(data_samples1), len(data_samples2))
for samples in (data_samples1, data_samples2):
    # squeeze=False keeps axs 2-D even when only one sample is shown.
    fig, axs = plt.subplots(1, n_preview, figsize=(20,10), squeeze=False)
    for ax in axs.flat:
        ax.axis('off')
    for m in range(n_preview):
        axs[0, m].imshow(show_image(samples[m], squeeze = False, denormalize = denorm))

In [None]:
inv_heatmap = pixel_invariance(
    data_samples1=data_samples1,
    data_samples2=data_samples2,
    data_labels=data_labels,
    resize_transform=transforms.Resize,
    size=64,
    epochs=1000,
    learning_rate=0.1,
    l1_weight=0.2,
    zero_small_values=True,
    blur_output=True,
)

# viz_map receives the image path and the heatmap; its result is shown without axes.
invariance_view = viz_map(img_path, inv_heatmap)
plt.imshow(invariance_view)
plt.axis("off")

### Interaction-CAM

In [None]:
# Grad-CAM for both views.
gradcam1, gradcam2 = get_gradcam(ssl_model, img1, img2)

# Interaction-CAM for both views under each reduction mode, computed in the
# order mean -> max -> attn.
intcams = {
    reduction: get_interactioncam(ssl_model, img1, img2, reduction=reduction)
    for reduction in ('mean', 'max', 'attn')
}
intcam1_mean, intcam2_mean = intcams['mean']
intcam1_max, intcam2_max = intcams['max']
intcam1_attn, intcam2_attn = intcams['attn']

In [None]:
def _plot_cam_pair(title, cam1, cam2):
    """Draw one 1x4 row: each view followed by that view with its CAM overlaid.

    ``title`` is placed above the first panel. Returns the created (fig, axs) pair.
    """
    fig, axs = plt.subplots(1, 4, figsize=(20,20))
    for ax in axs:
        ax.axis('off')

    axs[0].imshow(show_image(img1[0], squeeze = False, denormalize = denorm))
    axs[0].set_title(title)
    axs[1].imshow(overlay_heatmap(img1, cam1, denormalize = denorm))
    axs[2].imshow(show_image(img2[0], squeeze = False, denormalize = denorm))
    axs[3].imshow(overlay_heatmap(img2, cam2, denormalize = denorm))
    return fig, axs


fig, axs = _plot_cam_pair("Grad-CAM", gradcam1, gradcam2)
fig, axs = _plot_cam_pair("Interaction-CAM Mean", intcam1_mean, intcam2_mean)
fig, axs = _plot_cam_pair("Interaction-CAM Max", intcam1_max, intcam2_max)
fig, axs = _plot_cam_pair("Interaction-CAM X-Attention", intcam1_attn, intcam2_attn)