In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import h5py
import matplotlib.pyplot as plt
import math
import random

In [None]:
import os

if os.path.exists('./3dshapes.h5') :
  print("file already exists")
else :
    !wget https://storage.googleapis.com/3d-shapes/3dshapes.h5

!ls -lh 3dshapes.h5

In [None]:
class Shapes3d(Dataset):
  """Map-style dataset backed by an HDF5 file that is read fully into memory.

  The file must contain two HDF5 datasets named 'images' and 'labels'; both
  are copied into NumPy arrays at construction time, so the file handle is
  closed again before any sample is requested.
  """

  def __init__(self, path):
    """Read every image and label stored at `path` into RAM."""
    self.path = path
    print("dataset to RAM")
    with h5py.File(path, 'r') as archive:
      # `[()]` materialises the whole HDF5 dataset as an in-memory array.
      self.images = archive['images'][()]
      self.labels = archive['labels'][()]

  def __len__(self):
    """Number of samples (length of the image array)."""
    return len(self.images)

  def __getitem__(self, idx):
    """Return `(image, label)` for sample `idx`.

    The image is converted to a float tensor, divided by 255 and its axes
    are reordered with permute(2, 0, 1).
    The label is returned exactly as stored.
    """
    # assumes images are stored channels-last (H, W, C) with 0-255 values,
    # so the result is (C, H, W) in [0, 1] — TODO confirm against the file
    pixels = torch.from_numpy(self.images[idx]).float().div(255.0)
    return pixels.permute(2, 0, 1), self.labels[idx]

# Load the archive once, then expose it through a shuffling mini-batch loader
# (128 samples per batch, served by two worker processes).
dataset = Shapes3d('./3dshapes.h5')
dataset_loader = DataLoader(
    dataset,
    batch_size=128,
    shuffle=True,
    num_workers=2,
)

In [None]:
# Model Architecture

class BetaTCVAE(nn.Module):
  """Convolutional VAE network with a 12-dimensional latent space.

  Encoder: four stride-2 convolutions followed by a 256-unit hidden layer and
  two linear heads (mean and log-variance). Decoder: the mirror image, built
  from two linear layers and four stride-2 transposed convolutions.

  Attribute names are part of the checkpoint format (state_dict keys), so
  they must not be renamed.
  """

  def __init__(self):
    super().__init__()

    # Every (transposed) convolution uses the same geometry; with kernel 4,
    # stride 2 and padding 1 each layer halves (or doubles) the spatial size.
    geometry = dict(kernel_size=4, stride=2, padding=1)

    # ---- encoder ----
    # (3, 64, 64) -> (32, 32, 32)
    self.conv1 = nn.Conv2d(3, 32, **geometry)
    # -> (64, 16, 16)
    self.conv2 = nn.Conv2d(32, 64, **geometry)
    # -> (128, 8, 8)
    self.conv3 = nn.Conv2d(64, 128, **geometry)
    # -> (256, 4, 4)
    self.conv4 = nn.Conv2d(128, 256, **geometry)

    self.hl1 = nn.Linear(256 * 4 * 4, 256)
    self.hl2_mu = nn.Linear(256, 12)      # posterior mean
    self.hl2_logvar = nn.Linear(256, 12)  # posterior log(variance)

    # ---- decoder ----
    self.hl3 = nn.Linear(12, 256)
    self.hl4 = nn.Linear(256, 256 * 4 * 4)

    self.convT1 = nn.ConvTranspose2d(256, 128, **geometry)
    self.convT2 = nn.ConvTranspose2d(128, 64, **geometry)
    self.convT3 = nn.ConvTranspose2d(64, 32, **geometry)
    self.convT4 = nn.ConvTranspose2d(32, 3, **geometry)

  def encode(self, x):
    """Map an image batch to `(mu, logvar)`, each with 12 columns.

    `logvar` is clamped to [-6, 6].
    """
    features = x
    for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
      features = F.leaky_relu(conv(features))

    hidden = F.leaky_relu(self.hl1(features.view(-1, 256 * 4 * 4)))
    mu = self.hl2_mu(hidden)
    logvar = torch.clamp(self.hl2_logvar(hidden), min=-6.0, max=6.0)
    return mu, logvar

  def reparameterize(self, mu, logvar):
    """Draw `mu + eps * std` with `eps ~ N(0, I)` and `std = exp(logvar / 2)`."""
    std = torch.exp(0.5 * logvar)
    noise = torch.randn_like(std)
    return mu + noise * std

  def decode(self, z):
    """Map latent codes to images with values clamped to [1e-6, 1]."""
    hidden = F.leaky_relu(self.hl4(F.leaky_relu(self.hl3(z))))

    grid = hidden.view(-1, 256, 4, 4)
    for deconv in (self.convT1, self.convT2, self.convT3):
      grid = F.leaky_relu(deconv(grid))

    # The last layer uses a sigmoid instead of leaky ReLU; the clamp keeps
    # the output at or above 1e-6.
    return torch.clamp(torch.sigmoid(self.convT4(grid)), min=1e-6, max=1.0)

  def forward(self, x):
    """Encode, sample and decode; returns `(recon_x, z, mu, logvar)`."""
    mu, logvar = self.encode(x)
    z = self.reparameterize(mu, logvar)
    return self.decode(z), z, mu, logvar

