<a href="https://colab.research.google.com/github/VedantDere0104/Dual_Directional_GAN/blob/main/Dual_Directional_GAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
####

In [2]:
import torch
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.datasets import VOCSegmentation
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
from torchsummary import summary

In [3]:
# Run-wide configuration shared by the cells below.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed torch's global RNG so weight initialisation, the random smoke-test
# inputs and DataLoader shuffling repeat across Restart & Run All.
# (Some CUDA kernels may still be non-deterministic.)
seed = 0
torch.manual_seed(seed)

in_channels = 3          # channels of the input images
out_channels = 3         # channels of the generated images
out_channels_disc = 1    # NOTE(review): Discriminator has no out_channels parameter
z_dim = 512              # channel width of the encoder bottleneck codes

In [4]:
def show_tensor_images(image_tensor, num_images=2, size=(1, 28, 28)):
  """Plot the first ``num_images`` images of a batch as a single grid.

  Args:
    image_tensor: tensor that is reshaped with ``view(-1, *size)``, so it must
      hold a whole number of ``size``-shaped images (values presumably in
      [0, 1] — TODO confirm).
    num_images: how many leading images to draw.
    size: (channels, height, width) of one image.
  """
  batch = image_tensor.detach().cpu().view(-1, *size)
  grid = make_grid(batch[:num_images], nrow=5)
  # make_grid returns (C, H, W); matplotlib's imshow expects channels last.
  hwc = grid.permute(1, 2, 0).squeeze()
  plt.imshow(hwc)
  plt.show()

In [5]:
def crop(image, new_shape):
    """Centre-crop a 4-D tensor to the height/width of ``new_shape``.

    Only indices 2 and 3 of ``new_shape`` are read; batch and channel axes of
    ``image`` are kept whole. Offsets use integer halving, so for odd size
    differences the extra row/column is dropped from the top/left side.

    Args:
        image: tensor indexed as (batch, channels, height, width).
        new_shape: any sequence whose [2] and [3] give the target H and W
            (e.g. another tensor's ``.shape``).

    Returns:
        A view of ``image`` with spatial size (new_shape[2], new_shape[3]).
    """
    target_h, target_w = new_shape[2], new_shape[3]
    top = image.shape[2] // 2 - target_h // 2
    left = image.shape[3] // 2 - target_w // 2
    return image[:, :, top:top + target_h, left:left + target_w]

In [6]:
class Conv(nn.Module):
  """Conv2d block with optional BatchNorm, ReLU, Dropout and 2x2 max-pool.

  Stages always run in the order conv -> norm -> activation -> dropout -> pool.
  Each stage after the convolution exists only if its ``use_*`` flag is set.
  With the defaults (3x3 kernel, stride 1, padding 1, pooling on), the
  convolution keeps H and W and the max-pool then halves them.

  Args:
    in_channels: channels of the input feature map.
    out_channels: channels produced by the convolution.
    kernel_size, stride, padding: forwarded to ``nn.Conv2d``.
    use_norm: add ``BatchNorm2d`` after the convolution.
    use_activation: add an in-place ``ReLU``.
    use_dropout: add ``Dropout`` with its default probability.
    use_pool: add a 2x2, stride-2 ``MaxPool2d``.
  """
  def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
               padding=1, use_norm=True, use_activation=True,
               use_dropout=False, use_pool=True):
    super().__init__()

    # forward() consults these flags to decide which stages to run.
    self.use_norm = use_norm
    self.use_activation = use_activation
    self.use_dropout = use_dropout
    self.use_pool = use_pool

    # Submodule names (conv1/norm/activation/dropout/maxpool) are the
    # state_dict keys; keep them and their registration order stable.
    self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
                           stride=stride, padding=padding)
    if use_norm:
      self.norm = nn.BatchNorm2d(out_channels)
    if use_activation:
      self.activation = nn.ReLU(inplace=True)
    if use_dropout:
      self.dropout = nn.Dropout()
    if use_pool:
      self.maxpool = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))

  def forward(self, x):
    """Apply the convolution, then every enabled optional stage in order."""
    out = self.conv1(x)
    optional_stages = (
        (self.use_norm, 'norm'),
        (self.use_activation, 'activation'),
        (self.use_dropout, 'dropout'),
        (self.use_pool, 'maxpool'),
    )
    for enabled, attr in optional_stages:
      if enabled:
        out = getattr(self, attr)(out)
    return out

