<a href="https://colab.research.google.com/github/VedantDere0104/GANs/blob/main/Image_translation_with_dual%E2%80%90directional_generative_adversarial_networks.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Image translation with dual-directional generative adversarial networks — paper: https://ietresearch.onlinelibrary.wiley.com/doi/full/10.1049/cvi2.12011

In [1]:
####

In [2]:
import torch
from torch import nn
from torchsummary import summary
from torchvision import transforms
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
import torchvision

In [3]:
def show_tensor_images(image_tensor, num_images=2, size=(1, 28, 28) , switch = True):
  """Plot the first ``num_images`` images of ``image_tensor`` in a 2-column grid.

  Args:
    image_tensor: tensor of images; it is reshaped with ``view(-1, *size)``.
    num_images: number of leading images to show.
    size: per-image shape ``(C, H, W)``.
    switch: when True, the grid values are assumed to lie in [0, 1] and are
      rescaled to 0..255 integers before plotting. When False they are plotted
      as-is (matplotlib expects float RGB data in [0, 1]).
  """
  image_unflat = image_tensor.detach().cpu().view(-1, *size)
  image_grid = make_grid(image_unflat[:num_images], nrow=2 , normalize=False)
  if switch:
    # matplotlib clips float RGB data to [0, 1], so a float grid scaled by 255
    # rendered almost fully saturated. Integer data is read on a 0..255 scale,
    # so convert to uint8 after scaling.
    image_grid = (image_grid * 255.0).clamp(0, 255).round().to(torch.uint8)
  plt.imshow(image_grid.permute(1, 2, 0).squeeze())
  plt.show()

In [4]:
def crop(image, new_shape):
    """Center-crop the last two dimensions of a 4-D ``image`` to ``new_shape[2:]``.

    ``new_shape`` only needs indices 2 and 3 (target height and width); the
    first two dimensions of ``image`` are kept whole.
    """
    target_h, target_w = new_shape[2], new_shape[3]
    top = image.shape[2] // 2 - target_h // 2
    left = image.shape[3] // 2 - target_w // 2
    return image[:, :, top:top + target_h, left:left + target_w]

In [5]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')  # prefer the GPU when present

In [6]:
class Helper_1(nn.Module):
  """Downsampling block: Conv2d -> optional InstanceNorm2d -> ReLU.

  Args:
    in_channels: channels of the input feature map.
    out_channels: channels produced by the convolution.
    kernel_size: convolution kernel size. The default (2, 2) with the default
      stride halves the spatial size.
    stride: convolution stride (default (2, 2)).
    use_batch_norm: when True, normalise the convolution output. Despite the
      name, the normalisation layer is ``nn.InstanceNorm2d``.
  """
  def __init__(self , in_channels , out_channels , kernel_size = (2 , 2) , stride = (2 , 2) , use_batch_norm = True):
    super(Helper_1 , self).__init__()

    self.use_batch_norm = use_batch_norm
    self.conv1 = nn.Conv2d(in_channels , out_channels , kernel_size , stride)
    # Attribute name kept for compatibility; it is an InstanceNorm2d created
    # with default arguments.
    self.batch_norm = nn.InstanceNorm2d(out_channels)
    self.relu = nn.ReLU()

  def forward(self , x):
    x = self.conv1(x)
    # The flag used to be stored but ignored, so layers built with
    # use_batch_norm=False were still normalised. Honour it, as Helper_2 does.
    if self.use_batch_norm:
      x = self.batch_norm(x)
    return self.relu(x)


In [None]:
# Quick shape check: one downsampling block on a random 64x64 RGB tensor.
helper = Helper_1(in_channels=3, out_channels=64, kernel_size=(2, 2), stride=(2, 2), use_batch_norm=True)
x = torch.randn(1, 3, 64, 64)
z = helper(x)
z.shape

