In [770]:
# ! curl -o input.txt https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt 

In [771]:
# tokens = nlp('and to')
# ! pip install levenshtein

from Levenshtein import distance as edit_distance

def get_proximity_with_vocab(text, tokenizer, vocab):
    def proximity_fn(word):
        if word in vocab:
            distance = 0
        else:
            distance = min(*map(
                lambda vocab_word: edit_distance(word, vocab_word),
                vocab.keys()
            ))
        
        return distance


    tokens = tokenizer(text)
    distance = reduce(lambda y,token: y+proximity_fn(str(token)), tokens, 0)

    return distance

def get_all_proximity_with_vocab(texts, tokenizer, vocab):
    return np.fromiter(
        map(
            lambda text: get_proximity_with_vocab(text, tokenizer, vocab),
            texts,
        ),
        dtype=object
    )

nlp = spacy.blank("en")

texts = [
    'hello, how are you doing?',
    'hello, how are you doing?',
]

get_all_proximity_with_vocab(texts, nlp, word_vocab)

array([14, 14], dtype=object)

In [None]:
! head input.txt

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:


In [841]:
import tensorflow as tf

from collections import Counter
from itertools import chain

data = 'helloksdhglsadjgalsdhglaksddd'
block_size = 5

# ds = tf.data.Dataset.range(len(data) - block_size - 1).map(
#     # lambda index: data[index:index+block_size], data[index+1:index+block_size+1]
#     lambda index: data[index:index+block_size]
# )

from pathlib import Path

with open('input.txt') as f:
    tiny_shakespere = list(f.readlines())

char_frequencies = Counter(chain(*tiny_shakespere))
char_vocab = dict(zip(char_frequencies.keys(), range(len(char_frequencies))))

def make_dataset(data, block_size, vocab):
    def indices(phrase):
        return list(map(char_vocab.get, phrase))

    @tf.numpy_function(Tout=(tf.int64, tf.int64))
    def get_item(index):
        X = data[index:index+block_size]
        y = data[index+1:index+block_size+1]
        return indices(X), indices(y) 
    
    return tf.data.Dataset.range(len(data) - block_size).map(get_item)

ds = make_dataset(data, block_size, char_vocab)
X, y = next(iter(ds))

print(
    f'\n\nTrain Set\n=============='
    f'\n{X.shape=} {y.shape=}'
    f'\n{X.numpy()=} {y.numpy()=}'
)



Train Set
X.shape=TensorShape([5]) y.shape=TensorShape([5])
X.numpy()=array([22,  8, 28, 28, 14]) y.numpy()=array([ 8, 28, 28, 14, 25])


In [None]:
import tensorflow as tf

from tensorflow.keras import layers

