<a href="https://colab.research.google.com/github/YassineNJ/UNet-implementation/blob/main/UNet_VAE_MNIST.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torch.optim as optim
import matplotlib
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
from tqdm import tqdm
from torchvision import datasets
from torch.utils.data import DataLoader
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Run-wide configuration shared (as module globals) by the training helpers.
batch_size = 64
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Directory where validate() writes model checkpoints. validate() reads this
# global, but it was never assigned anywhere, so the first checkpoint save
# raised NameError.
# NOTE(review): defaults to the root of the Drive mounted above — adjust if
# checkpoints should go somewhere else.
path = '/content/drive/MyDrive'

# MNIST as float tensors in [0, 1] (ToTensor), native 28x28 resolution.
train_data = datasets.MNIST(root='../input/data',train=True,download=True,transform=transforms.ToTensor())
train_loader = DataLoader(train_data,batch_size=batch_size,shuffle=True)
val_data = datasets.MNIST(root='../input/data',train=False,download=True,transform=transforms.ToTensor())
val_loader = DataLoader(val_data,batch_size=batch_size,shuffle=False)

In [None]:
class VAE(nn.Module):
  """Fully connected variational autoencoder.

  The encoder maps a flattened input of size ``in_dim`` to ``2 * z_dim``
  values, interpreted as the mean and log standard deviation of a diagonal
  Gaussian over the latent code. The decoder maps a sampled latent code back
  to ``in_dim`` values squashed into (0, 1) by a sigmoid.

  Args:
    in_dim: number of features per sample after flattening.
    z_dim: dimensionality of the latent code.
  """

  def __init__(self,in_dim = 784,z_dim = 20):
    super().__init__()
    self.z_dim = z_dim

    # Encoder narrows in_dim -> in_dim/4 -> in_dim/16 -> 2*z_dim.
    enc_hidden_1 = in_dim // 4
    enc_hidden_2 = in_dim // 16
    self.encoder = nn.Sequential(
        nn.Linear(in_dim, enc_hidden_1),
        nn.ReLU(inplace=True),
        nn.Linear(enc_hidden_1, enc_hidden_2),
        nn.ReLU(inplace=True),
        nn.Linear(enc_hidden_2, 2 * z_dim),
    )

    # Decoder widens z_dim -> 4*z_dim -> 16*z_dim -> in_dim.
    # NOTE(review): there is no activation between these Linear layers, so the
    # decoder is an affine map followed by a sigmoid — confirm this is intended.
    dec_hidden_1 = 4 * z_dim
    dec_hidden_2 = 16 * z_dim
    self.decoder = nn.Sequential(
        nn.Linear(z_dim, dec_hidden_1),
        nn.Linear(dec_hidden_1, dec_hidden_2),
        nn.Linear(dec_hidden_2, in_dim),
        nn.Sigmoid(),
    )

  def forward(self,x):
    """Encode ``x``, sample a latent code, and decode it.

    Args:
      x: input batch; everything after the batch dimension is flattened.

    Returns:
      Tuple ``(x_hat, mu, log_sigma)``: the flattened reconstruction of shape
      (batch, in_dim) and the two (batch, z_dim) posterior parameter tensors.
    """
    flat = x.view(x.size(0), -1)
    # First z_dim encoder outputs are the mean, the remaining z_dim the log-std.
    stats = self.encoder(flat).view(-1, 2, self.z_dim)
    mu, log_sigma = stats[:, 0, :], stats[:, 1, :]
    # Reparameterisation: z = mu + eps * sigma with eps ~ N(0, I).
    noise = torch.randn_like(mu)
    z = mu + noise * torch.exp(log_sigma)
    return self.decoder(z), mu, log_sigma

