# MADE: Masked Autoencoder for Distribution Estimation

We will apply this algorithm to the MNIST dataset to generate new handwritten digits.

## Table of Contents

1. [Hyperparameters...](#1st)

<div id='1st'/>

## 1. Let us import some libraries and define some classes and functions

In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import argparse
import torch.optim as optim
from torch.autograd import Variable
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import pylab
import torch.distributions.binomial
import os
from torchvision.utils import save_image
os.environ['KMP_DUPLICATE_LIB_OK']='True'

In [2]:
class MaskedLinear(nn.Linear):
    """Linear layer whose weights are multiplied element-wise by a configurable mask.

    The mask is registered as a buffer rather than a parameter, so it is not
    returned by ``parameters()`` and therefore not updated by the optimizer.
    """

    def __init__(self, in_features, out_features, bias=True):
        """Create the layer with an all-ones mask.

        in_features: size of each input sample
        out_features: size of each output sample
        bias: forwarded to nn.Linear; when False the layer has no bias term
        """
        # Initialise the underlying nn.Linear (creates weight and optional bias).
        super().__init__(in_features, out_features, bias)

        # All-ones mask: until set_mask is called every weight is kept.
        self.register_buffer('mask', torch.ones(out_features, in_features))

    def set_mask(self, mask):
        """Overwrite the mask from a NumPy array.

        mask: array laid out as (in_features, out_features); it is cast to
              uint8 and transposed before being copied into the
              (out_features, in_features) buffer.
        """
        self.mask.data.copy_(torch.from_numpy(mask.astype(np.uint8).T))

    def forward(self, input):
        """Return the masked affine transform of ``input`` computed in float32."""
        # self.bias is None when the layer was built with bias=False; calling
        # .float() on it unconditionally would raise AttributeError.
        bias = None if self.bias is None else self.bias.float()
        return F.linear(input.float(), self.mask.float() * self.weight.float(), bias)

In [3]:
class MADE(nn.Module):
    """Masked MLP: a stack of MaskedLinear layers with ReLU between them and a final Sigmoid.

    The masks are derived from an ordering of the inputs and randomly sampled
    per-unit "degrees" (stored in ``self.m``), and are installed into the
    MaskedLinear layers by ``update_masks``.
    """

    def __init__(self, nin, hidden_sizes, nout, num_masks=1, natural_ordering=True):
        """
        nin: integer; number of inputs
        hidden sizes: a sequence of integers; number of units in hidden layers
        nout: integer; number of outputs, which usually collectively parameterize some kind of 1D distribution
              note: if nout is e.g. 2x larger than nin (perhaps the mean and std), then the first nin
              will be all the means and the second nin will be stds. i.e. output dimensions depend on the
              same input dimensions in "chunks" and should be carefully decoded downstream appropriately.
        num_masks: can be used to train ensemble over orderings/connections
        natural_ordering: force natural ordering of dimensions, don't use random permutations
        """

        super().__init__() # Initializes nn.Module
        self.nin = nin
        self.nout = nout
        # Copy into a list so tuples (or other sequences) are accepted as well.
        self.hidden_sizes = list(hidden_sizes)

        assert self.nout % self.nin == 0, "nout must be integer multiple of nin"

        # define a simple MLP neural net: MaskedLinear + ReLU per layer,
        # with the last ReLU swapped for a Sigmoid on the output layer
        widths = [nin] + self.hidden_sizes + [nout]
        layers = []
        for h0, h1 in zip(widths, widths[1:]): # consecutive (fan_in, fan_out) pairs
            layers.extend([
                    MaskedLinear(h0, h1),
                    nn.ReLU(),
                ])
        layers.pop() # pop the last ReLU for the output layer
        layers.append(nn.Sigmoid())
        self.net = nn.Sequential(*layers)

        # seeds for orders/connectivities of the model ensemble
        self.natural_ordering = natural_ordering
        self.num_masks = num_masks
        self.seed = 0 # for cycling through num_masks orderings

        self.m = {}
        self.update_masks() # builds the initial self.m connectivity
        # note, we could also precompute the masks and cache them, but this
        # could get memory expensive for large number of masks.

    def update_masks(self):
        """Sample the next ordering/connectivity and write the masks into the layers.

        With a single mask (num_masks == 1) this is a no-op after the first call.
        Otherwise the seed cycles through 0..num_masks-1, so the same num_masks
        mask sets repeat.
        """
        if self.m and self.num_masks == 1:
            return # only a single seed, masks are already built
        L = len(self.hidden_sizes) # number of hidden layers

        # fetch the next seed and construct a random stream
        rng = np.random.RandomState(self.seed)
        self.seed = (self.seed + 1) % self.num_masks # we repeat the process every num_masks.

        # sample the order of the inputs and the connectivity of all neurons
        self.m[-1] = np.arange(self.nin) if self.natural_ordering else rng.permutation(self.nin)
        for l in range(L):
            low = self.m[l-1].min()
            # randint's upper bound is exclusive and must be greater than low.
            # For nin >= 2 this is always nin - 1 (unchanged behaviour); for
            # nin == 1 it avoids randint(0, 0), which raises ValueError.
            high = max(self.nin - 1, low + 1)
            self.m[l] = rng.randint(low, high, size=self.hidden_sizes[l])

        # construct the mask matrices: hidden layers use <=, the output layer
        # uses strict < against the input ordering
        masks = [self.m[l-1][:,None] <= self.m[l][None,:] for l in range(L)]
        masks.append(self.m[L-1][:,None] < self.m[-1][None,:])

        # handle the case where nout = nin * k, for integer k > 1
        if self.nout > self.nin:
            k = self.nout // self.nin
            # replicate the mask across the other outputs
            masks[-1] = np.concatenate([masks[-1]]*k, axis=1)

        # set the masks in all MaskedLinear layers
        layers = [l for l in self.net.modules() if isinstance(l, MaskedLinear)]
        for l,m in zip(layers, masks):
            l.set_mask(m)

    def forward(self, x):
        """Run ``x`` through the masked network and return the sigmoid outputs."""
        return self.net(x)

In [4]:
def run_epoch(split, upto=None):
    """Run one pass over the train or test data and record the mean batch loss.

    Uses module-level state: ``model``, ``opt``, ``args``, ``xtr``, ``xte``,
    ``epoch``, ``trainL`` and ``testL``.

    split: 'train' runs on xtr with parameter updates; any other value runs
           on xte without updates ('test' additionally records into testL)
    upto: optional cap on the number of batches processed
    """
    is_train = split == 'train'
    model.train() if is_train else model.eval()
    nsamples = 1 if is_train else args["samples"]
    x = xtr if is_train else xte
    N = x.size(0) # number of rows (samples) in the chosen split
    B = 100 # batch size, less than in the loaded code!
    nsteps = N//B if upto is None else min(N//B, upto) # full batches only; a trailing partial batch is dropped
    lossfs = []

    # Used as a context manager so the previous grad mode is restored on exit;
    # the bare call left autograd disabled globally after a test pass.
    with torch.set_grad_enabled(is_train):
        for step in range(nsteps):

            # fetch the next batch of data
            xb = x[step*B:step*B+B].float()

            # average the model output over nsamples forward passes, potentially
            # resampling the masks between passes
            xbhat = torch.zeros_like(xb)
            for s in range(nsamples):
                # perform order/connectivity-agnostic training by resampling the masks
                if step % args["resample_every"] == 0 or split == 'test': # if in test, cycle masks every time
                    model.update_masks()
                # forward the model
                xbhat += model(xb)
            xbhat /= nsamples

            # binary cross entropy summed over the batch, divided by the batch size
            # (reduction='sum' is the non-deprecated spelling of size_average=False)
            loss = F.binary_cross_entropy(xbhat, xb, reduction='sum') / B
            lossfs.append(loss.item())

            # backward/update
            if is_train:
                opt.zero_grad()
                loss.backward()
                opt.step()

    print("%s epoch average loss: %f" % (split, np.mean(lossfs)))

    if split == 'train':
        trainL[epoch] = np.mean(lossfs)

    if split == 'test':
        testL[epoch] = np.mean(lossfs)

In [9]:
def sample_64(n):
    """Generate 64 images of 28x28 pixels with the global ``model`` and save them as one PNG.

    Each image starts as all zeros and is filled pixel by pixel in index
    order: the model is run on the partial image and pixel i is set to the
    rounded value of output i.

    n: index used in the file name; the image is written to
       <args["sample_dir"]>/fake_images-<n+1>.png
    """
    n_images, n_pixels = 64, 28*28
    sample = torch.zeros((n_images, n_pixels))

    # No gradients are needed while sampling; without this every one of the
    # 64*784 forward passes is recorded by autograd.
    with torch.no_grad():
        for r in range(n_images):
            image = torch.zeros(n_pixels)
            for i in range(n_pixels):
                prob = model(image)
                # Rounding (instead of drawing from a Bernoulli) gives a clearer image.
                # NOTE(review): rounding is deterministic and nothing in this
                # function changes the model between rows, so all 64 rows go
                # through the same computation.
                image[i] = torch.round(prob[i])
            sample[r] = image

    sample = sample.view(n_images, 1, 28, 28)
    # Make sure the target directory exists before writing.
    os.makedirs(args["sample_dir"], exist_ok=True)
    # Name the file from the argument rather than the global `epoch`.
    save_image(sample, os.path.join(args["sample_dir"], 'fake_images-{}.png'.format(n+1)))

In [14]:
# Run configuration, looked up by string key throughout the notebook.
args = dict(
    dtr_path='./data/x_train.npy',   # training data (.npy)
    dte_path='./data/x_test.npy',    # test data (.npy)
    hiddens='500',                   # comma-separated hidden layer sizes
    num_masks=1,                     # number of mask sets to cycle through
    resample_every=20,               # resample masks every this many training steps
    samples=1,                       # forward passes averaged per test batch
    epochs=100,                      # number of training epochs
    sample_dir='./samples',          # where generated images are written
)

In [11]:
# Hyper-parameters:
# Seed NumPy's and PyTorch's global random number generators with a fixed value.
np.random.seed(42)
torch.manual_seed(42)
num_epochs = args["epochs"] # number of training epochs, read from the config dict (an earlier note here said 150)

## 2. Image Preprocessing

In [12]:
# IMAGE PREPROCESSING:

# Read each split from its .npy file and wrap the array as a torch tensor.
xtr = torch.from_numpy(np.load(args["dtr_path"]))
xte = torch.from_numpy(np.load(args["dte_path"]))

# Sizes: 60000x784, 10000x784, recall that 28*28 = 784

## 3. Model

In [13]:
# MODEL and optimizer:

# Parse the comma-separated 'hiddens' string into a list of integer layer sizes.
# Recall, map(fun,iter) applies the function to every element of the iter.
hidden_list = list(map(int, args["hiddens"].split(',')))
# Input and output sizes are both the number of columns of xtr.
model = MADE(xtr.size(1), hidden_list, xtr.size(1), num_masks=args["num_masks"])
print("number of model parameters:",sum([np.prod(p.size()) for p in model.parameters()]))
# model.cuda()
model = model.float()

# set up the optimizer: Adam with learning rate 1e-3 and weight decay 1e-4
opt = torch.optim.Adam(model.parameters(), 1e-3, weight_decay=1e-4) # Initially -3
# StepLR multiplies the learning rate by gamma=0.1 every 45 scheduler steps.
# This is learning-rate decay; the weight decay is the separate Adam argument above.
scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=45, gamma=0.1)

