In [6]:
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

# Configuration
DATA_DIR = 'data'
BATCH_SIZE = 32
SEED = 0
# DataLoader warns (and oversubscribes the CPU) when asked for more worker
# processes than the machine has cores; cap the original 12 at what exists.
NUM_WORKERS = min(12, os.cpu_count() or 1)

# Seed torch's global RNG so the training-loader shuffle order (and therefore
# `sample`) and any later weight initialisation are reproducible across runs.
torch.manual_seed(SEED)

# load MNIST; ToTensor converts each image to a float tensor scaled to [0, 1]
train_ds = MNIST(DATA_DIR, train=True, download=True, transform=ToTensor())
test_ds = MNIST(DATA_DIR, train=False, download=True, transform=ToTensor())

# create data loaders (the held-out MNIST test split serves as validation data)
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
val_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

# one (images, labels) batch for shape checks below; images are (BATCH_SIZE, 1, 28, 28)
sample = next(iter(train_loader))

In [29]:
class Encoder(nn.Module):
    """Convolutional encoder mapping images to a spatial grid of latent vectors.

    Two stride-2 3x3 convolutions each halve the spatial size (rounding up),
    e.g. 28x28 -> 14x14 -> 7x7. A final 1x1 convolution then projects the
    features to ``latent_dim`` channels.

    Args:
        latent_dim: number of channels in the output latent grid.
        in_channels: number of channels in the input images (1 for MNIST).
    """

    def __init__(self, latent_dim=16, in_channels=1):
        super().__init__()
        # The attribute name and layer order are kept so existing checkpoints
        # (state_dict keys encoder.0.*, encoder.2.*, encoder.4.*) still load.
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 32, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, latent_dim, 1, padding=0),
        )

    def forward(self, x):
        """Encode (N, in_channels, H, W) images to (N, latent_dim, ceil(H/4), ceil(W/4))."""
        return self.encoder(x)


In [30]:
print(sample[0].shape)

encoder = Encoder()
out = encoder(sample[0])
print(out.shape)
# Two stride-2 convs shrink 28x28 -> 7x7 and channels become latent_dim (16 by default);
# fail loudly instead of relying on someone reading the printed shape.
assert out.shape == (sample[0].shape[0], 16, 7, 7), out.shape

torch.Size([32, 1, 28, 28])
torch.Size([32, 16, 7, 7])


In [46]:
class Decoder(nn.Module):
    """Transposed-convolution decoder mapping a latent grid back to images.

    Two stride-2 transposed convolutions (with ``output_padding=1``) each
    exactly double the spatial size, e.g. 7x7 -> 14x14 -> 28x28. A final
    stride-1 3x3 layer maps to ``out_channels``. No output activation is
    applied, so the returned values are unbounded.

    Args:
        latent_dim: number of channels in the input latent grid.
        out_channels: number of channels in the reconstructed images (1 for MNIST).
    """

    def __init__(self, latent_dim=16, out_channels=1):
        super().__init__()
        # The attribute name and layer order are kept so existing checkpoints
        # (state_dict keys decoder.0.*, decoder.2.*, decoder.4.*) still load.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(latent_dim, 64, 3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, out_channels, 3, padding=1),
        )

    def forward(self, x):
        """Decode (N, latent_dim, H, W) latents to (N, out_channels, 4H, 4W)."""
        return self.decoder(x)

In [47]:
# Random latent grid shaped like the encoder output: (batch 32, 16 channels, 7x7).
randn = torch.randn(32, 16, 7, 7)
decoder = Decoder()
out = decoder(randn)
print(out.shape)
# The decoder must invert the encoder's 4x downsampling back to MNIST-sized images.
assert out.shape == (32, 1, 28, 28), out.shape

torch.Size([32, 1, 28, 28])


In [4]:
class VectorQuantizer(nn.Module):
    """Vector-quantization bottleneck (VQ-VAE style).

    Each ``embedding_dim``-long feature vector of the input is replaced by its
    nearest codebook entry (squared Euclidean distance). Gradients reach the
    input through the straight-through estimator. The combined codebook and
    commitment loss of the latest forward pass is stored on ``self.loss``.

    Args:
        num_embeddings: number of codebook entries K.
        embedding_dim: length D of each codebook vector. This must equal the
            size of the input along ``channel_dim``.
        beta: weight of the commitment loss term.
        channel_dim: input axis holding the D features. Defaults to 1, which
            matches channels-first (N, C, H, W) tensors such as this notebook's
            Encoder output. Pass -1 for channels-last inputs.
    """

    def __init__(self, num_embeddings, embedding_dim, beta=0.25, channel_dim=1):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.beta = beta
        self.channel_dim = channel_dim

        # Codebook stored column-wise with shape (D, K), one code per column,
        # initialised with torch.nn.init.uniform_ (default range [0, 1)).
        self.embeddings = nn.Parameter(
            nn.init.uniform_(torch.empty(self.embedding_dim, self.num_embeddings))
        )

    def forward(self, x):
        """Quantize ``x`` and return a tensor of the same shape.

        The returned values equal the selected codebook vectors. The gradient of
        the output with respect to ``x`` is the identity (straight-through).
        This method also sets ``self.loss = beta * commitment_loss + codebook_loss``.

        Raises:
            ValueError: if ``x.shape[channel_dim] != embedding_dim``.
        """
        if x.shape[self.channel_dim] != self.embedding_dim:
            raise ValueError(
                f"expected {self.embedding_dim} features along dim "
                f"{self.channel_dim}, got input of shape {tuple(x.shape)}"
            )

        # Move the feature axis last so each row of `flattened` is the feature
        # vector of a single position. Flattening a channels-first tensor
        # directly would mix different channels and positions into one row.
        x_last = x.movedim(self.channel_dim, -1)
        flattened = x_last.reshape(-1, self.embedding_dim)

        encoding_indices = self.get_code_indices(flattened)
        # A row lookup equals one_hot(indices) @ embeddings.T, with the same
        # values and the same gradient to the chosen codes, but it never builds
        # an (N, K) one-hot matrix.
        quantized = self.embeddings.t()[encoding_indices]
        quantized = quantized.view(x_last.shape).movedim(-1, self.channel_dim)

        # The commitment term moves the encoder output toward its (frozen) code.
        # The codebook term moves the code toward the (frozen) encoder output.
        commitment_loss = torch.mean((quantized.detach() - x) ** 2)
        codebook_loss = torch.mean((quantized - x.detach()) ** 2)
        self.loss = self.beta * commitment_loss + codebook_loss

        # Straight-through estimator: the forward value is `quantized`, and
        # backward passes the upstream gradient to `x` unchanged.
        return x + (quantized - x).detach()

    def get_code_indices(self, flattened_inputs):
        """Return, for each row of ``flattened_inputs``, the index of the nearest code.

        Args:
            flattened_inputs: (N, D) tensor of feature vectors.

        Returns:
            (N,) tensor of indices into the K codebook columns.
        """
        # ||a - e||^2 = ||a||^2 + ||e||^2 - 2 a.e, evaluated for every (row, code) pair.
        similarity = torch.matmul(flattened_inputs, self.embeddings)
        distances = (
            torch.sum(flattened_inputs ** 2, dim=1, keepdim=True)
            + torch.sum(self.embeddings ** 2, dim=0)
            - 2 * similarity
        )
        return torch.argmin(distances, dim=1)