# Neural Turing Machines: Tutorial

#### Step 1: Import libraries 

In [11]:
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
from time import time
from tensorboardX import SummaryWriter
import torchvision.utils as vutils
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

#### Let's define the controller part

In [12]:
class Controller(nn.Module):
    """Feed-forward controller: two linear layers, each followed by a sigmoid.

    Args:
        num_inputs: Width of the concatenated ``[x, last_read]`` vector fed
            to the first layer.
        num_outputs: Width of the vector returned by :meth:`forward`.
        num_hiddens: Width of the hidden layer.
    """

    def __init__(self, num_inputs, num_outputs, num_hiddens):
        super(Controller, self).__init__()

        print("--- Initialize Controller")
        self.fc1 = nn.Linear(num_inputs, num_hiddens)
        self.fc2 = nn.Linear(num_hiddens, num_outputs)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise both linear layers (Xavier weights, small normal biases)."""
        # Same order as before (fc1 weight, fc1 bias, fc2 weight, fc2 bias) so a
        # fixed RNG seed yields the same parameters.
        for layer in (self.fc1, self.fc2):
            nn.init.xavier_uniform_(layer.weight, gain=1.4)
            # `nn.init.normal` (no trailing underscore) is deprecated; use the
            # in-place `normal_` like the other modules in this file.
            nn.init.normal_(layer.bias, std=0.01)

    def forward(self, x, last_read):
        """Return the controller output for one time step.

        Args:
            x: Current external input; concatenated with ``last_read`` along
                dim 1, so both must be 2-D with the same number of rows.
            last_read: Vector read from memory at the previous step.

        Returns:
            Tensor with ``num_outputs`` columns and values in (0, 1).
        """
        joined = torch.cat((x, last_read), dim=1)
        hidden = torch.sigmoid(self.fc1(joined))
        return torch.sigmoid(self.fc2(hidden))

#### Memory Part 

In [13]:
class Memory(nn.Module):
    """Shared addressing logic for the NTM read and write heads.

    Args:
        M: Number of memory slots (length of an addressing weight vector).
        N: Width of a memory slot (length of the key vector ``k``).
        controller_out: Width of the controller output. Not used here; the
            subclasses use it to size their own linear layers.
    """

    def __init__(self, M, N, controller_out):
        super(Memory, self).__init__()

        self.N = N
        self.M = M
        # Parameter vector layouts: key (N), beta (1), g (1), shift (3), gamma (1)
        # and, for writing, two extra N-wide vectors (add / erase).
        self.read_lengths = self.N + 1 + 1 + 3 + 1
        self.write_lengths = self.N + 1 + 1 + 3 + 1 + self.N + self.N
        self.w_last = []
        self.reset_memory()

    def get_weights(self):
        """Return the list of addressing weights recorded so far."""
        return self.w_last

    def reset_memory(self):
        """Forget the weight history and start again from an all-zero weighting."""
        self.w_last = [torch.zeros([1, self.M], dtype=torch.float32)]

    def address(self, k, beta, g, s, gamma, memory, w_last):
        """Compute the new addressing weights.

        Applies, in order: content similarity, interpolation with the previous
        weights, circular shift, and sharpening.

        Args:
            k: Key vector compared against every memory row.
            beta: Key strength multiplying the similarities before the softmax.
            g: Interpolation gate between content weights and ``w_last``.
            s: Shift weighting passed to ``_convolve`` (3 entries).
            gamma: Sharpening exponent.
            memory: Memory matrix (assumed ``(M, N)`` - confirm against caller).
            w_last: Weights produced at the previous step.

        Returns:
            The new weighting over memory slots.
        """
        # Content focus
        wc = self._similarity(k, beta, memory)
        # Location focus
        wg = self._interpolate(wc, g, w_last)
        w_hat = self._shift(wg, s)
        return self._sharpen(w_hat, gamma)

    def _similarity(self, k, beta, memory):
        """Softmax over ``beta``-scaled cosine similarity between ``k`` and memory rows."""
        # Cosine similarity
        cos = F.cosine_similarity(memory, k, dim=-1, eps=1e-16)
        return F.softmax(beta * cos, dim=-1)

    def _interpolate(self, wc, g, w_last):
        """Blend content weights with the previous weights using gate ``g``."""
        return g * wc + (1 - g) * w_last

    def _shift(self, wg, s):
        """Circularly shift ``wg`` by the shift weighting ``s``."""
        # The previous version first allocated a zero tensor that was
        # immediately overwritten; the convolution result is all that is needed.
        return _convolve(wg, s)

    def _sharpen(self, w_hat, gamma):
        """Raise weights to ``gamma`` and renormalise each row to sum to one."""
        powered = w_hat ** gamma
        # keepdim=True keeps the row sums aligned with their rows, so the
        # normalisation is also correct when there is more than one row
        # (identical result for the single-row case used in this file).
        return torch.div(powered, torch.sum(powered, dim=-1, keepdim=True) + 1e-16)

#### Reading Part

In [14]:
class ReadHead(Memory):
    """Read head: maps the controller output to addressing parameters and reads memory.

    Args:
        M: Number of memory slots.
        N: Width of a memory slot.
        controller_out: Width of the controller output fed to :meth:`forward`.
    """

    def __init__(self, M, N, controller_out):
        super(ReadHead, self).__init__(M, N, controller_out)

        print("--- Initialize Memory: ReadHead")
        self.fc_read = nn.Linear(controller_out, self.read_lengths)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the projection layer (Xavier weights, small normal bias)."""
        nn.init.xavier_uniform_(self.fc_read.weight, gain=1.4)
        # `nn.init.normal` is deprecated; `normal_` matches WriteHead and NTM.
        nn.init.normal_(self.fc_read.bias, std=0.01)

    def read(self, memory, w):
        """Read from memory (according to section 3.1)."""
        return torch.matmul(w, memory)

    def forward(self, x, memory):
        """Address memory from the controller output and return what was read.

        Args:
            x: Controller output (2-D; split along dim 1).
            memory: Current memory matrix.

        Returns:
            Tuple ``(read_vector, w)`` where ``w`` is the weighting that was
            used; ``w`` is also appended to ``self.w_last``.
        """
        param = self.fc_read(x)
        k, beta, g, s, gamma = torch.split(param, [self.N, 1, 1, 3, 1], dim=1)

        # Squash each raw parameter into the range its role requires.
        k = torch.tanh(k)
        beta = F.softplus(beta)            # non-negative key strength
        g = torch.sigmoid(g)               # gate in (0, 1)
        s = F.softmax(s, dim=1)            # shift weights sum to one
        gamma = 1 + F.softplus(gamma)      # sharpening exponent >= 1

        w = self.address(k, beta, g, s, gamma, memory, self.w_last[-1])
        self.w_last.append(w)
        return self.read(memory, w), w