In [None]:
# Smoke test: with default settings Conv halves H and W, so a 3x512x512 input
# should be summarised as a 32x256x256 output.
conv = Conv(3, out_channels=32).to(device)
summary(conv, input_size=(3, 512, 512))

In [8]:
class ConvT(nn.Module):
  """Decoder block: concatenate a centre-cropped skip tensor, then upsample.

  ``forward(x, y)`` centre-crops ``y`` to ``x``'s height and width and
  concatenates the two along channels. It then applies ``ConvTranspose2d``,
  followed by the optional InstanceNorm, LeakyReLU(0.2) and Dropout stages.
  The transposed conv is built for ``2 * in_channels`` inputs, so the
  concatenation must have exactly that many channels. With the default 2x2
  kernel, stride 2 and padding 0, H and W double.

  Args:
    in_channels: channels of ``x`` (the conv takes twice this many).
    out_channels: channels produced by the transposed convolution.
    kernel_size, stride, padding: forwarded to ``nn.ConvTranspose2d``.
    use_norm: add ``InstanceNorm2d``.
    use_activation: add ``LeakyReLU(0.2)``.
    use_dropout: add ``Dropout`` with its default probability.
  """
  def __init__(self, in_channels, out_channels, kernel_size=(2, 2),
               stride=(2, 2), padding=0, use_norm=True,
               use_activation=True, use_dropout=False):
    super().__init__()

    self.use_norm = use_norm
    self.use_activation = use_activation
    self.use_dropout = use_dropout

    # Input width is doubled because forward() concatenates the skip tensor.
    self.convT1 = nn.ConvTranspose2d(in_channels * 2, out_channels,
                                     kernel_size=kernel_size, stride=stride,
                                     padding=padding)
    if use_norm:
      self.norm = nn.InstanceNorm2d(out_channels)
    if use_activation:
      self.activation = nn.LeakyReLU(0.2)
    if use_dropout:
      self.dropout = nn.Dropout()

  def forward(self, x, y):
    """Upsample ``x`` after concatenating the centre crop of skip tensor ``y``."""
    skip = crop(y, x.shape)
    out = self.convT1(torch.cat([x, skip], dim=1))
    optional_stages = (
        (self.use_norm, 'norm'),
        (self.use_activation, 'activation'),
        (self.use_dropout, 'dropout'),
    )
    for enabled, attr in optional_stages:
      if enabled:
        out = getattr(self, attr)(out)
    return out

In [None]:
# Smoke test: y is centre-cropped to x's 256x256, concatenated (256 channels)
# and upsampled; the expected result is torch.Size([2, 32, 512, 512]).
convT = ConvT(128, out_channels=32).to(device)
x = torch.randn(2, 128, 256, 256).to(device)
y = torch.randn(2, 128, 512, 512).to(device)
z = convT(x, y)
z.shape

