# CycleGAN
A simplified CycleGAN, built following [this repo](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix).

It is modified to handle single-channel TPC data.

In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader 
from torch.nn import init

import sys
from pathlib import Path
import numpy as np
import pandas as pd
import itertools
# from functools import partial

In [2]:
%load_ext autoreload
%autoreload 2

sys.path.append('/sdcc/u/yhuang2/PROJs/GAN/collisionGAN/utils')
import network

In [3]:
sys.path.append('/sdcc/u/yhuang2/PROJs/GAN/collisionGAN//models')
from resnet import ResnetGenerator
from alexnet import AlexNetDiscriminator
from base_model import BaseModel

## A simple option class
(A more flexible configuration mechanism can come later.)

In [4]:
class options:
    """
    Flat container for every setting used by this notebook.

    All values are fixed here; edit them in this class to change an
    experiment. Attributes are grouped as data, base-model and training
    settings.
    """
    def __init__(self):
        phase = 'train'
        assert phase in ['train', 'test'], "Invalid phase, choose from ['train', 'test']"

        settings = {
            # data
            'dataroot': '/sdcc/u/yhuang2/PROJs/GAN/datasets/ls4gan/toyzero_cropped/toyzero_2021-06-29_safi_U/',
            'phase': phase,
            'max_dataset_size': 1000,

            # base_model
            'cuda': True,
            'isTrain': phase == 'train',
            'checkpoint_dir': '/sdcc/u/yhuang2/PROJs/GAN/collisionGAN/checkpoints',
            'experiment_name': 'experiment',

            # training
            'netG': 'resnet_6blocks',
            'batch_size': 32,
            'num_workers': 1,
            'input_nc': 1,
            'output_nc': 1,
            'ngf': 64,
            'norm_type': 'instance',
            'use_dropout': False,
            'gan_mode': 'vanilla',
            'lr': 0.0002,
            'lr_policy': 'linear',
            'lr_linear': 100,
            'lr_step': 50,
            'epochs': 200,
            'beta1': .5,

            # loss weights and fake-image pool
            'lambda_A': 10,
            'lambda_B': 10,
            'lambda_identity': .5,
            'pool_size': 50,
        }
        # Dicts keep insertion order, so attributes are created in the order listed.
        for attr_name, attr_value in settings.items():
            setattr(self, attr_name, attr_value)

opt = options()

## Construct generators and discriminators

In [5]:
def define_G(
    input_nc,
    output_nc,
    ngf,
    netG,
    norm_type='instance',
    use_dropout=False,
    cuda=True,
    net_name='network',
):
    """
    Create a ResNet generator, optionally move it to CUDA, and initialize its weights.

    Parameters:
        1. input_nc (int): number of input channels
        2. output_nc (int): number of output channels
        3. ngf (int): base number of filters in the conv layers
        4. netG (str): the architecture's name: resnet_9blocks | resnet_6blocks
            (remaining architectures are not implemented yet)
        5. norm_type (str): normalization name, passed to `network.get_norm_layer`
            (e.g. instance | batch | none)
        6. use_dropout (bool): forwarded to `ResnetGenerator`
        7. cuda (bool): if True, assert CUDA is available and move the network
            to 'cuda' before weight initialization
        8. net_name (str): label forwarded to `network.init_weights`

    Returns: A generator

    Raises:
        NotImplementedError: if `netG` is not a supported architecture name.
    """
    # Supported architecture names -> number of residual blocks.
    n_blocks_by_name = {'resnet_9blocks': 9, 'resnet_6blocks': 6}
    if netG not in n_blocks_by_name:
        raise NotImplementedError(f'Generator model name {netG} is not implemented')

    norm_layer = network.get_norm_layer(norm_type=norm_type)
    net = ResnetGenerator(
        input_nc,
        output_nc,
        ngf,
        norm_layer=norm_layer,
        use_dropout=use_dropout,
        n_blocks=n_blocks_by_name[netG],
    )

    if cuda:
        assert torch.cuda.is_available(), "Cuda is not available"
        net.to('cuda')
    network.init_weights(net, net_name=net_name)

    return net


def define_D(input_nc, cuda=True, net_name='network'):
    """
    Build a small AlexNet-style discriminator and initialize its weights.

    Parameters:
        1. input_nc (int): number of channels in the input images
        2. cuda (bool): if True, assert CUDA is available and move the
           network to 'cuda' (done before weight initialization)
        3. net_name (str): label forwarded to `network.init_weights`

    Returns: A discriminator
    """
    discriminator = AlexNetDiscriminator(input_channels=input_nc)

    if cuda:
        assert torch.cuda.is_available(), "Cuda is not available"
        discriminator.to('cuda')

    network.init_weights(discriminator, net_name=net_name)
    return discriminator


