In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import h5py
import matplotlib.pyplot as plt
import math

In [None]:
import os

# Remove any existing copy first. Without -O/-nc, wget would otherwise save the
# new download under a different name (e.g. 3dshapes.h5.1), and the loader
# below would keep reading the old file.
# NOTE(review): this re-downloads the whole dataset on every run; consider
# skipping the download when a complete file is already present.
if os.path.exists('./3dshapes.h5') :
  os.remove('./3dshapes.h5')

# Fetch the 3D Shapes dataset (HDF5) into the working directory.
!wget https://storage.googleapis.com/3d-shapes/3dshapes.h5

# Show the file size as a quick sanity check that the download completed.
!ls -lh 3dshapes.h5

In [None]:
class Shapes3d(Dataset):
  """Map-style dataset over the 3D Shapes HDF5 file, held entirely in RAM.

  The ``images`` and ``labels`` datasets are read once at construction time,
  so item access never touches the file again.

  Args:
    path: path to the ``3dshapes.h5`` file.
  """

  def __init__(self, path):
    self.path = path
    print("dataset to RAM")
    with h5py.File(path, 'r') as h5:
      # `[()]` materialises each whole HDF5 dataset as a NumPy array.
      self.images, self.labels = h5['images'][()], h5['labels'][()]

  def __len__(self):
    return len(self.images)

  def __getitem__(self, idx):
    """Return ``(image, label)``.

    The image is converted to float, scaled by 1/255 (presumably uint8 pixel
    values -> [0, 1]) and permuted from (H, W, C) to (C, H, W). The label is
    returned as the raw NumPy row.
    """
    pixels = torch.from_numpy(self.images[idx]).float() / 255.0
    return pixels.permute(2, 0, 1), self.labels[idx]

# Build the in-memory dataset and a shuffling loader; the analysis cells below
# draw their batches from `dataset_loader`.
# NOTE(review): no random seed is set in the visible cells, so the batches seen
# by the MIG / SSIM cells change between runs.
dataset = Shapes3d('./3dshapes.h5')
dataset_loader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=2)

In [6]:
# Model Architecture

class BetaTCVAE(nn.Module) :
  """Convolutional VAE for 3x64x64 images.

  Encoder: four stride-2 convolutions (3x64x64 -> 256x4x4), a 256-unit hidden
  layer, then linear heads for the posterior mean and log-variance.
  Decoder: the mirror image, ending in a sigmoid.
  The class name refers to the beta-TC-VAE objective; the loss itself is not
  defined in this class.

  Attribute names are the keys of the saved state dict, so they must not change
  if existing checkpoints are to keep loading.

  Args:
    latent_dim: size of the latent code (default 12, the original setting).
  """

  def __init__(self, latent_dim=12) :
    super(BetaTCVAE, self).__init__()
    self.latent_dim = latent_dim

    # encoder
    # (3, 64, 64)
    self.conv1 = nn.Conv2d(3, 32, kernel_size=4, stride=2, padding=1)
    # (32, 32, 32)
    self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1)
    # (64, 16, 16)
    self.conv3 = nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1)
    # (128, 8, 8)
    self.conv4 = nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1)
    # (256, 4, 4)

    self.hl1 = nn.Linear(256 * 4 * 4, 256)
    self.hl2_mu = nn.Linear(256, latent_dim) # mean
    self.hl2_logvar = nn.Linear(256, latent_dim) # log(variance)

    # decoder: latent -> 256 -> (256, 4, 4) -> ... -> (3, 64, 64)
    self.hl3 = nn.Linear(latent_dim, 256)
    self.hl4 = nn.Linear(256, 256 * 4 * 4)

    self.convT1 = nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1)
    self.convT2 = nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1)
    self.convT3 = nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1)
    self.convT4 = nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1)

  def encode(self, x) :
    """Map images of shape (B, 3, 64, 64) to ``(mu, logvar)``, each (B, latent_dim).

    ``logvar`` is clamped to [-6, 6] so that ``exp(0.5 * logvar)`` stays bounded.

    Raises:
      ValueError: if the input's spatial size does not produce the 256x4x4
        feature map that ``hl1`` expects (i.e. images that are not 64x64).
    """
    c1 = F.leaky_relu(self.conv1(x))
    c2 = F.leaky_relu(self.conv2(c1))
    c3 = F.leaky_relu(self.conv3(c2))
    c4 = F.leaky_relu(self.conv4(c3))
    # Without this check, view(-1, 4096) would silently fold extra spatial
    # positions of larger images into the batch dimension.
    if c4.shape[-3:] != (256, 4, 4):
      raise ValueError(
        f"encode expects 3x64x64 images; got encoder feature map "
        f"{tuple(c4.shape[-3:])}, expected (256, 4, 4)")
    h1 = F.leaky_relu(self.hl1(c4.view(-1, 256 * 4 * 4)))
    mu = self.hl2_mu(h1)
    logvar = self.hl2_logvar(h1)
    logvar = torch.clamp(logvar, min=-6.0, max=6.0)
    return mu, logvar

  def reparameterize(self, mu, logvar) :
    """Sample ``z = mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``."""
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return mu + eps * std

  def decode(self, z) :
    """Map latent codes (B, latent_dim) to images (B, 3, 64, 64).

    The output is a sigmoid clamped to [1e-6, 1], presumably so that a
    log-likelihood loss on it stays finite.
    """
    h3 = F.leaky_relu(self.hl3(z))
    h4 = F.leaky_relu(self.hl4(h3))
    cT1 = F.leaky_relu(self.convT1(h4.view(-1, 256, 4, 4)))
    cT2 = F.leaky_relu(self.convT2(cT1))
    cT3 = F.leaky_relu(self.convT3(cT2))
    cT4 = torch.sigmoid(self.convT4(cT3))
    img = torch.clamp(cT4, min=1e-6, max=1.0)
    return img

  def forward(self, x) :
    """Encode, sample and decode; returns ``(recon_x, z, mu, logvar)``."""
    mu, logvar = self.encode(x)
    z = self.reparameterize(mu, logvar)
    recon_x = self.decode(z)
    return recon_x, z, mu, logvar

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Put the model on the same device that the analysis cells send their inputs to
# (`.to(device)`); a CPU-only model would fail with a device mismatch on CUDA.
model = BetaTCVAE().to(device)

