<a href="https://colab.research.google.com/github/CheshireCat12/Deep_learning_challenges/blob/master/PGGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import numpy as np
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision.utils as vutils
import torch.optim as optim

from torch.utils.data import DataLoader
from torch.autograd import Variable

In [0]:
# Image preprocessing: resize so the shorter side is 32 px, convert to a
# tensor in [0, 1], then map to [-1, 1] via (x - 0.5) / 0.5 -- the same
# range as a Tanh-terminated generator output.
transform = transforms.Compose([
    transforms.Resize(32),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, ), std=(0.5, )),
])

# FashionMNIST training split, downloaded into ./data on first use.
fasion_mnist = datasets.FashionMNIST(
    root="./data",
    train=True,
    transform=transform,
    download=True,
)

# Parameters
# NOTE(review): shuffle=False keeps the batch order fixed across epochs;
# GAN training usually shuffles -- confirm this is intended.
params_loader = dict(batch_size=64, shuffle=False)
train_loader = DataLoader(fasion_mnist, **params_loader)

In [0]:
# Generator model

class Generator(nn.Module):
    """DCGAN-style generator mapping a latent vector to a 128x128 image.

    Args:
        Z_dim (int): size of the latent vector. The input must have shape
            ``(batch, Z_dim, 1, 1)`` to obtain a 128x128 output.
        ngf (int): base number of generator feature maps.
        ncc (int): number of color channels of the generated image.

    The output goes through a final ``Tanh``, so its values lie in
    ``[-1, 1]``. Every hidden stage uses BatchNorm + ReLU.
    """

    def __init__(self, Z_dim, ngf, ncc):
        super().__init__()

        # ConvTranspose2d output size: (in - 1) * stride - 2 * padding + kernel_size
        # -> with kernel 4, stride 2, padding 1 every stage doubles the size.
        self.layers = nn.Sequential(
            # (Z_dim) x 1 x 1 -> (ngf*8) x 4 x 4   [(1 - 1) * 1 - 2 * 0 + 4 = 4]
            nn.ConvTranspose2d(Z_dim, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(),
            # (ngf*8) x 4 x 4 -> (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(),
            # (ngf*4) x 8 x 8 -> (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(),
            # (ngf*2) x 16 x 16 -> (ngf) x 32 x 32
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(),
            # (ngf) x 32 x 32 -> (ngf) x 64 x 64
            # Hidden stage: BatchNorm + ReLU like the others; a Tanh here would
            # squash a hidden activation into [-1, 1].
            nn.ConvTranspose2d(ngf, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(),
            # (ngf) x 64 x 64 -> (ncc) x 128 x 128, values in [-1, 1]
            nn.ConvTranspose2d(ngf, ncc, 4, 2, 1, bias=False),
            nn.Tanh(),
        )

    def forward(self, input_):
        """Generate images from latent codes of shape ``(batch, Z_dim, 1, 1)``."""
        return self.layers(input_)

In [51]:
# Discriminator model

# ref: https://github.com/nashory/pggan-pytorch/blob/master/network.py
def get_named_module(model, module_name):
    """Wrap the direct child ``module_name`` of ``model`` in a new ``nn.Sequential``.

    The child is shared, not copied: the returned container holds the very
    same module object (and therefore the same parameters) as ``model``.

    Args:
        model (nn.Module): module whose direct children are searched.
        module_name (str): name of the child to extract.

    Returns:
        nn.Sequential: a container holding the child under ``module_name``.
        If ``model`` has no such child, the container is empty, which acts
        as an identity in ``forward``.
    """
    new_model = nn.Sequential()

    child = dict(model.named_children()).get(module_name)
    if child is not None:
        # Registering the same object shares its weights; no state copy needed.
        new_model.add_module(module_name, child)

    return new_model

class ConcatenateLayers(nn.Module):
    """Apply two sub-modules to the same input and return both outputs.

    Despite its name, nothing is concatenated: ``forward`` returns the tuple
    ``(trans_layer(x), new_layer(x))``. This is the pair that ``FadeIn``
    unpacks and blends.

    Args:
        trans_layer (nn.Module): transition path (first tuple element).
        new_layer (nn.Module): newly added path (second tuple element).
    """

    def __init__(self, trans_layer, new_layer):
        super().__init__()

        self.trans_layer, self.new_layer = trans_layer, new_layer

    def forward(self, x):
        from_transition = self.trans_layer(x)
        from_new = self.new_layer(x)
        return from_transition, from_new
    
class FadeIn(nn.Module):
    """Blend two tensors with a weight ``alpha`` kept inside ``[0, 1]``.

    ``alpha`` starts at 0, where the output equals the transition tensor.
    It is moved with :meth:`update_alpha`; at 1 the output equals the new
    tensor. ``alpha`` is a plain Python float attribute, so it is not stored
    in ``state_dict``.
    """

    def __init__(self):
        super().__init__()

        self.alpha = 0.

    def update_alpha(self, delta):
        """Shift ``alpha`` by ``delta``, clamping the result to ``[0, 1]``."""
        shifted = self.alpha + delta
        self.alpha = max(0., min(shifted, 1.))

    def forward(self, x):
        """
        Args:
            x (tuple):
                val_from_trans (torch.Tensor): Tensor from the transition layer
                val_from_new (torch.Tensor): Tensor from the last added layer

        Return:
            Fade in the tensor coming from the last added layer.
            Formula: (1-alpha)*trans + alpha*new
        """
        trans, new = x
        weight_new = self.alpha
        weight_trans = 1. - weight_new
        return torch.add(weight_trans * trans, weight_new * new)


class Discriminator(nn.Module):
    """Progressively growing discriminator (PGGAN style).

    The network starts as ``rgb_layer`` (from-RGB stage) followed by
    ``init_layers``. :meth:`grow_network` prepends a higher-resolution stage
    that is faded in against the previous from-RGB path, and
    :meth:`clean_network` removes the transition path once the fade-in is done.

    Every stage built by ``_block`` with its defaults is a 3x3 conv with
    stride 2, so it halves the spatial size. The last conv of ``init_layers``
    (stride 2, no padding) maps a 4x4 map to 1x1. With ``self.resl == R`` the
    network therefore expects ``2R x 2R`` input images: 16x16 initially and
    32x32 after one growth step.

    Args:
        img_size (int): stored as ``self.img_size``; not used to build layers.
        ncc (int): number of color channels of the input images.
        init_depth (int): stored as ``self.init_depth``; not used to build layers.
        max_depth (int): channel count of the lowest-resolution feature maps.
    """

    def __init__(self, img_size, ncc=3, init_depth=64, max_depth=1024):
        super().__init__()

        self.img_size = img_size
        self.ncc = ncc
        self.init_depth = init_depth
        self.max_depth = max_depth
        # Spatial size right after the from-RGB stage; inputs are 2 * resl wide.
        self.resl = 8

        self.model = self._create_model()
        # Output channels of the from-RGB stage created by the next
        # grow_network() call; halved after every growth step (no lower bound).
        self.dim = self.max_depth // 2

    def _create_model(self):
        """Build the initial network: from-RGB stage followed by ``init_layers``."""
        model = nn.Sequential()

        rgb_layer = self._from_rgb(self.max_depth)
        init_layers = self._init_layers()

        model.add_module("rgb_layer", rgb_layer)
        model.add_module("init_layers", init_layers)

        return model

    def _block(self, depth_in, depth_out, last=False, stride=2, padding=1):
        """Return the layer list of one stage.

        A 3x3 bias-free conv, followed by BatchNorm + LeakyReLU(0.2) unless
        ``last`` is True. With the default stride/padding the spatial size is
        halved (rounded up).
        """
        block = [nn.Conv2d(in_channels=depth_in,
                           out_channels=depth_out,
                           kernel_size=(3, 3),
                           stride=stride,
                           padding=padding,
                           bias=False)]
        if not last:
            block += [nn.BatchNorm2d(depth_out),
                      nn.LeakyReLU(0.2)]

        return block

    def _init_layers(self):
        """Lowest-resolution stages: depth x 8 x 8 -> depth x 4 x 4 -> 1 x 1 x 1."""
        internal_layers = [*self._block(self.max_depth, self.max_depth),
                           # Last layer differs from the others: depth x 4 x 4 -> 1 x 1 x 1
                           # (no padding, no BatchNorm / activation -> raw score).
                           *self._block(self.max_depth, 1, True, 2, 0)]

        return nn.Sequential(*internal_layers)

    def _from_rgb(self, dim):
        """From-RGB stage: ``ncc`` -> ``dim`` channels, halving the spatial size."""
        rgb_layer = self._block(self.ncc, dim)

        return nn.Sequential(*rgb_layer)

    def grow_network(self):
        """Add a new stage on top of the discriminator and start fading it in.

        Two paths are created and their outputs blended by a ``FadeIn``
        module (``self.model.fadein``), whose ``alpha`` starts at 0:

        * transition path: 2x average pooling followed by the previous
          from-RGB stage (shared, not copied);
        * new path: a fresh from-RGB stage (``self.dim`` channels) followed by
          a new stage going from ``self.dim`` to ``2 * self.dim`` channels.

        NOTE: new parameters are created here, so an optimizer built before
        this call does not see them.

        Raises:
            RuntimeError: if the previous growth step was not finished with
                :meth:`clean_network`.
        """
        if hasattr(self.model, "concatenate_block"):
            # Growing again would look up a missing "rgb_layer" and silently
            # drop the transition path's from-RGB stage.
            raise RuntimeError("grow_network() called before clean_network() "
                               "finished the previous growth step.")

        last_resl = self.resl
        self.resl = self.resl * 2
        growing = f"{self.resl}x{self.resl}->{last_resl}x{last_resl}"

        print(f"Growing network : {growing}, can take few seconds...")

        self.layer_name = "layer_" + growing

        # TODO: find a better way to handle the channels' dim
        dim_in = self.dim
        dim_out = self.dim * 2

        hist_from_rgb = get_named_module(self.model, "rgb_layer")

        # Transition path, so the output does not change abruptly when the new
        # stage is added: downsample the input and reuse the old from-RGB stage.
        transition_from_rgb = nn.Sequential()
        transition_from_rgb.add_module("downsample_from_rgb",
                                       nn.AvgPool2d(kernel_size=2))
        transition_from_rgb.add_module("hist_from_rgb", hist_from_rgb)

        # New path on the top of the discriminator.
        new_layer = nn.Sequential()
        new_layer.add_module("new_from_rgb", self._from_rgb(dim_in))
        new_layer.add_module("new_layer", nn.Sequential(*self._block(dim_in,
                                                                  dim_out)))

        # Create the new model
        new_model = nn.Sequential()
        new_model.add_module("concatenate_block",
                             ConcatenateLayers(transition_from_rgb,
                                               new_layer))
        new_model.add_module("fadein", FadeIn())

        # Re-attach the existing lower-resolution stages. They are the same
        # module objects, so their trained weights are kept as-is.
        for name, m in self.model.named_children():
            if name != "rgb_layer":
                new_model.add_module(name, m)

        self.model = new_model
        self.dim //= 2

    def clean_network(self):
        """Once the new stage is completely faded in, remove the transition path.

        The new from-RGB stage becomes ``rgb_layer`` and the new stage is kept
        under ``self.layer_name``. The current fade-in ``alpha`` is not
        checked.

        Raises:
            RuntimeError: if there is no pending growth step to clean.
        """
        if not hasattr(self.model, "concatenate_block"):
            raise RuntimeError("clean_network() called without a pending "
                               "grow_network() step.")

        new_branch = self.model.concatenate_block.new_layer
        rgb_layer = get_named_module(new_branch, "new_from_rgb")
        faded_in_layer = get_named_module(new_branch, "new_layer")

        new_model = nn.Sequential()
        new_model.add_module("rgb_layer", rgb_layer)
        new_model.add_module(self.layer_name, faded_in_layer)

        # Keep every lower-resolution stage (shared modules, weights unchanged).
        for name, m in self.model.named_children():
            if name not in ["concatenate_block", "fadein"]:
                new_model.add_module(name, m)

        self.model = new_model

    def forward(self, input_):
        """Return the raw (unactivated) scores flattened to a 1-D tensor.

        There is one score per image when the input is ``2 * self.resl``
        pixels per side.
        """
        output = self.model(input_)

        return output.view(-1, 1).squeeze(1)
    
# Smoke test of the growing discriminator on one real batch.
# With resl == R the discriminator consumes 2R x 2R images (16x16 initially).
# Growing twice would require 64x64 inputs: the 32x32 batch would reach the
# final 3x3 / padding-0 conv as a 2x2 map and raise a RuntimeError. So grow
# only until the network matches the loader's image size.
img, _ = next(iter(train_loader))
img_size = img.shape[-1]

net_d = Discriminator(img_size, ncc=img.shape[1])
with torch.no_grad():
    while 2 * net_d.resl < img_size:
        net_d.grow_network()
        print(net_d)
        # Forward pass while the new stage is still being faded in.
        print(net_d(img).shape)
        net_d.clean_network()

    assert 2 * net_d.resl == img_size, "image size must be a power-of-two multiple of 16"
    print(net_d)
    val = net_d(img)

print(val.shape)

Growing network : 16x16->8x8, can take few seconds...
Discriminator(
  (model): Sequential(
    (concatenate_block): ConcatenateLayers(
      (trans_layer): Sequential(
        (downsample_from_rgb): AvgPool2d(kernel_size=2, stride=2, padding=0)
        (hist_from_rgb): Sequential(
          (rgb_layer): Sequential(
            (0): Conv2d(1, 1024, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
            (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): LeakyReLU(negative_slope=0.2)
          )
        )
      )
      (new_layer): Sequential(
        (new_from_rgb): Sequential(
          (0): Conv2d(1, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): LeakyReLU(negative_slope=0.2)
        )
        (new_layer): Sequential(
          (0): Conv2d(512, 1024, kernel_size=(3, 3), stride=(2, 2)

In [0]:
# https://github.com/pytorch/examples/blob/master/dcgan/main.py
# custom weights initialization called on netG and netD
def weights_init(m):
    """DCGAN weight initialization, meant to be used as ``net.apply(weights_init)``.

    * Modules whose class name contains ``"Conv"``: weight ~ N(0, 0.02).
    * Modules whose class name contains ``"BatchNorm"``: weight ~ N(1, 0.02)
      and bias = 0.

    Parameters that are ``None`` or absent are skipped. This covers a
    BatchNorm built with ``affine=False`` and conv biases, which are left
    unchanged.
    """
    classname = m.__class__.__name__
    weight = getattr(m, "weight", None)
    bias = getattr(m, "bias", None)

    if "Conv" in classname:
        if weight is not None:
            nn.init.normal_(weight, mean=0.0, std=0.02)
    elif "BatchNorm" in classname:
        if weight is not None:
            nn.init.normal_(weight, mean=1.0, std=0.02)
        if bias is not None:
            nn.init.zeros_(bias)

In [0]:
# Size of the latent vector fed to the generator
Z_dim = 100

# Color channels of the generated images.
# NOTE(review): the FashionMNIST pipeline normalizes a single channel;
# confirm that 3 is intended here.
ncc = 3

# Base feature-map counts of the generator (ngf) and discriminator (ndf)
ngf = ndf = 64

# Use the first GPU when one is available, otherwise the CPU
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")