<a href="https://colab.research.google.com/github/Ryan0v0/ACM-codes/blob/master/vq_vae.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
#!pip3 install -U -r requirements.txt

# Step1: Splitting up neural net params into chunks

In [2]:
import torch
import torch.nn as nn

# Define the neural network architecture
class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.fc1 = nn.Linear(10, 100)
        self.fc2 = nn.Linear(100, 100)
        self.fc3 = nn.Linear(100, 1)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Create an instance of the neural network
net = NeuralNetwork()

# Split up the neural network parameters into chunks
chunk_size = 1000
param_chunks = []
for param in net.parameters():
    flattened_param = param.view(-1)
    chunks = torch.split(flattened_param, chunk_size)
    param_chunks.extend(chunks)

# Print the number of parameter chunks
print("Number of parameter chunks:", len(param_chunks))
print("Parameter chunks:", param_chunks)


Number of parameter chunks: 15
Parameter chunks: [tensor([-6.4396e-02,  2.1130e-01, -1.1838e-01,  2.9958e-01,  2.1870e-01,
         2.6336e-01,  2.5198e-01,  2.9975e-01, -3.1238e-01,  1.7613e-01,
         1.0182e-01, -1.2976e-01,  2.8008e-01,  2.3157e-01, -2.1085e-02,
         2.4409e-01, -1.8932e-01,  3.0637e-01,  2.3122e-01,  8.3699e-02,
        -2.1217e-02, -1.2121e-01,  2.6251e-01,  2.1006e-01,  2.7667e-01,
         4.6984e-02,  1.7459e-01,  8.8939e-03, -5.0475e-02,  2.3915e-01,
         3.1617e-01, -2.3763e-01, -2.1570e-01, -2.0415e-01,  2.7576e-01,
        -2.1171e-01,  4.5023e-03,  1.8539e-01, -1.5277e-01,  1.6450e-01,
        -2.0056e-01,  1.9189e-01,  9.8684e-02, -4.5814e-02, -1.4340e-01,
        -4.6220e-02,  7.9708e-02, -1.1465e-01,  1.1937e-01,  2.0092e-01,
         3.2817e-02, -1.5336e-01,  1.3120e-01, -4.4501e-03,  5.4510e-02,
         3.0196e-01,  1.1010e-01, -6.1521e-02, -1.1210e-02, -2.6767e-01,
        -2.0423e-01,  9.6125e-02,  2.4598e-01, -1.5151e-01,  5.7496e-02,
 

In [3]:
import numpy as np

# Convert param_chunks to a numpy array
param_chunks_np = np.concatenate([chunk.detach().numpy() for chunk in param_chunks])

print(type(param_chunks_np))
print("size:", param_chunks_np.shape)

<class 'numpy.ndarray'>
size: (11301,)


# Step2: learning a mapping from each chunk to an integer via VQ-VAE

In [4]:
from __future__ import print_function


import matplotlib.pyplot as plt
import numpy as np
from scipy.signal import savgol_filter


from six.moves import xrange

# import umap

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.optim as optim

import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torchvision.utils import make_grid

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Load Data

To verify the correctness of the VQVAE itself, I am currently using CIFAR-10 as input data.

Next step: using the weights of the above neural network as input.

In [6]:
training_data = datasets.CIFAR10(root="data", train=True, download=True,
                                  transform=transforms.Compose([
                                      transforms.ToTensor(),
                                      transforms.Normalize((0.5,0.5,0.5), (1.0,1.0,1.0))
                                  ]))

validation_data = datasets.CIFAR10(root="data", train=False, download=True,
                                  transform=transforms.Compose([
                                      transforms.ToTensor(),
                                      transforms.Normalize((0.5,0.5,0.5), (1.0,1.0,1.0))
                                  ]))

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:07<00:00, 23619885.03it/s]


Extracting data/cifar-10-python.tar.gz to data
Files already downloaded and verified


In [7]:
print(training_data.data[0])
print("size:", training_data.data.shape)
print(validation_data.data[0])
print("size:", validation_data.data.shape)

[[[ 59  62  63]
  [ 43  46  45]
  [ 50  48  43]
  ...
  [158 132 108]
  [152 125 102]
  [148 124 103]]

 [[ 16  20  20]
  [  0   0   0]
  [ 18   8   0]
  ...
  [123  88  55]
  [119  83  50]
  [122  87  57]]

 [[ 25  24  21]
  [ 16   7   0]
  [ 49  27   8]
  ...
  [118  84  50]
  [120  84  50]
  [109  73  42]]

 ...

 [[208 170  96]
  [201 153  34]
  [198 161  26]
  ...
  [160 133  70]
  [ 56  31   7]
  [ 53  34  20]]

 [[180 139  96]
  [173 123  42]
  [186 144  30]
  ...
  [184 148  94]
  [ 97  62  34]
  [ 83  53  34]]

 [[177 144 116]
  [168 129  94]
  [179 142  87]
  ...
  [216 184 140]
  [151 118  84]
  [123  92  72]]]
size: (50000, 32, 32, 3)
[[[158 112  49]
  [159 111  47]
  [165 116  51]
  ...
  [137  95  36]
  [126  91  36]
  [116  85  33]]

 [[152 112  51]
  [151 110  40]
  [159 114  45]
  ...
  [136  95  31]
  [125  91  32]
  [119  88  34]]

 [[151 110  47]
  [151 109  33]
  [158 111  36]
  ...
  [139  98  34]
  [130  95  34]
  [120  89  33]]

 ...

 [[ 68 124 177]
  [ 42 100 

In [8]:
data_variance = np.var(training_data.data / 255.0)

In [9]:
data_variance = np.var(param_chunks_np / 255.0)

print(data_variance)

9.486475e-08


## Vector Quantizer Layer

This layer takes a tensor to be quantized. The channel dimension will be used as the space in which to quantize. All other dimensions will be flattened and will be seen as different examples to quantize.

The output tensor will have the same shape as the input.

As an example for a `BCHW` tensor of shape `[16, 64, 32, 32]`, we will first convert it to an `BHWC` tensor of shape `[16, 32, 32, 64]` and then reshape it into `[16384, 64]` and all `16384` vectors of size `64`  will be quantized independently. In otherwords, the channels are used as the space in which to quantize. All other dimensions will be flattened and be seen as different examples to quantize, `16384` in this case.

In [10]:
class VectorQuantizer(nn.Module):
    def __init__(self, num_embeddings, embedding_dim, commitment_cost):
        super(VectorQuantizer, self).__init__()

        self._embedding_dim = embedding_dim
        self._num_embeddings = num_embeddings

        self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim)
        self._embedding.weight.data.uniform_(-1/self._num_embeddings, 1/self._num_embeddings)
        self._commitment_cost = commitment_cost

    def forward(self, inputs):
        # convert inputs from BCHW -> BHWC
        inputs = inputs.permute(0, 2, 3, 1).contiguous()
        input_shape = inputs.shape

        # Flatten input
        flat_input = inputs.view(-1, self._embedding_dim)

        # Calculate distances
        distances = (torch.sum(flat_input**2, dim=1, keepdim=True)
                    + torch.sum(self._embedding.weight**2, dim=1)
                    - 2 * torch.matmul(flat_input, self._embedding.weight.t()))

        # Encoding
        encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
        encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device)
        encodings.scatter_(1, encoding_indices, 1)

        # Quantize and unflatten
        quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape)

        # Loss
        e_latent_loss = F.mse_loss(quantized.detach(), inputs)
        q_latent_loss = F.mse_loss(quantized, inputs.detach())
        loss = q_latent_loss + self._commitment_cost * e_latent_loss

        quantized = inputs + (quantized - inputs).detach()
        avg_probs = torch.mean(encodings, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        # convert quantized from BHWC -> BCHW
        return loss, quantized.permute(0, 3, 1, 2).contiguous(), perplexity, encodings

We will also implement a slightly modified version  which will use exponential moving averages to update the embedding vectors instead of an auxillary loss. This has the advantage that the embedding updates are independent of the choice of optimizer for the encoder, decoder and other parts of the architecture. For most experiments the EMA version trains faster than the non-EMA version.

In [11]:
class VectorQuantizerEMA(nn.Module):
    def __init__(self, num_embeddings, embedding_dim, commitment_cost, decay, epsilon=1e-5):
        super(VectorQuantizerEMA, self).__init__()

        self._embedding_dim = embedding_dim
        self._num_embeddings = num_embeddings

        self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim)
        self._embedding.weight.data.normal_()
        self._commitment_cost = commitment_cost

        self.register_buffer('_ema_cluster_size', torch.zeros(num_embeddings))
        self._ema_w = nn.Parameter(torch.Tensor(num_embeddings, self._embedding_dim))
        self._ema_w.data.normal_()

        self._decay = decay
        self._epsilon = epsilon

    def forward(self, inputs):
        # convert inputs from BCHW -> BHWC
        inputs = inputs.permute(0, 2, 3, 1).contiguous()
        input_shape = inputs.shape

        # Flatten input
        flat_input = inputs.view(-1, self._embedding_dim)

        # Calculate distances
        distances = (torch.sum(flat_input**2, dim=1, keepdim=True)
                    + torch.sum(self._embedding.weight**2, dim=1)
                    - 2 * torch.matmul(flat_input, self._embedding.weight.t()))

        # Encoding
        encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
        encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device)
        encodings.scatter_(1, encoding_indices, 1)

        # Quantize and unflatten
        quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape)

        # Use EMA to update the embedding vectors
        if self.training:
            self._ema_cluster_size = self._ema_cluster_size * self._decay + \
                                     (1 - self._decay) * torch.sum(encodings, 0)

            # Laplace smoothing of the cluster size
            n = torch.sum(self._ema_cluster_size.data)
            self._ema_cluster_size = (
                (self._ema_cluster_size + self._epsilon)
                / (n + self._num_embeddings * self._epsilon) * n)

            dw = torch.matmul(encodings.t(), flat_input)
            self._ema_w = nn.Parameter(self._ema_w * self._decay + (1 - self._decay) * dw)

            self._embedding.weight = nn.Parameter(self._ema_w / self._ema_cluster_size.unsqueeze(1))

        # Loss
        e_latent_loss = F.mse_loss(quantized.detach(), inputs)
        loss = self._commitment_cost * e_latent_loss

        # Straight Through Estimator
        quantized = inputs + (quantized - inputs).detach()
        avg_probs = torch.mean(encodings, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        # convert quantized from BHWC -> BCHW
        return loss, quantized.permute(0, 3, 1, 2).contiguous(), perplexity, encodings

## Encoder & Decoder Architecture

The encoder and decoder architecture is based on a ResNet and is implemented below:

In [12]:
class Residual(nn.Module):
    def __init__(self, in_channels, num_hiddens, num_residual_hiddens):
        super(Residual, self).__init__()
        self._block = nn.Sequential(
            nn.ReLU(True),
            nn.Conv2d(in_channels=in_channels,
                      out_channels=num_residual_hiddens,
                      kernel_size=3, stride=1, padding=1, bias=False),
            nn.ReLU(True),
            nn.Conv2d(in_channels=num_residual_hiddens,
                      out_channels=num_hiddens,
                      kernel_size=1, stride=1, bias=False)
        )

    def forward(self, x):
        return x + self._block(x)


class ResidualStack(nn.Module):
    def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens):
        super(ResidualStack, self).__init__()
        self._num_residual_layers = num_residual_layers
        self._layers = nn.ModuleList([Residual(in_channels, num_hiddens, num_residual_hiddens)
                             for _ in range(self._num_residual_layers)])

    def forward(self, x):
        for i in range(self._num_residual_layers):
            x = self._layers[i](x)
        return F.relu(x)

