In [1]:
import torch
from torch import nn, optim

import torch.nn.functional as F
import torch.nn.init as init

from ignite.llm import (
    Transformer,
    GroupedQueryAttn,
    MultiHeadLatentAttn,
    Tokenizer,
)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
B, S, D = 8, 1024, 512
mla_kwargs = {
    "latent_dim": 256,
    "pos_dim": 128,
    "n_heads": 16,
}
gqa_kwargs = {
    "n_query_heads": 8,
    "n_query_groups": 4,
}

tk = Tokenizer()

mini_llama = Transformer(
    vocab_size=tk.vocab_size,
    model_dim=D,
    hidden_dim=1024,
    attn_module=GroupedQueryAttn,
    attn_kwargs=gqa_kwargs,
    n_attn_blocks=3,
)

mini_deepseek = Transformer(
    vocab_size=tk.vocab_size,
    model_dim=D,
    hidden_dim=1024,
    attn_module=MultiHeadLatentAttn,
    attn_kwargs=mla_kwargs,
    n_attn_blocks=3,
)

<bound method Module.parameters of Transformer(
  (embed): Embedding(128256, 512)
  (attn_blocks): ModuleList(
    (0-2): 3 x AttentionBlock(
      (mlp): SwiGLU(
        (W): Linear(in_features=512, out_features=682, bias=False)
        (V): Linear(in_features=512, out_features=682, bias=False)
        (W2): Linear(in_features=682, out_features=512, bias=False)
      )
      (norm): RMSNorm()
      (attn): GroupedQueryAttn(
        (fused_qkv): Linear(in_features=512, out_features=1024, bias=True)
        (out_proj): Linear(in_features=512, out_features=512, bias=True)
      )
    )
  )
  (norm): RMSNorm()
)>
Parameter count: 7.12e+07
<bound method Module.parameters of Transformer(
  (embed): Embedding(128256, 512)
  (attn_blocks): ModuleList(
    (0-2): 3 x AttentionBlock(
      (mlp): SwiGLU(
        (W): Linear(in_features=512, out_features=682, bias=False)
        (V): Linear(in_features=512, out_features=682, bias=False)
        (W2): Linear(in_features=682, out_features=512, bia

In [3]:
txt = ["Hey there", "How are you?"]
x = tk(txt)
x.shape

torch.Size([2, 5])

In [4]:
with torch.no_grad():
    out_l = mini_llama(x)
    print(out_l.shape)
    out_d = mini_deepseek(x)
    print(out_d.shape)

torch.Size([2, 5, 128256])
torch.Size([2, 5, 128256])


In [5]:
p = torch.softmax(out_d, dim=-1)
p = p.max(dim=-1).indices.tolist()
p

[[64822, 55519, 101747, 69796, 13772], [64822, 116598, 116598, 87252, 122673]]