In [10]:
class Encoder_Source(nn.Module):
  """Eight stacked pooling ``Conv`` blocks (each halves H and W).

  Channel widths run hidden_dim, 2h, 4h, 8h, 16h, 32h, 64h and finally
  ``z_dim``. BatchNorm is used only in the middle four blocks (conv3-conv6).
  A 512x512 input therefore leaves as a (z_dim, 2, 2) map (512 / 2**8 = 2).

  Args:
    in_channels: channels of the input image.
    z_dim: channels of the final block.
    hidden_dim: width of the first block; doubled by each of the next six.
  """
  def __init__(self, in_channels, z_dim=512, hidden_dim=32):
    super().__init__()
    h = hidden_dim
    self.conv1 = Conv(in_channels, h, use_norm=False)
    self.conv2 = Conv(h, 2 * h, use_norm=False)
    self.conv3 = Conv(2 * h, 4 * h)
    self.conv4 = Conv(4 * h, 8 * h)
    self.conv5 = Conv(8 * h, 16 * h)
    self.conv6 = Conv(16 * h, 32 * h)
    self.conv7 = Conv(32 * h, 64 * h, use_norm=False)
    self.conv8 = Conv(64 * h, z_dim, use_norm=False)

  def forward(self, x):
    """Run ``x`` through conv1 ... conv8 in order."""
    for block in (self.conv1, self.conv2, self.conv3, self.conv4,
                  self.conv5, self.conv6, self.conv7, self.conv8):
      x = block(x)
    return x

  

In [None]:
# Smoke test: eight halvings take 512x512 down to 2x2, so the summary should
# end with a 512x2x2 output.
encoder_source = Encoder_Source(in_channels=3).to(device)
summary(encoder_source, input_size=(3, 512, 512))

In [12]:
class Encoder_Target(nn.Module):
  """Strided-convolution encoder (no pooling) for the generator's 2nd input.

  Three 7x7 / stride-4 / unpadded ``Conv`` blocks are followed by a 6x6 /
  stride-1 one. BatchNorm is used only on the middle two. A 512x512 input
  shrinks 512 -> 127 -> 31 -> 7 -> 2, giving an (out_channels, 2, 2) map.

  Args:
    in_channels: channels of the input image.
    out_channels: channels of the produced code.
    hidden_dim: width of the first block (doubled by each of the next two).
  """
  def __init__(self, in_channels, out_channels, hidden_dim=32):
    super().__init__()
    strided = dict(kernel_size=7, stride=4, padding=0, use_pool=False)
    self.conv1 = Conv(in_channels, hidden_dim, use_norm=False, **strided)
    self.conv2 = Conv(hidden_dim, hidden_dim * 2, **strided)
    self.conv3 = Conv(hidden_dim * 2, hidden_dim * 4, **strided)
    self.conv4 = Conv(hidden_dim * 4, out_channels, kernel_size=6, stride=1,
                      padding=0, use_pool=False, use_norm=False)

  def forward(self, x):
    """Run ``x`` through the four strided blocks in order."""
    for block in (self.conv1, self.conv2, self.conv3, self.conv4):
      x = block(x)
    return x

In [None]:
# Smoke test: 512 -> 127 -> 31 -> 7 -> 2, so the summary should end with a
# 512x2x2 output, matching Encoder_Source.
encoder_target = Encoder_Target(in_channels=3, out_channels=512).to(device)
summary(encoder_target, input_size=(3, 512, 512))

