<a href="https://colab.research.google.com/github/Shiva1906/DeepLearning/blob/VAE/VAE/VAE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt
import cv2
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets as datasets
from torch.utils.data import Dataset,DataLoader
from torchsummary import summary
from tqdm import tqdm

Loading the MNIST dataset

In [None]:
# MNIST train/test splits, converted to tensors with ToTensor and downloaded to
# ./data if not already present. One ToTensor instance is shared by both splits.
to_tensor = torchvision.transforms.ToTensor()
train_data = datasets.MNIST(root='./data', train=True, download=True, transform=to_tensor)
test_data = datasets.MNIST(root='./data', train=False, download=True, transform=to_tensor)

In [None]:
# Mini-batches of 64 images; only the training split is shuffled.
train_dataloader, test_dataloader = (
    DataLoader(train_data, batch_size=64, shuffle=True),
    DataLoader(test_data, batch_size=64, shuffle=False),
)

# NOTE(review): these iterators are not consumed by any visible cell; presumably
# kept for pulling a single batch interactively with next(...).
train_dataloader_iter, test_dataloader_iter = iter(train_dataloader), iter(test_dataloader)

Model definitions: a plain UNet, then a UNet-based VAE

In [None]:
class UNet(nn.Module):
  """Plain (non-variational) UNet for 1-channel images.

  Encoder: five stages of (3x3 conv + ReLU) x 2 with 64/128/256/512/1024 channels,
  separated by 2x2 max-pools. Decoder: four stages that upsample with a stride-2
  transposed conv, concatenate the matching encoder activation along dim 1, and
  refine with (3x3 conv + ReLU) x 2. The final conv maps to 1 channel and is
  followed by a ReLU, so the output is non-negative.

  Input height/width must be divisible by 16 so the skip tensors line up with the
  upsampled decoder tensors.
  """
  def __init__(self):
    super().__init__()

    def double_conv(c_in, c_mid, c_out=None):
      # Two padded 3x3 convs (spatial size preserved), each followed by ReLU.
      if c_out is None:
        c_out = c_mid
      return nn.Sequential(
          nn.Conv2d(in_channels=c_in, out_channels=c_mid, kernel_size=3, stride=1, padding=1),
          nn.ReLU(),
          nn.Conv2d(in_channels=c_mid, out_channels=c_out, kernel_size=3, stride=1, padding=1),
          nn.ReLU(),
      )

    def halve():
      # 2x2 max-pool: halves height and width.
      return nn.MaxPool2d(kernel_size=2, stride=2)

    def upsample(c_in, c_out):
      # Stride-2 transposed conv: doubles height and width.
      return nn.ConvTranspose2d(in_channels=c_in, out_channels=c_out, kernel_size=2, stride=2)

    # Modules are registered in data-flow order (this fixes state_dict key order).
    # Contracting path.
    self.conv_block_enc1 = double_conv(1, 64)
    self.max_pool1 = halve()
    self.conv_block_enc2 = double_conv(64, 128)
    self.max_pool2 = halve()
    self.conv_block_enc3 = double_conv(128, 256)
    self.max_pool3 = halve()
    self.conv_block_enc4 = double_conv(256, 512)
    self.max_pool4 = halve()
    self.conv_block_enc5 = double_conv(512, 1024)

    # Expanding path: each decoder block receives upsampled + skip channels,
    # i.e. twice its output width.
    self.conv_transpose4 = upsample(1024, 512)
    self.conv_block_dec4 = double_conv(1024, 512)
    self.conv_transpose3 = upsample(512, 256)
    self.conv_block_dec3 = double_conv(512, 256)
    self.conv_transpose2 = upsample(256, 128)
    self.conv_block_dec2 = double_conv(256, 128)
    self.conv_transpose1 = upsample(128, 64)
    self.conv_block_dec1 = double_conv(128, 64, 1)

  def forward(self, x):
    # Encoder: keep every stage's activation for the skip connections.
    e1 = self.conv_block_enc1(x)
    e2 = self.conv_block_enc2(self.max_pool1(e1))
    e3 = self.conv_block_enc3(self.max_pool2(e2))
    e4 = self.conv_block_enc4(self.max_pool3(e3))
    out = self.conv_block_enc5(self.max_pool4(e4))

    # Decoder: upsample, concatenate the matching skip on the channel axis, refine.
    stages = (
        (self.conv_transpose4, self.conv_block_dec4, e4),
        (self.conv_transpose3, self.conv_block_dec3, e3),
        (self.conv_transpose2, self.conv_block_dec2, e2),
        (self.conv_transpose1, self.conv_block_dec1, e1),
    )
    for up, refine, skip in stages:
      out = refine(torch.cat((up(out), skip), 1))
    return out


