In [1]:
!mkdir -p exp visualization weights

In [2]:
!mkdir -p model/layers

#### ContractingBlock.py

In [3]:
%%writefile model/layers/ContractingBlock.py
from torch import nn

class ContractingBlock(nn.Module):
    '''
    ContractingBlock Class
    A single stride-2 convolution that doubles the number of channels,
    optionally followed by instance normalization, then an activation.
    Values:
        input_channels: the number of channels to expect from a given input
        use_bn: whether to apply instance normalization after the convolution
        kernel_size: the size of the convolution kernel
        activation: 'relu' selects ReLU; any other value selects LeakyReLU(0.2)
    '''
    def __init__(self, input_channels, use_bn=True, kernel_size=3, activation='relu'):
        super().__init__()
        out_channels = input_channels * 2
        self.conv1 = nn.Conv2d(
            input_channels,
            out_channels,
            kernel_size=kernel_size,
            padding=1,
            stride=2,
            padding_mode='reflect',
        )
        if activation == 'relu':
            self.activation = nn.ReLU()
        else:
            self.activation = nn.LeakyReLU(0.2)
        # The normalization layer only exists when it is going to be used.
        if use_bn:
            self.instancenorm = nn.InstanceNorm2d(out_channels)
        self.use_bn = use_bn

    def forward(self, x):
        '''
        Forward pass of ContractingBlock: convolution, optional instance
        normalization, activation.
        Parameters:
            x: image tensor of shape (batch size, channels, height, width)
        '''
        features = self.conv1(x)
        if self.use_bn:
            features = self.instancenorm(features)
        return self.activation(features)

Writing model/layers/ContractingBlock.py


#### ExpandingBlock.py

In [4]:
%%writefile model/layers/ExpandingBlock.py
from torch import nn

class ExpandingBlock(nn.Module):
    '''
    ExpandingBlock Class
    A single stride-2 transposed convolution that halves the number of
    channels, optionally followed by instance normalization, then ReLU.
    Values:
        input_channels: the number of channels to expect from a given input
        use_bn: whether to apply instance normalization after the convolution
    '''
    def __init__(self, input_channels, use_bn=True):
        super().__init__()
        out_channels = input_channels // 2
        self.conv1 = nn.ConvTranspose2d(
            input_channels,
            out_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            output_padding=1,
        )
        # The normalization layer only exists when it is going to be used.
        if use_bn:
            self.instancenorm = nn.InstanceNorm2d(out_channels)
        self.use_bn = use_bn
        self.activation = nn.ReLU()

    def forward(self, x):
        '''
        Forward pass of ExpandingBlock: transposed convolution, optional
        instance normalization, ReLU.
        Parameters:
            x: image tensor of shape (batch size, channels, height, width)
        '''
        upsampled = self.conv1(x)
        if self.use_bn:
            upsampled = self.instancenorm(upsampled)
        return self.activation(upsampled)

Writing model/layers/ExpandingBlock.py


#### ResidualBlock.py

In [5]:
%%writefile model/layers/ResidualBlock.py
from torch import nn

class ResidualBlock(nn.Module):
    '''
    ResidualBlock Class
    Two 3x3 reflect-padded convolutions, each followed by instance
    normalization (with a ReLU in between), added back onto the input.
    The number of channels and the spatial size are left unchanged.
    Values:
        input_channels: the number of channels to expect from a given input
    '''

    def __init__(self, input_channels):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(input_channels, input_channels, kernel_size=3, padding=1, padding_mode='reflect')
        self.conv2 = nn.Conv2d(input_channels, input_channels, kernel_size=3, padding=1, padding_mode='reflect')
        # One InstanceNorm2d is applied after both convolutions. With the
        # default arguments used here it has no learnable parameters and keeps
        # no running statistics, so sharing it is equivalent to two instances.
        self.instancenorm = nn.InstanceNorm2d(input_channels)
        self.activation = nn.ReLU()

    def forward(self, x):
        '''
        Function for completing a forward pass of ResidualBlock: 
        Given an image tensor, completes a residual block and returns the transformed tensor.
        Parameters:
            x: image tensor of shape (batch size, channels, height, width)
        Returns:
            x plus the output of the conv -> norm -> ReLU -> conv -> norm branch
        '''
        # No defensive copy of x is needed: none of the layers below operate
        # in place (nn.ReLU() is constructed without inplace=True), so x is
        # still intact when it is added back at the end.
        residual = self.conv1(x)
        residual = self.instancenorm(residual)
        residual = self.activation(residual)
        residual = self.conv2(residual)
        residual = self.instancenorm(residual)
        return x + residual

