# Prepare notebook

In [None]:
# gdown is used by the next cell to download the dataset archive from Google Drive.
!pip install gdown -q

In [None]:
# Download the dataset archive by Google Drive file id, saved under the file name DATA_PATH points to.
!gdown 18KIvMBWD031oDvg0DVebI06SMwKuTp4l -O sh3_sc6_y32_x32_imgs.npz

In [None]:
# pytorch_fid provides the FID computation used later to score each model's reconstructions.
!pip install -q pytorch_fid

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import os
from tqdm.auto import tqdm

import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms

import gc

In [None]:
# Prefer a CUDA GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
print(DEVICE)

### Constants and globals

In [None]:
DATA_PATH = './sh3_sc6_y32_x32_imgs.npz'  # archive written by the gdown cell
TRAIN_RATIO = 0.8  # fraction of images assigned to the training split

LATENT_DIM = 256  # channels of the VAE's conv bottleneck / width of its hidden linear layers
Z_DIM = 10  # dimensionality of the latent code z

BETA_1 = 0.9  # Adam beta1
BETA_2 = 0.99  # Adam beta2
LEARNING_RATE = 1e-4

BATCH_SIZE = 128
TRAIN_EPOCHS = 60

W_KL = 1  # weight on the KL term in vae_loss; overwritten every epoch when the PI controller is enabled
W_SCALE = 1  # multiplier on both loss terms; divided back out when losses are logged

## Data

In [None]:
class DspritesDataset(Dataset):
    """Map-style dataset over the ``imgs`` array of a dSprites-style ``.npz`` file.

    Each item is ``transform(image)`` for one entry of ``imgs``.

    Args:
        npz_path: Path to the ``.npz`` archive; only its ``imgs`` array is read.
        transform: Optional callable applied to each image. Defaults to
            ``transforms.ToTensor()``.

    If ``imgs`` is uint8 with no value above 1 (binary masks such as the
    original dSprites release), each image is multiplied by 255 before the
    transform. The images are then standard 8-bit images, and ``ToTensor``
    maps them to {0.0, 1.0} instead of {0.0, 1/255}.
    """

    def __init__(self, npz_path, transform=None):
        # NOTE(review): allow_pickle=True lets np.load unpickle object arrays
        # from the archive. Only enable it for files you trust; 'imgs' alone
        # normally does not need it.
        # The context manager closes the archive once 'imgs' has been read
        # into memory.
        with np.load(npz_path, allow_pickle=True, encoding='latin1') as data:
            self.images = data['imgs']

        self.transform = transforms.ToTensor() if transform is None else transform

        # ToTensor divides uint8 arrays by 255, so {0, 1} uint8 masks would
        # come out as {0, 1/255}. Detect that case once, up front.
        self._rescale_binary = (
            self.images.dtype == np.uint8
            and self.images.size > 0
            and int(self.images.max()) <= 1
        )

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        if self._rescale_binary:
            # Stays uint8: 1 * 255 == 255 fits without overflow.
            image = image * np.uint8(255)
        return self.transform(image)


In [None]:
# Load every image from the archive; the train/test split happens in the next cell.
raw_dataset = DspritesDataset(npz_path=DATA_PATH)
print(f"{len(raw_dataset)}")

In [None]:
# Training gets floor(TRAIN_RATIO * N) images; the test split gets the remainder.
n_images = len(raw_dataset)
train_size = int(TRAIN_RATIO * n_images)
test_size = n_images - train_size
lengths = [train_size, test_size]

train_dataset, test_dataset = random_split(raw_dataset, lengths)

In [None]:
for split_name, split in (("Train", train_dataset), ("Test", test_dataset)):
    print(f'{split_name} size: {len(split)}')

In [None]:
import random

# 5x5 grid of training images, each drawn uniformly at random (with replacement).
plt.figure(figsize=(10, 10))
last_index = len(train_dataset) - 1
for cell in range(1, 26):
    plt.subplot(5, 5, cell)
    sample = train_dataset[random.randint(0, last_index)]
    plt.imshow(sample.squeeze(), cmap='gray')
    plt.axis('off')

