In [1]:
import os
import gc
import time
import json
import pickle
import numpy as np

import torch
import jax
import equinox as eqx
import jax.numpy as jnp
import jax.tree_util as jtu

from functools import partial
from equinox._misc import default_floating_dtype
from jaxtyping import Array, Float, Scalar
from typing import Optional, Tuple, List, NamedTuple

from sentencepiece import SentencePieceProcessor

# 1. Tokenizer

In [2]:
class Tokenizer:
    def __init__(self, model_path: str):
        self._model = SentencePieceProcessor(model_file=model_path)

    @property
    def eos_id(self) -> int:
        return self._model.eos_id()

    @property
    def pad_id(self) -> int:
        return self._model.pad_id()

    def encode(self, s: str) -> List[int]:
        return [self._model.bos_id(), *self._model.encode(s)]

    def decode(self, t: List[int]) -> str:
        return self._model.decode(t)

# 2. RoPE

In [3]:
def precompute_frequencies(dim, max_pos, theta=10000.0):
    inv_freq = 1.0 / (
        theta ** (jnp.arange(0, dim, 2, dtype=jnp.float32)[: (dim // 2)] / dim)
    )
    t = jnp.arange(0, max_pos, dtype=jnp.float32)
    freqs = jnp.outer(t, inv_freq)
    return jnp.cos(freqs), jnp.sin(freqs)

def calculate_rope(x, cos_freq, sin_freq):
    # x shape  is [seqlen, num_heads, heads_dim]
    sin = jax.lax.expand_dims(sin_freq, (1,))
    cos = jax.lax.expand_dims(cos_freq, (1,))

    # Get the even-odd positions from the inputs
    x1 = x[..., 0::2]
    x2 = x[..., 1::2]

    # Matmul with the rotation matrix
    # [cos_nθ, -sin_nθ] [x1]
    # [sin_nθ,  cos_nθ] [x2]
    # => [x1 * cos_nθ - x2 * sin_nθ, x1 * sin_nθ + x2 * cos_nθ]
    pos_embed = jnp.stack([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)
    pos_embed = jax.lax.collapse(pos_embed, -2)
    return pos_embed.astype(x.dtype)

# 3. RMSNorm

In [4]:
class RMSNorm(eqx.Module):
    eps: float
    weight: Float[Array, "*shape"]

    def __init__(self, dim, eps, dtype=jnp.bfloat16):
        dtype = default_floating_dtype if dtype is None else dtype
        self.eps = eps
        self.weight = jnp.ones(shape=dim, dtype=dtype)

    def _norm(self, x):
        return x * jax.lax.rsqrt(jnp.mean(x **2 , keepdims=True) + self.eps)

    def __call__(self, x):
        output = self._norm(x.astype(jnp.float32)).astype(x.dtype)
        return output * self.weight

# 4. FeedForward

In [5]:
class FeedForward(eqx.Module):
    w1: eqx.nn.Linear
    w2: eqx.nn.Linear
    w3: eqx.nn.Linear

    def __init__(self, args, key, dtype=jnp.bfloat16):
        dtype = default_floating_dtype if dtype is None else dtype
        key1, key2, key3 = jax.random.split(key, 3)

        self.w1 = eqx.nn.Linear(args.dim, args.hidden_dim, use_bias=False, key=key1, dtype=dtype)
        self.w2 = eqx.nn.Linear(args.hidden_dim, args.dim, use_bias=False, key=key2, dtype=dtype)
        self.w3 = eqx.nn.Linear(args.dim, args.hidden_dim, use_bias=False, key=key3, dtype=dtype)

    def __call__(self, x):
        h = jax.nn.silu(self.w1(x).astype(jnp.float32)).astype(x.dtype)
        return self.w2(h * self.w3(x))

# 5. Attention

In [6]:
class Attention(eqx.Module):
    dim: int
    n_heads: int
    head_dim: int
    n_kv_heads: int
    kv_repeats: int
    sliding_window: int
    scale: float
    split_sizes: Tuple
    wqkv: eqx.nn.Linear
    wo: eqx.nn.Linear

    def __init__(self, args, key, dtype=jnp.bfloat16):
        dtype = default_floating_dtype if dtype is None else dtype
        key1, key2 = jax.random.split(key, 2)

        self.n_heads = args.n_heads
        self.head_dim = args.head_dim
        self.n_kv_heads = args.n_kv_heads
        self.dim = args.dim
        self.kv_repeats = self.n_heads // self.n_kv_heads
        self.sliding_window = args.sliding_window
        self.scale = args.head_dim**-0.5
        total_head_dim = (args.n_heads + 2 * args.n_kv_heads) * args.head_dim
        self.split_sizes = (args.n_heads * args.head_dim, (args.n_heads * args.head_dim) + (args.n_kv_heads * args.head_dim))
        self.wqkv = eqx.nn.Linear(args.dim, total_head_dim, use_bias=False, key=key1, dtype=dtype)
        self.wo = eqx.nn.Linear(args.n_heads * args.head_dim, args.dim, use_bias=False, key=key2, dtype=dtype)

    def compute_scores_and_output(self, xq, key, value, mask, seqlen, pos_mask):
        query = jnp.transpose(xq, (1, 0, 2))
        key = jnp.transpose(key, (1, 0, 2))
        value = jnp.transpose(value, (1, 0, 2))

        # # # scores : [n_heads, seqlen | 1, seqlen]
        scores = jnp.matmul(query, jnp.transpose(key, (0, 2, 1))) * self.scale
        if pos_mask is not None:
            scores = jnp.where(pos_mask, -jnp.inf, scores)

        if mask is not None:
            # Mask will of shape [seqlen, seqlen] but our scores
            # have shape [num_heads, seqlen, seqlen], hence we need
            # to introduce another dimension in the mask
            mask = mask[jnp.newaxis, ...]
            scores = scores + mask

        scores = jax.nn.softmax(scores.astype(jnp.float32), axis=-1).astype(query.dtype)
        output = jnp.matmul(scores, value)
        output = jnp.reshape(jnp.transpose(output, (1, 0, 2)), (seqlen, -1))
        output = jax.vmap(self.wo)(output)
        return output

    def __call__(self,  x, cos_freq, sin_freq, positions, mask=None, cache_k=None, cache_v=None):
        # x shape: [seqlen, embed_dim]
        seqlen = x.shape[0]

        xqkv = jax.vmap(self.wqkv)(x)
        xq, xk, xv = jnp.split(xqkv, self.split_sizes, axis=-1)

        xq = jnp.reshape(xq, (seqlen, self.n_heads, self.head_dim))
        xk = jnp.reshape(xk, (seqlen, self.n_kv_heads, self.head_dim))
        xv = jnp.reshape(xv, (seqlen, self.n_kv_heads, self.head_dim))

        xq = calculate_rope(xq, cos_freq, sin_freq)
        xk = calculate_rope(xk, cos_freq, sin_freq)

        if positions.shape[0] > 1:
            # prefill
            cache_k = cache_k.at[positions, :, :].set(xk[positions, :, :], mode="drop")
            cache_v = cache_v.at[positions, :, :].set(xv[positions, :, :], mode="drop")
            key = jnp.repeat(xk, self.kv_repeats, axis=1)
            value = jnp.repeat(xv, self.kv_repeats, axis=1)
            output = self.compute_scores_and_output(xq, key, value, mask, seqlen, None)
        else:
            # single-token generation
            one_hot_indices = jax.nn.one_hot(positions, self.sliding_window, dtype=cache_k.dtype).reshape(self.sliding_window, 1, 1)
            # the `where` update is only necessary if you are calling the cache multiple times with the same prompt
            # Ideally, we expect that you flush out the cache with the new prompt, and start over.
            # What does this do? It ensures that we are not adding any values updated earlier 
            # with the new updates, meaning we are always replacing the value not updating it.
            # For example, if prompt had a length of 6, and you want to generate 7th token, this
            # ensures that we are not adding the old value of 7th token to the updated value as
            # it would lead to wrong results. 
            # In case, you are flushing the cache after every prompt, remove the `jnp.where()` condition
            # and pass the updates directly to cache_k, and cache_v respectively 
            # i.e. cache_k = cache_k + xk * one_hot_indices
            # and cache_v = cache_v + xv * one_hot_indices
            k_updates = cache_k + xk * one_hot_indices
            v_updates = cache_v + xv * one_hot_indices
            cache_k = jnp.where(cache_k, cache_k, k_updates)
            cache_v = jnp.where(cache_v, cache_v, v_updates)
        
            cur_pos = positions[-1] + 1
            causal_mask = jnp.broadcast_to(jnp.arange(self.sliding_window) >= cur_pos,(1, 1, self.sliding_window)).reshape(self.sliding_window,1,1)
            key = jnp.repeat(jnp.where(causal_mask, 0, cache_k), axis=1, repeats=self.kv_repeats)
            value = jnp.repeat(jnp.where(causal_mask, 0, cache_v), axis=1, repeats=self.kv_repeats)
            output = self.compute_scores_and_output(xq, key, value, mask, seqlen, causal_mask.reshape((1, 1, self.sliding_window)))

        return output, cache_k, cache_v

# 6. TransformerBlock

In [7]:
class TransformerBlock(eqx.Module):
    dim: int
    n_heads: int
    attention: Attention
    attention_norm: RMSNorm
    feed_forward: FeedForward
    ffn_norm: RMSNorm

    def __init__(self, args, key, dtype=jnp.bfloat16):
        key1, key2 = jax.random.split(key, 2)
        self.n_heads = args.n_heads
        self.dim = args.dim

        self.attention = Attention(args, key=key1, dtype=dtype)
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps, dtype=dtype)

        self.feed_forward = FeedForward(args, key=key2, dtype=dtype)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps, dtype=dtype)

    def __call__(self, x, cos_freq, sin_freq, positions, mask, cache_k, cache_v):
        normed_x = jax.vmap(self.attention_norm)(x)
        r, cache_k, cache_v = self.attention(normed_x, cos_freq, sin_freq, positions, mask, cache_k, cache_v)
        h = x + r
        r = jax.vmap(self.feed_forward)(jax.vmap(self.ffn_norm)(h))
        out = h +r
        return out, cache_k, cache_v

# 7. Transformer

In [8]:
class Transformer(eqx.Module):
    tok_embeddings: eqx.nn.Embedding
    layers: TransformerBlock
    norm: RMSNorm
    output: eqx.nn.Linear
    vocab_size: int
    n_layers: int
    sliding_window: int

    def __init__(self, args, key, dtype=jnp.bfloat16):
        self.vocab_size = args.vocab_size
        self.n_layers = args.n_layers
        self.sliding_window = args.sliding_window
        keys = jax.random.split(key, args.n_layers + 2)
        embed_key, linear_key, tf_layers_keys = keys[0], keys[1], keys[2:]

        self.tok_embeddings = eqx.nn.Embedding(args.vocab_size, args.dim, key=embed_key, dtype=dtype)
        self.norm = RMSNorm(dim=args.dim, eps=args.norm_eps, dtype=dtype)
        self.output = eqx.nn.Linear(args.dim, args.vocab_size, use_bias=False, key=linear_key, dtype=dtype)
        make_layers = lambda k: TransformerBlock(args, key=k, dtype=dtype)
        self.layers = eqx.filter_vmap(make_layers)(tf_layers_keys)
        del make_layers

    def compute_mask(self, seqlen):
        t = jnp.full((seqlen, seqlen), dtype=jnp.bfloat16, fill_value=1)
        mask = jnp.tril(t, k=0)
        # make the mask banded to account for sliding window
        mask = jnp.log(jnp.triu(mask, k=-self.sliding_window))
        return mask


    def __call__(self, x, cos_freq, sin_freq, positions, mask, cache_k, cache_v):
        # x is of shape (seqlen, )
        h = jax.vmap(self.tok_embeddings)(x)
        
        if x.shape[-1] > 1:
            seqlen = x.shape[-1]
            mask = self.compute_mask(seqlen)
        else:
            mask = None

        dynamic_layers, static_layers = eqx.partition(self.layers, eqx.is_array)
        
        def f(_x, _dynamic_l):
            layer = eqx.combine(_dynamic_l, static_layers)
            h, cache_k, cache_v, layer_idx = _x
            h, cache_ki, cache_vi = layer(
                h,
                cos_freq,
                sin_freq,
                positions,
                mask,
                cache_k[layer_idx, ...],
                cache_v[layer_idx, ...],
            )
            cache_k = cache_k.at[layer_idx, :, :, :].set(cache_ki)
            cache_v = cache_v.at[layer_idx, :, :, :].set(cache_vi)
            return (h, cache_k, cache_v, layer_idx + 1), None

        layer_idx = 0
        (h, cache_k, cache_v, layer_idx), _ = jax.lax.scan(f, (h, cache_k, cache_v, layer_idx), dynamic_layers)

        h = jax.vmap(self.norm)(h)
        h = jax.vmap(self.output)(h).astype(jnp.float32)
        return h, cache_k, cache_v

In [9]:
class ModelArgs(NamedTuple):
    dim: int
    n_layers: int
    n_heads: int
    n_kv_heads: int
    head_dim: int
    hidden_dim: int
    vocab_size: int
    sliding_window: int
    norm_eps: float
    max_batch_size: int = 1


with open('./mistral-7B-v0.1/params.json', 'r') as f:
    args = ModelArgs(**json.loads(f.read()))

In [10]:
model = Transformer(args, key=jax.random.PRNGKey(1), dtype=jnp.bfloat16)
model = eqx.tree_deserialise_leaves("mistral7B_jax_port_fast.eqx", model)
model = eqx.filter_vmap(eqx.filter_jit(model), in_axes=(0, None, None, None, None, 0, 0))
print("Model weights loaded successfully!")

Model weights loaded successfully!


In [11]:
cache_k = jnp.zeros((args.max_batch_size, args.n_layers, args.sliding_window, args.n_kv_heads, args.head_dim), dtype=jnp.bfloat16)
cache_v = jnp.zeros((args.max_batch_size, args.n_layers, args.sliding_window, args.n_kv_heads, args.head_dim), dtype=jnp.bfloat16)
cos_freq, sin_freq = precompute_frequencies(args.head_dim, 128000)

In [12]:
fake_pos = jnp.array([0, 1, 2, 3, 4], dtype=jnp.int32)
fake_inp = jnp.asarray([[1,  832,  349,  265, 1369]], dtype=jnp.int32)
fake_mask = None
fake_pos_padded = jnp.pad(fake_pos, (0, args.sliding_window - len(fake_pos)), constant_values=-1)

# warmup for prefilling
_ = model(fake_inp, cos_freq[fake_pos], sin_freq[fake_pos], fake_pos_padded, fake_mask, cache_k, cache_v)

# warmup for generation
fake_pos = jnp.array([5], dtype=jnp.int32)
fake_inp = jnp.asarray([[1369]], dtype=jnp.int32)
fake_mask = None
_ = model(fake_inp, cos_freq[fake_pos], sin_freq[fake_pos], fake_pos_padded, fake_mask, cache_k, cache_v)

In [13]:
tokenizer = Tokenizer("mistral-7B-v0.1/tokenizer.model")
print("Tokenizer loaded successfully!")

Tokenizer loaded successfully!


In [14]:
def generate(prompts, model, tokenizer, max_tokens=36):
    """Generates completion of length `max_tokens` for a list of given prompts."""

    cache_shape = (args.max_batch_size, args.n_layers, args.sliding_window, args.n_kv_heads, args.head_dim)
    
    for prompt in prompts:
        # 1. Encode the prompt
        encoded = tokenizer.encode(prompt)
        cur_pos = len(encoded)

        # 2. We need to flush the cache with every prompt. 
        # Need a better way to do this but for now it's okay!
        cache_k = jnp.zeros(cache_shape, dtype=jnp.bfloat16)
        cache_v = jnp.zeros(cache_shape, dtype=jnp.bfloat16)

        # 3. pre-fill
        positions = jnp.arange(0, cur_pos)
        positions_padded = jnp.pad(positions, (0, args.sliding_window - len(positions)), constant_values=-1)
        print("Prefilling...", end="   ")
        start = time.time()
        logits, cache_k, cache_v = model(
            jnp.asarray([encoded]),
            cos_freq[positions],
            sin_freq[positions],
            positions_padded,
            None,
            cache_k,
            cache_v
        )
        print(f"Time taken : {time.time()- start :.2f} seconds")
        logprobs = jax.nn.log_softmax(logits, axis=-1)
        next_token = jnp.argmax(logprobs[:, -1,:], axis=-1)

        # 4. Generation
        generated = [next_token[0].item()]
        print("Generating...", end="   ")
        overall_start = time.time()
        for t in range(max_tokens):
            cur_pos+=1
            pos = jnp.array([cur_pos])
            logits, cache_k, cache_v = logits, cache_k, cache_v = model(
                jax.lax.expand_dims(next_token, (1,)),
                cos_freq[pos],
                sin_freq[pos],
                pos,
                None,
                cache_k,
                cache_v
            )
            logprobs = jax.nn.log_softmax(logits[:, -1, :], axis=-1)
            next_token = jnp.argmax(logprobs, axis=-1)
            generated.append(next_token[0].item())
    
        end = time.time()
        print(f"Time taken to generate {max_tokens} tokens: {end- overall_start:.2f} seconds")
        print("\nPrompt     : ", prompt)
        print("Completion :", end=" ")
        res = prompt + " " + "".join(tokenizer.decode(generated))
        print(repr(res))
        print("\n", "="*75)

In [16]:
prompts = [
    "This is a test",
    "Hello, I am a language model,",
    "I am a helpful assistant"
]
generate(prompts, model, tokenizer, max_tokens=64)

Prefilling...   Time taken : 0.01 seconds
Generating...   Time taken to generate 64 tokens: 1.66 seconds

Prompt     :  This is a test
Completion : 'This is a test of the emergency broadcast system.\n\nThis is only a test.\n\nIf this had been an actual emergency, you would have been instructed where to go and what to do.\n\nThis is only a test.\n\nThis is a test of the emergency broadcast system.\n\nThis is only a test'

Prefilling...   Time taken : 0.01 seconds
Generating...   Time taken to generate 64 tokens: 1.63 seconds

Prompt     :  Hello, I am a language model,
Completion : 'Hello, I am a language model, and I am here to help you with your writing. I can provide you with a variety of writing prompts to get your creative juices flowing.\n\n## Introduction\n\nWriting prompts are a great way to get your creative juices flowing. They can help you come up with new ideas for stories, poems'

Prefilling...   Time taken : 0.01 seconds
Generating...   Time taken to generate 64 tokens: 1.