In [None]:
import os

import numpy as np
import torch
from torch.amp import autocast, GradScaler
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm

from experiments.Checkpointing import save_checkpoint
from models.VAE import VAE
from modules.Losses import VGGPerceptualLoss, vae_loss
from modules.SaveOutputs import save_reconstructions, save_samples

# ---------------------------------------------------------------------------
# One-time setup for VAE training: hyperparameters, data pipeline, model,
# optimizer, loss helper and AMP gradient scaler.
# ---------------------------------------------------------------------------

# All inputs are resized to one fixed shape below, so letting cuDNN benchmark
# and cache convolution algorithms is worthwhile.
torch.backends.cudnn.benchmark = True
torch.cuda.empty_cache()

# NOTE(review): dataset_dir is not used anywhere in this script; ImageFolder
# below reads a hard-coded path instead. Confirm which one is intended.
dataset_dir: str = "./data/raw"
out_dir: str = "./outputs/"
batch_size: int = 64
latent_dim: int = 128
checkpoint_dir: str = "./experiments/checkpoints"
epochs: int = 10
lr: float = 3e-4

torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
on_cuda: bool = device.type == "cuda"

transform = transforms.Compose([
    transforms.Resize((64, 64)),  # resize to 64x64
    transforms.ToTensor(),  # convert to tensor & scale to [0,1]
])

# Raw string: "\T" is not a valid escape sequence and triggers a SyntaxWarning
# on recent Python versions. The resulting path value is unchanged.
dataset = datasets.ImageFolder(root=r"G:\Temp", transform=transform)
train_loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
    # Pinned host memory only helps host-to-GPU copies; on CPU it just warns.
    pin_memory=on_cuda,
)
n_train = len(train_loader.dataset)
print("Loaded datasets, number of samples: ", len(dataset))

# GradScaler expects a device-type string. Loss scaling is only useful together
# with CUDA float16 autocast, so it is disabled (pass-through) on CPU.
scaler = GradScaler(device.type, enabled=on_cuda)
model = VAE(latent_dim=latent_dim).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# NOTE(review): vgg_loss is not explicitly moved to `device` here; presumably
# VGGPerceptualLoss / vae_loss handle device placement - confirm.
vgg_loss = VGGPerceptualLoss()
epoch = 1  # 1-based epoch number
os.makedirs(out_dir, exist_ok=True)

# ---------------------------------------------------------------------------
# Training loop: one pass over train_loader per epoch with mixed-precision
# forward/backward, followed by an epoch summary, a reconstruction dump and a
# checkpoint every 10th epoch.
# ---------------------------------------------------------------------------

# float16 autocast is only used on CUDA; on CPU it is switched off explicitly
# instead of warning on every step.
use_amp = device.type == "cuda"

# `epoch` is the loop variable itself, so the number passed to the summary,
# save_reconstructions and save_checkpoint is the epoch that just finished.
for epoch in range(1, epochs + 1):
    model.train()
    running_total = running_recon = running_l1 = running_kld = running_perceptual = 0.0
    progress_bar = tqdm(train_loader, total=len(train_loader), desc="Training")

    for x, _ in progress_bar:
        x = x.to(device, non_blocking=True)
        optimizer.zero_grad(set_to_none=True)
        with autocast(device_type="cuda", dtype=torch.float16, enabled=use_amp):
            logits, mean, logvar = model(x)
            loss = vae_loss(logits, x, mean, logvar, vgg_loss_fn=vgg_loss)
        scaler.scale(loss.total).backward()
        scaler.step(optimizer)
        scaler.update()

        # Read each loss component once: every .item() forces a device sync.
        step_total = loss.total.item()
        step_recon = loss.recon.item()
        step_l1 = loss.l1.item()
        step_kld = loss.kld.item()
        step_percep = loss.percep.item()

        running_total += step_total
        running_recon += step_recon
        running_l1 += step_l1
        running_kld += step_kld
        running_perceptual += step_percep

        progress_bar.set_postfix(
            loss=f"{step_total:.3f}",
            recon=f"{step_recon:.3f}",
            l1=f"{step_l1:.3f}",
            kld=f"{step_kld:.3f}",
            percep=f"{step_percep:.3f}",
        )

    # NOTE(review): per-batch values are summed and divided by the number of
    # samples. That is a per-sample mean only if vae_loss returns batch-summed
    # losses; if it returns batch means, divide by len(train_loader) - confirm.
    print(
        f"Epoch {epoch:02d} | total: {running_total / n_train:.4f} | "
        f"recon: {running_recon / n_train:.4f} | l1: {running_l1 / n_train:.4f} | "
        f"kld: {running_kld / n_train:.4f} | "
        f"perceptual: {running_perceptual / n_train:.4f}"
    )

    # Uses the last batch of the epoch as the reconstruction input.
    save_reconstructions(model, x, out_dir, epoch, device)

    if epoch % 10 == 0:
        save_checkpoint(model, optimizer, epoch, checkpoint_dir)