In [5]:
#from unsloth import FastLanguageModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import re
from datasets import load_dataset, Dataset
import textstat
import numpy as np
import json
import os
from tqdm import tqdm
from glob import glob

In [2]:
# Keyword arguments shared by every `from_pretrained` model load in this notebook.
params = dict(
    load_in_4bit=True,
    device_map="auto",  # let the loader decide where the weights are placed
    torch_dtype=torch.float16,  # half precision
    use_cache=True,
)

In [None]:
# Both checkpoints are loaded with the same `params`. device_map="auto" already
# places the quantized weights on the available accelerator(s), so no
# `.to("cuda")` is chained here: moving a bitsandbytes 4-bit model after loading
# is redundant and is rejected with a ValueError by some transformers releases.
pretrained_model = AutoModelForCausalLM.from_pretrained("unsloth/meta-llama-3.1-8b-instruct-unsloth-bnb-4bit", **params).eval()
pretrained_tokenizer = AutoTokenizer.from_pretrained("unsloth/meta-llama-3.1-8b-instruct-unsloth-bnb-4bit")

# NOTE(review): absolute, machine-specific path — consider a configurable directory.
model_lora = AutoModelForCausalLM.from_pretrained("/home/ben/code/wandb/gsm8k/test_gsm8k/merged", **params).eval()
tokenizer_lora = AutoTokenizer.from_pretrained("/home/ben/code/wandb/gsm8k/test_gsm8k/merged")

In [None]:
# Load and prep dataset
# System message sent with every question: it asks the model to wrap its
# reasoning in <reasoning> tags and its final answer in <answer> tags.
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""

# Template with {reasoning} and {answer} placeholders in the same tag layout.
# NOTE(review): not referenced in the visible cells — presumably kept for
# building a worked (one-shot) example; confirm before removing.
XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""

def extract_hash_answer(text: str) -> str | None:
    if "####" not in text:
        return None
    return text.split("####")[1].strip()

def get_gsm8k_questions(split = "train") -> Dataset:
    """Load one split of GSM8K ('main' config) and add chat-style columns.

    Every row is mapped to:
      * 'prompt': the system message (SYSTEM_PROMPT) followed by the question
        as the user message;
      * 'answer': the text after the '####' marker of the original answer
        (None when the marker is missing).

    For 1-shot prompting, add an example user/assistant exchange between the
    system and user messages.
    """
    def to_chat_example(row):
        messages = [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': row['question']},
        ]
        return {'prompt': messages, 'answer': extract_hash_answer(row['answer'])}

    all_splits = load_dataset('openai/gsm8k', 'main')  # type: ignore
    return all_splits[split].map(to_chat_example)  # type: ignore

# Build the prompt/answer dataset from the default split ("train") and show its size.
# NOTE(review): if the LoRA checkpoint was fine-tuned on this same split, the
# comparison is measured on training data — confirm, or evaluate on "test".
dataset = get_gsm8k_questions()
len(dataset)

In [None]:
generation_config = {
    "max_new_tokens": 256,
    "do_sample": False,  # greedy decoding, so both models are compared deterministically
}

BATCH_SIZE = 8
num_ex = 50  # number of dataset examples to generate for

# Decoder-only models continue from the last position of each row, so a batch
# of prompts with different lengths must be padded on the LEFT; with right
# padding the shorter prompts would be continued from pad tokens.
pretrained_tokenizer.padding_side = "left"
if pretrained_tokenizer.pad_token is None:
    # padding=True needs a pad token; fall back to EOS when none is configured.
    pretrained_tokenizer.pad_token = pretrained_tokenizer.eos_token


def _generate_completions(model, tokenizer, inputs):
    """Generate for one tokenized batch and return only the newly generated text.

    `inputs` is the padded batch; everything up to the padded prompt length is
    sliced off the generated ids before decoding with `tokenizer`.
    """
    prompt_length = inputs.input_ids.shape[-1]
    generated = model.generate(**inputs, **generation_config)
    return tokenizer.batch_decode(generated[:, prompt_length:], skip_special_tokens=True)


output_pretrained, output_lora = [], []
for i in tqdm(range(0, num_ex, BATCH_SIZE)):
    batch_end = min(i + BATCH_SIZE, num_ex)

    # Prepare batch of prompts
    batch_prompts = [
        [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": dataset[j]['question']},
        ]
        for j in range(i, batch_end)
    ]
    text = [
        pretrained_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        for messages in batch_prompts
    ]

    # The chat template has already inserted the special tokens (including BOS),
    # so they must not be added a second time when the rendered text is tokenized.
    inputs = pretrained_tokenizer(
        text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
        add_special_tokens=False,
    ).to("cuda")

    # The same tokenized batch is fed to both models; each output is decoded
    # with the tokenizer that was loaded alongside that model.
    output_pretrained.extend(_generate_completions(pretrained_model, pretrained_tokenizer, inputs))
    output_lora.extend(_generate_completions(model_lora, tokenizer_lora, inputs))

