In [None]:
# Mount Google Drive: the image folders read further below live under /content/drive/My Drive.
from google.colab import drive       
drive.mount('/content/drive')         
# Alternative when Drive is already mounted and a fresh mount is needed:
# drive.mount("/content/drive", force_remount=True)   

In [None]:
!pip install fastai==2.1.8                                  

In [None]:
import fastai                    
#export                             
from fastai.basics import *        
from fastai.callback.core import *          
from fastai.vision.all import *         
fastai.__version__                                

'2.1.8'

## Wrapping the modules

### GANModule (second version, with feature-extraction hooks)

In [None]:
#export
class GANModule(Module):
    """Wrapper around a `generator` and a `critic` to create a GAN.

    `forward` dispatches to the generator when `gen_mode` is True and to the critic
    otherwise. This version can also register forward hooks that collect
    intermediate activations into three lists (callers are responsible for
    resetting them):

    * `input`         -- input tuples of child 0 of each hooked block (`SaveFeatures`);
    * `features`      -- outputs of child 5 of each hooked block (`SaveFeatures`);
    * `features_crit` -- outputs of children 0, 2, 4 and 6 of the critic (`SaveCriticFeatures`).

    The handles of all registered hooks are kept in `hook_handles`, so the hooks
    can be detached again with `remove_hooks()`.
    """
    def __init__(self, generator=None, critic=None, gen_mode=False):
        if generator is not None: self.generator=generator
        if critic    is not None: self.critic   =critic
        self.features = []
        self.input = []
        self.features_crit = []
        # Without the handles the hooks could never be removed: every new instance
        # hooking the same modules (e.g. re-creating the Learner in the same kernel)
        # would leave behind hooks that keep appending activations forever.
        self.hook_handles = []
        store_attr('gen_mode')

    def SaveFeatures(self,unet):
        """Register hooks on `unet`.

        For every sub-module obtained by iterating each child of `unet`, record the
        input of its child 0 (into `input`) and the output of its child 5 (into
        `features`).
        """
        for layer in unet.children():
            for block in layer:
                for j,sub in enumerate(block.children()):
                    if j==0:
                        self.hook_handles.append(sub.register_forward_hook(self.hook_fn1))
                    if j==5:
                        self.hook_handles.append(sub.register_forward_hook(self.hook_fn))

    def hook_fn(self,m,i,o):
        "Forward hook: append the module output `o` to `features`."
        self.features.append(o)

    def hook_fn1(self,m,i,o):
        "Forward hook: append the module's input tuple `i` to `input`."
        self.input.append(i)

    def SaveCriticFeatures(self,unet):
        "Register hooks recording the outputs of children 0, 2, 4 and 6 of `unet` into `features_crit`."
        for j,layer1 in enumerate(unet.children()):
            if j in (0,2,4,6):
                self.hook_handles.append(layer1.register_forward_hook(self.hook_fn_crit))

    def hook_fn_crit(self,m,i,o):
        "Forward hook: append the critic sub-module output `o` to `features_crit`."
        self.features_crit.append(o)

    def remove_hooks(self):
        "Detach every hook registered through this instance."
        for handle in self.hook_handles: handle.remove()
        self.hook_handles = []

    def forward(self, *args):
        "Run the generator on `args` in generator mode, the critic otherwise."
        return self.generator(*args) if self.gen_mode else self.critic(*args)

    def switch(self, gen_mode=None):
        "Put the module in generator mode if `gen_mode`, in critic mode otherwise."
        self.gen_mode = (not self.gen_mode) if gen_mode is None else gen_mode

## Creating the generator  

### Generator building blocks (code version 6)

In [None]:
class deconv2d(nn.Module):
    """Upsample block suggested by [2] to remove checkerboard pattern from images.

    Nearest-neighbour upsampling by `scale_factor_`, reflection padding, then a
    stride-1 convolution, so the output spatial size is exactly `scale_factor_`
    times the input size.

    Args:
        input_: number of input channels.
        output_dim: number of output channels.
        ks: kernel size of the convolution (default 3).
        scale_factor_: nearest-neighbour upsampling factor.
        use_cuda: move the sub-modules to CUDA device 0 at construction time.
    """
    def __init__(self, input_, output_dim,ks=3,scale_factor_=2,use_cuda=False):
        super(deconv2d, self).__init__()
        self.up1 = nn.UpsamplingNearest2d(scale_factor=scale_factor_)
        self.c2 = nn.Conv2d(input_, output_dim, kernel_size=ks, stride=1, padding=0)
        # The conv consumes `ks - 1` pixels per dimension; give them back with reflection
        # padding (the extra pixel goes right/bottom for even kernels). For ks=3 this is (1, 1, 1, 1).
        pad_lo = (ks - 1) // 2
        pad_hi = ks - 1 - pad_lo
        self.pad_sizes = (pad_lo, pad_hi, pad_lo, pad_hi)
        if use_cuda:
            # `Module.cuda` takes the device as `device`; the old `device_id=` keyword raised TypeError.
            self.up1 = self.up1.cuda(0)
            self.c2 = self.c2.cuda(0)

    def forward(self, x):
        h = self.up1(x)
        h = F.pad(h, self.pad_sizes, mode='reflect')
        return self.c2(h)