# def define_D(
#     input_nc, 
#     ndf,
#     netD,
#     n_layer=3,
#     norm_type='instance',
# ):
#     """
#     Create a discriminator

#     Parameters:
#         1. input_nc (int): number of channels in input images
#         2. ndf (int): number of filters in the first conv layer
#         3. netD (str): the architecture's name: basic | n_layers | pixel
#         4.n_layers_D (int): the number of conv layers in the discriminator; 
#             effective when netD=='n_layers'
#         5. norm_type (str): the type of normalization layers used in the network.

#     Returns: A discriminator
#     """
    
#     norm_layer = network.get_norm_layer(norm_type=norm_type)
#     net = NLayerDiscriminator(input_nc, ndf, n_layers=n_layers, norm_layer=norm_layer)
#     return network.init_weight(net)

## GAN Loss

In [6]:
class GANLoss(nn.Module):
    """
    Define different GAN objectives.

    The target label tensors are registered as buffers so that `.to(device)`
    moves them together with the module.
    """

    def __init__(self, gan_mode):
        """
        Initialize the GANLoss class.

        Parameters:
            gan_mode (str) -- the type of GAN objective.
                It currently supports
                    - vanilla (BCEWithLogitsLoss);
                    - lsgan (MSELoss);
                    - wgangp (mean of the prediction, sign chosen by the target).

        Raises:
            NotImplementedError: for any other `gan_mode`.

        NOTE: DO NOT use sigmoid as the last layer of Discriminator.
            - vanilla handles it with BCEWithLogitsLoss.
            - lsgan needs no sigmoid.
        """
        super(GANLoss, self).__init__()
        self.register_buffer('real_label', torch.tensor(1.0))
        self.register_buffer('fake_label', torch.tensor(0.0))
        self.gan_mode = gan_mode
        if gan_mode == 'lsgan':
            self.loss = nn.MSELoss()
        elif gan_mode == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif gan_mode == 'wgangp':
            self.loss = None
        else:
            raise NotImplementedError('gan mode %s not implemented' % gan_mode)

    # Overriding `forward` (not `__call__`) keeps nn.Module's call machinery,
    # including forward hooks, intact; callers still use `criterion(pred, flag)`.
    def forward(self, prediction, target_is_real):
        """
        Calculate loss given Discriminator's output and ground truth labels.

        Parameters:
            prediction (tensor) -- prediction from a discriminator
            target_is_real (bool) -- if the ground truth label is for real images or fake images

        Returns:
            the calculated loss (a scalar tensor).
        """
        if self.gan_mode in ['lsgan', 'vanilla']:
            target_tensor = self.real_label if target_is_real else self.fake_label
            # Broadcast the scalar label to the discriminator's output shape.
            loss = self.loss(prediction, target_tensor.expand_as(prediction))
        elif self.gan_mode == 'wgangp':
            mean_score = prediction.mean()
            loss = -mean_score if target_is_real else mean_score
        return loss

