In [1]:
from benchmark import SmolEvalWrapper, run_benchmarks
from model import ModelConfig, LlamaModel
from transformers import AutoTokenizer

import torch

In [2]:
# Prefer the first CUDA GPU, but fall back to CPU so the notebook still runs
# (slowly) on machines without a usable GPU instead of failing at `.to(device)`.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [3]:
# Hub identifier of the checkpoint whose tokenizer is loaded here.
hf_checkpoint = "HuggingFaceTB/SmolLM-360M"
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path=hf_checkpoint,
)

# Use the end-of-sequence token for padding; `tokenizer.pad_token_id` is later
# passed to the model config as `padding_idx`.
tokenizer.pad_token = tokenizer.eos_token

In [4]:
# Fine-tuned checkpoint to evaluate (path is relative to the notebook's working directory).
checkpoint_path = "math_3epoch/model.checkpoint.2025-02-01--07-20-39.pt"

# map_location="cpu": without it torch.load restores each tensor onto the device
# it was saved from, which fails if that device does not exist on this machine.
# The weights are moved to `device` later together with the model.
# weights_only=True restricts unpickling to tensors/primitive containers.
state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)

In [5]:
# Checkpoint keys are stored with a "module._orig_mod." prefix
# (presumably added by DistributedDataParallel + torch.compile wrappers during
# training — confirm against the training script). Strip it so the keys line up
# with a bare LlamaModel.
prefix = "module._orig_mod."
prefix_len = len(prefix)

# Only strip when the prefix is actually present: blindly slicing the first
# `prefix_len` characters would silently corrupt any key saved without it.
smol_state_dict = {
    (k[prefix_len:] if k.startswith(prefix) else k): v
    for k, v in state_dict.items()
}

In [6]:
# Architecture hyper-parameters for the model being evaluated.
# n_attn_heads * d_head = 960 (q_proj width) and n_kv_heads * d_head = 320
# (k_proj / v_proj width), matching the module printout further down.
# NOTE(review): rope_theta / rms_norm_eps / initializer_range presumably mirror
# the training run — confirm, since they are not checked when weights are loaded.
model_config_kwargs = dict(
    vocab_size=tokenizer.vocab_size,
    d_model=960,
    d_head=64,
    d_mlp_proj=2560,
    n_layers=32,
    n_kv_heads=5,
    n_attn_heads=15,
    rms_norm_eps=1e-5,
    initializer_range=0.008,
    rope_theta=100000.0,
    padding_idx=tokenizer.pad_token_id,
)
model_config = ModelConfig(**model_config_kwargs)

# Instantiate the model from the config; checkpoint weights are loaded in a later cell.
model = LlamaModel(model_config)

In [7]:
# Copy the re-keyed checkpoint weights into the model. The call's return value
# is the cell output and reports whether every key matched.
model.load_state_dict(smol_state_dict)

<All keys matched successfully>

In [8]:
# Move the model to `device`. As the cell's last expression, the returned
# module is displayed (the architecture printout below).
model.to(device)

