In [1]:
import sys
from pathlib import Path
import gc

import torch
import pandas as pd
from datasets import load_dataset
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)

sys.path.append(str(Path.cwd().resolve().parent))

In [2]:
from src.config import (
    MODELS_DIR,
    GSM8K_PATH,
    TEACHER_SYSTEM_PROMPT,
    TEACHER_USER_PROMPT,
)
from src.dataset_generator.helpers.answers import (
    ParsingError,
    parse_gold_answer_number,
    parse_teacher_final_answer,
)

In [3]:
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

In [None]:
def cleanup_model(model, tokenizer):
    if model is not None:
        del model
    if tokenizer is not None:
        del tokenizer
    
    gc.collect()
    
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

In [5]:
def build_prompt_cot(question: str) -> str:
    sys_txt = TEACHER_SYSTEM_PROMPT.strip()
    usr_txt = TEACHER_USER_PROMPT.strip().format(question=question.strip())
    return f"{sys_txt}\n\n{usr_txt}\n"


def build_prompt_label_only(question: str) -> str:
    prompt = f"""You are a rigorous and precise math solver. Solve the question and output the result only as:
Final Answer: <number>

Example 1:
Question: A farm has 3 barns with 12 cows each. It sells 7 cows and buys 5 more. How many cows now?
Final Answer: 34

Example 2:
Question: Pens cost $2 and notebooks $5. Alex buys 3 pens and 2 notebooks and pays with $20. How much change?
Final Answer: 4

Example 3:
Question: A tank holds 250 liters. 35% is drained, then 40 liters are added. How many liters now?
Final Answer: 202.5

Question: {question.strip()}"""
        
    return prompt

In [6]:
def create_quantization_config(bf16: bool) -> BitsAndBytesConfig:
    return BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16 if bf16 else torch.float16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
    )


def load_model(model_id: str, load_path: str | None, quant_config: BitsAndBytesConfig):
    model = AutoModelForCausalLM.from_pretrained(
        load_path if load_path else model_id,
        quantization_config=quant_config,
        device_map="auto",
    )
    model.config.use_cache = True
    
    return model


def setup_tokenizer(model_id: str, load_path: str | None):
    tokenizer = AutoTokenizer.from_pretrained(
        load_path if load_path else model_id, 
        use_fast=True, 
        padding_side="left", 
        truncation_size="left"
    )
    return tokenizer


def load_model_and_tokenizer(model_id: str, bf16: bool, load_path: str | None = None):
    quant_config = create_quantization_config(bf16)
    model = load_model(model_id, load_path, quant_config)
    tokenizer = setup_tokenizer(model_id, load_path)
    model.generation_config.pad_token_id = tokenizer.pad_token_id
    return model, tokenizer

In [7]:
def prepare_batch_inputs(questions: list[str], mode: str, tokenizer):
    build_prompt = build_prompt_cot if mode == "cot" else build_prompt_label_only
    prompts = [build_prompt(q) for q in questions]
    
    return tokenizer(
        prompts,
        return_tensors="pt",
        padding=True,
        truncation=True,
    ), prompts

def extract_responses(generated_tokens, prompts, tokenizer):
    full_texts = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
    return [full[len(prompt):].strip() for prompt, full in zip(prompts, full_texts)]


def process_batch(model, tokenizer, questions: list[str], mode: str) -> list[str]:
    encoded_inputs, prompts = prepare_batch_inputs(questions, mode, tokenizer)
    encoded_inputs = encoded_inputs.to(model.device)
    
    generated_tokens = model.generate(
        **encoded_inputs,
        do_sample=False,
        use_cache=True,
    )
    return extract_responses(generated_tokens, prompts, tokenizer)


def extract_questions_and_answers(dataset) -> tuple[list[str], list[float | None]]:
    questions = dataset["question"]
    gold_nums = [parse_gold_answer_number(answer) for answer in dataset["answer"]]
    return questions, gold_nums


def load_benchmark_dataset(
    path: str, split: str, limit: int | None = None
) -> tuple[list[str], list[float | None]]:
    dataset = load_dataset(path, name="main", split=split)
    if limit is not None:
        dataset = dataset.select(range(min(limit, len(dataset))))
        
    return extract_questions_and_answers(dataset)


def count_correct_predictions(predictions: list[str], gold_answers: list[float | None]) -> int:
    correct_count = 0
    
    for pred_text, gold_num in zip(predictions, gold_answers):
        try:
            pred_num = parse_teacher_final_answer(pred_text)
        except ParsingError:
            pred_num = None
        
        if (pred_num is not None and 
            gold_num is not None and 
            pred_num == gold_num):
            correct_count += 1
            
    return correct_count