class res_block(nn.Module):
    """Residual block with reflection padding: ``x + B(pad(relu(B(pad(x)))))``.

    `B` is `self.layers`: a bias-free `ks`x`ks` conv with stride `s` followed by an
    affine InstanceNorm. The *same* `B` is applied twice, so both convolutions share
    weights; this requires `in_dim == out_dim` (and `s == 1` for the residual sum
    to line up).

    NOTE(review): standard residual blocks use two separate convolutions; the
    sharing is kept because changing it would alter the architecture and any
    saved checkpoint.

    No `.cuda()` is called here: device placement is left to the caller (e.g.
    `.cuda()` on the parent model), so the block also runs on CPU.
    """
    def __init__(self,in_dim, out_dim, ks=3, s=1):
        super().__init__()
        self.in_dim,self.out_dim,self.ks,self.s = in_dim,out_dim,ks,s
        # Per-side reflection padding; keeps the spatial size for odd `ks` with stride 1.
        self.p = int((self.ks - 1) / 2)
        self.layers=nn.Sequential(
            nn.Conv2d(self.in_dim,self.out_dim,kernel_size=self.ks,stride=self.s,padding=0,bias=False),
            nn.InstanceNorm2d(num_features=out_dim,eps = 0.00001,affine=True,momentum=0.05))

    def forward(self,x):
        pad = (self.p,) * 4
        y = F.relu(self.layers(F.pad(x, pad, mode='reflect')))
        y = self.layers(F.pad(y, pad, mode='reflect'))
        return x+y

In [None]:
import torch.nn.functional as F

# Layer configurations consumed by `Decoder` (despite the `discriminator_` prefix).
# Spatial sizes in the trailing comments assume a 256x256 input image.

# Upsampling half of `Decoder`: l1-l2 build `layers_`, l3-l4 build the skip stages
# `layers_3` / `layers_4` (their input is a skip feature concatenated to x, hence 2*ci).
discriminator_cfg_p16_D = dict(
    l_num=4, linear=256,
    l0_c=3,
    l1_ci=256, l1_cf=256, l1_k=3, l1_s=2,   # 16 -> 32
    l2_ci=256, l2_cf=128, l2_k=3, l2_s=2,   # 32 -> 64
    l3_ci=128, l3_cf=64,  l3_k=3, l3_s=2,   # 64 -> 128
    l4_ci=64,  l4_cf=32,  l4_k=3, l4_s=2,   # 128 -> 256
)

supported_patch_size_D = {16: discriminator_cfg_p16_D}

# Encoder half of `Decoder`, applied after a 15-pixel reflection pad (256 -> 286).
# l1 builds `layers`; l2..l5 build `layers32`, `layers64`, `layers128`, `layers256`.
discriminator_cfg_p16 = dict(
    l_num=5, linear=256,
    l0_c=3,
    l1_ci=3,   l1_cf=32,  l1_k=3, l1_s=1,   # 286 -> 284
    l2_ci=32,  l2_cf=32,  l2_k=3, l2_s=2,   # 284 -> 141
    l3_ci=32,  l3_cf=64,  l3_k=3, l3_s=2,   # 141 -> 70
    l4_ci=64,  l4_cf=128, l4_k=3, l4_s=2,   # 70 -> 34
    l5_ci=128, l5_cf=256, l5_k=3, l5_s=2,   # 34 -> 16
)

supported_patch_size = {16: discriminator_cfg_p16}


