In [3]:
#Cell to create the VAE functions

#Import necessary packages
import torch
import torch.nn.functional as F
from torch import nn

#intialize the VAE class

class VAE(nn.Module):
    def __init__(self, input_dim, hidden_dim_fin = 200, z_dim = 20):
        super().__init__()
        
        self.bn1 = nn.BatchNorm1d(400)
        self.bn2 = nn.BatchNorm1d(200)
        
        #encoder
        self.img_to_hid1 = nn.Linear(input_dim, 400)
        self.hid1_to_hid2 = nn.Linear(400, hidden_dim_fin)
        self.hid2_to_mu = nn.Linear(hidden_dim_fin, z_dim)
        self.hid2_to_sigma = nn.Linear(hidden_dim_fin, z_dim)
        
        
        #decoder
        self.z_to_hid2 = nn.Linear(z_dim, hidden_dim_fin)
        self.hid2_to_hid1 = nn.Linear(hidden_dim_fin, 400)
        self.hid1_to_img = nn.Linear(400, input_dim)
        
        self.relu = nn.ReLU()
        
 #Steps: Input -> Hidden dim -> Mean & Std
        # Mean & Std -> reparameterization trick applied 
        # from there that is passed through the decoder -> output
        
    def encode(self, x):
        # q_phi(z|x)
        
        #apply batch normalization after first dimension reduction
        h_1 = self.bn1(self.img_to_hid1(x))
        #Pass the output through next dimension reduction function to get to the final hidden state then apply the relu activation function
        h_fin = self.relu(self.hid1_to_hid2(h_1))
        mu, sigma = self.hid2_to_mu(h_fin), self.hid2_to_sigma(h_fin)
        
        return mu, sigma
    
    
    def decode(self, z):
        #p_theta(x|z)
        
        #apply batch normalization after the first linear transformation
        h_a = self.bn2(self.z_to_hid2(z))
        #Apply the second linear transformation before relu activation function
        h_b = self.relu(self.hid2_to_hid1(h_a))
        
        return torch.sigmoid(self.hid1_to_img(h_b)) #the values for the pixels have to be between 0 and 1 might be diff for other purposes
    
    
    
    def forward(self, x):
        mu, sigma = self.encode(x)
        #reparameterization trick applied here
        epsilon = torch.randn_like(sigma)
        z_reparametrized = mu + sigma*epsilon
        
        x_reconstructed = self.decode(z_reparametrized)
        
        return x_reconstructed, mu, sigma
        
    
    
if __name__ == "__main__":
    #To test that we wrote the function well we use a random x: Follow code
    x = torch.randn(4, 28*28) # 28*28 input_dim = 784
    vae = VAE(input_dim=784)
    x_reconstructed, mu, sigma = vae(x)
    print(x_reconstructed.shape)
    print(mu.shape)
    print(sigma.shape)

torch.Size([4, 784])
torch.Size([4, 20])
torch.Size([4, 20])


'/Users/michael/Documents/Bioinformatics/593/Homework5'

In [10]:
#import necessaru packages for data and training 
import torch
import torchvision.datasets as datasets # standard datasets
from tqdm import tqdm
from torch import nn, optim
from torchvision import transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader # this gives easier dataset management by creating mini batches etc

#configuration

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

INPUT_DIM = 784
HIDDEN_DIM = 200
Z_DIM = 20
NUM_EPOCHS = 10 
BATCH_SIZE = 32 
LR_RATE = 0.001


#Dataset Loading
dataset = datasets.MNIST(root="dataset/", train=True, transform=transforms.ToTensor(), download=True)
train_loader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=True)

model = VAE(INPUT_DIM, HIDDEN_DIM, Z_DIM).to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=LR_RATE)
loss_fn = nn.BCELoss(reduction="sum")


#Initiate training

for epoch in range(NUM_EPOCHS):
    loop = tqdm(enumerate(train_loader))
    for i, (x, label) in loop:
        #fowardpass
        x = x.to(DEVICE).view(x.shape[0], INPUT_DIM)
        x_reconstructed, mu, sigma = model(x)
        
        
        #compute loss
        reconstruction_loss = loss_fn(x_reconstructed, x)
        kl_div = -torch.sum(1 + torch.log(sigma.pow(2)) - mu.pow(2) - sigma.pow(2)) # this will push the latent space towards standard gaussian
        
        #Backprop
        loss = reconstruction_loss + kl_div
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loop.set_postfix(loss=loss.item())



1875it [00:16, 113.12it/s, loss=8.08e+3]
1875it [00:18, 100.75it/s, loss=9.03e+3]
1875it [00:17, 106.61it/s, loss=7.84e+3]
1875it [00:18, 101.20it/s, loss=8.32e+3]
1875it [00:17, 106.08it/s, loss=7.38e+3]
1875it [00:17, 104.63it/s, loss=7.97e+3]
1875it [00:18, 102.32it/s, loss=6.83e+3]
1875it [00:19, 96.46it/s, loss=6.57e+3] 
1875it [00:19, 97.05it/s, loss=6.54e+3] 
1875it [00:18, 102.43it/s, loss=7.33e+3]


In [6]:
pwd

'/Users/michael/Documents/Bioinformatics/Personal_Projects/VAE/Basic_VAE_on_MNIST_Data'