In [2]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

In [None]:
class VQVAE(nn.Module):
    """Minimal VQ-VAE for single-channel images with a tiny codebook.

    The encoder downsamples the input with two stride-2 convolutions, a 1x1
    convolution projects the features to 2 channels, and every spatial
    position is snapped to the nearest of 3 learnable 2-D codebook vectors.
    The quantized map is projected back to 4 channels and decoded with two
    stride-2 transposed convolutions followed by ``tanh``.

    The codebook is deliberately small (3 vectors of 2 values each) so that
    the embeddings can be plotted directly in 2-D.
    """

    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Conv2d(in_channels=16, out_channels=4, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(4),
            nn.ReLU(),
        )

        # 1x1 conv down to the codebook dimension: each codebook vector has 2 values.
        self.pre_quant_conv = nn.Conv2d(in_channels=4, out_channels=2, kernel_size=1)
        # Only 3 embeddings of dimension 2, so they are easy to visualize.
        self.embedding = nn.Embedding(num_embeddings=3, embedding_dim=2)
        # 1x1 conv back up to the channel count the decoder expects.
        self.post_quant_conv = nn.Conv2d(in_channels=2, out_channels=4, kernel_size=1)

        # Weight (beta) of the commitment term in the quantization loss.
        self.commitment_beta = 0.2

        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(in_channels=4, out_channels=16, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=16, out_channels=1, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        """Encode, vector-quantize and decode a batch of images.

        Args:
            x: Input tensor laid out as (B, 1, H, W); the first encoder
                convolution requires exactly one input channel.

        Returns:
            A tuple ``(output, quantize_losses, min_encoding_indices)``:

            * ``output`` -- decoder reconstruction, squashed by ``tanh``.
            * ``quantize_losses`` -- scalar
              ``codebook_loss + commitment_beta * commitment_loss``.
            * ``min_encoding_indices`` -- integer tensor of shape
              (B, H', W') with the index of the selected codebook vector at
              every latent position.
        """
        encoded_output = self.encoder(x)
        quant_input = self.pre_quant_conv(encoded_output)

        # --- quantization ---
        B, C, H, W = quant_input.shape
        # Move the latent (channel) dimension last and flatten the spatial
        # grid: (B, C, H, W) -> (B, H*W, C).
        quant_input = quant_input.permute(0, 2, 3, 1).reshape((B, H * W, C))

        # Pairwise distance between every latent vector and every codebook
        # vector; the codebook is tiled once per batch element.
        codebook = self.embedding.weight
        dist = torch.cdist(quant_input, codebook[None, :].repeat((B, 1, 1)))

        # Index of the nearest codebook vector for each latent position.
        min_encoding_indices = torch.argmin(dist, dim=-1)

        # Gather the selected codebook vectors: (B*H*W, C).
        quant_out = torch.index_select(codebook, 0, min_encoding_indices.view(-1))
        quant_input = quant_input.reshape((-1, C))

        # The commitment term moves the encoder output toward the (frozen)
        # codebook vector; the codebook term moves the codebook vector toward
        # the (frozen) encoder output.
        commitment_loss = torch.mean((quant_out.detach() - quant_input) ** 2)
        codebook_loss = torch.mean((quant_out - quant_input.detach()) ** 2)
        quantize_losses = codebook_loss + self.commitment_beta * commitment_loss

        # Straight-through estimator: the forward value is the quantized
        # vector, while the gradient flows to the encoder as if quantization
        # were the identity.
        quant_out = quant_input + (quant_out - quant_input).detach()

        # Restore the image layout: (B*H*W, C) -> (B, H, W, C) -> (B, C, H, W).
        quant_out = quant_out.reshape((B, H, W, C)).permute(0, 3, 1, 2)
        min_encoding_indices = min_encoding_indices.reshape((B, H, W))

        # --- decoding ---
        decoder_input = self.post_quant_conv(quant_out)
        output = self.decoder(decoder_input)
        return output, quantize_losses, min_encoding_indices
        
        
        
        
        

        
    