# The checkpoint is loaded as a plain state dict for this class.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints
# (or pass weights_only=True on PyTorch versions that support it).
checkpoint = torch.load("final_model.pth", map_location=device)
model.load_state_dict(checkpoint)

# Inference only. Gradient tracking is disabled where needed by the
# `with torch.no_grad():` blocks in the cells below; a bare `torch.no_grad()`
# call has no effect outside a `with` statement.
model.eval()

In [None]:
def calculate_entropy(v_k, bins) :
    """Shannon entropy (natural log, i.e. nats) of a discrete variable.

    Args:
        v_k: 1-D tensor of non-negative integer codes.
        bins: minimum histogram length, passed as ``minlength`` to
            ``torch.bincount``.

    Returns:
        0-dim tensor ``-sum(p * log p)`` over the non-empty bins.
    """
    histogram = torch.bincount(v_k, minlength=bins)
    probs = histogram / histogram.sum()
    # Empty bins are dropped, which treats 0 * log(0) as 0.
    nonzero = probs[probs > 0]
    return -(nonzero * nonzero.log()).sum()
    
def calculate_mutual_information(curr_z_discrete, curr_v, num_bins_z, curr_v_bins) :
    """Empirical joint and marginal distributions of two discrete variables.

    Despite the name, no mutual-information value is returned; the caller
    computes MI from the probability tables produced here.

    Args:
        curr_z_discrete: 1-D integer tensor of latent bin indices in
            [0, num_bins_z).
        curr_v: 1-D integer tensor of factor class indices in
            [0, curr_v_bins), same length as ``curr_z_discrete``.
        num_bins_z: number of latent bins (rows of the joint table).
        curr_v_bins: number of factor classes (columns of the joint table).

    Returns:
        ``(P_zv, P_z, P_v)``: the joint table of shape
        (num_bins_z, curr_v_bins), its row sums and its column sums.
    """
    n_samples = curr_z_discrete.shape[0]
    # Encode each (z, v) pair as one row-major flat index so a single bincount
    # builds the whole joint histogram.
    flat_index = curr_v + curr_z_discrete * curr_v_bins
    joint = torch.bincount(flat_index, minlength=num_bins_z * curr_v_bins)
    P_zv = joint.reshape(num_bins_z, curr_v_bins) / n_samples
    return P_zv, P_zv.sum(dim=1), P_zv.sum(dim=0)

