## VQGAN From Scratch

#### Install Libraries

In [2]:
import torch
import torch.nn as nn 
import torch.nn.functional as F 
import torchvision
import torchvision.transforms as T
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
import wandb
# from config import Config
import os
from PIL import Image
from torchvision.utils import save_image
from tqdm import tqdm
import wandb

In [2]:
config = Config()

In [7]:
# import torch
class Config:
    def __init__(self):
        self.in_channels = 3
        self.hidden_channels = 128
        self.latent_dim = 256
        self.out_channels = 3
        self.img_size = 64 
        self.num_embeddings = 1024 #Total Number of Embeddings in the CodeBook
        self.embedding_dim = 256   #Dimensionality of Each CodeBook Vector
        self.beta = 0.25           #Commitment Loss Weight
        self.batch_size = 32
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        
        self.epochs = 20
        self.lr = 2e-4
        self.project_name = "VQ_VAE"
        self.save_dir = "./saves"
        self.log_interval = 100
        self.save_interval = 2
        self.use_wandb = True
        

In [8]:
config = Config()

## Network Architecture

### Helper Class for Residual Connection and Attention Mechanism

In [9]:
class ResBlock(nn.Module):
    ''' 
        Residual Block with GroupNormalization.
        This will help in building Deeper Networks without Diminishing Gradient
    '''
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        
        self.block = nn.Sequential(
            nn.GroupNorm(num_groups=32, num_channels=in_channels),
            nn.SiLU(),
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.GroupNorm(num_groups=32, num_channels=out_channels),
            nn.SiLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        )
        
        if in_channels != out_channels:
            self.residual_connection = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        else:
            self.residual_connection = nn.Identity()
    
    def forward(self, x):
        return self.residual_connection(x) + self.block(x)
    
    
class AttentionBlock(nn.Module):
    ''' 
        A Self-Attention Block,
        It will help the model focus on relevant parts of the image
    '''
    def __init__(self, channels):
        super().__init__()
        self.norm = nn.GroupNorm(num_groups=32, num_channels=channels)
        self.qkv = nn.Conv2d(channels, out_channels=channels*3, kernel_size=1) #Create Q,K,V matrices
        self.out_proj = nn.Conv2d(channels, channels, kernel_size=1)
        
    def forward(self, x):
        B, C, H, W = x.shape #Get the Batch_size, Channels, Height, Width
        h = self.norm(x)
        q, k, v = self.qkv(h).chunk(3, dim = 1)
        
        #Reshape the Q,K,V for Attention Computation
        q = q.view(B, C, H * W).permute(0, 2, 1)  #Shape: [B, H*W, C]
        k = k.view(B, C, H * W)                   #Shape: [B, C, H*W]
        v = v.view(B, C, H * W).permute(0, 2, 1)  #Shape: [B, H*W, C]
        
        #Compute Attention Weights
        weight = torch.einsum('bic, bcj->bij', q, k)*(C ** -0.5)
        logits = F.softmax(weight, dim = -1)
        
        #Attend to Values
        h = torch.einsum('bij, bjc->bic', logits, v)
        h = h.permute(0, 2, 1).view(B, C, H, W)  #Reshape back to iamge format
        
        return x + self.out_proj(h)
        

### Encoder of the VAE

