<a href="https://colab.research.google.com/github/VedantDere0104/Cycle_GAN/blob/main/Cycle_GAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
####

In [None]:
import torch
from torch import nn

from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.datasets import VOCSegmentation
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
from torchsummary import summary

In [None]:
# Channel counts handed to the model constructors in the cells below.
in_channels = 3
out_channels = 3
# NOTE(review): z_dim is not referenced anywhere else in the visible notebook — confirm it is still needed.
z_dim = 512
# Same value as the `hidden_dim` default of the Generator/Encoder/Discriminator constructors.
hidden_dim = 32
# Run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
def show_tensor_images(image_tensor, num_images=2, size=(3, 512, 512)):
  """Plot the leading images of a batch as a single matplotlib grid.

  The tensor is detached from the graph, moved to the CPU and reshaped to
  ``(-1, *size)`` before the grid is assembled.

  Args:
    image_tensor: tensor holding the images to display.
    num_images: number of images, counted from the start of the batch, to show.
    size: ``(channels, height, width)`` of a single image.
  """
  batch = image_tensor.detach().cpu().view(-1, *size)
  selected = batch[:num_images]
  # Two images per row; values are plotted as-is, without renormalisation.
  grid = make_grid(selected, nrow=2, normalize=False)
  # The grid is channels-first; matplotlib expects channels-last.
  channels_last = grid.permute(1, 2, 0).squeeze()
  plt.imshow(channels_last)
  plt.show()

In [None]:
def crop(image, new_shape):
  """Return the centre crop of `image` with the spatial size taken from `new_shape`.

  Args:
    image: tensor whose dimensions 2 and 3 are height and width.
    new_shape: sequence whose entries 2 and 3 give the target height and width
      (for example the ``.shape`` of another tensor).

  Returns:
    The slice of `image` covering the centred window; batch and channel
    dimensions are left untouched.
  """
  target_h, target_w = new_shape[2], new_shape[3]
  # Window start = centre of the image minus half of the target extent.
  top = image.shape[2] // 2 - target_h // 2
  left = image.shape[3] // 2 - target_w // 2
  return image[:, :, top:top + target_h, left:left + target_w]

In [None]:
class Conv(nn.Module):
  """Conv2d followed by optional instance norm, LeakyReLU, dropout and max-pool.

  With the default 3x3 kernel, stride 1 and padding 1 the convolution keeps the
  spatial size; the optional 2x2 max-pool (stride 2) then halves it.

  Args:
    in_channels: channels of the input feature map.
    out_channels: channels produced by the convolution.
    kernel_size, stride, padding: forwarded to ``nn.Conv2d``.
    use_norm: append ``nn.InstanceNorm2d(out_channels)``.
    use_activation: append ``nn.LeakyReLU(0.2)``.
    use_dropout: append ``nn.Dropout()``.
    use_pool: append ``nn.MaxPool2d(kernel_size=2, stride=2)``.
  """

  def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
               padding=1, use_norm=True, use_activation=True,
               use_dropout=False, use_pool=True):
    super().__init__()

    self.use_norm = use_norm
    self.use_activation = use_activation
    self.use_dropout = use_dropout
    self.use_pool = use_pool

    self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)

    # Optional stages are only created (and registered) when switched on.
    if use_norm:
      self.norm = nn.InstanceNorm2d(out_channels)
    if use_activation:
      self.activation = nn.LeakyReLU(0.2)
    if use_dropout:
      self.dropout = nn.Dropout()
    if use_pool:
      self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

  def forward(self, x):
    """Apply the convolution and then every enabled stage, in fixed order."""
    out = self.conv1(x)
    stages = (
        (self.use_norm, 'norm'),
        (self.use_activation, 'activation'),
        (self.use_dropout, 'dropout'),
        (self.use_pool, 'maxpool'),
    )
    for enabled, attr in stages:
      if enabled:
        out = getattr(self, attr)(out)
    return out

In [None]:
# Sanity check: build a single Conv block and print its layer summary for a 3x512x512 input.
conv = Conv(in_channels , out_channels ).to(device)
summary(conv , (3 , 512 , 512))