In [8]:
class Encoder(nn.Module):
  """Convolutional encoder mapping an image to a ``(B, out_channels, 1, 1)`` latent.

  Eight Helper_1 blocks (kernel 2, stride 2) each halve the spatial size. An
  ``img_size`` x ``img_size`` input therefore ends at
  ``(img_size // 256) x (img_size // 256)`` with ``hidden_dim * 32`` channels.
  That map is flattened and passed through Linear -> BatchNorm1d -> ReLU -> Linear.

  Args:
    in_channels: channels of the input image.
    hidden_dim: base width; conv widths are multiples of it (up to hidden_dim * 64).
    out_channels: size of the latent vector. The head used to always emit
      ``hidden_dim * 16`` features, which equals out_channels for the
      (3, 32, 512) configuration used in this notebook.
    img_size: side of the square input. Must be >= 256. The default 512 (with
      hidden_dim=32) reproduces the original 4096-wide first linear layer.

  Note: nn.BatchNorm1d needs more than one sample per batch in training mode.
  """
  def __init__(self , in_channels , hidden_dim , out_channels , img_size = 512):
    super(Encoder , self).__init__()

    if img_size < 256:
      raise ValueError(f"img_size must be >= 256 for 8 stride-2 convs, got {img_size}")

    self.conv1 = Helper_1(in_channels , hidden_dim , (2 , 2) , (2 , 2) , False)
    self.conv2 = Helper_1(hidden_dim , hidden_dim * 2 , (2 , 2) , (2 , 2) , True)
    self.conv3 = Helper_1(hidden_dim * 2 , hidden_dim * 4 , (2 , 2) , (2 , 2) , True)
    self.conv4 = Helper_1(hidden_dim * 4 , hidden_dim * 8 , (2 , 2) , (2 , 2) ,True)
    self.conv5 = Helper_1(hidden_dim * 8 , hidden_dim * 16 , (2 , 2) , (2 ,2) , True)
    self.conv6 = Helper_1(hidden_dim * 16 , hidden_dim * 32 , (2 , 2) ,(2 , 2) , True)
    self.conv7 = Helper_1(hidden_dim * 32 , hidden_dim * 64 , (2 , 2) , (2 , 2) , True)
    self.conv8 = Helper_1(hidden_dim * 64 , hidden_dim * 32 , (2 , 2) , (2 , 2) , True)
    self.flatten = nn.Flatten()
    # Eight floor-halvings shrink the side to img_size // 2**8.
    final_side = img_size // 256
    self.linear1 = nn.Linear(hidden_dim * 32 * final_side * final_side , hidden_dim * 16)
    self.batchnorm = nn.BatchNorm1d(hidden_dim * 16)
    self.relu = nn.ReLU()

    self.linear2 = nn.Linear(hidden_dim * 16 , out_channels)


  def forward(self , x):
    x = self.conv1(x)
    x = self.conv2(x)
    x = self.conv3(x)
    x = self.conv4(x)
    x = self.conv5(x)
    x = self.conv6(x)
    x = self.conv7(x)
    x = self.conv8(x)
    x = self.flatten(x)
    x = self.relu(self.batchnorm(self.linear1(x)))
    x = self.linear2(x)
    # Reshape the latent vector into a 1x1 feature map for transposed convs.
    x = x.view(x.shape[0] , x.shape[1] , 1 , 1)
    return x

In [None]:
# Layer-by-layer summary of the encoder for a single 3x512x512 image.
encoder = Encoder(in_channels=3, hidden_dim=32, out_channels=512)
encoder = encoder.to(device)
summary(encoder, input_size=(3, 512, 512))

In [None]:
# Forward a random batch of ten 512x512 images and inspect the latent shape.
x = torch.randn(10, 3, 512, 512, device=device)
z = encoder(x)
z.shape

In [11]:
class Helper_2(nn.Module):
  """Upsampling block: ConvTranspose2d -> optional InstanceNorm2d -> LeakyReLU.

  With the default kernel (2, 2) and stride (2, 2), each spatial side doubles.
  ``use_batchnorm`` controls whether the (instance) normalisation layer is
  created and applied.
  """
  def __init__(self , in_channels , out_channels , kernel_size = (2 , 2) , stride = (2 , 2) , use_batchnorm = True):
    super().__init__()

    self.use_batchnorm = use_batchnorm
    self.convT1 = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride)
    if use_batchnorm:
      # Named ``batchnorm`` but it is an InstanceNorm2d.
      self.batchnorm = nn.InstanceNorm2d(out_channels)
    self.lrelu = nn.LeakyReLU()

  def forward(self , x):
    upsampled = self.convT1(x)
    if self.use_batchnorm:
      upsampled = self.batchnorm(upsampled)
    return self.lrelu(upsampled)
    

