In [None]:
# standard packages
import numpy as np
import matplotlib.pyplot as plt

# pytorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as  data

# torchvision
import torchvision
from torchvision.datasets import MNIST
from torchvision import transforms

# Wandb
import wandb

# paths
DATASET_PATH = './data'
CHECKPOINT_PATH = './checkpoints'
# Misspelled historical name, kept as an alias so existing references keep working.
CECKPOINT_PATH = CHECKPOINT_PATH

# seed: fix the NumPy and PyTorch (CPU + CUDA) generators used by this notebook
seed = 7
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

# ensure reproducibility: deterministic cuDNN kernels, benchmark mode off
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# device: use the GPU when one is visible, otherwise the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

In [None]:
import h5py
import os
from PIL import Image
import urllib.request


class fixedMNIST(data.Dataset):
    """Binarized MNIST dataset, proposed in
    http://proceedings.mlr.press/v15/larochelle11a/larochelle11a.pdf

    The three ``.amat`` text files are downloaded once and converted into a
    single HDF5 cache, ``<root>/data.h5``, holding two arrays:

    * ``"train"``: the official train and validation files concatenated
    * ``"test"``: the official test file

    Args:
        root: Directory that holds (or will hold) the downloaded files and
            the ``data.h5`` cache. ``~`` is expanded.
        train: If True serve the ``"train"`` array, otherwise ``"test"``.
        transform: Accepted for signature compatibility but ignored.
        download: If True, download and convert the data when the cache is
            missing.

    Raises:
        RuntimeError: If ``data.h5`` does not exist (after the optional
            download).
    """

    base_url = "http://www.cs.toronto.edu/~larocheh/public/datasets/binarized_mnist/"
    train_file = "binarized_mnist_train.amat"
    val_file = "binarized_mnist_valid.amat"
    test_file = "binarized_mnist_test.amat"

    def __init__(self, root, train=True, transform=None, download=False):
        # we ignore transform.
        self.root = os.path.expanduser(root)
        self.train = train  # training set or test set

        if download:
            self.download()
        if not self._check_exists():
            raise RuntimeError(
                "Dataset not found." + " You can use download=True to download it"
            )

        self.data = self._get_data(train=train)

    def __getitem__(self, index):
        """Return ``(image, dummy_target)`` for the sample at ``index``.

        The image is converted through PIL and ``ToTensor`` and cast to
        float32; the target is a constant ``-1`` because this dataset carries
        no labels.
        """
        img = self.data[index]
        img = Image.fromarray(img)
        img = transforms.ToTensor()(img).type(torch.FloatTensor)
        return img, torch.tensor(-1)  # Meaningless tensor instead of target

    def __len__(self):
        """Number of images in the selected split."""
        return len(self.data)

    def _get_data(self, train=True):
        """Load the ``"train"`` or ``"test"`` array from the HDF5 cache into memory."""
        key = "train" if train else "test"
        with h5py.File(os.path.join(self.root, "data.h5"), "r") as hf:
            # Materialise inside the ``with`` block: the file is closed afterwards.
            array = np.array(hf.get(key))
        return array

    def get_mean_img(self):
        """Return the per-pixel mean over the split as a flat vector."""
        return self.data.mean(0).flatten()

    @staticmethod
    def _read_amat(path):
        """Parse a whitespace-separated ``.amat`` text file into an int8 array (one row per line)."""
        with open(path) as f:
            # Skip blank lines so a trailing newline cannot produce a ragged row.
            rows = [[int(tok) for tok in line.split()] for line in f if line.strip()]
        return np.array(rows).astype("int8")

    def download(self):
        """Download the three ``.amat`` files and build ``data.h5``.

        Does nothing when the cache already exists. Note that a cache written
        by an earlier version of this class stored the *validation* file under
        ``"test"``; delete ``data.h5`` to rebuild it with the real test split.
        """
        if self._check_exists():
            return
        os.makedirs(self.root, exist_ok=True)

        print("Downloading MNIST with fixed binarization...")
        for filename in (self.train_file, self.val_file, self.test_file):
            url = self.base_url + filename
            print("Downloading from {}...".format(url))
            local_filename = os.path.join(self.root, filename)
            urllib.request.urlretrieve(url, local_filename)
            print("Saved to {}".format(local_filename))

        train_data = np.concatenate(
            [
                self._read_amat(os.path.join(self.root, self.train_file)),
                self._read_amat(os.path.join(self.root, self.val_file)),
            ]
        )
        # The test split must come from the test file. Reading the validation
        # file here would leak data that is already part of "train".
        test_data = self._read_amat(os.path.join(self.root, self.test_file))
        with h5py.File(os.path.join(self.root, "data.h5"), "w") as hf:
            hf.create_dataset("train", data=train_data.reshape(-1, 28, 28))
            hf.create_dataset("test", data=test_data.reshape(-1, 28, 28))
        print("Done!")

    def _check_exists(self):
        """True when the HDF5 cache is present under ``root``."""
        return os.path.exists(os.path.join(self.root, "data.h5"))

