In [None]:
import os

import numpy as np
import matplotlib.pyplot as plt
import torch
from sklearn.datasets import load_digits
from sklearn import datasets
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F



In [2]:
class Digits(Dataset):
    """Fixed train/val/test splits of scikit-learn's ``load_digits`` data.

    ``mode="train"`` keeps rows 0-999, ``mode="val"`` keeps rows 1000-1349,
    and any other mode keeps the remaining rows (1350 onward). Rows are stored
    as float32; an optional ``transforms`` callable is applied per sample.
    """

    def __init__(self, mode="train", transforms=None):
        all_rows = load_digits().data
        if mode == "train":
            subset = all_rows[:1000]
        elif mode == "val":
            subset = all_rows[1000:1350]
        else:
            subset = all_rows[1350:]
        self.data = subset.astype(np.float32)
        self.transforms = transforms

    def __len__(self):
        """Number of rows in the selected split."""
        return self.data.shape[0]

    def __getitem__(self, idx):
        """Return row ``idx``, passed through ``transforms`` when one is set."""
        row = self.data[idx]
        if not self.transforms:
            return row
        return self.transforms(row)

In [3]:
class RoundStraightThrough(torch.autograd.Function):
    """Round to the nearest integer with a straight-through gradient.

    The forward pass applies ``torch.round``; the backward pass returns the
    incoming gradient unchanged, i.e. rounding is treated as the identity for
    differentiation. Use it through ``RoundStraightThrough.apply(x)``.

    ``forward`` and ``backward`` must be static methods defined at class level;
    autograd looks them up on the class when ``apply`` is called.
    """

    @staticmethod
    def forward(ctx, input):
        return torch.round(input)

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator: d round(x) / dx is taken to be 1.
        return grad_output.clone()

In [4]:
def log_min_exp(a, b, epsilon=1e-8):
    """Compute ``log(exp(a) - exp(b))`` elementwise in a numerically stable way.

    Uses ``log(exp(a) - exp(b)) = a + log(1 - exp(b - a))``, which assumes
    ``b <= a`` elementwise. ``1 - exp(b - a)`` is computed as
    ``-expm1(b - a)`` for accuracy when ``b`` is close to ``a``.

    Args:
        a: tensor of log-values, assumed >= ``b`` elementwise.
        b: tensor of log-values broadcastable with ``a``.
        epsilon: small constant added inside the log. It keeps ``a == b``
            (e.g. when both log-CDFs saturate) finite instead of ``-inf``, which
            would otherwise produce NaN gradients. Pass 0 for the unguarded form.

    Returns:
        Tensor with the broadcast shape of ``a`` and ``b``.
    """
    # Add epsilon AFTER forming 1 - exp(b - a): adding 1e-8 to -exp(b - a)
    # first would be rounded away in float32 near -1.
    return a + torch.log(-torch.expm1(b - a) + epsilon)

In [5]:
def log_integer_probability(x, mean, logscale):
    """Log-probability of integer ``x`` under a discretized logistic distribution.

    The mass assigned to integer ``x`` is the logistic CDF mass of the bin
    ``[x - 0.5, x + 0.5]``::

        log(sigmoid((x + 0.5 - mean) / scale) - sigmoid((x - 0.5 - mean) / scale))

    with ``scale = exp(logscale)``. The CDF difference is evaluated in log
    space with ``log_min_exp`` on ``logsigmoid`` values.

    Args:
        x: integer-valued tensor.
        mean: location, broadcastable with ``x``.
        logscale: log of the scale, broadcastable with ``x``.

    Returns:
        Elementwise log-probabilities with the broadcast shape of the inputs.
    """
    scale = torch.exp(logscale)
    # Standardize each bin edge: the whole (edge - mean) difference is divided
    # by the scale, not just the mean.
    upper = (x + 0.5 - mean) / scale
    lower = (x - 0.5 - mean) / scale
    return log_min_exp(F.logsigmoid(upper), F.logsigmoid(lower))