number of model parameters: 785284


In [None]:
#### TRAINING:
# To resume from a saved model instead of training from scratch:
#state_dict = torch.load('save/model_final.ckpt')
#model.load_state_dict(state_dict)

# Checkpoints and the loss plot are written under 'save'; create it up front
# so the first torch.save / savefig cannot fail on a missing directory.
os.makedirs('save', exist_ok=True)

# Per-epoch mean losses, filled in by run_epoch via the global `epoch` index.
testL = np.zeros(num_epochs)
trainL = np.zeros(num_epochs)

for epoch in range(num_epochs):

    print("epoch %d" % (epoch, ))
    run_epoch('test', upto=5) # run only a few batches for approximate test accuracy
    run_epoch('train')
    # Advance the learning-rate schedule once per epoch, after the optimizer
    # has stepped. This yields the same per-epoch learning rates as the old
    # scheduler.step(epoch) at the top of the loop, without the deprecated
    # epoch argument and without stepping the scheduler before the optimizer.
    scheduler.step()

    if (epoch) % 10 == 0:
        sample_64(epoch)
        torch.save(model.state_dict(), os.path.join('save', 'model--{}.ckpt'.format(epoch+1)))

    # Redraw the loss curves after every epoch (entries for epochs not yet run are still 0).
    # NOTE(review): the plotted values are the cross-entropy losses recorded by
    # run_epoch, although the labels say "log-likelihood".
    fig = plt.figure()
    pylab.xlim(0, num_epochs + 1)
    plt.plot(range(1, num_epochs + 1), trainL, label='train log-likelihood')
    plt.plot(range(1, num_epochs + 1), testL, label='test log-likelihood')
    plt.legend()
    plt.title("MADE log-likelihood")
    plt.xlabel("Epoch")
    plt.ylabel("Log-likelihood")
    plt.savefig(os.path.join('save', 'loss.pdf'))
    plt.show()
    # Release the figure so one open figure does not accumulate per epoch.
    plt.close(fig)

