In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import h5py
import math

In [None]:
import os

# Remove any copy already present in the working directory so the download
# below always writes a fresh './3dshapes.h5'.
if os.path.exists('./3dshapes.h5') :
  os.remove('./3dshapes.h5')

# IPython shell escape: fetch the 3dshapes HDF5 file into the working directory.
!wget https://storage.googleapis.com/3d-shapes/3dshapes.h5

# Show the downloaded file with a human-readable size as a quick sanity check.
!ls -lh 3dshapes.h5

In [None]:
class Shapes3d(Dataset):
  """Image-only dataset backed by the arrays stored in a 3dshapes HDF5 file.

  The constructor reads the 'images' and 'labels' entries of the file in
  full and keeps them as attributes. Indexing yields a single image as a
  float tensor scaled by 1/255 with its last axis moved to the front;
  labels are stored but never returned.
  """

  def __init__(self, path):
    """Read both arrays from the HDF5 file at `path` into memory."""
    self.path = path
    print("dataset to RAM")
    with h5py.File(path, 'r') as h5_file:
      # Indexing an h5py dataset with [()] reads its whole content.
      self.images = h5_file['images'][()]
      self.labels = h5_file['labels'][()]

  def __len__(self):
    """Number of images held in memory."""
    return len(self.images)

  def __getitem__(self, idx):
    """Return image `idx` as a float tensor in channel-first layout."""
    scaled = torch.from_numpy(self.images[idx]).float() / 255.0
    # Stored layout has channels last; move them to the front.
    return scaled.permute(2, 0, 1)

# Build the in-memory dataset once, then wrap it in a shuffling loader that
# serves batches of 128 images using two worker processes.
dataset = Shapes3d(path='./3dshapes.h5')
dataset_loader = DataLoader(
  dataset=dataset,
  batch_size=128,
  shuffle=True,
  num_workers=2,
)

In [None]:
# Model Architecture

class BetaTCVAE(nn.Module):
  """Convolutional variational autoencoder for (3, 64, 64) images.

  The encoder applies four stride-2 convolutions followed by a hidden linear
  layer and two linear heads that produce the mean and the log-variance of a
  diagonal Gaussian over the latent code. The decoder mirrors it with two
  linear layers and four stride-2 transposed convolutions ending in a sigmoid.

  Args:
    latent_dim: size of the latent code. Defaults to 12, the value that was
      previously hard-coded, so existing callers and saved weights for the
      default configuration are unaffected.
  """

  def __init__(self, latent_dim=12):
    super().__init__()
    self.latent_dim = latent_dim

    # encoder
    # (3, 64, 64)
    self.conv1 = nn.Conv2d(3, 32, kernel_size=4, stride=2, padding=1)
    # (32, 32, 32)
    self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1)
    # (64, 16, 16)
    self.conv3 = nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1)
    # (128, 8, 8)
    self.conv4 = nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1)
    # (256, 4, 4)

    self.hl1 = nn.Linear(256 * 4 * 4, 256)
    self.hl2_mu = nn.Linear(256, latent_dim) # mean
    self.hl2_logvar = nn.Linear(256, latent_dim) # log(variance)

    # decoder
    self.hl3 = nn.Linear(latent_dim, 256)
    self.hl4 = nn.Linear(256, 256 * 4 * 4)

    self.convT1 = nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1)
    self.convT2 = nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1)
    self.convT3 = nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1)
    self.convT4 = nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1)

  def encode(self, x):
    """Map a batch of images to (mu, logvar), each of shape (batch, latent_dim).

    logvar is clamped to [-6, 6] before it is returned.
    """
    c1 = F.leaky_relu(self.conv1(x))
    c2 = F.leaky_relu(self.conv2(c1))
    c3 = F.leaky_relu(self.conv3(c2))
    c4 = F.leaky_relu(self.conv4(c3))
    # Keep the batch axis fixed while flattening. With view(-1, 256 * 4 * 4) an
    # input of the wrong spatial size could be folded into extra batch rows
    # instead of failing; this form makes hl1 reject it with a shape error.
    h1 = F.leaky_relu(self.hl1(c4.view(c4.size(0), -1)))
    mu = self.hl2_mu(h1)
    logvar = self.hl2_logvar(h1)
    logvar = torch.clamp(logvar, min=-6.0, max=6.0)
    return mu, logvar

  def reparameterize(self, mu, logvar):
    """Draw z = mu + eps * std with eps ~ N(0, I) and std = exp(0.5 * logvar)."""
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return mu + eps * std

  def decode(self, z):
    """Map latent codes of shape (batch, latent_dim) to images of shape (batch, 3, 64, 64).

    The sigmoid output is clamped to [1e-6, 1.0], so returned values are
    strictly positive.
    """
    h3 = F.leaky_relu(self.hl3(z))
    h4 = F.leaky_relu(self.hl4(h3))
    cT1 = F.leaky_relu(self.convT1(h4.view(-1, 256, 4, 4)))
    cT2 = F.leaky_relu(self.convT2(cT1))
    cT3 = F.leaky_relu(self.convT3(cT2))
    cT4 = torch.sigmoid(self.convT4(cT3))
    img = torch.clamp(cT4, min=1e-6, max=1.0)
    return img

  def forward(self, x):
    """Encode, sample and decode; returns (recon_x, z, mu, logvar)."""
    mu, logvar = self.encode(x)
    z = self.reparameterize(mu, logvar)
    recon_x = self.decode(z)
    return recon_x, z, mu, logvar

