# PositionEmbedding

In [None]:
class PositionEmbedding(layers.Layer):
    """Adds a learned position embedding to its input.

    A trainable table of shape `[sequence_length, width]` is created at build
    time from the last two dimensions of the input shape and is added to the
    input on every call (broadcast over any leading dimensions).

    Args:
      initializer: Initializer (name, config or instance) for the embedding
        table. Resolved through `tf.keras.initializers.get`.
      **kwargs: Additional keyword arguments passed to the base layer.
    """

    def __init__(self, initializer="glorot_uniform", **kwargs):
        super().__init__(**kwargs)
        self._initializer = tf.keras.initializers.get(initializer)

    def get_config(self):
        """Returns the layer config including the serialized initializer."""
        config = {
            "initializer": tf.keras.initializers.serialize(self._initializer),
        }
        base_config = super().get_config()
        return {**base_config, **config}

    def build(self, input_shape):
        """Creates the `[sequence_length, width]` embedding table.

        Raises:
          ValueError: If either of the last two input dimensions is unknown,
            since the table needs a fully defined shape.
        """
        sequence_length = input_shape[-2]
        width = input_shape[-1]
        if sequence_length is None or width is None:
            raise ValueError(
                "PositionEmbedding requires static sequence_length and width, "
                f"got input_shape={input_shape}."
            )

        # `name` is passed by keyword: the positional order of `add_weight`
        # differs between Keras versions (Keras 3 takes `shape` first).
        self._position_embeddings = self.add_weight(
            name="position_embeddings",
            shape=[sequence_length, width],
            initializer=self._initializer)

        super().build(input_shape)

    def call(self, inputs):
        """Returns `inputs` with the position embedding table added."""
        return inputs + self._position_embeddings

# Demo: feed zeros through the layer so the output equals the embedding table.
xx = tf.zeros((2, 4, 12))
l = PositionEmbedding()
output = l(xx)

print('PositionEmbedding Layer')
print(f'Input: {xx.shape} --> {output.shape}')
print(f'Embedding Size: {l.weights[0].shape}')
# Compare against the weight tensor itself (not the `weights` list); the
# (4, 12) table broadcasts over the batch dimension of `output`.
print(f'Verify Embeddings: {tf.reduce_all(tf.math.equal(output, l.weights[0]))}')

# StochasticDepth

In [None]:
"""It is sourced from tensorflow-models package

Source: https://github.com/tensorflow/models/blob/v2.12.0/official/vision/modeling/layers/nn_layers.py#L227-L262
"""
class StochasticDepth(layers.Layer):
    """Stochastic depth layer: randomly zeroes whole samples while training.

    For each element along the first (batch) axis, the entire sample is either
    kept and scaled by `1 / keep_prob`, or replaced with zeros.
    """

    def __init__(self, stochastic_depth_drop_rate, **kwargs):
        """Initializes a stochastic depth layer.

        Args:
          stochastic_depth_drop_rate: A `float` of drop rate.
          **kwargs: Additional keyword arguments to be passed.
        """
        super().__init__(**kwargs)
        self._drop_rate = stochastic_depth_drop_rate

    def get_config(self):
        """Returns the base layer config extended with the drop rate."""
        base_config = super().get_config()
        return {**base_config, 'stochastic_depth_drop_rate': self._drop_rate}

    def call(self, inputs, training=None):
        """Applies stochastic depth to `inputs`.

        Args:
          inputs: Input tensor; the first axis is treated as the batch axis.
          training: Whether the layer runs in training mode. When `None`, the
            Keras backend learning phase is used.

        Returns:
          A tensor with the same shape as `inputs`. The input is returned
          unchanged when not training or when the drop rate is `None` or 0.
        """
        if training is None:
            training = tf.keras.backend.learning_phase()

        is_identity = (
            not training or self._drop_rate is None or self._drop_rate == 0
        )
        if is_identity:
            return inputs

        keep_prob = 1.0 - self._drop_rate
        # One random draw per sample, broadcastable over the remaining axes.
        noise_shape = [tf.shape(inputs)[0], *([1] * (inputs.shape.rank - 1))]
        noise = keep_prob + tf.random.uniform(noise_shape, dtype=inputs.dtype)
        # floor(keep_prob + U[0, 1)) is 1 with probability keep_prob, else 0.
        keep_mask = tf.floor(noise)
        return tf.math.divide(inputs, keep_prob) * keep_mask

