<a href="https://colab.research.google.com/github/VedantDere0104/Image_Reconstruction/blob/main/Super_Resolution_GAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
####

In [2]:
#! unzip '/content/drive/MyDrive/Wider_Face_Dataset/WIDER_train.zip' -d '/content/drive/MyDrive/Wider_Face_Dataset/'

In [3]:
import torchvision
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader , Dataset
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import torch
from torch import nn
from torchsummary import summary
import cv2 as cv
import numpy as np
from torchvision.utils import make_grid

In [4]:
def show_tensor_images(image_tensor, num_images=2, size=(1, 28, 28) , switch = True, nrow=2):
  """Plot the first `num_images` images of a batch as a single grid.

  Args:
    image_tensor: tensor whose elements can be viewed as (-1, *size).
    num_images: how many images from the front of the batch to draw.
    size: per-image (C, H, W) shape used to un-flatten `image_tensor`.
    switch: if True, values are assumed to lie in [0, 1]; they are rescaled
      to 0..255 and cast to uint8 before plotting.
    nrow: images per grid row (forwarded to `make_grid`).
  """
  image_unflat = image_tensor.detach().cpu().view(-1, *size)
  image_grid = make_grid(image_unflat[:num_images], nrow=nrow, normalize=False)
  if switch:
    # imshow clips float RGB data to [0, 1]; after scaling to 0..255 the data
    # must be integer, otherwise every non-black pixel saturates to white.
    image_grid = (image_grid * 255.0).clamp(0, 255).to(torch.uint8)
  plt.imshow(image_grid.permute(1 , 2, 0).squeeze())
  plt.show()

In [5]:
# Use the CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [6]:
# ImageFolder (with its default loader) yields PIL images, so resize the PIL
# image directly and convert it to a float tensor in [0, 1] exactly once.
# The previous ToTensor -> ToPILImage -> ... -> ToTensor round trip was
# redundant, and ToPILImage's truncating float->uint8 cast could re-quantise pixels.
transform = transforms.Compose([transforms.Resize((512 , 512)) ,
                                transforms.ToTensor()])

low_res_dataset = torchvision.datasets.ImageFolder("/content/drive/MyDrive/Wider_Face_Dataset/Cleaned_Dataset/", transform=transform)

In [7]:
# High-resolution counterparts, loaded with the same `transform` as the low-resolution set.
high_res_dataset = torchvision.datasets.ImageFolder(
    root="/content/drive/MyDrive/Wider_Face_Dataset/High_res_dataset",
    transform=transform,
)

In [8]:
class Low_High_Dataset(Dataset):
  """Pair two datasets index-by-index into (low_res, high_res) samples.

  Each wrapped dataset must yield indexable samples whose element 0 is the
  image (e.g. ImageFolder's `(image, label)` tuples); everything else, such
  as the label, is dropped. Pairing is purely positional.
  NOTE(review): both datasets must list corresponding images in the same
  order — for ImageFolder that presumably means identical file names in both roots.
  """
  def __init__(self ,
               low_res_dataset ,
               high_res_dataset):
    super(Low_High_Dataset , self).__init__()

    self.low_res_dataset = low_res_dataset
    self.high_res_dataset = high_res_dataset

  def __len__(self):
    # Only indices valid in BOTH datasets can be paired. Using the low-res
    # length alone made __getitem__ raise IndexError when the high-res set was shorter.
    return min(len(self.low_res_dataset), len(self.high_res_dataset))

  def __getitem__(self , idx):
    """Return `(low_res_image, high_res_image)` at position `idx`."""
    x = self.low_res_dataset[idx][0]
    y = self.high_res_dataset[idx][0]
    return (x , y)

In [9]:
# Combine both ImageFolders into one dataset of (low_res, high_res) pairs.
dataset = Low_High_Dataset(low_res_dataset=low_res_dataset, high_res_dataset=high_res_dataset)

In [10]:
batch_size = 2
# `device` is already selected in the setup cell near the top of the notebook;
# the duplicate re-definition that used to live here was removed.