In [None]:
def criterion(x_hat,mu,log_sigma,x):
  """VAE objective: reconstruction loss plus KL divergence to a unit Gaussian.

  Args:
    x_hat: reconstruction produced by the model.
    mu: posterior means.
    log_sigma: posterior log standard deviations.
    x: original input batch; must hold the same number of elements as x_hat.

  Returns:
    Scalar tensor ``L_recon(x_hat, x) + KL``, summed over every element.

  The reconstruction term is computed by the module-level callable ``L_recon``,
  which must be defined before this function is called.
  """
  # The fully connected VAE returns a flattened x_hat (batch, features) while x
  # keeps its image layout; BCELoss rejects mismatched sizes, so align the
  # target with the reconstruction. This is a no-op when shapes already agree.
  target = x.reshape(x_hat.shape)
  l_rec = L_recon(x_hat,target)
  # Closed-form KL( N(mu, sigma^2) || N(0, 1) ) summed over all latent entries.
  l_reg = torch.sum(0.5*(torch.exp(log_sigma)**2 + mu**2 - 1 - 2*log_sigma))
  return l_rec + l_reg

def display_images(x,x_hat,batch_size):
  """Show a grid of inputs (top half) above their reconstructions (bottom half).

  Args:
    x: batch of input images; the first dimension is the batch.
    x_hat: batch of reconstructions in the same order as x. It may have a
      different layout than x (e.g. flattened to (batch, H*W)).
    batch_size: nominal batch size. The grid is n x n per half, where
      n = floor(sqrt(min(batch_size, actual batch length))).

  Raises:
    ValueError: if a flattened batch does not hold square images.
  """
  def to_images(batch):
    """Return the batch as a numpy array of images that imshow accepts."""
    arr = batch.detach().cpu().numpy()
    if arr.ndim == 2:
      # Flattened images: recover the square layout.
      side = int(round(np.sqrt(arr.shape[1])))
      if side * side != arr.shape[1]:
        raise ValueError(
            'cannot display flattened images with {} features: not a square'.format(arr.shape[1]))
      return arr.reshape(arr.shape[0], side, side)
    if arr.ndim == 4:
      if arr.shape[1] == 1:
        # Drop only the channel axis; a blanket squeeze would also remove a
        # batch axis of length 1.
        return arr[:, 0]
      # Multi-channel images: imshow expects channels last.
      return arr.transpose(0, 2, 3, 1)
    return arr

  originals = to_images(x)
  reconstructions = to_images(x_hat)

  # Never index past the images that actually exist in this batch.
  available = min(batch_size, len(originals), len(reconstructions))
  n = int(np.sqrt(available))
  if n == 0:
    return

  # squeeze=False keeps axs two-dimensional even when n == 1.
  fig ,axs = plt.subplots(2*n,n, figsize = (15,15), squeeze=False)
  for idx in range(n*n):
    row, col = divmod(idx, n)
    panels = ((axs[row,col], originals[idx]), (axs[row+n,col], reconstructions[idx]))
    for ax, image in panels:
      ax.imshow(image,cmap='gray')
      ax.axis('off')

  plt.show()

def train_model(model):
  """Run one training epoch over ``train_loader`` and return the mean loss.

  Relies on module-level globals: ``train_loader``, ``optimizer``, ``device``,
  ``criterion``, ``display_images``, ``batch_size`` and the current ``epoch``.

  Args:
    model: network whose forward pass returns ``(x_hat, mu, log_sigma)``.

  Returns:
    Accumulated batch losses divided by the number of samples in the
    training dataset.
  """
  model.train()
  running_total = 0
  for batch_idx, (images, _labels) in enumerate(train_loader):
    images = images.to(device)
    optimizer.zero_grad()
    recon, mu, log_sigma = model(images)

    # Visual sanity check: first batch of every 10th epoch.
    is_preview_step = batch_idx == 0 and epoch % 10 == 0
    if is_preview_step:
      display_images(images, recon, batch_size)

    batch_loss = criterion(recon, mu, log_sigma, images)
    batch_loss.backward()
    optimizer.step()
    running_total += batch_loss.item()

  sample_count = len(train_loader.dataset)
  return running_total / sample_count


