In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import math

In [None]:
class Dsprite(Dataset):
  """Map-style dataset over an ``.npz`` archive of images and latent labels.

  The archive is expected to expose the keys ``'imgs'``,
  ``'latents_values'`` and ``'latents_classes'``. Each item is a pair
  ``(image, latent_values)`` where the image is a float32 tensor with a
  leading channel axis of size 1.
  """

  def __init__(self, path):
    """Load the archive found at ``path`` and keep its arrays in memory."""
    # NOTE(review): allow_pickle=True lets np.load unpickle arbitrary objects;
    # only point this at archives from a trusted source.
    archive = np.load(path, encoding='latin1', allow_pickle=True)
    self.imgs = archive['imgs']
    # Ground-truth factor values and their class indices, stored as float tensors.
    self.true_latent = torch.from_numpy(archive['latents_values']).float()
    self.latent_info = torch.from_numpy(archive['latents_classes']).float()

  def __len__(self):
    """Number of images in the archive."""
    return len(self.imgs)

  def __getitem__(self, idx):
    """Return ``(image, latent_values)`` for sample ``idx``."""
    # Convert the stored image to float32 and prepend a channel axis.
    pixels = torch.from_numpy(self.imgs[idx].astype(np.float32))
    return pixels.unsqueeze(0), self.true_latent[idx]


# Location of the .npz archive, relative to the notebook's working directory.
path = './dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz'

# Build the dataset and a shuffled loader that yields (image, latent_values) batches.
# NOTE(review): the batch size is hard-coded here; keep it in sync with the
# BATCH_SIZE constant defined in the configuration cell.
dataset = Dsprite(path=path)
dataset_loader = DataLoader(dataset, batch_size = 256, shuffle = True, num_workers = 2)

In [None]:
# Sanity check: print the number of batches in one pass over the loader,
# then the shape of the image tensor of the first batch drawn from it.
print(len(dataset_loader))
iter_data = iter(dataset_loader)
print(next(iter_data)[0].shape)

In [None]:
# Model Architecture

class BetaTCVAE(nn.Module):
  """Convolutional VAE with a 10-dimensional diagonal-Gaussian latent.

  The encoder stacks four stride-2 convolutions followed by two linear
  layers producing ``mu`` and ``logvar``; the decoder mirrors it with two
  linear layers and four stride-2 transposed convolutions. The flatten to
  ``32 * 4 * 4`` features means a 1x64x64 input maps to one latent row.
  """

  def __init__(self):
    super().__init__()

    # Every (transposed) convolution shares the same kernel geometry.
    geometry = dict(kernel_size=4, stride=2, padding=1)

    # encoder
    self.conv1 = nn.Conv2d(1, 32, **geometry)
    self.conv2 = nn.Conv2d(32, 32, **geometry)
    self.conv3 = nn.Conv2d(32, 32, **geometry)
    self.conv4 = nn.Conv2d(32, 32, **geometry)
    self.hl1 = nn.Linear(32 * 4 * 4, 256)
    self.hl2_mu = nn.Linear(256, 10)
    self.hl2_logvar = nn.Linear(256, 10)

    # decoder
    self.hl3 = nn.Linear(10, 256)
    self.hl4 = nn.Linear(256, 32 * 4 * 4)
    self.cT1 = nn.ConvTranspose2d(32, 32, **geometry)
    self.cT2 = nn.ConvTranspose2d(32, 32, **geometry)
    self.cT3 = nn.ConvTranspose2d(32, 32, **geometry)
    self.cT4 = nn.ConvTranspose2d(32, 1, **geometry)

  def encode(self, x):
    """Return ``(mu, logvar)`` of the approximate posterior for images ``x``."""
    feats = x
    for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
      feats = F.relu(conv(feats))
    hidden = F.relu(self.hl1(feats.view(-1, 32 * 4 * 4)))
    mu = self.hl2_mu(hidden)
    # Log-variance is limited to [-6, 6].
    logvar = torch.clamp(self.hl2_logvar(hidden), min=-6.0, max=6.0)
    return mu, logvar

  def reparameterize(self, mu, logvar):
    """Draw ``z = mu + eps * std`` with ``eps`` sampled from a standard normal."""
    sigma = torch.exp(0.5 * logvar)
    noise = torch.randn_like(sigma)
    return mu + noise * sigma

  def decode(self, z):
    """Map latent codes ``z`` to images with values clamped to [1e-6, 1.0]."""
    hidden = F.relu(self.hl4(F.relu(self.hl3(z))))
    maps = hidden.view(-1, 32, 4, 4)
    for deconv in (self.cT1, self.cT2, self.cT3):
      maps = F.relu(deconv(maps))
    pixels = torch.sigmoid(self.cT4(maps))
    # The lower bound keeps the reconstruction strictly positive.
    return torch.clamp(pixels, min=1e-6, max=1.0)

  def forward(self, x):
    """Encode, sample and decode; returns ``(recon_x, z, mu, logvar)``."""
    mu, logvar = self.encode(x)
    z = self.reparameterize(mu, logvar)
    return self.decode(z), z, mu, logvar