In [11]:
# Reshuffle every epoch; the last batch may be smaller than `batch_size` (drop_last is off).
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [None]:
# Show the DataLoader object as this cell's output (a quick check that the cell above ran).
dataloader

In [13]:

def crop(image, new_shape):
  """Center-crop dimensions 2 and 3 of `image` (height/width in an N, C, H, W layout).

  `new_shape` is indexed like a tensor shape; only entries 2 and 3 are read.
  All other dimensions of `image` are returned unchanged.
  """
  def _centered(full, wanted):
    # Window of length `wanted` whose midpoint sits at full // 2.
    start = full // 2 - wanted // 2
    return slice(start, start + wanted)

  rows = _centered(image.shape[2], new_shape[2])
  cols = _centered(image.shape[3], new_shape[3])
  return image[:, :, rows, cols]

In [14]:

class Helper_1(nn.Module):
  """Downsampling unit: Conv2d -> optional InstanceNorm2d -> ReLU or LeakyReLU(0.2).

  Args:
    in_channels, out_channels: convolution channel counts.
    kernel_size, stride: passed straight to nn.Conv2d. With the 2x2 / 2
      defults and no padding, H and W are halved (rounding down).
    use_batch_norm: if True, normalise the conv output with InstanceNorm2d
      (the name is historical — this is not BatchNorm).
    activation: 'relu' selects ReLU; any other value selects LeakyReLU(0.2).
  """
  def __init__(self , in_channels , out_channels , kernel_size = (2 , 2) , stride = (2 , 2) , use_batch_norm = True , activation = 'lrelu'):
    super(Helper_1 , self).__init__()

    self.use_batch_norm = use_batch_norm
    self.conv1 = nn.Conv2d(in_channels , out_channels , kernel_size , stride)
    # Always registered (default InstanceNorm2d has no parameters or buffers),
    # so the module layout does not depend on the flag.
    self.batch_norm = nn.InstanceNorm2d(out_channels)
    self.activation = activation
    if self.activation == 'relu':
      self.relu = nn.ReLU()
    else :
      self.relu = nn.LeakyReLU(0.2)

  def forward(self , x):
    x = self.conv1(x)
    # Honour the flag: the norm used to be applied unconditionally, silently
    # ignoring use_batch_norm=False.
    if self.use_batch_norm:
      x = self.batch_norm(x)
    x = self.relu(x)
    return x

In [15]:

