## Train VAE and AE on MNIST

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision.datasets import MNIST,CelebA
from torch.utils.data import DataLoader
from torch.autograd import Variable
from torchvision import transforms
import os
import timeit
import numpy as np
from AE import *
import torch.nn.functional as F
from tqdm.notebook import tqdm
import operator

In [None]:
path1 = %pwd
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
transform_MNIST = transforms.Compose([
    transforms.ToTensor(),
    transforms.ToPILImage(),
    transforms.Pad(2),
    transforms.ToTensor(),
])
transform_CelebA = transforms.Compose([
    transforms.CenterCrop(140),
    transforms.Scale((64,64)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
dataset = MNIST(path1,train=True, transform=transform_MNIST, download=False)


In [None]:
def get_data_loader(dataset, batch_size, cuda=False):
    """Build a shuffling ``DataLoader`` over *dataset*.

    Args:
        dataset: Dataset (or any object ``DataLoader`` accepts) to batch.
        batch_size: Number of samples per batch.
        cuda: When true, the loader additionally uses 10 worker processes
            and pinned memory; otherwise the ``DataLoader`` defaults apply.

    Returns:
        The configured ``DataLoader``.
    """
    loader_options = {'batch_size': batch_size, 'shuffle': True}
    if cuda:
        # Extra settings are only requested on the CUDA path.
        loader_options['num_workers'] = 10
        loader_options['pin_memory'] = True
    return DataLoader(dataset, **loader_options)

In [None]:
def train_AE(model, dataset, epochs=200, 
         batch_size=128, img_size=32,
         lr=1e-3, weight_decay=1e-3,
         device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
         resume_path=None, epoch_start=0):
    """Train an autoencoder with a summed binary cross-entropy reconstruction loss.

    Each batch is pushed through ``model`` and the BCE between the
    reconstruction and the input, summed over all elements and divided by the
    batch size, is minimised with Adam. A checkpoint of the state dict is
    written to ``./ae_MNIST_{epoch}.pth`` after every even epoch index > 0.

    Args:
        model: Module whose forward pass returns the reconstruction of its
            input. BCELoss requires both reconstruction and input in [0, 1].
        dataset: Dataset yielding ``(image, label)`` pairs; labels are ignored.
        epochs: Last epoch index to run. The loop is inclusive, i.e. it runs
            indices ``epoch_start .. epochs``, so the checkpoint named after
            ``epochs`` is written when ``epochs`` is even.
        batch_size: Mini-batch size.
        img_size: Unused; kept for interface compatibility.
        lr: Adam learning rate.
        weight_decay: Adam weight decay (L2 penalty).
        device: Device to train on. The default is evaluated once, when the
            function is defined.
        resume_path: Optional path to a saved state dict that is loaded into
            ``model`` before training continues.
        epoch_start: Epoch index to start counting from (use together with
            ``resume_path`` to continue an interrupted run).
    """
    device = torch.device(device)

    if resume_path is not None:
        # Load weights into the caller's model *before* the optimizer is built,
        # so the optimizer tracks the parameters that are actually trained.
        model.load_state_dict(torch.load(resume_path, map_location=device))
        print('model loaded')

    model = model.to(device)
    # reduction='sum' is the non-deprecated spelling of size_average=False.
    criterion = torch.nn.BCELoss(reduction='sum')
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

    # Build the loader once; it reshuffles on every pass. Worker processes and
    # pinned memory are only requested when training on a CUDA device.
    data_loader = get_data_loader(dataset, batch_size=batch_size,
                                  cuda=(device.type == 'cuda'))

    for epoch in range(epoch_start, epochs + 1):
        model.train()
        print("Epoch {}".format(epoch))
        # Iterating the loader directly lets tqdm display the batch total.
        for x, _ in tqdm(data_loader):
            x = x.to(device)
            x_reconstructed = model(x)
            # Summed BCE divided by batch size: per-sample reconstruction loss.
            loss = criterion(x_reconstructed, x) / x.size(0)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        # Reports the loss of the last batch of the epoch only.
        print("epoch : {}/{}, recon loss = {:.8f}".format(epoch + 1, epochs, loss.item()))
        if epoch > 0 and epoch % 2 == 0:
            torch.save(model.state_dict(), './ae_MNIST_{}.pth'.format(epoch))

In [None]:
def train_VAE(model, dataset, epochs=200, 
         batch_size=128, img_size=32,
         lr=1e-3, weight_decay=1e-3,
         device=torch.device("cuda" if torch.cuda.is_available() else "cpu")):
    """Train a variational autoencoder on reconstruction + KL loss.

    For each batch the model returns ``((mean, logvar), x_reconstructed)``.
    The objective is the BCE between reconstruction and input (summed over all
    elements, divided by the batch size) plus
    ``model.kl_divergence_loss(mean, logvar)``, minimised with Adam. A
    checkpoint of the state dict is written to ``./vae_MNIST_{epoch}.pth``
    after every even epoch index > 0.

    Args:
        model: Module with the forward signature described above and a
            ``kl_divergence_loss(mean, logvar)`` method. BCELoss requires
            both reconstruction and input in [0, 1].
        dataset: Dataset yielding ``(image, label)`` pairs; labels are ignored.
        epochs: Number of epochs to run (indices ``0 .. epochs - 1``).
        batch_size: Mini-batch size.
        img_size: Unused; kept for interface compatibility.
        lr: Adam learning rate.
        weight_decay: Adam weight decay (L2 penalty).
        device: Device to train on. The default is evaluated once, when the
            function is defined.
    """
    device = torch.device(device)
    model = model.to(device)
    # reduction='sum' is the non-deprecated spelling of size_average=False.
    criterion = torch.nn.BCELoss(reduction='sum')
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

    # Build the loader once; it reshuffles on every pass. Worker processes and
    # pinned memory are only requested when training on a CUDA device.
    data_loader = get_data_loader(dataset, batch_size=batch_size,
                                  cuda=(device.type == 'cuda'))

    for epoch in range(epochs):
        model.train()
        print("Epoch {}".format(epoch))
        # Iterating the loader directly lets tqdm display the batch total.
        for x, _ in tqdm(data_loader):
            x = x.to(device)
            (mean, logvar), x_reconstructed = model(x)
            penalization_loss = model.kl_divergence_loss(mean, logvar)
            reconstruction_loss = criterion(x_reconstructed, x) / x.size(0)
            loss = reconstruction_loss + penalization_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        # Values are those of the last batch of the epoch. The total and the
        # reconstruction term are reported separately; previously the total
        # was printed under the label "recon loss".
        print("epoch : {}/{}, loss = {:.8f}, recon loss = {:.8f}".format(
            epoch + 1, epochs, loss.item(), reconstruction_loss.item()))
        if epoch > 0 and epoch % 2 == 0:
            torch.save(model.state_dict(), './vae_MNIST_{}.pth'.format(epoch))
        

#### Train VAE

In [None]:
# Build the VAE to be trained. image_size=32 matches the padded MNIST images;
# channel_num=1 and z_size=10 are presumably the input channel count and the
# latent dimensionality - confirm against the VAE_MNIST definition.
model = VAE_MNIST(image_size=32, channel_num=1,kernel_num=128, z_size=10)

In [None]:
# Train the VAE for `epochs` epochs and report the elapsed time as measured
# by timeit.default_timer() (seconds).
epochs = 200
start = timeit.default_timer()
train_VAE(model, dataset, epochs)
end = timeit.default_timer()
print("Runtime: ",end-start)

#### Train AE

In [None]:
# Build the plain autoencoder; this rebinds `model`, replacing the VAE above.
# channel_num=1 and z_size=10 are presumably the input channel count and the
# latent dimensionality - confirm against the AE_MNIST definition.
model = AE_MNIST(image_size=32, channel_num=1,kernel_num=128, z_size=10)

In [None]:
# Train the autoencoder and report the elapsed time as measured by
# timeit.default_timer() (seconds). train_AE's epoch loop is inclusive of
# `epochs`, so it runs epoch indices 0..200 here.
epochs = 200
start = timeit.default_timer()
train_AE(model, dataset, epochs)
end = timeit.default_timer()
print("Runtime: ",end-start)