class RotaryPositionalEncodings(layers.Layer):
    def __init__(self, dims, block_size, base=10000.):
        super(RotaryPositionalEncodings, self).__init__()
        
        self.dims = dims
        self.base = base
        self.block_size = block_size

        self.theta = 1. / (self.base ** (tf.range(0, self.dims, 2, dtype=tf.float32) / self.dims))
        self.positions = tf.range(self.block_size, dtype=tf.float32)
        self.mtheta = self.positions[..., None]*self.theta[None, ...]
        self.mtheta_paired = tf.concat([self.mtheta]*2, axis=-1)

        self.cos_mtheta = tf.math.cos(self.mtheta_paired)
        self.sin_mtheta = tf.math.sin(self.mtheta_paired)

    def call(self, x):
        B, T = x.shape[:2]

        x_real = x*self.cos_mtheta[None, :T, None, ...]
        # print(f'{x_real=}\n{self.cos_mtheta[:T]=}')
        x_img = tf.concat(
            [-x[..., self.dims//2:], x[..., :self.dims//2]],
            axis=-1
        )*self.sin_mtheta[None, :T, None, ...]
        # print(f'{x_img=}\n{self.sin_mtheta[:T]=}')
        output = x_real + x_img
        return output

B, T, H, D = 2, 8, 4, 4
B_i, T_i, H_i, D_i = 0, 1, 0, 0
x = tf.random.uniform((B, T, H, D))
l = RotaryPositionalEncodings(4, 64)
output = l(x)
# tf.print(f'{l.theta=} {l.theta.shape}')
# tf.print(f'{l.positions=} {l.positions.shape}')
tf.print(f'{x[B_i]=}')
tf.print(f'{l.cos_mtheta=}')
tf.print(f'{l.sin_mtheta=}')
tf.print(f'{output[B_i]=}\n{output[B_i]=} {output.shape}')

x[B_i]=<tf.Tensor: shape=(8, 4, 4), dtype=float32, numpy=
array([[[0.31179297, 0.8263413 , 0.6849456 , 0.0067091 ],
        [0.78749514, 0.3906511 , 0.29263055, 0.99216926],
        [0.95810425, 0.55623317, 0.16466296, 0.13445711],
        [0.13229859, 0.5348098 , 0.57090175, 0.50970507]],

       [[0.48252344, 0.15580535, 0.3703227 , 0.49210668],
        [0.567016  , 0.2077086 , 0.18223882, 0.99883735],
        [0.36950588, 0.37927854, 0.7723117 , 0.68211746],
        [0.39932835, 0.7840713 , 0.67880154, 0.73395896]],

       [[0.5520444 , 0.10948515, 0.6487982 , 0.9890779 ],
        [0.8203654 , 0.70470357, 0.9578625 , 0.02297425],
        [0.93598676, 0.6513264 , 0.31663585, 0.00111556],
        [0.9212191 , 0.3822806 , 0.77246034, 0.91514194]],

       [[0.5751133 , 0.793342  , 0.4289763 , 0.19118965],
        [0.7452506 , 0.41762006, 0.8173511 , 0.20117116],
        [0.6457157 , 0.16484237, 0.4484123 , 0.6057888 ],
        [0.816115  , 0.4129653 , 0.2632984 , 0.19087589]],

      

In [358]:
import math

import tensorflow as tf
import numpy as np

from tensorflow.keras import layers

class RotaryPositionalEncodings(layers.Layer):
    def __init__(self, dims, block_size, base=10000.):
        super(RotaryPositionalEncodings, self).__init__()
        
        self.dims = dims
        self.base = base
        self.block_size = block_size

        self.theta = 1. / (self.base ** (tf.range(0, self.dims, 2, dtype=tf.float32) / self.dims))
        self.positions = tf.range(self.block_size, dtype=tf.float32)
        self.mtheta = self.positions[..., None]*self.theta[None, ...]
        self.mtheta_paired = tf.concat([self.mtheta]*2, axis=-1)

        self.cos_mtheta = tf.math.cos(self.mtheta_paired)
        self.sin_mtheta = tf.math.sin(self.mtheta_paired)

    def call(self, x):
        B, T = x.shape[:2]

        x_real = x*self.cos_mtheta[None, :T, None, ...]
        # print(f'{x_real=}\n{self.cos_mtheta[:T]=}')
        x_img = tf.concat(
            [-x[..., self.dims//2:], x[..., :self.dims//2]],
            axis=-1
        )*self.sin_mtheta[None, :T, None, ...]
        # print(f'{x_img=}\n{self.sin_mtheta[:T]=}')
        output = x_real + x_img
        return output

class KVCache(object):
    def __init__(self, cache_size, block_size, heads, head_dims):
        super(KVCache, self).__init__()

        cache_shape = (cache_size, block_size, heads, head_dims)
        self.cache_k = np.zeros(cache_shape, dtype='float32')
        self.cache_v = np.zeros(cache_shape, dtype='float32')

    def update(self, start, xk, xv):
        batch_size, seq_len = xk.shape[:2]
        # self.cache_k[:batch_size, start:start+seq_len] = xk.numpy()
        # self.cache_v[:batch_size, start:start+seq_len] = xv.numpy()
        self.cache_k[:batch_size, start:start+seq_len] = xk
        self.cache_v[:batch_size, start:start+seq_len] = xv

    def get(self, batch_size, seq_len, start):
        keys = tf.constant(self.cache_k[:batch_size, :start+seq_len])
        values = tf.constant(self.cache_v[:batch_size, :start+seq_len])
        return keys, values

class GroupedQueryAttention(layers.Layer):
    def __init__(self, block_size, heads, kv_heads, dims, cache_size):
        super(GroupedQueryAttention, self).__init__()
        
        self.heads = heads
        self.dims = dims
        self.head_dims = dims // self.heads
        self.kv_heads = kv_heads if kv_heads else heads
        self.block_size = block_size
        self.cache_size = cache_size
        
        self.wq = layers.Dense(self.dims, use_bias=False)
        self.wk = layers.Dense(self.kv_heads * self.head_dims, use_bias=False)
        self.wv = layers.Dense(self.kv_heads * self.head_dims, use_bias=False)
        self.wo = layers.Dense(self.dims, use_bias=False)
        
        self.cache = KVCache(self.cache_size, self.block_size, self.kv_heads, self.head_dims)
        self.rope = RotaryPositionalEncodings(self.head_dims, self.block_size)
    
    def call(self, x, start):
        # print(f'{x=}\n{start=}')
        shape = tf.shape(x)
        B = shape[0]
        T = shape[1]

        # (B, T, dims)
        q, k, v = self.wq(x), self.wk(x), self.wv(x)

        # (B, T, heads/kv_heads, head_dims)
        xq = tf.reshape(q, (B, T, self.heads, self.head_dims))
        xk = tf.reshape(k, (B, T, self.kv_heads, self.head_dims))
        xv = tf.reshape(v, (B, T, self.kv_heads, self.head_dims))

        # Apply RoPE
        # (B, T, heads/kv_heads, head_dims)
        xq = self.rope(xq)
        xk = self.rope(xk)

        # Update KV cache
        self.cache.update(start, xk, xv)
        
        # Get prefix context from the cache.
        # (B, start+T, kv_heads, head_dims)
        keys, values = self.cache.get(B, T, start)

        # Expand kv_heads to heads
        # (B, start+T, heads, head_dims)
        # print(f'{keys.shape=} {values.shape=} {keys.dtype=}')
        keys = tf.tile(keys, multiples=(1, 1, self.heads//self.kv_heads, 1))
        values = tf.tile(values, multiples=(1, 1, self.heads//self.kv_heads, 1))
        
        # Transpose xq, keys and values to (B, heads, T/start+T, head_dims)
        xq = tf.transpose(xq, perm=[0, 2, 1, 3])
        xk = tf.transpose(keys, perm=[0, 2, 1, 3])
        xv = tf.transpose(values, perm=[0, 2, 1, 3])
        
        # Multiply xq and xk to compute attention matrix.
        # (B, heads, T, head_dims) @ (B, heads, head_dims, start+T) -> (B, heads, T, start+T)
        xa = xq @ tf.transpose(xk, perm=[0, 1, 3, 2]) / math.sqrt(self.head_dims)
        scores = tf.nn.softmax(xa)

        # Scale Values and compute output
        # (B, heads, T, start+T) @ (B, heads, start+T, head_dims) -> (B, heads, T, head_dims)
        output = scores @ xv
        
        # Reshape output to (B, T, dims)
        output = tf.reshape(
            tf.transpose(output, perm=[0, 2, 1, 3]),
            shape=(B, T, -1)
        )

        return self.wo(output)

l = GroupedQueryAttention(
    block_size=16,
    heads=4,
    kv_heads=2,
    dims=16,
    cache_size=4
)

B, T, C = 2, 2, 16
start = 1
x = tf.reshape(tf.range(B*T*C), (B, T, C))
output = l(x, start)

B_i, T_i, C_i = 0, 0, 3
H_i, F_i = 0, C_i

print(
    f'{x[B_i]=}'
    f'\n{l.cache.cache_k.shape=}\n{l.cache.cache_k[B_i, :start+T]=}'
    f'\n{l.cache.cache_v[B_i, :start+T]=}'
)

x[B_i]=<tf.Tensor: shape=(2, 16), dtype=int32, numpy=
array([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15],
       [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]],
      dtype=int32)>
l.cache.cache_k.shape=(4, 16, 2, 4)
l.cache.cache_k[B_i, :start+T]=array([[[  0.       ,   0.       ,   0.       ,   0.       ],
        [  0.       ,   0.       ,   0.       ,   0.       ]],

       [[ 11.398402 ,  11.38393  ,  -5.348648 ,  12.098706 ],
        [  1.0817871, -16.598928 ,  -1.4157329,  14.4951935]],

       [[ 23.552362 ,  21.765985 ,   8.511203 ,  37.022823 ],
        [  8.715925 , -41.078644 , -12.075365 ,  37.266632 ]]],
      dtype=float32)
l.cache.cache_v[B_i, :start+T]=array([[[  0.        ,   0.        ,   0.        ,   0.        ],
        [  0.        ,   0.        ,   0.        ,   0.        ]],

       [[  3.5393999 ,   2.8608413 ,  -3.3473425 ,  -0.45976973],
        [-10.767759  , -10.359337  ,   5.055772  ,  10.362978  ]],

       [[ 19.4

In [267]:
import tensorflow as tf

from tensorflow.keras import layers

class RMSNorm(layers.Layer):
    def __init__(self, eps=1e-8):
        super(RMSNorm, self).__init__()
        self.eps = eps
    
    def build(self, input_shape):
        self.w = self.add_weight(
            "kernel",
            shape=[int(input_shape[-1])],
            initializer='ones',
            trainable=True,
        )
    
    def norm(self, x):
        return x*tf.math.rsqrt(tf.math.reduce_mean(x**2, axis=-1, keepdims=True) + self.eps)
        
    def call(self, inputs):
        return self.w * self.norm(inputs)


batch, seqlen, dims = 5, 9, 7
x = tf.random.uniform((batch, seqlen, dims))
l = RMSNorm()
l.build(x.shape)
result = l(x)

print(
    f'{x.shape=} {result.shape=}'
    f'{x[:2]=} {result[:2]=}'
)

x.shape=TensorShape([5, 9, 7]) result.shape=TensorShape([5, 9, 7])x[:2]=<tf.Tensor: shape=(2, 9, 7), dtype=float32, numpy=
array([[[0.31179297, 0.8263413 , 0.6849456 , 0.0067091 , 0.78749514,
         0.3906511 , 0.29263055],
        [0.99216926, 0.95810425, 0.55623317, 0.16466296, 0.13445711,
         0.13229859, 0.5348098 ],
        [0.57090175, 0.50970507, 0.48252344, 0.15580535, 0.3703227 ,
         0.49210668, 0.567016  ],
        [0.2077086 , 0.18223882, 0.99883735, 0.36950588, 0.37927854,
         0.7723117 , 0.68211746],
        [0.39932835, 0.7840713 , 0.67880154, 0.73395896, 0.5520444 ,
         0.10948515, 0.6487982 ],
        [0.9890779 , 0.8203654 , 0.70470357, 0.9578625 , 0.02297425,
         0.93598676, 0.6513264 ],
        [0.31663585, 0.00111556, 0.9212191 , 0.3822806 , 0.77246034,
         0.91514194, 0.5751133 ],
        [0.793342  , 0.4289763 , 0.19118965, 0.7452506 , 0.41762006,
         0.8173511 , 0.20117116],
        [0.6457157 , 0.16484237, 0.4484123 , 0.605788

In [275]:
import tensorflow as tf

from tensorflow.keras import layers

class FeedForward(layers.Layer):
    def __init__(self, dims):
        super(FeedForward, self).__init__()

        self.hidden_dims = 4*dims
        self.linear_1 = layers.Dense(self.hidden_dims, use_bias=False, activation='swish')
        self.linear_2 = layers.Dense(self.hidden_dims, use_bias=False)
        self.linear_3 = layers.Dense(dims, use_bias=False)
    
    def call(self, x):
        # (..., dims) -> (..., hidden_dims)
        x = self.linear_1(x)*self.linear_2(x)
        
        # (..., hidden_dims) -> (..., dims)
        x = self.linear_3(x)
        
        return x

B, T, C = 2, 2, 16
l = FeedForward(C)
x = tf.reshape(tf.range(B*T*C, dtype='float32'), (B, T, C))
l(x)

<tf.Tensor: shape=(2, 2, 16), dtype=float32, numpy=
array([[[-2.0513205e+01,  1.0806817e+01, -9.1668015e+00, -7.4922695e+00,
         -1.3153091e+01,  5.2315617e-01,  1.2746422e+01, -3.4128265e+01,
          6.7470369e+00,  2.7175690e+01, -4.5521870e+00,  1.6494555e+01,
          1.8180428e+01,  1.5493259e+00, -1.4829735e+01,  3.4519196e+01],
        [-1.0848410e+02,  4.9710754e+01,  1.2151041e+01, -6.9968826e+01,
         -6.3353088e+01, -8.1262497e+01,  2.8162182e+01, -2.5223247e+02,
         -6.3216888e+01,  1.3744359e+02, -4.6534172e+01,  2.7248346e+02,
          2.6686646e+02, -1.2055109e+02, -2.0656288e+02,  1.3071808e+02]],

       [[-2.6512726e+02,  4.4408295e+01,  2.8791975e+01, -1.2816547e+02,
         -1.7421616e+02, -2.4323286e+02,  1.2617584e+01, -6.8494702e+02,
         -1.9408469e+02,  3.7911090e+02, -1.4601923e+02,  8.5417902e+02,
          7.5281537e+02, -4.2752676e+02, -5.7983594e+02,  2.5046817e+02],
        [-4.8799042e+02,  7.7549438e+00,  5.0849304e+01, -1.8859167

In [285]:
import tensorflow as tf

from tensorflow.keras import layers

class LlamaBlock(tf.keras.Model):
    def __init__(self, block_size, heads, kv_heads, dims, cache_size):
        super(LlamaBlock, self).__init__()
        
        self.rms_1 = RMSNorm()
        self.attention = GroupedQueryAttention(block_size, heads, kv_heads, dims, cache_size)
        self.rms_2 = RMSNorm()
        self.feed_forward = FeedForward(dims)
    
    def call(self, x, start=0):
        x += self.attention(self.rms_1(x), start=start)
        x += self.feed_forward(self.rms_2(x))
        
        return x

# l = LlamaBlock(
#     block_size=16,
#     heads=4,
#     kv_heads=2,
#     dims=16,
#     cache_size=4
# )

# B, T, C = 2, 2, 16
# start = 1
# x = tf.reshape(tf.range(B*T*C, dtype='float32'), (B, T, C))
# l(x)

<tf.Tensor: shape=(2, 2, 16), dtype=float32, numpy=
array([[[ 1.5042548 , -0.06997496, -0.17430654,  0.83612585,
          4.4509845 ,  5.4747915 ,  5.0120664 ,  5.330117  ,
          8.416934  ,  9.06529   , 10.056999  , 10.308799  ,
         11.04062   , 14.062305  , 14.283695  , 13.84579   ],
        [17.730944  , 16.037317  , 15.690351  , 17.138454  ,
         20.40212   , 21.468744  , 21.2813    , 21.50022   ,
         24.659027  , 24.916517  , 26.150116  , 26.664133  ,
         27.088156  , 30.176422  , 29.790104  , 30.220867  ]],

       [[33.831486  , 31.854721  , 31.671326  , 33.34517   ,
         36.32003   , 37.277905  , 37.55028   , 37.49979   ,
         40.585087  , 40.70484   , 42.30504   , 42.776077  ,
         43.12911   , 46.41141   , 45.518208  , 46.99089   ],
        [49.839653  , 47.869476  , 47.66023   , 49.371593  ,
         52.30976   , 53.274467  , 53.573917  , 53.524033  ,
         56.58298   , 56.68282   , 58.309742  , 58.80591   ,
         59.12888   , 62.410

In [290]:
import tensorflow as tf

from tensorflow.keras import layers
from functools import reduce

class LlamaModel(tf.keras.Model):
    def __init__(
        self,
        vocab_size, encoders, dims,
        block_size, heads, kv_heads, cache_size,
    ):
        super(LlamaModel, self).__init__()
        self.embeddings = layers.Embedding(vocab_size, dims)
        self.enc_blocks = [
            LlamaBlock(block_size, heads, kv_heads, dims, cache_size)
            for enc_id in range(encoders)
        ]
        self.rms = RMSNorm()
        self.head = layers.Dense(vocab_size, use_bias=False)
    
    def call(self, x, start=0):
        B, T = x.shape
        x_embed = self.embeddings(x)
        x = reduce(
            lambda y,enc_block: enc_block(y, start=start),
            self.enc_blocks,
            x_embed
        )
        x = self.rms(x)
        x = self.head(x)
        
        return x
    
    def train_step(self, data):
        X, y = data

        with tf.GradientTape() as tape:
            logits = self(X)
            loss = self.compute_loss(y, logits)

        grads = tape.gradient(loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        
        return compute_metrics(loss, y, logits)
    
    def test_step(self, data):
        X, y = data

        logits = self(X)
        loss = self.compute_loss(y, logits)
        
        return compute_metrics(loss, y, logits)
    
    def compute_metrics(self, loss, y, logits):
        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss)
            else:
                metric.update_state(y, y_pred)

        return {m.name: m.result() for m in self.metrics}

m = LlamaModel(
    vocab_size=1000,
    encoders=2,
    block_size=16,
    heads=4,
    kv_heads=2,
    dims=16,
    cache_size=4
)

# m.compile()

B, T = 2, 2
start = 1
x = tf.reshape(tf.range(B*T), (B, T))
m(x)
# m.summary(expand_nested=True)

<tf.Tensor: shape=(2, 2, 1000), dtype=float32, numpy=
array([[[-0.1239534 , -0.0607941 , -0.26741952, ..., -0.17009711,
          0.08234538, -0.09490468],
        [-0.06256978, -0.01165929, -0.35023674, ..., -0.11410779,
         -0.01988001, -0.10182261]],

       [[-0.00436667,  0.2219006 , -0.09742606, ...,  0.06465681,
         -0.02983708, -0.02597765],
        [-0.02063455,  0.18581536, -0.16156226, ...,  0.05736311,
          0.00983013, -0.05446468]]], dtype=float32)>

In [295]:
from tensorflow.keras import losses, optimizers

loss_fn = losses.SparseCategoricalCrossentropy(from_logits=True)
optimizer = optimizers.legacy.Adam()

TODO

* [X] loss and accuracy metrics.
* [X] training, validation and inference flags.
* Match dictionary with embedding size.

In [815]:
import tensorflow as tf

from tensorflow.keras import layers, metrics
from functools import reduce

import tensorflow as tf

class KVCache(object):
    def __init__(self, cache_size, block_size, heads, head_dims):
        super(KVCache, self).__init__()
        self.block_size = block_size

        cache_shape = (cache_size, block_size, heads, head_dims)

        with tf.device('/device:CPU:0'):
            self.cache_k = tf.Variable(tf.zeros(cache_shape), shape=cache_shape, trainable=False)
            self.cache_v = tf.Variable(tf.zeros(cache_shape), shape=cache_shape, trainable=False)

    def update(self, start, xk, xv):
        shape = tf.shape(xk)
        B = shape[0]
        T = shape[1]

        # Calculate update start and end positions
        start = start%self.block_size
        end = (start + T)%(self.block_size + 1)

        # start < end: It is a single cache update.
        # end > start: It is a split cache update.
        if start < end:
            self.cache_k[:B, start:start+T].assign(xk)
            self.cache_v[:B, start:start+T].assign(xv)
        else:
            # Update cache with partial sequence that fits towards the end.
            self.cache_k[:B, start:].assign(xk[:, :-(end+1)])
            self.cache_v[:B, start:].assign(xv[:, :-(end+1)])

            # Splillover sequence is cached towards the front of the cache.
            self.cache_k[:B, :end+1].assign(xk[:, -(end+1):])
            self.cache_v[:B, :end+1].assign(xv[:, -(end+1):])

    # TODO:: Update the callers to reflect the args order change.
    def get(self, batch_size, start, seq_len):
        # Calculate update start and end positions
        start = start%self.block_size
        end = (start + seq_len)%(self.block_size + 1)

        # start < end: It is a single cache fetch.
        # end > start: It is a split cache fetch.
        if start < end:
            keys = self.cache_k[:batch_size, :start+seq_len]
            values = self.cache_v[:batch_size, :start+seq_len]
        else:
            # Fetch sequence prefix
            keys_1 = self.cache_k[:, (end+1):]
            values_1 = self.cache_k[:, (end+1):]

            # Fetch sequence suffix
            keys_2 = self.cache_k[:, :(end+1):]
            values_2 = self.cache_k[:, :(end+1):]

            # Compose the whole sequence
            keys = tf.concat([keys_1, keys_2], axis=1)
            values = tf.concat([values_1, values_2], axis=1)

        return keys, values

def update_and_show(cache, batch_size, start, seqlen, data_shape, msg, debug=False):
    xk = tf.random.uniform((batch_size, seqlen, *data_shape))
    xv = tf.random.uniform((batch_size, seqlen, *data_shape))

    if debug:
        print(
            f'\n{msg}::xk:\n{xk}'
        )

    cache.update(start, xk, xv)
    print(
        f'\n{msg}:\n{cache.cache_k}'
    )

# cache_size, block_size, heads, head_dims = 1, 8, 1, 4
# cache = KVCache(cache_size, block_size, heads, head_dims)
# data_shape = (heads, head_dims)

# # print(f'\nInitial:\n{cache.cache_k}')

# # update_and_show(cache, cache_size, 2, 4, data_shape, 'InUpdate(2, 4)')
# # update_and_show(cache, cache_size, 4, 4, data_shape, 'EndUpdate(4, 4)')
# # update_and_show(cache, cache_size, 6, 4, data_shape, 'SpilledUpdate(6, 4)', debug=True)
# update_and_show(cache, cache_size, 6, 8, data_shape, 'Fill(6, 8)')
# print(
#     # f'\nInQuery:\n{cache.get(cache_size, 0, 4)[0]}'
#     # f'\n\nEndQuery:\n{cache.get(cache_size, 4, 4)[0]}'
#     f'\n\nSpilledQuery:\n{cache.get(cache_size, 6, 2)[0]}'
# )

class GroupedQueryAttention(tf.keras.Model):
    def __init__(self, block_size, heads, kv_heads, dims, cache_size):
        super(GroupedQueryAttention, self).__init__()
        
        self.heads = heads
        self.dims = dims
        self.head_dims = dims // self.heads
        self.kv_heads = kv_heads or heads
        self.block_size = block_size
        self.cache_size = cache_size
        
        self.wq = layers.Dense(self.dims, use_bias=False)
        self.wk = layers.Dense(self.kv_heads * self.head_dims, use_bias=False)
        self.wv = layers.Dense(self.kv_heads * self.head_dims, use_bias=False)
        self.wo = layers.Dense(self.dims, use_bias=False)
        
        self.cache = KVCache(self.cache_size, self.block_size, self.kv_heads, self.head_dims)
        self.rope = RotaryPositionalEncodings(self.head_dims, self.block_size)
    
    def call(self, x, start=0, use_cache=False):
        # print(f'{x=}\n{start=}')
        shape = tf.shape(x)
        B = shape[0]
        T = shape[1]

        # (B, T, dims)
        q, k, v = self.wq(x), self.wk(x), self.wv(x)

        # (B, T, heads/kv_heads, head_dims)
        xq = tf.reshape(q, (B, T, self.heads, self.head_dims))
        xk = tf.reshape(k, (B, T, self.kv_heads, self.head_dims))
        xv = tf.reshape(v, (B, T, self.kv_heads, self.head_dims))

        # Apply RoPE
        # (B, T, heads/kv_heads, head_dims)
        xq = self.rope(xq)
        xk = self.rope(xk)

        if use_cache:
            # Update KV cache
            self.cache.update(start, xk, xv)
            
            # Get prefix context from the cache.
            # (B, start+T, kv_heads, head_dims)
            keys, values = self.cache.get(B, start, T)
        else:
            assert start == 0
            keys, values = xk, xv

        # Expand kv_heads to heads
        # (B, start+T, heads, head_dims)
        # print(f'{keys.shape=} {values.shape=} {keys.dtype=}')
        keys = tf.tile(keys, multiples=(1, 1, self.heads//self.kv_heads, 1))
        values = tf.tile(values, multiples=(1, 1, self.heads//self.kv_heads, 1))
        
        # Transpose xq, keys and values to (B, heads, T/start+T, head_dims)
        xq = tf.transpose(xq, perm=[0, 2, 1, 3])
        xk = tf.transpose(keys, perm=[0, 2, 1, 3])
        xv = tf.transpose(values, perm=[0, 2, 1, 3])
        
        # Multiply xq and xk to compute attention matrix.
        # (B, heads, T, head_dims) @ (B, heads, head_dims, start+T) -> (B, heads, T, start+T)
        xa = xq @ tf.transpose(xk, perm=[0, 1, 3, 2]) / math.sqrt(self.head_dims)

        # Compute softmax scores.
        if use_cache:
            scores = tf.nn.softmax(xa)
        else:
            # If cache is not used, apply auto-regressive mask to block forward looking
            tril = tf.linalg.band_part(tf.ones((T, T), dtype=tf.float32), -1, 0)
            scores = tf.math.softmax(tf.where(tril > 0.0, xa, float('-inf')))

        # Scale Values and compute output
        # (B, heads, T, start+T) @ (B, heads, start+T, head_dims) -> (B, heads, T, head_dims)
        output = scores @ xv
        
        # Reshape output to (B, T, dims)
        output = tf.reshape(
            tf.transpose(output, perm=[0, 2, 1, 3]),
            shape=(B, T, self.dims)
        )

        return self.wo(output)

class LlamaBlock(tf.keras.Model):
    def __init__(self, block_size, heads, kv_heads, dims, cache_size, *args, **kwargs):
        super(LlamaBlock, self).__init__(*args, **kwargs)
        
        self.rms_1 = RMSNorm()
        self.attention = GroupedQueryAttention(block_size, heads, kv_heads, dims, cache_size)
        self.rms_2 = RMSNorm()
        self.feed_forward = FeedForward(dims)
    
    def call(self, x, start=0, use_cache=False):
        x += self.attention(self.rms_1(x), start=start, use_cache=use_cache)
        x += self.feed_forward(self.rms_2(x))
        
        return x

class LlamaModel(tf.keras.Model):
    def __init__(
        self,
        vocab_size, decoders, dims,
        block_size, heads, kv_heads, cache_size,
        loss_fn,
        *args, **kwargs,
    ):
        super(LlamaModel, self).__init__(*args, **kwargs)
        # Args
        self.block_size = block_size
        self.vocab_size = vocab_size
        self.cache_size = cache_size
        
        # Model elements
        self.embeddings = layers.Embedding(vocab_size, dims)
        self.dec_blocks = [
            LlamaBlock(block_size, heads, kv_heads, dims, cache_size)
            for _ in range(decoders)
        ]
        self.rms = RMSNorm()
        self.head = layers.Dense(vocab_size, use_bias=False)

        # Loss 
        self.loss_fn = loss_fn

        # Metrics
        self.loss_tracker = metrics.Mean(name="loss")
        self.acc_tracker = metrics.SparseCategoricalAccuracy(name='accuracy')
    
    def call(self, x, start=0, use_cache=False):
        x_embed = self.embeddings(x)
        x = reduce(
            lambda y,dec_block: dec_block(y, start=start, use_cache=use_cache),
            self.dec_blocks,
            x_embed
        )
        x = self.rms(x)
        x = self.head(x)
        
        return x
    
    def train_step(self, data):
        X, y = data
        y = tf.cast(y, dtype=tf.float32)

        with tf.GradientTape() as tape:
            logits = self(X)
            loss = loss_fn(y, logits)

        grads = tape.gradient(loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        
        return self.record(loss, y, logits)
    
    def test_step(self, data):
        X, y = data

        logits = self(X)
        loss = loss_fn(y, logits)
    
        return self.record(loss, y, logits)
    
    @property
    def metrics(self):
        return [self.loss_tracker, self.acc_tracker]
    
    def record(self, loss, y, logits):
        self.loss_tracker.update_state(loss)
        self.acc_tracker.update_state(y, tf.math.softmax(logits))

        return {m.name: m.result() for m in self.metrics}

block_size = 16
m = LlamaModel(
    vocab_size=len(char_vocab),
    decoders=1,
    block_size=block_size,
    heads=4,
    kv_heads=2,
    dims=16,
    cache_size=4,
    loss_fn = loss_fn
)

m(tf.random.uniform((2, block_size), minval=0, maxval=len(char_vocab), dtype=tf.int32))
m.compile(optimizer=optimizer, loss=loss_fn)

# # m.summary(expand_nested=True)

train_ds = make_dataset(text[:100], block_size, char_vocab).batch(32)
valid_ds = make_dataset(text[100:200], block_size, char_vocab).batch(32)
# # X, y = next(iter(ds))

history = m.fit(train_ds, validation_data=valid_ds)
# decode([0].numpy())

W0000 00:00:1707970084.179407       1 op_level_cost_estimator.cc:699] Error in PredictCost() for the op: op: "Softmax" attr { key: "T" value { type: DT_FLOAT } } inputs { dtype: DT_FLOAT shape { unknown_rank: true } } device { type: "GPU" } outputs { dtype: DT_FLOAT shape { unknown_rank: true } }




W0000 00:00:1707970096.618879       1 op_level_cost_estimator.cc:699] Error in PredictCost() for the op: op: "Softmax" attr { key: "T" value { type: DT_FLOAT } } inputs { dtype: DT_FLOAT shape { unknown_rank: true } } device { type: "GPU" } outputs { dtype: DT_FLOAT shape { unknown_rank: true } }




In [655]:
def generate(model, token_idx, tokens, block_size):
    sequence = token_idx
    for start in range(tokens):
        logits = model(token_idx, start=start, use_cache=True)[:, -1, :]
        token_idx =  tf.random.categorical(logits, 1, dtype=tf.int32)
        sequence = tf.concat([sequence, token_idx], axis=-1)
    
    return sequence

idx_to_str = {i:s for i, s in enumerate(char_vocab.keys())}
decoder = lambda x: ''.join([idx_to_str[i] for i in x])

decoder(generate(m, tf.random.uniform((1, 1), maxval=m.vocab_size, dtype=tf.int32), tokens=16, block_size=16)[0].numpy())

';L3ETeThypx?WDnnq'

In [861]:
import pandas as pd
import numpy as np

def get_text_stats(text, tokenizer, fns=[]):
    def stats_fn(token):
        return np.array([token.lower_] + [fn(token) for fn in fns])
    
    tokens = tokenizer(text)
    stats = np.stack([stats_fn(token) for token in tokens])

    return stats

def generate(model, token_idx, tokens):
    sequence = token_idx
    for start in range(tokens):
        logits = model(token_idx, start=start, use_cache=True)[:, -1, :]
        token_idx =  tf.random.categorical(logits, 1, dtype=tf.int32)
        sequence = tf.concat([sequence, token_idx], axis=-1)
    
    return sequence

def generate_and_evaluate(tokenizer, model_vocab, word_vocab, generator, model, tokens=50):
    idx_to_char = {i:c for c,i in model_vocab.items()}
    decoder = lambda x: ''.join([idx_to_char[i] for i in x])

    starter = tf.constant([model_vocab[' ']], shape=(1, 1), dtype=tf.int32)
    generated_text = decoder(
        generator(
            model,
            starter,
            tokens=tokens,
        )[0].numpy()
    )

    
    stats = get_text_stats(
        generated_text, tokenizer,
        fns=[
            lambda token: token.lower_ in word_vocab,
        ]
    )

    matches = stats[:, 0][stats[:, 1] == 'True']
    mismatches = stats[:, 0][stats[:, 1] == 'False']
    char_match_count = np.char.str_len(matches).sum()
    char_mismatch_count = np.char.str_len(mismatches).sum()
    
    return {
        'generated_text': generated_text,
        'dictwords': set(matches),
        'non-dictwords': set(mismatches),
        'char_match_count': char_match_count,
        'char_mismatch_count': char_mismatch_count,
    }

word_vocab = get_word_1gram_frequencies(tiny_shakespere, word_tokenizer)
generate_and_evaluate(word_tokenizer, char_vocab, word_vocab, generate, m)

{'generated_text': " K$:EE3Tyy-HriUzsFPQ!kr.c.mah\n:W'k . qv'zxWr$Lrs'WC",
 'dictwords': {'\n', ' ', '-', '.', ':'},
 'non-dictwords': {'hriuzsfpq!kr.c.mah', 'k$:ee3tyy', "qv'zxwr$lrs'wc", "w'k"},
 'char_match_count': 5,
 'char_mismatch_count': 44}

In [860]:
get_text_stats(
    'hello and welcome', word_tokenizer,
    fns=[
        lambda token: token.lower_ in word_vocab,
    ]
)

# word_vocab = get_word_1gram_frequencies(tiny_shakespere, word_tokenizer)
# list(word_vocab.keys())[:5]

array([['hello', 'False'],
       ['and', 'True'],
       ['welcome', 'True']], dtype='<U7')

In [758]:
%%timeit

def generate_many(model_vocab, generator, model, examples=4, tokens=50):
    starters = tf.random.uniform((examples, 1), maxval=model.vocab_size, dtype=tf.int32)
    generator(
        model,
        starters,
        tokens=tokens,
    )

generate_many(char_vocab, generate, m, examples=m.cache_size, tokens=100)

1.73 s ± 21.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [759]:
%%timeit

def generate_no_cache(model, token_idx, tokens):
    sequence = token_idx
    for _ in range(tokens):
        logits = model(sequence[:, -model.block_size:])[:, -1, :]
        token_idx =  tf.random.categorical(logits, 1, dtype=tf.int32)
        sequence = tf.concat([sequence, token_idx], axis=-1)
    
    return sequence

generate_many(char_vocab, generate_no_cache, m, examples=m.cache_size, tokens=100)

2.14 s ± 74.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [684]:
import pandas as pd

from tensorflow.keras import callbacks
from functools import partial

class ProgressEvaluation(callbacks.Callback):
    def __init__(self, steps, eval_fn, verbose=False):
        super(ProgressEvaluation, self).__init__()
        self.steps = steps
        self.eval_fn = eval_fn
        self.stats = pd.DataFrame(
            columns=[
                'epoch', 'step', 'loss',
                'dictwords', 'non-dictwords',
                'char_match_count', 'char_mismatch_count',
                'generated_text',
            ]
        )
        self.verbose = verbose
    
    def on_train_begin(self, logs=None):
        self.step = 0
    
    def on_epoch_begin(self, epoch, logs=None):
        self.epoch = epoch

    def on_train_batch_end(self, batch, logs={}):
        self.step += 1

        # Log results every 'steps' steps
        if self.step % self.steps == 0:
            results = self.eval_fn(self.model)
            results.update({
                'epoch': self.epoch,
                'step': self.step,
                'loss': logs.get("loss"),
            })

            self.stats.loc[len(self.stats)] = pd.Series(results)

            if self.verbose:
                char_match_rate = results['char_match_count'] / (results['char_match_count'] + results['char_mismatch_count'])

                print(
                    f"\nEpoch: {results['epoch']} Step: {results['step']} Loss={results['loss']:.3}"
                    f"\nGenerated: {results['generated_text']}"
                    f"\n\nStatistics\n==============="
                    f"\nDictionary Words: {results['dictwords']}"
                    f"\nNon-dictionary Words: {results['non-dictwords']}"
                    f"\nChar Matches: {results['char_match_count']}"
                    f"\nChar Mismatches: {results['char_mismatch_count']}"
                    f"\nChar Match Rate: {char_match_rate:.2%}"
                )

eval_fn = partial(generate_and_evaluate, word_tokenizer, word_vocab, generate, decoder)
# eval_cb = ProgressEvaluation(steps=2000, eval_fn=eval_fn)
eval_cb = ProgressEvaluation(steps=1, eval_fn=eval_fn)

train_ds = make_dataset(text[:100], block_size, char_vocab).batch(32)

m.compile(optimizer=optimizer, loss=loss_fn)
history = m.fit(train_ds, callbacks=[eval_cb], epochs=2)

Epoch 1/2


W0000 00:00:1707885216.674606       1 op_level_cost_estimator.cc:699] Error in PredictCost() for the op: op: "Softmax" attr { key: "T" value { type: DT_FLOAT } } inputs { dtype: DT_FLOAT shape { unknown_rank: true } } device { type: "GPU" } outputs { dtype: DT_FLOAT shape { unknown_rank: true } }


Epoch 2/2


In [696]:
# eval_cb.stats.to_pickle('train.progress.stats.pickle')
m.save_weights('llama.tf')

In [341]:
import tensorflow as tf

from collections import Counter
from itertools import chain

data = 'helloksdhglsadjgalsdhglaksddd'
block_size = 5

# ds = tf.data.Dataset.range(len(data) - block_size - 1).map(
#     # lambda index: data[index:index+block_size], data[index+1:index+block_size+1]
#     lambda index: data[index:index+block_size]
# )
with open('input.txt') as f:
    tiny_shakespere = ''.join(f.readlines())

char_frequencies = Counter(chain(*tiny_shakespere))
char_vocab = dict(zip(char_frequencies.keys(), range(len(char_frequencies))))

def make_dataset(data, block_size, vocab):
    def indices(phrase):
        return list(map(vocab.get, phrase))

    @tf.numpy_function(Tout=(tf.int64, tf.int64))
    def get_numpy_item(index):
        X, y = data[index:index+block_size], data[index+1:index+block_size+1]
        return indices(X), indices(y)

    def get_item(index):
        X, y = get_numpy_item(index)
        return tf.reshape(X, [block_size]), tf.reshape(y, [block_size])

ds = make_dataset(data, block_size, char_vocab).batch(2)
X, y = next(iter(ds))

print(
    f'\n\nTrain Set\n=============='
    f'\n{X.shape=} {y.shape=} {X} {y}'
    f'\n{X.numpy()=} {y.numpy()=}'
)



Train Set
X.shape=TensorShape([2, 5]) y.shape=TensorShape([2, 5]) [[22  8 28 28 14]
 [ 8 28 28 14 25]] [[ 8 28 28 14 25]
 [28 28 14 25  3]]
X.numpy()=array([[22,  8, 28, 28, 14],
       [ 8, 28, 28, 14, 25]]) y.numpy()=array([[ 8, 28, 28, 14, 25],
       [28, 28, 14, 25,  3]])


In [455]:
shape = (4, 16, 4, 4)
cache = tf.zeros(shape)

filled = tf.tensor_scatter_nd_update(
    cache, tf.stack(tf.unravel_index(range(4**3*16), dims=shape), axis=-1)[2*16*4*4:3*16*4*4], tf.reshape(tf.random.uniform(shape[1:]), -1),
)

In [515]:
# tf.unravel_index(range(4**3*16), dims=shape)
# filled[2]
# tf.stack(tf.unravel_index(range(4**3*16), dims=shape), axis=-1)[2*16*4*4:3*16*4*4]
start, seqlen = 2, 8

indices = tf.transpose(tf.unravel_index(tf.reshape(tf.reshape(tf.range(4*16*4*4), shape)[:2, start:start+seqlen], (-1)), dims=shape))

filled = tf.tensor_scatter_nd_update(
    cache, indices, tf.reshape(tf.random.uniform((2*seqlen, 4, 4)), (-1)),
)

(256, 4)


In [567]:
from collections import OrderedDict

char_vocab = OrderedDict(zip(
    char_frequencies.keys(),
    range(len(char_frequencies)),
))

In [858]:
import spacy
import itertools

import numpy as np

from collections import Counter

def get_word_1gram_frequencies(data, tokenizer):
    words = map(lambda x:x.lower_, itertools.chain(*map(tokenizer, data)))
    return Counter(words)

word_tokenizer = spacy.blank("en")
word_vocab = get_word_1gram_frequencies(tiny_shakespere, word_tokenizer)

print(
    f'{list(word_vocab.keys())[:5]}'
)

['first', 'citizen', ':', '\n', 'before']


In [855]:
import itertools

word_tokenizer = spacy.blank("en")

# words = list(map(lambda x:x.lower_, map(word_tokenizer, tiny_shakespere)))
# words[:5]
# word_tokenizer(tiny_shakespere[0])
tokens = list(itertools.chain(*map(word_tokenizer, tiny_shakespere[:2])))
len(tokens), tokens[:2]

(15, [First, Citizen])

In [604]:
block_size = 16
m = LlamaModel(
    vocab_size=len(char_vocab),
    encoders=2,
    block_size=block_size,
    heads=4,
    kv_heads=2,
    dims=16,
    cache_size=4,
    loss_fn = loss_fn
)

In [615]:
@tf.function
def generate(sequence, tokens, block_size):
    for start in range(tokens):
        sequence = sequence[:, -block_size:]
        logits = m(sequence, start=start, use_cache=True)[:, -1, :]
        prediction =  tf.random.categorical(logits, 1, dtype=tf.int32)

        sequence = tf.concat([sequence, prediction], axis=-1)
    
    return sequence

idx_to_str = {i:s for i, s in enumerate(char_vocab.keys())}
decode = lambda x: ''.join([idx_to_str[i] for i in x])

decode(generate(tf.zeros((1, 1), dtype=tf.int32), tokens=block_size, block_size=16)[0].numpy())

2024-02-13 15:37:30.381841: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 2514055488347275762
2024-02-13 15:37:30.381913: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 6864595417707083459
2024-02-13 15:37:30.382219: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 15030648477679010439
2024-02-13 15:37:30.382254: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 7802382919658029917
2024-02-13 15:37:30.382278: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 12903173487718679601
2024-02-13 15:37:30.382297: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 11760740079638168607
2024-02-13 15:37:30.382310: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv 

InvalidArgumentError: Graph execution error:

Detected at node llama_model_168/llama_block_349/grouped_query_attention_409/llama_model_168/llama_block_349/grouped_query_attention_409/strided_slice_69/_assign defined at (most recent call last):
  File "/opt/homebrew/Cellar/python@3.10/3.10.13_1/Frameworks/Python.framework/Versions/3.10/lib/python3.10/runpy.py", line 196, in _run_module_as_main

  File "/opt/homebrew/Cellar/python@3.10/3.10.13_1/Frameworks/Python.framework/Versions/3.10/lib/python3.10/runpy.py", line 86, in _run_code

  File "/Users/broxoli/.venv-tensorflow/lib/python3.10/site-packages/ipykernel_launcher.py", line 17, in <module>

  File "/Users/broxoli/.venv-tensorflow/lib/python3.10/site-packages/traitlets/config/application.py", line 1075, in launch_instance

  File "/Users/broxoli/.venv-tensorflow/lib/python3.10/site-packages/ipykernel/kernelapp.py", line 739, in start

  File "/Users/broxoli/.venv-tensorflow/lib/python3.10/site-packages/tornado/platform/asyncio.py", line 205, in start

  File "/opt/homebrew/Cellar/python@3.10/3.10.13_1/Frameworks/Python.framework/Versions/3.10/lib/python3.10/asyncio/base_events.py", line 603, in run_forever

  File "/opt/homebrew/Cellar/python@3.10/3.10.13_1/Frameworks/Python.framework/Versions/3.10/lib/python3.10/asyncio/base_events.py", line 1909, in _run_once

  File "/opt/homebrew/Cellar/python@3.10/3.10.13_1/Frameworks/Python.framework/Versions/3.10/lib/python3.10/asyncio/events.py", line 80, in _run

  File "/Users/broxoli/.venv-tensorflow/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 542, in dispatch_queue

  File "/Users/broxoli/.venv-tensorflow/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 531, in process_one

  File "/Users/broxoli/.venv-tensorflow/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 437, in dispatch_shell

  File "/Users/broxoli/.venv-tensorflow/lib/python3.10/site-packages/ipykernel/ipkernel.py", line 359, in execute_request

  File "/Users/broxoli/.venv-tensorflow/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 775, in execute_request

  File "/Users/broxoli/.venv-tensorflow/lib/python3.10/site-packages/ipykernel/ipkernel.py", line 446, in do_execute

  File "/Users/broxoli/.venv-tensorflow/lib/python3.10/site-packages/ipykernel/zmqshell.py", line 549, in run_cell

  File "/Users/broxoli/.venv-tensorflow/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3051, in run_cell

  File "/Users/broxoli/.venv-tensorflow/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3106, in _run_cell

  File "/Users/broxoli/.venv-tensorflow/lib/python3.10/site-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner

  File "/Users/broxoli/.venv-tensorflow/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3311, in run_cell_async

  File "/Users/broxoli/.venv-tensorflow/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3493, in run_ast_nodes

  File "/Users/broxoli/.venv-tensorflow/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3553, in run_code

  File "/var/folders/8y/5694n0_n42j4fmg6d5j411480000gn/T/ipykernel_13484/851780790.py", line 15, in <module>

  File "/var/folders/8y/5694n0_n42j4fmg6d5j411480000gn/T/ipykernel_13484/851780790.py", line 3, in generate

  File "/var/folders/8y/5694n0_n42j4fmg6d5j411480000gn/T/ipykernel_13484/851780790.py", line 5, in generate

  File "/Users/broxoli/.venv-tensorflow/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 65, in error_handler

  File "/Users/broxoli/.venv-tensorflow/lib/python3.10/site-packages/keras/src/engine/training.py", line 590, in __call__

  File "/Users/broxoli/.venv-tensorflow/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 65, in error_handler

  File "/Users/broxoli/.venv-tensorflow/lib/python3.10/site-packages/keras/src/engine/base_layer.py", line 1149, in __call__

  File "/Users/broxoli/.venv-tensorflow/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 96, in error_handler

  File "/var/folders/8y/5694n0_n42j4fmg6d5j411480000gn/T/ipykernel_13484/1265064723.py", line 155, in call

  File "/var/folders/8y/5694n0_n42j4fmg6d5j411480000gn/T/ipykernel_13484/1265064723.py", line 155, in call

  File "/Users/broxoli/.venv-tensorflow/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 65, in error_handler

  File "/Users/broxoli/.venv-tensorflow/lib/python3.10/site-packages/keras/src/engine/training.py", line 590, in __call__

  File "/Users/broxoli/.venv-tensorflow/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 65, in error_handler

  File "/Users/broxoli/.venv-tensorflow/lib/python3.10/site-packages/keras/src/engine/base_layer.py", line 1149, in __call__

  File "/Users/broxoli/.venv-tensorflow/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 96, in error_handler

  File "/var/folders/8y/5694n0_n42j4fmg6d5j411480000gn/T/ipykernel_13484/1265064723.py", line 123, in call

  File "/Users/broxoli/.venv-tensorflow/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 65, in error_handler

  File "/Users/broxoli/.venv-tensorflow/lib/python3.10/site-packages/keras/src/engine/training.py", line 590, in __call__

  File "/Users/broxoli/.venv-tensorflow/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 65, in error_handler

  File "/Users/broxoli/.venv-tensorflow/lib/python3.10/site-packages/keras/src/engine/base_layer.py", line 1149, in __call__

  File "/Users/broxoli/.venv-tensorflow/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 96, in error_handler

  File "/var/folders/8y/5694n0_n42j4fmg6d5j411480000gn/T/ipykernel_13484/200258366.py", line 67, in call

  File "/var/folders/8y/5694n0_n42j4fmg6d5j411480000gn/T/ipykernel_13484/200258366.py", line 69, in call

  File "/var/folders/8y/5694n0_n42j4fmg6d5j411480000gn/T/ipykernel_13484/2638763864.py", line 22, in update

Detected at node llama_model_168/llama_block_349/grouped_query_attention_409/llama_model_168/llama_block_349/grouped_query_attention_409/strided_slice_69/_assign defined at (most recent call last):
  File "/opt/homebrew/Cellar/python@3.10/3.10.13_1/Frameworks/Python.framework/Versions/3.10/lib/python3.10/runpy.py", line 196, in _run_module_as_main

  File "/opt/homebrew/Cellar/python@3.10/3.10.13_1/Frameworks/Python.framework/Versions/3.10/lib/python3.10/runpy.py", line 86, in _run_code

  File "/Users/broxoli/.venv-tensorflow/lib/python3.10/site-packages/ipykernel_launcher.py", line 17, in <module>

  File "/Users/broxoli/.venv-tensorflow/lib/python3.10/site-packages/traitlets/config/application.py", line 1075, in launch_instance

  File "/Users/broxoli/.venv-tensorflow/lib/python3.10/site-packages/ipykernel/kernelapp.py", line 739, in start

  File "/Users/broxoli/.venv-tensorflow/lib/python3.10/site-packages/tornado/platform/asyncio.py", line 205, in start

  File "/opt/homebrew/Cellar/python@3.10/3.10.13_1/Frameworks/Python.framework/Versions/3.10/lib/python3.10/asyncio/base_events.py", line 603, in run_forever

  File "/opt/homebrew/Cellar/python@3.10/3.10.13_1/Frameworks/Python.framework/Versions/3.10/lib/python3.10/asyncio/base_events.py", line 1909, in _run_once

  File "/opt/homebrew/Cellar/python@3.10/3.10.13_1/Frameworks/Python.framework/Versions/3.10/lib/python3.10/asyncio/events.py", line 80, in _run

  File "/Users/broxoli/.venv-tensorflow/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 542, in dispatch_queue

  File "/Users/broxoli/.venv-tensorflow/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 531, in process_one

  File "/Users/broxoli/.venv-tensorflow/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 437, in dispatch_shell

  File "/Users/broxoli/.venv-tensorflow/lib/python3.10/site-packages/ipykernel/ipkernel.py", line 359, in execute_request

  File "/Users/broxoli/.venv-tensorflow/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 775, in execute_request

  File "/Users/broxoli/.venv-tensorflow/lib/python3.10/site-packages/ipykernel/ipkernel.py", line 446, in do_execute

  File "/Users/broxoli/.venv-tensorflow/lib/python3.10/site-packages/ipykernel/zmqshell.py", line 549, in run_cell

  File "/Users/broxoli/.venv-tensorflow/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3051, in run_cell

  File "/Users/broxoli/.venv-tensorflow/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3106, in _run_cell

  File "/Users/broxoli/.venv-tensorflow/lib/python3.10/site-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner

  File "/Users/broxoli/.venv-tensorflow/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3311, in run_cell_async

  File "/Users/broxoli/.venv-tensorflow/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3493, in run_ast_nodes

  File "/Users/broxoli/.venv-tensorflow/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3553, in run_code

  File "/var/folders/8y/5694n0_n42j4fmg6d5j411480000gn/T/ipykernel_13484/851780790.py", line 15, in <module>

  File "/var/folders/8y/5694n0_n42j4fmg6d5j411480000gn/T/ipykernel_13484/851780790.py", line 3, in generate

  File "/var/folders/8y/5694n0_n42j4fmg6d5j411480000gn/T/ipykernel_13484/851780790.py", line 5, in generate

  File "/Users/broxoli/.venv-tensorflow/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 65, in error_handler

  File "/Users/broxoli/.venv-tensorflow/lib/python3.10/site-packages/keras/src/engine/training.py", line 590, in __call__

  File "/Users/broxoli/.venv-tensorflow/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 65, in error_handler

  File "/Users/broxoli/.venv-tensorflow/lib/python3.10/site-packages/keras/src/engine/base_layer.py", line 1149, in __call__

  File "/Users/broxoli/.venv-tensorflow/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 96, in error_handler

  File "/var/folders/8y/5694n0_n42j4fmg6d5j411480000gn/T/ipykernel_13484/1265064723.py", line 155, in call

  File "/var/folders/8y/5694n0_n42j4fmg6d5j411480000gn/T/ipykernel_13484/1265064723.py", line 155, in call

  File "/Users/broxoli/.venv-tensorflow/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 65, in error_handler

  File "/Users/broxoli/.venv-tensorflow/lib/python3.10/site-packages/keras/src/engine/training.py", line 590, in __call__

  File "/Users/broxoli/.venv-tensorflow/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 65, in error_handler

  File "/Users/broxoli/.venv-tensorflow/lib/python3.10/site-packages/keras/src/engine/base_layer.py", line 1149, in __call__

  File "/Users/broxoli/.venv-tensorflow/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 96, in error_handler

  File "/var/folders/8y/5694n0_n42j4fmg6d5j411480000gn/T/ipykernel_13484/1265064723.py", line 123, in call

  File "/Users/broxoli/.venv-tensorflow/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 65, in error_handler

  File "/Users/broxoli/.venv-tensorflow/lib/python3.10/site-packages/keras/src/engine/training.py", line 590, in __call__

  File "/Users/broxoli/.venv-tensorflow/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 65, in error_handler

  File "/Users/broxoli/.venv-tensorflow/lib/python3.10/site-packages/keras/src/engine/base_layer.py", line 1149, in __call__

  File "/Users/broxoli/.venv-tensorflow/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 96, in error_handler

  File "/var/folders/8y/5694n0_n42j4fmg6d5j411480000gn/T/ipykernel_13484/200258366.py", line 67, in call

  File "/var/folders/8y/5694n0_n42j4fmg6d5j411480000gn/T/ipykernel_13484/200258366.py", line 69, in call

  File "/var/folders/8y/5694n0_n42j4fmg6d5j411480000gn/T/ipykernel_13484/2638763864.py", line 22, in update

2 root error(s) found.
  (0) INVALID_ARGUMENT:  Cannot broadcast input shape [1,9,2,4] into final shape [1,8,2,4]
	 [[{{node llama_model_168/llama_block_349/grouped_query_attention_409/llama_model_168/llama_block_349/grouped_query_attention_409/strided_slice_69/_assign}}]]
	 [[llama_model_168/llama_block_350/grouped_query_attention_410/ReadVariableOp_34/_240]]
  (1) INVALID_ARGUMENT:  Cannot broadcast input shape [1,9,2,4] into final shape [1,8,2,4]
	 [[{{node llama_model_168/llama_block_349/grouped_query_attention_409/llama_model_168/llama_block_349/grouped_query_attention_409/strided_slice_69/_assign}}]]
0 successful operations.
0 derived errors ignored. [Op:__inference_generate_306524]

In [813]:
import tensorflow as tf

class KVCache(object):
    def __init__(self, cache_size, block_size, heads, head_dims):
        super(KVCache, self).__init__()
        self.block_size = block_size

        cache_shape = (cache_size, block_size, heads, head_dims)

        with tf.device('/device:CPU:0'):
            self.cache_k = tf.Variable(tf.zeros(cache_shape), shape=cache_shape, trainable=False)
            self.cache_v = tf.Variable(tf.zeros(cache_shape), shape=cache_shape, trainable=False)

    def update(self, start, xk, xv):
        shape = tf.shape(xk)
        B = shape[0]
        T = shape[1]

        # Calculate update start and end positions
        start = start%self.block_size
        end = (start + T)%(self.block_size + 1)

        # start < end: It is a single cache update.
        # end > start: It is a split cache update.
        if start < end:
            self.cache_k[:B, start:start+T].assign(xk)
            self.cache_v[:B, start:start+T].assign(xv)
        else:
            # Update cache with partial sequence that fits towards the end.
            self.cache_k[:B, start:].assign(xk[:, :-(end+1)])
            self.cache_v[:B, start:].assign(xv[:, :-(end+1)])

            # Splillover sequence is cached towards the front of the cache.
            self.cache_k[:B, :end+1].assign(xk[:, -(end+1):])
            self.cache_v[:B, :end+1].assign(xv[:, -(end+1):])

    # TODO:: Update the callers to reflect the args order change.
    def get(self, batch_size, start, seq_len):
        # Calculate update start and end positions
        start = start%self.block_size
        end = (start + seq_len)%(self.block_size + 1)

        # start < end: It is a single cache fetch.
        # end > start: It is a split cache fetch.
        if start < end:
            keys = self.cache_k[:batch_size, :start+seq_len]
            values = self.cache_v[:batch_size, :start+seq_len]
        else:
            # Fetch sequence prefix
            keys_1 = self.cache_k[:, (end+1):]
            values_1 = self.cache_k[:, (end+1):]

            # Fetch sequence suffix
            keys_2 = self.cache_k[:, :(end+1):]
            values_2 = self.cache_k[:, :(end+1):]

            # Compose the whole sequence
            keys = tf.concat([keys_1, keys_2], axis=1)
            values = tf.concat([values_1, values_2], axis=1)

        return keys, values

def update_and_show(cache, batch_size, start, seqlen, data_shape, msg, debug=False):
    xk = tf.random.uniform((batch_size, seqlen, *data_shape))
    xv = tf.random.uniform((batch_size, seqlen, *data_shape))

    if debug:
        print(
            f'\n{msg}::xk:\n{xk}'
        )

    cache.update(start, xk, xv)
    print(
        f'\n{msg}:\n{cache.cache_k}'
    )

cache_size, block_size, heads, head_dims = 1, 8, 1, 4
cache = KVCache(cache_size, block_size, heads, head_dims)
data_shape = (heads, head_dims)

# print(f'\nInitial:\n{cache.cache_k}')

# update_and_show(cache, cache_size, 2, 4, data_shape, 'InUpdate(2, 4)')
# update_and_show(cache, cache_size, 4, 4, data_shape, 'EndUpdate(4, 4)')
# update_and_show(cache, cache_size, 6, 4, data_shape, 'SpilledUpdate(6, 4)', debug=True)
update_and_show(cache, cache_size, 6, 8, data_shape, 'Fill(6, 8)')
print(
    # f'\nInQuery:\n{cache.get(cache_size, 0, 4)[0]}'
    # f'\n\nEndQuery:\n{cache.get(cache_size, 4, 4)[0]}'
    f'\n\nSpilledQuery:\n{cache.get(cache_size, 6, 2)[0]}'
)


Fill(6, 8):
<tf.Variable 'Variable:0' shape=(1, 8, 1, 4) dtype=float32, numpy=
array([[[[0.95810425, 0.55623317, 0.16466296, 0.13445711]],

        [[0.13229859, 0.5348098 , 0.57090175, 0.50970507]],

        [[0.48252344, 0.15580535, 0.3703227 , 0.49210668]],

        [[0.567016  , 0.2077086 , 0.18223882, 0.99883735]],

        [[0.36950588, 0.37927854, 0.7723117 , 0.68211746]],

        [[0.39932835, 0.7840713 , 0.67880154, 0.73395896]],

        [[0.31179297, 0.8263413 , 0.6849456 , 0.0067091 ]],

        [[0.78749514, 0.3906511 , 0.29263055, 0.99216926]]]],
      dtype=float32)>


SpilledQuery:
[[[[0.95810425 0.55623317 0.16466296 0.13445711]]

  [[0.13229859 0.5348098  0.57090175 0.50970507]]

  [[0.48252344 0.15580535 0.3703227  0.49210668]]

  [[0.567016   0.2077086  0.18223882 0.99883735]]

  [[0.36950588 0.37927854 0.7723117  0.68211746]]

  [[0.39932835 0.7840713  0.67880154 0.73395896]]

  [[0.31179297 0.8263413  0.6849456  0.0067091 ]]

  [[0.78749514 0.3906511  0.29263055

In [865]:
import tensorflow as tf

class SelfAttentionLayer(tf.keras.layers.Layer):
  def __init__(self, head_size):
    super().__init__()
    self.key = tf.keras.layers.Dense(head_size, use_bias=False)
    self.query = tf.keras.layers.Dense(head_size, use_bias=False)
    self.value = tf.keras.layers.Dense(head_size, use_bias=False)
    self.head_size = head_size


  def call(self, x):
    B, T, C = x.shape

    k = self.key(x)
    q = self.query(x)
    v = self.query(x)

    wei = k @ tf.transpose(q, perm=[0, 2, 1]) # (B, T, 16) @ (B, 16, T) --> (B, T, T)
    wei /= tf.cast(tf.math.sqrt(self.head_size * 1.), dtype=x.dtype)

    tril = tf.linalg.band_part(
        tf.ones((T, T)), -1, 0
    )

    wei = tf.nn.softmax(tf.where(tril > 0.0, wei, float('-inf')))
    out = wei @ v
    return out

class MultiHeadAttentionLayer(tf.keras.layers.Layer):
  def __init__(self, num_heads, head_size):
    super().__init__()
    self.attn_layers = [SelfAttentionLayer(head_size) for i in range(num_heads)]


  def call(self, x):
    return tf.concat([
        attn_layer(x) for attn_layer in self.attn_layers
    ], axis=-1)


B, T, C = 2, 2, 16
l = MultiHeadAttentionLayer(8, 16)
l(tf.random.uniform((B, T, 8*16)))

<tf.Tensor: shape=(2, 2, 128), dtype=float32, numpy=
array([[[-2.33013868e-01, -2.07698846e+00,  3.48630995e-01,
         -5.10134697e-01, -2.54128367e-01,  5.40200472e-01,
          3.85206193e-02,  7.73041844e-01, -1.09596026e+00,
          1.01041520e+00,  9.91183937e-01, -7.66317129e-01,
         -5.28275371e-02, -1.04835615e-01, -3.37260664e-01,
         -7.06953108e-01,  6.96793646e-02, -5.24373591e-01,
         -1.54647708e+00, -1.07588410e-01, -1.45081893e-01,
         -4.36975539e-01,  2.54979879e-01,  7.04379976e-02,
         -6.10021055e-01,  2.36156583e-02,  1.18972093e-01,
         -4.18260843e-01,  1.39139616e+00,  7.75396824e-01,
         -7.95781732e-01,  2.60917783e-01,  1.12052810e+00,
         -4.42397982e-01,  9.55890119e-02, -1.09111333e+00,
         -4.14958864e-01, -3.14540088e-01,  1.01840663e+00,
          9.63069439e-01,  3.75590533e-01,  1.06295928e-01,
          9.52532589e-01, -6.47454500e-01, -9.87813771e-02,
          3.03195119e-02, -8.26753795e-01,  2.4