In [None]:
# Pick the GPU when one is available, otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Restore the trained weights into a fresh network.
# NOTE(review): torch.load unpickles the file — only load checkpoints from a
# trusted source.
model = BetaTCVAE()
checkpoint = torch.load(
    "/kaggle/input/betatcvae-shapes3d/final_model.pth",
    map_location=device,
)
model.load_state_dict(checkpoint)

# Switch to inference mode and move to the selected device; the trailing
# expression also makes the notebook print the module summary.
model.eval()
model.to(device)

In [None]:
# batch curation

def discretize(dataset):
    """Convert every label column into integer ids of its distinct values.

    For each column, the distinct values are sorted and every entry is
    replaced by the position of its value in that sorted list (this is what
    `np.unique(..., return_inverse=True)` returns).

    Args:
        dataset: either an object exposing a `labels` array with one row per
            sample (read directly, no images are touched), or any iterable of
            `(image, label)` pairs.

    Returns:
        Integer array with the same shape as the label matrix.
    """
    # Fast path: take the label matrix straight from the dataset. Iterating
    # the dataset instead would run __getitem__ for every sample and convert
    # each image to a tensor only to throw it away.
    labels = getattr(dataset, "labels", None)
    if labels is not None and len(labels) == len(dataset):
        all_floats = np.asarray(labels)
    else:
        all_floats = np.array([sample[1] for sample in dataset])

    all_uids = np.zeros_like(all_floats, dtype=int)

    # Derive the number of factors from the data rather than assuming 6.
    for col in range(all_floats.shape[1]):
        _, inverse = np.unique(all_floats[:, col], return_inverse=True)
        all_uids[:, col] = inverse

    return all_uids

