In [None]:
!pip install -q wandb pytorch_lightning

In [None]:
# `from __future__` imports must be the first statement in the cell.
from __future__ import print_function

import argparse
import math
from collections import OrderedDict

import pytorch_lightning as pl
import torch
import torch.utils.data
import torchvision
import wandb
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.tuner import Tuner
from torch import nn, optim
from torch.nn import functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR, OneCycleLR
from torchvision import datasets, transforms
from torchvision.utils import save_image


# Seed all RNGs; workers=True also gives every DataLoader worker a derived seed,
# which matters because the data module below uses num_workers > 0.
pl.seed_everything(42, workers=True)

In [None]:
class CIFAR10DataModule(pl.LightningDataModule):
    """LightningDataModule serving CIFAR-10 images scaled to [-1, 1].

    ``ToTensor`` followed by ``Normalize(0.5, 0.5)`` per channel maps pixel
    values into [-1, 1]. The official test split is used as the validation set.

    Args:
        config: object with ``batch_size``. Optional ``data_dir`` (default
            './data') and ``num_workers`` (default 2) are read when present.
    """

    def __init__(self, config):
        super().__init__()
        self.batch_size = config.batch_size
        self.data_dir = getattr(config, 'data_dir', './data')
        # TODO: check if we need more workers
        self.num_workers = getattr(config, 'num_workers', 2)
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

    def prepare_data(self):
        # Download only; setup() builds the transformed datasets.
        torchvision.datasets.CIFAR10(root=self.data_dir, train=True, download=True)
        torchvision.datasets.CIFAR10(root=self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        if stage in ('fit', None):
            self.cifar10_train = torchvision.datasets.CIFAR10(
                root=self.data_dir, train=True, transform=self.transform)
        # trainer.validate() calls setup('validate'), which also needs the val split.
        if stage in ('fit', 'validate', None):
            self.cifar10_val = torchvision.datasets.CIFAR10(
                root=self.data_dir, train=False, transform=self.transform)

    def _make_loader(self, dataset, shuffle):
        """Build a DataLoader with the batch/worker settings shared by all splits."""
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            pin_memory=True,
            # Keep workers alive across epochs; DataLoader rejects this flag with 0 workers.
            persistent_workers=self.num_workers > 0,
        )

    def train_dataloader(self):
        return self._make_loader(self.cifar10_train, shuffle=True)

    def val_dataloader(self):
        return self._make_loader(self.cifar10_val, shuffle=False)

In [None]:
# Un-normalised CIFAR-10 training split, loaded one image at a time for inspection.
dataset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    num_workers=1,
)

In [None]:
# Pull the first (image, label) batch to check the raw tensor layout.
first_batch = next(iter(dataloader))
img, label = first_batch
img.shape

In [None]:
K = 100  # number of codebook entries
D = 10   # dimension of each code vector
codes = nn.Embedding(num_embeddings=K, embedding_dim=D)
# Toy channel-first latent map of shape [1, D, 2, 2].
latents = torch.randn(1, D, 2, 2)
print(latents.shape)

In [None]:
# Move channels last ([B, D, H, W] -> [B, H, W, D]) so every spatial position is
# one D-dimensional vector, then flatten to [B*H*W, D]. A new name is used so that
# re-running this cell does not permute `latents` a second time.
latents_bhwc = latents.permute(0, 2, 3, 1).contiguous()
print(latents_bhwc.shape)
flat_latents = latents_bhwc.view(-1, D)
print(flat_latents.shape)

# Squared Euclidean distance expands as ||x - y||^2 = ||x||^2 - 2 x·y + ||y||^2,
# i.e. Σ(x-y)^2 = Σx^2 - 2Σxy + Σy^2

In [None]:
# ||x||^2 per latent vector, kept as a column so it broadcasts across codes
latent_dists = flat_latents.pow(2).sum(dim=1, keepdim=True)
print(f'||latents||^2: {latent_dists.shape}')

# ||e||^2 per codebook entry
code_dists = codes.weight.pow(2).sum(dim=1)
print(f'||codes||^2: {code_dists.shape}\n')

