In [1]:
import torch
from torch import nn
from torch.nn import Parameter
from torch.optim import Adam
from torch.nn.utils import clip_grad_norm_
import torch.nn.functional as F

import numpy as np

In [2]:
def batch_to_one_hot(batch_cat_id, num_cats):
    """
    Arguments
    ---------
    batch_cat_id : torch.tensor [bs, seq_len, 1]
    
    Returns
    -------
    batch_cat_OH : torch.tensor [bs, seq_len, num_cats]
    
    """
    cat_samples = batch_cat_id.chunk(len(batch_cat_id), dim = 0)
    batch_cat_OH = list()
    for cat_sample in cat_samples:
        cat_id = cat_sample.squeeze()
        cat_OH = torch.zeros(len(cat_id), num_cats)
        cat_OH[torch.arange(len(cat_id)), cat_id] = 1
        batch_cat_OH.append(cat_OH)

    return torch.stack(batch_cat_OH, dim = 0)

In [162]:
def CESequenceLoss(p_y_x, y):
    """Cross Entropy Loss for Sequential Data
    
    Arguments
    ---------
    p_y_x : torch.tensor [bs, t, d]
        Probability vector for all categories with values [0, 1]
        
    y : torch.tensor [bs, t, d]
        One Hot Encoded ground truth labels
    
    Returns
    -------
    E : torch.tensor []
    """
    log_p_y_x = nn.LogSoftmax(-1)(log_p_y_x)
    log_p_y_x = torch.log(p_y_x).clamp(min = -100)
    E_i_t = - (y * log_p_y_x).sum(dim = 2)
    E_i = E_i_t.sum(dim = 1)
    E = E_i.mean(dim = 0)
    return E

In [142]:
class Controller(nn.Module):
    
    def __init__(self):
        """Controller : Recurrent Net Cell
        
        Attributes
        ----------
        cell : nn.GRUCell
        output_embedder : nn.Linear
        
        Methods
        -------
        next_state(x_t, h, r) -> (h)
        output(h, r) -> (y)
        
        """
        super(Controller, self).__init__()
        
        self.cell = nn.GRUCell(2 * d_hidden, d_hidden)
        self.output_embedder = nn.Sequential(
            nn.Linear(2 * d_hidden, d_hidden),
            nn.Softmax(-1))
            
        
    def next_state(self, x_t, h, r):
        """Concat input und read, return hidden state
        
        Arguments
        ---------
        x_t : torch.tensor [bs, d_hidden]
            Input timesteps t of sequence
        
        h : torch.tensor [bs, d_hidden]
            Hidden State
            
        r : torch.tensor [bs, d_hidden]
            Read 
            
        Returns
        -------
        h : torch.tensor [bs, d_hidden]
            Hidden State  
        """
        
        xr = torch.cat([x_t, r], dim = -1)
        h = self.cell(xr, h)
        return h
        
    def output(self, h, r):
        """Concat input und read, return hidden state
        
        Arguments
        ---------
        h : torch.tensor [bs, d_hidden]
            Hidden State
            
        r : torch.tensor [bs, d_hidden]
            Read 
            
        Returns
        -------
        y : torch.tensor [bs, d_output]
            Output, passed trough Activation
        """
        
        hr = torch.cat([h,r], dim = -1)
        y = self.output_embedder(hr)
        return y

In [152]:
class NeuralTuringMachine(nn.Module):
    """NTM : Total Model Infrastucture
    
    Attributes
    ----------
    controller
    writing_head
    reading_head
    memory
    
    Methods
    -------
    initialize_state_and_read(bs) -> (h, r)
    forward(x) -> (output)
    """
    
    def __init__(self):
        super(NeuralTuringMachine, self).__init__()
        
        self.controller = Controller()
        self.writing_head = WritingHead()
        self.reading_head = ReadingHead()
        self.memory = Memory()
        
    def initialize_state_and_read(self, bs):
        """Initialize h and r for first timestep
        
        Arguments
        ---------
        bs : int
        
        Returns
        -------
        h : torch.tensor [bs, d_hidden]
            initial hidden state
            
        r : torch.tensor [bs, d_hidden]
            initial read
        """
        h_0 = torch.zeros(x.shape[0], d_hidden)
        r_0 = torch.zeros(x.shape[0], d_hidden)
        return h_0, r_0
      
    def forward(self, x):
        """Read in Sequence of length ht, output Sequence of length ft
        
        Arguments
        ---------
        ht : int
            History sequence timesteps
            
        ft : int
            Future sequence timesteps
            
        x : torch.tensor [bs, ht, d_input]
            Input sequence
            
        Returns
        -------
        output : torch.tensor [bs, ft, d_output]
            Output  sequence
        """
        bs = x.shape[0]
        r,h = self.initialize_state_and_read(bs)
        self.memory.init_w_previous(bs)
        self.memory.init_memory(bs)
        
        output = list()
        
        for t in range(ft):
            h = self.controller.next_state(x[:,t,:], h, r)
            
            k_r, ß_r, g_r, s_r, y_r = self.reading_head(h)
            k_w, ß_w, g_w, s_w, y_w = self.writing_head(h)
            e,a = self.writing_head.variables_for_memory(h)
            
            attention_r = self.memory.attention(k_r, ß_r, g_r, s_r, y_r)
            attention_w = self.memory.attention(k_w, ß_w, g_w, s_w, y_w)

            r = self.memory.read(attention_r)
            self.memory.write(attention_w, e, a)
            
            y = self.controller.output(h,r)
            output.append(y)
            
        output = torch.stack(output, dim = 1)    
        return output

