<a href="https://colab.research.google.com/github/VedantDere0104/Pix2Pix_GAN/blob/main/Pix2Pix_GAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
####

In [3]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torchsummary import summary
from torchvision import transforms
from torchvision.datasets import VOCSegmentation
from torchvision.utils import make_grid
from tqdm.auto import tqdm

torch.manual_seed(0)

In [4]:
# Use CUDA when PyTorch can see a GPU, otherwise run everything on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [5]:
def show_tensor_images(image_tensor, num_images=2, size=(1, 28, 28)):
  """Draw the first `num_images` images of a batch as one matplotlib grid.

  Args:
    image_tensor: tensor whose elements can be viewed as (-1, *size).
    num_images: number of leading images to include in the grid.
    size: (channels, height, width) of a single image.
  """
  images = image_tensor.detach().cpu().view(-1, *size)
  grid = make_grid(images[:num_images], nrow=5)
  # make_grid yields (C, H, W); imshow expects (H, W, C). squeeze() drops any singleton dims.
  plt.imshow(grid.permute(1, 2, 0).squeeze())
  plt.show()

In [6]:
def crop(image, new_shape):
    """Center-crop the last two dims of a 4-D tensor to new_shape's height and width.

    Only new_shape[2] (height) and new_shape[3] (width) are read; the first two
    dims of `image` are kept whole. The window starts at
    (dim // 2 - target // 2) along each spatial axis.

    Args:
        image: tensor indexed as image[:, :, h, w].
        new_shape: any sequence (e.g. another tensor's .shape) of length >= 4.

    Returns:
        A slice (view) of `image`.
    """
    top = image.shape[2] // 2 - new_shape[2] // 2
    left = image.shape[3] // 2 - new_shape[3] // 2
    bottom = top + new_shape[2]
    right = left + new_shape[3]
    return image[:, :, top:bottom, left:right]

In [7]:
def get_summary(model , shape):
  """Print torchsummary's layer table for `model` and return whatever summary() returns.

  `shape` is passed straight through as torchsummary's input size; in this notebook it is
  given as (channels, height, width) without a batch dimension.
  """
  table = summary(model, shape)
  return table

In [8]:
class Conv(nn.Module):
  """Conv2d followed by optional InstanceNorm2d, ReLU, Dropout and 2x2 max-pool, in that order.

  Args:
    in_channels: input channels of the convolution.
    out_channels: output channels of the convolution.
    kernel_size, stride, padding: forwarded to nn.Conv2d; the defaults (3, 1, 1) keep H and W.
    use_norm: apply nn.InstanceNorm2d(out_channels) after the convolution.
    use_activation: apply an in-place ReLU.
    use_dropout: apply nn.Dropout() with its default probability.
    use_pool: apply nn.MaxPool2d(2, 2), halving H and W.
  """

  def __init__(self,
               in_channels,
               out_channels,
               kernel_size=3,
               stride=1,
               padding=1,
               use_norm=True,
               use_activation=True,
               use_dropout=False,
               use_pool=True):
    super().__init__()

    self.use_norm, self.use_activation = use_norm, use_activation
    self.use_dropout, self.use_pool = use_dropout, use_pool

    self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)

    # Optional stages are registered only when enabled.
    if use_norm:
      self.norm = nn.InstanceNorm2d(out_channels)
    if use_activation:
      self.activation = nn.ReLU(inplace=True)
    if use_dropout:
      self.dropout = nn.Dropout()
    if use_pool:
      self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

  def forward(self, x):
    out = self.conv1(x)
    # (flag, submodule attribute) pairs, applied in this fixed order.
    stages = (
        (self.use_norm, "norm"),
        (self.use_activation, "activation"),
        (self.use_dropout, "dropout"),
        (self.use_pool, "maxpool"),
    )
    for enabled, attr in stages:
      if enabled:
        out = getattr(self, attr)(out)
    return out


In [None]:
# Shape check: a default Conv keeps 512x512 through its 3x3 conv, then the pool halves it.
conv = Conv(in_channels=3, out_channels=32).to(device)
get_summary(conv, shape=(3, 512, 512))

