In [1]:
# `torch` is used directly in this notebook (device selection), so import it
# explicitly instead of relying on it leaking through the wildcard imports below.
import torch

from src.utils.visualization import *
from src.pipeline.trainer import *
from src.data.cifar10 import *

from src.modules.gan_module import GANModule
from src.modules.vae_module import VAEModule
from src.modules.ae_module import AEModule
from src.modules.rnvp_module import RNVPModule

In [2]:
# Shared CIFAR-10 data module used by every training cell in this notebook:
# data is read from ./data, batches hold 64 samples, and val_split is 0.2.
cifar10_data_module = CIFAR10DataModule(
    data_dir='./data',
    batch_size=64,
    val_split=0.2,
)
cifar10_data_module.setup()

# Sanity check: report the number of samples behind each split's dataloader.
print(f"Training set size: {len(cifar10_data_module.train_dataloader().dataset)}")
print(f"Validation set size: {len(cifar10_data_module.val_dataloader().dataset)}")
print(f"Test set size: {len(cifar10_data_module.test_dataloader().dataset)}")

Files already downloaded and verified
Files already downloaded and verified
Training set size: 40000
Validation set size: 10000
Test set size: 10000


In [3]:
# Select the compute device handed to every Trainer below: the first CUDA GPU
# when torch reports one as available, otherwise the CPU.
# `torch` comes from the explicit import in the notebook's import cell rather
# than from a wildcard import.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)

cuda:0


# GAN

In [7]:
# Baseline GAN experiment: latent_dim=100.
gan_model = GANModule(latent_dim=100)

# Run-specific log/checkpoint directories keep this experiment separate from
# the other GAN runs under ./tensorboard.
gan_trainer = Trainer(
    model=gan_model,
    data_module=cifar10_data_module,
    device=device,
    log_dir='./tensorboard/gan_log_basic',
    checkpoint_dir='./tensorboard/gan_checkpoints_basic',
)

# Train for 100 epochs, then store a checkpoint with the 'last_' prefix
# (presumably the final-epoch state, alongside the best-model saves the
# trainer reports in its log — confirm against Trainer.save_checkpoint).
gan_trainer.train(num_epochs=100)
gan_trainer.save_checkpoint(prefix='last_')

Epoch 0:
	Train Metrics = {'g_loss': 139.6159892578125, 'd_loss': 69.41216985778809, 'loss': 209.02815911560057, 'ssim': 0.8050960312843323}
	Val Metrics = {'loss': 38.208515033600435, 'ssim': 0.7971386863927173}
	Best model saved at epoch 0 with loss=38.208515
Epoch 1:
	Train Metrics = {'g_loss': 97.57289946289063, 'd_loss': 75.98701738891602, 'loss': 173.55991685180663, 'ssim': 0.7870937895774841}
	Val Metrics = {'loss': 39.13666084921284, 'ssim': 0.7875900412820707}
Epoch 2:
	Train Metrics = {'g_loss': 78.22355698242187, 'd_loss': 79.82592861938477, 'loss': 158.04948560180665, 'ssim': 0.7810884092330933}
	Val Metrics = {'loss': 34.788787234361, 'ssim': 0.8004127429549102}
	Best model saved at epoch 2 with loss=34.788787
Epoch 3:
	Train Metrics = {'g_loss': 76.75967722167968, 'd_loss': 79.8086055419922, 'loss': 156.56828276367187, 'ssim': 0.7789904623985291}
	Val Metrics = {'loss': 45.89941875190492, 'ssim': 0.7725792937218003}
