<a href="https://colab.research.google.com/github/VedantDere0104/GANs/blob/main/Using_U_Net_for_dual_direction_Generative_Adversarial_Network.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
####

In [2]:
import torch
from torch import nn
from torchsummary import summary

In [3]:
from torchvision import transforms
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
import torchvision

In [4]:

def show_tensor_images(image_tensor, num_images=2, size=(1, 28, 28) , switch = True):
  """Plot the first ``num_images`` images of a batch as a grid.

  Args:
    image_tensor: tensor that can be viewed as ``(-1, *size)``; it is
      detached and moved to the CPU before plotting.
    num_images: how many images from the front of the batch to show.
    size: ``(channels, height, width)`` of a single image.
    switch: when true, the grid values are multiplied by 255.0 before
      being handed to ``plt.imshow``.
  """
  images = image_tensor.detach().cpu().view(-1, *size)
  # Two images per row, no per-image rescaling.
  grid = make_grid(images[:num_images], nrow=2, normalize=False)
  grid = grid * 255.0 if switch else grid
  # make_grid yields channels-first; imshow wants channels-last.
  channels_last = grid.permute(1, 2, 0)
  plt.imshow(channels_last.squeeze())
  plt.show()

In [5]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [6]:
def crop(image, new_shape):
  """Centre-crop the last two dimensions of a 4-D tensor.

  Args:
    image: tensor indexed as ``image[:, :, rows, cols]``.
    new_shape: indexable whose entries 2 and 3 give the wanted number of
      rows and columns (entries 0 and 1 are ignored).

  Returns:
    The slice of ``image`` centred on the middle of its last two
    dimensions with the requested height and width.
  """
  target_h, target_w = new_shape[2], new_shape[3]
  # Offsets are measured from the integer midpoint of each spatial axis.
  top = image.shape[2] // 2 - target_h // 2
  left = image.shape[3] // 2 - target_w // 2
  return image[:, :, top:top + target_h, left:left + target_w]

In [7]:
class Downsample(nn.Module):
  """Contracting block: two 3x3 convolutions followed by a 2x2 max-pool.

  The first convolution doubles the channel count, the second keeps it,
  and the max-pool halves the spatial size.

  Args:
    in_channels: channels of the incoming feature map; the block outputs
      ``in_channels * 2`` channels.
    use_norm: apply ``nn.InstanceNorm2d`` after each convolution.
    use_dropout: apply ``nn.Dropout`` (default probability) after each
      convolution / normalisation.
  """
  def __init__(self ,
               in_channels ,
               use_norm = True ,
               use_dropout = False):
    super(Downsample , self).__init__()

    self.use_norm = use_norm
    self.use_dropout = use_dropout

    self.conv1 = nn.Conv2d(in_channels , in_channels * 2 , kernel_size=3 , padding=1)
    self.conv2 = nn.Conv2d(in_channels * 2 , in_channels * 2 , kernel_size=3 , padding=1)

    # nn.ReLU's only argument is `inplace`, so the former `nn.ReLU(0.2)` did
    # not set a slope: 0.2 was merely a truthy `inplace` flag.  Spell that out
    # so the behaviour (in-place ReLU) is unchanged but no longer misleading.
    # NOTE(review): if a leaky slope of 0.2 was intended (as in Upsample),
    # switch this to nn.LeakyReLU(0.2) - confirm with the author.
    self.activation = nn.ReLU(inplace=True)

    self.maxpool = nn.MaxPool2d(kernel_size=2)

    if self.use_norm:
      # No learnable parameters here, so one instance is shared by both convs.
      self.norm = nn.InstanceNorm2d(in_channels * 2)

    if self.use_dropout:
      self.dropout = nn.Dropout()

  def _post_conv(self , x):
    """Apply the optional norm / dropout and the activation after a conv."""
    if self.use_norm:
      x = self.norm(x)
    if self.use_dropout:
      x = self.dropout(x)
    return self.activation(x)

  def forward(self , x):
    """Return the down-sampled feature map for input ``x``."""
    x = self._post_conv(self.conv1(x))
    x = self._post_conv(self.conv2(x))
    return self.maxpool(x)

