In [25]:
import torch as t
import torch.nn as nn

# Seed the global torch RNG first so every random placeholder tensor in this
# notebook is reproducible under Restart & Run All.
random_seed = 0
t.manual_seed(random_seed)

# Every embedding is an op-code part (op_dim) followed by an info part (info_dim).
info_dim = 64
op_dim = 32
embed_dim = op_dim + info_dim

batch_size = 30

# Number of entries (rows) in each memory / sense store.
long_term_memory_size = 200
short_term_memory_size = 150
sense_size = 120

# Operation Indices

In [38]:
# One random op-code vector per module. The op-code part of each cmd row is
# matched against these rows to decide which module the row is routed to.
num_ops = 7
op_ids = t.rand(num_ops, op_dim)

# Cmd Module

### This computes the row salience for the cmd module

In [30]:
# Cmd buffer layout: (batch, row, slot, embed), with `crl` slots per row.
cmd_row_len = 10
crl = cmd_row_len
num_cmd_rows = 20

cmd = t.rand(batch_size, num_cmd_rows, cmd_row_len, embed_dim)

# Global salience: sum of every slot vector in the whole cmd -> (batch, embed).
cmd_sal = cmd.sum(dim=(-2, -3))
# Score each slot by its dot product with the global salience -> (batch, row, slot).
row_sum = t.einsum("bd,brsd->brs", cmd_sal, cmd)
# Normalize the scores over the slots of each row.
row_mask = row_sum.softmax(dim=-1)
# Salience-weighted summary of each row -> (batch, row, embed).
row_sal = t.einsum("brsd,brs->brd", cmd, row_mask)

# Data Transfer

### This transfers data from the Cmd Module to each submodule to be processed

In [40]:
# Op-code part of each row's salience summary -> (batch, row, op_dim).
op_sal = row_sal[..., :op_dim]

# Affinity of every cmd row for every module's op id -> (batch, row, op).
mod_sum = t.einsum("brd,od->bro", op_sal, op_ids)
# Each row spreads itself over the ops...
row_choose_op = mod_sum.softmax(dim=-1)
# ...and each op concentrates on the rows that chose it most strongly; the
# x1000 factor makes this softmax over rows very sharp.
op_choose_row = row_choose_op.mul(1000).softmax(dim=-2)
# Row-to-op routing weights -> (batch, row, op).
op_mask = row_choose_op * op_choose_row

# Route cmd rows to ops: (batch, op, slot, embed) and (batch, op, embed).
op_data = t.einsum("brse,bro->bose", cmd, op_mask)
op_row_sal = t.einsum("bre,bro->boe", row_sal, op_mask)

In [43]:
# Split op data and op row sal into the appropriate inputs for each module
def split_squeeze(t):
    """Split a tensor along dim 1 into its per-index slices.

    ``split_squeeze(x)`` returns ``(x[:, 0], x[:, 1], ...)``: one tensor per
    entry of dim 1, with only that dimension removed.

    The parameter name shadows the module-level ``torch`` alias ``t``. It is
    kept for backward compatibility, so only tensor methods are used here.

    Returns:
        tuple of tensors, ready to be unpacked into one name per slice.
    """
    # unbind removes exactly dim 1. A bare squeeze() would also drop any other
    # size-1 dim, e.g. the batch dim when batch_size == 1.
    return t.unbind(dim=1)

In [53]:
# Per-op slices of the routed data and salience. The i-th name receives op i,
# so this order must match the row order of op_ids.
(
    ltm_sal_read_data,
    ltm_121_read_data,
    stm_write_data,
    stm_sal_read_data,
    std_121_read_data,  # NOTE(review): "std" looks like a typo for "stm"; name kept because later cells use it
    sense_read_data,
    sense_write_data,
) = split_squeeze(op_data)

(
    ltm_sal_read_sal,
    ltm_121_read_sal,
    stm_write_sal,
    stm_sal_read_sal,
    std_121_read_sal,
    sense_read_sal,
    sense_write_sal,
) = split_squeeze(op_row_sal)

### This transfers the modules' results back to Cmd

In [8]:
# Placeholder for the modules' outputs: random values with op_data's shape.
op_result = t.rand(*op_data.shape)

In [9]:
def scale_cmd(cmd, scale):
    """Multiply each (batch, row) sub-tensor of ``cmd`` by its own scalar.

    Args:
        cmd: 4-D tensor; in this notebook (batch, row, slot, embed).
        scale: 2-D tensor matching the first two dims of ``cmd``.

    Returns:
        Tensor with the same shape as ``cmd``.
    """
    return cmd * scale[:, :, None, None]

# Map module results back onto cmd rows: each row receives the op_mask-weighted
# sum of the results of the ops it was routed to -> (batch, row, slot, embed).
new_op_cmd = t.einsum("bose,bro->brse", op_result, op_mask)

