In [6]:
!curl -LsSf https://astral.sh/uv/install.sh | sh
!uv sync
!source .venv/bin/activate

%cd /notebooks/vae-comparison
!git pull
%pip install einops matplotlib datasets
# %pip install -r requirements.txt


downloading uv 0.5.7 x86_64-unknown-linux-gnu
no checksums to verify
installing to /root/.local/bin
  uv
  uvx
everything's installed!
[2mResolved [1m92 packages[0m [2min 1ms[0m[0m
[2mAudited [1m89 packages[0m [2min 0.14ms[0m[0m
/notebooks/vae-comparison


  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]


Already up to date.
Collecting einops
  Downloading einops-0.8.0-py3-none-any.whl.metadata (12 kB)
Downloading einops-0.8.0-py3-none-any.whl (43 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.2/43.2 kB[0m [31m1.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: einops
Successfully installed einops-0.8.0
[0mNote: you may need to restart the kernel to use updated packages.


In [7]:
import numpy as np
import torch
from torch.utils.data import Dataset
from datasets import load_dataset

ds = load_dataset("uoft-cs/cifar10")


class CIFAR10Dataset(Dataset):
    def __init__(self, type: str = "train"):
        self.ds = ds[type]  # type: ignore

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        item = self.ds[idx]
        image = np.array(item["img"], dtype=np.float32)
        image = image / 255.0
        image = torch.tensor(image)
        label = item["label"]
        return (image, label)


In [8]:
import numpy as np
from torch import nn
import torch


class VAE(nn.Module):
    def __init__(self, input_shape, in_channels, latent_dim=16):
        super(VAE, self).__init__()

        size = 32
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, size, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(size, size * 2, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(size * 2, size * 4, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(size * 4, size * 8, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Flatten(),
        )

        conv_out_size = self._get_conv_out_size(input_shape)

        self.mu_layer = nn.Linear(conv_out_size, latent_dim)
        self.logvar_layer = nn.Linear(conv_out_size, latent_dim)

        self.predecode = nn.Linear(latent_dim, conv_out_size)

        self.decoder = nn.Sequential(
            nn.Unflatten(1, (size * 8, input_shape[1] // 16, input_shape[2] // 16)),
            nn.ConvTranspose2d(size * 8, size * 4, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(size * 4, size * 2, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(size * 2, size, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(size, in_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

        self.init_weights()

    def init_weights(self):
        def init_func(m):
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
                nn.init.kaiming_normal_(
                    m.weight, mode="fan_out", nonlinearity="relu"
                )
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

        self.apply(init_func)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)

        return eps.mul(std).add_(mu)

    def encode(self, x):
        h1 = self.encoder(x)
        mu, logvar = self.mu_layer(h1), self.logvar_layer(h1)

        return self.reparameterize(mu, logvar)

    def decode(self, z):
        h3 = self.predecode(z)
        return self.decoder(h3)

    def forward(self, x):
        z = self.encode(x)
        y = self.decode(z)

        return y

    def _get_conv_out_size(self, shape):
        out = self.encoder(torch.zeros(1, *shape))
        self.conv_out_shape = out.size()
        return int(np.prod(self.conv_out_shape))


In [10]:
from torch.utils.data import DataLoader
from einops import rearrange
from tqdm import tqdm
import os
from time import time

fingerprint = f"{round(time())}"
outdir = f"checkpoints/run-{fingerprint}"
os.makedirs(outdir, exist_ok=True)

print(f"Saving checkpoints to {outdir}")


ENABLE_MPS = False
device = torch.device(
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available() and ENABLE_MPS
    else "cpu"
)

vae = VAE(in_channels=3, input_shape=(3, 32, 32))
vae.to(device)
vae.train()

optimizer = torch.optim.AdamW(vae.parameters(), lr=1e-3)
β = 1


dataloader = DataLoader(CIFAR10Dataset("train"), shuffle=True, batch_size=32)

for epoch in range(10):
    total_loss = 0
    with tqdm(dataloader, desc=f"Epoch {epoch}") as pbar:
        for x, y in pbar:
            x = x.to(device)
            x = rearrange(x, "b w h c -> b c w h")

            y = y.to(device).float()
            recon = vae(x)

            # recon_loss = torch.nn.functional.mse_loss(recon, x)
            recon_loss = torch.nn.functional.binary_cross_entropy(
                recon, x, reduction="sum"
            )
            kl_loss = -0.5 * torch.sum(
                1
                + vae.logvar_layer.weight
                - vae.mu_layer.weight.pow(2)
                - vae.logvar_layer.weight.exp()
            )
            kl_loss /= x.size(0)  # Normalize by batch size

            loss = recon_loss + kl_loss * β
            total_loss += loss.item()

            vae.zero_grad()
            loss.backward()
            optimizer.step()

            pbar.set_postfix(
                {
                    "Loss": loss.item(),
                    "Recon Loss": recon_loss.item(),
                    "KL Loss": kl_loss.item(),
                }
            )

        loss = total_loss / len(dataloader)
        print(f"Epoch: {epoch}, Loss: {loss}")

        torch.save(vae.state_dict(), f"{outdir}/vae_epoch_{epoch}.pth")


Saving checkpoints to checkpoints/run-1733777568


RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [None]:
# Load the trained VAE model for inference
vae.eval()

# Prepare a DataLoader for visualization
visualization_dataloader = DataLoader(
    CIFAR10Dataset("train"), shuffle=True, batch_size=1
)

import matplotlib.pyplot as plt


# Function to visualize original and reconstructed images
def visualize_images(originals, reconstructeds):
    fig, axes = plt.subplots(2, 5, figsize=(15, 6))
    for idx, (original, reconstructed) in enumerate(zip(originals, reconstructeds)):
        axes[0, idx].imshow(original)
        axes[0, idx].axis("off")
        axes[1, idx].imshow(reconstructed)
        axes[1, idx].axis("off")

    plt.show()


# Collect the first 5 images and their reconstructions
original_images = []
reconstructed_images = []
for i, (image_tensor, _) in enumerate(visualization_dataloader):
    if i >= 5:
        break
    image = image_tensor.squeeze(0).cpu().numpy()
    original_images.append(image)

    # Encode and decode the image
    image_tensor = image_tensor.to(device)
    image_tensor = rearrange(image_tensor, "1 w h c -> 1 c w h")
    with torch.no_grad():
        reconstructed_tensor = vae(image_tensor)
        reconstructed_image = (
            rearrange(reconstructed_tensor, "1 c w h -> w h c").cpu().numpy()
        )
    reconstructed_images.append(reconstructed_image)

# Visualize the original and reconstructed images
visualize_images(original_images, reconstructed_images)
