In [1]:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np

from src.autoencoders import Autoencoder, VariationalAutoEncoder
from utils.mnist_loader import data_download, data_loader
from utils.model_trainer import autoencoder_trainer, vae_trainer
from utils.visualization import visualization

In [2]:
# Runtime configuration: device, training length, sampling count, batch size, RNG seed.
USE_CUDA = torch.cuda.is_available()
DEVICE = torch.device("cuda:0" if USE_CUDA else "cpu")
EPOCHS = 100       # NOTE: the recorded GPU run below took roughly 50 s per epoch
SAMPLES = 5        # images passed to visualization() and drawn by generate()
BATCH_SIZE = 256
SEED = 42

# Seed the RNGs the notebook may touch (weight init, any DataLoader shuffling,
# latent sampling) so a Restart & Run All reproduces the same numbers.
# torch.manual_seed also seeds all CUDA devices.
torch.manual_seed(SEED)
np.random.seed(SEED)

print(DEVICE)
train_data, test_data = data_download()
train_loader, test_loader = data_loader(train_data, test_data, batch_size=BATCH_SIZE)

cuda:0
number of training data :  60000
number of test data :  10000


# Autoencoder

In [3]:
# Reconstruction objective for the plain autoencoder: mean squared error.
criteria = nn.MSELoss()
# Plain autoencoder (n_hidden=336, z_dim=128), placed on the selected device.
ae = Autoencoder(z_dim=128, n_hidden=336).to(DEVICE)
optimizer = torch.optim.Adam(params=ae.parameters(), lr=1e-3)

In [4]:
# Train the autoencoder for EPOCHS epochs; the trainer returns the train and test
# loss histories (presumably one value per epoch, matching the log it prints).
train_loss, test_loss = autoencoder_trainer(
    model=ae,
    criteria=criteria,
    optimizer=optimizer,
    train_loader=train_loader,
    test_loader=test_loader,
    device=DEVICE,
    epochs=EPOCHS,
)

  1%|          | 1/100 [01:11<1:58:24, 71.76s/it]

epochs: 1 - Train loss: 0.05692167207598686 - Test loss: 0.06921812146902084


  2%|▏         | 2/100 [02:02<1:36:53, 59.32s/it]

epochs: 2 - Train loss: 0.04686934873461723 - Test loss: 0.05052662640810013


  3%|▎         | 3/100 [02:58<1:33:48, 58.03s/it]

epochs: 3 - Train loss: 0.03166787698864937 - Test loss: 0.026341402903199196


  4%|▍         | 4/100 [03:46<1:26:24, 54.01s/it]

epochs: 4 - Train loss: 0.026871783658862114 - Test loss: 0.026222048327326775


  5%|▌         | 5/100 [04:33<1:21:31, 51.49s/it]

epochs: 5 - Train loss: 0.022289780899882317 - Test loss: 0.023341067135334015


  6%|▌         | 6/100 [05:21<1:18:47, 50.29s/it]

epochs: 6 - Train loss: 0.020244594663381577 - Test loss: 0.01885257102549076


  7%|▋         | 7/100 [06:05<1:14:46, 48.24s/it]

epochs: 7 - Train loss: 0.016019558534026146 - Test loss: 0.013821464963257313


  8%|▊         | 8/100 [06:42<1:08:34, 44.72s/it]

epochs: 8 - Train loss: 0.014357063919305801 - Test loss: 0.016972238197922707


  9%|▉         | 9/100 [07:24<1:06:25, 43.79s/it]

epochs: 9 - Train loss: 0.01447738241404295 - Test loss: 0.010582073591649532


 10%|█         | 10/100 [08:03<1:03:28, 42.32s/it]

epochs: 10 - Train loss: 0.012674110941588879 - Test loss: 0.012179370038211346


 11%|█         | 11/100 [08:56<1:07:27, 45.48s/it]

epochs: 11 - Train loss: 0.011767672374844551 - Test loss: 0.009716476313769817


 12%|█▏        | 12/100 [09:39<1:05:33, 44.70s/it]

epochs: 12 - Train loss: 0.010552050545811653 - Test loss: 0.012292672879993916


 13%|█▎        | 13/100 [10:31<1:08:00, 46.90s/it]

epochs: 13 - Train loss: 0.010759622789919376 - Test loss: 0.00838910136371851


 14%|█▍        | 14/100 [11:23<1:09:21, 48.38s/it]

