Inspired by:

https://github.com/google-deepmind/sonnet/blob/v2/sonnet/src/nets/vqvae.py

Config values from paper plus:

https://github.com/google-deepmind/sonnet/blob/v2/examples/vqvae_example.ipynb

TODO:
- [ ] sample / interpolate latent space
- [X] try training without weight decay
- [ ] VQ-VAE 2 hierarchical encoding / decoding
- [X] EMA VQ Loss
- [ ] try to identify the optimal learning rate (is this more complicated due to training dynamics?)
- [X] try verifying that the codebook is initialized correctly
- [X] try to understand and log codebook usage and other important model dynamics
- [ ] try training with lr warmup
- [ ] try training with 1 cycle + warmup  

In [1]:
# Install the W&B and PyTorch Lightning packages used below (IPython shell escape).
!pip install -q wandb pytorch_lightning

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/2.2 MB[0m [31m11.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m777.7/777.7 kB[0m [31m16.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m196.4/196.4 kB[0m [31m14.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m254.1/254.1 kB[0m [31m18.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m840.2/840.2 kB[0m [31m31.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m62.7/62.7 kB[0m [31m6.8 MB/s[0m eta [36m0:00:00[0m
[?25h

In [2]:
import argparse
import math
import torch
import torch.utils.data
import torchvision
import wandb
import pytorch_lightning as pl
import torch.nn.init as init

from __future__ import print_function
from collections import OrderedDict
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor, EarlyStopping
from pytorch_lightning.tuner import Tuner
from torch.optim.lr_scheduler import CosineAnnealingLR, OneCycleLR


# Fix the global random seed (via Lightning) so runs are reproducible.
pl.seed_everything(42)

INFO:lightning_fabric.utilities.seed:Seed set to 42


42

# Data

In [3]:
class CIFAR10DataModule(pl.LightningDataModule):
    """LightningDataModule for CIFAR-10.

    Images are converted to tensors and normalised per channel with
    mean 0.5 / std 0.5, i.e. mapped from [0, 1] to [-1, 1]. The official
    CIFAR-10 test split is used as the validation set.

    Args:
        config: object exposing ``batch_size`` and ``num_workers``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

    def prepare_data(self):
        """Download both CIFAR-10 splits into ./data (no transforms here)."""
        torchvision.datasets.CIFAR10(root='./data', train=True, download=True)
        torchvision.datasets.CIFAR10(root='./data', train=False, download=True)

    def setup(self, stage=None):
        """Build the datasets needed for `stage`.

        The validation set is built for 'validate' as well as 'fit', so that
        `trainer.validate(...)` works and does not fail on a missing
        `cifar10_val` attribute.
        """
        if stage in ('fit', None):
            self.cifar10_train = torchvision.datasets.CIFAR10(
                root='./data', train=True, transform=self.transform)
        if stage in ('fit', 'validate', None):
            self.cifar10_val = torchvision.datasets.CIFAR10(
                root='./data', train=False, transform=self.transform)

    def _make_loader(self, dataset, shuffle):
        """Shared DataLoader settings for the train and validation loaders."""
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=self.config.batch_size,
            shuffle=shuffle,
            num_workers=self.config.num_workers, # TODO: check if we need more workers
            pin_memory=True
        )

    def train_dataloader(self):
        return self._make_loader(self.cifar10_train, shuffle=True)

    def val_dataloader(self):
        return self._make_loader(self.cifar10_val, shuffle=False)

# Model

In [None]:
class ResBlock(nn.Module):
    """Residual block: 3x3 conv -> ReLU -> BN -> 1x1 conv -> BN, plus a skip.

    Loosely follows the VQ-VAE paper ("[...] followed by two residual 3 × 3
    blocks (implemented as ReLU, 3x3 conv, ReLU, 1x1 conv), all having 256
    hidden units."), but adds BatchNorm and applies the final ReLU after the
    residual sum. The skip path is a 1x1 conv when the channel count changes
    and the identity otherwise. Spatial size is preserved.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels of the output feature map.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # The skip path is built first, before the main path, so parameter
        # registration (and therefore random initialisation) order is fixed.
        self.identity = (
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1)
            if in_channels != out_channels
            else nn.Identity()
        )

        main_path = [
            # padding=1 keeps H and W unchanged for the 3x3 conv
            nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(in_channels),
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1),
            nn.BatchNorm2d(out_channels),
        ]
        self.block = nn.Sequential(*main_path)
        self.res_act = nn.ReLU()

    def forward(self, x):
        residual = self.block(x)
        shortcut = self.identity(x)
        return self.res_act(residual + shortcut)

In [None]:
class Encoder(nn.Module):
    def __init__(
            self,
            in_channels,
            hidden_channels,
            out_channels,
            nlayers,
            nblocks
        ):
        '''
        "For 256 × 256 images, we use a two level latent hierarchy. [...] the encoder
        network first transforms and downsamples the image by a factor of 4 to a 64 × 64 representation
        which is quantized to our bottom level latent map. Another stack of residual blocks then further
        scales down the representations by a factor of two, yielding a top-level 32 × 32 latent map after
        quantization"

        Downsampling encoder: `nlayers` 4x4 stride-2 convs (each followed by
        BatchNorm + ReLU; each halves an even H and W), then `nblocks`
        ResBlocks.

        Args:
            in_channels: channels of the input feature map.
            hidden_channels: width of the strided convs and residual blocks.
            out_channels: channels of the returned feature map. The last
                ResBlock produces them, or a 1x1 conv when nblocks == 0.
            nlayers: number of stride-2 downsampling convs.
            nblocks: number of residual blocks.
        '''
        super().__init__()
        self.strided_blocks = nn.Sequential(*[
            nn.Sequential(
                nn.Conv2d(
                    in_channels=in_channels if i == 0 else hidden_channels,
                    out_channels=hidden_channels,
                    kernel_size=4,
                    stride=2,
                    padding=1
                ),
                nn.BatchNorm2d(hidden_channels),
                nn.ReLU()
            ) for i in range(nlayers)
        ])

        if nblocks > 0:
            self.res_blocks = nn.Sequential(*[
                ResBlock(
                    in_channels=hidden_channels,
                    out_channels=hidden_channels if i < nblocks-1 else out_channels
                ) for i in range(nblocks)
            ])
        elif hidden_channels != out_channels:
            # Without residual blocks nothing would map hidden_channels to
            # out_channels, and the downstream quantizer would receive the
            # wrong channel count. Project with a 1x1 conv instead.
            self.res_blocks = nn.Sequential(
                nn.Conv2d(
                    in_channels=hidden_channels,
                    out_channels=out_channels,
                    kernel_size=1
                )
            )
        else:
            self.res_blocks = nn.Sequential()

    def forward(self, x):
        x = self.strided_blocks(x)
        x = self.res_blocks(x)
        return x

In [None]:
class Quantizer(nn.Module):
    """Nearest-neighbour vector quantizer with a gradient-trained codebook.

    Each spatial position of a (B, C, H, W) input, with C == latent_channels,
    is replaced by its closest codebook vector (squared L2 distance). The
    encoder receives gradients through the straight-through estimator. The
    codebook is trained by the embedding loss.

    Args:
        codebook_size: number of codebook vectors K.
        latent_channels: dimensionality C of each codebook vector.
        commit_loss_beta: weight of the commitment loss inside ``vq_loss``.
        track_codebook: if True, accumulate code-usage counts in buffers and
            report their perplexity.
    """

    def __init__(
            self,
            codebook_size,
            latent_channels,
            commit_loss_beta,
            track_codebook
        ):
        super().__init__()
        self.codebook_size = codebook_size
        self.latent_channels = latent_channels
        self.commit_loss_beta = commit_loss_beta
        self.track_codebook = track_codebook

        self.codebook = nn.Embedding(codebook_size, latent_channels)
        init.uniform_(self.codebook.weight, -1/codebook_size, 1/codebook_size)

        if track_codebook:
            self.register_buffer('codebook_usage', torch.zeros(codebook_size, dtype=torch.float))
            self.register_buffer('total_usage', torch.tensor(0, dtype=torch.float))

    def reset_usage_stats(self):
        """Zero the accumulated usage counts (no-op when not tracking)."""
        # the usage buffers only exist when track_codebook is set
        if not self.track_codebook:
            return
        self.codebook_usage.zero_()
        self.total_usage.zero_()

    def calculate_perplexity(self, enc_idxs):
        """Add `enc_idxs` to the usage counts and return their perplexity.

        Perplexity is exp(entropy) of the accumulated usage distribution. It
        is 1 when a single code is used and K when all K codes are used
        equally often.
        """
        unique_indices, counts = torch.unique(enc_idxs, return_counts=True)
        self.codebook_usage.index_add_(0, unique_indices, counts.float())
        self.total_usage += torch.sum(counts)

        if self.total_usage > 0:
            probs = self.codebook_usage / self.total_usage
            perplexity = torch.exp(-torch.sum(probs * torch.log(probs + 1e-10)))
            return perplexity
        else:
            return torch.tensor([0.0])

    def forward(self, inputs):
        """Quantize a (B, C, H, W) tensor.

        Returns:
            dict with 'quantized_inputs' (B, C, H, W; straight-through),
            'vq_loss', 'embedding_loss', 'commitment_loss' and 'perplexity'.
            'perplexity' is tensor([0.]) when not tracking.
        """
        inputs = inputs.permute(0, 2, 3, 1).contiguous()
        flat_inputs = inputs.reshape(-1, self.latent_channels)

        # Σ(x-y)^2 = Σx^2 - 2xy + Σy^2
        dists = (
            torch.sum(flat_inputs ** 2, dim=1, keepdim=True) - # Σx^2
            2 * (flat_inputs @ self.codebook.weight.t()) +     # 2*xy
            torch.sum(self.codebook.weight ** 2, dim=1)        # Σy^2
        )

        code_idxs = torch.argmin(dists, dim=1)
        quantized_inputs = self.codebook(code_idxs).reshape(inputs.shape)

        # "The VQ objective uses the l2 error to move the embedding vectors
        # e_i towards the encoder outputs z_e(x)"
        embedding_loss = F.mse_loss(quantized_inputs, inputs.detach())

        # "since the volume of the embedding space is dimensionless, it can grow
        # arbitrarily if the embeddings e_i do not train as fast as the encoder
        # parameters. To make sure the encoder commits to an embedding and its
        # output does not grow, we add a commitment loss"
        commitment_loss = F.mse_loss(quantized_inputs.detach(), inputs)

        # parts 2 & 3 of the full loss (i.e. not including the reconstruction
        # loss): beta-weighted commitment loss plus the embedding loss
        vq_loss = commitment_loss * self.commit_loss_beta + embedding_loss

        # sets the output to be the input plus the residual value between the
        # quantized latents and the inputs like a resnet for Straight Through
        # Estimation (STE)
        quantized_inputs = inputs + (quantized_inputs - inputs).detach()
        quantized_inputs = quantized_inputs.permute(0, 3, 1, 2).contiguous()

        perplexity = (
            self.calculate_perplexity(code_idxs)
            if self.track_codebook
            else torch.tensor([0.0])
        )

        return {
            'quantized_inputs': quantized_inputs,
            'vq_loss':          vq_loss,
            'embedding_loss':   embedding_loss,
            'commitment_loss':  commitment_loss,
            'perplexity':       perplexity
        }

In [None]:
class QuantizerEMA(nn.Module):
    """Vector quantizer whose codebook is updated by exponential moving averages.

    Lookup and straight-through output are the same as the gradient-trained
    quantizer. Instead of an embedding loss, in training mode the codebook is
    set to m_i / N_i. Here N_i is an EMA of how many inputs chose code i, and
    m_i is an EMA of the sum of those inputs.

    Args:
        codebook_size: number of codebook vectors K.
        latent_channels: dimensionality C of each codebook vector.
        ema_gamma: EMA decay; larger values make the codebook move more slowly.
        commit_loss_beta: weight of the commitment loss (the only VQ loss term).
        track_codebook: if True, accumulate code-usage counts and report perplexity.
        epsilon: Laplace-smoothing constant applied to N before dividing, so
            codes that stop being selected never divide by (near) zero.
    """

    def __init__(
            self,
            codebook_size,
            latent_channels,
            ema_gamma,
            commit_loss_beta,
            track_codebook,
            epsilon=1e-5
        ):
        super().__init__()
        self.codebook_size = codebook_size
        self.latent_channels = latent_channels
        self.ema_gamma = ema_gamma
        self.commit_loss_beta = commit_loss_beta
        self.track_codebook = track_codebook
        self.epsilon = epsilon

        self.codebook = nn.Embedding(codebook_size, latent_channels)
        init.uniform_(self.codebook.weight, -1/codebook_size, 1/codebook_size)

        # EMA state is seeded so that m / N equals the initial codebook, which
        # leaves a code that is not selected where it started. The old seed
        # (N ~ 1e-6, m an unrelated random tensor) threw every code missed by
        # the first batch out to |m / N| up to 1e6 / codebook_size, where it
        # was never chosen again.
        self.register_buffer('N', torch.ones(codebook_size))
        self.register_buffer('m', self.codebook.weight.detach().clone())

        if track_codebook:
            self.register_buffer('codebook_usage', torch.zeros(codebook_size, dtype=torch.float))
            self.register_buffer('total_usage', torch.tensor(0, dtype=torch.float))

    @torch.no_grad()
    def ema_update(self, code_idxs, flat_inputs):
        """Fold one batch of assignments into the EMA state and refresh the codebook.

        Args:
            code_idxs: (N,) index of the chosen code for each row of flat_inputs.
            flat_inputs: (N, latent_channels) encoder outputs.
        """
        # accumulate in the buffers' precision, even if inputs are half precision
        flat_inputs = flat_inputs.detach().to(self.m.dtype)

        # number of vectors for each i which quantize to e_i
        counts = torch.bincount(code_idxs, minlength=self.codebook_size).to(self.N.dtype)

        # sum of vectors for each i which quantize to code e_i (scatter-add
        # rather than a dense one-hot matmul)
        embed_sums = torch.zeros_like(self.m).index_add_(0, code_idxs, flat_inputs)

        # update EMA of code usage and sum of codes (in place on the buffers)
        decay = self.ema_gamma
        self.N.mul_(decay).add_(counts, alpha=1 - decay)
        self.m.mul_(decay).add_(embed_sums, alpha=1 - decay)

        # Laplace smoothing: every smoothed N_i stays > 0 (while sum(N) > 0)
        # and the total sum(N) is unchanged. Without it, N_i of an unused
        # code decays towards 0 and m_i / N_i becomes inf or NaN.
        total = self.N.sum()
        smoothed_N = (self.N + self.epsilon) / (total + self.codebook_size * self.epsilon) * total

        self.codebook.weight.data.copy_(self.m / smoothed_N.unsqueeze(-1))

    def reset_usage_stats(self):
        """Zero the accumulated usage counts (no-op when not tracking)."""
        # the usage buffers only exist when track_codebook is set
        if not self.track_codebook:
            return
        self.codebook_usage.zero_()
        self.total_usage.zero_()

    def calculate_perplexity(self, enc_idxs):
        """Add `enc_idxs` to the usage counts and return their perplexity.

        Perplexity is exp(entropy) of the accumulated usage distribution. It
        is 1 when a single code is used and K when all K codes are used
        equally often.
        """
        unique_indices, counts = torch.unique(enc_idxs, return_counts=True)
        self.codebook_usage.index_add_(0, unique_indices, counts.float())
        self.total_usage += torch.sum(counts)

        if self.total_usage > 0:
            probs = self.codebook_usage / self.total_usage
            perplexity = torch.exp(-torch.sum(probs * torch.log(probs + 1e-10)))
            return perplexity
        else:
            return torch.tensor([0.0])

    def forward(self, inputs):
        """Quantize a (B, C, H, W) tensor; in training mode also update the codebook.

        Returns:
            dict with 'quantized_inputs' (B, C, H, W; straight-through),
            'vq_loss', 'commitment_loss' and 'perplexity'. 'perplexity' is
            tensor([0.]) when not tracking.
        """
        inputs = inputs.permute(0, 2, 3, 1).contiguous()
        flat_inputs = inputs.reshape(-1, self.latent_channels)

        # Σ(x-y)^2 = Σx^2 - 2xy + Σy^2
        dists = (
            torch.sum(flat_inputs ** 2, dim=1, keepdim=True) - # Σx^2
            2 * (flat_inputs @ self.codebook.weight.t()) +     # 2*xy
            torch.sum(self.codebook.weight ** 2, dim=1)        # Σy^2
        )

        code_idxs = torch.argmin(dists, dim=1)
        quantized_inputs = self.codebook(code_idxs).reshape(inputs.shape)

        if self.training:
            # perform exponential moving average update for codebook
            self.ema_update(code_idxs, flat_inputs)

        # "since the volume of the embedding space is dimensionless, it can grow
        # arbitrarily if the embeddings e_i do not train as fast as the encoder
        # parameters. To make sure the encoder commits to an embedding and its
        # output does not grow, we add a commitment loss"
        commitment_loss = F.mse_loss(quantized_inputs.detach(), inputs)

        # the only VQ term here: the codebook itself is learned by the EMA
        # update, so there is no embedding loss
        vq_loss = commitment_loss * self.commit_loss_beta

        # sets the output to be the input plus the residual value between the
        # quantized latents and the inputs like a resnet for Straight Through
        # Estimation (STE)
        quantized_inputs = inputs + (quantized_inputs - inputs).detach()
        quantized_inputs = quantized_inputs.permute(0, 3, 1, 2).contiguous()

        perplexity = (
            self.calculate_perplexity(code_idxs)
            if self.track_codebook
            else torch.tensor([0.0])
        )

        return {
            'quantized_inputs': quantized_inputs,
            'vq_loss':          vq_loss,
            'commitment_loss':  commitment_loss,
            'perplexity':       perplexity
        }

In [None]:
class Decoder(nn.Module):
    def __init__(
            self,
            in_channels,
            hidden_channels,
            out_channels,
            nblocks,
            nlayers
        ):
        '''
        "The decoder similarly has two residual 3 × 3 blocks, followed by
        two transposed convolutions with stride 2 and window size 4 × 4"

        Upsampling decoder: `nblocks` ResBlocks, then `nlayers` 4x4
        stride-2 transposed convs. Each doubles H and W. The first
        nlayers-1 keep hidden_channels and are followed by BatchNorm + ReLU;
        the last maps to out_channels with no activation.

        Args:
            in_channels: channels of the input latent map.
            hidden_channels: width of the residual and hidden transposed blocks.
            out_channels: channels of the output.
            nblocks: number of residual blocks. With 0 blocks a 1x1 conv maps
                in_channels to hidden_channels instead (if they differ).
            nlayers: total number of stride-2 transposed convs.
        '''
        super().__init__()
        if nblocks > 0:
            self.res_blocks = nn.Sequential(*[
                ResBlock(
                    in_channels=in_channels if i==0 else hidden_channels,
                    out_channels=hidden_channels
                ) for i in range(nblocks)
            ])
        elif in_channels != hidden_channels:
            # Without residual blocks nothing would map in_channels to the
            # hidden width that the transposed convs expect.
            self.res_blocks = nn.Sequential(
                nn.Conv2d(
                    in_channels=in_channels,
                    out_channels=hidden_channels,
                    kernel_size=1
                )
            )
        else:
            self.res_blocks = nn.Sequential()

        self.transposed_blocks = nn.Sequential(*[
            nn.Sequential(
                nn.ConvTranspose2d(
                    in_channels=hidden_channels,
                    out_channels=hidden_channels,
                    kernel_size=4,
                    stride=2,
                    padding=1
                ),
                nn.BatchNorm2d(hidden_channels),
                nn.ReLU()
            ) for _ in range(nlayers-1)
        ])

        self.out_layer = nn.ConvTranspose2d(
            in_channels=hidden_channels,
            out_channels=out_channels,
            kernel_size=4,
            stride=2,
            padding=1
        )

    def forward(self, z_q):
        out = self.res_blocks(z_q)
        out = self.transposed_blocks(out)
        out = self.out_layer(out)
        return out

In [None]:
class VQVAE2(nn.Module):
    """Two-level VQ-VAE.

    x -> bottom_encoder -> bottom_z -> top_encoder -> top quantizer
    -> top_decoder -> (same resolution as bottom_z). That reconstruction is
    concatenated with bottom_z on the channel axis, quantized by the bottom
    quantizer (2 * latent_channels wide) and decoded back to image space.

    Args:
        config: reads in_channels, hidden_channels, latent_channels, nlayers,
            nblocks, use_ema, codebook_size, ema_gamma, commit_loss_beta and
            track_codebook.
    """

    def __init__(self, config):
        super().__init__()
        latent = config.latent_channels
        # the bottom quantizer/decoder see [bottom_z, top reconstruction] stacked
        stacked_latent = latent * 2
        conv_kwargs = dict(
            hidden_channels=config.hidden_channels,
            nlayers=config.nlayers,
            nblocks=config.nblocks,
        )

        # submodules are created in a fixed order (encoders, decoders, then
        # quantizers top before bottom) so random initialisation is stable
        self.bottom_encoder = Encoder(in_channels=config.in_channels, out_channels=latent, **conv_kwargs)
        self.top_encoder = Encoder(in_channels=latent, out_channels=latent, **conv_kwargs)
        self.top_decoder = Decoder(in_channels=latent, out_channels=latent, **conv_kwargs)
        self.bottom_decoder = Decoder(in_channels=stacked_latent, out_channels=config.in_channels, **conv_kwargs)

        quantizer_kwargs = dict(
            codebook_size=config.codebook_size,
            commit_loss_beta=config.commit_loss_beta,
            track_codebook=config.track_codebook,
        )
        if config.use_ema:
            quantizer_cls = QuantizerEMA
            quantizer_kwargs['ema_gamma'] = config.ema_gamma
        else:
            quantizer_cls = Quantizer

        self.top_quantizer = quantizer_cls(latent_channels=latent, **quantizer_kwargs)
        self.bottom_quantizer = quantizer_cls(latent_channels=stacked_latent, **quantizer_kwargs)

        self.config = config

    def loss(self, x_hat, x, quantized):
        """Reconstruction MSE plus both quantizers' VQ losses.

        Returns the MSE, the total 'loss', and the entries of `quantized`.
        """
        reconstruction = F.mse_loss(x_hat, x)
        total = (
            reconstruction
            + quantized['bottom_quantized']['vq_loss']
            + quantized['top_quantized']['vq_loss']
        )
        result = {'MSE': reconstruction, 'loss': total}
        result.update(quantized)
        return result

    def forward(self, x):
        bottom_z = self.bottom_encoder(x)
        top_stats = self.top_quantizer(self.top_encoder(bottom_z))

        # decode the quantized top map back to bottom resolution and stack it
        # with the bottom encoding before the bottom quantizer
        top_recon = self.top_decoder(top_stats['quantized_inputs'])
        bottom_stats = self.bottom_quantizer(torch.cat([bottom_z, top_recon], dim=1))

        x_hat = self.bottom_decoder(bottom_stats['quantized_inputs'])
        quantized = {'bottom_quantized': bottom_stats, 'top_quantized': top_stats}
        result = {'x_hat': x_hat}
        result.update(self.loss(x_hat, x, quantized))
        return result

# Lightning Model

In [None]:
class LitVQVAE2(pl.LightningModule):
    """Lightning wrapper that trains a VQVAE2 and logs its metrics.

    Args:
        model: a VQVAE2 whose forward returns 'x_hat', 'MSE', 'loss',
            'top_quantized' and 'bottom_quantized'.
        config: hyper-parameter object; reads lr, beta1, beta2, weight_decay,
            use_lr_schedule, max_epochs, use_ema and track_codebook.
    """

    def __init__(self, model, config):
        super().__init__()
        self.model = model
        self.config = config
        self.lr = config.lr
        # No Trainer (and therefore no logger) is attached while the module is
        # being constructed, so a `self.logger` check here can never fire.
        # Hyper-parameters are logged by the caller instead, e.g. via
        # wandb.init(config=config.to_dict()).

    def forward(self, x):
        return self.model(x)

    def _log_metrics(self, split, out):
        """Log the loss terms (and perplexity) of a VQVAE2 output under `split/`."""
        self.log(f'{split}/loss', out['loss'], prog_bar=True)
        self.log(f'{split}/MSE',  out['MSE'],  prog_bar=True)

        for level in ('top', 'bottom'):
            stats = out[f'{level}_quantized']
            self.log(f'{split}/{level}/vq_loss',         stats['vq_loss'],         prog_bar=True)
            self.log(f'{split}/{level}/commitment_loss', stats['commitment_loss'], prog_bar=True)
            # only the non-EMA quantizer returns an embedding (codebook) loss
            if not self.config.use_ema:
                self.log(f'{split}/{level}/embedding_loss', stats['embedding_loss'], prog_bar=True)
            if self.config.track_codebook:
                self.log(f'{split}/{level}/perplexity', stats['perplexity'], prog_bar=True)

    def training_step(self, batch, batch_idx):
        x, _ = batch
        out = self(x)
        self._log_metrics('train', out)
        return out['loss']

    def validation_step(self, batch, batch_idx):
        x, _ = batch
        out = self(x)
        self._log_metrics('val', out)

        # once per validation epoch, log originals above their reconstructions
        if batch_idx == 0 and isinstance(self.logger, WandbLogger):
            n_images = min(x.size(0), 8)
            comparison = torch.cat([x[:n_images], out['x_hat'][:n_images].to(x.dtype)])
            # nrow=n_images gives exactly one row of originals and one row of
            # reconstructions, matching the caption even for small batches
            grid = torchvision.utils.make_grid(comparison, nrow=n_images)
            self.logger.experiment.log({"val/reconstructions": [wandb.Image(grid, caption="Top: Original, Bottom: Reconstructed")]})

        return out['loss']

    def _reset_codebook_usage(self):
        """Zero both quantizers' usage counters so perplexity is tracked per epoch."""
        if self.config.track_codebook:
            self.model.top_quantizer.reset_usage_stats()
            self.model.bottom_quantizer.reset_usage_stats()

    # `on_epoch_end` is not a LightningModule hook in current Lightning, so
    # the previous reset never ran. Reset at the start of each train and
    # validation epoch so the two splits' statistics do not mix.
    def on_train_epoch_start(self):
        self._reset_codebook_usage()

    def on_validation_epoch_start(self):
        self._reset_codebook_usage()

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.lr,
            betas=(self.config.beta1, self.config.beta2),
            weight_decay=self.config.weight_decay
        )

        if self.config.use_lr_schedule:
            # T_max is in epochs: the scheduler is stepped once per epoch by default
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.config.max_epochs)
            return [optimizer], [scheduler]

        return optimizer


# Config

In [None]:
def get_num_downsample_layers(img_size):
    """
    Return how many stride-2 Conv2D layers are needed so that
    img_size / 2**n <= 2, i.e. the smallest n that brings an
    img_size x img_size input down to a 2x2 output volume.

    Raises:
        ValueError: if img_size is smaller than 2.
    """
    if img_size < 2:
        raise ValueError("Image size must be at least 2x2.")

    halvings_needed = math.log2(img_size / 2)
    return math.ceil(halvings_needed)

def build_channel_dims(start_channels, nlayers):
    """
    Return the channel count of each of `nlayers` downsample layers,
    starting at `start_channels` and doubling at every layer (as the
    spatial dimensions halve). Returns [] when nlayers <= 0.
    """
    return [start_channels * 2 ** layer for layer in range(nlayers)]

class CIFAR10VQVAE2Config:
    """Hyper-parameters for training the two-level VQ-VAE on CIFAR-10."""

    def __init__(self):
        # model checkpoints
        self.checkpoint_path = "./checkpoints"
        self.save_top_k = 1
        # training
        self.batch_size = 128
        self.max_epochs = 120
        self.training_steps = 250000
        self.num_workers = 2
        # optimizer
        self.lr = 2e-4
        self.beta1 = 0.9
        self.beta2 = 0.999
        self.weight_decay = 0.0 # 1e-2
        self.use_wd_schedule = False
        self.use_lr_schedule = False
        # input properties
        self.img_size = 32
        self.in_channels = 3
        # latents / quantization
        self.latent_channels = 10
        self.top_codebook_size = 256
        self.bottom_codebook_size = 256
        self.codebook_size = 512
        self.commit_loss_beta = 0.25
        self.track_codebook = True
        self.use_ema = True
        self.ema_gamma = 0.99
        # encoder/decoder
        self.hidden_channels = 256
        self.nblocks = 1
        self.nlayers = 1

    def update(self, updates):
        """Override existing attributes from a mapping (e.g. ``wandb.config``).

        Keys that do not name an existing attribute are ignored rather than
        added as new attributes.

        `nlayers` is deliberately not re-derived from `img_size` here. The
        bottom and top encoders each apply `nlayers` halvings, so shrinking
        the bottom map all the way to 2x2 would leave nothing for the top
        level. (The old derivation also read the undefined `start_channels`
        and therefore always raised AttributeError.)
        """
        for key, value in updates.items():
            if hasattr(self, key):
                setattr(self, key, value)

    def to_dict(self):
        """Return a shallow copy of all attributes as a plain dict."""
        return dict(vars(self))

# COMPONENT TESTING

In [None]:
# Small CIFAR-10 loader for smoke-testing the components below. Only
# ToTensor is applied (no Normalize, unlike CIFAR10DataModule).
dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transforms.ToTensor())
dataloader = torch.utils.data.DataLoader(dataset, batch_size=3, shuffle=False, num_workers=1)

Files already downloaded and verified


In [None]:
# Grab the first batch of 3 images and inspect its shape.
img, label = next(iter(dataloader))
img.shape

torch.Size([3, 3, 32, 32])

In [None]:
# Default hyper-parameters.
config = CIFAR10VQVAE2Config()

In [None]:
# Build an untrained model from the default config.
vqvae2 = VQVAE2(config)

In [None]:
# Smoke-test a forward pass on the sample batch.
out = vqvae2(img)

In [None]:
# Check the shapes of both quantized latent maps and the reconstruction.
print(out['top_quantized']['quantized_inputs'].shape)
print(out['bottom_quantized']['quantized_inputs'].shape)
print(out['x_hat'].shape)

# Train

In [None]:
# Build the model, its Lightning wrapper and the data for a single run.
config = CIFAR10VQVAE2Config()
model = VQVAE2(config)
lit_model = LitVQVAE2(model, config)
cifar10_data = CIFAR10DataModule(config)

# Start the W&B run with the full config; the WandbLogger below reuses it.
wandb.init(project="VQ-VAE-2 CIFAR-10", config=config.to_dict())
wandb_logger = WandbLogger(project="VQ-VAE-2 CIFAR-10", log_model=False)
wandb_logger.watch(lit_model, log="all")

lr_monitor = LearningRateMonitor(logging_interval='step')

# The monitored metric is logged as 'val/loss', so '{val_loss}' in the
# filename would never be filled in. Because the name contains '/',
# auto_insert_metric_name is turned off (as the Lightning docs recommend);
# otherwise the inserted 'val/loss=' prefix would create a sub-directory.
checkpoint_callback = ModelCheckpoint(
    dirpath=config.checkpoint_path,
    filename='model-epoch={epoch:02d}-val_loss={val/loss:.2f}',
    auto_insert_metric_name=False,
    every_n_epochs=5,
    save_top_k=config.save_top_k,
    monitor='val/loss',
    mode='min',
    save_last=True
)

# Stop when val/loss has not improved for 15 validation runs (or becomes non-finite).
early_stop_callback = EarlyStopping(
    monitor='val/loss',
    min_delta=0.00,
    patience=15,
    verbose=True,
    check_finite=True
)

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loggers/wandb.py:389: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.
[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`


In [None]:
# wandb.finish()

In [None]:
# Single-GPU, 16-bit mixed-precision trainer that logs to W&B every step.
# Checkpointing is currently disabled (checkpoint_callback is commented out).
trainer = pl.Trainer(
    max_epochs=config.max_epochs,
    devices=1,
    accelerator="gpu",
    precision="16-mixed",
    logger=wandb_logger,
    callbacks=[
        lr_monitor,
        early_stop_callback,
        # checkpoint_callback
    ],
    log_every_n_steps=1,
    # overfit_batches=1,
)

# Optional learning-rate range test:
# tuner = Tuner(trainer)
# tuner.lr_find(lit_model, datamodule=cifar10_data)

INFO:pytorch_lightning.utilities.rank_zero:Using 16bit Automatic Mixed Precision (AMP)
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


In [None]:
# Train, then close the W&B run.
trainer.fit(lit_model, cifar10_data)
wandb.finish()

# Sweeps

In [None]:
# Random search over hyper-parameters that CIFAR10VQVAE2Config actually
# defines. config.update() ignores unknown keys, so sweeping names such as
# 'latent_dim' or 'kld_weight' would silently produce duplicate runs.
# The objective is reconstruction error: val/loss also contains the
# beta-weighted commitment term, so it is not comparable across different
# commit_loss_beta / latent sizes.
sweep_config = {
    'method': 'random',
    'metric': {
        'name': 'val/MSE',
        'goal': 'minimize'
    },
    'parameters': {
        'nblocks': {
            'values': [0, 1, 2]
        },
        'latent_channels': {
            'values': [10, 32, 64]
        },
        'hidden_channels': {
            'values': [64, 128, 256]
        },
        'codebook_size': {
            'values': [256, 512, 1024]
        },
        # log-uniform so each decade of the 1e-5..1e-2 range is sampled equally
        'lr': {
            'min': 1e-5,
            'max': 1e-2,
            'distribution': 'log_uniform_values'
        },
        'beta1': {
            'values': [0.9, 0.95, 0.99]
        },
        'beta2': {
            'values': [0.999, 0.9999]
        },
        'commit_loss_beta': {
            'values': [0.1, 0.25, 0.5, 1.0]
        }
    }
}

sweep_id = wandb.sweep(sweep_config, project="VQ-VAE-2 CIFAR-10")

In [None]:
def train():
    """Run one W&B sweep trial.

    Builds the config from the defaults overridden by ``wandb.config``, then
    fits the model with early stopping on val/loss (patience 3).
    """
    with wandb.init():
        run_config = CIFAR10VQVAE2Config()
        run_config.update(wandb.config)

        lightning_model = LitVQVAE2(VQVAE2(run_config), run_config)
        datamodule = CIFAR10DataModule(run_config)
        run_logger = WandbLogger(project="VQ-VAE-2 CIFAR-10", log_model=False)

        stopper = EarlyStopping(
            monitor='val/loss',
            min_delta=0.00,
            patience=3,
            verbose=True,
            check_finite=True
        )

        pl.Trainer(
            max_epochs=run_config.max_epochs,
            devices=1,
            accelerator="gpu",
            precision="16-mixed",
            logger=run_logger,
            callbacks=[stopper]
        ).fit(lightning_model, datamodule)

In [None]:
# Launch a local sweep agent that runs `train` for 5 trials.
wandb.agent(sweep_id, train, count=5)