class Decoder(nn.Module):
    """Encoder / residual bottleneck / decoder network used as the body of `Generator`.

    Layout (channel counts from `supported_patch_size[16]` and `supported_patch_size_D[16]`):

    * `norm_`: InstanceNorm of the 3-channel input, then a 15-pixel reflection pad.
    * Encoder: `layers` (l1, 3->32, stride 1), `layers32` (l2), `layers64` (l3),
      `layers128` (l4), `layers256` (l5, ->256); each is conv + InstanceNorm + ReLU.
    * `res_`: 9 `res_block`s on 256 channels.
    * `layers_`: two upsampling `deconv2d` stages (256->256->128).
    * `layers_3` / `layers_4`: skip stages. Each receives the concatenation of a
      *detached* encoder feature map (resized with `Interpolate`) and the current
      decoder output, squeezes it with a 1x1 conv and upsamples with `deconv2d`.
    * `last_conv_`: 7x7 conv to 3 channels after concatenating the detached `layers32`
      features. `first_layer_cat` fuses the result with the (detached) input image,
      and a sigmoid rescaled to [-1, 1] gives the output.

    NOTE: the order in which sub-modules are created and assigned below is
    significant. It fixes initialisation order and parameter order, and
    `GANModule.SaveFeatures` hooks this module's children by position
    (child 0 = `norm_`, child 5 = `layers256`).

    Device placement is left to the caller (`Generator` moves the whole model to CUDA).
    """
    def __init__(self):
        super(Decoder, self).__init__()
        patch_size = 16
        norm_kw = dict(eps=0.00001, affine=True, momentum=0.05)

        # ---- Encoder -----------------------------------------------------
        cfg = supported_patch_size[patch_size]
        layers, layers32, layers64, layers128, layers256 = [], [], [], [], []
        # l1 goes to `layers`; l2..l5 get their own stage so their outputs can be reused as skips.
        stage_for = {2: layers32, 3: layers64, 4: layers128, 5: layers256}
        self.norm_ = nn.InstanceNorm2d(num_features=3, **norm_kw)
        for l in range(1, cfg['l_num'] + 1):
            stage_for.get(l, layers).extend([
                nn.Conv2d(cfg[f'l{l}_ci'], cfg[f'l{l}_cf'], kernel_size=cfg[f'l{l}_k'],
                          stride=cfg[f'l{l}_s'], padding=0, bias=False),
                nn.InstanceNorm2d(num_features=cfg[f'l{l}_cf'], **norm_kw),
                nn.ReLU(),
            ])
        self.layers = nn.Sequential(*layers)
        self.layers32 = nn.Sequential(*layers32)
        self.layers64 = nn.Sequential(*layers64)
        self.layers128 = nn.Sequential(*layers128)
        self.layers256 = nn.Sequential(*layers256)

        # ---- Decoder -----------------------------------------------------
        cfg = supported_patch_size_D[patch_size]
        layers1, layers3, layers4 = [], [], []
        skip_stage_for = {3: layers3, 4: layers4}
        for l in range(1, cfg['l_num'] + 1):
            ci, cf, k = cfg[f'l{l}_ci'], cfg[f'l{l}_cf'], cfg[f'l{l}_k']
            if l >= 3:
                # Skip stage (concat starts from 128 and goes to 32): the 1x1 conv
                # squeezes the concatenated [skip, x] (2*ci channels) back to ci.
                skip_stage_for[l].extend([
                    nn.Conv2d(2 * ci, ci, kernel_size=1, stride=1, padding=0, bias=False),
                    deconv2d(ci, cf, ks=k),
                    nn.InstanceNorm2d(num_features=cf, **norm_kw),
                    nn.ReLU(),
                ])
            else:
                layers1.extend([
                    deconv2d(ci, cf, ks=k, scale_factor_=cfg[f'l{l}_s']),
                    nn.InstanceNorm2d(num_features=cf, **norm_kw),
                    nn.ReLU(),
                ])
        # Keep this order: the residual blocks are created after the decoder stages
        # above, and the attribute assignment order below fixes the children order.
        self.res_ = nn.Sequential(*[res_block(in_dim=256, out_dim=256) for _ in range(9)])
        self.layers_ = nn.Sequential(*layers1)
        self.last_conv_ = nn.Conv2d(64, 3, kernel_size=7, stride=1, padding=0, bias=False)
        self.first_layer_cat = nn.Conv2d(6, 3, kernel_size=1, stride=1, padding=0, bias=False)
        self.layers_3 = nn.Sequential(*layers3)
        self.layers_4 = nn.Sequential(*layers4)

    def Interpolate(self, x,y):
        "Resize `x` (nearest neighbour) to the spatial size of `y` when they differ."
        if y.shape[-2:] != x.shape[-2:]:
            x = F.interpolate(x, y.shape[-2:], mode='nearest')
        return x

    def _cat_skip(self, skip, x):
        "Concatenate the detached `skip` (resized to `x`) in front of `x` along the channel axis."
        return torch.cat([self.Interpolate(skip.detach(), x), x], dim=1)

    def forward(self, tensor_orig):
        # Encoder: normalise, reflection-pad by 15 pixels, then the conv stages.
        x = F.pad(self.norm_(tensor_orig), (15, 15, 15, 15), mode='reflect')
        x = self.layers(x)
        x32 = self.layers32(x)
        x64 = self.layers64(x32)
        x128 = self.layers128(x64)
        x = self.layers256(x128)
        # Residual bottleneck, then the two plain upsampling stages.
        x = self.layers_(self.res_(x))
        # Skip stages. The encoder features are detached, so no gradient reaches the
        # encoder through the skip connections.
        x = self.layers_3(self._cat_skip(x128, x))
        x = self.layers_4(self._cat_skip(x64, x))
        x = self._cat_skip(x32, x)
        x = self.last_conv_(F.pad(x, (3, 3, 3, 3), mode='reflect'))
        # Fuse with the raw input image, then map to [-1, 1]. `torch.sigmoid` avoids
        # building (and moving to CUDA) a new `nn.Sigmoid` module on every call.
        x = self.first_layer_cat(torch.cat([tensor_orig.detach(), x], dim=1))
        return torch.sigmoid(x) * 2 - 1

In [None]:
class Generator(nn.Module):
  """Style-transfer generator: a single `Decoder` wrapped in `nn.Sequential`.

  `forward` returns the generated image. It then feeds that image through the same
  network a second time and discards the result. That pass is not useless: forward
  hooks on the decoder's sub-modules (`GANModule.SaveFeatures`) record activations on
  every call, and `GANLoss.generator` compares the features recorded for the input
  (first pass) with those recorded for the generated image (second pass).
  """
  def __init__(self):
    super().__init__()
    self.gen = nn.Sequential(Decoder()).cuda()

  def forward(self, x):
    generated = self.gen(x)
    self.gen(generated)  # hook-only pass, see the class docstring
    return generated


