In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import h5py
import matplotlib.pyplot as plt
import math
import random

In [None]:
import os

if os.path.exists('./3dshapes.h5') :
  print("file already exists")
else :
    !wget https://storage.googleapis.com/3d-shapes/3dshapes.h5

!ls -lh 3dshapes.h5

In [None]:
class Shapes3d(Dataset) :
  """In-memory dataset backed by an HDF5 file holding `images` and `labels`.

  Both HDF5 datasets are read in full at construction time, so indexing
  afterwards never touches the file again.
  """

  def __init__(self, path) :
    """Read the `images` and `labels` datasets stored at `path` into memory."""
    self.path = path
    print("dataset to RAM")
    with h5py.File(path, 'r') as archive :
      # `[()]` materialises the complete HDF5 dataset as an in-memory array.
      self.images, self.labels = archive['images'][()], archive['labels'][()]

  def __len__(self) :
    """Return the number of stored images."""
    return len(self.images)

  def __getitem__(self, idx) :
    """Return `(image, label)` for position `idx`.

    The image is converted to a float tensor, divided by 255 and has its
    last axis moved to the front; the label is returned as stored.
    """
    pixels = torch.from_numpy(self.images[idx]).float() / 255.0
    return pixels.permute(2, 0, 1), self.labels[idx]

# Load the archive into memory once, then wrap it in a shuffling loader that
# serves mini-batches of 128 samples using two worker processes.
dataset = Shapes3d(path='./3dshapes.h5')
dataset_loader = DataLoader(
  dataset,
  batch_size=128,
  shuffle=True,
  num_workers=2,
)

In [None]:
# Model Architecture

class BetaTCVAE(nn.Module) :
  """Convolutional VAE with a 12-dimensional latent code.

  The encoder is four stride-2 convolutions followed by two linear layers
  producing a mean and a (clamped) log-variance; the decoder mirrors it with
  two linear layers and four stride-2 transposed convolutions.
  """

  def __init__(self) :
    super().__init__()

    # ---- encoder: each conv halves the spatial size ----
    # shape notes assume a (3, 64, 64) input
    self.conv1 = nn.Conv2d(3, 32, kernel_size=4, stride=2, padding=1)     # -> (32, 32, 32)
    self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1)    # -> (64, 16, 16)
    self.conv3 = nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1)   # -> (128, 8, 8)
    self.conv4 = nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1)  # -> (256, 4, 4)

    self.hl1 = nn.Linear(256 * 4 * 4, 256)
    self.hl2_mu = nn.Linear(256, 12)      # latent mean
    self.hl2_logvar = nn.Linear(256, 12)  # latent log-variance

    # ---- decoder: mirror image of the encoder ----
    self.hl3 = nn.Linear(12, 256)
    self.hl4 = nn.Linear(256, 256 * 4 * 4)

    self.convT1 = nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1)
    self.convT2 = nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1)
    self.convT3 = nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1)
    self.convT4 = nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1)

  def encode(self, x) :
    """Map images `x` to `(mu, logvar)`; `logvar` is clamped to [-6, 6]."""
    hidden = x
    for conv in (self.conv1, self.conv2, self.conv3, self.conv4) :
      hidden = F.leaky_relu(conv(hidden))
    # Flatten the conv features to rows of 256 * 4 * 4 values for the dense layer.
    hidden = F.leaky_relu(self.hl1(hidden.view(-1, 256 * 4 * 4)))
    mean = self.hl2_mu(hidden)
    log_variance = torch.clamp(self.hl2_logvar(hidden), min=-6.0, max=6.0)
    return mean, log_variance

  def reparameterize(self, mu, logvar) :
    """Draw `mu + eps * std` with `eps` standard normal and `std = exp(logvar / 2)`."""
    sigma = torch.exp(0.5 * logvar)
    noise = torch.randn_like(sigma)
    return mu + noise * sigma

  def decode(self, z) :
    """Map latent codes `z` to images with values clamped to [1e-6, 1]."""
    hidden = F.leaky_relu(self.hl3(z))
    hidden = F.leaky_relu(self.hl4(hidden)).view(-1, 256, 4, 4)
    for deconv in (self.convT1, self.convT2, self.convT3) :
      hidden = F.leaky_relu(deconv(hidden))
    pixels = torch.sigmoid(self.convT4(hidden))
    return torch.clamp(pixels, min=1e-6, max=1.0)

  def forward(self, x) :
    """Encode, sample and decode; returns `(recon_x, z, mu, logvar)`."""
    mu, logvar = self.encode(x)
    z = self.reparameterize(mu, logvar)
    return self.decode(z), z, mu, logvar

