# CycleGAN

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt


In [None]:
# Able to train on up-paired training data
#Translate an image from target to source domain

#Generator G horses -> zebras
#adversarial loss

#Inverse mapping F, zebras -> horses
#cycle consitency loss

In [None]:
# X is a set of horses
# Y is a set of zebras

# G: X -> Y
# F: Y -> X

# Adversarial Loss:
#Discriminator D_Y classifies if an generated zebra is real or fake
#Discriminator D_X classifies if an generated horse is real or fake

# Forward Cycle Consistency Loss:
# X_hat = F(G(X))
# want to minimize ||X - X_hat|| (L1 loss)

# Backward Cycle Consistency Loss:
# Y_hat = G(F(Y))
# want to minimize ||Y - Y_hat|| (L1 loss)

In [None]:
# Adversarial Loss:
#Loss_D(G, D_Y, X, Y) = -E[log(D_Y(Y))] - E[log(1 - D_Y(G(X)))]

#Cycle Consistency Loss:
#Loss_cyc(G, F, X, Y) = E||X - F(G(X))|| + E||Y - G(F(Y))|| (L1 Loss)

# Full Objective:
#Loss(G, F, D_X, D_Y, X, Y) = Loss_D(G, D_Y, X, Y) + Loss_D(F, D_X, Y, X) + lambda * Loss_cyc(G, F, X, Y)

In [2]:
# Network Architecture:

# Discriminator: netowkr classifies if 70x70 patch is real or fake
# generating a grid

# Train G to minimize
# (D(G(x)) - 1)^2
# Train D to minimize
# (D(y)-1)^2 + D(G(x))^2

#batch size = 1
# lambda = 10
# lr = 0.0002 with Adam optimizer
#keep learning rate constant for 100 epochs,
#then linearly decay to 0 over next 100 epochs

In [9]:
###############################################################################
# Models
###############################################################################
num_classes = 14

class DoubleConv(nn.Module):
  def __init__(self, input_channels: int, output_channels: int):
    """Initialize the DownSampleBlock class.
      
    Parameters:
      input_size (int) -- input size to block
      input_channels (int) -- #channels into first layer in block
      output_channels (int) -- #channels each layer produces
    """
    super().__init__()
    conv1 = nn.Conv2d(in_channels=input_channels, out_channels=output_channels, kernel_size=3, stride=1, padding=0)
    conv2 = nn.Conv2d(in_channels=output_channels, out_channels=output_channels, kernel_size=3, stride=1, padding=0)

    nn.init.normal_(conv1.weight, mean=0.0, std= (2/(9*input_channels))**(1/2)) #sqrt(2/N)
    nn.init.normal_(conv2.weight, mean=0.0, std= (2/(9*output_channels))**(1/2)) 
    
    self.conv = nn.Sequential(conv1, nn.BatchNorm2d(output_channels), nn.ReLU(),
                              conv2, nn.BatchNorm2d(output_channels), nn.ReLU())
    
  def forward(self, x: torch.Tensor) -> torch.Tensor:
    return self.conv(x)

class DownSampleBlock(nn.Module):
  """
  DownSampling block in UNET

  MaxPool, DoubleCov
  MaxPool, Conv, BatchNorm, ReLU, Conv, BatchNorm, ReLU
  """
  def __init__(self, input_channels: int, output_channels: int):
    """Initialize the DownSampleBlock class.

    Parameters:
      input_size (int) -- input size to block
      input_channels (int) -- #channels into first layer in block
      output_channels (int) -- #channels each layer produces
    """
    super().__init__()
    self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
    self.conv = DoubleConv(input_channels, output_channels)

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    return self.conv(self.max_pool(x))

