In [8]:
# !git clone https://github.com/akshitv2/VAE-latent-space-experiment.git
# %cd /content/VAE-latent-space-experiment

fatal: destination path 'VAE-latent-space-experiment' already exists and is not an empty directory.
/content/VAE-latent-space-experiment


In [3]:
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import os
from tqdm import tqdm

from experiments.Checkpointing import save_checkpoint
from models.VAE import VAE
from modules.Losses import vae_loss
from modules.SaveOutputs import save_reconstructions, save_samples

dataset_dir: str = "./data/raw"
out_dir: str = "./outputs/"
batch_size: int = 64
latent_dim: int = 512
checkpoint_dir = "./experiments/checkpoints"
epochs: int = 10
lr: float = 3e-4
beta: float = 0.5

torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

transform = transforms.Compose([
    transforms.Resize((64, 64)),  # resize to 224x224
    transforms.ToTensor()  # convert to tensor & scale to [0,1]
])


In [23]:
dataset = datasets.ImageFolder(root="G:\Temp", transform=transform)
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)

In [10]:


print("Loaded datasets, number of samples: ", len(dataset))

# Model & Optimizer
model = VAE(latent_dim=latent_dim).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)


global_step = 0
os.makedirs(out_dir, exist_ok=True)

Loaded datasets, number of samples:  202599


In [11]:
train_mode = True
current_epoch = 0
epochs = 2
beta = 0.001
beta_anneal_factor = 1.1

In [20]:
training_loss_tracker = []

In [29]:
import torch
import torch.nn as nn
import torchvision.models as models
import torch.nn.functional as F

# -------------------------------
# Perceptual loss using VGG16
# -------------------------------
class VGGPerceptualLoss(nn.Module):
    def __init__(self, layer_ids=[3, 8, 15], device='cuda'):
        super().__init__()
        self.device = device
        vgg = models.vgg16(pretrained=True).features.to(device)  # <-- move to device
        self.layers = nn.ModuleList([vgg[:i+1] for i in layer_ids])
        for param in vgg.parameters():
            param.requires_grad = False

    def forward(self, x, y):
        # normalize to ImageNet stats
        mean = torch.tensor([0.485, 0.456, 0.406], device=self.device).view(1,3,1,1)
        std  = torch.tensor([0.229, 0.224, 0.225], device=self.device).view(1,3,1,1)
        x_norm = (x - mean) / std
        y_norm = (y - mean) / std

        loss = 0
        for layer in self.layers:
            loss += F.mse_loss(layer(x_norm), layer(y_norm))
        return loss

# -------------------------------
# VAE Loss combining L1 + Perceptual + KL
# -------------------------------
class VAEVggLoss(nn.Module):
    def __init__(self, recon_weight=1.0, perc_weight=0.1, kl_weight=0.01):
        super().__init__()
        self.recon_weight = recon_weight
        self.perc_weight = perc_weight
        self.kl_weight = kl_weight
        self.perc_loss = VGGPerceptualLoss()

    def forward(self, x_recon, x, mu, logvar):
        # L1 reconstruction loss
        recon_loss = F.l1_loss(x_recon, x)

        # Perceptual loss
        perc_loss = self.perc_loss(x_recon, x)

        # KL divergence
        kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        kl_loss = kl_loss / x.size(0)  # normalize by batch size

        total_loss = self.recon_weight * recon_loss + \
                     self.perc_weight * perc_loss + \
                     self.kl_weight * kl_loss
        return total_loss, recon_loss, perc_loss, kl_loss
vgg_loss = VAEVggLoss()

