### Data

In [1]:
from tqdm.auto import tqdm
from datasets import load_dataset
import numpy as  np
from transformers import GPT2Tokenizer

In [2]:
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
dataset = load_dataset('wikitext', 'wikitext-103-v1')
split = dataset['train']

In [None]:
# compute text with token length >= 512
batch = 500
min_tok_length = 512
filtered = []
lengths = []
max_length = 0
for i in tqdm(range(0, len(split), batch)):
    lengths += [len(tok) for tok in tokenizer.batch_encode_plus([split[min(i+j, len(split) - 1)]['text'] for j in range(batch)])['input_ids']]

In [None]:
np.save('token_length.npy', np.array(lengths))

In [3]:
l = np.load('token_length.npy')

In [4]:
import matplotlib.pyplot as plt
%matplotlib inline

In [5]:
all([len(tok) > 512 for tok in tokenizer.batch_encode_plus([split[int(i)]['text'] for i in np.where(l > 512)[0][:20]])['input_ids']])

True

### Apply to MHA (GPT2 WikiText)

In [6]:
import torch
import torch.nn as nn
import numpy as np
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import transformers_drop_in as drop_in
import tensor_util as tu
from config import CONFIG
from performer_pytorch.performer_pytorch import causal_linear_attention_noncuda

In [7]:
CONFIG.do_consolidate = True
CONFIG.consolidate_ratio = 0.5
CONFIG.context_length = 400
CONFIG.consolidate_length = 200
CONFIG.temperature = 0.1
CONFIG.fix_prune_rate = True

In [8]:
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2').to(CONFIG.device)
model.eval()

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2AttentionDropIn(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

In [9]:
batch_size = 1
indices, = np.where(l > CONFIG.context_length)
batch_iter = iter(np.array_split(np.random.choice(indices, len(indices), replace=False), len(indices) // batch_size))

In [10]:
n_layer = 12
drop_in.GLOBALS.outputs = {
    'unnormalized': [[] for _ in range(n_layer)],
    'eig': [[] for _ in range(n_layer)],
    'final': [[] for _ in range(n_layer)],
    'value': [[] for _ in range(n_layer)],
    'query': [[] for _ in range(n_layer)],
    'key': [[] for _ in range(n_layer)],
    'out': [[] for _ in range(n_layer)],
    'mask': [[] for _ in range(n_layer)],
}
def record_attn(layer_idx, query, key, value, unnormalized_attn, final_attn, attn_output, attn_mask):
    drop_in.GLOBALS.outputs['unnormalized'][layer_idx] += [unnormalized_attn.cpu()]
    drop_in.GLOBALS.outputs['final'][layer_idx] += [final_attn.cpu()]
    drop_in.GLOBALS.outputs['value'][layer_idx] += [value.cpu()]
    drop_in.GLOBALS.outputs['query'][layer_idx] += [value.cpu()]
    drop_in.GLOBALS.outputs['key'][layer_idx] += [value.cpu()]
    drop_in.GLOBALS.outputs['out'][layer_idx] += [attn_output.cpu()]
    drop_in.GLOBALS.outputs['mask'][layer_idx] += [attn_mask.cpu()]

def no_op(query, key, value, attn_weights):
    pass

drop_in.record_attn_vars = record_attn

In [15]:
drop_in.GLOBALS.scaling_enabled = False
n_sample = 1
n_layer = 12
cols = 4
rows = n_layer // cols
rank_by_layer = [[] for _ in range(n_layer)]
with torch.no_grad():
    for i in tqdm(range(n_sample)):
        batch = next(batch_iter)
        model_input = {name: t.to(CONFIG.device) for name, t in tokenizer.batch_encode_plus(split[batch]['text'],
                                                                                         return_tensors="pt",
                                                                                         truncation=True,
                                                                                         max_length=CONFIG.context_length).items()}
        model(**model_input)

  0%|          | 0/1 [00:00<?, ?it/s]

In [16]:
data = {}
for m in ['key', 'query', 'value', 'out', 'mask']:
    data[m] = torch.stack([torch.cat(t, dim=0) for t in drop_in.GLOBALS.outputs[m]], dim=1)[0]

In [17]:
def compute_linear_attn(q, k, v):
    return causal_linear_attention_noncuda(
        torch.softmax(q, dim=-1),
        torch.exp(k),
        v,
    )

In [30]:
idx = slice(None), slice(None)
q = data['query'][idx]
k = data['key'][idx]
v = data['value'][idx]
raw = q @ k.transpose(-2, -1)
mask = torch.tril(torch.ones_like(raw, dtype=torch.bool), diagonal=0).transpose(-2, -1)
raw[~mask] = -float("inf")
attn = torch.softmax(raw, dim=-1)
out = attn @ v
out_p = compute_linear_attn(q, k, v)

In [31]:
(((out - out_p)**2).sum(dim=-1)**0.5).numpy().mean()

5.460476