This notebook looks at the key-value (KV) cache implementation in the `CausalSelfAttention` class. It illustrates how the cache works and what the self-attention mechanism outputs with and without it.

In [1]:
from mingpt.model import GPT
from mingpt.attention import CausalSelfAttention
from mingpt.utils import set_seed
import torch

We look at a small example: a batch of 2 sequences, each of length 8, with embedding dimension 6. For now, we use a single attention head for simplicity.

We generate a random input tensor of shape `(batch_size, seq_len, dim)`, where dim 1 is the sequence dimension. It stands in for the token embeddings (with positional embeddings added) that are fed into the self-attention mechanism.

In [2]:
set_seed(42)
batch_size = 2
seq_len = 8
dim = 6

input = torch.rand((batch_size, seq_len, dim))
input.shape

torch.Size([2, 8, 6])

In [3]:
config = GPT.get_default_config()

config.n_layer = 1
config.n_head = 1
config.n_embd = dim
config.block_size = seq_len

We pass this input through a forward pass of the `CausalSelfAttention` class and look at the outputs when no key-value caching is used. Here, we feed in the full sequence (as if we were processing the entire sequence to predict the next token).

In [4]:
set_seed(42)
mha = CausalSelfAttention(config)
mha.eval()

without_kv_cache = mha.forward(input)

In [5]:
without_kv_cache

tensor([[[-0.2179,  0.4357,  0.3333,  0.0294, -0.0310, -0.0278],
         [-0.1889,  0.5065,  0.3372, -0.0225, -0.0488, -0.0271],
         [-0.1682,  0.4691,  0.3606, -0.0076, -0.0312,  0.0076],
         [-0.1536,  0.4177,  0.3679, -0.0322, -0.0192,  0.0309],
         [-0.1518,  0.3811,  0.3394, -0.0629, -0.0441,  0.0369],
         [-0.1075,  0.3083,  0.3525, -0.0363, -0.0332,  0.0558],
         [-0.1208,  0.3341,  0.3497, -0.0366, -0.0378,  0.0537],
         [-0.1245,  0.3554,  0.3567, -0.0293, -0.0346,  0.0549]],

        [[ 0.0404, -0.1563,  0.5403, -0.0059,  0.1507,  0.3894],
         [-0.0704,  0.1210,  0.4321, -0.0240,  0.0440,  0.2168],
         [-0.0642,  0.1390,  0.4008, -0.0369,  0.0128,  0.1795],
         [-0.0706,  0.1430,  0.4408, -0.0333,  0.0486,  0.2177],
         [-0.0364,  0.0500,  0.4691, -0.0181,  0.0731,  0.2813],
         [-0.0682,  0.1153,  0.4586, -0.0224,  0.0610,  0.2626],
         [-0.0420,  0.0875,  0.4597, -0.0096,  0.0579,  0.2689],
         [-0.0447,  0.0

Now we turn on key-value caching, although only for a single pass. If we set `use_kv_cache` to `True`, the key and value tensors are computed once and then stored in the cache. In a decoding setting, we could reuse these cached tensors in the later token-generation steps.

In [6]:
set_seed(42)
mha = CausalSelfAttention(config)
mha.eval()

with_kv_cache = mha.forward(input, use_kv_cache=True)

In [7]:
with_kv_cache

tensor([[[-0.2179,  0.4357,  0.3333,  0.0294, -0.0310, -0.0278],
         [-0.1889,  0.5065,  0.3372, -0.0225, -0.0488, -0.0271],
         [-0.1682,  0.4691,  0.3606, -0.0076, -0.0312,  0.0076],
         [-0.1536,  0.4177,  0.3679, -0.0322, -0.0192,  0.0309],
         [-0.1518,  0.3811,  0.3394, -0.0629, -0.0441,  0.0369],
         [-0.1075,  0.3083,  0.3525, -0.0363, -0.0332,  0.0558],
         [-0.1208,  0.3341,  0.3497, -0.0366, -0.0378,  0.0537],
         [-0.1245,  0.3554,  0.3567, -0.0293, -0.0346,  0.0549]],

        [[ 0.0404, -0.1563,  0.5403, -0.0059,  0.1507,  0.3894],
         [-0.0704,  0.1210,  0.4321, -0.0240,  0.0440,  0.2168],
         [-0.0642,  0.1390,  0.4008, -0.0369,  0.0128,  0.1795],
         [-0.0706,  0.1430,  0.4408, -0.0333,  0.0486,  0.2177],
         [-0.0364,  0.0500,  0.4691, -0.0181,  0.0731,  0.2813],
         [-0.0682,  0.1153,  0.4586, -0.0224,  0.0610,  0.2626],
         [-0.0420,  0.0875,  0.4597, -0.0096,  0.0579,  0.2689],
         [-0.0447,  0.0

The outputs of the self-attention mechanism are the same in both cases, which is expected for the first pass. The only difference is that the key and value tensors have now been cached for later steps.

In [8]:
torch.equal(without_kv_cache, with_kv_cache)

True

We can now look inside the key-value cache. Its shape shows that the allocated cache is larger than what we have stored in it: our batch contains only 2 sequences, while the cache has room for 64, the `max_batch_size` value set in the config.

In [9]:
mha.cache_k.shape

torch.Size([64, 1, 8, 6])
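
To put numbers on the gap between allocated and used cache space, here is a small arithmetic sketch. It assumes the four cache dimensions are `(max_batch_size, n_head, block_size, head_dim)`, which matches the shape printed above but is inferred rather than confirmed from the class source. It also assumes 4-byte float32 elements (PyTorch's default dtype).

```python
import math

# Assumed layout, inferred from the printed shape torch.Size([64, 1, 8, 6]):
# (max_batch_size, n_head, block_size, head_dim)
cache_shape = (64, 1, 8, 6)
batch_size = 2  # number of sequences we actually passed in

allocated = math.prod(cache_shape)                   # elements reserved up front
used = math.prod((batch_size,) + cache_shape[1:])    # elements our batch can fill
bytes_per_element = 4                                # assumption: float32

print(allocated, used, used / allocated)
print(allocated * bytes_per_element, "bytes per cache tensor")
```

Under these assumptions, our batch of 2 can fill at most 1/32 of the key cache. The rest is reserved so that larger batches can reuse the same buffer.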

In [10]:
mha.cache_k[:3]

tensor([[[[-0.0403,  0.4043, -0.2079, -0.1077, -0.4971, -0.2688],
          [ 0.0716,  0.6593, -0.5755,  0.0521, -0.2243, -0.1871],
          [-0.0142,  0.5693, -0.3494,  0.0481, -0.2169,  0.0308],
          [-0.1485,  0.2542, -0.0300, -0.3365, -0.4056, -0.2569],
          [-0.1228,  0.1589, -0.3271, -0.0833, -0.1632, -0.2081],
          [-0.2708,  0.2987, -0.0424, -0.2511, -0.1694, -0.0641],
          [ 0.0767,  0.5388, -0.3960,  0.0277, -0.3429, -0.2034],
          [ 0.1095,  0.6960, -0.4371,  0.0652, -0.3131, -0.1237]]],


        [[[-0.2830,  0.1052,  0.3896, -0.4659, -0.3163,  0.0543],
          [ 0.0539,  0.5208, -0.5527,  0.1603, -0.2434, -0.1303],
          [-0.1671,  0.2920, -0.2535, -0.0939, -0.1222, -0.0482],
          [-0.1294,  0.3796,  0.0153, -0.0825, -0.1500,  0.3572],
          [-0.2505,  0.1359,  0.4108, -0.3808, -0.2291,  0.1948],
          [ 0.1183,  0.5452, -0.2858,  0.0636, -0.3351, -0.0121],
          [-0.1156,  0.4128, -0.0587, -0.2097, -0.2028, -0.1035],
      

So how would we actually use key-value caching in practice?

Initially, when we have an input sequence, we feed all of it into the model to generate the next token. This is the `pre-fill` stage, in which several input tokens are processed in one pass. After that, only one new token is generated at a time. At each step, we compute the key and value tensors for the new token and add them to the cache. This means that we input only the new token each time rather than the full sequence (as we would without key-value caching). We also output only the attention output for the new token, since that is all that is needed for generation.

The `forward` method takes another argument, `start_pos`, which tells the model the position in the sequence at which the current input starts. In the `pre-fill` stage, when we process the input sequence for the first time to generate the first new token, we set `start_pos` to 0 to fill the cache. For each subsequent token, we set `start_pos` to that token's position in the sequence, incrementing it by one per generated token.
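
The pre-fill and decode steps can be sketched without any framework. This is a minimal plain-Python illustration, not the `CausalSelfAttention` implementation: `attend`, `full_causal_attention` and `ToyCachedAttention` are hypothetical names, there is a single head, and the query, key and value projections are assumed to be the identity (no learned weights), so each token vector serves as its own query, key and value.

```python
import math

def attend(q, keys, values):
    """Softmax attention of one query vector over lists of key/value vectors."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    top = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[j] for w, v in zip(weights, values)) / total
            for j in range(dim)]

def full_causal_attention(tokens):
    """No cache: position i attends over tokens[0..i], all passed in at once."""
    return [attend(x, tokens[: i + 1], tokens[: i + 1])
            for i, x in enumerate(tokens)]

class ToyCachedAttention:
    """Hypothetical single-head attention with a preallocated key/value cache."""

    def __init__(self, block_size):
        self.cache_k = [None] * block_size  # one slot per position
        self.cache_v = [None] * block_size

    def forward(self, tokens, start_pos=0):
        outputs = []
        for offset, x in enumerate(tokens):
            pos = start_pos + offset
            self.cache_k[pos] = x  # identity "projection" (assumption)
            self.cache_v[pos] = x
            # the new token attends over everything cached so far
            outputs.append(
                attend(x, self.cache_k[: pos + 1], self.cache_v[: pos + 1])
            )
        return outputs

sequence = [[0.1, 0.2], [0.3, -0.1], [0.5, 0.4], [-0.2, 0.6], [0.0, 0.3]]

reference = full_causal_attention(sequence)

cached = ToyCachedAttention(block_size=len(sequence))
stepwise = cached.forward(sequence[:3], start_pos=0)  # pre-fill with 3 tokens
for i in range(3, len(sequence)):
    # decode: feed only the new token and tell the cache where it goes
    stepwise += cached.forward(sequence[i : i + 1], start_pos=i)

matches = all(
    math.isclose(a, b)
    for ref_row, step_row in zip(reference, stepwise)
    for a, b in zip(ref_row, step_row)
)
print(matches)
```

In this toy version, `start_pos` is simply the index of the first cache slot to write. The decode calls receive one token each, yet still attend over all earlier positions through the cache.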

In [11]:
set_seed(42)
mha = CausalSelfAttention(config)
mha.eval()

with_cache_sequential = []

# prefill the cache with the first 3 tokens
print("prefill input")
print(input[:, :3])
output = mha.forward(input[:, :3], use_kv_cache=True, start_pos=0)
print("current key-cache:")
print(mha.cache_k[:batch_size])
print("prefill output")
print(output)
with_cache_sequential.append(output)

# now we can process the rest of the sequence sequentially and update the cache
for i in range(3, seq_len):
    print("i: ", i)
    print("input")
    print(input[:, i : (i + 1)])
    output = mha.forward(input[:, i : (i + 1)], use_kv_cache=True, start_pos=i)
    with_cache_sequential.append(output)
    print("current key-cache:")
    print(mha.cache_k[:batch_size])
    print("output")
    print(output)

prefill input
tensor([[[0.8823, 0.9150, 0.3829, 0.9593, 0.3904, 0.6009],
         [0.2566, 0.7936, 0.9408, 0.1332, 0.9346, 0.5936],
         [0.8694, 0.5677, 0.7411, 0.4294, 0.8854, 0.5739]],

        [[0.7539, 0.1952, 0.0050, 0.3068, 0.1165, 0.9103],
         [0.6440, 0.7071, 0.6581, 0.4913, 0.8913, 0.1447],
         [0.5315, 0.1587, 0.6542, 0.3278, 0.6532, 0.3958]]])
current key-cache:
tensor([[[[-0.0403,  0.4043, -0.2079, -0.1077, -0.4971, -0.2688],
          [ 0.0716,  0.6593, -0.5755,  0.0521, -0.2243, -0.1871],
          [-0.0142,  0.5693, -0.3494,  0.0481, -0.2169,  0.0308],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000]]],


        [[[-0.2830,  0.1052,  0.3896, -0.4659, -0.3163,  0.0543],
          [

As the outputs above show, each step returns only the attention output for the new token. If we concatenate all of these outputs, we get the same result as running the whole sequence without key-value caching (not bit-for-bit, because of floating-point precision, but very close):

In [12]:
torch.concatenate(with_cache_sequential, dim=1)

tensor([[[-0.2179,  0.4357,  0.3333,  0.0294, -0.0310, -0.0278],
         [-0.1889,  0.5065,  0.3372, -0.0225, -0.0488, -0.0271],
         [-0.1682,  0.4691,  0.3606, -0.0076, -0.0312,  0.0076],
         [-0.1536,  0.4177,  0.3679, -0.0322, -0.0192,  0.0309],
         [-0.1518,  0.3811,  0.3394, -0.0629, -0.0441,  0.0369],
         [-0.1075,  0.3083,  0.3525, -0.0363, -0.0332,  0.0558],
         [-0.1208,  0.3341,  0.3497, -0.0366, -0.0378,  0.0537],
         [-0.1245,  0.3554,  0.3567, -0.0293, -0.0346,  0.0549]],

        [[ 0.0404, -0.1563,  0.5403, -0.0059,  0.1507,  0.3894],
         [-0.0704,  0.1210,  0.4321, -0.0240,  0.0440,  0.2168],
         [-0.0642,  0.1390,  0.4008, -0.0369,  0.0128,  0.1795],
         [-0.0706,  0.1430,  0.4408, -0.0333,  0.0486,  0.2177],
         [-0.0364,  0.0500,  0.4691, -0.0181,  0.0731,  0.2813],
         [-0.0682,  0.1153,  0.4586, -0.0224,  0.0610,  0.2626],
         [-0.0420,  0.0875,  0.4597, -0.0096,  0.0579,  0.2689],
         [-0.0447,  0.0

In [13]:
torch.equal(torch.concatenate(with_cache_sequential, dim=1), with_kv_cache)

False

In [14]:
torch.concatenate(with_cache_sequential, dim=1) == with_kv_cache

tensor([[[ True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True],
         [False, False,  True, False,  True, False],
         [ True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True]],

        [[ True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True],
         [False,  True, False, False, False,  True],
         [False, False, False, False, False,  True],
         [ True,  True,  True,  True,  True,  True]]])

In [15]:
torch.allclose(torch.concatenate(with_cache_sequential, dim=1), with_kv_cache)

True
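
Why can `torch.equal` fail while `torch.allclose` passes? Floating-point addition is not associative, so two mathematically equivalent computations that sum in a different order can differ in the last bits. The sketch below shows this with plain Python doubles. It is only an analogy: the notebook's tensors are float32, and the exact source of the differences inside the batched versus single-token computation is an assumption, not something shown in the outputs above.

```python
import math

# The same three numbers, summed in two different orders
left_to_right = (0.1 + 0.2) + 0.3
right_to_left = 0.1 + (0.2 + 0.3)

print(left_to_right == right_to_left)              # exact comparison
print(math.isclose(left_to_right, right_to_left))  # tolerance-based comparison
print(abs(left_to_right - right_to_left))          # size of the discrepancy
```

The exact comparison fails while the tolerance-based one succeeds, which mirrors the pattern in cells 13 to 15. `torch.allclose` uses default tolerances of `rtol=1e-05` and `atol=1e-08`.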

As a reminder, here is what happens without key-value caching.

At each step, we input all the tokens up to that point to generate the next token. This is inefficient because we recompute the key and value tensors for the entire sequence every time. With key-value caching we avoid this, and we also save computation by computing the attention output only for the new token.

In [16]:
set_seed(42)
mha = CausalSelfAttention(config)
mha.eval()

print("input:")
print(input[:, :3])
without_cache = mha.forward(input[:, :3])
print("output:")
print(without_cache)

# without a cache, each step re-processes the whole prefix input[:, :i]
for i in range(4, seq_len + 1):
    print("i: ", i)
    print("input:")
    print(input[:, :i])
    without_cache = mha.forward(input[:, :i])
    print("output:")
    print(without_cache)

input:
tensor([[[0.8823, 0.9150, 0.3829, 0.9593, 0.3904, 0.6009],
         [0.2566, 0.7936, 0.9408, 0.1332, 0.9346, 0.5936],
         [0.8694, 0.5677, 0.7411, 0.4294, 0.8854, 0.5739]],

        [[0.7539, 0.1952, 0.0050, 0.3068, 0.1165, 0.9103],
         [0.6440, 0.7071, 0.6581, 0.4913, 0.8913, 0.1447],
         [0.5315, 0.1587, 0.6542, 0.3278, 0.6532, 0.3958]]])
output:
tensor([[[-0.2179,  0.4357,  0.3333,  0.0294, -0.0310, -0.0278],
         [-0.1889,  0.5065,  0.3372, -0.0225, -0.0488, -0.0271],
         [-0.1682,  0.4691,  0.3606, -0.0076, -0.0312,  0.0076]],

        [[ 0.0404, -0.1563,  0.5403, -0.0059,  0.1507,  0.3894],
         [-0.0704,  0.1210,  0.4321, -0.0240,  0.0440,  0.2168],
         [-0.0642,  0.1390,  0.4008, -0.0369,  0.0128,  0.1795]]],
       grad_fn=<ViewBackward0>)
i:  4
input:
tensor([[[0.8823, 0.9150, 0.3829, 0.9593, 0.3904, 0.6009],
         [0.2566, 0.7936, 0.9408, 0.1332, 0.9346, 0.5936],
         [0.8694, 0.5677, 0.7411, 0.4294, 0.8854, 0.5739],
         [0

Let's now do a quick time comparison to see how much faster key-value caching is. We will use a larger batch size, sequence length, dimension and number of heads now:

In [17]:
set_seed(42)
batch_size = 8
seq_len = 64
dim = 128

input = torch.rand((batch_size, seq_len, dim))
input.shape

torch.Size([8, 64, 128])

In [18]:
config = GPT.get_default_config()

config.n_head = 8
config.n_embd = dim
config.block_size = seq_len

In [19]:
%%timeit -n 100

set_seed(42)
mha = CausalSelfAttention(config)
mha.eval()

# prefill the cache with the first 3 tokens
mha.forward(input[:, :3], use_kv_cache=True, start_pos=0)

# now we can process the rest of the sequence sequentially and update the cache
for i in range(3, seq_len):
    with_cache_sequential = mha.forward(
        input[:, i : (i + 1)], use_kv_cache=True, start_pos=i
    )

19.6 ms ± 1.3 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [20]:
%%timeit -n 100

set_seed(42)
mha = CausalSelfAttention(config)
mha.eval()

mha.forward(input[:, :3])

# without a cache, each step re-processes the whole prefix input[:, :i]
for i in range(4, seq_len + 1):
    without_cache = mha.forward(input[:, :i])

43.9 ms ± 3.06 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
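
To relate the two timings to the amount of work done, the sketch below counts how many token positions are fed through the attention layer in each loop and compares that with the measured speedup. The counting model is a simplification (it ignores the cost of attending over the cache and any fixed per-call overhead), and the variable names are illustrative only.

```python
seq_len = 64
prefill = 3

# With the cache: 3 tokens in the pre-fill, then one token per step
tokens_with_cache = prefill + (seq_len - prefill)

# Without the cache: 3 tokens first, then the full prefix of length i each step
tokens_without_cache = prefill + sum(range(prefill + 1, seq_len + 1))

work_ratio = tokens_without_cache / tokens_with_cache
measured_speedup = 43.9 / 19.6  # mean times reported by %%timeit above

print(tokens_with_cache, tokens_without_cache)
print(round(work_ratio, 2), round(measured_speedup, 2))
```

The measured speedup (about 2.2x) is much smaller than the ratio of tokens processed. One plausible reason is that both timed cells also include `set_seed` and the construction of `CausalSelfAttention`, and that per-call overhead dominates at this small scale. The gap between the two approaches should widen with longer sequences.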