In [None]:
# Use the GPU when one is visible, otherwise stay on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Rebuild the network and restore the saved weights, mapping tensors onto `device`.
# NOTE(review): depending on the installed torch version's `weights_only` default,
# torch.load may unpickle arbitrary objects - only load checkpoints you trust.
model = BetaTCVAE()
checkpoint = torch.load(
    "/kaggle/input/betatcvae-shapes3d/final_model.pth",
    map_location=device,
)
model.load_state_dict(checkpoint)

# Switch to evaluation mode and move the module to `device`
# (both calls return the module, so the cell still displays it).
model.eval().to(device)

In [None]:
# batch curation

def discretize(dataset) :
    """Convert every label column into integer ids.

    Each column of the label matrix is replaced by the index of its value in
    the sorted list of that column's unique values (``np.unique`` with
    ``return_inverse=True``), so equal raw values share one id.

    Parameters
    ----------
    dataset : indexable of ``(image, label)`` pairs
        If the object exposes a ``labels`` attribute with one entry per
        sample, it is used directly; otherwise every sample is fetched and its
        second element is taken as the label row.

    Returns
    -------
    np.ndarray
        Integer array with the same 2-D shape as the label matrix.
    """
    # Fast path: read the label matrix directly instead of pulling every
    # sample through __getitem__ (which would also build each image tensor).
    labels = getattr(dataset, "labels", None)
    if labels is not None and len(labels) == len(dataset) :
        all_floats = np.asarray(labels)
    else :
        all_floats = np.array([sample[1] for sample in dataset])

    all_uids = np.zeros_like(all_floats, dtype=int)

    # Handle however many label columns exist rather than assuming exactly six.
    for i in range(all_floats.shape[1]) :
        _, all_uids[:, i] = np.unique(all_floats[:, i], return_inverse=True)

    return all_uids

class CuratedSampler :
    """Draw image batches made of "normal" and "outlier" samples.

    Criteria are dicts mapping a label column index to the discretized id
    (see ``discretize``) that column must have.

    * A sample is *normal* when it satisfies every entry of
      ``normal_criteria`` (an empty dict makes every sample normal).
    * A sample is an *outlier* when it is not normal and satisfies every entry
      of ``outlier_criteria`` (an empty dict makes every non-normal sample an
      outlier).

    Attributes
    ----------
    discrete : np.ndarray    -- discretized label matrix of the dataset
    dataset                  -- the wrapped dataset
    normal_idx : list[int]   -- ascending indices of normal samples
    outlier_idx : list[int]  -- ascending indices of outlier samples
    """

    def __init__(self, dataset, normal_criteria, outlier_criteria = None) :
        # None instead of a shared mutable `{}` default.
        if outlier_criteria is None :
            outlier_criteria = {}

        self.discrete = discretize(dataset)
        self.dataset = dataset

        # Boolean masks over all samples replace a per-sample Python loop.
        normal_mask = np.ones(len(self.discrete), dtype=bool)
        for label_idx, normal_value in normal_criteria.items() :
            normal_mask &= self.discrete[:, label_idx] == normal_value

        # Only samples that are not normal can be outliers.
        outlier_mask = ~normal_mask
        for label_idx, outlier_value in outlier_criteria.items() :
            outlier_mask &= self.discrete[:, label_idx] == outlier_value

        self.normal_idx = np.flatnonzero(normal_mask).tolist()
        self.outlier_idx = np.flatnonzero(outlier_mask).tolist()

    def get_batch(self, num_normal, device, num_outlier=0) :
        """Sample images with replacement and stack them into one tensor.

        Parameters
        ----------
        num_normal : int   -- number of normal samples to draw
        device             -- device the stacked batch is moved to
        num_outlier : int  -- number of outlier samples to draw (default 0)

        Returns
        -------
        torch.Tensor
            The normal images followed by the outlier images, on ``device``.

        Raises
        ------
        IndexError
            If samples are requested from a pool that is empty.
        """
        # random.choices on an empty pool fails with an unhelpful
        # "list index out of range"; raise the same type with a clear message.
        if num_normal > 0 and not self.normal_idx :
            raise IndexError("cannot draw normal samples: no sample matches normal_criteria")
        if num_outlier > 0 and not self.outlier_idx :
            raise IndexError("cannot draw outlier samples: no sample matches outlier_criteria")

        normal_idxes = random.choices(self.normal_idx, k = num_normal)
        outlier_idxes = random.choices(self.outlier_idx, k = num_outlier)

        # Element 0 of each dataset item is the image tensor.
        batch = [self.dataset[idx][0] for idx in normal_idxes + outlier_idxes]

        return torch.stack(batch).to(device)