In [None]:
class ConvT(nn.Module):
  """ConvTranspose2d followed by optional instance norm, ReLU and dropout.

  With the default 2x2 kernel, stride 2 and no padding the transposed
  convolution doubles the spatial size of its input.

  Args:
    in_channels: channels of the input feature map.
    out_channels: channels produced by the transposed convolution.
    kernel_size, stride, padding: forwarded to ``nn.ConvTranspose2d``.
    use_norm: append ``nn.InstanceNorm2d(out_channels)``.
    use_activation: append an in-place ``nn.ReLU``.
    use_dropout: append ``nn.Dropout()``.
  """

  def __init__(self, in_channels, out_channels, kernel_size=(2, 2),
               stride=(2, 2), padding=0, use_norm=True,
               use_activation=True, use_dropout=False):
    super().__init__()

    self.use_norm = use_norm
    self.use_activation = use_activation
    self.use_dropout = use_dropout

    self.convT1 = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding)

    # Optional stages are only created (and registered) when switched on.
    if use_norm:
      self.norm = nn.InstanceNorm2d(out_channels)
    if use_activation:
      self.activation = nn.ReLU(inplace=True)
    if use_dropout:
      self.dropout = nn.Dropout()

  def forward(self, x):
    """Apply the transposed convolution and then every enabled stage, in fixed order."""
    out = self.convT1(x)
    stages = (
        (self.use_norm, 'norm'),
        (self.use_activation, 'activation'),
        (self.use_dropout, 'dropout'),
    )
    for enabled, attr in stages:
      if enabled:
        out = getattr(self, attr)(out)
    return out

In [None]:
# Sanity check: build a single ConvT block (3 -> 32 channels) and print its summary for a 3x256x256 input.
convT = ConvT(in_channels , 32).to(device)
summary(convT , (3 , 256, 256))

In [None]:
class Decoder_block(nn.Module):
  """Decoder stage: merge a skip tensor, apply a ConvT, then two un-pooled Convs.

  The skip tensor is centre-cropped to the spatial size of the main input and
  concatenated on the channel axis, so both inputs are expected to carry
  `in_channels` channels (the ConvT is built for ``in_channels * 2``).

  Args:
    in_channels: channels of each of the two inputs.
    out_channels: channels of the block output.
    use_norm, use_activation, use_dropout: forwarded to all three sub-blocks.
  """

  def __init__(self, in_channels, out_channels, use_norm=True,
               use_activation=True, use_dropout=False):
    super().__init__()

    # The same optional-stage switches are shared by every sub-block.
    shared = dict(use_norm=use_norm,
                  use_activation=use_activation,
                  use_dropout=use_dropout)

    self.convT1 = ConvT(in_channels * 2, in_channels, **shared)
    self.conv1 = Conv(in_channels, in_channels, use_pool=False, **shared)
    self.conv2 = Conv(in_channels, out_channels, use_pool=False, **shared)

  def forward(self, x, y):
    """Combine `x` with the skip tensor `y` and run the three sub-blocks."""
    skip = crop(y, x.shape)
    merged = torch.cat([x, skip], dim=1)
    return self.conv2(self.conv1(self.convT1(merged)))



In [None]:
# Smoke test: run a Decoder_block on random tensors and inspect the output shape.
x = torch.randn(2 , 3 , 64 , 64).to(device)
# Spatially larger than x; Decoder_block.forward crops it to x's size before concatenating.
y = torch.randn(2 , 3 , 512 ,512).to(device)
decoder_block = Decoder_block(3 , 32).to(device)
z = decoder_block(x , y)
z.shape

In [None]:
class Generator(nn.Module):
  """Encoder-decoder image generator with skip connections.

  Seven pooled Conv blocks shrink the input while widening the channels up to
  ``hidden_dim * 64``; an un-pooled Conv sits in the middle; seven
  Decoder_blocks then grow the feature map again, each one also receiving the
  output of the mirrored encoder stage. A final Conv plus sigmoid produces the
  output image.

  Args:
    in_channels: channels of the input image.
    out_channels: channels of the generated image.
    hidden_dim: channel width of the first encoder stage.
  """

  def __init__(self, in_channels, out_channels, hidden_dim=32):
    super().__init__()
    h = hidden_dim

    # Encoder: each stage doubles the channel count.
    self.conv1 = Conv(in_channels, h, use_norm=False)
    self.conv2 = Conv(h, h * 2)
    self.conv3 = Conv(h * 2, h * 4)
    self.conv4 = Conv(h * 4, h * 8)
    self.conv5 = Conv(h * 8, h * 16)
    self.conv6 = Conv(h * 16, h * 32, use_norm=False)
    self.conv7 = Conv(h * 32, h * 64, use_norm=False)

    self.middle = Conv(h * 64, h * 64, use_pool=False)

    # Decoder: each stage halves the channel count.
    self.decoder1 = Decoder_block(h * 64, h * 32, use_norm=False)
    self.decoder2 = Decoder_block(h * 32, h * 16)
    self.decoder3 = Decoder_block(h * 16, h * 8)
    self.decoder4 = Decoder_block(h * 8, h * 4)
    self.decoder5 = Decoder_block(h * 4, h * 2)
    self.decoder6 = Decoder_block(h * 2, h, use_norm=False)
    self.decoder7 = Decoder_block(h, out_channels, use_norm=False)

    self.last = Conv(out_channels, out_channels, use_norm=False,
                     use_pool=False, use_activation=False)
    self.sigmoid = nn.Sigmoid()

  def forward(self, x):
    """Return the generated image for input batch `x`, squashed by a sigmoid."""
    skips = []
    feat = x
    for idx in range(1, 8):
      feat = getattr(self, f'conv{idx}')(feat)
      skips.append(feat)

    feat = self.middle(feat)

    # Decoder stages consume the encoder outputs deepest-first.
    for idx in range(1, 8):
      feat = getattr(self, f'decoder{idx}')(feat, skips.pop())

    return self.sigmoid(self.last(feat))

