<a href="https://colab.research.google.com/github/VedantDere0104/CoMo_GAN/blob/main/CoMo_GAN_Version_1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
####

In [3]:
import torch
from torch import nn
from torchsummary import summary
from torchvision import transforms
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
import torchvision

In [4]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [5]:
# Channel counts and latent size handed to the Generator / Discriminator
# constructors further down in the notebook.
in_channels = 3        # channels of the input images
out_channels = 3       # channels produced by the generators
out_channels_disc = 1  # channels produced by the discriminators / O_Net
z_dim = 512            # size of the encoder's latent code

In [6]:

def show_tensor_images(image_tensor, num_images=2, size=(3 , 512 , 512) , switch = False):
  """Plot the first `num_images` images of a tensor as a matplotlib grid.

  Args:
    image_tensor: tensor that can be viewed as (-1, *size).
    num_images: how many leading images to put in the grid.
    size: (channels, height, width) of a single image.
    switch: when True the grid values are multiplied by 255.0 before plotting.
  """
  # Detach and move to CPU so matplotlib can consume the data.
  images = image_tensor.detach().cpu().view(-1, *size)[:num_images]
  grid = make_grid(images, nrow=2, normalize=False)
  grid = grid * 255.0 if switch else grid
  # make_grid yields channels-first; imshow expects channels-last.
  plt.imshow(grid.permute(1, 2, 0).squeeze())
  plt.show()

In [7]:
def crop(image, new_shape):
    """Center-crop a 4-D tensor to the height/width given by `new_shape`.

    Args:
        image: tensor indexed as [batch, channel, height, width].
        new_shape: indexable whose entries 2 and 3 are the target height
            and width (entries 0 and 1 are ignored).

    Returns:
        The slice of `image` centered on the middle pixel with the
        requested height and width.
    """
    target_h, target_w = new_shape[2], new_shape[3]
    # Start offsets are measured from the integer midpoint of each axis.
    top = image.shape[2] // 2 - target_h // 2
    left = image.shape[3] // 2 - target_w // 2
    return image[:, :, top:top + target_h, left:left + target_w]

In [8]:
class Conv(nn.Module):
  """Conv2d block with optional InstanceNorm2d, activation and dropout.

  The convolution is created with padding_mode='reflect'. With the default
  2x2 kernel, stride 2 and no padding, each spatial side is halved.

  Args:
    in_channels: number of input channels.
    out_channels: number of output channels.
    kernel_size, stride, padding: forwarded to nn.Conv2d.
    use_norm: apply nn.InstanceNorm2d(out_channels) after the convolution.
    use_activation: apply an activation after the (optional) norm.
    use_dropout: apply nn.Dropout() (library-default probability) last.
    activation: 'lrelu' selects LeakyReLU(0.2); any other value selects an
      in-place ReLU. The default used to be the misspelling 'lreu', which
      never matched 'lrelu' and therefore silently produced ReLU.
  """
  def __init__(self , 
               in_channels , 
               out_channels , 
               kernel_size = (2 , 2) , 
               stride = (2 , 2) , 
               padding = 0 , 
               use_norm = True , 
               use_activation = True , 
               use_dropout = False , 
               activation = 'lrelu'):
    super(Conv , self).__init__()

    self.use_norm = use_norm
    self.use_activation = use_activation
    self.use_dropout = use_dropout
    self.activation = activation

    self.conv1 = nn.Conv2d(in_channels , out_channels , kernel_size , stride , padding , padding_mode='reflect')
    
    if self.use_norm:
      self.norm = nn.InstanceNorm2d(out_channels)
    if self.use_activation:
      if self.activation == 'lrelu':
        self.activation_ = nn.LeakyReLU(0.2)
      else :
        # Unknown names keep falling back to ReLU, as before.
        self.activation_ = nn.ReLU(inplace=True)
    if self.use_dropout:
      self.dropout = nn.Dropout()
  
  def forward(self , x):
    """Apply convolution, then norm / activation / dropout if enabled."""
    x = self.conv1(x)
    if self.use_norm:
      x = self.norm(x)
    if self.use_activation:
      x = self.activation_(x)
    if self.use_dropout:
      x = self.dropout(x)

    return x

