In [8]:
!git clone https://github.com/akshitv2/VAE-latent-space-experiment.git
%cd /content/VAE-latent-space-experiment

fatal: destination path 'VAE-latent-space-experiment' already exists and is not an empty directory.
/content/VAE-latent-space-experiment


In [5]:
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import os

from experiments.Checkpointing import save_checkpoint
from models.VAE import VAE
from modules.Losses import vae_loss
from modules.SaveOutputs import save_reconstructions, save_samples

# --- Experiment configuration ---------------------------------------------
# NOTE(review): dataset_dir is not read by the data-loading cell, which passes its
# own hard-coded root to ImageFolder — confirm which location is intended.
dataset_dir: str = "./data/raw"
out_dir: str = "./outputs/"  # output directory passed to save_reconstructions
batch_size: int = 32
latent_dim: int = 64  # forwarded to VAE(latent_dim=...)
checkpoint_dir: str = "./experiments/checkpoints"  # passed to save_checkpoint
epochs: int = 10
lr: float = 3e-4  # Adam learning rate
beta: float = 0.5  # forwarded to vae_loss(beta=...)
seed: int = 0  # global torch RNG seed for reproducible runs

torch.manual_seed(seed)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Square side length every input image is resized to before entering the model.
# NOTE(review): presumably must match the input size the VAE architecture expects — confirm in models/VAE.
image_size: int = 64

transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),  # resize to 64x64
    transforms.ToTensor(),  # convert to a (C, H, W) float tensor scaled to [0, 1]
])


In [7]:
# Build the training set from an ImageFolder root (one sub-directory per class;
# the training loop discards the labels). The raw string keeps the exact same path
# value while avoiding the invalid "\T" escape, which Python 3.12+ flags with a
# SyntaxWarning.
# NOTE(review): this is a local Windows path while the setup cell targets Colab
# (/content), and the dataset_dir config value is not used here — confirm the root.
dataset = datasets.ImageFolder(root=r"G:\Temp", transform=transform)
train_loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
    # Pinned host memory only speeds up host->GPU copies; on CPU-only runs it is
    # useless and PyTorch warns about it.
    pin_memory=torch.cuda.is_available(),
)

In [8]:
print("Loaded datasets, number of samples: ", len(dataset))

# Model & Optimizer
model = VAE(latent_dim=latent_dim).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

global_step = 0  # global step counter, reset whenever this cell runs

# Make sure both output locations exist before training writes into them.
# NOTE(review): save_checkpoint may already create its directory; doing it here is harmless either way.
os.makedirs(out_dir, exist_ok=True)
os.makedirs(checkpoint_dir, exist_ok=True)

Loaded datasets, number of samples:  202599


In [9]:
# Training-run state: the training cell only runs when train_mode is set, numbers
# epochs starting after current_epoch, and runs `epochs` of them per execution.
train_mode, current_epoch, epochs = True, 0, 10

In [11]:
epochs: int = 100  # longer run: supersedes the epoch count set in earlier cells

In [None]:
# Train for `epochs` more epochs, continuing the epoch numbering from `current_epoch`
# so the cell can be re-run to resume. `current_epoch` is advanced only after an
# epoch has fully finished, so interrupting the cell mid-epoch does not skip that
# epoch number on the next run.
checkpoint_every: int = 10  # checkpoint cadence; the final epoch of a run is always saved too

if train_mode:
    first_epoch = current_epoch + 1
    last_epoch = current_epoch + epochs
    n_train = len(train_loader.dataset)  # loop-invariant normaliser for the running sums

    for epoch in range(first_epoch, last_epoch + 1):
        model.train()
        running_total = 0.0
        running_recon = 0.0
        running_kld = 0.0

        for x, _ in train_loader:  # labels from ImageFolder are not used
            x = x.to(device)
            optimizer.zero_grad(set_to_none=True)
            logits, mean, logvar = model(x)
            loss = vae_loss(logits, x, mean, logvar, beta=beta)
            loss.total.backward()
            optimizer.step()
            global_step += 1

            running_total += loss.total.item()
            running_recon += loss.recon.item()
            running_kld += loss.kld.item()

        # Reconstructions are generated from the last batch seen in this epoch.
        save_reconstructions(model, (x.cpu(), None), out_dir, epoch, device)
        # NOTE(review): dividing by the sample count yields a per-sample average only if
        # vae_loss returns a batch *sum*; if it returns a batch mean, divide by
        # len(train_loader) instead — confirm against modules.Losses.
        print(
            f"Epoch {epoch:02d} | total: {running_total / n_train:.4f} | "
            f"recon: {running_recon / n_train:.4f} | kld: {running_kld / n_train:.4f}"
        )

        # Validation pass, currently disabled: no `test_loader` is defined in the cells shown.
        # model.eval()
        # test_total = test_recon = test_kld = 0.0
        # with torch.no_grad():
        #     for x, _ in test_loader:
        #         x = x.to(device)
        #         logits, mean, logvar = model(x)
        #         loss = vae_loss(logits, x, mean, logvar, beta=beta)
        #         test_total += loss.total.item()
        #         test_recon += loss.recon.item()
        #         test_kld += loss.kld.item()
        # n_test = len(test_loader.dataset)
        # print(f"  [val] total: {test_total / n_test:.4f} | recon: {test_recon / n_test:.4f} | kld: {test_kld / n_test:.4f}")

        # Save on the regular cadence and always at the end of the run, so a run whose
        # length is not a multiple of checkpoint_every still leaves its final weights.
        if epoch % checkpoint_every == 0 or epoch == last_epoch:
            save_checkpoint(model, optimizer, epoch, checkpoint_dir)

        current_epoch = epoch  # mark this epoch as completed