In [6]:
# Model sizes
D = 64   # input dimension
M = 256  # hidden width of each translation (t) net

# Optimization settings
lr = 1e-3           # learning rate
num_epochs = 1000   # maximum number of training epochs
max_patience = 20   # early stopping: stop after this many epochs without improvement


def nett():
    """Build a fresh translation net mapping D // 2 features to D // 2 features."""
    return nn.Sequential(
        nn.Linear(D // 2, M),
        nn.LeakyReLU(),
        nn.Linear(M, M),
        nn.LeakyReLU(),
        nn.Linear(M, D // 2),
    )


# A single builder selects IDF's two-way (one-net) coupling.
netts = [nett]

In [None]:
class IDF(nn.Module):
    """Integer Discrete Flow over integer-valued vectors of dimension ``D``.

    The flow stacks ``num_flows`` additive coupling layers, each followed by a
    reversal of the feature order. Translations are rounded with a
    straight-through estimator, so integer inputs map to integer outputs and
    the flow is exactly invertible. The latent prior is a factorized
    discretized logistic with learnable ``mean`` and ``logscale`` of shape (1, D).

    Args:
        netts: list of zero-argument builders for translation nets. One builder
            selects a 2-way coupling (each net maps D/2 -> D/2 features); four
            builders select a 4-way coupling (each net maps 3D/4 -> D/4).
        num_flows: number of coupling layers.
        D: data dimensionality; must be divisible by the number of chunks
            (2 for one net, 4 for four nets).

    Raises:
        ValueError: if ``netts`` does not contain 1 or 4 builders, or ``D`` is
            not divisible by the number of chunks.
    """

    def __init__(self, netts, num_flows, D=2):
        super().__init__()
        if len(netts) == 1:
            self.t = nn.ModuleList([netts[0]() for _ in range(num_flows)])
            self.idf_git = 1

        elif len(netts) == 4:
            self.t_a = nn.ModuleList([netts[0]() for _ in range(num_flows)])
            self.t_b = nn.ModuleList([netts[1]() for _ in range(num_flows)])
            self.t_c = nn.ModuleList([netts[2]() for _ in range(num_flows)])
            self.t_d = nn.ModuleList([netts[3]() for _ in range(num_flows)])
            # coupling() dispatches on idf_git, so it must record the 4-way split.
            self.idf_git = 4

        else:
            raise ValueError(
                f"The number of transformation nets must be either 1 or 4. "
                f"The provided list contains {len(netts)} nets."
            )

        # torch.chunk would silently produce unequal chunks otherwise, and the
        # nets' input/output widths would no longer match.
        num_chunks = 2 if self.idf_git == 1 else 4
        if D % num_chunks != 0:
            raise ValueError(
                f"D must be divisible by {num_chunks} for this coupling; got D={D}."
            )

        self.num_flows = num_flows
        self.round = RoundStraightThrough.apply

        # Parameters of the factorized discretized-logistic prior.
        self.mean = nn.Parameter(torch.zeros(1, D))
        self.logscale = nn.Parameter(torch.ones(1, D))

        self.D = D

    def coupling(self, x, index, forward=True):
        """Apply (or invert) the additive coupling layer ``index``.

        Args:
            x: 2-D tensor (batch, D); features are split along dim 1.
            index: which coupling layer's nets to use.
            forward: if True apply the coupling, otherwise its exact inverse.

        Returns:
            Tensor with the same shape as ``x``.
        """
        if self.idf_git == 1:
            (xa, xb) = torch.chunk(x, 2, 1)

            if forward:
                yb = xb + self.round(self.t[index](xa))
            else:
                yb = xb - self.round(self.t[index](xa))
            return torch.cat((xa, yb), 1)

        # 4-way coupling: each quarter is shifted by a net of the other three,
        # using already-updated quarters; the inverse undoes them in reverse order.
        (xa, xb, xc, xd) = torch.chunk(x, 4, 1)
        if forward:
            ya = xa + self.round(self.t_a[index](torch.cat((xb, xc, xd), 1)))
            yb = xb + self.round(self.t_b[index](torch.cat((ya, xc, xd), 1)))
            yc = xc + self.round(self.t_c[index](torch.cat((ya, yb, xd), 1)))
            yd = xd + self.round(self.t_d[index](torch.cat((ya, yb, yc), 1)))
        else:
            yd = xd - self.round(self.t_d[index](torch.cat((xa, xb, xc), 1)))
            yc = xc - self.round(self.t_c[index](torch.cat((xa, xb, yd), 1)))
            yb = xb - self.round(self.t_b[index](torch.cat((xa, yc, yd), 1)))
            ya = xa - self.round(self.t_a[index](torch.cat((yb, yc, yd), 1)))

        return torch.cat((ya, yb, yc, yd), 1)

    def permute(self, x):
        """Reverse the feature order (its own inverse)."""
        return x.flip(1)

    def f(self, x):
        """Map data ``x`` to latent ``z``: coupling then permutation, per flow."""
        z = x
        for i in range(self.num_flows):
            z = self.coupling(z, i, forward=True)
            z = self.permute(z)
        return z

    def f_inv(self, z):
        """Map latent ``z`` back to data by undoing ``f`` in reverse order."""
        x = z
        for i in reversed(range(self.num_flows)):
            x = self.permute(x)
            x = self.coupling(x, i, forward=False)

        return x

    def log_prior(self, x):
        """Per-example log-probability of ``x`` under the prior (summed over dim 1)."""
        log_p = log_integer_probability(x, self.mean, self.logscale)
        return log_p.sum(1)

    def forward(self, x, reduction="avg"):
        """Negative log-likelihood of ``x``.

        Args:
            x: integer-valued data batch.
            reduction: ``'sum'`` sums over the batch; any other value averages.

        Returns:
            Scalar tensor.
        """
        z = self.f(x)
        if reduction == 'sum':
            return -self.log_prior(z).sum()
        else:
            return -self.log_prior(z).mean()

    def sample(self, batch_size, int_max=100):
        """Generate ``batch_size`` integer samples by inverting the flow.

        Args:
            batch_size: number of samples.
            int_max: unused; kept for backward compatibility.

        Returns:
            Tensor of shape (batch_size, 1, self.D).
        """
        z = self.prior_sample(batch_size=batch_size, D=self.D)
        x = self.f_inv(z)
        # x holds batch_size * D elements, so only a singleton middle axis fits.
        return x.view(batch_size, 1, self.D)

    def prior_sample(self, batch_size, D=2):
        """Draw integer latents from the discretized logistic prior.

        Uses inverse-CDF sampling: for ``y ~ Uniform[0, 1)``,
        ``mean + exp(logscale) * log(y / (1 - y))`` is logistic-distributed,
        and rounding yields an integer sample.

        Args:
            batch_size: number of samples.
            D: ignored; the dimensionality is always ``self.D``.

        Returns:
            Tensor of shape (batch_size, self.D) holding integer values.
        """
        y = torch.rand(batch_size, self.D)
        # torch.rand can return exactly 0, which would make log(y / (1 - y)) = -inf.
        y = y.clamp(min=torch.finfo(y.dtype).tiny)
        x = torch.exp(self.logscale) * torch.log(y / (1. - y)) + self.mean
        return torch.round(x)

In [8]:
# Build a one-flow IDF. D must match the width the translation nets were built
# for (D // 2 in/out); with the constructor default D=2 the prior's (1, D)
# mean/logscale would not broadcast against the 64-dimensional latents.
IDF(netts, 1, D=D)

IDF(
  (t): ModuleList(
    (0): Sequential(
      (0): Linear(in_features=32, out_features=256, bias=True)
      (1): LeakyReLU(negative_slope=0.01)
      (2): Linear(in_features=256, out_features=256, bias=True)
      (3): LeakyReLU(negative_slope=0.01)
      (4): Linear(in_features=256, out_features=32, bias=True)
    )
  )
)