In [None]:
# Smoke test: build a default Conv block and print its layer summary
# for a 3x512x512 input.
conv = Conv(3 , 32).to(device)
summary(conv , (3 , 512 , 512))

In [10]:
class ConvT(nn.Module):
  """ConvTranspose2d block with optional InstanceNorm2d, activation and dropout.

  With the default 2x2 kernel, stride 2 and no padding, each spatial side
  is doubled.

  Args:
    in_channels: number of input channels.
    out_channels: number of output channels.
    kernel_size, stride, padding: forwarded to nn.ConvTranspose2d.
    use_norm: apply nn.InstanceNorm2d(out_channels) after the convolution.
    use_activation: apply an activation after the (optional) norm.
    use_dropout: apply nn.Dropout() (library-default probability) last.
    activation: 'lrelu' selects LeakyReLU(0.2); any other value selects an
      in-place ReLU. The default used to be the misspelling 'reu', which
      only worked because of that fall-through; behaviour is unchanged.
  """
  def __init__(self , 
               in_channels , 
               out_channels , 
               kernel_size = (2 , 2) , 
               stride = (2 , 2) , 
               padding = 0 , 
               use_norm = True , 
               use_activation = True , 
               use_dropout = False , 
               activation = 'relu'):
    super(ConvT , self).__init__()

    self.use_norm = use_norm
    self.use_activation = use_activation
    self.use_dropout = use_dropout
    self.activation = activation

    self.convT1 = nn.ConvTranspose2d(in_channels , out_channels , kernel_size , stride , padding )
    
    if self.use_norm:
      self.norm = nn.InstanceNorm2d(out_channels)
    if self.use_activation:
      if self.activation == 'lrelu':
        self.activation_ = nn.LeakyReLU(0.2)
      else :
        # Unknown names keep falling back to ReLU, as before.
        self.activation_ = nn.ReLU(inplace=True)
    if self.use_dropout:
      self.dropout = nn.Dropout()
  
  def forward(self , x):
    """Apply transposed convolution, then norm / activation / dropout if enabled."""
    x = self.convT1(x)
    if self.use_norm:
      x = self.norm(x)
    if self.use_activation:
      x = self.activation_(x)
    if self.use_dropout:
      x = self.dropout(x)

    return x

In [None]:
# Smoke test: build a default ConvT block and print its layer summary
# for a 3x256x256 input.
convT = ConvT(3 , 32).to(device)
summary(convT , (3 , 256 , 256))

In [12]:
class Linear(nn.Module):
  """Fully connected layer with optional BatchNorm1d and activation.

  Args:
    in_channels: number of input features.
    out_channels: number of output features.
    kernel_size, stride, padding: accepted but not used by this layer.
    use_norm: apply nn.BatchNorm1d(out_channels) after the linear map.
    use_activation: apply an activation after the (optional) norm.
    activation: 'lrelu' selects LeakyReLU(0.2); any other value selects an
      in-place ReLU.
  """
  def __init__(self ,
               in_channels ,
               out_channels , 
               kernel_size = (2, 2) , 
               stride = (2 , 2) ,
               padding = 0 ,
               use_norm = True , 
               use_activation = True , 
               activation = 'lrelu'):
    super().__init__()

    self.use_norm , self.use_activation , self.activation = use_norm , use_activation , activation

    self.linear1 = nn.Linear(in_channels , out_channels)

    if use_norm:
      self.norm = nn.BatchNorm1d(out_channels)
    if use_activation:
      self.activation_ = nn.LeakyReLU(0.2) if activation == 'lrelu' else nn.ReLU(inplace=True)

  def forward(self ,x):
    """Run the linear map, then whichever optional stages are enabled."""
    out = self.linear1(x)
    optional_stages = ((self.use_norm , 'norm') , (self.use_activation , 'activation_'))
    for enabled , attr in optional_stages:
      if enabled:
        out = getattr(self , attr)(out)
    return out

In [13]:
# Smoke test: push a random batch of three 3-feature rows through a
# default Linear block.
linear = Linear(3 , 32).to(device)
x = torch.randn(3 , 3).to(device)
z = linear(x)