In [None]:
# Transform pipeline handed to the dataset constructor
# (fixedMNIST accepts the argument but ignores it).
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)

# binarized mnist: dataset class and its cache directory
loader_fn = fixedMNIST
root = DATASET_PATH + "/fixedmnist"

# training data, randomly split into 50k train / 10k validation samples
train_dataset = loader_fn(root=root, train=True, download=True, transform=transform)
train_set, val_set = data.random_split(train_dataset, [50000, 10000])

# held-out test data
test_set = loader_fn(root=root, train=False, download=True, transform=transform)

# data loaders: only the training loader reshuffles between epochs
train_loader = data.DataLoader(train_set, shuffle=True, batch_size=1024)
val_loader = data.DataLoader(val_set, shuffle=False, batch_size=1024)
test_loader = data.DataLoader(test_set, shuffle=False, batch_size=1024)

In [None]:
# Take the first training batch only and show its first image (channel 0).
for i, (image, label) in enumerate(train_loader):
    sample_1 = image[0, 0]
    plt.figure(1)
    plt.imshow(sample_1)
    break

In [None]:
class Encoder(nn.Module):
  '''
  Fully connected encoder: 784 -> 512 -> 256, with an activation after
  each linear layer.

  Parameters:
    z_dim: accepted for a uniform constructor signature; not used here
    act_fn: activation class, instantiated once per layer
  '''

  def __init__(self, z_dim: int=2, act_fn: object=nn.ReLU):
    super().__init__()

    widths = (784, 512, 256)
    layers = []
    # One Linear + activation pair per consecutive pair of widths.
    for n_in, n_out in zip(widths[:-1], widths[1:]):
      layers.append(nn.Linear(n_in, n_out))
      layers.append(act_fn())
    self.encoder = nn.Sequential(*layers)

  def forward(self, x):
    '''Map a batch of flattened images (last dim 784) to 256 features.'''
    features = self.encoder(x)
    return features

class Decoder(nn.Module):
  '''
  Fully connected decoder: 256 -> 512 -> 784.

  The last layer has no activation, so ``forward`` returns raw
  (pre-sigmoid) values.

  Parameters:
    z_dim: accepted for a uniform constructor signature; not used here
    act_fn: activation class used after the hidden layer
  '''

  def __init__(self, z_dim: int=2, act_fn: object=nn.ReLU):
    super().__init__()

    hidden = nn.Linear(256, 512)
    activation = act_fn()
    output = nn.Linear(512, 784)
    self.decoder = nn.Sequential(hidden, activation, output)

  def forward(self, x):
    '''Map a batch of 256-dim features to 784 output values.'''
    return self.decoder(x)

