# Cycle GAN
A simplified CycleGAN following [this repo](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix).

Modified to handle single-channel TPC data.

In [2]:
import itertools
import sys
from pathlib import Path
# from functools import partial

import torch
import torch.nn as nn
from torch.nn import init

In [3]:
# Make the project's utils directory importable and pull in ImagePool,
# the fake-image pool used by the discriminator updates.
util_path = '/sdcc/u/yhuang2/PROJs/GAN/yi_CycleGAN/utils'
# Explicit check instead of `assert`, which is skipped under `python -O`.
if not Path(util_path).exists():
    raise FileNotFoundError(f'utils directory not found: {util_path}')
if util_path not in sys.path:
    sys.path.append(util_path)
from network import ImagePool

In [4]:
# Load model: make the project's models directory importable and pull in
# ResnetGenerator, which define_G instantiates.
model_path = '/sdcc/u/yhuang2/PROJs/GAN/yi_CycleGAN/models'
# Explicit check instead of `assert`, which is skipped under `python -O`.
if not Path(model_path).exists():
    raise FileNotFoundError(f'models directory not found: {model_path}')
if model_path not in sys.path:
    sys.path.append(model_path)
from resnet import ResnetGenerator

In [6]:
def define_G(
    input_nc,
    output_nc,
    ngf,
    netG,
    norm_type='instance',
    use_dropout=False
):
    """
    Create a generator.

    Parameters:
        1. input_nc (int): number of input channels
        2. output_nc (int): number of output channels
        3. ngf (int): base number of filters in the conv layers
        4. netG (str): the architecture's name: resnet_9blocks | resnet_6blocks
            (other architectures are not implemented yet)
        5. norm_type (str): the name of the normalization: instance | batch | none
        6. use_dropout (bool): whether to use dropout.

    Returns: the result of `init_weights` applied to a ResnetGenerator.

    Raises:
        NotImplementedError: if `netG` is not a supported architecture name.
    """
    # Number of residual blocks for each supported architecture name.
    n_blocks_by_name = {
        'resnet_9blocks': 9,
        'resnet_6blocks': 6,
    }

    # Validate the name before building anything, so an unsupported
    # architecture fails fast with a clear message.
    if netG not in n_blocks_by_name:
        raise NotImplementedError(f'Generator model name {netG} is not implemented')

    # NOTE(review): get_norm_layer and init_weights are not defined or
    # imported in this notebook as shown; confirm where they come from.
    norm_layer = get_norm_layer(norm_type=norm_type)
    net = ResnetGenerator(
        input_nc,
        output_nc,
        ngf,
        norm_layer=norm_layer,
        use_dropout=use_dropout,
        n_blocks=n_blocks_by_name[netG],
    )
    return init_weights(net)


def define_D(
    input_nc,
    ndf,
    netD,
    n_layer=3,
    norm_type='instance',
):
    """
    Create a discriminator.

    Parameters:
        1. input_nc (int): number of channels in input images
        2. ndf (int): number of filters in the first conv layer
        3. netD (str): the architecture's name: basic | n_layers
            ('pixel' is not implemented yet)
        4. n_layer (int): the number of conv layers in the discriminator;
            effective only when netD == 'n_layers' ('basic' always uses 3)
        5. norm_type (str): the type of normalization layers used in the network.

    Returns: the result of `init_weights` applied to an NLayerDiscriminator.

    Raises:
        NotImplementedError: if `netD` is not 'basic' or 'n_layers'.
    """
    # Resolve the layer count from the architecture name; fail fast on
    # unsupported names, consistent with define_G.
    if netD == 'basic':
        n_layers = 3
    elif netD == 'n_layers':
        n_layers = n_layer
    else:
        raise NotImplementedError(f'Discriminator model name {netD} is not implemented')

    # NOTE(review): get_norm_layer, NLayerDiscriminator and init_weights are
    # not defined or imported in this notebook as shown; confirm their source.
    norm_layer = get_norm_layer(norm_type=norm_type)
    net = NLayerDiscriminator(input_nc, ndf, n_layers=n_layers, norm_layer=norm_layer)
    return init_weights(net)