In [14]:
class Generator_Encoder(nn.Module):
  """Encode an image into a (batch, out_channels, 1, 1) code.

  Seven default Conv blocks (each halving the spatial size) are followed
  by a flatten and three Linear blocks.

  Args:
    in_channels: channels of the input image.
    out_channels: size of the produced code.
    hidden_dim: base channel width; the convolutions widen it up to
      hidden_dim * 32.
    input_size: side length of the (square) input images. It is used to
      size the first Linear layer, which was previously hard-coded to
      16384 features (the value for hidden_dim=32 and 512x512 inputs).

  Raises:
    ValueError: if `input_size` is too small to survive seven halvings.
  """
  def __init__(self, 
                in_channels , 
               out_channels, 
               hidden_dim = 32 , 
               input_size = 512 ,
               ):
    super(Generator_Encoder , self).__init__()

    self.conv1 = Conv(in_channels , hidden_dim , use_norm=False)
    self.conv2 = Conv(hidden_dim , hidden_dim * 2)
    self.conv3 = Conv(hidden_dim * 2 , hidden_dim * 4)
    self.conv4 = Conv(hidden_dim * 4 , hidden_dim * 8)
    self.conv5 = Conv(hidden_dim * 8 , hidden_dim * 16)
    self.conv6 = Conv(hidden_dim * 16 , hidden_dim * 32)
    self.conv7 = Conv(hidden_dim * 32 , hidden_dim * 32)

    self.flatten = nn.Flatten()

    # Seven kernel-2 / stride-2 convolutions floor-halve each side seven times.
    final_side = input_size // 2 ** 7
    if final_side < 1:
      raise ValueError(f"input_size must be at least {2 ** 7}, got {input_size}")
    flat_features = hidden_dim * 32 * final_side * final_side

    self.linear1 = Linear(flat_features , hidden_dim * 32)
    self.linear2 = Linear(hidden_dim * 32 , hidden_dim * 32)
    self.linear3 = Linear(hidden_dim * 32 , out_channels)
    

  def forward(self , x):
    """Return the code for `x` reshaped to (batch, out_channels, 1, 1)."""
    x = self.conv1(x)
    x = self.conv2(x)
    x = self.conv3(x)
    x = self.conv4(x)
    x = self.conv5(x)
    x = self.conv6(x)
    x = self.conv7(x)
    x = self.flatten(x)
    x = self.linear1(x)
    x = self.linear2(x)
    x = self.linear3(x)
    # Append two singleton spatial axes so the code can feed conv layers.
    return x.view(x.shape[0] , x.shape[1] , 1 , 1)

In [None]:
# Smoke test: build the encoder with a 512-wide code and print its layer
# summary for a 3x512x512 input.
encoder = Generator_Encoder(3 , 512).to(device)
summary(encoder , (3 , 512 , 512))

In [16]:
class HE(nn.Module):
  """Four ConvT blocks followed by four Conv blocks.

  With the default ConvT / Conv settings the four upsampling stages are
  undone by the four downsampling stages, so the output keeps the input's
  spatial size while the channel count goes from `in_channels` to
  `out_channels`.

  Args:
    in_channels: channels of the input tensor.
    out_channels: channels of the output tensor.
    hidden_dim: base channel width of the intermediate layers.
  """
  def __init__(self ,
               in_channels , 
               out_channels,  
               hidden_dim = 32 ):
    super().__init__()

    # Expanding path: widen channels up to hidden_dim * 8.
    self.convT1 = ConvT(in_channels , hidden_dim)
    self.convT2 = ConvT(hidden_dim , hidden_dim * 2)
    self.convT3 = ConvT(hidden_dim * 2 , hidden_dim * 4)
    self.convT4 = ConvT(hidden_dim * 4 , hidden_dim * 8)

    # Contracting path: narrow channels back down to out_channels.
    self.conv1 = Conv(hidden_dim * 8 , hidden_dim * 4)
    self.conv2 = Conv(hidden_dim * 4 , hidden_dim * 2)
    self.conv3 = Conv(hidden_dim * 2 , hidden_dim)
    self.conv4 = Conv(hidden_dim , out_channels)

  def forward(self,  x):
    """Run `x` through the eight stages in definition order."""
    stages = (
        self.convT1 , self.convT2 , self.convT3 , self.convT4 ,
        self.conv1 , self.conv2 , self.conv3 , self.conv4 ,
    )
    for stage in stages:
      x = stage(x)
    return x