# Demo: run StochasticDepth in training mode and measure how much was dropped.
batch_size, num_patches, dims = 2, 256, 768
drop_prob = .5

l = StochasticDepth(drop_prob)
xx = tf.random.normal((batch_size, num_patches, dims))
output = l(xx, training=True)

# Fraction of output elements that are exactly zero, i.e. the observed drop
# rate. `non_zeros` is the complementary kept fraction.
zero_fraction = (
    tf.math.reduce_sum(tf.cast(output == 0, tf.int64))
    / (batch_size * num_patches * dims)
)
non_zeros = 1 - zero_fraction

print('StochasticDepth Layer')
print('---------------------')
print(f'batch_size: {batch_size}, num_patches: {num_patches}, dims: {dims}')
print(f'\ndrop_prob: {drop_prob}')
print(f'\nInput: {xx.shape} --> {output.shape}')
# Report the dropped fraction so the value matches the "Drop Rate" label.
print(f'\nOutput Drop Rate: {zero_fraction}')

# MultiHeadAttention

In [None]:
import tensorflow as tf

from tensorflow.keras import layers

class SelfAttention(layers.Layer):
    """Single-head causal (masked) scaled dot-product self-attention.

    Args:
      dims: Output width of the key, query and value projections.
      **kwargs: Additional keyword arguments passed to the base layer.
    """

    def __init__(self, dims, **kwargs):
        super().__init__(**kwargs)

        self.key = layers.Dense(dims, use_bias=False)
        self.query = layers.Dense(dims, use_bias=False)
        self.value = layers.Dense(dims, use_bias=False)
        self.dims = dims

    def call(self, x):
        """Applies causal self-attention.

        Args:
          x: Input tensor; the second axis is used as the sequence axis
            (the original code unpacks it as `(B, T, C)`).

        Returns:
          A tensor whose last dimension is `dims`, where each position attends
          only to itself and earlier positions.
        """
        # Use the dynamic shape so unknown sequence lengths are supported.
        seq_len = tf.shape(x)[1]

        k, q, v = self.key(x), self.query(x), self.value(x)

        # (B, T, dims) @ (B, dims, T) --> (B, T, T)
        scores = tf.matmul(q, k, transpose_b=True)
        # tf.math.sqrt only accepts floating tensors, so cast the integer width.
        scores /= tf.math.sqrt(tf.cast(self.dims, scores.dtype))

        # Lower triangular matrix: position i may attend to positions j <= i.
        causal_mask = tf.linalg.band_part(
            tf.ones((seq_len, seq_len)), -1, 0
        )

        # Masked entries become -inf so they get zero weight after softmax.
        weights = tf.nn.softmax(
            tf.where(causal_mask > 0.0, scores, float('-inf'))
        )
        return weights @ v

class MultiHeadAttention(tf.keras.layers.Layer):
    """Runs several independent `SelfAttention` heads and concatenates them.

    Args:
      num_heads: Number of `SelfAttention` heads to create.
      dims: Width passed to each `SelfAttention` head.
    """

    def __init__(self, num_heads, dims):
        super().__init__()
        heads = []
        for _ in range(num_heads):
            heads.append(SelfAttention(dims))
        self.attn_layers = heads

    def call(self, x):
        """Applies every head to `x` and joins the results on the last axis."""
        head_outputs = []
        for head in self.attn_layers:
            head_outputs.append(head(x))
        return tf.concat(head_outputs, axis=-1)