class CuratedSampler:
    """Draw batches of images whose discretised labels match given criteria.

    A sample is "normal" when it matches every entry of `normal_criteria`.
    A sample is an "outlier" when it is not normal and matches every entry of
    `outlier_criteria`; with empty outlier criteria every non-normal sample
    counts as an outlier.

    Attributes:
        discrete: integer label matrix returned by `discretize(dataset)`.
        dataset: the wrapped dataset; `dataset[i][0]` must be a tensor.
        normal_idx: ascending list of indices of normal samples.
        outlier_idx: ascending list of indices of outlier samples.
    """

    def __init__(self, dataset, normal_criteria, outlier_criteria=None):
        """Index the dataset.

        Args:
            dataset: dataset accepted by `discretize`.
            normal_criteria: dict `{label column: required discrete value}`.
            outlier_criteria: optional dict of the same form; defaults to no
                constraints.
        """
        # Avoid a shared mutable default argument.
        if outlier_criteria is None:
            outlier_criteria = {}

        self.discrete = discretize(dataset)
        self.dataset = dataset

        # Whole-column comparisons replace the per-sample Python loop.
        normal_mask = self._match_mask(normal_criteria)
        outlier_mask = ~normal_mask & self._match_mask(outlier_criteria)

        self.normal_idx = np.flatnonzero(normal_mask).tolist()
        self.outlier_idx = np.flatnonzero(outlier_mask).tolist()

    def _match_mask(self, criteria):
        """Boolean mask of the rows of `self.discrete` satisfying all `criteria`."""
        mask = np.ones(len(self.discrete), dtype=bool)
        for label_idx, wanted in criteria.items():
            mask &= self.discrete[:, label_idx] == wanted
        return mask

    def get_batch(self, num_normal, device, num_outlier=0):
        """Sample images with replacement and stack them into one tensor.

        Args:
            num_normal: number of normal samples; they come first in the batch.
            device: device the stacked batch is moved to.
            num_outlier: number of outlier samples appended after the normals.

        Returns:
            Tensor of `num_normal + num_outlier` stacked images on `device`.

        Raises:
            IndexError: if samples are requested from an empty pool.
        """
        if num_normal > 0 and not self.normal_idx:
            raise IndexError("no sample matches normal_criteria; cannot draw normal samples")
        if num_outlier > 0 and not self.outlier_idx:
            raise IndexError("no sample matches outlier_criteria; cannot draw outlier samples")

        # Normals are drawn before outliers so the RNG is consumed in a fixed order.
        chosen = random.choices(self.normal_idx, k=num_normal)
        chosen += random.choices(self.outlier_idx, k=num_outlier)

        return torch.stack([self.dataset[idx][0] for idx in chosen]).to(device)

In [None]:
import cv2

class VAEGradCAM:
    """Grad-CAM style heat-maps for a single latent mean of a VAE encoder.

    Hooks on `target_layer` record its output during the forward pass and the
    gradient arriving at that output during the backward pass. Calling the
    object back-propagates one entry of the encoder mean and combines both
    recordings into a heat-map.

    Call `remove_hooks()` when the object is no longer needed; otherwise every
    new instance adds another pair of hooks to the same layer.
    """

    def __init__(self, target_layer, model=None):
        """Attach the recording hooks.

        Args:
            target_layer: module whose output activations are visualised.
            model: object exposing `encode(x) -> (mu, logvar)` and
                `zero_grad()`. When omitted, the notebook-level global
                `model` is looked up at call time (original behaviour).
        """
        self.target_layer = target_layer
        self.model = model
        self.gradients = None
        self.activations = None

        # Keep the handles so the hooks can be detached again.
        self._handles = [
            target_layer.register_forward_hook(self.save_activation),
            target_layer.register_full_backward_hook(self.save_gradient),
        ]

    def remove_hooks(self):
        """Detach both hooks from the target layer (safe to call repeatedly)."""
        for handle in self._handles:
            handle.remove()
        self._handles = []

    def save_activation(self, module, input, output):
        """Forward hook: remember the layer output."""
        # Detached so the stored copy does not keep the autograd graph alive
        # between calls; the heat-map itself needs no gradient.
        self.activations = output.detach()

    def save_gradient(self, module, grad_input, grad_output):
        """Backward hook: remember the gradient w.r.t. the layer output."""
        self.gradients = grad_output[0]

    def __call__(self, x, latent_idx, visualize_negative=False):
        """Compute the heat-map for latent `latent_idx` of the first sample in `x`.

        Args:
            x: image batch indexed as (batch, channel, height, width); only
                sample 0 is used.
            latent_idx: column of the encoder mean to explain.
            visualize_negative: when True, back-propagate the negated mean so
                regions that push the latent down are highlighted instead.

        Returns:
            `(cam, value)`: a NumPy heat-map scaled to [0, 1] and resized to
            the spatial size of `x`, and the latent mean as a Python float.
        """
        net = self.model if self.model is not None else model

        mu, _ = net.encode(x)
        target_value = mu[0, latent_idx]
        net.zero_grad()

        score = -target_value if visualize_negative else target_value
        score.backward()

        gradients = self.gradients[0]      # (channels, h, w) of sample 0
        activations = self.activations[0]  # (channels, h, w) of sample 0

        # Channel weight = gradient averaged over the spatial positions.
        weights = torch.mean(gradients, dim=(1, 2))
        # Weighted sum of the activation maps; only positive evidence is kept.
        cam = F.relu((weights.view(-1, 1, 1) * activations).sum(dim=0))

        # Rescale to [0, 1]; the epsilon guards against an all-zero map.
        cam = cam - torch.min(cam)
        cam = cam / (torch.max(cam) + 1e-8)
        cam = cam.detach().cpu().numpy()

        # cv2.resize takes the target size as (width, height).
        cam = cv2.resize(cam, (x.shape[3], x.shape[2]))

        return cam, target_value.item()

