In [1]:
class Props(dict):
    def __init__(self, *args, **kwargs):
        super(Props, self).__init__(*args, **kwargs)
        self.__dict__ = self
    
    def __getattribute__(self, name):
        try:
            return super(Props, self).__getattribute__(name)
        except AttributeError:
            return None

In [2]:
# ! curl -o input.txt https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt 

from pathlib import Path

datapath = Path('./input.txt')

with open(datapath) as f:
    tiny_shakespere = list(f.readlines())

In [3]:
import spacy
import itertools

import numpy as np

from collections import Counter

def get_word_1gram_frequencies(data, tokenizer):
    words = map(lambda x:x.lower_, itertools.chain(*map(tokenizer, data)))
    return Counter(words)

word_tokenizer = spacy.blank("en")
word_vocab = get_word_1gram_frequencies(tiny_shakespere, word_tokenizer)

print(
    f'{list(word_vocab.keys())[:5]}'
)

['first', 'citizen', ':', '\n', 'before']


In [4]:
# tokens = nlp('and to')
# ! pip install levenshtein
from functools import reduce

from Levenshtein import distance as edit_distance

def get_proximity_with_vocab(text, tokenizer, vocab):
    def proximity_fn(word):
        if word in vocab:
            distance = 0
        else:
            distance = min(*map(
                lambda vocab_word: edit_distance(word, vocab_word),
                vocab.keys()
            ))
        
        return distance


    tokens = tokenizer(text)
    distance = reduce(lambda y,token: y+proximity_fn(str(token)), tokens, 0)

    return distance

def get_all_proximity_with_vocab(texts, tokenizer, vocab):
    return np.fromiter(
        map(
            lambda text: get_proximity_with_vocab(text, tokenizer, vocab),
            texts,
        ),
        dtype=object
    )

nlp = spacy.blank("en")

texts = [
    'hello, how are you doing?',
    'hello, how are you doing?',
]

get_all_proximity_with_vocab(texts, nlp, word_vocab)

array([1, 1], dtype=object)

In [5]:
! head input.txt

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:


In [6]:
all_data = ''.join(tiny_shakespere)

all_data_length = len(all_data)
train_data_length = int(all_data_length*.9)
valid_data_length = all_data_length - train_data_length

train_data = all_data[:train_data_length]
valid_data = all_data[-valid_data_length:]

In [7]:
def get_vectorizer(config):
    c_idx_mapping = dict(map(
        lambda item: (item[1], item[0]),
        enumerate(config.vocab)
    ))
    encoder = lambda text: np.fromiter(map(c_idx_mapping.get, text), dtype=int)
    decoder = lambda ids: ''.join([config.vocab[idx] for idx in ids])

    return encoder, decoder

config = Props(
    valid_split=0.1,
    vocab="\n !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz",
    model=Props(block_size=256),
)

encoder, decoder = get_vectorizer(config)

train_data_vectorized = encoder(train_data)
valid_data_vectorized = encoder(valid_data)
log_size = config.model.block_size

print(
    f'\n{len(train_data)=} {train_data_vectorized.shape=}'
    f'\n{len(valid_data)=} {valid_data_vectorized.shape=}'
)
print(
    f'\n{train_data[:log_size]=}\n{train_data_vectorized[:log_size]=}'
    f'\n\n{valid_data[:log_size]=}\n{valid_data_vectorized[:log_size]=}'
)


len(train_data)=1003854 train_data_vectorized.shape=(1003854,)
len(valid_data)=111540 valid_data_vectorized.shape=(111540,)

train_data[:log_size]='First Citizen:\nBefore we proceed any further, hear me speak.\n\nAll:\nSpeak, speak.\n\nFirst Citizen:\nYou are all resolved rather to die than to famish?\n\nAll:\nResolved. resolved.\n\nFirst Citizen:\nFirst, you know Caius Marcius is chief enemy to the people.\n\nAll:\n'
train_data_vectorized[:log_size]=array([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43,
       44, 53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39,
       52, 63,  1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1,
       51, 43,  1, 57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10,  0, 31,
       54, 43, 39, 49,  6,  1, 57, 54, 43, 39, 49,  8,  0,  0, 18, 47, 56,
       57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 37, 53, 59,  1, 39,
       56, 43,  1, 39, 50, 50,  1, 56, 43, 57, 53, 50, 60, 43, 42,  1, 56,
       39, 58, 46, 

In [8]:
import tensorflow as tf

def make_dataset(data, config):
    t_data = tf.constant(data)
    offsets = tf.range(config.model.block_size + 1, dtype=tf.int64)

    @tf.function
    def to_item(idx):
        item = tf.gather(t_data, offsets + idx)
        return item[:-1], item[1:]

    return tf.data.Dataset.range(data.shape[0] - config.model.block_size - 1).map(
        to_item
    )

ds = make_dataset(train_data_vectorized, config).shuffle(
    train_data_vectorized.shape[0] - config.model.block_size - 1,
).batch(2)
X, y = next(iter(ds))

print(
    f'\n\nTrain Set\n=============='
    f'\n{X.shape=} {y.shape=}'
    f'\n{X.numpy()=}\n{y.numpy()=}'
)

2024-05-02 08:27:05.041971: I metal_plugin/src/device/metal_device.cc:1154] Metal device set to: Apple M1
2024-05-02 08:27:05.041997: I metal_plugin/src/device/metal_device.cc:296] systemMemory: 8.00 GB
2024-05-02 08:27:05.042004: I metal_plugin/src/device/metal_device.cc:313] maxCacheSize: 2.67 GB
2024-05-02 08:27:05.042068: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:306] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2024-05-02 08:27:05.042115: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:272] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)
2024-05-02 08:27:15.248655: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:422] ShuffleDatasetV3:2: Filling up shuffle buffer (this may take a while): 618993 of 1003597
2024-05-02 0



