In [261]:
# ! curl -o input.txt https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt 

In [279]:
import spacy
import itertools

import numpy as np

from collections import Counter

def get_word_1gram_frequencies(data, tokenizer):
    words = map(lambda x:x.lower_, itertools.chain(*map(tokenizer, data)))
    return Counter(words)

word_tokenizer = spacy.blank("en")
word_vocab = get_word_1gram_frequencies(tiny_shakespere, word_tokenizer)

print(
    f'{list(word_vocab.keys())[:5]}'
)

['first', 'citizen', ':', '\n', 'before']


In [278]:
# tokens = nlp('and to')
# ! pip install levenshtein

from Levenshtein import distance as edit_distance

def get_proximity_with_vocab(text, tokenizer, vocab):
    def proximity_fn(word):
        if word in vocab:
            distance = 0
        else:
            distance = min(*map(
                lambda vocab_word: edit_distance(word, vocab_word),
                vocab.keys()
            ))
        
        return distance


    tokens = tokenizer(text)
    distance = reduce(lambda y,token: y+proximity_fn(str(token)), tokens, 0)

    return distance

def get_all_proximity_with_vocab(texts, tokenizer, vocab):
    return np.fromiter(
        map(
            lambda text: get_proximity_with_vocab(text, tokenizer, vocab),
            texts,
        ),
        dtype=object
    )

nlp = spacy.blank("en")

texts = [
    'hello, how are you doing?',
    'hello, how are you doing?',
]

get_all_proximity_with_vocab(texts, nlp, word_vocab)

array([1, 1], dtype=object)

In [None]:
! head input.txt

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:


In [264]:
from tensorflow.keras import losses, optimizers

loss_fn = losses.SparseCategoricalCrossentropy(from_logits=True)
optimizer = optimizers.legacy.Adam()

TODO

* [X] loss and accuracy metrics.
* [X] training, validation and inference flags.
* Match dictionary with embedding size.

In [273]:
from collections import OrderedDict

model_vocab = OrderedDict([('F', 0), ('i', 1), ('r', 2), ('s', 3), ('t', 4), (' ', 5), ('C', 6), ('z', 7), ('e', 8), ('n', 9), (':', 10), ('\n', 11), ('B', 12), ('f', 13), ('o', 14), ('w', 15), ('p', 16), ('c', 17), ('d', 18), ('a', 19), ('y', 20), ('u', 21), ('h', 22), (',', 23), ('m', 24), ('k', 25), ('.', 26), ('A', 27), ('l', 28), ('S', 29), ('Y', 30), ('v', 31), ('?', 32), ('R', 33), ('M', 34), ('W', 35), ("'", 36), ('L', 37), ('I', 38), ('N', 39), ('g', 40), (';', 41), ('b', 42), ('!', 43), ('O', 44), ('j', 45), ('V', 46), ('-', 47), ('T', 48), ('H', 49), ('E', 50), ('U', 51), ('D', 52), ('P', 53), ('q', 54), ('x', 55), ('J', 56), ('G', 57), ('K', 58), ('Q', 59), ('&', 60), ('Z', 61), ('X', 62), ('3', 63), ('$', 64)])

In [269]:
from pathlib import Path

datapath = Path('./input.txt')

with open(datapath) as f:
    tiny_shakespere = list(f.readlines())

In [270]:
all_data = ''.join(tiny_shakespere)

all_data_length = len(all_data)
train_data_length = int(all_data_length*.9)
valid_data_length = all_data_length - train_data_length

train_data = all_data[:train_data_length]
valid_data = all_data[-valid_data_length:]

In [274]:
import tensorflow as tf

from collections import OrderedDict

def make_dataset(data, block_size, vocab):
    def indices(phrase):
        return list(map(vocab.get, phrase))

    @tf.numpy_function(Tout=(tf.int64, tf.int64))
    def get_numpy_item(index):
        X, y = data[index:index+block_size], data[index+1:index+block_size+1]
        return indices(X), indices(y)

    def get_item(index):
        X, y = get_numpy_item(index)
        return tf.reshape(X, [block_size]), tf.reshape(y, [block_size])
    
    return tf.data.Dataset.range(len(data) - block_size).map(get_item)

ds = make_dataset(train_data, block_size, model_vocab).batch(2)
X, y = next(iter(ds))

print(
    f'\n\nTrain Set\n=============='
    f'\n{X.shape=} {y.shape=} {X} {y}'
    f'\n{X.numpy()=} {y.numpy()=}'
)