# Total routing weight of each row, i.e. how much of the row is replaced by
# module output -> (batch, row). It lies in [0, 1] because row_choose_op sums
# to 1 over ops and op_choose_row <= 1.
op_mix = op_mask.sum(dim=-1)

# Per-row convex blend of the old cmd and the module output.
# NOTE(review): new_op_cmd is already weighted by op_mask and is scaled by
# op_mix again here, so partially routed rows are down-weighted twice —
# confirm this is intended.
new_cmd = scale_cmd(cmd, 1 - op_mix) + scale_cmd(new_op_cmd, op_mix)

# Long-Term Memory / Instincts

### Creating LTM Structures and Helper Tensors

In [10]:
# Random placeholder LTM key/value tables, one row per entry -> (ltm, embed).
ltm_keys = t.rand(long_term_memory_size, embed_dim)
ltm_vals = t.rand(long_term_memory_size, embed_dim)

In [11]:
# Per-slot salience mask -> (slot, embed). Slot i of a cmd row only "sees" its
# own contiguous chunk of split_size info dims; the op-code dims are masked out
# and the leftover info dims (info_dim % cmd_row_len) are unused.
split_size = info_dim // cmd_row_len
# With fewer info dims than slots every chunk would be empty, the mask would be
# silently all zeros, and salient reads would degrade to uniform averages.
assert split_size > 0, "info_dim must be >= cmd_row_len to build sal_mask"

slot_idx = t.arange(cmd_row_len).repeat_interleave(split_size)
dim_idx = op_dim + t.arange(cmd_row_len * split_size)
sal_mask = t.zeros((cmd_row_len, embed_dim))
sal_mask[slot_idx, dim_idx] = 1

### Salient Graph-Based Read

In [12]:
ld = ltm_sal_read_data  # not used by this read
ls = ltm_sal_read_sal

# Give every cmd slot its own masked copy of the row salience -> (batch, slot, embed).
sal_keys = t.einsum("be,se->bse", ls, sal_mask)
# Sharp (x50) attention of each slot over the LTM keys -> (batch, slot, ltm).
ltm_sal = (t.einsum("bse,me->bsm", sal_keys, ltm_keys) * 50).softmax(dim=-1)

# Output: attention-weighted LTM values -> (batch, slot, embed).
ltm_sal_out = t.einsum("bsm,me->bse", ltm_sal, ltm_vals)

In [13]:
# Sanity check: the routed LTM read data is (batch, slot, embed).
ld.size()

torch.Size([30, 10, 96])

### One-to-One Graph-Based Read

In [14]:
lod = ltm_121_read_data
los = ltm_121_read_sal  # not used by this read

# Each slot vector queries the LTM keys directly (softer x5 sharpening)
# -> (batch, slot, ltm).
ltm_121_sal = (t.einsum("bse,me->bsm", lod, ltm_keys) * 5).softmax(dim=-1)
# Output: attention-weighted LTM values -> (batch, slot, embed).
ltm_121_out = t.einsum("bsm,me->bse", ltm_121_sal, ltm_vals)

# Short-Term Memory

### Creating STM Structures 

In [26]:
# These are fixed or learned
stm_locs = t.rand((short_term_memory_size, embed_dim))

# These are transient, namely they are set while running the processor
stm_keys = t.rand((batch_size, short_term_memory_size, embed_dim))
stm_vals = t.rand((batch_size, short_term_memory_size, embed_dim))

In [17]:
# These masks are useful for writing key value pairs
num_lkv_triples = crl // 3
nlt = num_lkv_triples

l_mask = t.zeros((nlt, crl))
k_mask = t.zeros((nlt, crl))
v_mask = t.zeros((nlt, crl))
            
for i in range(num_lkv_triples * 3):
    i3 = i % 3
    n3 = i // 3
    
    if i3 == 0:
        l_mask[n3, i] = 1
    elif i3 == 1:
        k_mask[n3, i] = 1
    elif i3 == 2:
        v_mask[n3, i] = 1

### Write Key Value Pairs To Locations

In [18]:
swd = stm_write_data
sws = stm_write_sal  # not used by this write

# Gather the location / key / value slots of every triple -> (batch, triple, embed).
e_sum = "nik,ji->njk"
locs, keys, vals = (t.einsum(e_sum, swd, m) for m in (l_mask, k_mask, v_mask))

# Sharp (x20) attention of each triple's location over the STM locations
# -> (batch, triple, stm).
stm_sal = (t.einsum("njk,lk->njl", locs, stm_locs) * 20).softmax(dim=-1)
# Total write weight received by each STM entry -> (batch, stm).
# NOTE(review): when several triples address the same entry this exceeds 1,
# making the retain weight (1 - mask) negative — consider clamping.
mask = stm_sal.sum(dim=1)

