In [75]:
import torch
from torch import nn
import torch.nn.functional as F

from mogrifier import Mogrifier

import math
from collections import namedtuple
from functools import partial
from inspect import isfunction


from compressive_transformer_pytorch import CompressiveTransformer

In [82]:
model = CompressiveTransformer(
    num_tokens = 20000,            # vocabulary size; the last axis of the logits has this size
    emb_dim = 128,                 # embedding dimensions, embedding factorization from Albert paper
    dim = 512,                     # model width (the trailing 512 in the memory shapes printed below)
    depth = 12,
    seq_len = 1024,                # segment length fed to each forward call
    mem_len = 1024,                # memory length
    cmem_len = 1024 // 4,          # compressed memory buffer length (= mem_len // cmem_ratio = 256)
    cmem_ratio = 4,                # compressed memory ratio, 4 was recommended in paper
    reconstruction_loss_weight = 1,# weight to place on compressed memory reconstruction loss
    attn_dropout = 0.1,            # dropout post-attention
    ff_dropout = 0.1,              # dropout in feedforward
    attn_layer_dropout = 0.1,      # dropout for attention layer output
    gru_gated_residual = True,     # whether to gate the residual intersection, from 'Stabilizing Transformer for RL' paper
    mogrify_gru = False,           # experimental feature that adds a mogrifier for the update and residual before gating by the GRU
    # NOTE(review): range(6, 13) selects 7 layers, matching the leading 7 in the
    # mem/cmem shapes printed later; presumably layers are numbered from 1 — confirm.
    memory_layers = range(6, 13),  # specify which layers to use long-range memory, from 'Do Transformers Need LR Memory' paper
    one_kv_head = False,            # share one key/value head for all queries, from Shazeer's 'One Write-Head is All You Need'
    ff_glu = True                  # use GLU variant for feedforward
)


In [83]:
model