In [None]:
# Smoke test: build an HE block (512 -> 256 channels) and print its layer
# summary for a 512x1x1 input.
he = HE(512 , 256).to(device)
summary(he , (512 , 1 , 1))

In [18]:
class FIN(nn.Module):
  """Normalize `x` and modulate it with scale/shift terms driven by `o`.

  forward computes
      (x - mean(x)) / sqrt(var(x) + eps) * (a_c * cos(o) + b_c)
                                         + (a_s * sin(o) + b_s)
  where mean/var are taken over the whole tensor and a_c, b_c, a_s, b_s are
  learnable scalars.

  Previously the four coefficients were re-sampled with torch.randn on every
  call, so they were never optimised and two calls with the same input gave
  different results; the input was also divided by the variance rather than
  the standard deviation, with no epsilon. The coefficients are now
  registered parameters, so they follow the module across devices and no
  global `device` is needed.

  Args:
    eps: small constant added to the variance for numerical stability.
  """
  def __init__(self , eps = 1e-5):
    super(FIN , self).__init__()
    self.eps = eps
    # Scalar coefficients broadcast against `o`, whatever its shape.
    self.cos_weight = nn.Parameter(torch.randn(()))
    self.cos_bias = nn.Parameter(torch.randn(()))
    self.sin_weight = nn.Parameter(torch.randn(()))
    self.sin_bias = nn.Parameter(torch.randn(()))

  def forward(self , x , o):
    """Return the normalized, modulated tensor (same shape as `x` after broadcasting)."""
    normalized = (x - torch.mean(x)) / torch.sqrt(torch.var(x) + self.eps)
    scale = self.FIN_parameters(o , switch = True)
    shift = self.FIN_parameters(o , switch = True , angle = 'sin')
    return normalized * scale + shift

  def FIN_parameters(self , o , switch =False , angle = 'cos'):
    """Return a * g(o) + b for the coefficient pair selected by `angle`.

    Args:
      o: tensor the modulation is computed from.
      switch: when True g is cos/sin (per `angle`); when False g is identity.
      angle: 'cos' selects the cosine pair, anything else the sine pair.
    """
    if angle == 'cos':
      a , b = self.cos_weight , self.cos_bias
    else:
      a , b = self.sin_weight , self.sin_bias
    if not switch:
      return a * o + b
    basis = torch.cos(o) if angle == 'cos' else torch.sin(o)
    return a * basis + b

In [None]:
# Smoke test: apply FIN to a random image batch with a random
# one-element `o` and display the output shape.
x = torch.randn(2 , 3 , 512 , 512).to(device)
o = torch.randn(1).to(device)
fin = FIN().to(device)
norm = fin(x , o)
norm.shape

In [20]:
class DRB(nn.Module):
  """Combine two HE branches with a FIN-modulated copy of the input.

  forward returns the channel-wise concatenation
      [he(x), FIN(x, o), x, hem(x), FIN(x, o), x]

  Args:
    in_channels: channels of the input, passed to both HE branches.
    out_channels: channels produced by each HE branch.
  """
  def __init__(self , 
               in_channels , 
               out_channels):
    super().__init__()

    self.he = HE(in_channels , out_channels)
    self.FIN = FIN()
    self.hem = HE(in_channels , out_channels)

  def forward(self , x , o):
    """Return the six-way concatenation along dim 1 described on the class."""
    first_branch = self.he(x)
    modulated = self.FIN(x , o)
    second_branch = self.hem(x)
    # Both halves share the same FIN output and the raw input.
    pieces = [first_branch , modulated , x , second_branch , modulated , x]
    return torch.cat(pieces , dim=1)

In [None]:
# Smoke test: run a DRB on a random batch of three 512x1x1 codes and
# print the shape of the concatenated result.
x = torch.randn(3 , 512 , 1 , 1).to(device)
o = torch.randn(1).to(device)
drb = DRB(512 , 512).to(device)
x= drb(x , o)
print(x.shape)