In [10]:
class ConvT(nn.Module):
  """ConvTranspose2d followed by optional InstanceNorm2d, ReLU and Dropout, in that order.

  With the default kernel_size=(2, 2) and stride=(2, 2) (no padding), H and W are doubled.

  Args:
    in_channels: input channels of the transposed convolution.
    out_channels: output channels of the transposed convolution.
    kernel_size, stride: forwarded to nn.ConvTranspose2d.
    use_norm: apply nn.InstanceNorm2d(out_channels).
    use_activation: apply an in-place ReLU.
    use_dropout: apply nn.Dropout() with its default probability.
  """

  def __init__(self,
               in_channels,
               out_channels,
               kernel_size=(2, 2),
               stride=(2, 2),
               use_norm=True,
               use_activation=True,
               use_dropout=False):
    super().__init__()

    self.use_norm, self.use_activation, self.use_dropout = use_norm, use_activation, use_dropout

    self.convT1 = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride)

    # Optional stages are registered only when enabled.
    if use_norm:
      self.norm = nn.InstanceNorm2d(out_channels)
    if use_activation:
      self.activation = nn.ReLU(inplace=True)
    if use_dropout:
      self.dropout = nn.Dropout()

  def forward(self, x):
    out = self.convT1(x)
    for enabled, attr in ((self.use_norm, "norm"),
                          (self.use_activation, "activation"),
                          (self.use_dropout, "dropout")):
      if enabled:
        out = getattr(self, attr)(out)
    return out

In [None]:
# Shape check: a default ConvT doubles a 256x256 input to 512x512.
convT = ConvT(in_channels=3, out_channels=32).to(device)
get_summary(convT, shape=(3, 256, 256))

In [12]:
class Decoder_block(nn.Module):
  """One decoder step: merge a skip tensor, upsample 2x with ConvT, then a 3x3 Conv.

  forward(x, y) center-crops the skip tensor `y` to x's height/width and concatenates it with
  `x` along channels. The concatenation must therefore have 2 * in_channels channels (the
  Generator passes in_channels for each of x and y).

  Args:
    in_channels: channel count of x (and, as used in this notebook, of y).
    out_channels: channels produced by the trailing Conv.
    use_norm: forwarded to the transposed conv only; the trailing Conv always normalizes.
    use_dropout: forwarded to the transposed conv only.
  """

  def __init__(self, in_channels, out_channels, use_norm=True, use_dropout=False):
    super().__init__()
    merged_channels = in_channels * 2
    self.convT1 = ConvT(merged_channels, merged_channels, use_norm=use_norm, use_dropout=use_dropout)
    self.conv1 = Conv(merged_channels, out_channels, use_pool=False)

  def forward(self, x, y):
    skip = crop(y, x.shape)
    merged = torch.cat([x, skip], dim=1)
    upsampled = self.convT1(merged)
    return self.conv1(upsampled)


In [None]:
# Shape check: y (512x512) is center-cropped to x's 256x256, the 6-channel concatenation is
# upsampled back to 512x512, and the trailing Conv projects to 64 channels.
decoder_block = Decoder_block(3, 64).to(device)
x = torch.randn(2, 3, 256, 256).to(device)
y = torch.randn(2, 3, 512, 512).to(device)
z = decoder_block(x, y)
assert z.shape == (2, 64, 512, 512), z.shape
z.shape