In [None]:
epochs = 200
if train_mode:
    for epoch in range(current_epoch + 1, current_epoch + epochs + 1):
        model.train()
        running_total = running_recon = running_kld = 0.0

        for batch_idx, (x, _) in enumerate(tqdm(train_loader), start=1):
            x = x.to(device)
            optimizer.zero_grad(set_to_none=True)
            logits, mean, logvar = model(x)
            loss, l1_loss, perc_loss, kl_loss = vgg_loss(logits, x, mean, logvar)
            # loss, l1_loss, perc_loss, kl_loss = criterion(x_recon, x, mu, logvar)
            loss.backward()
            optimizer.step()
            if batch_idx % 3000 == 0:
                print(f"Epoch [{epoch+1}/{epochs}], Step [{epoch}/{len(train_loader)}], "
                  f"Total Loss: {loss.item():.4f}, "
                  f"L1: {l1_loss.item():.4f}, "
                  f"Perceptual: {perc_loss.item():.4f}, "
                  f"KL: {kl_loss.item():.4f}")

            # running_total += loss.total.item()
            # running_recon += loss.recon.item()
            # running_kld += loss.kld.item()
        current_epoch += 1
        # beta = min(0.8, beta * beta_anneal_factor)

        save_reconstructions(model, x, out_dir, current_epoch, device)
        n_train = len(train_loader.dataset)

        # print(
        #      f"Epoch {epoch:02d} | loss: {loss.item}
        # print(
        #     f"Epoch {epoch:02d} | total: {running_total / n_train:.4f} | "
        #     f"recon: {running_recon / n_train:.4f} | kld: {running_kld / n_train:.4f} | "
        #     f"beta: {beta:.4f}"
        # )
        # if(epoch % 10) == 0:
        #     training_loss_tracker.append([running_total / n_train, running_recon / n_train, running_kld / n_train, beta])

        if epoch % 100 == 0:
            save_checkpoint(model, optimizer, epoch, checkpoint_dir)

 95%|█████████▍| 3002/3166 [06:25<00:20,  8.11it/s]

Epoch [48/200], Step [47/3166], Total Loss: 3.7064, L1: 0.0987, Perceptual: 27.6040, KL: 84.7289


100%|██████████| 3166/3166 [06:47<00:00,  7.76it/s]
 95%|█████████▍| 3002/3166 [06:39<00:20,  8.09it/s]

Epoch [49/200], Step [48/3166], Total Loss: 3.5058, L1: 0.1069, Perceptual: 27.5072, KL: 64.8140


100%|██████████| 3166/3166 [07:01<00:00,  7.52it/s]
 95%|█████████▍| 3002/3166 [06:41<00:20,  8.07it/s]

Epoch [50/200], Step [49/3166], Total Loss: 3.5882, L1: 0.1137, Perceptual: 29.1536, KL: 55.9137


100%|██████████| 3166/3166 [07:03<00:00,  7.48it/s]
 95%|█████████▍| 3002/3166 [06:43<00:20,  8.03it/s]

Epoch [51/200], Step [50/3166], Total Loss: 3.3575, L1: 0.1116, Perceptual: 27.9086, KL: 45.4949


100%|██████████| 3166/3166 [07:05<00:00,  7.44it/s]
 95%|█████████▍| 3002/3166 [06:42<00:20,  8.11it/s]

Epoch [52/200], Step [51/3166], Total Loss: 2.9288, L1: 0.1041, Perceptual: 24.1249, KL: 41.2171


100%|██████████| 3166/3166 [07:04<00:00,  7.45it/s]
 95%|█████████▍| 3002/3166 [06:48<00:20,  7.84it/s]

Epoch [53/200], Step [52/3166], Total Loss: 3.1003, L1: 0.1134, Perceptual: 26.1807, KL: 36.8820


100%|██████████| 3166/3166 [07:10<00:00,  7.35it/s]
 95%|█████████▍| 3002/3166 [06:43<00:20,  8.04it/s]

Epoch [54/200], Step [53/3166], Total Loss: 3.1097, L1: 0.1161, Perceptual: 26.3963, KL: 35.3922


100%|██████████| 3166/3166 [07:05<00:00,  7.44it/s]
 95%|█████████▍| 3002/3166 [06:44<00:20,  7.98it/s]

Epoch [55/200], Step [54/3166], Total Loss: 3.0031, L1: 0.1137, Perceptual: 25.6187, KL: 32.7496


100%|██████████| 3166/3166 [07:06<00:00,  7.42it/s]
 95%|█████████▍| 3002/3166 [06:39<00:20,  8.09it/s]

Epoch [56/200], Step [55/3166], Total Loss: 3.1241, L1: 0.1146, Perceptual: 26.5846, KL: 35.1018