epochs: 14 - Train loss: 0.010882385075092316 - Test loss: 0.010398639366030693


 15%|█▌        | 15/100 [12:12<1:09:12, 48.85s/it]

epochs: 15 - Train loss: 0.00919578317552805 - Test loss: 0.009261674247682095


 16%|█▌        | 16/100 [12:59<1:07:22, 48.12s/it]

epochs: 16 - Train loss: 0.009178899228572845 - Test loss: 0.008618634194135666


 17%|█▋        | 17/100 [13:47<1:06:42, 48.22s/it]

epochs: 17 - Train loss: 0.009064820595085621 - Test loss: 0.00868864543735981


 18%|█▊        | 18/100 [14:38<1:07:02, 49.06s/it]

epochs: 18 - Train loss: 0.0086137093603611 - Test loss: 0.008065783418715


 19%|█▉        | 19/100 [15:30<1:07:28, 49.98s/it]

epochs: 19 - Train loss: 0.009300037287175655 - Test loss: 0.006804963573813438


 20%|██        | 20/100 [16:15<1:04:25, 48.31s/it]

epochs: 20 - Train loss: 0.008688893169164658 - Test loss: 0.010581860318779945


 21%|██        | 21/100 [17:03<1:03:27, 48.19s/it]

epochs: 21 - Train loss: 0.00768634956330061 - Test loss: 0.0067733703181147575


 22%|██▏       | 22/100 [17:50<1:02:08, 47.80s/it]

epochs: 22 - Train loss: 0.007566662039607763 - Test loss: 0.007619213312864304


 23%|██▎       | 23/100 [18:46<1:04:43, 50.43s/it]

epochs: 23 - Train loss: 0.007941557094454765 - Test loss: 0.006349925417453051


 24%|██▍       | 24/100 [19:41<1:05:42, 51.87s/it]

epochs: 24 - Train loss: 0.007659104187041521 - Test loss: 0.00927924644201994


 25%|██▌       | 25/100 [20:32<1:04:28, 51.58s/it]

epochs: 25 - Train loss: 0.007254360243678093 - Test loss: 0.007969583384692669


In [None]:
# Visual check of the trained autoencoder on SAMPLES items drawn via test_loader.
visualization(
    model=ae,
    loader=test_loader,
    device=DEVICE,
    num_of_samples=SAMPLES,
)

In [None]:
# Learning curves for the plain autoencoder.
# NOTE(review): train_loss / test_loss are rebound by every training cell in this
# notebook; re-run the autoencoder training cell first if another model was trained since.
# Assumes one loss entry per epoch, as in the trainer's printed log.
fig, ax = plt.subplots()
ax.plot(np.array(train_loss), label="train")
ax.plot(np.array(test_loss), label="test")
ax.set(title="Autoencoder loss (MSE)", xlabel="epoch", ylabel="loss")
ax.legend()
plt.show()

# Variational Autoencoder

In [None]:
def vae_loss(reconstruction, x, mu, log_var):
    """Return the two VAE loss terms, unweighted, as a tuple.

    Args:
        reconstruction: decoder output; binary_cross_entropy expects values in [0, 1].
        x: target tensor with the same shape as ``reconstruction``.
        mu: mean of the approximate posterior.
        log_var: log-variance of the approximate posterior (same shape as ``mu``).

    Returns:
        (reconstruction_loss, kl_loss), both scalar tensors summed over every element
        (batch and feature dimensions alike). Combining them is up to the caller.
    """
    # Negative log-likelihood under a Bernoulli pixel model.
    recon_term = nn.functional.binary_cross_entropy(reconstruction, x, reduction='sum')
    # Closed-form KL( N(mu, exp(log_var)) || N(0, I) ).
    kl_term = -0.5 * (1 + log_var - mu ** 2 - torch.exp(log_var)).sum()
    return recon_term, kl_term

In [None]:
# VAE with the same n_hidden / z_dim as the autoencoder above, on the selected device.
vae = VariationalAutoEncoder(z_dim=128, n_hidden=336).to(DEVICE)
optimizer = torch.optim.Adam(params=vae.parameters(), lr=1e-3)
# vae_loss returns (reconstruction_loss, kl_loss); the trainer combines them
# (presumably recon + beta * KL, given its `beta` argument).
criteria = vae_loss