In [7]:
class CycleGAN(BaseModel):
    """
    CycleGAN with two generators and, in training mode, two discriminators.

    - netG_A maps domain A -> B; netG_B maps domain B -> A.
    - netD_A scores domain-B images (real_B vs. G_A(A)); netD_B scores
      domain-A images (real_A vs. G_B(B)).

    NOTE(review): `self.opt` and `self.isTrain` are read but never assigned in
    this class; presumably `BaseModel.__init__` sets them — confirm.
    """

    def __init__(self, opt):
        """
        Build the generators and, when training, the discriminators, fake-image
        pools, loss functions, optimizers and learning-rate schedulers.

        Parameters:
            opt: option object providing the attributes read below (cuda,
                input_nc, output_nc, ngf, netG, norm_type, use_dropout,
                lambda_identity, pool_size, gan_mode, lr, beta1, ...).
        """
        super(CycleGAN, self).__init__(opt)
        self.device = 'cuda' if opt.cuda else 'cpu'

        self.loss_names = ['D_A', 'D_B', 'G_A', 'G_B', 'cycle_A', 'cycle_B', 'idt_A', 'idt_B']
        # `model_names` and the creation of the discriminators below are both
        # keyed on self.isTrain, so every listed model actually exists.
        if self.isTrain:
            self.model_names = ['G_A', 'G_B', 'D_A', 'D_B']
        else:
            self.model_names = ['G_A', 'G_B']

        # Generator networks; G_B swaps the channel counts because it maps B back to A.
        self.netG_A = define_G(
            opt.input_nc,
            opt.output_nc,
            opt.ngf,
            opt.netG,
            opt.norm_type,
            opt.use_dropout,
            cuda=opt.cuda,
            net_name='generator A->B',
        )
        self.netG_B = define_G(
            opt.output_nc,
            opt.input_nc,
            opt.ngf,
            opt.netG,
            opt.norm_type,
            opt.use_dropout,
            cuda=opt.cuda,
            net_name='generator B->A',
        )

        # Discriminator networks, criterions, optimizers and schedulers for training.
        if self.isTrain:
            # netD_A judges domain-B images, hence output_nc input channels.
            self.netD_A = define_D(opt.output_nc, cuda=opt.cuda, net_name='discriminator A')
            self.netD_B = define_D(opt.input_nc, cuda=opt.cuda, net_name='discriminator B')

            # The identity loss feeds B images to G_A and A images to G_B,
            # which requires both domains to have the same channel count.
            if opt.lambda_identity > 0:
                assert opt.input_nc == opt.output_nc, \
                    'identity loss requires input_nc == output_nc'

            # Fake-image buffers: the discriminators are trained on a
            # mixture of old and new fake images.
            self.fake_A_pool = network.ImagePool(opt.pool_size)
            self.fake_B_pool = network.ImagePool(opt.pool_size)

            # Loss functions
            self.criterionGAN = GANLoss(opt.gan_mode).to(self.device)
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()

            # One optimizer shared by both generators, one shared by both discriminators.
            self.optimizer_G = torch.optim.Adam(
                itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()),
                lr=opt.lr,
                betas=(opt.beta1, .999)
            )
            self.optimizer_D = torch.optim.Adam(
                itertools.chain(self.netD_A.parameters(), self.netD_B.parameters()),
                lr=opt.lr,
                betas=(opt.beta1, .999)
            )
            self.optimizers = [self.optimizer_G, self.optimizer_D]
            self.schedulers = [network.get_scheduler(optimizer, opt) for optimizer in self.optimizers]

    def set_input(self, data):
        """
        Store one batch on the model's device.

        Parameters:
            data: mapping with keys 'A' and 'B' holding tensors (e.g. a batch
                produced by a DataLoader over `toyzero_dataset`).
        """
        self.real_A = data['A'].to(self.device)
        self.real_B = data['B'].to(self.device)

    def forward(self):
        """Run both generators in both directions and cache translations and reconstructions."""
        self.fake_B = self.netG_A(self.real_A)  # G_A(A)
        self.rec_A = self.netG_B(self.fake_B)   # G_B(G_A(A))
        self.fake_A = self.netG_B(self.real_B)  # G_B(B)
        self.rec_B = self.netG_A(self.fake_A)   # G_A(G_B(B))

    def backward_D_basic(self, netD, real, fake):
        """
        Calculate GAN loss for the discriminator and backpropagate it.

        Parameters:
            1. netD (network): the discriminator D
            2. real (tensor): real images
            3. fake (tensor): images generated by a generator

        Return:
            the discriminator loss.
        """
        # Real
        loss_D_real = self.criterionGAN(netD(real), True)

        # Fake; detached so this loss does not backpropagate into the generator.
        loss_D_fake = self.criterionGAN(netD(fake.detach()), False)

        # Average the real and fake terms, then calculate gradients.
        loss_D = (loss_D_real + loss_D_fake) * .5
        loss_D.backward()

        return loss_D

    def backward_D_A(self):
        """
        Calculate GAN loss for the discriminator D_A (real_B vs. pooled fake_B).
        """
        self.loss_D_A = self.backward_D_basic(
            self.netD_A,
            self.real_B,
            # the collection of fake images is a
            # mixture of old and new fake images
            self.fake_B_pool.query(self.fake_B)
        )

    def backward_D_B(self):
        """
        Calculate GAN loss for the discriminator D_B (real_A vs. pooled fake_A).
        """
        self.loss_D_B = self.backward_D_basic(
            self.netD_B,
            self.real_A,
            # the collection of fake images is a
            # mixture of old and new fake images
            self.fake_A_pool.query(self.fake_A)
        )

    def backward_G(self):
        """
        Calculate the combined loss for generators G_A and G_B and backpropagate it.

        The total is adversarial + cycle-consistency (+ identity when
        lambda_identity > 0) for both directions.
        """
        lambda_idt = self.opt.lambda_identity  # a multiplier applied on top of lambda_A / lambda_B
        lambda_A, lambda_B = self.opt.lambda_A, self.opt.lambda_B

        # Identity loss: each generator should leave images of its target domain unchanged.
        self.loss_idt_A, self.loss_idt_B = 0, 0
        if lambda_idt > 0:
            self.idt_A = self.netG_A(self.real_B)
            self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt
            self.idt_B = self.netG_B(self.real_A)
            self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt

        # GAN losses: the generators are scored against the "real" label.
        self.loss_G_A = self.criterionGAN(self.netD_A(self.fake_B), True)
        self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True)

        # Cycle losses
        self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A
        self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B

        # Combine and backpropagate
        self.loss_G = self.loss_G_A + self.loss_G_B \
                    + self.loss_cycle_A + self.loss_cycle_B \
                    + self.loss_idt_A + self.loss_idt_B
        self.loss_G.backward()

    @staticmethod
    def _set_requires_grad(nets, requires_grad):
        """Set `requires_grad` on every parameter of each network in `nets`."""
        for net in nets:
            for param in net.parameters():
                param.requires_grad = requires_grad

    def optimize_parameters(self):
        """
        Calculate losses and gradients, and update network weights.

        Runs forward(), then one generator step, then one discriminator step.
        """
        self.forward()

        # Optimize generators; freeze the discriminators so loss_G.backward()
        # computes no gradients for their parameters.
        self._set_requires_grad([self.netD_A, self.netD_B], False)
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()

        # Optimize discriminators
        self._set_requires_grad([self.netD_A, self.netD_B], True)
        self.optimizer_D.zero_grad()
        self.backward_D_A()
        self.backward_D_B()
        self.optimizer_D.step()

