In [None]:
#default_exp train_loop

In [1]:
#hide
#export
from fastai import *
from fastai.vision import *
from fastai.callbacks import *
from fastai.utils.mem import *
from fastai.vision.gan import *
from PIL import Image

import numpy as np

import torch
import torch.nn.functional as F
import torch.nn as nn

from torch.utils.data import DataLoader
from torch.utils.data.dataset import TensorDataset
import pdb

from pyfiles.losses import *

# Training Wrapper

In [None]:
#export
class UNET_UNIT(Learner):
    """A `Learner` suitable for GANs.

    Wraps `generator` and `critic` in a `GANModule`, wraps the two loss functions in a
    `GANLoss`, and registers a switcher callback plus a `GANTrainer` on the learner.

    - `data`, `learn_kwargs`: forwarded to `Learner.__init__`.
    - `gen_loss_func`, `crit_loss_func`: loss callables handed to `GANLoss`.
    - `n_crit`, `n_gen`: forwarded to the default `FixedGANSwitcher`; not used when `switcher` is given.
    - `switcher`: callback factory passed through `callback_fns`; defaults to `FixedGANSwitcher`.
    - `gen_first`, `switch_eval`, `show_img`, `clip`: forwarded to `GANTrainer`.

    The trainer is also exposed as `self.gan_trainer`.
    """
    def __init__(self, data:DataBunch, generator:nn.Module, critic:nn.Module, gen_loss_func:LossFunction,
                 crit_loss_func:LossFunction, n_crit=None, n_gen=None, switcher:Callback=None, gen_first:bool=False, switch_eval:bool=True,
                 show_img:bool=True, clip:float=None, **learn_kwargs):
        gan = GANModule(generator, critic)
        loss_func = GANLoss(gen_loss_func, crit_loss_func, gan)
        switcher = ifnone(switcher, partial(FixedGANSwitcher, n_crit=n_crit, n_gen=n_gen))
        super().__init__(data, gan, loss_func=loss_func, callback_fns=[switcher], **learn_kwargs)
        # `gen_first` has to reach the trainer: it decides which network is trained first.
        trainer = GANTrainer(self, switch_eval=switch_eval, clip=clip, gen_first=gen_first, show_img=show_img)
        self.gan_trainer = trainer
        self.callbacks.append(trainer)

In [None]:
#export
class GANModule(nn.Module):
    """Wrapper around a `generator` and a `critic` to create a GAN.

    `forward` dispatches to `self.generator` while `gen_mode` is True and to
    `self.critic` otherwise. When `generator` is None neither attribute is assigned,
    so subclasses can provide `generator`/`critic` as methods instead.
    """
    def __init__(self, generator:nn.Module=None, critic:nn.Module=None, gen_mode:bool=True):
        super().__init__()
        self.gen_mode = gen_mode
        # Compare against None rather than relying on truthiness: container modules
        # (e.g. an empty `nn.Sequential`) define `__len__` and would evaluate as False.
        if generator is not None: self.generator,self.critic = generator,critic

    def forward(self, *args):
        "Run the generator on `args` in generator mode, the critic otherwise."
        return self.generator(*args) if self.gen_mode else self.critic(*args)

    def switch(self, gen_mode:bool=None):
        "Put the model in generator mode if `gen_mode`, in critic mode otherwise; toggle when `gen_mode` is None."
        self.gen_mode = (not self.gen_mode) if gen_mode is None else gen_mode

