In [1]:
import torch
from torch import nn, optim

import torch.nn.functional as F
from einops import rearrange
from typing import Optional
from ember import (
    Transformer,
    MultiHeadLatentAttn,
    KVCache,
    GroupedQueryAttn,
    Tokenizer,
    RoPE,
    apply_rotary_pos_emb,
    TopKSampler
)
from transformers import AutoTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
B, S, D = 8, 1024, 512
mla_kwargs = {
    "latent_dim": 256,
    "pos_dim": 128,
    "n_heads": 16,
}
gqa_kwargs = {
    "n_query_heads": 8,
    "n_query_groups": 4,
}
# 
# tk = Tokenizer()
tk = Tokenizer(model = "openai-community/gpt2")
model = Transformer(
    vocab_size=tk.vocab_size,
    model_dim=D,
    hidden_dim=1024,
    attn_module=GroupedQueryAttn,
    attn_kwargs=gqa_kwargs,
    n_attn_blocks=3,
)

# mini_deepseek = Transformer(
#     vocab_size=tk.vocab_size,
#     model_dim=D,
#     hidden_dim=1024,
#     attn_module=MultiHeadLatentAttn,
#     attn_kwargs=mla_kwargs,
#     n_attn_blocks=3,
# )

<bound method Module.parameters of Transformer(
  (embed): Embedding(50257, 512)
  (attn_blocks): ModuleList(
    (0-2): 3 x AttentionBlock(
      (mlp): SwiGLU(
        (W): Linear(in_features=512, out_features=682, bias=False)
        (V): Linear(in_features=512, out_features=682, bias=False)
        (W2): Linear(in_features=682, out_features=512, bias=False)
      )
      (norm): RMSNorm()
      (attn): GroupedQueryAttn(
        (fused_qkv): Linear(in_features=512, out_features=1024, bias=True)
        (out_proj): Linear(in_features=512, out_features=512, bias=True)
        (rope): RoPE()
      )
    )
  )
  (norm): RMSNorm()
)>
Parameter count: 3.12e+07


In [3]:
txt = ["Hey there", "How are you?"]
x = tk(txt)
x.shape

torch.Size([2, 4])

In [13]:
txt, tk.decode(tk(txt))

(['Hey there', 'How are you?'],
 ['Hey there<|endoftext|><|endoftext|>', 'How are you?'])

In [5]:
out = model.generate(
    x,
    max_new_tokens=50,
    sampler=TopKSampler(50),
    tokenizer=tk,
)
out

['Hey there',

In [9]:
@torch.inference_mode
def generate(
    model,
    indices: torch.Tensor,
    max_new_tokens: int,
    sampler: nn.Module,
    tokenizer: callable,
) -> torch.Tensor:
    cache_config: dict = model.cache_config
    B, S = indices.shape

    cache = KVCache(
        n_layers=model.n_attn_blocks,
        max_batch_size=B,
        max_seq_len=S + max_new_tokens,
        n_heads=cache_config["n_heads"],
        head_dim=cache_config["head_dim"],
    )
    finished = torch.zeros((B,), dtype=torch.bool, device=indices.device)

    logits = model.forward(indices, cache)  # prefill cache
    cache.initialize_prefill(S)

    max_tokens = S + max_new_tokens
    while cache.current_len <= max_tokens:
        next_tokens = sampler(logits[:, -1, :])
        indices = torch.cat([indices, next_tokens], dim=-1)

        is_eos = next_tokens.squeeze(-1) == tokenizer.eos_token_id
        finished = finished | is_eos
        if finished.all():
            break

        logits = model.forward(next_tokens, cache=cache)
        cache.step()

    # remove tokens past <eos>
    output_strings = []
    for i in range(B):
        print(indices[i])
        seq = indices[i].tolist()
        try:
            eos_idx = seq.index(tokenizer.eos_token_id)
            print(eos_idx)
            seq = seq[:eos_idx]
        except ValueError:
            pass
        output_strings.append(tokenizer.decode(seq))

    return output_strings


generate(
    model,
    x,
    max_new_tokens=50,
    sampler=TopKSampler(50),
    tokenizer=tk,
)

tensor([10814,   612, 50256, 50256, 16039,  6728,  9690,  9690, 26747, 30016,
         9937, 37942,  4152, 30978, 24696, 22692, 19114, 33719, 21889, 24696,
        19380, 33719,  8563, 24776, 29714, 29714,  9328, 48215,  8563, 26747,
        15594, 30978, 24696, 32917,   701, 40292,   701, 33719, 42853, 39646,
         9328, 19380, 19380, 38493,  9328,  9328, 30962, 37624, 35728, 46341,
        19530, 19380, 46341, 33719, 20612])
2
tensor([ 2437,   389,   345,    30, 19450,  3501, 37707, 19878, 19878, 32247,
        48486, 19019, 19387, 19878, 32371, 46215, 16880,  3501,  3501, 38045,
        38045, 24350, 35477, 42587, 19019, 31198, 16880, 32371,  8905, 35477,
        23243, 46116, 45949, 42956, 16547, 38044, 43426,  5480, 11232, 32371,
        19878, 39261, 12450, 37707, 38977, 45949, 43426, 19878, 38044, 31198,
        33396,  8905, 49809, 29583, 48486])


['Hey there',
 'How are you? olive giving NIGHT PER PER VID classy injust INC PERbertoorea Tap giving giving Bis Bis hallsDemon Tune injust prominence TapbertoEPDemon sorted awarding Vinyl Wickedolphins INFORMATION Prepare Palest Jonathanberto PER Hancock shoulders NIGHT remod Vinyl Prepare PER INFORMATION prominence ReductionEPossession Flip classy']

In [20]:
tk = Tokenizer()
tk.decode([tk.pad_token_id])

'<|end_of_text|>'