In [None]:
# Sanity check: build a Generator and print its layer summary for a 3x512x512 input.
generator = Generator(in_channels, out_channels).to(device)
summary(generator , (3 , 512 , 512))

In [None]:
class Encoder(nn.Module):
  """Stack of nine pooled Conv blocks reducing an image to one vector per sample.

  Every block uses Conv's default 2x2 max-pool, so the spatial size is halved
  nine times. The final ``view`` only succeeds when the remaining spatial size
  is 1x1 (e.g. a 512x512 input).

  Args:
    in_channels: channels of the input image.
    out_channels: length of the returned vector per sample.
    hidden_dim: channel width of the first block.
    kernel_size, stride, padding: accepted but not used by this class.
  """

  def __init__(self, in_channels, out_channels, hidden_dim=32,
               kernel_size=3, stride=1, padding=1):
    super().__init__()
    h = hidden_dim

    self.conv1 = Conv(in_channels, h, use_norm=False)
    self.conv2 = Conv(h, h * 2)
    self.conv3 = Conv(h * 2, h * 4)
    self.conv4 = Conv(h * 4, h * 8)
    self.conv5 = Conv(h * 8, h * 16)
    self.conv6 = Conv(h * 16, h * 32)
    self.conv7 = Conv(h * 32, h * 32)
    self.conv8 = Conv(h * 32, h * 16)
    self.conv9 = Conv(h * 16, out_channels, use_norm=False)

  def forward(self, x):
    """Run all nine blocks and flatten the result to ``(batch, channels)``."""
    feat = x
    for idx in range(1, 10):
      feat = getattr(self, f'conv{idx}')(feat)
    batch, channels = feat.shape[0], feat.shape[1]
    return feat.view(batch, channels)

In [None]:
# Sanity check: build an Encoder with a single output channel and print its summary for a 3x512x512 input.
encoder = Encoder(in_channels , 1).to(device)
summary(encoder , (3, 512 ,512))

In [None]:
class Discriminator(nn.Module):
  """Discriminator implemented as a thin wrapper around an Encoder.

  Args:
    in_channels: channels of the input image.
    out_channels: number of values returned per sample.
    hidden_dim: channel width of the Encoder's first block.
  """

  def __init__(self, in_channels, out_channels, hidden_dim=32):
    super().__init__()
    self.encoder = Encoder(in_channels, out_channels, hidden_dim=hidden_dim)

  def forward(self, x):
    """Return the Encoder output for `x` unchanged."""
    return self.encoder(x)

In [None]:
# Loss functions: BCE-with-logits for the adversarial terms, L1 for reconstruction.
adv_criterion = nn.BCEWithLogitsLoss()
recon_criterion = nn.L1Loss()
# Weight applied to the reconstruction terms inside get_loss.
lambda_recon = 200
# Adam beta coefficients shared by all four optimizers.
betas = (0.5 , 0.999)

n_epochs = 200
# Re-assigned here with the same values as in the first configuration cell.
in_channels = 3
out_channels = 3
# Training progress is printed / plotted every `display_step` steps.
display_step = 100
batch_size = 2
lr = 0.0002
# Size passed to nn.functional.interpolate for both image halves in train().
target_shape = 512

In [None]:
# The only preprocessing is conversion to a tensor; resizing happens later in train().
transform = transforms.Compose([
    transforms.ToTensor()
])

import torchvision
# Images are read from Google Drive; train() unpacks each batch as (img, _) and ignores the second element.
dataset = torchvision.datasets.ImageFolder("/content/drive/MyDrive/Maps/maps/", transform=transform)