In [22]:
class Generator_Decoder(nn.Module):
  """Stack of nine default ConvT blocks.

  Each default ConvT doubles the spatial size, so a 1x1 input becomes
  512x512. Channels go in_channels -> hidden_dim -> ... -> hidden_dim * 32
  -> hidden_dim * 32 -> hidden_dim * 16 -> out_channels.

  Args:
    in_channels: channels of the input tensor.
    out_channels: channels of the output tensor.
    hidden_dim: base channel width of the intermediate layers.
  """
  def __init__(self , 
               in_channels , 
               out_channels , 
               hidden_dim = 32):
    super().__init__()

    h = hidden_dim
    self.convT1 = ConvT(in_channels , h)
    self.convT2 = ConvT(h , h * 2)
    self.convT3 = ConvT(h * 2 , h * 4)
    self.convT4 = ConvT(h * 4 , h * 8)
    self.convT5 = ConvT(h * 8 , h * 16)
    self.convT6 = ConvT(h * 16 , h * 32)
    self.convT7 = ConvT(h * 32 , h * 32)
    self.convT8 = ConvT(h * 32 , h * 16)
    self.convT9 = ConvT(h * 16 , out_channels)

  def forward(self , x):
    """Run `x` through the nine ConvT blocks in definition order."""
    layers = (
        self.convT1 , self.convT2 , self.convT3 ,
        self.convT4 , self.convT5 , self.convT6 ,
        self.convT7 , self.convT8 , self.convT9 ,
    )
    for layer in layers:
      x = layer(x)
    return x

In [None]:
# Smoke test: build the decoder (512 -> 3 channels) and print its layer
# summary for a 512x1x1 input.
generator_decoder = Generator_Decoder(512 , 3).to(device)
summary(generator_decoder , (512 , 1 , 1))

In [24]:
class Generator(nn.Module):
  """Encoder -> DRB -> decoder image generator conditioned on `o`.

  Args:
    in_channels: channels of the input image.
    out_channels: channels of the generated image.
    z_dim: size of the encoder's latent code.
    hidden_dim: kept for backward compatibility; no longer used to size
      the decoder input (see the note in __init__).
  """
  def __init__(self , 
               in_channels , 
               out_channels , 
               z_dim = 512 , 
               hidden_dim = 32):
    super(Generator ,self).__init__()

    self.encoder = Generator_Encoder(in_channels , z_dim)
    self.drb = DRB(z_dim , z_dim)
    # DRB(z_dim, z_dim) concatenates six z_dim-channel tensors
    # ([he, fin, x] twice). The former expression
    # (hidden_dim * 32 + z_dim) * 2 only equals 6 * z_dim for the defaults
    # (z_dim=512, hidden_dim=32) and broke for any other z_dim.
    self.decoder = Generator_Decoder(6 * z_dim , out_channels)

  def forward(self , x , o):
    """Encode `x`, mix the code with `o` in the DRB, and decode the result."""
    x = self.encoder(x)
    h_ = self.drb(x , o)
    y = self.decoder(h_)
    return y

In [None]:
# Smoke test: run a default Generator on two random 3x512x512 images and
# print the output shape.
x = torch.randn(2 , 3 , 512 , 512).to(device)
o = torch.randn(1).to(device)
generator = Generator(3 , 3).to(device)
y = generator(x , o)
print(y.shape)

In [26]:
class Discriminator(nn.Module):
  """Discriminator implemented as a single Generator_Encoder.

  Args:
    in_channels: channels of the input image.
    out_channels: channels of the (batch, out_channels, 1, 1) output.
  """
  def __init__(self , 
               in_channels , 
               out_channels):
    super().__init__()

    self.encoder = Generator_Encoder(in_channels , out_channels)

  def forward(self , x):
    """Return the encoder output for `x` unchanged."""
    return self.encoder(x)

In [None]:
# Smoke test: build a single-output Discriminator and print its layer
# summary for a 3x512x512 input.
discriminator = Discriminator(3 , 1).to(device)
summary(discriminator , (3 , 512 , 512))