def validate(model):
  """Evaluate on ``val_loader``, checkpoint on improvement, decay LR on plateau.

  Relies on module-level globals that must exist before the first call:
  ``val_loader``, ``device``, ``criterion``, ``optimizer``, ``scheduler``,
  ``path`` (directory checkpoints are written to), ``min_loss`` (best
  validation loss so far) and ``j`` (plateau counter, expected to start at 1).

  Side effects:
    * When the loss improves, updates ``min_loss``, resets ``j`` to 1 and saves
      ``model.state_dict()`` to ``{path}/model_<ClassName>.pth``.
    * Otherwise increments ``j``; when it reaches 5 (four consecutive
      non-improving calls) it resets ``j`` and calls ``scheduler.step()``.

  Args:
    model: network whose forward pass returns ``(x_hat, mu, log_sigma)``.

  Returns:
    Accumulated batch losses divided by the number of validation samples.
  """
  global min_loss
  global j

  total_loss = 0
  model.eval()
  with torch.no_grad():
    for x,_ in val_loader:
      x = x.to(device)
      x_hat , mu , log_sigma = model(x)
      # .item() already yields a Python float; going through .data is unnecessary.
      total_loss += criterion(x_hat,mu,log_sigma,x).item()

  avg_loss = total_loss / len(val_loader.dataset)

  if avg_loss < min_loss:
    min_loss = avg_loss
    print('\nMin Loss:{:.4f}'.format(min_loss))
    torch.save(model.state_dict(), f'{path}/model_{type(model).__name__}.pth')
    print('model saved')
    j=1
  else:
    j+=1
    if j==5:
      j=1
      scheduler.step()
      for param_group in optimizer.param_groups:
        print("Current learning rate is: {}".format(param_group['lr']))

  return avg_loss

In [None]:
# Train the fully connected VAE. The names below are read as globals by
# train_model(), validate() and criterion().
model_1 = VAE(in_dim = 784,z_dim = 20).to(device)
epochs = 100
optimizer = torch.optim.Adam(model_1.parameters(), lr=1e-4)
# Stepped manually from validate() when the validation loss plateaus; every
# step multiplies the learning rate by 0.8.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.8)
# Best validation loss so far. Start at +inf so the first epoch always counts
# as an improvement, whatever the scale of the loss.
min_loss = float('inf')
# Plateau counter used by validate().
j = 1
L_recon =  nn.BCELoss(reduction='sum')
for epoch in range(epochs):
  train_loss = train_model(model_1)
  validation_loss = validate(model_1)
  print(f"Epoch : {epoch} \nTrain loss : {train_loss:.6f} \nValidation loss : {validation_loss:.6f}")

# UNet


In [None]:

def crop(e,d):
  """Center-crop feature map ``e`` to the spatial size of ``d``.

  Height and width are handled independently, so non-square feature maps are
  cropped correctly. When the size difference is odd, the extra row/column is
  removed from the bottom/right.

  Args:
    e: tensor of shape (batch, channels, H_e, W_e), e.g. an encoder skip map.
    d: tensor whose dimensions 2 and 3 give the target height and width.

  Returns:
    A view of ``e`` with shape (batch, channels, d.size(2), d.size(3)).

  Raises:
    ValueError: if ``e`` is smaller than ``d`` along height or width.
  """
  target_h, target_w = d.size(2), d.size(3)
  if e.size(2) < target_h or e.size(3) < target_w:
    raise ValueError(
        'cannot crop a {}x{} map to the larger size {}x{}'.format(
            e.size(2), e.size(3), target_h, target_w))
  top = (e.size(2) - target_h) // 2
  left = (e.size(3) - target_w) // 2
  return e[:, :, top:top + target_h, left:left + target_w]