In [8]:
class Upsample(nn.Module):
  """Expanding block: transposed-conv upsampling, skip concatenation, two convs.

  Args:
    in_channels: channels of the tensor being upsampled.  The block
      outputs ``in_channels // 2`` channels, and ``conv2`` is sized for a
      skip connection that carries ``in_channels`` channels.
    use_norm: apply ``nn.InstanceNorm2d`` after each 3x3 convolution.
    use_dropout: apply ``nn.Dropout`` (default probability) after each
      convolution / normalisation.
  """
  def __init__(self ,
               in_channels ,
               use_norm = True ,
               use_dropout = False):
    super(Upsample , self).__init__()

    self.use_norm = use_norm
    self.use_dropout = use_dropout

    out_channels = in_channels // 2

    # kernel 2 / stride 2 transposed conv doubles height and width.
    self.convT1 = nn.ConvTranspose2d(in_channels , in_channels , kernel_size=2 , stride=2 , padding=0)

    self.conv1 = nn.Conv2d(in_channels , out_channels , kernel_size=3 , padding=1)
    # Input is conv1's output concatenated with the skip connection.
    self.conv2 = nn.Conv2d(in_channels + out_channels , out_channels , kernel_size=3 , padding=1)

    self.lrelu = nn.LeakyReLU(0.2)

    if self.use_norm:
      # Both normalised tensors have `out_channels` channels.  The previous
      # num_features of `in_channels * 2` did not match and made PyTorch warn
      # about the size mismatch on every forward pass.
      self.norm = nn.InstanceNorm2d(out_channels)
    if self.use_dropout:
      self.dropout = nn.Dropout()

  def _post_conv(self , x):
    """Apply the optional norm / dropout and the LeakyReLU after a conv."""
    if self.use_norm:
      x = self.norm(x)
    if self.use_dropout:
      x = self.dropout(x)
    return self.lrelu(x)

  def forward(self , x , x_skip_con):
    """Upsample ``x``, merge it with ``x_skip_con`` and return the result.

    Args:
      x: feature map to upsample.
      x_skip_con: skip-connection tensor; it is centre-cropped to the
        spatial size of the upsampled ``x`` before concatenation.
    """
    x = self.convT1(x)
    x = self._post_conv(self.conv1(x))

    # Align the skip connection spatially, then stack along the channel axis.
    x_skip_con = crop(x_skip_con , x.shape)
    x = torch.cat((x , x_skip_con) , dim=1)

    return self._post_conv(self.conv2(x))


In [9]:
# Scratch instantiation of Upsample (3 input channels, dropout enabled);
# `upsample` is not referenced again in the visible cells.
upsample = Upsample(3 , use_dropout=True)

In [10]:
# Random scratch tensors left on the CPU; both names are reassigned in a later
# cell before the generator is first called.
x = torch.randn(2 , 2048 , 8 , 8)
y = torch.randn(2 , 3 , 512 , 512)

In [11]:
class Feature_map_block(nn.Module):
  """Channel-mapping block made of two 1x1 convolutions.

  The first convolution widens the input to ``in_channels * 2`` channels
  and the second maps that to ``out_channels``.  Each convolution is
  followed by instance normalisation and the block's activation.  The
  spatial size is untouched (kernel size 1, stride 1).

  Args:
    in_channels: channels of the incoming tensor.
    out_channels: channels of the returned tensor.
  """
  def __init__(self , in_channels , out_channels):
    super().__init__()
    widened = in_channels * 2

    # First stage: in_channels -> widened.
    self.conv = nn.Conv2d(in_channels , widened , kernel_size = 1 , stride = 1)
    self.instance = nn.InstanceNorm2d(widened)
    # Despite the attribute name this is a plain ReLU, not a leaky one.
    self.lrelu = nn.ReLU()

    # Second stage: widened -> out_channels.
    self.instance1 = nn.InstanceNorm2d(out_channels)
    self.conv1 = nn.Conv2d(widened , out_channels , kernel_size=1 , stride=1)

  def forward(self , x):
    """Return ``x`` mapped to ``out_channels`` channels."""
    hidden = self.conv(x)
    hidden = self.instance(hidden)
    hidden = self.lrelu(hidden)

    mapped = self.conv1(hidden)
    mapped = self.instance1(mapped)
    return self.lrelu(mapped)

