In [1]:
import torch
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
import os
from tqdm import tqdm

from experiments.Checkpointing import save_checkpoint
from models.VAE import VAE
from modules.SaveOutputs import save_reconstructions, save_samples
from modules.Losses import VGGLoss, vae_loss

# --- Configuration ----------------------------------------------------------
dataset_dir: str = "./data/raw"
out_dir: str = "./outputs/"
batch_size: int = 64
latent_dim: int = 128
checkpoint_dir = "./experiments/checkpoints"


# Seed torch's global RNG and pick the GPU when one is available.
torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Preprocessing ----------------------------------------------------------
transform = transforms.Compose([
    transforms.Resize((64, 64)),  # resize every image to 64x64
    transforms.ToTensor()  # convert to tensor & scale to [0,1]
])

# --- Dataset and 50/50 train/test split -------------------------------------
# The path is a raw string: in a normal literal "\T" is an invalid escape
# sequence, which Python warns about and plans to reject.
# NOTE(review): `dataset_dir` above is never used; the images are read from this
# hard-coded absolute path instead. Confirm which location is intended.
dataset = datasets.ImageFolder(root=r"G:\Temp", transform=transform)
train_size = int(0.5 * len(dataset))
test_size = len(dataset) - train_size  # remainder, so the two sizes always sum to len(dataset)
# A dedicated, fixed-seed generator keeps the split identical across runs.
train_dataset, test_dataset = random_split(dataset, [train_size, test_size], generator=torch.Generator().manual_seed(42))

# --- Loaders ----------------------------------------------------------------
# Only the training loader shuffles; the test loader keeps a fixed order.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
test_loader  = DataLoader(test_dataset,  batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

In [2]:
# Run-control state and optimisation hyperparameters.
train_mode = True
current_epoch = 1  # the training loop starts counting epochs from this value
global_step = 0
epochs: int = 10
lr: float = 3e-4

print("Loaded datasets, number of samples: ", len(dataset))

# Build the VAE first, then move it to the selected device.
model = VAE(latent_dim=latent_dim)
model = model.to(device)

# Adam over every model parameter, plus the VGG-based loss helper.
optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)
vgg_loss_fn = VGGLoss(device=device)

# Make sure the output directory exists before anything is written to it.
os.makedirs(out_dir, exist_ok=True)

Loaded datasets, number of samples:  202599




In [3]:
# Training schedule for this run.
# NOTE: this replaces the `epochs` value assigned in the previous cell.
epochs = 50
n_train = len(train_loader.dataset)
n_test = len(test_loader.dataset)
# Initial loss weights, handed to vae_loss as its `beta` / `gamma` arguments.
# NOTE(review): presumably the KL weight and the perceptual weight — confirm
# against modules.Losses.vae_loss.
beta = 0.9
gamma = 0.0001

for epoch in range(current_epoch, current_epoch + epochs):
    # ---- training pass ----
    model.train()
    running_total = running_recon = running_kld = running_perceptual = 0.0
    # Each weight grows by 10% per epoch from its own previous value, capped at 5.
    # (gamma used to be recomputed from beta, which discarded its 0.0001 start.)
    beta = min(5, beta * 1.1)
    gamma = min(5, gamma * 1.1)

    for batch_idx, (x, _) in enumerate(tqdm(train_loader), start=1):
        x = x.to(device)
        optimizer.zero_grad(set_to_none=True)
        logits, mean, logvar = model(x)
        loss = vae_loss(logits, x, mean, logvar, beta, gamma, vgg_loss_fn)
        loss.total.backward()
        optimizer.step()
        running_total += loss.total.item()
        running_recon += loss.recon.item()
        running_kld += loss.kld.item()
        running_perceptual += loss.perceptual.item()

    # All four terms are reported per sample.
    # NOTE(review): this assumes vae_loss returns batch-summed terms, as the
    # existing normalisation of total/recon/kld implies — confirm.
    print(
        f"Epoch {epoch:02d} | total: {running_total / n_train:.4f} | "
        f"recon: {running_recon / n_train:.4f} | kld: {running_kld / n_train:.4f} | "
        f"perceptual: {running_perceptual / n_train:.4f}"
    )

    # Record progress so that re-running this cell continues with the next epoch.
    current_epoch = epoch + 1
    # Label the reconstructions with the epoch that was just trained, matching
    # the log line above and the checkpoint below. `x` is the last training batch.
    save_reconstructions(model, x, out_dir, epoch, device)

    # ---- validation pass ----
    model.eval()
    test_total = test_recon = test_kld = test_perceptual = 0.0
    with torch.no_grad():
        for x, _ in test_loader:
            x = x.to(device)
            logits, mean, logvar = model(x)
            loss = vae_loss(logits, x, mean, logvar, beta, gamma, vgg_loss_fn)
            test_total += loss.total.item()
            test_recon += loss.recon.item()
            test_kld += loss.kld.item()
            test_perceptual += loss.perceptual.item()
    print(
        f"  [val] total: {test_total / n_test:.4f} | recon: {test_recon / n_test:.4f} | "
        f"kld: {test_kld / n_test:.4f} | perceptual: {test_perceptual / n_test:.4f}"
    )

    # Checkpoint every tenth epoch.
    if epoch % 10 == 0:
        save_checkpoint(model, optimizer, epoch, checkpoint_dir)

  0%|          | 6/1583 [00:47<3:29:58,  7.99s/it] 


KeyboardInterrupt: 