# RotaryPositionalEncodings

In [None]:
import tensorflow as tf

from tensorflow.keras import layers

class RotaryPositionalEncodings(layers.Layer):
    """Rotary positional encoding (RoPE) using the rotate-half formulation.

    Cosine and sine tables of shape `(block_size, dims)` are precomputed in
    the constructor and applied to inputs indexed as
    `(batch, seq, heads, dims)`.

    Args:
      dims: Size of the last (feature) axis to rotate. Must be even.
      block_size: Number of positions to precompute tables for.
      base: Base of the geometric frequency progression.
      **kwargs: Additional keyword arguments passed to the base layer.

    Raises:
      ValueError: If `dims` is not even (features are rotated in two halves).
    """

    def __init__(self, dims, block_size, base=10000., **kwargs):
        super().__init__(**kwargs)

        if dims % 2 != 0:
            raise ValueError(f"`dims` must be even, got {dims}.")

        self.dims = dims
        self.base = base
        self.block_size = block_size

        # One frequency per feature pair: base ** (-2i / dims).
        self.theta = 1. / (self.base ** (tf.range(0, self.dims, 2, dtype=tf.float32) / self.dims))
        self.positions = tf.range(self.block_size, dtype=tf.float32)
        # (block_size, dims // 2) angle for every position/frequency pair.
        self.mtheta = self.positions[..., None]*self.theta[None, ...]
        # Duplicate so both halves of the feature axis share the same angles.
        self.mtheta_paired = tf.concat([self.mtheta]*2, axis=-1)

        self.cos_mtheta = tf.math.cos(self.mtheta_paired)
        self.sin_mtheta = tf.math.sin(self.mtheta_paired)

    def call(self, x, start=0):
        """Rotates `x` according to its sequence positions.

        Args:
          x: Tensor indexed as `(batch, seq, heads, dims)`.
          start: Position of the first element of `x` in the table. Defaults
            to 0. `start + seq` must not exceed `block_size`, otherwise the
            table slice is shorter than the sequence and broadcasting fails.

        Returns:
          A tensor with the same shape and dtype as `x`.
        """
        # Dynamic length: the static shape may be None for variable sequences.
        seq_len = tf.shape(x)[1]
        half = self.dims // 2

        # Slice the tables to this chunk, match the input dtype, and add
        # broadcast axes for batch and heads.
        cos = tf.cast(self.cos_mtheta[start:start + seq_len], x.dtype)[None, :, None, :]
        sin = tf.cast(self.sin_mtheta[start:start + seq_len], x.dtype)[None, :, None, :]

        # Rotate-half: (x1, x2) -> (-x2, x1).
        rotated = tf.concat([-x[..., half:], x[..., :half]], axis=-1)

        return x * cos + rotated * sin

# B, T, H, D = 2, 8, 4, 4
# B_i, T_i, H_i, D_i = 0, 1, 0, 0
# x = tf.random.uniform((B, T, H, D))
# l = RotaryPositionalEncodings(4, 64)
# output = l(x)
# # tf.print(f'{l.theta=} {l.theta.shape}')
# # tf.print(f'{l.positions=} {l.positions.shape}')
# tf.print(f'{x[B_i]=}')
# tf.print(f'{l.cos_mtheta=}')
# tf.print(f'{l.sin_mtheta=}')
# tf.print(f'{output[B_i]=}\n{output[B_i]=} {output.shape}')

# KVCache

In [None]:
import tensorflow as tf

