# Loading Llama 3 from pretrained weights

In [6]:
from pprint import pprint
import time

import torch

from gollem.models.llama3.config import get_llama3_model_config
from gollem.models.llama3.model import Llama3

cfg = get_llama3_model_config("llama-3.2-1B")
device = "cuda"

cfg.inference_mode = True
cfg.max_sample_batch_size = 1
pprint(cfg)

model = Llama3.from_pretrained(cfg)
model.to(device)

Llama3Config(model_name='llama-3.2-1B',
             n_ctx=1024,
             n_layer=16,
             n_head=32,
             n_kv_head=8,
             d_model=2048,
             intermediate_size=8192,
             vocab_size=128256,
             learning_rate=0.0003,
             warmup_iters=0,
             learning_rate_decay_frac=0.001,
             rope_theta=10000,
             rmsnorm_eps=1e-06,
             weight_decay=0.1,
             grad_clip=1.0,
             betas=(0.9, 0.95),
             fused_adamw=True,
             zero_optimizer=True,
             flash=True,
             activation_checkpointing=False,
             compile=True,
             from_pretrained=False,
             max_sample_batch_size=1,
             inference_mode=True)
Loading weights from pretrained llama3 llama-3.2-1B
Downloading weights for llama-3.2-1B
Existing checkpoint found for llama-3.2-1B at /home/user/gollem/checkpoints/llama-3.2-1B skipping download


  sd_hf = torch.load(checkpoint_dir / "original" / "consolidated.00.pth")  # type: ignore


Creating model and loading weights


Llama3(
  (tok_embeddings): Embedding(128256, 2048)
  (layers): ModuleList(
    (0-15): 16 x TransformerBlock(
      (attention_norm): RMSNorm()
      (attention): InferenceAttention(
        (wq): Linear(in_features=2048, out_features=2048, bias=False)
        (wk): Linear(in_features=2048, out_features=512, bias=False)
        (wv): Linear(in_features=2048, out_features=512, bias=False)
        (wo): Linear(in_features=2048, out_features=2048, bias=False)
      )
      (ffn_norm): RMSNorm()
      (feed_forward): MLP(
        (w1): Linear(in_features=2048, out_features=8192, bias=False)
        (w2): Linear(in_features=8192, out_features=2048, bias=False)
        (w3): Linear(in_features=2048, out_features=8192, bias=False)
        (silu): SiLU()
      )
    )
  )
  (norm): RMSNorm()
  (output): Linear(in_features=2048, out_features=128256, bias=False)
)

In [7]:
tokenizer = cfg.get_tokenizer()
prompts = [
    "On tuesday I will be in a ",
    "On wednesday ",
    "President Trump has ",
    "President Biden has ",
]
prompt_tokens = [
    torch.tensor(tokenizer.encode(prompt), device=device) for prompt in prompts
]


In [8]:
times = []
for p in prompt_tokens:
    start = time.time()
    output_tokens = model.generate([p], 1000, stop_tokens=None, echo=False)
    end = time.time()
    times.append(end - start)

num_output_tokens = sum(len(o.squeeze(0)) for o in output_tokens)
print("Mean time per prompt:", sum(times) / len(times))
print("Mean output tokens per prompt:", num_output_tokens / len(times))
print("Mean time per output token:", sum(times) / num_output_tokens)


Mean time per prompt: 11.781816720962524
Mean output tokens per prompt: 250.0
Mean time per output token: 0.0471272668838501


# KV caching vs no KV caching

With KV-caching
```
max_new_tokens = 50
Mean time per prompt: 0.7313917279243469
Mean output tokens per prompt: 12.5
Mean time per output token: 0.05851133823394775

max_new_tokens = 1000
Mean time per prompt: 11.781816720962524
Mean output tokens per prompt: 250.0
Mean time per output token: 0.0471272668838501
```

Without KV-caching
``` 
max_new_tokens = 50
Mean time per prompt: 0.7028171420097351
Mean output tokens per prompt: 13.5
Mean time per output token: 0.05206052903775816

max_new_tokens = 1000
Mean time per prompt: 94.20367068052292
Mean output tokens per prompt: 251.0
Mean time per output token: 0.3753134290060674
```

In [5]:
for out_seq in output_tokens:
    response = tokenizer.decode(out_seq.tolist())
    print(response)

On tuesday 23nd October a group of mining engineers worked closely with metal manufacturers to find a solution. 5% of mining is capable of deploying this function because