In [None]:
class Reparameterize(nn.Module):
  """Reparameterization trick for a diagonal Gaussian.

  forward(mu, logvar) returns mu + exp(0.5 * logvar) * eps, with eps ~ N(0, I)
  drawn fresh (same shape as mu/logvar) on every call, including in eval mode.
  Writing the sample this way keeps it differentiable w.r.t. mu and logvar.
  """
  def __init__(self):
    super().__init__()

  def forward(self, mu, logvar):
    sigma = (0.5 * logvar).exp()  # log-variance -> standard deviation
    noise = torch.randn_like(sigma)
    return mu + noise * sigma

class UNet_VAE(nn.Module):
  """UNet-shaped variational autoencoder.

  Encoder: five stages of (3x3 conv + ReLU) x 2 with 32/64/128/256/512 channels,
  separated by 2x2 max-pools. The deepest feature map is flattened and projected by
  two linear layers to ``mu`` and ``logvar`` (``latent_dim`` values each).
  ``Reparameterize`` samples a latent from them. ``fc1`` projects the latent back,
  and it is reshaped to the bottleneck feature map.

  Decoder: four stages that upsample with a stride-2 transposed conv, concatenate
  the matching encoder activation along dim 1, and refine with (3x3 conv + ReLU) x 2.
  The last conv maps to ``in_channels`` channels and is followed by a ReLU, so the
  reconstruction is non-negative but not bounded above.

  NOTE(review): because of the encoder->decoder skip concatenations, the
  reconstruction can use x1..x4 directly. It does not have to route information
  through the sampled latent.

  Args:
    in_channels: image channels of the input (and of the reconstruction). Default 1.
    latent_dim: size of the latent vector. Default 1024.
    input_size: height == width of the inputs. It sizes the linear layers around the
      bottleneck and must be divisible by 16 (four 2x2 pools). Default 128.

  Raises:
    ValueError: if ``input_size`` is not divisible by 16.
  """
  def __init__(self, in_channels=1, latent_dim=1024, input_size=128):
    super(UNet_VAE,self).__init__()
    if input_size % 16 != 0:
      raise ValueError(f"input_size must be divisible by 16, got {input_size}")
    encoder_features = [32,64,128,256,512]
    decoder_features = encoder_features[::-1]
    bottleneck_size = input_size // 16  # spatial size after four 2x2 max-pools
    flat_features = encoder_features[4] * bottleneck_size * bottleneck_size

    def double_conv(c_in, c_mid, c_out=None):
      # Two padded 3x3 convs (spatial size preserved), each followed by ReLU.
      if c_out is None:
        c_out = c_mid
      return nn.Sequential(
          nn.Conv2d(in_channels=c_in,out_channels=c_mid,kernel_size=3,stride=1,padding=1),
          nn.ReLU(),
          nn.Conv2d(in_channels=c_mid,out_channels=c_out,kernel_size=3,stride=1,padding=1),
          nn.ReLU()
      )

    def pool():
      return nn.MaxPool2d(kernel_size=2,stride=2)

    def upsample(c_in, c_out):
      return nn.ConvTranspose2d(in_channels=c_in,out_channels=c_out,kernel_size=2,stride=2)

    # Modules are registered in data-flow order; attribute names define the
    # state_dict keys, so they must not change (checkpoints depend on them).
    self.conv_block_enc1 = double_conv(in_channels, encoder_features[0])
    self.max_pool1 = pool()
    self.conv_block_enc2 = double_conv(encoder_features[0], encoder_features[1])
    self.max_pool2 = pool()
    self.conv_block_enc3 = double_conv(encoder_features[1], encoder_features[2])
    self.max_pool3 = pool()
    self.conv_block_enc4 = double_conv(encoder_features[2], encoder_features[3])
    self.max_pool4 = pool()
    self.conv_block_enc5 = double_conv(encoder_features[3], encoder_features[4])

    self.flatten = nn.Flatten()

    self.mu = nn.Linear(in_features=flat_features,out_features=latent_dim)
    self.logvar = nn.Linear(in_features=flat_features,out_features=latent_dim)
    self.fc1 = nn.Linear(in_features=latent_dim,out_features=flat_features)
    self.reparameterize = Reparameterize()

    # Each decoder block receives upsampled + skip channels (twice its output width).
    self.conv_transpose4 = upsample(decoder_features[0], decoder_features[1])
    self.conv_block_dec4 = double_conv(decoder_features[0], decoder_features[1])
    self.conv_transpose3 = upsample(decoder_features[1], decoder_features[2])
    self.conv_block_dec3 = double_conv(decoder_features[1], decoder_features[2])
    self.conv_transpose2 = upsample(decoder_features[2], decoder_features[3])
    self.conv_block_dec2 = double_conv(decoder_features[2], decoder_features[3])
    self.conv_transpose1 = upsample(decoder_features[3], decoder_features[4])
    self.conv_block_dec1 = double_conv(decoder_features[3], decoder_features[4], in_channels)

  def forward(self,x):
    """Encode ``x``, sample a latent, and decode. Returns ``(reconstruction, mu, logvar)``."""
    x1 = self.conv_block_enc1(x)
    x2 = self.conv_block_enc2(self.max_pool1(x1))
    x3 = self.conv_block_enc3(self.max_pool2(x2))
    x4 = self.conv_block_enc4(self.max_pool3(x3))
    x5 = self.conv_block_enc5(self.max_pool4(x4))

    # Flatten once and project to the Gaussian posterior parameters.
    flat = self.flatten(x5)
    mu = self.mu(flat)
    logvar = self.logvar(flat)

    # Sample a latent and map it back to a feature map with the encoder
    # bottleneck's own (channels, height, width).
    z = self.reparameterize(mu,logvar)
    y = self.fc1(z).view(-1, *x5.shape[1:])

    y = self.conv_block_dec4(torch.cat((self.conv_transpose4(y),x4),1))
    y = self.conv_block_dec3(torch.cat((self.conv_transpose3(y),x3),1))
    y = self.conv_block_dec2(torch.cat((self.conv_transpose2(y),x2),1))
    y = self.conv_block_dec1(torch.cat((self.conv_transpose1(y),x1),1))

    return y,mu,logvar