class KVCache:
    """Ring-buffer cache for attention keys and values.

    Two non-trainable variables of shape
    `(cache_size, block_size, heads, head_dims)` are allocated on the CPU.
    Positions are written modulo `block_size`, so writes that run past the end
    of the buffer wrap around to the front.

    Args:
      cache_size: Size of the first (batch) axis of the cache.
      block_size: Number of sequence positions held by the cache.
      heads: Number of key/value heads.
      head_dims: Feature size per head.
    """

    def __init__(self, cache_size, block_size, heads, head_dims):
        super().__init__()
        self.block_size = block_size

        cache_shape = (cache_size, block_size, heads, head_dims)

        with tf.device('/device:CPU:0'):
            self.cache_k = tf.Variable(tf.zeros(cache_shape), shape=cache_shape, trainable=False)
            self.cache_v = tf.Variable(tf.zeros(cache_shape), shape=cache_shape, trainable=False)

    def update(self, start, xk, xv):
        """Writes `xk`/`xv` into the cache beginning at position `start`.

        Args:
          start: Position of the first new element; reduced modulo
            `block_size`.
          xk: Keys indexed as `(batch, seq, heads, head_dims)`.
          xv: Values with the same layout as `xk`.

        The sequence length is expected to be at most `block_size`; longer
        inputs cannot be placed with a single wrap-around.
        """
        shape = tf.shape(xk)
        B = shape[0]
        T = shape[1]

        # Slot of the first element and one-past-the-last slot before wrapping.
        start = start % self.block_size
        stop = start + T

        if stop <= self.block_size:
            # Everything fits before the end of the buffer: single update.
            self.cache_k[:B, start:stop].assign(xk)
            self.cache_v[:B, start:stop].assign(xv)
        else:
            # Split update: `tail` elements fit at the end of the buffer ...
            tail = self.block_size - start
            self.cache_k[:B, start:].assign(xk[:, :tail])
            self.cache_v[:B, start:].assign(xv[:, :tail])

            # ... and the spillover is written at the front of the buffer.
            spill = stop - self.block_size
            self.cache_k[:B, :spill].assign(xk[:, tail:])
            self.cache_v[:B, :spill].assign(xv[:, tail:])

    def get(self, batch_size, start, seq_len):
        """Fetches cached keys and values ending at `start + seq_len`.

        Args:
          batch_size: Number of leading batch rows to return.
          start: Position of the first element of the current chunk; reduced
            modulo `block_size`.
          seq_len: Length of the current chunk.

        Returns:
          A `(keys, values)` tuple. If the chunk does not wrap, the prefix
          `[0, start + seq_len)` of the buffer is returned. If it wraps, the
          whole buffer is returned in chronological order (oldest first).

        NOTE(review): because `start` is reduced modulo `block_size`, a
        non-wrapping chunk read after an earlier wrap only returns the prefix
        of the buffer, not the older entries stored behind it - confirm this
        is the intended window.
        """
        start = start % self.block_size
        stop = start + seq_len

        if stop <= self.block_size:
            # Single fetch: everything from the front up to the chunk's end.
            keys = self.cache_k[:batch_size, :stop]
            values = self.cache_v[:batch_size, :stop]
        else:
            # Split fetch: the newest entries wrapped to [0, spill); the
            # oldest surviving entries start right after them.
            spill = stop - self.block_size

            # Older part of the sequence (prefix).
            keys_1 = self.cache_k[:batch_size, spill:]
            values_1 = self.cache_v[:batch_size, spill:]

            # Newest part of the sequence (suffix).
            keys_2 = self.cache_k[:batch_size, :spill]
            values_2 = self.cache_v[:batch_size, :spill]

            # Compose the whole sequence in chronological order.
            keys = tf.concat([keys_1, keys_2], axis=1)
            values = tf.concat([values_1, values_2], axis=1)

        return keys, values

