# VQ-VAE + Transformer Prior on MNIST

This notebook demonstrates:
- Training a Vector Quantized VAE (VQ-VAE) on MNIST
- Extracting quantized codebook indices as discrete image representations
- Training a causal Transformer on the discrete token sequences (as a prior)
- Autoregressively sampling new images using the Transformer + VQ-VAE decoder

In [1]:
# ----- Imports and setup -----
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
import numpy as np
import matplotlib.pyplot as plt
from dataclasses import dataclass

# Use GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## VQ-VAE Model
- Discrete bottleneck enabled by a codebook (vector quantizer)
- Shallow model for MNIST, but you can go deeper for complex data

In [2]:
@dataclass
class VAEOutput:
    """Container for the loss terms computed by ``VQVAE.loss_function``."""

    # Total objective (reconstruction term + VQ term); backpropagate through this.
    loss: torch.Tensor
    # Reconstruction term on its own (stored detached by ``VQVAE.loss_function``).
    recon_loss: torch.Tensor
    # Vector-quantization term on its own (stored detached by ``VQVAE.loss_function``).
    vq_loss: torch.Tensor

class VectorQuantizer(nn.Module):
    """
    Vector-quantization bottleneck of a VQ-VAE.

    Every spatial position of the encoder output is replaced by its nearest
    codebook vector (squared Euclidean distance). The forward pass also returns
    the VQ-VAE "embedding" loss plus the beta-weighted "commitment" loss.

    Args:
        num_embeddings: number of codebook vectors (stored as ``self.K``).
        embedding_dim: size of each codebook vector (stored as ``self.D``).
        beta: weight applied to the commitment loss.
    """
    def __init__(self, num_embeddings: int, embedding_dim: int, beta: float = 0.25):
        super().__init__()
        self.K = num_embeddings
        self.D = embedding_dim
        self.beta = beta
        self.embedding = nn.Embedding(self.K, self.D)
        # Small uniform init whose range shrinks with the codebook size.
        with torch.no_grad():
            self.embedding.weight.uniform_(-1 / self.K, 1 / self.K)

    def forward(self, latents: torch.Tensor):
        """
        Quantize a batch of latent feature maps.

        Args:
            latents: tensor of shape (B, D, H, W) with D == ``embedding_dim``.

        Returns:
            Tuple ``(quantized, vq_loss)``: ``quantized`` has shape (B, D, H, W)
            and ``vq_loss`` is a scalar tensor.

        Raises:
            ValueError: if ``latents`` is not 4-D or its channel count is not D.
        """
        if latents.dim() != 4 or latents.size(1) != self.D:
            # Without this check, view(-1, D) below could silently regroup
            # values from a tensor with the wrong channel count.
            raise ValueError(
                f"expected latents of shape (B, {self.D}, H, W), got {tuple(latents.shape)}"
            )
        # Reorder to (B, H, W, D) so the last axis is the embedding axis
        latents = latents.permute(0, 2, 3, 1).contiguous()
        # Flatten batch and spatial dims, shape (BHW, D)
        flat_latents = latents.view(-1, self.D)
        codebook = self.embedding.weight
        # Squared distances via ||z||^2 + ||e||^2 - 2 z.e, shape (BHW, K)
        dist = (
            flat_latents.pow(2).sum(dim=1, keepdim=True)
            + codebook.pow(2).sum(dim=1)
            - 2 * torch.matmul(flat_latents, codebook.t())
        )
        # For each latent vector, index of the nearest codebook vector
        encoding_inds = torch.argmin(dist, dim=1)
        # Direct row lookup: follows the codebook dtype (a float32 one-hot matmul
        # fails for double/half modules) and avoids a dense (BHW, K) matrix.
        quantized = self.embedding(encoding_inds).view(latents.shape)
        # Commitment loss pulls the encoder output toward the (frozen) codes;
        # embedding loss pulls the codes toward the (frozen) encoder output.
        commitment_loss = F.mse_loss(quantized.detach(), latents)
        embedding_loss = F.mse_loss(quantized, latents.detach())
        vq_loss = self.beta * commitment_loss + embedding_loss
        # Straight-through estimator: forward value is the quantized vector,
        # backward gradient flows to the encoder as if quantization were identity.
        quantized = latents + (quantized - latents).detach()
        return quantized.permute(0, 3, 1, 2).contiguous(), vq_loss

