In [3]:
import torch
# Small 1-D example tensor in the default floating-point dtype.
t = torch.tensor([1, 2, 2, 3], dtype=torch.get_default_dtype())
t

tensor([1., 2., 2., 3.])

In [13]:
# read head -> emits a weight vector w_t with N entries (one per memory location)
# read vector r_t = convex combination of the memory rows, using w_t as the weights

In [38]:
# class ReadHead(torch.nn.Module):
#     def __init__(self, *args, **kwargs):
#         super().__init__(*args, **kwargs)
    
#     def forward(self, mem_mat: torch.Tensor):
#         """Emit read vector r_t"""
        
#         # produce weights w_t
#         # emit read vector r_t
#         weights = torch.rand(mem_mat.size()[0])
#         weights = weights[:, None]
#         return torch.sum(weights * mem_mat, dim=0)

In [None]:
# class WriteHead(torch.nn.Module):
#     def __init__(self, *args, **kwargs):
#         super().__init__(*args, **kwargs)
#     def forward(self, memmat: torch.Tensor):
#         """Write head adjusts memory matrix."""

#         weights = torch.rand(memmat.size()[0])
#         v_erase =  torch.rand(memmat.size()[1])
#         v_add =  torch.rand(memmat.size()[1])
#         weights = weights[:, None]
#         ones = torch.ones(memmat.size()[1])
#         for i in range(len(memmat.size()[0])):
#             memmat[i, :] = memmat[i, :] * (ones - weights[i] * v_erase)
#             memmat[i, :] = memmat[i, :] * (ones - weights[i] * v_add)

#         return None




In [None]:
# free params controller
SIZE_OF_MEM = 10
# Read and write head counts share the same value.
NUM_READ_HEADS = NUM_WRITE_HEADS = 5
# The integers 1 through 5, spelled out as a concrete list.
LOC_SHIFT_RANGE = [1, 2, 3, 4, 5]


In [9]:
from typing import Optional
class MemoryBank(torch.nn.Module):
    """Batched external memory matrix.

    After ``init_state`` the memory lives in ``self.data`` with shape
    ``(batch_size, num_vectors, vec_dim)``: one row of ``vec_dim`` values per
    memory location, per batch element.
    """

    def __init__(self, num_vectors: int, vec_dim: int):
        """Record the memory geometry; no storage is allocated until ``init_state``.

        Args:
            num_vectors: number of memory locations (rows).
            vec_dim: width of each memory row.
        """
        super(MemoryBank, self).__init__()
        self.num_vec = num_vectors
        self.vec_dim = vec_dim
        self.batch_size: Optional[int] = None
        self.data: Optional[torch.Tensor] = None

    def init_state(self, batch_size, device):
        """Allocate a zero-filled memory for ``batch_size`` sequences on ``device``."""
        self.batch_size = batch_size
        self.data = torch.zeros(batch_size, self.num_vec, self.vec_dim, device=device)

    def update(self, weight: torch.Tensor, erase_vec: torch.Tensor, add_vec: torch.Tensor):
        """Apply one erase-then-add write to the memory.

        Every memory row ``i`` is first scaled element-wise by
        ``1 - weight[i] * erase_vec`` and then incremented by
        ``weight[i] * add_vec``.

        Args:
            weight: write weighting over memory locations, shape
                ``(batch, num_vectors)`` or ``(num_vectors,)`` (shared by the batch).
            erase_vec: erase vector, shape ``(batch, vec_dim)``.
            add_vec: add vector, shape ``(batch, vec_dim)``.

        Raises:
            RuntimeError: if ``init_state`` has not been called yet.
        """
        if self.data is None:
            raise RuntimeError("MemoryBank.init_state() must be called before update()")

        # Outer products via broadcasting: (.., N, 1) * (.., 1, M) -> (.., N, M).
        per_location = weight.unsqueeze(-1)
        erase_term = per_location * erase_vec.unsqueeze(-2)
        add_term = per_location * add_vec.unsqueeze(-2)

        # Rebind instead of mutating in place so tensors already captured by
        # autograd (e.g. earlier reads of the memory) stay valid.
        self.data = self.data * (1 - erase_term) + add_term
        

In [10]:
t = torch.ones(2, 3)
# Tile the rows three times, then fold back to (rows, cols, copies).
t.repeat(1, 3, 1).reshape(*t.shape, -1)
# torch.ones(size=(3,2))
# t, t.shape


tensor([[[1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.]],

        [[1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.]]])