Train Set
X.shape=TensorShape([2, 256]) y.shape=TensorShape([2, 256])
X.numpy()=array([[50, 50,  1, 57, 47, 52, 45,  8,  0,  0, 29, 33, 17, 17, 26, 10,
         0,  5, 32, 47, 57,  1, 61, 43, 50, 50,  1, 58, 46, 39, 58,  1,
        58, 46, 53, 59,  1, 46, 39, 57, 58,  1, 41, 39, 59, 57, 43,  0,
        14, 59, 58,  1, 58, 46, 53, 59,  1, 57, 46, 53, 59, 50, 42, 57,
        58,  1, 54, 50, 43, 39, 57, 43,  1, 51, 43,  1, 40, 43, 58, 58,
        43, 56,  6,  1, 61, 53, 59, 50, 42, 57, 58,  1, 58, 46, 53, 59,
         1, 61, 43, 43, 54,  8,  0,  0, 24, 39, 42, 63, 10,  0, 21,  1,
        41, 53, 59, 50, 42,  1, 61, 43, 43, 54,  6,  1, 51, 39, 42, 39,
        51,  6,  1, 61, 53, 59, 50, 42,  1, 47, 58,  1, 42, 53,  1, 63,
        53, 59,  1, 45, 53, 53, 42,  8,  0,  0, 29, 33, 17, 17, 26, 10,
         0, 13, 52, 42,  1, 21,  1, 41, 53, 59, 50, 42,  1, 57, 47, 52,
        45,  6,  1, 61, 53, 59, 50, 42,  1, 61, 43, 43, 54, 47, 52, 45,
         1, 42, 53,  1, 51, 43,  1, 45, 53, 53, 42,  6

In [9]:
# config = Props(
#     valid_split=0.1,
#     vocab="\n !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz",
#     model=Props(
#         layers=2,
#         heads=2,
#         dims=8,
#         dropout=0.2,
#         block_size=4,
#         bias=False,
#         vocab_size=65,
#         pos_encoding='rope',
#         use_cache=True,
#     ),
# )

# Configuration for NanoGPT is sourced from
# https://github.com/karpathy/nanoGPT/blob/master/config/train_shakespeare_char.py

def is_interactive():
    return False

config = Props(
    # Dataset
    dataset_size=300 if is_interactive()  else None,
    valid_split=0.1,

    # Model
    model=Props(
        layers=6,
        heads=6,
        dims=384,
        dropout=0.2,
        block_size=32 if is_interactive() else 256,
        bias=False,
        vocab_size=65,
        pos_encoding='rope',
        use_cache=True,
    ),
    
    # Vocab
    vocab="\n !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz",

    # Training
    train=Props(
        epochs=1,
        log_steps=1 if is_interactive() else 200,
#         max_steps=2 if is_interactive() else 5000,
        max_steps=2 if is_interactive() else 600,
        batch_size=2 if is_interactive() else 64,
    ),
    
    # Optimizer
    lr=1e-3,
    beta1=0.9,
    beta2=0.99,
)

print(
    f'{config=}'
)

config={'dataset_size': None, 'valid_split': 0.1, 'model': {'layers': 6, 'heads': 6, 'dims': 384, 'dropout': 0.2, 'block_size': 256, 'bias': False, 'vocab_size': 65, 'pos_encoding': 'rope', 'use_cache': True}, 'vocab': "\n !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz", 'train': {'epochs': 1, 'log_steps': 200, 'max_steps': 600, 'batch_size': 64}, 'lr': 0.001, 'beta1': 0.9, 'beta2': 0.99}


In [15]:
import math

import tensorflow as tf

from tensorflow.keras import layers, initializers

def dense_init(config, size, activation=None, kernel_initializer=None):
    kernel_initializer = kernel_initializer or initializers.RandomNormal(
        mean=0.,
        stddev=0.02,
    )
    return layers.Dense(
        size,
        use_bias=config.bias,
        activation=activation,
        kernel_initializer=kernel_initializer,
        bias_initializer=initializers.Zeros(),
    )

def embed_init(config, size):
    return layers.Embedding(
        size,
        config.dims,
        embeddings_initializer=initializers.RandomNormal(
            mean=0.,
            stddev=0.02,
        ),
    )

import tensorflow as tf

from tensorflow.keras import layers

class RotaryPositionalEncodings(layers.Layer):
    def __init__(self, block_size, dims, base=10000.):
        super(RotaryPositionalEncodings, self).__init__()
        
        self.dims = dims
        self.base = base
        self.block_size = block_size

        self.theta = 1. / (self.base ** (tf.range(0, self.dims, 2, dtype=tf.float32) / self.dims))
        self.positions = tf.range(self.block_size, dtype=tf.float32)
        self.mtheta = self.positions[..., None]*self.theta[None, ...]
        self.mtheta_paired = tf.concat([self.mtheta]*2, axis=-1)

        self.cos_mtheta = tf.math.cos(self.mtheta_paired)
        self.sin_mtheta = tf.math.sin(self.mtheta_paired)

    def call(self, x, index=0):
        shape = tf.shape(x)
        T = shape[-2]

        x_real = x*self.cos_mtheta[index:index+T, ...]

        x_img = tf.concat(
            [-x[..., self.dims//2:], x[..., :self.dims//2]],
            axis=-1
        )*self.sin_mtheta[index:index+T, ...]

        output = x_real + x_img
        return output

# B, H, T, D = 2, 4, 8, 4
# B_i = 0
# x = tf.random.uniform((B, H, T, D))
# l = RotaryPositionalEncodings(64, D)
# output = l(x)
# # tf.print(f'{l.theta=} {l.theta.shape}')
# # tf.print(f'{l.positions=} {l.positions.shape}')
# # tf.print(f'{x[B_i]=}')
# # tf.print(f'{l.cos_mtheta=}')
# # tf.print(f'{l.sin_mtheta=}')
# tf.print(f'{output[B_i]=}\n{output[B_i]=} {output.shape}')

import tensorflow as tf

from tensorflow.keras import layers

class SelfAttention(layers.Layer):
    def __init__(self, config):
        super().__init__()
        self.dims = config.dims
        self.head_size = config.dims // config.heads
        self.heads = config.heads
        
        # Key, query and value projection layers
        self.key = dense_init(config, config.dims)
        self.query = dense_init(config, config.dims)
        self.value = dense_init(config, config.dims)
        
        # Dropout layers
        self.attn_dropout = layers.Dropout(config.dropout)
        self.residual_dropout = layers.Dropout(config.dropout)
        
        # RoPE
        if config.pos_encoding == 'rope':
            self.rope = RotaryPositionalEncodings(config.block_size, self.head_size)
    
    def update_cache(self, k, v, cache, token_axis=1):
        def glue_fn(x, cache_x):
            # Glue the cached sequence and the input on token dimension.
            return tf.concat([cache_x, x], axis=token_axis)

        # 1. Compose full key and value sequences
        k, v = glue_fn(k, cache['key']), glue_fn(v, cache['value'])

        # 2. Update cache
        cache['key'] = k
        cache['value'] = v

        return k, v, cache

    def call(self, x, training=None, cache=None):
        shape = tf.shape(x)
        B = shape[0]
        T = shape[1]
        C = shape[2]
        
        # 1. Compute keys, queries and values for all heads
        # (B, T, dims) -> (B, T, heads, head_size)
        k = tf.reshape(self.key(x), [B, T, self.heads, self.head_size])
        q = tf.reshape(self.query(x), [B, T, self.heads, self.head_size])
        v = tf.reshape(self.value(x), [B, T, self.heads, self.head_size])

        # 1.1. Interact with KVCache
        if cache:
            # 1.2. Compose full key and value sequences, and update cache.
            k, v, cache = self.update_cache(k, v, cache, token_axis=1)

            # 1.3. Initialize RoPE encoding index for query
            seqlen = tf.shape(k)[1]
            rope_q_index = seqlen - 1
        else:
            rope_q_index = 0
        
        # 2. Transpose keys, queries and values to facilitate matrix multiplication
        # (B, T, heads, head_size) -> (B, heads, T, head_size)
        k = tf.transpose(k, perm=[0, 2, 1, 3])
        q = tf.transpose(q, perm=[0, 2, 1, 3])
        v = tf.transpose(v, perm=[0, 2, 1, 3])
        
        # 2.1. Apply RoPE if it is enabled
        if hasattr(self, 'rope'):
            q = self.rope(q, index=rope_q_index)
            k = self.rope(k)
        
        # 3. Compute QK^T
        # (B, heads, T, head_size) @ (B, heads, head_size, T)
        # -> (B, heads, T, T)
        attention = q @ tf.transpose(k, perm=[0, 1, 3, 2])
        attention /= tf.cast(tf.math.sqrt(self.head_size * 1.), dtype=x.dtype)

        # 4. Conpute masked attention scores
        tril = 1. if cache else tf.linalg.band_part(tf.ones((T, T)), -1, 0)
        scores = tf.nn.softmax(tf.where(tril > 0.0, attention, float('-inf')))
        
        # 5. Apply attention dropout
        scores = self.attn_dropout(scores, training=training)
        
        # 6. Attend values
        # (B, heads, T, T) @ (B, heads, T, head_size)
        # -> (B, heads, T, head_size)
        x = scores @ v
        
        # 7. Format output
        # (B, heads, T, head_size) -> (B, T, dims)
        x = tf.reshape(tf.transpose(x, perm=[0, 2, 1, 3]), [B, T, self.dims])
        
        # 8. Apply residual dropout
        x = self.residual_dropout(x, training=training)
        
        return x

# B, T, C = 2, 2, config.model.dims
# l = SelfAttention(config.model)
# l(tf.random.uniform((B, T, C)))

# head_size = config.dims // config.heads
# cache = dict(key=tf.zeros([B, 0, config.heads, head_size]), value=tf.zeros([B, 0, config.heads, head_size]))
# l(tf.random.uniform((B, 1, C)), cache=cache)

import tensorflow as tf

from tensorflow.keras import layers, initializers

class FeedForward(layers.Layer):
    def __init__(self, config):
        super().__init__()
        self.config = config

        self.fc = dense_init(config, 4*config.dims, activation='gelu')
        self.projection = dense_init(
            config, config.dims,
            kernel_initializer=initializers.RandomNormal(
                mean=0.,
                stddev=0.02/math.sqrt(2 * config.layers),
            )
        )
        self.dropout = layers.Dropout(config.dropout)
    
    def call(self, x, training=None):
        # (..., dims) -> (..., fc.dims)
        x = self.fc(x)
        
        # (..., fc.dims) -> (..., dims)
        x = self.projection(x)
        
        x = self.dropout(x, training=training)
        
        return x

# B, T, C = 2, 2, config.model.dims
# l = FeedForward(config.model)
# x = tf.reshape(tf.range(B*T*C, dtype='float32'), (B, T, C))
# l(x)

import tensorflow as tf

from tensorflow.keras import layers

class Block(tf.keras.Model):
    def __init__(self, config):
        super().__init__()
        
        self.norm_1 = layers.LayerNormalization()
        self.attention = SelfAttention(config)
        self.norm_2 = layers.LayerNormalization()
        self.feed_forward = FeedForward(config)
    
    def call(self, x, training=None, cache=None):
        x = x + self.attention(self.norm_1(x), training=training, cache=cache)
        x = x + self.feed_forward(self.norm_2(x), training=training)
        
        return x

# l = Block(config.model)
# B, T, C = 2, 2, config.model.dims
# start = 1
# x = tf.reshape(tf.range(B*T*C, dtype='float32'), (B, T, C))
# l(x)

import tensorflow as tf

from tensorflow.keras import layers, metrics
from functools import reduce

class NanoGPT(tf.keras.Model):
    def __init__(self, config):
        super().__init__()
        
        # Input Args
        self.use_cache = config.use_cache
        self.heads = config.heads
        self.head_size = config.dims // config.heads
        self.block_size = config.block_size
        self.num_layers = config.layers
        
        # Model elements
        self.token_embed = embed_init(config, config.vocab_size)
        self.dropout = layers.Dropout(config.dropout)
        self.blocks = list(map(
            lambda _: Block(config),
            range(config.layers),
        ))
        self.norm = layers.LayerNormalization()
        self.head = dense_init(config, config.vocab_size)
        
        # Conditional model elements
        if config.pos_encoding == 'embed':
            self.pos_embed = embed_init(config, config.block_size)
        
        # # Metrics
        # self.trackers = {
        #     'loss': metrics.Mean(name="loss"),
        #     'val_loss': metrics.Mean(name="val_loss"),
        # }
    
    def call(self, x, training=None, cache=None):
        B, T = x.shape
        
        # 1. Get embeddings for input tokens (B, T) -> (B, T, dims)
        x_token_embed = self.token_embed(x)
        
        # 2. Get position embeddings for input tokens (B, T) -> (B, T, dims)
        if hasattr(self, 'pos_embed'):
            x_pos_embed = self.pos_embed(tf.range(T))
        else:
            x_pos_embed = 0.
        
        # 3. Combine token and position embeddings
        x = self.dropout(x_token_embed + x_pos_embed, training=training)
        
        # 4. Apply blocks
        x = reduce(
            lambda y,item: item[0](y, training=training, cache=item[1]),
            zip(self.blocks, cache if cache else [None]*self.num_layers),
            x,
        )
        
        # 5. Apply layer norm
        x = self.norm(x)
        
        # 6. Prediction head (B, T, dims) -> (B, T, vocab_size)
        x = self.head(x)
        
        return x
    
    @tf.function
    def train_step(self, data):
        # 1. Separate input and target
        X, y = data
        y = tf.cast(y, dtype=tf.float32)
        
        # 2. Compute loss (B, T) -> (B, T, vocab_size)
        with tf.GradientTape() as tape:
            logits = self(X, training=True)
            loss = self.compute_loss(y=y, y_pred=logits)

        # 3. Compute and apply gradients
        grads = tape.gradient(loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        
        return self.record('loss', loss)
    
    @tf.function
    def test_step(self, data):
        # 1. Separate input and target
        X, y = data

        # 2. Compute loss (B, T) -> (B, T, vocab_size)
        logits = self(X, training=False)
        loss = self.compute_loss(y=y, y_pred=logits)
    
        return self.record('val_loss', loss)

    def init_cache(self, sequence):
        # 1. Get the batch size to initialize the cache
        B = sequence.shape[0]
        
        # 2. Initialize cache shape
        shape = (B, 0, self.heads, self.head_size)
        
        # 2. Initialize a cache for each attention block
        return list(map(
            # 3. Keys and values are initialized with zeros
            lambda _: dict(
                key=tf.zeros(shape),
                value=tf.zeros(shape),
            ),
            self.blocks,
        ))
    
    @tf.function
    def generate(self, token_idx, num_tokens):
        # 1. Initial the first token, the cache, and the input sequence length
        sequence = token_idx
        cache = self.init_cache(sequence) if self.use_cache else None
        seq_len = 1 if self.use_cache else self.block_size
        
        # 2. Token generation loop
        for index in range(num_tokens):
            logits = self(sequence[:, -seq_len:], training=False, cache=cache)[:, -1, :]
            token_idx =  tf.random.categorical(logits, 1, dtype=tf.int32)
            sequence = tf.concat([sequence, token_idx], axis=-1)

        return sequence
    
    @property
    def metrics(self):
        return list(self.trackers.values())

    def record(self, name, loss):
        self.trackers[name].update_state(loss)
        return {m.name: m.result() for m in self.metrics}

# block_size = 16
# m = NanoGPT(config)

# m(tf.random.uniform((2, block_size)))
# m.compile(optimizer=optimizer, loss=loss_fn)


# B, T = 2, 2
# start = 1
# x = tf.reshape(tf.range(B*T), (B, T))
# m(x)
# m.summary(expand_nested=True)

m = NanoGPT(config.model)
# m.load_weights('nodegpt.weights.h5')

In [16]:
@tf.function
def generate(self, token_idx, num_tokens):
    # 1. Initial the first token, the cache, and the input sequence length
    sequence = token_idx
    cache = self.init_cache(sequence) if self.use_cache else None
    seq_len = 1 if self.use_cache else self.block_size
    
    # 2. Token generation loop
    for index in range(num_tokens):
        logits = self(sequence[:, -seq_len:], training=False, cache=cache)[:, -1, :]
        token_idx =  tf.random.categorical(logits, 1, dtype=tf.int32)
        sequence = tf.concat([sequence, token_idx], axis=-1)

    return sequence

# 1. Pick start character
NEWLINE_ID = config.vocab.index('\n')
sequence = tf.constant([NEWLINE_ID], shape=(1, 1), dtype=tf.int32)
# num_tokens = config.model.block_size
num_tokens = 16

for _ in range(num_tokens):
    if _ % 10 == 0: print(f'Token: {_}')
    logits = m(sequence[:, -config.model.block_size:], training=False)[:, -1, :]
    token_idx =  tf.random.categorical(logits, 1, dtype=tf.int32)
    sequence = tf.concat([sequence, token_idx], axis=-1)

generated_text = decoder(sequence[0].numpy())
print(f'{generated_text=}')

Token: 0
Token: 10
generated_text='\n&&KeYTCeroKQRncI'


In [81]:
import copy

caches = []

def dump_cache(cache):
    caches.append(copy.deepcopy(cache))

# 1. Pick start character
NEWLINE_ID = config.vocab.index('\n')
sequence = tf.constant([NEWLINE_ID], shape=(1, 1), dtype=tf.int32)

cache = m.init_cache(sequence) if m.use_cache else None
seq_len = 1 if m.use_cache else m.block_size

# 2. Token generation loop
for index in range(4):
    logits = m(sequence[:, -seq_len:], training=False, cache=cache)[:, -1, :]
    token_idx =  tf.random.categorical(logits, 1, dtype=tf.int32)
    sequence = tf.concat([sequence, token_idx], axis=-1)
    dump_cache(cache)

tril=1.0
scores=<tf.Tensor: shape=(1, 2, 1, 1), dtype=float32, numpy=
array([[[[1.]],

        [[1.]]]], dtype=float32)>
tril=1.0
scores=<tf.Tensor: shape=(1, 2, 1, 1), dtype=float32, numpy=
array([[[[1.]],

        [[1.]]]], dtype=float32)>
tril=1.0
scores=<tf.Tensor: shape=(1, 2, 1, 2), dtype=float32, numpy=
array([[[[0.49945444, 0.50054556]],

        [[0.49998766, 0.50001234]]]], dtype=float32)>
tril=1.0
scores=<tf.Tensor: shape=(1, 2, 1, 2), dtype=float32, numpy=
array([[[[0.4997439 , 0.50025606]],

        [[0.5000082 , 0.4999918 ]]]], dtype=float32)>
tril=1.0
scores=<tf.Tensor: shape=(1, 2, 1, 3), dtype=float32, numpy=
array([[[[0.33344847, 0.33319923, 0.3333523 ]],

        [[0.3331369 , 0.33317405, 0.33368906]]]], dtype=float32)>
tril=1.0
scores=<tf.Tensor: shape=(1, 2, 1, 3), dtype=float32, numpy=
array([[[[0.33318198, 0.3332842 , 0.33353382]],

        [[0.3331056 , 0.33318123, 0.33371317]]]], dtype=float32)>
tril=1.0
scores=<tf.Tensor: shape=(1, 2, 1, 4), dtype=float32, num

In [288]:
import math

import tensorflow as tf

from tensorflow.keras import layers

class GroupedQueryAttention(tf.keras.Model):
    def __init__(self, cache_size, block_size, heads, kv_heads, dims):
        super(GroupedQueryAttention, self).__init__()
        
        self.heads = heads
        self.head_size = dims // heads
        self.kv_heads = kv_heads or heads
        
        self.query = tf.keras.layers.Dense(dims, use_bias=False)
        self.key = tf.keras.layers.Dense(kv_heads*self.head_size, use_bias=False)
        self.value = tf.keras.layers.Dense(kv_heads*self.head_size, use_bias=False)
        
        self.cache = KVCache(cache_size, block_size, self.kv_heads, self.head_size)
        self.rope = RotaryPositionalEncodings(block_size, self.head_size)
        
    def call(self, x, start=0, inference=False):
        shape = tf.shape(x)
        B, T = shape[0], shape[1]

        q = self.query(x)
        k = self.key(x)
        v = self.value(x)
        
        # (B, T, dims) -> (B, T, heads/kv_heads, head_dims)
        q = tf.reshape(q, [B, T, self.heads, self.head_size])
        k = tf.reshape(k, [B, T, self.kv_heads, self.head_size])
        v = tf.reshape(v, [B, T, self.kv_heads, self.head_size])
        
        # RoPE expects inputs in (B, heads/kv_heads, T, head_size) format.
        # Transpose(B, T, heads/kv_heads, head_size) -> (B, heads/kv_heads, T, head_size)
        q = self.rope(tf.transpose(q, perm=[0, 2, 1, 3]), start=start)
        k = self.rope(tf.transpose(k, perm=[0, 2, 1, 3]), start=start)
        
        if inference:
            # Update KV cache
            # KV cache expects inputs in (B, T, heads/kv_heads, head_size) format.
            self.cache.update(
                start,
                tf.transpose(k, perm=[0, 2, 1, 3]),
                v,
            )
            
            # Get prefix context from the cache.
            # (B, start+T, ...)
            k, v = self.cache.get(B, start, T)
            
            ## Replicate KV heads to match query heads.
            # (B, start+T, heads/kv_heads, head_size) -> (B, start+T, heads, head_size)
            k = tf.tile(k, multiples=(1, 1, self.heads//self.kv_heads, 1))
            
            # (B, T, heads/kv_heads, head_dims) -> (B, T, heads, head_dims)
            v = tf.tile(v, multiples=(1, 1, self.heads//self.kv_heads, 1))
            
            # (B, heads/kv_heads, T, head_size) @ (B, heads/kv_heads, head_size, start+T)
            # -> (B, heads/kv_heads, T, start+T)
            wei = q @ tf.transpose(k, perm=[0, 2, 3, 1])
            wei /= tf.cast(tf.math.sqrt(self.head_size * 1.), dtype=x.dtype)
            wei = tf.nn.softmax(wei)
        else:
            assert start == 0
            
            ## Replicate KV heads to match query heads.
            # (B, heads/kv_heads, T, head_size) -> (B, heads, T, head_size)
            k = tf.tile(k, multiples=(1, self.heads//self.kv_heads, 1, 1))
            
            # (B, T, heads/kv_heads, head_dims) -> (B, T, heads, head_dims)
            v = tf.tile(v, multiples=(1, 1, self.heads//self.kv_heads, 1))
            
            # (B, heads/kv_heads, T, head_size) @ (B, heads/kv_heads, head_size, start+T)
            # -> (B, heads/kv_heads, T, start+T)
            wei = q @ tf.transpose(k, perm=[0, 1, 3, 2])
            wei /= tf.cast(tf.math.sqrt(self.head_size * 1.), dtype=x.dtype)

            tril = tf.linalg.band_part(tf.ones((T, T)), -1, 0)
            wei = tf.nn.softmax(tf.where(tril > 0.0, wei, float('-inf')))

        # (B, heads/kv_heads, T, start+T) @ (B, heads/kv_heads, start+T, head_dims)
        # -> (B, heads/kv_heads, T, head_dims)
        out = wei @ tf.transpose(v, perm=[0, 2, 1, 3])
        
        # (B, heads/kv_heads, T, head_dims) -> (B, T, heads/kv_heads, head_dims)
        out = tf.transpose(out, perm=[0, 2, 1, 3])
        
        return tf.reshape(out, shape=(B, T, -1))

# l = GroupedQueryAttention(
#     block_size=16,
#     heads=4,
#     kv_heads=2,
#     dims=16,
#     cache_size=4
# )

# B, T, C = 2, 2, 16
# start = 1
# x = tf.reshape(tf.range(B*T*C), (B, T, C))
# output = l(x, start)

# B_i, T_i, C_i = 0, 0, 3
# H_i, F_i = 0, C_i

# print(
#     f'{x[B_i]=}'
#     f'\n{l.cache.cache_k.shape=}\n{l.cache.cache_k[B_i, :start+T]=}'
#     f'\n{l.cache.cache_v[B_i, :start+T]=}'
# )

import math

import tensorflow as tf

from tensorflow.keras import layers

class SlidingWindowAttention(tf.keras.Model):
    def __init__(self, config):
        super(SlidingWindowAttention, self).__init__()

        self.config = config
        self.head_size = config.dims // config.heads
        
        self.query = tf.keras.layers.Dense(config.dims, use_bias=False)
        self.key = tf.keras.layers.Dense(config.kv_heads*self.head_size, use_bias=False)
        self.value = tf.keras.layers.Dense(config.kv_heads*self.head_size, use_bias=False)
        
        self.rope = RotaryPositionalEncodings(config.block_size, self.head_size)
    
    def as_strided(self, x):
        shape = tf.shape(x)
        B, T = shape[0], shape[2]

        # (B, heads/kv_heads, T, head_size) -> (B, heads/kv_heads, T+2*window_size, head_size)
        padded_x = tf.pad(x, [[0, 0], [0, 0], [self.config.window_size]*2, [0, 0]])

        # indices.shape = (T, window_size + 1)
        indices = tf.tile(tf.reshape(tf.range(T), (-1, 1)), [1, self.config.window_size + 1])
        indices = indices + tf.expand_dims(tf.range(self.config.window_size + 1), axis=0)

        # (B, heads/kv_heads, T, window_size + 1)
        strided_x = tf.gather(padded_x, indices, axis=2)

#         print(
#             f'{x.shape=}'
#             f'\n{indices=}'
#             f'\n{padded_x.shape=}'
#             f'{strided_x=}'
#             f'\n{x=}'
#         )

        return strided_x
    
    def as_grid(self, x):
        # (B, heads, T, window_size+1) -> (B, heads, T, T)
        diagonals = tf.transpose(x, perm=[0, 1, 3, 2])[..., ::-1, :]
        grid_x = tf.linalg.diag(diagonals, k=(-self.config.window_size, 0), align='RIGHT_RIGHT')

        # print(
        #     f'\n{x.shape=} {diagonals.shape=}'
        #     f'\n{grid_x.shape=}'
        #     f'\n{x[0, 0]=}'
        #     # f'\n{x[0, 0, -2:]=}'
        #     f'\n{grid_x[0, 0]=}'
        #     # f'\n{grid_x[0, 0, -2:]=}'
        #     f'\n{tf.math.reduce_sum(x[..., 0, 0])=}'
        #     f'\n{tf.reduce_all(x[..., 1, :] == grid_x[..., 1, :2])=}'
        #     f'\n{tf.reduce_all(x[..., -1, :] == grid_x[..., -1, -2:])=}'
        # )

        return grid_x
        
    def call(self, x, start=0, inference=False):
        shape = tf.shape(x)
        B, T = shape[0], shape[1]

        q = self.query(x)
        k = self.key(x)
        v = self.value(x)
        
        # (B, T, dims) -> (B, T, heads/kv_heads, head_dims)
        q = tf.reshape(q, [B, T, self.config.heads, self.head_size])
        k = tf.reshape(k, [B, T, self.config.kv_heads, self.head_size])
        v = tf.reshape(v, [B, T, self.config.kv_heads, self.head_size])
        
        # RoPE expects inputs in (B, heads/kv_heads, T, head_size) format.
        # Transpose(B, T, heads/kv_heads, head_size) -> (B, heads/kv_heads, T, head_size)
        q = self.rope(tf.transpose(q, perm=[0, 2, 1, 3]), start=start)
        k = self.rope(tf.transpose(k, perm=[0, 2, 1, 3]), start=start)
            
        ## Replicate KV heads to match query heads.
        # (B, heads/kv_heads, T, head_size) -> (B, heads, T, head_size)
        k = tf.tile(k, multiples=(1, self.config.heads//self.config.kv_heads, 1, 1))
        
        # (B, T, heads/kv_heads, head_dims) -> (B, T, heads, head_dims)
        v = tf.tile(v, multiples=(1, 1, self.config.heads//self.config.kv_heads, 1))

        # Compose strided keys to mimic the sliding window. 
        # (B, heads, T, head_size) -> (B, heads, T, window_size+1, head_size)
        strided_k = self.as_strided(k)

        # (B, heads, T, head_size) @ (B, heads, T, window_size+1, head_size)
        # -> (B, heads, T, window_size+1)
        wei = tf.einsum('bhtd,bhtxd->bhtx', q, strided_k)

        # (B, heads, T, window_size+1) -> (B, heads, T, T)
        wei = self.as_grid(wei)
        wei /= tf.cast(tf.math.sqrt(self.head_size * 1.), dtype=x.dtype)

        tril = tf.linalg.band_part(tf.ones((T, T)), -1, 0)
        wei = tf.nn.softmax(tf.where(tril > 0.0, wei, float('-inf')))

        # (B, heads, T, T) @ (B, heads, T, head_dims)
        # -> (B, heads, T, head_dims)
        out = wei @ tf.transpose(v, perm=[0, 2, 1, 3])
        
        # (B, heads, T, head_dims) -> (B, T, heads, head_dims)
        out = tf.transpose(out, perm=[0, 2, 1, 3])
        
        return tf.reshape(out, shape=(B, T, -1))

# l = SlidingWindowAttention(
#     block_size=16,
#     heads=4,
#     kv_heads=2,
#     dims=32,
#     window_size=2
# )

# B, T, C = 2, 2, 32
# x = tf.cast(tf.reshape(tf.range(B*T*C), (B, T, C)), dtype=tf.float32)
# output = l(x)

# B_i, T_i, C_i = 0, 0, 3
# H_i, F_i = 0, C_i

# print(
#     f'{x[B_i]=}'
#     f'{output.shape=}'
# )

import tensorflow as tf

class SelfAttentionLayer(tf.keras.layers.Layer):
    def __init__(self, cache_size, block_size, head_size):
        
        super().__init__()
        # Input args
        self.head_size = head_size
        
        self.key = tf.keras.layers.Dense(head_size, use_bias=False)
        self.query = tf.keras.layers.Dense(head_size, use_bias=False)
        self.value = tf.keras.layers.Dense(head_size, use_bias=False)
        
        self.cache = KVCache(cache_size, block_size, head_size)
        self.rope = RotaryPositionalEncodings(block_size, head_size)

    def call(self, x, start=0, inference=False):
        B, T, C = x.shape

        k = self.key(x)
        q = self.query(x)
        v = self.value(x)
        
        # Apply RoPE
        q = self.rope(q, start=start)
        k = self.rope(k, start=start)
        
        if inference:
            # Update KV cache
            self.cache.update(start, k, v)
            
            # Get prefix context from the cache.
            # (B, start+T, head_dims)
            k, v = self.cache.get(B, start, T)
            
            # (B, T, head_size) @ (B, head_size, start+T) --> (B, T, start+T)
            wei = q @ tf.transpose(k, perm=[0, 2, 1])
            wei /= tf.cast(tf.math.sqrt(self.head_size * 1.), dtype=x.dtype)
            wei = tf.nn.softmax(wei)
        else:
            assert start == 0
            
            # (B, T, head_size) @ (B, head_size, T) --> (B, T, T)
            wei = q @ tf.transpose(k, perm=[0, 2, 1])
            wei /= tf.cast(tf.math.sqrt(self.head_size * 1.), dtype=x.dtype)

            tril = tf.linalg.band_part(tf.ones((T, T)), -1, 0)
            wei = tf.nn.softmax(tf.where(tril > 0.0, wei, float('-inf')))
            
        out = wei @ v
        return out

class MultiHeadAttentionLayer(tf.keras.layers.Layer):
    def __init__(self, cache_size, block_size, num_heads, head_size):
        super().__init__()
        self.attn_layers = [SelfAttentionLayer(cache_size, block_size, head_size) for i in range(num_heads)]

    def call(self, x, start=0, inference=False):
        return tf.concat([
            attn_layer(x, start=start, inference=inference) for attn_layer in self.attn_layers
        ], axis=-1)

# cache_size, block_size, num_heads, head_size = 2, 16, 1, 16
# B, T, C = 2, 2, num_heads*head_size

# l = MultiHeadAttentionLayer(cache_size, block_size, num_heads, head_size)

# for start in range(T):
#     print(
#         f'\n{start=}::{l(tf.random.uniform((B, 1, C)), start=start, inference=True).shape}\n'
#     )

import tensorflow as tf

from tensorflow.keras import layers

class RMSNorm(layers.Layer):
    def __init__(self, eps=1e-8):
        super(RMSNorm, self).__init__()
        self.eps = eps
    
    def build(self, shape):
        self.w = self.add_weight(
            "kernel",
            shape=(1, shape[-1]),
            initializer='ones',
            trainable=True,
        )
    
    def norm(self, x):
        return x*tf.math.rsqrt(tf.math.reduce_mean(x**2, axis=-1, keepdims=True) + self.eps)
        
    def call(self, x):
        return self.w * self.norm(x)

# data = tf.constant(np.arange(20).reshape(5, 2, 2) * 10, dtype=tf.float32)

# l=RMSNorm()
# l.build(data.shape)
# l(data)

import tensorflow as tf

from tensorflow.keras import layers

class FeedForward(layers.Layer):
    def __init__(self, dims):
        super(FeedForward, self).__init__()

        self.hidden_dims = 4*dims
        self.linear_1 = layers.Dense(self.hidden_dims, use_bias=False, activation='swish')
        self.linear_2 = layers.Dense(self.hidden_dims, use_bias=False)
        self.linear_3 = layers.Dense(dims, use_bias=False)
    
    def call(self, x):
        # (..., dims) -> (..., hidden_dims)
        x = self.linear_1(x)*self.linear_2(x)
        
        # (..., hidden_dims) -> (..., dims)
        x = self.linear_3(x)
        
        return x

# B, T, C = 2, 2, 16
# l = FeedForward(C)
# x = tf.reshape(tf.range(B*T*C, dtype='float32'), (B, T, C))
# l(x)

import tensorflow as tf

from tensorflow.keras import layers

class AttentionSelector(object):
    def __init__(self, config, *args, **kwargs):
        super(AttentionSelector, self).__init__(*args, **kwargs)
        
        self.config = config
    
    @property
    def choice(self):
        return (self.config.attention or 'msa').lower()
    
    def select(self):
        if self.choice in ['slidingwindow', 'slidingwindowattention', 'swa']:
            return SlidingWindowAttention(self.config)
        elif self.choice in ['groupedquery', 'groupedqueryattention', 'gqa']:
            return GroupedQueryAttention(self.config)
        else:
            return MultiHeadAttentionLayer(self.config)
    
    def call(self, l, x, start=0, inference=False):
        if self.choice in ['slidingwindow', 'slidingwindowattention', 'swa']:
            return l(x)
        elif self.choice in ['groupedquery', 'groupedqueryattention', 'gqa']:
            return l(x, start=start, inference=inference)
        else:
            return l(x, start=start, inference=inference)

class NormSelector(object):
    def __init__(self, config, *args, **kwargs):
        super(NormSelector, self).__init__(*args, **kwargs)
        self.config = config
    
    def select(self):
        choice = (self.config.norm or 'rms').lower()
        if choice in ['batchnorm', 'bn', 'batchnormalization']:
            return layers.BatchNormalization()
        else:
            return RMSNorm()

class LlamaBlock(tf.keras.Model):
    def __init__(self, config, *args, **kwargs):
        super(LlamaBlock, self).__init__(*args, **kwargs)
        
        self.norm_selector = NormSelector(config)
        self.attention_selector = AttentionSelector(config)
        
        self.norm_1 = self.norm_selector.select()
        self.attention = self.attention_selector.select()
        self.norm_2 = self.norm_selector.select()
    
    def call(self, x, start, inference):
        x += self.attention_selector.call(
            self.attention,
            self.norm_1(x), start=start, inference=inference
        )
        x = self.norm_2(x)
#         x += self.feed_forward(self.norm_2(x))
        
        return x

# l = LlamaBlock(
#     block_size=16,
#     heads=4,
#     kv_heads=2,
#     dims=16,
#     cache_size=4
# )

# B, T, C = 2, 2, 16
# start = 1
# x = tf.reshape(tf.range(B*T*C, dtype='float32'), (B, T, C))
# l(x)

import tensorflow as tf

from tensorflow.keras import layers, metrics
from functools import reduce

class LlamaModel(tf.keras.Model):
    def __init__(self, config, loss_fn, *args, **kwargs):
        super(LlamaModel, self).__init__(*args, **kwargs)
        # Args
        self.config = config

        # Model elements
        self.embeddings = layers.Embedding(config.vocab_size, config.embed_dims)
        self.pos_embeddings = layers.Embedding(config.block_size, config.embed_dims) if config.pos_embeddings else None
        self.dec_blocks = [LlamaBlock(config) for _ in range(config.decoders)]

        self.head = layers.Dense(config.vocab_size, use_bias=False)

        # Loss 
        self.loss_fn = loss_fn

        # Metrics
        self.loss_tracker = metrics.Mean(name="loss")
    
    def alternating_reduce(self, fn, blocks, x):
        for b_index, block in enumerate(blocks):
            # Pick the sub-block for transformer branch
            pivot = b_index % self.config.alternate_blocks

            # Split input into subblocks
            xs = tf.split(x, self.config.alternate_blocks, axis=-1)

            # Process picked block
            x = fn(xs[pivot], block)

            # Re-compose
            x = tf.concat([*xs[:pivot], x, *xs[(pivot + 1):]], axis=-1)

        return x
    
    def call(self, x, start=0, inference=False):
        B, T = x.shape
        
        x_pos_embed = self.pos_embeddings(tf.range(T)) if self.pos_embeddings else None
        x_embed = self.embeddings(x)
        
        x = x_embed + x_pos_embed if self.pos_embeddings else x_embed
        
        reduce_fn = self.alternating_reduce if self.config.alternate_blocks else reduce
        
        x = reduce_fn(
            lambda y,dec_block: dec_block(y, start=start, inference=inference),
            self.dec_blocks,
            x
        )
        x = self.head(x)
        
        return x
    
    def train_step(self, data):
        X, y = data
        y = tf.cast(y, dtype=tf.float32)

        with tf.GradientTape() as tape:
            logits = self(X)
            loss = self.loss_fn(y, logits)

        grads = tape.gradient(loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        
        return self.record(loss, y, logits)
    
    def test_step(self, data):
        X, y = data

        logits = self(X)
        loss = self.loss_fn(y, logits)
    
        return self.record(loss, y, logits)
    
    @property
    def metrics(self):
        return [self.loss_tracker]
    
    def record(self, loss, y, logits):
        self.loss_tracker.update_state(loss)

        return {m.name: m.result() for m in self.metrics}

model_config = ModelConfig(
    vocab_size=len(model_vocab),
    
    # Model Size and Prediction Capacity
    decoders=4,
    block_size=64,
    embed_dims=512,
    dims=128,
    
    # Attention Params
    attention='swa',
    heads=8,
    kv_heads=8,
    
    # KV Cache
    cache_size=2,
    
    # Sliding Window Attention
    window_size=3,
    
    # Features
    pos_embeddings=False,
    alternate_blocks=4,
)

llama = LlamaModel(config=model_config, loss_fn=loss_fn)
llama(tf.random.uniform((2, llama.config.block_size), minval=0, maxval=llama.config.vocab_size, dtype=tf.int32))
llama.compile(optimizer=optimizer, loss=loss_fn)
llama.summary(expand_nested=True)

Model: "llama_model_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 embedding_2 (Embedding)     multiple                  33280     
                                                                 
 llama_block_1 (LlamaBlock)  multiple                  49408     
|¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯|
| rms_norm_2 (RMSNorm)       multiple                  128      |
|                                                               |
| sliding_window_attention_  multiple                  49152    |
| 16 (SlidingWindowAttentio                                     |
| n)                                                            |
||¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯||
|| dense_99 (Dense)          multiple                  16384   ||
||                                                             ||
|| dense_100 (Dense)         multiple                

In [290]:
# llama = LlamaModel(config=model_config, loss_fn=loss_fn)

# config_data = model_config._asdict()
# config_data['kv_heads'] = config_data['kv_heads']//grouping_factor

# grouped_config = ModelConfig(**config_data)

# llama_grouped.compile(optimizer=optimizer, loss=loss_fn)
# llama_grouped(tf.random.uniform((2, 5), minval=0, maxval=llama_grouped.vocab_size, dtype=tf.int32))
# llama_grouped.summary(expand_nested=True)

In [291]:
def get_text_stats(text, tokenizer, fns=[]):
    def stats_fn(token):
        return np.array([token.lower_] + [fn(token) for fn in fns])
    
    tokens = tokenizer(text)
    stats = np.stack([stats_fn(token) for token in tokens])

    return stats

def random_token_from_logits(logits, samples=1):
    return tf.random.categorical(logits, samples, dtype=tf.int32)

def argmax_token(logits, samples=1):
    return tf.expand_dims(
        tf.math.argmax(tf.math.softmax(logits), axis=-1, output_type=tf.int32),
        axis=-1
    )

def generate(model, token_idx, tokens, randomize=True):
    sequence = token_idx
    for start in range(tokens - 1):
        logits = model(token_idx, start=start, inference=True)[:, -1, :]
#         token_idx =  tf.random.categorical(logits, 1, dtype=tf.int32)
        token_idx = random_token_from_logits(logits) if randomize else argmax_token(logits)
        sequence = tf.concat([sequence, token_idx], axis=-1)
    
    return sequence

def generate_no_cache(model, token_idx, tokens, randomize=True):
    sequence = token_idx
    for _ in range(tokens - 1):
        logits = model(sequence[:, -model.block_size:])[:, -1, :]
#         token_idx =  tf.random.categorical(logits, 1, dtype=tf.int32)
        token_idx = random_token_from_logits(logits) if randomize else argmax_token(logits)
        sequence = tf.concat([sequence, token_idx], axis=-1)
    
    return sequence

def generate_and_evaluate(model_vocab, generator, model, tokens=50, randomize=True):
    idx_to_char = {i:c for c,i in model_vocab.items()}
    decoder = lambda x: ''.join([idx_to_char[i] for i in x])

    starter = tf.constant([model_vocab[' ']], shape=(1, 1), dtype=tf.int32)
    generated_text = decoder(
        generator(
            model,
            starter,
            tokens=tokens,
            randomize=randomize,
        )[0].numpy()
    )

    print(f'{generator.__name__}:: {generated_text=}')

# word_vocab = get_word_1gram_frequencies(tiny_shakespere, word_tokenizer)
# print(
#     f'\nVocab: {list(word_vocab.keys())[:5]}'
# )
# generate_and_evaluate(model_vocab, generate, llama, randomize=False, tokens=min(50, llama.block_size))
# generate_and_evaluate(model_vocab, generate_no_cache, llama, randomize=False, tokens=min(50, llama.block_size))

In [None]:
import itertools

def get_layers(model, exceptions=[]):
    def fn(l):
        if hasattr(l, 'layers'):
            return list(itertools.chain(*map(fn, l.layers)))
        elif l.name in exceptions:
            return [None]
        else:
            return [l]
    
    return list(filter(lambda x: x is not None, fn(model)))


def apply_groupings(src_model, target_model, factor=2):
    def grouped_apply(l1, l2):
        grouped_w1s = []

        for w1, w2 in zip(l1.get_weights(), l2.get_weights()):
            w1_groups = tf.split(w1, factor, axis=-1)
            print(f'\nSplit {src_model.name}.{l1.name}({w1.shape}) --> {factor}x{w1_groups[0].shape}')

            w1_grouped = tf.math.add(*w1_groups) / factor
            print(f'Average {factor}x{w1_groups[0].shape} --> {w1_grouped.shape}')

            grouped_w1s.append(w1_grouped)
        
        # Apply grouped weights to the target
        l2.set_weights(grouped_w1s)
        print(f'Copied? {np.array_equal(grouped_w1s, l2.get_weights())}')

    src_attention = src_model.dec_blocks[0].attention
    target_attention = target_model.dec_blocks[0].attention

    src_other_layers = get_layers(
        src_model,
        exceptions=list(map(lambda x: x.name, [src_attention.key, src_attention.value]))
    )

    target_other_layers = get_layers(
        target_model,
        exceptions=list(map(lambda x: x.name, [target_attention.key, target_attention.value]))
    )

    for src_layer, target_layer in zip(src_other_layers, target_other_layers):
        target_layer.set_weights(src_layer.get_weights())
    
    # Group keys and values.
    print(f'\n Grouping Layers ({factor=})\n============================')
    grouped_apply(src_attention.key, target_attention.key)
    grouped_apply(src_attention.value, target_attention.value)

def compare_weights(one, two):
    one_layers = get_layers(one)
    two_layers = get_layers(two)

    print('\n Comparison Status\n===================')

    for ol, tl in filter(lambda x: x[0].get_weights(),  zip(one_layers, two_layers)):
        if np.array_equal(ol.get_weights(), tl.get_weights()):
            print(f'{ol.name} == {tl.name}')
        else:
            print(f'\t{ol.name} != {tl.name}')

# apply_groupings(llama, llama_grouped)
# compare_weights(llama, llama_grouped)

In [None]:
def get_grouped_generator(grouped_model, grouping_factor, base_fn=generate_no_cache):
    def fn(model, *args, **kwargs):
        apply_groupings(model, grouped_model)
        return base_fn(grouped_model, *args, **kwargs)
    
    return fn

generator_grouped = get_grouped_generator(llama_grouped, grouping_factor=grouping_factor)
generator_grouped_cached = get_grouped_generator(llama_grouped, grouping_factor=grouping_factor, base_fn=generate)

generate_and_evaluate(model_vocab, generator, llama, randomize=False, tokens=min(50, llama.block_size))


 Grouping Layers (factor=2)

Split llama_model_37.dense_384((128, 128)) --> 2x(128, 64)
Average 2x(128, 64) --> (128, 64)
Copied? True

Split llama_model_37.dense_385((128, 128)) --> 2x(128, 64)
Average 2x(128, 64) --> (128, 64)
Copied? True
fn:: generated_text=' sssawaksedreerederedoubexfoctleereere:rere:reree:'


In [None]:
# llama.layers[2].layers[1].attn_layers[0].logs['cache']['k']
cache_logs = llama.layers[2].layers[1].attn_layers[0].logs['cache']
no_cache_logs = llama.layers[2].layers[1].attn_layers[0].logs['no-cache']

log_items = len(cache_logs['k'])

# for key in ['in_x', 'k', 'q', 'v', 'rope_k', 'rope_q', 'wei', 'out']:
for key in ['wei']:
    for idx in [1]:
        print(
            f'{key}({idx})::{tf.math.reduce_all(cache_logs[key][idx] == no_cache_logs[key][idx+1][..., idx:, :])}'
            f' {cache_logs[key][idx]} {no_cache_logs[key][idx+1][..., idx:, :]}'
            f' {cache_logs[key][idx].shape} {no_cache_logs[key][idx+1][..., idx:, :].shape}'
        )

wei(1)::False [[[21.616919 43.966854]]] [[[21.616917 43.966858]]] (1, 1, 2) (1, 1, 2)


In [None]:
# tf.math.reduce_all(cache_logs['cache_k'][1] == no_cache_logs['rope_k'][2])
index = 1
print(
    f"{tf.math.reduce_all(cache_logs['wei'][index] == no_cache_logs['wei'][index+1])}"
    # f"\n{tf.math.reduce_all(cache_logs['cache_v'][1] == no_cache_logs['v'][2])}"
    f"\n{tf.math.reduce_all(cache_logs['cache_k'][index] == no_cache_logs['rope_k'][index+1])}",
    # f"\n{tf.math.reduce_all(cache_logs['cache_v'][index] == no_cache_logs['v'][index+1][:, 1:, :])}",
    f"\n{tf.math.reduce_all(cache_logs['rope_q'][index] == no_cache_logs['rope_q'][index+1][:, 1:, :])}",
    # f"\n{cache_logs['rope_q'][index]=}"
    # f"\n{no_cache_logs['rope_q'][index+1]=}"
    # f"\n{cache_logs['cache_k'][index]=}"
    # f"\n{no_cache_logs['rope_k'][index+1]=}"
    f"\n{cache_logs['wei'][index]=}"
    f"\n{no_cache_logs['wei'][index+1]=}"
    f"\n{cache_logs['wei'][index]=}"
    f"\n{no_cache_logs['wei'][index+1]=}"
)
# cache_logs['out'][1], no_cache_logs['out'][2][..., 1:, :]

False
True 
True 
cache_logs['wei'][index]=<tf.Tensor: shape=(1, 1, 2), dtype=float32, numpy=array([[[21.616919, 43.966854]]], dtype=float32)>
no_cache_logs['wei'][index+1]=<tf.Tensor: shape=(1, 2, 2), dtype=float32, numpy=
array([[[29.863405 ,  7.2788296],
        [21.616917 , 43.966858 ]]], dtype=float32)>
cache_logs['wei'][index]=<tf.Tensor: shape=(1, 1, 2), dtype=float32, numpy=array([[[21.616919, 43.966854]]], dtype=float32)>
no_cache_logs['wei'][index+1]=<tf.Tensor: shape=(1, 2, 2), dtype=float32, numpy=
array([[[29.863405 ,  7.2788296],
        [21.616917 , 43.966858 ]]], dtype=float32)>


In [None]:
c_q = cache_logs['rope_q'][index]
c_k = cache_logs['cache_k'][index]

c_q, c_k, c_q @ tf.transpose(c_k, perm=[0, 2, 1])

(<tf.Tensor: shape=(1, 1, 16), dtype=float32, numpy=
 array([[[ 2.697421  , -2.4562492 , -1.5515298 ,  0.14776081,
           0.37695628, -1.5206785 ,  0.8845265 ,  0.17617644,
          -2.1475816 ,  1.1095126 , -0.10803113, -4.174893  ,
           0.12987702,  0.36463377, -4.113693  , -0.7377639 ]]],
       dtype=float32)>,
 <tf.Tensor: shape=(1, 2, 16), dtype=float32, numpy=
 array([[[ 4.474342  , -1.4330535 , -4.977524  , -4.9114738 ,
          -0.23111431,  1.3314574 ,  0.03826439,  1.3454266 ,
           1.7946535 ,  0.05863512, -1.4367882 , -2.3322222 ,
           0.93547916, -2.7148142 ,  0.902055  ,  0.88373613],
         [ 1.8138299 , -2.927385  , -3.3529627 , -4.1700516 ,
          -0.463976  , -2.7048278 , -1.8990085 ,  1.7266546 ,
           1.0458257 , -0.78157973, -0.0542576 , -2.8697302 ,
           2.9209538 , -0.50924057, -3.9430103 ,  0.7489088 ]]],
       dtype=float32)>,
 <tf.Tensor: shape=(1, 1, 2), dtype=float32, numpy=array([[[21.616919, 43.966854]]], dtype=floa

In [None]:
nc_q = no_cache_logs['rope_q'][index + 1]
nc_k = no_cache_logs['rope_k'][index + 1]

nc_q, nc_k, nc_q @ tf.transpose(nc_k, perm=[0, 2, 1])

(<tf.Tensor: shape=(1, 2, 16), dtype=float32, numpy=
 array([[[-0.23348725, -1.2534161 , -3.0547209 , -0.6355625 ,
          -0.08479708,  1.8937602 ,  1.8187456 , -1.0931925 ,
          -0.7623807 ,  2.3252077 , -0.7686442 , -1.8549696 ,
           0.79256994, -1.0165763 ,  1.2038559 ,  0.9727511 ],
         [ 2.697421  , -2.4562492 , -1.5515298 ,  0.14776081,
           0.37695628, -1.5206785 ,  0.8845265 ,  0.17617644,
          -2.1475816 ,  1.1095126 , -0.10803113, -4.174893  ,
           0.12987702,  0.36463377, -4.113693  , -0.7377639 ]]],
       dtype=float32)>,
 <tf.Tensor: shape=(1, 2, 16), dtype=float32, numpy=
 array([[[ 4.474342  , -1.4330535 , -4.977524  , -4.9114738 ,
          -0.23111431,  1.3314574 ,  0.03826439,  1.3454266 ,
           1.7946535 ,  0.05863512, -1.4367882 , -2.3322222 ,
           0.93547916, -2.7148142 ,  0.902055  ,  0.88373613],
         [ 1.8138299 , -2.927385  , -3.3529627 , -4.1700516 ,
          -0.463976  , -2.7048278 , -1.8990085 ,  1.7266546

In [None]:
from tensorflow.keras import layers

K = 2

def alternating_reduce(fn, blocks, x):
    for b_index, block in enumerate(blocks):
        # Pick the sub-block for transformer branch
        pivot = b_index % K

        # Split input into subblocks
        xs = tf.split(x, K, axis=-1)

        # Process picked block
        x = block(xs[pivot])

        # Re-compose
        x = tf.concat([*xs[:pivot], x, *xs[(pivot + 1):]], axis=-1)
    
    return x

x = tf.reshape(tf.cast(tf.range(2*100), dtype=tf.float32), (2, 100))

print(f'{x=}')
alternating_reduce(
    lambda y,dec_block: dec_block(y),
    [
        # layers.Dense(100, use_bias=False),
        layers.Dense(50, use_bias=False, kernel_initializer='ones'),
        layers.Dense(50, use_bias=False, kernel_initializer='ones'),
        layers.Dense(50, use_bias=False, kernel_initializer='ones'),
        layers.Dense(50, use_bias=False, kernel_initializer='ones'),
    ],
    x
)

x=<tf.Tensor: shape=(2, 100), dtype=float32, numpy=
array([[  0.,   1.,   2.,   3.,   4.,   5.,   6.,   7.,   8.,   9.,  10.,
         11.,  12.,  13.,  14.,  15.,  16.,  17.,  18.,  19.,  20.,  21.,
         22.,  23.,  24.,  25.,  26.,  27.,  28.,  29.,  30.,  31.,  32.,
         33.,  34.,  35.,  36.,  37.,  38.,  39.,  40.,  41.,  42.,  43.,
         44.,  45.,  46.,  47.,  48.,  49.,  50.,  51.,  52.,  53.,  54.,
         55.,  56.,  57.,  58.,  59.,  60.,  61.,  62.,  63.,  64.,  65.,
         66.,  67.,  68.,  69.,  70.,  71.,  72.,  73.,  74.,  75.,  76.,
         77.,  78.,  79.,  80.,  81.,  82.,  83.,  84.,  85.,  86.,  87.,
         88.,  89.,  90.,  91.,  92.,  93.,  94.,  95.,  96.,  97.,  98.,
         99.],
       [100., 101., 102., 103., 104., 105., 106., 107., 108., 109., 110.,
        111., 112., 113., 114., 115., 116., 117., 118., 119., 120., 121.,
        122., 123., 124., 125., 126., 127., 128., 129., 130., 131., 132.,
        133., 134., 135., 136., 137., 138., 1

<tf.Tensor: shape=(2, 100), dtype=float32, numpy=
array([[ 61250.,  61250.,  61250.,  61250.,  61250.,  61250.,  61250.,
         61250.,  61250.,  61250.,  61250.,  61250.,  61250.,  61250.,
         61250.,  61250.,  61250.,  61250.,  61250.,  61250.,  61250.,
         61250.,  61250.,  61250.,  61250.,  61250.,  61250.,  61250.,
         61250.,  61250.,  61250.,  61250.,  61250.,  61250.,  61250.,
         61250.,  61250.,  61250.,  61250.,  61250.,  61250.,  61250.,
         61250.,  61250.,  61250.,  61250.,  61250.,  61250.,  61250.,
         61250., 186250., 186250., 186250., 186250., 186250., 186250.,
        186250., 186250., 186250., 186250., 186250., 186250., 186250.,
        186250., 186250., 186250., 186250., 186250., 186250., 186250.,
        186250., 186250., 186250., 186250., 186250., 186250., 186250.,
        186250., 186250., 186250., 186250., 186250., 186250., 186250.,
        186250., 186250., 186250., 186250., 186250., 186250., 186250.,
        186250., 186250., 1

In [338]:
import tensorflow as tf

class AltUpBlockSequence(tf.keras.Model):
    def __init__(self, config, *args, **kwargs):
        super(AltUpBlockSequence, self).__init__(*args, **kwargs)
        self.config = config

        self.blocks = [LlamaBlock(config) for _ in range(config.decoders)]
    
    def build(self, shape):
        self.prediction_scalars = self.add_weight(
            "prediction_scalars",
            shape=(self.config.alternate_blocks,)*2,
            initializer='ones',
            trainable=True,
        )
        self.correction_scalars = self.add_weight(
            "correction_scalars",
            shape=(self.config.alternate_blocks,),
            initializer='ones',
            trainable=True,
        )
    
    def call(self, x, start=0, inference=False):
        for b_index, block in enumerate(self.blocks):
            # Pick the sub-block for transformer branch
            pivot = b_index % self.config.alternate_blocks

            # Split input into subblocks
            xs = tf.split(x, self.config.alternate_blocks, axis=-1)

            # Prediction with a linear map
            xs_hat = tf.unstack(
                tf.einsum('aa,axyz->axyz', self.prediction_scalars, tf.stack(xs, axis=0)),
                axis=0,
            )

            # Process picked block
            x = block(xs[pivot], start=start, inference=inference)

            # Correction
            xs = list(map(
                lambda idx: xs_hat[idx] + self.correction_scalars[idx]*(x - xs_hat[pivot]),
                range(self.config.alternate_blocks),
            ))

            # Re-compose
            # x = tf.concat([*xs[:pivot], x, *xs[(pivot + 1):]], axis=-1)
            x = tf.concat(xs, axis=-1)

        return x


l = AltUpBlockSequence(model_config)
x = tf.random.uniform((2, model_config.block_size, model_config.embed_dims))
y = l(x)

In [88]:
import tensorflow as tf

from tensorflow.keras import layers, metrics, losses, optimizers

class CustomModel(tf.keras.Model):
    def __init__(self):
        super().__init__()

        self.dense = layers.Dense(10)

        self.trackers = {
            'loss': metrics.Mean(name="loss"),
            'val_loss': metrics.Mean(name="val_loss"),
        }
    
    def call(self, x):
        return self.dense(x)
    
    @tf.function
    def train_step(self, data):
        X, y = data
        with tf.GradientTape() as tape:
            logits = self(X, training=True)
            loss = self.compute_loss(y=y, y_pred=logits)

        grads = tape.gradient(loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))

        return self.record('loss', loss)

    @tf.function
    def test_step(self, data):
        logits = self(X, training=True)
        loss = self.compute_loss(y=y, y_pred=logits)

        return self.record('val_loss', loss)
    
    @property
    def metrics(self):
        return list(self.trackers.values())

    def record(self, name, loss):
        self.trackers[name].update_state(loss)

        return {m.name: m.result() for m in self.metrics}

loss_fn = losses.SparseCategoricalCrossentropy(from_logits=True)
optimizer = optimizers.AdamW()
model = CustomModel()
model.compile(optimizer=optimizer, loss=loss_fn)

B, C = 8, 32
classes = 10
X = tf.random.normal((B, C))
y = tf.random.uniform((B,), minval=0, maxval=classes, dtype=tf.int32)
print(f'{X.shape=} {y.shape=}')

model.train_step((X, y)), model.reset_metrics(), model.get_metrics_result()



X.shape=TensorShape([8, 32]) y.shape=TensorShape([8])


({'loss': <tf.Tensor: shape=(), dtype=float32, numpy=2.7486548>,
  'val_loss': <tf.Tensor: shape=(), dtype=float32, numpy=0.0>},
 None,
 {'loss': <tf.Tensor: shape=(), dtype=float32, numpy=0.0>,
  'val_loss': <tf.Tensor: shape=(), dtype=float32, numpy=0.0>})