## Variational autoencoder (VAE) model

In [None]:
class VAE(nn.Module):
    """Convolutional variational autoencoder for single-channel images.

    Encoder: four kernel-4, stride-2, padding-1 convolutions (each halves the
    spatial size) followed by a kernel-4, stride-1 convolution to
    ``latent_dim`` channels. A 64x64 input therefore ends as a
    ``latent_dim x 1 x 1`` map, which ``encode`` flattens.
    NOTE(review): other input sizes do not reduce to 1x1 and are not
    supported by that flatten.

    Decoder: the mirror image of the encoder, mapping a
    ``latent_dim x 1 x 1`` tensor back to ``1 x 64 x 64``. A sigmoid bounds
    the output to (0, 1).

    Args:
        latent_dim: Channels of the convolutional bottleneck and width of the
            hidden linear layers.
        z_dim: Dimensionality of the latent code z.
    """

    def __init__(self, latent_dim, z_dim):
        super().__init__()
        self.latent_dim = latent_dim
        self.z_dim = z_dim
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=latent_dim, kernel_size=4, stride=1),
            nn.ReLU(),
        )
        self.fc_encoder = nn.Linear(latent_dim, latent_dim)
        self.fc_mu = nn.Linear(latent_dim, z_dim)
        # Despite its name, this head produces the log-variance (see reparameterize).
        self.fc_sigma = nn.Linear(latent_dim, z_dim)

        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(in_channels=latent_dim, out_channels=64, kernel_size=4),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=64, out_channels=64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=32, out_channels=32, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=32, out_channels=1, kernel_size=4, stride=2, padding=1),
        )
        self.fc_decoder = nn.Linear(z_dim, latent_dim)

    def reparameterize(self, mu, sigma):
        """Sample z = mu + exp(sigma / 2) * eps with eps ~ N(0, I).

        ``sigma`` is treated as the log-variance, so ``exp(sigma / 2)`` is the
        standard deviation.
        """
        std = torch.exp(sigma / 2.)
        # randn_like already allocates on mu's device and dtype. Moving the
        # noise to mu.get_device() fails on CPU, where that call returns -1.
        eps = torch.randn_like(mu)
        return mu + std * eps

    def encode(self, x):
        """Return ``(h, mu, sigma)``: hidden features, latent mean and log-variance."""
        h = self.encoder(x)
        h = h.view(-1, self.latent_dim)
        h = self.fc_encoder(h)
        mu = self.fc_mu(h)
        sigma = self.fc_sigma(h)
        return h, mu, sigma

    def decode(self, z):
        """Map latent codes ``z`` of shape ``(batch, z_dim)`` to images in (0, 1)."""
        h = self.fc_decoder(z)
        # Use the configured width rather than a hard-coded 256 so any
        # latent_dim works.
        h = h.view(-1, self.latent_dim, 1, 1)
        return torch.sigmoid(self.decoder(h))

    def forward(self, x):
        """Return ``(x_hat, mu, sigma)`` for a batch ``x``; ``sigma`` is the log-variance."""
        _, mu, sigma = self.encode(x)
        z = self.reparameterize(mu, sigma)
        x_hat = self.decode(z)
        return x_hat, mu, sigma


## Train utils

In [None]:
def kl_divergence(mu, sigma):
    """KL divergence of N(mu, exp(sigma)) from N(0, I), multiplied by W_SCALE.

    ``sigma`` is treated as a log-variance (``sigma.exp()`` is the variance).
    The elementwise terms are summed over dim 1 (latent dimensions) and then
    averaged over dim 0 (the batch).
    """
    elementwise = -0.5 * (1 + sigma - mu.pow(2) - sigma.exp())
    per_sample = elementwise.sum(dim=1)
    return W_SCALE * per_sample.mean()

