In [None]:
import zipfile
import os

# Unpack the uploaded archive (a Colab-style /content path) into a local "vae_data"
# folder, which is the default `folder_path` the dataset cell reads images from.
zip_path = '/content/vae_data_archive.zip'
extract_to = 'vae_data'

with zipfile.ZipFile(zip_path, mode='r') as archive:
    archive.extractall(path=extract_to)


In [None]:
!pip install wandb lpips 

In [None]:
import os
import torch
import random
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as transforms
from torchvision.utils import make_grid
from torchvision import datasets, transforms
from PIL import Image
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import lpips
import wandb
from dataclasses import dataclass, asdict
from torch.amp.grad_scaler import GradScaler


In [None]:
@dataclass
class VAEConfig:
    """Hyper-parameters for the VAE model, its data pipeline and the training loop."""

    # Output channels of each encoder stage; the decoder uses them in reverse order.
    # Variable-length tuple, hence ``tuple[int, ...]`` (``tuple[int]`` means a 1-tuple).
    block_out_channel: tuple[int, ...] = (64, 64, 128, 256)
    input_channel: int = 3  # channels of the input images (RGB)
    output_channel: int = 3  # channels of the reconstructed images
    latent_channel: int = 4  # channels of the latent z
    num_res_layers: int = 2  # ResnetBlocks per encoder/decoder stage and per bottleneck
    # Used as nn.GroupNorm ``num_groups`` in the ResnetBlocks (a group count,
    # not channels per group), so every stage width must be divisible by it.
    group_channels: int = 16
    lr: float = 1e-4  # AdamW learning rate
    beta: float = 0.7  # weight of the KL term in the total loss
    epochs: int = 10
    batch: int = 16  # DataLoader batch size
    image_size: int = 128  # images are resized to image_size x image_size
    folder_path: str = "vae_data"  # directory scanned for training images
    train_split: float = 0.9  # fraction of images used for training; the rest is validation


In [None]:
config = VAEConfig()