Writing model/layers/ResidualBlock.py


#### Discriminator.py

In [6]:
%%writefile model/Discriminator.py
from torch import nn
from model.layers.ContractingBlock import ContractingBlock

class Discriminator(nn.Module):
    '''
    Discriminator Class
    A 7x7 feature-lifting convolution, three LeakyReLU contracting blocks
    (the first one without normalization) and a final 1x1 convolution that
    produces a one-channel map of real/fake scores, one per image region.
    Parameters:
        input_channels: the number of image input channels
        hidden_channels: the initial number of discriminator convolutional filters
    '''
    def __init__(self, input_channels, hidden_channels=64):
        super().__init__()
        # Settings shared by every contracting block of the discriminator.
        block_kwargs = dict(kernel_size=4, activation='lrelu')
        self.upfeature = nn.Conv2d(
            input_channels, hidden_channels, kernel_size=7, padding=3, padding_mode='reflect'
        )
        self.contract1 = ContractingBlock(hidden_channels, use_bn=False, **block_kwargs)
        self.contract2 = ContractingBlock(hidden_channels * 2, **block_kwargs)
        self.contract3 = ContractingBlock(hidden_channels * 4, **block_kwargs)
        self.final = nn.Conv2d(hidden_channels * 8, 1, kernel_size=1)

    def forward(self, x):
        '''
        Forward pass of Discriminator: runs the image through every stage in
        order and returns the resulting score map.
        Parameters:
            x: image tensor of shape (batch size, channels, height, width)
        '''
        stages = (self.upfeature, self.contract1, self.contract2, self.contract3, self.final)
        scores = x
        for stage in stages:
            scores = stage(scores)
        return scores

Writing model/Discriminator.py


#### Generator.py

In [7]:
%%writefile model/Generator.py
from torch import nn
from model.layers.ContractingBlock import ContractingBlock
from model.layers.ResidualBlock import ResidualBlock
from model.layers.ExpandingBlock import ExpandingBlock

class Generator(nn.Module):
    '''
    Generator Class
    An upfeature convolution, 2 contracting blocks, 9 residual blocks,
    2 expanding blocks and a downfeature convolution followed by tanh, used
    to transform an input image into an image from the other class.
    Values:
        input_channels: the number of channels to expect from a given input
        output_channels: the number of channels to expect for a given output
        hidden_channels: the number of filters produced by the upfeature layer
    '''
    NUM_RES_BLOCKS = 9

    def __init__(self, input_channels, output_channels, hidden_channels=64):
        super().__init__()
        self.upfeature = nn.Conv2d(
            input_channels, hidden_channels, kernel_size=7, padding=3, padding_mode='reflect'
        )
        self.contract1 = ContractingBlock(hidden_channels)
        self.contract2 = ContractingBlock(hidden_channels * 2)
        res_mult = 4
        # Registered as res0 ... res8 so the state-dict keys stay the same.
        for block_idx in range(self.NUM_RES_BLOCKS):
            setattr(self, f'res{block_idx}', ResidualBlock(hidden_channels * res_mult))
        self.expand2 = ExpandingBlock(hidden_channels * 4)
        self.expand3 = ExpandingBlock(hidden_channels * 2)
        self.downfeature = nn.Conv2d(
            hidden_channels, output_channels, kernel_size=7, padding=3, padding_mode='reflect'
        )
        self.tanh = nn.Tanh()

    def forward(self, x):
        '''
        Forward pass of Generator: upfeature, contracting blocks, residual
        blocks, expanding blocks, downfeature, then tanh.
        Parameters:
            x: image tensor of shape (batch size, channels, height, width)
        '''
        hidden = self.upfeature(x)
        hidden = self.contract1(hidden)
        hidden = self.contract2(hidden)
        for block_idx in range(self.NUM_RES_BLOCKS):
            hidden = getattr(self, f'res{block_idx}')(hidden)
        hidden = self.expand2(hidden)
        hidden = self.expand3(hidden)
        return self.tanh(self.downfeature(hidden))

