In [None]:
# standard packages
import numpy as np
import matplotlib.pyplot as plt

# pytorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as  data

# torchvision
import torchvision
from torchvision.datasets import MNIST
from torchvision import transforms

# Wandb
import wandb

# paths
DATASET_PATH = './data'
CHECKPOINT_PATH = './checkpoints'
CECKPOINT_PATH = CHECKPOINT_PATH  # backward-compatible alias for the original (misspelled) name

# seed every RNG the notebook uses (numpy, torch CPU, torch CUDA) so runs are repeatable
seed = 7
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    # seeds every visible GPU, including the current one
    torch.cuda.manual_seed_all(seed)

# ensure reproducibility: force deterministic cuDNN kernels and disable auto-tuning
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# device used for the model and batches
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

In [None]:
# transformations applied to all images: ToTensor gives [0, 1], Normalize((0.5,), (0.5,)) maps to [-1, 1]
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5,), (0.5,))])

BATCH_SIZE = 1024

# load the official MNIST training split and carve a validation set out of it;
# a dedicated, seeded generator keeps the split identical even when this cell is re-run on its own
train_dataset = MNIST(root=DATASET_PATH, train=True, transform=transform, download=True)
train_set, val_set = data.random_split(train_dataset, [50000, 10000],
                                       generator=torch.Generator().manual_seed(seed))

# load the held-out MNIST test split (train=False) so test metrics are not computed on training images
test_set = MNIST(root=DATASET_PATH, train=False, transform=transform, download=True)

# data loaders (only the training loader shuffles)
train_loader = data.DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
val_loader = data.DataLoader(val_set, batch_size=BATCH_SIZE, shuffle=False)
test_loader = data.DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# preview: take the first shuffled training batch and show its first image with that image's label
image, label = next(iter(train_loader))
sample_1 = image[0, 0]  # first image of the batch, channel 0
plt.figure(1)
plt.imshow(sample_1, cmap='gray')
print(label[0])  # label of the image shown above

In [None]:
class Encoder(nn.Module):

  '''
    Fully connected (MLP) encoder: 784 -> 512 -> 256 -> head.

    Inputs must already be flattened so the last axis has 784 features.
    The head depends on ``model_type``:
      'ae'  : Linear(256, z_dim*10); the logits are reshaped to (batch, z_dim, 10) and
              softmax-normalised over the last axis, i.e. one 10-way categorical
              distribution per latent dimension.
      'vae' : two Linear(256, z_dim) layers returning (z_mean, z_log_var).

    Parameters:
      model_type: 'ae' or 'vae'; any other value builds no head and forward raises ValueError
      input_channels: kept for interface compatibility; unused by the MLP
      hid_dim: kept for interface compatibility; layer widths are fixed
      z_dim: latent dimension / number of categorical distributions
      act_fn: activation class, instantiated once per hidden layer
    '''

  def __init__(
      self, 
      model_type=None, # 'ae' or 'vae'
      input_channels: int=1, 
      hid_dim: int=128,
      z_dim: int=2,
      act_fn: object=nn.ReLU
  ):
    
    super().__init__()
    self.model_type = model_type

    # shared MLP trunk
    self.encoder = nn.Sequential(
        nn.Linear(784, 512),
        act_fn(),
        nn.Linear(512, 256),
        act_fn(),
    )

    if model_type == 'ae':
      # z_dim groups of 10 logits each
      self.linear = nn.Linear(256, z_dim*10)
    if model_type == 'vae':
      self.z_mean = torch.nn.Linear(256,z_dim)
      self.z_log_var = torch.nn.Linear(256,z_dim)

  def forward(self, x):
    '''
    'ae'  -> probabilities of shape (batch, z_dim, 10), each row summing to 1.
    'vae' -> (z_mean, z_log_var), each of shape (batch, z_dim).
    Raises ValueError for any other model_type (previously returned None silently).
    '''
    if self.model_type == 'ae':
      logits = self.linear(self.encoder(x))
      return F.softmax(logits.reshape(logits.shape[0], -1, 10), dim=2)
    if self.model_type == 'vae':
      hidden = self.encoder(x)
      return self.z_mean(hidden), self.z_log_var(hidden)
    raise ValueError(f"Unknown encoder model_type: {self.model_type!r} (expected 'ae' or 'vae')")

