In [None]:
# install from git
!pip install pytorch-lightning python-dotenv wandb==0.15.0 protobuf==3.20.3 boto3 --quiet
!if [ -e ./side_project_utils ]; then rm -rf ./side_project_utils; fi
!git clone https://github.com/LongDangHoang/side_project_utils ./side_project_utils --quiet  > /dev/null

In [None]:
%load_ext autoreload
%autoreload 2

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import sys

sys.path.append(".")

import math
import numpy as np
import matplotlib.pyplot as plt

from tqdm.auto import tqdm
from typing import List, Dict, Optional, Union, Tuple
from pathlib import Path

from torchvision import datasets, transforms
from torchvision.datasets import ImageFolder
from torchvision.models import resnet50, ResNet50_Weights
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import StepLR, ReduceLROnPlateau, OneCycleLR

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import EarlyStopping, StochasticWeightAveraging, Callback, ModelCheckpoint, LearningRateMonitor

from side_project_utils.callbacks import *
from side_project_utils.training_setup import *

# Seed every torch RNG (CPU and all CUDA devices) with the same value.
SEED = 314
for seed_fn in (torch.manual_seed, torch.cuda.manual_seed_all):
    seed_fn(SEED)

# Prefer the GPU whenever one is visible to torch.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device}")

In [None]:
# Define hyperparameters
## Define hyperparameters
env = set_os_env_from_notebook_secrets()
project_name = "pokemon-resnet-vae"

config = dict(
    batch_size = 256,                 # batch size of both data loaders
    use_constant_lr = False,          # True -> no LR scheduler is attached to the optimizer
    lr = 2e-4,                        # Adam learning rate
    num_epoch = -1,                   # forwarded to Trainer(max_epochs=...)
    overfit_batch = 0,                # forwarded to Trainer(overfit_batches=...); 0 disables overfitting mode
    weight_decay = 0,                 # Adam weight decay
    dropout = 0.1,                    # dropout rate used inside the decoder
    log_wandb = True,                 # master switch for all Weights & Biases logging
    init_new_wandb_run = True,        # True -> always open a fresh W&B run
    use_existing_run = None,          # id of a previous run to resume / load checkpoints from
    use_augmentation = True,          # apply the random flips / colour jitter to the dataset
    accumulate_grad_batches = 1,      # forwarded to Trainer(accumulate_grad_batches=...)
)

# An existing run is resumed only when one is named AND a fresh run was not requested.
resume_existing_run = bool(config["use_existing_run"]) and not config["init_new_wandb_run"]

if config["log_wandb"]:
    # Authenticate only when logging is enabled, so the notebook can run without
    # W&B credentials when `log_wandb` is False.
    wandb.login()

    # start a new wandb run to track this script
    # (the `globals()` guard keeps a re-executed cell from opening a second run)
    if "run" not in globals():
        run = wandb.init(
            project=project_name,
            id=config["use_existing_run"] if resume_existing_run else None,
            resume="must" if resume_existing_run else None,
            config=config,
        )
        assert run is not None

# Get data

## Pokemon dataset


In [None]:
from pathlib import Path
import json

# Resolve where the Pokemon images live for the current runtime and, on Colab,
# download them through the Kaggle CLI when they are not on disk yet.
if env == "KAGGLE":
    dataset_path = Path("/kaggle/input/pokemon-image-dataset/images")

