In [1]:
import os

# Point the Hugging Face cache at shared scratch storage.
# This assignment MUST run before transformers / vllm / datasets are imported:
# transformers resolves its default cache directory from TRANSFORMERS_CACHE
# when it is first imported. Setting it after the imports (as this cell used to)
# has no effect on any loader that is not given `cache_dir` explicitly.
# NOTE(review): code that downloads through huggingface_hub directly reads
# HF_HUB_CACHE rather than TRANSFORMERS_CACHE — consider setting that too.
cache_dir = '/scratch3/workspace/wenlongzhao_umass_edu-reason/dev_kedar/transformers_cache/'
os.environ['TRANSFORMERS_CACHE'] = cache_dir

import torch
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer, AutoConfig
from datasets import load_from_disk
from vllm.inputs import TokensPrompt
from vllm import LLM, SamplingParams
from tqdm import tqdm
from small_llm_reasoning.evaluation.gsm8k import get_score, eight_shot_messages


  from .autonotebook import tqdm as notebook_tqdm
2025-04-06 18:07:44,956	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


In [2]:
# Load the `split` partition of GSM8K from its on-disk Arrow directory.
data_path = "../datasets/gsm8k"
split = 'val'
data = load_from_disk(data_path + '/' + split + '/')
data

Dataset({
    features: ['question', 'answer'],
    num_rows: 1000
})

In [3]:
# Loading model
# Hugging Face access token (meta-llama repos are gated); None when the
# `hf_token` environment variable is unset.
hf_token = os.getenv("hf_token")

# Other checkpoints — swap one in for model_name as needed:
# model_name = "meta-llama/Llama-3.2-1B-Instruct"
# model_name = "meta-llama/Llama-3.1-8B-Instruct"
# model_name = "meta-llama/Llama-3.2-3B"
# model_name = "meta-llama/Llama-3.3-70B-Instruct"
model_name = "meta-llama/Llama-3.2-3B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    cache_dir=cache_dir,
    token=hf_token,
)
# Pad with the EOS id (presumably because the checkpoint defines no pad token).
tokenizer.pad_token_id = tokenizer.eos_token_id



In [4]:
def get_prompt(ex, few_shot):
    """Build the chat messages for a single GSM8K problem.

    Args:
        ex: the problem text inserted after "Problem: " (callers in this
            notebook pass the example's 'question' string).
        few_shot: when truthy, `eight_shot_messages` is prepended ahead of the
            new user turn.

    Returns:
        A new list of {'role', 'content'} message dicts whose last entry is the
        user turn for `ex`.
    """
    instruction = (
        'Given the following problem, reason and give a final answer to the problem.\n'
        f'Problem: {ex}\n'
        'Your response should end with "The final answer is [answer]" where [answer] is the response to the problem.\n'
    )
    question_turn = {'role': 'user', 'content': instruction}
    # Only touch eight_shot_messages when it is actually requested.
    prefix = eight_shot_messages if few_shot else []
    return prefix + [question_turn]

def tokenize_function(example, few_shot):
    """Render one dataset row through the chat template and tokenize it.

    Written for `Dataset.map(..., batched=False)`; uses the module-level
    `tokenizer` and `get_prompt`.

    Args:
        example: a single row; only its 'question' field is read.
        few_shot: forwarded to `get_prompt` to toggle the eight-shot prefix.

    Returns:
        {'input_ids': {'prompt_token_ids': [...]}}. The inner key presumably
        mirrors vLLM's `TokensPrompt` field so rows can be fed to vLLM as-is.
    """
    messages = get_prompt(example['question'], few_shot)
    # Render to text with the assistant header appended so generation starts
    # as the assistant's turn.
    rendered = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    # The rendered template already begins with the BOS token
    # (<|begin_of_text|> for this tokenizer), so don't add special tokens again.
    encoded = tokenizer(rendered, add_special_tokens=False)
    return {'input_ids': {'prompt_token_ids': encoded['input_ids']}}