In [27]:
from typing import Tuple
class NeuralNetController(torch.nn.Module):
    """Controller network: a single LSTM cell that is stepped once per call."""

    def __init__(self, in_size: int = 50, h_size: int = 20, num_layers: int = 1, *args, **kwargs):
        """Build the LSTM cell.

        Args:
            in_size: width of each input element.
            h_size: width of the hidden and cell state.
            num_layers: stored as ``num_lstm_cells``; exactly one cell is
                created regardless of this value.
            *args, **kwargs: forwarded to ``torch.nn.Module.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.inp_size: int = in_size
        self.hidden_size: int = h_size
        self.num_lstm_cells: int = num_layers
        self.lstm_cell_0 = torch.nn.LSTMCell(self.inp_size, self.hidden_size)
        # self.memory_bank: MemoryBank = MemoryBank(10, h_size)

    def forward(self, x: torch.Tensor, hidden_state: torch.Tensor, cell_state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Advance the LSTM cell by one step.

        Returns:
            The ``(hidden_state, cell_state)`` pair produced by the cell for
            input ``x`` and the given previous state.
        """
        previous_state = (hidden_state, cell_state)
        next_hidden, next_cell = self.lstm_cell_0(x, previous_state)
        return next_hidden, next_cell


In [28]:
class Combinator(torch.nn.Module):
    """Combine output of controller and memory access to make final prediction."""

    def __init__(self, hidden_dim: int, read_vec_dim: int = 10, num_read_heads: int = 1, out_dim: int = 10):
        """Default out dim for copy task: 0 through 9.

        Args:
            hidden_dim: width of the controller hidden state.
            read_vec_dim: width of a single read vector.
            num_read_heads: number of read vectors concatenated to the hidden state.
            out_dim: number of output classes.
        """
        super().__init__()
        self.hidden_dim: int = hidden_dim
        self.read_vec_dim: int = read_vec_dim
        self.num_read_heads: int = num_read_heads
        self.layer_1: torch.nn.Linear = torch.nn.Linear(
            self.hidden_dim + self.read_vec_dim * self.num_read_heads,
            out_dim
            )

    def forward(self, x: torch.Tensor):
        """Return the predicted class index for each input row.

        Args:
            x: tensor whose last axis has size
                ``hidden_dim + read_vec_dim * num_read_heads``.

        Returns:
            Integer tensor with the last axis reduced away, holding the index
            of the most probable class. The indices carry no gradient; a
            training loss has to be computed from ``self.layer_1(x)`` instead.
        """
        logits = self.layer_1(x)
        # torch.nn.Softmax is a module class; the functional form is what
        # actually normalises the logits. Reduce over the class axis only so a
        # batch yields one index per row.
        class_probs = torch.softmax(logits, dim=-1)
        return torch.argmax(class_probs, dim=-1)