class Decoder(nn.Module):

  '''
  Fully connected (MLP) decoder: head -> 256 -> 512 -> 784 (raw outputs, no output activation).

  The head that maps the latent code to the 256-dim hidden layer depends on model_type:
    'VAE' / 'uniform' : draws one category per latent dimension from the encoder's
                        probabilities, one-hot encodes it, then Linear(z_dim*10, 256)
    'AE'              : feeds the flattened probabilities themselves into Linear(z_dim*10, 256)
    'normal'          : reparameterised Gaussian sample, then Linear(z_dim, 256)

  Parameters:
    model_type: one of 'VAE', 'uniform', 'AE', 'normal'
    output_channels: kept for interface compatibility; unused by the MLP
    hid_dim: kept for interface compatibility; layer widths are fixed
    z_dim: latent dimension / number of categorical distributions
    act_fn: activation class, instantiated once per layer
  '''

  def __init__(
      self,
      model_type=None, # 'VAE', 'AE', 'uniform', 'normal'
      output_channels: int=1, 
      hid_dim: int=128,
      z_dim: int=2,
      act_fn: object=nn.ReLU
  ):
    
    super().__init__()

    self.model_type = model_type

    # shared trunk; registered before the head, so it comes first in parameters()
    self.decoder = nn.Sequential(
        nn.Linear(256, 512),
        act_fn(),
        nn.Linear(512,784),
    )

    # Gumbel-Softmax and VAE of DD-VAE
    if self.model_type == 'VAE' or self.model_type == 'uniform':
      self.linear_vae = nn.Sequential(
        nn.Linear(z_dim*10, 256),
        act_fn()
      )
    # Normal AE and AE of DD-VAE
    elif self.model_type == 'AE':
      self.linear_ae = nn.Sequential(
        nn.Linear(z_dim*10, 256),
        act_fn()
      )
    # Normal VAE
    elif self.model_type == 'normal':
      self.linear_ = nn.Sequential(
        nn.Linear(z_dim, 256),
        act_fn()
      )
    else:
      print('Something went wrong')


  def variational_forward(self, simplex):
    '''Sample a one-hot code per latent dimension from `simplex` (probabilities) and embed it.'''
    # NOTE(review): Categorical(...).sample() is not differentiable, so no gradient reaches the
    # encoder through this path; the "Gumbel-Softmax" label suggests a relaxed or
    # straight-through sample (e.g. F.gumbel_softmax) may have been intended -- confirm.
    sample = torch.distributions.Categorical(simplex).sample()
    one_hot = F.one_hot(sample, 10).reshape(sample.shape[0], -1)
    # .float() keeps the tensor on the simplex's device instead of round-tripping via the CPU
    return self.linear_vae(one_hot.float())

  def simplex_forward(self, x):
    '''Embed the flattened probabilities directly (deterministic decoding).'''
    x = x.reshape(x.shape[0],-1)
    return self.linear_ae(x)

  def normal_forward(self, z_mean, z_log_var):
    '''Sample z ~ N(z_mean, exp(z_log_var)) with the reparameterisation trick and embed it.'''
    # z_log_var is a log-variance, so exp() yields the variance expected by reparameterize
    return self.linear_(self.reparameterize(z_mean, z_log_var.exp()))

  def reparameterize(self, z_mean, z_var):
    '''Return z_mean + eps * sqrt(z_var) with eps ~ N(0, I); z_var is a variance.'''

    assert not (z_var < 0).any().item(), "The reparameterization trick got a negative variance as input. "

    # randn_like draws on z_mean's device and dtype
    noise = torch.randn_like(z_mean)
    # scale by the standard deviation, not the variance
    return z_mean + noise * z_var.sqrt()

  def forward(self, x, z_mean=None, z_log_var=None):
    '''
    Decode to a (batch, 784) tensor.

    x is the encoder's probabilities for 'VAE'/'uniform'/'AE'; for 'normal' x is ignored and
    (z_mean, z_log_var) are used instead. Raises ValueError for an unknown model_type.
    '''
    if self.model_type in ('VAE', 'uniform'):
      # Gumbel-Softmax and VAE of DD-VAE
      hidden = self.variational_forward(x)
    elif self.model_type == 'AE':
      # Normal AE and AE of DD-VAE
      hidden = self.simplex_forward(x)
    elif self.model_type == 'normal':
      # Normal VAE
      hidden = self.normal_forward(z_mean, z_log_var)
    else:
      raise ValueError(f'Wrong decoder type: {self.model_type}')

    return self.decoder(hidden)