In [None]:
class DD_VAE(nn.Module):

  '''
  Auto-encoder over categorical latents with one encoder and two decoders.

  The encoder output is projected to ``z_dim`` groups of ``num_categories``
  logits. The *stochastic* decoder (``linear_vae`` + ``decoder_VAE``) receives
  a one-hot sample drawn from those logits; the *deterministic* decoder
  (``linear_ae`` + ``decoder_AE``) receives the logits themselves.

    Parameters:
      model_tpye: free-form model label; stored, not otherwise used here
        (the misspelled keyword is kept because callers pass it by name)
      encoder: encoder class, instantiated as ``encoder(z_dim)``; its output
        must have 256 features
      decoder: decoder class, instantiated as ``decoder(z_dim)`` once per
        decoder; it must take 256 features and expose a ``.decoder`` submodule
      z_dim: latent dimension (number of categorical latent variables)
      learning_rate: learning rate of both Adam optimizers
      act_fn: activation class used in the two latent-to-hidden projections
      num_categories: categories of each categorical distribution
    '''

  def __init__(
      self,
      # model type
      model_tpye: str = 'DD-VAE',
      # Encoder, Decoder
      encoder: object = Encoder,
      decoder: object = Decoder,
      # model specifications
      z_dim: int=2,
      learning_rate = 1e-3,
      act_fn: object=nn.ReLU,
      num_categories: int=10,
  ):

    super().__init__()
    self.model_tpye = model_tpye
    self.binary_cross_entropy = nn.BCELoss(reduction="none")
    self.z_dim = z_dim
    self.num_categories = num_categories
    latent_features = z_dim * num_categories

    self.encoder = encoder(z_dim).to(device)
    self.linear = nn.Linear(256, latent_features)

    # stochastic decoder (generative process)
    self.linear_vae = nn.Sequential(
      nn.Linear(latent_features, 256),
      act_fn()
    )
    self.decoder_VAE = decoder(z_dim).to(device)

    # deterministic decoder (approximation)
    self.linear_ae = nn.Sequential(
      nn.Linear(latent_features, 256),
      act_fn()
    )
    self.decoder_AE = decoder(z_dim).to(device)
    # we start with the same weights for the two decoders
    self.decoder_AE.decoder.load_state_dict(self.decoder_VAE.decoder.state_dict())

    # dirichlet prior
    concentration = torch.ones(num_categories) * 0.1
    self.prior = torch.distributions.dirichlet.Dirichlet(concentration)

    # optimizers: both are built over *all* parameters of the module; which
    # parameters actually move is decided by which loss is back-propagated.
    self.optimizer_rec = optim.Adam(self.parameters(), lr=learning_rate)
    self.optimizer_app = optim.Adam(self.parameters(), lr=learning_rate)

  def _sample_one_hot(self, logits):
    '''Draw one category per latent variable and return the flattened float one-hot code.'''
    sample = torch.distributions.Categorical(logits=logits).sample()
    one_hot = F.one_hot(sample, self.num_categories).reshape(sample.shape[0], -1)
    # .float() keeps the tensor on the device it was sampled on.
    return one_hot.float()

  def _decode_stochastic(self, z):
    '''Stochastic decoder path: latent code -> per-pixel probabilities.'''
    return torch.sigmoid(self.decoder_VAE(self.linear_vae(z)))

  def _decode_deterministic(self, z):
    '''Deterministic decoder path: flattened latent vector -> per-pixel probabilities.'''
    return torch.sigmoid(self.decoder_AE(self.linear_ae(z)))

  def dirichlet_sampling(self, num_samples=10):
    '''
    sample from dirichlet prior to explore latent space

    Parameters:
      num_samples: number of latent vectors drawn from the prior

    Returns:
      (x_rec_vae, x_rec_ae): outputs of the stochastic and the deterministic
      decoder for the same prior samples. The stochastic branch is computed
      under ``torch.no_grad()``; the deterministic branch keeps its graph.
    '''
    simplex_batch = self.prior.rsample((num_samples, self.z_dim))
    simplex_batch = simplex_batch.to(device)

    with torch.no_grad():
      # generative process (stochastic decoder)
      # NOTE(review): the Dirichlet draws are probabilities but are passed to
      # Categorical as ``logits=`` (inside _sample_one_hot), which flattens the
      # distribution. Kept as in the original; confirm whether ``probs=`` was intended.
      x_rec_vae = self._decode_stochastic(self._sample_one_hot(simplex_batch))

    # approximation of generative process (determininstic decoder)
    z_ae = simplex_batch.reshape(simplex_batch.shape[0], -1)
    x_rec_ae = self._decode_deterministic(z_ae)

    return x_rec_vae, x_rec_ae

  def forward(self, x):
    '''
    Forward pass through the encoder and both decoders

    Parameters:
      x: batch of flattened inputs accepted by the encoder

    Returns:
      (x_rec_vae, x_rec_ae): sigmoid outputs of the stochastic and the
      deterministic decoder
    '''

    x = self.encoder(x)
    simplex = self.linear(x)
    # One row of logits per latent variable. No softmax is applied here
    # (the original author noted it acted as a gradient bottleneck).
    simplex = simplex.reshape(x.shape[0], -1, self.num_categories)

    # generative process (stochastic decoder)
    x_rec_vae = self._decode_stochastic(self._sample_one_hot(simplex))

    # approximation of generative process (determininstic decoder)
    z_ae = simplex.reshape(simplex.shape[0], -1)
    x_rec_ae = self._decode_deterministic(z_ae)

    return x_rec_vae, x_rec_ae

  def optimize_reconstruction(self, x, x_rec_vae, x_rec_ae):
    '''
    Takes output of forward as input
    Computes rec_losses for each decoder, back-propagates their sum and
    steps ``optimizer_rec``.

    Returns:
      (rec_loss_vae, rec_loss_ae) as detached 0-dim tensors, so that callers
      can log or accumulate them without keeping autograd history alive.
    '''

    # BCE summed over pixels, averaged over the batch.
    rec_loss_vae = self.binary_cross_entropy(x_rec_vae, x)
    rec_loss_vae = rec_loss_vae.sum(1).mean() # TODO: sum everything instead of mean

    rec_loss_ae = self.binary_cross_entropy(x_rec_ae, x)
    rec_loss_ae = rec_loss_ae.sum(1).mean() # TODO: sum everything instead of mean

    reconstruction_loss = rec_loss_vae + rec_loss_ae

    reconstruction_loss.backward()
    self.optimizer_rec.step()
    self.optimizer_rec.zero_grad()

    return rec_loss_vae.detach(), rec_loss_ae.detach()

  def optimize_approximation(self, x_rec_vae, x_rec_ae):
    '''
    Takes output of dirchlet sampling as input
    Computes the cross-decoder approximation loss and optimizes the
    deterministic decoder to approximate the stochastic decoder.

    Returns:
      the cross loss as a detached 0-dim tensor
    '''

    # The stochastic decoder's output is the (soft) target: no gradient flows into it.
    x_rec_vae = x_rec_vae.detach()
    cross_loss = self.binary_cross_entropy(x_rec_ae, x_rec_vae) # ANESI-Loss
    cross_loss = cross_loss.sum(1).mean() #TODO: Check whether to use sum or mean

    cross_loss.backward()
    self.optimizer_app.step()
    self.optimizer_app.zero_grad()

    return cross_loss.detach()

  def fit(self, epochs, dataloader):
    '''
    Run the two-step optimisation for ``epochs`` passes over ``dataloader``.

    Each batch performs a reconstruction step on the data and then an
    approximation step on 10 samples from the Dirichlet prior. Per-batch and
    per-epoch (summed) losses are sent to wandb.

    Parameters:
      epochs: number of passes over ``dataloader``
      dataloader: iterable yielding ``(images, _)`` batches
    '''

    self.to(device)

    for e in range(epochs):

      # Plain floats: accumulating the loss tensors themselves would chain
      # every batch's autograd history together for the whole epoch.
      epoch_rec_vae_loss = 0.0
      epoch_rec_ae_loss = 0.0
      epoch_cross_loss = 0.0

      for batch_images, _ in dataloader:

        batch_images = batch_images.to(device)
        batch_images = batch_images.reshape(batch_images.shape[0], -1)

        # optimize reconstruction step
        x_rec_vae, x_rec_ae = self(batch_images)
        rec_loss_vae, rec_loss_ae = self.optimize_reconstruction(batch_images, x_rec_vae, x_rec_ae)

        # optimize approximation step
        x_rec_vae, x_rec_ae = self.dirichlet_sampling(10)
        cross_loss = self.optimize_approximation(x_rec_vae, x_rec_ae)

        rec_loss_vae = rec_loss_vae.item()
        rec_loss_ae = rec_loss_ae.item()
        cross_loss = cross_loss.item()

        epoch_rec_vae_loss += rec_loss_vae
        epoch_rec_ae_loss += rec_loss_ae
        epoch_cross_loss += cross_loss

        # One call so the three metrics share a single wandb step.
        wandb.log({
          "instance_loss": rec_loss_vae,
          "det_decoder_loss": rec_loss_ae,
          "cross_loss": cross_loss,
        })

      wandb.log({
        "epoch": e,
        "epoch_rec_vae_loss": epoch_rec_vae_loss,
        "epoch_rec_ae_loss": epoch_rec_ae_loss,
        "epoch_cross_loss": epoch_cross_loss,
      })

      print(f'Epoch: {e} done, stochastic decoder loss: {epoch_rec_vae_loss}, deterministic decoder loss: {epoch_rec_ae_loss}, approximation loss: {epoch_cross_loss}')

    print('Training complete')

  def train(self, epochs=True, dataloader=None):
    '''
    trains the model when called as ``train(epochs, dataloader)``

    This name shadows ``nn.Module.train(mode)``, which ``nn.Module.eval()``
    also calls with a single boolean. When no dataloader is given the call is
    therefore forwarded to the base class, so ``model.eval()`` and
    ``model.train()`` keep their standard meaning and return the module.

    Parameters:
      epochs: number of epochs (or the boolean ``mode`` in the base-class form)
      dataloader: training batches; ``None`` selects the base-class behaviour
    '''

    if dataloader is None:
      return super().train(epochs)

    return self.fit(epochs, dataloader)