In [157]:
class Head(nn.Module):
    """ Head : Superclass of ReadingHead and WritingHead
    
    Attributes
    ----------
    project_to_key : nn.Linear
    project_to_temperature : nn.Linear
    project_to_gate : nn.Linear
    project_to_shift : nn.Linear
    project_to_gamma : nn.Linear
    
    Methods
    -------
    project_to_variables(h) -> (k, ß, g, s, y)
    adjust_variables_for_attention(k, ß, g, s, y) -> (k, ß, g, s, y)
    
    """

    def __init__(self):
        super(Head, self).__init__()
        
        self.project_to_key = nn.Linear(d_hidden, d_hidden)
        self.project_to_temperature = nn.Linear(d_hidden, 1)
        self.project_to_gate = nn.Linear(d_hidden, 1)
        self.project_to_shift = nn.Linear(d_hidden, 3)
        self.project_to_gamma = nn.Linear(d_hidden, 1)
        
         
    def project_to_variables(self, h):
        """Create Parameters for Attention
        
        Arguments
        ---------
        h : torch.tensor [bs, d_hidden]
        
        Returns
        -------
        k : torch.tensor [bs, d_hidden]
        ß : torch.tensor [bs, 1]
        g : torch.tensor [bs, 1]
        s : torch.tensor [bs, 3]
        y : torch.tensor [bs, 1]
        
        """
        
        k = self.project_to_key(h)
        ß = self.project_to_temperature(h)
        g = self.project_to_gate(h)
        s = self.project_to_shift(h)
        y = self.project_to_gamma(h)
        
        return k, ß, g, s, y
    
    def adjust_variables_for_attention(self, k, ß, g, s, y):
        """Adjust variables to correct value ranges

        Arguments
        -------
        k : torch.tensor [bs, d_hidden]
        ß : torch.tensor [bs, 1]
        g : torch.tensor [bs, 1]
        s : torch.tensor [bs, 3]
        y : torch.tensor [bs, 1]
        
        Returns
        -------
        k : torch.tensor [bs, d_hidden]
        ß : torch.tensor [bs, 1]
        g : torch.tensor [bs, 1]
        s : torch.tensor [bs, 3]
        y : torch.tensor [bs, 1]
        
        """
        
        k = k.clone()
        ß = nn.ReLU()(ß)
        g = nn.Sigmoid()(g)
        y = torch.ones(1) + nn.ReLU()(y) 
        s = nn.Softmax(dim = 1)(s)
        return k, ß, g, s, y
        
    
    def forward(self, h):
        k, ß, g, s, y = self.project_to_variables(h)
        k, ß, g, s, y = self.adjust_variables_for_attention(k, ß, g, s, y)
        return k, ß, g, s, y

In [156]:
class ReadingHead(Head):
    """ReadingHead : Subclass of Head
    
    Attributes
    ----------
    project_to_key : nn.Linear
    project_to_temperature : nn.Linear
    project_to_gate : nn.Linear
    project_to_shift : nn.Linear
    project_to_gamma : nn.Linear
    
    Methods
    -------
    project_to_variables(h) -> (k, ß, g, s, y)
    adjust_variables_for_attention(k, ß, g, s, y) -> (k, ß, g, s, y)
    
    """
    
    def __init__(self):
        super(ReadingHead, self).__init__()

