In [1]:
import math

import torch
import torch.nn.functional as F


# Toy dimensions used to check two ways of building the attention mask that
# is needed when only the new tokens are scored against cache + new tokens.
n_ctx = 10  # side length of the full square causal mask built below
seqlen = 3  # number of new tokens (rows of the mask)
start_pos = 2  # number of positions that precede the new tokens (cache length)
T = start_pos + seqlen  # total number of key/value positions (columns of the mask)
tokens = torch.randint(0, 100, (1, seqlen))

# Construction 1: build the (seqlen, seqlen) causal block, then left-pad it
# with `start_pos` columns of zeros.
mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device)
print(mask)
# Keep -inf strictly above the main diagonal; everything else becomes 0.
mask = torch.triu(mask, diagonal=1)
print(mask)
# When performing key-value caching, we compute the attention scores
# only for the new sequence. Thus, the matrix of scores is of size
# (seqlen, cache_len + seqlen), and the only masked entries are (i, j) for
# j > cache_len + i, since row i corresponds to token cache_len + i.
mask = torch.hstack([torch.zeros((seqlen, start_pos), device=tokens.device), mask])
print("mask 1")
print(mask)
mask_1 = mask

# Construction 2: build the full (n_ctx, n_ctx) causal mask once and slice out
# the rows of the new tokens and the columns of every position seen so far.
full_mask = torch.full((n_ctx, n_ctx), float("-inf"), device=tokens.device)
full_mask = torch.triu(full_mask, diagonal=1)
print(full_mask)
mask_2 = full_mask[start_pos : start_pos + seqlen, : start_pos + seqlen]
print("mask 2")
print(mask_2)

# `Tensor.all()` collapses a whole mask into a single bool, so comparing the
# two reductions never looks at individual entries. `torch.equal` checks that
# both masks have the same shape and the same value in every position.
assert torch.equal(mask_1, mask_2)


tensor([[-inf, -inf, -inf],
        [-inf, -inf, -inf],
        [-inf, -inf, -inf]])
tensor([[0., -inf, -inf],
        [0., 0., -inf],
        [0., 0., 0.]])
mask 1
tensor([[0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0.]])
tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])
mask 2
tensor([[0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0.]])


In [2]:
# Compare a hand-written masked attention against
# F.scaled_dot_product_attention when both are given the same additive mask.
# Relies on `T`, `seqlen` and `mask_2` defined in the mask-construction cell.
d_head = 4
n_head = 2

# (B, n_head, T, d_head)
K = torch.randn(1, n_head, T, d_head)
V = torch.randn(1, n_head, T, d_head)
# Queries exist only for the `seqlen` new tokens: (B, n_head, seqlen, d_head)
Q = torch.randn(1, n_head, seqlen, d_head)
print(f"{K.shape=}")
print(f"{V.shape=}")
print(f"{Q.shape=}")

# Reference implementation: scaled scores, additive mask, softmax over keys.
attn_scores = Q @ K.transpose(-2, -1)
attn_scores = attn_scores / math.sqrt(d_head)
print(f"{attn_scores.shape=}")
# mask_2 is (seqlen, T) and broadcasts over the batch and head dimensions.
attn_scores = attn_scores + mask_2
print(f"{attn_scores.shape=}")
attn = attn_scores.softmax(dim=-1)
print(f"{attn.shape=}")
print(attn)
z_og = attn @ V
print(f"{z_og.shape=}")
print(z_og)

# is_causal stays False because the causal structure is supplied explicitly
# through attn_mask.
z_flash = F.scaled_dot_product_attention(Q, K, V, is_causal=False, attn_mask=mask_2)
print(f"{z_flash.shape=}")
print(z_flash)

# Check the equivalence programmatically instead of by eyeballing the two
# printed tensors. Tolerances are kept loose enough for float32 round-off.
torch.testing.assert_close(z_flash, z_og, rtol=1e-4, atol=1e-5)


K.shape=torch.Size([1, 2, 5, 4])
V.shape=torch.Size([1, 2, 5, 4])
Q.shape=torch.Size([1, 2, 3, 4])
attn_scores.shape=torch.Size([1, 2, 3, 5])
attn_scores.shape=torch.Size([1, 2, 3, 5])
attn.shape=torch.Size([1, 2, 3, 5])
tensor([[[[0.3098, 0.3718, 0.3185, 0.0000, 0.0000],
          [0.4430, 0.3402, 0.1989, 0.0179, 0.0000],
          [0.1740, 0.3748, 0.1429, 0.0801, 0.2282]],

         [[0.1503, 0.7668, 0.0828, 0.0000, 0.0000],
          [0.1756, 0.3616, 0.3000, 0.1628, 0.0000],
          [0.0159, 0.0421, 0.7838, 0.0119, 0.1464]]]])
z_og.shape=torch.Size([1, 2, 3, 4])
tensor([[[[-0.4766,  0.2880,  0.1352, -0.0508],
          [-0.4505,  0.2423,  0.0181,  0.0174],
          [-0.0904,  0.2590, -0.2730, -0.4101]],

         [[ 0.2620, -0.2075,  0.5307, -1.4057],
          [-0.0495, -0.0957, -0.3282, -0.5466],
          [-0.6198,  0.3243, -0.9575,  0.7876]]]])