In [None]:
# collecting data for MIG

ANALYSIS_BATCHES = 100

latent_code = []
ground_truth = []

# Encode the first ANALYSIS_BATCHES batches from the loader. The posterior mean
# returned by `encode` serves as the latent code; the log-variance is unused.
with torch.no_grad() :
    for _, (data, label) in zip(range(ANALYSIS_BATCHES), dataset_loader) :
        batch_mu, _logvar = model.encode(data.to(device))
        latent_code.append(batch_mu.cpu())
        ground_truth.append(label)

# z: (num_samples, num_latents); v: (num_samples, num_factors)
z = torch.cat(latent_code, dim = 0)
v = torch.cat(ground_truth, dim = 0)

In [None]:
# Pre-processing: min-max normalise each latent dimension and discretise it into bins

num_bins_z = 20

# Min-max normalise every latent dimension over the collected samples, then
# split [0, 1] into `num_bins_z` equal-width bins.
z_min = z.min(dim=0, keepdim=True).values
z_max = z.max(dim=0, keepdim=True).values
z_range = z_max - z_min
# A constant dimension would divide by zero; give it a unit range instead.
z_range = z_range.masked_fill(z_range == 0, 1.0)
z_norm = (z - z_min) / z_range
# Scaling by slightly less than num_bins_z keeps the maximum (1.0) in the last
# bin (num_bins_z - 1) instead of spilling into an extra bin.
z_discrete = (z_norm * (num_bins_z - 1e-5)).long()

In [None]:
# Computing the mutual-information matrix (latents x factors) and the factor entropies used by MIG

num_latent = z.shape[1]
num_ground = v.shape[1]

# mat_mi[i, j] = I(z_i; v_j) in nats; ground_entropies[j] = H(v_j).
mat_mi = torch.zeros(num_latent, num_ground)
ground_entropies = torch.zeros(num_ground)

for factor_idx in range(num_ground) :
    factor_values = v[:, factor_idx]
    # Map the raw (float) factor values to dense class ids 0..K-1.
    factor_ids = torch.searchsorted(torch.unique(factor_values), factor_values)
    n_factor_bins = int(factor_ids.max().item()) + 1
    ground_entropies[factor_idx] = calculate_entropy(factor_ids, n_factor_bins)

    for latent_idx in range(num_latent) :
        p_zv, p_z, p_v = calculate_mutual_information(
            z_discrete[:, latent_idx], factor_ids, num_bins_z, n_factor_bins)
        # MI = sum over occupied cells of p(z,v) * log(p(z,v) / (p(z) p(v))).
        # Restricting to p(z,v) > 0 also skips the 0/0 cells of the ratio.
        occupied = p_zv > 0
        ratio = p_zv / (p_z.unsqueeze(1) * p_v.unsqueeze(0))
        mat_mi[latent_idx, factor_idx] = (p_zv[occupied] * ratio[occupied].log()).sum()

In [None]:
# Calculating the gaps

# MIG: for every ground-truth factor, take the gap between the largest and the
# second-largest mutual information over all latent dimensions and divide it by
# the factor's entropy. The final score is the mean over factors.
top2_mi = torch.topk(mat_mi, k=2, dim=0).values  # (2, num_factors), descending
best_mi_row = top2_mi[0]
second_best_mi_row = top2_mi[1]

gap_vec = best_mi_row - second_best_mi_row
# The epsilon guards against a factor with zero entropy (a single value).
normalized = gap_vec / (ground_entropies + 1e-10)
final_mig = torch.mean(normalized).item()
print(f"final MIG for the model : {final_mig:.4f}")
# One value per ground-truth factor (column of mat_mi), not per latent code.
print(f"individual MIG for each ground-truth factor\n {normalized}")

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
# Heatmap of the mutual-information matrix: rows are latent dimensions,
# columns are ground-truth factors.
mi_data = mat_mi.cpu().numpy()

# NOTE(review): assumes the columns of v follow this order; confirm against the
# 3dshapes label layout.
factor_labels = ['Floor Hue', 'Wall Hue', 'Object Hue', 'Scale', 'Shape', 'Orientation']
latent_labels = ["$z_{" + str(k) + "}$" for k in range(mi_data.shape[0])]