class UpSampleBlock(nn.Module):
  """
  UpSampling block in UNET

  Upsample, Conv, Concat, Conv, ReLU, Conv, ReLU
  """
  def __init__(self, input_channels: int, output_channels: int):
    """Initialize the DownSampleBlock class.
    
    Parameters:
      input_size (int) -- input size to block
      input_channels (int) -- #channels into first layer in block
      output_channels (int) -- #channels each layer produces
    """
    super().__init__()
    self.up_sample = nn.ConvTranspose2d(input_channels, output_channels, kernel_size=2, stride=2)
    self.conv = DoubleConv(input_channels, output_channels)

  def forward(self, x: torch.Tensor, res: torch.Tensor) -> torch.Tensor:
    """Forward pass for DownSampleBlock.

     Parameters:
      x (torch.Tensor) -- input tensor to block
      res (torch.Tensor) -- residual connection feeding into block
    
    Return x
    x will be used as the input to the next upsizing block (or final layer)
    """
    x = self.up_sample(x)
    size_diff1 = (res.shape[2]-x.shape[2])//2
    size_diff2 = (res.shape[2]-x.shape[2]) - size_diff1
    x = torch.concat((x, res[:, :, size_diff1:-size_diff2, size_diff1:-size_diff2]), dim = 1)
    x = self.conv(x)
    return x

class UNET(nn.Module):
  def __init__(self, channels_in = 4):
    """Initialize the DownSampleBlock class.
    
    Parameters:
      channels_in (int) -- input images channel size
    """
    super().__init__()
    self.first_conv = DoubleConv(channels_in, 64)
    self.downsample_blocks = nn.ModuleList([DownSampleBlock(c, 2*c) for c in [64, 128, 256, 512]])
    self.upsample_blocks = nn.ModuleList([UpSampleBlock(2*c, c) for c in [512, 256, 128, 64]])
    self.final_layer = nn.Conv2d(in_channels=64, out_channels=num_classes, kernel_size=1, stride=1, padding=0)
    self.loss_fn = nn.CrossEntropyLoss()
  
  def forward(self, x: torch.Tensor, target=None) -> torch.Tensor:
    """Forward pass for DownSampleBlock.

     Parameters:
      x (torch.Tensor) --input tensor to block
      targets (torch.Tensor) --target output of model

    Returns (logits, loss)

    logits (torch.Tensor) --model's raw output
    loss (torch.float32) --2d CrossEntropyLoss result
    """
    x = self.first_conv(x)
    residuals = []
    for downsample in self.downsample_blocks:
      residuals.append(x)
      x = downsample(x)
    for i, upsample in enumerate(self.upsample_blocks):
      print(i, x.shape, residuals[-(i+1)].shape)
      x = upsample(x, residuals[-(i+1)])
    x = self.final_layer(x)

    if target is None:
      loss = None
    else:
      diff = 94 # (512 - 324)//2
      loss = self.loss_fn(x, target[:, 0, diff:-diff, diff:-diff])
    return x, loss

In [14]:
G = UNET()
x = torch.randn((1, 4, 256, 256))
G(x)[0].shape

0 torch.Size([1, 1024, 8, 8]) torch.Size([1, 512, 24, 24])
1 torch.Size([1, 512, 12, 12]) torch.Size([1, 256, 57, 57])
2 torch.Size([1, 256, 20, 20]) torch.Size([1, 128, 122, 122])
3 torch.Size([1, 128, 36, 36]) torch.Size([1, 64, 252, 252])


torch.Size([1, 14, 68, 68])

In [8]:
try:
    from torchinfo import summary
except:
    print("[INFO] Couldn't find torchinfo... installing it.")
    !pip3 install -q torchinfo
    from torchinfo import summary
summary(G, input_size=[1, 4, 512, 512], col_names =['input_size', 'output_size', 'num_params', 'trainable'])

[INFO] Couldn't find torchinfo... installing it.


ModuleNotFoundError: No module named 'torchinfo'

In [None]:
model = UNET().to(device)
image, target = sample["image"].to(device), sample["mask"].to(device)
logits, loss = model(image, target)
print(logits.shape, loss)

In [5]:
## To Do:
# - Modify UNET to preseve image size