# Quick shape check on the GPU with a batch of ten 3x256x256 images.
Gen = Generator().cuda()
x = (torch.randn(10,3,256,256)).cuda()
Gen(x).shape

torch.Size([10, 3, 256, 256])

## Creating the critic

### Paper Critic

In [None]:
class discriminator_(nn.Module):
    def __init__(self,Spectral=True):
        super(discriminator_, self).__init__()
        
        self.Spectral = Spectral

        discriminator_cfg_p16 = {
            'l_num': 5,'linear':256,
            'l0_c': 3,                           # 256
            'l1_ci': 3,'l1_cf': 128,   'l1_k': 5, 'l1_s': 2, # 64
            'l2_ci': 128,'l2_cf': 256,   'l2_k': 5, 'l2_s': 2, # 64
            'l3_ci': 256,'l3_cf': 512,   'l3_k': 5, 'l3_s': 2, # 64
            'l4_ci': 512,'l4_cf': 1024,   'l4_k': 5, 'l4_s': 2, # 16
            'l5_ci': 512,'l5_cf': 512,   'l5_k': 5, 'l5_s': 2, # 16
            'l6_ci': 512,'l6_cf': 1024,   'l6_k': 5, 'l6_s': 2, # 16
            'l7_ci': 1024,'l7_cf': 1024,   'l7_k': 5, 'l7_s': 2, # 16
        }

        supported_patch_size = {
            16: discriminator_cfg_p16
        }

        patch_size = 16
        cfg = supported_patch_size[patch_size]
        
        l = 1
        if not self.Spectral:   
          self.h0 = nn.Sequential(*[nn.Conv2d(cfg[f'l1_ci'],cfg[f'l1_cf'],kernel_size=cfg[f'l1_k'],stride=cfg[f'l1_s'],padding=0,bias=False),
                              nn.InstanceNorm2d(num_features=cfg[f'l{l}_cf'],eps = 0.00001,affine=True,momentum=0.05),
                              nn.LeakyReLU(0.2)])
        else: 
          self.h0 = nn.Sequential(*[nn.utils.spectral_norm(nn.Conv2d(cfg[f'l1_ci'],cfg[f'l1_cf'],kernel_size=cfg[f'l1_k'],stride=cfg[f'l1_s'],padding=0,bias=False)),
                              nn.LeakyReLU(0.2)])
        self.h0_pred = nn.Sequential(*[nn.Conv2d(cfg[f'l{l}_cf'],1,kernel_size=3,stride=1,padding=0,bias=False)])

        l = 2
        if not self.Spectral:
          self.h1 = nn.Sequential(*[nn.Conv2d(cfg[f'l{l}_ci'],cfg[f'l{l}_cf'],kernel_size=cfg[f'l{l}_k'],stride=cfg[f'l{l}_s'],padding=0,bias=False),
                              nn.InstanceNorm2d(num_features=cfg[f'l{l}_cf'],eps = 0.00001,affine=True,momentum=0.05),
                              nn.LeakyReLU(0.2)])
        else:
          self.h1 = nn.Sequential(*[nn.utils.spectral_norm(nn.Conv2d(cfg[f'l2_ci'],cfg[f'l2_cf'],kernel_size=cfg[f'l2_k'],stride=cfg[f'l2_s'],padding=0,bias=False)),
                              nn.LeakyReLU(0.2)])
        self.h1_pred = nn.Sequential(*[nn.Conv2d(cfg[f'l{l}_cf'],1,kernel_size=5,stride=1,padding=0,bias=False)])

        l = 3
        if not self.Spectral:
          self.h2 = nn.Sequential(*[nn.Conv2d(cfg[f'l{l}_ci'],cfg[f'l{l}_cf'],kernel_size=cfg[f'l{l}_k'],stride=cfg[f'l{l}_s'],padding=0,bias=False),
                              nn.InstanceNorm2d(num_features=cfg[f'l{l}_cf'],eps = 0.00001,affine=True,momentum=0.05),
                              nn.LeakyReLU(0.2)])
        else:
          self.h2 = nn.Sequential(*[nn.utils.spectral_norm(nn.Conv2d(cfg[f'l3_ci'],cfg[f'l3_cf'],kernel_size=cfg[f'l3_k'],stride=cfg[f'l3_s'],padding=0,bias=False)),
                              nn.LeakyReLU(0.2)])
        self.h2_pred = nn.Sequential(*[nn.Conv2d(cfg[f'l{l}_cf'],1,kernel_size=5,stride=1,padding=0,bias=False)])

        l = 4
        if not self.Spectral:
          self.h3 = nn.Sequential(*[nn.Conv2d(cfg[f'l{l}_ci'],cfg[f'l{l}_cf'],kernel_size=cfg[f'l{l}_k'],stride=cfg[f'l{l}_s'],padding=0,bias=False),
                              nn.InstanceNorm2d(num_features=cfg[f'l{l}_cf'],eps = 0.00001,affine=True,momentum=0.05),
                              nn.LeakyReLU(0.2)])
        else:
          self.h3 = nn.Sequential(*[nn.utils.spectral_norm(nn.Conv2d(cfg[f'l4_ci'],cfg[f'l4_cf'],kernel_size=cfg[f'l4_k'],stride=cfg[f'l4_s'],padding=0,bias=False)),
                              nn.LeakyReLU(0.2)])
        self.h3_pred = nn.Sequential(*[nn.Conv2d(cfg[f'l{l}_cf'],1,kernel_size=2,stride=1,padding=0,bias=False)])

        self.avg = nn.AdaptiveAvgPool2d((1,1))    
        self.lrelu = nn.LeakyReLU(0.2)
    
    def shape_(self,x,s,k):
        shape = x.shape[2]
        shape = int((((s*(shape-1))+k)-shape)/2)
        # shape = int((shape+2)/2)
        return shape 

    def forward(self,x):
        x_h0 = self.h0(x)
        shape = self.shape_(x_h0,1,3)
        x_h0_ = F.pad(x_h0,(shape,shape,shape,shape),mode='reflect')
        h0_pred = self.h0_pred(x_h0_)
        # h0_pred = self.lrelu(h0_pred)
        h0_pred = self.avg(h0_pred)

        x_h1 = self.h1(x_h0)
        shape = self.shape_(x_h1,1,5)
        x_h1_ = F.pad(x_h1,(shape,shape,shape,shape),mode='reflect')
        h1_pred = self.h1_pred(x_h1_)
        # h1_pred = self.lrelu(h1_pred)
        h1_pred = self.avg(h1_pred)

        x_h2 = self.h2(x_h1)
        shape = self.shape_(x_h2,1,5) 
        x_h2_ = F.pad(x_h2,(shape,shape,shape,shape),mode='reflect')
        h2_pred = self.h2_pred(x_h2_)
        # h2_pred = self.lrelu(h2_pred)
        h2_pred = self.avg(h2_pred)
        
        x_h3 = self.h3(x_h2)
        shape = self.shape_(x_h3,1,2) 
        x_h3_ = F.pad(x_h3,(shape,shape,shape,shape),mode='reflect')
        h3_pred = self.h3_pred(x_h3_)
        # h3_pred = self.lrelu(h3_pred)
        final = self.avg(h3_pred)
        
        return h0_pred,h1_pred,h2_pred,final

