In [1]:
# data importing
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [2]:

class Dsprite(Dataset) :
    """PyTorch dataset over the dSprites ``.npz`` archive.

    Each item is ``(image, latent_values)``: the image as a float32 tensor of
    shape ``(1, H, W)`` (64x64 for this archive) and that image's row of
    ground-truth latent values ``[color, shape, scale, orientation, posX, posY]``
    as a float32 tensor.

    Args:
        path: filesystem path to the dSprites ``.npz`` file.
    """

    def __init__(self, path) :
        # NOTE(review): allow_pickle=True lets np.load unpickle object arrays,
        # which can execute arbitrary code -- only point this at a trusted archive.
        # The `with` block closes the underlying .npz file handle once the
        # arrays below have been read into memory.
        with np.load(path, encoding='latin1', allow_pickle=True) as data :
            # all images in one array of shape (num_images, 64, 64);
            # 737,280 images in the full archive
            self.imgs = data['imgs']
            # ground-truth factor values per image:
            # [color, shape, scale, orientation, posX, posY]
            self.intended_latent = torch.from_numpy(data['latents_values']).float()
            # the same factors as integer class indices: each factor takes a fixed
            # set of values and every image is one combination of them, so this
            # records which value index of each factor an image uses
            self.classified_latent = torch.from_numpy(data['latents_classes']).long()

    def __len__(self) :
        return len(self.imgs)

    def __getitem__(self, idx) :
        # cast to float32 and add a leading channel axis -> (1, H, W)
        img = torch.from_numpy(self.imgs[idx].astype(np.float32)).unsqueeze(0)
        return img, self.intended_latent[idx]

# dSprites archive, resolved relative to the notebook's working directory
path = "./data/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz"
dataset = Dsprite(path)

In [3]:
# data loading
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Reshuffle on every pass and serve the images in batches of 128.
dataset_loader = DataLoader(dataset, shuffle=True, batch_size=128)

# Number of batches in one full pass over the data (5760 for the full archive).
print(len(dataset_loader))

# Pull a single batch to check the tensor layout:
# (batch=128, channels=1, height=64, width=64).
iter_data = iter(dataset_loader)
first_batch_images = next(iter_data)[0]
print(first_batch_images.shape)

5760
torch.Size([128, 1, 64, 64])


In [4]:
# Model architecture
class BetaVAE(nn.Module) :
  """Fully connected beta-VAE operating on flattened images.

  Encoder: input_dim -> hidden_dim1 -> hidden_dim2 -> hidden_dim3 -> (mu, logvar),
  each head of size latent_dim. The decoder mirrors the encoder back to
  input_dim and applies a sigmoid, so reconstructions lie in [0, 1].

  Args:
    input_dim: number of values per flattened input (4096 = 64 * 64 by default).
    hidden_dim1, hidden_dim2, hidden_dim3: widths of the hidden layers.
    latent_dim: size of the latent code.
  """

  def __init__(self,
               input_dim=4096,
               hidden_dim1 = 1024,
               hidden_dim2 = 256,
               hidden_dim3 = 64,
               latent_dim = 10
        ) :
    super(BetaVAE, self).__init__()
    # remembered so forward() flattens inputs to the configured size
    # instead of a hard-coded 4096
    self.input_dim = input_dim

    # encoder
    self.hl1 = nn.Linear(input_dim, hidden_dim1)
    self.hl2 = nn.Linear(hidden_dim1, hidden_dim2)
    self.hl3 = nn.Linear(hidden_dim2, hidden_dim3)
    self.hl41_mu = nn.Linear(hidden_dim3, latent_dim) # mean
    self.hl42_logvar = nn.Linear(hidden_dim3, latent_dim) # log(variance)

    # decoder
    self.hl5 = nn.Linear(latent_dim, hidden_dim3)
    self.hl6 = nn.Linear(hidden_dim3, hidden_dim2)
    self.hl7 = nn.Linear(hidden_dim2, hidden_dim1)
    self.hl8 = nn.Linear(hidden_dim1, input_dim) # decoded image

  def encode(self, x) :
    """Map flattened inputs of shape (batch, input_dim) to (mu, logvar), each (batch, latent_dim)."""
    h1 = F.relu(self.hl1(x))
    h2 = F.relu(self.hl2(h1))
    h3 = F.relu(self.hl3(h2))
    mu = self.hl41_mu(h3)
    logvar = self.hl42_logvar(h3)
    return mu, logvar

  def reparameterize(self, mu, logvar) :
    """Sample z = mu + eps * exp(0.5 * logvar) with eps ~ N(0, I).

    Sampling happens regardless of train/eval mode.
    """
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return mu + eps * std

  def decode(self, z) :
    """Map latent codes (batch, latent_dim) to reconstructions (batch, input_dim) in [0, 1]."""
    h5 = F.relu(self.hl5(z))
    h6 = F.relu(self.hl6(h5))
    h7 = F.relu(self.hl7(h6))
    img = torch.sigmoid(self.hl8(h7))
    return img

  def forward(self, x) :
    """Encode, sample and decode.

    Args:
      x: any tensor whose elements form whole rows of input_dim values,
        e.g. (batch, 1, 64, 64) for the default configuration.

    Returns:
      (reconstruction of shape (batch, input_dim), mu, logvar).
    """
    # reshape (not view) also accepts non-contiguous inputs
    x = x.reshape(-1, self.input_dim)
    mu, logvar = self.encode(x)
    z = self.reparameterize(mu, logvar)
    img = self.decode(z)
    return img, mu, logvar

