In [0]:
import numpy as np
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision.utils as vutils
import torch.optim as optim

from torch.utils.data import DataLoader
from torch.autograd import Variable

In [0]:
# Image preprocessing
transform = transforms.Compose([transforms.Resize(32),
                                transforms.ToTensor(),
                                transforms.Normalize(mean=(0.5, ),
                                                     std=(0.5, ))])


fasion_mnist = datasets.FashionMNIST(root="./data",
                                     train=True,
                                     transform=transform,
                                     download=True)

# Parameters
params_loader = {'batch_size': 64,
                 'shuffle': False}

train_loader = DataLoader(fasion_mnist, **params_loader)

0it [00:00, ?it/s]

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz


26427392it [00:07, 3768389.22it/s]                               


Extracting ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz


0it [00:00, ?it/s]

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw/train-labels-idx1-ubyte.gz


32768it [00:00, 102345.58it/s]           
0it [00:00, ?it/s]

Extracting ./data/FashionMNIST/raw/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


4423680it [00:01, 3859783.02it/s]                             
0it [00:00, ?it/s]

Extracting ./data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


8192it [00:00, 33584.18it/s]            


Extracting ./data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz
Processing...
Done!


In [0]:
# Generator model

class Generator(nn.Module):
    def __init__(self, z_dim, init_depth=1024, ncc=3):
        super(Generator, self).__init__()
        
        self.z_dim = z_dim
        self.init_depth = init_depth
        self.ncc = ncc
        
        self.model = self._create_model()
        
        # resolution at the output of the original network
        self.resl = 8
        self.dim = self.init_depth // 2
        
    
    def _block(self, ch_in, ch_out, ks=(4, 4), stride=2, padding=1, last=False):
        block = [nn.ConvTranspose2d(in_channels=ch_in,
                                    out_channels=ch_out,
                                    kernel_size=ks,
                                    stride=stride,
                                    padding=padding,
                                    bias=False)]
        if not last:
            block += [nn.BatchNorm2d(ch_out),
                      nn.ReLU()]
        else:
            block += [nn.Tanh()]
            
        return block
    
    def _init_layers(self):
        layers = [*self._block(self.z_dim, self.init_depth, stride=1, padding=0),
                  *self._block(self.init_depth, self.init_depth)]
        
        return nn.Sequential(*layers)
    
    def _to_rgb(self, dim):
        to_rgb = self._block(dim,
                             self.ncc,
                             ks=(5, 5),
                             stride=1,
                             padding=2,
                             last=True)
        return nn.Sequential(*to_rgb)
        
    
    def _create_model(self):
        init_layers = self._init_layers()
        to_rgb = self._to_rgb(self.init_depth)
        
        model = nn.Sequential()
        model.add_module("init_layers", init_layers)
        model.add_module("to_rgb", to_rgb)
        
        return model
    
    def grow_network(self):
        old_resl = self.resl
        self.resl *= 2
        growing = f"{old_resl}x{old_resl}->{self.resl}x{self.resl}"
        
        print(f"Growing the generator network : {growing}, can take few seconds...")
        
        self.layer_name = "layer_" + growing
        
        # Create the new model
        new_model = nn.Sequential()
        deep_copy_model(self.model, new_model, ["to_rgb"], not_is_in)
        
        # Add this layer to avoid a too brusk change when a new layer is added.
        hist_to_rgb = nn.Sequential()
        deep_copy_model(self.model, hist_to_rgb, ["to_rgb"], is_in)
        
        # Add this layer to avoid a too brusk change when a new layer is added.
        transition_to_rgb = nn.Sequential()
        transition_to_rgb.add_module("upsample_to_rgb",
                                     nn.Upsample(scale_factor=2, mode='nearest'))
        transition_to_rgb.add_module("hist_from_rgb", hist_to_rgb)
        
        # TODO: find a better way to handle the channels' dim
        dim_in = self.dim * 2
        dim_out = self.dim
        
        # Add the new layer on the top of the descriminator
        new_layer = nn.Sequential()
        new_layer.add_module("new_layer", nn.Sequential(*self._block(dim_in,
                                                                     dim_out)))
        new_layer.add_module("new_to_rgb", self._to_rgb(dim_out))
        
        
        new_model.add_module("concatenate_block",
                             ConcatenateLayers(transition_to_rgb,
                                               new_layer))
        new_model.add_module("fadein", FadeIn())
        
        
        self.model = new_model
        
    def clean_network(self):
        new_model = nn.Sequential()
        
        deep_copy_model(self.model,
                        new_model, 
                        ["concatenate_block", "fadein"],
                        not_is_in)
        
        rgb_layer, fadedIn_layer = nn.Sequential(), nn.Sequential()
        deep_copy_model(self.model.concatenate_block.new_layer,
                        rgb_layer,
                        ["new_to_rgb"],
                        is_in)
        deep_copy_model(self.model.concatenate_block.new_layer,
                        fadedIn_layer,
                        ["new_layer"],
                        is_in)
        
        new_model.add_module(self.layer_name, fadedIn_layer.new_layer)
        new_model.add_module("to_rgb", rgb_layer.new_to_rgb)
        
        self.model = new_model
        
        
    def forward(self, x):

        return self.model(x)