Train Set
X.shape=TensorShape([2, 128]) y.shape=TensorShape([2, 128]) [[ 0  1  2  3  4  5  6  1  4  1  7  8  9 10 11 12  8 13 14  2  8  5 15  8
   5 16  2 14 17  8  8 18  5 19  9 20  5 13 21  2  4 22  8  2 23  5 22  8
  19  2  5 24  8  5  3 16  8 19 25 26 11 11 27 28 28 10 11 29 16  8 19 25
  23  5  3 16  8 19 25 26 11 11  0  1  2  3  4  5  6  1  4  1  7  8  9 10
  11 30 14 21  5 19  2  8  5 19 28 28  5  2  8  3 14 28 31  8 18  5  2 19
   4 22  8  2  5  4 14  5]
 [ 1  2  3  4  5  6  1  4  1  7  8  9 10 11 12  8 13 14  2  8  5 15  8  5
  16  2 14 17  8  8 18  5 19  9 20  5 13 21  2  4 22  8  2 23  5 22  8 19
   2  5 24  8  5  3 16  8 19 25 26 11 11 27 28 28 10 11 29 16  8 19 25 23
   5  3 16  8 19 25 26 11 11  0  1  2  3  4  5  6  1  4  1  7  8  9 10 11
  30 14 21  5 19  2  8  5 19 28 28  5  2  8  3 14 28 31  8 18  5  2 19  4
  22  8  2  5  4 14  5 18]] [[ 1  2  3  4  5  6  1  4  1  7  8  9 10 11 12  8 13 14  2  8  5 15  8  5
  16  2 14 17  8  8 18  5 19  9 20  5 13 21  2  4 22  8  2 

In [284]:
from collections import namedtuple

fields = (
    'vocab_size', 
    
    # Model Size and Prediction Capacity
    'decoders', 'block_size', 'embed_dims', 'dims',
    
    # Attention Params
    'attention', 'heads', 'kv_heads',
    
    # KV Cache
    'cache_size',
    
    # Sliding Window Attention
    'window_size', 
    
    # Features
    'pos_embeddings', 'norm',
    
    # Alternating Updates
    'alternate_blocks'
)
ModelConfig = namedtuple('ModelConfig', fields, defaults=(None, )*len(fields))

In [288]:
import tensorflow as tf

from tensorflow.keras import layers

class RotaryPositionalEncodings(layers.Layer):
    def __init__(self, block_size, dims, base=10000.):
        super(RotaryPositionalEncodings, self).__init__()
        
        self.dims = dims

        # θ = 1 / (base^(2(i - 1) / dims)), i = [1, dims/2]
        self.theta = 1. / (base ** (tf.range(0, dims, 2, dtype=tf.float32) / dims))
        self.positions = tf.range(block_size, dtype=tf.float32)
        
        self.mtheta = tf.concat([tf.einsum('n,d->nd', self.positions, self.theta)]*2, axis=-1)

        self.cos_mtheta = tf.reshape(
            tf.math.cos(self.mtheta),
            [1, 1, *self.mtheta.shape],
        )
        self.sin_mtheta =  tf.reshape(
            tf.math.sin(self.mtheta),
            [1, 1, *self.mtheta.shape],
        )

    def call(self, x, start):
        """Compute positional encodings

        Arguments:
            x: A tensor of shape (B, _, T, _)
            start: Token start position
        
        Returns:
            An position-encoded tensor of shape (B, _, T, _)
        """
        B, _, T, _ = x.shape

        # (B, _, T, dims)*(1, 1, T, dims) -> (B, _, T, dims)
        x_real = x*self.cos_mtheta[..., start:start+T, :]

        # (B, _, T, dims)*(1, 1, T, dims) -> (B, _, T, dims)
        x_img = tf.concat(
            [-x[..., self.dims//2:], x[..., :self.dims//2]],
            axis=-1
        )*self.sin_mtheta[..., start:start+T, :]

        # (B, _, T, dims) + (B, _, T, dims) -> (B, _, T, dims)
        output = x_real + x_img
        return output

# B, T, H, D = 2, 8, 4, 4
# B_i, T_i, H_i, D_i = 0, 1, 0, 0
# x = tf.random.uniform((B, H, T, D))
# l = RotaryPositionalEncodings(T, D)
# output = l(x, start=0)
# # tf.print(f'{l.theta=} {l.theta.shape}')
# # tf.print(f'{l.positions=} {l.positions.shape}')
# tf.print(f'{x[B_i]=}')
# tf.print(f'{l.cos_mtheta=}')
# tf.print(f'{l.sin_mtheta=}')
# tf.print(f'{output[B_i]=}\n{output[B_i]=} {output.shape}')

import tensorflow as tf
import numpy as np

class KVCache(object):
    def __init__(self, *shape):
        super(KVCache, self).__init__()

        with tf.device('/device:CPU:0'):
            self.cache_k = tf.Variable(tf.zeros(shape), trainable=False)
            self.cache_v = tf.Variable(tf.zeros(shape), trainable=False)

    def update(self, start, xk, xv):
        B, T, _, _ = xk.shape

        # print(
        #     f'KVCache.update({B}, {T}):: {xk.shape=} {xv.shape=}'
        # )

        self.cache_k[:B, start:start+T].assign(xk)
        self.cache_v[:B, start:start+T].assign(xv)

    def get(self, batch_size, start, seqlen):
        # print(
        #     f'KVCache.get({batch_size}, {start}, {seqlen})'
        #     f'{self.cache_k.shape=} {self.cache_v.shape=}'
        # )

        keys = self.cache_k[:batch_size, :start+seqlen]
        values = self.cache_v[:batch_size, :start+seqlen]

        return keys, values

def update_and_show(cache, batch_size, start, data_shape, msg, debug=False):
    xk = np.random.rand(batch_size, 1, *data_shape)
    xv = np.random.rand(batch_size, 1, *data_shape)

    if debug:
        print(
            f'\n{msg}::xk:\n{xk}'
            f'\n{msg}::xv:\n{xv}'
        )

    cache.update(start, xk, xv)
    print(
        f'\n{msg}::key:\n{cache.cache_k}'
        f'\n{msg}::value:\n{cache.cache_v}'
    )

# cache_size, block_size, heads, head_dims = 1, 8, 2, 4
# cache = KVCache(cache_size, block_size, heads, head_dims)
# data_shape = (head_dims,)

# print(f'\nInitial:\n{cache.cache_k}')

# # update_and_show(cache, cache_size, 2, 4, data_shape, 'InUpdate(2, 4)')
# # update_and_show(cache, cache_size, 4, 4, data_shape, 'EndUpdate(4, 4)')
# # update_and_show(cache, cache_size, 6, 4, data_shape, 'SpilledUpdate(6, 4)', debug=True)
# update_and_show(cache, cache_size, 0, data_shape, 'Update(0)', debug=True)
# print(
#     # f'\nInQuery:\n{cache.get(cache_size, 0, 4)[0]}'
#     # f'\n\nEndQuery:\n{cache.get(cache_size, 4, 4)[0]}'
#     f'\n\nQuery::key:\n{cache.get(cache_size, 0, 2)[0]}'
#     f'\n\nQuery::value:\n{cache.get(cache_size, 0, 2)[1]}'
# )

import math

import tensorflow as tf

from tensorflow.keras import layers

class GroupedQueryAttention(tf.keras.Model):
    def __init__(self, cache_size, block_size, heads, kv_heads, dims):
        super(GroupedQueryAttention, self).__init__()
        
        self.heads = heads
        self.head_size = dims // heads
        self.kv_heads = kv_heads or heads
        
        self.query = tf.keras.layers.Dense(dims, use_bias=False)
        self.key = tf.keras.layers.Dense(kv_heads*self.head_size, use_bias=False)
        self.value = tf.keras.layers.Dense(kv_heads*self.head_size, use_bias=False)
        
        self.cache = KVCache(cache_size, block_size, self.kv_heads, self.head_size)
        self.rope = RotaryPositionalEncodings(block_size, self.head_size)
        
    def call(self, x, start=0, inference=False):
        shape = tf.shape(x)
        B, T = shape[0], shape[1]

        q = self.query(x)
        k = self.key(x)
        v = self.value(x)
        
        # (B, T, dims) -> (B, T, heads/kv_heads, head_dims)
        q = tf.reshape(q, [B, T, self.heads, self.head_size])
        k = tf.reshape(k, [B, T, self.kv_heads, self.head_size])
        v = tf.reshape(v, [B, T, self.kv_heads, self.head_size])
        
        # RoPE expects inputs in (B, heads/kv_heads, T, head_size) format.
        # Transpose(B, T, heads/kv_heads, head_size) -> (B, heads/kv_heads, T, head_size)
        q = self.rope(tf.transpose(q, perm=[0, 2, 1, 3]), start=start)
        k = self.rope(tf.transpose(k, perm=[0, 2, 1, 3]), start=start)
        
        if inference:
            # Update KV cache
            # KV cache expects inputs in (B, T, heads/kv_heads, head_size) format.
            self.cache.update(
                start,
                tf.transpose(k, perm=[0, 2, 1, 3]),
                v,
            )
            
            # Get prefix context from the cache.
            # (B, start+T, ...)
            k, v = self.cache.get(B, start, T)
            
            ## Replicate KV heads to match query heads.
            # (B, start+T, heads/kv_heads, head_size) -> (B, start+T, heads, head_size)
            k = tf.tile(k, multiples=(1, 1, self.heads//self.kv_heads, 1))
            
            # (B, T, heads/kv_heads, head_dims) -> (B, T, heads, head_dims)
            v = tf.tile(v, multiples=(1, 1, self.heads//self.kv_heads, 1))
            
            # (B, heads/kv_heads, T, head_size) @ (B, heads/kv_heads, head_size, start+T)
            # -> (B, heads/kv_heads, T, start+T)
            wei = q @ tf.transpose(k, perm=[0, 2, 3, 1])
            wei /= tf.cast(tf.math.sqrt(self.head_size * 1.), dtype=x.dtype)
            wei = tf.nn.softmax(wei)
        else:
            assert start == 0
            
            ## Replicate KV heads to match query heads.
            # (B, heads/kv_heads, T, head_size) -> (B, heads, T, head_size)
            k = tf.tile(k, multiples=(1, self.heads//self.kv_heads, 1, 1))
            
            # (B, T, heads/kv_heads, head_dims) -> (B, T, heads, head_dims)
            v = tf.tile(v, multiples=(1, 1, self.heads//self.kv_heads, 1))
            
            # (B, heads/kv_heads, T, head_size) @ (B, heads/kv_heads, head_size, start+T)
            # -> (B, heads/kv_heads, T, start+T)
            wei = q @ tf.transpose(k, perm=[0, 1, 3, 2])
            wei /= tf.cast(tf.math.sqrt(self.head_size * 1.), dtype=x.dtype)

            tril = tf.linalg.band_part(tf.ones((T, T)), -1, 0)
            wei = tf.nn.softmax(tf.where(tril > 0.0, wei, float('-inf')))

        # (B, heads/kv_heads, T, start+T) @ (B, heads/kv_heads, start+T, head_dims)
        # -> (B, heads/kv_heads, T, head_dims)
        out = wei @ tf.transpose(v, perm=[0, 2, 1, 3])
        
        # (B, heads/kv_heads, T, head_dims) -> (B, T, heads/kv_heads, head_dims)
        out = tf.transpose(out, perm=[0, 2, 1, 3])
        
        return tf.reshape(out, shape=(B, T, -1))

# l = GroupedQueryAttention(
#     block_size=16,
#     heads=4,
#     kv_heads=2,
#     dims=16,
#     cache_size=4
# )

# B, T, C = 2, 2, 16
# start = 1
# x = tf.reshape(tf.range(B*T*C), (B, T, C))
# output = l(x, start)

# B_i, T_i, C_i = 0, 0, 3
# H_i, F_i = 0, C_i

# print(
#     f'{x[B_i]=}'
#     f'\n{l.cache.cache_k.shape=}\n{l.cache.cache_k[B_i, :start+T]=}'
#     f'\n{l.cache.cache_v[B_i, :start+T]=}'
# )

import math

import tensorflow as tf

from tensorflow.keras import layers

class SlidingWindowAttention(tf.keras.Model):
    def __init__(self, config):
        super(SlidingWindowAttention, self).__init__()

        self.config = config
        self.head_size = config.dims // config.heads
        
        self.query = tf.keras.layers.Dense(config.dims, use_bias=False)
        self.key = tf.keras.layers.Dense(config.kv_heads*self.head_size, use_bias=False)
        self.value = tf.keras.layers.Dense(config.kv_heads*self.head_size, use_bias=False)
        
        self.rope = RotaryPositionalEncodings(config.block_size, self.head_size)
    
    def as_strided(self, x):
        shape = tf.shape(x)
        B, T = shape[0], shape[2]

        # (B, heads/kv_heads, T, head_size) -> (B, heads/kv_heads, T+2*window_size, head_size)
        padded_x = tf.pad(x, [[0, 0], [0, 0], [self.config.window_size]*2, [0, 0]])

        # indices.shape = (T, window_size + 1)
        indices = tf.tile(tf.reshape(tf.range(T), (-1, 1)), [1, self.config.window_size + 1])
        indices = indices + tf.expand_dims(tf.range(self.config.window_size + 1), axis=0)

        # (B, heads/kv_heads, T, window_size + 1)
        strided_x = tf.gather(padded_x, indices, axis=2)

#         print(
#             f'{x.shape=}'
#             f'\n{indices=}'
#             f'\n{padded_x.shape=}'
#             f'{strided_x=}'
#             f'\n{x=}'
#         )

        return strided_x
    
    def as_grid(self, x):
        # (B, heads, T, window_size+1) -> (B, heads, T, T)
        diagonals = tf.transpose(x, perm=[0, 1, 3, 2])[..., ::-1, :]
        grid_x = tf.linalg.diag(diagonals, k=(-self.config.window_size, 0), align='RIGHT_RIGHT')

        # print(
        #     f'\n{x.shape=} {diagonals.shape=}'
        #     f'\n{grid_x.shape=}'
        #     f'\n{x[0, 0]=}'
        #     # f'\n{x[0, 0, -2:]=}'
        #     f'\n{grid_x[0, 0]=}'
        #     # f'\n{grid_x[0, 0, -2:]=}'
        #     f'\n{tf.math.reduce_sum(x[..., 0, 0])=}'
        #     f'\n{tf.reduce_all(x[..., 1, :] == grid_x[..., 1, :2])=}'
        #     f'\n{tf.reduce_all(x[..., -1, :] == grid_x[..., -1, -2:])=}'
        # )

        return grid_x
        
    def call(self, x, start=0, inference=False):
        shape = tf.shape(x)
        B, T = shape[0], shape[1]

        q = self.query(x)
        k = self.key(x)
        v = self.value(x)
        
        # (B, T, dims) -> (B, T, heads/kv_heads, head_dims)
        q = tf.reshape(q, [B, T, self.config.heads, self.head_size])
        k = tf.reshape(k, [B, T, self.config.kv_heads, self.head_size])
        v = tf.reshape(v, [B, T, self.config.kv_heads, self.head_size])
        
        # RoPE expects inputs in (B, heads/kv_heads, T, head_size) format.
        # Transpose(B, T, heads/kv_heads, head_size) -> (B, heads/kv_heads, T, head_size)
        q = self.rope(tf.transpose(q, perm=[0, 2, 1, 3]), start=start)
        k = self.rope(tf.transpose(k, perm=[0, 2, 1, 3]), start=start)
            
        ## Replicate KV heads to match query heads.
        # (B, heads/kv_heads, T, head_size) -> (B, heads, T, head_size)
        k = tf.tile(k, multiples=(1, self.config.heads//self.config.kv_heads, 1, 1))
        
        # (B, T, heads/kv_heads, head_dims) -> (B, T, heads, head_dims)
        v = tf.tile(v, multiples=(1, 1, self.config.heads//self.config.kv_heads, 1))

        # Compose strided keys to mimic the sliding window. 
        # (B, heads, T, head_size) -> (B, heads, T, window_size+1, head_size)
        strided_k = self.as_strided(k)

        # (B, heads, T, head_size) @ (B, heads, T, window_size+1, head_size)
        # -> (B, heads, T, window_size+1)
        wei = tf.einsum('bhtd,bhtxd->bhtx', q, strided_k)

        # (B, heads, T, window_size+1) -> (B, heads, T, T)
        wei = self.as_grid(wei)
        wei /= tf.cast(tf.math.sqrt(self.head_size * 1.), dtype=x.dtype)

        tril = tf.linalg.band_part(tf.ones((T, T)), -1, 0)
        wei = tf.nn.softmax(tf.where(tril > 0.0, wei, float('-inf')))

        # (B, heads, T, T) @ (B, heads, T, head_dims)
        # -> (B, heads, T, head_dims)
        out = wei @ tf.transpose(v, perm=[0, 2, 1, 3])
        
        # (B, heads, T, head_dims) -> (B, T, heads, head_dims)
        out = tf.transpose(out, perm=[0, 2, 1, 3])
        
        return tf.reshape(out, shape=(B, T, -1))

# l = SlidingWindowAttention(
#     block_size=16,
#     heads=4,
#     kv_heads=2,
#     dims=32,
#     window_size=2
# )

# B, T, C = 2, 2, 32
# x = tf.cast(tf.reshape(tf.range(B*T*C), (B, T, C)), dtype=tf.float32)
# output = l(x)

# B_i, T_i, C_i = 0, 0, 3
# H_i, F_i = 0, C_i

# print(
#     f'{x[B_i]=}'
#     f'{output.shape=}'
# )

import tensorflow as tf

class SelfAttentionLayer(tf.keras.layers.Layer):
    def __init__(self, cache_size, block_size, head_size):
        
        super().__init__()
        # Input args
        self.head_size = head_size
        
        self.key = tf.keras.layers.Dense(head_size, use_bias=False)
        self.query = tf.keras.layers.Dense(head_size, use_bias=False)
        self.value = tf.keras.layers.Dense(head_size, use_bias=False)
        
        self.cache = KVCache(cache_size, block_size, head_size)
        self.rope = RotaryPositionalEncodings(block_size, head_size)

    def call(self, x, start=0, inference=False):
        B, T, C = x.shape

        k = self.key(x)
        q = self.query(x)
        v = self.value(x)
        
        # Apply RoPE
        q = self.rope(q, start=start)
        k = self.rope(k, start=start)
        
        if inference:
            # Update KV cache
            self.cache.update(start, k, v)
            
            # Get prefix context from the cache.
            # (B, start+T, head_dims)
            k, v = self.cache.get(B, start, T)
            
            # (B, T, head_size) @ (B, head_size, start+T) --> (B, T, start+T)
            wei = q @ tf.transpose(k, perm=[0, 2, 1])
            wei /= tf.cast(tf.math.sqrt(self.head_size * 1.), dtype=x.dtype)
            wei = tf.nn.softmax(wei)
        else:
            assert start == 0
            
            # (B, T, head_size) @ (B, head_size, T) --> (B, T, T)
            wei = q @ tf.transpose(k, perm=[0, 2, 1])
            wei /= tf.cast(tf.math.sqrt(self.head_size * 1.), dtype=x.dtype)

            tril = tf.linalg.band_part(tf.ones((T, T)), -1, 0)
            wei = tf.nn.softmax(tf.where(tril > 0.0, wei, float('-inf')))
            
        out = wei @ v
        return out

class MultiHeadAttentionLayer(tf.keras.layers.Layer):
    def __init__(self, cache_size, block_size, num_heads, head_size):
        super().__init__()
        self.attn_layers = [SelfAttentionLayer(cache_size, block_size, head_size) for i in range(num_heads)]

    def call(self, x, start=0, inference=False):
        return tf.concat([
            attn_layer(x, start=start, inference=inference) for attn_layer in self.attn_layers
        ], axis=-1)

# cache_size, block_size, num_heads, head_size = 2, 16, 1, 16
# B, T, C = 2, 2, num_heads*head_size

# l = MultiHeadAttentionLayer(cache_size, block_size, num_heads, head_size)

# for start in range(T):
#     print(
#         f'\n{start=}::{l(tf.random.uniform((B, 1, C)), start=start, inference=True).shape}\n'
#     )

import tensorflow as tf

from tensorflow.keras import layers

class RMSNorm(layers.Layer):
    def __init__(self, eps=1e-8):
        super(RMSNorm, self).__init__()
        self.eps = eps
    
    def build(self, shape):
        self.w = self.add_weight(
            "kernel",
            shape=(1, shape[-1]),
            initializer='ones',
            trainable=True,
        )
    
    def norm(self, x):
        return x*tf.math.rsqrt(tf.math.reduce_mean(x**2, axis=-1, keepdims=True) + self.eps)
        
    def call(self, x):
        return self.w * self.norm(x)

# data = tf.constant(np.arange(20).reshape(5, 2, 2) * 10, dtype=tf.float32)

# l=RMSNorm()
# l.build(data.shape)
# l(data)

import tensorflow as tf

from tensorflow.keras import layers

class FeedForward(layers.Layer):
    def __init__(self, dims):
        super(FeedForward, self).__init__()

        self.hidden_dims = 4*dims
        self.linear_1 = layers.Dense(self.hidden_dims, use_bias=False, activation='swish')
        self.linear_2 = layers.Dense(self.hidden_dims, use_bias=False)
        self.linear_3 = layers.Dense(dims, use_bias=False)
    
    def call(self, x):
        # (..., dims) -> (..., hidden_dims)
        x = self.linear_1(x)*self.linear_2(x)
        
        # (..., hidden_dims) -> (..., dims)
        x = self.linear_3(x)
        
        return x

# B, T, C = 2, 2, 16
# l = FeedForward(C)
# x = tf.reshape(tf.range(B*T*C, dtype='float32'), (B, T, C))
# l(x)

import tensorflow as tf

from tensorflow.keras import layers

class AttentionSelector(object):
    def __init__(self, config, *args, **kwargs):
        super(AttentionSelector, self).__init__(*args, **kwargs)
        
        self.config = config
    
    @property
    def choice(self):
        return (self.config.attention or 'msa').lower()
    
    def select(self):
        if self.choice in ['slidingwindow', 'slidingwindowattention', 'swa']:
            return SlidingWindowAttention(self.config)
        elif self.choice in ['groupedquery', 'groupedqueryattention', 'gqa']:
            return GroupedQueryAttention(self.config)
        else:
            return MultiHeadAttentionLayer(self.config)
    
    def call(self, l, x, start=0, inference=False):
        if self.choice in ['slidingwindow', 'slidingwindowattention', 'swa']:
            return l(x)
        elif self.choice in ['groupedquery', 'groupedqueryattention', 'gqa']:
            return l(x, start=start, inference=inference)
        else:
            return l(x, start=start, inference=inference)

class NormSelector(object):
    def __init__(self, config, *args, **kwargs):
        super(NormSelector, self).__init__(*args, **kwargs)
        self.config = config
    
    def select(self):
        choice = (self.config.norm or 'rms').lower()
        if choice in ['batchnorm', 'bn', 'batchnormalization']:
            return layers.BatchNormalization()
        else:
            return RMSNorm()

class LlamaBlock(tf.keras.Model):
    def __init__(self, config, *args, **kwargs):
        super(LlamaBlock, self).__init__(*args, **kwargs)
        
        self.norm_selector = NormSelector(config)
        self.attention_selector = AttentionSelector(config)
        
        self.norm_1 = self.norm_selector.select()
        self.attention = self.attention_selector.select()
        self.norm_2 = self.norm_selector.select()
    
    def call(self, x, start, inference):
        x += self.attention_selector.call(
            self.attention,
            self.norm_1(x), start=start, inference=inference
        )
        x = self.norm_2(x)
#         x += self.feed_forward(self.norm_2(x))
        
        return x

# l = LlamaBlock(
#     block_size=16,
#     heads=4,
#     kv_heads=2,
#     dims=16,
#     cache_size=4
# )

# B, T, C = 2, 2, 16
# start = 1
# x = tf.reshape(tf.range(B*T*C, dtype='float32'), (B, T, C))
# l(x)

import tensorflow as tf

from tensorflow.keras import layers, metrics
from functools import reduce

class LlamaModel(tf.keras.Model):
    def __init__(self, config, loss_fn, *args, **kwargs):
        super(LlamaModel, self).__init__(*args, **kwargs)
        # Args
        self.config = config

        # Model elements
        self.embeddings = layers.Embedding(config.vocab_size, config.embed_dims)
        self.pos_embeddings = layers.Embedding(config.block_size, config.embed_dims) if config.pos_embeddings else None
        self.dec_blocks = [LlamaBlock(config) for _ in range(config.decoders)]

        self.head = layers.Dense(config.vocab_size, use_bias=False)

        # Loss 
        self.loss_fn = loss_fn

        # Metrics
        self.loss_tracker = metrics.Mean(name="loss")
    
    def alternating_reduce(self, fn, blocks, x):
        for b_index, block in enumerate(blocks):
            # Pick the sub-block for transformer branch
            pivot = b_index % self.config.alternate_blocks

            # Split input into subblocks
            xs = tf.split(x, self.config.alternate_blocks, axis=-1)

            # Process picked block
            x = fn(xs[pivot], block)

            # Re-compose
            x = tf.concat([*xs[:pivot], x, *xs[(pivot + 1):]], axis=-1)

        return x
    
    def call(self, x, start=0, inference=False):
        B, T = x.shape
        
        x_pos_embed = self.pos_embeddings(tf.range(T)) if self.pos_embeddings else None
        x_embed = self.embeddings(x)
        
        x = x_embed + x_pos_embed if self.pos_embeddings else x_embed
        
        reduce_fn = self.alternating_reduce if self.config.alternate_blocks else reduce
        
        x = reduce_fn(
            lambda y,dec_block: dec_block(y, start=start, inference=inference),
            self.dec_blocks,
            x
        )
        x = self.head(x)
        
        return x
    
    def train_step(self, data):
        X, y = data
        y = tf.cast(y, dtype=tf.float32)

        with tf.GradientTape() as tape:
            logits = self(X)
            loss = self.loss_fn(y, logits)

        grads = tape.gradient(loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        
        return self.record(loss, y, logits)
    
    def test_step(self, data):
        X, y = data

        logits = self(X)
        loss = self.loss_fn(y, logits)
    
        return self.record(loss, y, logits)
    
    @property
    def metrics(self):
        return [self.loss_tracker]
    
    def record(self, loss, y, logits):
        self.loss_tracker.update_state(loss)

        return {m.name: m.result() for m in self.metrics}

model_config = ModelConfig(
    vocab_size=len(model_vocab),
    
    # Model Size and Prediction Capacity
    decoders=4,
    block_size=64,
    embed_dims=512,
    dims=128,
    
    # Attention Params
    attention='swa',
    heads=8,
    kv_heads=8,
    
    # KV Cache
    cache_size=2,
    
    # Sliding Window Attention
    window_size=3,
    
    # Features
    pos_embeddings=False,
    alternate_blocks=4,
)

llama = LlamaModel(config=model_config, loss_fn=loss_fn)
llama(tf.random.uniform((2, llama.config.block_size), minval=0, maxval=llama.config.vocab_size, dtype=tf.int32))
llama.compile(optimizer=optimizer, loss=loss_fn)
llama.summary(expand_nested=True)

Model: "llama_model_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 embedding_2 (Embedding)     multiple                  33280     
                                                                 
 llama_block_1 (LlamaBlock)  multiple                  49408     
|¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯|
| rms_norm_2 (RMSNorm)       multiple                  128      |
|                                                               |
| sliding_window_attention_  multiple                  49152    |
| 16 (SlidingWindowAttentio                                     |
| n)                                                            |
||¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯||
|| dense_99 (Dense)          multiple                  16384   ||
||                                                             ||
|| dense_100 (Dense)         multiple                

In [290]:
# llama = LlamaModel(config=model_config, loss_fn=loss_fn)

# config_data = model_config._asdict()
# config_data['kv_heads'] = config_data['kv_heads']//grouping_factor

# grouped_config = ModelConfig(**config_data)

# llama_grouped.compile(optimizer=optimizer, loss=loss_fn)
# llama_grouped(tf.random.uniform((2, 5), minval=0, maxval=llama_grouped.vocab_size, dtype=tf.int32))
# llama_grouped.summary(expand_nested=True)

In [291]:
def get_text_stats(text, tokenizer, fns=[]):
    def stats_fn(token):
        return np.array([token.lower_] + [fn(token) for fn in fns])
    
    tokens = tokenizer(text)
    stats = np.stack([stats_fn(token) for token in tokens])

    return stats

def random_token_from_logits(logits, samples=1):
    return tf.random.categorical(logits, samples, dtype=tf.int32)

def argmax_token(logits, samples=1):
    return tf.expand_dims(
        tf.math.argmax(tf.math.softmax(logits), axis=-1, output_type=tf.int32),
        axis=-1
    )

def generate(model, token_idx, tokens, randomize=True):
    sequence = token_idx
    for start in range(tokens - 1):
        logits = model(token_idx, start=start, inference=True)[:, -1, :]
#         token_idx =  tf.random.categorical(logits, 1, dtype=tf.int32)
        token_idx = random_token_from_logits(logits) if randomize else argmax_token(logits)
        sequence = tf.concat([sequence, token_idx], axis=-1)
    
    return sequence

def generate_no_cache(model, token_idx, tokens, randomize=True):
    sequence = token_idx
    for _ in range(tokens - 1):
        logits = model(sequence[:, -model.block_size:])[:, -1, :]
#         token_idx =  tf.random.categorical(logits, 1, dtype=tf.int32)
        token_idx = random_token_from_logits(logits) if randomize else argmax_token(logits)
        sequence = tf.concat([sequence, token_idx], axis=-1)
    
    return sequence

def generate_and_evaluate(model_vocab, generator, model, tokens=50, randomize=True):
    idx_to_char = {i:c for c,i in model_vocab.items()}
    decoder = lambda x: ''.join([idx_to_char[i] for i in x])

    starter = tf.constant([model_vocab[' ']], shape=(1, 1), dtype=tf.int32)
    generated_text = decoder(
        generator(
            model,
            starter,
            tokens=tokens,
            randomize=randomize,
        )[0].numpy()
    )

    print(f'{generator.__name__}:: {generated_text=}')

# word_vocab = get_word_1gram_frequencies(tiny_shakespere, word_tokenizer)
# print(
#     f'\nVocab: {list(word_vocab.keys())[:5]}'
# )
# generate_and_evaluate(model_vocab, generate, llama, randomize=False, tokens=min(50, llama.block_size))
# generate_and_evaluate(model_vocab, generate_no_cache, llama, randomize=False, tokens=min(50, llama.block_size))

In [None]:
import itertools

def get_layers(model, exceptions=[]):
    def fn(l):
        if hasattr(l, 'layers'):
            return list(itertools.chain(*map(fn, l.layers)))
        elif l.name in exceptions:
            return [None]
        else:
            return [l]
    
    return list(filter(lambda x: x is not None, fn(model)))


def apply_groupings(src_model, target_model, factor=2):
    def grouped_apply(l1, l2):
        grouped_w1s = []

        for w1, w2 in zip(l1.get_weights(), l2.get_weights()):
            w1_groups = tf.split(w1, factor, axis=-1)
            print(f'\nSplit {src_model.name}.{l1.name}({w1.shape}) --> {factor}x{w1_groups[0].shape}')

            w1_grouped = tf.math.add(*w1_groups) / factor
            print(f'Average {factor}x{w1_groups[0].shape} --> {w1_grouped.shape}')

            grouped_w1s.append(w1_grouped)
        
        # Apply grouped weights to the target
        l2.set_weights(grouped_w1s)
        print(f'Copied? {np.array_equal(grouped_w1s, l2.get_weights())}')

    src_attention = src_model.dec_blocks[0].attention
    target_attention = target_model.dec_blocks[0].attention

    src_other_layers = get_layers(
        src_model,
        exceptions=list(map(lambda x: x.name, [src_attention.key, src_attention.value]))
    )

    target_other_layers = get_layers(
        target_model,
        exceptions=list(map(lambda x: x.name, [target_attention.key, target_attention.value]))
    )

    for src_layer, target_layer in zip(src_other_layers, target_other_layers):
        target_layer.set_weights(src_layer.get_weights())
    
    # Group keys and values.
    print(f'\n Grouping Layers ({factor=})\n============================')
    grouped_apply(src_attention.key, target_attention.key)
    grouped_apply(src_attention.value, target_attention.value)

def compare_weights(one, two):
    one_layers = get_layers(one)
    two_layers = get_layers(two)

    print('\n Comparison Status\n===================')

    for ol, tl in filter(lambda x: x[0].get_weights(),  zip(one_layers, two_layers)):
        if np.array_equal(ol.get_weights(), tl.get_weights()):
            print(f'{ol.name} == {tl.name}')
        else:
            print(f'\t{ol.name} != {tl.name}')

# apply_groupings(llama, llama_grouped)
# compare_weights(llama, llama_grouped)

In [None]:
def get_grouped_generator(grouped_model, grouping_factor, base_fn=generate_no_cache):
    def fn(model, *args, **kwargs):
        apply_groupings(model, grouped_model)
        return base_fn(grouped_model, *args, **kwargs)
    
    return fn

generator_grouped = get_grouped_generator(llama_grouped, grouping_factor=grouping_factor)
generator_grouped_cached = get_grouped_generator(llama_grouped, grouping_factor=grouping_factor, base_fn=generate)

generate_and_evaluate(model_vocab, generator, llama, randomize=False, tokens=min(50, llama.block_size))


 Grouping Layers (factor=2)

Split llama_model_37.dense_384((128, 128)) --> 2x(128, 64)
Average 2x(128, 64) --> (128, 64)
Copied? True

Split llama_model_37.dense_385((128, 128)) --> 2x(128, 64)
Average 2x(128, 64) --> (128, 64)
Copied? True
fn:: generated_text=' sssawaksedreerederedoubexfoctleereere:rere:reree:'


In [None]:
# llama.layers[2].layers[1].attn_layers[0].logs['cache']['k']
cache_logs = llama.layers[2].layers[1].attn_layers[0].logs['cache']
no_cache_logs = llama.layers[2].layers[1].attn_layers[0].logs['no-cache']

log_items = len(cache_logs['k'])

# for key in ['in_x', 'k', 'q', 'v', 'rope_k', 'rope_q', 'wei', 'out']:
for key in ['wei']:
    for idx in [1]:
        print(
            f'{key}({idx})::{tf.math.reduce_all(cache_logs[key][idx] == no_cache_logs[key][idx+1][..., idx:, :])}'
            f' {cache_logs[key][idx]} {no_cache_logs[key][idx+1][..., idx:, :]}'
            f' {cache_logs[key][idx].shape} {no_cache_logs[key][idx+1][..., idx:, :].shape}'
        )

wei(1)::False [[[21.616919 43.966854]]] [[[21.616917 43.966858]]] (1, 1, 2) (1, 1, 2)


In [None]:
# tf.math.reduce_all(cache_logs['cache_k'][1] == no_cache_logs['rope_k'][2])
index = 1
print(
    f"{tf.math.reduce_all(cache_logs['wei'][index] == no_cache_logs['wei'][index+1])}"
    # f"\n{tf.math.reduce_all(cache_logs['cache_v'][1] == no_cache_logs['v'][2])}"
    f"\n{tf.math.reduce_all(cache_logs['cache_k'][index] == no_cache_logs['rope_k'][index+1])}",
    # f"\n{tf.math.reduce_all(cache_logs['cache_v'][index] == no_cache_logs['v'][index+1][:, 1:, :])}",
    f"\n{tf.math.reduce_all(cache_logs['rope_q'][index] == no_cache_logs['rope_q'][index+1][:, 1:, :])}",
    # f"\n{cache_logs['rope_q'][index]=}"
    # f"\n{no_cache_logs['rope_q'][index+1]=}"
    # f"\n{cache_logs['cache_k'][index]=}"
    # f"\n{no_cache_logs['rope_k'][index+1]=}"
    f"\n{cache_logs['wei'][index]=}"
    f"\n{no_cache_logs['wei'][index+1]=}"
    f"\n{cache_logs['wei'][index]=}"
    f"\n{no_cache_logs['wei'][index+1]=}"
)
# cache_logs['out'][1], no_cache_logs['out'][2][..., 1:, :]

False
True 
True 
cache_logs['wei'][index]=<tf.Tensor: shape=(1, 1, 2), dtype=float32, numpy=array([[[21.616919, 43.966854]]], dtype=float32)>
no_cache_logs['wei'][index+1]=<tf.Tensor: shape=(1, 2, 2), dtype=float32, numpy=
array([[[29.863405 ,  7.2788296],
        [21.616917 , 43.966858 ]]], dtype=float32)>
cache_logs['wei'][index]=<tf.Tensor: shape=(1, 1, 2), dtype=float32, numpy=array([[[21.616919, 43.966854]]], dtype=float32)>
no_cache_logs['wei'][index+1]=<tf.Tensor: shape=(1, 2, 2), dtype=float32, numpy=
array([[[29.863405 ,  7.2788296],
        [21.616917 , 43.966858 ]]], dtype=float32)>


In [None]:
c_q = cache_logs['rope_q'][index]
c_k = cache_logs['cache_k'][index]

c_q, c_k, c_q @ tf.transpose(c_k, perm=[0, 2, 1])

(<tf.Tensor: shape=(1, 1, 16), dtype=float32, numpy=
 array([[[ 2.697421  , -2.4562492 , -1.5515298 ,  0.14776081,
           0.37695628, -1.5206785 ,  0.8845265 ,  0.17617644,
          -2.1475816 ,  1.1095126 , -0.10803113, -4.174893  ,
           0.12987702,  0.36463377, -4.113693  , -0.7377639 ]]],
       dtype=float32)>,
 <tf.Tensor: shape=(1, 2, 16), dtype=float32, numpy=
 array([[[ 4.474342  , -1.4330535 , -4.977524  , -4.9114738 ,
          -0.23111431,  1.3314574 ,  0.03826439,  1.3454266 ,
           1.7946535 ,  0.05863512, -1.4367882 , -2.3322222 ,
           0.93547916, -2.7148142 ,  0.902055  ,  0.88373613],
         [ 1.8138299 , -2.927385  , -3.3529627 , -4.1700516 ,
          -0.463976  , -2.7048278 , -1.8990085 ,  1.7266546 ,
           1.0458257 , -0.78157973, -0.0542576 , -2.8697302 ,
           2.9209538 , -0.50924057, -3.9430103 ,  0.7489088 ]]],
       dtype=float32)>,
 <tf.Tensor: shape=(1, 1, 2), dtype=float32, numpy=array([[[21.616919, 43.966854]]], dtype=floa

In [None]:
nc_q = no_cache_logs['rope_q'][index + 1]
nc_k = no_cache_logs['rope_k'][index + 1]

nc_q, nc_k, nc_q @ tf.transpose(nc_k, perm=[0, 2, 1])

(<tf.Tensor: shape=(1, 2, 16), dtype=float32, numpy=
 array([[[-0.23348725, -1.2534161 , -3.0547209 , -0.6355625 ,
          -0.08479708,  1.8937602 ,  1.8187456 , -1.0931925 ,
          -0.7623807 ,  2.3252077 , -0.7686442 , -1.8549696 ,
           0.79256994, -1.0165763 ,  1.2038559 ,  0.9727511 ],
         [ 2.697421  , -2.4562492 , -1.5515298 ,  0.14776081,
           0.37695628, -1.5206785 ,  0.8845265 ,  0.17617644,
          -2.1475816 ,  1.1095126 , -0.10803113, -4.174893  ,
           0.12987702,  0.36463377, -4.113693  , -0.7377639 ]]],
       dtype=float32)>,
 <tf.Tensor: shape=(1, 2, 16), dtype=float32, numpy=
 array([[[ 4.474342  , -1.4330535 , -4.977524  , -4.9114738 ,
          -0.23111431,  1.3314574 ,  0.03826439,  1.3454266 ,
           1.7946535 ,  0.05863512, -1.4367882 , -2.3322222 ,
           0.93547916, -2.7148142 ,  0.902055  ,  0.88373613],
         [ 1.8138299 , -2.927385  , -3.3529627 , -4.1700516 ,
          -0.463976  , -2.7048278 , -1.8990085 ,  1.7266546

In [None]:
from tensorflow.keras import layers

K = 2

def alternating_reduce(fn, blocks, x):
    for b_index, block in enumerate(blocks):
        # Pick the sub-block for transformer branch
        pivot = b_index % K

        # Split input into subblocks
        xs = tf.split(x, K, axis=-1)

        # Process picked block
        x = block(xs[pivot])

        # Re-compose
        x = tf.concat([*xs[:pivot], x, *xs[(pivot + 1):]], axis=-1)
    
    return x

x = tf.reshape(tf.cast(tf.range(2*100), dtype=tf.float32), (2, 100))

print(f'{x=}')
alternating_reduce(
    lambda y,dec_block: dec_block(y),
    [
        # layers.Dense(100, use_bias=False),
        layers.Dense(50, use_bias=False, kernel_initializer='ones'),
        layers.Dense(50, use_bias=False, kernel_initializer='ones'),
        layers.Dense(50, use_bias=False, kernel_initializer='ones'),
        layers.Dense(50, use_bias=False, kernel_initializer='ones'),
    ],
    x
)

x=<tf.Tensor: shape=(2, 100), dtype=float32, numpy=
array([[  0.,   1.,   2.,   3.,   4.,   5.,   6.,   7.,   8.,   9.,  10.,
         11.,  12.,  13.,  14.,  15.,  16.,  17.,  18.,  19.,  20.,  21.,
         22.,  23.,  24.,  25.,  26.,  27.,  28.,  29.,  30.,  31.,  32.,
         33.,  34.,  35.,  36.,  37.,  38.,  39.,  40.,  41.,  42.,  43.,
         44.,  45.,  46.,  47.,  48.,  49.,  50.,  51.,  52.,  53.,  54.,
         55.,  56.,  57.,  58.,  59.,  60.,  61.,  62.,  63.,  64.,  65.,
         66.,  67.,  68.,  69.,  70.,  71.,  72.,  73.,  74.,  75.,  76.,
         77.,  78.,  79.,  80.,  81.,  82.,  83.,  84.,  85.,  86.,  87.,
         88.,  89.,  90.,  91.,  92.,  93.,  94.,  95.,  96.,  97.,  98.,
         99.],
       [100., 101., 102., 103., 104., 105., 106., 107., 108., 109., 110.,
        111., 112., 113., 114., 115., 116., 117., 118., 119., 120., 121.,
        122., 123., 124., 125., 126., 127., 128., 129., 130., 131., 132.,
        133., 134., 135., 136., 137., 138., 1

<tf.Tensor: shape=(2, 100), dtype=float32, numpy=
array([[ 61250.,  61250.,  61250.,  61250.,  61250.,  61250.,  61250.,
         61250.,  61250.,  61250.,  61250.,  61250.,  61250.,  61250.,
         61250.,  61250.,  61250.,  61250.,  61250.,  61250.,  61250.,
         61250.,  61250.,  61250.,  61250.,  61250.,  61250.,  61250.,
         61250.,  61250.,  61250.,  61250.,  61250.,  61250.,  61250.,
         61250.,  61250.,  61250.,  61250.,  61250.,  61250.,  61250.,
         61250.,  61250.,  61250.,  61250.,  61250.,  61250.,  61250.,
         61250., 186250., 186250., 186250., 186250., 186250., 186250.,
        186250., 186250., 186250., 186250., 186250., 186250., 186250.,
        186250., 186250., 186250., 186250., 186250., 186250., 186250.,
        186250., 186250., 186250., 186250., 186250., 186250., 186250.,
        186250., 186250., 186250., 186250., 186250., 186250., 186250.,
        186250., 186250., 186250., 186250., 186250., 186250., 186250.,
        186250., 186250., 1

In [338]:
import tensorflow as tf

class AltUpBlockSequence(tf.keras.Model):
    def __init__(self, config, *args, **kwargs):
        super(AltUpBlockSequence, self).__init__(*args, **kwargs)
        self.config = config

        self.blocks = [LlamaBlock(config) for _ in range(config.decoders)]
    
    def build(self, shape):
        self.prediction_scalars = self.add_weight(
            "prediction_scalars",
            shape=(self.config.alternate_blocks,)*2,
            initializer='ones',
            trainable=True,
        )
        self.correction_scalars = self.add_weight(
            "correction_scalars",
            shape=(self.config.alternate_blocks,),
            initializer='ones',
            trainable=True,
        )
    
    def call(self, x, start=0, inference=False):
        for b_index, block in enumerate(self.blocks):
            # Pick the sub-block for transformer branch
            pivot = b_index % self.config.alternate_blocks

            # Split input into subblocks
            xs = tf.split(x, self.config.alternate_blocks, axis=-1)

            # Prediction with a linear map
            xs_hat = tf.unstack(
                tf.einsum('aa,axyz->axyz', self.prediction_scalars, tf.stack(xs, axis=0)),
                axis=0,
            )

            # Process picked block
            x = block(xs[pivot], start=start, inference=inference)

            # Correction
            xs = list(map(
                lambda idx: xs_hat[idx] + self.correction_scalars[idx]*(x - xs_hat[pivot]),
                range(self.config.alternate_blocks),
            ))

            # Re-compose
            # x = tf.concat([*xs[:pivot], x, *xs[(pivot + 1):]], axis=-1)
            x = tf.concat(xs, axis=-1)

        return x


l = AltUpBlockSequence(model_config)
x = tf.random.uniform((2, model_config.block_size, model_config.embed_dims))
y = l(x)

In [341]:
x[0], y[0]

(<tf.Tensor: shape=(64, 512), dtype=float32, numpy=
 array([[0.31179297, 0.8263413 , 0.6849456 , ..., 0.79634094, 0.16588712,
         0.7652644 ],
        [0.9043931 , 0.09195149, 0.25456262, ..., 0.10521233, 0.29894042,
         0.96700084],
        [0.18097997, 0.90534747, 0.6530881 , ..., 0.810333  , 0.8480791 ,
         0.68864477],
        ...,
        [0.9133718 , 0.3668282 , 0.8784882 , ..., 0.57490706, 0.8577874 ,
         0.02953506],
        [0.06336856, 0.5252453 , 0.8964504 , ..., 0.12895262, 0.16139305,
         0.5827892 ],
        [0.5203084 , 0.2430867 , 0.10670376, ..., 0.82163894, 0.05010164,
         0.4061122 ]], dtype=float32)>,
 <tf.Tensor: shape=(64, 512), dtype=float32, numpy=
 array([[-1.1757655 ,  0.22772452,  0.39361706, ...,  2.4383278 ,
          0.5172272 , -0.91794693],
        [-0.616339  , -0.24007252, -0.29521257, ...,  2.067838  ,
          0.48718598, -0.9977868 ],
        [-1.4486234 ,  0.33871815,  0.46358114, ...,  2.3470387 ,
          0.4813824

In [304]:
import tensorflow as tf

class BlockSequence(tf.keras.Model):
    def __init__(self, config, *args, **kwargs):
        super(BlockSequence, self).__init__(*args, **kwargs)
        self.config = config

        self.blocks = [LlamaBlock(config) for _ in range(config.decoders)]
    
    def call(self, x, start=0, inference=False):
        return reduce(
            lambda y,block: block(y, start=start, inference=inference),
            self.blocks,
            x
        )

# config_data = model_config._asdict()
# config_data['']

# altup_config = ModelConfig(**config_data)

l = BlockSequence(model_config)
l(tf.random.uniform((2, model_config.block_size, model_config.dims)))


<tf.Tensor: shape=(2, 64, 128), dtype=float32, numpy=
array([[[-0.36719808,  0.97809   , -0.03038533, ..., -0.30581978,
         -0.1215362 ,  3.1138353 ],
        [-0.8910546 ,  0.52530986,  0.37287   , ..., -0.5833475 ,
          0.08942268,  3.0151086 ],
        [-0.7777609 ,  0.37908828,  0.09426907, ..., -0.5609133 ,
          0.11016361,  2.9316347 ],
        ...,
        [-0.71178305,  1.202099  ,  0.05243267, ..., -1.1422229 ,
          0.4896912 ,  2.347373  ],
        [-0.8180739 ,  0.9634706 ,  0.15713139, ..., -0.98911667,
          0.5352649 ,  2.310747  ],
        [-0.7443944 ,  0.99957556,  0.16659923, ..., -1.024197  ,
          0.6366435 ,  2.0893376 ]],

       [[-0.2841137 ,  1.5303981 ,  0.10613061, ..., -0.74273384,
         -0.68495226,  2.1344197 ],
        [-0.3547677 ,  1.3389255 ,  0.5335099 , ..., -0.65516317,
         -0.28303394,  2.4078557 ],
        [-0.3474855 ,  1.1259456 ,  0.2917909 , ..., -0.86325204,
         -0.06728553,  2.332008  ],
        ...,