In [None]:
# Two generators, one per direction: train() feeds `condition` to generator_x and `real` to generator_y.
generator_x = Generator(in_channels , out_channels).to(device)
generator_y = Generator(in_channels , out_channels).to(device)
# One discriminator per generator; each returns a single value per image (out_channels=1).
discriminator_x = Discriminator(in_channels , 1).to(device)
discriminator_y = Discriminator(in_channels , 1).to(device)

In [None]:
# One Adam optimizer per network, all sharing the same learning rate and betas.
opt_generator_x = torch.optim.Adam(generator_x.parameters() , lr = lr , betas=betas)
opt_generator_y = torch.optim.Adam(generator_y.parameters(), lr = lr , betas=betas)
opt_discriminator_x = torch.optim.Adam(discriminator_x.parameters(), lr = lr , betas=betas)
opt_discriminator_y = torch.optim.Adam(discriminator_y.parameters(), lr = lr , betas=betas)

In [None]:
def weights_init(m):
    """Initialise convolution and batch-norm parameters in place.

    Meant to be passed to ``nn.Module.apply`` so that it is called once for
    every sub-module.

    * ``nn.Conv2d`` / ``nn.ConvTranspose2d``: weights drawn from N(0, 0.02).
    * ``nn.BatchNorm2d``: weights drawn from N(0, 0.02), biases set to 0.
      Layers built with ``affine=False`` have no weight/bias and are skipped
      instead of raising.

    Any other module type is left untouched.
    NOTE(review): the blocks in this file use ``nn.InstanceNorm2d``, which this
    function does not modify — confirm that is intended.

    Args:
        m: the module to initialise.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    if isinstance(m, nn.BatchNorm2d):
        # With affine=False both parameters are None; initialising None would crash.
        if m.weight is not None:
            torch.nn.init.normal_(m.weight, 0.0, 0.02)
        if m.bias is not None:
            torch.nn.init.constant_(m.bias, 0)


In [None]:
# Re-initialise all four networks by applying weights_init to every sub-module.
generator_x = generator_x.apply(weights_init)
generator_y = generator_y.apply(weights_init)
discriminator_x = discriminator_x.apply(weights_init)
discriminator_y = discriminator_y.apply(weights_init)

In [None]:
# Shuffled mini-batches of `batch_size` images for the training loop.
dataloader = DataLoader(dataset , batch_size = batch_size , shuffle=True)
# Channel counts used by train() when reshaping tensors for display.
input_dim = 3
real_dim = 3

In [None]:
def get_loss(real , 
             condition , 
             gen , 
             disc , 
             adv_criterion = adv_criterion , 
             recon_criterion = recon_criterion , 
             lambda_recon=lambda_recon):
  """Compute the generator objective for one translation direction.

  The loss is the sum of
    * an adversarial term: the discriminator's prediction on the generated
      image scored against an all-ones ("real") target, so that minimising it
      pushes the generator to fool the discriminator;
    * ``lambda_recon`` times ``recon_criterion(fake, real)``;
    * ``lambda_recon`` times ``adv_criterion(fake, real)`` applied directly to
      the images.

  Args:
    real: target images the generator output is compared against.
    condition: input images fed to the generator.
    gen: generator network.
    disc: discriminator network scoring the generated images.
    adv_criterion: adversarial loss (module-level default).
    recon_criterion: reconstruction loss (module-level default).
    lambda_recon: weight of the two image-space terms.

  Returns:
    Scalar tensor with the combined generator loss.
  """
  fake = gen(condition)
  disc_fake_pred = disc(fake)

  # Fix: the generator must be rewarded when the discriminator calls its output
  # real, so the target is ones. The previous all-zeros target trained the
  # generator to be *detected* as fake.
  adv_loss = adv_criterion(disc_fake_pred , torch.ones_like(disc_fake_pred))

  recon_loss = recon_criterion(fake , real)

  # NOTE(review): the adversarial criterion is also applied pixel-wise between the
  # generated and real images. With the default BCEWithLogitsLoss this treats the
  # generator's sigmoid output as logits — kept unchanged, confirm it is intended.
  pixel_adv_loss = adv_criterion(fake , real)

  return adv_loss + lambda_recon * recon_loss + lambda_recon * pixel_adv_loss

In [None]:
def train():
  """Run the full training loop over the module-level models and dataloader.

  For every batch the image is split along its width into a left half
  (`condition`) and a right half (`real`). Then, in this order, generator_x,
  generator_y, discriminator_x and discriminator_y each take one optimizer
  step. Every `display_step` steps the running mean losses are printed and the
  inputs and generated images are plotted.

  Relies on module-level names: n_epochs, dataloader, target_shape, device,
  display_step, input_dim, real_dim, adv_criterion, get_loss, the four
  networks and their optimizers. Returns nothing.
  """
  mean_generator_loss = 0
  mean_discriminator_loss = 0
  cur_step = 0
  for epoch in range(n_epochs):
    for img , _ in tqdm(dataloader):
      # Each dataset image holds both domains side by side: split at the middle column.
      image_width = img.shape[3]
      condition = img[: , : , : , :image_width//2]
      condition = nn.functional.interpolate(condition , size = target_shape)
      real = img[: , : , : , image_width//2:]
      real = nn.functional.interpolate(real , size = target_shape)
      # NOTE(review): cur_batch_size is assigned but never read in this function.
      cur_batch_size = len(condition)
      real = real.to(device)
      condition = condition.to(device)

      # --- generator_x update (condition -> real) ---
      opt_generator_x.zero_grad()

      gen_loss_condition_real = get_loss(real , condition , generator_x , discriminator_x)

      gen_loss_condition_real.backward()
      opt_generator_x.step()

      # --- generator_y update (real -> condition) ---
      opt_generator_y.zero_grad()

      gen_loss_real_conditon = get_loss(condition , real , generator_y , discriminator_y)

      gen_loss_real_conditon.backward()
      opt_generator_y.step()

      # --- discriminator_x update ---
      # zero_grad also clears the gradients left on the discriminator by the generator backward pass.
      opt_discriminator_x.zero_grad()
      
      # The fake image is produced without gradient tracking, so only the discriminator is trained here.
      with torch.no_grad():
        fake_condition_real = generator_x(condition)
      disc_fake_pred_condition_real = discriminator_x(fake_condition_real)
      disc_real_pred = discriminator_x(real)
      
      # Generated images are scored against zeros, real images against ones.
      disc_fake_loss = adv_criterion(disc_fake_pred_condition_real , torch.zeros_like(disc_fake_pred_condition_real))
      disc_real_loss = adv_criterion(disc_real_pred , torch.ones_like(disc_real_pred))

      disc_loss = (disc_fake_loss + disc_real_loss)/2

      disc_loss.backward()
      opt_discriminator_x.step()


      # --- discriminator_y update (mirror of the block above) ---
      opt_discriminator_y.zero_grad()
      with torch.no_grad():
        fake_real_condition = generator_y(real)
      disc_fake_pred_real_condition = discriminator_y(fake_real_condition)
      disc_real_pred_ = discriminator_y(condition)

      disc_loss_1 = adv_criterion(disc_fake_pred_real_condition , torch.zeros_like(disc_fake_pred_real_condition))
      disc_loss_2 = adv_criterion(disc_real_pred_ , torch.ones_like(disc_real_pred_))

      disc_loss_ = (disc_loss_1 + disc_loss_2) /2
      disc_loss_.backward()
      opt_discriminator_y.step()

      # Average the two directions for reporting only; no further backward pass happens.
      disc_loss = (disc_loss + disc_loss_)/2
      gen_loss = (gen_loss_real_conditon + gen_loss_condition_real)/2

      # Dividing by display_step turns the running sums into means over one display window.
      mean_discriminator_loss += disc_loss.item() / display_step
      mean_generator_loss += gen_loss.item() / display_step

      if cur_step % display_step == 0:
        if cur_step > 0:
          print(f"Epoch {epoch}: Step {cur_step}: Generator loss: {mean_generator_loss}, Discriminator loss: {mean_discriminator_loss}")
        else:
          print("Pretrained initial state")
        print('condition')
        show_tensor_images(condition, size=(input_dim, target_shape, target_shape))
        print('real')
        show_tensor_images(real, size=(real_dim, target_shape, target_shape) )
        print('fake_condition_real')
        show_tensor_images(fake_condition_real , size=(input_dim , target_shape , target_shape))
        print('fake_real_condition')
        show_tensor_images(fake_real_condition , size=(input_dim , target_shape , target_shape))
        # Start a fresh averaging window.
        mean_generator_loss = 0
        mean_discriminator_loss = 0
      cur_step += 1

In [None]:
# Start training with the module-level models, optimizers and dataloader.
train()

In [None]:
#torch.save(generator_x.state_dict() , '/content/drive/MyDrive/Map_Dataset_Cycle_Gen_x.pth')
#torch.save(generator_y.state_dict() , '/content/drive/MyDrive/Map_Dataset_Cycle_Gen_y.pth')
#torch.save(discriminator_x.state_dict() , '/content/drive/MyDrive/Map_Dataset_Cycle_Disc_x.pth')
#torch.save(discriminator_y.state_dict() , '/content/drive/MyDrive/Map_Dataset_Cycle_Disc_y.pth')