# 2 x·e for every (latent, code) pair
lat_code_dists = 2 * torch.matmul(flat_latents, codes.weight.T)
print(f'2*lats@codes: {lat_code_dists.shape}\n')

dist = latent_dists + code_dists - lat_code_dists  # [BHW x K]
print(f'dist.shape: {dist.shape} = {latent_dists.shape} + {code_dists.shape} - {lat_code_dists.shape}\n')

# Index of the nearest code for each latent vector, shape [BHW, 1]
encoding_inds = dist.argmin(dim=1, keepdim=True)
print(f'encoding_inds.shape: {encoding_inds.shape}')

In [None]:
class ResBlock(nn.Module):
    """Residual block: a narrow conv branch added to a (possibly projected) skip path.

    The branch is three 3x3 convs (padding 1, each followed by BatchNorm and
    LeakyReLU) at a reduced width, then a 1x1 conv back to ``out_channels``.
    The skip path is a 1x1 conv when channel counts differ, else the identity.
    LeakyReLU is applied after the sum. Spatial size is preserved.

    Args:
        in_channels: channels of the input.
        out_channels: channels of the output.
        bottleneck: inner width is ``out_channels // bottleneck`` (default 4,
            the previous hard-coded value), clamped to at least 1 so small
            ``out_channels`` do not produce zero-channel convs.

    Raises:
        ValueError: if ``bottleneck`` is smaller than 1.
    """

    def __init__(self, in_channels, out_channels, bottleneck=4):
        super().__init__()
        if bottleneck < 1:
            raise ValueError(f"bottleneck must be >= 1, got {bottleneck}")

        if in_channels != out_channels:
            self.identity = nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1
            )
        else:
            self.identity = nn.Identity()

        hidden_channels = max(1, out_channels // bottleneck)

        def conv_bn_act(c_in, c_out):
            # 3x3 'same' conv -> BatchNorm -> LeakyReLU
            return [
                nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=3, padding=1),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(),
            ]

        # Same layer order and Sequential indices as before, so state_dict keys match.
        self.block = nn.Sequential(
            *conv_bn_act(in_channels, hidden_channels),
            *conv_bn_act(hidden_channels, hidden_channels),
            *conv_bn_act(hidden_channels, hidden_channels),
            nn.Conv2d(
                in_channels=hidden_channels,
                out_channels=out_channels,
                kernel_size=1
            )
        )

    def forward(self, x):
        out = self.block(x) + self.identity(x)
        return F.leaky_relu(out)

In [None]:
class Encoder(nn.Module):
    """Convolutional VAE encoder mapping images to ``(mu, logvar)``.

    Every entry of ``config.channel_dims`` adds one stage:
    a 3x3 stride-2 conv (halving H and W), then BatchNorm and LeakyReLU, then
    ``config.nblocks`` ResBlocks at constant width. The last feature map is
    flattened into two linear heads. Their input size ``channel_dims[-1] * 4``
    assumes the final map is 2x2.
    """

    def __init__(self, config):
        super().__init__()

        stages = []
        prev_ch = config.in_channels
        for ch in config.channel_dims:
            # shrink spatial dims, widen channels
            downsample = nn.Conv2d(
                in_channels=prev_ch,
                out_channels=ch,
                kernel_size=3,
                stride=2,
                padding=1
            )
            norm = nn.BatchNorm2d(ch)
            act = nn.LeakyReLU()
            # refine at the current resolution and width
            residuals = nn.Sequential(*[
                ResBlock(in_channels=ch, out_channels=ch)
                for _ in range(config.nblocks)
            ])
            stages.append(nn.Sequential(downsample, norm, act, residuals))
            prev_ch = ch

        flat_features = config.channel_dims[-1] * 4  # C x 2 x 2
        self.blocks = nn.Sequential(*stages)
        self.fc_mu = nn.Linear(flat_features, config.latent_dim)
        self.fc_logvar = nn.Linear(flat_features, config.latent_dim)

    def forward(self, x):
        features = torch.flatten(self.blocks(x), start_dim=1)
        return self.fc_mu(features), self.fc_logvar(features)

