# Hyperdimension

In [612]:
import torchhd as hd
import torch

In [613]:
d = 10_000

In [614]:
person = {
    'name': 'Bedi',
    'sex': 'Male',
    'date of birth': '24th december',
}

In [615]:
vocab = {}

for key, value in person.items():
    vocab[key] = hd.random(1, d)
    vocab[value] = hd.random(1, d)

In [616]:
vocab

{'name': MAPTensor([[ 1.,  1., -1.,  ...,  1., -1., -1.]]),
 'Bedi': MAPTensor([[ 1., -1., -1.,  ..., -1.,  1.,  1.]]),
 'sex': MAPTensor([[-1.,  1.,  1.,  ..., -1., -1., -1.]]),
 'Male': MAPTensor([[-1., -1., -1.,  ...,  1.,  1., -1.]]),
 'date of birth': MAPTensor([[ 1.,  1., -1.,  ...,  1., -1.,  1.]]),
 '24th december': MAPTensor([[-1., -1.,  1.,  ..., -1., -1.,  1.]])}

In [631]:
memory = hd.multibundle(torch.concat([
    hd.bind(vocab['name'], vocab['Bedi']),
    hd.bind(vocab['sex'], vocab['Male']),
    hd.bind(vocab['date of birth'], vocab['24th december'])
]))

In [632]:
retrieved_name = hd.bind(memory, hd.inverse(vocab['name']))
hd.cosine_similarity(retrieved_name, vocab['Bedi'])

MAPTensor([[0.5783]])

# Infini-attention

In [619]:
import torch
import torch.nn.functional as F

In [620]:
dkey = 64
dvalue = 64
N = 10

def sigma(x):
    # sigma(x) = ELU(x) + 1
    return F.elu(x) + 1

Ms_prev = torch.randn(dkey, dvalue) * 0.01
zs_prev = torch.randn(dkey) * 0.01

In [621]:
Q = torch.randn(N, dkey)
K = torch.randn(N, dkey)
V = torch.randn(N, dvalue)

In [622]:
def read(Ms, zs, Q):
    sigma_Q = sigma(Q)
    
    numerator = sigma_Q @ Ms
    denominator = sigma_Q @ zs
    denominator = denominator.unsqueeze(-1) + 1e-5
    
    return numerator / denominator


def save(Ms, zs, K, V):
    sigma_K = sigma(K)
    
    delta_Ms = sigma_K.transpose(-2, -1) @ V
    Ms_new = Ms + delta_Ms
    
    delta_zs = torch.sum(sigma_K, dim=0)
    zs_new = zs + delta_zs
    
    return Ms_new, zs_new

In [623]:
MS_prev, zs_prev = save(Ms_prev, zs_prev, K, V)

In [624]:
Amem = read(Ms_prev, zs_prev, Q)
print(Amem[0, :5])

tensor([ 2.1462e-04,  5.8591e-05,  1.3485e-04, -7.8125e-05, -9.7636e-05])


In [625]:
for _ in range(100_000):
    new_K = torch.randn(N, dkey)
    new_V = torch.randn(N, dvalue)
    MS_prev, zs_prev = save(Ms_prev, zs_prev, new_K, new_V)

In [626]:
Amem_new = read(Ms_prev, zs_prev, Q)
print(Amem[0, :5])

tensor([ 2.1462e-04,  5.8591e-05,  1.3485e-04, -7.8125e-05, -9.7636e-05])


In [627]:
F.cosine_similarity(Amem, Amem_new)

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000])