In [7]:
import torch
# Select the compute backend in order of preference: CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

In [12]:
import torch
from torch.utils.data import Dataset, DataLoader

# Special-token ids.
#   eos_id: the <eos> id; generate() stops once this id is produced.
#   pad_id: presumably the tokenizer's padding id; not used in the cells shown here (TODO confirm).
pad_id, eos_id = 5, 6

def generate(model, x: torch.Tensor, max_new_tokens: int, eos_token_id=None, max_context_length: int = 256) -> list:
  """Greedily extend a prompt one token at a time.

  At every step the whole sequence so far is fed to ``model`` and the
  highest-scoring token at the last position is appended. Sampling controls
  (top_k, top_p, temperature) are not implemented yet.

  Args:
    model: Module called as ``model(ids)`` with ``ids`` of shape (1, seq);
      it must return logits indexable as ``[0, -1]`` (last position of the
      single batch element).
    x: Non-empty 1-D tensor of prompt token ids. It is not modified.
    max_new_tokens: Upper bound on the number of tokens to append.
    eos_token_id: Id that ends generation once produced. ``None`` (default)
      uses the notebook-level ``eos_id``.
    max_context_length: Longest sequence that is ever fed to the model.

  Returns:
    The prompt ids followed by the generated ids, as a list of ints. The
    list can be one longer than ``max_context_length`` because the token
    predicted from a full context is still appended.

  Raises:
    ValueError: If ``x`` is empty or not 1-D.
  """
  if x.dim() != 1 or x.numel() == 0:
    raise ValueError("generate() expects a non-empty 1-D tensor of token ids")

  stop_id = eos_id if eos_token_id is None else eos_token_id

  # Run on the device the model lives on instead of relying on a global set
  # in another cell; fall back to the prompt's device for parameter-free models.
  first_param = next(model.parameters(), None)
  run_device = x.device if first_param is None else first_param.device

  tokens = x.detach().cpu().tolist()

  # Inference only: no autograd graph should be built for each step.
  with torch.no_grad():
    for _ in range(max_new_tokens):
      # Never feed more than the context window (also covers oversized prompts).
      if len(tokens) > max_context_length:
        break

      input_ids = torch.tensor(tokens, dtype=torch.long, device=run_device).unsqueeze(0)
      logits = model(input_ids)

      # Softmax is monotonic, so the argmax of the raw logits is the greedy pick.
      next_id = int(torch.argmax(logits[0, -1], dim=-1).item())
      tokens.append(next_id)

      if next_id == stop_id:
        break

  return tokens

In [None]:
from llama_config import LlamaConfig
from llama_model import LlamaForCausalLM
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download

# Architecture hyperparameters for the checkpoint loaded below.
llama_config = LlamaConfig(
    vocab_size=32768,
    context_length=256,
    emb_dim=256,
    hidden_dim=2048,
    n_layers=20,
    n_heads=128,
    n_kv_groups=64,
)

# Build the model and place it on the selected device before the weights are loaded.
llama_model = LlamaForCausalLM(llama_config).to(device)

# Download the pretrained weights file from the Hugging Face Hub.
model_path = hf_hub_download(
    repo_id="AhmetSemih/llama-50m-pretrained-books-tr_tokenizer",
    filename="llama-50m-pretrained-books-tr_tokenizer.safetensors",
)

# Read the safetensors file and copy its tensors into the model.
state_dict = load_file(model_path)
llama_model.load_state_dict(state_dict)

# Switch to evaluation mode; as the cell's last expression this also displays the module tree.
llama_model.eval()

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32768, 256)
    (layers): ModuleList(
      (0-19): 20 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=256, out_features=256, bias=False)
          (k_proj): Linear(in_features=256, out_features=128, bias=False)
          (v_proj): Linear(in_features=256, out_features=128, bias=False)
          (o_proj): Linear(in_features=256, out_features=256, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=256, out_features=2048, bias=False)
          (up_proj): Linear(in_features=256, out_features=2048, bias=False)
          (down_proj): Linear(in_features=2048, out_features=256, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )
  (lm_head): Linear(in_features=256, out_features=32768, bias=False)
)

In [None]:
# Prompt text. Suppose the next word is "neredeydin", so the next token is "nere" (id 3018).
sample_text = "dedi ki"

# Hard-coded token ids, presumably the tokenizer's encoding of sample_text (TODO confirm).
encoded_ids = [3452, 2, 20059]

In [22]:
# Turn the prompt ids into a 1-D tensor and move it to the selected device.
id_tensor = torch.tensor(encoded_ids).to(device)


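-----------------------------------------------------------------------------
Cross-entropy loss notation: `c` = vocabulary size, `s` = the logits the model produces for the input tokens, `p` = index of the correct (target) token. For one position the loss is $-\log\left(\dfrac{e^{s_p}}{\sum_{j=1}^{c} e^{s_j}}\right)$.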

![Cross-entropy loss formula](https://framerusercontent.com/images/KWVlMyakoCF4DcOSjtN1WsAXs.webp?width=1300&height=348)

In [23]:
# Add a leading batch axis and run a forward pass.
# The result holds one row of vocabulary logits per input position.
out = llama_model(id_tensor[None, :])
out

tensor([[[  7.3754,   1.6016,   5.1854,  ...,  -7.5244,  -6.4256, -12.8459],
         [  0.6572,   1.1768,   1.8223,  ...,  -7.1019,  -4.7086, -14.1919],
         [  4.9584,   0.2911,   6.3374,  ...,  -8.1821,  -7.5339, -15.9622]]],
       device='mps:0', grad_fn=<LinearBackward0>)

In [None]:
# out[0] picks the single batch element and [2] the third (last) input position,
# i.e. the logits used to predict the token that follows "ki".
output_logits=out[0][2] # logits for the token after "ki"
output_logits

tensor([  4.9584,   0.2911,   6.3374,  ...,  -8.1821,  -7.5339, -15.9622],
       device='mps:0', grad_fn=<SelectBackward0>)

In [30]:
# Same values as out[0][2], written as one index expression:
# batch element 0, position 2, every vocabulary logit.
out[0,2,:]

tensor([  4.9584,   0.2911,   6.3374,  ...,  -8.1821,  -7.5339, -15.9622],
       device='mps:0', grad_fn=<SliceBackward0>)

In [29]:
# Last position for every batch element; the batch axis is kept, so the result is 2-D.
out[:, -1, :]

tensor([[  4.9584,   0.2911,   6.3374,  ...,  -8.1821,  -7.5339, -15.9622]],
       device='mps:0', grad_fn=<SliceBackward0>)

In [25]:
# Cross-entropy between the last-position logits and the expected next token.
# The correct token ("nere") has id 3018; the vocabulary has 32768 entries.

import torch.nn.functional as F

correct_token_id = 3018
F.cross_entropy(
    input=output_logits[None, :],  # add a batch axis: one prediction over the vocabulary
    target=torch.tensor([correct_token_id], device=device),
)


tensor(13.3377, device='mps:0', grad_fn=<NllLossBackward0>)

In [None]:
# Find the model's greedy prediction for the next token: the vocabulary index
# with the highest probability.
# torch.argmax is used rather than unpacking torch.max into `_`, because assigning
# to `_` at notebook top level overrides IPython's "last output" variable.
probs = torch.softmax(output_logits, dim=-1)
max_index = torch.argmax(probs, dim=-1)
max_index.item()  # 32240 = ":" token

32240