In [None]:
# standard packages
import numpy as np
import matplotlib.pyplot as plt

# pytorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as  data

# torchvision
import torchvision
from torchvision.datasets import MNIST
from torchvision import transforms

# Wandb
import wandb

# paths
DATASET_PATH = './data'
CHECKPOINT_PATH = './checkpoints'
# the original (misspelled) name is kept as an alias so existing references keep working
CECKPOINT_PATH = CHECKPOINT_PATH

# seed every random number generator used in this notebook
seed = 7
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

# ensure reproducibility: deterministic cuDNN kernels, no auto-tuner
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# device: use the GPU when one is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

In [None]:
# transformations applied to all images:
# convert to a tensor, then shift/scale with mean 0.5 and std 0.5
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5,), (0.5,))])

# load the training split and carve a validation set out of it
train_dataset = MNIST(root=DATASET_PATH, train=True, transform=transform, download=True)
train_set, val_set = data.random_split(train_dataset, [50000, 10000])

# load the held-out test split (train=False; train=True would reload the training images)
test_set = MNIST(root=DATASET_PATH, train=False, transform=transform, download=True)

# data loaders: one per split, each under its own name so none is overwritten
train_loader = data.DataLoader(train_set, batch_size=64, shuffle=True)
val_loader = data.DataLoader(val_set, batch_size=64, shuffle=False)
test_loader = data.DataLoader(test_set, batch_size=64, shuffle=False)

In [None]:
# grab a single batch and preview its first image together with the matching label
image, label = next(iter(train_loader))
sample_1 = image[0,0]
plt.figure(1)
plt.imshow(sample_1)
# index 0 so the printed label belongs to the image shown above
print(label[0])

In [None]:
class Encoder(nn.Module):

  '''
    Fully connected encoder: 784 -> 512 -> 256, followed by two linear
    heads that each map the 256 features to z_dim outputs.

    Parameters:
      model_type: free-form tag stored on the module (not used by this class)
      input_channels: accepted for interface compatibility, currently unused
      hid_dim: accepted for interface compatibility, currently unused
      z_dim: size of each output head (latent dimension)
      act_fn: activation class instantiated after each hidden linear layer
    '''

  def __init__(
      self, 
      model_type=None, # 'uniform' 'normal'
      input_channels: int=1, 
      hid_dim: int=128,
      z_dim: int=2,
      act_fn: object=nn.ReLU
  ):

    super().__init__()
    self.model_type = model_type

    # shared trunk; layer order is kept so state_dict keys stay the same
    trunk_layers = [
        nn.Linear(784, 512),
        act_fn(),
        nn.Linear(512, 256),
        act_fn(),
    ]
    self.encoder = nn.Sequential(*trunk_layers)

    # two heads on top of the trunk
    self.z_mean = nn.Linear(256, z_dim)
    self.z_log_var = nn.Linear(256, z_dim)

  def forward(self, x):
    ''' Returns the pair (z_mean, z_log_var) computed from input x. '''
    features = self.encoder(x)
    return self.z_mean(features), self.z_log_var(features)