Epoch 4:
	Train Metrics = {'g_loss': 103.68159293212891,

In [8]:
# Small-latent GAN experiment: latent_dim=10.
gan_model_2 = GANModule(latent_dim=10)

# Separate log/checkpoint directories for the 'small_latent' run.
gan_trainer_2 = Trainer(
    model=gan_model_2,
    data_module=cifar10_data_module,
    device=device,
    log_dir='./tensorboard/gan_log_small_latent',
    checkpoint_dir='./tensorboard/gan_checkpoints_small_latent',
)

# Train for 100 epochs, then store a checkpoint with the 'last_' prefix.
gan_trainer_2.train(num_epochs=100)
gan_trainer_2.save_checkpoint(prefix='last_')

Epoch 0:
	Train Metrics = {'g_loss': 147.97749813232423, 'd_loss': 66.97502859420776, 'loss': 214.95252672653197, 'ssim': 0.8022816811561585}
	Val Metrics = {'loss': 41.699823039352516, 'ssim': 0.7979382276535034}
	Best model saved at epoch 0 with loss=41.699823
Epoch 1:
	Train Metrics = {'g_loss': 102.75185173950196, 'd_loss': 72.57394107666016, 'loss': 175.3257928161621, 'ssim': 0.7847237666130066}
	Val Metrics = {'loss': 36.337819737233936, 'ssim': 0.7878051996231079}
	Best model saved at epoch 1 with loss=36.337820
Epoch 2:
	Train Metrics = {'g_loss': 110.41684400634766, 'd_loss': 71.16767210083007, 'loss': 181.58451610717773, 'ssim': 0.7838585991859436}
	Val Metrics = {'loss': 36.41892650628545, 'ssim': 0.7816830888675277}
Epoch 3:
	Train Metrics = {'g_loss': 110.34383098754883, 'd_loss': 70.88035786132812, 'loss': 181.22418884887696, 'ssim': 0.781248490524292}
	Val Metrics = {'loss': 38.20579816733196, 'ssim': 0.7910138121835745}
Epoch 4:
	Train Metrics = {'g_loss': 123.090480551

In [9]:
# Huge-latent GAN experiment: latent_dim=2000.
gan_model_3 = GANModule(latent_dim=2000)

# Separate log/checkpoint directories for the 'huge_latent' run.
gan_trainer_3 = Trainer(
    model=gan_model_3,
    data_module=cifar10_data_module,
    device=device,
    log_dir='./tensorboard/gan_log_huge_latent',
    checkpoint_dir='./tensorboard/gan_checkpoints_huge_latent',
)

# Train for 100 epochs, then store a checkpoint with the 'last_' prefix.
gan_trainer_3.train(num_epochs=100)
gan_trainer_3.save_checkpoint(prefix='last_')

Epoch 0:
	Train Metrics = {'g_loss': 155.78549383544922, 'd_loss': 63.57507931709289, 'loss': 219.36057315254212, 'ssim': 0.8083353238105774}
	Val Metrics = {'loss': 46.679218923969636, 'ssim': 0.8029564281178129}
	Best model saved at epoch 0 with loss=46.679219
Epoch 1:
	Train Metrics = {'g_loss': 118.98868752441406, 'd_loss': 67.09509100952148, 'loss': 186.08377853393554, 'ssim': 0.7911824394226075}
	Val Metrics = {'loss': 43.433646232459196, 'ssim': 0.7902024673048857}
	Best model saved at epoch 1 with loss=43.433646
Epoch 2:
	Train Metrics = {'g_loss': 114.7934759338379, 'd_loss': 67.24471710815429, 'loss': 182.0381930419922, 'ssim': 0.7837855472564698}
	Val Metrics = {'loss': 36.50995565523767, 'ssim': 0.7792053272010414}
	Best model saved at epoch 2 with loss=36.509956
Epoch 3:
	Train Metrics = {'g_loss': 116.56244364013672, 'd_loss': 67.35256701965332, 'loss': 183.91501065979003, 'ssim': 0.7793386483192444}
	Val Metrics = {'loss': 39.090981374121014, 'ssim': 0.7779532503929867}


# VAE

In [6]:
# Baseline VAE experiment: latent_dim=128.
vae_model = VAEModule(latent_dim=128)

# Run-specific log/checkpoint directories keep this experiment separate from
# the other VAE runs under ./tensorboard.
vae_trainer = Trainer(
    model=vae_model,
    data_module=cifar10_data_module,
    device=device,
    log_dir='./tensorboard/vae_log_basic',
    checkpoint_dir='./tensorboard/vae_checkpoints_basic',
)

# Train for 100 epochs, then store a checkpoint with the 'last_' prefix.
vae_trainer.train(num_epochs=100)
vae_trainer.save_checkpoint(prefix='last_')

Epoch 0:
	Train Metrics = {'loss': 118703.7883625, 'bce': 118669.31875, 'kld': 344.6952050079346, 'ssim': 0.4732855936050415}
	Val Metrics = {'loss': 113673.0668789809, 'bce': 113621.87191480891, 'kld': 511.949738715105, 'ssim': 0.3763685621273745}
	Best model saved at epoch 0 with loss=113673.066879
Epoch 1:
	Train Metrics = {'loss': 114017.9762125, 'bce': 113966.237875, 'kld': 517.3831455078125, 'ssim': 0.3482531197547913}
	Val Metrics = {'loss': 112350.26787669188, 'bce': 112297.26375895701, 'kld': 530.039898137378, 'ssim': 0.3199806042537568}
	Best model saved at epoch 1 with loss=112350.267877
Epoch 2:
	Train Metrics = {'loss': 113134.681725, 'bce': 113083.49915, 'kld': 511.8260671875, 'ssim': 0.3071547957420349}
	Val Metrics = {'loss': 111780.80297074045, 'bce': 111728.98629080414, 'kld': 518.1662447984052, 'ssim': 0.28815733892902445}
	Best model saved at epoch 2 with loss=111780.802971
Epoch 3:
	Train Metrics = {'loss': 112642.3921125, 'bce': 112592.643525, 'kld': 497.485811914

In [6]:
# Small-latent VAE experiment: latent_dim=10.
vae_model_2 = VAEModule(latent_dim=10)

# Separate log/checkpoint directories for the 'small_latent' run.
vae_trainer_2 = Trainer(
    model=vae_model_2,
    data_module=cifar10_data_module,
    device=device,
    log_dir='./tensorboard/vae_log_small_latent',
    checkpoint_dir='./tensorboard/vae_checkpoints_small_latent',
)

# Train for 100 epochs, then store a checkpoint with the 'last_' prefix.
vae_trainer_2.train(num_epochs=100)
vae_trainer_2.save_checkpoint(prefix='last_')

Epoch 0:
	Train Metrics = {'loss': 120337.5838, 'bce': 120332.778325, 'kld': 48.05561430244446, 'ssim': 0.5378468700408936}
	Val Metrics = {'loss': 117343.99017217357, 'bce': 117337.63146894904, 'kld': 63.58372660351407, 'ssim': 0.5019091181694322}
	Best model saved at epoch 0 with loss=117343.990172
Epoch 1:
	Train Metrics = {'loss': 118217.2187, 'bce': 118210.76605, 'kld': 64.52699364624023, 'ssim': 0.5029779838562012}
	Val Metrics = {'loss': 117098.96327627389, 'bce': 117092.06827229299, 'kld': 68.94767885147387, 'ssim': 0.49791210510168865}
	Best model saved at epoch 1 with loss=117098.963276
Epoch 2:
	Train Metrics = {'loss': 118097.6263, 'bce': 118090.9581125, 'kld': 66.6818806640625, 'ssim': 0.4997215682029724}
	Val Metrics = {'loss': 116996.44449144109, 'bce': 116989.7347482086, 'kld': 67.09622753653557, 'ssim': 0.49648571925558105}
	Best model saved at epoch 2 with loss=116996.444491
Epoch 3:
	Train Metrics = {'loss': 117984.1016, 'bce': 117977.4138625, 'kld': 66.8792368041992

In [7]:
# Huge-latent VAE experiment: latent_dim=512.
vae_model_3 = VAEModule(latent_dim=512)

# Separate log/checkpoint directories for the 'huge_latent' run.
vae_trainer_3 = Trainer(
    model=vae_model_3,
    data_module=cifar10_data_module,
    device=device,
    log_dir='./tensorboard/vae_log_huge_latent',
    checkpoint_dir='./tensorboard/vae_checkpoints_huge_latent',
)

# Train for 100 epochs, then store a checkpoint with the 'last_' prefix.
vae_trainer_3.train(num_epochs=100)
vae_trainer_3.save_checkpoint(prefix='last_')

Epoch 0:
	Train Metrics = {'loss': 118780.613725, 'bce': 118706.9776375, 'kld': 736.3600170043945, 'ssim': 0.47853396272659304}
	Val Metrics = {'loss': 114043.71466958598, 'bce': 113926.70081110668, 'kld': 1170.139168320188, 'ssim': 0.38665815827193534}
	Best model saved at epoch 0 with loss=114043.714670
Epoch 1:
	Train Metrics = {'loss': 114245.189575, 'bce': 114134.8969625, 'kld': 1102.924287890625, 'ssim': 0.35692215671539307}
	Val Metrics = {'loss': 112624.62506220143, 'bce': 112508.59259305334, 'kld': 1160.323857981688, 'ssim': 0.3284646200526292}
	Best model saved at epoch 1 with loss=112624.625062
Epoch 2:
	Train Metrics = {'loss': 113281.37975, 'bce': 113168.7258625, 'kld': 1126.541613671875, 'ssim': 0.3144909646034241}
	Val Metrics = {'loss': 111945.26137042197, 'bce': 111831.44491441082, 'kld': 1138.165201563744, 'ssim': 0.29585559277018164}
	Best model saved at epoch 2 with loss=111945.261370
Epoch 3:
	Train Metrics = {'loss': 112652.2814875, 'bce': 112536.2331875, 'kld': 1

In [5]:
# Largest VAE experiment: latent_dim=2048.
vae_model_4 = VAEModule(latent_dim=2048)

# Separate log/checkpoint directories for the 'really_huge_latent' run.
vae_trainer_4 = Trainer(
    model=vae_model_4,
    data_module=cifar10_data_module,
    device=device,
    log_dir='./tensorboard/vae_log_really_huge_latent',
    checkpoint_dir='./tensorboard/vae_checkpoints_really_huge_latent',
)

# Train for 100 epochs, then store a checkpoint with the 'last_' prefix.
vae_trainer_4.train(num_epochs=100)
vae_trainer_4.save_checkpoint(prefix='last_')

Epoch 0:
	Train Metrics = {'loss': 119557.8518875, 'bce': 119448.7796375, 'kld': 1090.7241618652345, 'ssim': 0.4901458318710327}
	Val Metrics = {'loss': 114357.79968650478, 'bce': 114204.03508160828, 'kld': 1537.6429972071558, 'ssim': 0.39656165147283273}
	Best model saved at epoch 0 with loss=114357.799687
Epoch 1:
	Train Metrics = {'loss': 114650.3156, 'bce': 114484.685525, 'kld': 1656.3000234375, 'ssim': 0.37050615310668944}
	Val Metrics = {'loss': 113068.43708947054, 'bce': 112879.70839968153, 'kld': 1887.289088158091, 'ssim': 0.3426227683474304}
	Best model saved at epoch 1 with loss=113068.437089
Epoch 2:
	Train Metrics = {'loss': 113669.2433625, 'bce': 113490.9826, 'kld': 1782.6076212890625, 'ssim': 0.32906164026260376}
	Val Metrics = {'loss': 112314.8880249801, 'bce': 112120.03809215764, 'kld': 1948.4998289460589, 'ssim': 0.30971912289880643}
	Best model saved at epoch 2 with loss=112314.888025
Epoch 3:
	Train Metrics = {'loss': 113121.5462, 'bce': 112936.1541625, 'kld': 1853.9

# AE

In [4]:
# Baseline autoencoder experiment: learning rate 5e-4.
ae_model = AEModule(lr=5e-4)

# Run-specific log/checkpoint directories for the 'basic' AE run.
ae_trainer = Trainer(
    model=ae_model,
    data_module=cifar10_data_module,
    device=device,
    log_dir='./tensorboard/ae_logs_basic',
    checkpoint_dir='./tensorboard/ae_checkpoints_basic',
)

# Train for 50 epochs, then store a checkpoint with the 'last_' prefix.
# NOTE(review): the recorded output of this cell ends with
# "AttributeError: 'AEModule' object has no attribute 'on_epoch_end'".
# The missing hook has to be fixed in the project code before this cell can
# run cleanly; whether save_checkpoint was reached in that run is not visible
# from the output — verify.
ae_trainer.train(num_epochs=50)
ae_trainer.save_checkpoint(prefix='last_')

Epoch 0:
	Train Metrics = {'loss': 0.024891943296790123, 'ssim': 0.4784761930465698}
	Val Metrics = {'loss': 0.013246978330574219, 'ssim': 0.36298553427313546}
	Best model saved at epoch 0 with loss=0.013247
Epoch 1:
	Train Metrics = {'loss': 0.011396158915758133, 'ssim': 0.3179770886421204}
	Val Metrics = {'loss': 0.009880427603319192, 'ssim': 0.2870136396900104}
	Best model saved at epoch 1 with loss=0.009880
Epoch 2:
	Train Metrics = {'loss': 0.009109029860049486, 'ssim': 0.2653541066169739}
	Val Metrics = {'loss': 0.008371602773524014, 'ssim': 0.24915698189644297}
	Best model saved at epoch 2 with loss=0.008372
Epoch 3:
	Train Metrics = {'loss': 0.007951960050314665, 'ssim': 0.2361314368247986}
	Val Metrics = {'loss': 0.007509948353573775, 'ssim': 0.2253497918699957}
	Best model saved at epoch 3 with loss=0.007510
Epoch 4:
	Train Metrics = {'loss': 0.007177755065262317, 'ssim': 0.21593969898223878}
	Val Metrics = {'loss': 0.006892165344113567, 'ssim': 0.20836858954399254}
	Best mod

AttributeError: 'AEModule' object has no attribute 'on_epoch_end'

In [4]:
# Large-learning-rate autoencoder experiment: lr=1e-2.
ae_model_2 = AEModule(lr=1e-2)

# Separate log/checkpoint directories for the 'big_lr' run.
ae_trainer_2 = Trainer(
    model=ae_model_2,
    data_module=cifar10_data_module,
    device=device,
    log_dir='./tensorboard/ae_logs_big_lr',
    checkpoint_dir='./tensorboard/ae_checkpoints_big_lr',
)

# Train for 50 epochs, then store a checkpoint with the 'last_' prefix.
ae_trainer_2.train(num_epochs=50)
ae_trainer_2.save_checkpoint(prefix='last_')

Epoch 0:
	Train Metrics = {'loss': 0.019632446879148484, 'ssim': 0.4024359302520752}
	Val Metrics = {'loss': 0.010427090514356353, 'ssim': 0.278573648185487}
	Best model saved at epoch 0 with loss=0.010427
Epoch 1:
	Train Metrics = {'loss': 0.008251360622048377, 'ssim': 0.23319719724655152}
	Val Metrics = {'loss': 0.006743645948019757, 'ssim': 0.19921577584211994}
	Best model saved at epoch 1 with loss=0.006744
Epoch 2:
	Train Metrics = {'loss': 0.006218027845025062, 'ssim': 0.18556841764450074}
	Val Metrics = {'loss': 0.005859147448828266, 'ssim': 0.17603389510683193}
	Best model saved at epoch 2 with loss=0.005859
Epoch 3:
	Train Metrics = {'loss': 0.005809548465162516, 'ssim': 0.17265866222381593}
	Val Metrics = {'loss': 0.005387157464257566, 'ssim': 0.1638464510061179}
	Best model saved at epoch 3 with loss=0.005387
Epoch 4:
	Train Metrics = {'loss': 0.005272466291487217, 'ssim': 0.15838783226013184}
	Val Metrics = {'loss': 0.00529921293899322, 'ssim': 0.15503779224529388}
	Best mo

# RealNVP

In [4]:
# RealNVP experiment: learning rate 1e-3.
rnvp_model = RNVPModule(lr=1e-3)

# Run-specific log/checkpoint directories for the 'basic' RealNVP run.
rnvp_trainer = Trainer(
    model=rnvp_model,
    data_module=cifar10_data_module,
    device=device,
    log_dir='./tensorboard/rnvp_logs_basic',
    checkpoint_dir='./tensorboard/rnvp_checkpoints_basic',
)

# Train for 10 epochs, then store a checkpoint with the 'last_' prefix.
# NOTE(review): the recorded output shows only the epoch-0 train metrics and
# then "AttributeError: 'RealNVP' object has no attribute 'decode'".
# Presumably something in the training/validation path calls `decode` on the
# model — confirm and fix in the project code before re-running.
rnvp_trainer.train(num_epochs=10)
rnvp_trainer.save_checkpoint(prefix='last_')

Epoch 0:
	Train Metrics = {'loss': -4394.908418505859, 'ssim': 0.9570071745872497}


AttributeError: 'RealNVP' object has no attribute 'decode'