class Encoder(nn.Module):
  """Compress an image batch into codes of shape (N, hidden_dim * 16, 1, 1).

  Eight Helper_1 stages (kernel 2, stride 2) each halve H and W (floor), so a
  square input of side `input_size` leaves a
  (hidden_dim * 32, input_size // 256, input_size // 256) feature map.
  That map is flattened and passed through Linear -> ReLU -> Linear.

  Args:
    in_channels: channels of the input image.
    hidden_dim: base channel width; the code has hidden_dim * 16 channels.
    out_channels: not used by this module; kept for interface compatibility.
    input_size: side length of the square input (default 512, matching the
      notebook's 512x512 transform).
      NOTE(review): InstanceNorm2d rejects 1x1 maps in training mode, so
      inputs smaller than 512 likely fail in conv8 — confirm before relying on them.

  Raises:
    ValueError: if input_size < 256 (the eight halvings would leave nothing).
  """
  def __init__(self , in_channels , hidden_dim , out_channels , input_size = 512):
    super(Encoder , self).__init__()

    if input_size < 256:
      raise ValueError(f"input_size must be >= 256, got {input_size}")

    self.conv1 = Helper_1(in_channels , hidden_dim , (2 , 2) , (2 , 2) , False , 'relu')
    self.conv2 = Helper_1(hidden_dim , hidden_dim * 2 , (2 , 2) , (2 , 2) , True , 'relu')
    self.conv3 = Helper_1(hidden_dim * 2 , hidden_dim * 4 , (2 , 2) , (2 , 2) , True , 'relu')
    self.conv4 = Helper_1(hidden_dim * 4 , hidden_dim * 8 , (2 , 2) , (2 , 2) ,True, 'relu')
    self.conv5 = Helper_1(hidden_dim * 8 , hidden_dim * 16 , (2 , 2) , (2 ,2) , True, 'relu')
    self.conv6 = Helper_1(hidden_dim * 16 , hidden_dim * 32 , (2 , 2) ,(2 , 2) , True, 'relu')
    self.conv7 = Helper_1(hidden_dim * 32 , hidden_dim * 64 , (2 , 2) , (2 , 2) , True, 'relu')
    self.conv8 = Helper_1(hidden_dim * 64 , hidden_dim * 32 , (2 , 2) , (2 , 2) , True, 'relu')
    self.flatten = nn.Flatten()
    # Eight floor-halvings shrink each side by a factor of 256. This used to be
    # hard-coded as 4096, which only matched hidden_dim=32 with 512x512 inputs.
    final_side = input_size // 256
    self.linear1 = nn.Linear(hidden_dim * 32 * final_side * final_side , hidden_dim * 16)
    # NOTE(review): registered but never used in forward(); kept so existing
    # state_dicts keep loading.
    self.batchnorm = nn.BatchNorm1d(hidden_dim * 16)
    self.relu = nn.ReLU()

    self.linear2 = nn.Linear(hidden_dim * 16 , hidden_dim * 16)


  def forward(self , x):
    x = self.conv1(x)
    x = self.conv2(x)
    x = self.conv3(x)
    x = self.conv4(x)
    x = self.conv5(x)
    x = self.conv6(x)
    x = self.conv7(x)
    x = self.conv8(x)
    x = self.flatten(x)
    x = self.relu(self.linear1(x))
    x = self.linear2(x)
    # Reshape the vector code into a 1x1 feature map for the transposed-conv decoder.
    x = x.view(x.shape[0] , x.shape[1] , 1 , 1)
    return x

In [16]:

class Helper_2(nn.Module):
  """Upsampling unit: ConvTranspose2d, then optional InstanceNorm2d, then optional activation.

  With the default 2x2 kernel and stride 2, the output has twice the input's
  height and width. `use_batchnorm` toggles an InstanceNorm2d (despite the
  name). `activation` picks LeakyReLU() for 'lrelu' and ReLU for 'relu';
  any other value leaves the output un-activated.
  """
  def __init__(self , in_channels , out_channels , kernel_size = (2 , 2) , stride = (2 , 2) , use_batchnorm = True , activation = 'relu'):
    super().__init__()

    self.use_batchnorm = use_batchnorm
    self.convT1 = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride)
    if use_batchnorm:
      self.batchnorm = nn.InstanceNorm2d(out_channels)

    self.activation = activation
    # Each activation module is stored under the attribute named by its key,
    # which lets forward() fetch it with getattr.
    if activation == 'relu':
      self.relu = nn.ReLU()
    elif activation == 'lrelu':
      self.lrelu = nn.LeakyReLU()

  def forward(self , x):
    out = self.convT1(x)
    if self.use_batchnorm:
      out = self.batchnorm(out)
    if self.activation in ('lrelu', 'relu'):
      out = getattr(self, self.activation)(out)
    return out

In [17]:

