<a href="https://colab.research.google.com/github/A-b-h-a-y-0-2/Computer-Vision/blob/main/Variational_autoencoder.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Pipeline
## Input image → hidden layer → (mean, std) → reparameterization trick → decoder → output image


# Imports


In [47]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Variational autoencoder model (encoder + decoder)

In [48]:
class Variational_Autoencoder(nn.Module):
  """Fully connected variational autoencoder (VAE) for flattened images.

  Encoder: x -> Linear + ReLU -> two Linear heads producing (mu, sigma).
  Decoder: z -> Linear + ReLU -> Linear + sigmoid, so outputs lie in (0, 1).

  Args:
    input_dim: number of features in one flattened input (784 for 28x28 MNIST).
    hidden_dim: width of the hidden layer used by both encoder and decoder.
    z_dim: size of the latent vector z.
  """

  def __init__(self,input_dim,hidden_dim=200,z_dim=20):
    super().__init__()
    # encoder
    self.img_2hid = nn.Linear(input_dim,hidden_dim)
    self.hid_2mu = nn.Linear(hidden_dim,z_dim)
    self.hid_2std = nn.Linear(hidden_dim,z_dim)
    # decoder
    self.z_2hid = nn.Linear(z_dim,hidden_dim)
    self.hid_2img = nn.Linear(hidden_dim,input_dim)

  def encoder(self,x):
    """Encode x (last dimension = input_dim) into latent parameters (mu, sigma).

    sigma is the raw output of a linear layer and is not constrained to be
    positive; it is only consumed as sigma * eps (symmetric noise) and sigma ** 2.
    """
    h = F.relu(self.img_2hid(x))
    mu, sigma = self.hid_2mu(h), self.hid_2std(h)
    return mu, sigma

  def decoder(self,z):
    """Decode latent z (last dimension = z_dim) into a reconstruction in (0, 1)."""
    h = F.relu(self.z_2hid(z))
    return torch.sigmoid(self.hid_2img(h))

  def reparameterize(self, mu, sigma):
    """Sample z = mu + sigma * eps with eps ~ N(0, I) (reparameterization trick).

    Drawing eps separately keeps z differentiable w.r.t. mu and sigma. eps must
    be standard normal: uniform noise (torch.rand_like) would bias z by
    +0.5 * sigma and make the latent sample non-Gaussian.
    """
    epsilon = torch.randn_like(sigma)
    return mu + sigma * epsilon

  def forward(self,x):
    """Encode x, sample z, decode it; returns (reconstruction, mu, sigma)."""
    mu, sigma = self.encoder(x)
    reparametarized_z = self.reparameterize(mu, sigma)
    reconstructed_x = self.decoder(reparametarized_z)
    return reconstructed_x, mu, sigma

In [49]:
# Shape check on a random batch of 4 flattened 28x28 inputs.
x = torch.randn(4, 28 * 28)
VAE = Variational_Autoencoder(input_dim=784)
reconstructed_x, mu, sigma = VAE(x)
print(*(t.shape for t in (reconstructed_x, mu, sigma)), sep='\n')



torch.Size([4, 784])
torch.Size([4, 20])
torch.Size([4, 20])


In [50]:
import torch
import torchvision.datasets as datasets
from tqdm import tqdm
from torch import nn, optim
from torchvision import transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader


In [55]:
# Pick the GPU only when one is actually present. `is_available` must be
# called: the bare function object is always truthy, which would select
# 'cuda' even on CPU-only machines.
Device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch_size = 32
input_dim = 784   # 28 * 28 pixels of a flattened MNIST image
hid_dim = 200
z_dim = 20
num_epochs = 10
lr = 3e-4         # Adam learning rate

In [56]:
# MNIST training split as tensors in [0, 1], downloaded on first run (Colab path).
train_dataset = datasets.MNIST(
    '/content/datasets',
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
)

In [57]:
model = Variational_Autoencoder(input_dim=input_dim, hidden_dim=hid_dim, z_dim=z_dim)
optimizer = optim.Adam(model.parameters(), lr=lr)
# Summed over every pixel of the batch, matching the torch.sum used for the
# KL term in the training loop so both losses share the same scale.
loss_fn = nn.BCELoss(reduction='sum')

In [58]:
for epoch in range(num_epochs):
  for x, _ in tqdm(train_loader):
    # Flatten each image in the batch to one vector (784 values for MNIST).
    x = x.view(x.shape[0], -1)
    # The decoder already ends in a sigmoid, so its output is directly a valid
    # probability for BCELoss. Do not apply sigmoid again: that would squash
    # every pixel into (0.5, 0.73) and make dark pixels unreachable.
    x_reconstructed, mu, sigma = model(x)

    reconstruction_loss = loss_fn(x_reconstructed, x)

    # Closed-form KL( N(mu, sigma^2) || N(0, I) ), summed over batch and latent dims.
    kl_div = -0.5 * torch.sum(1 + torch.log(sigma.pow(2)) - mu.pow(2) - sigma.pow(2))

    loss = reconstruction_loss + kl_div
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


100%|██████████| 1875/1875 [00:24<00:00, 75.19it/s]
100%|██████████| 1875/1875 [00:25<00:00, 73.89it/s]
100%|██████████| 1875/1875 [00:25<00:00, 74.54it/s]
100%|██████████| 1875/1875 [00:24<00:00, 75.39it/s]
100%|██████████| 1875/1875 [00:25<00:00, 74.38it/s]
100%|██████████| 1875/1875 [00:25<00:00, 73.69it/s]
100%|██████████| 1875/1875 [00:25<00:00, 72.44it/s]
100%|██████████| 1875/1875 [00:25<00:00, 72.20it/s]
100%|██████████| 1875/1875 [00:27<00:00, 68.51it/s]
100%|██████████| 1875/1875 [00:26<00:00, 71.23it/s]


In [59]:
def inference(digit,num_examples=1):
  """Generate `num_examples` VAE samples of one MNIST digit and save them as PNGs.

  Encodes the first training image whose label equals `digit`, samples
  z = mu + sigma * eps with eps ~ N(0, I), decodes z, and writes each sample
  to `generated_{digit}_ex{example}.png`. Uses the global `model` and
  `train_dataset`.

  Args:
    digit: class label to generate samples for.
    num_examples: number of independent samples to decode and save.

  Raises:
    ValueError: if no training image has label `digit`.
  """
  # Only the requested digit is encoded, so each call writes files that
  # actually depict `digit` instead of being overwritten by other classes.
  image = next((img for img, label in train_dataset if label == digit), None)
  if image is None:
    raise ValueError(f'no training image with label {digit}')

  with torch.no_grad():
    mu, sigma = model.encoder(image.view(1, 784))
    for example in range(num_examples):
      # Reparameterization trick: the noise must be standard normal.
      epsilon = torch.randn_like(sigma)
      z = mu + sigma * epsilon
      out = model.decoder(z).view(-1, 1, 28, 28)
      save_image(out, f'generated_{digit}_ex{example}.png')

for idx in range(10):
  inference(idx, num_examples=1)