In [None]:
# Summary of a single upsampling block on a 3x512x512 input.
helper_2 = Helper_2(in_channels=3, out_channels=32, kernel_size=(2, 2), stride=(2, 2), use_batchnorm=True)
helper_2 = helper_2.to(device)
summary(helper_2, input_size=(3, 512, 512))

In [13]:
class Generator(nn.Module):
  """Two-input generator that fuses an encoded image ``x`` with a downsampled image ``y``.

  Data flow for 512 x 512 inputs (all Helper blocks use kernel 2 / stride 2):
    x -> Encoder -> (B, z_in_channels, 1, 1) -> convT1..convT7 -> (B, hidden_dim*32, 128, 128)
    y -> conv1, conv2                                           -> (B, hidden_dim*32, 128, 128)
    concat on channels -> convT_1 (256 x 256) -> convT_2 -> (B, out_channels, 512, 512)

  The final block ends in InstanceNorm + LeakyReLU, so the output is unbounded.
  The training code treats it as logits (BCEWithLogitsLoss inside get_loss).

  Args:
    z_in_channels: latent width produced by the encoder and consumed by convT1.
    img_in_channels: channels of both input images ``x`` and ``y``.
    hidden_dim: base width of the decoder and the ``y`` branch.
    out_channels: channels of the generated image.
  """
  def __init__(self , z_in_channels , img_in_channels , hidden_dim , out_channels):
    super(Generator , self).__init__()

    # The encoder keeps its own width of 32, independent of hidden_dim, because
    # its linear head is sized for that width. Its input/output channels now
    # follow the constructor arguments instead of being hard-coded to 3 / 512.
    self.encoder = Encoder(img_in_channels , 32 , z_in_channels)

    self.convT1 = Helper_2(z_in_channels , hidden_dim , use_batchnorm=False)
    self.convT2 = Helper_2(hidden_dim , hidden_dim  *2)
    self.convT3 = Helper_2(hidden_dim * 2 , hidden_dim * 4)
    self.convT4 = Helper_2(hidden_dim * 4 , hidden_dim * 8)
    self.convT5 = Helper_2(hidden_dim * 8 , hidden_dim * 16)
    self.convT6 = Helper_2(hidden_dim * 16 , hidden_dim  * 32)
    self.convT7 = Helper_2(hidden_dim * 32 , hidden_dim * 32)

    self.conv1 = Helper_1(img_in_channels , hidden_dim , use_batch_norm=False)
    self.conv2 = Helper_1(hidden_dim , hidden_dim * 32)

    # Input is the channel-wise concatenation of both branches.
    self.convT_1 = Helper_2(hidden_dim * 32 * 2 , hidden_dim * 32 , use_batchnorm=False)
    # Previously hard-coded to 3 output channels; out_channels was ignored.
    self.convT_2 = Helper_2(hidden_dim * 32 , out_channels)


  def forward(self , x , y):
    x = self.encoder(x)
    x = self.convT1(x)
    x = self.convT2(x)
    x = self.convT3(x)
    x = self.convT4(x)
    x = self.convT5(x)
    x = self.convT6(x)
    x = self.convT7(x)

    y = self.conv1(y)
    y = self.conv2(y)

    z = torch.cat([x , y] , dim=1)

    z = self.convT_1(z)
    z = self.convT_2(z)

    return z


In [None]:
# Smoke test: the same random batch feeds both the encoder (x) and image (y) branches.
generator = Generator(z_in_channels=512, img_in_channels=3, hidden_dim=32, out_channels=3).to(device)
x = torch.randn(5, 3, 512, 512, device=device)
ans = generator(x, x)
ans.shape

