In [12]:
#from unsloth import FastLanguageModel
import gc
import json
import os
import re
from glob import glob

import numpy as np
import pandas as pd
import textstat
import torch
from datasets import load_dataset, Dataset
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

In [3]:
# Keyword arguments forwarded to `from_pretrained` when each checkpoint is loaded.
# NOTE(review): newer transformers releases prefer
# `quantization_config=BitsAndBytesConfig(load_in_4bit=True)` over the bare
# `load_in_4bit` flag — confirm against the installed version.
params = {
    "load_in_4bit": True,
    "device_map": "auto",  # Automatically handle model placement
    "torch_dtype": torch.float16,  # Use half precision
    "use_cache": True,
}

In [8]:
# Checkpoints to evaluate, in the order they will appear in the results table.
# The first entry looks like a Hugging Face Hub id; the rest are local directories.
# NOTE(review): the local entries are absolute paths on one machine — consider
# deriving them from a configurable base directory so the notebook is portable.
model_paths = [
    'unsloth/meta-llama-3.1-8b-instruct-unsloth-bnb-4bit',
    "/home/ben/code/wandb/gsm8k/test_gsm8k/merged",
    "/home/ben/code/wandb/gsm8k/gsm8k_flesch/merged",
    '/home/ben/code/wandb/gsm8k/outputs/checkpoint-1500',
    '/home/ben/code/wandb/gsm8k/outputs/checkpoint-2500',
]


In [None]:
# Load and prep dataset

# System message sent with every question. It asks the model to wrap its
# reasoning and its final answer in <reasoning>/<answer> tags, which the
# metric functions later look for.
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""

# Template that renders a reasoning/answer pair in the same tag layout.
# It is not referenced by the other cells shown in this notebook.
XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""

def extract_hash_answer(text: str) -> str | None:
    if "####" not in text:
        return None
    return text.split("####")[1].strip()

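# Zero-shot prompt: system instructions followed by the question.
# For 1-shot prompting, insert an example user/assistant message pair between them.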
def get_gsm8k_questions(split = "train") -> Dataset:
    """Load one split of GSM8K and attach chat prompts and parsed answers.

    Each example gets two columns added or overwritten by `map`:
      * 'prompt': a two-message chat (system instructions, then the question).
      * 'answer': the text after "####" in the reference solution, or None
        if the marker is missing.

    Args:
        split: Name of the split to select from the 'main' configuration.

    Returns:
        The mapped dataset for the requested split.
    """
    def _to_chat_example(row):
        system_msg = {'role': 'system', 'content': SYSTEM_PROMPT}
        user_msg = {'role': 'user', 'content': row['question']}
        return {
            'prompt': [system_msg, user_msg],
            'answer': extract_hash_answer(row['answer']),
        }

    gsm8k = load_dataset('openai/gsm8k', 'main')  # type: ignore
    return gsm8k[split].map(_to_chat_example)  # type: ignore

# Evaluation data: the GSM8K "test" split. This module-level `dataset` is read
# by the generation and scoring functions below. The bare `len(...)` displays
# the number of examples as the cell output.
dataset = get_gsm8k_questions(split = "test")
len(dataset)

In [1]:
def load_model(model_path: str):
    """Load a causal LM and its tokenizer, ready for batched generation.

    Args:
        model_path: Hub id or local directory accepted by `from_pretrained`.

    Returns:
        A `(model, tokenizer)` tuple. The model is in eval mode and placed
        according to `params["device_map"]`; the tokenizer pads on the left
        and is guaranteed to have a pad token.
    """
    # `params` already carries `device_map`, which decides where the weights
    # live. A further `.to("cuda")` is redundant, and bitsandbytes-quantized
    # models can reject `.to()` outright depending on library versions.
    model = AutoModelForCausalLM.from_pretrained(model_path, **params).eval()

    # The entries in `params` (quantization, dtype, device map, cache) are model
    # options, so they are deliberately not forwarded to the tokenizer.
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    # Decoder-only models continue from the last position of each row, so
    # batched prompts must be padded on the left.
    tokenizer.padding_side = "left"
    # `padding=True` fails without a pad token; fall back to EOS if none is set.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    return model, tokenizer

# Keyword arguments passed to `model.generate` for every batch:
# at most 256 newly generated tokens per prompt, with sampling disabled.
generation_config = {
    "max_new_tokens": 256,
    "do_sample": False,
}
def run_generation(model_, tokenizer_, batch_size=8, num_ex=50):
    """Generate completions for the first `num_ex` questions of `dataset`.

    Prompts are built from `SYSTEM_PROMPT` plus each example's 'question',
    rendered with the tokenizer's chat template, and generated in batches
    using `generation_config`.

    Args:
        model_: Model exposing `generate` and `device`.
        tokenizer_: Matching tokenizer with a chat template.
        batch_size: Number of prompts per `generate` call.
        num_ex: Number of examples to run; capped at `len(dataset)`.

    Returns:
        A list of decoded completions (prompt tokens removed), where entry
        `k` corresponds to `dataset[k]`.
    """
    # Never index past the end of the dataset when `num_ex` exceeds its size.
    num_ex = min(num_ex, len(dataset))
    outputs = []

    # Decoder-only models continue from the last position of each row, so a
    # batch must be padded on the left. The caller's setting is restored after.
    previous_padding_side = tokenizer_.padding_side
    tokenizer_.padding_side = "left"
    try:
        for start in tqdm(range(0, num_ex, batch_size)):
            stop = min(start + batch_size, num_ex)

            # Prepare batch of prompts
            conversations = [
                [
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {"role": "user", "content": dataset[idx]['question']},
                ]
                for idx in range(start, stop)
            ]
            text = [
                tokenizer_.apply_chat_template(conv, tokenize = False, add_generation_prompt = True)
                for conv in conversations
            ]

            # The rendered chat template already contains the special tokens it
            # needs, so the tokenizer must not prepend a second copy (e.g. BOS).
            inputs = tokenizer_(
                text,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=512,
                add_special_tokens=False,
            ).to(model_.device)  # follow the model rather than assuming "cuda"

            input_length = inputs.input_ids.shape[-1]
            output = model_.generate(**inputs, **generation_config)
            # Drop the (padded) prompt so only newly generated tokens are decoded.
            output = output[:, input_length:]
            outputs.extend(tokenizer_.batch_decode(output, skip_special_tokens=True))
    finally:
        tokenizer_.padding_side = previous_padding_side
    return outputs

def run(model_path: str):
    """Load one checkpoint, generate its completions, and free its memory.

    Args:
        model_path: Hub id or local directory passed to `load_model`.

    Returns:
        The list of completions produced by `run_generation`.
    """
    model, tokenizer = load_model(model_path)
    try:
        return run_generation(model, tokenizer)
    finally:
        # Release the weights even if generation fails, so the next checkpoint
        # has room. Dropping the references and collecting is enough; moving a
        # quantized model with `.to("cpu")` first is unnecessary and may be
        # rejected by the library.
        del model, tokenizer
        gc.collect()
        torch.cuda.empty_cache()

In [None]:
# Generate completions for each checkpoint in turn. `all_outputs[k]` holds the
# completions for `model_paths[k]`; the metrics cell relies on this ordering
# when it zips the two lists together.
all_outputs = []
for model_path in model_paths:
    outputs = run(model_path)
    all_outputs.append(outputs)

In [10]:
def get_answer(text: str) -> str:
    """Return the stripped text that follows the first "<answer>" tag.

    The result ends at the first "</answer>" if one follows, otherwise at the
    next "<answer>" tag or the end of the text.

    Raises:
        IndexError: if `text` contains no "<answer>" tag.
    """
    after_open_tag = text.split("<answer>")[1]
    inner, _, _ = after_open_tag.partition("</answer>")
    return inner.strip()

def get_answer_rate(responses):
    """Count responses whose integer answer matches the reference answer.

    `responses[i]` is compared against `dataset[i]['answer']`. A response is
    left out of both counts when it has no "<answer>" tag, or when either the
    extracted answer or the reference cannot be parsed as an integer.

    Args:
        responses: Completions, index-aligned with `dataset`.

    Returns:
        A `(correct, wrong)` tuple of counts.
    """
    def _as_int(value) -> int:
        # Remove thousands separators so values such as "1,000" still parse.
        return int(str(value).replace(",", ""))

    correct, wrong = 0, 0
    for i, response in enumerate(responses):
        try:
            predicted = _as_int(get_answer(response))
            expected = _as_int(dataset[i]['answer'])
        except (IndexError, ValueError):
            # Missing tag or non-integer content: skip this response. Only the
            # expected failures are caught, so interrupts and real bugs surface.
            continue
        if predicted == expected:
            correct += 1
        else:
            wrong += 1
    return correct, wrong

def get_average_flesch_kincaid(responses) -> float:
    """Return the mean Flesch-Kincaid grade of `responses`.

    Args:
        responses: Sequence of completion strings.

    Returns:
        The arithmetic mean of `textstat.flesch_kincaid_grade` over all
        responses, or NaN when `responses` is empty.
    """
    # An empty batch has no meaningful average; NaN avoids a ZeroDivisionError.
    if len(responses) == 0:
        return float("nan")
    scores = [textstat.flesch_kincaid_grade(r) for r in responses]
    return sum(scores) / len(scores)

def get_average_length(responses) -> float:
    """Return the mean length, in characters, of `responses`.

    Args:
        responses: Sequence of completion strings.

    Returns:
        The average of `len(r)` over all responses, or NaN when `responses`
        is empty.
    """
    # An empty batch has no meaningful average; NaN avoids a ZeroDivisionError.
    if len(responses) == 0:
        return float("nan")
    return sum(len(r) for r in responses) / len(responses)

def soft_format_reward_func(responses) -> int:
    """Count the responses that follow the reasoning/answer tag format.

    A response counts when it begins with a <reasoning>...</reasoning> section
    followed, after optional whitespace, by an <answer>...</answer> section.
    Text after the closing answer tag is allowed.

    Returns:
        The number of matching responses.
    """
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    # `re.match` anchors at the start of the string, so any leading text or
    # whitespace before <reasoning> makes a response fail the check.
    return sum(1 for r in responses if re.match(pattern, r, flags=re.DOTALL))

def has_reasoning(responses) -> int:
    """Return how many responses contain an opening "<reasoning>" tag."""
    return sum(1 for r in responses if "<reasoning>" in r)

def has_answer(responses) -> int:
    """Return how many responses contain an opening "<answer>" tag."""
    return sum(1 for r in responses if "<answer>" in r)


In [None]:
# One summary row per checkpoint, pairing each path with its completions.
# "answer_rate" is a (correct, wrong) tuple; the three format columns are counts.
metrics = [
    {
        'model_path': model_path,
        "answer_rate": get_answer_rate(outputs),
        "average_flesch_kincaid": get_average_flesch_kincaid(outputs),
        "average_length": get_average_length(outputs),
        "soft_format_reward": soft_format_reward_func(outputs),
        "has_reasoning": has_reasoning(outputs),
        "has_answer": has_answer(outputs),
    }
    for model_path, outputs in zip(model_paths, all_outputs)
]
pd.DataFrame(metrics)