## Instantiating the generator and critic

In [None]:
from fastai.vision.models import resnet18
import torch.nn as nn
# `generator` is the `Gen` instance built (and shape-checked) above. The critic uses
# InstanceNorm instead of spectral normalisation. No resnet18 body is built here:
# nothing in the notebook used it.
generator = Gen
critic = discriminator_(Spectral=False)
critic

In [None]:
# Two independent but identical blur pools (10x10 mean, stride 1, padding 4, so the
# output is 1 pixel smaller per spatial dim). `GANLoss.generator` applies them to the
# generator input and to the generated image before the content loss.
res_input, res_output = (
    nn.AvgPool2d(10, stride=1, padding=4, count_include_pad=False).cuda()
    for _ in range(2)
)

## Loss function

In [None]:
import torch.nn as nn
import torch.nn.functional as F

#export
class GANLoss(GANModule):
    """Wrapper around `crit_loss_func` and `gen_loss_func`, plus style/content terms.

    At construction, forward hooks are registered on `gan_model.generator`
    (`SaveFeatures`) and `gan_model.critic` (`SaveCriticFeatures`). The activations
    of every forward pass therefore accumulate in `self.input`, `self.features` and
    `self.features_crit`. The two loss modes `generator` and `critic` (selected with
    `switch`) consume those buffers and reset them before returning.
    """
    def __init__(self, gen_loss_func, crit_loss_func, gan_model):
        super().__init__()
        # Hook the models this loss actually evaluates, not the notebook globals
        # `generator` / `critic`, which need not be the same objects.
        self.SaveFeatures(gan_model.generator)
        self.SaveCriticFeatures(gan_model.critic)
        store_attr('gen_loss_func,crit_loss_func,gan_model')

    def gram(self,input_1):
        "Per-sample Gram matrices (b, c, c) of `input_1` (b, c, h, w), divided by b*c*h*w and scaled by 1e6."
        b,c,h,w = input_1.size()
        x = input_1.view(b,c, -1)
        return torch.bmm(x, x.transpose(1,2))/(b*c*h*w)*1e6

    def gram_loss(self,input_, target):
        """MSE between the Gram matrices of `input_` and `target`.

        If the batch sizes differ, both are truncated to the smaller one, so the
        compared tensors always have matching shapes. Each Gram matrix is computed once.
        """
        g_input, g_target = self.gram(input_), self.gram(target)
        n = min(g_input.shape[0], g_target.shape[0])
        return F.mse_loss(g_input[:n], g_target[:n])

    def generator(self, output, target):
        """Generator loss for one batch: evaluate `output` with the critic, then combine terms.

        `output` is the generated batch and `target` the `y` batch. The weighted terms
        (each printed on every call) are:

        * `style`   * `gen_loss_func(fake_pred, output, target)`;
        * `content` * MSE(blurred generator input, blurred `output`);
        * `aware`   * MSE between the first two feature maps recorded by the generator hooks;
        * `gatey`   * Gram loss between critic features of `target` and `output`
          (all stages but the first; weight currently 0).
        """
        input_list = self.input
        fake_pred = self.gan_model.critic(output)
        fake = self.features_crit
        self.features_crit = []
        # Result unused: this pass only fills `features_crit` with the features of `target`.
        self.gan_model.critic(target)
        real = self.features_crit
        main_loss = 0
        for j,(inp,targ) in enumerate(zip(real, fake)):
            if j>0:
                main_loss += self.gram_loss(inp,targ)

        style = 0.1; content = 100.0; aware = 100.0; gatey = 0.0
        gen_loss = self.gen_loss_func(fake_pred, output, target)
        # Features recorded for the input (first generator pass) and for the generated
        # image (second pass).
        x, y = self.features[0], self.features[1]
        # `input_list[0]` is the input tuple of the first hooked layer on the first pass.
        input_ = res_input(input_list[0][0])
        target_ = res_output(output)
        style_aware_loss1 = F.mse_loss(input_,target_)
        style_aware_loss = F.mse_loss(x,y)
        self.gen_loss = (style*gen_loss)+(content*style_aware_loss1)+(aware*style_aware_loss)+gatey*main_loss

        weights = [style ,content ,aware ,gatey]
        losses__ = [gen_loss,style_aware_loss1,style_aware_loss,main_loss]
        name = ['style_loss','content_loss','style_aware_loss','style_loss_perceptual']
        for w, term, label in zip(weights, losses__, name):
            # Zero-weighted terms are printed unweighted so they stay visible.
            print((1 if w==0 else w)*term, label)

        self._reset_hook_buffers()
        return self.gen_loss

    def critic(self, real_pred, input):
        "Create some `fake_pred` with the generator from `input` and compare them to `real_pred` in `self.crit_loss_func`."
        fake = self.gan_model.generator(input)
        fake_pred = self.gan_model.critic(fake)
        self.crit_loss = self.crit_loss_func(real_pred, fake_pred)
        print(self.crit_loss,'critic_loss')
        self._reset_hook_buffers()
        return self.crit_loss

    def _reset_hook_buffers(self):
        "Clear the activation buffers filled by the forward hooks, plus the bookkeeping attributes."
        self.check = 0
        self.features = []
        self.hook_layers = []
        self.input = []
        self.features_crit = []