# Square resize -> float tensor in [0, 1] -> per-channel (x - 0.5) / 0.5, i.e. [-1, 1],
# which matches the Tanh range of the decoder's output.
transform = transforms.Compose(
    [
        transforms.Resize((config.image_size, config.image_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)

class ImageFolderDataset(Dataset):
    """Map-style dataset over the image files directly inside one folder.

    Only regular files whose names end in one of ``extensions`` (case-insensitive)
    are used; subdirectories are not searched. Paths are sorted so that the
    index -> file mapping is identical across runs and filesystems, which keeps a
    seeded ``random_split`` reproducible.

    Args:
        folder_path: directory to scan.
        transform: optional callable applied to each RGB ``PIL.Image``.
        extensions: accepted filename suffixes (default: .png, .jpg, .jpeg).
    """

    def __init__(self, folder_path, transform=None, extensions=('.png', '.jpg', '.jpeg')):
        self.folder_path = folder_path
        suffixes = tuple(ext.lower() for ext in extensions)
        # os.listdir order is arbitrary, so sort for a stable index -> file mapping.
        self.image_files = sorted(
            os.path.join(folder_path, name)
            for name in os.listdir(folder_path)
            if name.lower().endswith(suffixes)
            and os.path.isfile(os.path.join(folder_path, name))
        )
        self.transform = transform

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        # `with` releases the file handle; convert() returns an independent image copy.
        with Image.open(self.image_files[idx]) as img:
            rgb = img.convert("RGB")
        if self.transform:
            rgb = self.transform(rgb)
        return rgb

dataset = ImageFolderDataset(config.folder_path, transform=transform)
n_images = len(dataset)
if n_images < 2:
    raise ValueError(
        f"Need at least 2 images in {config.folder_path!r} for a train/val split, found {n_images}"
    )

# Round instead of truncating: int(10 * (1 - 0.9)) == 0 because 1 - 0.9 is slightly
# below 0.1, which left the validation set empty. Keep at least one image on each side.
val_size = min(n_images - 1, max(1, n_images - round(n_images * config.train_split)))
train_size = n_images - val_size

# Seeded generator so the train/val partition is the same on every run.
split_generator = torch.Generator().manual_seed(42)
train_ds, val_ds = random_split(dataset, [train_size, val_size], generator=split_generator)

train_loader = DataLoader(train_ds, batch_size=config.batch, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=config.batch, shuffle=False)

# Up to 8 fixed validation images, reused every epoch to visualise reconstructions.
fixed_val_batch = next(iter(val_loader))
fixed_val_images = fixed_val_batch[:8].to("cuda" if torch.cuda.is_available() else "cpu")


In [None]:
import torch
import torch.nn as nn

class ResnetBlock(nn.Module):
    """Pre-activation residual block: two (GroupNorm -> SiLU -> 3x3 conv) stages plus a skip.

    Dropout sits between the two stages. The skip path is a 1x1 conv when the channel
    count changes and the identity otherwise. The 3x3 convs use stride 1 / padding 1,
    so height and width are unchanged.

    Args:
        in_channel: input channels.
        out_channel: output channels (defaults to ``in_channel``).
        group_channel: passed to nn.GroupNorm as ``num_groups`` (number of groups).
        drop: dropout probability.
        eps: GroupNorm epsilon.
    """

    def __init__(self,
                 in_channel: int,
                 out_channel: int = None,
                 group_channel: int = 32,
                 drop: float = 0,
                 eps: float = 1e-5,
                 ):
        super().__init__()

        if out_channel is None:
            out_channel = in_channel
        self.out_channel = out_channel

        self.norm1 = nn.GroupNorm(group_channel, in_channel, eps=eps)
        self.norm2 = nn.GroupNorm(group_channel, out_channel, eps=eps)

        same_size_3x3 = dict(kernel_size=3, stride=1, padding=1)
        self.conv1 = nn.Conv2d(in_channel, out_channel, **same_size_3x3)
        self.conv2 = nn.Conv2d(out_channel, out_channel, **same_size_3x3)

        self.act = nn.SiLU()
        self.dropout = nn.Dropout(p=drop)

        if in_channel == out_channel:
            self.residual_layer = nn.Identity()
        else:
            # 1x1 projection so the skip path matches the new channel count.
            self.residual_layer = nn.Conv2d(in_channel, out_channel, kernel_size=1)

    def forward(self, x: torch.Tensor):
        out = self.conv1(self.act(self.norm1(x)))
        out = self.dropout(out)
        out = self.conv2(self.act(self.norm2(out)))
        return self.residual_layer(x) + out

class Downsample(nn.Module):
    """Reduce spatial resolution by ``stride`` using a strided conv or average pooling.

    Args:
        in_channel: input channels.
        out_channel: output channels (defaults to ``in_channel``). Only the conv path can
            change the channel count.
        use_conv: use a learned strided Conv2d (``kernel`` x ``kernel``, padding
            ``kernel // 2``); otherwise non-learned average pooling with window ``stride``.
        kernel: conv kernel size (conv path only).
        stride: downsampling factor.

    Raises:
        ValueError: if ``out_channel != in_channel`` without ``use_conv``.
    """

    def __init__(self,
                 in_channel: int,
                 out_channel: int = None,
                 use_conv: bool = False,
                 kernel: int = 3,
                 stride: int = 2
                 ):
        super().__init__()

        out_channel = in_channel if out_channel is None else out_channel
        self.out_channel = out_channel
        self.padding = kernel // 2  # symmetric padding for odd kernels
        self.use_conv = use_conv

        if use_conv:
            self.down_layer = nn.Conv2d(in_channel, out_channel, kernel_size=kernel,
                                        stride=stride, padding=self.padding)
        else:
            if out_channel != in_channel:
                raise ValueError(
                    f"Downsample without use_conv cannot map {in_channel} -> {out_channel} channels"
                )
            # ceil_mode keeps an odd trailing row/column as its own (partial) window,
            # averaged over real pixels only. This replaces zero-padding, which darkened
            # edges by averaging them with zeros. Even sizes give identical results.
            self.down_layer = nn.AvgPool2d(kernel_size=stride, stride=stride, ceil_mode=True)

    def forward(self, x: torch.Tensor):
        return self.down_layer(x)

class Upsample(nn.Module):
    """Double spatial resolution by nearest-neighbour interpolation and/or a learned layer.

    Modes (``use_conv`` takes precedence over ``use_conv_tranpose``):
      * ``use_conv``: optional 2x nearest interpolation (``interpolate``), then a 3x3 conv.
      * ``use_conv_tranpose``: a ConvTranspose2d (kernel 4, stride 2, padding 1), which
        itself doubles H and W; ``interpolate`` is ignored.
      * neither: optional 2x nearest interpolation only.

    Raises:
        ValueError: if ``out_channel != in_channel`` but no conv layer is requested.
    """

    def __init__(self,
                 in_channel: int,
                 out_channel: int = None,
                 use_conv: bool = False,
                 use_conv_tranpose: bool = False,
                 interpolate: bool = True
                 ):
        super().__init__()

        out_channel = in_channel if out_channel is None else out_channel
        self.out_channel = out_channel
        self.use_conv = use_conv
        self.use_conv_tranpose = use_conv_tranpose
        self.interpolate = interpolate

        self.layer = None
        if use_conv:
            self.layer = nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=1, padding=1)
        elif use_conv_tranpose:
            self.layer = nn.ConvTranspose2d(in_channel, out_channel, kernel_size=4, stride=2, padding=1)
        elif out_channel != in_channel:
            raise ValueError(
                f"Upsample without a conv layer cannot map {in_channel} -> {out_channel} channels"
            )

    def forward(self, x: torch.Tensor):
        # Dispatch on the layer actually built in __init__. Checking the flag instead
        # ran the 3x3 conv without upsampling when both flags were set.
        if isinstance(self.layer, nn.ConvTranspose2d):
            return self.layer(x)

        if self.interpolate:
            x = nn.functional.interpolate(x, scale_factor=2, mode='nearest')

        if self.layer is not None:
            x = self.layer(x)

        return x

class EncoderBlock(nn.Module):
    """One encoder stage: ``num_res_layers`` ResnetBlocks followed, optionally, by a Downsample.

    Only the first ResnetBlock changes width (``in_channel -> out_channel``); the others
    keep ``out_channel``. When ``down_layer`` is True, a ``Downsample(out_channel,
    use_conv=True)`` is applied last.
    """

    def __init__(self,
                 num_res_layers: int,
                 in_channel: int,
                 out_channel: int,
                 group_channel: int,
                 drop: int = 0,
                 eps: float = 1e-5,
                 down_layer: bool = True,
                 ):
        super().__init__()

        # Boolean flag only; the resampling module itself is stored as `down_block`.
        self.down_layer = down_layer
        self.layers = nn.ModuleList(
            ResnetBlock(in_channel if i == 0 else out_channel, out_channel, group_channel, drop, eps)
            for i in range(num_res_layers)
        )

        if down_layer:
            self.down_block = Downsample(out_channel, use_conv=True)

    def forward(self, x: torch.Tensor):
        for res_block in self.layers:
            x = res_block(x)
        if not self.down_layer:
            return x
        return self.down_block(x)

class DecoderBlock(nn.Module):
    """One decoder stage: ``num_res_layers`` ResnetBlocks followed, optionally, by an Upsample.

    Only the first ResnetBlock changes width (``in_channel -> out_channel``). When
    ``up_layer`` is True, an ``Upsample(out_channel, use_conv=True, interpolate=True)``
    (2x nearest interpolation then a 3x3 conv) is applied last.
    """

    def __init__(self,
                 num_res_layers: int,
                 in_channel: int,
                 out_channel: int,
                 group_channel: int,
                 drop: int = 0,
                 eps: float = 1e-5,
                 up_layer: bool = True
                 ):
        super().__init__()

        # Boolean flag only; the resampling module itself is stored as `up_block`.
        self.up_layer = up_layer
        self.layers = nn.ModuleList(
            ResnetBlock(in_channel if i == 0 else out_channel, out_channel, group_channel, drop, eps)
            for i in range(num_res_layers)
        )

        if up_layer:
            self.up_block = Upsample(out_channel, use_conv=True, interpolate=True)

    def forward(self, x: torch.Tensor):
        for res_block in self.layers:
            x = res_block(x)
        if not self.up_layer:
            return x
        return self.up_block(x)


class BottleNeck(nn.Module):
    """A stack of ``num_res_layers`` ResnetBlocks with no resampling.

    The first block maps ``in_channel -> out_channel``; the rest keep ``out_channel``.
    """

    def __init__(self,
                 num_res_layers: int,
                 in_channel: int,
                 out_channel: int,
                 group_channel: int,
                 drop: int = 0,
                 eps: float = 1e-5):
        super().__init__()

        self.layers = nn.ModuleList(
            ResnetBlock(in_channel if i == 0 else out_channel, out_channel, group_channel, drop, eps)
            for i in range(num_res_layers)
        )

    def forward(self, x: torch.Tensor):
        out = x
        for res_block in self.layers:
            out = res_block(out)
        return out



class Encoder(nn.Module):
    """Image -> parameters of a diagonal Gaussian over the latent.

    Pipeline: 3x3 input projection -> one EncoderBlock per entry of
    ``block_out_channels`` (every block except the last ends in a Downsample) ->
    BottleNeck -> GroupNorm/SiLU/3x3-conv head producing ``2 * output_channel``
    channels. In this notebook ``VAE.encode`` splits those channels into mean and
    log-variance.

    ``group_channel`` is the GroupNorm *number of groups* everywhere, including the
    output head, so every entry of ``block_out_channels`` must be divisible by it.
    """

    def __init__(self,
                 block_out_channels: list[int],
                 input_channel: int = 3,
                 output_channel: int = 4,  # latent channels
                 num_res_layers: int = 2,
                 group_channel: int = 32,
                 drop: float = 0,
                 eps: float = 1e-5
                 ):
        super().__init__()

        out_channel = block_out_channels[0]
        # Project the raw image into the first stage's width.
        self.first_layer = nn.Conv2d(input_channel, out_channel, kernel_size=3, stride=1, padding=1)

        # Encoder stages: stage idx maps the previous width to block_out_channels[idx].
        last_idx = len(block_out_channels) - 1
        self.encoder_blocks = nn.ModuleList([])
        for idx, block_channel in enumerate(block_out_channels):
            in_channel, out_channel = out_channel, block_channel
            self.encoder_blocks.append(
                EncoderBlock(num_res_layers, in_channel, out_channel, group_channel, drop, eps, idx != last_idx)
            )

        final_channel = block_out_channels[-1]
        self.bottleneck = BottleNeck(num_res_layers=num_res_layers, in_channel=final_channel,
                                     out_channel=final_channel, group_channel=group_channel,
                                     drop=drop, eps=eps)

        # The head uses the same group count / eps as the rest of the network; a hard-coded
        # 32 ignored the configured value and failed for widths not divisible by 32.
        # NOTE(review): checkpoints trained with the old 32-group head load fine (same
        # parameter shapes) but normalise differently when group_channel != 32.
        self.latent_out = nn.Sequential(
            nn.GroupNorm(num_groups=group_channel, num_channels=final_channel, eps=eps),
            nn.SiLU(),
            nn.Conv2d(final_channel, 2 * output_channel, kernel_size=3, padding=1)
        )

    def forward(self, x):
        x = self.first_layer(x)
        for encoder_block in self.encoder_blocks:
            x = encoder_block(x)

        x = self.bottleneck(x)
        return self.latent_out(x)


class Decoder(nn.Module):
    """Latent -> image in [-1, 1] (Tanh output).

    Mirrors ``Encoder``: ``block_out_channels`` is used in reverse order. Pipeline: 3x3
    latent projection -> BottleNeck -> one DecoderBlock per reversed entry (every block
    except the last ends in a 2x Upsample) -> GroupNorm/SiLU/3x3-conv/Tanh head with
    ``output_channel`` channels.

    ``group_channel`` is the GroupNorm *number of groups* everywhere, including the
    output head, so every entry of ``block_out_channels`` must be divisible by it.
    """

    def __init__(self,
                 block_out_channels: list[int],
                 input_channel: int = 4,  # latent channels
                 output_channel: int = 3,
                 num_res_layers: int = 2,
                 group_channel: int = 32,
                 drop: float = 0,
                 eps: float = 1e-5
                 ):
        super().__init__()

        reversed_channels = block_out_channels[::-1]

        out_channel = reversed_channels[0]
        # Project the latent into the widest stage's channel count.
        self.first_layer = nn.Conv2d(input_channel, out_channel, kernel_size=3, stride=1, padding=1)

        # Bottleneck at the lowest resolution, before any upsampling.
        self.bottleneck = BottleNeck(num_res_layers=num_res_layers, in_channel=out_channel,
                                     out_channel=out_channel, group_channel=group_channel,
                                     drop=drop, eps=eps)

        # Decoder stages: stage idx maps the previous width to reversed_channels[idx].
        last_idx = len(reversed_channels) - 1
        self.decoder_blocks = nn.ModuleList([])
        for idx, block_channel in enumerate(reversed_channels):
            in_channel, out_channel = out_channel, block_channel
            self.decoder_blocks.append(
                DecoderBlock(num_res_layers, in_channel, out_channel, group_channel, drop, eps, idx != last_idx)
            )

        final_channel = reversed_channels[-1]
        # The head uses the same group count / eps as the rest of the network; a hard-coded
        # 32 ignored the configured value and failed for widths not divisible by 32.
        # NOTE(review): checkpoints trained with the old 32-group head load fine (same
        # parameter shapes) but normalise differently when group_channel != 32.
        self.image_out = nn.Sequential(
            nn.GroupNorm(num_groups=group_channel, num_channels=final_channel, eps=eps),
            nn.SiLU(),
            nn.Conv2d(final_channel, output_channel, kernel_size=3, padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        x = self.first_layer(x)
        x = self.bottleneck(x)
        for decoder_block in self.decoder_blocks:
            x = decoder_block(x)

        return self.image_out(x)


class VAE(nn.Module):
    """Convolutional VAE with a diagonal-Gaussian posterior q(z|x).

    Built from a config object exposing ``block_out_channel``, ``input_channel``,
    ``output_channel``, ``latent_channel``, ``num_res_layers`` and ``group_channels``.
    ``forward`` returns ``(reconstruction, mu, var)``.
    """

    # Range for the predicted log-variance: keeps exp(log_var) from underflowing to 0
    # (which made log(var) in the KL term -inf) or overflowing to inf.
    LOG_VAR_MIN = -30.0
    LOG_VAR_MAX = 20.0

    def __init__(self, config):
        super().__init__()

        self.encoder = Encoder(block_out_channels=config.block_out_channel,
                               input_channel=config.input_channel,
                               output_channel=config.latent_channel,
                               num_res_layers=config.num_res_layers,
                               group_channel=config.group_channels
                               )
        self.decoder = Decoder(block_out_channels=config.block_out_channel,
                               input_channel=config.latent_channel,
                               output_channel=config.output_channel,
                               num_res_layers=config.num_res_layers,
                               group_channel=config.group_channels
                               )

    def encode(self, x: torch.Tensor):
        """Encode ``x`` and draw one reparameterised sample.

        Returns:
            (z, mu, var), where ``z = mu + sqrt(var) * eps`` and ``eps ~ N(0, I)``.
        """
        h = self.encoder(x)
        # Diagonal covariance: the first half of the channels is the mean, the second
        # half the log-variance.
        mu, log_var = torch.chunk(h, 2, dim=1)
        log_var = torch.clamp(log_var, min=self.LOG_VAR_MIN, max=self.LOG_VAR_MAX)
        var = torch.exp(log_var)
        std = torch.exp(0.5 * log_var)
        z = mu + std * torch.randn_like(mu)
        return z, mu, var

    def decode(self, z: torch.Tensor):
        """Map latents back to image space."""
        return self.decoder(z)

    def forward(self, x: torch.Tensor):
        z, mu, var = self.encode(x)
        return self.decode(z), mu, var

In [None]:
# Build the model on the GPU when available; a hard-coded "cuda" crashed on CPU-only machines.
vae = VAE(config).to("cuda" if torch.cuda.is_available() else "cpu")



In [None]:
# Pick the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
device

In [None]:
# Never hard-code the W&B API key in the notebook. With no `key`, wandb.login() looks for
# the WANDB_API_KEY environment variable or stored credentials, and otherwise prompts.
wandb.login()
run = wandb.init(
    project="vae-landscapes-256-pure-pytorch",
    config=asdict(config)
)

device = "cuda" if torch.cuda.is_available() else "cpu"
# `vae` comes from the earlier model cell (re-running this cell continues training the
# same model); make sure it lives on the same device as the batches.
vae = vae.to(device)
# Loss scaling matters mainly for float16; with the bfloat16 autocast used in the
# training loop it acts only as a safeguard.
scaler = GradScaler(device=device)

optimizer_vae = optim.AdamW(vae.parameters(), lr=config.lr)
recon_loss_fn = nn.L1Loss()
lpips_fn = lpips.LPIPS(net='alex').to(device)

def vae_loss(x, x_const, mu, var, recon_fn=None):
    """Return ``(reconstruction_loss, kl_loss)`` for a diagonal-Gaussian VAE.

    Args:
        x: target images.
        x_const: reconstructions of ``x`` produced by the decoder.
        mu: posterior means.
        var: posterior variances (not log-variances).
        recon_fn: criterion called as ``recon_fn(x_const, x)``; defaults to the
            notebook-level ``recon_loss_fn`` (an ``nn.L1Loss`` here).

    The KL term is KL(q(z|x) || N(0, I)) averaged over every latent element
    (``torch.mean``), not summed.
    """
    if recon_fn is None:
        recon_fn = recon_loss_fn
    recon_loss = recon_fn(x_const, x)
    # A variance that underflowed to exactly 0 would make log(var) = -inf and the loss
    # NaN; clamping at the smallest normal float only affects that degenerate case.
    log_var = torch.log(var.clamp_min(torch.finfo(var.dtype).tiny))
    kl_loss = -0.5 * torch.mean(1 + log_var - mu.pow(2) - var)
    return recon_loss, kl_loss

global_steps = 0
# Side length of the latent grid: every encoder stage except the last halves the
# resolution (16 for the default 128px / 4-stage config).
latent_size = config.image_size // 2 ** (len(config.block_out_channel) - 1)
# Checkpoints are saved as model_dict{epoch + offset}.pth. The +10 offset is kept from
# the original cell, presumably to continue numbering after an earlier 10-epoch run
# (the model is reused from a previous cell) — TODO confirm.
checkpoint_epoch_offset = 10
os.makedirs("checkpoints", exist_ok=True)

for epoch in range(config.epochs):
    # ---- training ----
    vae.train()
    for batch in tqdm(train_loader, desc=f"Epoch {epoch + 1}/{config.epochs}"):
        global_steps += 1
        batch = batch.to(device)

        with torch.autocast(device_type=device, dtype=torch.bfloat16):
            x_const, mu, var = vae(batch)
            recon_loss, kl_loss = vae_loss(batch, x_const, mu, var)
            lpips_loss = lpips_fn(x_const, batch).mean()
            total_loss = recon_loss + config.beta * kl_loss + lpips_loss

        wandb.log({
            "loss/total_loss": total_loss.item(),
            "loss/reconstr_loss": recon_loss.item(),
            "loss/kl_loss": kl_loss.item(),
            "lpips_loss": lpips_loss.item()
        }, step=global_steps)

        optimizer_vae.zero_grad()
        scaler.scale(total_loss).backward()
        scaler.step(optimizer_vae)
        scaler.update()

    # ---- end-of-epoch visuals and checkpoint ----
    vae.eval()
    with torch.no_grad():
        # Top row: originals; bottom row: their reconstructions. nrow follows the actual
        # number of fixed images, which can be < 8 for a small validation split.
        x_const, _, _ = vae(fixed_val_images)
        both = torch.cat([fixed_val_images, x_const], dim=0)
        grid_both = make_grid(both.cpu() * 0.5 + 0.5, nrow=len(fixed_val_images))

        # Decode samples from N(0, I) to see what the prior generates.
        z = torch.randn(8, config.latent_channel, latent_size, latent_size, device=device)
        sampled = vae.decode(z)
        grid_sampled = make_grid(sampled.cpu() * 0.5 + 0.5, nrow=4)

        wandb.log({
            "original_vs_reconstructed": wandb.Image(grid_both),
            "sampled_from_gaussian": wandb.Image(grid_sampled)
        }, step=global_steps)

    torch.save(vae.state_dict(), f"checkpoints/model_dict{epoch + checkpoint_epoch_offset}.pth")

    # torch.cuda memory statistics are CUDA-only, so only report them when training on GPU.
    if device == "cuda":
        mem_allocated = torch.cuda.memory_allocated(device) / 1024**2
        mem_reserved = torch.cuda.memory_reserved(device) / 1024**2
        print(f"[Epoch {epoch}] GPU Memory: Allocated = {mem_allocated:.2f} MB, Reserved = {mem_reserved:.2f} MB")
