In [3]:
import torch
t = torch.Tensor([1,2,2,3])
t

tensor([1., 2., 2., 3.])

In [13]:
# read_head -> w_t vec weights with N entries 
# read_vector r_t = convex combination of mem elements with weights = w_t

In [38]:
# class ReadHead(torch.nn.Module):
#     def __init__(self, *args, **kwargs):
#         super().__init__(*args, **kwargs)
    
#     def forward(self, mem_mat: torch.Tensor):
#         """Emit read vector r_t"""
        
#         # product weights w_t
#         # emit reaad vector r_t
#         weights = torch.rand(mem_mat.size()[0])
#         weights = weights[:, None]
#         return torch.sum(weights * mem_mat, dim=0)

In [None]:
# class WriteHead(torch.nn.Module):
#     def __init__(self, *args, **kwargs):
#         super().__init__(*args, **kwargs)
#     def forward(self, memmat: torch.Tensor):
#         """Write head adjusts memory matrix."""

#         weights = torch.rand(memmat.size()[0])
#         v_erase =  torch.rand(memmat.size()[1])
#         v_add =  torch.rand(memmat.size()[1])
#         weights = weights[:, None]
#         ones = torch.ones(memmat.size()[1])
#         for i in range(len(memmat.size()[0])):
#             memmat[i, :] = memmat[i, :] * (ones - weights[i] * v_erase)
#             memmat[i, :] = memmat[i, :] * (ones - weights[i] * v_add)

#         return None




In [None]:
# free params controller
SIZE_OF_MEM = 10
NUM_READ_HEADS = 5
NUM_WRITE_HEADS = 5
LOC_SHIFT_RANGE = list(range(1,6))


In [9]:
from typing import Optional
class MemoryBank(torch.nn.Module):
    def __init__(self, num_vectors: int, vec_dim: int):
        super(MemoryBank, self).__init__()
        self.num_vec = num_vectors
        self.vec_dim = vec_dim
        self.batch_size: Optional[int] = None
        self.data: torch.Tensor | None = None

    def init_state(self, batch_size, device):
        self.batch_size = batch_size
        self.data = torch.zeros(batch_size, self.num_vec, self.vec_dim).to(device)

    def update(self, weight: torch.Tensor, erase_vec: torch.Tensor, add_vec: torch.Tensor):
        # make sure that batch_dim of tensor is indeed self.batch_dim
        # TODO check if dims are ok and batch dim is considered correctly
        erase_row_stack = erase_vec.repeat(erase_vec.shape[0], 1, self.num_vec).reshape(erase_vec.shape[0], erase_vec.shape[1], -1)
        erase_row_stack *= weight
        self.data -= erase_row_stack
        self.data += weight*add_vec.repeat(add_vec.shape[0], add_vec.shape[1], self.num_vec).reshape(add_vec.shape[0], add_vec.shape[1], -1)
        

In [None]:
# 4 20 10
# 4 20

20

20 10





In [385]:
b = torch.randn(1, 20, 10)
r = torch.abs(torch.randn(1, 20)).unsqueeze(-1)
r = r / torch.sum(r)
r.shape
(r * b).sum(1).shape



torch.Size([1, 10])

In [10]:
t = torch.ones([2,3])
t.repeat(1, 3, 1).reshape(t.shape[0], t.shape[1], -1)
# torch.ones(size=(3,2))
# t, t.shape


tensor([[[1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.]],

        [[1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.]]])