Writing model/Generator.py


In [8]:
%%writefile config.py
# Root folder of the training images. train.py passes it to ImageDataset,
# which reads '<TRAIN_DIR>/day/*.jpg' and '<TRAIN_DIR>/night/*.jpg'.
# NOTE(review): this is an absolute Kaggle input path; change it when running elsewhere.
TRAIN_DIR = '/kaggle/input/day-night-gan/data'

Writing config.py


In [9]:
%%writefile dataset.py
import os
import glob
from PIL import Image
import torch
import random
from torch.utils.data import Dataset

class ImageDataset(Dataset):
    '''
    Unpaired day/night image dataset.
    Reads '<root>/day/*.jpg' and '<root>/night/*.jpg' (each list sorted by
    path). The dataset length is the size of the larger folder; for indices
    past the end of the smaller folder a random image of that folder is used.
    Parameters:
        root: directory containing the 'day' and 'night' sub-folders
        transforms: optional callable applied to each PIL image
        mode: accepted for compatibility; currently unused
    '''
    def __init__(self, root, transforms=None, mode='train'):
        self.transforms = transforms
        self.day = sorted(glob.glob(os.path.join(root, 'day', '*.jpg')))
        self.night = sorted(glob.glob(os.path.join(root, 'night', '*.jpg')))
        assert len(self.day) > 0, "Make sure you downloaded the images!"
        # Without this check an empty night folder only fails later, inside
        # random.choice, on the first __getitem__ call.
        assert len(self.night) > 0, "Make sure you downloaded the images! (no night/*.jpg found)"

    @staticmethod
    def _as_tensor(img):
        '''
        Return img as a float tensor of shape (channels, height, width) with
        values in [0, 1]. Tensors are returned untouched; anything else is
        treated as a PIL-style image exposing .size and .tobytes().
        '''
        if torch.is_tensor(img):
            return img
        # assumes 8 bits per channel (e.g. 'RGB' or 'L' images) -- TODO confirm
        # if other image modes are ever produced by the transforms.
        width, height = img.size
        pixels = torch.frombuffer(bytearray(img.tobytes()), dtype=torch.uint8)
        return pixels.view(height, width, -1).permute(2, 0, 1).contiguous().float() / 255

    def __getitem__(self, index):
        '''
        Return a (day, night) pair of image tensors rescaled from [0, 1] to [-1, 1].
        Raises:
            IndexError: if index is past the end of the dataset
        '''
        if index >= len(self):
            raise IndexError('list index out of range')

        # Pair by position while both folders have an image at this index;
        # otherwise draw a random image from the folder that ran out.
        if index >= len(self.day):
            day_path = random.choice(self.day)
            night_path = self.night[index]
        elif index >= len(self.night):
            day_path = self.day[index]
            night_path = random.choice(self.night)
        else:
            day_path = self.day[index]
            night_path = self.night[index]

        day, night = Image.open(day_path).convert('RGB'), Image.open(night_path).convert('RGB')
        if self.transforms is not None:
            day, night = self.transforms(day), self.transforms(night)

        # The rescaling below needs tensors; convert here when no transform
        # was given or the transform pipeline does not end in ToTensor.
        day, night = self._as_tensor(day), self._as_tensor(night)

        return (day - 0.5) * 2, (night - 0.5) * 2

    def __len__(self):
        return max(len(self.day), len(self.night))