In [28]:
class M_(nn.Module):
  """Apply one shared Generator to two (image, o) pairs.

  Args:
    in_channels: channels of the input images.
    out_channels: channels of the generated images.
  """
  def __init__(self , 
               in_channels , 
               out_channels):
    super(M_ , self).__init__()

    self.generator_x = Generator(in_channels , out_channels)

  def forward(self , x , x_ , o , o_):
    """Return (generator_x(x, o), generator_x(x_, o_)).

    The generator returns a single tensor. The previous code tuple-unpacked
    that tensor (`x , _ = ...`), which iterates over the batch axis: it
    raised for any batch size other than 2 and otherwise returned one
    sample from each call instead of the full batches.
    """
    y = self.generator_x(x , o)
    y_ = self.generator_x(x_ , o_)
    return y , y_

In [29]:
# Smoke test: run M_ on two random batches of two 3x512x512 images, each
# with its own random one-element `o`.
x = torch.randn(2 , 3 , 512 , 512).to(device)
x_ = torch.randn_like(x)
o = torch.randn(1).to(device)
o_ = torch.randn_like(o)
m = M_(3 , 3).to(device)
x , x_ = m(x , x_ , o , o_)

In [30]:
class O_Net(nn.Module):
  """Difference between two discriminators' outputs on two inputs.

  Args:
    in_channels: channels of the input images.
    out_channels: channels of each discriminator's output.
  """
  def __init__(self ,
               in_channels , 
               out_channels):
    super(O_Net , self).__init__()

    self.discriminator_x = Discriminator(in_channels ,out_channels)
    self.discriminator_y = Discriminator(in_channels ,out_channels)

  def forward(self , x , x_):
    """Return discriminator_x(x) - discriminator_y(x_).

    The second discriminator used to be fed `x` as well, which left the
    `x_` argument unused.
    """
    o = self.discriminator_x(x)
    o_ = self.discriminator_y(x_)
    delta = o - o_
    return delta

In [None]:
# Smoke test: run O_Net on two random batches of two 3x512x512 images and
# display the output shape.
onet = O_Net(3 , 1).to(device)
x = torch.randn(2 , 3 , 512 , 512).to(device)
y = torch.randn_like(x)
a = onet(x , y )
a.shape

In [32]:
# Instantiate the networks used by the training loop: two generators,
# two discriminators and the O_Net, all moved to `device`.
generator_x = Generator(in_channels , out_channels , z_dim).to(device)
generator_y = Generator(in_channels , out_channels , z_dim).to(device)
discriminator_x = Discriminator(in_channels , out_channels_disc).to(device)
discriminator_y = Discriminator(in_channels , out_channels_disc).to(device)
o_net = O_Net(in_channels , out_channels_disc).to(device)

In [33]:
def weights_init(m):
  """Initialise convolution and BatchNorm2d layers; meant for `Module.apply`.

  Conv2d / ConvTranspose2d weights and BatchNorm2d weights are drawn from
  N(0, 0.02); BatchNorm2d biases are set to 0. Other module types are left
  untouched.

  Uses the in-place `normal_` / `constant_` initialisers; the un-suffixed
  `normal` / `constant` names used previously are deprecated.
  """
  if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
    torch.nn.init.normal_(m.weight, 0.0, 0.02)
  if isinstance(m, nn.BatchNorm2d):
    torch.nn.init.normal_(m.weight, 0.0, 0.02)
    torch.nn.init.constant_(m.bias, 0)

In [34]:
# Apply weights_init to every sub-module of each network
# (Module.apply returns the module itself, hence the re-assignment).
generator_x = generator_x.apply(weights_init)
generator_y = generator_y.apply(weights_init)
discriminator_x = discriminator_x.apply(weights_init)
discriminator_y = discriminator_y.apply(weights_init)
o_net = o_net.apply(weights_init)

In [35]:
# Loss functions used by get_loss and the training loop.
criterion = nn.BCEWithLogitsLoss()
l1_loss = nn.L1Loss()
lambda_recon = 200  # weight of the L1 term in get_loss