In [15]:
class Discriminator(nn.Module):
  """Image classifier returning raw logits of shape ``(B, out_channels)``.

  Eight Helper_1 blocks (kernel 2, stride 2) reduce an ``img_size`` x ``img_size``
  input to ``(img_size // 256)**2`` positions with ``hidden_dim * 32`` channels.
  A head of three Linear -> BatchNorm1d -> ReLU stages plus a final Linear follows.

  The output is NOT passed through a sigmoid. The training code scores it with
  ``nn.BCEWithLogitsLoss``, which applies its own sigmoid. Use
  ``torch.sigmoid(disc(x))`` (or ``disc.sigmoid``) when probabilities are needed.

  Args:
    in_channels: channels of the input image.
    hidden_dim: base width of the conv stack and MLP head.
    out_channels: number of logits per sample.
    img_size: side of the square input. Must be >= 256. The default 512 (with
      hidden_dim=32) reproduces the original 4096-wide first linear layer.
  """
  def __init__(self , in_channels , hidden_dim , out_channels , img_size = 512):
    super(Discriminator , self).__init__()

    if img_size < 256:
      raise ValueError(f"img_size must be >= 256 for 8 stride-2 convs, got {img_size}")

    self.conv1 = Helper_1(in_channels , hidden_dim , use_batch_norm=False)
    self.conv2 = Helper_1(hidden_dim , hidden_dim * 2)
    self.conv3 = Helper_1(hidden_dim * 2 , hidden_dim * 4)
    self.conv4 = Helper_1(hidden_dim * 4 , hidden_dim * 8)
    self.conv5 = Helper_1(hidden_dim * 8 , hidden_dim * 16)
    self.conv6 = Helper_1(hidden_dim * 16 , hidden_dim * 32)
    self.conv7 = Helper_1(hidden_dim * 32 , hidden_dim * 32)
    self.conv8 = Helper_1(hidden_dim * 32 , hidden_dim * 32)
    self.flatten = nn.Flatten()

    self.relu = nn.ReLU()
    # Eight floor-halvings shrink the side to img_size // 2**8.
    final_side = img_size // 256
    self.linear1 = nn.Linear(hidden_dim * 32 * final_side * final_side , hidden_dim * 32)
    self.batchnorm1 = nn.BatchNorm1d(hidden_dim * 32)
    self.linear2 = nn.Linear(hidden_dim * 32 , hidden_dim * 8)
    self.batchnorm2 = nn.BatchNorm1d(hidden_dim * 8)
    self.linear3 = nn.Linear(hidden_dim * 8 , hidden_dim)
    self.batchnorm3 = nn.BatchNorm1d(hidden_dim)
    self.linear4 = nn.Linear(hidden_dim , out_channels)
    # Not used by forward(); kept for callers that want probabilities.
    self.sigmoid = nn.Sigmoid()


  def forward(self , x):
    x = self.conv1(x)
    x = self.conv2(x)
    x = self.conv3(x)
    x = self.conv4(x)
    x = self.conv5(x)
    x = self.conv6(x)
    x = self.conv7(x)
    x = self.conv8(x)
    x = self.flatten(x)
    x = self.relu(self.batchnorm1(self.linear1(x)))
    x = self.relu(self.batchnorm2(self.linear2(x)))
    x = self.relu(self.batchnorm3(self.linear3(x)))
    # Return logits. Applying Sigmoid here and again inside BCEWithLogitsLoss
    # squashed every prediction into (0.5, 0.73) and crippled the gradients.
    return self.linear4(x)

In [None]:
# Layer-by-layer summary of the discriminator for a 3x512x512 image.
disc = Discriminator(in_channels=3, hidden_dim=32, out_channels=1)
disc = disc.to(device)
summary(disc, input_size=(3, 512, 512))

In [None]:
# One prediction per image for a random batch of ten.
a = disc(torch.randn(10, 3, 512, 512, device=device))
print(a.shape)

In [18]:
# One generator per target domain; both share the same architecture.
generator_x = Generator(z_in_channels=512, img_in_channels=3, hidden_dim=32, out_channels=3).to(device)
generator_y = Generator(z_in_channels=512, img_in_channels=3, hidden_dim=32, out_channels=3).to(device)

In [19]:
# One discriminator per domain, each emitting a single score per image.
discriminator_x = Discriminator(in_channels=3, hidden_dim=32, out_channels=1).to(device)
discriminator_y = Discriminator(in_channels=3, hidden_dim=32, out_channels=1).to(device)

In [20]:
def weights_init(m):
  """DCGAN-style initialisation, intended for ``model.apply(weights_init)``.

  * Conv2d / ConvTranspose2d: weights ~ N(0, 0.02).
  * BatchNorm1d / BatchNorm2d with affine parameters: weights ~ N(1, 0.02), bias = 0.

  Every other module (Linear, InstanceNorm2d without affine parameters, ...) is
  left untouched.
  """
  if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
    torch.nn.init.normal_(m.weight, 0.0, 0.02)
  # The models here use BatchNorm1d; the old check only matched BatchNorm2d.
  if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)) and m.weight is not None:
    # Centre the scale on 1 (DCGAN recipe). A mean-0 init scales the normalised
    # activations to ~0.
    torch.nn.init.normal_(m.weight, 1.0, 0.02)
    torch.nn.init.constant_(m.bias, 0)