In [None]:
# Configuration, model and optimizer setup.
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

BATCH_SIZE = 256          # default batch size used by the weight / density-matrix helpers below
TRAINING_BATCHES = 500    # cap on batches per epoch; also the divisor for per-epoch averages
LR = 5e-4                 # learning rate passed to Adam
EPOCHS = 100              # number of passes of the outer training loop
GAMMA = 1                 # weight of the kld_loss term in loss_function
BETA = 6                  # weight of the tc_loss term in loss_function
ANNEAL_STEPS = 5000       # value of num_iters at which the kld_loss ramp in loss_function reaches 1.0

# The model is created here; it is moved to `device` in the training cell.
model = BetaTCVAE()
optimizer = optim.Adam(model.parameters(), lr=LR)

def log_density_gaussian(x, mu, logvar):
  """Element-wise log-density of ``x`` under a Gaussian N(mu, exp(logvar)).

  The three arguments are broadcast against each other; no reduction is
  applied, so the result has the broadcast shape.
  """
  log_normalizer = -0.5 * (logvar + math.log(2 * math.pi))
  inv_var = torch.exp(-logvar)
  sq_dev = (x - mu) ** 2
  return log_normalizer - 0.5 * (sq_dev * inv_var)

def BCE_loss_function(recon_x, x):
  """Binary cross-entropy between ``recon_x`` and target ``x``, summed over all elements."""
  return F.binary_cross_entropy(input=recon_x, target=x, reduction='sum')

# log q(z|x): log-density of each sampled code z under the Gaussian posterior
# described by its own (mu, logvar).
# Output shape: (batch,) -- one log-density per sample, summed over latent dimensions.
def log_q_zx(z, mu, logvar):
  per_dim_log_density = log_density_gaussian(z, mu, logvar)
  return torch.sum(per_dim_log_density, dim=1)

# log p(z): log-density of each latent code under a Gaussian with zero mean
# and zero log-variance (i.e. unit variance).
# Output shape: (batch,) -- one log-density per sample, summed over latent dimensions.
def log_p_z(z):
  prior_mu = torch.zeros_like(z)
  prior_logvar = torch.zeros_like(z)
  return torch.sum(log_density_gaussian(z, prior_mu, prior_logvar), dim=1)

# Builds the (batch_size, batch_size) weight matrix whose log is added to the
# pairwise log-density matrix from mat_log_q_z when estimating log q(z).
# Every entry is the same value, 1 / batch_size.
# NOTE(review): total_training_samples is accepted but never used here.
def initialize_stratified_weights(batch_size=BATCH_SIZE, total_training_samples=BATCH_SIZE*TRAINING_BATCHES):
  uniform_weight = 1.0 / batch_size
  return torch.full((batch_size, batch_size), uniform_weight)

# Pairwise log-density matrix: entry [i, j, d] is the log-density of latent
# code z_i in dimension d under the Gaussian given by (mu_j, logvar_j),
# i.e. as if z_i had been produced by the encoder for sample j.
# Output shape: (batch_size, batch_size, latent_dim).
def mat_log_q_z(z, mu, logvar, batch_size=BATCH_SIZE, latent_dim=10):
  # Put the samples on axis 0 and the posteriors on axis 1 so they broadcast pairwise.
  samples = z.view(batch_size, 1, latent_dim)
  post_mu = mu.view(1, batch_size, latent_dim)
  post_logvar = logvar.view(1, batch_size, latent_dim)
  return log_density_gaussian(samples, post_mu, post_logvar)