In [None]:
# initialization of model
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# --- hyperparameters ---
BATCH_SIZE = 128          # default batch size assumed by the loss helpers
TRAINING_BATCHES = 500    # batch budget per epoch used by the training loop
EPOCHS = 50
LR = 5e-4                 # Adam learning rate
BETA = 6.0                # weight on the total-correlation term
GAMMA = 1.0               # weight on the dimension-wise KL term
ANNEAL_STEPS = 5000       # iterations over which the GAMMA term is ramped up

# --- model and optimizer ---
model = BetaTCVAE()
optimizer = optim.Adam(params=model.parameters(), lr=LR)

def log_density_gaussian(z, mu, logvar):
  """Element-wise log-density of z under a Gaussian with mean mu and variance exp(logvar).

  The three arguments are combined with ordinary broadcasting; nothing is
  summed, so the result has the broadcast shape of the inputs.
  """
  log_normalizer = -0.5 * (math.log(2 * math.pi) + logvar)
  squared_error = (z - mu) ** 2
  precision = torch.exp(-logvar)
  return log_normalizer - 0.5 * (squared_error * precision)

def BCE_loss_function(recon_x, x):
  """Binary cross-entropy between recon_x and x, summed over every element."""
  return F.binary_cross_entropy(input=recon_x, target=x, reduction='sum')

def log_q_zx(z, mu, logvar):
  """Log-density of each z under its Gaussian (mu, logvar), summed over dim 1."""
  per_dim_log_density = log_density_gaussian(z, mu, logvar)
  return torch.sum(per_dim_log_density, dim=1)

def log_p_z(z):
  """Log-density of each z under a zero-mean, unit-variance Gaussian, summed over dim 1."""
  # The same all-zero tensor serves as the mean and as the log-variance
  # (log-variance 0 means variance 1).
  zeros = torch.zeros_like(z)
  per_dim_log_density = log_density_gaussian(z, zeros, zeros)
  return torch.sum(per_dim_log_density, dim=1)

def initialize_weights(batch_size=BATCH_SIZE, dataset_size = len(dataset)):
  """Build the (batch_size, batch_size) weight matrix used by the loss.

  Every off-diagonal entry is
  (dataset_size - batch_size + 1) / (dataset_size * (batch_size - 1))
  and every diagonal entry is 1 / dataset_size. The defaults are bound when
  the function is defined.
  """
  off_diagonal = (dataset_size - batch_size + 1) / (dataset_size * (batch_size - 1))
  weights = torch.full((batch_size, batch_size), off_diagonal)
  weights.fill_diagonal_(1.0 / dataset_size)
  return weights