## Callbacks for GAN training

### Main callback

In [None]:
#export 
def set_freeze_model(m, rg):
    """Toggle training of `m` by setting `requires_grad=rg` on each of its parameters.

    `rg=False` freezes the model; `rg=True` unfreezes it.
    """
    for weight in m.parameters():
        weight.requires_grad_(rg)

In [None]:
class SiameseImage(fastuple):
    "A two-element tuple whose first element (an image) is what gets displayed."
    def show(self, title,ax=None,ctx=None, **kwargs):
        "Display the first element with `show_image`, using `title` as the title."
        first, _second = self
        return show_image(first, title=title, ax=ax, ctx=ctx, **kwargs)

In [None]:
#export
prediction = []
from IPython.display import clear_output

class GANTrainer(Callback):
    """Handles GAN Training.

    Keeps the generator/critic phases in sync with the model and the loss, freezes
    the model that is not being trained, and swaps `xb`/`yb` in critic mode. Before
    each batch it sets the optimizer `lr` to `gen_lr` / `critic_lr` for the current
    phase, when that rate is given. In generator mode, when `show_img` is True, it
    displays the first generated image of each training batch.
    """
    run_after = TrainEvalCallback
    def __init__(self, switch_eval=False, clip=None, beta=0.98, gen_first=False, show_img=True,gen_lr =None ,critic_lr =None):
        store_attr('switch_eval,clip,gen_first,show_img')
        self.gen_loss,self.crit_loss = AvgSmoothLoss(beta=beta),AvgSmoothLoss(beta=beta)
        self.gen_lr_,self.critic_lr_ = gen_lr,critic_lr
        self.gen_lr_1,self.critic_lr_1 = gen_lr,critic_lr

    def _set_trainable(self):
        "Unfreeze the model of the current phase, freeze the other one (and set train/eval if `switch_eval`)."
        train_model = self.generator if     self.gen_mode else self.critic
        loss_model  = self.generator if not self.gen_mode else self.critic
        set_freeze_model(train_model, True)
        set_freeze_model(loss_model, False)
        if self.switch_eval:
            train_model.train()
            loss_model.eval()

    def before_fit(self):
        "Initialize smootheners."
        self.generator,self.critic = self.model.generator,self.model.critic
        self.gen_mode = self.gen_first
        self.switch(self.gen_mode)
        self.crit_losses,self.gen_losses = [],[]
        self.gen_loss.reset() ; self.crit_loss.reset()
        self.mbar = master_bar(list(range(self.n_epoch)))

    def before_validate(self):
        "Switch in generator mode for showing results."
        self.switch(gen_mode=True)

    def before_batch(self):
        "Clamp the weights with `self.clip` if it's not None, set the correct input/target and the phase learning rate."
        if self.training and self.clip is not None:
            for p in self.critic.parameters(): p.data.clamp_(-self.clip, self.clip)
        if not self.gen_mode:
            (self.learn.xb,self.learn.yb) = (self.yb,self.xb)
        # Per-phase learning rate. Skipped when no rate was given for the phase:
        # writing `lr=None` into the optimizer would break the next step.
        if self.gen_mode:
            lr, label = self.gen_lr_1, 'learning_rate_gen_not'
        else:
            lr, label = self.critic_lr_1, 'learning_rate_crit_not'
        if lr is not None:
            self.learn.opt.set_hyper('lr', lr)
            print(self.learn.opt.hypers[0]['lr'], label)

    def after_batch(self):
        "Record `last_loss` in the proper list (and show the first generated image in generator mode)."
        if not self.training: return
        if self.gen_mode:
            self.gen_loss.accumulate(self.learn)
            self.gen_losses.append(self.gen_loss.value)
            self.last_gen = to_detach(self.pred)
            self.learn.model.cuda()
            if self.show_img:
                img = SiameseImage(self.last_gen[0],'kjkjj')
                self.mbar.show_imgs([img], ['gjgj'])
        else:
            self.crit_loss.accumulate(self.learn)
            self.crit_losses.append(self.crit_loss.value)
            self.learn.model.cuda()

    def before_epoch(self):
        "Put the critic or the generator back to eval if necessary."
        self.switch(self.gen_mode)
        self.mbar.update(self.epoch)

    def after_epoch(self):
        "Clear the cell output accumulated during the epoch."
        clear_output(wait=True)

    def switch(self, gen_mode=None):
        "Switch the model and loss function, if `gen_mode` is provided, in the desired mode."
        self.gen_mode = (not self.gen_mode) if gen_mode is None else gen_mode
        self._set_trainable()
        self.model.switch(gen_mode)
        self.loss_func.switch(gen_mode)