class Generator(nn.Module):
  """Two-input generator: encodes `x` into a 1x1 code and fuses it with `y`.

  Pipeline, with spatial sizes for 512x512 inputs:
    x -> Encoder -> (N, z_in_channels, 1, 1)
      -> 7 x Helper_2 (each doubles H and W) -> (N, hidden_dim * 32, 128, 128)
    y -> 2 x Helper_1 (each halves H and W) -> (N, hidden_dim * 32, 128, 128)
    concat on channels -> 2 x Helper_2 -> (N, out_channels, 512, 512)
  The branches are concatenated, so `y` must downsample to the 128-pixel side
  produced from `x` (i.e. roughly 512x512).
  The final stage uses ReLU, so outputs are >= 0 but unbounded above.

  Args:
    z_in_channels: channels of the encoder code. Must be a multiple of 16
      because the Encoder emits hidden * 16 channels.
    img_in_channels: channels of both input images `x` and `y`.
    hidden_dim: base channel width of the decoder and the `y` branch.
    out_channels: channels of the generated image.

  Raises:
    ValueError: if z_in_channels is not a multiple of 16.
  """
  def __init__(self , z_in_channels , img_in_channels , hidden_dim , out_channels):
    super(Generator , self).__init__()

    if z_in_channels % 16 != 0:
      raise ValueError(f"z_in_channels must be a multiple of 16, got {z_in_channels}")
    # The Encoder's code has (its hidden_dim) * 16 channels, which must equal
    # z_in_channels for convT1. This used to be hard-coded as Encoder(3, 32, 512),
    # ignoring img_in_channels; for the notebook's (512, 3, ...) arguments it is identical.
    self.encoder = Encoder(img_in_channels , z_in_channels // 16 , z_in_channels)

    self.convT1 = Helper_2(z_in_channels , hidden_dim , use_batchnorm=False)
    self.convT2 = Helper_2(hidden_dim , hidden_dim  *2 , use_batchnorm=True)
    self.convT3 = Helper_2(hidden_dim * 2 , hidden_dim * 4 , use_batchnorm=True)
    self.convT4 = Helper_2(hidden_dim * 4 , hidden_dim * 8 , use_batchnorm=True)
    self.convT5 = Helper_2(hidden_dim * 8 , hidden_dim * 16 , use_batchnorm=True)
    self.convT6 = Helper_2(hidden_dim * 16 , hidden_dim  * 32 , use_batchnorm=True)
    self.convT7 = Helper_2(hidden_dim * 32 , hidden_dim * 32 , use_batchnorm=False)

    self.conv1 = Helper_1(img_in_channels , hidden_dim , use_batch_norm=False)
    self.conv2 = Helper_1(hidden_dim , hidden_dim * 32 , use_batch_norm=False)

    self.convT_1 = Helper_2(hidden_dim * 32 * 2 , hidden_dim * 32 , use_batchnorm=False)
    # Was hard-coded to 3 output channels regardless of out_channels.
    self.convT_2 = Helper_2(hidden_dim * 32 , out_channels , use_batchnorm=False , activation='relu')


  def forward(self , x , y):
    # Code branch: encode x, then upsample 1x1 -> 128x128.
    x = self.encoder(x)
    x = self.convT1(x)
    x = self.convT2(x)
    x = self.convT3(x)
    x = self.convT4(x)
    x = self.convT5(x)
    x = self.convT6(x)
    x = self.convT7(x)

    # Image branch: downsample y by 4 to meet the code branch.
    y = self.conv1(y)
    y = self.conv2(y)

    z = torch.cat([x , y] , dim=1)

    z = self.convT_1(z)
    z = self.convT_2(z)

    return z
    

In [18]:
class Discriminator(nn.Module):
  """Image critic returning one raw score (logit) per output channel.

  Eight Helper_1 stages (kernel 2, stride 2) shrink H and W by 256 (floor).
  For a 512x512 input that leaves a (hidden_dim * 32, 2, 2) map, which is
  flattened and passed through three Linear+ReLU layers and a final Linear,
  giving (N, out_channels).

  The output is *logits*. Pair it with nn.BCEWithLogitsLoss, or apply
  torch.sigmoid for probabilities.

  Args:
    in_channels: channels of the input image.
    hidden_dim: base channel width.
    out_channels: number of output scores per image.
    input_size: side length of the square input (default 512).

  Raises:
    ValueError: if input_size < 256.
  """
  def __init__(self , in_channels , hidden_dim , out_channels , input_size = 512):
    super(Discriminator , self).__init__()

    if input_size < 256:
      raise ValueError(f"input_size must be >= 256, got {input_size}")

    self.conv1 = Helper_1(in_channels , hidden_dim , use_batch_norm=False)
    self.conv2 = Helper_1(hidden_dim , hidden_dim * 2 )
    self.conv3 = Helper_1(hidden_dim * 2 , hidden_dim * 4)
    self.conv4 = Helper_1(hidden_dim * 4 , hidden_dim * 8)
    self.conv5 = Helper_1(hidden_dim * 8 , hidden_dim * 16)
    self.conv6 = Helper_1(hidden_dim * 16 , hidden_dim * 32)
    self.conv7 = Helper_1(hidden_dim * 32 , hidden_dim * 32)
    self.conv8 = Helper_1(hidden_dim * 32 , hidden_dim * 32)
    self.flatten = nn.Flatten()

    self.relu = nn.ReLU()
    # Was hard-coded to 4096, which only matched hidden_dim=32 with 512x512 inputs.
    final_side = input_size // 256
    self.linear1 = nn.Linear(hidden_dim * 32 * final_side * final_side , hidden_dim * 32)
    # NOTE(review): batchnorm1..3 are registered but never used in forward();
    # kept so existing state_dicts keep loading.
    self.batchnorm1 = nn.BatchNorm1d(hidden_dim * 32)
    self.linear2 = nn.Linear(hidden_dim * 32 , hidden_dim * 8)
    self.batchnorm2 = nn.BatchNorm1d(hidden_dim * 8)
    self.linear3 = nn.Linear(hidden_dim * 8 , hidden_dim)
    self.batchnorm3 = nn.BatchNorm1d(hidden_dim)
    self.linear4 = nn.Linear(hidden_dim , out_channels)
    # Kept for compatibility; forward() no longer applies it (see below).
    self.sigmoid = nn.Sigmoid()


  def forward(self , x):
    x = self.conv1(x)
    x = self.conv2(x)
    x = self.conv3(x)
    x = self.conv4(x)
    x = self.conv5(x)
    x = self.conv6(x)
    x = self.conv7(x)
    x = self.conv8(x)
    x = self.flatten(x)
    x = self.relu(self.linear1(x))
    x = self.relu(self.linear2(x))
    x = self.relu(self.linear3(x))
    # Return logits. The training loss is BCEWithLogitsLoss, which applies the
    # sigmoid itself; sigmoid-ing here too squashed scores twice, into (0.5, 0.73).
    return self.linear4(x)

In [19]:
# Three generators with identical architecture, constructed in x, y, _ order.
generator_x, generator_y, generator_ = [
    Generator(z_in_channels=512, img_in_channels=3, hidden_dim=32, out_channels=3).to(device)
    for _gen_idx in range(3)
]

In [20]:
# Two discriminators with identical architecture (3-channel input, 1 score each), built x then y.
discriminator_x, discriminator_y = [
    Discriminator(in_channels=3, hidden_dim=32, out_channels=1).to(device)
    for _disc_idx in range(2)
]

In [21]:

def weights_init(m):
  """DCGAN-style initialiser meant to be used via `nn.Module.apply`.

  Conv2d / ConvTranspose2d weights are drawn from N(0, 0.02). BatchNorm2d
  scales are drawn from N(1, 0.02) and biases are zeroed. Every other module
  (Linear, InstanceNorm, BatchNorm1d, ...) keeps PyTorch's default initialisation.
  """
  if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
    torch.nn.init.normal_(m.weight, 0.0, 0.02)
  # affine=False BatchNorm has no weight/bias to initialise.
  if isinstance(m, nn.BatchNorm2d) and m.weight is not None:
    # The BN scale is centred on 1 (identity); N(0, 0.02) would nearly zero
    # every normalised activation. Uses the in-place init functions; the
    # un-suffixed normal/constant are deprecated.
    torch.nn.init.normal_(m.weight, 1.0, 0.02)
    torch.nn.init.constant_(m.bias, 0)

In [22]:

# nn.Module.apply initialises in place and returns the same module, so no
# reassignment is needed. The order is preserved because every init draws
# from the global RNG.
for _net in (generator_x, generator_y, discriminator_x, discriminator_y, generator_):
  _net.apply(weights_init)

In [23]:
# BCE-on-logits criterion (used for discriminator scores and inside get_loss),
# plus an L1 term weighted by lambda_recon in get_loss.
criterion, l1_loss = nn.BCEWithLogitsLoss(), nn.L1Loss()
lambda_recon = 200

In [24]:


# Optimiser settings (shared by all Adam optimisers).
lr = 2e-4
betas = (0.5, 0.999)
# Loop length and logging cadence.
n_epochs = 100
display_step = 10
# Image geometry used when plotting: channel counts and square side length.
input_dim = 3
real_dim = 3
target_shape = 512

In [25]:
# Global step counter and the running loss averages reset at each log point.
cur_step = mean_generator_loss = mean_discriminator_loss = 0

In [26]:
def _adam(module):
  """Adam over `module`'s parameters using the notebook-wide `lr` and `betas`."""
  return torch.optim.Adam(module.parameters(), lr=lr, betas=betas)

opt_generator_y = _adam(generator_y)
opt_generator_x = _adam(generator_x)
# NOTE(review): opt_discriminator_x is never stepped in the training loop below.
opt_discriminator_x = _adam(discriminator_x)
opt_discriminator_y = _adam(discriminator_y)
opt_generator_ = _adam(generator_)

In [27]:
def get_loss(fake , real  , criterion = criterion , l1_loss = l1_loss , lambda_recon = lambda_recon):
  """Reconstruction loss between a generated batch and its target.

  Returns `criterion(fake, real) + lambda_recon * l1_loss(fake, real)`.
  With the notebook's BCEWithLogitsLoss criterion, `fake` is treated as logits
  and `real` as soft targets.
  NOTE(review): the generators' outputs come from a final ReLU, not logits —
  confirm this pairing is intended.
  The defaults are bound when this cell runs; re-run it after changing
  `lambda_recon` or the loss objects.
  """
  reconstruction = l1_loss(fake, real)
  return criterion(fake, real) + lambda_recon * reconstruction

In [None]:
# Autograd anomaly detection records a traceback for every op and checks
# backward outputs for NaNs, which slows every training step considerably.
# Flip to True only while debugging a NaN / in-place modification error.
DETECT_ANOMALY = False
torch.autograd.set_detect_anomaly(DETECT_ANOMALY);

In [29]:
# Number of initial steps during which generator_y is trained; the training loop freezes it afterwards.
threshold: int = 80

In [None]:
# Main training loop.
# Naming inherited from the data pipeline: `real` is the low-resolution batch
# and `condition` is the high-resolution batch.
#   generator_y     : high -> low (trained only for the first `threshold` steps)
#   generator_      : maps generator_y's output back towards the high-res batch
#   generator_x     : low -> high, fed generator_y / generator_ outputs
#   discriminator_y : trained to score `condition` as real and every generated batch as fake
# NOTE(review): no generator receives an adversarial loss from a discriminator,
# and discriminator_x / opt_discriminator_x are never used — confirm intent.
for epoch in range(n_epochs):
  for low_res , high_res in tqdm(dataloader):
    real = low_res.to(device)
    condition = high_res.to(device)
    cur_batch_size = len(condition)

    # --- generator_y: high -> low ---
    if cur_step < threshold:
      opt_generator_y.zero_grad()
      hl1 = generator_y(condition , real)
      hl2 = generator_y(condition , condition)

      # NOTE(review): hl1 is scored against the high-res batch and hl2 against
      # the low-res batch — confirm these targets are intended.
      loss_Y_XY = get_loss(hl1 , condition)
      loss_Y_YY = get_loss(hl2 , real)
      loss_Y = (loss_Y_XY + loss_Y_YY) /2

      loss_Y.backward()
      opt_generator_y.step()
    else:
      # generator_y is frozen from step `threshold` onwards. This was
      # `elif cur_step > threshold`, which skipped step == threshold and left a
      # stale loss_Y. It also ran generator_x into hl1/hl2, which are
      # recomputed below before any use.
      loss_Y = 0

    # --- generator_: refine generator_y's output towards high-res ---
    opt_generator_.zero_grad()
    with torch.no_grad():
      hl2 = generator_y(condition , condition)
    lh = generator_(hl2 , hl2)
    loss_ = get_loss(lh , condition)
    loss_.backward()
    opt_generator_.step()

    # --- generator_x: low -> high ---
    opt_generator_x.zero_grad()
    with torch.no_grad():
      hl1 = generator_y(condition , real)
      hl2 = generator_y(condition , condition)
      lh = generator_(hl2 , hl2)

    lh1 = generator_x(hl1 , lh)
    lh2 = generator_x(lh , lh)

    loss_X_XX = get_loss(lh1 , condition)
    loss_X_YX = get_loss(lh2 , condition)
    loss_X = (loss_X_XX + loss_X_YX) /2
    loss_X.backward()
    opt_generator_x.step()

    # --- discriminator_y: real high-res vs. every generated batch ---
    opt_discriminator_y.zero_grad()
    with torch.no_grad():
      disc_hl1 = generator_y(condition , real)
      disc_hl2 = generator_y(condition ,condition)
      disc_lh1 = generator_x(disc_hl1 , condition)
      # NOTE(review): identical inputs to disc_lh1, so this duplicates it —
      # probably meant to use disc_hl2. Left unchanged pending confirmation.
      disc_lh2 = generator_x(disc_hl1 , condition)
      disc_lh = generator_(disc_hl2 , disc_hl2)

    disc_fake_lh = discriminator_y(disc_lh)
    disc_ = criterion(disc_fake_lh , torch.zeros_like(disc_fake_lh))

    disc_fake_hl1 = discriminator_y(disc_hl1)
    disc_loss_fake_pred_hl1 = criterion(disc_fake_hl1 , torch.zeros_like(disc_fake_hl1))

    disc_fake_hl2 = discriminator_y(disc_hl2)
    disc_loss_fake_pred_hl2 = criterion(disc_fake_hl2 , torch.zeros_like(disc_fake_hl2))

    disc_fake_lh1 = discriminator_y(disc_lh1)
    disc_loss_fake_pred_lh1 = criterion(disc_fake_lh1 , torch.zeros_like(disc_fake_lh1))

    disc_fake_lh2 = discriminator_y(disc_lh2)
    disc_loss_fake_pred_lh2 = criterion(disc_fake_lh2 , torch.zeros_like(disc_fake_lh2))


    disc_real_pred = discriminator_y(condition)
    disc_real_pred_loss = criterion(disc_real_pred , torch.ones_like(disc_real_pred))

    # Mean of six BCE terms: five generated batches labelled 0, one real batch labelled 1.
    disc_y_loss = (disc_loss_fake_pred_hl1 + disc_loss_fake_pred_hl2 + disc_real_pred_loss + disc_loss_fake_pred_lh1 + disc_loss_fake_pred_lh2 + disc_) /6

    disc_y_loss.backward()
    opt_discriminator_y.step()

    # --- logging ---
    disc_loss =  disc_y_loss
    # Reported generator loss excludes loss_ (generator_). After `threshold`,
    # loss_Y is 0, so this becomes loss_X / 2.
    gen_loss = (loss_X + loss_Y)/2

    mean_discriminator_loss += disc_loss.item() / display_step
    mean_generator_loss += gen_loss.item() / display_step

    if cur_step % display_step == 0:
      if cur_step > 0:
        print(f"Epoch {epoch}: Step {cur_step}: Generator loss: {mean_generator_loss}, Discriminator loss: {mean_discriminator_loss}")
      else:
        print("Pretrained initial state")


      print('Low_res_img')
      show_tensor_images(real, size=(real_dim, target_shape, target_shape) , switch=False)


      print('high --> low')
      show_tensor_images(hl1, size=(real_dim, target_shape, target_shape) , switch=False)
      print('high --> low')
      show_tensor_images(hl2, size=(real_dim, target_shape, target_shape) , switch= False)

      print('High_res_img')
      show_tensor_images(condition, size=(input_dim, target_shape, target_shape) , switch=False)

      print('low --> high')
      show_tensor_images(lh1, size=(real_dim, target_shape, target_shape) , switch=False)
      print('low --> high')
      show_tensor_images(lh2, size=(real_dim, target_shape, target_shape) , switch=False)

      print('low , low --> High')
      show_tensor_images(lh ,size=(real_dim, target_shape, target_shape) , switch=False )

      mean_generator_loss = 0
      mean_discriminator_loss = 0
    cur_step += 1