class Decoder(nn.Module):

  '''
  Fully connected decoder: samples z from a Normal distribution, then maps
  z_dim -> 256 -> 512 -> 784.

  Parameters:
    model_type: free-form tag stored on the module (not used by this class)
    output_channels: accepted for interface compatibility, currently unused
    hid_dim: accepted for interface compatibility, currently unused
    z_dim: latent dimension
    act_fn: activation class instantiated after each hidden linear layer
  '''

  def __init__(
      self,
      model_type=None, # VAE, AE, uniform, normal
      output_channels: int=1, 
      hid_dim: int=128,
      z_dim: int=2,
      act_fn: object=nn.ReLU
  ):
    
    super().__init__()

    self.model_type = model_type

    # no output activation: the final layer returns raw values
    self.decoder = nn.Sequential(
        nn.Linear(256, 512),
        act_fn(),
        nn.Linear(512,784),
    )

    # Normal VAE: lifts the sampled latent to the decoder's input width
    self.linear_ = nn.Sequential(
      nn.Linear(z_dim, 256),
      act_fn()
    )

  def normal_forward(self, z_mean, z_log_var):
    '''
    Draws a reparameterised sample z ~ N(z_mean, exp(z_log_var)) and
    projects it to the decoder input.
    '''
    # Normal() expects the standard deviation as its scale, so the
    # log-variance has to be halved before exponentiating
    z_std = torch.exp(0.5 * z_log_var)
    z = torch.distributions.Normal(z_mean, z_std).rsample()
    return self.linear_(z)

  def forward(self, x, z_mean=None, z_log_var=None):
    '''
    Decodes a latent sample drawn from (z_mean, z_log_var).

    x is ignored; it is kept only so the positional call signature stays
    the same. z_mean and z_log_var must both be given.
    '''
    x = self.normal_forward(z_mean, z_log_var)
    x = self.decoder(x)
    return x