In [None]:
def plot_latent_cam(original, cam_map, latent_idx, latent_val, visualize_negative):
    """Show an image, its CAM heat-map and their blend side by side.

    Args:
        original: image tensor; after `squeeze()` it must be 3-D so that
            `permute(1, 2, 0)` yields a channels-last array.
        cam_map: 2-D heat-map; it is multiplied by 255 and cast to uint8
            before colouring (values presumably lie in [0, 1] — confirm with
            the caller).
        latent_idx: latent index shown in the middle panel title.
        latent_val: latent value shown in the right panel title.
        visualize_negative: selects which of the two figure titles is used.
    """
    rgb_image = original.squeeze().permute(1, 2, 0).cpu().numpy()

    # Colour the map with OpenCV's JET palette, rescale to floats and reverse
    # the channel axis (OpenCV colour order is BGR, matplotlib expects RGB).
    jet = cv2.applyColorMap(np.uint8(255 * cam_map), cv2.COLORMAP_JET)
    heatmap = (np.float32(jet) / 255)[..., ::-1]

    # Equal-weight blend, renormalised so its maximum is 1.
    blended = 0.5 * heatmap + 0.5 * rgb_image
    overlay = blended / np.max(blended)

    fig, axes = plt.subplots(1, 3, figsize=(10, 4))

    panels = (
        (rgb_image, "Original Image"),
        (heatmap, f"latent index {latent_idx}"),
        (overlay, f"value = {latent_val:.3f}"),
    )
    for ax, (picture, title) in zip(axes, panels):
        ax.imshow(picture)
        ax.set_title(title)
        ax.axis('off')

    fig.suptitle(
        "Heatmap of negative activations"
        if visualize_negative
        else "Heatmap of positive activations"
    )

    plt.show()

In [None]:
# Keep only samples whose label column 5 has discrete id 14 (the 15th distinct
# value of that column after discretisation).
# NOTE(review): column 5 is presumably the orientation factor of 3D Shapes —
# confirm against the dataset description.
curator = CuratedSampler(dataset, {5 : 14})

# get_batch draws with random.choices; seed it so that a fixed position in the
# batch (indexed in the next cell) refers to the same image on every run.
random.seed(0)
batch = curator.get_batch(100, device)

# Grad-CAM on the last encoder convolution. Re-running this cell registers a
# new pair of hooks on model.conv4 each time.
cam = VAEGradCAM(model.conv4)


In [None]:
# Latent dimension to explain and the sample to explain it on.
latent_idx = 4
image = batch[56].unsqueeze(0).to(device)  # add a leading batch axis of size 1

# Regions that push the latent mean down.
heatmap1, val1 = cam(image, latent_idx=latent_idx, visualize_negative=True)
plot_latent_cam(
    image, heatmap1, latent_idx,
    latent_val=val1, visualize_negative=True,
)

# Regions that push the latent mean up.
heatmap2, val2 = cam(image, latent_idx=latent_idx)
plot_latent_cam(
    image, heatmap2, latent_idx,
    latent_val=val2, visualize_negative=False,
)