In [1]:
import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn
import matplotlib.pyplot as plt
%matplotlib inline

In [3]:
class MaskedLinear(nn.Linear):
  """Same as nn.Linear, except the weight matrix is multiplied elementwise by a configurable mask.

  The mask is stored with register_buffer rather than as an nn.Parameter: it is part of the
  module's state (saved in state_dict, moved by .to()/.cuda()) but is not trained by the optimizer.
  """
  def __init__(self, in_features, out_features, bias=True):
    super().__init__(in_features, out_features, bias)
    # Same (out_features, in_features) layout as self.weight; all ones == fully connected.
    self.register_buffer('mask', torch.ones(out_features, in_features))

  def set_mask(self, mask):
    """Load a new connectivity mask.

    mask: numpy array of shape (in_features, out_features); nonzero/True entries keep a connection.
          It is transposed here to match the (out_features, in_features) layout of self.weight.
    """
    # copy_ writes into the existing buffer in place (converting to the buffer's dtype)
    # instead of rebinding self.mask to a brand-new tensor.
    self.mask.data.copy_(torch.from_numpy(mask.astype(np.uint8).T))

  def forward(self, input):
    # y = x @ (mask * W)^T + b, where mask * W is the elementwise (Hadamard) product.
    return F.linear(input, self.mask * self.weight, self.bias)

In [6]:
class MADE(nn.Module):
  """Masked Autoencoder for Distribution Estimation: an MLP of MaskedLinear layers.

  The masks are built so that output dimension d only depends on input dimensions that come
  strictly before d in the current input ordering (self.m[-1]).
  """
  def __init__(self, nin, nout, hidden_sizes, num_masks=1, natural_ordering=False):
    """
    nin: integer; number of inputs
    nout: integer; number of outputs, which usually collectively parameterize some kind of 1D
          distribution. Must be an integer multiple of nin. If nout is e.g. 2x larger than nin
          (perhaps the mean and std), then the first nin outputs will be all the means and the
          second nin will be the stds, i.e. output dimensions depend on the same input dimensions
          in "chunks" and should be decoded downstream accordingly.
    hidden_sizes: sequence of integers; number of units in each hidden layer
    num_masks: number (>= 1) of orderings/connectivities to cycle through on each call to
               update_masks(); can be used to train an ensemble over orderings/connections
    natural_ordering: if True, use the ordering 0..nin-1 for the inputs instead of a random
                      permutation
    """
    super().__init__()
    self.nin = nin
    self.nout = nout
    self.hidden_sizes = list(hidden_sizes)  # copy so a caller's tuple/list works and is not aliased
    assert self.nout % self.nin == 0, "nout must be integer multiple of nin"
    assert num_masks >= 1, "num_masks must be at least 1"

    # define a simple MLP: MaskedLinear layers with a ReLU between consecutive layers
    layers = []
    sizes = [nin] + self.hidden_sizes + [nout]  # size of each layer, input to output
    for h0, h1 in zip(sizes, sizes[1:]):
      layers.extend([MaskedLinear(h0, h1), nn.ReLU()])
    layers.pop()  # drop the last ReLU so the output layer stays linear
    self.net = nn.Sequential(*layers)

    # seeds for orders/connectivities of the model ensemble
    self.natural_ordering = natural_ordering
    self.num_masks = num_masks
    self.seed = 0  # for cycling through num_masks orderings

    # self.m[l]: connectivity label of each unit in hidden layer l; self.m[-1]: input ordering
    self.m = {}
    self.masks = []  # most recently built masks, each of shape (fan_in, fan_out)
    self.ordering = np.arange(self.nin)
    self.update_masks()  # builds the initial self.m connectivity
    # note, we could also precompute the masks and cache them, but this
    # could get memory expensive for large number of masks.

  def update_masks(self):
    """Sample the next ordering/connectivity and load its masks into every MaskedLinear layer.

    Seeds are cycled deterministically: 0, 1, ..., num_masks - 1, 0, ...
    """
    if self.m and self.num_masks == 1:
      return  # only a single seed and the masks are already built: nothing would change

    L = len(self.hidden_sizes)

    # fetch the next seed and construct a random stream
    rng = np.random.RandomState(self.seed)
    self.seed = (self.seed + 1) % self.num_masks

    # sample the order of the inputs and the connectivity of all neurons
    self.m[-1] = np.arange(self.nin) if self.natural_ordering else rng.permutation(self.nin)
    for l in range(L):
      # labels in [min label of the previous layer, nin - 2] (randint's upper bound is exclusive);
      # note this requires nin >= 2 whenever there are hidden layers
      self.m[l] = rng.randint(self.m[l-1].min(), self.nin - 1, size=self.hidden_sizes[l])
    self.ordering = self.m[-1]

    # construct the mask matrices, each (fan_in, fan_out):
    # hidden unit k may see units of the previous layer whose label is <= its own label
    masks = [self.m[l-1][:, None] <= self.m[l][None, :] for l in range(L)]
    # output d may only see units whose label is strictly smaller than d's position in the ordering
    # (with no hidden layers, self.m[L-1] is self.m[-1], i.e. a direct input->output mask)
    masks.append(self.m[L-1][:, None] < self.m[-1][None, :])

    # handle the case where nout = nin * k, for integer k > 1
    if self.nout > self.nin:
      k = self.nout // self.nin
      # replicate the mask across the other outputs
      masks[-1] = np.concatenate([masks[-1]] * k, axis=1)
    self.masks = masks

    # set the masks in all MaskedLinear layers
    layers = [module for module in self.net.modules() if isinstance(module, MaskedLinear)]
    for layer, mask in zip(layers, masks):
      layer.set_mask(mask)

  def visualize_masks(self):
    """Show each current mask as a grayscale image (rows = fan-in units, columns = fan-out units)."""
    for mask in self.masks:
      plt.figure(figsize=(5, 5))
      plt.imshow(mask, cmap='gray')
      plt.show()

  def forward(self, x):
    """Run the masked MLP on x and return its nout-dimensional output."""
    return self.net(x)