In [14]:
class Generator(nn.Module):
  """Two-input encoder/decoder generator producing an image in (0, 1).

  ``forward(x, y)`` proceeds as follows:
    * ``x`` runs through seven pooling ``Conv`` blocks (x1..x7) and
      ``last_conv`` (x8); every block halves H and W.
    * ``y`` is encoded by ``Encoder_Target`` into a ``z_dim``-channel map.
    * Both codes are concatenated on channels and fused by ``conv_mapping_1``
      (no pooling).
    * The result is decoded by seven ``ConvT`` blocks, each consuming a
      centre-cropped encoder map, then ``convT_last`` and two refining convs.
    * A final sigmoid squashes the output into (0, 1).

  The concatenation needs x8 and the encoded ``y`` to share H and W. For
  512x512 inputs both are 2x2 and the output is (batch, out_channels, 512,
  512). Other sizes may fail at the concatenation — TODO confirm.

  NOTE(review): each ConvT receives the encoder map from one level *above*
  its own resolution (e.g. x7 is twice the bottleneck's size), and ``crop``
  keeps only that map's centre. The skips are therefore not spatially
  aligned as in a standard U-Net. This looks unintended; fixing it changes
  the channel plan.

  Args:
    in_channels: channels of both ``x`` and ``y``.
    out_channels: channels of the generated image.
    z_dim: channels of each bottleneck code (x8 and the encoded ``y``).
    hidden_dim: base width; encoder widths are hidden_dim * 2**k.
  """
  def __init__(self, in_channels, out_channels, z_dim=512, hidden_dim=32):
    super().__init__()

    # Encoder for the second input.
    self.encoder_target = Encoder_Target(in_channels, z_dim)

    # Contracting path for the first input; each block halves H and W.
    self.conv1 = Conv(in_channels, hidden_dim, use_norm=False)
    self.conv2 = Conv(hidden_dim, hidden_dim * 2, use_norm=False)
    self.conv3 = Conv(hidden_dim * 2, hidden_dim * 4)
    self.conv4 = Conv(hidden_dim * 4, hidden_dim * 8)
    self.conv5 = Conv(hidden_dim * 8, hidden_dim * 16)
    self.conv6 = Conv(hidden_dim * 16, hidden_dim * 32, use_norm=False)
    self.conv7 = Conv(hidden_dim * 32, hidden_dim * 64, use_norm=False)

    self.last_conv = Conv(hidden_dim * 64, z_dim, use_norm=False)

    # Bottleneck: fuse the two z_dim-channel codes at unchanged resolution.
    self.conv_mapping_1 = Conv(z_dim * 2, hidden_dim * 64, use_pool=False, use_norm=False)

    # Expanding path; each ConvT doubles H and W.
    self.convT1 = ConvT(hidden_dim * 64, hidden_dim * 32)
    self.convT2 = ConvT(hidden_dim * 32, hidden_dim * 16)
    self.convT3 = ConvT(hidden_dim * 16, hidden_dim * 8)
    self.convT4 = ConvT(hidden_dim * 8, hidden_dim * 4)
    self.convT5 = ConvT(hidden_dim * 4, hidden_dim * 2)
    self.convT6 = ConvT(hidden_dim * 2, hidden_dim, use_norm=False)
    self.convT7 = ConvT(hidden_dim, hidden_dim, use_norm=False)

    self.convT_last = nn.ConvTranspose2d(hidden_dim, out_channels, kernel_size=(2, 2), stride=(2, 2))
    self.relu = nn.ReLU(inplace=True)
    self.sigmoid = nn.Sigmoid()

    # The refinement convs consume convT_last's output, which has
    # out_channels channels. Declaring the first one with in_channels inputs
    # breaks as soon as in_channels != out_channels.
    self.last_conv1 = Conv(out_channels, out_channels, use_pool=False)
    self.last_conv2 = Conv(out_channels, out_channels, use_pool=False, use_norm=False, use_activation=False)

  def forward(self, x, y):
    """Generate an image from ``x`` conditioned on the encoding of ``y``."""
    y = self.encoder_target(y)

    # Contracting path; keep every level for the skip connections.
    x1 = self.conv1(x)
    x2 = self.conv2(x1)
    x3 = self.conv3(x2)
    x4 = self.conv4(x3)
    x5 = self.conv5(x4)
    x6 = self.conv6(x5)
    x7 = self.conv7(x6)

    x8 = self.last_conv(x7)

    # Bottleneck: join both codes along channels.
    z = torch.cat([x8, y], dim=1)
    z = self.conv_mapping_1(z)

    # Expanding path; each ConvT centre-crops its skip to the current size.
    x9 = self.convT1(z, x7)
    x10 = self.convT2(x9, x6)
    x11 = self.convT3(x10, x5)
    x12 = self.convT4(x11, x4)
    x13 = self.convT5(x12, x3)
    x14 = self.convT6(x13, x2)
    x15 = self.convT7(x14, x1)

    x = self.relu(self.convT_last(x15))
    x = self.last_conv1(x)
    x = self.sigmoid(self.last_conv2(x))

    return x

