In [1]:

import torch

In [6]:
t = torch.Tensor([1,2,2,3])
t

tensor([1., 2., 2., 3.])

In [None]:
nnc = None 

membank = None

In [12]:
n_loc: int = 10  # N
dim_loc: int = 20  # M 

memmat = torch.rand(size=(n_loc, dim_loc))
memmat.size()

torch.Size([10, 20])

In [13]:
# read_head -> w_t vec weights with N entries 
# read_vector r_t = convex combination of mem elements with weights = w_t

In [37]:
w = torch.arange(0.1,.6,0.1)
w = w[:, None]
print(w.size())
memmat = torch.randint(10, (5,2))
print(memmat)
print(w)
print(w * memmat)
torch.sum(w * memmat, dim=0)




torch.Size([5, 1])
tensor([[0, 0],
        [6, 8],
        [6, 5],
        [9, 2],
        [7, 5]])
tensor([[0.1000],
        [0.2000],
        [0.3000],
        [0.4000],
        [0.5000]])
tensor([[0.0000, 0.0000],
        [1.2000, 1.6000],
        [1.8000, 1.5000],
        [3.6000, 0.8000],
        [3.5000, 2.5000]])


tensor([10.1000,  6.4000])

In [38]:
class ReadHead(torch.nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
    
    def forward(self, mem_mat: torch.Tensor):
        """Emit read vector r_t"""
        
        # product weights w_t
        # emit reaad vector r_t
        weights = torch.rand(mem_mat.size()[0])
        weights = weights[:, None]
        return torch.sum(weights * mem_mat, dim=0)

In [39]:
rh = ReadHead()

In [41]:
memmat

tensor([[0, 0],
        [6, 8],
        [6, 5],
        [9, 2],
        [7, 5]])

In [44]:

rh.forward(memmat).gradient()

AttributeError: 'Tensor' object has no attribute 'gradient'

In [None]:
class WriteHead(torch.nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
    def forward(self, memmat: torch.Tensor):
        """Write head adjusts memory matrix."""

        weights = torch.rand(memmat.size()[0])
        v_erase =  torch.rand(memmat.size()[1])
        v_add =  torch.rand(memmat.size()[1])
        weights = weights[:, None]
        ones = torch.ones(memmat.size()[1])
        for i in range(len(memmat.size()[0])):
            memmat[i, :] = memmat[i, :] * (ones - weights[i] * v_erase)
            memmat[i, :] = memmat[i, :] * (ones - weights[i] * v_add)

        return None




In [52]:
e_mat = e.repeat(4,1)
e_mat


tensor([[0.2607, 0.6224, 0.4122],
        [0.2607, 0.6224, 0.4122],
        [0.2607, 0.6224, 0.4122],
        [0.2607, 0.6224, 0.4122]])

In [None]:
# free params controller
SIZE_OF_MEM = 10
NUM_READ_HEADS = 5
NUM_WRITE_HEADS = 5
LOC_SHIFT_RANGE = list(range(1,6))


In [80]:
from typing import Optional
class MemoryBank(torch.nn.Module):
    def __init__(self, num_vectors: int, vec_dim: int):
        super(MemoryBank, self).__init__()
        self.num_vec = num_vectors
        self.vec_dim = vec_dim
        self.batch_size: Optional[int] = None
        self.data: torch.Tensor | None = None

    def init_state(self, batch_size, device):
        self.batch_size = batch_size
        self.data = torch.zeros(batch_size, self.num_vec, self.vec_dim).to(device)

    def update(self, weight: torch.Tensor, erase_vec: torch.Tensor, add_vec: torch.Tensor):
        # make sure that batch_dim of tensor is indeed self.batch_dim
        # TODO check if dims are ok and batch dim is considered correctly
        erase_row_stack = erase_vec.repeat(erase_vec.shape[0], 1, self.num_vec).reshape(erase_vec.shape[0], erase_vec.shape[1], -1)
        erase_row_stack *= weight
        self.data -= erase_row_stack
        self.data += weight*add_vec.repeat(add_vec.shape[0], add_vec.shape[1], self.num_vec).reshape(add_vec.shape[0], add_vec.shape[1], -1)
        

In [78]:
t = torch.ones([2,3])
t.repeat(1, 3, 1).reshape(t.shape[0], t.shape[1], -1)
# torch.ones(size=(3,2))
# t, t.shape


tensor([[[1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.]],

        [[1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.]]])

In [25]:
class NeuralNetController(torch.nn.Module):
    """Some kind of Recurrent net or feedforward net."""
    # typically LSTM
    def __init__(self, h_size=20, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.hidden_size = h_size
        self.lstm_cell = torch.nn.LSTM(input_size=1, hidden_size=h_size, batch_first=True)
        self.memory_bank: MemoryBank = MemoryBank(10, h_size)

    
    def forward(self):
        # outputs k_t, beta_t, g_t, s_t, \gamma_t
        return


In [26]:
class Combinator(torch.nn.Module):
    """Combine output of controller and memory access to make final prediction."""

    def __init__(self, in_dim: int, out_dim: int=10):
        """Default out dim for copy task: 0 through 9."""
        super.__init__()
        self.input_dim: int = in_dim
        layer_1: torch.nn.Linear = torch.nn.Linear(in_dim, out_dim)
    def forward(self, x: torch.Tensor):
        return torch.argmax(torch.nn.Softmax(self.layer_1(x)))



In [38]:
from abc import abstractmethod
from typing import Tuple, Callable
from dataclasses import dataclass

class Head(torch.nn.Module):
    """produce key, key_strength, interpolation_gate, shift_weighting, sharpening_factor.

    key in mathbb{R}^{mem_size}
    key_strength in mathbb{R}_{+} how sharp the attention to memory locations should be
    interpolation_gate in (0,1) how much of last weight should be retained
    shift_weighting in mathbb{R}^{num_mem_locations} prob distribution over the num locations
    sharpening_factor in [1, infty)
    """

    def __init__(self, in_dim: int,
                 membank: MemoryBank,
                 sim_func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
                ):

        self.membank = membank
        self.num_mem_locations = self.membank.num_vec
        self.sim_func = sim_func
        self.mem_size = self.membank.vec_dim
        self.current_weight: torch.Tensor = torch.ones(self.num_mem_locations) / self.num_mem_locations # size = num_mem_locations

    @abstractmethod
    def get_params(self):
        ...

    @abstractmethod
    def forward(self, x: torch.Tensor):
        ...

    def get_weight(
        self,
        k: torch.Tensor,
        ks: torch.Tensor,
        ig: torch.Tensor,
        sw: torch.Tensor,
        sf: torch.Tensor
        ):
        loc_similarity: torch.Tensor = self.sim_func(k, self.membank.data)
        loc_weight: torch.Tensor = torch.nn.functional.softmax(ks * loc_similarity, dim=-1)
        gated_weighting: torch.Tensor = (1 - ig) * self.current_weight + ig * loc_weight
        compound_weight: torch.Tensor = torch.zeros(self.num_mem_locations)

        # TODO implement convolution without for loops
        for i in range(self.num_mem_locations):
            for j in range(self.num_mem_locations):
                compound_weight[i] += gated_weighting[j] * sw[(i-j)%self.num_mem_locations] 

        sharpened_weight: torch.Tensor = torch.pow(compound_weight, sf)
        weight = sharpened_weight / torch.sum(sharpened_weight)
        self.current_weight = weight
        return weight


class ReadHead(Head):
    def __init__(self, in_dim, membank, sim_func):
        super().__init__(in_dim, membank, sim_func)
        self.out_size = self.mem_size + 1 + 1 + self.num_mem_locations + 1
        self.layer_1 = torch.nn.Linear(in_features=in_dim, out_features=self.out_size)

    def get_params(self, h: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        out: torch.Tensor = self.layer_1(h)
        read_key: torch.Tensor = out[:, 0 : self.mem_size]
        key_strenght: torch.Tensor = torch.exp(out[:, self.mem_size])
        interpolation_gate: torch.Tensor = torch.nn.Sigmoid(out[:, self.mem_size + 1])
        shift_weighting: torch.Tensor = torch.softmax(out[:, (self.mem_size + 2): -1], dim=-1)
        sharp_factor: torch.Tensor = 1 + torch.exp(out[:, -1])
        return (read_key, key_strenght, interpolation_gate, shift_weighting, sharp_factor)

    
    def forward(self, h: torch.Tensor):
        q_key, key_strength, interpolation_gate, shift_weighting, sharp_factor = self.get_params(h)
        weight: torch.Tensor = self.get_weight(q_key, key_strength, interpolation_gate, shift_weighting, sharp_factor)
        read_vec: torch.Tensor = torch.sum(weight * self.membank.data, dim=0)
        return read_vec


class WriteHead(Head):
    def __init__(self, in_dim, membank, sim_func):
        super().__init__(in_dim, membank, sim_func)
        self.out_size: int = self.mem_size * 3 + 1 + 1 + self.num_mem_locations + 1
        self.layer_1 = torch.nn.Linear(in_features=in_dim, out_features=self.out_size)

    def forward(self, h: torch.Tensor):
        params: Tuple[torch.Tensor, ...] = self.get_params(h)
        erase_vec: torch.Tensor = params[0]
        add_vec: torch.Tensor = params[1]
        q_key: torch.Tensor = params[2]
        key_strength: torch.Tensor = params[3]
        interpolation_gate: torch.Tensor  = params[4]
        shift_weighting: torch.Tensor  = params[5]
        sharp_factor: torch.Tensor = params[6]
        weight: torch.Tensor = self.get_weight(q_key, key_strength, interpolation_gate, shift_weighting, sharp_factor)
        self.membank.update(weight, erase_vec, add_vec)
        return None

    def get_params(self, h: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        out: torch.Tensor = self.layer_1(h)
        erase_vec: torch.Tensor = torch.nn.functional.sigmoid(out[:, :self.mem_size])
        add_vec: torch.Tensor = out[:, self.mem_size: 2*self.mem_size]
        qkey: torch.Tensor = out[:, self.mem_size: 3*self.mem_size]
        key_strength: torch.Tensor = torch.exp(out[:, 3*self.mem_size])
        interpolation_gate: torch.Tensor = torch.nn.Sigmoid(out[:, 3*self.mem_size + 1])
        shift_weighting: torch.Tensor = torch.softmax(out[:, (3*self.mem_size + 2): -1], dim=-1)
        sharp_factor: torch.Tensor = 1 + torch.exp(out[:, -1])
        return (erase_vec, add_vec, qkey, key_strength, interpolation_gate, shift_weighting, sharp_factor)


    

In [54]:
# dataset
copy_dataset = None
sort_dataset = None

In [73]:
# copy dataset 
copy_vecs = torch.randint(2, (1000, 8))
copy_vecs[1, :]


tensor([1, 1, 1, 0, 1, 1, 0, 0])

In [77]:
from torch.utils.data import Dataset
class CopyDataset(Dataset):
    def __init__(self, len, delim: int=-1):
        super().__init__()
        self.delim = delim
        self.data = torch.randint(2, (len, 8))
        self.data = torch.column_stack([self.data, torch.ones(self.data.size()[0]) * self.delim])
    
    def __getitem__(self, idx: int):
        return (self.data[idx, :], self.data[idx, :-1])

In [89]:
cop_data = CopyDataset(10)
cop_data[8]

(tensor([ 0.,  0.,  0.,  0.,  1.,  1.,  1.,  0., -1.]),
 tensor([0., 0., 0., 0., 1., 1., 1., 0.]))

- for each head, be it a read or write head, the addressing mechanism is implemented
=> for each head, the controller needs to produce
  + key vector $k_t$
  + key strength $\beta_t$
  + interpolation gate $g_t$
  + shift weighting $s_t$
  + sharpening factor $\gamma_t$

In [10]:
x = torch.randint(3, (5,3))
y = torch.randint(3, (3,))
x, y, torch.cat([x] + y, dim=1)

TypeError: can only concatenate list (not "Tensor") to list

In [16]:

class A:
    def __init__(self, a,b,c):
        self.a = a
        self.b = b
        self.c = c
        self.ls = [a,b,c]

class B:
    def __init__(self, candA: A):
        self.candA = candA
    
    def print(self):
        print(self.candA.ls)

    def mod(self, v):
        self.candA.a = v
        self.candA.ls.append(self.candA.a)


In [17]:
a = A(1,2,3)
b = B(a)
c = B(a)

b.print()
c.print()

b.mod(234)
b.print()
c.print()
c.mod(888)
b.print()
c.print()

[1, 2, 3]
[1, 2, 3]
[1, 2, 3, 234]
[1, 2, 3, 234]
[1, 2, 3, 234, 888]
[1, 2, 3, 234, 888]


In [41]:
t=torch.rand(size=(10,5)) * 10
t, torch.nn.functional.softmax(t, dim=-1).sum(dim=-1)

(tensor([[6.1253, 8.4934, 8.1074, 7.6174, 6.4826],
         [1.5601, 2.4506, 4.0441, 3.1744, 2.9900],
         [6.9845, 0.4665, 0.5360, 4.0491, 7.3489],
         [2.4286, 5.8532, 2.7234, 1.9773, 5.1179],
         [0.7154, 6.6040, 1.2617, 8.1936, 1.4594],
         [9.3102, 0.7271, 2.4330, 6.2105, 4.5922],
         [4.8239, 6.0186, 2.6772, 4.6056, 7.0927],
         [9.7810, 4.5565, 0.1035, 1.3065, 8.4828],
         [4.5577, 1.2025, 5.8060, 4.7937, 5.3212],
         [2.6435, 1.2668, 8.5653, 2.6699, 6.4882]]),
 tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
         1.0000]))

In [48]:
-2 % 10

torch.Tensor([2]) * 2

tensor([4.])