z_flash.shape=torch.Size([1, 2, 3, 4])
tensor([[[[-0.4766,  0.2880,  0.1352, -0.0508],
          [-0.4505,  0.2423,  0.0181,  0.0174]

In [13]:
import time

from gollem.models.gpt2.config import GPT2Config
from gollem.models.gpt2.model import GPT


def generate_text(model, cfg, x):
    """Greedily decode ``cfg.n_ctx`` tokens and time the whole loop.

    At every step the model is asked for logits, the arg-max over the last
    dimension becomes both the recorded token and the input for the next step.

    Args:
        model: object exposing ``sample(x, start_pos=i)`` that returns logits.
        cfg: configuration providing ``n_ctx``, the number of decoding steps.
        x: starting token tensor; ``x.size(0)`` is taken as the batch size.
            Assumed to be of shape (B, 1) -- TODO confirm against the model.

    Returns:
        Tuple ``(time_taken, output)`` where ``time_taken`` is the elapsed
        time in seconds and ``output`` is a ``(B, cfg.n_ctx)`` tensor holding
        the chosen token ids. ``output`` is created by ``torch.zeros`` without
        a dtype, so the ids are stored in the default floating-point dtype.
    """
    # perf_counter is monotonic and high-resolution, so the measured duration
    # cannot be distorted by system clock adjustments the way time.time() can.
    time_start = time.perf_counter()
    B = x.size(0)
    output = torch.zeros((B, cfg.n_ctx), device=x.device)
    for i in range(cfg.n_ctx):
        logits = model.sample(x, start_pos=i)
        # Greedy choice; the result is fed back as the next step's input.
        x = torch.argmax(logits, dim=-1)
        output[:, i] = x.squeeze(-1)
    time_taken = time.perf_counter() - time_start
    print(f"Time taken: {time_taken} seconds")
    return time_taken, output


def compute_expected_flops(cfg):
    """Print a rough FLOP estimate for generation with and without caching.

    The estimate charges ``(B * d_model) ** 2`` per layer per step when
    caching, and ``(B * d_model * i) ** 2`` per layer at step ``i`` (for
    ``i = 1 .. n_ctx``) when not caching. Both totals and their ratio are
    printed; nothing is returned.

    Args:
        cfg: configuration providing ``max_sample_batch_size``, ``n_ctx``,
            ``d_model`` and ``n_layer``.
    """
    n_steps = cfg.n_ctx
    n_layers = cfg.n_layer
    width = cfg.max_sample_batch_size * cfg.d_model

    # With caching every step costs the same, so the total is a plain product.
    total_flops_caching = n_steps * n_layers * width**2

    # Without caching step i is charged for all i positions seen so far.
    total_flops_no_caching = sum(
        n_layers * (width * step) ** 2 for step in range(1, n_steps + 1)
    )

    ratio = total_flops_no_caching / total_flops_caching
    print(f"{total_flops_caching=}")
    print(f"{total_flops_no_caching=}")
    print(f"speedup: {ratio}")


In [16]:
# Benchmark greedy generation with and without key-value caching and check
# that both settings produce the same tokens.
seed = 0
vocab_size = 100
max_sample_batch_size = 1
# Seed before drawing the prompt so the cell is reproducible on a fresh kernel.
torch.manual_seed(seed)
x = torch.randint(0, vocab_size, (max_sample_batch_size, 1))

cfg = GPT2Config(
    vocab_size=vocab_size,
    n_ctx=512,
    n_layer=8,
    n_head=4,
    d_model=128,
    flash=True,
    max_sample_batch_size=max_sample_batch_size,
    use_kv_caching=False,
)
# Re-seed right before each construction: the equality check at the end only
# makes sense if both models start from identical weights.
# NOTE(review): this assumes GPT draws its initial weights from the global
# torch RNG in the same order for both configs -- confirm against the model.
torch.manual_seed(seed)
model = GPT(cfg)
# Without KV caching the attention module holds cache_x and no k/v caches.
assert model.transformer.h[0].attn.cache_x is not None
assert model.transformer.h[0].attn.cache_k is None
assert model.transformer.h[0].attn.cache_v is None
no_caching_time, no_caching_output = generate_text(model, cfg, x)


cfg = GPT2Config.override(cfg, use_kv_caching=True)
torch.manual_seed(seed)
model = GPT(cfg)
# With KV caching the attention module holds k/v caches and no cache_x.
assert model.transformer.h[0].attn.cache_x is None
assert model.transformer.h[0].attn.cache_k is not None
assert model.transformer.h[0].attn.cache_v is not None
caching_time, caching_output = generate_text(model, cfg, x)

# Both runs start from the same prompt, so the generated ids must agree.
assert no_caching_output.allclose(caching_output)
print(f"speedup: {no_caching_time / caching_time}")

compute_expected_flops(cfg)

Time taken: 0.7336409091949463 seconds
Time taken: 0.572786808013916 seconds
speedup: 1.2808271750160878
total_flops_caching=67108864
total_flops_no_caching=5881253068800
speedup: 87637.5


In [17]:
from math import factorial


# Exact value of 512! (Python integers have arbitrary precision). 512 equals
# the n_ctx used in the benchmark cell.
# NOTE(review): the notebook does not say what this scratch calculation is
# for -- confirm its purpose or drop the cell.
factorial(512)

3477289793132605363283045917545604711992250655643514570342474831551610412066352543473209850339502253644322433110213945452950017020700690132641531132609379413587118640447161868610408995574973614275882823562549684250124803968552397251205625120655558221217087864436207992465509591872320268380814151785881725352800207863134700768597399809657208738499042913738268415847127986184303873380423297718017247676910950195457589869427325150335515295950098769992795539310703785929170990023970619071471434241132521175859508178508966184339941402328233164321874103563412623863324969543199731304073425672820273985793825430484568768008623499281404119054312761974356746032818425307441775273658857216295122538723866131188215408478974931073983819560817636952364227958802962043017708088094771476324286392990388330462645858348881588473877378418434136648928335862091963669797757488958218269240400578451402875222386750821375703159545267274370949049147967826410007407778979191340933935304227609551402113871736500473583473533792