In [109]:
from typing import Tuple
class NeuralNetController(torch.nn.Module):
    """Some kind of Recurrent net or feedforward net."""
    # typically LSTM
    def __init__(self, in_size:int=50, h_size: int=20, num_layers: int=1, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.hidden_size: int = h_size
        self.inp_size: int = in_size
        self.num_lstm_cells: int = num_layers
        self.lstm_cell_0 = torch.nn.LSTMCell(
            input_size=self.inp_size,
            hidden_size=self.hidden_size,
            )
        # self.memory_bank: MemoryBank = MemoryBank(10, h_size)

    
    def forward(self, x: torch.Tensor, hidden_state: torch.Tensor, cell_state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """product hidden vector of size self.hidden_size from input with size self.inp_size."""
        # outputs k_t, beta_t, g_t, s_t, \gamma_t
        print(f"x size: {x.shape},\t h size: {hidden_state.shape},\t c size: {cell_state.shape}")
        return self.lstm_cell_0(x, (hidden_state, cell_state))


In [427]:
class Combinator(torch.nn.Module):
    """Combine output of controller and memory access to make final prediction."""

    def __init__(self, hidden_dim: int, read_vec_dim: int=10, num_read_heads: int=1, out_dim: int=10):
        """Default out dim for copy task: 0 through 9."""
        super().__init__()
        self.hidden_dim: int = hidden_dim
        self.read_vec_dim: int = read_vec_dim
        self.num_read_heads: int = num_read_heads
        self.layer_1: torch.nn.Linear = torch.nn.Linear(
            self.hidden_dim + self.read_vec_dim * self.num_read_heads,
            out_dim
            )
    def forward(self, x: torch.Tensor):
        return torch.argmax(torch.nn.functional.softmax(self.layer_1(x)))



In [428]:
from typing import List, Callable
class NeuralTuringMachine(torch.nn.Module):
    """Neural Turing Machine."""

    def __init__(
            self,
            in_size: int,
            out_size: int,
            hidden_size: int,
            num_mem_vectors: int,
            mem_vect_dim: int,
            num_read_heads: int,
            num_write_heads: int,
            num_lstm_layers: int,
            sim_func: Callable
            ):
        super().__init__()
        self.inp_size: int = in_size  # input seq dim of each element
        self.out_dim: int = out_size # output seq dim of each element
        self.hidden_size: int = hidden_size
        self.num_mem_vectors: int = num_mem_vectors
        self.vect_dim: int = mem_vect_dim
        self.num_read_heads: int = num_read_heads
        self.num_write_heads: int = num_write_heads
        self.membank: MemoryBank = MemoryBank(self.num_mem_vectors, self.vect_dim)
        self.controller: NeuralNetController = NeuralNetController(
            self.inp_size,
            self.hidden_size,
            num_layers=num_lstm_layers
            )
        self.sim_func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = sim_func
        self.combinator: Combinator = Combinator(
            self.hidden_size,
            self.vect_dim,
            num_read_heads=self.num_read_heads,
            out_dim=self.out_dim
            )

        self.read_heads: List[ReadHead] = [ReadHead(self.hidden_size, self.membank, sim_func=self.sim_func)]
        self.write_heads: List[WriteHead] = [WriteHead(self.hidden_size, self.membank, sim_func=self.sim_func)]


    def read(self):
        pass

    def write(self):
        pass

    def forward(self, x: torch.Tensor):
        """Tensor of t entries, each entry with dim inp_size."""

        h_t: torch.Tensor = torch.randn(x.shape[0], self.hidden_size)
        c_t: torch.Tensor = torch.randn(x.shape[0], self.hidden_size)
        read_vecs: List[torch.Tensor]
        out_vals: List = []
        # read_vecs: List[torch.Tensor] = [read_head.forward(h_t) for read_head in self.read_heads]

        for i in range(x.shape[1]):
            # always expect batch dim
            print(f"dim h: {h_t.shape}, dim c: {c_t.shape}")
            h_t, c_t = self.controller.forward(x[:, i:i+1], h_t, c_t)

            read_vecs = [read_head.forward(h_t) for read_head in self.read_heads]
            read_vec = torch.concat(read_vecs)
            out_vals.append(self.combinator.forward(torch.concat([h_t, read_vec], dim=1)))
            # update current weights for writing
            for write_head in  self.write_heads:
                write_head.get_weight()

            # update memory 
            for write_head in  self.write_heads:
                write_head.forward(h_t)

        return torch.Tensor(out_vals)
            

    def __call__(self, x: torch.Tensor):
        return self.forward(x)


In [429]:
t.size()

torch.Size([10, 5])

In [430]:
from abc import abstractmethod
from typing import Tuple, Callable
from dataclasses import dataclass

class Head(torch.nn.Module):
    """produce key, key_strength, interpolation_gate, shift_weighting, sharpening_factor.

    key in mathbb{R}^{mem_size}
    key_strength in mathbb{R}_{+} how sharp the attention to memory locations should be
    interpolation_gate in (0,1) how much of last weight should be retained
    shift_weighting in mathbb{R}^{num_mem_locations} prob distribution over the num locations
    sharpening_factor in [1, infty)
    """

    def __init__(self, in_dim: int,
                 membank: MemoryBank,
                 sim_func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
                ):

        super().__init__()
        self.membank = membank
        self.num_mem_locations = self.membank.num_vec
        self.sim_func = sim_func
        self.mem_size = self.membank.vec_dim
        self.current_weight: torch.Tensor = torch.ones(self.num_mem_locations) / self.num_mem_locations # size = num_mem_locations
        self.bs: int = 0



    @abstractmethod
    def get_params(self):
        ...

    @abstractmethod
    def forward(self, x: torch.Tensor):
        ...

    def batch_size_infer(self, bs: int):
        self.current_weight = self.current_weight.unsqueeze(0)
        self.current_weight = self.current_weight.repeat(bs, 1)
        self.bs = bs

    def get_weight(
        self,
        k: torch.Tensor,
        ks: torch.Tensor,
        ig: torch.Tensor,
        sw: torch.Tensor,
        sf: torch.Tensor
        ):
        
        # print(f"key.shape: {k.shape}, membank data shape: {self.membank.data.shape}")
        loc_similarity: torch.Tensor = self.sim_func(k, self.membank.data)
        loc_similarity = loc_similarity.squeeze(dim=-1)
        ks = ks.unsqueeze(1).repeat(1, loc_similarity.shape[1])
        # print(f"loc similarity.shape: {loc_similarity.shape}")
        # print(f"key strenght shape: {ks.shape}\t ks * loc_sim shape {(ks*loc_similarity).shape}")
        loc_weight: torch.Tensor = torch.nn.functional.softmax(ks * loc_similarity, dim=1)
        # print(f"loc weight shape {loc_weight.shape}\t self.current_weight.shape {self.current_weight.shape}",
        #       f"ig.shape {ig.shape}")
        gated_weighting: torch.Tensor = (1 - ig).unsqueeze(1) * self.current_weight + ig.unsqueeze(1) * loc_weight
        compound_weight: torch.Tensor = torch.zeros(self.bs, self.num_mem_locations)

        # TODO implement convolution without for loops
        # print(f"sw shape: {sw.shape}.")
        # print(f"compound weight shape: {compound_weight.shape}.")

        for i in range(self.num_mem_locations):
            for j in range(self.num_mem_locations):
                compound_weight[:, i] += gated_weighting[:, j] * sw[:, (i-j)%self.num_mem_locations] 

        print(compound_weight.shape, sf.shape)
        sharpened_weight: torch.Tensor = torch.pow(compound_weight, sf.unsqueeze(1).repeat(1, compound_weight.shape[1]))
        weight = sharpened_weight / torch.sum(sharpened_weight)
        self.current_weight = weight
        return weight


class ReadHead(Head):
    def __init__(self, in_dim, membank, sim_func):
        print(membank)
        super().__init__(in_dim, membank, sim_func)
        self.out_size = self.mem_size + 1 + 1 + self.num_mem_locations + 1
        self.layer_1 = torch.nn.Linear(in_features=in_dim, out_features=self.out_size)

    def get_params(self, h: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        out: torch.Tensor = self.layer_1(h)
        read_key: torch.Tensor = out[:, 0 : self.mem_size]
        key_strenght: torch.Tensor = torch.exp(out[:, self.mem_size])
        # interpolation_gate: torch.Tensor = torch.nn.Sigmoid(out[:, self.mem_size + 1])
        interpolation_gate: torch.Tensor = torch.nn.functional.sigmoid(out[:, self.mem_size + 1])
        shift_weighting: torch.Tensor = torch.softmax(out[:, (self.mem_size + 2): -1], dim=-1)
        sharp_factor: torch.Tensor = 1 + torch.exp(out[:, -1])
        return (read_key, key_strenght, interpolation_gate, shift_weighting, sharp_factor)

    
    def forward(self, h: torch.Tensor):
        q_key, key_strength, interpolation_gate, shift_weighting, sharp_factor = self.get_params(h)
        weight: torch.Tensor = self.get_weight(q_key, key_strength, interpolation_gate, shift_weighting, sharp_factor)
        print(f"weight shape: {weight.shape}")
        read_vec: torch.Tensor = torch.sum(weight.unsqueeze(-1) * self.membank.data, dim=1)
        print(f"read vec shape: {read_vec.shape}")
        return read_vec


class WriteHead(Head):
    def __init__(self, in_dim, membank, sim_func):
        super().__init__(in_dim, membank, sim_func)
        self.out_size: int = self.mem_size * 3 + 1 + 1 + self.num_mem_locations + 1
        self.layer_1 = torch.nn.Linear(in_features=in_dim, out_features=self.out_size)

    def forward(self, h: torch.Tensor):
        params: Tuple[torch.Tensor, ...] = self.get_params(h)
        erase_vec: torch.Tensor = params[0]
        add_vec: torch.Tensor = params[1]
        q_key: torch.Tensor = params[2]
        key_strength: torch.Tensor = params[3]
        interpolation_gate: torch.Tensor  = params[4]
        shift_weighting: torch.Tensor  = params[5]
        sharp_factor: torch.Tensor = params[6]
        # weight: torch.Tensor = self.get_weight(q_key, key_strength, interpolation_gate, shift_weighting, sharp_factor)
        self.membank.update(self.current_weight, erase_vec, add_vec)
        return None

    def get_params(self, h: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        out: torch.Tensor = self.layer_1(h)
        erase_vec: torch.Tensor = torch.nn.functional.sigmoid(out[:, :self.mem_size])
        add_vec: torch.Tensor = out[:, self.mem_size: 2*self.mem_size]
        qkey: torch.Tensor = out[:, self.mem_size: 3*self.mem_size]
        key_strength: torch.Tensor = torch.exp(out[:, 3*self.mem_size])
        interpolation_gate: torch.Tensor = torch.nn.sigmoid(out[:, 3*self.mem_size + 1])
        shift_weighting: torch.Tensor = torch.softmax(out[:, (3*self.mem_size + 2): -1], dim=-1)
        sharp_factor: torch.Tensor = 1 + torch.exp(out[:, -1])
        return (erase_vec, add_vec, qkey, key_strength, interpolation_gate, shift_weighting, sharp_factor)


    

In [431]:
# dataset
copy_dataset = None
sort_dataset = None

In [432]:
# copy dataset 
copy_vecs = torch.randint(2, (1000, 8))
copy_vecs[1, :]


tensor([0, 0, 0, 0, 1, 1, 0, 0])

In [433]:
from torch.utils.data import Dataset
class CopyDataset(Dataset):
    def __init__(self, len, delim: int=-1):
        super().__init__()
        self.delim = delim
        self.length = len
        self.data = torch.randint(2, (len, 8))
        self.data = torch.column_stack([self.data, torch.ones(self.data.size()[0]) * self.delim])
    
    def __getitem__(self, idx: int):
        return (self.data[idx, :], self.data[idx, :-1])
    def __len__(self) -> int:
        return self.length


In [434]:
cop_data = CopyDataset(10)

cop_data[8]

(tensor([ 0.,  0.,  1.,  0.,  1.,  1.,  0.,  0., -1.]),
 tensor([0., 0., 1., 0., 1., 1., 0., 0.]))

- for each head, be it a read or write head, the addressing mechanism is implemented
=> for each head, the controller needs to produce
  + key vector $k_t$
  + key strength $\beta_t$
  + interpolation gate $g_t$
  + shift weighting $s_t$
  + sharpening factor $\gamma_t$

In [435]:
from torch.utils.data import DataLoader

ds = DataLoader(cop_data, batch_size=2)

In [436]:
for i, obj in enumerate(ds):
    if i < 1:
        print(obj)
        print(obj[0].shape, obj[0])
        print(obj[1].shape, obj[1])


    

[tensor([[ 0.,  1.,  0.,  0.,  0.,  1.,  1.,  0., -1.],
        [ 0.,  1.,  0.,  1.,  1.,  1.,  0.,  1., -1.]]), tensor([[0., 1., 0., 0., 0., 1., 1., 0.],
        [0., 1., 0., 1., 1., 1., 0., 1.]])]
torch.Size([2, 9]) tensor([[ 0.,  1.,  0.,  0.,  0.,  1.,  1.,  0., -1.],
        [ 0.,  1.,  0.,  1.,  1.,  1.,  0.,  1., -1.]])
torch.Size([2, 8]) tensor([[0., 1., 0., 0., 0., 1., 1., 0.],
        [0., 1., 0., 1., 1., 1., 0., 1.]])


In [437]:
def train(ntm: NeuralTuringMachine, ds: CopyDataset, loss_func: Callable = None, epochs: int=10, bs: int = 4):
    dataloader = DataLoader(ds, batch_size=bs)
    adam = torch.optim.Adam(ntm.parameters())
    ntm.membank.init_state(bs, torch.device("cpu"))
    for rh in ntm.read_heads:
        rh.batch_size_infer(bs)


    for i in range(epochs):
        for ep, (x,y) in enumerate(dataloader):
            adam.zero_grad()
            print(f"x: {x}, shape: {x.shape}")
            pred = ntm(x)
            loss = loss_func(pred,y)
            loss.backward()
            adam.step()
            print(loss)

        print(f"Finished epoch no {ep}")

In [438]:
def dot_product(t1: torch.Tensor, t2: torch.Tensor) -> torch.Tensor:
    return torch.sum(t1 * t2, dim=1)
t1 = torch.randn(2, 10)
t2 = torch.randn(2, 10)

dot_product(t1, t2)


tensor([ 0.6453, -1.9508])

In [439]:
def new_dot_prod(k: torch.Tensor, mb: torch.Tensor):
    # key.shape = (batch_size, size_of_mem_vector)
    # mb.shape = (batch_size, num_mem_vectors, size_of_mem_vector)
    k = k.unsqueeze(1)
    k = k.repeat(1,20,1)
    # print(k.shape)
    # print(k[0, :, :])
    # print(mb.shape)

    prod = k * mb
    return torch.sum(prod, dim=2, keepdim=True)




t1 = torch.randn((4, 10))
t2 = torch.randn((4, 20, 10))


dp = new_dot_prod(t1, t2)

torch.softmax(dp, dim=1)[0], dp.shape


(tensor([[2.0068e-01],
         [1.5779e-03],
         [8.6358e-03],
         [2.9689e-03],
         [5.3717e-02],
         [2.4518e-04],
         [1.6183e-03],
         [2.0148e-01],
         [3.3282e-04],
         [2.2563e-02],
         [8.6456e-03],
         [1.8847e-02],
         [3.0825e-01],
         [5.0610e-02],
         [6.2993e-03],
         [1.0062e-01],
         [1.1395e-04],
         [6.8084e-03],
         [5.2486e-03],
         [7.4083e-04]]),
 torch.Size([4, 20, 1]))

In [440]:
ntm = NeuralTuringMachine(1, 10, 20, 20, 10, 1, 1, 1, new_dot_prod)

MemoryBank()


In [441]:
ra = torch.arange(1,21)
ra = ra.unsqueeze(0).repeat(4,1)
fac = torch.ones(4)*0.5
fac.unsqueeze(1) * ra

tensor([[ 0.5000,  1.0000,  1.5000,  2.0000,  2.5000,  3.0000,  3.5000,  4.0000,
          4.5000,  5.0000,  5.5000,  6.0000,  6.5000,  7.0000,  7.5000,  8.0000,
          8.5000,  9.0000,  9.5000, 10.0000],
        [ 0.5000,  1.0000,  1.5000,  2.0000,  2.5000,  3.0000,  3.5000,  4.0000,
          4.5000,  5.0000,  5.5000,  6.0000,  6.5000,  7.0000,  7.5000,  8.0000,
          8.5000,  9.0000,  9.5000, 10.0000],
        [ 0.5000,  1.0000,  1.5000,  2.0000,  2.5000,  3.0000,  3.5000,  4.0000,
          4.5000,  5.0000,  5.5000,  6.0000,  6.5000,  7.0000,  7.5000,  8.0000,
          8.5000,  9.0000,  9.5000, 10.0000],
        [ 0.5000,  1.0000,  1.5000,  2.0000,  2.5000,  3.0000,  3.5000,  4.0000,
          4.5000,  5.0000,  5.5000,  6.0000,  6.5000,  7.0000,  7.5000,  8.0000,
          8.5000,  9.0000,  9.5000, 10.0000]])

In [442]:
train(ntm , cop_data, None)

x: tensor([[ 0.,  1.,  0.,  0.,  0.,  1.,  1.,  0., -1.],
        [ 0.,  1.,  0.,  1.,  1.,  1.,  0.,  1., -1.],
        [ 0.,  0.,  0.,  1.,  1.,  0.,  0.,  0., -1.],
        [ 0.,  0.,  1.,  0.,  0.,  0.,  1.,  1., -1.]]), shape: torch.Size([4, 9])
dim h: torch.Size([4, 20]), dim c: torch.Size([4, 20])
x size: torch.Size([4, 1]),	 h size: torch.Size([4, 20]),	 c size: torch.Size([4, 20])
torch.Size([4, 20]) torch.Size([4])
weight shape: torch.Size([4, 20])
read vec shape: torch.Size([4, 10])


  return torch.argmax(torch.nn.functional.softmax(self.layer_1(x)))


TypeError: Head.get_weight() missing 5 required positional arguments: 'k', 'ks', 'ig', 'sw', and 'sf'