elif env == "COLAB":
    dataset_path = Path('./kaggle/input/pokemon-image-dataset/images')

    # The Kaggle CLI reads its credentials from ~/.kaggle/kaggle.json.
    Path("/root/.kaggle/").mkdir(parents=True, exist_ok=True)
    with open("/root/.kaggle/kaggle.json", "w") as f:
        json.dump({"username": "danghoanglong", "key": os.getenv("KAGGLE_JSON_KEY")}, f)
    # Keep the API key private to the current user (the CLI warns otherwise).
    os.chmod("/root/.kaggle/kaggle.json", 0o600)

    if not dataset_path.exists():
        # Without a key the download below can only fail with an opaque CLI error.
        if os.getenv("KAGGLE_JSON_KEY") is None:
            raise RuntimeError("KAGGLE_JSON_KEY is not set; cannot download the Pokemon dataset")

        with open("./kaggle_pokemon_script.sh", "w") as f:
            f.write("""
                mkdir ./kaggle/input/pokemon-image-dataset -p
                kaggle datasets download -d hlrhegemony/pokemon-image-dataset -p ./kaggle/input/pokemon-image-dataset
                unzip -q ./kaggle/input/pokemon-image-dataset/pokemon-image-dataset.zip
                mv images ./kaggle/input/pokemon-image-dataset/
            """)
        os.system("chmod +x ./kaggle_pokemon_script.sh")
        # The script's exit status is that of its last command (`mv`), which fails
        # whenever the download or the extraction did not produce `images/`.
        if os.system("./kaggle_pokemon_script.sh") != 0:
            raise RuntimeError("Downloading the Pokemon dataset failed; see the shell output above")

else:
    # Fail here with a clear message instead of a NameError on `dataset_path` later on.
    raise RuntimeError(f"Unsupported environment {env!r}: expected 'KAGGLE' or 'COLAB'")

In [None]:
# Random augmentations; they are only added to the pipeline when
# config["use_augmentation"] is truthy.
data_aug_transforms = [
    transforms.RandomHorizontalFlip(p=0.2),
    transforms.RandomVerticalFlip(p=0.2),
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
]

# Per-channel normalisation statistics and the transform that undoes them:
# Normalize(-m/s, 1/s) maps (x - m) / s back to x.
channel_means = [0.485, 0.456, 0.406]
channel_stds = [0.229, 0.224, 0.225]
normaliser = transforms.Normalize(mean=channel_means, std=channel_stds)
inv_normaliser = transforms.Normalize(
    mean=[-m / s for m, s in zip(channel_means, channel_stds)],
    std=[1 / s for s in channel_stds],
)
to_pil = transforms.ToPILImage()

# Assemble the pipeline: optional augmentation first, then the preprocessing
# preset that ships with the pretrained ResNet50 weights.
preprocessing_steps = list(data_aug_transforms) if config["use_augmentation"] else []
preprocessing_steps.append(ResNet50_Weights.DEFAULT.transforms())

pokemon_dataset = ImageFolder(dataset_path, transform=transforms.Compose(preprocessing_steps))

# Show the first sample, un-normalised, as the cell output.
to_pil(inv_normaliser(pokemon_dataset[0][0]))

In [None]:
# Split 80/20 into train and validation subsets with a fixed generator seed so
# the split is identical on every run.
# NOTE(review): both subsets wrap `pokemon_dataset`, so when augmentation is
# enabled it is applied to validation samples as well — confirm this is intended.
pokemon_train_dataset, pokemon_valid_dataset = torch.utils.data.random_split(
    pokemon_dataset,
    [0.8, 0.2],
    generator=torch.Generator().manual_seed(42)
)

print(f"Train: {len(pokemon_train_dataset)}, Valid: {len(pokemon_valid_dataset)}")

# `os.cpu_count()` may return None; fall back to loading in the main process.
loader_workers = os.cpu_count() or 0

pokemon_train_loader = torch.utils.data.DataLoader(
    pokemon_train_dataset,
    batch_size=config["batch_size"],
    # Shuffle unless we are deliberately overfitting a fixed set of batches.
    # The previous `config["overfit_batch"] is None` test was never true for the
    # configured value 0, which silently disabled shuffling for normal training.
    shuffle=not config["overfit_batch"],
    num_workers=loader_workers,
    pin_memory=True,
)

pokemon_valid_loader = torch.utils.data.DataLoader(
    pokemon_valid_dataset,
    batch_size=config["batch_size"],
    shuffle=False,
    num_workers=loader_workers,
    pin_memory=True,
)

# Samples are indexed as (channels, height, width); axis 2 is the width.
pokemon_img_size = pokemon_train_dataset[0][0].shape[2]
print(f"Pokemon image size: {pokemon_img_size}")

# Define models