In [None]:
def reconstruction_loss(x, x_hat):
    """Per-sample summed squared error between ``x_hat`` and ``x``, times W_SCALE.

    Squared errors are summed over all elements and then divided by
    ``x.size(0)`` (the batch size). This puts the term on the same per-sample
    scale as ``kl_divergence``, which sums over latent dimensions and averages
    over the batch.

    Args:
        x: Target batch; dim 0 is the batch dimension.
        x_hat: Reconstruction with the same shape as ``x``.
    """
    batch_size = x.size(0)
    # Without this division the reconstruction term grows with the batch
    # size while the KL term does not, shrinking the effective KL weight by
    # a factor of batch_size.
    return W_SCALE * F.mse_loss(x_hat, x, reduction='sum') / batch_size

In [None]:
def vae_loss(x, x_hat, mu, sigma):
    """Return the VAE objective: reconstruction term plus ``W_KL`` times the KL term.

    ``W_KL`` is read from the module-level global at call time, so updates
    made by the PI controller during training take effect immediately.
    """
    rec_term = reconstruction_loss(x, x_hat)
    kl_term = kl_divergence(mu, sigma)
    return rec_term + W_KL * kl_term

In [None]:
@torch.no_grad()
def model_eval(model, loader, criterion):
    """Evaluate ``model`` on every batch of ``loader`` with gradients disabled.

    Args:
        model: Module whose forward returns ``(x_hat, mu, sigma)``.
        loader: Sized iterable of input batches; each is moved to ``DEVICE``.
        criterion: ``criterion(x, x_hat, mu, sigma)`` returning a scalar tensor.

    Returns:
        ``(total_loss, rec_loss, kl_loss)``: for each loss, the sum of its
        per-batch values divided by the number of samples seen. All three are
        NaN when the loader yields no samples.
    """
    total_loss = 0.0
    rec_loss = 0.0
    kl_loss = 0.0
    total_samples = 0

    model.eval()
    itr = tqdm(loader, total=len(loader), leave=False)
    itr.set_description("(Eval)")

    for batch in itr:
        total_samples += len(batch)
        batch = batch.to(DEVICE)

        _x, _mu, _sigma = model(batch)

        kl_loss += kl_divergence(_mu, _sigma).item()
        rec_loss += reconstruction_loss(batch, _x).item()
        total_loss += criterion(batch, _x, _mu, _sigma).item()

        # Running averages for display only. Every term is normalised by
        # the sample count and W_SCALE.
        itr.set_postfix(
            total_loss=round(total_loss / total_samples / W_SCALE, 4),
            kl_div=round(kl_loss / total_samples / W_SCALE, 4),
            rec_loss=round(rec_loss / total_samples / W_SCALE, 4),
        )

    if total_samples == 0:
        # Empty loader (e.g. a zero-length test split): nothing to average.
        nan = float('nan')
        return nan, nan, nan

    return (
        total_loss / total_samples,
        rec_loss / total_samples,
        kl_loss / total_samples,
    )


In [None]:
def _train_one_epoch(model, loader, optimizer, criterion, description):
    """Run one optimisation pass over ``loader`` and return loss accumulators.

    Returns:
        ``(total, kl, rec, samples)``: ``total``, ``kl`` and ``rec`` are sums
        over batches of the respective per-batch loss divided by ``W_SCALE``.
        ``samples`` is the number of samples seen.
    """
    model.train()
    total = kl_sum = rec_sum = 0.0
    samples = 0

    # A fresh bar per epoch. One tqdm bar is closed once its first pass
    # finishes, so reusing it does not give later epochs their own progress.
    progress = tqdm(loader, total=len(loader), leave=False)
    progress.set_description(description)

    for batch in progress:
        samples += len(batch)
        batch = batch.to(DEVICE)

        x_hat, mu, sigma = model(batch)
        kl_loss = kl_divergence(mu, sigma)
        rec_loss = reconstruction_loss(batch, x_hat)
        loss = criterion(batch, x_hat, mu, sigma)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total += loss.item() / W_SCALE
        kl_sum += kl_loss.item() / W_SCALE
        rec_sum += rec_loss.item() / W_SCALE
        progress.set_postfix(loss=round(total / samples, 5))

    return total, kl_sum, rec_sum, samples