### Switchers

In [None]:
# #export
class FixedGANSwitcher(Callback):
    "Switcher to do `n_crit` iterations of the critic then `n_gen` iterations of the generator."
    run_after = GANTrainer

    def __init__(self, n_crit=1, n_gen=1):
        store_attr('n_crit,n_gen')

    def before_train(self):
        # Restart both phase counters at the start of every training phase.
        self.n_c = self.n_g = 0

    def after_batch(self):
        """Switch the model if necessary.

        Counts the finished training batch for the active phase. When the count
        reaches that phase's quota, flips `GANTrainer` to the other phase and resets
        both counters. A quota is either an int or a callable that receives the
        other phase's counter.
        """
        if self.training:
            trainer = self.learn.gan_trainer
            if trainer.gen_mode:
                self.n_g += 1
                quota, other_count, own_count = self.n_gen, self.n_c, self.n_g
            else:
                self.n_c += 1
                quota, other_count, own_count = self.n_crit, self.n_g, self.n_c
            needed = quota if isinstance(quota, int) else quota(other_count)
            if needed == own_count:
                trainer.switch()
                self.n_c = self.n_g = 0

## GAN data for art

### Dataloader 

In [None]:
import numpy as np
import random

# Pool of painting indices not yet drawn in the current round (mutated in place).
list_ = []
def num_list():
    "Refill `list_` in place with the painting indices 0..28."
    list_.extend(range(0, 29))

num_list()

def output_pictures(t):
    """`get_y` for the DataBlock: return the path of a random painting (`t` is ignored).

    Paintings are drawn without replacement. Once the pool is empty it is refilled,
    so every round of 29 calls visits each painting exactly once.
    """
    chosen = random.choice(list_)
    list_.remove(chosen)
    if not list_:
        num_list()
    folder = '/content/drive/My Drive/Colab_Notebooks/Style_transfer/Style_aware_content_loss/painting_1/' # painting_1
    return f'{folder}{chosen}.jpg'

bs = 16     # batch size
size = 256  # image side length; not referenced by the DataBlock below

### Datablock

In [None]:
# Paired images: x = a content image from `get_image_files`, y = a painting path
# drawn by `output_pictures`. Light brightness / saturation jitter is applied per batch.
dblock = DataBlock(
    blocks=(ImageBlock, ImageBlock),
    get_items=get_image_files,
    get_y=output_pictures,
    splitter=RandomSplitter(),
    batch_tfms=[
        # 16 draw values, presumably one per item of a `bs`=16 batch.
        Brightness(p=0.2, draw=[0.5, 0.55, 0.6, 0.7] * 4),
        Saturation(max_lighting=0.2, p=0.2),
    ],
)