In [None]:
class Resnet50Decoder(nn.Module):
    """Convolutional decoder that turns a latent feature map into a 3-channel image.

    The spatial resolution is doubled five times (e.g. 7 -> 14 -> 28 -> 56 -> 112 -> 224)
    while the channel count shrinks through 128, 64, 32 and 16 down to 3.

    Args:
        in_latent_dim: number of channels of the incoming latent feature map.
        dropout_rate: dropout probability applied after each of the four residual stages.
    """

    def __init__(self, in_latent_dim: int=64, dropout_rate: float=0.1):
        super().__init__()

        self.act = nn.SiLU()
        self.dropout = nn.Dropout(dropout_rate)
        self.intermediate_channels = [128, 64, 32, 16]

        # One nearest-neighbour x2 upsampler per stage, plus one for the output stage.
        self.upsamples = nn.ModuleList(
            [nn.Upsample(scale_factor=2, mode="nearest") for _ in range(5)]
        )
        self.bns = nn.ModuleList(
            [nn.BatchNorm2d(width) for width in self.intermediate_channels]
        )

        # 3x3 convolutions that change the channel count: latent -> 128 -> ... -> 16 -> 3.
        conv_inputs = [in_latent_dim, *self.intermediate_channels]
        conv_outputs = [*self.intermediate_channels, 3]
        self.down_channel_convs = nn.ModuleList(
            [
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1)
                for c_in, c_out in zip(conv_inputs, conv_outputs)
            ]
        )

        self.blocks = nn.ModuleList(
            [self._residual_branch(width) for width in self.intermediate_channels]
        )

    @staticmethod
    def _residual_branch(width: int) -> nn.Sequential:
        """Build the channel-preserving 3x3 -> 1x1 conv branch added back onto a stage."""
        return nn.Sequential(
            nn.Conv2d(width, width, kernel_size=3, padding=1),
            nn.BatchNorm2d(width),
            nn.SiLU(),
            nn.Conv2d(width, width, kernel_size=1, padding=0),
            nn.BatchNorm2d(width),
            nn.SiLU(),
        )

    def forward(self, x):
        """Decode latent map `x` of shape (batch, in_latent_dim, H, W) into (batch, 3, 32*H, 32*W)."""
        stages = zip(self.upsamples, self.down_channel_convs, self.bns, self.blocks)
        for upsample, to_width, norm, branch in stages:
            features = to_width(upsample(x))
            # Residual connection around the normalised conv branch.
            features = features + branch(norm(features))
            x = self.dropout(features)

        # Output stage: final upsample and projection to 3 channels, no activation.
        return self.down_channel_convs[-1](self.upsamples[-1](x))

# Smoke test: a batch of 16 latent maps with 64 channels at 7x7 must decode
# into 16 three-channel images of 224x224.
latent = torch.randn(16, 64, 7, 7)
decoder = Resnet50Decoder()
output = decoder(latent)
expected_output_shape = (16, 3, 224, 224)
assert tuple(output.shape) == expected_output_shape

