In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F 

from torch_geometric.nn.inits import uniform, glorot

In [2]:
class NTM(nn.Module):
    """Neural Turing Machine: feed-forward controller, one read head, one write head.

    The controller consumes the current input concatenated with the previous
    read vector; both heads address the same ``Memory`` instance, and a final
    linear layer maps ``[controller output, read vector]`` to the output.

    Args:
        vector_length: size of the output vector. The controller is built for
            inputs of ``vector_length + 1`` features (presumably data bits plus
            one extra channel such as a delimiter -- confirm against the caller).
        hidden_size: width of the controller's hidden/output layer.
        memory_size: ``(rows, row_width)`` of the external memory.
    """

    def __init__(self, vector_length, hidden_size, memory_size):
        super().__init__()
        self.controller = FeedForwardController(
            vector_length=vector_length + 1 + memory_size[1],
            hidden_size=hidden_size,
        )
        self.memory = Memory(memory_size)
        self.read_head = ReadHead(self.memory, hidden_size)
        self.write_head = WriteHead(self.memory, hidden_size)
        self.fc = nn.Linear(hidden_size + memory_size[1], vector_length)
        self.reset_parameters()

    def reset_parameters(self):
        """Initialize the output layer (Xavier-uniform weights, small normal bias)."""
        nn.init.xavier_uniform_(self.fc.weight, gain=1)
        nn.init.normal_(self.fc.bias, std=0.01)

    def get_initial_state(self, batch_size=1):
        """Reset the memory and return the initial recurrent state.

        Returns:
            Tuple ``(read, read_head_state, write_head_state)`` to pass as
            ``previous_state`` on the first call to ``forward``.
        """
        self.memory.reset(batch_size)
        read = self.memory.get_initial_read(batch_size)
        read_head_state = self.read_head.get_initial_state(batch_size)
        write_head_state = self.write_head.get_initial_state(batch_size)
        return (read, read_head_state, write_head_state)

    def forward(self, x, previous_state):
        """Run one time step.

        Args:
            x: input for this step, concatenated with the previous read vector
                along ``dim=1`` before entering the controller.
            previous_state: tuple ``(read, read_head_state, write_head_state)``
                as returned by ``get_initial_state`` or a previous ``forward``.

        Returns:
            ``(output, state)`` where ``output`` is the sigmoid of the final
            linear layer and ``state`` is the tuple for the next step.
        """
        # The original unpacked an undefined name (`previous_states`), which
        # raised NameError on every call; unpack the actual parameter.
        previous_read, previous_read_head_state, previous_write_head_state = previous_state

        controller_x = torch.cat([x, previous_read], dim=1)
        controller_x = self.controller(controller_x)
        # Read happens before write, so the read sees memory as left by the
        # previous step.
        read_head_output, read_head_state = self.read_head(controller_x, previous_read_head_state)
        write_head_state = self.write_head(controller_x, previous_write_head_state)

        fc_input = torch.cat((controller_x, read_head_output), dim=1)
        state = (read_head_output, read_head_state, write_head_state)
        # torch.sigmoid is the non-deprecated spelling of F.sigmoid.
        return torch.sigmoid(self.fc(fc_input)), state

In [3]:
class Memory(nn.Module):
    """External memory bank shared by the read and write heads.

    Holds a constant initial memory image (buffer ``initial_state``), a learnt
    initial read vector (parameter ``initial_read``) and, after ``reset``, the
    working memory tensor ``self.memory`` with a leading batch dimension.
    """

    def __init__(self, memory_size):
        super().__init__()
        self._memory_size = memory_size

        # Memory starts at a small constant value; stored as a buffer so it
        # follows the module across devices without being trained.
        bias = torch.ones(memory_size) * 1e-6
        self.register_buffer('initial_state', bias.data)

        # The first read vector is a trainable parameter.
        row_width = self._memory_size[1]
        self.initial_read = nn.Parameter(torch.randn(1, row_width) * 0.01)

    def get_size(self):
        """Return the ``memory_size`` passed to the constructor."""
        return self._memory_size

    def reset(self, batch_size):
        """Replace the working memory with ``batch_size`` copies of the initial image."""
        fresh = self.initial_state.clone()
        self.memory = fresh.repeat(batch_size, 1, 1)

    def get_initial_read(self, batch_size):
        """Return the learnt initial read vector tiled ``batch_size`` times along dim 0."""
        first_read = self.initial_read.clone()
        return first_read.repeat(batch_size, 1)

    def read(self):
        """Return the current working memory (requires a prior ``reset``)."""
        return self.memory

    def write(self, w, e, a):
        """Erase-then-add update of the working memory; returns the new memory.

        ``w`` is the write weighting, ``e`` the erase vector and ``a`` the add
        vector; each batch element contributes the outer products ``w x e`` and
        ``w x a``.
        """
        weighting = w.unsqueeze(-1)
        erase = torch.matmul(weighting, e.unsqueeze(1))
        self.memory = self.memory * (1 - erase)
        add = torch.matmul(weighting, a.unsqueeze(1))
        self.memory = self.memory + add
        return self.memory

    def size(self):
        """Same value as ``get_size``."""
        return self._memory_size

