In [2]:
import numpy as np #importing packages
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt

# mnist dataset
from torchvision import datasets
from torchvision.transforms import ToTensor

from dataset_tools import MNIST
from plotting import loss_plots

from collections import defaultdict


# Reproducibility: seed torch and numpy so weight initialisation and
# DataLoader shuffling are repeatable across Restart & Run All.
SEED = 42
torch.manual_seed(SEED)  # seeds the CPU and all CUDA generators
np.random.seed(SEED)

# GPU flag, reused below for pin_memory and for moving the model onto the GPU.
cuda = torch.cuda.is_available()
print(cuda)

False


# Data

In [4]:
# Note: unlike a fully-connected VAE, we don't flatten the images here;
# each batch keeps its (batch, 1, 28, 28) image layout.
data_train, data_test = (
    MNIST(train=is_train, transform=ToTensor()) for is_train in (True, False)
)

# Settings shared by both loaders; host memory is pinned only when a GPU is available.
loader_kwargs = dict(batch_size=64, pin_memory=cuda)
train_loader = DataLoader(data_train, shuffle=True, **loader_kwargs)
test_loader = DataLoader(data_test, shuffle=False, **loader_kwargs)


In [5]:
# Sanity-check one batch. Images must stay 4-D (batch, channel, height, width)
# because they are not flattened for the convolutional model, and there must be
# exactly one label per image. Printing alone would not catch a regression.
x, y = next(iter(train_loader))
print(x.shape, y.shape)
assert x.shape[1:] == (1, 28, 28), f"unexpected image shape {tuple(x.shape)}"
assert y.shape == (x.shape[0],), f"expected one label per image, got {tuple(y.shape)}"

torch.Size([64, 1, 28, 28]) torch.Size([64])


# VAE

In [10]:

# input img -> hidden -> mu, logvar -> reparameterization trick (sample point from distribution made from mu, logvar) -> decoder -> output img
from models import VAE_CNN



In [11]:
# Convolutional VAE with a 2-dimensional latent space.
model = VAE_CNN(input_dim=28 * 28, hidden_dim=256, latent_dim=2)
print(model)

# Move the model to the GPU (when present) *before* building the optimizer,
# as the torch.optim docs recommend.
model = model.cuda() if cuda else model
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

def loss_function(x, x_hat, mu, sigma):
    """VAE training loss: summed MSE reconstruction error plus KL divergence.

    Args:
        x: input batch (the reconstruction target).
        x_hat: decoder output; should have the same shape as ``x``.
        mu: mean of the approximate posterior q(z|x).
        sigma: standard deviation of q(z|x). It may be signed if the encoder
            does not constrain it; only |sigma| affects the loss, exactly as in
            the log(sigma^2) formulation.
            NOTE(review): this takes the standard deviation, not the
            log-variance. Confirm the caller passes sigma (the notebook's
            pipeline comment mentions logvar).

    Returns:
        Scalar tensor: ``sum((x_hat - x)^2) + KL(N(mu, sigma^2) || N(0, 1))``,
        with both terms summed (not averaged) over all elements and the batch.
    """
    # Reconstruction term: squared error summed over pixels and batch.
    reconstruction = F.mse_loss(x_hat, x, reduction='sum')
    # log(sigma^2) written as 2*log|sigma|. The two are identical for sigma != 0,
    # but this form avoids sigma**2 underflowing to 0 (log(0) = -inf, giving an
    # infinite loss and NaN gradients) when |sigma| is small, especially in fp16/bf16.
    log_var = 2 * torch.log(sigma.abs())
    # Closed-form KL divergence between N(mu, sigma^2) and the standard normal prior.
    kld = -0.5 * torch.sum(1 + log_var - mu.pow(2) - sigma.pow(2))
    return reconstruction + kld