In [None]:
class ResnetVAE(LightningModule):
    """Variational auto-encoder with a frozen, pretrained ResNet encoder and a trainable decoder.

    The encoder's last feature map is projected by two 1x1 convolutions into the
    mean and log-variance of a diagonal Gaussian at every spatial location. A
    latent sample drawn with the reparameterisation trick is decoded into an image.

    Args:
        latent_space_dim: number of channels of the latent feature map.
        resnet_ver: key into `ENCODER_LATENT_SPACE_LOOKUP` selecting the backbone.
        decoder_kwargs: extra keyword arguments forwarded to the decoder constructor.
    """

    # (channels, height, width) the flattened encoder output is reshaped to, per backbone.
    ENCODER_LATENT_SPACE_LOOKUP = {
        "resnet50": (2048, 7, 7)
    }

    def __init__(self, latent_space_dim: int=64, resnet_ver: str="resnet50", decoder_kwargs: Optional[dict]=None):
        super().__init__()
        # `None` instead of a shared mutable `{}` default.
        decoder_kwargs = {} if decoder_kwargs is None else decoder_kwargs

        self.encoder_latent_img_shape = self.ENCODER_LATENT_SPACE_LOOKUP[resnet_ver]
        self.decoder_latent_img_shape = (latent_space_dim, *self.encoder_latent_img_shape[1:])
        self.resnet_ver = resnet_ver

        self.encoder = self.prepare_frozen_encoder()
        self.decoder = self.prepare_decoder(**decoder_kwargs)

        encoder_latent__channels = self.encoder_latent_img_shape[0]
        self.mu = nn.Conv2d(in_channels=encoder_latent__channels, out_channels=latent_space_dim, kernel_size=1, stride=1, padding=0)
        self.log_var = nn.Conv2d(in_channels=encoder_latent__channels, out_channels=latent_space_dim, kernel_size=1, stride=1, padding=0)

    def prepare_decoder(self, **kwargs):
        """Build the decoder, matching its input channels to the latent dimension."""
        return Resnet50Decoder(in_latent_dim=self.decoder_latent_img_shape[0], **kwargs)

    def prepare_frozen_encoder(self):
        """Load the pretrained backbone, strip its pooling/classifier head and freeze it.

        Raises:
            ValueError: if `self.resnet_ver` is not a supported backbone.
        """
        if self.resnet_ver == "resnet50":
            encoder = resnet50(weights=ResNet50_Weights.DEFAULT)
            # Replace the head with identities so the (flattened) last feature map is returned.
            encoder.avgpool = nn.Identity()
            encoder.fc = nn.Identity()
        else:
            raise ValueError(f"{self.resnet_ver} is not a recognised resnet model")

        for param in encoder.parameters():
            param.requires_grad = False

        # `requires_grad = False` does not stop BatchNorm layers from updating their
        # running statistics in train mode; eval mode is what actually freezes them.
        encoder.eval()

        return encoder

    def train(self, mode: bool = True):
        """Switch train/eval mode as usual, but always keep the frozen encoder in eval mode.

        Without this, any `.train()` call (e.g. by the Trainer) would put the
        encoder's BatchNorm layers back into train mode and let their running
        statistics drift away from the pretrained values.
        """
        super().train(mode)
        self.encoder.eval()
        return self

    def forward(self, x):
        """Encode `x`, sample a latent with the reparameterisation trick, and decode it.

        Returns:
            Tuple `(reconstructed, mu, log_var)`.
        """
        # The headless encoder returns a flattened feature map; restore (C, H, W).
        x = self.encoder(x).reshape((-1, *self.encoder_latent_img_shape))
        mu = self.mu(x)
        log_var = self.log_var(x)
        # z = mu + eps * sigma, with sigma = exp(log_var / 2) and eps ~ N(0, I).
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        z = mu + eps * std
        reconstructed = self.decoder(z)
        return reconstructed, mu, log_var

    def vae_loss_function(self, reconstructed, original, mu, log_var, beta=1.0):
        """Compute the VAE objective, averaged over the batch.

        Args:
            reconstructed: decoder output.
            original: target images, same shape as `reconstructed`.
            mu: latent means.
            log_var: latent log-variances.
            beta: weight of the KL term in the total loss.

        Returns:
            Tuple `(total_loss, reconstruction_loss, kl_divergence)`, where the
            reconstruction loss is the summed squared error per sample and the KL
            term is the divergence from a standard normal prior per sample.
        """
        batch_size = original.size(0)
        reconstruction_loss = F.mse_loss(reconstructed, original, reduction='sum') / batch_size
        kl_divergence = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp()) / batch_size
        total_loss = reconstruction_loss + beta * kl_divergence
        return total_loss, reconstruction_loss, kl_divergence

    def training_step(self, batch, batch_index):
        """Run one training step, log the three loss terms and return the total loss."""
        x, _ = batch  # labels are not used
        reconstructed, mu, log_var = self(x)
        loss, reconstruction_loss, kl_divergence = self.vae_loss_function(reconstructed, x, mu, log_var)
        self.log('train_step__loss', loss)
        self.log('train_step__kl_loss', kl_divergence)
        self.log('train_step__reconstruction_loss', reconstruction_loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Evaluate one validation batch and log the epoch-aggregated loss terms."""
        x, _ = batch  # labels are not used
        reconstructed, mu, log_var = self(x)
        loss, reconstruction_loss, kl_divergence = self.vae_loss_function(reconstructed, x, mu, log_var)
        self.log('valid_epoch__loss', loss, on_step=False, on_epoch=True)
        self.log('valid_epoch__kl_loss', kl_divergence, on_step=False, on_epoch=True)
        self.log('valid_epoch__reconstruction_loss', reconstruction_loss, on_step=False, on_epoch=True)

    def configure_optimizers(self):
        """Create the Adam optimizer and, unless a constant LR is requested, a plateau scheduler.

        Hyperparameters are read from the notebook-level `config` dict.
        """
        optimizer = torch.optim.Adam(self.parameters(), lr=config["lr"], weight_decay=config["weight_decay"])
        if config["use_constant_lr"]:
            return optimizer

        scheduler = ReduceLROnPlateau(optimizer)
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                "scheduler": scheduler,
                # Step the scheduler every 100 training steps on the logged training loss.
                "frequency": 100,
                "interval": "step",
                "monitor": "train_step__loss",
            }
        }

# Init model

In [None]:
# Arguments for the checkpoint-sync helper from side_project_utils. The W&B run
# is only handed over when logging is enabled.
s3_checkpoint_options = {
    "project_name": project_name,
    "wandb_run": run if config["log_wandb"] else None,
    "load_from_run": config["use_existing_run"],
    "every_n_epochs": 100,
}
s3_sync_callback = setup_s3_model_checkpointing(**s3_checkpoint_options)

In [None]:
# Build the VAE and move it to the selected device.
model = ResnetVAE(
    latent_space_dim=64,
    resnet_ver="resnet50",
    decoder_kwargs=dict(dropout_rate=config["dropout"]),
)
model = model.to(device)

# Only parameters that still require gradients are counted.
trainable_param_count = sum(
    param.numel() for param in model.parameters() if param.requires_grad
)
print("Number of trainable params: ", trainable_param_count)

# Define callbacks

In [None]:
class SampleReconstruction(Callback):
    """Periodically log original-vs-reconstructed image pairs through a W&B logger.

    Args:
        logger: logger whose `log_image` receives the rendered comparison figure.
        sample_input: normalised image batch of shape (batch, channels, height, width).
        every_n_epochs: log on every epoch whose index is a multiple of this value
            (and additionally on the last epoch when `max_epochs` is finite).
    """

    def __init__(self, logger: WandbLogger, sample_input: torch.Tensor, every_n_epochs: int=100):
        super().__init__()
        self.sample_input = sample_input
        assert len(self.sample_input.shape) == 4, "Please ensure to keep the batch dimension"

        self.logger = logger
        self.inv_normaliser = transforms.Normalize(mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225], std=[1/0.229, 1/0.224, 1/0.225])
        self.to_pil = transforms.ToPILImage()

        self.every_n_epochs = every_n_epochs

    def _to_displayable(self, image: torch.Tensor):
        """Undo the input normalisation and convert one (C, H, W) tensor to a PIL image."""
        restored = self.inv_normaliser(image.detach().float().cpu())
        # The decoder output is unbounded; values outside [0, 1] would wrap around
        # when ToPILImage casts to uint8 and show up as speckle artefacts.
        return self.to_pil(restored.clamp(0, 1))

    def on_validation_epoch_end(self, trainer, pl_module):
        """Reconstruct the stored sample batch and log the side-by-side figure."""
        if not ((trainer.current_epoch % self.every_n_epochs == 0) or (trainer.current_epoch == trainer.max_epochs - 1)):
            return

        with torch.no_grad():
            # Follow the module's own device rather than a notebook-level global.
            x = self.sample_input.to(pl_module.device)
            reconstructed, _, _ = pl_module(x)
            reconstructed = reconstructed.cpu()

        n_images = x.shape[0]
        # squeeze=False keeps `axs` two-dimensional even for a single sample.
        fig, axs = plt.subplots(nrows=n_images, ncols=2, figsize=(2, n_images), squeeze=False)
        for i in range(n_images):
            axs[i, 0].imshow(self._to_displayable(self.sample_input[i]))
            axs[i, 1].imshow(self._to_displayable(reconstructed[i]))
        axs[0, 0].set_title("Original")
        axs[0, 1].set_title("Reconstructed")
        for ax in axs.ravel():
            ax.axis(False)

        self.logger.log_image(
            key="sample_reconstruction",
            images=[wandb.Image(fig).image]
        )

        # Close this specific figure so repeated logging does not accumulate open figures.
        plt.close(fig)

# Train

In [None]:
callbacks = []

if config["log_wandb"]:
    callbacks.append(s3_sync_callback)
    wandb_logger = WandbLogger(project=project_name)

    try:
        wandb_logger.watch(model)
    except ValueError as watch_error:
        # Re-running this cell watches the same model again; only that specific
        # complaint is tolerated, anything else is re-raised.
        already_watched = "You can only call `wandb.watch` once per model." in str(watch_error)
        if not already_watched:
            raise

    # The first validation batch is the fixed input for the reconstruction previews.
    seed_input, _ = next(iter(pokemon_valid_loader))
    sample_callback = SampleReconstruction(wandb_logger, seed_input, every_n_epochs=100)
    callbacks += [LearningRateMonitor(logging_interval='step'), sample_callback]


# Local checkpoints go to the directory the S3 sync callback reads from, at the
# same epoch cadence.
checkpoint_callback = ModelCheckpoint(
    dirpath=s3_sync_callback.save_local_dir,
    filename="{epoch}-{step}",
    save_last=True,
    every_n_epochs=s3_sync_callback.every_n_epochs,
    save_on_train_epoch_end=True,
)
callbacks.append(checkpoint_callback)


trainer = Trainer(
    accelerator="gpu" if device == "cuda" else "cpu",
    devices=1,
    max_epochs=config["num_epoch"],
    log_every_n_steps=5,
    precision="32",
    logger=wandb_logger if config["log_wandb"] else None,
    callbacks=callbacks,
    accumulate_grad_batches=config["accumulate_grad_batches"] or 1,
    overfit_batches=config["overfit_batch"],
)

# Resume from the previous run's last checkpoint only when a run to load from is named.
resume_ckpt_path = (
    s3_sync_callback.load_local_dir/"last.ckpt" if config["use_existing_run"] else None
)
trainer.fit(
    model,
    train_dataloaders=pokemon_train_loader,
    val_dataloaders=pokemon_valid_loader,
    ckpt_path=resume_ckpt_path,
)

In [None]:
# save to s3
# Explicit upload through the sync callback once training has stopped.
# NOTE(review): presumably this pushes the files under `s3_sync_callback.save_local_dir`;
# the helper is defined in side_project_utils — confirm there.
s3_sync_callback.upload_files_to_s3()

# Generate images

In [None]:
# Reconstruct one dataset sample and show it next to the original.
model = model.to(device)
model.eval()

with torch.no_grad():
    sample_image = pokemon_dataset[100][0]
    x = sample_image.unsqueeze(0).to(device)  # add the batch dimension
    reconstructed, mu, log_var = model(x)
    reconstructed = reconstructed.squeeze(0).cpu()

# The decoder output is unbounded: clamp to [0, 1] after undoing the normalisation,
# otherwise out-of-range values wrap around in ToPILImage's uint8 conversion and
# appear as speckle artefacts.
reconstructed_pil = to_pil(inv_normaliser(reconstructed).clamp(0, 1))
sample_pil = to_pil(inv_normaliser(sample_image).clamp(0, 1))

fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(2, 1))

for ax in axs.ravel():
    ax.axis(False)

axs[0].imshow(sample_pil)
axs[1].imshow(reconstructed_pil)
plt.show()

In [None]:
# Close the active W&B run so that it is marked as finished.
wandb.finish()