100%|██████████| 3166/3166 [07:01<00:00,  7.51it/s]
 95%|█████████▍| 3002/3166 [06:42<00:20,  8.02it/s]

Epoch [57/200], Step [56/3166], Total Loss: 2.9251, L1: 0.1122, Perceptual: 24.7491, KL: 33.7983


100%|██████████| 3166/3166 [07:04<00:00,  7.46it/s]
 95%|█████████▍| 3002/3166 [06:39<00:20,  8.08it/s]

Epoch [58/200], Step [57/3166], Total Loss: 3.0550, L1: 0.1122, Perceptual: 26.0212, KL: 34.0651


100%|██████████| 3166/3166 [07:01<00:00,  7.50it/s]
 95%|█████████▍| 3002/3166 [06:40<00:20,  8.05it/s]

Epoch [59/200], Step [58/3166], Total Loss: 2.8185, L1: 0.1065, Perceptual: 23.4940, KL: 36.2581


100%|██████████| 3166/3166 [07:02<00:00,  7.50it/s]
 95%|█████████▍| 3002/3166 [06:41<00:19,  8.21it/s]

Epoch [60/200], Step [59/3166], Total Loss: 3.0553, L1: 0.1170, Perceptual: 26.0064, KL: 33.7709


100%|██████████| 3166/3166 [07:03<00:00,  7.48it/s]
 95%|█████████▍| 3002/3166 [06:40<00:20,  8.07it/s]

Epoch [61/200], Step [60/3166], Total Loss: 2.8850, L1: 0.1095, Perceptual: 24.3369, KL: 34.1719


100%|██████████| 3166/3166 [07:02<00:00,  7.49it/s]
 95%|█████████▍| 3002/3166 [06:39<00:20,  8.12it/s]

Epoch [62/200], Step [61/3166], Total Loss: 2.8143, L1: 0.1077, Perceptual: 23.7443, KL: 33.2176


100%|██████████| 3166/3166 [07:02<00:00,  7.50it/s]
 95%|█████████▍| 3002/3166 [06:41<00:20,  8.08it/s]

Epoch [63/200], Step [62/3166], Total Loss: 2.8559, L1: 0.1086, Perceptual: 23.8554, KL: 36.1757


100%|██████████| 3166/3166 [07:03<00:00,  7.47it/s]
 95%|█████████▍| 3002/3166 [06:39<00:20,  8.14it/s]

Epoch [64/200], Step [63/3166], Total Loss: 2.8495, L1: 0.1103, Perceptual: 23.9267, KL: 34.6469


100%|██████████| 3166/3166 [07:01<00:00,  7.52it/s]
 95%|█████████▍| 3002/3166 [06:41<00:20,  8.05it/s]

Epoch [65/200], Step [64/3166], Total Loss: 2.9684, L1: 0.1102, Perceptual: 25.0890, KL: 34.9287


100%|██████████| 3166/3166 [07:04<00:00,  7.47it/s]
 95%|█████████▍| 3002/3166 [06:39<00:20,  8.10it/s]

Epoch [66/200], Step [65/3166], Total Loss: 2.6627, L1: 0.0987, Perceptual: 22.2269, KL: 34.1268


100%|██████████| 3166/3166 [07:01<00:00,  7.50it/s]
 95%|█████████▍| 3002/3166 [06:42<00:20,  8.05it/s]

Epoch [67/200], Step [66/3166], Total Loss: 2.7655, L1: 0.1056, Perceptual: 22.9924, KL: 36.0633


100%|██████████| 3166/3166 [07:04<00:00,  7.46it/s]
 95%|█████████▍| 3002/3166 [06:39<00:20,  8.10it/s]

Epoch [68/200], Step [67/3166], Total Loss: 2.9364, L1: 0.1119, Perceptual: 24.6887, KL: 35.5623


100%|██████████| 3166/3166 [07:01<00:00,  7.52it/s]
 95%|█████████▍| 3002/3166 [06:41<00:20,  8.03it/s]

Epoch [69/200], Step [68/3166], Total Loss: 2.7881, L1: 0.1030, Perceptual: 23.4107, KL: 34.4035