CompressiveTransformer(
  (token_emb): Embedding(20000, 128)
  (to_model_dim): Linear(in_features=128, out_features=512, bias=True)
  (to_logits): Sequential(
    (0): Linear(in_features=512, out_features=128, bias=True)
    (1): Linear(in_features=128, out_features=20000, bias=True)
  )
  (attn_layers): ModuleList(
    (0): GRUGating(
      (fn): PreNorm(
        (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (fn): SelfAttention(
          (compress_mem_fn): ConvCompress(
            (conv): Conv1d(512, 512, kernel_size=(4,), stride=(4,))
          )
          (to_q): Linear(in_features=512, out_features=512, bias=False)
          (to_kv): Linear(in_features=512, out_features=1024, bias=False)
          (to_out): Linear(in_features=512, out_features=512, bias=True)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (dropout): Dropout(p=0.1, inplace=False)
          (reconstruction_attn_dropout): Dropout(p=0.0, inplace=False)
        )
      )
    

In [84]:
SEGMENT_LEN = 1024  # must equal the model's seq_len

inputs = torch.randint(0, 256, (1, 2 * SEGMENT_LEN))  # [b, 2n] token ids
masks = torch.ones_like(inputs, dtype=torch.bool)     # [b, 2n] every position is valid

# Split into consecutive segments and move the segment axis first:
# [b, 2n] -> [b, 2, n] -> [2, b, n]. Use the real batch size: a hard-coded 1
# would fold extra samples into the segment axis whenever b > 1.
batch_size = inputs.shape[0]
segments = inputs.reshape(batch_size, -1, SEGMENT_LEN).transpose(0, 1)
masks = masks.reshape(batch_size, -1, SEGMENT_LEN).transpose(0, 1)

# segments[0]: [b, n]
# logits: [b, n, num_tokens]; memories holds (mem, cmem), mem [7, b, 1024, 512]
logits, memories, aux_loss = model(segments[0], mask = masks[0])
logits.shape

torch.Size([1, 1024, 20000])

In [37]:
# compressed_mem[7, b, 256, 512]
logits,  memories, aux_loss = model(segments[1], mask = masks[1], memories = memories)

# memories is a named tuple that contains the memory (mem) and the compressed memory (cmem)

In [45]:
from inspect import isfunction

def default(x, val):
    """Return ``x`` unless it is ``None``.

    When ``x`` is ``None`` the fallback ``val`` is returned instead.
    If ``val`` is a plain Python function (per ``inspect.isfunction``),
    it is called with no arguments, which allows lazy defaults.
    """
    if x is None:
        return val() if isfunction(val) else val
    return x

In [48]:
mem, cmem = memories

In [50]:
mem.shape, cmem.shape

(torch.Size([7, 1, 1024, 512]), torch.Size([7, 1, 256, 512]))

In [59]:
x = torch.randn(1,16)
b = 2
d = 512

In [64]:
torch.randn(2,0)

tensor([], size=(2, 0))

In [65]:
def iterate_tensor(t):
    """Lazily yield ``t[0], t[1], ...``, i.e. successive slices along the first axis."""
    yield from (t[idx] for idx in range(t.shape[0]))


In [72]:
m = torch.randn(2,3,4)
m

tensor([[[ 1.7537, -0.8728, -1.2249,  0.0646],
         [ 0.7961,  0.2608, -0.9098, -1.7165],
         [ 0.7452, -1.1558,  1.6637, -1.4666]],

        [[-0.9516,  2.1680, -1.0025,  0.8389],
         [ 0.1069, -1.4778,  1.0530, -1.1011],
         [ 0.0072,  0.6764, -0.2917, -1.2895]]])

In [73]:
m.transpose(1,2).shape

torch.Size([2, 4, 3])

In [76]:
class ConvCompress(nn.Module):
    """Shorten a memory sequence by ``ratio`` with a strided 1-D convolution.

    ``forward`` takes ``mem`` laid out as (batch, length, dim), i.e. channels last.
    It returns (batch, length // ratio, dim). Because the convolution has no
    padding, any trailing ``length % ratio`` positions are dropped.
    """

    def __init__(self, dim, ratio = 4):
        super(ConvCompress, self).__init__()
        # kernel_size == stride: each output step summarises one disjoint window of `ratio` inputs
        self.conv = nn.Conv1d(dim, dim, kernel_size=ratio, stride=ratio)

    def forward(self, mem):
        # Conv1d wants (batch, channels, length), so swap the last two axes in and back out.
        channels_first = mem.transpose(1, 2)
        compressed = self.conv(channels_first)
        return compressed.transpose(1, 2)

In [77]:
def reshape_dim(t, dim, split_dims):
    """Replace axis ``dim`` of ``t`` by the axes listed in ``split_dims``.

    Args:
        t: input tensor.
        dim: axis to split; negative values count from the end.
        split_dims: sizes that take the place of axis ``dim``. At most one entry
            may be ``-1``, which ``reshape`` infers.

    Returns:
        ``t`` reshaped so that axis ``dim`` becomes ``split_dims``.

    Raises:
        IndexError: if ``dim`` is outside ``[-t.dim(), t.dim() - 1]``.
    """
    shape = list(t.shape)
    num_dims = len(shape)
    # Validate before normalising: a bare modular wrap would silently map an
    # out-of-range dim (e.g. 3 on a 3-D tensor) onto a different axis, and
    # would divide by zero for a 0-dim tensor.
    if not -num_dims <= dim < num_dims:
        raise IndexError(
            f"dim {dim} out of range for tensor with {num_dims} dimensions")
    dim %= num_dims
    shape[dim:dim + 1] = list(split_dims)
    return t.reshape(shape)

In [88]:
# Demo: split the feature axis into attention heads,
# [b, n, h*h_d] -> [b, n, h, h_d] -> [b, h, n, h_d].
# NOTE(review): despite its name, `merge_heads` *splits* heads. `h` is not read
# by the lambda; the head count is inferred from -1 (here 128 / 32 = 4).
dim_h = 32
h = 4
merge_heads = lambda x: reshape_dim(x, -1, (-1, dim_h)).transpose(1, 2)


q = k = v = torch.randn(2,16,128)  # q, k and v start as the same tensor object
q, k, v = map(merge_heads, (q, k, v))   # [b,n,d] -> [b,h,n,h_d]

q.shape, k.shape, v.shape

(torch.Size([2, 4, 16, 32]),
 torch.Size([2, 4, 16, 32]),
 torch.Size([2, 4, 16, 32]))

In [89]:
k, v = map(lambda x: x.expand(-1, h, -1, -1), (k, v))  # k,v:[b,h,len_cmem+mem+x,h_d]
k.shape, v.shape

(torch.Size([2, 4, 16, 32]), torch.Size([2, 4, 16, 32]))

In [94]:
tmp = torch.randn(2,1,4)

In [95]:
tmp.expand(2,3,4).shape

torch.Size([2, 3, 4])

In [96]:
t = 4
total_mem_len = 6
mask = torch.ones(t, t + total_mem_len,).triu_(diagonal = 1 + total_mem_len)

In [97]:
mask

tensor([[0., 0., 0., 0., 0., 0., 0., 1., 1., 1.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])

In [98]:
def split_at_index(dim, index, t):
    """Split ``t`` along axis ``dim`` into ``t[..., :index]`` and ``t[..., index:]``.

    Args:
        dim: axis to split; negative values count from the end.
        index: split position with Python slice semantics (negative counts
            from the end of that axis).
        t: tensor to split.

    Returns:
        ``(left, right)`` views of ``t`` before and from ``index`` along ``dim``.

    Raises:
        IndexError: if ``dim`` is out of range for ``t``.
    """
    num_dims = t.dim()
    if not -num_dims <= dim < num_dims:
        raise IndexError(
            f"dim {dim} out of range for tensor with {num_dims} dimensions")
    # Normalise negative dims: `(slice(None),) * dim` is an empty tuple when
    # dim < 0, which would silently split along axis 0 instead.
    dim %= num_dims
    pre_slices = (slice(None),) * dim
    left = t[(*pre_slices, slice(None, index))]
    right = t[(*pre_slices, slice(index, None))]
    return left, right


def queue_fifo(*args, length, dim=-2):
    """Concatenate tensors along ``dim`` and split off the last ``length`` entries.

    Args:
        *args: tensors to concatenate in order. They must match on every axis
            except ``dim``.
        length: number of trailing entries to return in the second output.
            ``length <= 0`` keeps nothing; a value larger than the concatenated
            size keeps everything.
        dim: axis to concatenate and split along; negative values count from the end.

    Returns:
        ``(leading, trailing)``: ``trailing`` holds the last
        ``min(length, size)`` entries along ``dim``, and ``leading`` holds
        everything before them. Both share the concatenated tensor's dtype and device.
    """
    queue = torch.cat(args, dim=dim)
    size = queue.shape[dim]
    # Clamp so that length <= 0 and length > size need no special-casing.
    keep = max(0, min(length, size))
    cut = size - keep
    # `narrow` wraps negative dims itself, so the default dim=-2 splits the
    # intended axis. For keep == 0 it yields an empty slice with the queue's own
    # dtype (a fresh torch.empty would always be float).
    return queue.narrow(dim, 0, cut), queue.narrow(dim, cut, keep)

In [99]:
7%4

3

In [100]:
class ConvCompress(nn.Module):
    """Strided-convolution compressor: (batch, length, dim) -> (batch, length // ratio, dim).

    NOTE(review): this redefines the identical class from an earlier cell; this
    later definition is the one that is bound afterwards.
    """

    def __init__(self, dim, ratio = 4):
        super().__init__()
        self.conv = nn.Conv1d(
            in_channels=dim,
            out_channels=dim,
            kernel_size=ratio,
            stride=ratio,
        )

    def forward(self, mem):
        # Channels-last in, channels-first for Conv1d, channels-last out.
        return self.conv(mem.transpose(1, 2)).transpose(1, 2)

In [101]:
Cc = ConvCompress(dim=64)

In [102]:
mem = torch.randn(2, 16, 64)

In [103]:
Cc(mem).shape

torch.Size([2, 4, 64])

In [None]:
# Library snippet kept for reference only. It uses `self` and `mem_len`, which do
# not exist at notebook top level, so evaluating it raises NameError:
#   slice(- min(mem_len, self.mem_len) - self.seq_len, -self.seq_len)
# The same expression evaluated with this notebook's settings (seq_len = mem_len = 1024):
demo_seq_len, demo_max_mem_len, demo_mem_len = 1024, 1024, 1024
slice(-min(demo_mem_len, demo_max_mem_len) - demo_seq_len, -demo_seq_len)

In [123]:
li = [i for i in range(10)]

In [125]:
s = slice(5)  # indexes 0,1,2,3,4
li[s]

[0, 1, 2, 3, 4]

In [136]:
s = slice(0, 6)  # indexes 0,1,2,3,4,5 (stop is exclusive)
li[s]

[0, 1, 2, 3, 4, 5]

In [137]:
s = slice(1, 10, 2)  # indexes 1,3,5,7,9
li[s]

[1, 3, 5, 7, 9]

In [138]:
s = slice(9, 0, -2)  # indexes 9,7,5,3,1
li[s]

[9, 7, 5, 3, 1]

In [141]:
li

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

In [139]:
s = slice(-6, -1)  # indexes 4,5,6,7,8 (with len(li) == 10: -6 -> 4, exclusive stop -1 -> 9)
li[s]

[4, 5, 6, 7, 8]

In [143]:
x = torch.randn(1,10,5)
x1 = x.split(2, dim=1)

In [144]:
len(x1)

5

In [145]:
x1[0].shape

torch.Size([1, 2, 5])

In [146]:
def fun1():
    """Generator that yields the integers 0 through 9 in order."""
    yield from range(10)

In [147]:
res = fun1()

In [153]:
next(res)

3

In [154]:
for i in res:
    print(i, 'ok')

4 ok
5 ok
6 ok
7 ok
8 ok
9 ok


# Memory cache

In [159]:
from collections import namedtuple

In [160]:
mem_to_max_patch_idx = namedtuple('mem_to_max_patch_idx', ['cur_idx', 'max_attach_idx'])

In [162]:
k = mem_to_max_patch_idx(cur_idx=0, max_attach_idx=5)
k.cur_idx, k.max_attach_idx

In [157]:
# 4x4 grid whose entry at (row, col) is its row-major index, row * 4 + col.
m = torch.arange(float(4 * 4)).view(4, 4)
m

tensor([[ 0.,  1.,  2.,  3.],
        [ 4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11.],
        [12., 13., 14., 15.]])

In [280]:
# For every patch index on a patch_height x patch_width grid (row-major), record
# the largest patch index inside the window that extends `mem_expand_size` rows
# down and `mem_expand_size` columns right from it, clipped to the grid.
mem_mapping = dict()  # mapping cur_idx: max_attach_idx
mem_expand_size = 1
patch_height = 4
patch_width = 4
# Iterate over the grid defined above (not a hard-coded 4*4), so changing
# patch_height / patch_width keeps the table consistent.
for cur_patch_idx in range(patch_height * patch_width):
    cur_row = cur_patch_idx // patch_width
    cur_col = cur_patch_idx % patch_width
    max_row = min(patch_height-1, cur_row+mem_expand_size)
    max_col = min(patch_width-1, cur_col+mem_expand_size)
    max_idx = max_row * patch_width + max_col
    mem_mapping[cur_patch_idx] = max_idx

mem_mapping

{0: 5,
 1: 6,
 2: 7,
 3: 7,
 4: 9,
 5: 10,
 6: 11,
 7: 11,
 8: 13,
 9: 14,
 10: 15,
 11: 15,
 12: 13,
 13: 14,
 14: 15,
 15: 15}

In [281]:
def update_mem(mem_pool, mem_mapping, mem, cur_patch_idx):
    """Store ``mem`` for ``cur_patch_idx`` in ``mem_pool``, then evict expired entries.

    ``mem_pool`` is modified in place and nothing is returned. After the
    insertion, every key ``k`` with ``mem_mapping[k] <= cur_patch_idx`` is
    removed. This can include the entry that was just inserted.

    Raises:
        KeyError: if a key in ``mem_pool`` is missing from ``mem_mapping``.
    """
    mem_pool[cur_patch_idx] = mem

    # Gather first, delete afterwards: a dict must not change size while iterated.
    expired = [idx for idx in mem_pool if mem_mapping[idx] <= cur_patch_idx]
    for idx in expired:
        del mem_pool[idx]

In [282]:
# Simulate the cache over all 16 patches, snapshotting the pool after each step.
mem_pool_list = []
mem_pool = {}
for i in range(16):
    update_mem(mem_pool, mem_mapping, 'mem'+str(i), i)
    # mem_pool_list[i] is the pool state right after patch i was stored/evicted
    mem_pool_list.append(dict(mem_pool))

    print('cur:', i)
    print(mem_pool)
    print('*' * 16)

cur: 0
{0: 'mem0'}
****************
cur: 1
{0: 'mem0', 1: 'mem1'}
****************
cur: 2
{0: 'mem0', 1: 'mem1', 2: 'mem2'}
****************
cur: 3
{0: 'mem0', 1: 'mem1', 2: 'mem2', 3: 'mem3'}
****************
cur: 4
{0: 'mem0', 1: 'mem1', 2: 'mem2', 3: 'mem3', 4: 'mem4'}
****************
cur: 5
{1: 'mem1', 2: 'mem2', 3: 'mem3', 4: 'mem4', 5: 'mem5'}
****************
cur: 6
{2: 'mem2', 3: 'mem3', 4: 'mem4', 5: 'mem5', 6: 'mem6'}
****************
cur: 7
{4: 'mem4', 5: 'mem5', 6: 'mem6', 7: 'mem7'}
****************
cur: 8
{4: 'mem4', 5: 'mem5', 6: 'mem6', 7: 'mem7', 8: 'mem8'}
****************
cur: 9
{5: 'mem5', 6: 'mem6', 7: 'mem7', 8: 'mem8', 9: 'mem9'}
****************
cur: 10
{6: 'mem6', 7: 'mem7', 8: 'mem8', 9: 'mem9', 10: 'mem10'}
****************
cur: 11
{8: 'mem8', 9: 'mem9', 10: 'mem10', 11: 'mem11'}
****************
cur: 12
{8: 'mem8', 9: 'mem9', 10: 'mem10', 11: 'mem11', 12: 'mem12'}
****************
cur: 13
{9: 'mem9', 10: 'mem10', 11: 'mem11', 13: 'mem13'}
****************
c

## Concatenate memories with x

In [318]:
def get_mem_from_pool(mem_pool, cur_patch_idx, mem_expand_size=1, patch_height=4, patch_width=4):
    """Collect the cached memories visible to patch ``cur_patch_idx``.

    Patches lie row-major on a ``patch_height x patch_width`` grid. The visible
    neighbourhood covers rows up to ``mem_expand_size`` above the current row
    and columns within ``mem_expand_size`` on either side, clipped to the grid.
    Only patches that come strictly before the current one are kept: all such
    columns in earlier rows, and only the columns to the left in the current row.

    Args:
        mem_pool: dict mapping patch index -> memory.
        cur_patch_idx: row-major index of the current patch.
        mem_expand_size: neighbourhood radius in rows/columns.
        patch_height, patch_width: grid size.

    Returns:
        List of ``mem_pool`` values for the visible patches, in row-major order.

    Raises:
        ValueError: if ``cur_patch_idx`` is not on the grid.
        KeyError: if a visible patch has no entry in ``mem_pool``.
    """
    num_patches = patch_height * patch_width
    if not 0 <= cur_patch_idx < num_patches:
        raise ValueError(
            f"cur_patch_idx {cur_patch_idx} outside grid of {num_patches} patches")
    cur_row, cur_col = divmod(cur_patch_idx, patch_width)

    first_row = max(0, cur_row - mem_expand_size)
    first_col = max(0, cur_col - mem_expand_size)
    last_col = min(patch_width - 1, cur_col + mem_expand_size)

    # Plain integer loops (row outer, col inner) give the same row-major order as
    # a meshgrid/reshape, without building tensors for a handful of indices.
    mem_idxs = [
        row * patch_width + col
        for row in range(first_row, cur_row + 1)
        for col in range(first_col, last_col + 1)
        if row < cur_row or col < cur_col
    ]
    return [mem_pool[idx] for idx in mem_idxs]


In [319]:
mem_pool_list[6]

{2: 'mem2', 3: 'mem3', 4: 'mem4', 5: 'mem5', 6: 'mem6'}

In [320]:
get_mem_from_pool(mem_pool_list[5], 6,)

['mem1', 'mem2', 'mem3', 'mem5']

In [258]:
mem_list = []
for i in range(4*4):
    mem = torch.full((2,2*2), i)
    mem_list.append(mem)

In [259]:
mem_list[0].shape

torch.Size([2, 4])

In [262]:
m1, m2, m3, m4 = mem_list[0],mem_list[1],mem_list[2],mem_list[4]
ms = [m1,m2,m3,m4]

In [263]:
x = mem_list[5]
ms += [x]

In [265]:
x2 = torch.concat(ms, dim=1)

# Relative position embedding

In [204]:
# get pair-wise relative position index for each token inside the window
window_size = (7,7)

coords_h = torch.arange(window_size[0])
coords_w = torch.arange(window_size[1])
# indexing="ij" makes the (row, col) layout explicit; leaving it out emits a
# UserWarning on torch >= 1.10 (where "ij" is also the current default).
coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"))  # 2, Wh, Ww

In [212]:
coords

tensor([[[0, 0, 0, 0, 0, 0, 0],
         [1, 1, 1, 1, 1, 1, 1],
         [2, 2, 2, 2, 2, 2, 2],
         [3, 3, 3, 3, 3, 3, 3],
         [4, 4, 4, 4, 4, 4, 4],
         [5, 5, 5, 5, 5, 5, 5],
         [6, 6, 6, 6, 6, 6, 6]],

        [[0, 1, 2, 3, 4, 5, 6],
         [0, 1, 2, 3, 4, 5, 6],
         [0, 1, 2, 3, 4, 5, 6],
         [0, 1, 2, 3, 4, 5, 6],
         [0, 1, 2, 3, 4, 5, 6],
         [0, 1, 2, 3, 4, 5, 6],
         [0, 1, 2, 3, 4, 5, 6]]])

In [211]:
coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
coords_flatten.shape, coords_flatten[:, 8]

(torch.Size([2, 49]), tensor([1, 1]))

In [213]:
# 2, Wh*Ww, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  
relative_coords.shape, relative_coords

(torch.Size([2, 49, 49]),
 tensor([[[ 0,  0,  0,  ..., -6, -6, -6],
          [ 0,  0,  0,  ..., -6, -6, -6],
          [ 0,  0,  0,  ..., -6, -6, -6],
          ...,
          [ 6,  6,  6,  ...,  0,  0,  0],
          [ 6,  6,  6,  ...,  0,  0,  0],
          [ 6,  6,  6,  ...,  0,  0,  0]],
 
         [[ 0, -1, -2,  ..., -4, -5, -6],
          [ 1,  0, -1,  ..., -3, -4, -5],
          [ 2,  1,  0,  ..., -2, -3, -4],
          ...,
          [ 4,  3,  2,  ...,  0, -1, -2],
          [ 5,  4,  3,  ...,  1,  0, -1],
          [ 6,  5,  4,  ...,  2,  1,  0]]]))

In [217]:
# Wh*Ww, Wh*Ww, 2; [i, j, :] is the (row, col) offset of the i-th patch in the window relative to the j-th patch
relative_coords = relative_coords.permute(1, 2, 0).contiguous() 

In [218]:
relative_coords.shape

torch.Size([49, 49, 2])

In [230]:
relative_coords[7,:,0]

tensor([ 1,  1,  1,  1,  1,  1,  1,  0,  0,  0,  0,  0,  0,  0, -1, -1, -1, -1,
        -1, -1, -1, -2, -2, -2, -2, -2, -2, -2, -3, -3, -3, -3, -3, -3, -3, -4,
        -4, -4, -4, -4, -4, -4, -5, -5, -5, -5, -5, -5, -5])

In [231]:
relative_coords[7,:,1]

tensor([ 0, -1, -2, -3, -4, -5, -6,  0, -1, -2, -3, -4, -5, -6,  0, -1, -2, -3,
        -4, -5, -6,  0, -1, -2, -3, -4, -5, -6,  0, -1, -2, -3, -4, -5, -6,  0,
        -1, -2, -3, -4, -5, -6,  0, -1, -2, -3, -4, -5, -6])

In [232]:
relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
relative_coords[:, :, 1] += window_size[1] - 1

In [233]:
relative_coords[7,:,0]

tensor([7, 7, 7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 6, 6, 5, 5, 5, 5, 5, 5, 5, 4, 4, 4,
        4, 4, 4, 4, 3, 3, 3, 3, 3, 3, 3, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1,
        1])

In [234]:
relative_coords[7,:,1]

tensor([6, 5, 4, 3, 2, 1, 0, 6, 5, 4, 3, 2, 1, 0, 6, 5, 4, 3, 2, 1, 0, 6, 5, 4,
        3, 2, 1, 0, 6, 5, 4, 3, 2, 1, 0, 6, 5, 4, 3, 2, 1, 0, 6, 5, 4, 3, 2, 1,
        0])

In [236]:
# define a parameter table of relative position bias  # shape : 2*Wh-1 * 2*Ww-1, nH
num_heads = 4
relative_position_bias_table = nn.Parameter(
    torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) 
relative_position_bias_table.shape  # 2*7-1=13  13*13=169

torch.Size([169, 4])

In [239]:
relative_coords[:, :, 0] *= 2 * window_size[1] - 1

In [241]:
relative_coords[7,:,0]

tensor([91, 91, 91, 91, 91, 91, 91, 78, 78, 78, 78, 78, 78, 78, 65, 65, 65, 65,
        65, 65, 65, 52, 52, 52, 52, 52, 52, 52, 39, 39, 39, 39, 39, 39, 39, 26,
        26, 26, 26, 26, 26, 26, 13, 13, 13, 13, 13, 13, 13])

In [242]:
relative_coords.shape

torch.Size([49, 49, 2])

In [243]:
relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
relative_position_index.shape

torch.Size([49, 49])

In [244]:
relative_position_index

tensor([[ 84,  83,  82,  ...,   2,   1,   0],
        [ 85,  84,  83,  ...,   3,   2,   1],
        [ 86,  85,  84,  ...,   4,   3,   2],
        ...,
        [166, 165, 164,  ...,  84,  83,  82],
        [167, 166, 165,  ...,  85,  84,  83],
        [168, 167, 166,  ...,  86,  85,  84]])

In [245]:
relative_position_bias = relative_position_bias_table[relative_position_index.view(-1)].view(
    window_size[0] * window_size[1], window_size[0] * window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
relative_position_bias.shape

torch.Size([49, 49, 4])

In [246]:
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
relative_position_bias.shape

torch.Size([4, 49, 49])