In [21]:
# Module.apply returns the module itself, so the names are rebound to the same objects.
generator_x, generator_y, discriminator_x, discriminator_y = (
    net.apply(weights_init)
    for net in (generator_x, generator_y, discriminator_x, discriminator_y)
)

In [22]:
# BCE on logits, an L1 term, and the weight applied to the L1 term in get_loss.
criterion, l1_loss, lambda_recon = nn.BCEWithLogitsLoss(), nn.L1Loss(), 200

In [23]:
# Training hyper-parameters.
n_epochs = 100
input_dim = real_dim = 3   # channels of the condition / real images
display_step = 10          # log and plot every `display_step` steps
batch_size = 2
lr = 2e-6
target_shape = 512         # each half-image is resized to target_shape x target_shape

In [24]:
# Only conversion to a float tensor; resizing happens later in the training loop.
transform = transforms.Compose([transforms.ToTensor()])

dataset = torchvision.datasets.ImageFolder(root="/content/drive/MyDrive/Maps/maps/", transform=transform)

In [None]:
# Display the dataset's repr as a quick check that the image folder loaded.
dataset

In [26]:
  # Running loss averages and the global step counter used by the training loop.
  mean_generator_loss = mean_discriminator_loss = 0
  cur_step = 0
  dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [27]:
betas = 0.5, 0.999  # Adam (beta1, beta2) shared by every optimizer

In [28]:
# One Adam optimizer per network, all with the same learning rate and betas.
opt_generator_y = torch.optim.Adam(params=generator_y.parameters(), lr=lr, betas=betas)
opt_generator_x = torch.optim.Adam(params=generator_x.parameters(), lr=lr, betas=betas)
opt_discriminator_x = torch.optim.Adam(params=discriminator_x.parameters(), lr=lr, betas=betas)
opt_discriminator_y = torch.optim.Adam(params=discriminator_y.parameters(), lr=lr, betas=betas)

In [29]:
def get_gen_loss(gen , disc , real , condition , adv_criterion , recon_criterion , lambda_recon):
  """Pix2pix-style generator loss: adversarial term + ``lambda_recon`` * reconstruction term.

  Args:
    gen: callable mapping ``condition`` to a generated image.
    disc: conditional critic called as ``disc(fake, condition)``.
    real: target image the generator should reproduce.
    condition: generator input.
    adv_criterion: loss applied to the critic output against all-ones labels.
    recon_criterion: loss between the generated image and ``real``.
    lambda_recon: weight of the reconstruction term.

  Returns:
    The scalar total generator loss.

  NOTE(review): the call signatures ``gen(condition)`` / ``disc(fake, condition)``
  do not match this notebook's Generator.forward(x, y) / Discriminator.forward(x).
  """
  fake = gen(condition)
  disc_fake_hat = disc(fake , condition)
  gen_adv_loss = adv_criterion(disc_fake_hat , torch.ones_like(disc_fake_hat))
  # Compare the generated image with its target. The old code compared `real`
  # with `condition`, a constant that gives the generator no gradient.
  gen_rec_loss = recon_criterion(fake , real)
  return gen_adv_loss + lambda_recon * gen_rec_loss


In [30]:
def get_loss(fake , real  , criterion = criterion , l1_loss = l1_loss , lambda_recon = lambda_recon):
  """Reconstruction loss of ``fake`` against ``real``: ``criterion`` + ``lambda_recon`` * L1.

  The defaults are bound to the module-level ``criterion``, ``l1_loss`` and
  ``lambda_recon`` when this function is defined.
  """
  bce_term = criterion(fake , real)
  l1_term = l1_loss(fake , real)
  return bce_term + lambda_recon * l1_term

In [None]:
def _adversarial_loss(disc, fake):
  """BCE-with-logits loss rewarding ``fake`` for being classified as real (label 1) by ``disc``."""
  pred = disc(fake)
  return criterion(pred, torch.ones_like(pred))