100%|██████████| 3166/3166 [07:03<00:00,  7.48it/s]
 95%|█████████▍| 3002/3166 [06:43<00:20,  8.04it/s]

Epoch [70/200], Step [69/3166], Total Loss: 2.7553, L1: 0.1080, Perceptual: 22.8435, KL: 36.2978


100%|██████████| 3166/3166 [07:05<00:00,  7.45it/s]
 95%|█████████▍| 3002/3166 [06:40<00:20,  8.02it/s]

Epoch [71/200], Step [70/3166], Total Loss: 2.8206, L1: 0.1076, Perceptual: 23.6765, KL: 34.5400


100%|██████████| 3166/3166 [07:02<00:00,  7.49it/s]
 95%|█████████▍| 3002/3166 [06:42<00:20,  8.09it/s]

Epoch [72/200], Step [71/3166], Total Loss: 2.6848, L1: 0.1075, Perceptual: 22.3029, KL: 34.6996


100%|██████████| 3166/3166 [07:04<00:00,  7.47it/s]
 95%|█████████▍| 3002/3166 [06:41<00:20,  8.07it/s]

Epoch [73/200], Step [72/3166], Total Loss: 2.8178, L1: 0.1025, Perceptual: 23.6206, KL: 35.3233


100%|██████████| 3166/3166 [07:03<00:00,  7.48it/s]
 95%|█████████▍| 3002/3166 [06:43<00:21,  7.74it/s]

Epoch [74/200], Step [73/3166], Total Loss: 2.9026, L1: 0.1098, Perceptual: 24.4550, KL: 34.7233


100%|██████████| 3166/3166 [07:06<00:00,  7.43it/s]
 95%|█████████▍| 3002/3166 [06:41<00:20,  7.97it/s]

Epoch [75/200], Step [74/3166], Total Loss: 2.8174, L1: 0.1067, Perceptual: 23.2899, KL: 38.1613


100%|██████████| 3166/3166 [07:03<00:00,  7.47it/s]
 95%|█████████▍| 3000/3166 [08:45<00:37,  4.43it/s]

Epoch [76/200], Step [75/3166], Total Loss: 2.7291, L1: 0.1039, Perceptual: 22.6685, KL: 35.8428


100%|██████████| 3166/3166 [09:16<00:00,  5.68it/s]
 95%|█████████▍| 3002/3166 [07:00<00:20,  7.85it/s]

Epoch [77/200], Step [76/3166], Total Loss: 2.7833, L1: 0.1073, Perceptual: 22.9867, KL: 37.7396


100%|██████████| 3166/3166 [07:23<00:00,  7.14it/s]
 95%|█████████▍| 3002/3166 [06:42<00:20,  8.12it/s]

Epoch [78/200], Step [77/3166], Total Loss: 2.7703, L1: 0.1067, Perceptual: 23.1654, KL: 34.7050


100%|██████████| 3166/3166 [07:05<00:00,  7.44it/s]
 95%|█████████▍| 3002/3166 [06:42<00:20,  8.07it/s]

Epoch [79/200], Step [78/3166], Total Loss: 2.6907, L1: 0.1027, Perceptual: 22.1834, KL: 36.9696


100%|██████████| 3166/3166 [07:04<00:00,  7.45it/s]
 95%|█████████▍| 3002/3166 [06:39<00:20,  8.15it/s]

Epoch [80/200], Step [79/3166], Total Loss: 2.7129, L1: 0.1055, Perceptual: 22.6130, KL: 34.6086


100%|██████████| 3166/3166 [07:01<00:00,  7.51it/s]
 95%|█████████▍| 3002/3166 [06:39<00:19,  8.20it/s]

Epoch [81/200], Step [80/3166], Total Loss: 2.6871, L1: 0.1038, Perceptual: 22.2024, KL: 36.3099


100%|██████████| 3166/3166 [07:01<00:00,  7.51it/s]
 95%|█████████▍| 3002/3166 [06:39<00:20,  8.14it/s]

Epoch [82/200], Step [81/3166], Total Loss: 2.7861, L1: 0.1119, Perceptual: 23.2960, KL: 34.4657