In [4]:
class Head(nn.Module):
    """Base addressing head shared by the read and write heads.

    Projects the controller output into the addressing parameters (key ``k``,
    key strength ``beta``, interpolation gate ``g``, 3-way shift distribution
    ``s``, sharpening exponent ``gamma``) and combines them into a normalized
    weighting over memory rows.

    Args:
        memory: object exposing ``get_size() -> (rows, row_width)``.
        hidden_size: number of features in the controller output.
    """

    def __init__(self, memory, hidden_size):
        super().__init__()
        self.memory = memory
        memory_length, memory_vector_length = memory.get_size()
        # (k : vector, beta: scalar, g: scalar, s: vector, gamma: scalar)
        self.k_layer = nn.Linear(hidden_size, memory_vector_length)
        self.beta_layer = nn.Linear(hidden_size, 1)
        self.g_layer = nn.Linear(hidden_size, 1)
        self.s_layer = nn.Linear(hidden_size, 3)
        self.gamma_layer = nn.Linear(hidden_size, 1)

        for layer in (self.k_layer, self.beta_layer, self.g_layer, self.s_layer, self.gamma_layer):
            nn.init.xavier_uniform_(layer.weight, gain=1.4)
            nn.init.normal_(layer.bias, std=0.01)

        # Learnt logits for the initial weighting, one per memory row.
        self._initial_state = nn.Parameter(torch.randn(1, memory_length) * 1e-5)

    def get_initial_state(self, batch_size):
        """Return the initial weighting, softmax-normalized and tiled per batch element."""
        return F.softmax(self._initial_state, dim=1).repeat(batch_size, 1)

    def get_head_weight(self, x, previous_state, memory_read):
        """Compute this step's weighting over memory rows.

        Args:
            x: controller output.
            previous_state: this head's weighting from the previous step.
            memory_read: current memory contents.

        Returns:
            Weighting with one row per batch element, normalized along dim 1.
        """
        k = self.k_layer(x)
        beta = F.softplus(self.beta_layer(x))
        # torch.sigmoid is the non-deprecated spelling of F.sigmoid.
        g = torch.sigmoid(self.g_layer(x))
        s = F.softmax(self.s_layer(x), dim=1)
        gamma = 1 + F.softplus(self.gamma_layer(x))
        # Focusing by content: softmax over beta-scaled cosine similarity
        # between each memory row and the key.
        similarity = F.cosine_similarity(memory_read + 1e-16, k.unsqueeze(1) + 1e-16, dim=-1)
        w_c = F.softmax(beta * similarity, dim=1)
        # Focusing by location: gate against the previous weighting, shift,
        # then sharpen and renormalize.
        w_g = g * w_c + (1 - g) * previous_state
        w_t = self.shift(w_g, s)
        w = w_t ** gamma
        w = torch.div(w, torch.sum(w, dim=1).unsqueeze(1) + 1e-16)
        return w

    def shift(self, w_g, s):
        """Circularly shift each weighting by its 3-tap distribution ``s``.

        For every batch row, ``out[i] = w[i-1]*s[0] + w[i]*s[1] + w[i+1]*s[2]``
        with indices wrapping around. Computed for the whole batch at once
        instead of looping over rows in Python and writing into a clone.

        Args:
            w_g: weightings, shape ``(batch, rows)``.
            s: shift distributions, shape ``(batch, 3)``.
        """
        assert s.size(-1) == 3
        from_previous = torch.roll(w_g, shifts=1, dims=1)   # element i holds w[i-1]
        from_next = torch.roll(w_g, shifts=-1, dims=1)      # element i holds w[i+1]
        return from_previous * s[:, 0:1] + w_g * s[:, 1:2] + from_next * s[:, 2:3]
    
def _convolve(w, s):
    """Apply a 3-tap kernel ``s`` to the 1-D tensor ``w`` with wrap-around.

    Output element ``i`` equals ``w[i-1]*s[0] + w[i]*s[1] + w[i+1]*s[2]``,
    indices taken circularly. The result has the same length as ``w``.
    """
    assert s.size(0) == 3
    # Wrap one element around each end so the un-padded conv1d below
    # behaves circularly and returns len(w) outputs.
    wrapped = torch.cat((w[-1:], w, w[:1]), dim=0)
    signal = wrapped.view(1, 1, -1)
    kernel = s.view(1, 1, -1)
    return F.conv1d(signal, kernel).view(-1)