In [59]:
def get_answer(text: str) -> str:
    """Return the stripped text between '<answer>' and the next '</answer>'.

    A missing closing tag is tolerated (everything after '<answer>' up to the
    next '<answer>' is used). Raises IndexError when '<answer>' is absent.
    """
    after_open_tag = text.split("<answer>")[1]
    inner, _, _ = after_open_tag.partition("</answer>")
    return inner.strip()

def get_answer_rate(responses, answers=None):
    """Count responses whose integer answer matches / differs from the gold answer.

    Args:
        responses: Generated completions; responses[i] is scored against the
            i-th gold answer.
        answers: Optional gold answers aligned with `responses`. When omitted,
            the gold answer is read from the notebook-level
            `dataset[i]['answer']`.

    Returns:
        A `(correct, wrong)` tuple. A pair is skipped (counted in neither
        total) when the response has no <answer> tag, when either side is not
        an integer, or when no gold answer exists for that index.
    """
    def _to_int(value):
        # Thousands separators ("1,000") would otherwise make int() fail and
        # the example would be dropped from both counts.
        return int(str(value).replace(",", "").strip())

    correct, wrong = 0, 0
    for i, response in enumerate(responses):
        try:
            gold = dataset[i]['answer'] if answers is None else answers[i]
            predicted = _to_int(get_answer(response))
            expected = _to_int(gold)
        except (IndexError, ValueError):
            # Unparseable or unaligned pair: best-effort scoring skips it, but
            # only for these expected failure modes (no blanket `except`).
            continue
        if predicted == expected:
            correct += 1
        else:
            wrong += 1
    return correct, wrong

In [None]:
# (correct, wrong) counts: `output_pretrained` first, `output_lora` second.
get_answer_rate(output_pretrained), get_answer_rate(output_lora)

In [None]:
def get_average_flesch_kincaid(responses) -> float:
    """Return the mean `textstat.flesch_kincaid_grade` score over `responses`.

    Returns NaN for an empty collection, where the mean is undefined, instead
    of raising ZeroDivisionError.
    """
    if len(responses) == 0:
        return float("nan")
    scores = [textstat.flesch_kincaid_grade(r) for r in responses]
    return sum(scores) / len(scores)

# Mean Flesch-Kincaid grade: `output_pretrained` first, `output_lora` second.
get_average_flesch_kincaid(output_pretrained), get_average_flesch_kincaid(output_lora)

In [None]:
def get_average_length(responses) -> float:
    """Return the mean length, in characters, of the given responses.

    Returns NaN for an empty collection, where the mean is undefined, instead
    of raising ZeroDivisionError.
    """
    if len(responses) == 0:
        return float("nan")
    return sum(len(r) for r in responses) / len(responses)

# Mean response length in characters: `output_pretrained` first, `output_lora` second.
get_average_length(output_pretrained), get_average_length(output_lora)


In [None]:
def soft_format_reward_func(responses) -> int:
    """Count the responses that follow the <reasoning>/<answer> format.

    A response counts when it STARTS with a <reasoning>...</reasoning> section
    followed, after optional whitespace, by an <answer>...</answer> section.
    The pattern is anchored at the beginning of the response (`match`), so any
    leading text or whitespace makes the response not count; trailing text is
    allowed.
    """
    pattern = re.compile(r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>", flags=re.DOTALL)
    return sum(1 for r in responses if pattern.match(r))

# Same display order as the other comparison cells: `output_pretrained` first, `output_lora` second.
soft_format_reward_func(output_pretrained), soft_format_reward_func(output_lora)

In [None]:
def has_reasoning(responses) -> int:
    """Count the responses that contain an opening '<reasoning>' tag."""
    return sum(1 for r in responses if "<reasoning>" in r)

# Same display order as the other comparison cells: `output_pretrained` first, `output_lora` second.
has_reasoning(output_pretrained), has_reasoning(output_lora)

In [None]:
def has_answer(responses) -> int:
    """Count the responses that contain an opening '<answer>' tag."""
    return sum(1 for r in responses if "<answer>" in r)

# Same display order as the other comparison cells: `output_pretrained` first, `output_lora` second.
has_answer(output_pretrained), has_answer(output_lora)