In [None]:
class DD_VAE(nn.Module):

  '''
    Wires an Encoder to one or two Decoders and trains them end to end.

    Model types (parameter ``model_tpye``; the misspelling is kept so existing callers work):
      'DD-VAE': categorical encoder ('ae'), sampling decoder ('VAE') and simplex decoder ('AE');
                loss = rec_vae + rec_ae + KL(q(z|x) || uniform)
      'U-VAE' : categorical encoder ('ae') and sampling decoder ('uniform');
                loss = rec + KL(q(z|x) || uniform)
      'N-VAE' : Gaussian encoder ('vae') and reparameterised decoder ('normal');
                loss = rec + KL(N(mu, sigma^2) || N(0, I))
      'N-AE'  : categorical encoder ('ae') and simplex decoder ('AE'); loss = rec

    Every loss term is a per-sample value averaged over the batch.

    Parameters:
      model_tpye: one of the model types above; anything else raises ValueError
      encoder: encoder class, called as encoder(kind, input_channels, hid_dim, z_dim)
      decoder: decoder class, called as decoder(kind, input_channels, hid_dim, z_dim)
      input_channels: number of input channels (forwarded to encoder/decoders)
      hid_dim: base number of hidden units (forwarded to encoder/decoders)
      z_dim: latent dimension / number of categorical distributions
    '''

  def __init__(
      self,
      # model type
      model_tpye: str = 'DD-VAE', # 'DD-VAE', 'U-VAE', 'N-VAE', 'N-AE'
      # Encoder, Decoder
      encoder: object = Encoder,
      decoder: object = Decoder,
      # model specifications
      input_channels: int=1, 
      hid_dim: int=128,
      z_dim: int=2,

  ):

    super().__init__()
    self.model_tpye = model_tpye
    spec = (input_channels, hid_dim, z_dim)

    if model_tpye == 'DD-VAE':

      self.encoder = encoder('ae', *spec).to(device)
      self.decoder_VAE = decoder('VAE', *spec).to(device)

      # both decoders start from identical weights: shared trunk AND latent head
      self.decoder_AE = decoder('AE', *spec).to(device)
      self.decoder_AE.decoder.load_state_dict(self.decoder_VAE.decoder.state_dict())
      self.decoder_AE.linear_ae.load_state_dict(self.decoder_VAE.linear_vae.state_dict())

    elif model_tpye == 'U-VAE':

      self.encoder = encoder('ae', *spec).to(device)
      self.decoder_VAE = decoder('uniform', *spec).to(device)

    elif model_tpye == 'N-VAE':

      self.encoder = encoder('vae', *spec).to(device)
      self.decoder_VAE = decoder('normal', *spec).to(device)

    elif model_tpye == 'N-AE':

      self.encoder = encoder('ae', *spec).to(device)
      self.decoder_VAE = decoder('AE', *spec).to(device)

    else:
      raise ValueError(f"Unknown model_tpye: {model_tpye!r} (expected 'DD-VAE', 'U-VAE', 'N-VAE' or 'N-AE')")


  def forward(self, x):
    '''
    Encode x and decode it with every decoder of the configured model type.

    Returns the 5-tuple (x_rec_vae, x_rec_ae, simplex, z_mean, z_log_var); entries that do
    not apply to the model type are None.
    '''
    if self.model_tpye == 'N-VAE':
      z_mean, z_log_var = self.encoder(x)
      x_rec_vae = self.decoder_VAE(None, z_mean, z_log_var)
      return x_rec_vae, None, None, z_mean, z_log_var

    # categorical models: 'DD-VAE', 'U-VAE', 'N-AE'
    simplex = self.encoder(x)
    x_rec_vae = self.decoder_VAE(simplex)
    x_rec_ae = self.decoder_AE(simplex) if self.model_tpye == 'DD-VAE' else None
    return x_rec_vae, x_rec_ae, simplex, None, None


  def compute_loss(self, x, x_rec_vae, x_rec_ae, simplex, z_mean, z_log_var):
    ''' 
    Compute the training loss for the configured model type.

    Returns (total_loss, rec_loss, reg_loss, rec_loss_ae); rec_loss_ae is None unless the
    model type is 'DD-VAE'. For the categorical models reg_loss is the KL divergence between
    the encoder's q(z|x) and a uniform prior p(z); for 'N-VAE' it is the Gaussian KL to N(0, I).
    '''
    rec_loss = self._rec_loss(x, x_rec_vae)

    if self.model_tpye == 'DD-VAE':
      rec_loss_ae = self._rec_loss(x, x_rec_ae)
      reg_loss = self._uniform_kl(simplex)
      return rec_loss + rec_loss_ae + reg_loss, rec_loss, reg_loss, rec_loss_ae

    if self.model_tpye == 'U-VAE':
      reg_loss = self._uniform_kl(simplex)
      return rec_loss + reg_loss, rec_loss, reg_loss, None

    if self.model_tpye == 'N-VAE':
      # closed-form KL(N(mu, exp(log_var)) || N(0, I)), summed over latent dims, averaged over the batch
      reg_loss = torch.mean(-0.5 * torch.sum(1 + z_log_var - z_mean ** 2 - z_log_var.exp(), dim = 1), dim = 0)
      return rec_loss + reg_loss, rec_loss, reg_loss, None

    if self.model_tpye == 'N-AE':
      reg_loss = 0
      return rec_loss + reg_loss, rec_loss, reg_loss, None

    raise ValueError(f'Unknown model_tpye: {self.model_tpye!r}')

  @staticmethod
  def _rec_loss(x, x_rec):
    ''' Squared error summed over features, averaged over the batch. '''
    return F.mse_loss(x, x_rec, reduction="none").sum(dim=1).mean(dim=0)

  def _uniform_kl(self, simplex):
    ''' KL(q(z|x) || uniform) summed over latent dimensions and averaged over the batch. '''
    num_categories = simplex.shape[-1]
    # uniform prior built on the simplex's device/dtype (the U-VAE branch used to leave it on the CPU)
    p_z = torch.full((num_categories,), 1.0 / num_categories, device=simplex.device, dtype=simplex.dtype)
    kl = self._KL(simplex, p_z)
    # per-sample value, consistent with the batch-mean reconstruction terms
    return kl.reshape(kl.shape[0], -1).sum(dim=1).mean(dim=0)

  def par_loss(self):
    ''' Par_loss: mean squared error between all parameters of the VAE and AE decoders (DD-VAE only). '''
    if self.model_tpye != 'DD-VAE':
      raise ValueError("par_loss is only defined for model type 'DD-VAE'")
    # both decoders register their parameters in the same order with the same shapes
    params_vae = torch.cat([p.flatten() for p in self.decoder_VAE.parameters()])
    params_ae = torch.cat([p.flatten() for p in self.decoder_AE.parameters()])
    return F.mse_loss(params_vae, params_ae)

  def _KL(self,P,Q):
    ''' KL(P || Q) along the last axis, one value per distribution; eps guards against log(0). '''
    eps = 1e-15
    P = P + eps
    Q = Q + eps
    return torch.sum(P * torch.log(P / Q), dim=-1)


  def train(self, learning_rate=True, epochs=None, dataloader=None):
    '''
    Backward-compatible entry point with two roles:
      * train(learning_rate, epochs, dataloader) runs the training loop (see fit);
      * train() / train(mode) keeps nn.Module semantics (toggle training mode). nn.Module.eval()
        calls self.train(False), which the previous override broke.
    '''
    if epochs is None and dataloader is None:
      return super().train(learning_rate)
    return self.fit(learning_rate, epochs, dataloader)

  def fit(self, learning_rate, epochs, dataloader):
    '''
    Train with Adam for `epochs` passes over `dataloader`, logging losses to wandb.

    Each batch of images is flattened to (batch, -1) before the forward pass.
    Per-batch: "instance_loss" (the batch-mean loss). Per-epoch: sums of the batch losses.
    '''
    self.to(device)
    super().train(True)
    optimizer = optim.Adam(self.parameters(), lr=learning_rate)

    for e in range(epochs):

      epoch_loss = 0.0
      epoch_rec_loss = 0.0
      epoch_reg_loss = 0.0
      epoch_rec_loss_ae = 0.0

      for batch_images, _ in dataloader:

        batch_images = batch_images.to(device).reshape(batch_images.shape[0], -1)

        # pass batch through the model
        x_rec_vae, x_rec_ae, simplex, z_mean, z_log_var = self.forward(batch_images)
        loss, rec_loss, reg_loss, rec_loss_ae = self.compute_loss(batch_images, x_rec_vae, x_rec_ae, simplex, z_mean, z_log_var)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # accumulate Python floats: summing the tensors would keep every batch's
        # autograd graph alive until the end of the epoch
        epoch_loss += loss.item()
        epoch_rec_loss += float(rec_loss)
        epoch_reg_loss += float(reg_loss)  # reg_loss is the int 0 for 'N-AE'
        if rec_loss_ae is not None:
          epoch_rec_loss_ae += float(rec_loss_ae)

        # every loss term is already a batch mean, i.e. a per-instance value
        wandb.log({"instance_loss": loss.item()})

      # one log call per epoch so all epoch metrics share a single wandb step
      wandb.log({
        "epoch": e,
        "epoch_loss": epoch_loss,
        "rec_loss": epoch_rec_loss,
        "reg_loss": epoch_reg_loss,
        "rec_loss_ae": epoch_rec_loss_ae,
      })
      print(f'Epoch: {e} done, Loss: {epoch_loss}, Rec_Loss: {epoch_rec_loss}, Reg_Loss: {epoch_reg_loss}')

    print('Training complete')

  def generate(self, model, x):
    '''
    Reconstruct x with `model` and display the first reconstruction as a 28x28 image.

    x is treated as a batch and flattened to (x.shape[0], -1) before encoding.
    '''
    with torch.no_grad():
      x = x.to(device).reshape(x.shape[0], -1)
      x_rec_vae = model.forward(x)[0]
    plt.figure(2)
    # x_rec_vae is (batch, 784): take the first sample, not a single pixel
    plt.imshow(x_rec_vae[0].reshape(28, 28).cpu().numpy())
    return