def train_model(
        model,
        batch_size,
        epochs,
        criterion,
        train_set,
        test_set,
        lr=2e-5,
        use_PI=False,
        pi_kw_args=None,
):
    """Train ``model`` with Adam and evaluate it on ``test_set`` after every epoch.

    When ``use_PI`` is true, a ``PIDControl`` instance updates the global
    ``W_KL`` (the KL weight read by ``vae_loss``) once per epoch. The update
    is driven by that epoch's mean per-batch KL. ``W_KL`` keeps its last value
    after training returns.

    Args:
        model: Module whose forward returns ``(x_hat, mu, sigma)``.
        batch_size: Batch size for both data loaders.
        epochs: Number of passes over ``train_set``.
        criterion: ``criterion(x, x_hat, mu, sigma)`` returning a scalar loss.
        train_set: Training dataset; shuffled each epoch.
        test_set: Evaluation dataset; iterated in order.
        lr: Adam learning rate (betas come from BETA_1 / BETA_2).
        use_PI: Whether to adapt ``W_KL`` with the PI controller.
        pi_kw_args: Keyword arguments for ``PIDControl.pid``. Must contain
            'des_kl', 'kp' and 'ki' when ``use_PI`` is true.

    Returns:
        dict with keys "Train_Total_Loss", "Train_Rec_Loss", "Train_KL_Loss",
        "Test_Total_Loss", "Test_Rec_Loss" and "Test_KL_Loss", each a list
        with one value per epoch. Train values are sums of per-batch losses
        (divided by W_SCALE) divided by the number of samples. Test values
        come from ``model_eval``.

    Raises:
        ValueError: if ``use_PI`` is true and a required PI argument is missing.
    """
    global W_KL

    if pi_kw_args is None:
        pi_kw_args = {}
    if use_PI:
        for key in ('des_kl', 'kp', 'ki'):
            if key not in pi_kw_args:
                raise ValueError(f"{key} is required for PI")
        PI = PIDControl()

    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)
    optimizer = optim.Adam(model.parameters(), lr=lr, betas=(BETA_1, BETA_2))

    history = {
        "Train_Total_Loss": [],
        "Train_Rec_Loss": [],
        "Train_KL_Loss": [],
        "Test_Total_Loss": [],
        "Test_Rec_Loss": [],
        "Test_KL_Loss": [],
    }

    for epoch in range(epochs):
        total, kl_sum, rec_sum, samples = _train_one_epoch(
            model,
            train_loader,
            optimizer,
            criterion,
            f"(Training) Epoch [{epoch + 1}/{epochs}]",
        )

        if use_PI:
            # kl_sum already has W_SCALE divided out, and pid() divides by
            # W_SCALE again. Scale back up first so the controller sees the
            # epoch's mean per-batch KL.
            W_KL = PI.pid(kl_div=W_SCALE * kl_sum / len(train_loader), **pi_kw_args)

        history["Train_Total_Loss"].append(total / samples)
        history["Train_KL_Loss"].append(kl_sum / samples)
        history["Train_Rec_Loss"].append(rec_sum / samples)

        test_total, test_rec, test_kl = model_eval(
            model=model,
            loader=test_loader,
            criterion=criterion,
        )
        history["Test_Total_Loss"].append(test_total)
        history["Test_KL_Loss"].append(test_kl)
        history["Test_Rec_Loss"].append(test_rec)

    return history

In [None]:
def plot_reconstructions(model, dataset, n=5):
    """Plot ``n`` images from ``dataset`` above the model's reconstructions of them.

    The first ``BATCH_SIZE`` items of ``dataset`` (in order) go through
    ``model``. The top row shows the first ``n`` originals and the bottom row
    the matching outputs. ``n`` is capped at the size of that batch.
    """
    loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False)
    model.eval()
    with torch.no_grad():
        batch = next(iter(loader)).to(DEVICE)
        x_hat, *_ = model(batch)

    # Avoid an IndexError when the dataset (or batch) has fewer than n images.
    n = min(n, len(batch))

    plt.figure(figsize=(4 * n, 6))
    for i in range(n):
        plt.subplot(2, n, i + 1)
        plt.imshow(batch[i][0].cpu(), cmap='gray')
        plt.title("Original")
        plt.axis('off')

        plt.subplot(2, n, i + 1 + n)
        plt.imshow(x_hat[i][0].cpu(), cmap='gray')
        plt.title("Reconstructed")
        plt.axis('off')
    plt.show()