def mat_log_q_z(batch_z, batch_mu, batch_logvar, batch_size=BATCH_SIZE, latent_dim=12):
  """Pairwise per-dimension log-densities of every code under every encoder Gaussian.

  Returns a (batch_size, batch_size, latent_dim) tensor whose entry [i, j, d]
  is the log-density of z_j[d] under the Gaussian given by mu_i[d] and
  logvar_i[d].
  """
  # Codes vary along axis 1, encoder parameters along axis 0; broadcasting
  # inside log_density_gaussian fills in the full grid.
  codes = batch_z.view(1, batch_size, latent_dim)
  param_shape = (batch_size, 1, latent_dim)
  return log_density_gaussian(codes,
                              batch_mu.view(param_shape),
                              batch_logvar.view(param_shape))

def loss_function(num_iters, recon_batch, batch, batch_z, batch_mu, batch_logvar):
  """Beta-TCVAE objective with a weighted minibatch estimate of q(z).

  The KL part of the ELBO is split into three terms, each summed over the batch:
    mi_loss   = sum_j [ log q(z_j | x_j) - log q(z_j) ]
    tc_loss   = sum_j [ log q(z_j) - log prod_d q(z_j[d]) ]
    dwkl_loss = sum_j [ log prod_d q(z_j[d]) - log p(z_j) ]

  Args:
    num_iters: number of optimisation steps taken so far; drives the linear
      warm-up of the dimension-wise KL weight over ANNEAL_STEPS steps.
    recon_batch: decoder output for the batch.
    batch: the input images the reconstruction is compared against.
    batch_z: sampled latent codes, shape (batch, latent_dim).
    batch_mu, batch_logvar: encoder outputs, same shape as batch_z.

  Returns:
    (loss, BCE, mi_loss, tc_loss, dwkl_loss) where
    loss = BCE + mi_loss + BETA * tc_loss + gamma_ratio * GAMMA * dwkl_loss.
  """
  # Take sizes from the tensors so a batch that is not exactly BATCH_SIZE
  # (or a model with a different latent size) does not break the reshapes.
  batch_size, latent_dim = batch_z.shape

  # pairwise shape : [batch, batch, latent]
  # pairwise[i, j, d] = log q(z_j[d] | x_i)
  pairwise = mat_log_q_z(batch_z, batch_mu, batch_logvar,
                         batch_size=batch_size, latent_dim=latent_dim)
  log_p_z_val = log_p_z(batch_z)
  log_q_zx_val = log_q_zx(batch_z, batch_mu, batch_logvar)
  BCE = BCE_loss_function(recon_batch, batch)

  # Follow the latent codes' device instead of relying on a global.
  weights = initialize_weights(batch_size=batch_size).to(batch_z.device)
  log_weights = torch.log(weights + 1e-10)

  log_joint_prob = pairwise.sum(2) + log_weights
  # log_joint_prob shape : [batch, batch]
  # [i, j] is the weighted log-probability that x_i generated latent code z_j
  #
  # The density of code z_j under the aggregate posterior is a sum over the
  # data points x_i, i.e. over axis 0. Reducing over axis 1 (as was done
  # before) sums over the latent codes instead and does not estimate q(z_j).
  # The weight matrix is symmetric, so it needs no transpose.
  log_q_z = torch.logsumexp(log_joint_prob, dim=0, keepdim=False)
  # log_q_z shape : [batch]; entry j estimates log q(z_j)

  log_marg_prob = pairwise + log_weights.view(batch_size, batch_size, 1)
  # log_marg_prob shape : [batch, batch, latent]
  # weighted log-probability of each component of z_j under each x_i's distribution
  log_prod_q_z = torch.logsumexp(log_marg_prob, dim=0, keepdim=False).sum(1)
  # log_prod_q_z shape : [batch]
  # entry j is the log of the product of the marginal densities of the components of z_j

  mi_loss = (log_q_zx_val - log_q_z).sum()
  tc_loss = (log_q_z - log_prod_q_z).sum()
  dwkl_loss = (log_prod_q_z - log_p_z_val).sum()

  # Linear warm-up from 0 to 1 over the first ANNEAL_STEPS iterations.
  gamma_ratio = min(1.0 * num_iters/ANNEAL_STEPS, 1.0)

  loss = BCE + mi_loss + BETA * tc_loss + gamma_ratio * GAMMA * dwkl_loss
  return loss, BCE, mi_loss, tc_loss, dwkl_loss

