# **DEPENDENCIES**

In [None]:
# Running locally: uncomment to import dreamwalker from the repository checkout
# import sys
# sys.path.append('..')

# Running on Colab: install dreamwalker from GitHub
%pip install git+https://github.com/adityaprakash-work/DreamWalker.git

In [None]:
from google.colab import drive; drive.mount('/content/drive')

import numpy as np
import torch
from torchvision.utils import make_grid
from torch import optim
from torch.nn import functional as F
from torch.optim import lr_scheduler

import dreamwalker as dw
from dreamwalker.pytorch_generative import models, trainer

# **DATASET**

In [None]:
# Optional cell: run it to fetch the ImageNet-mini dataset from Kaggle
# into the local Colab filesystem.
dataset_dir = "/content/dataset"
dataset_url = (
    "https://www.kaggle.com/datasets/ifigotin/imagenetmini-1000/data"
)
dw.utils.datasets.download(dataset_url, dataset_dir)

In [None]:
# Image folders for training and validation.
train_dir = "/content/dataset/imagenetmini-1000/imagenet-mini/train"
valid_dir = "/content/dataset/imagenetmini-1000/imagenet-mini/val"
# Set `valid_dir = None` above to carve the validation set out of train_dir
# instead of reading a separate folder.
batch_size = 16  # shared by every loader so both code paths behave the same

if valid_dir is None:
    dataset = dw.utils.datasets.ImageStream(train_dir, ext="JPEG")
    # valid_size=0.2 — presumably a 20% hold-out of train_dir; confirm in get_loaders.
    train_loader, valid_loader = dw.utils.datasets.get_loaders(
        dataset, batch_size=batch_size, return_valid=True, valid_size=0.2
    )
else:
    train_dataset = dw.utils.datasets.ImageStream(train_dir, ext="JPEG")
    valid_dataset = dw.utils.datasets.ImageStream(valid_dir, ext="JPEG")
    train_loader = dw.utils.datasets.get_loaders(train_dataset, batch_size=batch_size)
    valid_loader = dw.utils.datasets.get_loaders(valid_dataset, batch_size=batch_size)

In [None]:
# Sanity check: pull a single batch from the training loader and show its shape.
batch = next(iter(train_loader), None)
if batch is not None:
    data, _ = batch
    print(data.shape)

# **TRAINING**

In [None]:
# VQ-VAE-2 hyper-parameters: RGB in / RGB out (3 channels each); the
# n_embeddings / embedding_dim pair presumably sizes the codebook.
vqvae2_kwargs = dict(
    in_channels=3,
    out_channels=3,
    hidden_channels=128,
    n_residual_blocks=2,
    residual_channels=64,
    n_embeddings=512,
    embedding_dim=64,
)
model = models.VectorQuantizedVAE2(**vqvae2_kwargs)

# Adam at lr=2e-4. Each scheduler step multiplies the current LR by a
# constant 0.999977, i.e. exponential decay. How often it steps is up to the
# Trainer — TODO confirm (per batch vs per epoch).
optimizer = optim.Adam(model.parameters(), lr=2e-4)
scheduler = lr_scheduler.MultiplicativeLR(
    optimizer,
    lr_lambda=lambda _step: 0.999977,
)

def loss_fn(x, _, preds):
    """Compute the VQ-VAE-2 training objective for one batch.

    Args:
        x: the input batch that the model reconstructs.
        _: second argument supplied by the Trainer (presumably labels); unused.
        preds: the model output, unpacked as ``(reconstruction, vq_loss)``.

    Returns:
        A dict with the model-supplied ``"vq_loss"``, the MSE between the
        reconstruction and ``x`` as ``"reconstruction_loss"``, and their
        weighted sum ``reconstruction_loss + 0.25 * vq_loss`` as ``"loss"``.
    """
    reconstruction, vq_loss = preds
    reconstruction_loss = F.mse_loss(reconstruction, x)
    total = reconstruction_loss + 0.25 * vq_loss
    return dict(
        vq_loss=vq_loss,
        reconstruction_loss=reconstruction_loss,
        loss=total,
    )

# Wire the model, objective, optimisation state and data into the Trainer.
# NOTE(review): n_gpus=0 — confirm whether this pins training to the CPU in
# this Trainer; on a Colab GPU runtime a non-zero value is probably intended.
model_trainer = trainer.Trainer(
    # What to train and how.
    model=model,
    loss_fn=loss_fn,
    optimizer=optimizer,
    lr_scheduler=scheduler,
    # Data.
    train_loader=train_loader,
    eval_loader=valid_loader,
    # Output / hardware.
    log_dir="/content/logs/vqvae0",
    n_gpus=0,
)

In [None]:
# Make grid of original vs reconstructed images
def make_grid_ovsr(original, reconstructions, nrow=None):
    """Build one image: a grid of originals beside a grid of reconstructions.

    Both halves are normalised with a single shared value range (the joint
    min/max of ``original`` and ``reconstructions``). Normalising each half
    independently would stretch each to [0, 1] on its own and hide
    brightness/contrast errors in the reconstructions.

    Args:
        original: batch of input images (assumed (B, C, H, W), as accepted by
            ``torchvision.utils.make_grid``).
        reconstructions: model outputs for ``original``, same shape.
        nrow: images per grid row; defaults to ceil(sqrt(B)) for a roughly
            square layout.

    Returns:
        The originals' grid and the reconstructions' grid concatenated along
        the last (width) axis, originals on the left.
    """
    if nrow is None:
        nrow = int(np.ceil(np.sqrt(original.shape[0])))
    # One shared range so identical pixel values map to identical intensities
    # in both halves.
    low = min(original.min().item(), reconstructions.min().item())
    high = max(original.max().item(), reconstructions.max().item())
    grid_o = make_grid(original, nrow=nrow, normalize=True, value_range=(low, high))
    grid_r = make_grid(reconstructions, nrow=nrow, normalize=True, value_range=(low, high))
    return torch.cat([grid_o, grid_r], dim=-1)


@torch.no_grad()
def recplt_monitor(model_trainer):
    """Log an original-vs-reconstruction grid for one evaluation batch.

    Takes the first batch from ``model_trainer.eval_loader``, runs the model
    on it in eval mode, and writes the side-by-side grid through the trainer's
    summary writer at the trainer's current step. The model's previous
    train/eval mode is restored afterwards, even if logging raises.

    Args:
        model_trainer: the Trainer whose ``model``, ``eval_loader``,
            ``device``, ``_summary_writer`` and ``_step`` are used.
    """
    model = model_trainer.model
    # Remember the caller's mode instead of forcing train() afterwards.
    was_training = model.training
    model.eval()
    try:
        x, _ = next(iter(model_trainer.eval_loader))
        x = x.to(model_trainer.device)
        x_recon, _ = model(x)
        # Under no_grad nothing needs detaching; just move to CPU for plotting.
        model_trainer._summary_writer.add_image(
            "Reconstruction Fidelity",
            make_grid_ovsr(x.cpu(), x_recon.cpu()),
            model_trainer._step,
        )
    finally:
        model.train(was_training)
    

# Launch training with reconstruction-grid logging. The positional 5 is
# presumably the number of epochs — confirm against the dreamwalker Trainer.
epoch_monitors = [recplt_monitor]
model_trainer.interleaved_train_and_eval(5, arbitrary_monitors=epoch_monitors)