In [None]:
import numpy as np


In [None]:
def sample_data(count=100000, seed=0):
    """Generate a shuffled, labelled 2-D toy dataset with three noisy classes.

    Classes 0 and 1 are Gaussian blobs centred at (-1.5, 2.5) and (1.5, 2.5).
    Class 2 is the arc ``(2 cos t, -sin t)`` for ``t`` evenly spaced in
    ``[0, pi]``.  Every point gets independent Gaussian noise with standard
    deviation 0.2.

    Parameters
    ----------
    count : int, optional
        Requested total number of points.  Each class receives ``count // 3``
        points, so ``3 * (count // 3)`` rows are returned (99999 by default).
    seed : int, optional
        Seed of the private ``np.random.RandomState`` used for all draws; the
        global numpy RNG is not touched.

    Returns
    -------
    data_x : np.ndarray
        Array of shape ``(3 * (count // 3), 2)`` with the point coordinates.
    data_y : np.ndarray
        Integer class labels (0, 1 or 2) aligned with ``data_x``.

    Raises
    ------
    ValueError
        If ``count`` is negative.
    """
    if count < 0:
        raise ValueError("count must be non-negative, got {}".format(count))

    per_class = count // 3
    noise_std = 0.2
    rand = np.random.RandomState(seed)

    # The draw order (blob 0, blob 1, arc noise, permutation) is part of the
    # contract: changing it would change the dataset produced by a given seed.
    blob_left = np.array([[-1.5, 2.5]]) + rand.randn(per_class, 2) * noise_std
    blob_right = np.array([[1.5, 2.5]]) + rand.randn(per_class, 2) * noise_std

    angles = np.linspace(0, np.pi, per_class)
    arc = np.column_stack([2 * np.cos(angles), -np.sin(angles)])
    arc += rand.randn(*arc.shape) * noise_std

    data_x = np.concatenate([blob_left, blob_right, arc], axis=0)
    data_y = np.repeat(np.arange(3), per_class)

    # Shuffle points and labels with the same permutation.
    perm = rand.permutation(len(data_x))
    return data_x[perm], data_y[perm]

In [None]:
class MLP(nn.Module):
    """Conditional model for the second column of a batch given the first.

    The first column (``x``) is one-hot encoded over ``input_size`` values and
    passed through ``Linear -> ReLU -> Linear`` to produce ``num_classes``
    logits for the second column (``y``).  ``forward`` returns a scalar loss:
    the value produced by ``theta_model`` for the ``x`` column plus the mean
    negative log-likelihood of ``y`` under the predicted distribution.

    Parameters
    ----------
    input_size : int
        Number of distinct values ``x`` may take (width of the one-hot input).
    hidden_size : int
        Width of the hidden layer.
    num_classes : int
        Number of distinct values ``y`` may take (number of output logits).
    """

    def __init__(self, input_size, hidden_size, num_classes):
        super(MLP, self).__init__()
        # `Model` is defined elsewhere in the notebook.
        # NOTE(review): presumably it returns the negative log-likelihood of
        # the x column under a learned marginal -- confirm against `Model`.
        self.theta_model = Model(input_size)

        self.input_size = input_size
        self.num_classes = num_classes

        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, batch):
        """Return the scalar training loss for ``batch``.

        Parameters
        ----------
        batch : array-like of shape (N, 2)
            Integer samples; column 0 holds ``x`` and column 1 holds ``y``.

        Returns
        -------
        torch.Tensor
            ``theta_model.forward(xs)`` plus the mean over the batch of
            ``-log p(y | x)``.
        """
        # Normalise to an ndarray so slicing, broadcasting and `.size` behave
        # the same for lists and arrays.
        batch = np.asarray(batch)
        xs_batch = batch[:, 0]
        ys_batch = batch[:, 1]

        # One-hot encode x by comparing each value against 0..input_size-1;
        # values outside that range yield an all-zero row.
        xs_one_hot = xs_batch[:, None] == np.arange(self.input_size)
        matrix = torch.as_tensor(xs_one_hot).float()

        logits = self.fc2(self.relu(self.fc1(matrix)))

        # log_softmax instead of softmax(...).log(): the latter underflows to
        # log(0) = -inf when the target class has a very small probability.
        log_probs = F.log_softmax(logits, dim=1)

        # gather requires an int64 index of shape (N, 1).
        targets = torch.as_tensor(ys_batch).long().unsqueeze(1)
        ys_error = -log_probs.gather(1, targets).sum() / ys_batch.size

        xs_error = self.theta_model.forward(xs_batch)
        return xs_error + ys_error

