In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

In [51]:
class QueryNetwork(nn.Module):
    def __init__(self, dim_in, dim_hidden, num_heads, batch_norm = True):
        super(QueryNetwork, self).__init__()
        self.dim_in = dim_in
        self.dim_hidden = dim_hidden
        self.num_heads = num_heads
        self.batch_norm = batch_norm
        self.Wq = nn.Linear(dim_in, dim_hidden * num_heads, bias = False)
        self.bn = nn.BatchNorm1d(dim_hidden * num_heads)
    def forward(self, x):
        # for memory of size less than 384*384 the paper says that batch norm gives no significant improvement
        # when using batch norm, padding tokens in the sequence can skew mean and variance estimate
        query = self.Wq(x)
        
        if self.batch_norm:
            original_shape = query.shape
            query = query.reshape(-1, self.dim_hidden * self.num_heads)
            normalized = self.bn(query)
            normalized = normalized.reshape(original_shape)
            return normalized.reshape(*normalized.shape[:-1], self.num_heads, self.dim_hidden)
        else:
            return query.reshape(*query.shape[:-1], self.num_heads, self.dim_hidden) # batch size x context length x num heads x query dim

In [18]:
qn = QueryNetwork(7,5, 4)
qn(torch.rand(100,10,7)).shape

torch.Size([100, 10, 4, 5])

In [73]:
class ProductKey(nn.Module):
    def __init__(self, dim, num_subkeys, top_k, num_heads):
        super(ProductKey, self).__init__()
        assert dim % 2 == 0, "key must be able to be split into 2"
        self.dim = dim
        self.subkey_size = dim // 2
        self.top_k = top_k
        self.num_subkeys = num_subkeys
        
        keyl = torch.empty(num_heads, num_subkeys, self.subkey_size)
        keyr = torch.empty(num_heads, num_subkeys, self.subkey_size)

        std = 1/math.sqrt(dim)
        keyl.uniform_(-std, std)
        keyr.uniform_(-std, std)

        self.keyl = nn.Parameter(keyl)
        self.keyr = nn.Parameter(keyr)

    def forward(self, query):
        # multihead query
        batch, context_length, num_heads, query_size = query.shape
        
        queryl = query[..., :self.subkey_size]
        queryr = query[..., self.subkey_size:]

        scorel = torch.einsum('bcnq,nkq->bcnk', queryl, self.keyl) # batch size x context length x num head x subquery length , num heads x num keys x subquery length
        scorer = torch.einsum('bcnq,nkq->bcnk', queryr, self.keyr)

        top_keys_l, top_idx_l = scorel.topk(self.top_k) # batch, context, heads, top k
        top_keys_r, top_idx_r = scorer.topk(self.top_k)

        #duplicate along the rows
        product_scores_l = top_keys_l.reshape(*top_keys_l.shape[:-1], top_keys_l.shape[-1], 1).expand(*top_keys_l.shape[:-1], top_keys_l.shape[-1], top_keys_l.shape[-1])
        # duplicate along the columns
        product_scores_r = top_keys_r.reshape(*top_keys_r.shape[:-1], 1, top_keys_r.shape[-1]).expand(*top_keys_r.shape[:-1], top_keys_r.shape[-1], top_keys_r.shape[-1])

        product_scores = (product_scores_l + product_scores_r) # batch, context, heads, top k, top k
        product_scores = product_scores.reshape(batch, context_length, num_heads, self.top_k * self.top_k)

        product_indices_l = top_idx_l.reshape(*top_idx_l.shape[:-1], top_idx_l.shape[-1], 1).expand(*top_idx_l.shape[:-1], top_idx_l.shape[-1], top_idx_l.shape[-1])
        product_indices_r = top_idx_r.reshape(*top_idx_r.shape[:-1], top_idx_r.shape[-1], 1).expand(*top_idx_r.shape[:-1], top_idx_r.shape[-1], top_idx_r.shape[-1]) 

        product_indices = product_indices_l * self.num_subkeys + product_indices_r
        product_indices = product_indices.reshape(batch, context_length, num_heads, self.top_k * self.top_k)

        top_product_scores, top_product_indices = product_scores.topk(self.top_k)
        selected_value_weights = F.softmax(top_product_scores, dim=-1)
        selected_value_indices = torch.gather(product_indices, -1, top_product_indices)
        #print(top_product_scores)
        #print(torch.gather(product_scores, -1, top_product_indices)) # they should be equal
        return selected_value_weights, selected_value_indices
        
        