In [11]:
class Encoder(nn.Module):
    ''' 
        Main VQ-VAE Encoder class with Residual & Attention Blocks
    '''
    def __init__(self, config):
        super().__init__()
        
        #Initial Convolution to Increase Channel Dimension
        self.conv_in = nn.Conv2d(config.in_channels, config.hidden_channels, kernel_size=3, stride=1, padding=1)
        
        #DownSampling Path
        self.down = nn.Sequential(
            ResBlock(config.hidden_channels, config.hidden_channels),
            nn.Conv2d(config.hidden_channels, config.hidden_channels, kernel_size=4, stride=2, padding=1),  #Shape: [64--->32]
            ResBlock(config.hidden_channels, config.hidden_channels),
            nn.Conv2d(config.hidden_channels, config.hidden_channels, kernel_size=4, stride = 2, padding=1),
        )
        
        #Attention Is All It Needs
        self.mid = nn.Sequential(
            ResBlock(config.hidden_channels, config.hidden_channels),
            AttentionBlock(config.hidden_channels),
            ResBlock(config.hidden_channels, config.hidden_channels)
        )
        
        #Final Projection to the Latent Dimension
        self.conv_out = nn.Sequential(
            nn.GroupNorm(num_groups=32, num_channels=config.hidden_channels),
            nn.SiLU(),
            nn.Conv2d(config.hidden_channels, config.latent_dim, kernel_size=1)
        )
    
    def forward(self, x):
        #Initial Convolution
        x = self.conv_in(x)
        #DownSampling 
        x = self.down(x)
        #Give IT ATTENTION
        x = self.mid(x)
        #Final Projection
        output = self.conv_out(x)
        
        return output
        
    

- Test a Dummy Forward Pass

In [12]:
x = torch.randn(1, 3, 64, 64)
encoder = Encoder(config)
output = encoder(x)
print("="*30, f"Input Shape: [{x.shape}]", "="*30)
print("="*30, f"Output Shape: [{output.shape}]", "="*30)




### Decoder of the VAE

