<a href="https://colab.research.google.com/github/Rohit1217/VAE/blob/main/vae.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [6]:
import numpy as np
import matplotlib.pyplot as plt

import torch
from torchvision import datasets,transforms
from torch.utils.data import TensorDataset,DataLoader
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

In [92]:
def get_device():
  return torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [9]:
# Transformations
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((128,),(128,))])

# Load MNIST dataset
trainset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
testset = datasets.MNIST(root='./data', train=False, transform=transform, download=True)

# Data loaders
trainloader = DataLoader(dataset=trainset, batch_size=128, shuffle=True)
testloader = DataLoader(dataset=testset, batch_size=128,shuffle=False)

In [91]:
class Encoder(nn.Module):
  def __init__(self,in_channels,out_channels,kernel,latent_dim):
    super(Encoder,self).__init__()
    self.in_channels=in_channels
    self.out_channels=out_channels
    self.latent_dim=latent_dim
    self.conv1=nn.Conv2d(in_channels,out_channels,kernel+2)
    self.conv2=nn.Conv2d(out_channels,out_channels,kernel)
    self.conv3=nn.Conv2d(out_channels,out_channels,kernel)
    self.conv4=nn.Conv2d(out_channels,out_channels,kernel)
    self.conv5=nn.Conv2d(out_channels,2,kernel)
    self.conv6=nn.Conv2d(out_channels,2,kernel)
  def forward(self,x):
    x=self.conv1(x)
    x=self.conv2(x)
    x=self.conv3(x)
    x=self.conv4(x)
    x=self.conv5(x)
    mean,log_variance=x[:,0],x[:,1]
    epsilon=torch.randn(self.latent_dim,self.latent_dim)
    z=mean+epsilon*torch.exp(2*log_variance)
    return mean,log_variance,z

'''enc=Encoder(1,64,5,6)
x=enc(torch.ones(6,1,28,28))
x[0].shape,x[2].shape'''

(torch.Size([6, 6, 6]), torch.Size([6, 6, 6]))

In [None]:
class Decoder(nn.Module):
  def __init__(self,in_channels,out_channels,kernel,input_dim):
    super(Decoder,self).__init__()
    self.in_channels=in_channels
    self.out_channels=out_channels
    self.input_dim=input_dim
    self.conv1=nn.ConvTranspose2d(in_channels,out_channels,kernel)
    self.conv2=nn.ConvTranspose2d(out_channels,out_channels,kernel)
    self.conv3=nn.ConvTranspose2d(out_channels,out_channels,kernel)
    self.conv4=nn.ConvTranspose2d(out_channels,out_channels,kernel)
    self.conv5=nn.ConvTranspose2d(out_channels,2,kernel+2)
    self.conv6=nn.ConvTranspose2d(out_channels,out_channels,kernel)
    self.register_buffer('epsilon',torch.randn(self.input_dim,self.input_dim))

  def forward(self,x):
    x=F.relu(self.conv1(x))
    x=F.relu(self.conv2(x))
    x=F.relu(self.conv3(x))
    x=F.relu(self.conv4(x))
    x=F.relu(self.conv5(x))
    mean,variance=x[:,0],x[:,1]
    return mean+self.epsilon*variance


'''enc=Decoder(1,64,5,28)
x=enc(torch.ones(2,1,6,6))
x.shape,x'''

In [None]:
class VAE(nn.Module):
  def __init__(self,input_dim,latent_dim,in_channels,out_channels,kernel):
    super(VAE,self).__init__()
    self.input_dim=input_dim
    self.latent_dim=latent_dim
    self.encoder=Encoder(in_channels,out_channels,kernel,latent_dim)
    self.decoder=Decoder(in_channels,out_channels,kernel,input_dim)
    self.e=1e-8

  def forward(self,x):
    z,mean,log_variance=self.encoder(x)
    b,w,h=z.shape
    z=z.view(b,1,w,h)
    x=x.view(b,self.input_dim,self.input_dim)
    x_re=self.decoder(z)
    return x_re,x,mean,log_variance


vae=VAE(28,6,1,64,5)
x=vae(torch.randn(6,1,28,28))
x[0].shape,x[1].shape,x[2].shape,x[3].shape,x

In [86]:
def Loss(x_re,x_or,mean,log_variance):
  re_loss=F.mse_loss(x_re,x_or,reduction='mean')
  d=mean.shape[1]
  kl=0.5*torch.sum(1+log_variance-torch.exp(log_variance)-torch.pow(mean,2))
  return re_loss+kl

In [87]:
'''Loss(x[0],x[1],x[2],x[3])'''

tensor(-82.6075, grad_fn=<AddBackward0>)

In [80]:
model=VAE(28,6,1,64,5)
model=model.to(get_device())
epochs=50
optimizer=optim.Adam(model.parameters(),lr=0.001)

tensor(43.0873)

In [None]:
for epoch in epochs():
  total_loss=0
  count=0
  for x,_ in trainloader:
    x.to(get_device())
    x_re,x,mean,log_variance=model(x)
    loss=Loss(x_re,x,mean,log_variance)

    total_loss+=loss.item()
    count+=1

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

  print(f'total_loss{total_loss},mse_loss{total_loss/count},epoch{epoch}')