In [38]:
from typing import List, Callable
class NeuralTuringMachine(torch.nn.Module):
    """Neural Turing Machine: an LSTM controller coupled to an external memory bank.

    At every time step the controller consumes one input element, each read
    head produces a read vector from memory, the combinator maps the
    controller state plus read vectors to an output, and each write head is
    applied to the memory.
    """

    def __init__(
            self,
            in_size: int,
            out_size: int,
            hidden_size: int,
            num_mem_vectors: int,
            mem_vect_dim: int,
            num_read_heads: int,
            num_write_heads: int,
            num_lstm_layers: int,
            sim_func: Callable
            ):
        """Assemble controller, memory, heads and output layer.

        Args:
            in_size: width of each input sequence element.
            out_size: number of output classes per sequence element.
            hidden_size: width of the controller hidden state.
            num_mem_vectors: number of memory locations.
            mem_vect_dim: width of each memory row.
            num_read_heads: number of read heads to create.
            num_write_heads: number of write heads to create.
            num_lstm_layers: forwarded to the controller as ``num_layers``.
            sim_func: similarity function handed to every head.
        """
        super().__init__()
        self.inp_size: int = in_size  # input seq dim of each element
        self.out_dim: int = out_size # output seq dim of each element
        self.hidden_size: int = hidden_size
        self.num_mem_vectors: int = num_mem_vectors
        self.vect_dim: int = mem_vect_dim
        self.num_read_heads: int = num_read_heads
        self.num_write_heads: int = num_write_heads
        self.membank: MemoryBank = MemoryBank(self.num_mem_vectors, self.vect_dim)
        self.controller: NeuralNetController = NeuralNetController(
            self.inp_size,
            self.hidden_size,
            num_layers=num_lstm_layers
            )
        self.sim_func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = sim_func
        self.combinator: Combinator = Combinator(
            self.hidden_size,
            self.vect_dim,
            num_read_heads=self.num_read_heads,
            out_dim=self.out_dim
            )

        # ModuleList (not a plain list) so the heads' parameters are registered
        # with this module and show up in .parameters() / .to() / state_dict().
        # One head per requested count keeps the combinator input width
        # (hidden_size + vect_dim * num_read_heads) consistent with what
        # forward() concatenates.
        self.read_heads: torch.nn.ModuleList = torch.nn.ModuleList(
            ReadHead(self.hidden_size, self.membank, sim_func=self.sim_func)
            for _ in range(self.num_read_heads)
        )
        self.write_heads: torch.nn.ModuleList = torch.nn.ModuleList(
            WriteHead(self.hidden_size, self.membank, sim_func=self.sim_func)
            for _ in range(self.num_write_heads)
        )

    def read(self):
        """Placeholder; reading currently happens inside ``forward``."""
        pass

    def write(self):
        """Placeholder; writing currently happens inside ``forward``."""
        pass

    def forward(self, x: torch.Tensor):
        """Run the machine over a batch of sequences.

        Args:
            x: input of shape ``(batch, seq_len, inp_size)``; the batch axis is
                always expected.

        Returns:
            The per-step combinator outputs stacked along a new trailing
            (time) axis.
        """
        batch_size, seq_len = x.shape[0], x.shape[1]
        device = x.device

        # Start every sequence from empty memory and uniform head weightings so
        # state from a previous call (possibly with another batch size) cannot leak in.
        self.membank.init_state(batch_size, device)
        uniform_weight = torch.full(
            (self.num_mem_vectors,), 1.0 / self.num_mem_vectors, device=device
        )
        for head in list(self.read_heads) + list(self.write_heads):
            head.current_weight = uniform_weight.clone()

        h_t: torch.Tensor = torch.randn(batch_size, self.hidden_size, device=device)
        c_t: torch.Tensor = torch.randn(batch_size, self.hidden_size, device=device)
        out_vals: List[torch.Tensor] = []

        for step in range(seq_len):
            h_t, c_t = self.controller(x[:, step, :], h_t, c_t)

            # Concatenate along the feature axis; the default dim=0 would glue
            # batches together instead.
            read_vec = torch.cat([read_head(h_t) for read_head in self.read_heads], dim=-1)
            out_vals.append(self.combinator(torch.cat([h_t, read_vec], dim=-1)))

            # Each write head is invoked once per step. The former bare
            # get_weight() call is dropped: that method requires the addressing
            # parameters as arguments and cannot be called without them.
            for write_head in self.write_heads:
                write_head(h_t)

        # torch.Tensor(list_of_tensors) cannot assemble multi-element tensors;
        # stack the per-step outputs instead.
        return torch.stack(out_vals, dim=-1)


In [37]:
t.shape

torch.Size([10, 5])

In [30]:
from abc import abstractmethod
from typing import Tuple, Callable
from dataclasses import dataclass

