In [None]:
# Environment setup: install the SentencePiece tokenizer library and fetch the
# Mistral-7B v0.1 checkpoint archive. Later cells read tokenizer.model,
# params.json and consolidated.00.pth from ./mistral-7B-v0.1 (presumably the
# folder this tarball extracts to — confirm).
!pip install sentencepiece
!wget https://files.mistral-7b-v0-1.mistral.ai/mistral-7B-v0.1.tar
!tar -xf mistral-7B-v0.1.tar

In [None]:
from pathlib import Path
import json
import math
from dataclasses import dataclass
import torch
from torch import nn
from sentencepiece import SentencePieceProcessor

# 1) Mistral LLM
[Mistral Github](https://github.com/mistralai/mistral-src/blob/main/one_file_ref.py)

TODO: Explore the original rotating-buffer cache and chunked prefill (currently using the one-file reference implementation).

## 1.1 Config

In [None]:
@dataclass
class ModelArgs:
    """Hyper-parameters of the Mistral transformer.

    Field names match the keys of the checkpoint's ``params.json`` so an
    instance can be created with ``ModelArgs(**params)``.
    """

    dim: int  # model / embedding width
    n_layers: int  # number of transformer blocks
    head_dim: int  # width of one attention head
    hidden_dim: int  # inner width of the feed-forward network
    n_heads: int  # number of query heads
    n_kv_heads: int  # number of key/value heads (grouped-query attention)
    sliding_window: int  # attention window and rolling KV-cache length
    norm_eps: float  # epsilon used by RMSNorm
    vocab_size: int  # tokenizer vocabulary size
    max_batch_size: int = 0  # batch capacity of the KV cache; set by the caller

# Folder of the extracted checkpoint (tokenizer.model, params.json, consolidated.00.pth).
model_path = './mistral-7B-v0.1'
# Toy prompts used by the tokenization and model demos below.
data_set = [
    "Lucky is a dog from my neighbour.",
    "She likes to play ball with Lucy.",
    "Lucy is her best friend.",
]
# Only referenced by the commented-out distributed setup in the next cell;
# nothing else in the visible notebook uses it.
num_pipeline_ranks = 2

In [None]:
# if num_pipeline_ranks > 1:
#     torch.distributed.init_process_group()
#     torch.cuda.set_device(torch.distributed.get_rank())
#     should_print = torch.distributed.get_rank() == 0

## 1.2 Tokenization

In [None]:
class Tokenizer:
    """SentencePiece tokenizer wrapper exposing the special ids and helpers the demo uses."""

    def __init__(self, model_path: str):
        # fail early (and report the path) when the tokenizer file is missing
        assert Path(model_path).exists(), model_path
        sp = SentencePieceProcessor(model_file=model_path)
        assert sp.vocab_size() == sp.get_piece_size()
        self._model = sp

    @property
    def n_words(self):
        """Number of entries in the vocabulary."""
        return self._model.vocab_size()

    @property
    def bos_id(self) -> int:
        """Id of the beginning-of-sequence token."""
        return self._model.bos_id()

    @property
    def eos_id(self):
        """Id of the end-of-sequence token."""
        return self._model.eos_id()

    @property
    def pad_id(self):
        """Padding id reported by SentencePiece (this checkpoint reports -1)."""
        return self._model.pad_id()

    def encode(self, s: str, bos: bool = True):
        """Convert ``s`` to token ids, prefixed with the BOS id unless ``bos`` is false."""
        ids = self._model.encode(s)
        return [self.bos_id, *ids] if bos else ids

    def decode(self, t):
        """Convert a sequence of token ids back to text."""
        return self._model.decode(t)

In [None]:
# encode each sentence (Tokenizer.encode prepends the BOS id by default)
tokenizer = Tokenizer(f'{model_path}/tokenizer.model')
encoded_prompts = [tokenizer.encode(prompt) for prompt in data_set]
print(encoded_prompts)

# allocate a (batch, max_prompt_len) tensor pre-filled with pad_id
# (this tokenizer reports pad_id == -1, as the printed output shows)
prompt_lens = [len(x) for x in encoded_prompts]
min_prompt_len = min(prompt_lens)
max_prompt_len = max(prompt_lens)
print('Length: ', min_prompt_len, max_prompt_len)
input_tokens = torch.full(
    (len(data_set), max_prompt_len),
    tokenizer.pad_id,
    dtype=torch.long,
    device="cuda"
)
print(f"Empty Tensor: {input_tokens}")

# inserting tokens into the tensor: row i gets prompt i's ids, the tail stays pad_id
for i, encoded in enumerate(encoded_prompts):
    input_tokens[i, :len(encoded)] = torch.tensor(encoded).to(input_tokens)
print(f"Tokenized Tensor: {input_tokens}")
# True where a real prompt token is present, False on padding
input_mask = input_tokens != tokenizer.pad_id

# positions 0..min_prompt_len-1 for the rotary embedding of the shared prefix;
# later cells feed only the first min_prompt_len columns, which contain no padding
positions = torch.arange(0, min_prompt_len).to("cuda")
print(f"Position: {positions}")

[[1, 393, 11791, 349, 264, 3914, 477, 586, 18583, 28723], [1, 985, 12672, 298, 1156, 4374, 395, 18010, 28723], [1, 18010, 349, 559, 1489, 1832, 28723]]
Length:  7 10
Empty Tensor: tensor([[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
        [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
        [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1]], device='cuda:0')
Tokenized Tensor: tensor([[    1,   393, 11791,   349,   264,  3914,   477,   586, 18583, 28723],
        [    1,   985, 12672,   298,  1156,  4374,   395, 18010, 28723,    -1],
        [    1, 18010,   349,   559,  1489,  1832, 28723,    -1,    -1,    -1]],
       device='cuda:0')
Position: tensor([0, 1, 2, 3, 4, 5, 6], device='cuda:0')


## 1.3 Model

In [None]:
# Small configuration for a quick demo with a randomly initialised model.
# sliding_window=3 is shorter than every prompt, so the rolling KV cache
# wraps around during the demo.
args = ModelArgs(
    dim=512,
    n_layers=1,
    head_dim=128,
    hidden_dim=2048,
    n_heads=4,
    n_kv_heads=2,
    sliding_window=3,
    norm_eps=1e-5,
    vocab_size=32_000,
    max_batch_size=3,
)

# For comparison: the released Mistral-7B v0.1 configuration
# {
#     "dim": 4096,
#     "n_layers": 32,
#     "head_dim": 128,
#     "hidden_dim": 14336,
#     "n_heads": 32,
#     "n_kv_heads": 8,
#     "norm_eps": 1e-05,
#     "sliding_window": 4096,
#     "vocab_size": 32000
# }

In [None]:
def _reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):

    # 4 dimension
    ndim = x.ndim

    # (m, seq_len, n_head, head_dim // 2) --> (1, seq_len, 1, head_dim // 2)
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]

    # (seq_len, head_dim // 2) --> (1, seq_len, 1, head_dim // 2)
    return freqs_cis.view(*shape)

def apply_rotary_emb(xq, xk, freqs_cis):
    """Apply rotary position embeddings (RoPE) to queries and keys.

    Consecutive feature pairs (2i, 2i+1) are treated as one complex number and
    multiplied by the matching complex phase in ``freqs_cis``.

    xq: (m, seq_len, n_heads, head_dim); xk: (m, seq_len, n_kv_heads, head_dim)
    freqs_cis: (seq_len, head_dim // 2), complex
    Returns the rotated (xq, xk) with the same shapes and dtypes as the inputs.
    """
    def as_complex(t):
        # (..., head_dim) -> (..., head_dim // 2, 2) -> complex (..., head_dim // 2)
        return torch.view_as_complex(t.float().reshape(*t.shape[:-1], -1, 2))

    q_complex = as_complex(xq)
    k_complex = as_complex(xk)
    # phases viewed as (1, seq_len, 1, head_dim // 2) to broadcast over batch and heads
    phases = _reshape_for_broadcast(freqs_cis, q_complex)

    q_rotated = torch.view_as_real(q_complex * phases).reshape(*xq.shape)
    k_rotated = torch.view_as_real(k_complex * phases).reshape(*xk.shape)
    return q_rotated.type_as(xq), k_rotated.type_as(xk)

def repeat_kv(keys: torch.Tensor, values: torch.Tensor, repeats: int, dim: int = 2):
    """Expand key/value heads to line up with the query heads (grouped-query attention).

    Each key/value head is repeated ``repeats`` times in place (interleaved), so
    head ``h`` of the result comes from key/value head ``h // repeats``.

    Args:
        keys, values: tensors whose head axis is ``dim``. With the default
            ``dim=2`` this is the (m, seq_len, n_kv_heads, head_dim) layout used
            in this notebook.
        repeats: n_heads // n_kv_heads.
        dim: axis holding the heads. Defaults to 2. The flattened
            (tokens, n_kv_heads, head_dim) layout of the reference
            implementation uses ``dim=1``.

    Returns:
        (keys, values) with the head axis multiplied by ``repeats``.
    """
    keys = torch.repeat_interleave(keys, repeats=repeats, dim=dim)
    values = torch.repeat_interleave(values, repeats=repeats, dim=dim)
    return keys, values

**Sliding window attention**

![](https://miro.medium.com/v2/resize:fit:1400/0*uJ9qfE3Ik92XnEdz)

Information Flow from Input to Upper Layer.
Note that tokens outside the sliding window still influence next word prediction. At each attention layer, information can move forward by W tokens at most: after two attention layers, information can move forward by 2W tokens, etc.

For instance in a sequence of length 16K and a sliding window of 4K, after 4 layers, information has propagated to the full sequence length.

**Rolling buffer cache**

![](https://github.com/mistralai/mistral-src/raw/main/assets/rolling_cache.png)

The cache has a fixed size of W, and the (key, value) pair for position i is stored in cache slot i % W. Once i ≥ W, older entries in the cache are overwritten.

In [None]:
class Attention(nn.Module):
    """Grouped-query attention with a sliding-window rolling KV cache.

    Prefill (more than one position): attends over the current keys/values
    under the caller-supplied banded causal ``mask``. The last
    ``sliding_window`` keys/values are written into the cache.
    Decode (exactly one position): writes the new key/value into the cache,
    then attends over the cached entries without a mask.
    Position ``p`` is stored in cache slot ``p % sliding_window``.
    """

    def __init__(self, args: "ModelArgs"):
        super().__init__()
        self.args = args

        self.n_heads = args.n_heads
        self.n_kv_heads = args.n_kv_heads
        # number of query heads sharing one key/value head
        self.repeats = self.n_heads // self.n_kv_heads
        self.sliding_window = self.args.sliding_window
        self.scale = self.args.head_dim**-0.5

        self.wq = nn.Linear(args.dim, args.n_heads * args.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
        self.wv = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
        self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, bias=False)

        # Only the sliding window is cached, not the whole sequence.
        # The caches are registered as non-persistent buffers so that
        # ``module.to(device, dtype)`` moves and casts them together with the
        # weights. A float16 model needs a float16 cache, otherwise ``scatter_``
        # below fails on the dtype mismatch. Being non-persistent keeps them out
        # of ``state_dict``, so checkpoint loading is unaffected.
        cache_shape = (
            args.max_batch_size,
            args.sliding_window,
            self.n_kv_heads,
            self.args.head_dim,
        )
        self.register_buffer("cache_k", torch.empty(cache_shape, dtype=torch.float32), persistent=False)
        self.register_buffer("cache_v", torch.empty(cache_shape, dtype=torch.float32), persistent=False)

    def forward(self, x, freqs_cis, positions, mask):
        """
        x: torch Tensor(m, seq_len, dim)
        freqs_cis: torch Tensor(seq_len, head_dim // 2)
        positions: torch Tensor(seq_len), absolute positions of x's tokens
        mask: torch Tensor(seq_len, seq_len) or None (None during 1-token decode)
        Returns: torch Tensor(m, seq_len, dim)
        """
        bsz, seqlen, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

        ### grouped query attention ###
        # (m, seq_len, n_heads, head_dim) for queries, (m, seq_len, n_kv_heads, head_dim) for keys/values
        xq = xq.view(bsz, seqlen, self.n_heads, self.args.head_dim)
        xk = xk.view(bsz, seqlen, self.n_kv_heads, self.args.head_dim)
        xv = xv.view(bsz, seqlen, self.n_kv_heads, self.args.head_dim)

        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

        # Rolling buffer: keep only the last `sliding_window` positions, position p in slot p % W.
        # e.g. W=3, positions [0..6] -> last three [4, 5, 6] -> slots [1, 2, 0]
        scatter_pos = (positions[-self.sliding_window:] % self.sliding_window)[None, :, None, None]
        # (m, min(seq_len, W), n_kv_heads, head_dim)
        scatter_pos = scatter_pos.repeat(bsz, 1, self.n_kv_heads, self.args.head_dim)
        self.cache_k[:bsz].scatter_(dim=1, index=scatter_pos, src=xk[:, -self.sliding_window:])
        self.cache_v[:bsz].scatter_(dim=1, index=scatter_pos, src=xv[:, -self.sliding_window:])

        if positions.shape[0] > 1:
            # prefill: attend over all current keys; the banded mask enforces the window
            key, value = repeat_kv(xk, xv, self.repeats)
        else:
            # decode: read the cache. Assuming the sequence started at position 0,
            # slots [0, cur_pos) are filled while cur_pos < W, and every slot afterwards.
            cur_pos = positions[-1].item() + 1
            key, value = repeat_kv(
                self.cache_k[:bsz, :cur_pos, ...], self.cache_v[:bsz, :cur_pos, ...], self.repeats
            )

        # (m, n_heads, seq_len, head_dim)
        query = xq.transpose(1, 2)
        key = key.transpose(1, 2)
        value = value.transpose(1, 2)

        # (m, n_heads, seq_len | 1, kv_len)
        scores = torch.matmul(query, key.transpose(2, 3)) * self.scale

        if mask is not None:
            # mask broadcast as (1, 1, seq_len, seq_len); 0 = keep, -inf = masked
            scores += mask[None, None, ...]
        scores = scores.float()
        scores = nn.functional.softmax(scores, dim=-1).type_as(query)

        # (m, n_heads, seq_len, head_dim) -> (m, seq_len, n_heads * head_dim)
        output = torch.matmul(scores, value)
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        return self.wo(output)

In [None]:
### kv cache in the original model ###
'''
from xformers.ops.fmha.attn_bias import (
    AttentionBias,
    BlockDiagonalCausalMask,
    BlockDiagonalCausalWithOffsetPaddedKeysMask,
    BlockDiagonalMask,
)


@dataclass
class RotatingCacheInputMetadata:
    # rope absolute positions
    positions: torch.Tensor
    # which elements in the sequences need to be cached
    to_cache_mask: torch.Tensor
    # how many elements are cached per sequence
    cached_elements: torch.Tensor
    # where tokens should go in the cache
    cache_positions: torch.Tensor

    # if prefill, use block diagonal causal mask
    # else use causal with padded key mask
    prefill: bool
    mask: AttentionBias
    seqlens: list[int]


def interleave_list(l1, l2):
    assert len(l1) == len(l2)
    return [v for pair in zip(l1, l2) for v in pair]

def unrotate(cache: torch.Tensor, seqlen: int):
    assert cache.ndim == 3  # (W, H, D)
    position = seqlen % cache.shape[0]
    if seqlen < cache.shape[0]:
        return cache[:seqlen]
    elif position == 0:
        return cache
    else:
        return torch.cat([cache[position:], cache[:position]], dim=0)

class CacheView:
    def __init__(self, cache_k, cache_v, metadata, kv_seqlens):
        self.cache_k = cache_k
        self.cache_v = cache_v
        self.kv_seqlens = kv_seqlens
        self.metadata = metadata

    def update(self, xk, xv):
        """
        to_cache_mask masks the last [sliding_window] tokens in each sequence
        """
        n_kv_heads, head_dim = self.cache_k.shape[-2:]
        flat_cache_k = self.cache_k.view(-1, n_kv_heads, head_dim)
        flat_cache_v = self.cache_v.view(-1, n_kv_heads, head_dim)

        flat_cache_k.index_copy_(0, self.metadata.cache_positions, xk[self.metadata.to_cache_mask])
        flat_cache_v.index_copy_(0, self.metadata.cache_positions, xv[self.metadata.to_cache_mask])

    def interleave_kv(self, xk, xv):
        """
        This is a naive implementation and not optimized for speed.
        """
        assert xk.ndim == xv.ndim == 3 # (B * T, H, D)
        assert xk.shape == xv.shape

        if all([s == 0 for s in self.metadata.seqlens]):
            # No cache to interleave
            return xk, xv

        # Make it a list of [(T, H, D)]
        xk = torch.split(xk, self.metadata.seqlens)
        xv = torch.split(xv, self.metadata.seqlens)

        # Order elements in cache by position by unrotating
        cache_k = [unrotate(t, s) for t, s in zip(self.cache_k, self.kv_seqlens)]
        cache_v = [unrotate(t, s) for t, s in zip(self.cache_v, self.kv_seqlens)]

        interleaved_k = interleave_list(cache_k, xk)
        interleaved_v = interleave_list(cache_v, xv)

        return torch.cat(interleaved_k, dim=0), torch.cat(interleaved_v, dim=0)

    @property
    def sliding_window(self):
        return self.cache_k.shape[1]

    @property
    def key(self):
        return self.cache_k[:len(self.kv_seqlens)]

    @property
    def value(self):
        return self.cache_v[:len(self.kv_seqlens)]

    @property
    def prefill(self):
        return self.metadata.prefill

    @property
    def mask(self):
        return self.metadata.mask


class RotatingBufferCache:
    """
    This is an example that implements a less naive rotating buffer cache, allowing for variable length sequences.
    Allocated cache is rectangular which is wasteful (see PagedAttention for better mechanisms)
    """
    def __init__(self, n_layers, max_batch_size, sliding_window, n_kv_heads, head_dim):

        self.sliding_window = sliding_window
        self.n_kv_heads = n_kv_heads
        self.head_dim = head_dim

        self.cache_k = torch.empty((
            n_layers,
            max_batch_size,
            sliding_window,
            n_kv_heads,
            head_dim
        ))
        self.cache_v = torch.empty((
            n_layers,
            max_batch_size,
            sliding_window,
            n_kv_heads,
            head_dim
        ))
        # holds the valid length for each batch element in the cache
        self.kv_seqlens = None

    def get_view(self, layer_id: int, metadata: RotatingCacheInputMetadata) -> CacheView:
        return CacheView(self.cache_k[layer_id], self.cache_v[layer_id], metadata, self.kv_seqlens)

    def reset(self):
        self.kv_seqlens = None

    def init_kvseqlens(self, batch_size: int):
        self.kv_seqlens = torch.zeros((batch_size,), device=self.device, dtype=torch.long)

    @property
    def device(self):
        return self.cache_k.device

    def to(self, device: torch.device, dtype: torch.dtype):
        self.cache_k = self.cache_k.to(device=device, dtype=dtype)
        self.cache_v = self.cache_v.to(device=device, dtype=dtype)

        return self

    def update_seqlens(self, seqlens: List[int]):
        self.kv_seqlens += torch.tensor(seqlens, device=self.device, dtype=torch.long)

    def get_input_metadata(self, seqlens: List[int]):

        if self.kv_seqlens is None:
            self.init_kvseqlens(len(seqlens))
        seqpos = self.kv_seqlens.tolist()
        masks = [
            [x >= seqlen - self.sliding_window for x in range(seqlen)]
                for seqlen in seqlens
        ]
        to_cache_mask = torch.tensor(sum(masks, []), device=self.device, dtype=torch.bool)
        cached_elements = torch.tensor([sum(mask) for mask in masks], device=self.device, dtype=torch.long)
        positions = [torch.arange(pos, pos + seqlen) for pos, seqlen in zip(seqpos, seqlens)]
        positions = torch.cat(positions).to(device=self.device, dtype=torch.long)
        batch_idx = torch.tensor(
            sum([[i]*seqlen for i, seqlen in enumerate(seqlens)], []),
            device=self.device,
            dtype=torch.long,
        )
        cache_positions = positions % self.sliding_window + batch_idx * self.sliding_window

        first_prefill = seqpos[0] == 0
        subsequent_prefill = any(seqlen > 1 for seqlen in seqlens)
        if first_prefill:
            mask = BlockDiagonalCausalMask.from_seqlens(seqlens).make_local_attention(self.sliding_window)
        elif subsequent_prefill:
            mask = BlockDiagonalMask.from_seqlens(
                q_seqlen=seqlens,
                kv_seqlen=[
                    s + cached_s.clamp(max=self.sliding_window).item()
                        for (s, cached_s) in zip(seqlens, self.kv_seqlens)
                ]
            ).make_local_attention_from_bottomright(self.sliding_window)
        else:
            mask = BlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens(
                q_seqlen=seqlens,
                kv_padding=self.sliding_window,
                kv_seqlen=(self.kv_seqlens + cached_elements).clamp(max=self.sliding_window).tolist()
            )

        return RotatingCacheInputMetadata(
            positions=positions,
            to_cache_mask=to_cache_mask,
            cached_elements=cached_elements,
            cache_positions=cache_positions[to_cache_mask],
            prefill=first_prefill or subsequent_prefill,
            mask=mask,
            seqlens=seqlens,
        )

class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args

        self.n_heads: int = args.n_heads
        self.head_dim: int = args.head_dim
        self.n_kv_heads: int = args.n_kv_heads
        self.repeats = self.n_heads // self.n_kv_heads
        self.scale = self.args.head_dim**-0.5

        self.wq = nn.Linear(args.dim, args.n_heads * args.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
        self.wv = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
        self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, bias=False)

    def forward(self, x, freqs_cis, cache):

        seqlen_sum, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        xq = xq.view(seqlen_sum, self.n_heads, self.head_dim)
        xk = xk.view(seqlen_sum, self.n_kv_heads, self.head_dim)
        xv = xv.view(seqlen_sum, self.n_kv_heads, self.head_dim)
        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

        if cache is None:
            key, val = xk, xv
        elif cache.prefill:
            key, val = cache.interleave_kv(xk, xv)
            cache.update(xk, xv)
        else:
            cache.update(xk, xv)
            key, val = cache.key, cache.value
            key = key.view(seqlen_sum * cache.sliding_window, self.n_kv_heads, self.head_dim)
            val = val.view(seqlen_sum * cache.sliding_window, self.n_kv_heads, self.head_dim)

        # Repeat keys and values to match number of query heads
        key, val = repeat_kv(key, val, self.repeats, dim=1)

        # xformers requires (B=1, S, H, D)
        xq, key, val = xq[None, ...], key[None, ...], val[None, ...]
        output = memory_efficient_attention(xq, key, val, None if cache is None else cache.mask)
        return self.wo(output.view(seqlen_sum, self.n_heads * self.head_dim))

'''

'\nfrom xformers.ops.fmha.attn_bias import (\n    AttentionBias,\n    BlockDiagonalCausalMask,\n    BlockDiagonalCausalWithOffsetPaddedKeysMask,\n    BlockDiagonalMask,\n)\n\n\n@dataclass\nclass RotatingCacheInputMetadata:\n    # rope absolute positions\n    positions: torch.Tensor\n    # which elements in the sequences need to be cached\n    to_cache_mask: torch.Tensor\n    # how many elements are cached per sequence\n    cached_elements: torch.Tensor\n    # where tokens should go in the cache\n    cache_positions: torch.Tensor\n\n    # if prefill, use block diagonal causal mask\n    # else use causal with padded key mask\n    prefill: bool\n    mask: AttentionBias\n    seqlens: list[int]\n\n\ndef interleave_list(l1, l2):\n    assert len(l1) == len(l2)\n    return [v for pair in zip(l1, l2) for v in pair]\n\ndef unrotate(cache: torch.Tensor, seqlen: int):\n    assert cache.ndim == 3  # (W, H, D)\n    position = seqlen % cache.shape[0]\n    if seqlen < cache.shape[0]:\n        return c

In [None]:
class FeedForward(nn.Module):
    """SwiGLU feed-forward network: w2(silu(w1(x)) * w3(x))."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        model_dim, inner_dim = args.dim, args.hidden_dim
        # creation order w1, w2, w3 is kept so parameter init and state_dict order are unchanged
        self.w1 = nn.Linear(model_dim, inner_dim, bias=False)  # gate projection
        self.w2 = nn.Linear(inner_dim, model_dim, bias=False)  # output projection
        self.w3 = nn.Linear(model_dim, inner_dim, bias=False)  # value projection

    def forward(self, x) -> torch.Tensor:
        gate = nn.functional.silu(self.w1(x))
        value = self.w3(x)
        return self.w2(gate * value)

In [None]:
class RMSNorm(torch.nn.Module):
    """Root-mean-square layer norm with a learned per-feature scale.

    y = x / sqrt(mean(x**2, last axis) + eps) * weight, with the normalisation
    computed in float32 and cast back to the input dtype before scaling.
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        inv_rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x * inv_rms

    def forward(self, x):
        normed = self._norm(x.float())
        return normed.type_as(x) * self.weight

class TransformerBlock(nn.Module):
    """Pre-norm transformer block.

    h = x + Attention(RMSNorm(x)); out = h + FeedForward(RMSNorm(h)).
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        # sub-module creation order is kept so parameter initialisation is unchanged
        self.attention = Attention(args)
        self.feed_forward = FeedForward(args=args)
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.args = args

    def forward(self, x, freqs_cis, positions, mask):
        """
        x: (m, seq_len, dim) hidden states
        freqs_cis, positions, mask: forwarded unchanged to the attention layer
        Returns (m, seq_len, dim).
        """
        # Call the sub-modules themselves (not `.forward`) so any registered hooks run.
        r = self.attention(self.attention_norm(x), freqs_cis, positions, mask)

        # residual connection, then the feed-forward sub-layer with its own residual
        h = x + r
        r = self.feed_forward(self.ffn_norm(h))
        out = h + r
        return out

In [None]:
def precompute_freqs_cis(head_dim, end=128_000, theta=10000.0):
    """Precompute the complex rotary phases for positions 0..end-1.

    Returns a complex tensor of shape (end, head_dim // 2) whose entry [p, i]
    is exp(1j * p * theta ** (-2i / head_dim)).
    """
    half = head_dim // 2
    exponents = torch.arange(0, head_dim, 2)[:half].float() / head_dim
    inv_freq = 1.0 / (theta ** exponents)  # (head_dim // 2,)
    pos = torch.arange(end, device=inv_freq.device)  # (end,)
    angles = torch.outer(pos, inv_freq).float()  # (end, head_dim // 2)
    return torch.polar(torch.ones_like(angles), angles)

class Transformer(nn.Module):
    """Decoder-only Mistral transformer: token embeddings -> blocks -> RMSNorm -> vocab logits."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.vocab_size = args.vocab_size
        self.n_layers = args.n_layers
        self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
        self.layers = torch.nn.ModuleList(
            [TransformerBlock(args=args) for _ in range(self.n_layers)]
        )
        self.norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.output = nn.Linear(args.dim, args.vocab_size, bias=False)
        # Complex rotary table for 128k positions, shape (128_000, head_dim // 2).
        # It is a plain attribute (not a parameter/buffer), so Module.to() does not
        # touch it. forward() moves it to the activations' device on first use
        # instead of hard-coding "cuda" here.
        self.freqs_cis = precompute_freqs_cis(self.args.head_dim, 128_000)

    def forward(self, input_ids, positions):
        """
        input_ids: (m, seq_len) token ids
        positions: (seq_len,) absolute positions, used to select rotary phases
        Returns float32 logits of shape (m, seq_len, vocab_size).
        """
        # [m, seq_len] --> [m, seq_len, dim]
        h = self.tok_embeddings(input_ids)

        if self.freqs_cis.device != h.device:
            self.freqs_cis = self.freqs_cis.to(h.device)
        # (seq_len, head_dim // 2)
        freqs_cis = self.freqs_cis[positions]

        mask = None
        if input_ids.shape[1] > 1:
            seqlen = input_ids.shape[1]
            # (seq_len, seq_len)
            tensor = torch.full(
                (seqlen, seqlen),
                dtype=h.dtype,
                fill_value=1,
                device=h.device,
            )

            # causal (lower-triangular) mask, banded for the sliding window:
            # query i keeps keys j with i - sliding_window <= j <= i.
            # NOTE(review): this admits sliding_window + 1 keys per query, while the
            # decode path of Attention reads at most sliding_window cached entries —
            # confirm which window size is intended.
            mask = torch.tril(tensor, diagonal=0).to(h.dtype)
            mask = torch.triu(mask, diagonal=-self.args.sliding_window)
            # 1 becomes 0 (attend), 0 becomes -inf (masked) before being added to the scores
            mask = torch.log(mask)

        # (m, seq_len, dim)
        for layer in self.layers:
            h = layer(h, freqs_cis, positions, mask)
        return self.output(self.norm(h)).float()

    ### to load it straight from pretrained ###
    @staticmethod
    def from_folder(folder, max_batch_size=1, device="cuda", dtype=torch.float16):
        """Build a Transformer from ``folder/params.json`` and load ``folder/consolidated.00.pth``.

        params.json keys must match the ``ModelArgs`` fields. ``ModelArgs`` is
        looked up when this is called; a later cell that rebinds ``ModelArgs``
        (the MoE section) would make this build the wrong config class.
        """
        with open(f'{folder}/params.json', 'r') as f:
            model_args = ModelArgs(**json.load(f))
        model_args.max_batch_size = max_batch_size
        model = Transformer(model_args).to(device=device, dtype=dtype)
        # The checkpoint is a downloaded file: weights_only=True refuses arbitrary
        # pickled objects. map_location="cpu" keeps the loaded tensors on the CPU;
        # load_state_dict copies them into the model's parameters.
        loaded = torch.load(f'{folder}/consolidated.00.pth', map_location="cpu", weights_only=True)
        model.load_state_dict(loaded)
        return model

In [None]:
model = Transformer(args).to("cuda", dtype=torch.float32)
logits = model.forward(input_tokens[:, :min_prompt_len], positions)
logprobs = nn.functional.log_softmax(logits, dim=-1)
print(logits.size(), logprobs.size())

torch.Size([3, 7, 32000]) torch.Size([3, 7, 32000])


## 1.4 Generation

In [None]:
# model = model.from_folder("./mistral-7B-v0.1")
# model.eval()

In [None]:
# Generation section: same prompts as `data_set`, re-tokenized here.
# Anything derived from `encoded_prompts` (lengths, padded tensor) must be
# recomputed from this new list.
model_path = './mistral-7B-v0.1'
prompts = [
    "Lucky is a dog from my neighbour.",
    "She likes to play ball with Lucy.",
    "Lucy is her best friend.",
]

# tokenize each prompt into subword ids (BOS id prepended by default)
tokenizer = Tokenizer(f'{model_path}/tokenizer.model')
encoded_prompts = [tokenizer.encode(prompt) for prompt in prompts]

In [None]:
# Lengths must come from the prompts encoded in the previous cell. Reusing
# `prompt_lens` from an earlier section silently sizes the batch with stale values
# (the saved output below shows 11 columns for prompts of at most 10 tokens).
prompt_lens = [len(x) for x in encoded_prompts]
min_prompt_len = min(prompt_lens)
max_prompt_len = max(prompt_lens)

# create a (batch, max_prompt_len) tensor filled with pad_id
input_tokens = torch.full(
    (len(prompts), max_prompt_len),
    tokenizer.pad_id,
    dtype=torch.long,
    device="cuda"
)

# copy each prompt's ids into its row; the remainder stays pad_id
for i, encoded in enumerate(encoded_prompts):
    input_tokens[i, :len(encoded)] = torch.tensor(encoded).to(input_tokens)
# True where a real prompt token is present
input_mask = input_tokens != tokenizer.pad_id
print(input_tokens)

tensor([[    1,   393, 11791,   349,   264,  3914,   477,   586, 18583, 28723,
            -1],
        [    1,   985, 12672,   298,  1156,  4374,   395, 18010, 28723,    -1,
            -1],
        [    1, 18010,   349,   559,  1489,  1832, 28723,    -1,    -1,    -1,
            -1]], device='cuda:0')


In [None]:
# Prefill: run the shared prefix (first min_prompt_len tokens, no padding in any row)
# through the model in one pass.
positions = torch.arange(0, min_prompt_len).to("cuda")
logits = model.forward(input_tokens[:, :min_prompt_len], positions)
logprobs = nn.functional.log_softmax(logits, dim=-1)
print(logprobs.size())

# log-probability the model assigned to each actual next prompt token:
# position t's distribution is scored against token t + 1, giving (m, min_prompt_len - 1).
# Kept as a list because the decoding loop appends one (m, 1) tensor per step.
all_logprobs = [
    logprobs[:, :-1, :].gather(2, input_tokens[:, 1:min_prompt_len, None]).squeeze(-1),
]
print(all_logprobs)

torch.Size([3, 7, 32000])
[tensor([[-11.2580, -10.4556, -10.4530, -11.7142,  -9.8308, -10.6482],
        [-11.0530, -10.4912, -10.6849, -10.0249, -10.4921, -11.3246],
        [-10.3404, -11.1509, -10.4690, -11.3796,  -9.9112, -11.8490]],
       device='cuda:0', grad_fn=<SqueezeBackward1>)]


In [None]:
def sample_top_p(probs: torch.Tensor, p: float):
    assert 0 <= p <= 1

    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    mask = probs_sum - probs_sort > p
    probs_sort[mask] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))

    # sampled from the multinomial probability distribution
    next_token = torch.multinomial(probs_sort, num_samples=1)

    # torch gather is indexing
    return torch.gather(probs_idx, -1, next_token)

def sample(logits: torch.Tensor, temperature: float, top_p: float):
    """Pick one next-token id per row of ``logits`` (shape (m, vocab_size)).

    A positive ``temperature`` scales the logits before the softmax (values
    below 1 sharpen the distribution) and then applies top-p sampling.
    Otherwise the choice is greedy (argmax). Always returns a 1-D tensor of ids.
    """
    if not temperature > 0:
        greedy = torch.argmax(logits, dim=-1).unsqueeze(0)
        return greedy.reshape(-1)

    scaled_probs = torch.softmax(logits / temperature, dim=-1)
    return sample_top_p(scaled_probs, top_p).reshape(-1)

In [None]:
# Autoregressive decoding: generate `max_tokens` new tokens for every prompt.
# NOTE: this cell consumes `logprobs` / `all_logprobs` from the prefill cell and
# rebinds them; re-run the prefill cell before re-running this one.
max_tokens = 10
temperature = 0.8
top_p = 0.8
generated = []

# Decoding starts at min_prompt_len (the shortest prompt, used for the batched
# prefill). While cur_pos is still inside a longer prompt, that prompt's real
# token is forced instead of the sampled one.
for cur_pos in range(min_prompt_len, min_prompt_len + max_tokens):

    # sample from the distribution at the last position -> (m,)
    # greedy alternative: next_token = torch.argmax(logprobs[:, -1, :], dim=-1)
    next_token = sample(logprobs[:, -1, :], temperature=temperature, top_p=top_p)

    if cur_pos < input_mask.shape[1]:
        # keep the prompt token where one exists, otherwise use the sampled token
        next_token = torch.where(
            input_mask[:, cur_pos],
            input_tokens[:, cur_pos],
            next_token
        )
    all_logprobs.append(logprobs[:, -1, :].gather(1, next_token[:, None]))

    # list of (m, 1) tensors, one per decoding step
    generated.append(next_token[:, None])

    # feed only the new token at its absolute position; earlier keys/values come from the KV cache
    logits = model(next_token[:, None], torch.LongTensor([cur_pos]).to(next_token))
    logprobs = nn.functional.log_softmax(logits, dim=-1)

all_logprobs = torch.cat(all_logprobs, 1)
print('All Prob Size: ', all_logprobs.size())
generated = torch.cat(generated, 1)
print('Generated Tokens: ', generated)

res = []
for i, x in enumerate(encoded_prompts):
    res.append(tokenizer.decode(x[:min_prompt_len] + generated[i].tolist()))

for x in res:
    print(x)
    print("=====================")

All Prob Size:  torch.Size([3, 16])
Generated Tokens:  tensor([[  586, 18583, 28723, 29919, 31342,  6654,  4228,   399, 31841, 15854],
        [18010, 28723,   948, 16018, 28367,   689, 15099, 20251,  3445,  6302],
        [30021, 30327, 17565, 17159,  1603,  3185, 29532, 24422,  4040, 10668]],
       device='cuda:0')
Lucky is a dog from my neighbour.费ইrior School R❒ gri
She likes to play ball with Lucy. end dés (+ Chistes fucked command college
Lucy is her best friend.技颜 ALL cultiv////CON效itungbig Wild


# 2) Mixtral Mixture of Experts (8x7B MoE)

1. An ensemble technique with multiple experts (e.g. some experts may specialize in different languages or tasks).
2. The outputs of the selected experts are combined (weighted sum or averaging).
3. Only 2 out of 8 experts are used for each token.
4. Before the experts run, a gating network produces logits that are used to select the top-K experts for each token.

[MoE One File Ref](https://github.com/mistralai/mistral-src/blob/main/moe_one_file_ref.py)


In [None]:
@dataclass
class MoeArgs:
    """Mixture-of-experts routing configuration."""

    num_experts: int  # total experts per MoE layer
    num_experts_per_tok: int  # experts selected (top-k) for each token

@dataclass
class ModelArgs:
    """Hyper-parameters of the Mixtral (MoE) transformer.

    This variant has no ``sliding_window`` field. It adds an ``moe`` routing
    sub-config and ``max_seq_len``, which sizes the KV cache.
    """

    dim: int  # model / embedding width
    n_layers: int  # number of transformer blocks
    head_dim: int  # width of one attention head
    hidden_dim: int  # inner width of each expert's feed-forward network
    n_heads: int  # number of query heads
    n_kv_heads: int  # number of key/value heads (grouped-query attention)
    norm_eps: float  # epsilon used by RMSNorm
    vocab_size: int  # tokenizer vocabulary size
    moe: MoeArgs  # expert count and experts-per-token
    max_batch_size: int = 0  # batch capacity of the KV cache
    max_seq_len: int = 0  # sequence capacity of the KV cache

# Toy MoE routing: 4 experts per layer, 2 selected per token.
moe_args = MoeArgs(num_experts=4, num_experts_per_tok=2)

# Small model for the MoE demo. max_batch_size and max_seq_len size the lazily
# allocated KV cache (see Attention.get_caches).
args = ModelArgs(
    dim=512,
    n_layers=1,
    head_dim=128,
    hidden_dim=2048,
    n_heads=4,
    n_kv_heads=2,
    norm_eps=1e-5,
    vocab_size=32_000,
    max_batch_size=3,
    max_seq_len=20,
    moe=moe_args
)

In [None]:
# no sliding window in MoE, the rest remains
class Attention(nn.Module):
    """Grouped-query self-attention with a lazily allocated key/value cache.

    ``n_heads`` query heads share ``n_kv_heads`` key/value heads; keys and
    values are expanded by ``repeat_kv`` (defined earlier in the notebook)
    with ``self.repeats = n_heads // n_kv_heads``. The cache has
    ``args.max_seq_len`` slots per sequence, and positions are written modulo
    that length, so ``args.max_seq_len`` must be > 0.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args

        self.n_heads: int = args.n_heads
        self.n_kv_heads: int = args.n_kv_heads

        # query heads served by each key/value head
        self.repeats = self.n_heads // self.n_kv_heads
        # 1 / sqrt(head_dim) score scaling
        self.scale = self.args.head_dim**-0.5

        q_width = args.n_heads * args.head_dim
        kv_width = args.n_kv_heads * args.head_dim
        self.wq = nn.Linear(args.dim, q_width, bias=False)
        self.wk = nn.Linear(args.dim, kv_width, bias=False)
        self.wv = nn.Linear(args.dim, kv_width, bias=False)
        self.wo = nn.Linear(q_width, args.dim, bias=False)
        # created on the first forward call, see get_caches
        self._cache_k = None
        self._cache_v = None

    def _empty_cache(self, x: torch.Tensor) -> torch.Tensor:
        """Uninitialised (max_batch_size, max_seq_len, n_kv_heads, head_dim) buffer with x's dtype/device."""
        shape = (
            self.args.max_batch_size,
            self.args.max_seq_len,
            self.n_kv_heads,
            self.args.head_dim,
        )
        return torch.empty(shape, dtype=x.dtype, device=x.device)

    def get_caches(self, x: torch.Tensor):
        """Return ``(cache_k, cache_v)``, allocating each on first use.

        The dtype and device are taken from ``x`` at allocation time; an
        existing cache is returned unchanged on later calls.
        """
        if self._cache_k is None:
            self._cache_k = self._empty_cache(x)
        if self._cache_v is None:
            self._cache_v = self._empty_cache(x)
        return self._cache_k, self._cache_v

    def forward(self, x, freqs_cis, positions, mask):
        """Attend over ``x`` (batch, seqlen, dim) and return (batch, seqlen, dim).

        ``positions`` holds the absolute position of each of the ``seqlen``
        tokens. With more than one position, attention uses only this call's
        keys/values plus ``mask``. With a single position, ``mask`` must be
        None and attention reads cache slots ``[0, position + 1)``.
        """
        batch, seqlen, _ = x.shape
        head_dim = self.args.head_dim
        cache_k, cache_v = self.get_caches(x)

        q = self.wq(x).view(batch, seqlen, self.n_heads, head_dim)
        k = self.wk(x).view(batch, seqlen, self.n_kv_heads, head_dim)
        v = self.wv(x).view(batch, seqlen, self.n_kv_heads, head_dim)
        q, k = apply_rotary_emb(q, k, freqs_cis=freqs_cis)

        # Write this step's keys/values into the cache; slots wrap modulo max_seq_len.
        slot = (positions % self.args.max_seq_len)[None, :, None, None]
        slot = slot.repeat(batch, 1, self.n_kv_heads, head_dim)
        cache_k[:batch].scatter_(dim=1, index=slot, src=k)
        cache_v[:batch].scatter_(dim=1, index=slot, src=v)

        if positions.shape[0] > 1:
            # prefill: only the keys/values computed in this call
            keys, values = repeat_kv(k, v, self.repeats)
        else:
            # single-step decode: everything cached so far, including this token
            assert mask is None
            end = int(positions[-1].item() + 1)
            keys, values = repeat_kv(
                cache_k[:batch, :end, ...],
                cache_v[:batch, :end, ...],
                self.repeats,
            )

        # move heads in front of the sequence axis for batched matmul
        q, keys, values = (t.transpose(1, 2) for t in (q, keys, values))
        scores = torch.matmul(q, keys.transpose(2, 3)) * self.scale
        if mask is not None:
            scores += mask[None, None, ...]

        # softmax in float32, then back to the query dtype
        probs = nn.functional.softmax(scores.float(), dim=-1).type_as(q)
        out = torch.matmul(probs, values)
        out = out.transpose(1, 2).contiguous().view(batch, seqlen, -1)
        return self.wo(out)

In [None]:
# A MoE layer is a gate plus a list of feed-forward experts
class MoeLayer(nn.Module):
    """Sparse mixture-of-experts layer with top-k routing.

    For every token the ``gate`` produces one logit per expert. The
    ``moe_args.num_experts_per_tok`` highest-scoring experts are evaluated on
    that token, and their outputs are summed, weighted by a softmax over the
    selected logits.

    Args:
        experts: non-empty sequence of modules mapping (n, dim) -> (n, dim).
        gate: a single module mapping (n, dim) -> (n, len(experts)) logits.
        moe_args: object with a ``num_experts_per_tok`` attribute (the top-k).
    """

    def __init__(self, experts, gate, moe_args):
        super().__init__()
        assert len(experts) > 0

        self.experts = nn.ModuleList(experts)
        # one module scoring all experts at once (not one gate per expert)
        self.gate = gate
        self.args = moe_args

    def forward(self, inputs: torch.Tensor):
        """Route each token of ``inputs`` (..., dim) to its top-k experts.

        Returns:
            Tensor with the same shape as ``inputs``.
        """
        # (..., dim) -> (tokens, dim); reshape (unlike view) also accepts
        # non-contiguous inputs such as transposed views.
        inputs_squashed = inputs.reshape(-1, inputs.shape[-1])

        # (tokens, num_experts)
        gate_logits = self.gate(inputs_squashed)

        # both (tokens, num_experts_per_tok)
        weights, selected_experts = torch.topk(
            gate_logits, self.args.num_experts_per_tok)
        # normalise in float32, then return to the input dtype
        weights = nn.functional.softmax(
            weights, dim=1, dtype=torch.float).type_as(inputs)

        # (tokens, dim)
        results = torch.zeros_like(inputs_squashed)
        for i, expert in enumerate(self.experts):
            # tokens that selected expert i, and the top-k slot it occupies for each
            token_idx, slot_idx = torch.where(selected_experts == i)
            if token_idx.numel() == 0:
                continue  # no token routed here: skip this expert's forward pass

            # (selected, 1) routing weight broadcast over the (selected, dim) expert output
            results[token_idx] += (weights[token_idx, slot_idx, None] *
                expert(inputs_squashed[token_idx]))

        # (tokens, dim) -> original shape of inputs
        return results.reshape(inputs.shape)

In [None]:
class TransformerBlock(nn.Module):
    """One decoder layer: attention followed by a mixture-of-experts FFN.

    Each sub-layer receives an RMSNorm-normalised input, and its output is
    added back to the residual stream (pre-norm residual layout).
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.attention = Attention(args)
        # The dense FFN of Mistral is replaced by `num_experts` FeedForward
        # experts plus a bias-free linear gate producing one logit per expert.
        self.feed_forward = MoeLayer(
            experts=[FeedForward(args=args) for _ in range(args.moe.num_experts)],
            gate=nn.Linear(args.dim, args.moe.num_experts, bias=False),
            moe_args=args.moe,
        )
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.args = args

    def forward(self, x, freqs_cis, positions, mask):
        """Apply attention and MoE sub-layers; (m, seq_len, dim) -> (m, seq_len, dim)."""
        # Sub-modules are invoked via __call__ (not .forward) so registered
        # forward hooks run as PyTorch intends.
        h = x + self.attention(self.attention_norm(x), freqs_cis, positions, mask)
        return h + self.feed_forward(self.ffn_norm(h))

In [None]:
class Transformer(nn.Module):
    """Mixtral-style mixture-of-experts decoder with optional pipeline parallelism.

    With ``num_pipeline_ranks > 1`` the model is split across ranks:

    * rank 0 owns ``tok_embeddings``;
    * the last rank owns ``norm`` and ``output``;
    * the ``n_layers`` blocks are cut into contiguous slices of
      ``ceil(n_layers / num_pipeline_ranks)``, and each rank keeps only its
      own slice.

    Hidden states move between ranks with ``torch.distributed.send``/``recv``.
    The logits are broadcast from the last rank to every rank.
    """

    def __init__(self, args, pipeline_rank=0, num_pipeline_ranks=1):
        super().__init__()
        self.args = args
        self.vocab_size = args.vocab_size
        self.pipeline_rank = pipeline_rank
        self.num_pipeline_ranks = num_pipeline_ranks
        self._precomputed_freqs_cis = None

        # Modules specific to some ranks:
        self.tok_embeddings = None
        self.norm = None
        self.output = None
        if pipeline_rank == 0:
            self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
        if pipeline_rank == num_pipeline_ranks - 1:
            self.norm = RMSNorm(args.dim, eps=args.norm_eps)
            self.output = nn.Linear(args.dim, args.vocab_size, bias=False)

        # Initialize all layers but slice off those not of this rank. The
        # ModuleDict keys are the global layer ids, so checkpoint keys
        # "layers.<id>.…" are matched against them in load_state_dict.
        layers = [TransformerBlock(args=args) for _ in range(args.n_layers)]
        num_layers_per_rank = math.ceil(args.n_layers / self.num_pipeline_ranks)
        offset = self.pipeline_rank * num_layers_per_rank
        end = min(args.n_layers, offset + num_layers_per_rank)
        self.layers = nn.ModuleDict({str(i): layers[i] for i in range(offset, end)})
        self.n_local_layers = len(self.layers)

    @property
    def dtype(self) -> torch.dtype:
        """dtype of the first parameter (taken as the model dtype)."""
        return next(self.parameters()).dtype

    @property
    def device(self) -> torch.device:
        """Device of the first parameter (taken as the model device)."""
        return next(self.parameters()).device

    @property
    def freqs_cis(self) -> torch.Tensor:
        """Rotary frequency table for 128k positions (theta = 1e6).

        Computed once, and moved to the model's device whenever that differs.
        """
        if self._precomputed_freqs_cis is None:
            self._precomputed_freqs_cis = precompute_freqs_cis(
                head_dim=self.args.head_dim, end=128_000, theta=1000000.0)
        if self._precomputed_freqs_cis.device != self.device:
            self._precomputed_freqs_cis = self._precomputed_freqs_cis.to(device=self.device)
        return self._precomputed_freqs_cis

    def forward(self, input_ids, positions):
        """Run this rank's layers and return float32 logits.

        Args:
            input_ids: token ids of shape (bsz, seqlen).
            positions: absolute positions used to index ``freqs_cis`` and the
                KV caches; presumably of length ``seqlen`` (TODO confirm).

        Returns:
            Logits of shape (bsz, seqlen, vocab_size) as float32. Every rank
            returns them when pipelined.
        """
        # rotary frequencies for the requested positions
        freqs_cis = self.freqs_cis[positions]

        bsz, seqlen = input_ids.shape

        if self.pipeline_rank == 0:
            h = self.tok_embeddings(input_ids)
        else:
            h = torch.empty(bsz, seqlen, self.args.dim, device=self.device, dtype=self.dtype)
            torch.distributed.recv(h, src=self.pipeline_rank - 1)

        mask = None
        if seqlen > 1:
            # log(1) = 0 on/below the diagonal and log(0) = -inf above it,
            # i.e. an additive causal mask.
            tensor = torch.full(
                (seqlen, seqlen),
                dtype=h.dtype,
                fill_value=1,
                device=h.device,
            )
            mask = torch.log(torch.tril(tensor, diagonal=0)).to(h.dtype)

        # (m, seq_len, dim)
        for layer in self.layers.values():
            h = layer(h, freqs_cis, positions, mask)

        if self.pipeline_rank < self.num_pipeline_ranks - 1:
            torch.distributed.send(h, dst=self.pipeline_rank + 1)
            # placeholder filled by the broadcast from the last rank
            outs = torch.empty(*h.shape[:-1], self.vocab_size, device=h.device, dtype=h.dtype)
        else:
            outs = self.output(self.norm(h))
        if self.num_pipeline_ranks > 1:
            torch.distributed.broadcast(outs, src=self.num_pipeline_ranks - 1)
        return outs.float()

    def _owns_param(self, key):
        """Return True if checkpoint parameter ``key`` belongs to this pipeline rank.

        Raises:
            ValueError: if ``key`` matches none of the known prefixes.
        """
        if key.startswith("tok_embeddings"):
            return self.pipeline_rank == 0
        if key.startswith("norm") or key.startswith("output"):
            return self.pipeline_rank == self.num_pipeline_ranks - 1
        if key.startswith("layers"):
            # "layers.<id>.…" -> keep only the layers sliced onto this rank
            return key.split(".")[1] in self.layers
        raise ValueError(f"Unexpected key {key}")

    def load_state_dict(self, state_dict, *args, **kwargs):
        """Load only the parameters owned by this pipeline rank.

        Keys belonging to other ranks are logged at DEBUG level and skipped.
        The remaining arguments (e.g. ``strict``, ``assign``) are forwarded
        to ``nn.Module.load_state_dict``, whose result is returned.

        Raises:
            ValueError: on a key with an unknown prefix.
        """
        import logging

        state_to_load = {}
        for k, v in state_dict.items():
            if self._owns_param(k):
                state_to_load[k] = v
            else:
                logging.debug(
                    "Skipping parameter %s at pipeline rank %d",
                    k,
                    self.pipeline_rank,
                )
        return super().load_state_dict(state_to_load, *args, **kwargs)

    @staticmethod
    def from_folder(
            folder, max_batch_size, max_seq_len, num_pipeline_ranks=1,
            device="cuda", dtype=torch.float16
        ):
        """Build a model from ``folder/params.json`` and ``folder/consolidated.00.pth``.

        The model is built on the ``meta`` device, so parameters have no
        storage during construction. The checkpoint tensors are then assigned
        (``assign=True``), and the model is moved to ``device``/``dtype``.

        Args:
            folder: checkpoint directory, either a ``str`` or a ``pathlib.Path``.
            max_batch_size, max_seq_len: sizes of the attention KV caches.
            num_pipeline_ranks: when > 1, the rank is read from ``torch.distributed``.
            device, dtype: target placement of the returned model.
        """
        folder = Path(folder)
        with open(folder / "params.json", "r") as f:
            model_args = ModelArgs.from_dict(json.load(f))
        model_args.max_batch_size = max_batch_size
        model_args.max_seq_len = max_seq_len
        pipeline_rank = torch.distributed.get_rank() if num_pipeline_ranks > 1 else 0
        with torch.device("meta"):
            model = Transformer(
                model_args,
                pipeline_rank=pipeline_rank,
                num_pipeline_ranks=num_pipeline_ranks
            )
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        loaded = torch.load(str(folder / "consolidated.00.pth"), mmap=True)
        model.load_state_dict(loaded, assign=True)
        return model.to(device=device, dtype=dtype)

In [None]:
# Smoke test: run the randomly initialised toy MoE model on the shared prompt prefix.
model = Transformer(args).to("cuda", dtype=torch.float32)
with torch.no_grad():
    # Call the module (not .forward) so hooks run. Inference needs no autograd graph.
    logits = model(input_tokens[:, :min_prompt_len], positions)
logprobs = nn.functional.log_softmax(logits, dim=-1)
# (bsz, seq_len, vocab_size): one row of next-token scores per prompt and position
assert logits.shape == (input_tokens.shape[0], min_prompt_len, args.vocab_size)
assert logprobs.shape == logits.shape

torch.Size([12]) torch.Size([12])
torch.Size([12]) torch.Size([12])
torch.Size([11]) torch.Size([11])
torch.Size([7]) torch.Size([7])


# 3) Phi-2

In [None]:
import math
from typing import Optional

from transformers import PretrainedConfig

class PhiConfig(PretrainedConfig):
    """Phi configuration.

    Stores the hyper-parameters of the ``phi-msft`` model type under
    GPT-style field names (``n_embd``, ``n_head``, ``n_layer``,
    ``n_positions``). ``attribute_map`` maps the standard Hugging Face names
    (``hidden_size``, ``num_attention_heads``, ...) onto these fields.
    """

    model_type = "phi-msft"
    # Hugging Face attribute name -> field name stored by this config
    attribute_map = {
        "max_position_embeddings": "n_positions",
        "hidden_size": "n_embd",
        "num_attention_heads": "n_head",
        "num_hidden_layers": "n_layer",
    }

    def __init__(
        self,
        vocab_size: int = 50304,
        n_positions: int = 2048,
        n_embd: int = 1024,
        n_layer: int = 20,
        n_inner: Optional[int] = None,
        n_head: int = 16,
        n_head_kv: Optional[int] = None,
        rotary_dim: Optional[int] = 32,
        activation_function: Optional[str] = "gelu_new",
        flash_attn: bool = False,
        flash_rotary: bool = False,
        fused_dense: bool = False,
        attn_pdrop: float = 0.0,
        embd_pdrop: float = 0.0,
        resid_pdrop: float = 0.0,
        layer_norm_epsilon: float = 1e-5,
        initializer_range: float = 0.02,
        tie_word_embeddings: bool = False,
        pad_vocab_size_multiple: int = 64,
        **kwargs):

        # Round the vocabulary size up to a multiple of pad_vocab_size_multiple
        # (e.g. 50295 -> 50304 with the default multiple of 64; 51200 stays 51200).
        self.vocab_size = int(
            math.ceil(vocab_size / pad_vocab_size_multiple) *
            pad_vocab_size_multiple
        )
        self.n_positions = n_positions
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.n_inner = n_inner
        self.n_head = n_head
        self.n_head_kv = n_head_kv
        # Cap the rotary dimension at the per-head size n_embd // n_head.
        # NOTE(review): the annotation allows rotary_dim=None, but
        # min(None, int) raises TypeError. Confirm whether None must be supported.
        self.rotary_dim = min(rotary_dim, n_embd // n_head)
        self.activation_function = activation_function
        self.flash_attn = flash_attn
        self.flash_rotary = flash_rotary
        self.fused_dense = fused_dense
        self.attn_pdrop = attn_pdrop
        self.embd_pdrop = embd_pdrop
        self.resid_pdrop = resid_pdrop
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_range = initializer_range

        # tie_word_embeddings and any extra kwargs are handled by PretrainedConfig.
        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)

In [None]:
# microsoft/phi-2 config.json transcribed as a Python dict.
# JSON literals must be Python ones here: false -> False, null -> None.
arg = {
    "_name_or_path": "microsoft/phi-2",
    "activation_function": "gelu_new",
    "architectures": ["PhiForCausalLM"],
    "attn_pdrop": 0.0,
    "auto_map": {
        "AutoConfig": "configuration_phi.PhiConfig",
        "AutoModelForCausalLM": "modeling_phi.PhiForCausalLM"
    },
    "embd_pdrop": 0.0,
    "flash_attn": False,
    "flash_rotary": False,
    "fused_dense": False,
    "img_processor": None,
    "initializer_range": 0.02,
    "layer_norm_epsilon": 1e-05,
    "model_type": "phi-msft",
    "n_embd": 2560,
    "n_head": 32,
    "n_head_kv": None,
    "n_inner": None,
    "n_layer": 32,
    "n_positions": 2048,
    "resid_pdrop": 0.1,
    "rotary_dim": 32,
    "tie_word_embeddings": False,
    "torch_dtype": "float16",
    "transformers_version": "4.35.2",
    "vocab_size": 51200
}