100%|██████████| 3166/3166 [07:00<00:00,  7.52it/s]
 95%|█████████▍| 3002/3166 [06:40<00:20,  8.11it/s]

Epoch [83/200], Step [82/3166], Total Loss: 2.7661, L1: 0.1041, Perceptual: 22.9383, KL: 36.8173


100%|██████████| 3166/3166 [07:02<00:00,  7.50it/s]
 95%|█████████▍| 3002/3166 [06:40<00:20,  7.99it/s]

Epoch [84/200], Step [83/3166], Total Loss: 2.8251, L1: 0.1078, Perceptual: 23.5839, KL: 35.8950


100%|██████████| 3166/3166 [07:02<00:00,  7.49it/s]
 95%|█████████▍| 3002/3166 [06:41<00:20,  8.03it/s]

Epoch [85/200], Step [84/3166], Total Loss: 2.7018, L1: 0.1048, Perceptual: 22.2717, KL: 36.9833


100%|██████████| 3166/3166 [07:03<00:00,  7.47it/s]
 95%|█████████▍| 3002/3166 [06:40<00:20,  8.10it/s]

Epoch [86/200], Step [85/3166], Total Loss: 2.7193, L1: 0.1040, Perceptual: 22.5674, KL: 35.8514


100%|██████████| 3166/3166 [07:02<00:00,  7.49it/s]
 95%|█████████▍| 3002/3166 [06:42<00:20,  8.08it/s]

Epoch [87/200], Step [86/3166], Total Loss: 2.6970, L1: 0.1041, Perceptual: 22.4155, KL: 35.1286


100%|██████████| 3166/3166 [07:04<00:00,  7.46it/s]
 95%|█████████▍| 3002/3166 [06:40<00:20,  8.07it/s]

Epoch [88/200], Step [87/3166], Total Loss: 2.7732, L1: 0.1047, Perceptual: 22.8894, KL: 37.9510


100%|██████████| 3166/3166 [07:02<00:00,  7.50it/s]
 95%|█████████▍| 3002/3166 [06:42<00:20,  8.16it/s]

Epoch [89/200], Step [88/3166], Total Loss: 2.7877, L1: 0.1043, Perceptual: 23.1175, KL: 37.1617


100%|██████████| 3166/3166 [07:04<00:00,  7.46it/s]
 95%|█████████▍| 3002/3166 [06:40<00:20,  8.05it/s]

Epoch [90/200], Step [89/3166], Total Loss: 2.6414, L1: 0.1029, Perceptual: 21.8096, KL: 35.7543


100%|██████████| 3166/3166 [07:02<00:00,  7.49it/s]
 95%|█████████▍| 3002/3166 [06:40<00:20,  8.14it/s]

Epoch [91/200], Step [90/3166], Total Loss: 2.9924, L1: 0.1140, Perceptual: 25.1511, KL: 36.3248


100%|██████████| 3166/3166 [07:02<00:00,  7.49it/s]
 95%|█████████▍| 3002/3166 [06:39<00:20,  8.10it/s]

Epoch [92/200], Step [91/3166], Total Loss: 2.7191, L1: 0.1037, Perceptual: 22.4742, KL: 36.8014


100%|██████████| 3166/3166 [07:01<00:00,  7.52it/s]
 95%|█████████▍| 3002/3166 [06:40<00:20,  8.09it/s]

Epoch [93/200], Step [92/3166], Total Loss: 2.6199, L1: 0.1018, Perceptual: 21.7163, KL: 34.6554


100%|██████████| 3166/3166 [07:02<00:00,  7.50it/s]
 95%|█████████▍| 3002/3166 [06:40<00:21,  7.56it/s]

Epoch [94/200], Step [93/3166], Total Loss: 2.9319, L1: 0.1100, Perceptual: 24.6081, KL: 36.1064


100%|██████████| 3166/3166 [07:02<00:00,  7.49it/s]
 95%|█████████▍| 3002/3166 [06:41<00:20,  7.98it/s]

Epoch [95/200], Step [94/3166], Total Loss: 2.7314, L1: 0.1049, Perceptual: 22.7773, KL: 34.8796