In [14]:
class Generator(nn.Module):
  """U-Net style generator: six pooling encoder Convs, a bottleneck Conv, six upsampling decoders.

  Each encoder Conv halves H and W with its 2x2 max-pool and each Decoder_block doubles them
  again after concatenating the matching encoder output. The input height and width should
  therefore be multiples of 64 for the output to match the input size. The output passes
  through a sigmoid, so its values lie in (0, 1).

  Args:
    in_channels: channels of the conditioning image.
    out_channels: channels of the generated image.
    hidden_dim: base width; encoder widths are hidden_dim * 2, 4, 8, 16, 32, 64.
  """

  def __init__(self ,
               in_channels ,
               out_channels ,
               hidden_dim = 32):
    super(Generator , self).__init__()

    # Encoder: every Conv halves the spatial size via its max-pool.
    self.conv1 = Conv(in_channels , hidden_dim * 2 , use_norm=False)
    self.conv2 = Conv(hidden_dim * 2 , hidden_dim * 4)
    self.conv3 = Conv(hidden_dim * 4 , hidden_dim * 8)
    self.conv4 = Conv(hidden_dim * 8 , hidden_dim * 16)
    self.conv5 = Conv(hidden_dim * 16 , hidden_dim * 32 , use_norm=False)
    self.conv6 = Conv(hidden_dim * 32 , hidden_dim * 64 , use_norm=False)

    # Bottleneck: no pooling, so it stays at conv6's resolution.
    self.middle = Conv(hidden_dim * 64 , hidden_dim * 64 , use_pool=False , use_norm=True)

    # Decoder: each block merges the same-resolution encoder output, then doubles H and W.
    self.decoder_1 = Decoder_block(hidden_dim * 64 , hidden_dim * 32 , use_norm=False)
    self.decoder_2 = Decoder_block(hidden_dim * 32 , hidden_dim * 16)
    self.decoder_3 = Decoder_block(hidden_dim * 16 , hidden_dim * 8)
    self.decoder_4 = Decoder_block(hidden_dim * 8 , hidden_dim * 4)
    self.decoder_5 = Decoder_block(hidden_dim * 4 , hidden_dim * 2 , use_norm=False)
    self.decoder_6 = Decoder_block(hidden_dim * 2 , hidden_dim , use_norm=False)

    # Plain 3x3 projection to out_channels (no norm, activation or pool), then a sigmoid.
    self.last = Conv(hidden_dim , out_channels , use_pool=False , use_norm=False , use_activation=False)
    self.sigmoid = nn.Sigmoid()

  def forward(self , x):
    x1 = self.conv1(x)
    x2 = self.conv2(x1)
    x3 = self.conv3(x2)
    x4 = self.conv4(x3)
    x5 = self.conv5(x4)
    x6 = self.conv6(x5)
    x_middle = self.middle(x6)

    # Skip connections pair each decoder input with the encoder output at the same resolution.
    x7 = self.decoder_1(x_middle , x6)
    x8 = self.decoder_2(x7 , x5)
    x9 = self.decoder_3(x8 , x4)
    x10 = self.decoder_4(x9 , x3)
    x11 = self.decoder_5(x10 , x2)
    x12 = self.decoder_6(x11 , x1)

    return self.sigmoid(self.last(x12))


In [None]:
# Summary of the full generator at the training resolution (3-channel 512x512 in, 3 channels out).
generator = Generator(in_channels=3, out_channels=3).to(device)
get_summary(generator, shape=(3, 512, 512))
# This inspection copy is large and unused later; free it so it does not sit on the device
# alongside the `gen` that is actually trained.
del generator

In [16]:
class Discriminator(nn.Module):
  """Whole-image real/fake classifier: seven pooling Conv blocks, then a small MLP.

  forward() ends in nn.Sigmoid, so it returns probabilities in (0, 1) with shape
  (batch, out_channels). Pair it with nn.BCELoss; nn.BCEWithLogitsLoss would apply a
  second sigmoid. Only the image is scored; the conditioning input is not concatenated in.

  Args:
    in_channels: channels of the input image.
    out_channels: width of the output layer (the notebook uses 1).
    hidden_dim: base channel width of the conv stack.
    input_size: height/width of the square images fed in. The seven 2x2 max-pools shrink
      each side to input_size // 2**7, and conv7 emits a single channel, so the first
      Linear sees (input_size // 128) ** 2 features. Defaults to 512 (16 features).

  Raises:
    ValueError: if input_size < 128, which would leave no spatial features for the head.
  """

  def __init__(self ,
               in_channels ,
               out_channels ,
               hidden_dim = 32 ,
               input_size = 512):
    super(Discriminator , self).__init__()

    self.conv1 = Conv(in_channels , hidden_dim)
    self.conv2 = Conv(hidden_dim , hidden_dim * 2)
    self.conv3 = Conv(hidden_dim * 2 , hidden_dim * 4)
    self.conv4 = Conv(hidden_dim * 4 , hidden_dim * 8)
    self.conv5 = Conv(hidden_dim * 8 , hidden_dim * 16)
    self.conv6 = Conv(hidden_dim * 16 , hidden_dim * 32)
    self.conv7 = Conv(hidden_dim * 32 , 1)

    # conv1..conv7 each pool by 2, so each side shrinks by 2**7.
    side = input_size // 2 ** 7
    if side < 1:
      raise ValueError(f"input_size must be at least 128, got {input_size}")

    self.flatten = nn.Flatten()

    self.linear1 = nn.Linear(side * side , 8)
    self.linear2 = nn.Linear(8 , out_channels)
    self.relu = nn.ReLU(inplace=True)
    # In training mode BatchNorm1d needs more than one sample per batch.
    self.batchnorm = nn.BatchNorm1d(8)

    self.sigmoid = nn.Sigmoid()

  def forward(self , x):
    x = self.conv1(x)
    x = self.conv2(x)
    x = self.conv3(x)
    x = self.conv4(x)
    x = self.conv5(x)
    x = self.conv6(x)
    x = self.conv7(x)
    x = self.flatten(x)
    x = self.relu(self.batchnorm(self.linear1(x)))
    return self.sigmoid(self.linear2(x))