In [None]:
# Smoke test: a full forward pass on random 512x512 inputs; the expected
# result is torch.Size([2, 3, 512, 512]).
x = torch.randn(2, 3, 512, 512).to(device)
y = torch.randn(2, 3, 512, 512).to(device)
generator = Generator(3, out_channels=3).to(device)
z = generator(x, y)
z.shape

In [16]:
class Discriminator(nn.Module):
  """Real/fake classifier returning one probability per image.

  ``Encoder_Source`` (eight pooling Conv blocks) reduces the image to a
  (z_dim, 2, 2) map for 512x512 inputs. The map is flattened and passed
  through:
    Linear(4*z_dim -> z_dim) -> BN -> ReLU
    -> Linear(z_dim -> 32) -> BN -> ReLU
    -> Linear(32 -> 1) -> Sigmoid

  The result has shape (batch, 1). Because the head ends in a sigmoid, the
  output is already a probability: pair it with ``nn.BCELoss``.
  ``nn.BCEWithLogitsLoss`` would apply a second sigmoid.

  ``linear1`` expects exactly 4 * z_dim features, i.e. a 2x2 encoder output.
  Inputs whose encoder output is not 2x2 fail there. The BatchNorm1d layers
  need a batch size > 1 in training mode.

  Args:
    in_channels: channels of the input image.
    z_dim: channels of the encoder output.
    hidden_dim: base width of the encoder, forwarded to ``Encoder_Source``.
  """
  def __init__(self, in_channels, z_dim=512, hidden_dim=32):
    super().__init__()

    # hidden_dim is forwarded; previously it was accepted but ignored.
    self.encoder = Encoder_Source(in_channels, z_dim, hidden_dim=hidden_dim)

    self.flatten = nn.Flatten()

    self.linear1 = nn.Linear(z_dim * 4, z_dim)
    self.batchnorm1 = nn.BatchNorm1d(z_dim)

    self.linear2 = nn.Linear(z_dim, 32)
    self.batchnorm2 = nn.BatchNorm1d(32)

    self.linear3 = nn.Linear(32, 1)
    self.relu = nn.ReLU(inplace=True)
    self.sigmoid = nn.Sigmoid()

  def forward(self, x):
    """Return P(real) for every image in ``x`` as a (batch, 1) tensor."""
    x = self.encoder(x)
    x = self.flatten(x)
    x = self.relu(self.batchnorm1(self.linear1(x)))
    x = self.relu(self.batchnorm2(self.linear2(x)))
    return self.sigmoid(self.linear3(x))

In [None]:
# Smoke test: the summary should end with a single sigmoid output per image.
discriminator = Discriminator(in_channels=in_channels).to(device)
summary(discriminator, input_size=(3, 512, 512))

In [18]:

# Losses. The Discriminator class ends in nn.Sigmoid, i.e. it already outputs
# probabilities, so plain BCELoss is the matching criterion. BCEWithLogitsLoss
# would apply a second sigmoid and squash every prediction into (0.5, 0.73),
# which weakens the adversarial gradients.
adv_criterion = nn.BCELoss()
recon_criterion = nn.L1Loss()
lambda_recon = 200          # weight of the L1 reconstruction term
betas = (0.5, 0.999)        # Adam betas shared by all four optimizers

n_epochs = 200
in_channels = 3
out_channels = 3
display_step = 5            # print losses / plot samples every N steps
batch_size = 2
lr = 0.0002
target_shape = 512          # images are resized to target_shape x target_shape

In [19]:

import torchvision

# Only convert images to float tensors; resizing happens later in `train`.
transform = transforms.Compose([transforms.ToTensor()])

# ImageFolder treats each sub-directory as a class; `train` ignores the labels.
# NOTE(review): hard-coded Colab / Google Drive path — make configurable.
dataset = torchvision.datasets.ImageFolder(
    "/content/drive/MyDrive/Maps/maps/",
    transform=transform,
)