In [5]:
# initialization and training configuration

BATCH_SIZE = 128  # NOTE(review): not read anywhere; the DataLoader above hard-codes batch_size=128
LR = 1e-3
EPOCHS = 80
BETA = 2.5  # weight on the KL term in the training loss (BCE + BETA * KLD)
SEED = 0

# Seed torch's global RNG so that weight init, DataLoader shuffling and the
# reparameterization noise repeat from run to run (GPU kernels may still add
# some nondeterminism).
torch.manual_seed(SEED)

model = BetaVAE()
optimizer = optim.Adam(model.parameters(), lr=LR)

def BCE_loss_function(recon_x, x) :
  """Binary cross-entropy between reconstructions and targets, summed over everything.

  Args:
    recon_x: decoder output of shape (batch, features) with values in [0, 1].
    x: target images with the same number of elements as recon_x
      (e.g. (batch, 1, 64, 64)); reshaped to recon_x's shape.

  Returns:
    Scalar tensor: BCE summed over all pixels and all samples in the batch.
  """
  # Match the target to the reconstruction's layout instead of hard-coding 4096,
  # so models built with a non-default input_dim work as well.
  target = x.reshape(recon_x.shape)
  return F.binary_cross_entropy(recon_x, target, reduction='sum')

def KLD_loss_function(mu, logvar) :
  """KL(N(mu, exp(logvar)) || N(0, I)) for diagonal Gaussians, summed over all elements.

  Closed form: -1/2 * sum(1 + logvar - mu^2 - exp(logvar)).
  """
  per_element = 1 + logvar - mu.pow(2) - torch.exp(logvar)
  return -0.5 * per_element.sum()

def loss_function(recon_x, x, mu, logvar, beta=None) :
  """Beta-VAE objective: summed reconstruction BCE plus beta-weighted KL divergence.

  Args:
    recon_x: decoder output (sigmoid probabilities).
    x: target images.
    mu, logvar: encoder outputs for the same batch.
    beta: KL weight. When None (the default), the notebook-level BETA is read
      at call time, matching the original behavior.

  Returns:
    Scalar tensor for the whole batch.
  """
  if beta is None :
    beta = BETA
  BCE = BCE_loss_function(recon_x, x)
  KLD = KLD_loss_function(mu, logvar)
  return BCE + beta * KLD

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using Device: {device}")

model.to(device)

# Each "epoch" trains on at most this many shuffled batches rather than the
# full 5760-batch loader, to keep training time manageable.
MAX_BATCHES_PER_EPOCH = 500

model.train()
train_BCE_losses = []
train_KLD_losses = []
print("Training started -")

for epoch in range(EPOCHS) :

  overall_BCE_loss = 0.0
  overall_KLD_loss = 0.0
  n_batches = 0

  for batch_idx, (data, _) in enumerate(dataset_loader) :
    # `>=` before training, so exactly MAX_BATCHES_PER_EPOCH batches are used
    if batch_idx >= MAX_BATCHES_PER_EPOCH : break
    data = data.to(device)
    optimizer.zero_grad()
    recon_batch, mu, logvar = model(data)
    # Compute each loss term once: the same tensors drive the gradient and the
    # logged values (same combination as loss_function: BCE + BETA * KLD).
    BCE = BCE_loss_function(recon_batch, data)
    KLD = KLD_loss_function(mu, logvar)
    loss = BCE + BETA * KLD
    loss.backward()
    optimizer.step()
    overall_BCE_loss += BCE.item()
    overall_KLD_loss += KLD.item()
    n_batches += 1

  # per-batch averages over the batches actually processed
  # (each term is already summed over its batch)
  average_BCE_loss = overall_BCE_loss / max(n_batches, 1)
  average_KLD_loss = overall_KLD_loss / max(n_batches, 1)
  train_BCE_losses.append(average_BCE_loss)
  train_KLD_losses.append(average_KLD_loss)
  print(f"epoch = {epoch + 1}\nAverage reconstruction loss = {average_BCE_loss:.2f}\tAverage KLD = {average_KLD_loss:.2f}")