In [None]:
#export
class GANLoss(GANModule):
    "Holds `loss_funcG` (generator loss) and `loss_funcC` (critic loss) together with the `gan_model` they evaluate."
    def __init__(self, loss_funcG:Callable, loss_funcC:Callable, gan_model:GANModule):
        super().__init__()
        self.loss_funcG = loss_funcG
        self.loss_funcC = loss_funcC
        self.gan_model = gan_model

    def generator(self, output, x_a, x_b):
        "Score the generator `output` with the critic, re-run the generator on the translations, then apply `self.loss_funcG`."
        # `output` is cut into chunks of two along dim 0; only the first two chunks are used.
        chunks = torch.split(output, 2, dim=0)
        x_a_recon, x_b_recon = torch.split(chunks[0], 1, dim=0)
        x_ab, x_ba = torch.split(chunks[1], 1, dim=0)

        # The predictions on the reconstructions are computed but not handed to the loss.
        _pred_aa, _pred_bb = self.gan_model.critic(x_a_recon, x_b_recon)
        pred_ab, pred_ba = self.gan_model.critic(x_ab, x_ba)

        # Feed the translations back through the generator (note the swapped argument order)
        # and pick entries 3 and 2 of the result as the cycled a / b outputs.
        second_pass = self.gan_model.generator(x_ba, x_ab)
        cycle_a, cycle_b = second_pass[3], second_pass[2]
        return self.loss_funcG(x_a, x_b, x_a_recon, x_b_recon, cycle_a, cycle_b, pred_ab, pred_ba)

    def critic(self, real_pred, b, c):
        "Generate fakes from `b` and `c`, score them with the critic, then apply `self.loss_funcC` with `real_pred`."
        generated = self.gan_model.generator(b.requires_grad_(False), c.requires_grad_(False))
        generated = generated.requires_grad_(True)
        chunks = torch.split(generated, 2, dim=0)
        first_pair = torch.split(chunks[0], 1, dim=0)
        second_pair = torch.split(chunks[1], 1, dim=0)
        pred_a_to_a, pred_b_to_b = self.gan_model.critic(first_pair[0], first_pair[1])
        pred_a_to_b, pred_b_to_a = self.gan_model.critic(second_pair[0], second_pair[1])
        return self.loss_funcC(real_pred[0], real_pred[1], pred_a_to_a, pred_b_to_b, pred_a_to_b, pred_b_to_a)

