In [None]:
# data importing
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np

class Dsprite(Dataset) :
    """Map-style dataset over a dSprites ``.npz`` archive.

    Indexing returns ``(image, latent_values)``: the image as a float32 tensor with a
    leading channel axis added, and the float tensor of ground-truth factor values
    stored for that image.
    """

    def __init__(self, path) :
        """Read the arrays used by this notebook from the archive at ``path``.

        Args:
            path: location of an ``.npz`` file containing the keys ``imgs``,
                ``latents_values`` and ``latents_classes``.
        """
        # NOTE(review): allow_pickle=True lets NumPy unpickle object arrays, which can
        # execute arbitrary code. Only point this at a trusted copy of the archive.
        # The context manager closes the underlying zip file as soon as the arrays
        # have been read into memory.
        with np.load(path, encoding='latin1', allow_pickle=True) as data :
            # The actual images (approximately 700k for the full dataset). Each entry
            # is a single 64x64 image; the channel axis is added in __getitem__.
            self.imgs = data['imgs']
            # Ground-truth factor values per image:
            # [color, shape, scale, orientation, posX, posY]
            self.intended_latent = torch.from_numpy(data['latents_values']).float()
            # Integer class index of each factor. Every factor takes one of a fixed
            # set of values and each image is one combination of them; this array
            # records which class of each property an image has.
            self.classified_latent = torch.from_numpy(data['latents_classes']).long()

    def __len__(self) :
        """Return the number of images in the dataset."""
        return len(self.imgs)

    def __getitem__(self, idx) :
        """Return ``(image, latent_values)`` for the sample at ``idx``.

        The image is converted to float32 and given a leading channel axis so it can
        be fed directly to ``nn.Conv2d``.
        """
        img = self.imgs[idx].astype(np.float32)
        img = torch.from_numpy(img).unsqueeze(0)

        return img, self.intended_latent[idx]

# Location of the dSprites archive, relative to the notebook's working directory.
path = "./data/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz"

# Read the archive once; this dataset object is what the DataLoader iterates over.
dataset = Dsprite(path)

In [None]:
#data loading
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Shuffled mini-batches of 128 samples, fetched by two worker processes.
dataset_loader = DataLoader(
    dataset,
    batch_size=128,
    shuffle=True,
    num_workers=2,
)

In [None]:
# Number of batches per epoch; prints 5760 for the full dataset.
print(len(dataset_loader))

# Pull a single batch to check the tensor layout.
iter_data = iter(dataset_loader)
first_images, _first_latents = next(iter_data)
# Prints (128, 1, 64, 64), i.e. (batch, channel, height, width):
#   128 images in a batch
#   1 channel (depth layer)
#   64 pixels of height
#   64 pixels of width
print(first_images.shape)