In [None]:
config = {
  "model_type": 'DD-VAE',
  "learning_rate": 0.001,
  "epochs": 5,
  "batch_size": 1024,
  "z_dim": 2,
}

wandb.init(project="test-project", entity="inspired-minds", name='dev', config=config)


model_type = config['model_type']
learning_rate = config['learning_rate']
epochs = config['epochs']
batch_size = config['batch_size']
z_dim = config['z_dim']

model = DD_VAE(model_tpye=model_type, z_dim=z_dim)
# build the loader from the config so that "batch_size" actually takes effect
dataloader = data.DataLoader(train_set, batch_size=batch_size, shuffle=True)

model.train(learning_rate, epochs, dataloader)

# close the run so re-running this cell starts a fresh wandb run
wandb.finish()

In [None]:
def sample(model, num_samples=16, z_center=(-0.5, -0.2)):
  '''
  Decode `num_samples` latents drawn from N(z_center, I) and plot them, 4 per row.

  Only meaningful for an 'N-VAE' model, whose decoder_VAE takes (None, z_mean, z_log_var).

  Parameters:
    model: trained DD_VAE of type 'N-VAE'
    num_samples: number of images to draw (any positive int; previously only 16 worked)
    z_center: mean of the sampling Gaussian; its length must equal the model's z_dim
  '''
  model_device = next(model.parameters()).device
  z_mean = torch.tensor([list(z_center)] * num_samples, device=model_device)
  z_log_var = torch.zeros_like(z_mean)  # log-variance 0 -> unit variance

  with torch.no_grad():
    sampled_img = model.decoder_VAE(None, z_mean, z_log_var).cpu().numpy()
  sampled_img = sampled_img.reshape(num_samples, 28, 28)

  ncols = 4
  nrows = -(-num_samples // ncols)  # ceiling division
  fig, ax = plt.subplots(nrows, ncols, squeeze=False)
  for i in range(num_samples):
    ax[i // ncols, i % ncols].imshow(sampled_img[i])

In [None]:
def visualize_manifold(grid_size=10, vae=None):
  '''
  For latent space z_dim=2 this returns a visualisation of the learned manifold.

  A grid_size x grid_size grid of latent points is built from standard-normal quantiles at
  evenly spaced probabilities (0.5/grid_size, 1.5/grid_size, ...). Each point is decoded with
  (numerically) zero variance, so every image is the decoder output at that exact point.
  Only meaningful for an 'N-VAE' model.

  Parameters:
    grid_size: number of points per latent axis
    vae: model to decode with; defaults to the notebook-level `model`

  Returns:
    numpy array of shape (grid_size**2, 28, 28) with pixel values mapped back to [0, 1]
  '''
  vae = model if vae is None else vae
  model_device = next(vae.parameters()).device

  # linspace gives exactly grid_size probabilities (float arange can be off by one)
  probs = torch.linspace(0.5 / grid_size, 1 - 0.5 / grid_size, grid_size)
  zs = torch.distributions.Normal(0, 1).icdf(probs)

  grid_x, grid_y = torch.meshgrid(zs, zs, indexing='ij')
  z_mean = torch.stack([grid_x.flatten(), grid_y.flatten()], dim=1).to(model_device)
  # log-variance -1e15 -> variance exp(-1e15) == 0, i.e. deterministic decoding
  z_log_var = torch.full_like(z_mean, -1e15)

  with torch.no_grad():
    decoded = vae.decoder_VAE(None, z_mean, z_log_var)

  # the decoder regresses images normalised with Normalize((0.5,), (0.5,)); undo that mapping
  # rather than applying a softmax across the 784 pixels
  sampled_img = (decoded * 0.5 + 0.5).clamp(0, 1).reshape(-1, 28, 28).cpu().numpy()

  fig, ax = plt.subplots(grid_size, grid_size, figsize=(20, 20), squeeze=False)
  for i in range(grid_size ** 2):
    ax[i // grid_size, i % grid_size].imshow(sampled_img[i], cmap='gray', vmin=0, vmax=1)

  return sampled_img