In [36]:
# Training hyper-parameters.
n_epochs = 100       # passes over the dataset
input_dim = 3        # channel count used when displaying y_o_
real_dim = 3         # channel count used when displaying the other images
display_step = 10    # steps between progress print-outs / image dumps
batch_size = 2       # DataLoader batch size
lr = 0.0002          # Adam learning rate for every optimizer
target_shape = 512   # side length the image halves are resized to

betas = (0.5 , 0.999)  # Adam betas for every optimizer

In [37]:

# Convert loaded images to tensors; no other preprocessing is applied here.
transform = transforms.Compose([ transforms.ToTensor(), ])

# NOTE(review): the path looks like a Google Drive mount in Colab; the drive
# must be mounted before this cell runs - confirm for other environments.
dataset = torchvision.datasets.ImageFolder("/content/drive/MyDrive/Maps/maps/", transform=transform)

In [38]:

# Mutable state read and updated by the training loop.
mean_generator_loss = 0       # generator loss reported at each display step
mean_discriminator_loss = 0   # discriminator loss reported at each display step
dataloader = DataLoader(dataset , batch_size = batch_size , shuffle=True)
cur_step = 0                  # global step counter across epochs

In [39]:
# One Adam optimizer per network, all sharing `lr` and `betas`.
opt_generator_y = torch.optim.Adam(generator_y.parameters() , lr=lr , betas=betas)
opt_generator_x = torch.optim.Adam(generator_x.parameters() , lr = lr , betas = betas)
opt_discriminator_x = torch.optim.Adam(discriminator_x.parameters() , lr = lr , betas=betas)
opt_discriminator_y = torch.optim.Adam(discriminator_y.parameters() , lr = lr , betas=betas)
opt_o_net = torch.optim.Adam(o_net.parameters() , lr=lr , betas = betas)

In [40]:
def get_gen_loss(del_1 , del_2 , delta , criterion_loss , loss_m):
  """Combine three norm terms with two externally computed loss terms.

  Returns
      ||del_1|| + ||del_1 - delta|| + ||del_2 - delta||
      + criterion_loss + loss_m

  Args:
    del_1, del_2: tensors whose norms are penalised.
    delta: target subtracted from del_1 / del_2 (tensor or scalar).
    criterion_loss, loss_m: extra terms added unchanged.
  """
  norm_terms = (
      torch.norm(del_1) ,
      torch.norm(del_1 - delta) ,
      torch.norm(del_2 - delta) ,
  )
  return sum(norm_terms) + criterion_loss + loss_m


In [41]:
def get_loss(fake , real  , criterion = criterion , l1_loss = l1_loss , lambda_recon = lambda_recon):
  """Return criterion(fake, real) + lambda_recon * l1_loss(fake, real).

  Args:
    fake: generated tensor.
    real: target tensor.
    criterion: first loss callable (defaults to the module-level `criterion`).
    l1_loss: second loss callable (defaults to the module-level `l1_loss`).
    lambda_recon: weight applied to the `l1_loss` term.
  """
  criterion_term = criterion(fake , real)
  recon_term = l1_loss(fake , real)
  return criterion_term + lambda_recon * recon_term

