# In [None]:
import torch, itertools
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, utils
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

# Compute device: CUDA when PyTorch reports it is available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# MNIST pipeline: resize every digit to 32x32, replicate the gray channel into
# 3 channels, then convert to a float tensor.
transform = transforms.Compose(
    [
        transforms.Resize((32, 32)),
        transforms.Grayscale(num_output_channels=3),
        transforms.ToTensor(),
    ]
)
train_dataset = datasets.MNIST(
    root='./data',
    train=True,
    transform=transform,
    download=True,
)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

# VQ-VAE Components
class Encoder(nn.Module):
    """Convolutional encoder made of five downsampling stages.

    Every stage is Conv2d(kernel 4, stride 2, padding 1, no bias) -> BatchNorm2d
    -> ReLU, which halves each even spatial size (32x32 -> 16 -> 8 -> 4 -> 2 -> 1)
    while the channels go input_channels -> 64 -> 64 -> 128 -> 256 -> 512.

    Args:
        input_channels: number of channels of the input image batch.
    """

    def __init__(self, input_channels):
        super().__init__()
        widths = [input_channels, 64, 64, 128, 256, 512]
        stages = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            stages += [
                nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
            ]
        # Same attribute name and layer order as a hand-written Sequential,
        # so state_dict keys are conv.0 ... conv.14.
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Return the feature map produced by the stacked conv stages."""
        return self.conv(x)

class Decoder(nn.Module):
    """Transposed-convolution decoder mirroring the encoder.

    Four upsampling stages ConvTranspose2d(kernel 4, stride 2, padding 1, no bias)
    -> BatchNorm2d -> ReLU take 512 channels down to 32, then a final biased
    ConvTranspose2d maps to 3 channels and a Sigmoid squashes the output into
    [0, 1]. Each stage doubles the spatial size (1x1 -> 2 -> 4 -> 8 -> 16 -> 32).
    """

    def __init__(self):
        super().__init__()
        widths = [512, 256, 128, 64, 32]
        layers = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            layers += [
                nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
            ]
        # Output head: keeps its bias (no BatchNorm follows it).
        layers += [
            nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1),
            nn.Sigmoid(),
        ]
        self.conv_trans = nn.Sequential(*layers)

    def forward(self, x):
        """Decode the feature map ``x`` into an image batch with values in [0, 1]."""
        return self.conv_trans(x)

class VectorQuantizer(nn.Module):
    """Nearest-neighbour codebook quantizer used as the VQ-VAE bottleneck.

    Dim 1 of the input holds the feature vector to quantize. Every position
    along the remaining dims is replaced by its closest codebook entry
    (Euclidean distance).

    The forward pass uses the straight-through estimator, so the reconstruction
    gradient reaches the encoder. A plain embedding lookup cuts the graph at the
    non-differentiable nearest-neighbour choice, and the encoder would never be
    updated. Because the straight-through output gives the codebook no
    gradient, the codebook and commitment terms of the VQ-VAE objective are
    stored in ``self.loss`` after every forward pass. Add that term to the
    training loss.

    Args:
        num_embeddings: Number of codebook vectors.
        embedding_dim: Length of each codebook vector; must equal the size of
            dim 1 of the input.
        commitment_cost: Weight of the commitment term inside ``self.loss``.
    """

    def __init__(self, num_embeddings, embedding_dim, commitment_cost=0.25):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.commitment_cost = commitment_cost
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        nn.init.uniform_(self.embedding.weight, -1.0, 1.0)
        # Codebook + weighted commitment loss of the most recent forward pass.
        self.loss = torch.tensor(0.0)

    def forward(self, x):
        """Quantize ``x`` of shape (N, embedding_dim, ...) and return the same shape.

        Raises:
            ValueError: if ``x`` has fewer than 2 dims or dim 1 != embedding_dim.
        """
        if x.dim() < 2 or x.size(1) != self.embedding_dim:
            raise ValueError(
                f"expected input with {self.embedding_dim} channels in dim 1, "
                f"got shape {tuple(x.shape)}"
            )
        # Channels last, so each row of flat_x is one position's feature vector.
        # Flattening NCHW directly would mix channels with spatial positions
        # whenever the spatial size is larger than 1x1.
        x_last = x.movedim(1, -1)
        flat_x = x_last.reshape(-1, self.embedding_dim)
        # The code choice is non-differentiable; do not build a graph for it.
        with torch.no_grad():
            indices = torch.cdist(flat_x, self.embedding.weight).argmin(dim=1)
        codes = self.embedding(indices).view_as(x_last)
        # Codebook term pulls the selected codes toward the (frozen) encoder
        # outputs; commitment term keeps the encoder close to its codes.
        codebook_loss = nn.functional.mse_loss(codes, x_last.detach())
        commitment_loss = nn.functional.mse_loss(x_last, codes.detach())
        self.loss = codebook_loss + self.commitment_cost * commitment_loss
        # Straight-through: the forward value is the code, while the backward
        # pass copies the decoder's gradient onto the encoder output unchanged.
        quantized = x_last + (codes - x_last).detach()
        return quantized.movedim(-1, 1).contiguous()


# Fixed batch of 5 training images, reconstructed after every epoch so that
# progress can be compared visually across epochs.
fixed_images, _ = next(iter(DataLoader(dataset=train_dataset, batch_size=5, shuffle=True)))
fixed_images = fixed_images.to(device)

# Model instantiation; the encoder's 512 output channels match embedding_dim.
input_channels = 3
encoder = Encoder(input_channels).to(device)
decoder = Decoder().to(device)
vector_quantizer = VectorQuantizer(num_embeddings=100, embedding_dim=512).to(device)

# One Adam optimizer updates encoder, decoder and codebook together.
parameters = list(encoder.parameters()) + list(decoder.parameters()) + list(vector_quantizer.parameters())
optimizer = optim.Adam(parameters, lr=0.001)


def plot_images(images, title):
    """Show a batch of (N, C, H, W) images side by side in a single row.

    Args:
        images: image tensor on any device; it is detached and copied to the
            CPU. Values are expected in [0, 1] (what the Decoder's Sigmoid
            produces). Both 1-channel and 3-channel images are supported.
        title: figure title.
    """
    images = images.detach().cpu().permute(0, 2, 3, 1).numpy()
    count = len(images)
    # squeeze=False keeps ``axes`` 2-D even when there is a single image.
    fig, axes = plt.subplots(1, count, figsize=(3 * count, 3), squeeze=False)
    for ax, image in zip(axes[0], images):
        if image.shape[-1] == 1:
            # imshow rejects (H, W, 1); show single-channel images as grayscale.
            ax.imshow(image[..., 0], cmap='gray', interpolation='nearest')
        else:
            ax.imshow(image, interpolation='nearest')
        ax.axis('off')
    fig.suptitle(title)
    plt.show()


# Training loop
num_epochs = 10
for epoch in range(num_epochs):
    # BatchNorm must use (and update) batch statistics while training.
    encoder.train()
    decoder.train()
    vector_quantizer.train()
    for images, _ in train_loader:
        images = images.to(device)
        optimizer.zero_grad()

        encoded = encoder(images)
        quantized = vector_quantizer(encoded)
        decoded = decoder(quantized)

        # Reconstruction loss trains the decoder and (via the straight-through
        # estimator) the encoder; the VQ loss trains the codebook and keeps the
        # encoder committed to it.
        loss = nn.functional.mse_loss(decoded, images) + vector_quantizer.loss
        loss.backward()
        optimizer.step()

    # After each epoch, reconstruct the fixed images in eval mode so BatchNorm
    # uses its running statistics instead of the 5-image batch and does not
    # update them.
    encoder.eval()
    decoder.eval()
    vector_quantizer.eval()
    with torch.no_grad():
        encoded = encoder(fixed_images)
        quantized = vector_quantizer(encoded)
        reconstructed_images = decoder(quantized)
    plot_images(reconstructed_images, f'Reconstructed Images, Epoch {epoch+1}')

print("Training complete")