In [None]:
from pathlib import Path   
# Folder of content images; `get_image_files` collects the `x` items from here.
path = Path('/content/drive/My Drive/Colab_Notebooks/Style_transfer/Style_aware_content_loss/input_/')

In [None]:
# `path` is both the source handed to `get_items` and the base path of the DataLoaders.
dls = dblock.dataloaders(path,path=path,bs = bs,num_workers=0)     

## GAN Learner

## Loss Function

In [None]:
def _tk_mean(fake_pred, output, target): # gen_loss
  fake_pred_ = 0
  for x in fake_pred:
    target = torch.ones_like(x)
    h_ = torch.mean((torch.abs(x-target)))
    fake_pred_ += h_
  return fake_pred_
def _tk_diff(real_pred, fake_pred):     # crit_loss
  fake_pred_ = 0
  for x in fake_pred:
    target = torch.zeros_like(x)
    h_ = torch.mean((torch.abs(x-target)))
    fake_pred_ += h_
  
  real_pred_ = 0
  for y in real_pred:
    target = torch.ones_like(y)
    h_ = torch.mean((torch.abs(y-target)))
    real_pred_ += h_
  return fake_pred_+real_pred_

## Training class

In [None]:
#export
@delegates()
class GANLearner(Learner):
    """A `Learner` suitable for GANs.

    Wraps `generator` and `critic` in a `GANModule`, uses a `GANLoss` built from
    `gen_loss_func` / `crit_loss_func`, and adds a `GANTrainer` plus a switcher
    callback (default: `FixedGANSwitcher(n_crit=2, n_gen=3)`).

    `gen_lr` / `critic_lr` (default 2e-5 each) are handed to `GANTrainer`, which sets
    them as the optimizer `lr` before every batch of the corresponding phase. They
    therefore take precedence over the `lr` passed to `fit`.
    """
    def __init__(self, dls, generator, critic, gen_loss_func, crit_loss_func, switcher=None, gen_first=False,
                 switch_eval=True, show_img=True, clip=None, cbs=None, metrics=None,
                 gen_lr=2e-5, critic_lr=2e-5, **kwargs):
        gan = GANModule(generator, critic)
        loss_func = GANLoss(gen_loss_func, crit_loss_func, gan)
        if switcher is None: switcher = FixedGANSwitcher(n_crit=2, n_gen=3)
        trainer = GANTrainer(clip=clip, switch_eval=switch_eval, gen_first=gen_first,
                             show_img=show_img, gen_lr=gen_lr, critic_lr=critic_lr)
        self.mode = gan.gen_mode
        cbs = L(cbs) + L(trainer, switcher)
        metrics = L(metrics) + L(*LossMetrics('gen_loss,crit_loss'))
        super().__init__(dls, gan, loss_func=loss_func, cbs=cbs, metrics=metrics, **kwargs)

    @classmethod
    def from_learners(cls, gen_learn, crit_learn, switcher=None, weights_gen=None, **kwargs):
        "Create a GAN from `learn_gen` and `learn_crit`."
        losses = gan_loss_from_func(gen_learn.loss_func, crit_learn.loss_func, weights_gen=weights_gen)
        return cls(gen_learn.dls, gen_learn.model, crit_learn.model, *losses, switcher=switcher, **kwargs)

    @classmethod
    def wgan(cls, dls, generator, critic, switcher=None, clip=None, switch_eval=False, **kwargs):   # Clip=0.01
        "Create a GAN from `data`, `generator` and `critic` using `_tk_mean` / `_tk_diff` as losses."
        return cls(dls, generator, critic, _tk_mean, _tk_diff, switcher=switcher, clip=clip, switch_eval=switch_eval, **kwargs)

GANLearner.from_learners = delegates(to=GANLearner.__init__)(GANLearner.from_learners)
GANLearner.wgan = delegates(to=GANLearner.__init__)(GANLearner.wgan)

## Training



In [None]:
# `GANLearner.wgan` plugs in `_tk_mean` / `_tk_diff` as generator / critic losses
# (clip=None by default, so no weight clipping).
learn = GANLearner.wgan(dls, generator, critic, opt_func = Adam)
# To resume from a checkpoint saved with `learn.save(..., with_opt=True)`:
# gan_ = 'path/where/your/model/is/saved'   
# learn.load(gan_,with_opt=True)    

In [None]:
# Make sure the whole GAN module is on the GPU, then clear this cell's output.
learn.model.cuda()  
clear_output(wait=True)

In [None]:
# Log the metrics (gen_loss, crit_loss) on the training set as well as on the validation set.
learn.recorder.train_metrics=True    
learn.recorder.valid_metrics=True 

In [None]:
# NOTE: the lr argument (0.01) is not what the optimizer ends up using:
# `GANTrainer.before_batch` sets the optimizer lr to the generator / critic rates
# (2e-5 here) before every batch.
learn.fit(50, 0.01, wd=0)  

In [None]:
# Placeholder: replace with the checkpoint location before running.
# Saves the model weights together with the optimizer state.
gan_ = 'path/to/the/place/where/you/wnat/to/save/your/model'
learn.save(gan_,with_opt=True, pickle_protocol=2)                         