In [5]:
class ReadHead(Head):
    """Head that reads from memory using the weighting computed by ``Head``."""

    def forward(self, x, previous_state):
        """Return ``(read_vector, weighting)`` for this step.

        The read vector is the weighting-weighted sum of memory rows; the
        weighting is returned as well so it can be fed back as
        ``previous_state`` on the next step.
        """
        contents = self.memory.read()
        weighting = self.get_head_weight(x, previous_state, contents)
        # Insert a singleton row dim so matmul mixes the memory rows per batch
        # element, then drop it again.
        read_vector = torch.matmul(weighting.unsqueeze(1), contents).squeeze(1)
        return read_vector, weighting


class WriteHead(Head):
    """Head that writes to memory using the weighting computed by ``Head``.

    Adds two projections of the controller output on top of the base head:
    an erase vector ``e`` (squashed to (0, 1)) and an add vector ``a``.
    """

    def __init__(self, memory, hidden_size):
        super().__init__(memory, hidden_size)
        memory_vector_length = memory.get_size()[1]
        self.e_layer = nn.Linear(hidden_size, memory_vector_length)
        self.a_layer = nn.Linear(hidden_size, memory_vector_length)
        for layer in (self.e_layer, self.a_layer):
            nn.init.xavier_uniform_(layer.weight, gain=1.4)
            nn.init.normal_(layer.bias, std=0.01)

    def forward(self, x, previous_state):
        """Update the memory in place and return this step's write weighting.

        Args:
            x: controller output.
            previous_state: this head's weighting from the previous step.
        """
        memory_read = self.memory.read()
        w = self.get_head_weight(x, previous_state, memory_read)
        # torch.sigmoid is the non-deprecated spelling of F.sigmoid.
        e = torch.sigmoid(self.e_layer(x))
        a = self.a_layer(x)

        # Side effect: the memory object replaces its stored tensor.
        self.memory.write(w, e, a)
        return w

In [6]:
# class Controller(nn.Module):
#     def __init__(self, vector_length, hidden_size):
#         super(Controller, self).__init__()
#         # We allow either a feed-forward network or a LSTM for the controller
#         self._controller = FeedForwardController(vector_length, hidden_size)

#     def forward(self, x):
#         return self._controller(x)

#     def get_initial_state(self, batch_size):
#         return self._controller.get_initial_state(batch_size)
    
class FeedForwardController(nn.Module):
    """Stateless two-layer ReLU controller.

    Args:
        vector_length: number of input features.
        hidden_size: width of both linear layers' outputs.
    """

    def __init__(self, vector_length, hidden_size):
        super().__init__()
        self.vector_length = vector_length
        self.hidden_size = hidden_size

        self.layer_1 = nn.Linear(vector_length, hidden_size)
        self.layer_2 = nn.Linear(hidden_size, hidden_size)

    def reset_parameter(self):
        """Re-draw both weight matrices uniformly in ``[-b, b]``, ``b = 5 / sqrt(in + hidden)``.

        Not invoked from ``__init__``; the layers keep ``nn.Linear``'s default
        initialization unless this is called explicitly. (Previously the method
        lacked ``self`` and referenced an un-imported ``np``, so it could not
        be called at all.)
        """
        bound = 5 / (self.vector_length + self.hidden_size) ** 0.5
        nn.init.uniform_(self.layer_1.weight, -bound, bound)
        nn.init.uniform_(self.layer_2.weight, -bound, bound)

    def forward(self, x):
        """Apply ``relu(layer_2(relu(layer_1(x))))``."""
        x = F.relu(self.layer_1(x))
        x = F.relu(self.layer_2(x))
        return x

    def get_initial_state(self, batch_size=None):
        """Return a placeholder state of ``0``; this controller keeps no state.

        ``batch_size`` is accepted (and ignored) so the call shape matches the
        other ``get_initial_state(batch_size)`` methods in this notebook.
        """
        return 0

In [7]:
# Scratch cell: build a 5-row, 3-column embedding table with max_norm=2 and
# look up every row; the lookup result is the cell's displayed output.
a = nn.Embedding(num_embeddings=5, embedding_dim=3, max_norm=2)
a(torch.arange(0, 5))

tensor([[-0.1285,  0.7534,  0.2341],
        [ 0.5962,  0.3908, -0.4896],
        [ 1.9124,  0.1074,  0.5756],
        [ 0.1425, -1.9427,  0.4533],
        [-0.2537,  0.1529,  0.5604]], grad_fn=<EmbeddingBackward>)