In [None]:
import json

import torch
from datasets import load_dataset
from huggingface_hub import login
# Authenticate with the Hugging Face Hub.
# NOTE(review): `hf_token` is not defined anywhere in the visible notebook —
# confirm it is set (e.g. from an environment variable) before this cell runs.
login(token=hf_token)

# Load the test split of lighteval/MATH. trust_remote_code=True allows the
# dataset repository's loading script to execute locally.
dataset = load_dataset(
    "lighteval/MATH",
    split="test",
    trust_remote_code=True,
)

In [None]:
def formatting_prompts_func(example):
    """Render one example as a single chat-formatted string.

    Builds a two-turn conversation (user: the problem plus a step-by-step
    instruction; assistant: the reference solution) and renders it to text
    with the module-level ``tokenizer``'s chat template.

    Args:
        example: Mapping with ``"problem"`` and ``"solution"`` entries.

    Returns:
        A dict ``{"task": <rendered conversation string>}``.
    """
    user_turn = {
        "role": "user",
        "content": f"Problem: {example['problem']}\n\nSolve this math problem step by step.",
    }
    assistant_turn = {"role": "assistant", "content": example['solution']}

    # tokenize=False returns text rather than token ids; no generation prompt
    # is appended because the assistant turn is already present.
    rendered = tokenizer.apply_chat_template(
        [user_turn, assistant_turn],
        tokenize=False,
        add_generation_prompt=False,
    )
    return {"task": rendered}

# Add a "task" column holding the chat-formatted text for every example.
# Applied one example at a time because formatting_prompts_func indexes a
# single record's fields.
dataset = dataset.map(
    formatting_prompts_func,
    batched=False,
)

In [None]:
def tokenize_function(example):
    """Tokenize the ``"task"`` text with the module-level ``tokenizer``.

    Every sequence is truncated and padded to exactly 5096 tokens.

    Args:
        example: Mapping whose ``"task"`` entry holds the text (or batch of
            texts) to tokenize.

    Returns:
        Whatever the tokenizer call returns for that input.
    """
    return tokenizer(
        example["task"],
        max_length=5096,
        padding="max_length",
        truncation=True,
    )
    
# Tokenize in batches, dropping the raw MATH columns. The "task" column is not
# listed in remove_columns, so it is kept alongside the tokenizer outputs.
# The cache is bypassed so the map is recomputed on every run.
tokenized_datasets = dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=["problem", "solution", "level", "type"],
    load_from_cache_file=False,
)

# Have the dataset return PyTorch tensors when indexed.
tokenized_datasets.set_format("torch")

In [None]:
from unsloth.chat_templates import get_chat_template
from unsloth import FastLanguageModel

# LoRA rank of the merged checkpoint to evaluate. The original notebook
# switched between 32, 64 and 128 by (un)commenting separate lines for the
# checkpoint name and the output file; deriving both from one value keeps them
# from drifting apart.
LORA_RANK = 32

# NOTE(review): the file name says "gsm8k" but the dataset loaded in this
# notebook is lighteval/MATH — confirm the intended name before renaming.
output_file = f"llama1b_r{LORA_RANK}_gsm8k_responses.jsonl"

# Load the model first: from_pretrained returns a fresh tokenizer, so applying
# the chat template before this call (as the original did) meant the templated
# tokenizer was overwritten and the "llama-3.1" template was discarded.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = f'bharati2324/Llama-1B-Math-LoRA-r{LORA_RANK}-merged',
    max_seq_length = 2048,
    dtype = torch.float16,
    load_in_4bit = False,
)

# Attach the Llama 3.1 chat template to the tokenizer that is actually used
# for inference.
# NOTE(review): earlier cells (formatting_prompts_func / tokenize_function)
# also read `tokenizer`; on a fresh kernel this cell must run before them.
tokenizer = get_chat_template(
    tokenizer,
    chat_template = "llama-3.1",
)

# Switch the model into Unsloth's inference mode.
FastLanguageModel.for_inference(model)

In [None]:
def generate_responses(problem_text, num_responses=5, max_new_tokens=256):
    """Sample several independent solutions to one math problem.

    Uses the module-level ``model`` and ``tokenizer``. The prompt is encoded
    once and reused for every sample.

    Args:
        problem_text: Problem statement to embed in the user turn.
        num_responses: Number of samples to draw. Zero or a negative value
            yields an empty list.
        max_new_tokens: Upper bound on generated tokens per sample, not
            counting the prompt. (The original used ``max_length=256``, which
            also counts prompt tokens and leaves little or no room for long
            problems.)

    Returns:
        A list of ``num_responses`` strings, each holding only the newly
        generated text with surrounding whitespace stripped.
    """
    input_prompt = [
        {"role": "user", "content": f"Q: {problem_text}\n\nSolve this math problem step by step.\nA:"},
    ]
    # add_generation_prompt=True appends the assistant header so the model
    # starts a reply instead of continuing the user turn.
    input_ids = tokenizer.apply_chat_template(
        input_prompt,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)
    prompt_length = input_ids.shape[-1]

    responses = []
    for _ in range(num_responses):
        output = model.generate(
            input_ids=input_ids,
            max_new_tokens=max_new_tokens,
            num_return_sequences=1,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
        )
        # Decode only the continuation. Splitting the full decoded text on
        # "A:" (as before) cut off any answer that itself contained "A:".
        completion_ids = output[0][prompt_length:]
        responses.append(
            tokenizer.decode(completion_ids, skip_special_tokens=True).strip()
        )
    return responses

In [None]:
# Number of sampled solutions requested per problem. The original loop used
# this name without defining it anywhere in the visible notebook.
num_responses = 5

# Write one JSON object per line: problem index, problem text, sampled responses.
with open(output_file, "w", encoding="utf-8") as f:
    for idx, example in enumerate(dataset):
        # The dataset loaded in this notebook stores the statement under
        # "problem"; fall back to "question" (the key the original loop used)
        # so datasets with that column still work.
        problem = example["problem"] if "problem" in example else example["question"]
        responses = generate_responses(problem, num_responses)

        output_entry = {
            "problem_id": idx,
            "problem": problem,
            "responses": responses
        }
        f.write(json.dumps(output_entry) + "\n")
        # Generation is slow; flush per record so an interrupted run keeps
        # everything produced so far.
        f.flush()
        print(f"Processed problem {idx+1}/{len(dataset)}")

print(f"Responses saved to {output_file}")