In [None]:
def trend_plot_helper(pobj):
    """Draw one subplot per entry of ``pobj``, side by side in a single row.

    Args:
        pobj: Mapping from a title of the form ``"<y label> - <x label>"`` to
            a list of ``(legend label, values)`` pairs. Each value sequence is
            plotted against the 1-based indices ``1..len(values)``.
    """
    if not pobj:
        return

    n_panels = len(pobj)
    plt.figure(figsize=(8 * n_panels, 6))
    for idx, (title, plots) in enumerate(pobj.items(), start=1):
        plt.subplot(1, n_panels, idx)
        for label, trend in plots:
            plt.plot(range(1, len(trend) + 1), trend, label=label)
        # partition() tolerates titles without " - " (empty x label) or with
        # several separators; split() plus 2-way unpacking raised ValueError.
        y_label, _, x_label = title.partition(' - ')
        plt.xlabel(x_label)
        plt.ylabel(y_label)
        plt.legend()

## FID utils

In [None]:
import shutil
from torch.utils.data import Subset, SubsetRandomSampler

# Fixed pool of FID reference images: up to 3500 from the test split and the
# rest (up to 5000 in total) from the train split.
total_samples = 5000
num_test_samples = min(3500, len(test_dataset))
# Cap at the train split size so choice(..., replace=False) cannot fail.
num_train_samples = min(total_samples - num_test_samples, len(train_dataset))

train_indices = np.random.choice(len(train_dataset), num_train_samples, replace=False)
test_indices = np.random.choice(len(test_dataset), num_test_samples, replace=False)

# Subset + shuffle=False yields the same order on every pass, so
# '{split}_{i}.png' written here and by save_fid_samples refers to the same
# source image. SubsetRandomSampler re-permutes its indices on each pass.
fid_train_dataloader = DataLoader(Subset(train_dataset, train_indices), batch_size=1, shuffle=False)
fid_test_dataloader = DataLoader(Subset(test_dataset, test_indices), batch_size=1, shuffle=False)
fid_loaders = {'train': fid_train_dataloader, 'test': fid_test_dataloader}

to_pil = transforms.ToPILImage()

ORIGINAL_DIR = 'original_samples'
GENERATED_DIR = 'generated_samples'

# Start from an empty directory so images from an earlier run are not
# included in the FID statistics.
if os.path.exists(ORIGINAL_DIR):
    shutil.rmtree(ORIGINAL_DIR)
os.mkdir(ORIGINAL_DIR)

for pre, loader in fid_loaders.items():
    for i, batch in enumerate(loader):
        # batch_size=1: batch[0][0] is the single image's only channel.
        # Round before .byte(), which truncates (e.g. 254.9999 -> 254).
        image = to_pil((batch[0][0] * 255).round().byte())
        image.save(os.path.join(ORIGINAL_DIR, f'{pre}_{i}.png'))

In [None]:
@torch.no_grad()
def save_fid_samples(model):
    """Write binarised reconstructions of every FID reference image to GENERATED_DIR.

    GENERATED_DIR is recreated empty on each call. Files are named
    '{split}_{i}.png', where ``i`` counts images across batches of the
    corresponding loader in ``fid_loaders``. Every image of every batch is
    saved, so any loader batch size works.
    """
    if os.path.exists(GENERATED_DIR):
        shutil.rmtree(GENERATED_DIR)
    os.mkdir(GENERATED_DIR)

    model.eval()
    for pre, loader in fid_loaders.items():
        i = 0
        for batch in loader:
            x_hat, *_ = model(batch.to(DEVICE))
            # Round the sigmoid outputs to {0, 1}, then map to 8-bit {0, 255}.
            images = (torch.round(x_hat) * 255).byte().cpu()
            for image in images:
                to_pil(image[0]).save(os.path.join(GENERATED_DIR, f'{pre}_{i}.png'))
                i += 1