def loss_function(num_iters, recon_x, x, z, mu, logvar):
  """Compute the decomposed training loss for one batch.

  The loss is ``BCE + mi_loss + BETA * tc_loss + ratio * GAMMA * kld_loss``
  where ``ratio`` ramps linearly with ``num_iters`` and saturates at 1.0
  once ``num_iters`` reaches ``ANNEAL_STEPS``.

  Args:
    num_iters: counter driving the ramp on the kld term.
    recon_x: reconstruction, same shape as ``x``, values in [0, 1].
    x: target images.
    z: sampled latent codes, shape (batch, latent_dim).
    mu, logvar: posterior parameters, same shape as ``z``.

  Returns:
    Tuple ``(loss, BCE, mi_loss, tc_loss, kld_loss)`` of scalar tensors,
    each summed over the batch.
  """
  # Take the sizes and device from the tensors themselves rather than from the
  # global BATCH_SIZE / device, so a batch of any size (e.g. a smaller final
  # batch from the loader) can be processed.
  batch_size, latent_dim = z.shape

  mat_log_q_z_val = mat_log_q_z(z, mu, logvar, batch_size=batch_size, latent_dim=latent_dim)
  log_p_z_val = log_p_z(z)
  log_q_zx_val = log_q_zx(z, mu, logvar)

  weights = initialize_stratified_weights(batch_size=batch_size).to(z.device)
  log_weights = torch.log(weights + 1e-10)

  # log q(z_i): weighted log-sum over the batch of the joint (all-dimension) densities.
  log_joint_prob = mat_log_q_z_val.sum(2) + log_weights
  log_q_z = torch.logsumexp(log_joint_prob, dim=1, keepdim=False)

  # log prod_d q(z_i,d): the same estimate per latent dimension, then summed over dimensions.
  log_marg_prob = mat_log_q_z_val + log_weights.view(batch_size, batch_size, 1)
  log_prod_q_z = torch.logsumexp(log_marg_prob, dim=1, keepdim=False).sum(1)

  mi_loss = (log_q_zx_val - log_q_z).sum()
  tc_loss = (log_q_z - log_prod_q_z).sum()
  kld_loss = (log_prod_q_z - log_p_z_val).sum()
  BCE = BCE_loss_function(recon_x, x)

  # Linear warm-up of the kld term, capped at 1.0.
  gamma_ratio = min(1.0 * num_iters / ANNEAL_STEPS, 1.0)
  loss = BCE + mi_loss + BETA * tc_loss + gamma_ratio * GAMMA * kld_loss
  return loss, BCE, mi_loss, tc_loss, kld_loss

In [None]:
model.to(device)

model.train()

# Per-epoch history: average loss terms, plus each KL-type term's share of
# the sum (mi + tc + kld).
train_BCE_losses = []
train_mi_losses_frac = []
train_tc_losses_frac = []
train_kld_losses_frac = []
train_mi_losses = []
train_tc_losses = []
train_kld_losses = []
# Counts optimizer steps across all epochs; drives the kld ramp in loss_function.
num_iters = 0

print("Training started -")

for epoch in range(EPOCHS) :
  overall_BCE_loss = 0
  overall_mi_loss = 0
  overall_tc_loss = 0
  overall_kld_loss = 0
  batches_seen = 0

  for batch_idx, (data, _) in enumerate(dataset_loader) :
    # Advance once per batch (not once per epoch) so that ANNEAL_STEPS is
    # measured in training iterations and the ramp can actually reach 1.0.
    num_iters += 1

    data = data.to(device)
    optimizer.zero_grad()
    recon_batch, z, mu, logvar = model(data.view(-1, 1, 64, 64))
    loss, BCE_loss, mi_loss, tc_loss, kld_loss = loss_function(num_iters, recon_batch, data, z, mu, logvar)
    loss.backward()
    optimizer.step()
    overall_BCE_loss += BCE_loss.item()
    overall_mi_loss += mi_loss.item()
    overall_tc_loss += tc_loss.item()
    overall_kld_loss += kld_loss.item()
    batches_seen += 1

    # Stop after exactly TRAINING_BATCHES batches.
    if batches_seen >= TRAINING_BATCHES :
      break

  # Average over the batches actually processed, so the reported values stay
  # correct even when the loader yields fewer than TRAINING_BATCHES batches.
  batch_count = max(batches_seen, 1)
  average_BCE_loss = overall_BCE_loss / batch_count
  average_mi_loss = overall_mi_loss / batch_count
  average_tc_loss = overall_tc_loss / batch_count
  average_kld_loss = overall_kld_loss / batch_count
  average_total_ELBO = average_mi_loss + average_tc_loss + average_kld_loss
  train_BCE_losses.append(average_BCE_loss)
  train_mi_losses_frac.append(average_mi_loss/average_total_ELBO)
  train_tc_losses_frac.append(average_tc_loss/average_total_ELBO)
  train_kld_losses_frac.append(average_kld_loss/average_total_ELBO)
  train_kld_losses.append(average_kld_loss)
  train_tc_losses.append(average_tc_loss)
  train_mi_losses.append(average_mi_loss)
  print(f"epoch : {epoch + 1}")
  print(f"BCE loss : {average_BCE_loss:.2f}\tMI loss: {average_mi_loss:.2f}\tTC loss: {average_tc_loss:.2f}\tKLD loss: {average_kld_loss:.2f}")