Writing dataset.py


#### loss.py

In [10]:
%%writefile loss.py
import torch

def get_disc_loss(real_X, fake_X, disc_X, adv_criterion):
    '''
    Compute the discriminator loss for one domain.
    The discriminator is scored against an all-zeros target on the generated
    images and an all-ones target on the real images; the two losses are averaged.
    Parameters:
        real_X: the real images from pile X
        fake_X: the generated images of class X
        disc_X: the discriminator for class X; takes images and returns real/fake class X
            prediction matrices
        adv_criterion: the adversarial loss function; takes the discriminator
            predictions and the target labels and returns an adversarial
            loss (which you aim to minimize)
    '''
    # detach() keeps this loss from back-propagating into the generator.
    pred_fake = disc_X(fake_X.detach())
    loss_on_fake = adv_criterion(pred_fake, torch.zeros_like(pred_fake))

    pred_real = disc_X(real_X)
    loss_on_real = adv_criterion(pred_real, torch.ones_like(pred_real))

    return (loss_on_fake + loss_on_real) / 2


def get_gen_adversarial_loss(real_X, disc_Y, gen_XY, adv_criterion):
    '''
    Compute the adversarial loss of a generator: translate real_X to class Y
    and score the discriminator's predictions against an all-ones target.
    Returns the loss together with the generated images.
    Parameters:
        real_X: the real images from pile X
        disc_Y: the discriminator for class Y; takes images and returns real/fake class Y
            prediction matrices
        gen_XY: the generator for class X to Y; takes images and returns the images
            transformed to class Y
        adv_criterion: the adversarial loss function; takes the discriminator
            predictions and the target labels and returns an adversarial
            loss (which you aim to minimize)
    '''
    generated_Y = gen_XY(real_X)
    pred_on_generated = disc_Y(generated_Y)
    target_real = torch.ones_like(pred_on_generated)
    return adv_criterion(pred_on_generated, target_real), generated_Y


def get_identity_loss(real_X, gen_YX, identity_criterion):
    '''
    Compute the identity loss of a generator: feed images that already belong
    to class X through the Y->X generator and compare the result with the input.
    Returns the loss together with the generator output.
    Parameters:
        real_X: the real images from pile X
        gen_YX: the generator for class Y to X; takes images and returns the images
            transformed to class X
        identity_criterion: the identity loss function; takes the real images from X and
            those images put through a Y->X generator and returns the identity
            loss (which you aim to minimize)
    '''
    same_X = gen_YX(real_X)
    return identity_criterion(same_X, real_X), same_X


def get_cycle_consistency_loss(real_X, fake_Y, gen_YX, cycle_criterion):
    '''
    Compute the cycle consistency loss: map the generated class-Y images back
    to class X and compare the reconstruction with the original images.
    Returns the loss together with the reconstructed images.
    Parameters:
        real_X: the real images from pile X
        fake_Y: the generated images of class Y
        gen_YX: the generator for class Y to X; takes images and returns the images
            transformed to class X
        cycle_criterion: the cycle consistency loss function; takes the real images from X and
            those images put through a X->Y generator and then Y->X generator
            and returns the cycle consistency loss (which you aim to minimize)
    '''
    reconstructed_X = gen_YX(fake_Y)
    return cycle_criterion(reconstructed_X, real_X), reconstructed_X