In [None]:
class Decoder(nn.Module):
    """Convolutional VAE decoder mapping latent vectors to images in (-1, 1).

    The latent is linearly mapped to a ``channel_dims[-1] x 2 x 2`` feature map.
    The decoder walks ``channel_dims`` in reverse. Each stage applies:
    ``config.nblocks`` ResBlocks, then a stride-2 ConvTranspose2d that doubles
    H and W and moves to the next smaller width, then BatchNorm and LeakyReLU.
    A final stride-2 ConvTranspose2d maps to ``config.in_channels``, followed
    by Tanh.
    """

    def __init__(self, config):
        super().__init__()

        self.latent_map = nn.Linear(config.latent_dim, config.channel_dims[-1]*4)

        blocks = []
        reverse_channels = list(config.channel_dims)[::-1]
        for c_in, c_out in zip(reverse_channels, reverse_channels[1:]):
            blocks.append(
                nn.Sequential(
                    # continue at current dim and channels
                    nn.Sequential(*[
                        ResBlock(
                            in_channels=c_in,
                            out_channels=c_in
                        ) for _ in range(config.nblocks)
                    ]),
                    # reduce channels, double spatial dims
                    nn.ConvTranspose2d(
                        in_channels=c_in,
                        out_channels=c_out,
                        kernel_size=3,
                        stride=2,
                        padding=1,
                        output_padding=1
                    ),
                    nn.BatchNorm2d(c_out),
                    nn.LeakyReLU()
                )
            )
        self.blocks = nn.Sequential(*blocks)
        # Tanh is applied directly to the conv output. A LeakyReLU in front of it
        # scaled every negative pre-activation by 0.01. That made outputs near -1
        # (dark pixels after Normalize(0.5, 0.5)) practically unreachable.
        # LeakyReLU has no parameters, so dropping it keeps state_dict keys unchanged.
        self.final_layer = nn.Sequential(
            nn.ConvTranspose2d(
                in_channels=reverse_channels[-1],
                out_channels=config.in_channels,
                kernel_size=3,
                stride=2,
                padding=1,
                output_padding=1
            ),
            nn.Tanh()
        )

        self.config = config

    def forward(self, z):
        z = self.latent_map(z)
        z = z.view(-1, self.config.channel_dims[-1], 2, 2)
        z = self.blocks(z)
        return self.final_layer(z)