In [None]:
class MADE(torch.nn.Module):
    """Masked autoencoder over two one-hot encoded discrete variables.

    The input is the concatenation of two one-hot vectors of width
    ``size // 2`` (first ``x``, then ``y``).  A stack of ``MaskedLinear``
    layers of width ``size`` produces ``size`` logits, which are split into
    one softmax per variable.  The masks are built so that the logits of
    ``x`` depend on no input and the logits of ``y`` depend only on the
    one-hot bits of ``x``, i.e. the network models ``p(x) * p(y | x)``.

    Masks are built with shape ``(in_features, out_features)``.
    NOTE(review): `MaskedLinear` is defined elsewhere in the notebook; it is
    assumed to expect the mask in that layout -- confirm.

    Parameters
    ----------
    size : int, optional
        Total width of the one-hot input/output; must be a positive multiple
        of 2.  Each variable takes ``size // 2`` distinct values.
    hidden_layers : int, optional
        Controls depth.  As written, ``hidden_layers + 1`` hidden degree
        vectors are drawn, giving ``hidden_layers + 2`` masked linear layers.
        NOTE(review): this looks like one more hidden layer than the name
        suggests; left unchanged to keep the parameter layout stable.

    Raises
    ------
    ValueError
        If ``size`` is not a positive multiple of the number of variables.
    """

    # Number of discrete variables per sample (columns of `batch`).
    NUM_VARS = 2

    def __init__(self, size=400, hidden_layers=1):
        super().__init__()

        if size <= 0 or size % self.NUM_VARS != 0:
            raise ValueError(
                "size must be a positive multiple of {}, got {}".format(self.NUM_VARS, size)
            )

        self.size = size
        bins = size // self.NUM_VARS

        # Degrees are assigned per *variable*, not per one-hot bit: every bit
        # of x has degree 0 and every bit of y has degree 1.  Giving each bit
        # its own degree would let a variable's logits see other bits of the
        # same variable, which leaks the value being predicted.
        ms = [np.repeat(np.arange(self.NUM_VARS), bins)]
        for _ in range(hidden_layers + 1):
            # Hidden degrees lie in [min previous degree, NUM_VARS - 2]; a
            # hidden unit with a larger degree could not feed any output.
            # With two variables this is always 0.
            ms.append(np.random.randint(ms[-1].min(), self.NUM_VARS - 1, size=size))

        # Input->hidden and hidden->hidden: a unit may see units of equal or
        # lower degree.
        masks = [
            torch.tensor(m[:, None] <= next_m[None, :]).float()
            for m, next_m in zip(ms, ms[1:])
        ]
        # Hidden->output: strictly lower degree only, so an output never
        # depends on its own variable.
        masks.append(torch.tensor(ms[-1][:, None] < ms[0][None, :]).float())

        layers = []
        for mask in masks:
            layers.append(MaskedLinear(size, size, mask))
            layers.append(nn.ReLU())
        layers.pop()  # no activation after the output layer

        self.model = nn.Sequential(*layers)

    def forward(self, batch):
        """Return the mean negative log-likelihood per variable for ``batch``.

        Parameters
        ----------
        batch : array-like of shape (N, 2)
            Integer samples; each column holds values in ``[0, size // 2)``.

        Returns
        -------
        torch.Tensor
            Scalar ``-sum(log p) / batch.size``, i.e. the summed negative
            log-probabilities of all entries divided by ``2 * N``.

        Raises
        ------
        ValueError
            If ``batch`` is not two-dimensional with ``NUM_VARS`` columns.
        """
        batch = np.asarray(batch)
        if batch.ndim != 2 or batch.shape[1] != self.NUM_VARS:
            raise ValueError(
                "batch must have shape (N, {}), got {}".format(self.NUM_VARS, batch.shape)
            )

        n_samples = batch.shape[0]
        bins = self.size // self.NUM_VARS

        # (N, 2, bins) one-hot, flattened to (N, size): x bits first, then y.
        one_hot = (batch[:, :, None] == np.arange(bins)).reshape(n_samples, self.size)
        matrix = torch.as_tensor(one_hot).float()

        # One group of `bins` logits per variable.
        logits = self.model(matrix).view(n_samples, self.NUM_VARS, bins)

        # log_softmax avoids the log(0) = -inf that softmax(...).log() can hit.
        log_probs = F.log_softmax(logits, dim=2)

        # Pick each variable's log-probability from *its own* group.  Indexing
        # the flattened (N, size) tensor with raw values would read both
        # entries from the first group.
        index = torch.as_tensor(batch).long().unsqueeze(2)
        picked = log_probs.gather(2, index)

        return -picked.sum() / batch.size