In [5]:
# Toggle for the eight-shot prefix; read again by the tokenization cell below.
few_shot = True

# Render the first validation question exactly as the model will see it.
prompt = get_prompt(data['question'][0], few_shot)
rendered_prompt = tokenizer.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True)
print(rendered_prompt)

<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Cutting Knowledge Date: December 2023
Today Date: 06 Apr 2025

<|eot_id|><|start_header_id|>user<|end_header_id|>

Given the following problem, reason and give a final answer to the problem.
Problem: There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?
Your response should end with "The final answer is [answer]" where [answer] is the response to the problem.<|eot_id|><|start_header_id|>assistant<|end_header_id|>

There are 15 trees originally. Then there were 21 trees after some more were planted. So there must have been 21 - 15 = 6. The final answer is 6<|eot_id|><|start_header_id|>user<|end_header_id|>

Given the following problem, reason and give a final answer to the problem.
Problem: If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?
Your response 

In [None]:
# Tokenize row by row; `few_shot` is forwarded to tokenize_function as a
# keyword argument via fn_kwargs instead of being captured by a lambda.
tokenized_dataset = data.map(
    tokenize_function,
    batched=False,
    fn_kwargs={'few_shot': few_shot},
)


In [None]:
# Number of tokenized rows. len() on the Dataset returns the row count directly,
# instead of first materialising the entire 'input_ids' column as a Python list.
len(tokenized_dataset)

In [15]:
# Persist the tokenized prompts under a path that reflects how they were built.
# The shot directory follows the `few_shot` toggle, so a zero-shot run is no
# longer saved under eight-shot/.
# TODO(review): align 'zero-shot' with whatever name the loading code expects.
shot_dir = 'eight-shot' if few_shot else 'zero-shot'
# Only the checkpoint this notebook was run with has a registered directory
# name. Switching `model_name` raises KeyError here until an entry is added,
# rather than silently writing another model's prompts under LLaMA3B-Instruct/.
model_dirs = {"meta-llama/Llama-3.2-3B-Instruct": "LLaMA3B-Instruct"}
model_dir = model_dirs[model_name]
tokenized_dataset.save_to_disk(f"{data_path}/tokenized/{model_dir}/{split}/{shot_dir}/")

Saving the dataset (1/1 shards): 100%|██████████| 1000/1000 [00:00<00:00, 54998.61 examples/s]


{'prompt_token_ids': [128000,
  128006,
  9125,
  128007,
  271,
  38766,
  1303,
  33025,
  2696,
  25,
  6790,
  220,
  2366,
  18,
  198,
  15724,
  2696,
  25,
  220,
  2705,
  5186,
  220,
  2366,
  20,
  271,
  128009,
  128006,
  882,
  128007,
  271,
  22818,
  279,
  2768,
  3575,
  11,
  2944,
  323,
  3041,
  264,
  1620,
  4320,
  311,
  279,
  3575,
  627,
  32298,
  25,
  51521,
  37548,
  706,
  264,
  400,
  1041,
  15,
  4121,
  430,
  568,
  6944,
  311,
  2349,
  1139,
  9333,
  19123,
  779,
  568,
  1436,
  3041,
  1124,
  311,
  813,
  11568,
  1634,
  323,
  44964,
  28844,
  13,
  1283,
  6944,
  311,
  2349,
  220,
  18,
  14,
  605,
  315,
  279,
  3300,
  1139,
  400,
  1135,
  19123,
  1418,
  279,
  2800,
  1139,
  400,
  1041,
  19123,
  13,
  2650,
  1690,
  9863,
  315,
  19123,
  690,
  51521,
  37548,
  617,
  304,
  682,
  5380,
  7927,
  2077,
  1288,
  842,
  449,
  330,
  791,
  1620,
  4320,
  374,
  510,
  9399,
  19727,
  1405,
  510,
  9399,
  