In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
import numpy as np
import random
import torchvision

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"using {device} device")

# Seed every RNG the notebook draws from (Python, NumPy, torch) so weight
# initialisation and data shuffling are reproducible on Restart & Run All.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
import os
# ToTensor converts the PIL digits to float tensors in [0, 1].
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)

batch_size = 256
# os.cpu_count() returns None when the CPU count cannot be determined, and
# DataLoader rejects num_workers=None; fall back to main-process loading then.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                          num_workers=os.cpu_count() or 0, pin_memory=False)

# Encoder and Decoder

In [None]:
class Encoder(nn.Module):
    """Convolutional encoder with two residual stages.

    Maps a (B, input_dim, H, W) image to a (B, output_dim, H/4, W/4) feature map
    for H, W divisible by 4 (28x28 MNIST -> 7x7). Two stride-2 convolutions
    downsample; two same-resolution 3x3 convolutions then each get a 1x1-conv
    projection shortcut. No activation is applied after the final residual sum.
    """

    def __init__(self, input_dim, output_dim=512):
        super().__init__()
        # Layer creation order is deliberately unchanged: it fixes both the
        # random parameter initialisation sequence and the state_dict keys.
        # Downsampling convs (kernel 4, stride 2, padding 1 halve H and W).
        self.conv1 = nn.Conv2d(input_dim, 64, kernel_size=4, stride=2, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        # Same-resolution convs (kernel 3, padding 1).
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(256)
        self.conv4 = nn.Conv2d(256, output_dim, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(output_dim)
        # 1x1 projections so each shortcut matches its conv branch's channel count.
        self.residual_conv1 = nn.Conv2d(128, 256, kernel_size=1)
        self.residual_conv2 = nn.Conv2d(256, output_dim, kernel_size=1)
        self.act = nn.ReLU()

    def forward(self, x):
        h = self.act(self.bn1(self.conv1(x)))
        h = self.act(self.bn2(self.conv2(h)))
        # Residual stages: projected shortcut + (conv -> BN -> ReLU) branch.
        h = self.residual_conv1(h) + self.act(self.bn3(self.conv3(h)))
        return self.residual_conv2(h) + self.act(self.bn4(self.conv4(h)))

    
class Decoder(nn.Module):
    """Transposed-convolution decoder mirroring Encoder.

    Maps a (B, input_dim, H, W) latent to (B, output_dim, 4H, 4W)
    (7x7 -> 28x28 for MNIST). The output layer has neither batch-norm nor an
    activation, so reconstructions are unbounded.
    """

    def __init__(self, input_dim, output_dim=1):
        super().__init__()
        # Creation order kept as-is (parameter init sequence and state_dict keys).
        # Same-resolution stages (kernel 3, padding 1, stride 1).
        self.deconv4 = nn.ConvTranspose2d(input_dim, 256, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(256)
        self.deconv3 = nn.ConvTranspose2d(256, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        # Upsampling stages (kernel 4, stride 2, padding 1 double H and W).
        self.deconv2 = nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.deconv1 = nn.ConvTranspose2d(64, output_dim, kernel_size=4, stride=2, padding=1)
        self.act = nn.ReLU()

    def forward(self, x):
        hidden_stages = (
            (self.deconv4, self.bn4),
            (self.deconv3, self.bn3),
            (self.deconv2, self.bn2),
        )
        for deconv, norm in hidden_stages:
            x = self.act(norm(deconv(x)))
        # Output layer: raw transposed convolution.
        return self.deconv1(x)

# Test: push one batch through an untrained Encoder/Decoder pair to check shapes.
enc = Encoder(1)
dec = Decoder(512)
x, _ = next(iter(train_loader))
with torch.no_grad():  # shape check only, so don't build an autograd graph
    z = enc(x)
    y = dec(z)
print(y.shape)
print(f"Encoded data shape: {z.shape}")

In [None]:
class VectorQuantizer(nn.Module):
    """Vector-quantisation bottleneck of a VQ-VAE.

    Replaces every D-dimensional vector of the input feature map with its nearest
    (squared-Euclidean) entry from a learnable codebook of ``num_embeddings`` vectors.

    Args:
        num_embeddings: number of codebook vectors K.
        embedding_dim: size D of each codebook vector; must match the channel
            dimension of the input feature map.
    """

    def __init__(self, num_embeddings, embedding_dim):
        super(VectorQuantizer, self).__init__()
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        nn.init.uniform_(self.embedding.weight, -1.0, 1.0)

    def forward(self, x):
        """Quantise a ``(B, D, H, W)`` feature map.

        Returns:
            quantized: ``(B, D, H, W)`` tensor of the selected codebook vectors. Its
                gradient reaches only the codebook; the straight-through step is left
                to the caller.
            commitment_loss: MSE(x, sg[quantized]); gradients reach the input only.
            codebook_loss: MSE(sg[x], quantized); gradients reach the codebook only.
            perplexity: exp of the entropy of the batch's code-usage distribution.
        """
        # (B, D, H, W) -> (B, H, W, D): one row per spatial position after flattening.
        x = x.permute(0, 2, 3, 1).contiguous()
        x_flat = x.view(-1, self.embedding_dim)
        codebook = self.embedding.weight

        # Squared distances ||x||^2 + ||e||^2 - 2 x.e, shape (B*H*W, K).
        distances = (
            torch.sum(x_flat ** 2, dim=1, keepdim=True)
            + torch.sum(codebook ** 2, dim=1)
            - 2 * torch.matmul(x_flat, codebook.t())
        )
        indices = torch.argmin(distances, dim=1).unsqueeze(1)

        # One-hot assignments, created directly on the input's device and in the
        # codebook's dtype, rather than on the CPU followed by a copy every step.
        encodings = torch.zeros(indices.shape[0], self.num_embeddings,
                                device=x.device, dtype=codebook.dtype)
        encodings.scatter_(1, indices, 1)

        # Quantize input: pick each row's codebook vector, restore (B, H, W, D).
        quantized = torch.matmul(encodings, codebook).view(x.shape)

        # Losses (sg[.] = stop-gradient, implemented with .detach()).
        codebook_loss = F.mse_loss(x.detach(), quantized)
        commitment_loss = F.mse_loss(x, quantized.detach())
        avg_probs = torch.mean(encodings, dim=0)
        # 1e-10 keeps log() finite for unused codes.
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        quantized = quantized.permute(0, 3, 1, 2).contiguous()
        return quantized, commitment_loss, codebook_loss, perplexity

In [None]:
class VQ_VAE(nn.Module):
    """Encoder -> vector quantiser -> decoder.

    The constructor takes module *instances*; the capitalised parameter names are
    kept because callers pass them by keyword.
    """

    def __init__(self, Encoder, Codebook, Decoder):
        super().__init__()
        # Tuple assignment registers the submodules left to right:
        # encoder, codebook, decoder.
        self.encoder, self.codebook, self.decoder = Encoder, Codebook, Decoder

    def forward(self, x):
        """Return (x_hat, commitment_loss, codebook_loss, perplexity) for input ``x``."""
        latents = self.encoder(x)
        quantized, commitment_loss, codebook_loss, perplexity = self.codebook(latents)
        # Straight-through estimator: the decoder sees the quantised values in the
        # forward pass, while the backward pass copies gradients onto `latents`.
        reconstruction = self.decoder(latents + (quantized - latents).detach())
        return reconstruction, commitment_loss, codebook_loss, perplexity

In [None]:
# Codebook hyper-parameters: K codebook vectors, each of dimension D.
num_embeddings = 3
embedding_dim = 2

# Objective pieces: MSE reconstruction, plus the commitment term weighted by beta.
recon = nn.MSELoss()
beta = 0.25

# Build order (encoder, codebook, decoder) fixes the parameter-init sequence.
encoder = Encoder(input_dim=1, output_dim=embedding_dim)
codebook = VectorQuantizer(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
decoder = Decoder(input_dim=embedding_dim, output_dim=1)
model = VQ_VAE(Encoder=encoder, Codebook=codebook, Decoder=decoder).to(device)

optimizer = optim.Adam(model.parameters(), lr=5e-4)

model

In [None]:
from tqdm import tqdm

num_epochs = 30
print_idx = 2  # print metrics every `print_idx` epochs

model.train()
train_losses = []
train_perplexities = []
codebook_vectors_list = []  # one codebook snapshot per epoch, used by the animation

for epoch in range(num_epochs):
    overall_loss = 0
    overall_perplexity = 0

    for (x, _) in tqdm(train_loader, desc=f'Training epoch {epoch+1}/{num_epochs}'):
        x = x.to(device)

        optimizer.zero_grad()

        x_hat, commitment_loss, codebook_loss, perplexity = model(x)
        recon_loss = recon(x_hat, x)

        # VQ-VAE objective: reconstruction + codebook term + beta * commitment term.
        loss = recon_loss + codebook_loss + beta * commitment_loss

        overall_loss += loss.item()
        overall_perplexity += perplexity.item()

        loss.backward()
        optimizer.step()

    # Calculate mean loss and mean perplexity for the epoch
    mean_loss = overall_loss / len(train_loader)
    mean_perplexity = overall_perplexity / len(train_loader)
    # .copy() is required: on a CPU device .cpu() returns the same tensor and
    # .numpy() shares its memory, so without a copy every stored snapshot would be
    # the live parameter (the optimizer updates it in place) and show the final codebook.
    codebook_vectors = model.codebook.embedding.weight.detach().cpu().numpy().copy()

    if (epoch + 1) % print_idx == 0:
        print(f'Epoch [{epoch+1}/{num_epochs}], Mean Loss: {mean_loss:.4f}, Mean Perplexity: {mean_perplexity:.4f}')

    # Save metrics for plotting
    train_losses.append(mean_loss)
    train_perplexities.append(mean_perplexity)
    codebook_vectors_list.append(codebook_vectors)


In [None]:
import matplotlib.pyplot as plt

# Plotting
# Plotting: per-epoch mean loss (left) and mean perplexity (right).
fig, (loss_ax, perplexity_ax) = plt.subplots(1, 2, figsize=(12, 4))

for axis, series, label, title, ylabel in (
    (loss_ax, train_losses, 'Training Loss', 'Training Loss Over Epochs', 'Loss'),
    (perplexity_ax, train_perplexities, 'Training Perplexity', 'Training Perplexity Over Epochs', 'Perplexity'),
):
    axis.plot(series, label=label)
    axis.set_title(title)
    axis.set_xlabel('Epoch')
    axis.set_ylabel(ylabel)
    axis.legend()

fig.tight_layout()
plt.show()

In [None]:
import matplotlib.pyplot as plt
import numpy as np

model.eval()

num_examples_to_plot = 5

# next(iter(loader)) yields an [inputs, labels] pair, so slice the inputs tensor
# itself; slicing the pair is a no-op and would run the whole batch through the model.
inputs, _ = next(iter(train_loader))
inputs = inputs[:num_examples_to_plot].to(device)

# Run the stages of `model` separately so the quantised latent can be shown too.
with torch.no_grad():
    encoded = model.encoder(inputs)
    codebook_output, _, _, _ = model.codebook(encoded)
    outputs = model.decoder(codebook_output)

inputs = inputs.cpu().numpy()
outputs = outputs.cpu().numpy()
# Average the latent channels into one displayable map
# (with embedding_dim == 2 this is (channel_0 + channel_1) / 2).
codebook_output = codebook_output.mean(dim=1, keepdim=True).cpu().numpy()
embedding_vectors = model.codebook.embedding.weight.detach().cpu().numpy()

# Plot examples: one column per example; rows are input, reconstruction, latent.
plt.figure(figsize=(3 * num_examples_to_plot, 6))

for example_idx in range(num_examples_to_plot):
    plt.subplot(3, num_examples_to_plot, 1 + example_idx)
    plt.imshow(inputs[example_idx].transpose(1, 2, 0), cmap='Greys')
    plt.title(f'Example {example_idx + 1}\nInput')

    plt.subplot(3, num_examples_to_plot, 1 + example_idx + num_examples_to_plot)
    plt.imshow(outputs[example_idx].transpose(1, 2, 0), cmap='Greys')
    plt.title('Output')

    plt.subplot(3, num_examples_to_plot, 1 + example_idx + 2 * num_examples_to_plot)
    plt.imshow(codebook_output[example_idx].transpose(1, 2, 0), cmap='viridis')
    plt.title('Codebook Output')
plt.tight_layout()
plt.show()

# Plot the codebook vectors as arrows from the origin. Only the first two
# coordinates are drawn, which is the whole vector when embedding_dim == 2.
colors = plt.cm.viridis(np.linspace(0, 1, len(embedding_vectors)))
plt.figure(figsize=(8, 8))

origin = np.zeros((2, len(embedding_vectors)), dtype=embedding_vectors.dtype)
plt.quiver(*origin, embedding_vectors[:, 0], embedding_vectors[:, 1], color=colors, scale_units='xy', angles='xy',
           scale=1)
plt.xlim(-1.5, 1.5)
plt.ylim(-1.5, 1.5)
plt.title('Codebook Vectors')
plt.tight_layout()
plt.show()

In [None]:
from matplotlib.animation import FuncAnimation

# Create a GIF of how the codebook vectors moved during training.
fig, ax = plt.subplots(figsize=(8, 8))
# Figure-level title: an axes title would be wiped by ax.clear() in every frame.
fig.suptitle('Codebook Vectors Over Training')
# Colour count taken from the snapshots themselves, not from another cell's variable.
colors = plt.cm.viridis(np.linspace(0, 1, len(codebook_vectors_list[0])))

def update(frame):
    """Redraw the codebook snapshot recorded after epoch ``frame + 1``."""
    vectors = codebook_vectors_list[frame]
    ax.clear()
    ax.quiver(np.zeros_like(vectors[:, 0]), np.zeros_like(vectors[:, 1]),
              vectors[:, 0], vectors[:, 1],
              angles='xy', scale_units='xy', scale=1, color=colors)
    ax.set_xlim(-1.5, 1.5)
    ax.set_ylim(-1.5, 1.5)
    ax.set_title(f'Epoch {frame + 1}')

ani = FuncAnimation(fig, update, frames=len(codebook_vectors_list), repeat=False)
# The Pillow writer needs no external binary (Pillow is a matplotlib dependency),
# unlike 'imagemagick'.
ani.save('codebook_vectors_animation.gif', writer='pillow', fps=10)
plt.show()

# Testing different codebook sizes (number of embeddings and embedding dimension)

## Coloured dataset

In [None]:
class ColouredMNIST(Dataset):
    """MNIST where each digit is tinted with a per-sample random colour.

    The grey channel is replicated into 3 channels, and each channel is scaled by a
    multiplier drawn with ``torch.rand`` (uniform on [0, 1)). The multipliers for
    sample ``idx`` come from a private generator seeded with ``random_seed * idx``.
    The colour of a given sample is therefore reproducible, independent of access
    order and worker process, and never touches the global torch RNG.

    Args:
        root, train, transform, download: forwarded to ``torchvision.datasets.MNIST``.
        random_seed: base seed for the per-sample colours. Note that ``idx == 0``
            always uses seed 0, whatever ``random_seed`` is.
    """

    def __init__(self, root, train=True, transform=None, download=True, random_seed=42):
        self.mnist_dataset = datasets.MNIST(root=root, train=train, transform=transform, download=download)
        self.random_seed = random_seed
        # Private generator: constructing the dataset no longer reseeds the global RNG.
        self.random_state = torch.Generator().manual_seed(self.random_seed)

    def __len__(self):
        return len(self.mnist_dataset)

    def __getitem__(self, idx):
        # A fresh CPU generator seeded per item gives the same multipliers as
        # seeding the global generator and calling torch.rand(3), without clobbering
        # the global RNG that also drives weight init and shuffling.
        generator = torch.Generator().manual_seed(self.random_seed * idx)
        original_data, target = self.mnist_dataset[idx]
        random_multipliers = torch.rand(3, generator=generator)
        three_channel_data = torch.cat(
            [original_data * multiplier for multiplier in random_multipliers],
            dim=0,
        )
        return three_channel_data, target

# Example: build the coloured training set and its loader.
transform = transforms.ToTensor()
coloured_dataset = ColouredMNIST(root='./data', train=True, transform=transform, download=True)
batch_size = 256
# os.cpu_count() may return None (rejected by DataLoader); use main-process loading then.
coloured_loader = DataLoader(coloured_dataset, batch_size=batch_size, shuffle=True,
                             num_workers=os.cpu_count() or 0, pin_memory=False)

# Access an example from the custom dataset
example_data, _ = coloured_dataset[10]
print("Example Data Shape:", example_data.shape)
plt.imshow(example_data.permute(1, 2, 0).numpy())
plt.show()

In [None]:
def VQ_VAE_trained_model(data_loader, input_dim, num_embeddings, embedding_dim,
                         device='cpu', num_epochs=30, print_idx=10, print_metric=False):
    """Build a fresh VQ-VAE and train it on ``data_loader``.

    Args:
        data_loader: sized iterable of ``(images, labels)`` batches; labels are ignored.
        input_dim: number of image channels (encoder input and decoder output).
        num_embeddings: number of codebook vectors.
        embedding_dim: size of each codebook vector (= encoder output channels).
        device: device the model and batches are moved to.
        num_epochs: number of passes over ``data_loader``.
        print_idx: with ``print_metric``, print metrics every ``print_idx`` epochs.
        print_metric: whether to print per-epoch metrics.

    Returns:
        ``(model, metrics)``. ``metrics`` maps ``'loss'`` and ``'perplexity'`` to lists
        holding, per epoch, the mean over batches of the total loss / perplexity.
    """
    # tqdm is already a notebook dependency; importing it here means the function
    # does not rely on another cell having been run first.
    from tqdm import tqdm

    encoder = Encoder(input_dim=input_dim, output_dim=embedding_dim)
    codebook = VectorQuantizer(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
    decoder = Decoder(input_dim=embedding_dim, output_dim=input_dim)

    model = VQ_VAE(Encoder=encoder, Codebook=codebook, Decoder=decoder).to(device)
    recon = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=5e-4)
    beta = 0.25

    # Average over *this* loader's batches (previously the global train_loader's);
    # max() avoids a division by zero for an empty loader.
    num_batches = max(len(data_loader), 1)

    model.train()
    train_losses = []
    train_perplexities = []

    for epoch in tqdm(range(num_epochs),
                      desc=f'Training model with {num_embeddings}, {embedding_dim} dimentional embeddings'):
        overall_loss = 0
        overall_perplexity = 0

        for x, _ in data_loader:
            x = x.to(device)
            optimizer.zero_grad()

            x_hat, commitment_loss, codebook_loss, perplexity = model(x)
            recon_loss = recon(x_hat, x)
            # VQ-VAE objective: reconstruction + codebook term + beta * commitment term.
            loss = recon_loss + codebook_loss + beta * commitment_loss

            overall_loss += loss.item()
            overall_perplexity += perplexity.item()

            loss.backward()
            optimizer.step()

        mean_loss = overall_loss / num_batches
        mean_perplexity = overall_perplexity / num_batches

        # Print metrics once every `print_idx` epochs (only if requested).
        if print_metric and (epoch + 1) % print_idx == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Mean Loss: {mean_loss:.4f}, Mean Perplexity: {mean_perplexity:.4f}')

        train_losses.append(mean_loss)
        train_perplexities.append(mean_perplexity)

    metrics = {
        'loss': train_losses,
        'perplexity': train_perplexities,
    }

    return model, metrics

In [None]:
def sample_model_output(model, data_loader, ax=None, num_examples=5):
    """Plot inputs (top row) and reconstructions (bottom row) from one batch.

    Args:
        model: module whose forward returns a 4-tuple starting with the
            reconstruction (as ``VQ_VAE`` does).
        data_loader: iterable of ``(images, labels)`` batches; only the first batch is
            used. Pass a one-element list to show the same batch on every call.
        ax: optional 2 x ``num_examples`` array of Axes to draw into (e.g. animation
            frames). A new figure is created when omitted.
        num_examples: number of columns to draw, capped at the batch size.

    Returns:
        The 2-D array of Axes that was drawn into.
    """
    # Slice the images tensor, not the (images, labels) pair: slicing the pair is
    # a no-op and would push the whole batch through the model.
    inputs, _ = next(iter(data_loader))
    num_examples = min(num_examples, inputs.shape[0])
    inputs = inputs[:num_examples]

    # Use the model's own device rather than the notebook-global `device`.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device('cpu')

    # Evaluate in eval mode, then restore whatever mode the caller had set.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            outputs, _, _, _ = model(inputs.to(model_device))
    finally:
        model.train(was_training)

    inputs = inputs.cpu().numpy()
    outputs = outputs.cpu().numpy()

    if ax is None:
        import matplotlib.pyplot as plt  # only needed when creating our own figure
        _, ax = plt.subplots(2, num_examples, figsize=(3 * num_examples, 6), squeeze=False)

    for col in range(num_examples):
        for row, images, title in ((0, inputs, f'Example {col + 1}\nInput'),
                                   (1, outputs, 'Output')):
            cell = ax[row, col]
            cell.clear()  # drop images left by a previous call (e.g. an earlier frame)
            cell.imshow(images[col].transpose(1, 2, 0))
            cell.set_title(title)

    return ax


## Effect of number of embeddings

In [None]:
# Train one model per codebook size: RGB input (3 channels), 3-dimensional codes.
NumEmbeddings = [3, 5, 10, 20, 30, 40, 50]
modelsNum = [
    VQ_VAE_trained_model(coloured_loader, input_dim=3, num_embeddings=n_codes,
                         embedding_dim=3, device=device)
    for n_codes in NumEmbeddings
]

In [None]:
# Show every model the same batch, so differences come from the models rather than
# from which images the shuffled loader happened to return.
fixed_batch = [next(iter(coloured_loader))]

# `vq_model` (not `model`) so the loop doesn't overwrite the notebook-level model.
for idx, (vq_model, metrics) in enumerate(modelsNum):
    print(f'Model with {NumEmbeddings[idx]} embedding vectors')
    final_loss = metrics['loss'][-1]
    final_perplexity = metrics['perplexity'][-1]
    print(f'final loss: {final_loss}')
    print(f'final perplexity: {final_perplexity}')
    sample_model_output(vq_model, fixed_batch)
    plt.tight_layout()
    plt.show()

from matplotlib.animation import FuncAnimation

# Create a GIF stepping through the models on shared axes.
fig, ax = plt.subplots(2, 5, figsize=(15, 6))

def update(frame):
    """Draw the reconstructions of model number ``frame`` on the shared axes."""
    vq_model, _ = modelsNum[frame]
    sample_model_output(vq_model, fixed_batch, ax)
    fig.suptitle(f'Number of embeddings: {NumEmbeddings[frame]}')

ani = FuncAnimation(fig, update, frames=len(modelsNum), repeat=False)
# Pillow writer: no external ImageMagick binary required.
ani.save('model_output_num_embedding_animation.gif', writer='pillow', fps=2)
plt.show()


In [None]:
# Train one model per embedding dimension, keeping 3 codebook vectors and RGB input.
DimEmbeddings = [2, 5, 10, 20, 30, 40, 50]
modelsDim = [
    VQ_VAE_trained_model(coloured_loader, input_dim=3, num_embeddings=3,
                         embedding_dim=code_dim, device=device)
    for code_dim in DimEmbeddings
]

In [None]:
# Show every model the same batch, so differences come from the models rather than
# from which images the shuffled loader happened to return.
fixed_batch = [next(iter(coloured_loader))]

# `vq_model` (not `model`) so the loop doesn't overwrite the notebook-level model.
for idx, (vq_model, metrics) in enumerate(modelsDim):
    print(f'Model with {DimEmbeddings[idx]} embedding dimension')
    final_loss = metrics['loss'][-1]
    final_perplexity = metrics['perplexity'][-1]
    print(f'final loss: {final_loss}')
    print(f'final perplexity: {final_perplexity}')
    sample_model_output(vq_model, fixed_batch)
    plt.tight_layout()
    plt.show()

from matplotlib.animation import FuncAnimation

# Create a GIF stepping through the models on shared axes.
fig, ax = plt.subplots(2, 5, figsize=(15, 6))

def update(frame):
    """Draw the reconstructions of model number ``frame`` on the shared axes."""
    vq_model, _ = modelsDim[frame]
    sample_model_output(vq_model, fixed_batch, ax)
    fig.suptitle(f'embedding dimension: {DimEmbeddings[frame]}')

# One frame per model in *this* sweep (was len(modelsNum)).
ani = FuncAnimation(fig, update, frames=len(modelsDim), repeat=False)
# Pillow writer: no external ImageMagick binary required.
ani.save('model_output_dim_embedding_animation.gif', writer='pillow', fps=2)
plt.show()