In [None]:
# Training loop: every epoch runs at most TRAINING_BATCHES optimisation steps
# and records the per-batch average of each loss term for later plotting.
model.to(device)
model.train()

train_bce_losses = []
train_mi_losses = []
train_tc_losses = []
train_dwkl_losses = []

# Global step counter; loss_function uses it to anneal the DWKL weight.
num_iters = 0

print(f"using device {device}")
print("Training started -")

for epoch in range(EPOCHS):
  overall_BCE_loss = 0
  overall_tc_loss = 0
  overall_mi_loss = 0
  overall_dwkl_loss = 0
  batches_seen = 0

  for batch_count, data in enumerate(dataset_loader, start=1):
    num_iters += 1
    data = data.to(device)
    optimizer.zero_grad()
    recon_batch, batch_z, batch_mu, batch_logvar = model(data)
    loss, bce_loss, mi_loss, tc_loss, dwkl_loss = loss_function(num_iters, recon_batch, data, batch_z, batch_mu, batch_logvar)
    loss.backward()
    optimizer.step()
    overall_BCE_loss += bce_loss.item()
    overall_tc_loss += tc_loss.item()
    overall_mi_loss += mi_loss.item()
    overall_dwkl_loss += dwkl_loss.item()

    batches_seen = batch_count
    # Stop after exactly TRAINING_BATCHES batches. The previous test
    # (batch_idx > TRAINING_BATCHES on a 0-based index) let two extra batches
    # through while the averages below still divided by TRAINING_BATCHES.
    if batch_count >= TRAINING_BATCHES:
      break

  # Average over the batches actually processed (the loader may hold fewer
  # than TRAINING_BATCHES); the max() guards against an empty loader.
  num_batches = max(batches_seen, 1)
  average_bce_loss = overall_BCE_loss / num_batches
  average_tc_loss = overall_tc_loss / num_batches
  average_mi_loss = overall_mi_loss / num_batches
  average_dwkl_loss = overall_dwkl_loss / num_batches
  train_bce_losses.append(average_bce_loss)
  train_tc_losses.append(average_tc_loss)
  train_mi_losses.append(average_mi_loss)
  train_dwkl_losses.append(average_dwkl_loss)
  print(f"epoch : {epoch}")
  print(f"BCE loss : {average_bce_loss:.2f}\tMI loss : {average_mi_loss:.2f}")
  print(f"TC loss (per image) : {average_tc_loss/BATCH_SIZE:.2f}\tDWKL loss (per image) : {average_dwkl_loss/BATCH_SIZE:.2f}")
  print("------------------------------------------------------------------")

print("Training finished")
save_path = "./model_weights_tcvae.pth"
torch.save(model.state_dict(), save_path)
print(f"Model weights saved to {save_path}")

In [None]:
# Per-epoch loss curves: DWKL on the left panel, TC and MI together on the right.
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
dwkl_ax, tc_mi_ax = axes

dwkl_ax.plot(train_dwkl_losses, label='DWKL loss')
dwkl_ax.set(xlabel="Epochs", ylabel="DWKL loss", title="Dimension wise KL loss")
dwkl_ax.grid(True, alpha=0.3)