In [None]:
# Model architecture
class GammaVAE(nn.Module) :
  """Convolutional variational autoencoder for single-channel 64x64 images.

  The encoder applies four stride-2 convolutions (64 -> 32 -> 16 -> 8 -> 4 pixels per
  side) followed by fully connected layers that output the mean and log-variance of
  a diagonal Gaussian over the latent code. The decoder mirrors this with fully
  connected layers and four stride-2 transposed convolutions, ending in a sigmoid so
  that every output pixel lies in [0, 1].

  Args:
    latent_dim: size of the latent code. Defaults to 10, the value that was
      previously hard-coded, so ``GammaVAE()`` builds the same network as before.
  """

  def __init__(self, latent_dim=10) :
    super().__init__()
    self.latent_dim = latent_dim

    # encoder
    # (1, 64, 64)
    self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=4, stride=2, padding=1)
    # (32, 32, 32)
    self.conv2 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=4, stride=2, padding=1)
    # (32, 16, 16)
    self.conv3 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=4, stride=2, padding=1)
    # (32, 8, 8)
    self.conv4 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=4, stride=2, padding=1)
    # (32, 4, 4)
    self.hl1 = nn.Linear(32*4*4, 256)
    self.hl2 = nn.Linear(256, 64)
    self.hl3_mu = nn.Linear(64, latent_dim) # mu
    self.hl3_logvar = nn.Linear(64, latent_dim) # logvar

    # decoder
    self.hl4 = nn.Linear(latent_dim, 64)
    self.hl5 = nn.Linear(64, 256)
    self.hl6 = nn.Linear(256, 32*4*4)
    # (32, 4, 4)
    self.convT1 = nn.ConvTranspose2d(in_channels=32, out_channels=32, kernel_size=4, stride=2, padding=1)
    # (32, 8, 8)
    self.convT2 = nn.ConvTranspose2d(in_channels=32, out_channels=32, kernel_size=4, stride=2, padding=1)
    # (32, 16, 16)
    self.convT3 = nn.ConvTranspose2d(in_channels=32, out_channels=32, kernel_size=4, stride=2, padding=1)
    # (32, 32, 32)
    self.convT4 = nn.ConvTranspose2d(in_channels=32, out_channels=1, kernel_size=4, stride=2, padding=1)
    # (1, 64, 64)

  def encode(self, x) :
    """Map a batch of images to the posterior parameters ``(mu, logvar)``.

    Args:
      x: image batch of shape (batch, 1, 64, 64).

    Returns:
      Two tensors of shape (batch, latent_dim): the mean and the log-variance.
    """
    c1 = F.relu(self.conv1(x))
    c2 = F.relu(self.conv2(c1))
    c3 = F.relu(self.conv3(c2))
    c4 = F.relu(self.conv4(c3))
    # Flatten the (32, 4, 4) feature maps into one vector per sample.
    h1 = F.relu(self.hl1(c4.view(-1, 32*4*4)))
    h2 = F.relu(self.hl2(h1))
    mu = self.hl3_mu(h2)
    logvar = self.hl3_logvar(h2)

    return mu, logvar

  def reparameterize(self, mu, logvar) :
    """Draw ``z = mu + eps * std`` with ``eps ~ N(0, I)``.

    Writing the sample this way keeps it differentiable with respect to ``mu`` and
    ``logvar``. Noise is drawn on every call, regardless of train/eval mode.
    """
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return mu + eps * std

  def decode(self, z) :
    """Map latent codes of shape (batch, latent_dim) to images of shape (batch, 1, 64, 64)."""
    h4 = F.relu(self.hl4(z))
    h5 = F.relu(self.hl5(h4))
    h6 = F.relu(self.hl6(h5))
    # Restore the (32, 4, 4) feature-map layout expected by the transposed convolutions.
    cT1 = F.relu(self.convT1(h6.view(-1, 32, 4, 4)))
    cT2 = F.relu(self.convT2(cT1))
    cT3 = F.relu(self.convT3(cT2))
    img = torch.sigmoid(self.convT4(cT3))
    return img

  def forward(self, x) :
    """Encode ``x``, sample a latent code, and decode it.

    Returns:
      ``(reconstruction, mu, logvar)``.
    """
    mu, logvar = self.encode(x)
    z = self.reparameterize(mu, logvar)
    img = self.decode(z)
    return img, mu, logvar

In [None]:
# initialization and training

# Must match the batch_size given to the DataLoader, because PEAK_CAP is scaled by it.
BATCH_SIZE = 128
LR = 5e-4
# Total number of training epochs.
EPOCHS = 120
# Epoch at which the KL target stops growing and GAMMA starts decaying.
CUTOFF_EPOCH = 60
# Weight on |KLD - target|: held at GAMMA_MAX until CUTOFF_EPOCH, then lowered
# linearly towards GAMMA_MIN at EPOCHS.
GAMMA_MAX = 10
GAMMA_MIN = 5
# Final KL target. The KL term is summed over the whole batch, so this is 25 per sample.
PEAK_CAP = 25*BATCH_SIZE

# Seed torch's global RNG so that weight initialisation and later torch-RNG draws are
# repeatable from a fresh kernel (GPU kernels may still be non-deterministic).
SEED = 0
torch.manual_seed(SEED)

model = GammaVAE()
optimizer = optim.Adam(model.parameters(), lr=LR)

def BCE_loss_function(recon_x, x) :
  """Binary cross-entropy between ``recon_x`` and the target ``x``, summed over all elements."""
  return F.binary_cross_entropy(input=recon_x, target=x, reduction='sum')

def KLD_loss_function(mu, logvar) :
  """Closed-form KL divergence from N(mu, exp(logvar)) to N(0, I), summed over all elements."""
  per_element = 1 + logvar - mu.pow(2) - logvar.exp()
  return -0.5 * per_element.sum()