class ResidualLayer(nn.Module):
    """
    Residual block: 3x3 conv -> ReLU -> 1x1 conv, summed with the block input.

    The skip addition needs the branch output to be shape-compatible with the
    input, so in practice ``in_channels`` and ``out_channels`` are equal.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        # padding=1 keeps the spatial size unchanged through the 3x3 conv
        conv3x3 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        activation = nn.ReLU(inplace=True)
        conv1x1 = nn.Conv2d(out_channels, out_channels, kernel_size=1)
        self.resblock = nn.Sequential(conv3x3, activation, conv1x1)

    def forward(self, x):
        """Return ``x`` plus the output of the convolutional branch."""
        branch = self.resblock(x)
        return x + branch

class VQVAE(nn.Module):
    """
    VQ-VAE: convolutional encoder, vector quantizer and convolutional decoder.

    Each entry of ``hidden_dims`` adds one stride-2 convolution to the encoder;
    the decoder mirrors it with one stride-2 transposed convolution per entry,
    and ends in a sigmoid.

    Args:
        in_channels: number of channels of the input images; the decoder
            produces the same number of channels.
        embedding_dim: size of each codebook vector / latent channel count.
        num_embeddings: number of codebook vectors.
        hidden_dims: channel widths of the encoder stages (default ``[64, 128]``).
        beta: commitment-loss weight passed to the quantizer.
    """
    def __init__(self, in_channels, embedding_dim, num_embeddings, hidden_dims=None, beta=0.25):
        super().__init__()
        if hidden_dims is None:
            hidden_dims = [64, 128]
        # Keep the image channel count: the running `channels` below is
        # advanced while stacking encoder stages, and the decoder's final layer
        # must map back to the original number of channels.
        image_channels = in_channels

        # Encoder: image to "latent image"
        modules = []
        channels = in_channels
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(channels, h_dim, 4, 2, 1),
                    nn.LeakyReLU()
                )
            )
            channels = h_dim
        for _ in range(2):
            modules.append(ResidualLayer(channels, channels))
        modules.append(nn.Conv2d(channels, embedding_dim, 1))
        self.encoder = nn.Sequential(*modules)

        self.vq = VectorQuantizer(num_embeddings, embedding_dim, beta)

        # Decoder: mirror of the encoder. For the default two-stage layout this
        # yields exactly the same module sequence as a hand-written
        # 128 -> 64 -> image decoder, so parameter names are unchanged.
        dec_mod = [nn.Conv2d(embedding_dim, hidden_dims[-1], 3, 1, 1)]
        for _ in range(2):
            dec_mod.append(ResidualLayer(hidden_dims[-1], hidden_dims[-1]))
        for i in range(len(hidden_dims) - 1, 0, -1):
            dec_mod.append(nn.ConvTranspose2d(hidden_dims[i], hidden_dims[i - 1], 4, 2, 1))
            dec_mod.append(nn.LeakyReLU())
        dec_mod.append(nn.ConvTranspose2d(hidden_dims[0], image_channels, 4, 2, 1))
        dec_mod.append(nn.Sigmoid())
        self.decoder = nn.Sequential(*dec_mod)

    def forward(self, x):
        """
        Encode, quantize and decode ``x``.

        Returns:
            Tuple ``(reconstruction, x, vq_loss)``, in the order expected by
            ``loss_function``.
        """
        z = self.encoder(x)
        q, vq_loss = self.vq(z)
        out = self.decoder(q)
        return out, x, vq_loss

    def loss_function(self, *args):
        """
        Combine reconstruction MSE with the VQ loss.

        Args:
            *args: ``(reconstruction, target, vq_loss)`` as returned by ``forward``.

        Returns:
            ``VAEOutput`` whose ``loss`` carries gradients; the two component
            terms are detached and intended for logging.
        """
        recons, target, vq_loss = args
        recons_loss = F.mse_loss(recons, target)
        loss = recons_loss + vq_loss
        return VAEOutput(loss, recons_loss.detach(), vq_loss.detach())

    def get_codebook_indices(self, x):
        """
        Encode ``x`` and return the nearest codebook index per latent position.

        Returns:
            Integer tensor of shape (N, H, W), where H and W are the spatial
            dimensions of the encoder output.
        """
        z = self.encoder(x)
        # (N, D, H, W) -> (N, H, W, D) so each row of the flat view is one latent vector
        latents = z.permute(0, 2, 3, 1).contiguous()
        flat_latents = latents.view(-1, self.vq.D)
        codebook = self.vq.embedding.weight
        # Squared Euclidean distance to every codebook vector, shape (NHW, K)
        dist = (
            torch.sum(flat_latents**2, dim=1, keepdim=True)
            + torch.sum(codebook**2, dim=1)
            - 2 * torch.matmul(flat_latents, codebook.t())
        )
        inds = torch.argmin(dist, dim=1)
        N, H, W = latents.shape[:3]
        return inds.view(N, H, W)

## Get MNIST data (grayscale 28x28 images scaled to [0, 1], batches of size 64)

In [3]:
# ToTensor converts each PIL image to a float tensor of shape (1, 28, 28) in [0, 1].
to_tensor = torchvision.transforms.ToTensor()
# Training split of MNIST, downloaded into ./data on first use.
trainset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    transform=to_tensor,
    download=True,
)
# Shuffled mini-batches of 64 images.
loader = DataLoader(dataset=trainset, batch_size=64, shuffle=True)

## Train VQ-VAE on MNIST
- Downsamples each 28x28 image to a 7x7 latent grid (two stride-2 convolutions: 28 → 14 → 7)
- Optimizes for MSE + VQ loss

In [4]:
# 1 input channel, 8-dim codes, 64-entry codebook, two encoder stages.
vqvae = VQVAE(
    in_channels=1,
    embedding_dim=8,
    num_embeddings=64,
    hidden_dims=[64, 128],
    beta=0.25,
)
vqvae = vqvae.to(device)
vqvae.train()
opt = torch.optim.Adam(vqvae.parameters(), lr=2e-3)

for epoch_number in range(1, 5):
    for images, _labels in loader:
        images = images.to(device)
        # forward() returns (reconstruction, input, vq_loss), which is exactly
        # the argument list loss_function expects.
        step = vqvae.loss_function(*vqvae(images))
        opt.zero_grad()
        step.loss.backward()
        opt.step()
    # Reports the loss of the last mini-batch of the epoch, not an epoch average.
    print(f'VQ-VAE Epoch {epoch_number} Loss: {step.loss.item():.4f}')

## Extract codebook indices for each image
- Each image is now represented as a `(H, W)` grid of discrete integer codes.
- We'll use these as tokens for the Transformer.

In [5]:
vqvae.eval()
# Encode the whole training set without tracking gradients; each batch of
# indices is moved to the CPU so the full collection does not sit on the GPU.
with torch.no_grad():
    code_batches = [
        vqvae.get_codebook_indices(images.to(device)).cpu()
        for images, _labels in loader
    ]
all_codes = torch.cat(code_batches, dim=0)  # (N, H, W)
num_images = all_codes.size(0)
flat_codes = all_codes.view(num_images, -1)  # (N, seq_len), row-major over the grid
seq_len = flat_codes.size(1)
print("Flattened code sequence shape:", flat_codes.shape)

## Minimal causal Transformer prior for discrete codes
- Each code sequence is used as a 1D sequence of tokens.
- Standard causal (autoregressive) mask is used so position t only attends to 0..t.

In [6]:
class TokenTransformer(nn.Module):
    """
    Causal Transformer over sequences of discrete code indices.

    Tokens are embedded, summed with a learned positional embedding, passed
    through a stack of Transformer encoder layers under a causal mask, and
    projected to per-position logits over the token vocabulary.

    Args:
        num_tokens: vocabulary size (number of distinct code indices).
        seq_len: maximum sequence length the positional embedding supports.
        d_model: embedding / model width.
        nhead: number of attention heads.
        num_layers: number of encoder layers.
    """
    def __init__(self, num_tokens, seq_len, d_model=128, nhead=4, num_layers=4):
        super().__init__()
        self.num_tokens = num_tokens
        self.seq_len = seq_len
        self.token_emb = nn.Embedding(num_tokens, d_model)
        self.pos_emb = nn.Parameter(torch.randn(1, seq_len, d_model))
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model, nhead), num_layers)
        self.ln = nn.LayerNorm(d_model)
        self.fc = nn.Linear(d_model, num_tokens)

    def forward(self, x):
        """
        Compute next-token logits.

        Args:
            x: integer tensor of shape (batch, length) with ``length <= seq_len``.

        Returns:
            Tensor of shape (batch, length, num_tokens). Because of the causal
            mask, the logits at position t depend only on ``x[:, :t+1]``.

        Raises:
            ValueError: if ``length`` exceeds ``seq_len``.
        """
        length = x.size(1)
        if length > self.seq_len:
            raise ValueError(
                f"sequence length {length} exceeds maximum supported length {self.seq_len}"
            )
        # Slice the positional table to the actual length so shorter inputs
        # (training prefixes, incremental sampling) are accepted.
        emb = self.token_emb(x) + self.pos_emb[:, :length]
        # The encoder layers are built with the default sequence-first layout.
        emb = emb.permute(1, 0, 2)
        # The mask must be sized to the actual input, not to self.seq_len.
        mask = nn.Transformer.generate_square_subsequent_mask(length).to(x.device)
        h = self.transformer(emb, mask)
        h = h.permute(1, 0, 2)
        return self.fc(self.ln(h))  # (batch, length, num_tokens)

## Train the Transformer prior
- At each step, predicts the next code index given previous indices.
- Loss is categorical cross-entropy.
- Simple teacher-forcing: the predictions made at positions [0, ..., L-2] are scored against the true tokens at positions [1, ..., L-1] (the sequence shifted by one).

In [7]:
# Vocabulary of the prior = size of the VQ-VAE codebook (64 for the model above).
num_tokens = vqvae.vq.K
transformer = TokenTransformer(num_tokens, seq_len).to(device)
optim_t = torch.optim.Adam(transformer.parameters(), lr=2e-4)
batch_size = 16

print("Training Transformer prior...")
transformer.train()
for epoch in range(4):
    # Fresh random order of the code sequences every epoch
    perm = torch.randperm(flat_codes.size(0))
    for i in range(0, flat_codes.size(0), batch_size):
        idx = perm[i:i+batch_size]
        batch = flat_codes[idx].to(device)
        # Teacher forcing with a single shift: the model sees the whole
        # sequence under a causal mask, so the logits at position t depend only
        # on tokens 0..t and are scored against token t+1. The last position has
        # no target and is dropped. (Shifting both the input and the logits
        # would leave one logit fewer than targets.)
        logits = transformer(batch)[:, :-1, :]
        tgt = batch[:, 1:]
        loss = F.cross_entropy(logits.reshape(-1, num_tokens), tgt.reshape(-1))
        optim_t.zero_grad()
        loss.backward()
        optim_t.step()
    # Reports the loss of the last mini-batch of the epoch, not an epoch average.
    print(f'Transformer Epoch {epoch+1} Loss: {loss.item():.4f}')

## Sample a new code sequence from the Transformer and decode with VQ-VAE

The Transformer autoregressively predicts code indices one by one. The generated sequence is then mapped to embeddings and decoded into an image.

In [8]:
print("Generating from Transformer prior...")
transformer.eval()
with torch.no_grad():
    # Buffer for the generated sequence; slots after the current position hold
    # placeholder zeros until they are sampled.
    seq = torch.zeros(1, seq_len, dtype=torch.long, device=device)
    # The prior only learns p(token t+1 | tokens 0..t), so the first token has
    # no model distribution. Draw it from the first tokens seen in the training
    # codes instead of always forcing code index 0.
    first_row = torch.randint(flat_codes.size(0), (1,)).item()
    seq[0, 0] = flat_codes[first_row, 0].item()
    for t in range(seq_len - 1):
        # Feed the full-length buffer: the causal mask keeps position t from
        # attending to the placeholder slots after it, and the input length
        # always matches the length the positional embedding was built for.
        logits = transformer(seq)
        probs = F.softmax(logits[0, t], dim=-1)
        seq[0, t + 1] = torch.multinomial(probs, 1).item()

    # The flattened sequence is reshaped back into a square code grid.
    H = W = int(np.sqrt(seq_len))
    sampled_grid = seq.view(1, H, W)

    # Convert indices back to embeddings and decode
    emb = vqvae.vq.embedding(sampled_grid.view(-1)).view(1, H, W, vqvae.vq.D)
    # Decoder expects channels-first input: (1, H, W, D) -> (1, D, H, W)
    emb = emb.permute(0, 3, 1, 2).contiguous()
    img_gen = vqvae.decoder(emb).cpu().squeeze().numpy()

plt.title("VQ-VAE + Transformer Prior: Sampled Image")
plt.imshow(img_gen, cmap='gray')
plt.axis('off')
plt.show()