In [3]:
!pip3 install --upgrade tqdm

Collecting tqdm
[?25l  Downloading https://files.pythonhosted.org/packages/a5/13/cd55c23e3e158ed5b87cae415ee3844fc54cb43803fa3a0a064d23ecb883/tqdm-4.40.0-py2.py3-none-any.whl (54kB)
[K     |██████                          | 10kB 16.8MB/s eta 0:00:01[K     |████████████                    | 20kB 3.5MB/s eta 0:00:01[K     |█████████████████▉              | 30kB 5.0MB/s eta 0:00:01[K     |███████████████████████▉        | 40kB 6.4MB/s eta 0:00:01[K     |█████████████████████████████▉  | 51kB 4.2MB/s eta 0:00:01[K     |████████████████████████████████| 61kB 3.3MB/s 
[?25hInstalling collected packages: tqdm
  Found existing installation: tqdm 4.28.1
    Uninstalling tqdm-4.28.1:
      Successfully uninstalled tqdm-4.28.1
Successfully installed tqdm-4.40.0


In [0]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from torch.optim import Adam
from torch.nn import NLLLoss, CrossEntropyLoss
from sklearn.model_selection import train_test_split
from tqdm.notebook import tqdm
from typing import Tuple

import plotly.offline as py
import plotly.graph_objs as go
import plotly.express as px
import plotly.figure_factory as ff

In [0]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [0]:
class CheckerboardSplit(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, X):
        b, c, h, w = X.shape
        
        mask1 = torch.tensor([[True, False], [False, True]]).repeat(h // 2, w // 2)
        mask2 = torch.tensor([[False, True], [True, False]]).repeat(h // 2, w // 2)

        x1 = X[:,:,mask1].view(b, c, h, w // 2)
        x2 = X[:,:,mask2].view(b, c, h, w // 2)
        
        return (x1, x2)


class ChannelSplit(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, X):
        b, c, h, w = X.shape
        x1, x2 = torch.split(X, c // 2, dim=1)
        return (x1, x2)


class AffineCouplingBlock(nn.Module):
    def __init__(self, n_filters):
        super().__init__()
        self.transform = nn.Sequential(
            nn.Conv2d(n_filters, n_filters, kernel_size=1, stride=1, padding=0),
            nn.ReLU(),
            nn.Conv2d(n_filters, n_filters, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(n_filters, n_filters, kernel_size=1, stride=1, padding=0)
        )
    
    def forward(self, h):
        h = h + self.transform(h)
        return h

class AffineCoupling(nn.Module):
    def __init__(self, c_in, *, n_filters=256, n_blocks=8):
        super().__init__()

        self.c_in = c_in
        self.simple_resnet = nn.Sequential(
            nn.Conv2d(c_in, n_filters, kernel_size=3, stride=1, padding=1),
            *[
                AffineCouplingBlock(n_filters) for _ in range(n_blocks)
             ],
            nn.ReLU(),
            nn.Conv2d(n_filters, 2 * c_in, kernel_size=3, stride=1, padding=1)
        )
    
    def forward(self, X):
        x1, x2 = X
        log_s, t = torch.split(self.simple_resnet(x1), self.c_in, dim=1)
        y1, y2 = x1, torch.exp(log_s) * (x1 + t)
        return (y1, y2)
        

class TupleFlip(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, X):
        x1, x2 = X
        return (x2, x1)


class InverseCheckerboardSplit(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, X):
        x1, x2 = X
        b, c, h, w = x1.shape

        mask1 = torch.tensor([[True, False], [False, True]]).repeat(h // 2, w)
        mask2 = torch.tensor([[False, True], [True, False]]).repeat(h // 2, w)

        Y = torch.zeros(b, c, h, w * 2)
        Y[:,:,mask1] = x1.view(b, c, -1)
        Y[:,:,mask2] = x2.view(b, c, -1)
        return Y


class InverseChannelSplit(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, X):
        Y = torch.cat(X, dim=1)
        return Y


class Squeeze(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, X):
        b, c, h, w = X.shape
        
        mask1 = torch.tensor([[True, False], [False, False]]).repeat(h // 2, w // 2)
        mask2 = torch.tensor([[False, True], [False, False]]).repeat(h // 2, w // 2)
        mask3 = torch.tensor([[False, False], [True, False]]).repeat(h // 2, w // 2)
        mask4 = torch.tensor([[False, False], [False, True]]).repeat(h // 2, w // 2)

        y1 = X[:,:,mask1].view(b, c, h // 2, w // 2)
        y2 = X[:,:,mask2].view(b, c, h // 2, w // 2)
        y3 = X[:,:,mask3].view(b, c, h // 2, w // 2)
        y4 = X[:,:,mask4].view(b, c, h // 2, w // 2)

        Y = torch.cat([y1, y2, y3, y4], dim=1)
        return Y

In [0]:
class CelebModel(nn.Module):
    def __init__(self, c_in):
        super().__init__()

        self.transform = nn.Sequential(
            CheckerboardSplit(),
            *[
                nn.Sequential(
                    AffineCoupling(c_in),
                    TupleFlip()
                ) for _ in range(4)
            ],
            InverseCheckerboardSplit(),

            Squeeze(),

            ChannelSplit(),
            *[
                nn.Sequential(
                    AffineCoupling(2 * c_in),
                    TupleFlip()
                ) for _ in range(3)
            ],
            InverseChannelSplit(),

            CheckerboardSplit(),
            *[
                nn.Sequential(
                    AffineCoupling(4 * c_in),
                    TupleFlip()
                ) for _ in range(3)
            ],
            InverseCheckerboardSplit(),
            
            Squeeze(),

            ChannelSplit(),
            *[
                nn.Sequential(
                    AffineCoupling(8 * c_in),
                    TupleFlip()
                ) for _ in range(3)
            ],
            InverseChannelSplit(),

            CheckerboardSplit(),
            *[
                nn.Sequential(
                    AffineCoupling(16 * c_in),
                    TupleFlip()
                ) for _ in range(3)
            ],
            InverseCheckerboardSplit()
        )
    
    def forward(self, X):
        X = self.transform(X)
        return X

In [0]:
w, h = 32, 32
model = CelebModel(3)

In [0]:
Z = model(torch.randn(2, 3, h, w))

In [86]:
Z.shape

torch.Size([2, 48, 8, 8])

In [58]:
Z

tensor([[[[-8.2260e-01, -8.4404e-01, -4.2416e-01,  ...,  1.1539e+00,
            9.2918e-01,  8.2985e-01],
          [ 1.6781e+00,  1.5061e+00,  1.3250e+00,  ...,  1.2642e+00,
            7.2473e-01,  6.8122e-01],
          [ 1.0881e+00,  1.0760e+00,  5.5404e-01,  ..., -4.1125e-01,
           -3.9030e-01, -3.2707e-01],
          ...,
          [ 3.5591e+00,  3.8263e+00,  1.2163e+00,  ..., -6.6586e-01,
           -1.9238e+00, -1.7205e+00],
          [ 1.5570e+00,  2.2883e+00, -7.4584e-01,  ...,  1.7321e+00,
           -4.8728e-01, -5.1052e-01],
          [-5.0331e-01, -6.6329e-01,  3.4068e-01,  ..., -5.8974e-01,
            8.0413e-01,  6.7210e-01]],

         [[-7.8825e-02, -1.7988e-01, -4.7321e-01,  ..., -6.4529e-01,
            1.6685e+00,  1.4285e+00],
          [-8.9735e-01, -8.1413e-01, -4.8527e-02,  ...,  7.7919e-01,
           -4.9819e-01, -5.8922e-01],
          [ 1.3988e+00,  1.5623e+00,  8.0480e-01,  ..., -1.0357e+00,
           -1.5504e+00, -1.3215e+00],
          ...,
     