In [None]:
import matplotlib.pyplot as plt

# One panel per tracked loss: (history, y-axis label, panel title).
loss_panels = [
  (train_BCE_losses, "BCE Loss", "Reconstruction Loss"),
  (train_KLD_losses, "KLD Loss", "KL Divergence"),
]

fig, axes = plt.subplots(1, 2, figsize=(12, 5))
for panel_ax, (history, y_label, panel_title) in zip(axes, loss_panels) :
  panel_ax.plot(history)
  panel_ax.set_xlabel("Epochs")
  panel_ax.set_ylabel(y_label)
  panel_ax.set_title(panel_title)
  panel_ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
model.eval()

# Fresh shuffled iterator; the traversal cell below keeps drawing from it.
iterable = iter(dataset_loader)
N_SHOW = 10  # number of original / reconstruction pairs to display

with torch.no_grad() :
  x, x_labels = next(iterable)
  x = x.to(device)
  # forward() samples z via reparameterize even in eval mode, so these
  # reconstructions are one stochastic draw
  recon_batch, mu, logvar = model(x)

# move only the displayed samples to the CPU, once
originals = x[:N_SHOW].cpu().numpy()
reconstructions = recon_batch[:N_SHOW].cpu().numpy()

fig, axes = plt.subplots(2, N_SHOW, figsize=(10, 2))

for i in range(N_SHOW) :
  axes[0, i].imshow(originals[i].reshape(64, 64), cmap='gray')
  axes[0, i].axis('off')
  axes[1, i].imshow(reconstructions[i].reshape(64, 64), cmap='gray')
  axes[1, i].axis('off')

# title the middle column of each row so it reads as a row label
axes[0, N_SHOW // 2].set_title("original images")
axes[1, N_SHOW // 2].set_title("reconstructed images")

plt.show()

In [None]:
def visualize_traversal(model, iterable, dims=10, steps=10, value_range=3.0) :
  """Plot latent traversals around the encoding of a single image.

  Takes the next batch from ``iterable`` and encodes its first image to get the
  posterior mean ``mu``. For each of the first ``dims`` latent coordinates, it
  replaces that coordinate with ``steps`` evenly spaced values in
  ``[-value_range, value_range]``, keeping the others fixed, and decodes the
  result. Rows of the grid are latent dimensions; columns are traversal values.

  Args:
    model: module whose forward returns (recon, mu, logvar) and which has
      ``decode``; it is switched to eval mode.
    iterable: iterator yielding (images, labels) batches, e.g. iter(dataset_loader).
    dims: number of latent dimensions (rows) to traverse.
    steps: number of traversal values (columns).
    value_range: traversal values span [-value_range, value_range].

  Raises:
    ValueError: if ``dims`` exceeds the model's latent size.
  """
  model.eval()
  # run on whatever device the model lives on instead of a notebook global
  model_device = next(model.parameters()).device
  data, _ = next(iterable)
  data = data.to(model_device)

  with torch.no_grad() :
    # only the first image is needed as the traversal anchor
    _, mu, _ = model(data[:1])
    base = mu[0].clone()  # (latent_dim,)

  latent_size = base.numel()
  if dims > latent_size :
    raise ValueError(f"dims={dims} exceeds the latent size {latent_size}")

  path_range = torch.linspace(-value_range, value_range, steps=steps, device=model_device)

  # squeeze=False keeps `axes` 2-D even when dims == 1 or steps == 1
  fig, axes = plt.subplots(dims, steps, figsize=(10, 10), squeeze=False)
  plt.subplots_adjust(wspace = 0.05, hspace = 0.05)

  for dim in range(dims) :
    # one latent code per traversal value, differing only in coordinate `dim`
    z_ = base.repeat(steps, 1)
    z_[:, dim] = path_range
    with torch.no_grad() :
      recon_imgs = model.decode(z_).cpu().numpy()

    for step in range(steps) :
      ax = axes[dim, step]
      ax.imshow(recon_imgs[step].reshape(64, 64), cmap='gray')
      ax.axis('off')

  plt.suptitle("Latent Traversals", fontsize=16)
  plt.show()

visualize_traversal(model, iterable)