def evaluate_predictions(predictions: list[str], gold_answers: list[float | None]) -> dict:
    correct_count = count_correct_predictions(predictions, gold_answers)
    total_count = len(predictions)
    
    accuracy = correct_count / total_count if total_count > 0 else 0.0
    
    return {"accuracy": accuracy, "n": total_count}


def generate_predictions(model, tokenizer, questions: list[str], mode: str, batch_size: int) -> list[str]:
    outputs = []
    model.eval()
    
    with torch.inference_mode():
        for i in tqdm(range(0, len(questions), batch_size), desc=f"Evaluating {mode}"):
            batch_questions = questions[i:i + batch_size]
            batch_outputs = process_batch(model, tokenizer, batch_questions, mode)
            outputs.extend(batch_outputs)

    return outputs


def benchmark(
    model, tokenizer, mode: str, batch_size: int, limit: int | None = None
):
    questions, gold_answers = load_benchmark_dataset(GSM8K_PATH, "test", limit)
    predictions = generate_predictions(model, tokenizer, questions, mode, batch_size)
    df = pd.DataFrame({
        "question": questions,
        "prediction": predictions,
        "gold_answer": gold_answers
    })
    return evaluate_predictions(predictions, gold_answers), df

In [8]:
MODEL_ID = "Qwen/Qwen2.5-3B"
SCTOD_PATH = MODELS_DIR / "qwen2.5_3b_sctod_lora"
LABELONLY_PATH = MODELS_DIR / "qwen2.5_3b_labelonly_lora"

# RUNS = [
#     {"name": "student_sctod", "mode": "cot", "path": SCTOD_PATH},
#     {"name": "student_label_only", "mode": "label-only", "path": LABELONLY_PATH},
#     {"name": "base_cot_prompting", "mode": "cot", "path": None},
#     {"name": "base_label_only", "mode": "label-only", "path": None},
# ]

RUNS = [
    {"name": f"student_label_only_checkpoint_{i}", "mode": "label-only", "path": str(path)}
    for i, path in enumerate(
        sorted(
            SCTOD_PATH.glob("checkpoint-*"),
            key=lambda path: int(str(path).split("checkpoint-")[-1]),
        ),
        start=1,
    )
]

In [None]:
limit = 200
batch_size = 16

results = []
model, tokenizer = None, None
all_predictions_dfs = []

for run in RUNS[:3]:
    name = run["name"]
    mode = run["mode"]
    path = run["path"]
    
    if model is not None or tokenizer is not None:
        cleanup_model(model, tokenizer)
        
    model, tokenizer = load_model_and_tokenizer(
        model_id=MODEL_ID,
        bf16=True,
        load_path=path,
    )

    metrics, df = benchmark(
        model,
        tokenizer,
        mode=mode,
        limit=limit,
        batch_size=batch_size,
    )
    print(f"{name} -> accuracy: {metrics['accuracy']:.4f}")
    results.append((name, metrics))
    
    df['model'] = name
    all_predictions_dfs.append(df)

cleanup_model(model, tokenizer)

predictions_df = pd.concat(all_predictions_dfs, ignore_index=True)
predictions_df = predictions_df.set_index('model')

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

GPU memory after loading: 2.18 GB


Evaluating label-only: 100%|██████████| 1/1 [00:43<00:00, 43.80s/it]



student_label_only_checkpoint_1 -> accuracy: 0.0000
Cleaning up previous model...
GPU memory after cleanup: 2.18 GB


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

GPU memory after loading: 3.62 GB


Evaluating label-only: 100%|██████████| 1/1 [00:39<00:00, 39.68s/it]



student_label_only_checkpoint_2 -> accuracy: 0.5000
Cleaning up previous model...
GPU memory after cleanup: 2.19 GB


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

GPU memory after loading: 3.63 GB


Evaluating label-only: 100%|██████████| 1/1 [00:33<00:00, 33.30s/it]



student_label_only_checkpoint_3 -> accuracy: 0.5000
GPU memory after cleanup: 2.19 GB

------- Summary -------
student_label_only_checkpoint_1: 0.0000
student_label_only_checkpoint_2: 0.5000
student_label_only_checkpoint_3: 0.5000

Combined predictions dataframe shape: (6, 3)
Models in dataframe: ['student_label_only_checkpoint_1', 'student_label_only_checkpoint_2', 'student_label_only_checkpoint_3']


In [None]:
print("\n----------------- Summary -----------------")
for name, m in results:
    print(f"{name:>24}: {m['accuracy']:.4f}")