In [74]:
block = torch.rand(1,2,3,4)
print(block)
print(block[...,:2])
print(block[...,2:])

tensor([[[[0.2253, 0.5491, 0.1194, 0.7783],
          [0.2550, 0.9119, 0.3249, 0.9291],
          [0.4815, 0.4903, 0.6554, 0.1917]],

         [[0.5868, 0.9642, 0.2588, 0.4964],
          [0.3499, 0.1581, 0.0349, 0.5953],
          [0.7301, 0.7644, 0.8572, 0.3248]]]])
tensor([[[[0.2253, 0.5491],
          [0.2550, 0.9119],
          [0.4815, 0.4903]],

         [[0.5868, 0.9642],
          [0.3499, 0.1581],
          [0.7301, 0.7644]]]])
tensor([[[[0.1194, 0.7783],
          [0.3249, 0.9291],
          [0.6554, 0.1917]],

         [[0.2588, 0.4964],
          [0.0349, 0.5953],
          [0.8572, 0.3248]]]])


In [29]:
a = torch.rand(2,4,1)
print(a)
print(a.expand(2,4,4))
b = a.reshape(2,1,4)
print(b.expand(2,4,4))

tensor([[[0.3561],
         [0.9040],
         [0.2236],
         [0.9060]],

        [[0.8523],
         [0.1579],
         [0.4793],
         [0.3675]]])
tensor([[[0.3561, 0.3561, 0.3561, 0.3561],
         [0.9040, 0.9040, 0.9040, 0.9040],
         [0.2236, 0.2236, 0.2236, 0.2236],
         [0.9060, 0.9060, 0.9060, 0.9060]],

        [[0.8523, 0.8523, 0.8523, 0.8523],
         [0.1579, 0.1579, 0.1579, 0.1579],
         [0.4793, 0.4793, 0.4793, 0.4793],
         [0.3675, 0.3675, 0.3675, 0.3675]]])
tensor([[[0.3561, 0.9040, 0.2236, 0.9060],
         [0.3561, 0.9040, 0.2236, 0.9060],
         [0.3561, 0.9040, 0.2236, 0.9060],
         [0.3561, 0.9040, 0.2236, 0.9060]],

        [[0.8523, 0.1579, 0.4793, 0.3675],
         [0.8523, 0.1579, 0.4793, 0.3675],
         [0.8523, 0.1579, 0.4793, 0.3675],
         [0.8523, 0.1579, 0.4793, 0.3675]]])


In [59]:
qn = QueryNetwork(7,6, 4)
query = qn(torch.rand(10,2,7))
#print(query.shape)
key = ProductKey(6, 3, 2, 4)
key(query)