#### Writing Part

In [15]:
class WriteHead(Memory):
    """Write head: maps the controller output to addressing parameters and updates memory.

    Args:
        M: Number of memory slots.
        N: Width of a memory slot.
        controller_out: Width of the controller output fed to :meth:`forward`.
    """

    def __init__(self, M, N, controller_out):
        super(WriteHead, self).__init__(M, N, controller_out)

        print("--- Initialize Memory: WriteHead")
        self.fc_write = nn.Linear(controller_out, self.write_lengths)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the projection layer (Xavier weights, small normal bias)."""
        nn.init.xavier_uniform_(self.fc_write.weight, gain=1.4)
        nn.init.normal_(self.fc_write.bias, std=0.01)

    def write(self, memory, w, e, a):
        """Write to memory (according to section 3.2).

        Args:
            memory: Current memory matrix.
            w: Write weighting over memory slots (single row).
            e: Erase vector.
            a: Add vector.

        Returns:
            The updated memory: ``memory * (1 - outer(w, e)) + outer(w, a)``.
        """
        # Flatten to 1-D explicitly. A blanket squeeze() would also collapse a
        # length-1 slot or width dimension into a scalar and break the outer
        # product.
        w = w.reshape(-1)
        e = e.reshape(-1)
        a = a.reshape(-1)

        # torch.ger is deprecated; torch.outer is its documented replacement.
        erase = torch.outer(w, e)
        add = torch.outer(w, a)

        return memory * (1 - erase) + add

    def forward(self, x, memory):
        """Address memory from the controller output and apply the write.

        Args:
            x: Controller output (2-D; split along dim 1).
            memory: Current memory matrix.

        Returns:
            Tuple ``(updated_memory, w)`` where ``w`` is the weighting that was
            used; ``w`` is also appended to ``self.w_last``.
        """
        param = self.fc_write(x)

        k, beta, g, s, gamma, a, e = torch.split(param, [self.N, 1, 1, 3, 1, self.N, self.N], dim=1)

        # Squash each raw parameter into the range its role requires.
        k = torch.tanh(k)
        beta = F.softplus(beta)            # non-negative key strength
        g = torch.sigmoid(g)               # gate in (0, 1)
        s = F.softmax(s, dim=-1)           # shift weights sum to one
        gamma = 1 + F.softplus(gamma)      # sharpening exponent >= 1
        a = torch.tanh(a)
        e = torch.sigmoid(e)               # erase amounts in (0, 1)

        w = self.address(k, beta, g, s, gamma, memory, self.w_last[-1])
        self.w_last.append(w)
        return self.write(memory, w, e, a), w


In [16]:
def _convolve(w, s):
    """Apply a 3-tap shift kernel to a weighting with wrap-around at both ends.

    Args:
        w: Weighting to shift; squeezed to 1-D before use.
        s: Shift kernel of shape ``(1, 3)``.

    Returns:
        Tensor of shape ``(1, len(w))`` where entry ``i`` equals
        ``s[0] * w[i - 1] + s[1] * w[i] + s[2] * w[i + 1]`` (indices modulo
        the length).
    """
    batch, taps = s.shape
    assert batch == 1, 'does _convolve work for b != 1?'
    assert taps == 3

    flat = torch.squeeze(w)
    # Pad with the last element in front and the first element behind so the
    # kernel sees circular neighbours at the borders.
    wrapped = torch.cat([flat[-1:], flat, flat[:1]])
    signal = wrapped.view(1, 1, -1)
    kernel = s.view(1, 1, -1)
    return F.conv1d(signal, kernel).view(batch, -1)

In [17]:
class NTM(nn.Module):
    """Neural Turing Machine with one read head and one write head.

    Args:
        M: Number of memory slots.
        N: Width of a memory slot.
        num_inputs: Width of the external input vector.
        num_outputs: Width of the output vector.
        controller_out_dim: Width of the controller output.
        controller_hid_dim: Width of the controller hidden layer.
    """

    def __init__(self, M, N, num_inputs, num_outputs, controller_out_dim, controller_hid_dim):
        super(NTM, self).__init__()

        print("----------- Build Neural Turing machine -----------")
        self.num_inputs = num_inputs
        self.num_outputs = num_outputs
        self.M = M
        self.N = N

        # Per-sequence state; replaced by initialize_state() before each sequence.
        self.memory = torch.zeros(self.M, self.N)
        self.last_read = torch.zeros(1, self.N)

        self.controller = Controller(self.num_inputs + self.N, controller_out_dim, controller_hid_dim)
        self.read_head = ReadHead(self.M, self.N, controller_out_dim)
        self.write_head = WriteHead(self.M, self.N, controller_out_dim)

        self.fc_out = nn.Linear(self.num_inputs + N, self.num_outputs)
        self.reset_parameters()

    def forward(self, X=None):
        """Run one time step and return the output vector.

        Args:
            X: External input of shape ``(1, num_inputs)``. When omitted, an
                all-zero input is used.

        Returns:
            Sigmoid-activated output computed from ``X`` concatenated with the
            vector just read from memory.
        """
        if X is None:
            X = torch.zeros(1, self.num_inputs)

        controller_out = self.controller(X, self.last_read)
        self._read_write(controller_out)

        joined = torch.cat((X, self.last_read), -1)
        return torch.sigmoid(self.fc_out(joined))

    def _read_write(self, controller_out):
        """Read from memory, then write to it, updating the stored state."""
        # READ
        read, _ = self.read_head(controller_out, self.memory)
        self.last_read = read

        # WRITE
        mem, _ = self.write_head(controller_out, self.memory)
        self.memory = mem

    def initialize_state(self):
        """Draw a fresh random memory and read vector and clear head histories."""
        stdev = 1 / (np.sqrt(self.N + self.M))
        # torch.empty replaces the legacy uninitialised torch.Tensor(M, N)
        # constructor; uniform_ overwrites every entry either way.
        self.memory = nn.init.uniform_(torch.empty(self.M, self.N), -stdev, stdev)
        self.last_read = torch.tanh(torch.randn(1, self.N))

        self.read_head.reset_memory()
        self.write_head.reset_memory()

    # Original (misspelled) name, kept so existing callers continue to work.
    initalize_state = initialize_state

    def reset_parameters(self):
        """Re-initialise the output layer (Xavier weights, normal bias)."""
        nn.init.xavier_uniform_(self.fc_out.weight, gain=1.4)
        nn.init.normal_(self.fc_out.bias, std=0.5)

    def get_memory_info(self):
        """Return ``(memory, read_weight_history, write_weight_history)``."""
        return self.memory, self.read_head.get_weights(), self.write_head.get_weights()

    def calculate_num_params(self):
        """Returns the total number of parameters."""
        return sum(p.numel() for p in self.parameters())


In [18]:
class BinaySeqDataset(Dataset):
    """Dataset of random binary sequences framed by start and end markers.

    ``args`` is a mapping providing ``'sequence_length'``, ``'token_size'``
    and ``'training_samples'``. Every item is generated on the fly; the index
    passed to ``__getitem__`` is not used.
    """

    def __init__(self, args):
        self.seq_len = args['sequence_length']
        self.seq_width = args['token_size']
        self.dataset_dim = args['training_samples']

    def _generate_seq(self):
        """Return ``(inp, outp)`` for one freshly sampled sequence.

        ``inp`` has ``seq_len + 2`` rows: a start row with a 1 in column 0,
        the random bits, and an end row with a 1 in the last column.
        ``outp`` holds the random bits only. Both are float tensors.
        """
        shape = (self.seq_len, self.seq_width)
        bits = torch.from_numpy(np.random.binomial(1, 0.5, shape))

        end_row = self.seq_len + 1
        framed = torch.zeros(end_row + 1, self.seq_width)
        # Payload occupies the rows between the two marker rows.
        framed[1:end_row, :self.seq_width] = bits.clone()
        # Start marker, then end marker.
        framed[0, 0] = 1.0
        framed[end_row, self.seq_width - 1] = 1.0

        target = bits.data.clone()
        return framed.float(), target.float()

    def __len__(self):
        return self.dataset_dim

    def __getitem__(self, idx):
        return self._generate_seq()

In [19]:
def clip_grads(net, min_grad=None, max_grad=None):
    """Clamp every existing parameter gradient of ``net`` in place.

    Args:
        net: Module whose parameters' gradients are clamped. Parameters
            without a gradient are skipped.
        min_grad: Lower bound. When omitted, falls back to
            ``args['min_grad']`` from the module-level ``args`` dict, which
            was the only source in the original version.
        max_grad: Upper bound. When omitted, falls back to
            ``args['max_grad']``.
    """
    # Resolve bounds once; the global lookup only happens when the caller did
    # not pass explicit limits, so the function is usable outside the script.
    if min_grad is None:
        min_grad = args['min_grad']
    if max_grad is None:
        max_grad = args['max_grad']

    for p in net.parameters():
        if p.grad is not None:
            p.grad.data.clamp_(min_grad, max_grad)

In [20]:
# Training script: builds the dataset and the NTM, then performs one
# optimisation step per generated sequence.
if __name__ == "__main__":
    # Hyper-parameters and file locations. clip_grads() reads 'min_grad' /
    # 'max_grad' from this module-level dict.
    args = {'sequence_length':300,'token_size':10,'memory_capacity':64,'memory_vector_size':128,'training_samples':99,
            'controller_output_dim':256,'controller_hidden_dim':512,'learning_rate':1e-4,'min_grad':-10,'max_grad':10,
           'logdir':'./','loadmodel':'','savemodel':'checkpoint.model'}
    # NOTE(review): every writer.* call below is commented out, so this writer
    # is currently created but never written to.
    writer = SummaryWriter()
    dataset = BinaySeqDataset(args)
    dataloader = DataLoader(dataset, batch_size=1,
                            shuffle=True, num_workers=4)

    model = NTM(M=args['memory_capacity'],
                N=args['memory_vector_size'],
                num_inputs=args['token_size'],
                num_outputs=args['token_size'],
                controller_out_dim=args['controller_output_dim'],
                controller_hid_dim=args['controller_hidden_dim'],
                )

    print(model)

    criterion = torch.nn.BCELoss()
    optimizer = torch.optim.RMSprop(model.parameters(), lr=args['learning_rate'])

    print("--------- Number of parameters -----------")
    print(model.calculate_num_params())
    print("--------- Start training -----------")

    # Losses collected since the last report (cleared every 50 iterations).
    losses = []

    # Optionally resume from a saved state dict.
    if args['loadmodel'] != '':
        model.load_state_dict(torch.load(args['loadmodel']))

    # One iteration per sample: X is the framed input sequence, Y the target.
    for e, (X, Y) in enumerate(dataloader):
        tmp = time()
        # Fresh memory / read vector / head histories for every sequence.
        model.initalize_state()
        optimizer.zero_grad()

        # The input carries two extra rows (start and end markers).
        inp_seq_len = args['sequence_length'] + 2
        out_seq_len = args['sequence_length']

        X.requires_grad = True

        # Network input: the sequence (outputs are discarded in this phase)
        for t in range(0, inp_seq_len):
            model(X[:, t])

        # Network input: none (model() substitutes an all-zero input)
        y_pred = torch.zeros(Y.size())
        for i in range(0, out_seq_len):
            y_pred[:, i] = model()

        loss = criterion(y_pred, Y)
        loss.backward()
        clip_grads(model)
        optimizer.step()
        losses += [loss.item()]

        # Periodic reporting.
        if e % 50 == 0:
            # NOTE(review): mean_loss is computed but the print below reports
            # the loss of the current iteration only - confirm which is intended.
            mean_loss = np.array(losses[-50:]).mean()
            print("Loss: ", loss.item())
#             writer.add_scalar('Mean loss', loss.item(), e)
            if e % 1000 == 0:
#                 for name, param in model.named_parameters():
#                     writer.add_histogram(name, param.clone().cpu().data.numpy(), e)
                # Only consumed by the commented-out visualisation code below.
                mem_pic, read_pic, write_pic = model.get_memory_info()
#                 pic1 = vutils.make_grid(y_pred, normalize=True, scale_each=True)
#                 pic2 = vutils.make_grid(Y, normalize=True, scale_each=True)
#                 pic3 = vutils.make_grid(mem_pic, normalize=True, scale_each=True)
#                 pic4 = vutils.make_grid(read_pic, normalize=True, scale_each=True)
#                 pic5 = vutils.make_grid(write_pic, normalize=True, scale_each=True)
#                 writer.add_image('NTM output', pic1, e)
#                 writer.add_image('True output', pic2, e)
#                 writer.add_image('Memory', pic3, e)
#                 writer.add_image('Read weights', pic4, e)
#                 writer.add_image('Write weights', pic5, e)
#                 torch.save(model.state_dict(), args.savemodel)
            # Start a new reporting window.
            losses = []


----------- Build Neural Turing machine -----------
--- Initialize Controller
--- Initialize Memory: ReadHead
--- Initialize Memory: WriteHead
NTM(
  (controller): Controller(
    (fc1): Linear(in_features=138, out_features=512, bias=True)
    (fc2): Linear(in_features=512, out_features=256, bias=True)
  )
  (read_head): ReadHead(
    (fc_read): Linear(in_features=256, out_features=134, bias=True)
  )
  (write_head): WriteHead(
    (fc_write): Linear(in_features=256, out_features=390, bias=True)
  )
  (fc_out): Linear(in_features=138, out_features=10, bias=True)
)
--------- Number of parameters -----------
338554
--------- Start training -----------


  if sys.path[0] == '':
  del sys.path[0]
  from ipykernel import kernelapp as app
  app.launch_new_instance()


Loss:  1.266266107559204


KeyboardInterrupt: 