def loss_function(recon_x, x, mu, logvar, epoch, cutoff_epoch) :
  """Training objective: ``BCE + GAMMA * |KLD - target|``.

  The KL target grows linearly from 0 to PEAK_CAP over the first ``cutoff_epoch``
  epochs and then stays at PEAK_CAP. GAMMA is held at GAMMA_MAX until
  ``cutoff_epoch`` and is then lowered linearly towards GAMMA_MIN at EPOCHS.

  Args:
    recon_x: reconstructed batch produced by the model.
    x: the original batch.
    mu, logvar: posterior parameters returned by the encoder.
    epoch: zero-based index of the current epoch.
    cutoff_epoch: epoch at which the KL target reaches PEAK_CAP.

  Returns:
    Scalar loss tensor (both terms are summed over the batch).
  """
  BCE = BCE_loss_function(recon_x, x)
  KLD = KLD_loss_function(mu, logvar)

  # Fraction of PEAK_CAP used as the KL target. With no warm-up (cutoff_epoch <= 0)
  # the full cap applies from the start, which also avoids dividing by zero.
  if cutoff_epoch > 0 :
    cap_ratio = min(epoch/cutoff_epoch, 1.0)
  else :
    cap_ratio = 1.0

  if epoch < cutoff_epoch :
    GAMMA = GAMMA_MAX
  else :
    decay_span = EPOCHS - cutoff_epoch
    # No decay window left (cutoff_epoch >= EPOCHS): treat the schedule as finished
    # instead of dividing by zero.
    progress = (epoch - cutoff_epoch) / decay_span if decay_span > 0 else 1.0
    # Clamp so GAMMA never leaves [GAMMA_MIN, GAMMA_MAX], even if epoch exceeds EPOCHS.
    progress = min(max(progress, 0.0), 1.0)
    GAMMA = GAMMA_MAX - progress * (GAMMA_MAX - GAMMA_MIN)
  return BCE + GAMMA * abs(KLD - PEAK_CAP * cap_ratio)

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using Device: {device}")

model.to(device)

model.train()

# Each epoch trains on at most this many shuffled batches instead of the full loader.
MAX_BATCHES_PER_EPOCH = 500

# Per-epoch averages (per batch) of each loss term, kept for plotting.
train_BCE_losses = []
train_KLD_losses = []
train_losses = []
print("Training started -")

for epoch in range(EPOCHS) :

  overall_BCE_loss = 0
  overall_KLD_loss = 0
  overall_loss = 0
  num_batches = 0

  for batch_idx, (data, _) in enumerate(dataset_loader) :
    data = data.to(device)
    optimizer.zero_grad()
    recon_batch, mu, logvar = model(data)
    loss = loss_function(recon_batch, data, mu, logvar, epoch, CUTOFF_EPOCH)
    loss.backward()
    optimizer.step()
    # The individual terms are recomputed for logging only, so skip autograd tracking.
    with torch.no_grad() :
      overall_BCE_loss += BCE_loss_function(recon_batch, data).item()
      overall_KLD_loss += KLD_loss_function(mu, logvar).item()
    overall_loss += loss.item()
    num_batches += 1
    # Stop after exactly MAX_BATCHES_PER_EPOCH batches so the averages below divide
    # by the number of batches that were actually accumulated.
    if num_batches >= MAX_BATCHES_PER_EPOCH : break

  # Guard against an empty loader so the averages never divide by zero.
  num_batches = max(num_batches, 1)
  average_BCE_loss = overall_BCE_loss / num_batches
  average_KLD_loss = overall_KLD_loss / num_batches
  average_loss = overall_loss / num_batches
  train_BCE_losses.append(average_BCE_loss)
  train_KLD_losses.append(average_KLD_loss)
  train_losses.append(average_loss)
  print(f"epoch = {epoch + 1}\nAverage reconstruction loss = {average_BCE_loss:.2f}\tAverage KLD = {average_KLD_loss:.2f}")

In [None]:
import matplotlib.pyplot as plt
# One panel per loss term: (per-epoch history, y-axis label, panel title).
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
panels = (
  (train_BCE_losses, "BCE Loss", "Reconstruction Loss"),
  (train_KLD_losses, "KLD Loss", "KL Divergence"),
)
for ax, (history, ylabel, title) in zip(axes, panels) :
  ax.plot(history)
  ax.set_xlabel("Epochs")
  ax.set_ylabel(ylabel)
  ax.set_title(title)
  ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
model.eval()

# Take one shuffled batch and run it through the trained model without tracking gradients.
iterable = iter(dataset_loader)

with torch.no_grad() :
  x, x_labels = next(iterable)
  x, x_labels = x.to(device), x_labels.to(device)
  recon_batch, mu, logvar = model(x)
  # encode returns the (mu, logvar) pair; it is not used further in this cell.
  recon_labels = model.encode(x)

# Top row: first ten inputs. Bottom row: their reconstructions.
fig, axes = plt.subplots(2, 10, figsize=(10, 2))

rows = ((x, "original images"), (recon_batch, "reconstructed images"))
for row, (images, title) in enumerate(rows) :
  for i in range(10) :
    ax = axes[row, i]
    ax.imshow(images[i].cpu().numpy().reshape(64, 64), cmap='gray')
    ax.axis('off')
    # Only one panel per row carries the row's title.
    if i == 5 : ax.set_title(title)

plt.show()