In [None]:
# Train the standard VAE (beta = 1); returns train and test loss histories.
train_loss, test_loss = vae_trainer(
    model=vae,
    beta=1,
    criteria=criteria,
    optimizer=optimizer,
    train_loader=train_loader,
    test_loader=test_loader,
    device=DEVICE,
    epochs=EPOCHS,
)

In [None]:
# Visual check of the trained VAE on SAMPLES items drawn via test_loader.
# The generation cell below moves `vae` to the CPU; put it back on DEVICE first so
# this cell still works when re-run afterwards (visualization is given device=DEVICE).
vae.to(DEVICE)
visualization(loader=test_loader, model=vae, device=DEVICE, num_of_samples=SAMPLES)

In [None]:
# Draw SAMPLES new images from the trained VAE and plot each one as a 28x28 matrix.
# generate() is called with the model on the CPU, as before; afterwards the model's
# device and train/eval mode are restored so earlier cells can be re-run safely.
vae.cpu()
was_training = vae.training
vae.eval()  # disable dropout / batch-norm updates, if the model has any
try:
    # Inference only: no autograd graph, so the samples can be handed to matplotlib.
    with torch.no_grad():
        generated_samples = vae.generate(SAMPLES)
finally:
    vae.train(was_training)
    vae.to(DEVICE)

for sample in generated_samples:
    plt.matshow(sample.reshape(28, 28))
    plt.show()

In [None]:
# Learning curves for the standard VAE (beta = 1).
# NOTE(review): train_loss / test_loss are rebound by every training cell in this
# notebook; re-run the VAE training cell first if another model was trained since.
# Assumes one loss entry per epoch, as in the trainer's printed log.
fig, ax = plt.subplots()
ax.plot(np.array(train_loss), label="train")
ax.plot(np.array(test_loss), label="test")
ax.set(title="VAE loss (beta=1)", xlabel="epoch", ylabel="loss")
ax.legend()
plt.show()

# Beta-Variational Autoencoder

In [None]:
# Fresh VAE for the beta-VAE experiment: same n_hidden / z_dim, independent weights.
b_vae = VariationalAutoEncoder(z_dim=128, n_hidden=336).to(DEVICE)
optimizer = torch.optim.Adam(params=b_vae.parameters(), lr=1e-3)
# Same unweighted (reconstruction_loss, kl_loss) pair; beta is applied by the trainer.
criteria = vae_loss

In [None]:
# Train the beta-VAE (beta = 4); returns train and test loss histories.
train_loss, test_loss = vae_trainer(
    model=b_vae,
    beta=4,
    criteria=criteria,
    optimizer=optimizer,
    train_loader=train_loader,
    test_loader=test_loader,
    device=DEVICE,
    epochs=EPOCHS,
)

In [None]:
# Visual check of the trained beta-VAE on SAMPLES items drawn via test_loader.
# The generation cell below moves `b_vae` to the CPU; put it back on DEVICE first so
# this cell still works when re-run afterwards (visualization is given device=DEVICE).
b_vae.to(DEVICE)
visualization(loader=test_loader, model=b_vae, device=DEVICE, num_of_samples=SAMPLES)

In [None]:
# Draw SAMPLES new images from the trained beta-VAE and plot each one as a 28x28 matrix.
# generate() is called with the model on the CPU, as before; afterwards the model's
# device and train/eval mode are restored so earlier cells can be re-run safely.
b_vae.cpu()
was_training = b_vae.training
b_vae.eval()  # disable dropout / batch-norm updates, if the model has any
try:
    # Inference only: no autograd graph, so the samples can be handed to matplotlib.
    with torch.no_grad():
        generated_samples = b_vae.generate(SAMPLES)
finally:
    b_vae.train(was_training)
    b_vae.to(DEVICE)

for sample in generated_samples:
    plt.matshow(sample.reshape(28, 28))
    plt.show()

In [None]:
# Learning curves for the beta-VAE (beta = 4).
# NOTE(review): train_loss / test_loss are rebound by every training cell in this
# notebook; re-run the beta-VAE training cell first if another model was trained since.
# Assumes one loss entry per epoch, as in the trainer's printed log.
fig, ax = plt.subplots()
ax.plot(np.array(train_loss), label="train")
ax.plot(np.array(test_loss), label="test")
ax.set(title="beta-VAE loss (beta=4)", xlabel="epoch", ylabel="loss")
ax.legend()
plt.show()