def update_and_show(cache, batch_size, start, seqlen, data_shape, msg, debug=False):
    """Writes random keys/values into `cache` and prints the key cache.

    Args:
      cache: Object exposing `update(start, xk, xv)` and a `cache_k` attribute.
      batch_size: Batch size of the random update.
      start: Start position passed to `cache.update`.
      seqlen: Sequence length of the random update.
      data_shape: Trailing dimensions appended after `(batch_size, seqlen)`.
      msg: Label used as the heading of the printed output.
      debug: When true, also prints the generated keys before the update.
    """
    update_shape = (batch_size, seqlen, *data_shape)
    # Keys are drawn before values so the random stream order is stable.
    xk = tf.random.uniform(update_shape)
    xv = tf.random.uniform(update_shape)

    if debug:
        print(f'\n{msg}::xk:\n{xk}')

    cache.update(start, xk, xv)
    print(f'\n{msg}:\n{cache.cache_k}')

# cache_size, block_size, heads, head_dims = 1, 8, 1, 4
# cache = KVCache(cache_size, block_size, heads, head_dims)
# data_shape = (heads, head_dims)

# # print(f'\nInitial:\n{cache.cache_k}')

# # update_and_show(cache, cache_size, 2, 4, data_shape, 'InUpdate(2, 4)')
# # update_and_show(cache, cache_size, 4, 4, data_shape, 'EndUpdate(4, 4)')
# # update_and_show(cache, cache_size, 6, 4, data_shape, 'SpilledUpdate(6, 4)', debug=True)
# update_and_show(cache, cache_size, 6, 8, data_shape, 'Fill(6, 8)')
# print(
#     # f'\nInQuery:\n{cache.get(cache_size, 0, 4)[0]}'
#     # f'\n\nEndQuery:\n{cache.get(cache_size, 4, 4)[0]}'
#     f'\n\nSpilledQuery:\n{cache.get(cache_size, 6, 2)[0]}'
# )

# GroupedQueryAttention

In [None]:
import math

import tensorflow as tf

from tensorflow.keras import layers