In [20]:
# Two generator/discriminator pairs: the *_x pair works towards `real` and the
# *_y pair towards `condition` (see the loop in `train`).
generator_x = Generator(in_channels, out_channels).to(device)
generator_y = Generator(in_channels, out_channels).to(device)
# Discriminator's second positional parameter is z_dim, not an output-channel
# count. Passing out_channels_disc (= 1) there shrank the encoder bottleneck to
# a single channel, so pass the shared latent width explicitly.
discriminator_x = Discriminator(in_channels, z_dim=z_dim).to(device)
discriminator_y = Discriminator(in_channels, z_dim=z_dim).to(device)

In [21]:
# One Adam optimizer per network, all sharing the same learning rate and betas.
opt_generator_x, opt_generator_y, opt_discriminator_x, opt_discriminator_y = (
    torch.optim.Adam(net.parameters(), lr=lr, betas=betas)
    for net in (generator_x, generator_y, discriminator_x, discriminator_y)
)

In [22]:
def weights_init(m):
  """Initialise one module in place; meant to be passed to ``Module.apply``.

  * Conv2d / ConvTranspose2d: weights ~ N(0, 0.02); biases keep PyTorch's
    default initialisation.
  * BatchNorm2d (affine): scale ~ N(1, 0.02), shift = 0. The scale is
    centred on 1 (identity), as in the DCGAN reference. A scale centred on 0
    multiplies every normalised activation by ~0 and starves later layers of
    signal.

  All other module types (Linear, BatchNorm1d, InstanceNorm2d, ...) are left
  untouched.

  Args:
    m: any ``nn.Module``; only the types listed above are modified.
  """
  if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
    torch.nn.init.normal_(m.weight, 0.0, 0.02)
  # affine=False BatchNorm has no weight/bias to initialise.
  if isinstance(m, nn.BatchNorm2d) and m.weight is not None:
    torch.nn.init.normal_(m.weight, 1.0, 0.02)
    torch.nn.init.constant_(m.bias, 0)

In [23]:
# Initialise every network in place with weights_init. Module.apply returns the
# module itself, so each name keeps referring to the same object.
generator_x, generator_y, discriminator_x, discriminator_y = (
    net.apply(weights_init)
    for net in (generator_x, generator_y, discriminator_x, discriminator_y)
)

In [24]:
# drop_last=True: the Discriminator head uses BatchNorm1d, which raises in
# training mode on a batch of one sample. That batch would occur at the end
# of every epoch whenever len(dataset) is not a multiple of batch_size.
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)
input_dim = 3   # channels of the condition images (used only for plotting)
real_dim = 3    # channels of the real / generated images (used only for plotting)

In [25]:
def get_gen_loss(fake_output,
                 real,
                 disc_fake_pred,
                 adv_criterion=adv_criterion,
                 recon_criterion=recon_criterion,
                 lambda_recon=lambda_recon):
  """Generator objective: fool the discriminator and reconstruct the target.

  loss = adv_criterion(disc_fake_pred, 1)
         + lambda_recon * recon_criterion(fake_output, real)

  Args:
    fake_output: generator output for the batch.
    real: target images that ``fake_output`` should reconstruct.
    disc_fake_pred: discriminator prediction on ``fake_output``.
    adv_criterion, recon_criterion, lambda_recon: default to the module-level
      objects of the same name, bound when this function is defined.

  Returns:
    Scalar loss tensor.
  """
  # The generator wants its fakes classified as *real*, so the adversarial
  # target is ones. Targeting zeros would reward it for being detected as fake.
  adv_loss = adv_criterion(disc_fake_pred, torch.ones_like(disc_fake_pred))
  # Pixel-wise reconstruction term, weighted once by lambda_recon.
  recon_loss = recon_criterion(fake_output, real)
  return adv_loss + lambda_recon * recon_loss