torch.save(model.state_dict(), os.path.join('save', 'model_final.ckpt'))
print("optimization done. full test set eval:") # 79.72 my last experiment with 100 epochs!
if num_epochs > 0:
    # `epoch` still holds the last loop value, so this overwrites testL[epoch]
    # for the final epoch with the full-test-set loss.
    run_epoch('test')

## 4. Nearest neighbors

In [None]:
# NEAREST NEIGHBOUR!!!!!
# Images 28x28, search the closest one.
# function(generated_image) --> closest training_image

# NOTE(review): `NN` is not assigned anywhere in the cells shown before this
# point, so this condition raises NameError unless it is defined elsewhere.
# The same name is rebound to a tensor further down in this cell.
if NN == True:
    # Training data, loaded again from disk as a NumPy array.
    aux_data_loader = np.load(args["dtr_path"])

    def nearest_gt(generated_image):
        """Return the training image with the smallest torch.dist to generated_image.

        Scans every row of aux_data_loader, reshaping each to 28x28 and
        converting it to a float tensor. Returns False if aux_data_loader
        is empty (the initial value of `closest`).
        """
        min_d = 0
        closest = False
        for i, image in enumerate(aux_data_loader):
            image = np.array(image).reshape(28,28) # all distances in binary...
            image = torch.tensor(image).float()
            d = torch.dist(generated_image,image) # must be torch tensors (1,28,28)
            # The first image always becomes the current best; later ones only if closer.
            if i == 0 or d < min_d:
                min_d = d
                closest = image

        return closest

    # calculate closest to...
    # Collect 24 generated images and, for each, its nearest training image.
    samples = torch.zeros(24, 1, 28, 28)
    NN = torch.zeros(24, 1, 28, 28)
    for i in range(0,24):
            # NOTE(review): no function named `sample` is defined in the cells
            # shown (only sample_64, which returns nothing) - confirm where it comes from.
            image = torch.tensor(sample(i))
            samples[i] = image
            NN[i] = nearest_gt(samples[i])
            print(i)
    # Write both grids as PNG files in the current working directory.
    save_image(samples, 'f24.png')
    save_image(NN.data, 'NN24.png')

## 5. References

[1] https://github.com/karpathy/pytorch-made