z_dim = 100
net_g = Generator(z_dim, ncc=1)
net_g.grow_network()
print(net_g)
net_g.clean_network()
print(net_g)
net_g.grow_network()
print(net_g)
net_g.clean_network()
print(net_g)
fixed_noise = torch.randn(params_loader["batch_size"], z_dim, 1, 1)
val = net_g(fixed_noise)
print(val.shape)

Growing the generator network : 8x8->16x16, can take few seconds...
Generator(
  (model): Sequential(
    (init_layers): Sequential(
      (0): ConvTranspose2d(100, 1024, kernel_size=(4, 4), stride=(1, 1), bias=False)
      (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): ConvTranspose2d(1024, 1024, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (4): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU()
    )
    (concatenate_block): ConcatenateLayers(
      (trans_layer): Sequential(
        (upsample_to_rgb): Upsample(scale_factor=2, mode=nearest)
        (hist_from_rgb): Sequential(
          (to_rgb): Sequential(
            (0): ConvTranspose2d(1024, 1, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), bias=False)
            (1): Tanh()
          )
        )
      )
      (new_layer): Sequential(
        (new_layer): Sequential(
          (0): Co

RuntimeError: ignored

In [0]:
# Discriminator model

def is_in(name, module_name):
    return name in module_name

def not_is_in(name, module_name):
    return name not in module_name

# ref: https://github.com/nashory/pggan-pytorch/blob/master/network.py
def deep_copy_model(model, new_model, module_name, is_inside):
    """
    Return the module of the model given in parameters,
    if not_ is True return all the module not in the module_name list.
    Args:
        model (torch.nn.Module): model to copy
        new_model (torch.nn.Module): The copied model.
        module_name (list(string)): module name to conserve during the copy
        is_inside (Callable): Include or exclude modules to save
    """
    for name, m in model.named_children():
        if is_inside(name, module_name):
            new_model.add_module(name, m)
            new_model[-1].load_state_dict(m.state_dict())
            

class ConcatenateLayers(nn.Module):
    def __init__(self, trans_layer, new_layer):
        super(ConcatenateLayers, self).__init__()
        
        self.trans_layer = trans_layer
        self.new_layer = new_layer
        
    def forward(self, x):
        return (self.trans_layer(x), self.new_layer(x))
    
class FadeIn(nn.Module):
    def __init__(self):
        super(FadeIn, self).__init__()
        
        self.alpha = 0.
        
    def update_alpha(self, delta):
        self.alpha += delta
        self.alpha = max(0., min(self.alpha, 1.))
        
    def forward(self, x):
        """
        Args:
            x (tuple):
                val_from_trans (torch.Tensor): Tensor from the transition layer
                val_from_new (torch.Tensor): Tensor from the last added layer
        
        Return:
            Fade in the tensor comming from the last added layer.
            Formula: (1-alpha)*trans + alpha* new
        """
        val_from_trans, val_from_new = x
        out = torch.add((1.-self.alpha) * val_from_trans,
                        self.alpha * val_from_new)
        return out


class Discriminator(nn.Module):
    def __init__(self, img_size, ncc=3, init_depth=64, max_depth=1024):
        super(Discriminator, self).__init__()
        
        self.img_size = img_size
        self.ncc = ncc
        self.init_depth = init_depth
        self.max_depth = max_depth
        self.resl = 8
        
        self.model = self._create_model()
        self.dim = self.max_depth // 2
        
    def _create_model(self):
        model = nn.Sequential()
        
        rgb_layer = self._from_rgb(self.max_depth)
        init_layers = self._init_layers()
        
        model.add_module("rgb_layer", rgb_layer)
        model.add_module("init_layers", init_layers)
        
        return model
        
    def _block(self, depth_in, depth_out, last=False, stride=2, padding=1):
        block = [nn.Conv2d(in_channels=depth_in,
                           out_channels=depth_out,
                           kernel_size=(3, 3),
                           stride=stride,
                           padding=padding,
                           bias=False)]
        if not last:
            block += [nn.BatchNorm2d(depth_out),
                      nn.LeakyReLU(0.2)]
        
        return block
    
    def _init_layers(self):
        internal_layers = [*self._block(self.max_depth, self.max_depth),
                           # Last layer different from the others : depth x 4 x 4 -> 1 x 1 x 1
                           *self._block(self.max_depth, 1, True, 2, 0)]
        
        return nn.Sequential(*internal_layers)
    
    def _from_rgb(self, dim):
        rgb_layer = self._block(self.ncc, dim)
        
        return nn.Sequential(*rgb_layer)
    
    def grow_network(self):
        """Add a new layer on the top of the descriminator."""
        old_resl = self.resl
        self.resl *= 2
        growing = f"{self.resl}x{self.resl}->{old_resl}x{old_resl}"
        
        print(f"Growing the discriminator network : {growing}, can take few seconds...")
        
        self.layer_name = "layer_" + growing
        
        
        # TODO: find a better way to handle the channels' dim
        dim_in = self.dim
        dim_out = self.dim * 2
        
        hist_from_rgb = nn.Sequential()
        deep_copy_model(self.model, hist_from_rgb, ["rgb_layer"], is_in)
        
        # Add this layer to avoid a too brusk change when a new layer is added.
        transition_from_rgb = nn.Sequential()
        transition_from_rgb.add_module("downsample_from_rgb",
                                       nn.AvgPool2d(kernel_size=2))
        transition_from_rgb.add_module("hist_from_rgb", hist_from_rgb)
        
        # Add the new layer on the top of the descriminator
        new_layer = nn.Sequential()
        new_layer.add_module("new_from_rgb", self._from_rgb(dim_in))
        new_layer.add_module("new_layer", nn.Sequential(*self._block(dim_in,
                                                                     dim_out)))
        
        # Create the new model
        new_model = nn.Sequential()
        new_model.add_module("concatenate_block",
                             ConcatenateLayers(transition_from_rgb,
                                               new_layer))
        new_model.add_module("fadein", FadeIn())
        
        deep_copy_model(self.model, new_model, ["rgb_layer"], not_is_in)
        
        self.model = new_model
        self.dim //= 2
        
    def clean_network(self):
        """Once the new layer is completly fade in, remove the useless layers."""
        rgb_layer, fadedIn_layer = nn.Sequential(), nn.Sequential()
        deep_copy_model(self.model.concatenate_block.new_layer,
                        rgb_layer,
                        ["new_from_rgb"],
                        is_in)
        deep_copy_model(self.model.concatenate_block.new_layer,
                        fadedIn_layer,
                        ["new_layer"],
                        is_in)
        
        new_model = nn.Sequential()
        new_model.add_module("rgb_layer", rgb_layer.new_from_rgb)
        new_model.add_module(self.layer_name, fadedIn_layer.new_layer)
        
        deep_copy_model(self.model,
                        new_model,
                        ["concatenate_block", "fadein"],
                        not_is_in)

        self.model = new_model
    
    def forward(self, input_):
        output = self.model(input_)

        return output.view(-1, 1).squeeze(1)
    
net_d = Discriminator(4, 1)
# print(net_d)
net_d.grow_network()
print(net_d)
net_d.clean_network()
print(net_d)
net_d.grow_network()
print(net_d)
net_d.clean_network()
print(net_d)
for img, _ in train_loader:
    val = net_d(img)
    
    print(val.shape)
    break

Growing the discriminator network : 16x16->8x8, can take few seconds...
Discriminator(
  (model): Sequential(
    (concatenate_block): ConcatenateLayers(
      (trans_layer): Sequential(
        (downsample_from_rgb): AvgPool2d(kernel_size=2, stride=2, padding=0)
        (hist_from_rgb): Sequential(
          (rgb_layer): Sequential(
            (0): Conv2d(1, 1024, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
            (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): LeakyReLU(negative_slope=0.2)
          )
        )
      )
      (new_layer): Sequential(
        (new_from_rgb): Sequential(
          (0): Conv2d(1, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): LeakyReLU(negative_slope=0.2)
        )
        (new_layer): Sequential(
          (0): Conv2d(512, 1024, kernel_size=(3,

In [0]:
# https://github.com/pytorch/examples/blob/master/dcgan/main.py
# custom weights initialization called on netG and netD
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

In [0]:
# Dimension of the latent space
Z_dim = 100

# Number of color channel in the final image
ncc = 3

# Number of internal node
ngf, ndf = 64, 64

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")