In [None]:
# Hyperparameters of this run; the same dict is passed to wandb.init below.
config = {
  "model_type": 'DD-VAE',
  "learning_rate": 0.001,
  "epochs": 100,
  "batch_size": 1024,
  "z_dim": 10,
  "ds": 10,
}

run_name = f'z_dim = {config["z_dim"]}, ds={config["ds"]}'
wandb.init(
  project="test-project",
  entity="inspired-minds",
  name=run_name,
  config=config,
)

# Unpack the settings into the notebook-level names.
# NOTE(review): batch_size is read here but the loaders were built with a literal 1024.
model_type, learning_rate, epochs, batch_size, z_dim = (
  config[key]
  for key in ('model_type', 'learning_rate', 'epochs', 'batch_size', 'z_dim')
)

# Build the model and train it on the training loader.
model = DD_VAE(model_tpye=model_type, z_dim=z_dim, learning_rate=learning_rate)
dataloader = train_loader

model.train(epochs, dataloader)

In [None]:
# Bare expression: the notebook displays the model's repr as the cell output.
model

In [None]:
def visualize_reconstructions():
  '''
  Decode a grid of latent codes with the stochastic decoder and plot it.

  The first two categorical latent variables sweep over every pair of
  categories; any further latent variables are held at category 0.
  For latent space z_dim=2 this returns a visualisation of the learned manifold.

  Reads the notebook-level ``model`` and ``device``. The number of categories
  is taken from ``model.num_categories`` when present, otherwise 10.

  Returns:
    numpy array of shape (K*K, 28, 28) with per-pixel sigmoid probabilities
    (K = number of categories); image ``i`` uses category ``i // K`` for the
    first latent variable and ``i % K`` for the second.

  Raises:
    ValueError: if the model has fewer than two latent variables.
  '''

  n_categories = getattr(model, 'num_categories', 10)
  if model.z_dim < 2:
    raise ValueError(
      f'visualize_reconstructions needs at least 2 latent variables, got z_dim={model.z_dim}'
    )

  z_1 = torch.arange(0, n_categories, 1)
  mesh_grid_x, mesh_grid_y = torch.meshgrid(z_1, z_1, indexing='ij')

  # Category index per latent variable; latents beyond the first two stay at 0
  # so the one-hot code always has the z_dim * K features linear_vae expects.
  codes = torch.zeros(n_categories ** 2, model.z_dim, dtype=torch.long)
  codes[:, 0] = mesh_grid_x.flatten()
  codes[:, 1] = mesh_grid_y.flatten()

  z = torch.nn.functional.one_hot(codes, n_categories).reshape(codes.shape[0], -1)
  z = z.float().to(device)

  # Plotting only: no autograd graph needed.
  with torch.no_grad():
    sampled_img = model.linear_vae(z)
    sampled_img = model.decoder_VAE(sampled_img)
    # Same output non-linearity as the model's own forward pass (per-pixel
    # sigmoid); a softmax across pixels would not give pixel probabilities.
    sampled_img = torch.sigmoid(sampled_img)

  sampled_img = sampled_img.reshape(-1, 28, 28).cpu().numpy()

  fig, ax = plt.subplots(n_categories, n_categories, figsize=(20, 20))
  for i in range(n_categories ** 2):
    ax[i // n_categories, i % n_categories].imshow(sampled_img[i])

  return sampled_img

In [None]:
# Draw the latent grid; the returned image array is assigned to `_` and not used.
_ = visualize_reconstructions()