LlamaModel(
  (embed_tokens): Embedding(49152, 960)
  (layers): ModuleList(
    (0-31): 32 x DecoderLayer(
      (self_attn): GroupedQueryAttention(
        (q_proj): Linear(in_features=960, out_features=960, bias=False)
        (k_proj): Linear(in_features=960, out_features=320, bias=False)
        (v_proj): Linear(in_features=960, out_features=320, bias=False)
        (o_proj): Linear(in_features=960, out_features=960, bias=False)
      )
      (mlp): GatedMlp(
        (up_proj): Linear(in_features=960, out_features=2560, bias=False)
        (gate_proj): Linear(in_features=960, out_features=2560, bias=False)
        (down_proj): Linear(in_features=2560, out_features=960, bias=False)
        (silu): SiLU()
      )
      (input_layernorm): RMSNorm((960,), eps=1e-05, elementwise_affine=True)
      (post_attention_layernorm): RMSNorm((960,), eps=1e-05, elementwise_affine=True)
    )
  )
  (norm): RMSNorm((960,), eps=1e-05, elementwise_affine=True)
  (lm_head): Linear(in_features=960, out

In [9]:
# Switch the model to evaluation mode before benchmarking. The returned module
# is displayed as the cell output.
model.eval()

LlamaModel(
  (embed_tokens): Embedding(49152, 960)
  (layers): ModuleList(
    (0-31): 32 x DecoderLayer(
      (self_attn): GroupedQueryAttention(
        (q_proj): Linear(in_features=960, out_features=960, bias=False)
        (k_proj): Linear(in_features=960, out_features=320, bias=False)
        (v_proj): Linear(in_features=960, out_features=320, bias=False)
        (o_proj): Linear(in_features=960, out_features=960, bias=False)
      )
      (mlp): GatedMlp(
        (up_proj): Linear(in_features=960, out_features=2560, bias=False)
        (gate_proj): Linear(in_features=960, out_features=2560, bias=False)
        (down_proj): Linear(in_features=2560, out_features=960, bias=False)
        (silu): SiLU()
      )
      (input_layernorm): RMSNorm((960,), eps=1e-05, elementwise_affine=True)
      (post_attention_layernorm): RMSNorm((960,), eps=1e-05, elementwise_affine=True)
    )
  )
  (norm): RMSNorm((960,), eps=1e-05, elementwise_affine=True)
  (lm_head): Linear(in_features=960, out

In [10]:
# Wrap model + tokenizer in the project's evaluation adapter that is handed to
# run_benchmarks below; requests are processed with batch_size=8.
eval_wrapper = SmolEvalWrapper(
    model,
    tokenizer,
    device,
    batch_size=8,
)

In [11]:
# Name of the benchmark task to run; also used as the key into
# results['results'] when extracting metrics.
task = "gsm8k"

In [12]:
# Run the benchmark for `task`. limit=100 caps the evaluation (the progress bars
# in the output show 100 requests), so the scores are a quick estimate with a
# correspondingly large stderr, not a full-test-set number.
results = run_benchmarks(eval_wrapper, [task], limit=100)

# Per-task results: metric name -> value, plus an 'alias' entry we do not want.
task_results = results['results'][task]
metric_keys = task_results.keys() - ['alias']

# Iterate the results dict itself (not the `metric_keys` set) so the displayed
# metrics keep the dict's stable insertion order; iterating the set
# directly yields an arbitrary order that can differ between kernel sessions.
metric_values = {
    metric: value
    for metric, value in task_results.items()
    if metric in metric_keys
}
metric_values

2025-02-10:01:27:48,622 INFO     [task.py:420] Building contexts for gsm8k on rank 0...
100% 100/100 [00:00<00:00, 449.27it/s]
2025-02-10:01:27:48,848 INFO     [evaluator.py:513] Running generate_until requests
100% 100/100 [03:12<00:00,  1.93s/it]


{'exact_match,strict-match': 0.01,
 'exact_match_stderr,flexible-extract': 0.01714466079977651,
 'exact_match,flexible-extract': 0.03,
 'exact_match_stderr,strict-match': 0.009999999999999998}

In [13]:
# Display the complete results dict (metrics plus the task configuration that
# produced them) for inspection.
results

{'results': {'gsm8k': {'alias': 'gsm8k',
   'exact_match,strict-match': 0.01,
   'exact_match_stderr,strict-match': 0.009999999999999998,
   'exact_match,flexible-extract': 0.03,
   'exact_match_stderr,flexible-extract': 0.01714466079977651}},
 'group_subtasks': {'gsm8k': []},
 'configs': {'gsm8k': {'task': 'gsm8k',
   'tag': ['math_word_problems'],
   'dataset_path': 'gsm8k',
   'dataset_name': 'main',
   'training_split': 'train',
   'test_split': 'test',
   'fewshot_split': 'train',
   'doc_to_text': 'Question: {{question}}\nAnswer:',
   'doc_to_target': '{{answer}}',
   'unsafe_code': False,
   'description': '',
   'target_delimiter': ' ',
   'fewshot_delimiter': '\n\n',
   'num_fewshot': 5,
   'metric_list': [{'metric': 'exact_match',
     'aggregation': 'mean',
     'higher_is_better': True,
     'ignore_case': True,
     'ignore_punctuation': False,
     'regexes_to_ignore': [',', '\\$', '(?s).*#### ', '\\.$']}],
   'output_type': 'generate_until',
   'generation_kwargs': {'unt