In [None]:
# Instantiate the UNet-VAE on the GPU when one is available (CPU otherwise) and
# print a layer-by-layer summary for 1x128x128 inputs, the size the training loop
# resizes MNIST digits to. torchsummary needs the matching device name.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet_VAE().to(device)
summary(model,(1,128,128),device=device.type)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 32, 128, 128]             320
              ReLU-2         [-1, 32, 128, 128]               0
            Conv2d-3         [-1, 32, 128, 128]           9,248
              ReLU-4         [-1, 32, 128, 128]               0
         MaxPool2d-5           [-1, 32, 64, 64]               0
            Conv2d-6           [-1, 64, 64, 64]          18,496
              ReLU-7           [-1, 64, 64, 64]               0
            Conv2d-8           [-1, 64, 64, 64]          36,928
              ReLU-9           [-1, 64, 64, 64]               0
        MaxPool2d-10           [-1, 64, 32, 32]               0
           Conv2d-11          [-1, 128, 32, 32]          73,856
             ReLU-12          [-1, 128, 32, 32]               0
           Conv2d-13          [-1, 128, 32, 32]         147,584
             ReLU-14          [-1, 128,

Training (L1 reconstruction loss + KL divergence)

In [None]:
loss_criterion = nn.L1Loss(reduction="mean")  # reconstruction term: mean absolute error

def kl_loss(mu, logvar):
    """KL divergence KL(N(mu, exp(logvar)) || N(0, I)), averaged element-wise.

    The per-element term -0.5 * (1 + logvar - mu^2 - exp(logvar)) is averaged
    with ``.mean()`` over the batch *and* the latent dimensions; it is not summed
    over latent dims. Together with the pixel-averaged L1 loss, this choice sets
    the relative weight of the KL term in the training objective.
    """
    per_element = 1 + logvar - mu.pow(2) - logvar.exp()
    return -0.5 * per_element.mean()

def validation(model,test_dataloader,image_size=128):
  """Mean L1 reconstruction error of ``model`` over ``test_dataloader``.

  Each batch's images are moved to the device holding the model's parameters and
  resized to ``image_size`` x ``image_size`` with ``F.interpolate`` (default
  nearest mode) before the forward pass. Only the reconstruction term
  (``loss_criterion``) is measured; the KL term added during training is not
  included. The model runs in eval mode without autograd, and its previous
  train/eval mode is restored on exit.

  Returns:
    The per-sample average loss: each batch's mean is weighted by its batch size,
    so a smaller final batch is not over-weighted.

  Raises:
    ValueError: if ``test_dataloader`` yields no batches.
  """
  was_training = model.training
  device = next(model.parameters()).device
  model.eval()
  total_loss = 0.0
  n_samples = 0
  try:
    with torch.no_grad():  # no graph needed: saves memory and time
      for x, _ in tqdm(test_dataloader, desc="validation"):
        x = F.interpolate(x.to(device),(image_size,image_size))
        output,_,_ = model(x)
        batch_size = x.size(0)
        total_loss += loss_criterion(output,x).item() * batch_size
        n_samples += batch_size
  finally:
    model.train(was_training)
  if n_samples == 0:
    raise ValueError("test_dataloader yielded no batches")
  return total_loss / n_samples

In [None]:
# Train with Adam on L1 reconstruction + KL. After each epoch the model is
# validated, and model.pth is overwritten only when the validation reconstruction
# loss improves, so the checkpoint holds the best epoch rather than the last one.
# Uses `loss_criterion`, `kl_loss` and `validation` defined in the cell above.
NUM_EPOCHS = 6
device = next(model.parameters()).device
optimizer = torch.optim.Adam(model.parameters(),lr=0.0001)
best_val_loss = float("inf")
print(len(train_dataloader))
model.train()  # validation() leaves the model in train mode again after each epoch
for epoch in range(NUM_EPOCHS):
  print("Epoch :",epoch)
  for x,_ in tqdm(train_dataloader, desc=f"epoch {epoch}"):
    x = F.interpolate(x.to(device),(128,128))
    output,mu,log_var = model(x)
    loss = loss_criterion(output,x) + kl_loss(mu,log_var)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
  val_loss = validation(model,test_dataloader)
  print("Validation Loss :",val_loss)
  if val_loss < best_val_loss:
    best_val_loss = val_loss
    torch.save(model,'model.pth')  # whole pickled module; the restore cell reads it back

938
Epoch : 0


938it [04:30,  3.47it/s]
157it [00:15, 10.34it/s]


Validation Loss : 0.0013062203663347918
Epoch : 1


938it [04:31,  3.46it/s]
157it [00:15, 10.34it/s]


Validation Loss : 0.001049689980814008
Epoch : 2


938it [04:30,  3.46it/s]
157it [00:15, 10.35it/s]


Validation Loss : 0.0006065391256936653
Epoch : 3


938it [04:30,  3.46it/s]
157it [00:15, 10.32it/s]


Validation Loss : 0.0004637551343418468
Epoch : 4


938it [04:30,  3.46it/s]
157it [00:15, 10.35it/s]


Validation Loss : 0.0004645395600592872
Epoch : 5


938it [04:30,  3.46it/s]
157it [00:15, 10.30it/s]


Validation Loss : 0.000625908891611086


In [None]:
# Restore the weights saved by the training loop into `model`.
# NOTE(security): weights_only=False unpickles arbitrary Python objects. It is
# needed because the training cell pickles the whole nn.Module, so only load
# model.pth files you produced yourself.
# map_location puts the tensors on the device `model` already lives on. A plain
# state_dict checkpoint is accepted as well as a pickled module.
checkpoint = torch.load("model.pth", map_location=next(model.parameters()).device, weights_only=False)
state_dict = checkpoint.state_dict() if isinstance(checkpoint, nn.Module) else checkpoint
model.load_state_dict(state_dict)

<All keys matched successfully>