class Head(torch.nn.Module):
    """produce key, key_strength, interpolation_gate, shift_weighting, sharpening_factor.

    key in mathbb{R}^{mem_size}
    key_strength in mathbb{R}_{+} how sharp the attention to memory locations should be
    interpolation_gate in (0,1) how much of last weight should be retained
    shift_weighting in mathbb{R}^{num_mem_locations} prob distribution over the num locations
    sharpening_factor in [1, infty)

    Subclasses provide ``get_params`` and ``forward``; the addressing itself
    (``get_weight``) is shared.
    """

    def __init__(self, in_dim: int,
                 membank: "MemoryBank",
                 sim_func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
                ):
        """Bind the head to a memory bank and a similarity function.

        Args:
            in_dim: width of the vector the head is driven by (used by subclasses).
            membank: memory the head addresses; its ``num_vec`` and ``vec_dim``
                fix the number of locations and the key width.
            sim_func: called as ``sim_func(key, memory)``; presumably returns
                one score per memory location with shape
                ``(batch, num_mem_locations)`` -- confirm against the function
                actually passed in.
        """
        # Must run before any submodule (e.g. membank) is assigned as an attribute.
        super().__init__()

        self.membank = membank
        self.num_mem_locations = self.membank.num_vec
        self.sim_func = sim_func
        self.mem_size = self.membank.vec_dim
        # Weighting from the previous step; starts uniform and broadcasts over the batch.
        self.current_weight: torch.Tensor = torch.ones(self.num_mem_locations) / self.num_mem_locations # size = num_mem_locations

    @abstractmethod
    def get_params(self, h: torch.Tensor):
        """Map the driving vector ``h`` to the head's addressing parameters."""
        ...

    @abstractmethod
    def forward(self, x: torch.Tensor):
        """Perform the head's memory access for the driving vector ``x``."""
        ...

    @staticmethod
    def _as_column(param: torch.Tensor) -> torch.Tensor:
        """Give a per-sample scalar of shape (batch,) a trailing axis so it broadcasts over locations."""
        return param.unsqueeze(-1) if param.dim() == 1 else param

    def get_weight(
        self,
        k: torch.Tensor,
        ks: torch.Tensor,
        ig: torch.Tensor,
        sw: torch.Tensor,
        sf: torch.Tensor
        ):
        """Compute the new weighting over memory locations and remember it.

        Stages: content addressing (softmax of scaled similarities),
        interpolation with the previous weighting, circular convolution with
        the shift distribution, sharpening and renormalisation.

        Args:
            k: key, shape ``(batch, mem_size)``.
            ks: key strength, shape ``(batch,)`` or ``(batch, 1)``.
            ig: interpolation gate, shape ``(batch,)`` or ``(batch, 1)``.
            sw: shift weighting, shape ``(batch, num_mem_locations)``.
            sf: sharpening factor, shape ``(batch,)`` or ``(batch, 1)``.

        Returns:
            Weighting of shape ``(batch, num_mem_locations)`` whose rows sum
            to one; it is also stored in ``self.current_weight``.
        """
        ks, ig, sf = (self._as_column(param) for param in (ks, ig, sf))

        loc_similarity: torch.Tensor = self.sim_func(k, self.membank.data)
        loc_weight: torch.Tensor = torch.nn.functional.softmax(ks * loc_similarity, dim=-1)

        previous_weight = self.current_weight.to(loc_weight.device)
        gated_weighting: torch.Tensor = (1 - ig) * previous_weight + ig * loc_weight

        # Circular convolution without Python loops:
        #   compound[b, i] = sum_j gated[b, j] * sw[b, (i - j) mod N]
        # shift_index[i, j] = (i - j) mod N; % on tensors follows Python's
        # sign convention, so the indices are always in [0, N).
        positions = torch.arange(self.num_mem_locations, device=gated_weighting.device)
        shift_index = (positions.unsqueeze(1) - positions.unsqueeze(0)) % self.num_mem_locations
        shifted = sw[..., shift_index]  # (..., N, N)
        compound_weight: torch.Tensor = (shifted * gated_weighting.unsqueeze(-2)).sum(dim=-1)

        sharpened_weight: torch.Tensor = torch.pow(compound_weight, sf)
        # Normalise each sample separately, not across the whole batch.
        weight = sharpened_weight / torch.sum(sharpened_weight, dim=-1, keepdim=True)
        self.current_weight = weight
        return weight