In [None]:
# Smoke test: build a 3 -> 32 channel Feature_map_block on the selected device
# and print its layer summary for a 3x512x512 input.
fmb = Feature_map_block(3 , 32).to(device)
summary(fmb , (3 , 512 , 512))

In [13]:
###################################################

In [14]:
class Encoder(nn.Module):
  """Six strided 2x2 convolutions, each halving height and width.

  Stage ``i`` (1-based) produces ``hidden_dim * 2 ** (i - 1)`` channels,
  so the output has ``hidden_dim * 32`` channels.  Every convolution is
  followed by instance normalisation and LeakyReLU(0.2).

  Args:
    in_channels: channels of the input image.
    out_channels: accepted for signature compatibility; not used by the
      layers built here.
    hidden_dim: channel count after the first convolution.
  """
  _NUM_STAGES = 6

  def __init__(self ,in_channels , out_channels , hidden_dim = 32 ):
    super().__init__()

    # widths[i] is the channel count entering stage i + 1.
    widths = [in_channels] + [hidden_dim * 2 ** k for k in range(self._NUM_STAGES)]
    for stage in range(1 , self._NUM_STAGES + 1):
      # Registered as conv1/instance1 ... conv6/instance6, in that order.
      setattr(self , f"conv{stage}" ,
              nn.Conv2d(widths[stage - 1] , widths[stage] , kernel_size=2 , stride=2))
      setattr(self , f"instance{stage}" , nn.InstanceNorm2d(widths[stage]))

    self.relu = nn.LeakyReLU(0.2)

  def forward(self , x):
    """Run ``x`` through the six conv / norm / activation stages."""
    for stage in range(1 , self._NUM_STAGES + 1):
      conv = getattr(self , f"conv{stage}")
      norm = getattr(self , f"instance{stage}")
      x = self.relu(norm(conv(x)))
    return x