fig, ax = plt.subplots(figsize=(10, 8))
ax.set_title("Mutual Information Matrix (Latents vs Factors)", fontsize=14)
sns.heatmap(mi_data,
            xticklabels=factor_labels,
            yticklabels=latent_labels,
            annot=True,
            fmt=".2f",
            cmap="viridis",
            ax=ax)
ax.set_xlabel("Ground Truth Factors")
ax.set_ylabel("Latent Codes")
plt.show()

In [None]:
import torch
import torch.nn.functional as F

def init_kernel(size, sigma) :
    """Normalised 2-D Gaussian window for SSIM.

    Args:
        size: side length of the square window (odd sizes centre the peak).
        sigma: standard deviation of the Gaussian, in pixels.

    Returns:
        float32 tensor of shape (1, 1, size, size) whose entries sum to 1
        (up to rounding).
    """
    x = torch.arange(size).float()
    center = size // 2
    gaussian_1d = torch.exp(- 0.5 * ((x - center) ** 2 / sigma ** 2))
    gaussian_1d = gaussian_1d / gaussian_1d.sum()
    # Separable Gaussian: the 2-D window is the outer product of the 1-D one.
    gaussian_2d = gaussian_1d.unsqueeze(1) @ gaussian_1d.unsqueeze(0)
    return gaussian_2d.view(1, 1, size, size)

def ssim(imgs1, imgs2, window_size=11, sigma=1.5, data_range=1.0) :
    """Mean structural similarity (SSIM) between two image batches.

    A Gaussian window (``window_size`` x ``window_size``, std ``sigma``) is
    applied to each channel separately, with zero padding at the borders. The
    stabilising constants are C1 = (0.01 * L)^2 and C2 = (0.03 * L)^2, where L is
    ``data_range`` (1.0 for images scaled to [0, 1]).

    Args:
        imgs1, imgs2: tensors of identical shape (B, C, H, W).
        window_size: side length of the Gaussian window (default 11).
        sigma: Gaussian standard deviation in pixels (default 1.5).
        data_range: dynamic range of the pixel values (default 1.0).

    Returns:
        0-dim tensor: the SSIM map averaged over batch, channels and pixels.

    Raises:
        ValueError: if the two inputs have different shapes.
    """
    # Elementwise ops below would otherwise broadcast mismatched batches silently.
    if imgs1.shape != imgs2.shape:
        raise ValueError(
            f"ssim expects inputs of identical shape, got "
            f"{tuple(imgs1.shape)} and {tuple(imgs2.shape)}")

    c1 = (0.01 * data_range) ** 2
    c2 = (0.03 * data_range) ** 2

    channel = imgs1.size(1)
    # Match the window to the inputs' device *and* dtype; a float32 window
    # makes conv2d fail on float16/float64 images.
    window = init_kernel(window_size, sigma).to(device=imgs1.device, dtype=imgs1.dtype)
    size = window.shape[2]
    # One copy of the window per channel; groups=channel filters channels independently.
    window = window.expand(channel, 1, size, size).contiguous()
    pad = size // 2

    def _blur(t):
        # Gaussian-weighted local mean of each channel.
        return F.conv2d(t, weight=window, padding=pad, groups=channel)

    mu1 = _blur(imgs1)
    mu2 = _blur(imgs2)

    mu1_sq = mu1 * mu1
    mu2_sq = mu2 * mu2
    mu1_mu2 = mu1 * mu2

    # Local (co)variances via E[XY] - E[X]E[Y].
    sigma1_sq = _blur(imgs1 * imgs1) - mu1_sq
    sigma2_sq = _blur(imgs2 * imgs2) - mu2_sq
    sigma12 = _blur(imgs1 * imgs2) - mu1_mu2

    numerator = (2 * mu1_mu2 + c1) * (2 * sigma12 + c2)
    denominator = (mu1_sq + mu2_sq + c1) * (sigma1_sq + sigma2_sq + c2)

    ssim_map = numerator / denominator

    return torch.mean(ssim_map)

In [None]:
# Reconstruct one batch from the loader and report the mean SSIM between the
# inputs and their reconstructions.
with torch.no_grad() :
    original_imgs, _labels = next(iter(dataset_loader))
    original_imgs = original_imgs.to(device)
    recon_batch = model(original_imgs)[0]

ssim_score = ssim(original_imgs, recon_batch)
print(f"The SSIM score of the model is : {ssim_score:.4f}")