def get_gen_loss(real_A, real_B, gen_AB, gen_BA, disc_A, disc_B, adv_criterion, identity_criterion, cycle_criterion, lambda_identity=0.1, lambda_cycle=10):
    '''
    Compute the combined loss of both generators:
        lambda_identity * identity loss + lambda_cycle * cycle loss + adversarial loss,
    where each term is the sum over the two translation directions.
    Returns the total loss and the generated images (fake_A, fake_B).
    Parameters:
        real_A: the real images from pile A
        real_B: the real images from pile B
        gen_AB: the generator for class A to B; takes images and returns the images
            transformed to class B
        gen_BA: the generator for class B to A; takes images and returns the images
            transformed to class A
        disc_A: the discriminator for class A; takes images and returns real/fake class A
            prediction matrices
        disc_B: the discriminator for class B; takes images and returns real/fake class B
            prediction matrices
        adv_criterion: the adversarial loss function; takes the discriminator
            predictions and the true labels and returns an adversarial
            loss (which you aim to minimize)
        identity_criterion: the reconstruction loss function used for the identity
            loss; takes two sets of images and returns their pixel differences
        cycle_criterion: the reconstruction loss function used for the cycle
            consistency loss; takes the real images and their round-trip
            reconstruction
        lambda_identity: the weight of the identity loss
        lambda_cycle: the weight of the cycle-consistency loss
    '''
    # Adversarial term: each generator tries to fool the opposite discriminator.
    adv_BA, fake_A = get_gen_adversarial_loss(real_B, disc_A, gen_BA, adv_criterion)
    adv_AB, fake_B = get_gen_adversarial_loss(real_A, disc_B, gen_AB, adv_criterion)
    adversarial_term = adv_BA + adv_AB

    # Identity term: a generator fed an image of its target class should return it.
    ident_A, _ = get_identity_loss(real_A, gen_BA, identity_criterion)
    ident_B, _ = get_identity_loss(real_B, gen_AB, identity_criterion)
    identity_term = ident_A + ident_B

    # Cycle term: translating there and back should reproduce the input.
    cycle_A, _ = get_cycle_consistency_loss(real_A, fake_B, gen_BA, cycle_criterion)
    cycle_B, _ = get_cycle_consistency_loss(real_B, fake_A, gen_AB, cycle_criterion)
    cycle_term = cycle_A + cycle_B

    total = lambda_identity * identity_term + lambda_cycle * cycle_term + adversarial_term
    return total, fake_A, fake_B

Writing loss.py


#### utils.py

In [11]:
%%writefile utils.py
from torch import nn
from torchvision.utils import make_grid

def weights_init(m):
    '''
    Initialize a single module in place; intended for use with Module.apply.
    Conv2d / ConvTranspose2d weights are drawn from N(0, 0.02). BatchNorm2d
    weights are drawn from N(0, 0.02) and its bias is set to 0. Every other
    module type is left untouched.
    '''
    conv_types = (nn.Conv2d, nn.ConvTranspose2d)
    if isinstance(m, conv_types):
        nn.init.normal_(m.weight, 0.0, 0.02)
    if isinstance(m, nn.BatchNorm2d):
        nn.init.normal_(m.weight, 0.0, 0.02)
        nn.init.constant_(m.bias, 0)
        


def visualize_images(image_tensor, num_images=25, size=(1, 28, 28)):
    '''
    Build an image grid from a batch of images.
    The values are shifted from [-1, 1] to [0, 1], the batch is reshaped to
    (-1, *size), and the first num_images images are tiled five per row.
    The grid is returned channels-last with singleton dimensions squeezed
    out; nothing is plotted here.
    '''
    rescaled = (image_tensor + 1) / 2
    images = rescaled.detach().cpu().view(-1, *size)
    grid = make_grid(images[:num_images], nrow=5)
    return grid.permute(1, 2, 0).squeeze()

Writing utils.py


#### train.py

In [12]:
%%writefile train.py
import os
import argparse
import torch
from torch import nn
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms
import cv2
import numpy as np
import matplotlib.pyplot as plt

from tqdm.auto import tqdm

from dataset import ImageDataset
from config import TRAIN_DIR
from model.Generator import Generator
from model.Discriminator import Discriminator
from utils import weights_init, visualize_images
from loss import get_disc_loss, get_gen_loss
from torch.utils.tensorboard import SummaryWriter

