In [1]:
import sys
from pathlib import Path

# The project root is taken to be two directory levels above the current
# working directory; results_dir is the `results` folder directly under it.
project_root = Path.cwd().parent.parent
results_dir = project_root.joinpath('results')

# Prepend the project root to the import search path so it takes precedence
# over every other entry when project-local modules are imported.
sys.path.insert(0, str(project_root))

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from tqdm import tqdm

from utils.data_loaders import load_mnist
from utils.visualization import plot_training_history, visualize_predictions, visualize_batch

from vanilla_vae import VanillaVAE, VanillaVAEConfig

In [3]:
class TrainingCFG:
    """Settings for the training run, exposed as plain class attributes."""

    # Upper bound of the epoch loop.
    epochs = 100
    # Both values are forwarded to load_mnist when the loaders are built.
    batch_size = 64
    val_split = 0.1

    # Reuse the device declared on VanillaVAEConfig; each batch is moved to it
    # before the training step.
    device = VanillaVAEConfig.device


# Echo the selected device once the configuration is defined.
print(TrainingCFG.device)

cuda


In [4]:
# Build the model from a default-constructed configuration object.
vae_config = VanillaVAEConfig()
vanilla_vae = VanillaVAE(vae_config)

In [5]:
# load_mnist is called with the configured batch size and validation split;
# its result is unpacked into three objects named train / val / test.
mnist_splits = load_mnist(TrainingCFG.batch_size, TrainingCFG.val_split)
train, val, test = mnist_splits

In [None]:
# Training loop: iterate over `train` once per epoch, and every SAMPLE_EVERY
# epochs render a grid of generated samples to results_dir.
SAMPLE_EVERY = 10     # epochs between generated-sample grids
SAMPLE_NROW = 4       # `nrow` argument forwarded to visualize_batch
SHOW_SAMPLES = True   # `show` argument forwarded to visualize_batch

# Nothing visible earlier in the notebook creates results_dir. Create it up
# front so that, in case visualize_batch does not create it itself, the first
# save does not fail only after SAMPLE_EVERY epochs of training.
results_dir.mkdir(parents=True, exist_ok=True)

for epoch in range(TrainingCFG.epochs):
    pbar = tqdm(train)
    for data, _ in pbar:  # second element of each batch is unused here
        data = data.to(TrainingCFG.device)

        loss = vanilla_vae.train_step(data)
        # The bar description reports the most recent batch's loss only.
        pbar.set_description(f"Epoch {epoch+1}/{TrainingCFG.epochs}: loss {loss:.4f}")

    if (epoch + 1) % SAMPLE_EVERY == 0:
        # The samples are only visualized, never back-propagated through, so
        # autograd tracking is disabled while they are produced.
        with torch.no_grad():
            generated_image = vanilla_vae.inference()

        visualize_batch(
            generated_image,
            nrow=SAMPLE_NROW,
            save_path=results_dir / f"vae_{epoch+1}.png",
            show=SHOW_SAMPLES,
        )

Epoch 1/100: loss 8778.8848: 100%|██████████| 844/844 [00:11<00:00, 74.64it/s] 
Epoch 2/100: loss 8028.6484: 100%|██████████| 844/844 [00:12<00:00, 66.73it/s] 
Epoch 3/100: loss 6989.2183: 100%|██████████| 844/844 [00:12<00:00, 66.09it/s] 
Epoch 4/100: loss 7234.0034: 100%|██████████| 844/844 [00:10<00:00, 77.54it/s] 
Epoch 5/100: loss 7446.3926: 100%|██████████| 844/844 [00:10<00:00, 77.70it/s] 
Epoch 6/100: loss 6477.6318: 100%|██████████| 844/844 [00:10<00:00, 76.85it/s] 
Epoch 7/100: loss 6826.9253: 100%|██████████| 844/844 [00:10<00:00, 77.18it/s] 
Epoch 8/100: loss 7011.7676: 100%|██████████| 844/844 [00:11<00:00, 76.15it/s] 
Epoch 9/100: loss 6534.4214: 100%|██████████| 844/844 [00:11<00:00, 75.95it/s]
Epoch 10/100: loss 6571.4834: 100%|██████████| 844/844 [00:13<00:00, 61.01it/s]
Epoch 11/100: loss 6522.9238: 100%|██████████| 844/844 [00:10<00:00, 77.02it/s]
Epoch 12/100: loss 6876.7510: 100%|██████████| 844/844 [00:10<00:00, 76.85it/s]
Epoch 13/100: loss 6488.6382: 100%|██████