In [5]:
class CycleGAN(nn.Module):
    """
    CycleGAN with two generators and, in training mode, two discriminators.

    G_A maps domain A to domain B and G_B maps domain B back to A. D_A scores
    domain-B images (real_B vs. fake_B) and D_B scores domain-A images
    (real_A vs. fake_A). Training combines adversarial, cycle-consistency,
    and optional identity losses.

    Parameters:
        opt: options object. Always read: input_nc, output_nc, ngf, netG,
            norm_type, use_dropout, isTrain. Read when isTrain is true: ndf,
            netD, n_layers_D, lambda_identity, pool_size, gan_mode, lr, beta1,
            lr_policy, n_epochs_no_decay, n_epochs_decay, lr_decay_steps.
            Read by backward_G: lambda_identity, lambda_A, lambda_B.
    """

    def __init__(self, opt):
        super().__init__()
        # backward_G reads the loss weights from the options, so keep them.
        self.opt = opt

        # model and loss names
        self.model_names = ['G_A', 'G_B', 'D_A', 'D_B']
        self.loss_names = [
            'D_A',      # GAN adversarial loss at the discriminator D_A
            'G_A',      # GAN adversarial loss at the generator G_A
            'cycle_A',  # cycle loss at domain A: G_B(G_A(A)) vs. A
            'idt_A',    # identity loss of G_A: G_A(B) vs. B
            'D_B',      # GAN adversarial loss at the discriminator D_B
            'G_B',      # GAN adversarial loss at the generator G_B
            'cycle_B',  # cycle loss at domain B: G_A(G_B(B)) vs. B
            'idt_B',    # identity loss of G_B: G_B(A) vs. A
        ]

        # Generators: G_A maps input_nc -> output_nc channels, G_B the reverse.
        self.netG_A = define_G(
            opt.input_nc,
            opt.output_nc,
            opt.ngf,
            opt.netG,
            opt.norm_type,
            opt.use_dropout
        )
        self.netG_B = define_G(
            opt.output_nc,
            opt.input_nc,
            opt.ngf,
            opt.netG,
            opt.norm_type,
            opt.use_dropout
        )

        # discriminator networks, criterions, and optimizers for training
        if opt.isTrain:
            # D_A scores domain-B images, D_B scores domain-A images.
            self.netD_A = define_D(
                opt.output_nc,
                opt.ndf,  # number of discriminator filters in the first conv layer
                opt.netD,
                opt.n_layers_D,
                opt.norm_type,
            )
            self.netD_B = define_D(
                opt.input_nc,
                opt.ndf,  # number of discriminator filters in the first conv layer
                opt.netD,
                opt.n_layers_D,
                opt.norm_type,
            )

            # The identity loss feeds real_B to G_A and real_A to G_B, which
            # requires both domains to have the same number of channels.
            # Explicit check rather than `assert`, which `python -O` strips.
            if opt.lambda_identity > 0 and opt.input_nc != opt.output_nc:
                raise ValueError(
                    'lambda_identity > 0 requires input_nc == output_nc, '
                    f'got input_nc={opt.input_nc}, output_nc={opt.output_nc}'
                )

            # Fake-image pools: the discriminators see a mixture of old and
            # new fake images.
            self.fake_A_pool = ImagePool(opt.pool_size)
            self.fake_B_pool = ImagePool(opt.pool_size)

            # Loss functions
            # NOTE(review): GANLoss is not defined or imported in this
            # notebook as shown; confirm where it is expected to come from.
            self.criterionGAN = GANLoss(opt.gan_mode)
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()

            # One optimizer for both generators, one for both discriminators.
            self.optimizer_G = torch.optim.Adam(
                itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()),
                lr=opt.lr,
                betas=(opt.beta1, .999)
            )
            self.optimizer_D = torch.optim.Adam(
                itertools.chain(self.netD_A.parameters(), self.netD_B.parameters()),
                lr=opt.lr,
                betas=(opt.beta1, .999)
            )
            self.optimizers = [self.optimizer_G, self.optimizer_D]

            # Only ImagePool is imported from the `network` utils module at
            # file level; import the module itself to reach get_scheduler.
            import network
            self.schedulers = [
                network.get_scheduler(
                    optimizer,
                    lr_policy=opt.lr_policy,
                    n_epochs_no_decay=opt.n_epochs_no_decay,
                    n_epochs_decay=opt.n_epochs_decay,
                    lr_decay_steps=opt.lr_decay_steps
                )
                for optimizer in self.optimizers
            ]

    def set_input(self, input):
        """
        Store one batch for the next forward / optimization step.

        Parameters:
            input: mapping with keys 'A' (stored as real_A), 'B' (stored as
                real_B) and 'A_paths' (stored as image_paths).
        """
        self.real_A = input['A']
        self.real_B = input['B']
        self.image_paths = input['A_paths']

    def forward(self):
        """Translate both domains and reconstruct them through the cycle."""
        self.fake_B = self.netG_A(self.real_A)  # G_A(A)
        self.rec_A = self.netG_B(self.fake_B)   # G_B(G_A(A))
        self.fake_A = self.netG_B(self.real_B)  # G_B(B)
        self.rec_B = self.netG_A(self.fake_A)   # G_A(G_B(B))

    def backward_D_basic(self, netD, real, fake):
        """
        Calculate the GAN loss for a discriminator and backpropagate it.

        Parameters:
            1. netD (network): the discriminator D
            2. real (tensor): real images
            3. fake (tensor): images generated by a generator; detached so
                the generator receives no gradient from this loss

        Returns:
            the discriminator loss: the mean of the real and fake GAN losses.
        """
        # Real
        loss_D_real = self.criterionGAN(netD(real), True)

        # Fake
        loss_D_fake = self.criterionGAN(netD(fake.detach()), False)

        # Combine loss and calculate gradients
        loss_D = (loss_D_real + loss_D_fake) * .5
        loss_D.backward()

        return loss_D

    def backward_D_A(self):
        """Calculate and backpropagate the GAN loss for the discriminator D_A."""
        self.loss_D_A = self.backward_D_basic(
            self.netD_A,
            self.real_B,
            # the collection of fake images is a
            # mixture of old and new fake images
            self.fake_B_pool.query(self.fake_B)
        )

    def backward_D_B(self):
        """Calculate and backpropagate the GAN loss for the discriminator D_B."""
        self.loss_D_B = self.backward_D_basic(
            self.netD_B,
            self.real_A,
            # the collection of fake images is a
            # mixture of old and new fake images
            self.fake_A_pool.query(self.fake_A)
        )

    def backward_G(self):
        """
        Calculate and backpropagate the combined loss for generators G_A and G_B.

        The total is the sum of the two adversarial losses, the two cycle
        losses (weighted by lambda_A / lambda_B) and, when lambda_identity
        is positive, the two identity losses.
        """
        lambda_idt = self.opt.lambda_identity  # a multiplier on lambda_A and lambda_B
        lambda_A, lambda_B = self.opt.lambda_A, self.opt.lambda_B

        # Identity loss
        self.loss_idt_A, self.loss_idt_B = 0, 0
        if lambda_idt > 0:
            self.idt_A = self.netG_A(self.real_B)
            self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt
            self.idt_B = self.netG_B(self.real_A)
            self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt

        # GAN losses: each generator tries to make its discriminator say "real".
        self.loss_G_A = self.criterionGAN(self.netD_A(self.fake_B), True)
        self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True)

        # Cycle losses
        self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A
        self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B

        # Combine and backpropagate
        self.loss_G = (
            self.loss_G_A + self.loss_G_B
            + self.loss_cycle_A + self.loss_cycle_B
            + self.loss_idt_A + self.loss_idt_B
        )
        self.loss_G.backward()

    @staticmethod
    def _set_requires_grad(nets, requires_grad):
        """Set `requires_grad` on every parameter of each network in `nets`."""
        for net in nets:
            for param in net.parameters():
                param.requires_grad = requires_grad

    def optimize_parameters(self):
        """
        Run one training step: forward pass, then a generator update,
        then a discriminator update.
        """
        self.forward()

        # Optimize generators; freeze the discriminators so the generator
        # loss does not produce gradients for D's parameters.
        self._set_requires_grad([self.netD_A, self.netD_B], False)
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()

        # Optimize discriminators; unfreeze them first, otherwise the
        # discriminators would never receive gradients.
        self._set_requires_grad([self.netD_A, self.netD_B], True)
        self.optimizer_D.zero_grad()
        self.backward_D_A()
        self.backward_D_B()
        self.optimizer_D.step()