Imports


In [11]:
import matplotlib.pyplot as plt
import sklearn.datasets as datasets
import numpy as np

import random

import torch
from torch import nn
from torch import optim
from torch.utils.data import TensorDataset, DataLoader

! pip install nflows
from nflows.flows.base import Flow
from nflows.distributions.normal import StandardNormal
from nflows.transforms.base import CompositeTransform
from nflows.transforms.autoregressive import MaskedAffineAutoregressiveTransform
from nflows.transforms.permutations import ReversePermutation



# MNIST

In [2]:
import gzip
import pickle

def load_data():
  f = gzip.open('mnist.pkl.gz', 'rb')

  # fix for encoding of pickle
  u = pickle._Unpickler(f)
  u.encoding = 'latin1'

  train_data, validation_data, test_data = u.load()
  f.close()

  return (train_data, validation_data, test_data)

In [3]:
def preprocess_data(data, rng, alpha=1.0e-6, logit=False, should_dequantize=True):
  """
  Processes the dataset
  """
  x = dequantize(data[0], rng) if should_dequantize else data[0]  # dequantize pixels
  x = logit_transform(x, alpha) if logit else x                   # logit
  labels = data[1]                                                # numeric labels
  encoded_labels = one_hot_encode(labels, 10)                     # 1-hot encoded labels
  return (x, labels, encoded_labels)

def dequantize(x, rng):
  """
  Adds noise to pixels to dequantize them
  """
  return x + rng.rand(*x.shape) / 256.0

def logit_transform(x, alpha=1.0e-6):
  """
  Transforms pixel values with logit to reduce the impact of boundary effects
  """
  a = alpha + (1 - 2*alpha) * x
  return np.log(a / (1.0 - a))

def one_hot_encode(labels, nr_labels):
  """
  Transforms numeric labels to 1-hot encoded labels
  """
  y = np.zeros([labels.size, nr_labels])
  y[range(labels.size), labels] = 1
  return y

In [4]:
def load_vectorized_data():
  train_data, validation_data, test_data = load_data()
  rng = np.random.RandomState(42)
  processed_train_data = preprocess_data(train_data, rng, logit=True)
  processed_validation_data = preprocess_data(validation_data, rng, logit=True)
  processed_test_data = preprocess_data(test_data, rng, logit=True)
  return (processed_train_data, processed_validation_data, processed_test_data)

In [5]:
train_data, validation_data, test_data = load_vectorized_data()

train_x = train_data[0]
train_labels = train_data[1]
train_y = train_data[2]

In [8]:
def build_flow(num_dim, hidden_features=1024, layers=5, batch_norm=False):
  base_dist = StandardNormal(shape=[num_dim])
  transforms = []
  for _ in range(layers):
    transforms.append(ReversePermutation(features=num_dim))
    transforms.append(MaskedAffineAutoregressiveTransform(features=num_dim,
                                                          hidden_features=hidden_features,
                                                          use_batch_norm=batch_norm))
  transform = CompositeTransform(transforms)
  return Flow(transform, base_dist)

In [20]:
def train(x, num_iter, train_loader, weight_decay=None):
  num_dim = x.shape[1]
  flow = build_flow(num_dim)
  optimizer = optim.Adam(flow.parameters(),
                         weight_decay=0 if weight_decay is None else weight_decay)

  x = x.clone().detach()

  i = 1
  for x, y, _ in train_loader:
    optimizer.zero_grad()
    loss = -flow.log_prob(inputs=x).mean()
    loss.backward()
    optimizer.step()

  # for iter in range(1, num_iter + 1):
  #   optimizer.zero_grad()
  #   loss = -flow.log_prob(inputs=x).mean()
  #   loss.backward()
  #   optimizer.step()

    if i % 100 == 0:
      print('iteration {}, loss {:.5f}'.format(i, loss.detach().numpy()))
    i+=1

  return flow


In [21]:
# train_x = torch.tensor(train_x, dtype=torch.float32)
# train_y = torch.tensor(train_y, dtype=torch.float32)
# train_labels = torch.tensor(train_labels, dtype=torch.float32)
trainset = TensorDataset(train_x.clone().detach(), train_y.clone().detach(), train_labels.clone().detach())
trainloader = DataLoader(trainset, batch_size=100, shuffle=True, num_workers=0)

In [22]:
epochs = 30
for epoch in range(1, epochs + 1):
  print('Epoch {}'.format(epoch))
  model = train(train_x, 5000, trainloader)



Epoch 1
iteration 100, loss 1950.92798
iteration 200, loss 1608.98401
iteration 300, loss 1511.99036
iteration 400, loss 1469.01550
iteration 500, loss 1425.98669
Epoch 2
iteration 100, loss 6136.89648
iteration 200, loss 3037.57788
iteration 300, loss 2683.40063
iteration 400, loss 2422.02759
iteration 500, loss 4343.75732
Epoch 3
iteration 100, loss 2060.49902
iteration 200, loss 1744.64954
iteration 300, loss 1590.92078
iteration 400, loss 1550.77344
iteration 500, loss 1480.35083
Epoch 4
iteration 100, loss 6548.99756
iteration 200, loss 3573.33228
iteration 300, loss 4737.01514
iteration 400, loss 4105.22754
iteration 500, loss 3695.04614
Epoch 5
iteration 100, loss 2219.29370
iteration 200, loss 1782.40552
iteration 300, loss 1665.16370
iteration 400, loss 1561.95850
iteration 500, loss 1523.92419
Epoch 6
iteration 100, loss 2602.58765
iteration 200, loss 2059.84717
iteration 300, loss 1863.29675
iteration 400, loss 1734.94177
iteration 500, loss 1653.68713
Epoch 7
iteration 100,