In [None]:
# Discriminator summary at 512x512, the resolution its linear head is sized for by default.
discriminator = Discriminator(3 , 1).to(device)
get_summary(discriminator , (3 , 512 , 512))

In [18]:
# The Discriminator ends in nn.Sigmoid, so its outputs are already probabilities. BCELoss is
# the matching criterion; BCEWithLogitsLoss would apply a second sigmoid on top.
adv_criterion = nn.BCELoss()
recon_criterion = nn.L1Loss()
# Multiplier on the L1 reconstruction term of the generator loss.
lambda_recon = 200

n_epochs = 200
in_channels = 3
out_channels = 3
display_step = 100   # steps between loss printouts / image previews in train()
batch_size = 2
lr = 0.0002
target_shape = 512   # side length both image halves are resized to in train()



In [19]:

# ToTensor turns each PIL image into a float CHW tensor (8-bit images are scaled to [0, 1]).
transform = transforms.Compose([
    transforms.ToTensor()
])

# train() treats each image as a side-by-side pair: left half = input, right half = target.
# NOTE(review): ImageFolder makes every subfolder a "class" and yields all of their images, so
# if this root holds both train/ and val/ splits they are merged into one training set. Confirm
# that is intended.
DATA_DIR = "/content/drive/MyDrive/Maps/maps/"
dataset = torchvision.datasets.ImageFolder(DATA_DIR, transform=transform)

In [20]:
# Build both networks on `device` (generator first), each with its own Adam optimizer.
gen = Generator(in_channels, out_channels).to(device)
disc = Discriminator(in_channels, 1).to(device)
gen_opt = torch.optim.Adam(gen.parameters(), lr=lr)
disc_opt = torch.optim.Adam(disc.parameters(), lr=lr)