In [None]:
# Two side-by-side panels built from the per-epoch training history.
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(12, 5))

# Left panel: average reconstruction (BCE) loss per epoch.
axes[0].plot(train_BCE_losses, label='BCE loss')
axes[0].set(xlabel="Epochs", ylabel="BCE loss", title="Reconstruction loss")
axes[0].grid(True, alpha=0.3)

# Right panel: the three terms of the decomposed KL part of the loss.
axes[1].plot(train_mi_losses, label='MI loss')
axes[1].plot(train_tc_losses, label='TC loss')
axes[1].plot(train_kld_losses, label='KLD loss')
axes[1].set(xlabel="Epochs", ylabel="loss", title="KLD losses")
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.show()


In [None]:
# Compare a handful of input images with their reconstructions.
model.eval()

# Fresh iterator over the loader; it is also handed to visualize_traversal later.
iterable = iter(dataset_loader)

with torch.no_grad() :
  x, x_labels = next(iterable)
  x = x.to(device)
  x_labels = x_labels.to(device)
  # forward() samples z through reparameterize, so recon_batch is decoded
  # from a sampled latent code rather than from mu.
  recon_batch, _, mu, logvar = model(x)
  # encode() returns the (mu, logvar) pair; recon_labels is not used further in this cell.
  recon_labels = model.encode(x)

# Row 0: originals, row 1: reconstructions, for the first 10 samples of the batch.
fig, axes = plt.subplots(2, 10, figsize=(10, 2))

for i in range(10) :
  axes[0, i].imshow(x[i].cpu().numpy().reshape(64, 64), cmap='gray')
  axes[0, i].axis('off')
  # Titles are attached to a single column (index 5) so they appear once per row.
  if i == 5 : axes[0, i].set_title("original images")
  axes[1, i].imshow(recon_batch[i].cpu().numpy().reshape(64, 64), cmap='gray')
  axes[1, i].axis('off')
  if i == 5 : axes[1, i].set_title("reconstructed images")

plt.show()

In [None]:
def visualize_traversal(model, iterable, dims=10, steps=12) :
  """Plot latent traversals for the first sample of the next batch.

  The first sample of the batch drawn from ``iterable`` is encoded, and its
  posterior mean is used as the base code. For each of the first ``dims``
  latent dimensions, that coordinate is swept over ``steps`` evenly spaced
  values in [-3, 3] while the others stay fixed, and each code is decoded.
  The result is a ``dims`` x ``steps`` grid of 64x64 images.

  Args:
    model: network exposing ``__call__`` returning ``(recon, z, mu, logvar)``
      and a ``decode`` method; it is switched to eval mode.
    iterable: iterator yielding ``(images, labels)`` batches; one batch is consumed.
    dims: number of latent dimensions (grid rows) to traverse.
    steps: number of values (grid columns) per dimension.
  """
  model.eval()
  data, _ = next(iterable)
  data = data.to(device)

  with torch.no_grad() :
    _, _, mu, _ = model(data)
    base = mu[0].clone().unsqueeze(0)

  path_range = torch.linspace(-3, 3, steps=steps).to(device)

  # squeeze=False keeps `axes` two-dimensional even when dims == 1 or
  # steps == 1, so the axes[dim, step] indexing below always works.
  fig, axes = plt.subplots(dims, steps, figsize=(10, 10), squeeze=False)
  plt.subplots_adjust(wspace = 0.05, hspace = 0.05)

  with torch.no_grad() :
    for dim in range(dims) :
      for step in range(steps) :
        # Overwrite a single coordinate of a copy of the base code.
        z_ = base.clone()
        z_[0, dim] = path_range[step]

        recon_img = model.decode(z_)

        ax = axes[dim, step]
        ax.imshow(recon_img[0].cpu().numpy().reshape(64, 64), cmap='gray')
        ax.axis('off')

  plt.suptitle("Latent Traversals", fontsize=16)
  plt.show()

# Draw the traversal grid using the next batch from the iterator created in the reconstruction cell.
visualize_traversal(model, iterable)