In [65]:
class Encoder(nn.Module):
    def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens):
        super(Encoder, self).__init__()

        '''
        self._conv_1 = nn.Conv2d(in_channels=in_channels,
                                 out_channels=num_hiddens//2,
                                 kernel_size=4,
                                 stride=2, padding=1)
        self._conv_2 = nn.Conv2d(in_channels=num_hiddens//2,
                                 out_channels=num_hiddens,
                                 kernel_size=4,
                                 stride=2, padding=1)
        self._conv_3 = nn.Conv2d(in_channels=num_hiddens,
                                 out_channels=num_hiddens,
                                 kernel_size=3,
                                 stride=1, padding=1)
        '''
        self._linear_1 = nn.Linear(in_channels, num_hiddens//2)
        self._linear_2 = nn.Linear(num_hiddens//2, num_hiddens)

        '''
        self._residual_stack = ResidualStack(in_channels=num_hiddens,
                                             num_hiddens=num_hiddens,
                                             num_residual_layers=num_residual_layers,
                                             num_residual_hiddens=num_residual_hiddens)
        '''


    def forward(self, inputs):
        # x = self._conv_1(inputs)
        x = self._linear_1(inputs.view(inputs.size(0), -1)) # x = self._linear_1(inputs)
        x = F.relu(x)

        # x = self._conv_2(x)
        x = self._linear_2(x)
        x = F.relu(x)

        # x = self._conv_3(x)
        return x # self._residual_stack(x)

In [66]:
class Decoder(nn.Module):
    def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens):
        super(Decoder, self).__init__()

        '''
        self._conv_1 = nn.Conv2d(in_channels=in_channels,
                                 out_channels=num_hiddens,
                                 kernel_size=3,
                                 stride=1, padding=1)
        '''

        self._linear_1 = nn.Linear(in_channels, num_hiddens)

        '''
        self._residual_stack = ResidualStack(in_channels=num_hiddens,
                                             num_hiddens=num_hiddens,
                                             num_residual_layers=num_residual_layers,
                                             num_residual_hiddens=num_residual_hiddens)

        self._conv_trans_1 = nn.ConvTranspose2d(in_channels=num_hiddens,
                                                out_channels=num_hiddens//2,
                                                kernel_size=4,
                                                stride=2, padding=1)

        self._conv_trans_2 = nn.ConvTranspose2d(in_channels=num_hiddens//2,
                                                out_channels=3,
                                                kernel_size=4,
                                                stride=2, padding=1)
        '''
        self._linear_2 = nn.Linear(num_hiddens, num_hiddens//2)

    def forward(self, inputs):
        # x = self._conv_1(inputs)
        x = self._linear_1(inputs.view(inputs.size(0), -1)) # x = self._linear_1(inputs)
        x = F.relu(x)

        # x = self._residual_stack(x)

        # x = self._conv_trans_1(x)
        x = self._linear_2(x)

        x = F.relu(x)

        return x # self._conv_trans_2(x)

## Train

We use the hyperparameters from the author's code:

In [67]:
batch_size = 256
num_training_updates = 15000

num_hiddens = 128
num_residual_hiddens = 32
num_residual_layers = 2

embedding_dim = 64
num_embeddings = 512

commitment_cost = 0.25

decay = 0.99

learning_rate = 1e-3

In [68]:
import torch
from torch.utils.data import DataLoader, TensorDataset

# Create a TensorDataset from param_chunks_np
dataset = TensorDataset(torch.from_numpy(param_chunks_np))

# Set the batch size and other DataLoader parameters
batch_size = 64
shuffle = True
pin_memory = True

# Create the DataLoader using the custom dataset
training_loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, pin_memory=pin_memory)

print(training_loader)

print(iter(training_loader))

data = next(iter(training_loader))

print("ori_data=", data)

print(type(next(iter(training_loader))))
# print(data[0])

data = next(iter(training_loader))
# for i in range(len(data)):
#    data[i] = data[i].to(device)
data = torch.stack(data).to(device)
print("data=", data)

# There's no label in the NN weight dataset
'''
for batch_idx, data in enumerate(training_loader):
    print("Batch Index:", batch_idx)
    print("Data:", data)
    print()
'''

<torch.utils.data.dataloader.DataLoader object at 0x7f3c79ee9f30>
<torch.utils.data.dataloader._SingleProcessDataLoaderIter object at 0x7f3c79ee9210>
ori_data= [tensor([ 0.0161,  0.0678, -0.0877,  0.0406,  0.0065, -0.0521, -0.0885, -0.2883,
        -0.0341,  0.0646,  0.0905,  0.0710,  0.0094,  0.0188, -0.0050,  0.0798,
         0.0407, -0.0976, -0.0636,  0.0140,  0.0546, -0.0873, -0.0953,  0.0580,
         0.0760, -0.0899,  0.0215, -0.0032,  0.0129, -0.0101, -0.0324, -0.0518,
         0.0512,  0.0070, -0.0245,  0.0450, -0.0626,  0.0195,  0.0714, -0.0559,
        -0.0102,  0.0126,  0.0284, -0.0664, -0.0254,  0.0215,  0.0745, -0.0946,
         0.0533, -0.0298,  0.0704, -0.0590,  0.0668, -0.0313, -0.0286, -0.0742,
        -0.0551,  0.0166,  0.0057,  0.0725, -0.0769, -0.0500, -0.0220, -0.0744])]
<class 'list'>
data= tensor([[-9.2900e-02,  1.0742e-02, -2.9678e-02, -5.3320e-02,  5.0359e-02,
          2.1486e-02, -4.9762e-02, -6.1279e-02, -3.9924e-02, -7.6475e-02,
         -8.3216e-02,  4.807

'\nfor batch_idx, data in enumerate(training_loader):\n    print("Batch Index:", batch_idx)\n    print("Data:", data)\n    print()\n'

In [69]:
cifar_training_data = datasets.CIFAR10(root="data", train=True, download=True,
                                  transform=transforms.Compose([
                                      transforms.ToTensor(),
                                      transforms.Normalize((0.5,0.5,0.5), (1.0,1.0,1.0))
                                  ]))

cifar_training_loader = DataLoader(cifar_training_data,
                             batch_size=batch_size,
                             shuffle=True,
                             pin_memory=True)

print(cifar_training_loader)

print(iter(cifar_training_loader))

(data, _) = next(iter(cifar_training_loader))
print(type(next(iter(cifar_training_loader))))
print(type(data))
print(type(_))
print("ori_data=", data)

data = data.to(device)

print("data=", data)

'''
for batch_idx, (data, _) in enumerate(cifar_training_loader):
    print("Batch Index:", batch_idx)
    # print("Data:", data)
    print("Label:", _)
    print()
'''

Files already downloaded and verified
<torch.utils.data.dataloader.DataLoader object at 0x7f3d541b3fa0>
<torch.utils.data.dataloader._SingleProcessDataLoaderIter object at 0x7f3d541b2ef0>
<class 'list'>
<class 'torch.Tensor'>
<class 'torch.Tensor'>
ori_data= tensor([[[[ 0.3196,  0.4216,  0.4059,  ...,  0.3824,  0.4098,  0.3824],
          [ 0.2137,  0.2608,  0.3588,  ...,  0.3510,  0.2843,  0.2255],
          [-0.0961, -0.0647,  0.0529,  ...,  0.2882,  0.1157, -0.0804],
          ...,
          [ 0.2569,  0.2608,  0.2804,  ...,  0.2843,  0.2686,  0.2608],
          [ 0.2922,  0.3157,  0.3157,  ...,  0.2608,  0.2765,  0.2647],
          [ 0.2882,  0.3039,  0.2922,  ...,  0.2647,  0.2647,  0.2569]],

         [[ 0.3039,  0.4020,  0.3902,  ...,  0.3706,  0.3941,  0.3706],
          [ 0.2059,  0.2451,  0.3471,  ...,  0.3314,  0.2569,  0.2059],
          [-0.1039, -0.0765,  0.0412,  ...,  0.2686,  0.0882, -0.1078],
          ...,
          [ 0.1510,  0.1392,  0.1510,  ...,  0.1627,  0.1549,

'\nfor batch_idx, (data, _) in enumerate(cifar_training_loader):\n    print("Batch Index:", batch_idx)\n    # print("Data:", data)\n    print("Label:", _)\n    print()\n'

In [70]:
validation_loader = DataLoader(validation_data,
                               batch_size=32,
                               shuffle=True,
                               pin_memory=True)

In [71]:
class Model(nn.Module):
    def __init__(self, num_hiddens, num_residual_layers, num_residual_hiddens,
                 num_embeddings, embedding_dim, commitment_cost, decay=0):
        super(Model, self).__init__()

        # Updated input size to 64
        self._encoder = Encoder(1, num_hiddens,
                                num_residual_layers,
                                num_residual_hiddens)

        '''Replaced nn.Conv2d with nn.Linear
        self._pre_vq_conv = nn.Conv2d(in_channels=num_hiddens,
                                      out_channels=embedding_dim,
                                      kernel_size=1,
                                      stride=1)
        '''
        # Replaced with nn.Linear
        self._pre_vq_linear = nn.Linear(num_hiddens, embedding_dim)

        if decay > 0.0:
            self._vq_vae = VectorQuantizerEMA(num_embeddings, embedding_dim,
                                              commitment_cost, decay)
        else:
            self._vq_vae = VectorQuantizer(num_embeddings, embedding_dim,
                                           commitment_cost)
        self._decoder = Decoder(embedding_dim,
                                num_hiddens,
                                num_residual_layers,
                                num_residual_hiddens)

    def forward(self, x):
        x = x.unsqueeze(0)  # Add an extra dimension for batch size
        z = self._encoder(x)
        z = z.view(z.size(0), -1)  # Flatten the tensor
        # Replaced self._pre_vq_conv with self._pre_vq_linear
        z = self._pre_vq_linear(z)
        loss, quantized, perplexity, _ = self._vq_vae(z)
        # Reshape quantized before passing it to the decoder
        quantized = quantized.unsqueeze(2).unsqueeze(3)
        x_recon = self._decoder(quantized)

        return loss, x_recon, perplexity

In [72]:
model = Model(num_hiddens, num_residual_layers, num_residual_hiddens,
              num_embeddings, embedding_dim,
              commitment_cost, decay).to(device)

In [73]:
optimizer = optim.Adam(model.parameters(), lr=learning_rate, amsgrad=False)

In [74]:
model.train()
train_res_recon_error = []
train_res_perplexity = []

for i in xrange(num_training_updates):
    # (data, _) = next(iter(training_loader))
    # data = data.to(device)
    data = next(iter(training_loader))
    data = torch.stack(data).to(device)
    #for i in range(len(data)):
    #  data[i] = data[i].to(device)
    optimizer.zero_grad()

    vq_loss, data_recon, perplexity = model(data)
    recon_error = F.mse_loss(data_recon, data) / data_variance
    loss = recon_error + vq_loss
    loss.backward()

    optimizer.step()

    train_res_recon_error.append(recon_error.item())
    train_res_perplexity.append(perplexity.item())

    if (i+1) % 100 == 0:
        print('%d iterations' % (i+1))
        print('recon_error: %.3f' % np.mean(train_res_recon_error[-100:]))
        print('perplexity: %.3f' % np.mean(train_res_perplexity[-100:]))
        print()

RuntimeError: ignored

## Plot Loss

In [None]:
train_res_recon_error_smooth = savgol_filter(train_res_recon_error, 201, 7)
train_res_perplexity_smooth = savgol_filter(train_res_perplexity, 201, 7)

In [None]:
f = plt.figure(figsize=(16,8))
ax = f.add_subplot(1,2,1)
ax.plot(train_res_recon_error_smooth)
ax.set_yscale('log')
ax.set_title('Smoothed NMSE.')
ax.set_xlabel('iteration')

ax = f.add_subplot(1,2,2)
ax.plot(train_res_perplexity_smooth)
ax.set_title('Smoothed Average codebook usage (perplexity).')
ax.set_xlabel('iteration')

## View Reconstructions

In [None]:
model.eval()

(valid_originals, _) = next(iter(validation_loader))
valid_originals = valid_originals.to(device)

vq_output_eval = model._pre_vq_conv(model._encoder(valid_originals))
_, valid_quantize, _, _ = model._vq_vae(vq_output_eval)
valid_reconstructions = model._decoder(valid_quantize)

In [None]:
(train_originals, _) = next(iter(training_loader))
train_originals = train_originals.to(device)
_, train_reconstructions, _, _ = model._vq_vae(train_originals)

In [None]:
def show(img):
    npimg = img.numpy()
    fig = plt.imshow(np.transpose(npimg, (1,2,0)), interpolation='nearest')
    fig.axes.get_xaxis().set_visible(False)
    fig.axes.get_yaxis().set_visible(False)

In [None]:
print(valid_reconstructions)

In [None]:
show(make_grid(valid_reconstructions.cpu().data)+0.5, )

In [None]:
show(make_grid(valid_originals.cpu()+0.5))

## View Embedding

In [None]:
proj = umap.UMAP(n_neighbors=3,
                 min_dist=0.1,
                 metric='cosine').fit_transform(model._vq_vae._embedding.weight.data.cpu())

In [None]:
plt.scatter(proj[:,0], proj[:,1], alpha=0.3)