In [1]:
# from google.colab import drive
# drive.mount('/content/drive')
# !git clone https://github.com/akshitv2/VAE-latent-space-experiment.git
# %cd VAE-latent-space-experiment
# !git checkout celebA-ParamTweaking
# !cp /content/drive/MyDrive/Datasets/img_align_celeba.zip /content/
# !unzip /content/img_align_celeba.zip -d /content/dataset > /dev/null
try:
    import google.colab
    running_in_colab = True
except ImportError:
    running_in_colab = False

if running_in_colab:
    dataset_dir = "/content/dataset"
    out_dir: str = "/content/drive/MyDrive/Temp/outputs"
    checkpoint_dir = "/content/drive/MyDrive/Temp/checkpoints"
else:
    dataset_dir = "G:/Temp"
    out_dir: str = "./outputs/"
    checkpoint_dir = "./experiments/checkpoints"

In [2]:
from torch import GradScaler
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import os
from tqdm import tqdm
from experiments.Checkpointing import save_checkpoint
from models.VAE import VAE
from modules.Losses import VAEVggLoss
from modules.SaveOutputs import save_reconstructions

torch.backends.cudnn.benchmark = True
torch.cuda.empty_cache()

dataset_dir: str = "./data/raw"
out_dir: str = "./outputs/"
batch_size: int = 64
latent_dim: int = 256
checkpoint_dir = "./experiments/checkpoints"
# epochs: int = 10
lr: float = 3e-4
beta: float = 0.5

torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

transform = transforms.Compose([
    transforms.Resize((64, 64)),  # resize to 224x224
    transforms.ToTensor()  # convert to tensor & scale to [0,1]
])

dataset = datasets.ImageFolder(root="G:\Temp", transform=transform)
train_test_split_var = 0.99
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [int(train_test_split_var*len(dataset)), len(dataset) - int(train_test_split_var*len(dataset))])
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)
n_train = len(train_loader.dataset)
print("Loaded datasets, number of samples: ", n_train)

# Model & Optimizer
scaler = GradScaler(device)
model = VAE(latent_dim=latent_dim).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

current_epoch = 1
os.makedirs(out_dir, exist_ok=True)

Loaded datasets, number of samples:  2026


In [None]:
epochs = 10
vgg_loss = VAEVggLoss(recon_weight=0.01, perc_weight=0.1, kl_weight=0.01, recon_loss_function = "mse")
for epoch in range(1, epochs + 1):
    model.train()
    running_total = running_recon = running_kld = running_perceptual =  0.0
    progress_bar = tqdm(enumerate(train_loader, start=1), total=len(train_loader), desc="Training")

    for batch_idx, (x, _) in progress_bar:
        x = x.to(device, non_blocking=True)
        optimizer.zero_grad(set_to_none=True)
        logits, mean, logvar = model(x)
        loss, l1_loss, perc_loss, kl_loss = vgg_loss(logits, x, mean, logvar)
        # loss, l1_loss, perc_loss, kl_loss = criterion(x_recon, x, mu, logvar)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        progress_bar.set_postfix(
            loss=f"{loss.item():.3f}",
            l1=f"{l1_loss.item():.3f}",
            kld=f"{kl_loss.item():.3f}",
            percep=f"{perc_loss.item():.3f}",

        )
        running_total += loss.item()
        running_recon += l1_loss.item()
        running_kld += kl_loss.item()
        running_perceptual += perc_loss.item()
    current_epoch += 1
    print(
        f"Epoch {current_epoch:02d} | total: {running_total / n_train:.4f} | "
        f"recon: {running_recon / n_train:.4f} | kld: {running_kld / n_train:.4f} | "
        f"perceptual: {running_perceptual / n_train:.4f}"
    )
    save_reconstructions(model=model, x=x, out_dir=out_dir, step = current_epoch, device=device, variant="")
    # save_reconstructions(model, x, out_dir, current_epoch, device)
    if current_epoch % 100 == 0:
        save_checkpoint(model, optimizer, current_epoch, checkpoint_dir)

Training: 100%|██████████| 32/32 [00:06<00:00,  5.00it/s, kld=5.750, l1=411.343, loss=421.448, percep=4.355] 


Epoch 02 | total: 10.3730 | recon: 10.1987 | kld: 0.1103 | perceptual: 0.0640


Training: 100%|██████████| 32/32 [00:04<00:00,  7.58it/s, kld=13.657, l1=346.803, loss=364.804, percep=4.343]


Epoch 03 | total: 8.4164 | recon: 8.0828 | kld: 0.2687 | perceptual: 0.0649


Training: 100%|██████████| 32/32 [00:04<00:00,  7.90it/s, kld=20.494, l1=232.459, loss=257.274, percep=4.321]


Epoch 04 | total: 6.3663 | recon: 5.9684 | kld: 0.3309 | perceptual: 0.0670


Training: 100%|██████████| 32/32 [00:04<00:00,  7.93it/s, kld=17.994, l1=178.974, loss=201.061, percep=4.092]


Epoch 05 | total: 5.0149 | recon: 4.5190 | kld: 0.4300 | perceptual: 0.0659


Training: 100%|██████████| 32/32 [00:04<00:00,  7.67it/s, kld=23.089, l1=141.278, loss=168.236, percep=3.869]


Epoch 06 | total: 4.6159 | recon: 4.0743 | kld: 0.4769 | perceptual: 0.0647


Training: 100%|██████████| 32/32 [00:04<00:00,  7.69it/s, kld=20.017, l1=155.773, loss=179.761, percep=3.972]


Epoch 07 | total: 4.3462 | recon: 3.7598 | kld: 0.5224 | perceptual: 0.0639


Training: 100%|██████████| 32/32 [00:04<00:00,  7.75it/s, kld=22.619, l1=141.541, loss=168.156, percep=3.996]


Epoch 08 | total: 4.2152 | recon: 3.6072 | kld: 0.5449 | perceptual: 0.0631


Training: 100%|██████████| 32/32 [00:04<00:00,  7.76it/s, kld=23.244, l1=133.926, loss=161.141, percep=3.970]


Epoch 09 | total: 4.0334 | recon: 3.4134 | kld: 0.5577 | perceptual: 0.0623