In [None]:
# Empty criteria mark every sample as "normal", so this draws 100 images
# (with replacement) from the whole dataset.
curator = CuratedSampler(dataset, normal_criteria={})
batch = curator.get_batch(num_normal=100, device=device)

In [None]:
def shape_traversal(batch, dim1, dim2, rows = 21, cols = 21, base_idx = 25) :
    """Plot a 2-D latent traversal grid decoded by the global ``model``.

    One sample of ``batch`` is encoded and its latent mean is used as the base
    code. Latent ``dim1`` is then swept over [-2, 2] down the rows (Y-axis)
    and latent ``dim2`` over [-2, 2] across the columns (X-axis); every
    combination is decoded and shown in its own subplot.

    Parameters
    ----------
    batch : torch.Tensor  -- images fed to ``model``
    dim1 : int            -- latent index varied along the rows
    dim2 : int            -- latent index varied along the columns
    rows, cols : int      -- grid size (number of steps per axis)
    base_idx : int        -- which sample of ``batch`` provides the base code

    Raises
    ------
    IndexError
        If ``base_idx`` does not address a sample of ``batch``.
    """
    with torch.no_grad() :
        _, _, mu, _ = model(batch)

    if not -len(mu) <= base_idx < len(mu) :
        raise IndexError(f"base_idx {base_idx} is out of range for a batch of {len(mu)} samples")
    base = mu[base_idx].clone().unsqueeze(0)

    # Separate ranges per axis: reusing the `rows`-sized range for the columns
    # breaks (or mis-scales) whenever rows != cols.
    row_values = torch.linspace(-2, 2, steps=rows).to(base.device)
    col_values = torch.linspace(-2, 2, steps=cols).to(base.device)

    # figsize is (width, height) -> (cols, rows); squeeze=False keeps `axes`
    # two-dimensional even for a single row or column.
    fig, axes = plt.subplots(rows, cols, figsize=(cols, rows), squeeze=False)
    plt.subplots_adjust(wspace=0.05, hspace=0.05)

    with torch.no_grad() :
        for row in range(rows) :
            z_ = base.clone()
            z_[0, dim1] = row_values[row]
            for col in range(cols) :
                z_[0, dim2] = col_values[col]
                img = model.decode(z_)

                ax = axes[row, col]
                # Move the leading channel axis last so imshow gets (H, W, C).
                ax.imshow(img[0].cpu().permute(1, 2, 0).squeeze().numpy())
                ax.axis('off')

    # dim2 changes along the columns (X) and dim1 along the rows (Y).
    fig.suptitle(f"2D matrix of changes in latent code index {dim1} and {dim2} \nX-axis is latent code {dim2} and Y-axis is latent code {dim1}")
    plt.show()
    
# Traverse latent dimensions 5 (rows) and 10 (columns) on the default 21 x 21 grid.
shape_traversal(batch, dim1=5, dim2=10)