def _discriminator_step(disc, opt, fakes, real_target):
  """Run one optimizer step on ``disc`` and return its mean loss.

  Each tensor in ``fakes`` is labelled 0 and ``real_target`` is labelled 1. The
  loss is the average of the individual BCE-with-logits terms. ``fakes`` should
  be detached so no gradient flows back into a generator.
  """
  opt.zero_grad()
  terms = []
  for fake in fakes:
    fake_pred = disc(fake)
    terms.append(criterion(fake_pred, torch.zeros_like(fake_pred)))
  real_pred = disc(real_target)
  terms.append(criterion(real_pred, torch.ones_like(real_pred)))
  loss = sum(terms) / len(terms)
  loss.backward()
  opt.step()
  return loss


for epoch in range(n_epochs):
  for img , _ in tqdm(dataloader):
    # Each sample holds two images side by side. The left half is the Y domain
    # ("condition") and the right half is the X domain ("real").
    image_width = img.shape[3]
    condition = nn.functional.interpolate(img[:, :, :, :image_width // 2], size=target_shape).to(device)
    real = nn.functional.interpolate(img[:, :, :, image_width // 2:], size=target_shape).to(device)

    # ---- Generator Y: both outputs are scored against `condition` ----
    # Each loss is reconstruction (get_loss) plus an adversarial term from
    # discriminator_y. Without the adversarial term the discriminators were
    # trained but never influenced the generators.
    # Gradients this puts into discriminator_y are cleared by its own
    # zero_grad() before its step.
    opt_generator_y.zero_grad()
    Y_XY = generator_y(real , real)
    Y_YY = generator_y(real , condition)
    loss_Y_XY = get_loss(Y_XY , condition) + _adversarial_loss(discriminator_y , Y_XY)
    loss_Y_YY = get_loss(Y_YY , condition) + _adversarial_loss(discriminator_y , Y_YY)
    loss_Y = (loss_Y_XY + loss_Y_YY) / 2
    loss_Y.backward()
    opt_generator_y.step()

    # ---- Generator X: both outputs are scored against `real` ----
    opt_generator_x.zero_grad()
    X_XX = generator_x(condition , real)
    X_YX = generator_x(condition , condition)
    loss_X_XX = get_loss(X_XX , real) + _adversarial_loss(discriminator_x , X_XX)
    loss_X_YX = get_loss(X_YX , real) + _adversarial_loss(discriminator_x , X_YX)
    loss_X = (loss_X_XX + loss_X_YX) / 2
    loss_X.backward()
    opt_generator_x.step()

    # ---- Discriminators ----
    # Train on this step's generator outputs (detached). The old code recomputed
    # fakes with swapped arguments (e.g. generator_y(condition, real)), so the
    # discriminators never saw what the generators actually produced.
    disc_y_loss = _discriminator_step(discriminator_y , opt_discriminator_y ,
                                      (Y_XY.detach() , Y_YY.detach()) , condition)
    disc_x_loss = _discriminator_step(discriminator_x , opt_discriminator_x ,
                                      (X_XX.detach() , X_YX.detach()) , real)

    disc_loss = (disc_x_loss + disc_y_loss) / 2
    gen_loss = (loss_X + loss_Y) / 2

    mean_discriminator_loss += disc_loss.item() / display_step
    mean_generator_loss += gen_loss.item() / display_step

    if cur_step % display_step == 0:
      if cur_step > 0:
        print(f"Epoch {epoch}: Step {cur_step}: Generator loss: {mean_generator_loss}, Discriminator loss: {mean_discriminator_loss}")
      else:
        print("Pretrained initial state")
      print('Y')
      show_tensor_images(condition, size=(input_dim, target_shape, target_shape) , switch=False)
      print('X')
      show_tensor_images(real, size=(real_dim, target_shape, target_shape) , switch=False)
      # get_loss treats generator outputs as logits (BCEWithLogitsLoss), so
      # map them to [0, 1] with a sigmoid before plotting.
      for title, generated in (('X --> Y', Y_XY), ('Y --> Y', Y_YY), ('X --> X', X_XX), ('Y --> X', X_YX)):
        print(title)
        show_tensor_images(torch.sigmoid(generated), size=(real_dim, target_shape, target_shape) , switch=False)
      mean_generator_loss = 0
      mean_discriminator_loss = 0
    cur_step += 1