## Load dataset

In [8]:
def load_image_fnames(dirname, max_dataset_size=float('inf')):
    """
    List the `.npz` files directly inside `dirname`.

    The listing is sorted, so the file order (and therefore the dataset's
    index -> file mapping) does not depend on the file system's
    directory-iteration order. If `max_dataset_size` is finite and smaller
    than the number of files found, a random subset of that size (drawn with
    NumPy's global RNG, in shuffled order) is returned instead.

    Parameters:
        dirname (str | Path): directory to search (non-recursive).
        max_dataset_size (int | float | None): cap on the number of files
            returned; `float('inf')` or None means no cap.

    Returns:
        numpy.ndarray of `pathlib.Path` objects.

    Raises:
        AssertionError: if `dirname` does not exist.
    """
    dirname = Path(dirname)
    assert dirname.exists(), f"{dirname} doesn't exist"
    # '*.npz' (not '*npz') so names that merely end in "npz" are not picked up.
    image_fnames = np.array(sorted(dirname.glob('*.npz')))

    # `inf < n` is always False, so an infinite cap never triggers subsampling.
    if max_dataset_size is not None and max_dataset_size < len(image_fnames):
        indices = np.arange(len(image_fnames))
        np.random.shuffle(indices)
        image_fnames = image_fnames[indices[:int(max_dataset_size)]]
    return image_fnames


class toyzero_dataset(Dataset):
    """
    Unpaired two-domain dataset of single-array `.npz` files.

    Files are read from `<dataroot>/<phase>A` and `<dataroot>/<phase>B`.
    Item `i` pairs the A image at index `i % size_A` with a B image drawn
    uniformly at random, so A and B images are not aligned with each other.
    """

    def __init__(self, opt):
        """
        Parameters:
            opt: option object providing `dataroot`, `phase` and `max_dataset_size`.

        Raises:
            ValueError: if either domain directory contains no images.
        """
        super(toyzero_dataset, self).__init__()
        dir_A = Path(opt.dataroot) / f'{opt.phase}A'
        dir_B = Path(opt.dataroot) / f'{opt.phase}B'
        self.image_fnames_A = load_image_fnames(dir_A, opt.max_dataset_size)
        self.image_fnames_B = load_image_fnames(dir_B, opt.max_dataset_size)
        self.size_A = len(self.image_fnames_A)
        self.size_B = len(self.image_fnames_B)
        # An empty domain would otherwise only fail later inside __getitem__
        # (modulo by zero / empty random range).
        if self.size_A == 0 or self.size_B == 0:
            raise ValueError(
                f'No images found (A: {self.size_A} in {dir_A}, B: {self.size_B} in {dir_B})'
            )

    def __len__(self):
        """One epoch covers the larger of the two domains."""
        return max(self.size_A, self.size_B)

    def _load(self, image_fname):
        """Load the first array of an `.npz` file as float32 with a new leading axis of size 1."""
        # np.load on an .npz returns an NpzFile that keeps the file open; the
        # context manager closes it after the array has been read.
        with np.load(image_fname) as archive:
            image = archive[archive.files[0]]
        return np.expand_dims(image.astype(np.float32), 0)

    def __getitem__(self, index):
        """Return {'A': image_A, 'B': image_B}; B is sampled uniformly from all B files."""
        index_A = index % self.size_A
        # The upper bound of torch.randint is exclusive, so this covers
        # 0 .. size_B - 1. torch's RNG is seeded per DataLoader worker;
        # NumPy's global RNG is not reseeded by older PyTorch releases.
        index_B = int(torch.randint(self.size_B, (1,)).item())
        image_A = self._load(self.image_fnames_A[index_A])
        image_B = self._load(self.image_fnames_B[index_B])

        return {'A': image_A, 'B': image_B}