class UNet(nn.Module):
  """Three-level U-Net with a variational bottleneck.

  The encoder applies three blocks of two unpadded 3x3 convolutions, with 2x2
  max-pooling between blocks. The 256-channel bottleneck is split into 128
  channels of means and 128 channels of log standard deviations, from which a
  latent map is sampled. The decoder upsamples twice with transposed
  convolutions, concatenating the center-cropped encoder features at each
  level, and finishes with a 1x1 convolution. Because the convolutions are
  unpadded the raw output is smaller than the input; it is bilinearly resized
  back to the input size before the final sigmoid.

  Expects single-channel input of shape (batch, 1, H, W). Uses the module-level
  helper ``crop`` for the skip connections.
  """

  @staticmethod
  def _double_conv(in_channels, out_channels):
    """Two unpadded 3x3 convolutions, each followed by an in-place ReLU."""
    return nn.Sequential(
        nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3),
        nn.ReLU(inplace=True),
        nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3),
        nn.ReLU(inplace=True),
    )

  def __init__(self):
    super(UNet, self).__init__()

    self.maxpool = nn.MaxPool2d(kernel_size=2 , stride = 2)

    # Attribute names (and the layer order inside each Sequential) are kept
    # stable so previously saved state_dicts still load.
    self.encoder_conv_1 = self._double_conv(1, 64)
    self.encoder_conv_2 = self._double_conv(64, 128)
    self.encoder_conv_3 = self._double_conv(128, 256)

    # The sampled latent map has 128 channels (half of the bottleneck).
    self.convt_3= nn.ConvTranspose2d(in_channels = 128, out_channels = 128 , kernel_size=2,stride=2)
    self.convt_4= nn.ConvTranspose2d(in_channels = 128, out_channels = 64 , kernel_size=2,stride=2)

    # Decoder blocks take the upsampled map concatenated with the skip map.
    self.decoder_conv_3 = self._double_conv(256, 128)
    self.decoder_conv_4 = self._double_conv(128, 64)
    self.decoder_conv_5 = nn.Conv2d(in_channels= 64, out_channels= 1 , kernel_size= 1)

  def forward(self,x):
    """Encode ``x``, sample the latent map, and decode a reconstruction.

    Args:
      x: input batch of shape (batch, 1, H, W).

    Returns:
      Tuple ``(x_hat, mu, log_sigma)``: the reconstruction with the same
      spatial size as ``x`` and values in (0, 1), plus the two 128-channel
      posterior parameter maps taken from the bottleneck.
    """
    e1 = self.encoder_conv_1(x)
    e2 = self.encoder_conv_2(self.maxpool(e1))
    e3 = self.encoder_conv_3(self.maxpool(e2))

    # Split the 256 bottleneck channels: first 128 are means, last 128 log-stds.
    stats = e3.view(e3.size(0),2,128,e3.size(2),e3.size(3))
    mu = stats[:,0]
    log_sigma = stats[:,1]

    # Reparameterisation: z = mu + eps * sigma with eps ~ N(0, I).
    z = mu + torch.randn_like(mu) * torch.exp(log_sigma)

    # Run each transposed convolution once and reuse the result for both the
    # crop target and the concatenation.
    up = self.convt_3(z)
    z = self.decoder_conv_3(torch.cat((up, crop(e2, up)), 1))
    up = self.convt_4(z)
    z = self.decoder_conv_4(torch.cat((up, crop(e1, up)), 1))
    z = self.decoder_conv_5(z)

    # Resize back to the input resolution, then map logits to (0, 1).
    z = torch.sigmoid(F.interpolate(z, size=x.shape[-2:], mode='bilinear', align_corners=False))
    return z , mu , log_sigma

In [None]:
# Data pipeline for the UNet run: same MNIST splits as before, but every image
# is converted to a tensor and then resized to 54x54.
batch_size = 64
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((54,54)),
])

# Training split, reshuffled every epoch.
train_data = datasets.MNIST(
    root='../input/data',
    train=True,
    download=True,
    transform=transform,
)
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)

# Held-out split, iterated in a fixed order.
val_data = datasets.MNIST(
    root='../input/data',
    train=False,
    download=True,
    transform=transform,
)
val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False)

In [None]:
# Train the variational UNet. The names below are read as globals by
# train_model(), validate() and criterion().
model_2 = UNet().to(device)
epochs = 100
optimizer = torch.optim.Adam(model_2.parameters(), lr=1e-4)
# Stepped manually from validate() when the validation loss plateaus; every
# step multiplies the learning rate by 0.8.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.8)
L_recon =  nn.BCELoss(reduction='sum')
# Best validation loss so far. Start at +inf so the first epoch always counts
# as an improvement, whatever the scale of the loss.
min_loss = float('inf')
# Plateau counter used by validate().
j = 1
for epoch in range(epochs):
  train_loss = train_model(model_2)
  validation_loss = validate(model_2)
  print(f"Epoch : {epoch} \nTrain loss : {train_loss:.6f} \nValidation loss : {validation_loss:.6f}")