In [2]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.nn.functional as F

import torchvision
import torchvision.datasets as datasets
from torchvision.transforms import ToTensor
from torchvision.utils import save_image

In [52]:
# Select the compute device: prefer CUDA, then Apple MPS, and fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device("cpu")
device

device(type='mps')

In [9]:
## Load Data
# MNIST train and test splits, downloaded into datasets/ on first run.
# ToTensor converts each image to a float tensor with values in [0, 1].
train_data = datasets.MNIST(
    root='datasets/',
    train=True,
    download=True,
    transform=ToTensor(),
)
test_data = datasets.MNIST(
    root='datasets/',
    train=False,
    download=True,
    transform=ToTensor(),
)

        

In [34]:
# Hyperparameters
# Seed torch's global RNG (torch.manual_seed seeds all devices) so weight init,
# DataLoader shuffling and the VAE's sampling noise repeat across Restart & Run All.
seed = 42
torch.manual_seed(seed)

batch_size = 128
epochs = 2

image_sz = 28*28   # flattened 28x28 MNIST image
hidden_sz = 400    # width of the encoder/decoder hidden layer
latent_sz = 20     # dimensionality of the latent code



In [35]:
# Shuffle only the training split. Evaluation results do not depend on order, and a fixed
# test order keeps any per-batch test outputs (e.g. saved reconstructions) comparable
# between epochs and runs.
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=4)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=4)

In [36]:
# Create directory to save the reconstructed and sampled images
import os

out_directory = 'result'
# exist_ok=True makes the cell idempotent and avoids the exists()/makedirs() race.
# It still raises if 'result' exists as a regular file, surfacing that problem here
# rather than later when images are written into it.
os.makedirs(out_directory, exist_ok=True)

In [45]:
# VAE Model
class VAE(nn.Module):
    """Fully connected variational autoencoder.

    Encoder: image_size -> hidden_size (ReLU) -> (mu, log_var), each latent_size wide.
    Decoder: latent_size -> hidden_size (ReLU) -> image_size (sigmoid), so outputs lie in [0, 1].

    Args:
        image_size: flattened input size. Defaults to the notebook global ``image_sz``.
        hidden_size: hidden-layer width. Defaults to the notebook global ``hidden_sz``.
        latent_size: latent-code width. Defaults to the notebook global ``latent_sz``.

    Defaults are resolved when the model is constructed, so ``VAE()`` uses the values
    from the hyperparameter cell. Explicit arguments allow other sizes.
    """

    def __init__(self, image_size=None, hidden_size=None, latent_size=None):
        super().__init__()

        # Fall back to the hyperparameter-cell globals (the original referenced the
        # undefined name `image_size`, which only worked via leftover kernel state).
        if image_size is None:
            image_size = image_sz
        if hidden_size is None:
            hidden_size = hidden_sz
        if latent_size is None:
            latent_size = latent_sz
        self.image_size = image_size

        self.fc1 = nn.Linear(image_size, hidden_size)
        self.fc2_mean = nn.Linear(hidden_size, latent_size)
        self.fc2_logvar = nn.Linear(hidden_size, latent_size)

        self.fc3 = nn.Linear(latent_size, hidden_size)
        self.fc4 = nn.Linear(hidden_size, image_size)

    def encode(self, x):
        """Map flattened inputs (batch, image_size) to (mu, log_var), each (batch, latent_size)."""
        h = F.relu(self.fc1(x))
        mu = self.fc2_mean(h)
        log_var = self.fc2_logvar(h)
        return mu, log_var

    def parameterize(self, mu, log_var):
        """Reparameterization trick: return z = mu + std * eps with eps ~ N(0, I).

        ``log_var`` is treated as log(std^2), so std = exp(log_var / 2). Drawing eps
        separately keeps z differentiable with respect to mu and log_var.
        """
        std = torch.exp(log_var / 2)
        eps = torch.randn_like(std)
        return mu + std * eps

    def decode(self, z):
        """Map latent codes (batch, latent_size) to outputs (batch, image_size) in [0, 1]."""
        h = F.relu(self.fc3(z))
        # torch.sigmoid replaces the deprecated F.sigmoid.
        return torch.sigmoid(self.fc4(h))

    def forward(self, x):
        """Encode, sample a latent code, and decode.

        Args:
            x: tensor with ``image_size`` elements per sample, e.g. (batch, 1, 28, 28).

        Returns:
            (reconstructed, mu, log_var) with shapes (batch, image_size),
            (batch, latent_size) and (batch, latent_size).
        """
        # reshape (unlike view) also accepts non-contiguous inputs, e.g. transposed batches.
        x = x.reshape(-1, self.image_size)
        mu, log_var = self.encode(x)
        z = self.parameterize(mu, log_var)
        reconstructed = self.decode(z)
        return reconstructed, mu, log_var
        
# Build the VAE, then move its parameters to the selected device
# (nn.Module.to modifies the module in place and returns it).
model = VAE()
model = model.to(device)


In [51]:
## Check output size
# Smoke test on random input. no_grad skips building an autograd graph we never use,
# and the assertion ties the expected shape to the hyperparameter cell.
with torch.no_grad():
    x = torch.rand(batch_size, 1, 28, 28, device=device)
    reconstructed = model(x)[0]
assert reconstructed.shape == (batch_size, image_sz)
reconstructed.shape

torch.Size([128, 784])

In [None]:
# Define loss 