In [9]:
dataset = toyzero_dataset(opt)
# Reshuffle sample order every epoch; batch size and worker count come from `opt`.
dataloader = DataLoader(
    dataset,
    shuffle=True,
    batch_size=opt.batch_size,
    num_workers=opt.num_workers,
)
n_batches = len(dataloader)
print(f'number of batches = {n_batches}')

number of batches = 32


## Train

In [10]:
model = CycleGAN(opt=opt)  # builds and initializes the generators (and, when training, the discriminators)

initialize generator A->B with normal distribution
initialize generator B->A with normal distribution
initialize discriminator A with normal distribution
initialize discriminator B with normal distribution


In [12]:
# Run length comes from the options so it cannot drift from `opt.epochs`.
epochs = opt.epochs
save_every = 20  # checkpoint frequency, in epochs
for epoch in range(epochs):
    print(f'\n{epoch + 1} / {epochs}')

    for i, data in enumerate(dataloader):
        model.set_input(data)
        model.optimize_parameters()  # run forward inside
        # Report only the losses of the last batch of each epoch.
        if i == len(dataloader) - 1:
            losses = model.get_current_losses()
            for key, val in losses.items():
                print(f'\t{key}, {val:.5f}')

    model.update_learning_rate()
    # Checkpoint periodically and always after the final epoch, so the
    # trained weights are never lost between save points.
    if epoch % save_every == 0 or epoch == epochs - 1:
        save_suffix = f'epoch_{epoch}'
        model.save_networks(save_suffix)


1 / 200
	D_A, 0.12154
	D_B, 0.09191
	G_A, 1.50898
	G_B, 1.68529
	cycle_A, 2.86329
	cycle_B, 2.99282
	idt_A, 1.49036
	idt_B, 1.41479
learning rate 2.00e-04 -> 2.00e-04

2 / 200
	D_A, 0.08734
	D_B, 0.06404
	G_A, 1.66890
	G_B, 3.45798
	cycle_A, 1.78605
	cycle_B, 2.98108
	idt_A, 1.48043
	idt_B, 0.88862
learning rate 2.00e-04 -> 2.00e-04

3 / 200
	D_A, 0.35954
	D_B, 0.24107
	G_A, 1.09639
	G_B, 2.27057
	cycle_A, 2.67571
	cycle_B, 3.00795
	idt_A, 1.51717
	idt_B, 1.32266
learning rate 2.00e-04 -> 2.00e-04

4 / 200
	D_A, 0.65901
	D_B, 0.63771
	G_A, 0.96532
	G_B, 3.08482
	cycle_A, 1.74189
	cycle_B, 2.92602
	idt_A, 1.46199
	idt_B, 0.86122
learning rate 2.00e-04 -> 2.00e-04

5 / 200
	D_A, 0.14052
	D_B, 0.22439
	G_A, 1.71800
	G_B, 4.33523
	cycle_A, 2.54174
	cycle_B, 3.06497
	idt_A, 1.55416
	idt_B, 1.23928
learning rate 2.00e-04 -> 2.00e-04

6 / 200
	D_A, 0.33007
	D_B, 0.20137
	G_A, 1.57826
	G_B, 4.25931
	cycle_A, 2.32158
	cycle_B, 3.20161
	idt_A, 1.57874
	idt_B, 1.16688
learning rate 2.00e-04 -> 2

In [165]:
# Peak GPU memory occupied by tensors so far, printed in GiB.
# `device=None` queries the current CUDA device; 'cuda' names it explicitly.
bytes_per_gib = 1024 ** 3
memory_all = torch.cuda.max_memory_allocated(device=None)
memory_cuda = torch.cuda.max_memory_allocated(device='cuda')
for peak_bytes in (memory_all, memory_cuda):
    print(peak_bytes / bytes_per_gib)

3.2298989295959473
3.2298989295959473