# Spread each triple's key / value over the entries it addresses -> (batch, stm, embed).
e_mask = "njk,njl->nkl"
new_keys = t.einsum(e_mask, stm_sal, keys)
new_vals = t.einsum(e_mask, stm_sal, vals)

def s_mask(mask, vals):
    """Scale each entry vector of ``vals`` by its per-entry weight in ``mask``.

    Args:
        mask: 2-D tensor (batch, entry).
        vals: 3-D tensor (batch, entry, embed).

    Returns:
        Tensor with the same shape as ``vals``.
    """
    return vals * mask.unsqueeze(-1)

# Blend the written keys/values into STM; both right-hand sides read the old
# stm_keys / stm_vals.
stm_keys, stm_vals = (
    s_mask(mask, new_keys) + s_mask(1 - mask, stm_keys),
    s_mask(mask, new_vals) + s_mask(1 - mask, stm_vals),
)

### Short-Term Memory Salient Graph-Based Read

In [19]:
sgd = stm_sal_read_data  # not used by this read
sgs = stm_sal_read_sal

# Per-slot masked copy of this op's row salience -> (batch, slot, embed).
st_sal_keys = t.einsum("ni,ji->nji", sgs, sal_mask)
# Sharp (x50) attention of each slot over the STM keys -> (batch, slot, stm).
# Must use st_sal_keys (from this op's salience), not the LTM cell's sal_keys.
stm_sal = t.einsum("nij,nlj->nil", st_sal_keys, stm_keys).mul(50).softmax(dim=-1)

# Output: attention-weighted STM values -> (batch, slot, embed).
stm_sal_out = t.einsum("nij,njk->nik", stm_sal, stm_vals)

### Short-Term Memory One-to-One Graph-Based Read

In [46]:
sod = std_121_read_data
sos = std_121_read_sal  # not used by this read

# Each slot vector queries the STM keys directly (x5 sharpening) -> (batch, slot, stm).
stm_121_sal = (t.einsum("bse,bme->bsm", sod, stm_keys) * 5).softmax(dim=-1)
# Output: attention-weighted STM values -> (batch, slot, embed).
stm_121_out = t.einsum("bsm,bme->bse", stm_121_sal, stm_vals)


# Senses

### Initialize Sense Locations

In [28]:
# These are fixed or learned
sense_locs = t.rand((sense_size, embed_dim))

# These are transient, namely they are
# set while running the processor
sense_vals = t.rand((batch_size, sense_size, embed_dim))

In [51]:
# Selection masks that pick (key, value) pairs out of a cmd row: pair n uses
# slot 2n as its key and slot 2n + 1 as its value. A trailing odd slot is ignored.
num_sense_doubles = crl // 2
nsd = num_sense_doubles

sense_k_mask = t.zeros((nsd, crl))
sense_v_mask = t.zeros((nsd, crl))

# Fill both masks directly from the pair index. The parity test must not rely
# on loop variables leaked from another cell (a stale `i3` left the value mask
# all zeros).
pair_idx = t.arange(nsd)
sense_k_mask[pair_idx, 2 * pair_idx] = 1
sense_v_mask[pair_idx, 2 * pair_idx + 1] = 1

### Read Data

In [47]:
srd = sense_read_data
srs = sense_read_sal  # not used by this read

# Each slot vector queries the sense locations (x5 sharpening), mirroring how
# the sense write cell addresses entries via sense_locs -> (batch, slot, sense).
# Stored under its own name so the routed input `sense_read_sal` is not
# overwritten (which also made re-running this cell read the wrong salience).
sense_read_attn = t.einsum("nij,lj->nil", srd, sense_locs).mul(5).softmax(dim=-1)
# Output: attention-weighted sense values -> (batch, slot, embed).
sense_read_out = t.einsum("nij,njk->nik", sense_read_attn, sense_vals)


### Write Data

In [54]:
sewd = sense_write_data
sews = sense_write_sal  # not used by this write

# Gather the key (address) and value slots of every pair -> (batch, pair, embed).
e_sum = "nik,ji->njk"
keys, vals = (t.einsum(e_sum, sewd, m) for m in (sense_k_mask, sense_v_mask))

# Sharp (x20) attention of each pair's key over the sense locations
# -> (batch, pair, sense).
sense_sal = (t.einsum("njk,lk->njl", keys, sense_locs) * 20).softmax(dim=-1)
# Total write weight received by each sense entry -> (batch, sense).
# NOTE(review): can exceed 1 when pairs collide, making (1 - mask) negative.
mask = sense_sal.sum(dim=1)

# Spread each pair's value over the entries it addresses, then blend it in.
e_mask = "njk,njl->nkl"
new_vals = t.einsum(e_mask, sense_sal, vals)

sense_vals = s_mask(mask, new_vals) + s_mask(1 - mask, sense_vals)