tensor([[[[ 0.9406,  0.8643],
          [ 0.1486,  0.1485],
          [ 1.3653,  1.2828],
          [ 0.5725,  0.5647]],

         [[ 1.1675,  0.9089],
          [ 0.0824, -0.0272],
          [ 0.3099,  0.1507],
          [ 0.2130,  0.1989]]],


        [[[ 0.4970,  0.3907],
          [ 0.3884,  0.3348],
          [-0.1141, -0.1393],
          [ 0.4365,  0.1453]],

         [[ 0.3047,  0.1988],
          [ 0.7850,  0.7605],
          [ 0.9413,  0.8459],
          [ 0.4884,  0.4194]]],


        [[[ 0.8606,  0.8399],
          [ 0.3752,  0.2896],
          [ 2.5514,  2.4396],
          [ 0.2540,  0.1578]],

         [[ 1.0256,  0.6873],
          [ 1.2209,  0.9314],
          [ 1.2635,  1.1965],
          [ 0.4429,  0.2562]]],


        [[[ 0.4367,  0.3870],
          [ 0.8675,  0.8348],
          [ 1.7378,  1.5173],
          [ 0.5041,  0.4577]],

         [[ 0.6319,  0.4495],
          [ 0.1568,  0.1509],
          [-0.0641, -0.1277],
          [ 0.6859,  0.2412]]],


        [[[ 0.01

In [97]:
class PKM(nn.Module):
    def __init__(self, dim_in, dim_hidden, num_subkeys, top_k, num_heads, batch_norm = True):
        super(PKM, self).__init__()
        self.dim_in, self.dim_hidden, self.num_subkeys, self.top_k, self.num_heads = dim_in, dim_hidden, num_subkeys, top_k, num_heads
        self.query_network = QueryNetwork(dim_in, dim_hidden, num_heads, batch_norm)
        self.key_table = ProductKey(dim_hidden, num_subkeys, top_k, num_heads)
        self.value_table = nn.Embedding(num_subkeys * num_subkeys, dim_in)
    def forward(self, x):
        queries = self.query_network(x)
        weights, indices = self.key_table(queries) # shape is batch, context length, num heads, top k
        original_shape = weights.shape
        
        weights, indices = weights.reshape(-1, self.top_k), indices.reshape(-1, self.top_k)
        values = self.value_table(indices)
        weights = weights.reshape(original_shape)
        values = values.reshape(*original_shape, self.dim_in)

        weighted_values = torch.einsum('bcnk,bcnkd->bcd', weights, values) # take linear combination of weights & values and sum over all heads

        return weighted_values
        
        
        
        

# Validation

## Query

In [55]:
query_network = QueryNetwork(6,4,2,False)
dumb_query_network = [nn.Linear(10, 5, bias=False), nn.Linear(10, 5, bias=False)]
dumb_query_network[0].weight.data = query_network.Wq.weight.data[:4,...]
dumb_query_network[1].weight.data = query_network.Wq.weight.data[4:,...]
data = torch.rand(1,1,6)
query_out = query_network(data)
head_0_out = dumb_query_network[0](data)
head_1_out = dumb_query_network[1](data)

assert torch.equal(query_out[...,0,:], head_0_out)
assert torch.equal(query_out[...,1,:], head_1_out)

## Product key

In [68]:
class DumbProductKey(nn.Module):
    def __init__(self, dim, num_subkeys, top_k):
        super(DumbProductKey, self).__init__()
        self.dim = dim
        self.subkey_size = dim // 2
        self.num_subkeys = num_subkeys
        self.top_k = top_k
        self.keyl = nn.Linear(self.subkey_size, num_subkeys, bias = False)
        self.keyr = nn.Linear(self.subkey_size, num_subkeys, bias = False)
    def forward(self, x):
        batch_size, context_length, query_size = x.shape
        query_l = x[..., :self.subkey_size]
        query_r = x[..., self.subkey_size:]
        top_left_keys, top_left_indices = self.keyl(query_l).topk(self.top_k)
        top_right_keys, top_right_indices = self.keyr(query_r).topk(self.top_k)
        product_keys_left = top_left_keys.reshape(batch_size, context_length, self.top_k, 1).expand(batch_size, context_length, self.top_k, self.top_k)
        product_keys_right = top_right_keys.reshape(batch_size, context_length, 1, self.top_k).expand(batch_size, context_length, self.top_k, self.top_k)
        product_keys = product_keys_left + product_keys_right
        product_indices_left = top_left_indices.reshape(batch_size, context_length, self.top_k, 1).expand(batch_size, context_length, self.top_k, self.top_k)
        product_keys_right = top_right_indices.reshape(batch_size, context_length, self.top_k, 1).expand(batch_size, context_length, self.top_k, self.top_k)
        product_indices = product_indices_left * self.num_subkeys + product_keys_right
        product_keys = product_keys.reshape(batch_size, context_length, self.top_k * self.top_k)
        product_indices = product_indices.reshape(batch_size, context_length, self.top_k * self.top_k)
        top_keys, top_indices = product_keys.topk(self.top_k)
        top_indices = torch.gather(product_indices, -1, top_indices)
        top_weights = F.softmax(top_keys, dim=-1)
        return top_weights, top_indices
        
        

In [79]:
dpk0 = DumbProductKey(10, 3, 2)
dpk1 = DumbProductKey(10, 3, 2)
pk = ProductKey(10, 3, top_k=2, num_heads=2)
print(pk.keyl.data[0].shape)
pk.keyl.data[0], pk.keyl.data[1]  = dpk0.keyl.weight.data, dpk1.keyl.weight.data
pk.keyr.data[0], pk.keyr.data[1] = dpk0.keyr.weight.data, dpk1.keyr.weight.data
data = torch.rand(1,1,2,10)
pk_out = pk(data)
dpk0_out = dpk0(data[:,:,0,:])
dpk1_out = dpk1(data[:,:,1,:])

assert torch.equal(pk_out[0][:,:,0,:], dpk0_out[0])
assert torch.equal(pk_out[0][:,:,1,:], dpk1_out[0])
assert torch.equal(pk_out[1][:,:,0,:], dpk0_out[1])
assert torch.equal(pk_out[1][:,:,1,:], dpk1_out[1])

torch.Size([3, 5])


## PKM

In [125]:
class DumbPKM(nn.Module): 
    # inefficient implementation to verify correctness
    def __init__(self, dim_in, dim_hidden, num_subkeys, top_k, num_heads, batch_norm = True):
        super(DumbPKM, self).__init__()
        self.dim_in, self.dim_hidden, self.num_subkeys, self.top_k, self.num_heads = dim_in, dim_hidden, num_subkeys, top_k, num_heads
        self.query_network = QueryNetwork(dim_in, dim_hidden, num_heads, batch_norm)
        self.key_table = ProductKey(dim_hidden, num_subkeys, top_k, num_heads)
        self.value_table = nn.Embedding(num_subkeys * num_subkeys, dim_in)
    def forward(self, x):
        queries = self.query_network(x)
        weights, indices = self.key_table(queries) # shape is batch, context length, num heads, top k
        
        #weights, indices = weights.reshape(-1, self.top_k), indices.reshape(-1, self.top_k)
        values = self.value_table(indices)
        #weights = weights.reshape(original_shape)
        #values = values.reshape(*original_shape, self.dim_in)
        # print(weights)
        # print(values)
        # print((weights.unsqueeze(-1) * values).shape )
        weighted_values = (weights.unsqueeze(-1) * values).sum(dim=2).sum(dim=2)# take linear combination of weights & values and sum over all heads
        
        return weighted_values
        
        
        

In [124]:
torch.manual_seed(0)
dpkm = DumbPKM(10,6,5,3,2, False)
torch.manual_seed(0)
pkm = PKM(10,6,5,3,2, False)
data = torch.rand(1,1,10)
dpkm_out = dpkm(data)
pkm_out = pkm(data)
assert torch.allclose(dpkm_out, pkm_out)

tensor([[[[0.3421, 0.3416, 0.3163],
          [0.3505, 0.3462, 0.3033]]]], grad_fn=<SoftmaxBackward0>)
tensor([[[[[-0.2813, -1.3299, -0.6538,  1.7198, -0.9610, -0.6375, -0.8870,
             0.8388,  1.1529, -1.7611],
           [-0.2813, -1.3299, -0.6538,  1.7198, -0.9610, -0.6375, -0.8870,
             0.8388,  1.1529, -1.7611],
           [-0.2813, -1.3299, -0.6538,  1.7198, -0.9610, -0.6375, -0.8870,
             0.8388,  1.1529, -1.7611]],

          [[-0.9069, -0.5918,  0.1508, -1.0411, -1.1559, -0.3167,  0.9403,
            -1.1470,  0.7928,  0.0832],
           [-0.9069, -0.5918,  0.1508, -1.0411, -1.1559, -0.3167,  0.9403,
            -1.1470,  0.7928,  0.0832],
           [-1.1070, -1.7174,  1.5346, -0.0032,  1.4403, -0.1106,  0.5769,
            -0.1692,  1.1887, -0.1575]]]]], grad_fn=<EmbeddingBackward0>)
torch.Size([1, 1, 2, 3, 10])