In [None]:
from pytorch_fid import fid_score as FID

def calculate_fid(model):
    """Regenerate the model's FID samples and return the FID against ORIGINAL_DIR.

    Calls ``save_fid_samples(model)`` first. It then scores the two
    directories with pytorch_fid using batch size ``BATCH_SIZE`` and
    ``dims=2048``.
    """
    save_fid_samples(model)
    image_dirs = [ORIGINAL_DIR, GENERATED_DIR]
    return FID.calculate_fid_given_paths(
        image_dirs,
        batch_size=BATCH_SIZE,
        device=DEVICE,
        dims=2048,
    )

## Ordinary VAE

In [None]:
model_ovae = VAE(latent_dim=LATENT_DIM, z_dim=Z_DIM)
model_ovae = model_ovae.to(DEVICE)

In [None]:
# Baseline run: use_PI is off, so the KL weight stays at the global W_KL.
history = train_model(
    model_ovae,
    BATCH_SIZE,
    TRAIN_EPOCHS,
    vae_loss,
    train_dataset,
    test_dataset,
    lr=LEARNING_RATE,
)

In [None]:
trend_plot_helper(
    {
        f"{split} Loss - Epoch": [
            (label, history[f"{split}_{key}_Loss"])
            for label, key in (("Total", "Total"), ("KL Dive", "KL"), ("Reconstruction", "Rec"))
        ]
        for split in ("Train", "Test")
    }
)

### Reconstructions of test images

In [None]:
plot_reconstructions(model=model_ovae, dataset=test_dataset)

### FID score

In [None]:
fid_score_ovae = calculate_fid(model=model_ovae)
print("FID score: {}".format(fid_score_ovae))

In [None]:
# Drop the notebook's reference to the model and collect any cyclic garbage
# first, so tensors kept alive only by reference cycles are freed. Then
# return PyTorch's cached-but-unused GPU memory to the driver.
del model_ovae
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

## ControlVAE

### PI controller (`PIDControl` implements only the proportional and integral terms)

In [None]:
class PIDControl:
    """PI controller that turns an observed KL value into a KL weight in [0, 1].

    Only proportional and integral terms are computed (there is no
    derivative term). Two values persist between calls: ``i_k1`` (integral
    term) and ``w_k1`` (previous output before clipping).
    """

    def __init__(self):
        self.i_k1 = 0.0  # integral term from the previous call
        self.w_k1 = 0.0  # unclipped output from the previous call

    def pid(self, des_kl, kl_div, kp=0.001, ki=-0.001):
        """Return the next KL weight, clipped to [0, 1].

        Args:
            des_kl: Desired KL value (the set point).
            kl_div: Observed KL value; divided by ``W_SCALE`` before use.
            kp: Gain of the sigmoid-shaped proportional term
                ``kp / (1 + exp(error))``, where ``error = des_kl - kl``.
            ki: Integral gain; the integral is updated as ``i - ki * error``.
                NOTE(review): with that rule a positive ``ki`` raises the
                weight while the KL is above the set point, whereas the
                negative default integrates in the opposite direction.
                The ControlVAE runs in this notebook pass ki=1e-3.
        """
        observed = kl_div / W_SCALE
        print("Round>>", des_kl, observed)

        error = des_kl - observed
        proportional = kp * (1.0 / (1.0 + np.exp(error)))

        # Anti-windup: integrate only while the previous unclipped output
        # was inside [0, 1]; otherwise hold the integral where it is.
        if 0 <= self.w_k1 <= 1:
            integral = self.i_k1 - ki * error
        else:
            integral = self.i_k1

        raw_weight = proportional + integral
        self.i_k1 = integral
        self.w_k1 = raw_weight

        weight = np.clip(raw_weight, 0.0, 1.0)
        print("---- Wk", weight)
        return weight


### Train ControlVAE

In [None]:
# PI gains shared by every ControlVAE run; each run adds its own 'des_kl' set point.
base_conf = dict(kp=1e-2, ki=1e-3)

#### Set point = 8