class GroupedQueryAttention(tf.keras.Model):
    """Grouped-query attention with rotary encodings and an optional KV cache.

    Queries use `heads` heads while keys and values use `kv_heads` heads that
    are tiled up to `heads` before the attention product.

    Args:
      block_size: Number of positions held by the KV cache and RoPE tables.
      heads: Number of query heads.
      kv_heads: Number of key/value heads; falls back to `heads` when falsy.
      dims: Model width; split evenly across `heads`.
      cache_size: Size of the batch axis of the KV cache.

    Raises:
      ValueError: If `dims` is not divisible by `heads`, or `heads` is not
        divisible by the resolved number of key/value heads.
    """

    def __init__(self, block_size, heads, kv_heads, dims, cache_size):
        super().__init__()

        self.heads = heads
        self.dims = dims
        self.kv_heads = kv_heads or heads

        # Fail early: otherwise the reshape / tile in `call` fails later with
        # a far less readable shape error.
        if dims % heads != 0:
            raise ValueError(
                f"`dims` ({dims}) must be divisible by `heads` ({heads})."
            )
        if heads % self.kv_heads != 0:
            raise ValueError(
                f"`heads` ({heads}) must be divisible by `kv_heads` "
                f"({self.kv_heads})."
            )

        self.head_dims = dims // self.heads
        self.block_size = block_size
        self.cache_size = cache_size

        self.wq = layers.Dense(self.dims, use_bias=False)
        self.wk = layers.Dense(self.kv_heads * self.head_dims, use_bias=False)
        self.wv = layers.Dense(self.kv_heads * self.head_dims, use_bias=False)
        self.wo = layers.Dense(self.dims, use_bias=False)

        self.cache = KVCache(self.cache_size, self.block_size, self.kv_heads, self.head_dims)
        self.rope = RotaryPositionalEncodings(self.head_dims, self.block_size)

    def call(self, x, start, use_cache):
        """Computes causal grouped-query attention.

        Args:
          x: Input indexed as `(B, T, dims)`.
          start: Position of the first token of `x`. Must be 0 when
            `use_cache` is false.
          use_cache: Whether to write the new keys/values to the KV cache and
            attend over the cached context.

        Returns:
          A tensor of shape `(B, T, dims)`.
        """
        shape = tf.shape(x)
        B = shape[0]
        T = shape[1]

        # (B, T, dims) for q; (B, T, kv_heads * head_dims) for k and v
        q, k, v = self.wq(x), self.wk(x), self.wv(x)

        # (B, T, heads/kv_heads, head_dims)
        xq = tf.reshape(q, (B, T, self.heads, self.head_dims))
        xk = tf.reshape(k, (B, T, self.kv_heads, self.head_dims))
        xv = tf.reshape(v, (B, T, self.kv_heads, self.head_dims))

        # Apply RoPE
        # (B, T, heads/kv_heads, head_dims)
        # NOTE(review): RoPE is invoked without a position offset, so chunks
        # with start > 0 are rotated as positions 0..T-1 - confirm whether an
        # offset should be passed.
        xq = self.rope(xq)
        xk = self.rope(xk)

        if use_cache:
            # Update KV cache
            self.cache.update(start, xk, xv)

            # Get prefix context from the cache.
            # (B, S, kv_heads, head_dims) where S is the cached context length
            keys, values = self.cache.get(B, start, T)
        else:
            assert start == 0, "`start` must be 0 when `use_cache` is False"
            keys, values = xk, xv

        # Expand kv_heads to heads
        # (B, S, heads, head_dims); tiling maps query head h to key/value
        # head h % kv_heads.
        group_size = self.heads // self.kv_heads
        keys = tf.tile(keys, multiples=(1, 1, group_size, 1))
        values = tf.tile(values, multiples=(1, 1, group_size, 1))

        # Transpose xq, keys and values to (B, heads, T/S, head_dims)
        xq = tf.transpose(xq, perm=[0, 2, 1, 3])
        xk = tf.transpose(keys, perm=[0, 2, 1, 3])
        xv = tf.transpose(values, perm=[0, 2, 1, 3])

        # Multiply xq and xk to compute attention matrix.
        # (B, heads, T, head_dims) @ (B, heads, head_dims, S) -> (B, heads, T, S)
        xa = xq @ tf.transpose(xk, perm=[0, 1, 3, 2]) / math.sqrt(self.head_dims*1.)

        # Auto-regressive mask for both paths. The last T keys line up with
        # the T queries, so query i may see key j iff j - i <= S - T. Without
        # the cache S == T, which is the usual lower-triangular mask; with the
        # cache this stops a multi-token chunk from attending to its own
        # future tokens.
        kv_len = tf.shape(xk)[2]
        query_idx = tf.range(T)[:, None]
        key_idx = tf.range(kv_len)[None, :]
        causal_mask = (key_idx - query_idx) <= (kv_len - T)

        # Masked logits become -inf and receive zero weight; key 0 is always
        # visible, so no row is fully masked.
        scores = tf.nn.softmax(tf.where(causal_mask, xa, float('-inf')))

        # Scale Values and compute output
        # (B, heads, T, S) @ (B, heads, S, head_dims) -> (B, heads, T, head_dims)
        output = scores @ xv

        # Reshape output to (B, T, dims)
        output = tf.reshape(
            tf.transpose(output, perm=[0, 2, 1, 3]),
            shape=(B, T, self.dims)
        )

        return self.wo(output)

# l = GroupedQueryAttention(
#     block_size=16,
#     heads=4,
#     kv_heads=2,
#     dims=16,
#     cache_size=4
# )

# B, T, C = 2, 2, 16
# start = 1
# x = tf.reshape(tf.range(B*T*C), (B, T, C))
# output = l(x, start)

# B_i, T_i, C_i = 0, 0, 3
# H_i, F_i = 0, C_i

# print(
#     f'{x[B_i]=}'
#     f'\n{l.cache.cache_k.shape=}\n{l.cache.cache_k[B_i, :start+T]=}'
#     f'\n{l.cache.cache_v[B_i, :start+T]=}'
# )