In [15]:
class Dual_direction_GAN(nn.Module):
  """U-Net style generator that takes two inputs.

  ``x`` goes through a feature-mapping block and six ``Downsample``
  stages.  ``y`` goes through the strided-convolution ``Encoder``.  The
  two results are concatenated along the channel axis at the bottleneck,
  mixed by a ``Feature_map_block`` and decoded by six ``Upsample`` stages
  that consume the contracting-path activations as skip connections.

  Args:
    in_channels: channels of both ``x`` and ``y``.
    out_channels: channels of the generated tensor.
    hidden_dim: base channel width of the network.
  """
  def __init__(self , in_channels ,  out_channels , hidden_dim = 32):
    super().__init__()
    h = hidden_dim

    # Contracting path (each Downsample doubles its channel count).
    self.upfeature = Feature_map_block(in_channels , h)
    self.downsample1 = Downsample(h)
    self.downsample2 = Downsample(h * 2)
    self.downsample3 = Downsample(h * 4)
    self.downsample4 = Downsample(h * 8)
    self.downsample5 = Downsample(h * 16)
    self.downsample6 = Downsample(h * 32)

    # Bottleneck: squeeze the deepest features, encode `y`, then fuse both.
    self.fmb = Feature_map_block(h * 64 , h * 32)
    self.encoder = Encoder(in_channels , h * 32)
    self.fmb1 = Feature_map_block(h * 32 * 2 , h * 32)

    # Expanding path (each Upsample halves its channel count).
    self.upsample1 = Upsample(h * 32)
    self.upsample2 = Upsample(h * 16)
    self.upsample3 = Upsample(h * 8)
    self.upsample4 = Upsample(h * 4)
    self.upsample5 = Upsample(h * 2)
    self.upsample6 = Upsample(h)

    self.last = Feature_map_block(h // 2 , out_channels)
    # NOTE(review): the attribute is called `sigmoid` but holds a ReLU, so the
    # output is non-negative and unbounded above - confirm this is intended.
    self.sigmoid = nn.ReLU()

  def forward(self , x , y):
    """Generate an output from ``x`` (U-Net input) and ``y`` (encoder input)."""
    # skips[i] is the activation after i down-sampling stages.
    skips = [self.upfeature(x)]
    down_path = (self.downsample1 , self.downsample2 , self.downsample3 ,
                 self.downsample4 , self.downsample5 , self.downsample6)
    for down in down_path:
      skips.append(down(skips[-1]))

    # The deepest activation is not a skip connection; it feeds the bottleneck.
    bottleneck = self.fmb(skips.pop())
    encoded = self.encoder(y)
    fused = self.fmb1(torch.cat((bottleneck , encoded) , dim = 1))

    up_path = (self.upsample1 , self.upsample2 , self.upsample3 ,
               self.upsample4 , self.upsample5 , self.upsample6)
    decoded = fused
    for up in up_path:
      # Skip connections are consumed deepest-first.
      decoded = up(decoded , skips.pop())

    return self.sigmoid(self.last(decoded))


In [16]:
# Dummy batch of two random 3x512x512 tensors on the selected device, used to
# exercise the generator below.
x = torch.randn(2 , 3 , 512 , 512 ).to(device)
y = torch.randn(2 , 3 , 512 , 512 ).to(device)


In [17]:
# Throw-away generator (3 channels in, 3 channels out) for a shape check.
GAN = Dual_direction_GAN(3 , 3).to(device)

In [18]:
# Forward pass on the dummy batch to confirm the network runs end to end.
z = GAN(x , y)

In [19]:
class Helper_1(nn.Module):
  """Conv2d -> optional InstanceNorm2d -> activation.

  Args:
    in_channels: channels of the input tensor.
    out_channels: channels produced by the convolution.
    kernel_size: convolution kernel size.
    stride: convolution stride.
    use_batch_norm: when true, normalise the convolution output with
      ``nn.InstanceNorm2d`` (the flag keeps its historical name).
    activation: ``'relu'`` selects ``nn.ReLU``; any other value,
      including the default, selects ``nn.LeakyReLU(0.2)``.
  """
  def __init__(self , in_channels , out_channels , kernel_size = (2 , 2) , stride = (2 , 2) , use_batch_norm = True , activation = 'lreu'):
    super(Helper_1 , self).__init__()

    self.use_batch_norm = use_batch_norm
    self.conv1 = nn.Conv2d(in_channels , out_channels , kernel_size , stride)
    # Always registered so the attribute exists regardless of the flag; it has
    # no parameters, so the state_dict is unaffected.
    self.batch_norm = nn.InstanceNorm2d(out_channels)
    self.activation = activation
    if self.activation == 'relu':
      self.relu = nn.ReLU()
    else :
      self.relu = nn.LeakyReLU(0.2)

  def forward(self , x):
    """Apply the convolution, the optional normalisation and the activation."""
    x = self.conv1(x)
    # Previously the normalisation ran unconditionally, so passing
    # use_batch_norm=False had no effect.
    if self.use_batch_norm:
      x = self.batch_norm(x)
    return self.relu(x)

In [20]:
class Discriminator(nn.Module):
  """Convolutional classifier: eight strided conv blocks and four linear layers.

  Each ``Helper_1`` block uses a 2x2 kernel with stride 2, so eight of
  them shrink a square input of side ``image_size`` to
  ``image_size // 256`` per side.  The flattened features pass through
  three LeakyReLU-activated linear layers and a final linear + sigmoid.

  Args:
    in_channels: channels of the input image.
    hidden_dim: base channel width.
    out_channels: number of outputs per sample.
    image_size: side length of the (square) input images.  It sizes the
      first linear layer; the default of 512 with ``hidden_dim=32``
      reproduces the previously hard-coded 4096 input features.

  Raises:
    ValueError: if ``image_size`` is smaller than 256, which would leave
      no spatial extent after the eight stride-2 convolutions.
  """
  def __init__(self , in_channels , hidden_dim , out_channels , image_size = 512):
    super(Discriminator , self).__init__()

    spatial = image_size // 256
    if spatial < 1:
      raise ValueError("image_size must be at least 256 for eight stride-2 convolutions")

    self.conv1 = Helper_1(in_channels , hidden_dim , use_batch_norm=False)
    self.conv2 = Helper_1(hidden_dim , hidden_dim * 2 )
    self.conv3 = Helper_1(hidden_dim * 2 , hidden_dim * 4)
    self.conv4 = Helper_1(hidden_dim * 4 , hidden_dim * 8)
    self.conv5 = Helper_1(hidden_dim * 8 , hidden_dim * 16)
    self.conv6 = Helper_1(hidden_dim * 16 , hidden_dim * 32)
    self.conv7 = Helper_1(hidden_dim * 32 , hidden_dim * 32)
    self.conv8 = Helper_1(hidden_dim * 32 , hidden_dim * 32)
    self.flatten = nn.Flatten()

    self.relu = nn.LeakyReLU(0.2)
    # Flattened size = channels after conv8 * remaining height * remaining width.
    flat_features = hidden_dim * 32 * spatial * spatial
    self.linear1 = nn.Linear(flat_features , hidden_dim * 32)
    # The BatchNorm1d layers are registered but never called in forward();
    # they are kept so existing state_dicts keep loading.
    self.batchnorm1 = nn.BatchNorm1d(hidden_dim * 32)
    self.linear2 = nn.Linear(hidden_dim * 32 , hidden_dim * 8)
    self.batchnorm2 = nn.BatchNorm1d(hidden_dim * 8)
    self.linear3 = nn.Linear(hidden_dim * 8 , hidden_dim)
    self.batchnorm3 = nn.BatchNorm1d(hidden_dim)
    self.linear4 = nn.Linear(hidden_dim , out_channels)
    self.sigmoid = nn.Sigmoid()


  def forward(self , x):
    """Return per-sample scores in (0, 1), shape ``(batch, out_channels)``.

    The scores have already been through a sigmoid, so they are
    probabilities rather than logits.
    """
    x = self.conv1(x)
    x = self.conv2(x)
    x = self.conv3(x)
    x = self.conv4(x)
    x = self.conv5(x)
    x = self.conv6(x)
    x = self.conv7(x)
    x = self.conv8(x)
    x = self.flatten(x)
    x = self.relu(self.linear1(x))
    x = self.relu(self.linear2(x))
    x = self.relu(self.linear3(x))
    x = self.sigmoid(self.linear4(x))
    return x

In [21]:
# Two independent 3-channel -> 3-channel generators.  In the training loop
# generator_x's outputs are compared with `real` and generator_y's outputs
# with `condition`.
generator_x = Dual_direction_GAN(3 , 3).to(device)
generator_y = Dual_direction_GAN(3 , 3).to(device)

In [22]:

# One discriminator per generator: 3 input channels, base width 32, one output
# score per image.
discriminator_x = Discriminator(3 , 32 , 1).to(device)
discriminator_y = Discriminator(3 , 32 , 1).to(device)

In [23]:
def weights_init(m):
  """Initialise one module in place; meant for ``nn.Module.apply``.

  * ``nn.Conv2d`` / ``nn.ConvTranspose2d``: weights drawn from N(0, 0.02).
  * ``nn.BatchNorm2d``: weights drawn from N(0, 0.02), bias set to 0.

  Every other module type is left untouched.
  """
  conv_types = (nn.Conv2d , nn.ConvTranspose2d)
  if isinstance(m , conv_types):
    torch.nn.init.normal_(m.weight , mean=0.0 , std=0.02)
  elif isinstance(m , nn.BatchNorm2d):
    torch.nn.init.normal_(m.weight , mean=0.0 , std=0.02)
    torch.nn.init.constant_(m.bias , 0)

In [24]:
# Re-initialise every sub-module of the four networks with weights_init
# (Module.apply visits each sub-module and returns the module itself).
generator_x = generator_x.apply(weights_init)
generator_y = generator_y.apply(weights_init)
discriminator_x = discriminator_x.apply(weights_init)
discriminator_y = discriminator_y.apply(weights_init)

In [25]:
# Binary cross-entropy that applies a sigmoid to its input internally.
criterion = nn.BCEWithLogitsLoss()
# Pixel-wise L1 term used by get_loss.
l1_loss = nn.L1Loss()
# Weight of the L1 term relative to `criterion` inside get_loss.
lambda_recon = 200
# Not referenced in the visible cells.
mse_loss = nn.MSELoss()

In [26]:
# Training hyper-parameters.
n_epochs = 100          # passes over the dataset
input_dim = 3           # channel count used when displaying `condition`
real_dim = 3            # channel count used when displaying `real` and generator outputs
display_step = 10       # log losses and show images every this many steps
batch_size = 2
lr = 0.0002             # learning rate shared by all four Adam optimizers
target_shape = 512      # side length both image halves are resized to
betas = (0.5 , 0.999)   # Adam betas shared by all four optimizers

In [27]:
# Only convert the loaded images to tensors; no resizing or normalisation here.
transform = transforms.Compose([ transforms.ToTensor(), ])

# NOTE(review): hard-coded path that presumably points at a Google Drive mount
# in Colab - adjust when running elsewhere.
dataset = torchvision.datasets.ImageFolder("/content/drive/MyDrive/Maps/maps/", transform=transform)

In [28]:
# Running loss averages; accumulated per step and reset every display_step steps.
mean_generator_loss = 0
mean_discriminator_loss = 0
dataloader = DataLoader(dataset , batch_size = batch_size , shuffle=True)
# Global step counter across all epochs.
cur_step = 0

In [29]:
# One Adam optimizer per network, all sharing the same learning rate and betas.
opt_generator_y = torch.optim.Adam(generator_y.parameters() , lr=lr , betas=betas)
opt_generator_x = torch.optim.Adam(generator_x.parameters() , lr = lr , betas = betas)
opt_discriminator_x = torch.optim.Adam(discriminator_x.parameters() , lr = lr , betas=betas)
opt_discriminator_y = torch.optim.Adam(discriminator_y.parameters() , lr = lr , betas=betas)

In [30]:

def get_loss(fake , real  , criterion = criterion , l1_loss = l1_loss , lambda_recon = lambda_recon , switch = True):
  """Combine ``criterion`` and a weighted L1 term between ``fake`` and ``real``.

  Args:
    fake: generated tensor.
    real: target tensor.
    criterion: primary loss; defaults to the module-level ``criterion``
      as it was when this function was defined.
    l1_loss: reconstruction loss; defaults to the module-level ``l1_loss``.
    lambda_recon: multiplier applied to the L1 term.
    switch: when true return ``criterion + lambda_recon * l1``; when
      false return the ``criterion`` term alone.
  """
  primary_term = criterion(fake , real)
  recon_term = l1_loss(fake , real)
  if not switch:
    return primary_term
  return primary_term + lambda_recon * recon_term

In [None]:
# Both discriminators end in nn.Sigmoid, so their outputs are already
# probabilities.  Scoring them with `criterion` (BCEWithLogitsLoss) would apply
# a second sigmoid, so the discriminator losses use plain BCE instead.  The
# generator losses keep using get_loss / `criterion` unchanged.
disc_criterion = nn.BCELoss()

for epoch in range(n_epochs):
  for img , _ in tqdm(dataloader):
    # Each image holds the pair side by side: the left half becomes
    # `condition`, the right half becomes `real`.
    image_width = img.shape[3]
    condition = img[: , : , : , :image_width//2]
    condition = nn.functional.interpolate(condition , size = target_shape)
    real = img[: , : , : , image_width//2:]
    real = nn.functional.interpolate(real , size = target_shape)
    cur_batch_size = len(condition)
    real = real.to(device)
    condition = condition.to(device)

    # ---- generator_y: both input pairings are trained towards `condition` ----
    opt_generator_y.zero_grad()
    Y_XY = generator_y(real , real)
    Y_YY = generator_y(condition , real)

    loss_Y_XY = get_loss(Y_XY , condition , switch=True)
    loss_Y_YY = get_loss(Y_YY , condition , switch=True)
    loss_Y = (loss_Y_XY + loss_Y_YY) /2

    loss_Y.backward()
    opt_generator_y.step()

    # ---- generator_x: both input pairings are trained towards `real` ----
    opt_generator_x.zero_grad()
    X_XX = generator_x(real , condition)
    X_YX = generator_x(condition , condition)

    loss_X_XX = get_loss(X_XX , real , switch=True)
    loss_X_YX = get_loss(X_YX , real ,  switch=True)
    loss_X = (loss_X_XX + loss_X_YX) /2
    loss_X.backward()
    opt_generator_x.step()

    # ---- discriminator_y: generator_y outputs are fake, `condition` is real ----
    opt_discriminator_y.zero_grad()
    # Fakes are regenerated without gradients so only the discriminator learns.
    with torch.no_grad():
      disc_Y_XY = generator_y(real , real)
      disc_Y_YY = generator_y(condition ,real)
    disc_fake_y_pred_YX = discriminator_y(disc_Y_XY)
    disc_loss_fake_pred_YX = disc_criterion(disc_fake_y_pred_YX , torch.zeros_like(disc_fake_y_pred_YX))

    disc_fake_y_pred_YY = discriminator_y(disc_Y_YY)
    disc_loss_fake_pred_YY = disc_criterion(disc_fake_y_pred_YY , torch.zeros_like(disc_fake_y_pred_YY))

    disc_real_pred = discriminator_y(condition)
    disc_real_pred_loss = disc_criterion(disc_real_pred , torch.ones_like(disc_real_pred))

    disc_y_loss = (disc_loss_fake_pred_YX + disc_loss_fake_pred_YY + disc_real_pred_loss) /3

    disc_y_loss.backward()
    opt_discriminator_y.step()

    # ---- discriminator_x: generator_x outputs are fake, `real` is real ----
    opt_discriminator_x.zero_grad()
    with torch.no_grad():
      disc_X_XX = generator_x(real , condition)
      disc_X_YX_ = generator_x(condition , condition)
    disc_fake_pred_XX = discriminator_x(disc_X_XX)
    disc_loss_fake_pred_XX = disc_criterion(disc_fake_pred_XX , torch.zeros_like(disc_fake_pred_XX))

    disc_fake_pred_YX_ = discriminator_x(disc_X_YX_)
    disc_loss_fake_pred_YX_ = disc_criterion(disc_fake_pred_YX_ , torch.zeros_like(disc_fake_pred_YX_))

    disc_x_real_pred = discriminator_x(real)
    disc_x_real_loss = disc_criterion(disc_x_real_pred , torch.ones_like(disc_x_real_pred))
    disc_x_loss = (disc_loss_fake_pred_XX + disc_loss_fake_pred_YX_ + disc_x_real_loss) / 3
    disc_x_loss.backward()
    opt_discriminator_x.step()

    # ---- bookkeeping: running means over the last display_step steps ----
    disc_loss = (disc_x_loss + disc_y_loss)/2
    gen_loss = (loss_X + loss_Y)/2

    mean_discriminator_loss += disc_loss.item() / display_step
    mean_generator_loss += gen_loss.item() / display_step

    if cur_step % display_step == 0:
      if cur_step > 0:
        print(f"Epoch {epoch}: Step {cur_step}: Generator loss: {mean_generator_loss}, Discriminator loss: {mean_discriminator_loss}")
      else:
        print("Pretrained initial state")
      print('Y')
      show_tensor_images(condition, size=(input_dim, target_shape, target_shape) , switch=False)
      print('X')
      show_tensor_images(real, size=(real_dim, target_shape, target_shape) , switch=False)
      print('X --> Y')
      show_tensor_images(Y_XY, size=(real_dim, target_shape, target_shape) , switch=False)
      print('Y --> Y')
      show_tensor_images(Y_YY, size=(real_dim, target_shape, target_shape) , switch=False)
      print('X --> X')
      show_tensor_images(X_XX, size=(real_dim, target_shape, target_shape) , switch=False)
      print('Y --> X')
      show_tensor_images(X_YX, size=(real_dim, target_shape, target_shape) , switch= False)
      mean_generator_loss = 0
      mean_discriminator_loss = 0
    cur_step += 1