In [None]:
#export
class GANTrainer(LearnerCallback):
    """Handles GAN Training.

    Keeps one optimizer per network, alternates which network is trainable via `switch`,
    records smoothed generator/critic losses and (optionally) shows sample outputs each epoch.

    - `switch_eval`: when True, the network that is not being trained is put in eval mode.
    - `clip`: if not None, critic weights are clamped to `[-clip, clip]` before every batch.
    - `beta`: smoothing factor given to the two `SmoothenValue` loss trackers.
    - `gen_first`: initial value of `gen_mode` when training begins.
    - `show_img`: whether `on_epoch_end` pushes sample images to the progress bar.
    """
    _order=-20
    def __init__(self, learn:Learner, switch_eval:bool=False, clip:float=None, beta:float=0.98, gen_first:bool=False,
                 show_img:bool=True):
        super().__init__(learn)
        self.switch_eval,self.clip,self.beta,self.gen_first,self.show_img = switch_eval,clip,beta,gen_first,show_img
        self.generator,self.critic = self.model.generator,self.model.critic

    def _set_trainable(self):
        "Enable gradients on the network being trained and disable them on the other one."
        train_model = self.generator if     self.gen_mode else self.critic
        loss_model  = self.generator if not self.gen_mode else self.critic
        requires_grad(train_model, True)
        requires_grad(loss_model, False)
        if self.switch_eval:
            train_model.train()
            loss_model.eval()

    def on_train_begin(self, **kwargs):
        "Create the optimizers for the generator and critic if necessary, initialize smootheners."
        if not getattr(self,'opt_gen',None):
            self.opt_gen = self.opt.new([nn.Sequential(*flatten_model(self.generator))])
        else: self.opt_gen.lr,self.opt_gen.wd = self.opt.lr,self.opt.wd
        if not getattr(self,'opt_critic',None):
            self.opt_critic = self.opt.new([nn.Sequential(*flatten_model(self.critic))])
        else: self.opt_critic.lr,self.opt_critic.wd = self.opt.lr,self.opt.wd
        self.gen_mode = self.gen_first
        self.switch(self.gen_mode)
        self.closses,self.glosses = [],[]
        self.smoothenerG,self.smoothenerC = SmoothenValue(self.beta),SmoothenValue(self.beta)
        self.recorder.add_metric_names(['gen_loss', 'disc_loss'])
        self.imgs,self.titles = [],[]

    def on_train_end(self, **kwargs):
        "Switch in generator mode for showing results."
        self.switch(gen_mode=True)

    def on_batch_begin(self, last_input, last_target, **kwargs):
        "Clamp the critic weights with `self.clip` if it's not None, return the batch with the input reused as target."
        if self.gen_mode:
            self.last_input = last_input

        if self.clip is not None:
            for p in self.critic.parameters(): p.data.clamp_(-self.clip, self.clip)
        # The incoming `last_target` is discarded: the input batch is handed on as the target too.
        return {'last_input':last_input,'last_target':last_input}

    def on_backward_begin(self, last_loss, last_output, **kwargs):
        "Record `last_loss` in the proper list; in generator mode also keep the output and the critic's predictions on it."
        last_loss = last_loss.detach().cpu()
        if self.gen_mode:
            self.smoothenerG.add_value(last_loss)
            self.glosses.append(self.smoothenerG.smooth)
            gen_out = last_output.detach()
            self.last_gen = gen_out.cpu()
            # Score the output on the device it already lives on: a hard-coded `.cuda()`
            # fails on CPU-only runs and forces a needless device round trip on GPU.
            gen_split = torch.split(gen_out, 1, 0)
            self.last_critic_preds_ns = self.critic(gen_split[0], gen_split[1])
            self.last_critic_preds_s = self.critic(gen_split[2], gen_split[3])
        else:
            self.smoothenerC.add_value(last_loss)
            self.closses.append(self.smoothenerC.smooth)

    def on_epoch_begin(self, epoch, **kwargs):
        "Put the critic or the generator back to eval if necessary."
        self.switch(self.gen_mode)

    def on_epoch_end(self, pbar, epoch, last_metrics, **kwargs):
        "Show sample images when requested and always report the smoothed generator/critic losses."
        if self.show_img and hasattr(self, 'last_gen'):
            # NOTE(review): `im` is not imported explicitly in this notebook; presumably it
            # arrives through one of the star imports - confirm.
            # Each stored output is rescaled with x/2+0.5 before being wrapped for display.
            panels = [('A to A', self.last_gen[0]), ('A to B', self.last_gen[2]),
                      ('B to B', self.last_gen[1]), ('B to A', self.last_gen[3])]
            for label,tensor in panels:
                self.imgs.append(im.Image(tensor/2+0.5))
                self.titles.append(f'Epoch {epoch}-{label}')
            pbar.show_imgs(self.imgs, self.titles)
        # The metric names are registered unconditionally in `on_train_begin`, so the values
        # must be returned even when no image is shown.
        return add_metrics(last_metrics, [getattr(self.smoothenerG,'smooth',None),getattr(self.smoothenerC,'smooth',None)])

    def switch(self, gen_mode:bool=None):
        "Switch the model, if `gen_mode` is provided, in the desired mode; toggle otherwise."
        self.gen_mode = (not self.gen_mode) if gen_mode is None else gen_mode
        self.opt.opt = self.opt_gen.opt if self.gen_mode else self.opt_critic.opt
        self._set_trainable()
        # Forward the resolved mode so the model and the loss cannot drift out of step
        # with the trainer when `gen_mode` is None.
        self.model.switch(self.gen_mode)
        self.loss_func.switch(self.gen_mode)

In [None]:
#export
class FixedGANSwitcher(LearnerCallback):
    """Switcher to do `n_crit` iterations of the critic then `n_gen` iterations of the generator.

    `n_crit` and `n_gen` may each be an int, or a callable that receives the other
    phase's iteration counter and returns the target count. `None` means 1.
    """
    def __init__(self, learn:Learner, n_crit=5, n_gen=1):
        super().__init__(learn)
        # Honour the requested counts instead of forcing both to 1.
        # `None` (what `UNET_UNIT` passes by default) keeps the 1:1 schedule.
        self.n_crit = 1 if n_crit is None else n_crit
        self.n_gen = 1 if n_gen is None else n_gen

    def on_train_begin(self, **kwargs):
        "Initiate the iteration counts."
        self.n_c,self.n_g = 0,0

    def on_batch_end(self, iteration, **kwargs):
        "Count the finished batch for the active phase and switch the model once its target is reached."
        if self.learn.gan_trainer.gen_mode:
            self.n_g += 1
            n_iter,n_in,n_out = self.n_gen,self.n_c,self.n_g
        else:
            self.n_c += 1
            n_iter,n_in,n_out = self.n_crit,self.n_g,self.n_c
        # A non-int target is treated as a callable of the other phase's counter.
        target = n_iter if isinstance(n_iter, int) else n_iter(n_in)
        if target == n_out:
            self.learn.gan_trainer.switch()
            self.n_c,self.n_g = 0,0