def _checkpoint_state(gen_AB, gen_BA, gen_opt, disc_A, disc_A_opt, disc_B, disc_B_opt):
    '''
    Gather the state dicts of both generators, both discriminators and their
    optimizers into one dict that can be handed to torch.save.
    '''
    return {
        'gen_AB': gen_AB.state_dict(),
        'gen_BA': gen_BA.state_dict(),
        'gen_opt': gen_opt.state_dict(),
        'disc_A': disc_A.state_dict(),
        'disc_A_opt': disc_A_opt.state_dict(),
        'disc_B': disc_B.state_dict(),
        'disc_B_opt': disc_B_opt.state_dict()
    }


if __name__ == "__main__":
    # Command-line arguments
    parser = argparse.ArgumentParser(description="Training script for day/night CycleGAN")
    parser.add_argument('--epochs', type=int, default=20, help='epochs')
    parser.add_argument('--batch_size', type=int, default=1, help='batch_size')
    parser.add_argument('--lr', type=float, default=0.001 , help='learning rate')
    parser.add_argument('--input_size', type=int, default=600, help='input size image')
    parser.add_argument('--logdir', type=str, default='./exp', help='tensorboard')
    parser.add_argument('--patience', type=int, default=5,
                        help='stop once this many epochs have passed without a new best loss')
    args = parser.parse_args()

    # Hyper-parameters initialization
    print('-------------------- Hyper-parameters initialization -----------------------')
    LR, EPOCHS, BATCH_SIZE, INPUT_SIZE  = args.lr, args.epochs, args.batch_size, args.input_size
    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
    VIS_DIR, WEIGHTS_DIR = './visualization', './weights'
    # Create the output folders up front so the script also works when it is
    # run outside the notebook: cv2.imwrite fails silently and torch.save
    # raises when the target directory is missing.
    os.makedirs(VIS_DIR, exist_ok=True)
    os.makedirs(WEIGHTS_DIR, exist_ok=True)
    writer = SummaryWriter(log_dir=args.logdir)
    print('-------------------- ------------------------------- -----------------------')

    # Dataset initialization
    print('-------------------- Dataset initialization -----------------------')
    train_tfms = transforms.Compose([
        transforms.Resize(INPUT_SIZE),
        transforms.RandomCrop(INPUT_SIZE),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])

    trainset = ImageDataset(TRAIN_DIR, train_tfms)
    trainloader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)
    print(f"Number of training samples: {len(trainset)}")
    print('-------------------- ------------------------------- -----------------------')

    # Model initialization: two generators (A->B, B->A) and one discriminator per domain
    print('-------------------- Model initialization -----------------------')
    gen_AB = Generator(3,3).to(DEVICE)
    gen_BA = Generator(3,3).to(DEVICE)
    disc_A = Discriminator(3).to(DEVICE)
    disc_B = Discriminator(3).to(DEVICE)
    gen_AB = gen_AB.apply(weights_init)
    gen_BA = gen_BA.apply(weights_init)
    disc_A = disc_A.apply(weights_init)
    disc_B = disc_B.apply(weights_init)
    print('-------------------- ------------------------------- -----------------------')

    # Optimizer and loss function initialization
    print('--------------------  Optimizer and loss function initialization -----------------------')
    gen_opt = torch.optim.Adam(list(gen_AB.parameters()) + list(gen_BA.parameters()), lr=LR, betas=(0.5, 0.999))
    disc_A_opt = torch.optim.Adam(disc_A.parameters(), lr=LR, betas=(0.5, 0.999))
    disc_B_opt = torch.optim.Adam(disc_B.parameters(), lr=LR, betas=(0.5, 0.999))

    adv_criterion = nn.MSELoss()
    recon_criterion = nn.L1Loss()

    print('-------------------- ------------------------------- -----------------------')

    # Train epochs
    print('-------------------- Train -----------------------')
    best_loss, best_epoch = 10000, 0
    for epoch in range(EPOCHS):
        discrimination_loss, generator_loss = 0, 0
        pbar = tqdm(trainloader)
        for idx, (real_A, real_B) in enumerate(pbar):
            real_A = real_A.to(DEVICE)
            real_B = real_B.to(DEVICE)

            ### Update discriminator A ###
            disc_A_opt.zero_grad() # Zero out the gradient before backpropagation
            with torch.no_grad():
                fake_A = gen_BA(real_B)
            disc_A_loss = get_disc_loss(real_A, fake_A, disc_A, adv_criterion)
            # The graph is not used again after this call, so it is not retained.
            disc_A_loss.backward()
            disc_A_opt.step()

            ### Update discriminator B ###
            disc_B_opt.zero_grad() # Zero out the gradient before backpropagation
            with torch.no_grad():
                fake_B = gen_AB(real_A)
            disc_B_loss = get_disc_loss(real_B, fake_B, disc_B, adv_criterion)
            disc_B_loss.backward()
            disc_B_opt.step()

            ### Update generator ###
            gen_opt.zero_grad()
            gen_loss, fake_A, fake_B = get_gen_loss(
                real_A, real_B, gen_AB, gen_BA, disc_A, disc_B, adv_criterion, recon_criterion, recon_criterion
            )
            gen_loss.backward()
            gen_opt.step()

            # Running sums for the epoch averages. The discriminator figure is
            # the mean of both discriminators (disc_B used to be left out).
            discrimination_loss += 0.5 * (disc_A_loss.item() + disc_B_loss.item())
            generator_loss += gen_loss.item()

            pbar.set_description(f"Epoch {epoch+1}: Iteration {idx+1}/{len(trainloader)}: Generator (U-Net) loss: {generator_loss/(idx+1)}, Discriminator loss: {discrimination_loss/(idx+1)}")

        num_batches = len(trainloader)
        writer.add_scalar('Loss/discrimination', discrimination_loss/num_batches, global_step=epoch+1)
        writer.add_scalar('Loss/generator', generator_loss/num_batches, global_step=epoch+1)

        # Save a picture of the last batch of the epoch: real images on top,
        # generated images below (fake_B under real_A, fake_A under real_B).
        real_img = visualize_images(torch.cat([real_A, real_B]), size=(3, INPUT_SIZE, INPUT_SIZE)).numpy()
        fake_img = visualize_images(torch.cat([fake_B, fake_A]), size=(3, INPUT_SIZE, INPUT_SIZE)).numpy()
        real_img = (real_img*255).astype(np.uint8)
        fake_img = (fake_img*255).astype(np.uint8)
        # OpenCV writes BGR, the tensors are RGB.
        img = cv2.vconcat([cv2.cvtColor(real_img, cv2.COLOR_RGB2BGR), cv2.cvtColor(fake_img, cv2.COLOR_RGB2BGR)])
        cv2.imwrite(os.path.join(VIS_DIR, f"{epoch+1}.jpg"), img)

        # Always keep the latest weights; additionally keep the best epoch so far.
        state = _checkpoint_state(gen_AB, gen_BA, gen_opt, disc_A, disc_A_opt, disc_B, disc_B_opt)
        torch.save(state, os.path.join(WEIGHTS_DIR, "cycleGAN_last.pth"))

        epoch_loss = (discrimination_loss + generator_loss) / num_batches
        if best_loss > epoch_loss:
            best_loss = epoch_loss
            best_epoch = epoch
            torch.save(state, os.path.join(WEIGHTS_DIR, "cycleGAN_best.pth"))

        # Early stopping. The operands used to be the other way round
        # (best_epoch - epoch), which can never be positive, so it never fired.
        if epoch - best_epoch > args.patience:
            break

    # Flush pending TensorBoard events to disk.
    writer.close()

Writing train.py


In [13]:
!python train.py

2024-02-15 11:20:02.288399: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-02-15 11:20:02.288511: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-02-15 11:20:02.406087: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
-------------------- Hyper-parameters initialization -----------------------
-------------------- ------------------------------- -----------------------
-------------------- Dataset initialization -----------------------
Number of training samples: 522
-------------------- ------------------------------- -----------------------
-------------------- Model