In [205]:
import os

import numpy as np
import torch
import torchvision.datasets as dset
import torch.nn as nn
from torch.distributions import MultivariateNormal
import torchvision.transforms as transforms

import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO, TraceEnum_ELBO
from pyro.optim import Adam

from utils.custom_mlp import MLP, Exp
from utils.mnist_cached import MNISTCached, mkdir_p

In [206]:
pyro.enable_validation(True)
pyro.distributions.enable_validation(False)
pyro.set_rng_seed(0)
pyro.clear_param_store()
# Enable smoke test - run the notebook cells on CI.
smoke_test = 'CI' in os.environ

In [207]:
# for loading and batching MNIST dataset
def setup_data_loaders(batch_size=128, use_cuda=False):
    root = './data'
    download = True
    trans = transforms.ToTensor()
    train_set = dset.MNIST(root=root, train=True, transform=trans,
                           download=download)
    test_set = dset.MNIST(root=root, train=False, transform=trans)

    kwargs = {'num_workers': 1, 'pin_memory': use_cuda}
    train_loader = torch.utils.data.DataLoader(dataset=train_set,
        batch_size=batch_size, shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(dataset=test_set,
        batch_size=batch_size, shuffle=False, **kwargs)
    return train_loader, test_loader

In [208]:
def CustomLinear(last_layer_size, layer_size, use_cuda):
    # get our nn layer module (in this case nn.Linear by default)
    cur_linear_layer = nn.Linear(last_layer_size, layer_size)
    # for numerical stability -- initialize the layer properly
    cur_linear_layer.weight.data.normal_(0, 0.001)
    cur_linear_layer.bias.data.normal_(0, 0.001)
    # use GPUs to share data during training (if available)
    if use_cuda:
        cur_linear_layer = nn.DataParallel(cur_linear_layer)
    return cur_linear_layer

In [209]:
class Memory():
    def __init__(self, n_address, n_memory_vec, n_address_vec, m_value = 1.0):
        '''
        
        '''
        self.y_dim = n_memory_vec
        self.z_dim = n_memory_vec
        self.n_address = n_address
        self.n_memory_vec = n_memory_vec
        self.n_address_vec = n_address_vec
        self.m_value = 1.0
        self.R = torch.rand(self.n_address, self.n_memory_vec)
        self.U = torch.eye(self.n_address)
        self.V = torch.eye(self.n_memory_vec)
        self.M = None
        
        self.A = nn.Parameter(torch.eye(self.n_address_vec, self.n_address))
        self.fc_y1 = nn.Linear(self.y_dim, self.n_address_vec//2)
        self.fc_y2 = nn.Linear(self.n_address_vec//2, self.n_address_vec)
        
    
    def forward(self, _input, batch_size):
        _y = self.fc_y1(_input)
        bt = self.fc_y2(_y+0.08*torch.randn(_y.size()))
        wt = self.A(bt)
        return wt.view(batch_size, 1, -1)
#         self.M = MultivariateNormal(self.R.view(-1, self.n_address*self.n_memory_vec), covariance_matrix=kronecker_product(self.V, self.U))
    
    
    def write(self, y_list, Z, v_sigma):
        W = None
        len_t = len(y_list)
        for _y in y_list:
            _y = self.fc_y1(_y)
            bt = self.fc_y2(_y)
            wt = self.A(bt)
            if W is None:
                W = wt.view(1, self.n_address)
            else:
                W = torch.cat((W, wt), 0)
        dd = Z - torch.matmul(W, self.R)
        sigma_c  = torch.matmul(W, self.U)
        sigma_gusai = torch.eye(len_t)
        sigma_z = torch.matmul(torch.matmul(W, self.U), torch.t(W)) + sigma_gusai * (v_sigma*v_sigma)
        R = self.R + torch.matmul(torch.matmul(torch.t(sigma_c),torch.inverse(sigma_z)), dd)
        U = self.U - torch.matmul(torch.matmul(torch.t(sigma_c),torch.inverse(sigma_z)), sigma_c)

In [210]:
class Encoder1(nn.Module):
    '''
    return q(y|x) and q(z|x)
    '''
    def __init__(self, D_in, H, C):
        super(Encoder1, self).__init__()
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(H, 2*C)
        self.linear3y = torch.nn.Linear(2*C, C)
        self.linear3z = torch.nn.Linear(2*C, C)

    def forward(self, x):
        x = F.relu(self.linear1(x))
        x = F.relu(self.linear2(x))
        y = F.relu(self.linear3y(x))
        z = F.relu(self.linear3z(x))
        return y, z


class Encoder2(nn.Module):
    '''
    return q(z|x, y, M) from 2 C-size vector(q(z'|x) and p(z|y, M))
    '''
    def __init__(self, C, H):
        super(Encoder1, self).__init__()
        self.linear1 = torch.nn.Linear(2*C, H)
        self.linear2 = torch.nn.Linear(H, C)

    def forward(self, x1, x2):
        x = torch.cat((x1, x2), 1)
        x = F.relu(self.linear1(x))
        z = F.relu(self.linear2(x))
        return z


class Decoder(nn.Module):
    '''
    return q(x|z)
    '''
    def __init__(self, D_in, H, D_out):
        super(Decoder, self).__init__()
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(H, D_out)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = F.relu(self.linear1(x))
        return self.sigmoid(self.linear2(x))

In [220]:
class KanervaMachine(nn.Module):
    '''
    '''
    def __init__(
        self,
        n_address=40,
        n_memory_vec=50,
        n_address_vec=25,
        n_input=784,
        n_hidden=100,
        batch_size=10,
        aux_loss_multiplier=None,
        use_cuda=False
    ):
        super(KanervaMachine, self).__init__()
        if use_cuda:
            self.cuda()
        self.use_cuda = use_cuda
        self.n_address = n_address
        self.n_memory_vec = n_memory_vec
        self.n_address_vec = n_address_vec
        self.n_latent_vec = n_memory_vec # the same size as one of the memory-cols
        self.batch_size = batch_size
        self.aux_loss_multiplier = aux_loss_multiplier
        self.n_input = n_input
        self.n_hidden = n_hidden
        self.A = CustomLinear(self.n_address_vec, self.n_address, use_cuda) ## Pointer Matrix
        self.R = None ##  (n_address × n_memory_vec) matrix as the mean of M
        self.U = None ## (n_address × n_address matrix) that provides the covariance between rows of M
        self.V = None ## (n_memory_vec × n_memory_vec) matrix that provides the covariance between cols of M
        self.encoder1 = None
        self.encoder2 = None
        self.mlp_memory = None
        self.decoder = None
        self.allow_broadcast = False
        self.setup_networks()
    
    def setup_networks(self):
        self.encoder1 = MLP(
            [self.n_input] + [self.n_hidden, ] + [[self.n_latent_vec, self.n_latent_vec, self.n_latent_vec, self.n_latent_vec]],
            activation=nn.Softplus,
            output_activation=[None, None, Exp, Exp],
            allow_broadcast=self.allow_broadcast,
            use_cuda=self.use_cuda
        )
        self.encoder2 = MLP(
            [self.n_latent_vec + self.n_latent_vec] + [self.n_hidden, ] + [[self.n_latent_vec, self.n_latent_vec]],
            activation=nn.Softplus,
            output_activation=[None, Exp],
            allow_broadcast=self.allow_broadcast,
            use_cuda=self.use_cuda
        )
        self.decoder = MLP(
            [self.n_latent_vec] + [self.n_hidden, ] + [self.n_input],
            activation=nn.Softplus,
            output_activation=nn.Sigmoid,
            allow_broadcast=self.allow_broadcast,
            use_cuda=self.use_cuda
        )
        self.mlp_memory = MLP(
            [self.n_latent_vec] + [self.n_address_vec//2, ] + [self.n_address_vec],
            activation=nn.Softplus,
            output_activation=nn.Softplus,
            allow_broadcast=self.allow_broadcast,
            use_cuda=self.use_cuda
        )
        self.R = torch.zeros(self.n_address, self.n_memory_vec, requires_grad=True)
        self.U = torch.eye(self.n_address, requires_grad=True)
        self.V = torch.eye(self.n_memory_vec, requires_grad=True)
        
    def model(self, xs):
        """
        The model corresponds to the following generative process:
        p(z) = normal(0,I)              # handwriting style (latent)
        p(y|x) = categorical(I/10.)     # which digit (semi-supervised)
        p(x|y,z) = bernoulli(loc(y,z))   # an image
        loc is given by a neural network  `decoder`
        :param xs: a batch of scaled vectors of pixels from an image
        :param ys: (optional) a batch of the class labels i.e.
                   the digit corresponding to the image(s)
        :return: None
        """
        # register this pytorch module and all of its sub-modules with pyro
#         print("model")
        pyro.module("kanerva_machine", self)
        
        batch_size = xs.size(0)
        with pyro.iarange("data", xs.size(0)):
            
            # sample y, the latent vector as the key of the memory, from the constant prior distribution
            y_prior_loc = xs.new_zeros([batch_size, self.n_latent_vec])
            y_prior_scale = xs.new_ones([batch_size, self.n_latent_vec])
            ys = pyro.sample("y", dist.Normal(y_prior_loc, y_prior_scale).independent(1))
            
            # sample z, the latent vector to generate an image, from the memory with "y".
            b = self.mlp_memory.forward(ys)
            w = self.A(b)
            z_loc = torch.matmul(w, self.R) # (batch_size, n_memory_vec)
            z_scale = torch.ones(batch_size, self.n_memory_vec)
            zs = pyro.sample("z", dist.Normal(z_loc, z_scale).independent(1))
            
            # sample x, the target images, with "z".
            x_loc = self.decoder.forward(zs)
            pyro.sample("x", dist.Bernoulli(x_loc).independent(1), obs=xs)
            
            return x_loc
    
    def guide(self, xs, len_t=30):
        """
        The guide corresponds to the following:
        q(y|x) = categorical(alpha(x))              # infer digit from an image
        q(z|x,y) = normal(loc(x,y),scale(x,y))       # infer handwriting style from an image and the digit
        loc, scale are given by a neural network `encoder_z`
        alpha is given by a neural network `encoder_y`
        :param xs: a batch of scaled vectors of pixels from an image
        :param ys: (optional) a batch of the class labels i.e.
                   the digit corresponding to the image(s)
        :return: None
        """
        # Writing Phase:
        #  sample y, z from q(y|x) and q(z|x)
#         print("guide")
        with pyro.iarange("data", xs.size(0)):
            y_loc_w, z_loc_w, y_scale_w, z_scale_w = self.encoder1.forward(xs)
            ys = pyro.sample("y", dist.Normal(y_loc_w, y_scale_w).independent(1))
            b = self.mlp_memory.forward(ys)
            w = self.A(b)
            self.write_inference(w, z_scale_w)
        
        #Reading Phase:
        # inform Pyro that the variables in the batch of xs, ys are conditionally independent
            batch_size = xs.size(0)
            b = self.mlp_memory.forward(ys)
            w = self.A(b)
            pre_z_loc = torch.matmul(w, self.R) # (batch_size, n_memory_vec)
            z_loc, z_scale = self.encoder2([z_loc_w, pre_z_loc])
            zs = pyro.sample("z", dist.Normal(z_loc, z_scale).independent(1))
    
    def write_inference(self, W, Z, v_sigma=1):
        '''
        Updating the Memory.
        '''
        dd = Z - torch.matmul(W, self.R)
        sigma_c  = torch.matmul(W, self.U)
        sigma_gusai = torch.eye(W.size(0))
        sigma_z = torch.matmul(torch.matmul(W, self.U), torch.t(W)) + sigma_gusai * (v_sigma*v_sigma)
        self.R = self.R + torch.matmul(torch.matmul(torch.t(sigma_c),torch.inverse(sigma_z)), dd)
        self.U = self.U - torch.matmul(torch.matmul(torch.t(sigma_c),torch.inverse(sigma_z)), sigma_c)
    
    def init_memory(self):
        self.R = torch.zeros(self.n_address, self.n_memory_vec, requires_grad=True)
        self.U = torch.eye(self.n_address, requires_grad=True)
        self.V = torch.eye(self.n_memory_vec, requires_grad=True)

In [226]:
def train(svi, train_loader, use_cuda=False):
    # initialize loss accumulator
    epoch_loss = 0.
    # do a training epoch over each mini-batch x returned
    # by the data loader
    for _, (x, _) in enumerate(train_loader):
        # if on GPU put mini-batch into CUDA memory
        if use_cuda:
            x = x.cuda()
        # do ELBO gradient and accumulate loss
        b_size = x.size(0)
        i_size = x.size(-1)*x.size(-1)
        epoch_loss += svi.step(x.view(b_size, i_size))
        print(epoch_loss)

    # return epoch loss
    normalizer_train = len(train_loader.dataset)
    total_epoch_loss_train = epoch_loss / normalizer_train
    return total_epoch_loss_train

In [231]:
def evaluate(svi, test_loader, use_cuda=False):
    # initialize loss accumulator
    test_loss = 0.
    # compute the loss over the entire test set
    for i, (x, _) in enumerate(test_loader):
#         km.init_memory()
        # if on GPU put mini-batch into CUDA memory
        if use_cuda:
            x = x.cuda()
        # compute ELBO estimate and accumulate loss
        b_size = x.size(0)
        i_size = x.size(-1)*x.size(-1)
        test_loss += svi.evaluate_loss(x.view(b_size, i_size))
    normalizer_test = len(test_loader.dataset)
    total_epoch_loss_test = test_loss / normalizer_test
    return total_epoch_loss_test

In [232]:
km = KanervaMachine()

adam_params = {"lr": 0.00042, "betas": (0.9, 0.999)}
optimizer = Adam(adam_params)

svi = SVI(km.model, km.guide, optimizer, loss=TraceEnum_ELBO(max_iarange_nesting=1), num_particles=7)

In [235]:
USE_CUDA = False
NUM_EPOCHS = 100
TEST_FREQUENCY = 5
train_loader, test_loader = setup_data_loaders(batch_size=256, use_cuda=USE_CUDA)

In [236]:
train_elbo = []
test_elbo = []
# training loop
for epoch in range(NUM_EPOCHS):
    total_epoch_loss_train = train(svi, train_loader, use_cuda=USE_CUDA)
    train_elbo.append(-total_epoch_loss_train)
    print("[epoch %03d]  average training loss: %.4f" % (epoch, total_epoch_loss_train))

    if epoch % TEST_FREQUENCY == 0:
        # report test diagnostics
        total_epoch_loss_test = evaluate(svi, test_loader, use_cuda=USE_CUDA)
        test_elbo.append(-total_epoch_loss_test)
        print("[epoch %03d] average test loss: %.4f" % (epoch, total_epoch_loss_test))

152288.515625
304689.921875
457300.5
609978.0
762715.0625
915330.28125
1067963.921875
1221261.375
1374337.65625
1527460.453125
1680886.84375
1834427.296875
1987942.171875
2141569.703125
2295400.359375
2449215.78125
2602829.984375
2756576.828125
2910191.671875
3064162.21875
3217843.546875
3371966.640625
3525916.25
3679610.265625
3833900.703125
3987880.390625
4142318.796875
4296252.171875
4450515.859375
4604649.03125
4758846.15625
4912894.03125
5067381.3125
5221869.03125
5376376.109375
5531120.15625
5685947.0625
5840700.78125
5995407.84375
6150243.53125
6305127.71875
6459654.25
6614178.71875
6768865.4375
6923582.59375
7078026.796875
7232596.5625
7387488.421875
7542311.1875
7697361.390625
7852392.421875
8007134.6875
8162143.75
8317054.265625
8471656.796875
8626795.84375
8781761.671875
8936864.765625
9091619.15625
9246555.390625
9401475.078125
9556625.03125
9711718.0625
9866592.734375
10021450.484375
10176413.5
10331313.46875
10486548.640625
10641249.25
10796491.234375
10951833.4375
111066



[epoch 000] average test loss: 608.8538
155875.125
311612.921875
467689.40625
623170.015625
779243.6875
935361.421875
1091272.328125
1247361.9375
1403190.109375
1559162.25
1714986.34375
1870671.15625
2026426.65625
2182581.765625
2338710.03125
2494271.578125
2650099.5
2806052.875
2962397.265625
3118352.53125
3273977.734375
3429894.375
3585761.21875
3742060.203125
3897708.578125
4053432.265625
4209310.625
4365051.484375
4520900.671875
4676783.40625
4832720.328125
4988806.640625
5144515.28125
5300195.34375
5455950.21875
5611767.6875
5768328.9375
5924262.578125
6079902.921875
6235890.6875
6391976.78125
6547777.46875
6703909.65625
6859526.125
7015469.1875
7171541.96875
7327699.53125
7483488.46875
7639425.859375
7795241.25
7951086.203125
8107283.609375
8262937.359375
8418594.765625
8574298.6875
8730230.78125
8886235.453125
9042156.75
9198353.78125
9354502.5
9510604.140625
9666648.609375
9822533.921875
9978681.09375
10134741.859375
10290820.78125
10446864.375
10603037.640625
10758915.1875
109

Process Process-38:
Traceback (most recent call last):
  File "/Users/YumaKajihara/.pyenv/versions/anaconda3-2.5.0/lib/python3.5/multiprocessing/process.py", line 252, in _bootstrap
    self.run()
  File "/Users/YumaKajihara/.pyenv/versions/anaconda3-2.5.0/lib/python3.5/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/Users/YumaKajihara/.pyenv/versions/anaconda3-2.5.0/lib/python3.5/site-packages/torch/utils/data/dataloader.py", line 52, in _worker_loop
    r = index_queue.get()
  File "/Users/YumaKajihara/.pyenv/versions/anaconda3-2.5.0/lib/python3.5/multiprocessing/queues.py", line 335, in get
    res = self._reader.recv_bytes()
  File "/Users/YumaKajihara/.pyenv/versions/anaconda3-2.5.0/lib/python3.5/multiprocessing/connection.py", line 216, in recv_bytes
    buf = self._recv_bytes(maxlength)
  File "/Users/YumaKajihara/.pyenv/versions/anaconda3-2.5.0/lib/python3.5/multiprocessing/connection.py", line 407, in _recv_bytes
    buf = se

KeyboardInterrupt: 

In [185]:
??SVI.step()