In [None]:
class Quantizer(nn.Module):
    """Vector-quantization bottleneck (VQ-VAE).

    Holds a codebook of ``config.codebook_size`` vectors of size
    ``config.latent_dim`` and replaces each input vector with its nearest entry
    (squared Euclidean distance).

    ``config.commitment_cost`` (the beta in the VQ-VAE loss) is read if present.
    It defaults to 0.25.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.codebook = nn.Embedding(config.codebook_size, config.latent_dim)

    def forward(self, latents):
        """Quantize a channel-first tensor.

        Args:
            latents: tensor of shape (B, D, *spatial) with D == config.latent_dim,
                e.g. (B, D, H, W) or (B, D).

        Returns:
            quantized: same shape as ``latents``. It uses a straight-through
                estimator, so gradients pass to ``latents`` unchanged.
            vq_loss: scalar, codebook loss + commitment_cost * commitment loss.
            indices: LongTensor of shape (B, *spatial) with the chosen code ids.
        """
        beta = getattr(self.config, 'commitment_cost', 0.25)
        weight = self.codebook.weight
        # channels last, then one row per latent vector: [N, D]
        moved = latents.movedim(1, -1)
        flat = moved.reshape(-1, weight.shape[1])

        # ||x - e||^2 = ||x||^2 - 2 x·e + ||e||^2, for every (vector, code) pair: [N, K]
        dists = (flat.pow(2).sum(dim=1, keepdim=True)
                 - 2 * flat @ weight.t()
                 + weight.pow(2).sum(dim=1))
        indices = dists.argmin(dim=1)
        quantized = self.codebook(indices).view_as(moved)

        # codebook term pulls codes toward encoder outputs; commitment term the reverse
        codebook_loss = F.mse_loss(quantized, moved.detach())
        commitment_loss = F.mse_loss(moved, quantized.detach())
        vq_loss = codebook_loss + beta * commitment_loss

        # straight-through: forward uses the code, backward is the identity
        quantized = moved + (quantized - moved).detach()
        return quantized.movedim(-1, 1), vq_loss, indices.view(moved.shape[:-1])

In [None]:
class VQVAE(nn.Module):
    """Encoder/decoder pair trained with the Gaussian VAE objective.

    Despite the name, no vector quantization happens here. ``forward`` draws z
    from the encoder's ``(mu, logvar)`` with the reparameterization trick and
    decodes it. ``loss`` is the mean-squared reconstruction error plus
    ``config.kld_weight`` times the KL divergence to N(0, I). The KL term is
    summed over latent dims and averaged over the batch.
    """

    def __init__(self, config):
        super().__init__()
        self.encoder = Encoder(config)
        self.decoder = Decoder(config)
        self.config = config

    def encode(self, x):
        """Return ``(mu, logvar)`` for a batch of images."""
        return self.encoder(x)

    def decode(self, z):
        """Map latent vectors back to image space."""
        return self.decoder(z)

    def reparameterize(self, mu, logvar):
        """Sample z = mu + sigma * eps, eps ~ N(0, I), differentiably in mu/logvar."""
        sigma = (0.5 * logvar).exp()
        noise = torch.randn_like(sigma)
        return noise * sigma + mu

    def loss(self, x_hat, x, mu, logvar):
        """Return ``(total, mse, kld)``."""
        recon = F.mse_loss(x_hat, x)
        kld_terms = 1 + logvar - mu.pow(2) - logvar.exp()
        kld = (-0.5 * kld_terms.sum(dim=1)).mean(dim=0)
        total = recon + self.config.kld_weight * kld
        return total, recon, kld

    def forward(self, x):
        mu, logvar = self.encode(x)
        x_hat = self.decode(self.reparameterize(mu, logvar))
        return x_hat, mu, logvar

In [None]:
class LitVQVAE(pl.LightningModule):
    """Lightning wrapper that trains and validates a VAE-style model.

    ``model(x)`` must return ``(x_hat, mu, logvar)``, and
    ``model.loss(x_hat, x, mu, logvar)`` must return ``(loss, mse, kld)``.

    Args:
        model: the network to train.
        config: object exposing ``lr``, ``beta1``, ``beta2`` and ``max_epochs``.
            Required; the ``None`` default is kept only for signature compatibility.

    Raises:
        ValueError: if ``config`` is None.
    """

    def __init__(self, model, config=None):
        super().__init__()
        if config is None:
            raise ValueError("LitVQVAE needs a config with lr, beta1, beta2 and max_epochs")
        self.model = model
        self.config = config
        # Plain attribute so Lightning's Tuner.lr_find can locate and overwrite it.
        self.lr = config.lr
        # Hyper-parameters are logged by the caller via wandb.init(config=...).
        # No Trainer (hence no logger) is attached while __init__ runs, so a
        # logger-based config update here would never execute.

    def forward(self, x):
        return self.model(x)

    def _shared_step(self, batch, stage):
        """Run the model, log ``{stage}/loss``, ``/mse``, ``/kld``; return (x, x_hat, loss)."""
        x, _ = batch  # labels unused: the objective is reconstruction
        x_hat, mu, logvar = self(x)
        loss, mse, kld = self.model.loss(x_hat, x, mu, logvar)
        self.log(f'{stage}/loss', loss, prog_bar=True)
        self.log(f'{stage}/mse', mse)
        self.log(f'{stage}/kld', kld)
        return x, x_hat, loss

    def training_step(self, batch, batch_idx):
        _, _, loss = self._shared_step(batch, 'train')
        return loss

    def validation_step(self, batch, batch_idx):
        x, x_hat, loss = self._shared_step(batch, 'val')

        # Log an original/reconstruction comparison from the first val batch.
        if batch_idx == 0 and self.logger is not None:
            n_images = min(x.size(0), 8)
            # x_hat may be half precision under mixed precision; match x's dtype.
            comparison = torch.cat([x[:n_images], x_hat[:n_images].to(x.dtype)])
            # nrow=n_images keeps originals on the first row and reconstructions
            # on the second (as the caption says) even for batches smaller than 8.
            grid = torchvision.utils.make_grid(comparison, nrow=n_images)
            self.logger.experiment.log({"val/reconstructions": [wandb.Image(grid, caption="Top: Original, Bottom: Reconstructed")]})

        return loss

    def configure_optimizers(self):
        # AdamW with cosine decay over max_epochs (scheduler steps per epoch by default).
        self.optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr, betas=(self.config.beta1, self.config.beta2))
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=self.config.max_epochs)
        return [self.optimizer], [self.scheduler]


In [None]:
def get_num_downsample_layers(img_size):
    """Return how many stride-2 Conv2d layers reduce ``img_size`` to a 2x2 map.

    Assuming each layer halves the side length (rounding up), the answer is
    ceil(log2(img_size / 2)). For example, 32 needs 4 layers: 32 -> 16 -> 8 -> 4 -> 2.

    Raises:
        ValueError: if ``img_size`` is below 2.
    """
    if img_size < 2:
        raise ValueError("Image size must be at least 2x2.")

    halvings = math.log2(img_size / 2)
    return math.ceil(halvings)

def build_channel_dims(start_channels, nlayers):
    """Channel widths for ``nlayers`` downsampling stages.

    The width starts at ``start_channels`` and doubles at every stage, mirroring
    the halving of the spatial size. For example, (32, 4) gives [32, 64, 128, 256].
    """
    return [start_channels * 2 ** stage for stage in range(nlayers)]

class CIFAR10VAEConfig:
    """Mutable hyper-parameter container for the CIFAR-10 VAE runs.

    ``nlayers`` and ``channel_dims`` are derived from ``img_size`` and
    ``start_channels``. They are computed at construction and recomputed at the
    end of every ``update`` call.
    """

    def __init__(self):
        # checkpointing
        self.checkpoint_path = "./checkpoints"
        self.save_top_k = 1

        # optimisation
        self.batch_size = 2048  # TODO: maxout for max throughput
        self.max_epochs = 60
        self.lr = 3e-4
        self.beta1, self.beta2 = 0.9, 0.999

        # data, loss weighting and architecture
        self.img_size = 32
        self.in_channels = 3
        self.latent_dim = 128
        self.kld_weight = 2.5e-5
        self.start_channels = 32
        self.nblocks = 0

        # derived fields
        self.nlayers = get_num_downsample_layers(self.img_size)
        self.channel_dims = build_channel_dims(self.start_channels, self.nlayers)

    def update(self, updates):
        """Copy entries of ``updates`` (anything with ``.items()``, e.g. ``wandb.config``)
        onto attributes that already exist; unknown keys are skipped. Derived
        fields are recomputed afterwards."""
        for key, value in updates.items():
            if not hasattr(self, key):
                continue
            setattr(self, key, value)
        self.nlayers = get_num_downsample_layers(self.img_size)
        self.channel_dims = build_channel_dims(self.start_channels, self.nlayers)

    def to_dict(self):
        """Shallow copy of the instance attributes, in definition order."""
        return dict(vars(self))

In [None]:
config = CIFAR10VAEConfig()
# The classes defined above are VQVAE / LitVQVAE; `VAE` / `LitVAE` do not exist
# in this notebook and raised NameError here.
model = VQVAE(config)
lit_model = LitVQVAE(model, config)
cifar10_data = CIFAR10DataModule(config)

wandb.init(project="VAE CIFAR-10", config=config.to_dict())
# WandbLogger attaches to the run opened just above.
wandb_logger = WandbLogger(project="VAE CIFAR-10", log_model=False)
wandb_logger.watch(lit_model, log="all")

lr_monitor = LearningRateMonitor(logging_interval='step')

# The monitored metric is named 'val/loss', so the filename template must use
# {val/loss}; the old {val_loss} never matched a logged metric. Because the name
# contains '/', auto_insert_metric_name is turned off so no sub-directory is created.
checkpoint_callback = ModelCheckpoint(
    dirpath=config.checkpoint_path,
    filename='model-epoch{epoch:02d}-val_loss{val/loss:.2f}',
    auto_insert_metric_name=False,
    every_n_epochs=5,
    save_top_k=config.save_top_k,
    monitor='val/loss',
    mode='min',
    save_last=True
)

# Stop when val/loss has not improved for 3 validation checks, or becomes non-finite.
early_stop_callback = EarlyStopping(
    monitor='val/loss',
    min_delta=0.00,
    patience=3,
    verbose=True,
    check_finite=True
)

In [None]:
# Uncomment to close the active W&B run manually (e.g. if you skip training below):
# wandb.finish()

In [None]:
# Checkpointing is currently disabled; append `checkpoint_callback` to enable it.
trainer_callbacks = [lr_monitor, early_stop_callback]

trainer = pl.Trainer(
    accelerator="gpu",
    devices=1,
    precision="16-mixed",
    max_epochs=config.max_epochs,
    logger=wandb_logger,
    callbacks=trainer_callbacks,
    log_every_n_steps=1,
)

# Optional learning-rate search before training:
#   Tuner(trainer).lr_find(lit_model, datamodule=cifar10_data)

In [None]:
# try/finally closes the W&B run even if fit() raises (e.g. CUDA OOM), instead
# of leaving it open for later cells such as the sweep below.
try:
    trainer.fit(lit_model, datamodule=cifar10_data)
finally:
    wandb.finish()

In [None]:
# Random search over architecture and optimiser settings.
# `lr` and `kld_weight` span several orders of magnitude, so they are sampled
# log-uniformly. A plain uniform draw over [1e-5, 1e-2] puts ~90% of trials above 1e-3.
sweep_config = {
    'method': 'random',
    'metric': {
        'name': 'val/loss',  # must match the key logged during validation
        'goal': 'minimize'
    },
    'parameters': {
        'nblocks': {
            'values': [0, 1, 2]
        },
        'latent_dim': {
            'values': [128, 256, 512]
        },
        'start_channels': {
            'values': [32, 64, 128]
        },
        'lr': {
            'min': 1e-5,
            'max': 1e-2,
            'distribution': 'log_uniform_values'
        },
        'beta1': {
            'values': [0.9, 0.95, 0.99]
        },
        'beta2': {
            'values': [0.999, 0.9999]
        },
        'kld_weight': {
            'min': 0.0000025,
            'max': 0.00025,
            'distribution': 'log_uniform_values'
        }
    }
}

sweep_id = wandb.sweep(sweep_config, project="VAE CIFAR-10")

In [None]:
def train():
    """Run one W&B sweep trial: build the config from ``wandb.config``, then train.

    Called by ``wandb.agent``. The ``with wandb.init()`` context closes the run
    when training returns or raises.
    """
    with wandb.init():
        config = CIFAR10VAEConfig()
        config.update(wandb.config)

        # The model classes defined above are VQVAE / LitVQVAE.
        model = VQVAE(config)
        lit_model = LitVQVAE(model, config)
        cifar10_data = CIFAR10DataModule(config)

        # Attaches to the run opened by wandb.init() above.
        wandb_logger = WandbLogger(project="VAE CIFAR-10", log_model=False)

        early_stop_callback = EarlyStopping(
            monitor='val/loss',
            min_delta=0.00,
            patience=3,
            verbose=True,
            check_finite=True
        )
        lr_monitor = LearningRateMonitor(logging_interval='step')

        # Same logging cadence and LR monitoring as the standalone trainer.
        trainer = pl.Trainer(
            max_epochs=config.max_epochs,
            devices=1,
            accelerator="gpu",
            precision="16-mixed",
            logger=wandb_logger,
            callbacks=[early_stop_callback, lr_monitor],
            log_every_n_steps=1,
        )

        trainer.fit(lit_model, datamodule=cifar10_data)

In [None]:
# Launch up to 5 sweep trials in this kernel; each call to train() is one W&B run.
wandb.agent(sweep_id, function=train, count=5)