In [None]:
# Main training loop.
#
# Per batch: update generator_x, discriminator_x, generator_y,
# discriminator_y and finally o_net, then report losses and show images
# every `display_step` steps.
#
# Fixes relative to the earlier version of this cell:
#   * the real-sample discriminator losses are computed on the discriminator
#     predictions (they were computed on the raw image `y`);
#   * discriminator_y is stepped with its own optimizer (it used
#     opt_discriminator_x);
#   * the generator loss fed to the o_net objective is detached instead of
#     being re-wrapped with torch.tensor(...);
#   * the mean losses are accumulated as Python floats between displays
#     (they were overwritten every step);
#   * the unused `cur_batch_size` local is gone.
for epoch in range(n_epochs):
  for img , _ in tqdm(dataloader):
    # Each loaded image is split down the middle: left half -> y, right half -> x.
    image_width = img.shape[3]
    y = img[: , : , : , :image_width//2]
    y = nn.functional.interpolate(y , size = target_shape)
    x = img[: , : , : , image_width//2:]
    x = nn.functional.interpolate(x , size = target_shape)
    x = x.to(device)
    y = y.to(device)  
    # Two independent random conditioning values per sample.
    o = torch.randn(x.shape[0] , 1 , 1 , 1).to(device)
    o_ = torch.randn_like(o)

    # ---- generator_x : x -> y ----
    opt_generator_x.zero_grad()
    y_o  = generator_x(x , o)
    loss_y_o = get_loss(y_o , y)
    loss_y_o.backward()
    opt_generator_x.step()

    # ---- discriminator_x : real y vs. generated y_o ----
    opt_discriminator_x.zero_grad()
    with torch.no_grad():
      y_o = generator_x(x , o)
    disc_fake_pred = discriminator_x(y_o)
    disc_real_pred = discriminator_x(y)
    disc_loss = criterion(disc_fake_pred , torch.zeros_like(disc_fake_pred))
    disc_real_pred_loss = criterion(disc_real_pred , torch.ones_like(disc_real_pred))
    disc_loss = (disc_loss + disc_real_pred_loss)/2
    disc_loss.backward()
    opt_discriminator_x.step()

    # ---- generator_y : y -> x ----
    opt_generator_y.zero_grad()
    x_o  = generator_y(y , o)
    loss_x_o = get_loss(x_o , x)
    loss_x_o.backward()
    opt_generator_y.step()

    # ---- discriminator_y : real x vs. generated x_o ----
    opt_discriminator_y.zero_grad()
    with torch.no_grad():
      x_o  = generator_y(y , o)
    disc_fake_pred_ = discriminator_y(x_o)
    disc_real_pred_ = discriminator_y(x)
    disc_loss_ = criterion(disc_fake_pred_ , torch.zeros_like(disc_fake_pred_))
    disc_real_pred_loss_ = criterion(disc_real_pred_ , torch.ones_like(disc_real_pred_))
    disc_loss_ = (disc_loss_ + disc_real_pred_loss_)/2
    disc_loss_.backward()
    opt_discriminator_y.step()


    generator_loss_ = (loss_y_o + loss_x_o)/2
    discriminator_loss_ = (disc_loss + disc_loss_)/2

    # Constant w.r.t. o_net: contributes to the reported value only.
    gen_loss_ = generator_loss_.detach()

    # ---- o_net ----
    opt_o_net.zero_grad()
    with torch.no_grad():
      y_o  = generator_x(x , o)
      x_o = generator_y(y , o)
      y_o_ = generator_x(x , o_)
      x_o_ = generator_y(y , o_)
    delta_1 = o_net(y_o , x_o)
    delta_2 = o_net(y_o_ , x_o_)
    loss_1 = get_gen_loss(delta_1 , delta_2 , 0 , gen_loss_ , 0)
    loss_1.backward()
    opt_o_net.step()

    # .item() keeps the running means free of autograd graphs.
    mean_generator_loss += ((generator_loss_ + loss_1)/2).item() / display_step
    mean_discriminator_loss += ((discriminator_loss_ + loss_1)/2).item() / display_step


    if cur_step % display_step == 0:
      if cur_step > 0:
        print(f"Epoch {epoch}: Step {cur_step}: Generator loss: {mean_generator_loss}, Discriminator loss: {mean_discriminator_loss}")
      else:
        print("Pretrained initial state")

      
      print('Y')
      show_tensor_images(y, size=(real_dim, target_shape, target_shape) , switch=False)


      print('X')
      show_tensor_images(x, size=(real_dim, target_shape, target_shape) , switch=False)

      print('y_o')
      show_tensor_images(y_o, size=(real_dim, target_shape, target_shape) , switch= False)

      print('x_o')
      show_tensor_images(x_o, size=(real_dim, target_shape, target_shape) , switch= False)

      print('y_o_')
      show_tensor_images(y_o_, size=(input_dim, target_shape, target_shape) , switch=False)

      print('x_o_')
      show_tensor_images(x_o_, size=(real_dim, target_shape, target_shape) , switch=False)



      # Start a fresh accumulation window.
      mean_generator_loss = 0
      mean_discriminator_loss = 0
    cur_step += 1