class ReadHead(Head):
    """Head that reads a weighted combination of the memory rows."""

    def __init__(self, in_dim, membank, sim_func):
        """Create the linear layer that emits all addressing parameters at once.

        The layer output is laid out as: key (``mem_size``), key strength (1),
        interpolation gate (1), shift weighting (``num_mem_locations``),
        sharpening factor (1).
        """
        super().__init__(in_dim, membank, sim_func)
        self.out_size = self.mem_size + 1 + 1 + self.num_mem_locations + 1
        self.layer_1 = torch.nn.Linear(in_features=in_dim, out_features=self.out_size)

    def get_params(self, h: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        """Split the linear projection of ``h`` into constrained addressing parameters.

        Args:
            h: driving vector of shape ``(batch, in_dim)``.

        Returns:
            ``(read_key, key_strength, interpolation_gate, shift_weighting,
            sharp_factor)``. The three scalar parameters keep a trailing axis
            of size one (``(batch, 1)``) so they broadcast against
            ``(batch, num_mem_locations)`` weightings.
        """
        out: torch.Tensor = self.layer_1(h)
        key_end = self.mem_size
        read_key: torch.Tensor = out[:, :key_end]
        # exp keeps the key strength positive.
        key_strength: torch.Tensor = torch.exp(out[:, key_end:key_end + 1])
        # torch.nn.Sigmoid is a module class; torch.sigmoid applies the function.
        interpolation_gate: torch.Tensor = torch.sigmoid(out[:, key_end + 1:key_end + 2])
        shift_weighting: torch.Tensor = torch.softmax(out[:, (key_end + 2): -1], dim=-1)
        # 1 + exp(.) keeps the sharpening factor above one.
        sharp_factor: torch.Tensor = 1 + torch.exp(out[:, -1:])
        return (read_key, key_strength, interpolation_gate, shift_weighting, sharp_factor)

    def forward(self, h: torch.Tensor):
        """Address the memory and return the read vector.

        Returns:
            Tensor of shape ``(batch, mem_size)``: for each sample, the memory
            rows summed with the head's weighting.
        """
        q_key, key_strength, interpolation_gate, shift_weighting, sharp_factor = self.get_params(h)
        weight: torch.Tensor = self.get_weight(q_key, key_strength, interpolation_gate, shift_weighting, sharp_factor)
        # (batch, N, 1) * (batch, N, mem_size), reduced over the location axis
        # (not the batch axis).
        read_vec: torch.Tensor = torch.sum(weight.unsqueeze(-1) * self.membank.data, dim=-2)
        return read_vec


class WriteHead(Head):
    """Head that erases from and adds to the memory."""

    def __init__(self, in_dim, membank, sim_func):
        """Create the linear layer that emits write vectors and addressing parameters.

        The layer output is laid out as: erase vector (``mem_size``), add
        vector (``mem_size``), key (``mem_size``), key strength (1),
        interpolation gate (1), shift weighting (``num_mem_locations``),
        sharpening factor (1).
        """
        super().__init__(in_dim, membank, sim_func)
        self.out_size: int = self.mem_size * 3 + 1 + 1 + self.num_mem_locations + 1
        self.layer_1 = torch.nn.Linear(in_features=in_dim, out_features=self.out_size)

    def forward(self, h: torch.Tensor):
        """Address the memory for ``h`` and apply the erase/add update.

        The weighting is computed from this step's parameters before writing,
        so the write never uses a stale weighting. Returns ``None``; the
        effect is the change to ``self.membank``.
        """
        (
            erase_vec,
            add_vec,
            q_key,
            key_strength,
            interpolation_gate,
            shift_weighting,
            sharp_factor,
        ) = self.get_params(h)
        weight: torch.Tensor = self.get_weight(q_key, key_strength, interpolation_gate, shift_weighting, sharp_factor)
        self.membank.update(weight, erase_vec, add_vec)
        return None

    def get_params(self, h: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        """Split the linear projection of ``h`` into write vectors and addressing parameters.

        Args:
            h: driving vector of shape ``(batch, in_dim)``.

        Returns:
            ``(erase_vec, add_vec, qkey, key_strength, interpolation_gate,
            shift_weighting, sharp_factor)``. The three scalar parameters keep
            a trailing axis of size one (``(batch, 1)``) so they broadcast
            against ``(batch, num_mem_locations)`` weightings.
        """
        out: torch.Tensor = self.layer_1(h)
        m = self.mem_size
        # Sigmoid keeps every erase component in (0, 1).
        erase_vec: torch.Tensor = torch.sigmoid(out[:, :m])
        add_vec: torch.Tensor = out[:, m: 2 * m]
        # The key occupies the third mem_size-wide slot; starting at m would
        # overlap the add vector and yield a key of width 2 * mem_size.
        qkey: torch.Tensor = out[:, 2 * m: 3 * m]
        key_strength: torch.Tensor = torch.exp(out[:, 3 * m: 3 * m + 1])
        # torch.nn.Sigmoid is a module class; torch.sigmoid applies the function.
        interpolation_gate: torch.Tensor = torch.sigmoid(out[:, 3 * m + 1: 3 * m + 2])
        shift_weighting: torch.Tensor = torch.softmax(out[:, (3 * m + 2): -1], dim=-1)
        sharp_factor: torch.Tensor = 1 + torch.exp(out[:, -1:])
        return (erase_vec, add_vec, qkey, key_strength, interpolation_gate, shift_weighting, sharp_factor)


    

In [15]:
# dataset
# Placeholders; both start out unset.
copy_dataset = sort_dataset = None

In [16]:
# copy dataset 
# 1000 random bit vectors of length 8 (values drawn from {0, 1}).
copy_vecs = torch.randint(0, 2, (1000, 8))
copy_vecs[1]


tensor([1, 0, 1, 0, 0, 1, 0, 1])

In [17]:
from torch.utils.data import Dataset
class CopyDataset(Dataset):
    """Random 8-bit vectors for the copy task, each followed by a delimiter value.

    Every row of ``self.data`` holds 8 random bits plus the delimiter in the
    last column. Stacking with the float delimiter column makes the stored
    tensor floating point.
    """

    def __init__(self, len, delim: int=-1):
        """Generate the samples.

        Args:
            len: number of samples to generate.
            delim: value appended to every bit vector as its final entry.
        """
        super().__init__()
        self.delim = delim
        bits = torch.randint(2, (len, 8))
        delim_column = torch.ones(bits.size()[0]) * self.delim
        self.data = torch.column_stack([bits, delim_column])

    def __len__(self):
        """Number of samples; needed by ``len()`` and by ``DataLoader`` samplers."""
        return self.data.shape[0]

    def __getitem__(self, idx: int):
        """Return ``(input, target)``: the delimited vector and the same vector without the delimiter."""
        return (self.data[idx, :], self.data[idx, :-1])

In [18]:
# Ten random samples; show the (input, target) pair at index 8.
cop_data = CopyDataset(len=10)
cop_data[8]

(tensor([ 0.,  0.,  1.,  0.,  1.,  1.,  0.,  0., -1.]),
 tensor([0., 0., 1., 0., 1., 1., 0., 0.]))

- Every head, whether it is a read head or a write head, implements the same addressing mechanism.
- Consequently, for each head the controller needs to produce:
  + key vector $k_t$
  + key strength $\beta_t$
  + interpolation gate $g_t$
  + shift weighting $s_t$
  + sharpening factor $\gamma_t$

In [19]:
# Scratch experiment: x is a 5x3 integer tensor, y a length-3 integer tensor.
x = torch.randint(3, (5,3))
y = torch.randint(3, (3,))
# NOTE(review): `[x] + y` adds a Python list and a tensor, which raises the
# TypeError recorded in this cell's output. torch.cat takes all tensors inside
# one sequence; which axis y was meant to extend is unclear -- confirm intent.
x, y, torch.cat([x] + y, dim=1)

TypeError: can only concatenate list (not "Tensor") to list

In [20]:

class A:
    """Plain holder for three values plus a list built from them."""

    def __init__(self, a, b, c):
        self.a, self.b, self.c = a, b, c
        # Built from the constructor arguments; reassigning self.a later does
        # not change the entries already stored here.
        self.ls = list((a, b, c))

class B:
    """Wrapper that keeps a reference to (not a copy of) an ``A`` instance."""

    def __init__(self, candA: A):
        self.candA = candA

    def print(self):
        """Print the wrapped object's list."""
        shared_list = self.candA.ls
        print(shared_list)

    def mod(self, v):
        """Set the wrapped object's ``a`` to ``v`` and append that value to its list."""
        wrapped = self.candA
        wrapped.a = v
        wrapped.ls.append(wrapped.a)


In [21]:
# Two wrappers around the same A instance: a change made through either
# wrapper is visible through both.
a = A(1, 2, 3)
b, c = B(a), B(a)

b.print()
c.print()

# Mutate through b ...
b.mod(234)
b.print()
c.print()

# ... then through c.
c.mod(888)
b.print()
c.print()

[1, 2, 3]
[1, 2, 3]
[1, 2, 3, 234]
[1, 2, 3, 234]
[1, 2, 3, 234, 888]
[1, 2, 3, 234, 888]


In [22]:
# Random scores in [0, 10); each softmax row should sum to one.
t = 10 * torch.rand(10, 5)
t, torch.softmax(t, dim=-1).sum(dim=-1)

(tensor([[5.2178, 1.9274, 1.1395, 6.4742, 5.1977],
         [2.7616, 3.0651, 1.0750, 6.6164, 2.5470],
         [3.1665, 6.1073, 1.9395, 6.3281, 9.4515],
         [7.6308, 5.6517, 2.6815, 1.7194, 8.4667],
         [0.6252, 2.7191, 8.2259, 3.9783, 2.7316],
         [1.6349, 4.1056, 2.0810, 3.2715, 4.2814],
         [7.0307, 6.4599, 4.1481, 4.1009, 9.3166],
         [2.1597, 5.4950, 1.7843, 8.0205, 2.7984],
         [1.7626, 6.8268, 3.1274, 4.3092, 9.9576],
         [2.6819, 1.9981, 8.2281, 0.5521, 2.0147]]),
 tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
         1.0000]))

In [23]:
# Python's % takes the sign of the divisor, so this evaluates to 8.
(-2) % 10

2 * torch.Tensor([2])

tensor([4.])