In [21]:
def weights_init(m):
    """DCGAN-style initialisation for a single module; intended for nn.Module.apply.

    - Conv2d / ConvTranspose2d: weights ~ N(0, 0.02); biases untouched.
    - Affine BatchNorm2d: weights ~ N(1, 0.02) and biases = 0. A mean of 1 keeps the
      normalised activations roughly unscaled at start (a mean of 0 would shrink them
      towards zero).
    - Any other module (e.g. InstanceNorm2d, BatchNorm1d, Linear) is left as PyTorch
      initialised it.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    # affine=False BatchNorm has no weight/bias tensors to initialise.
    if isinstance(m, nn.BatchNorm2d) and m.affine:
        torch.nn.init.normal_(m.weight, 1.0, 0.02)
        torch.nn.init.constant_(m.bias, 0)

In [None]:
# Module.apply initialises in place and returns the same module, so the optimizers created
# earlier keep referencing these parameters. Generator first, then discriminator.
gen, disc = gen.apply(weights_init), disc.apply(weights_init)

In [23]:
def get_loss(real ,
             condition ,
             gen = None ,
             disc = None ,
             adv_criterion = None ,
             recon_criterion = None ,
             lambda_recon = None):
  """Generator loss: adversarial term plus lambda_recon-weighted reconstruction term.

  The generator is rewarded when the discriminator labels its output as real (target 1),
  and penalised for its distance from the paired target image:

      fake = gen(condition)
      loss = adv_criterion(disc(fake), 1) + lambda_recon * recon_criterion(fake, real)

  Args:
    real: target images, same shape as gen(condition).
    condition: generator input images.
    gen, disc, adv_criterion, recon_criterion, lambda_recon: optional overrides. Any left as
      None is looked up in the notebook's globals at call time, so re-running the cells that
      rebuild the models or criteria takes effect without re-running this cell.

  Returns:
    The loss tensor (a scalar with the criteria's default 'mean' reduction); gradients flow
    into gen through disc.
  """
  env = globals()
  gen = env["gen"] if gen is None else gen
  disc = env["disc"] if disc is None else disc
  adv_criterion = env["adv_criterion"] if adv_criterion is None else adv_criterion
  recon_criterion = env["recon_criterion"] if recon_criterion is None else recon_criterion
  lambda_recon = env["lambda_recon"] if lambda_recon is None else lambda_recon

  fake = gen(condition)
  disc_fake_pred = disc(fake)
  # Fool the discriminator: score its verdict on the fakes against "real" (all-ones) labels.
  adv_loss = adv_criterion(disc_fake_pred , torch.ones_like(disc_fake_pred))
  recon_loss = recon_criterion(fake , real)
  return adv_loss + lambda_recon * recon_loss

In [24]:

# drop_last=True: the Discriminator's BatchNorm1d raises in training mode on a batch of one,
# which the trailing partial batch would be whenever len(dataset) % batch_size == 1.
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)
# Channel counts of the input (condition) and target halves; used to reshape previews in train().
input_dim = 3
real_dim = 3

In [26]:
def train():
  """Run the adversarial training loop over `dataloader` for `n_epochs` epochs.

  Each batch image is split at half its width: the left half is the condition (generator
  input) and the right half the real target. Both are resized to target_shape x target_shape
  and moved to `device`. Each step updates the discriminator on (fake, real) and then the
  generator via get_loss. Every `display_step` steps it prints the running mean losses and
  shows condition / real / generated previews.

  Relies on notebook globals: gen, disc, gen_opt, disc_opt, adv_criterion, recon_criterion,
  lambda_recon, dataloader, device and the config constants.
  """
  mean_generator_loss = 0
  mean_discriminator_loss = 0
  cur_step = 0
  for epoch in range(n_epochs):
    for img , _ in tqdm(dataloader):
      image_width = img.shape[3]
      # Side-by-side pair: left half = condition, right half = target.
      condition = nn.functional.interpolate(img[: , : , : , :image_width//2] , size = target_shape)
      real = nn.functional.interpolate(img[: , : , : , image_width//2:] , size = target_shape)
      real = real.to(device)
      condition = condition.to(device)

      # Discriminator update. Fakes are produced without tracking generator gradients.
      disc_opt.zero_grad()
      with torch.no_grad():
        fake_img = gen(condition)
      disc_fake_pred = disc(fake_img)
      disc_real_pred = disc(real)
      disc_fake_loss = adv_criterion(disc_fake_pred , torch.zeros_like(disc_fake_pred))
      disc_real_loss = adv_criterion(disc_real_pred , torch.ones_like(disc_real_pred))
      disc_loss = (disc_fake_loss + disc_real_loss)/2
      disc_loss.backward()
      disc_opt.step()

      # Generator update. Pass the current globals explicitly so the loss is computed on the
      # same `gen` whose parameters gen_opt updates, even if get_loss was defined earlier.
      gen_opt.zero_grad()
      gen_loss = get_loss(real , condition ,
                          gen=gen , disc=disc ,
                          adv_criterion=adv_criterion ,
                          recon_criterion=recon_criterion ,
                          lambda_recon=lambda_recon)
      gen_loss.backward()
      gen_opt.step()

      mean_discriminator_loss += disc_loss.item() / display_step
      mean_generator_loss += gen_loss.item() / display_step

      if cur_step % display_step == 0:
        if cur_step > 0:
          print(f"Epoch {epoch}: Step {cur_step}: Generator loss: {mean_generator_loss}, Discriminator loss: {mean_discriminator_loss}")
        else:
          print("Pretrained initial state")
        print('condition')
        show_tensor_images(condition, size=(input_dim, target_shape, target_shape))
        print('real')
        show_tensor_images(real, size=(real_dim, target_shape, target_shape) )
        print('Fake_generated_img')
        show_tensor_images(fake_img, size=(real_dim, target_shape, target_shape) )
        mean_generator_loss = 0
        mean_discriminator_loss = 0
      cur_step += 1

In [None]:
# Launch training. train() reads its models, optimizers, dataloader and settings from the
# cells above, so run those first.
train()

In [28]:
#torch.save(gen.state_dict() , '/content/drive/MyDrive/Map_Dataset_Gen_1.pth')

In [29]:
#torch.save(disc.state_dict() , '/content/drive/MyDrive/Map_Dataset_Disc_1.pth')

In [None]:
#gen.load_state_dict(torch.load('/content/drive/MyDrive/Map_Dataset_Gen_1.pth'))

In [None]:
#gen.load_state_dict(torch.load('/content/drive/MyDrive/Map_Dataset_Gen_1.pth'))
#disc.load_state_dict(torch.load('/content/drive/MyDrive/Map_Dataset_Disc_1.pth'))