In [None]:
class DD_VAE(nn.Module):

  '''
    Encoder/decoder pair trained with a reconstruction loss plus a
    KL regulariser.

    Parameters:
      model_tpye: free-form model tag stored on the module (spelling kept
        because callers pass it by keyword)
      encoder: encoder class, instantiated as encoder('vae', input_channels, hid_dim, z_dim)
      decoder: decoder class, instantiated as decoder('normal', input_channels, hid_dim, z_dim)
      input_channels: forwarded to the encoder and decoder constructors
      hid_dim: forwarded to the encoder and decoder constructors
      z_dim: latent dimension, forwarded to the encoder and decoder constructors
    '''

  def __init__(
      self,
      # model type
      model_tpye: str = 'DD-VAE', # 'DD-VAE', N-VAE, U-VAE
      # Encoder, Decoder
      encoder: object = Encoder,
      decoder: object = Decoder,
      # model specifications
      input_channels: int=1, 
      hid_dim: int=128,
      z_dim: int=2,

  ):

    super().__init__()
    self.model_tpye = model_tpye

    self.encoder = encoder('vae',input_channels, hid_dim, z_dim).to(device)
    self.decoder_VAE = decoder('normal', input_channels, hid_dim, z_dim).to(device)

    # latent statistics of the most recent forward pass, read by compute_loss
    self._z_mean = None
    self._z_log_var = None


  def forward(self, x):
    '''
    Forward pass through the encoder and the decoder.
    Returns the reconstruction and caches the latent statistics.
    '''

    z_mean, z_log_var = self.encoder(x)
    self._z_mean, self._z_log_var = z_mean, z_log_var

    # the decoder ignores its first positional argument
    return self.decoder_VAE(None, z_mean, z_log_var)


  def compute_loss(self, x, x_rec_vae, z_mean=None, z_log_var=None):
    ''' 
    Computes the total loss and its two components.

    rec_loss: mean squared error between reconstruction and input
    reg_loss: KL divergence between the diagonal Gaussian described by
      (z_mean, z_log_var) and a standard normal, summed over dim 1 and
      averaged over dim 0

    z_mean / z_log_var default to the values cached by the last forward().
    Returns (rec_loss + reg_loss, rec_loss, reg_loss).
    Raises RuntimeError when no latent statistics are available.
    '''

    if z_mean is None:
      z_mean = self._z_mean
    if z_log_var is None:
      z_log_var = self._z_log_var
    if z_mean is None or z_log_var is None:
      raise RuntimeError('compute_loss needs z_mean and z_log_var: call forward() first or pass them explicitly')

    rec_loss = F.mse_loss(x_rec_vae, x)
    reg_loss = torch.mean(-0.5 * torch.sum(1 + z_log_var - z_mean ** 2 - z_log_var.exp(), dim = 1), dim = 0)

    return rec_loss + reg_loss, rec_loss, reg_loss

  @staticmethod
  def _KL(P,Q):
    ''' Kl-Divergence between two distributions '''
    # eps keeps log() and the division away from zero
    eps = 1e-15
    P = P + eps
    Q = Q + eps
    return torch.sum(P*torch.log(P/Q))


  def train(self, learning_rate=True, num_epochs=None, dataloader=None):
    '''
    trains the model

    Called as train(learning_rate, num_epochs, dataloader) this runs the
    training loop. Because the name shadows nn.Module.train, a call with
    only the first argument (as made by .eval() and .train(mode)) is
    forwarded to nn.Module.train so mode switching keeps working.
    '''

    if num_epochs is None and dataloader is None:
      return super().train(bool(learning_rate))
    if num_epochs is None or dataloader is None:
      raise TypeError('train() needs learning_rate, num_epochs and dataloader to run the training loop')

    return self.fit(learning_rate, num_epochs, dataloader)


  def fit(self, learning_rate, num_epochs, dataloader):
    '''
    Runs num_epochs passes over dataloader with Adam, logging to wandb.
    Each batch is flattened to (batch, -1) before it is fed to the model.
    '''

    self.to(device)
    super().train(True)
    optimizer = optim.Adam(self.parameters(), lr=learning_rate)

    for e in range(num_epochs):

      # plain floats: accumulating the loss tensors would keep every
      # batch's autograd graph alive for the whole epoch
      epoch_loss = 0.0
      epoch_rec_loss = 0.0
      epoch_reg_loss = 0.0

      for batch_images, _ in dataloader:

        batch_images = batch_images.to(device)
        batch_images = batch_images.reshape(batch_images.shape[0], -1)

        optimizer.zero_grad()

        # pass batch through the model
        x_rec_vae = self(batch_images)

        loss, rec_loss, reg_loss = self.compute_loss(batch_images, x_rec_vae)

        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        epoch_rec_loss += rec_loss.item()
        epoch_reg_loss += reg_loss.item()

        instance_loss = loss.item() / batch_images.shape[0]

        wandb.log({"instance_loss": instance_loss})

      # one call so the three epoch metrics share a single wandb step
      wandb.log({
        "epoch_loss": epoch_loss,
        "rec_loss": epoch_rec_loss,
        "reg_loss": epoch_reg_loss,
      })
      print(f'Epoch: {e} done, Loss: {epoch_loss}, Rec_Loss: {epoch_rec_loss}, Reg_Loss: {epoch_reg_loss}')

    print('Training complete')

  def generate(self, model, x):
    '''
    Reconstructs x with model and shows the first reconstruction as a
    28x28 image in figure 2.
    '''
    # forward() returns only the reconstruction; no gradients are needed here
    with torch.no_grad():
      x_rec_vae = model.forward(x)
    x_rec_vae = x_rec_vae.cpu().detach().numpy()
    plt.figure(2)
    # works for any leading shape as long as the last axis holds 784 values
    x = x_rec_vae.reshape(-1, 28, 28)[0]
    plt.imshow(x)
    return

In [None]:
# single source of truth for the run, so wandb records what is actually executed
config = {
  "model_type": 'N-VAE',
  "learning_rate": 0.001,
  "epochs": 20,
  # taken from the loader so the logged value matches the batches really used
  "batch_size": train_loader.batch_size,
  "z_dim": 10,
}

wandb.init(project="test-project", entity="inspired-minds", name='test2', config=config)


model_type = config["model_type"]
learning_rate = config["learning_rate"]
epochs = config["epochs"]
batch_size = config["batch_size"]
z_dim = config["z_dim"]

model = DD_VAE(model_tpye=model_type, z_dim=z_dim)
dataloader = train_loader

model.train(learning_rate, epochs, dataloader)

# new name instead of overwriting sample_1, so this cell can be re-run safely
sample_flat = sample_1.to(device).reshape(1,-1)
model.generate(model, sample_flat.unsqueeze(0).unsqueeze(0))