In [None]:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np

from src.autoencoder import Autoencoder, VariationalAutoEncoder
from utils.mnist_loader import data_download, data_loader
from utils.model_trainer import autoencoder_trainer, vae_trainer
from utils.visualization import visualization

In [None]:
# Run configuration: device selection, training length, batching and seeding.
USE_CUDA = torch.cuda.is_available()
DEVICE = torch.device("cuda:0" if USE_CUDA else "cpu")
EPOCHS = 50
BATCH_SIZE = 256
SEED = 42

# Seed the RNGs before any model is built or data is loaded, so weight
# initialisation (and presumably data shuffling / VAE latent sampling inside the
# project modules) repeats across Restart & Run All. torch.manual_seed also seeds
# CUDA; bit-exact GPU runs may additionally need cuDNN determinism settings.
torch.manual_seed(SEED)
np.random.seed(SEED)
print(DEVICE)

train_data, test_data = data_download()
train_loader, test_loader = data_loader(train_data, test_data, batch_size=BATCH_SIZE)

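# Autoencoder

A deterministic autoencoder trained to reconstruct MNIST digits with a mean-squared-error reconstruction loss.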

In [None]:
# Autoencoder model (n_hidden=256, z_dim=64), MSE reconstruction criterion,
# and an Adam optimizer with learning rate 1e-3.
ae = Autoencoder(n_hidden=256, z_dim=64)
ae = ae.to(DEVICE)
criteria = nn.MSELoss()
optimizer = torch.optim.Adam(params=ae.parameters(), lr=1e-3)

In [None]:
# Fit the autoencoder; the trainer hands back (presumably per-epoch) train and test loss histories.
train_loss, test_loss = autoencoder_trainer(
    model=ae,
    criteria=criteria,
    optimizer=optimizer,
    train_loader=train_loader,
    test_loader=test_loader,
    device=DEVICE,
    epochs=EPOCHS,
)

In [None]:
# Inspect a few test-set samples through the trained autoencoder.
visualization(
    model=ae,
    loader=test_loader,
    device=DEVICE,
    num_of_samples=5,
)

In [None]:
# Loss curves returned by `autoencoder_trainer`. Plotting train next to test makes
# over/under-fitting visible.
# NOTE(review): assumes both histories hold plain floats with one entry per epoch
# (np.array(test_loss) already worked in the original) — confirm in utils.model_trainer.
fig, ax = plt.subplots()
ax.plot(np.asarray(train_loss), label="train")
ax.plot(np.asarray(test_loss), label="test")
ax.set(title="Autoencoder loss", xlabel="epoch", ylabel="loss")
ax.legend()
plt.show()

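# Variational Autoencoder

A VAE trained on the same data. Its loss has two terms: a summed binary cross-entropy reconstruction term (a per-pixel Bernoulli likelihood) and the KL divergence between the approximate posterior $\mathcal{N}(\mu, \sigma^2)$ and the standard-normal prior:

$$D_{KL} = -\tfrac{1}{2}\sum\left(1 + \log\sigma^2 - \mu^2 - \sigma^2\right)$$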

In [None]:
def vae_loss(reconstruction, x, mu, log_var):
    """Return the two terms of the (negative) ELBO for a Gaussian-latent VAE.

    Parameters
    ----------
    reconstruction : torch.Tensor
        Decoder output, used as the *input* of binary cross-entropy, so its values
        must lie in [0, 1] (presumably sigmoid outputs — confirm in the model).
    x : torch.Tensor
        Target batch; must have the same size as ``reconstruction``.
    mu, log_var : torch.Tensor
        Mean and log-variance of the approximate posterior q(z|x).

    Returns
    -------
    tuple[torch.Tensor, torch.Tensor]
        ``(reconstruction_loss, kl_loss)``: two scalars, each summed over every
        element of the batch (no averaging).
    """
    # Bernoulli likelihood per element -> summed binary cross-entropy.
    bce = nn.functional.binary_cross_entropy(reconstruction, x, reduction='sum')
    # Closed-form KL(N(mu, sigma^2) || N(0, 1)) with sigma^2 = exp(log_var).
    variance = log_var.exp()
    kl = -0.5 * (1 + log_var - mu.pow(2) - variance).sum()
    return bce, kl

In [None]:
# VAE model and its optimizer; the trainer is handed `vae_loss` as its criterion.
# NOTE(review): n_hidden=336 differs from the autoencoder's 256 — confirm it is intentional.
vae = VariationalAutoEncoder(n_hidden=336, z_dim=128)
vae = vae.to(DEVICE)
criteria = vae_loss
optimizer = torch.optim.Adam(params=vae.parameters(), lr=1e-3)

In [None]:
# Fit the VAE; this rebinds train_loss / test_loss to the VAE's loss histories.
train_loss, test_loss = vae_trainer(
    model=vae,
    criteria=criteria,
    optimizer=optimizer,
    train_loader=train_loader,
    test_loader=test_loader,
    device=DEVICE,
    epochs=EPOCHS,
)

In [None]:
# Inspect a few test-set samples through the trained VAE.
visualization(
    model=vae,
    loader=test_loader,
    device=DEVICE,
    num_of_samples=5,
)

In [None]:
# Draw new digits from the trained VAE.
# `vae.cpu()` is kept from the original cell — presumably `generate` creates its
# latent noise on the CPU (TODO confirm in src.autoencoder). It leaves the model
# on the CPU for any later cell.
N_GENERATED = 5

vae.cpu()
vae.eval()  # disable train-only behaviour (dropout / batch-norm updates), if the model has any
with torch.no_grad():  # inference only: outputs carry no autograd graph, so they convert to numpy
    generated_samples = list(vae.generate(N_GENERATED))

# One row of panels instead of a separate figure per sample.
n_panels = len(generated_samples)
fig, axes = plt.subplots(1, n_panels, figsize=(2 * n_panels, 2), squeeze=False)
for ax, sample in zip(axes[0], generated_samples):
    ax.matshow(np.asarray(sample).reshape(28, 28))  # assumes 784 values per sample, as the original reshape did
    ax.axis("off")
fig.suptitle("VAE-generated samples")
plt.show()

In [None]:
# Loss curves returned by `vae_trainer` (reconstruction + KL terms from `vae_loss`,
# which sums rather than averages, hence the large scale).
# NOTE(review): assumes both histories hold plain floats with one entry per epoch
# (np.array(test_loss) already worked in the original) — confirm in utils.model_trainer.
fig, ax = plt.subplots()
ax.plot(np.asarray(train_loss), label="train")
ax.plot(np.asarray(test_loss), label="test")
ax.set(title="VAE loss", xlabel="epoch", ylabel="loss")
ax.legend()
plt.show()