In [159]:
class WritingHead(Head):
    """WritingHead : Subclass of Head
    
    Attributes
    ----------
    project_to_key : nn.Linear
    project_to_temperature : nn.Linear
    project_to_gate : nn.Linear
    project_to_shift : nn.Linear
    project_to_gamma : nn.Linear
    
    project_to_add : nn.Linear
    project_to_erase : nn.Linear
    
    Methods
    -------
    project_to_variables(h) -> (k, ß, g, s, y)
    adjust_variables_for_attention(k, ß, g, s, y) -> (k, ß, g, s, y)
    variables_for_memory(h) -> (e,a)
    """
    
    
    def __init__(self):
        super(WritingHead, self).__init__()
        
        self.project_to_add = nn.Linear(d_hidden, d_hidden)
        self.project_to_erase = nn.Linear(d_hidden, d_hidden)
        
    def variables_for_memory(self, h):
        """Erase and Add for Memory writing
        Arguments
        ---------
        h : torch.tensor [bs, d_hidden]
        
        Returns
        -------
        a : torch.tensor [bs, d_hidden]
        e : torch.tensor [bs, d_hidden]
        """
        a = self.project_to_add(h)
        e = self.project_to_erase(h)
        return e,a

In [131]:
class Memory(nn.Module):
    """
    Attributes
    ----------
    N : int []
        Number of Memory states
        
    M : torch.tensor [bs, N, d_hidden]
        Memory Tensor
        
    w_previous : torch.tensor [bs, N]
        Attention weights of the previous timestep
        
    Methods
    -------
    attention(self, k, ß, g, s, y) -> (w)        
    read(w) -> (r)
    write(w,e,a) -> void
    
    """
    
    def __init__(self):
        super(Memory, self).__init__()
        
        self.N = N
        
        self.w_previous = nn.Softmax(-1)(torch.randn((bs, N)))
        
        self.init_w_previous_buffer()
        self.init_memory_buffer()
        
    def init_w_previous_buffer(self):
        """Create Attention buffer
        
        Arguments
        ---------
        N : int
            Number of Memory states
            
        Returns
        -------
        w_previous_buffer : torch.tensor [N]
        
        """
        self.register_buffer("w_previous_buffer", nn.Softmax(-1)(torch.Tensor(self.N)))
        
    def init_w_previous(self, bs):
        """Create Attention buffer
        
        Arguments
        ---------
        bs : int
        w_previous_buffer : torch.tensor [N]
            
        Returns
        -------
        w_previous : torch.tensor [bs, N]
        
        """
        self.w_previous = self.w_previous_buffer.clone().repeat(bs, 1)
        
    def init_memory_buffer(self):
        """Create Memory buffer
        
        Arguments
        ---------
        N : int
            Number of Memory states
            
        d_hidden : int
            Dimension of Memory states
            
        Returns
        -------
        M_init : torch.tensor [N, d_hidden]
        
        """
        self.register_buffer('M_init', torch.Tensor(N, d_hidden))
        nn.init.uniform_(self.M_init, 0, 1)
        
    def init_memory(self, bs):
        """Expand Memory Buffer to batchsize for first timestep in Training
        
        Arguments
        ---------
        bs : int
        M_init : torch.tensor [N, d_hidden]
        
        Returns
        -------
        M : torch.tensor [bs, N, d_hidden]
        """
        self.M = self.M_init.clone().repeat(bs, 1, 1)
    
    def attention_content_focus(self, k, ß):
        """
        Arguments
        ---------
        k : torch.tensor [bs, d_hidden]
            Key, (technically it acts more as the query in this case)

        M : torch.tensor [bs, N, d_hidden]
            Memory, (technically it is the keys and values)

        Returns
        -------
        w : torch.tensor [bs, N]


                       < k * M[i] >
        w_i = Softmax(-------------- * ß ) 
                       |k| * |M[i]|

        """

        dot_product = torch.bmm(self.M, k.unsqueeze(2)).squeeze(-1)
        M_norm = torch.linalg.norm(self.M, dim = -1)
        k_norm = torch.linalg.norm(k, dim = - 1).unsqueeze(-1)
        mul_norms = M_norm * k_norm

        alignment = dot_product / mul_norms
        w = nn.Softmax(dim = 1)(alignment * ß)

        return w

    def attention_location_focus(self, w_current, g):
        """
        Arguments
        ---------
        w : torch.tensor [bs, N]
        w_previous : torch.tensor [bs, N]
        g : torch.tensor [bs, 1]
        
        Returns
        -------
        w_g : torch.tensor [bs, N]
            
        """
        
        w_g = g * w_current + (torch.tensor(1.) - g) * self.w_previous
        return w_g

    def attention_convolution(self, w, s):
        """
        Arguments
        ---------
        w : torch.tensor [bs, N]
        s : torch.tensor [bs, 3]
            the indices of s are [-1, 0, 1]

        Returns
        -------
        w_shifted : torch.tensor [bs, N]

        """
        w_d = w[:,-1].unsqueeze(1)
        w_0 = w[:,0].unsqueeze(1)
        w_cycle = torch.cat([w_d, w, w_0], dim = -1).unsqueeze(1) # [bs, 1, N+2]
        s = s.flip(dims = (1,)).unsqueeze(2) # [bs, 3, 1]

        max_first_idx = w_cycle.shape[2] - 3
        w_shifted = torch.cat([torch.bmm(w_cycle[:,:,i:i+3],s) 
                               for i in range(max_first_idx + 1)], 
                              dim = -1).squeeze(1)

        return w_shifted


    def attention_sharpen(self, w, y):
        """Apply Temperature to shifted weights
        
        Arguments
        ---------
        w : torch.tensor [bs, N]
        
        Returns
        -------
        w : torch.tensor [bs, N]
    
        """
        nominator = (w ** y)
        denominator = nominator.sum(dim = 1).unsqueeze(1)
        
        w = nominator / denominator
        return w

    def attention(self, k, ß, g, s, y):
        """
        Arguments
        ---------
        k : torch.tensor [bs, d_hidden]
            Key, (technically it acts more as the query in this case)

        M : torch.tensor [N, d_hidden]
            Memory, (technically it is the keys and values)

        Returns
        -------
        w : torch.tensor [bs, N]

        """
        w = self.attention_content_focus(k,ß)
        w = self.attention_location_focus(w, g)
        w = self.attention_convolution(w, s)
        w = self.attention_sharpen(w, y)
        
        self.w_previous = w
        return w
    
    def read(self, w):
        """
        Arguments
        ---------
        w : torch.tensor [bs, N]

        M : torch.tensor [bs, N, d_hidden]
            Memory, (technically it is the keys and values)


        Returns
        -------
        r : torch.tensor [bs, d_hidden]

        """

        r = torch.bmm(w.unsqueeze(1), self.M).squeeze(1)
        return r
    
    
    def write(self, w, e, a):
        """Create Memory Matrix for next timestep

        Arguments
        ---------
        M : torch.tensor [bs, N, d_hidden]
        w : torch.tensor [bs, N]
        e : torch.tensor [bs, d_hidden]
        a : torch.tensor [bs, d_hidden]

        Returns
        -------
        M : torch.tensor [bs, N, d_hidden]

            Mt[i] = Mt-1[i] * (I - w[i] * diag(e)) + w[i] * a

        """
        bs = w.shape[0]
        M_next = list()
        
        for i in range(N):
            w_i = w[:,i].reshape(bs, 1, 1)         # [bs, 1, 1]
            I = torch.eye(d_hidden).repeat(bs,1,1) # [bs, d, d]
            e_diag = torch.diag_embed(e)           # [bs, d, d]
            M_i = self.M[:,i,:].unsqueeze(1)       # [bs, 1, d]
            
            M_i = torch.bmm(M_i, I - w_i * e_diag) + w_i * a.unsqueeze(1) 
            M_next.append(M_i)
            
        M_next = torch.cat(M_next, dim = 1)
        self.M = M_next

In [134]:
# Hyperparams
bs = 256
ht = 6
ft = 6
num_cats = 5
d_hidden = num_cats
N = 4

# Data
x = torch.randint(0,num_cats-1, (bs, ht, 1))
x_OH = batch_to_one_hot(x, num_cats)

model = NeuralTuringMachine()
optimizer = Adam(model.parameters(), lr = 0.001)

In [135]:
for epoch in range(500):
    y_pred = model(x_OH)
    loss = CESequenceLoss(y_pred, x_OH)
    loss.backward()
    clip_grad_norm_(model.parameters(), 10)
    optimizer.step()
    optimizer.zero_grad()
    if epoch % 50 == 0:
        print(f'loss = {loss.detach()}')

loss = 10.45910358428955
loss = 9.353789329528809
loss = 8.495030403137207
loss = 7.749838829040527
loss = 6.872945308685303
loss = 5.709312915802002
loss = 4.470155239105225
loss = 2.952481508255005
loss = 1.949126124382019
loss = 1.43279230594635