In [None]:
model_pid_8 = VAE(latent_dim=LATENT_DIM, z_dim=Z_DIM).to(DEVICE)

# PI controller arguments for this run: set point 8 plus the shared gains.
config_pid_8 = dict(des_kl=8, **base_conf)

In [None]:
W_KL = 1  # every run starts from an unweighted KL term
history_pid_8 = train_model(
    model_pid_8,
    BATCH_SIZE,
    TRAIN_EPOCHS,
    vae_loss,
    train_dataset,
    test_dataset,
    lr=LEARNING_RATE,
    use_PI=True,
    pi_kw_args=config_pid_8,
)
W_KL = 1  # discard the controller's final weight so later cells start from 1

In [None]:
trend_plot_helper(
    {
        f"{split} Loss - Epoch": [
            (label, history_pid_8[f"{split}_{key}_Loss"])
            for label, key in (("Total", "Total"), ("KL Dive", "KL"), ("Reconstruction", "Rec"))
        ]
        for split in ("Train", "Test")
    }
)

In [None]:
plot_reconstructions(model=model_pid_8, dataset=test_dataset)

In [None]:
fid_score_pid_8 = calculate_fid(model=model_pid_8)
print("FID score: {}".format(fid_score_pid_8))

In [None]:
# Drop the notebook's reference to the model and collect any cyclic garbage
# first, so tensors kept alive only by reference cycles are freed. Then
# return PyTorch's cached-but-unused GPU memory to the driver.
del model_pid_8
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

#### Set point = 14

In [None]:
model_pid_14 = VAE(latent_dim=LATENT_DIM, z_dim=Z_DIM).to(DEVICE)

# PI controller arguments for this run: set point 14 plus the shared gains.
config_pid_14 = dict(des_kl=14, **base_conf)

In [None]:
W_KL = 1  # every run starts from an unweighted KL term
history_pid_14 = train_model(
    model_pid_14,
    BATCH_SIZE,
    TRAIN_EPOCHS,
    vae_loss,
    train_dataset,
    test_dataset,
    lr=LEARNING_RATE,
    use_PI=True,
    pi_kw_args=config_pid_14,
)
W_KL = 1  # discard the controller's final weight so later cells start from 1

In [None]:
trend_plot_helper(
    {
        f"{split} Loss - Epoch": [
            (label, history_pid_14[f"{split}_{key}_Loss"])
            for label, key in (("Total", "Total"), ("KL Dive", "KL"), ("Reconstruction", "Rec"))
        ]
        for split in ("Train", "Test")
    }
)

In [None]:
plot_reconstructions(model=model_pid_14, dataset=test_dataset)

In [None]:
fid_score_pid_14 = calculate_fid(model=model_pid_14)
print("FID score: {}".format(fid_score_pid_14))

In [None]:
# Drop the notebook's reference to the model and collect any cyclic garbage
# first, so tensors kept alive only by reference cycles are freed. Then
# return PyTorch's cached-but-unused GPU memory to the driver.
del model_pid_14
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

#### Conclusion plot

We plot each loss (KL, reconstruction, and total) for both set points on shared axes, separately for the training and test data.

In [None]:
trend_plot_helper(
    {
        f"{split} KL Loss - Epoch": [
            (f"Set Point = {set_point}", hist[f"{split}_KL_Loss"])
            for set_point, hist in ((8, history_pid_8), (14, history_pid_14))
        ]
        for split in ("Train", "Test")
    }
)

In [None]:
trend_plot_helper(
    {
        f"{split} Rec Loss - Epoch": [
            (f"Set Point = {set_point}", hist[f"{split}_Rec_Loss"])
            for set_point, hist in ((8, history_pid_8), (14, history_pid_14))
        ]
        for split in ("Train", "Test")
    }
)

In [None]:
trend_plot_helper(
    {
        f"{split} Total Loss - Epoch": [
            (f"Set Point = {set_point}", hist[f"{split}_Total_Loss"])
            for set_point, hist in ((8, history_pid_8), (14, history_pid_14))
        ]
        for split in ("Train", "Test")
    }
)