In [26]:
def train():
  """Run the dual-generator / dual-discriminator training loop.

  Reads everything from module-level state: ``n_epochs``, ``dataloader``,
  ``target_shape``, ``device``, and the four networks (``generator_x``,
  ``generator_y``, ``discriminator_x``, ``discriminator_y``) with their
  optimizers. It also uses ``get_gen_loss``, ``adv_criterion``,
  ``display_step``, ``input_dim``, ``real_dim`` and ``show_tensor_images``.

  Each batch performs four optimizer steps in a fixed order: generator_x,
  generator_y, discriminator_x, discriminator_y. Returns None; losses are
  printed and sample images plotted every ``display_step`` steps.
  """
  # Running sums of (loss / display_step); reset after every report.
  mean_generator_loss = 0
  mean_discriminator_loss = 0
  cur_step = 0
  for epoch in range(n_epochs):
    for img , _ in tqdm(dataloader):
      # Each sample holds the paired images side by side: the left half is the
      # condition (X), the right half the real target (Y). Both halves are
      # resized to target_shape x target_shape with interpolate's default mode.
      image_width = img.shape[3]
      condition = img[: , : , : , :image_width//2]
      condition = nn.functional.interpolate(condition , size = target_shape)
      real = img[: , : , : , image_width//2:]
      real = nn.functional.interpolate(real , size = target_shape)
      cur_batch_size = len(condition)  # NOTE(review): never used below.
      real = real.to(device)
      condition = condition.to(device)
      #print(torch.max(condition) , torch.max(real))

      #X = condition , Y = real
      # generator_x => real
      # generator_y => condition
      # discriminator_x => real
      # discriminator_y => condition

      # ---- Generator X step ----
      # The first argument of generator_x is always `condition` (encoder/decoder
      # path). The second argument, fed to its encoder_target, is either `real`
      # or `condition`. Both outputs are scored by discriminator_x and
      # reconstructed against `real`; the two losses are averaged.
      opt_generator_x.zero_grad()

      fake_X_X_X = generator_x(condition , real)
      fake_X_Y_X = generator_x(condition , condition)
      
      disc_fake_pred_X_X_X = discriminator_x(fake_X_X_X)
      disc_fake_pred_X_Y_X = discriminator_x(fake_X_Y_X)

      gen_x_loss_1 = get_gen_loss(fake_X_X_X , real , disc_fake_pred_X_X_X)
      gen_x_loss_2 = get_gen_loss(fake_X_Y_X , real , disc_fake_pred_X_Y_X)

      generator_x_loss = (gen_x_loss_1 + gen_x_loss_2)/2

      generator_x_loss.backward()
      opt_generator_x.step()

      # ---- Generator Y step ----
      # Mirror of the X step: the first argument is always `real`, and the
      # outputs are scored by discriminator_y and reconstructed against
      # `condition`.
      opt_generator_y.zero_grad()
      
      fake_Y_X_Y = generator_y(real , real)
      fake_Y_Y_Y = generator_y(real , condition)

      disc_fake_pred_Y_X_Y = discriminator_y(fake_Y_X_Y)
      disc_fake_pred_Y_Y_Y = discriminator_y(fake_Y_Y_Y)

      gen_y_loss_1 = get_gen_loss(fake_Y_X_Y , condition , disc_fake_pred_Y_X_Y)
      gen_y_loss_2 = get_gen_loss(fake_Y_Y_Y , condition , disc_fake_pred_Y_Y_Y)

      generator_y_loss = (gen_y_loss_1 + gen_y_loss_2)/2

      generator_y_loss.backward()
      opt_generator_y.step()

      # ---- Discriminator X step ----
      # zero_grad also discards gradients that discriminator_x accumulated
      # during the generator X backward pass. The fakes are regenerated under
      # no_grad with the just-updated generator_x, so no graph reaches back into
      # the generator. Loss = mean of BCE(fake -> 0) twice and BCE(real -> 1).
      opt_discriminator_x.zero_grad()
      
      with torch.no_grad():
        fake_X_X_X = generator_x(condition , real)
        fake_X_Y_X = generator_x(condition , condition)
      disc_fake_pred_X_X_X_ = discriminator_x(fake_X_X_X)
      disc_fake_pred_X_Y_X_ = discriminator_x(fake_X_Y_X)
      disc_real_pred_X = discriminator_x(real)

      disc_fake_pred_X_X_X_loss = adv_criterion(disc_fake_pred_X_X_X_ , torch.zeros_like(disc_fake_pred_X_X_X_))
      disc_fake_pred_X_Y_X_loss = adv_criterion(disc_fake_pred_X_Y_X_ , torch.zeros_like(disc_fake_pred_X_Y_X_))
      disc_real_pred_X_loss = adv_criterion(disc_real_pred_X , torch.ones_like(disc_real_pred_X))
      
      discriminator_x_loss = (disc_fake_pred_X_X_X_loss + disc_fake_pred_X_Y_X_loss + disc_real_pred_X_loss)/3

      discriminator_x_loss.backward()
      opt_discriminator_x.step()


      # ---- Discriminator Y step ----
      # Same scheme for discriminator_y, with `condition` as its real images.
      opt_discriminator_y.zero_grad()

      with torch.no_grad():
        fake_Y_X_Y = generator_y(real , real)
        fake_Y_Y_Y = generator_y(real , condition)        

      disc_fake_pred_Y_X_Y_ = discriminator_y(fake_Y_X_Y)
      disc_fake_pred_Y_Y_Y_ = discriminator_y(fake_Y_Y_Y)
      disc_real_pred_Y = discriminator_y(condition)

      disc_fake_pred_Y_X_Y_loss = adv_criterion(disc_fake_pred_Y_X_Y_ , torch.zeros_like(disc_fake_pred_Y_X_Y_))
      disc_fake_pred_Y_Y_Y_loss = adv_criterion(disc_fake_pred_Y_Y_Y_ , torch.zeros_like(disc_fake_pred_Y_Y_Y_))
      disc_real_pred_Y_loss = adv_criterion(disc_real_pred_Y , torch.ones_like(disc_real_pred_Y))

      discriminator_y_loss = (disc_fake_pred_Y_X_Y_loss + disc_fake_pred_Y_Y_Y_loss + disc_real_pred_Y_loss)/3
      discriminator_y_loss.backward()
      opt_discriminator_y.step()

      # ---- Logging ----
      # Average the X and Y losses. Dividing by display_step makes the running
      # sums equal to the mean over the steps since the last report.
      discriminator_loss = (discriminator_x_loss + discriminator_y_loss)/2
      generator_loss = (generator_x_loss + generator_y_loss)/2

      mean_discriminator_loss += discriminator_loss.item() / display_step
      mean_generator_loss += generator_loss.item() / display_step

      # Every display_step steps: print the averages and plot the current
      # batch. At step 0 the losses are not printed, only reset.
      if cur_step % display_step == 0:
        if cur_step > 0:
          print(f"Epoch {epoch}: Step {cur_step}: Generator loss: {mean_generator_loss}, Discriminator loss: {mean_discriminator_loss}")
        else:
          print("Pretrained initial state")
        print('condition')
        show_tensor_images(condition, size=(input_dim, target_shape, target_shape))
        print('real')
        show_tensor_images(real, size=(real_dim, target_shape, target_shape) )
        print('fake_X_X_X')
        show_tensor_images(fake_X_X_X, size=(real_dim, target_shape, target_shape) )
        print('fake_X_Y_X')
        show_tensor_images(fake_X_Y_X , size=(real_dim , target_shape , target_shape))
        print('fake_Y_X_Y')
        show_tensor_images(fake_Y_X_Y , size=(real_dim , target_shape , target_shape))
        print('fake_Y_Y_Y')
        show_tensor_images(fake_Y_Y_Y , size=(real_dim , target_shape , target_shape))
        mean_generator_loss = 0
        mean_discriminator_loss = 0
      cur_step += 1

In [None]:
# Launch training: n_epochs passes over `dataloader`, printing losses and
# plotting samples every `display_step` steps.
train()