100%|██████████| 3166/3166 [07:03<00:00,  7.47it/s]
 95%|█████████▍| 3002/3166 [06:42<00:20,  8.06it/s]

Epoch [96/200], Step [95/3166], Total Loss: 2.7943, L1: 0.1066, Perceptual: 23.2951, KL: 35.8264


100%|██████████| 3166/3166 [07:04<00:00,  7.46it/s]
 95%|█████████▍| 3002/3166 [06:41<00:20,  8.08it/s]

Epoch [97/200], Step [96/3166], Total Loss: 2.8015, L1: 0.1030, Perceptual: 23.4017, KL: 35.8266


100%|██████████| 3166/3166 [07:03<00:00,  7.47it/s]
 95%|█████████▍| 3002/3166 [06:40<00:20,  8.11it/s]

Epoch [98/200], Step [97/3166], Total Loss: 2.6827, L1: 0.1023, Perceptual: 22.0305, KL: 37.7412


100%|██████████| 3166/3166 [07:02<00:00,  7.49it/s]
 95%|█████████▍| 3002/3166 [06:42<00:20,  8.08it/s]

Epoch [99/200], Step [98/3166], Total Loss: 2.7748, L1: 0.1063, Perceptual: 23.0315, KL: 36.5349


100%|██████████| 3166/3166 [07:04<00:00,  7.46it/s]
 95%|█████████▍| 3002/3166 [06:44<00:20,  8.04it/s]

Epoch [100/200], Step [99/3166], Total Loss: 2.7238, L1: 0.1070, Perceptual: 22.5224, KL: 36.4531


100%|██████████| 3166/3166 [07:07<00:00,  7.41it/s]
 95%|█████████▍| 3002/3166 [06:40<00:20,  8.03it/s]

Epoch [101/200], Step [100/3166], Total Loss: 2.5393, L1: 0.0980, Perceptual: 20.8441, KL: 35.6866


100%|██████████| 3166/3166 [07:02<00:00,  7.50it/s]


Saved checkpoint: ./experiments/checkpoints\vae_checkpoint_epoch_100.pt


 95%|█████████▍| 3002/3166 [06:53<00:20,  7.86it/s]

Epoch [102/200], Step [101/3166], Total Loss: 2.6884, L1: 0.1082, Perceptual: 22.2029, KL: 35.9883


100%|██████████| 3166/3166 [07:17<00:00,  7.24it/s]
 95%|█████████▍| 3002/3166 [06:50<00:20,  7.89it/s]

Epoch [103/200], Step [102/3166], Total Loss: 2.7936, L1: 0.1051, Perceptual: 23.1849, KL: 37.0026


100%|██████████| 3166/3166 [07:13<00:00,  7.30it/s]
 95%|█████████▍| 3002/3166 [06:52<00:20,  7.85it/s]

Epoch [104/200], Step [103/3166], Total Loss: 2.8089, L1: 0.1086, Perceptual: 23.3697, KL: 36.3362


100%|██████████| 3166/3166 [07:15<00:00,  7.27it/s]
 95%|█████████▍| 3002/3166 [06:54<00:21,  7.79it/s]

Epoch [105/200], Step [104/3166], Total Loss: 2.8108, L1: 0.1066, Perceptual: 23.3867, KL: 36.5520


100%|██████████| 3166/3166 [07:17<00:00,  7.24it/s]
 95%|█████████▍| 3002/3166 [07:08<00:21,  7.50it/s]

Epoch [106/200], Step [105/3166], Total Loss: 2.6828, L1: 0.1035, Perceptual: 22.0212, KL: 37.7165


100%|██████████| 3166/3166 [07:32<00:00,  7.00it/s]
 95%|█████████▍| 3002/3166 [07:06<00:21,  7.55it/s]

Epoch [107/200], Step [106/3166], Total Loss: 2.7350, L1: 0.1080, Perceptual: 22.5878, KL: 36.8230


100%|██████████| 3166/3166 [07:30<00:00,  7.03it/s]
 42%|████▏     | 1320/3166 [03:12<04:30,  6.82it/s]