for loss_history, curve_label in ((train_tc_losses, 'TC loss'), (train_mi_losses, 'MI loss')):
  tc_mi_ax.plot(loss_history, label=curve_label)
tc_mi_ax.set(xlabel="Epochs", ylabel="loss", title="TC and MI losses")
tc_mi_ax.legend()
tc_mi_ax.grid(True, alpha=0.3)
plt.show()

# Reconstruction (BCE) loss is drawn separately on pyplot's current axes.
bce_ax = plt.gca()
bce_ax.plot(train_bce_losses)
bce_ax.set(xlabel="Epochs", ylabel="BCE loss", title="BCE loss")
bce_ax.grid(True, alpha=0.3)
plt.show()

In [None]:
# Compare the first ten images of one batch (top row) with their
# reconstructions (bottom row).
model.eval()

# Kept as a notebook-level name: the traversal cell further down pulls its
# batch from this same iterator.
iterable = iter(dataset_loader)

with torch.no_grad():
  x = next(iterable).to(device)
  recon_batch, _, mu, logvar = model(x)

fig, axes = plt.subplots(2, 10, figsize=(15, 4))

panel_rows = ((x, "original images"), (recon_batch, "reconstructed images"))
for row_idx, (images, row_title) in enumerate(panel_rows):
  for i in range(10):
    ax = axes[row_idx, i]
    # Channel-first tensor -> channel-last array, the layout imshow expects.
    ax.imshow(images[i].cpu().permute(1, 2, 0).squeeze().numpy())
    ax.axis('off')
    if i == 5:
      ax.set_title(row_title)

plt.show()

In [None]:
def visualize_traversal(model, iterable, dims=12, steps=12):
  """Plot latent traversals around the encoded mean of one image.

  The first image of the next batch from `iterable` is encoded; its mean
  vector is the base code. For each of the first `dims` latent dimensions,
  that one coordinate is swept over `steps` evenly spaced values in [-3, 3]
  while the others stay at the base value, and the decoded images are shown
  as one row of a dims x steps grid.

  Args:
    model: object exposing encode(x) -> (mu, logvar) and decode(z).
    iterable: iterator yielding image batches; one batch is consumed.
    dims: number of latent dimensions to traverse. Must not exceed the
      model's latent size, otherwise indexing the code raises IndexError.
    steps: number of values per traversal.
  """
  model.eval()
  # Use whichever device the model lives on rather than a notebook global.
  model_device = next(model.parameters()).device
  data = next(iterable).to(model_device)

  with torch.no_grad():
    # Only the mean of the first image is needed, so encode just that image
    # instead of running a full forward pass over the whole batch.
    mu, _ = model.encode(data[:1])
    base = mu[0].clone().unsqueeze(0)

  path_range = torch.linspace(-3, 3, steps=steps).to(model_device)

  # squeeze=False keeps `axes` two-dimensional even when dims or steps is 1,
  # so axes[dim, step] is always valid.
  fig, axes = plt.subplots(dims, steps, figsize=(steps, dims), squeeze=False)
  plt.subplots_adjust(wspace = 0.05, hspace = 0.05)

  for dim in range(dims):
    # One row of codes: copies of the base with coordinate `dim` swept, then
    # decoded in a single call.
    z_row = base.repeat(steps, 1)
    z_row[:, dim] = path_range

    with torch.no_grad():
      recon_row = model.decode(z_row).cpu()

    for step in range(steps):
      ax = axes[dim, step]
      img = recon_row[step].permute(1, 2, 0).squeeze().numpy()
      ax.imshow(img)
      ax.axis('off')

  plt.suptitle("Latent Traversals", fontsize=16)
  plt.show()

# `iterable` is the DataLoader iterator created in the reconstruction cell, so
# the traversal is built from the batch that follows the one shown there.
visualize_traversal(model, iterable)