In [13]:
class Decoder(nn.Module):
    ''' 
        VQ-VAE Decoder class with Residual & Attention Blocks
    '''
    def __init__(self, config):
        super().__init__()
        
        #Initial Projection from Latent_Dim to Hidden_Channel
        self.conv_in = nn.Conv2d(config.latent_dim , config.hidden_channels, kernel_size= 3, stride= 1, padding=1)
        
        #Give it ATTENTION
        self.mid = nn.Sequential(
            ResBlock(config.hidden_channels, config.hidden_channels),
            AttentionBlock(config.hidden_channels),
            ResBlock(config.hidden_channels, config.hidden_channels),
        )
        
        #UpSampling 
        self.up = nn.Sequential(
            ResBlock(config.hidden_channels, config.hidden_channels),
            nn.Upsample(scale_factor=2, mode="nearest"), #Shape; [16---->32]
            nn.Conv2d(config.hidden_channels, config.hidden_channels, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            
            ResBlock(config.hidden_channels, config.hidden_channels),
            nn.Upsample(scale_factor=2, mode="nearest"), #Shape; [32---->64]
            nn.Conv2d(config.hidden_channels, config.hidden_channels, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2, inplace=True)
        )
        
        #Final Projection to Original Image Shape
        self.conv_out = nn.Sequential(
            nn.GroupNorm(num_groups=32, num_channels=config.hidden_channels),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(config.hidden_channels, config.in_channels, kernel_size=3, stride = 1, padding = 1)
        )
        
    def forward(self, x):
        #Initial Projection
        x = self.conv_in(x)
        
        #ATTENTION IS ALL IT WANTs
        x = self.mid(x)
        
        #UPSAMPLE THE Input
        x = self.up(x)
        
        #Final Projection
        out = self.conv_out(x)
        
        return out

- Test a Dummy Forward Pass 

In [14]:
dummy_input = torch.randn(4, config.latent_dim, 16, 16)
decoder = Decoder(config)

output = decoder(dummy_input)
print("="*30, f"Input Shape: [{dummy_input.shape}]", "="*30)
print("="*30, f"Output Shape: [{output.shape}]", "="*30)



### Vector Quantization

- Instead of passing continuous latent vectors directly to the decoder, we map each latent vector to the nearest codebook embedding.

**This Helps in:-**

- Compressing Information
- Introducing Discretization(which GANs Like a LOT)
- Encouraging Disentangled Represations{When a model learns to separate the high-level, abstract concepts in data into distinct and independent factors within its latent space}

In [15]:
class VectorQuantizer(nn.Module):
    ''' 
        This will take latent_inputs from the Encoder,
        Flatten them to [B*H*W, C]
        Find Nearest CodeBook vector for each spatial location,
        Replaces each vector with the closes embedding
    '''
    
    def __init__(self, config):
        super().__init__()
        
        self.embedding_dim = config.embedding_dim 
        self.num_embeddings = config.num_embeddings
        self.beta = config.beta
        
        #CodeBook Embeddings
        self.embedding = nn.Embedding(config.num_embeddings, config.embedding_dim)
        #Initialize the CodeBook vectors with UNIFORM DISTRIBUTION
        self.embedding.weight.data.uniform_(-1 / config.num_embeddings, 1 / config.num_embeddings)
        
    def forward(self, z):
        B, C, H, W = z.shape
        
        #Reshape the Input Latent Vector
        z_orig = z
        z_premuted = z.permute(0, 2, 3, 1).contiguous()  #Shape: [B, C, H, W]----->[B, H, W, C]
        
        #Flatten the vector
        z_flat = z_premuted.view(-1, self.embedding_dim)  #Shape: [B, H, W, C]----->[B*H*W, C]
        
        #Compute the Distances to Embeddings: [B*H*W, C]  v/s [1, num_embeddings, C]
        #It is computing Squared Euclidean Distance {||a - b||^2 = ||a||^2 + ||b||^2 - 2a@b}
        dist = (torch.sum(z_flat**2, dim = 1, keepdim=True) + torch.sum(self.embedding.weight ** 2, dim = 1) - 2 * torch.matmul(z_flat, self.embedding.weight.t()))
        
        #Get the Nearest Embeddings
        encoding_indices = torch.argmin(dist, dim = 1).unsqueeze(1) #Finds the index of the closest codebook vector for each of the input vectors
        encodings = torch.zeros(encoding_indices.shape[0], self.num_embeddings, device = z.device)
        #Get the ONE-HOT ENCODDDINGSSSS
        encodings.scatter_(1, encoding_indices, 1) 
        
        #Quantized = codebook[encoding_index] {It means fetch the Quantized for our latent_vector from the CodeBook}
        quantized = torch.matmul(encodings, self.embedding.weight)
        quantized = quantized.view(B, H, W, C).permute(0, 3, 1, 2).contiguous()
        
        #Compute the Loss
        # print("Quantized Shape:", quantized.detach().shape)
        # print("z.permute Shape:", z_orig.shape)
        commitment_loss = F.mse_loss(z_orig.detach(), quantized )
        embedding_loss = F.mse_loss(z_orig, quantized.detach())
        loss = commitment_loss*self.beta + embedding_loss
        
        #Straight-through estimator
        #argmin() operation has no gradient, which means the gradient from DECODER would stop here and never reach the ENCODER
        quantized_for_decoder = z_orig + (quantized.clone() - z_orig).detach()
        
        return quantized_for_decoder, loss, encoding_indices.view(B, H, W)
        

- Test a Dummy Forward Pass 

In [16]:
vq = VectorQuantizer(config)
z_e =  torch.randn(4, config.embedding_dim, 16, 16)
z_q, vq_loss, indices = vq(z_e)
z_q.shape
vq_loss

tensor(1.2584, grad_fn=<AddBackward0>)

### VQVAE Architecture

In [17]:
class VQVAE(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.encoder = Encoder(config)
        self.decoder = Decoder(config)
        self.quantizer = VectorQuantizer(config)
        self.loss_fn = nn.MSELoss()
        
    def forward(self, x):
        z_e = self.encoder(x)
        quantized , vq_loss, _ = self.quantizer(z_e)
        x_reconstructed = self.decoder(quantized)
        
        reconstruction_loss = self.loss_fn(x_reconstructed, x)
        total_loss = reconstruction_loss + vq_loss
        
        return {
            "x_reconstruction": x_reconstructed,
            "z_e": z_e,
            "quantized": quantized,
            "vq_loss":vq_loss,
            "reconstruction_loss":reconstruction_loss,
            "total_loss":total_loss
        }

- Test a Dummy Forward Pass

In [19]:
model = VQVAE(config)
x = torch.randn(2, 3, config.img_size, config.img_size)
output = model(x)

print(f"Reconstructed Image Shape: {output['x_reconstruction'].shape}")
print(f"Total Loss: {output['total_loss'].item():.4f}")

Reconstructed Image Shape: torch.Size([2, 3, 64, 64])
Total Loss: 1.3732


## DataLoading

In [20]:
from torchvision.datasets import CIFAR10

transform = T.Compose([
    T.Resize((config.img_size, config.img_size)),  #Resize to 64 x 64 for VQGAN
    T.ToTensor(),
    T.Normalize([0.5] * 3, [0.5] * 3)
])

dataset = CIFAR10(root="./data", train=True, download=True, transform=transform)
dataloader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True,num_workers=1)


100%|██████████| 170M/170M [00:03<00:00, 48.5MB/s] 


### Training Loop for Basic VQVAE

In [23]:
wandb.login(key = "674eaa4b76549d6faa73e5c97b2106b8bbb2d8e4")
# config = Config()

[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33matharv3105[0m ([33matharv3105-dr-a-p-j-abdul-kalam-technical-university[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [24]:
config.epochs

20

In [27]:
import torch.optim as optim
def train_vq(model,dataloader ,config):
    
    device = config.device
    model.to(device)
    model.train()
    
    optimizer = optim.Adam(model.parameters(), lr= config.lr)
    
    #Initialize WandB
    if config.use_wandb:
        wandb.init(project = config.project_name, config = config.__dict__)
    
    os.makedirs(config.save_dir, exist_ok=True)
    
    for epoch in range(config.epochs):
        epoch_loss = 0.0
        for batch_idx, (images, _) in enumerate(tqdm(dataloader)):
            images = images.to(device)
            
            outputs = model(images)
            loss = outputs["total_loss"]
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            epoch_loss += loss.item()
            
            #Log Metrics Every N steps
            if batch_idx % config.log_interval == 0:
                wandb.log({
                    "Total-Loss": outputs["total_loss"].item(),
                    "Reconstruction-Loss":outputs["reconstruction_loss"].item(),
                    "VQ_Loss": outputs["vq_loss"].item(),
                    "Epoch":epoch
                })
        
        #Save Reconstructions      
        if epoch % config.save_interval == 0:
            with torch.no_grad():
                recon = outputs["x_reconstruction"]
                #Save Images
                save_image(images[:8], os.path.join(config.save_dir, f"input_{epoch}.png"), nrow = 4, normalize = True)
                save_image(recon[:8], os.path.join(config.save_dir, f"recon_{epoch}.png"), nrow = 4, normalize = True)

                if config.use_wandb:
                    
                    wandb.log({
                         "Input": [wandb.Image(images[0])], 
                         "Reconstruction": [wandb.Image(recon[0])]
                        })
            checkpoint = {
                "model_state_dict":model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "epoch":epoch,
                "loss": epoch_loss / len(dataloader)
            }
            ckpt_path = os.path.join(config.save_dir, f"VQVAE_epoch_{epoch}.pth")
            torch.save(checkpoint, ckpt_path)
            print(f"Saved to Path: {ckpt_path}")
            
                
        print(f"Epoch[{epoch + 1}/ {config.epochs}]|| Loss: {epoch_loss / len(dataloader) :.4f}")
    
    wandb.finish()
    
    
        

In [70]:
# checkpoint = {
#     "model_state_dict": model.state_dict(),
# }

# torch.save(checkpoint, "./saves/checkpoint.pth")

In [65]:
# config.use_wandb = True

In [28]:
model = VQVAE(config)
train_vq(model, dataloader, config)

0,1
Epoch,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Reconstruction-Loss,█▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Total-Loss,█▆▃▂▂▂▂▂▁▂▁▁▁▁▁▁
VQ_Loss,▄█▄▃▃▃▂▂▂▂▂▂▁▁▁▁

0,1
Epoch,0.0
Reconstruction-Loss,0.01707
Total-Loss,0.03456
VQ_Loss,0.01749


100%|██████████| 1563/1563 [04:41<00:00,  5.55it/s]


Saved to Path: ./saves/VQVAE_epoch_0.pth
Epoch[1/ 20]|| Loss: 0.1333


100%|██████████| 1563/1563 [04:38<00:00,  5.61it/s]


Epoch[2/ 20]|| Loss: 0.0287


100%|██████████| 1563/1563 [04:40<00:00,  5.57it/s]


Saved to Path: ./saves/VQVAE_epoch_2.pth
Epoch[3/ 20]|| Loss: 0.0180


100%|██████████| 1563/1563 [04:44<00:00,  5.49it/s]


Epoch[4/ 20]|| Loss: 0.0148


100%|██████████| 1563/1563 [04:44<00:00,  5.49it/s]


Saved to Path: ./saves/VQVAE_epoch_4.pth
Epoch[5/ 20]|| Loss: 0.0129


100%|██████████| 1563/1563 [04:44<00:00,  5.50it/s]


Epoch[6/ 20]|| Loss: 0.0116


100%|██████████| 1563/1563 [04:44<00:00,  5.49it/s]


Saved to Path: ./saves/VQVAE_epoch_6.pth
Epoch[7/ 20]|| Loss: 0.0107


100%|██████████| 1563/1563 [04:44<00:00,  5.50it/s]


Epoch[8/ 20]|| Loss: 0.0101


100%|██████████| 1563/1563 [04:44<00:00,  5.49it/s]


Saved to Path: ./saves/VQVAE_epoch_8.pth
Epoch[9/ 20]|| Loss: 0.0095


100%|██████████| 1563/1563 [04:44<00:00,  5.50it/s]


Epoch[10/ 20]|| Loss: 0.0089


100%|██████████| 1563/1563 [04:44<00:00,  5.50it/s]


Saved to Path: ./saves/VQVAE_epoch_10.pth
Epoch[11/ 20]|| Loss: 0.0082


100%|██████████| 1563/1563 [04:44<00:00,  5.50it/s]


Epoch[12/ 20]|| Loss: 0.0080


100%|██████████| 1563/1563 [04:44<00:00,  5.50it/s]


Saved to Path: ./saves/VQVAE_epoch_12.pth
Epoch[13/ 20]|| Loss: 0.0075


100%|██████████| 1563/1563 [04:43<00:00,  5.51it/s]


Epoch[14/ 20]|| Loss: 0.0073


100%|██████████| 1563/1563 [04:43<00:00,  5.52it/s]


Saved to Path: ./saves/VQVAE_epoch_14.pth
Epoch[15/ 20]|| Loss: 0.0072


100%|██████████| 1563/1563 [04:42<00:00,  5.53it/s]


Epoch[16/ 20]|| Loss: 0.0069


100%|██████████| 1563/1563 [04:43<00:00,  5.51it/s]


Saved to Path: ./saves/VQVAE_epoch_16.pth
Epoch[17/ 20]|| Loss: 0.0067


100%|██████████| 1563/1563 [04:43<00:00,  5.50it/s]


Epoch[18/ 20]|| Loss: 0.0066


100%|██████████| 1563/1563 [04:43<00:00,  5.51it/s]


Saved to Path: ./saves/VQVAE_epoch_18.pth
Epoch[19/ 20]|| Loss: 0.0064


100%|██████████| 1563/1563 [04:43<00:00,  5.50it/s]

Epoch[20/ 20]|| Loss: 0.0061





0,1
Epoch,▁▁▁▁▂▂▂▂▂▃▃▃▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▇▇▇▇▇▇▇▇█████
Reconstruction-Loss,█▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Total-Loss,█▅▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
VQ_Loss,█▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
Epoch,19.0
Reconstruction-Loss,0.00301
Total-Loss,0.00574
VQ_Loss,0.00273
