In [None]:
import gc
import sys
from pathlib import Path

import pandas as pd
import torch
from datasets import load_dataset
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)

# Put the parent of the current working directory on the import path so the
# `src.*` imports in the next cell can resolve.
# NOTE(review): presumably that parent is the project root containing `src/` —
# this depends on where the kernel was started; confirm.
sys.path.append(str(Path.cwd().resolve().parent))

In [None]:
from src.config import (
    COT_SYSTEM_PROMPT,
    COT_USER_PROMPT,
    GSM8K_PATH,
    LABEL_ONLY_SYSTEM_PROMPT,
    LABEL_ONLY_USER_PROMPT,
    MODELS_DIR,
)
from src.dataset_generator.helpers.answers import (
    ParsingError,
    parse_gold_answer_number,
    parse_teacher_final_answer,
)

In [3]:
# Report the CUDA environment: whether CUDA is usable, how many devices are
# visible, and the CUDA version reported by this torch build.
cuda_report = (
    torch.cuda.is_available(),
    torch.cuda.device_count(),
    torch.version.cuda,
)
for cuda_fact in cuda_report:
    print(cuda_fact)

True
1
12.9


In [4]:
# Allow TF32 in both cuDNN kernels and CUDA matmuls for the rest of the session.
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True

In [None]:
def cleanup_model(model, tokenizer):
    """Drop local references, run the garbage collector and flush the CUDA cache.

    Only this function's own references to ``model`` and ``tokenizer`` are
    released here; any name the caller still has bound to the same objects
    keeps them alive. Either argument may be ``None``.

    Args:
        model: Model object to release, or ``None``.
        tokenizer: Tokenizer object to release, or ``None``.
    """
    # Rebinding the parameters releases this frame's references.
    model = tokenizer = None

    gc.collect()

    if not torch.cuda.is_available():
        return

    torch.cuda.empty_cache()
    torch.cuda.synchronize()

In [None]:
def build_prompt_cot(question: str) -> str:
    """Build the chain-of-thought prompt for one question.

    The stripped system prompt and the stripped user template (with the
    stripped question substituted for ``{question}``) are separated by a blank
    line, and the result ends with a single newline.

    Args:
        question: Raw question text.

    Returns:
        The full prompt string.
    """
    user_block = COT_USER_PROMPT.strip().format(question=question.strip())
    return "\n\n".join((COT_SYSTEM_PROMPT.strip(), user_block)) + "\n"


def build_prompt_label_only(question: str) -> str:
    """Build the label-only prompt for one question.

    The stripped system prompt and the stripped user template (with the
    stripped question substituted for ``{question}``) are separated by a blank
    line, and the result ends with a single newline.

    Args:
        question: Raw question text.

    Returns:
        The full prompt string.
    """
    user_block = LABEL_ONLY_USER_PROMPT.strip().format(question=question.strip())
    return "\n\n".join((LABEL_ONLY_SYSTEM_PROMPT.strip(), user_block)) + "\n"

In [None]:
def create_quantization_config(bf16: bool) -> BitsAndBytesConfig:
    """Create the 4-bit NF4 bitsandbytes configuration used to load models.

    Args:
        bf16: Use ``torch.bfloat16`` as the compute dtype when true,
            ``torch.float16`` otherwise.

    Returns:
        A ``BitsAndBytesConfig`` with 4-bit loading, the ``"nf4"`` quantization
        type and double quantization enabled.
    """
    if bf16:
        compute_dtype = torch.bfloat16
    else:
        compute_dtype = torch.float16

    return BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=compute_dtype,
    )


def load_model(model_id: str, load_path: str | None, quant_config: BitsAndBytesConfig):
    """Load a quantized causal language model with the KV cache enabled.

    Args:
        model_id: Identifier used when ``load_path`` is empty or ``None``.
        load_path: Location to load from instead of ``model_id``; takes
            precedence whenever it is truthy.
        quant_config: Quantization settings passed to ``from_pretrained``.

    Returns:
        The loaded model, placed with ``device_map="auto"``.
    """
    source = load_path or model_id
    model = AutoModelForCausalLM.from_pretrained(
        source, quantization_config=quant_config, device_map="auto"
    )
    model.config.use_cache = True
    return model


def setup_tokenizer(model_id: str, load_path: str | None):
    """Load the fast tokenizer, configured to pad and truncate on the left.

    Args:
        model_id: Identifier used when ``load_path`` is empty or ``None``.
        load_path: Location to load from instead of ``model_id``; takes
            precedence whenever it is truthy.

    Returns:
        The loaded tokenizer.
    """
    source = load_path or model_id
    # Left padding keeps every prompt flush against the generated continuation.
    return AutoTokenizer.from_pretrained(
        source, use_fast=True, padding_side="left", truncation_side="left"
    )


def load_model_and_tokenizer(model_id: str, bf16: bool, load_path: str | None = None):
    """Load a quantized model together with its tokenizer.

    The model's generation config is given the tokenizer's pad token id so
    batched generation pads consistently with the tokenizer.

    Args:
        model_id: Identifier used when ``load_path`` is not given.
        bf16: Selects bfloat16 (true) or float16 (false) as the quantization
            compute dtype.
        load_path: Optional location to load both model and tokenizer from
            instead of ``model_id``.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    quant_config = create_quantization_config(bf16)
    model = load_model(model_id, load_path, quant_config)
    tokenizer = setup_tokenizer(model_id, load_path)

    # Batches are tokenized with padding=True later on, which requires a pad
    # token. Fall back to EOS for tokenizers that ship without one instead of
    # propagating ``None`` into the generation config.
    if tokenizer.pad_token_id is None and tokenizer.eos_token is not None:
        tokenizer.pad_token = tokenizer.eos_token

    model.generation_config.pad_token_id = tokenizer.pad_token_id
    return model, tokenizer

In [None]:
def prepare_batch_inputs(questions: list[str], mode: str, tokenizer):
    """Build and tokenize the prompts for a batch of questions.

    Args:
        questions: Question texts for the batch.
        mode: ``"cot"`` selects the chain-of-thought prompt; any other value
            selects the label-only prompt.
        tokenizer: Callable tokenizer; invoked with padding and truncation
            enabled and ``return_tensors="pt"``.

    Returns:
        A ``(encoded_inputs, prompts)`` tuple: the tokenizer output and the
        list of prompt strings it was built from.
    """
    if mode == "cot":
        prompt_builder = build_prompt_cot
    else:
        prompt_builder = build_prompt_label_only

    prompts = list(map(prompt_builder, questions))
    encoded = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True)
    return encoded, prompts


def extract_responses(generated_tokens, prompts, tokenizer):
    """Decode generated sequences and strip the prompt text from each one.

    Each decoded sequence has its first ``len(prompt)`` characters removed and
    the remainder is whitespace-stripped. This assumes the decoded text begins
    with the prompt verbatim.

    Args:
        generated_tokens: Sequences accepted by ``tokenizer.batch_decode``.
        prompts: Prompt strings, aligned one-to-one with ``generated_tokens``.
        tokenizer: Object providing ``batch_decode``.

    Returns:
        One response string per prompt.
    """
    decoded = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)

    responses = []
    for prompt, text in zip(prompts, decoded):
        responses.append(text[len(prompt) :].strip())
    return responses


def process_batch(model, tokenizer, questions: list[str], mode: str) -> list[str]:
    """Generate one greedy completion per question and return only the new text.

    Args:
        model: Causal language model exposing ``device`` and ``generate``.
        tokenizer: Tokenizer used to encode the prompts and decode the output.
        questions: Question texts for this batch.
        mode: Prompt style, forwarded to ``prepare_batch_inputs``.

    Returns:
        The whitespace-stripped completion for each question, in input order.
    """
    encoded_inputs, _prompts = prepare_batch_inputs(questions, mode, tokenizer)
    encoded_inputs = encoded_inputs.to(model.device)

    generated_tokens = model.generate(
        **encoded_inputs, do_sample=False, use_cache=True, max_new_tokens=1024
    )

    # The generated sequences start with the (padded) prompt tokens. Cut the
    # prompt off by token width rather than by decoded character count: the
    # character count is wrong whenever the decoded prompt differs from the raw
    # prompt string (left truncation, special-token text removed by
    # skip_special_tokens).
    prompt_width = encoded_inputs["input_ids"].shape[1]
    continuation_tokens = generated_tokens[:, prompt_width:]

    decoded = tokenizer.batch_decode(continuation_tokens, skip_special_tokens=True)
    # Tensors are released when the locals go out of scope; no CPU copy needed.
    return [text.strip() for text in decoded]


def extract_questions_and_answers(dataset) -> tuple[list[str], list[float | None]]:
    """Split a dataset into its questions and parsed gold answers.

    Args:
        dataset: Mapping-like dataset with ``"question"`` and ``"answer"``
            columns.

    Returns:
        ``(questions, gold_nums)`` where ``gold_nums`` holds the result of
        ``parse_gold_answer_number`` for every entry of the answer column.
    """
    questions = dataset["question"]

    gold_nums = []
    for raw_answer in dataset["answer"]:
        gold_nums.append(parse_gold_answer_number(raw_answer))

    return questions, gold_nums


def load_benchmark_dataset(
    path: str, split: str, limit: int | None = None
) -> tuple[list[str], list[float | None]]:
    """Load one split of the ``"main"`` configuration and extract Q/A pairs.

    Args:
        path: Dataset path passed to ``load_dataset``.
        split: Split name to load.
        limit: When given, keep only the first ``min(limit, len(dataset))``
            rows; ``None`` keeps everything.

    Returns:
        ``(questions, gold_nums)`` as produced by
        ``extract_questions_and_answers``.
    """
    dataset = load_dataset(path, name="main", split=split)

    if limit is None:
        return extract_questions_and_answers(dataset)

    keep = min(limit, len(dataset))
    return extract_questions_and_answers(dataset.select(range(keep)))


def count_correct_predictions(
    predictions: list[str], gold_answers: list[float | None]
) -> int:
    """Count predictions whose parsed final answer equals the gold number.

    A prediction that raises ``ParsingError``, parses to ``None``, or is paired
    with a ``None`` gold value is counted as incorrect. Pairs are formed with
    ``zip``, so surplus items in the longer list are ignored.

    Args:
        predictions: Raw model outputs.
        gold_answers: Gold numeric answers aligned with ``predictions``.

    Returns:
        Number of exact matches.
    """

    def _is_correct(pred_text, gold_num) -> bool:
        try:
            pred_num = parse_teacher_final_answer(pred_text)
        except ParsingError:
            return False
        if pred_num is None or gold_num is None:
            return False
        return bool(pred_num == gold_num)

    return sum(
        1
        for pred_text, gold_num in zip(predictions, gold_answers)
        if _is_correct(pred_text, gold_num)
    )


def evaluate_predictions(
    predictions: list[str], gold_answers: list[float | None]
) -> dict:
    """Compute accuracy of the predictions against the gold answers.

    Args:
        predictions: Raw model outputs.
        gold_answers: Gold numeric answers aligned with ``predictions``.

    Returns:
        ``{"accuracy": ..., "n": ...}`` where ``n`` is ``len(predictions)`` and
        accuracy is ``0.0`` when there are no predictions.
    """
    total_count = len(predictions)
    correct_count = count_correct_predictions(predictions, gold_answers)

    if total_count > 0:
        accuracy = correct_count / total_count
    else:
        accuracy = 0.0

    return dict(accuracy=accuracy, n=total_count)


def generate_predictions(
    model, tokenizer, questions: list[str], mode: str, batch_size: int
) -> list[str]:
    """Run the model over all questions in consecutive batches.

    The model is switched to eval mode and generation runs under
    ``torch.inference_mode()`` with a tqdm progress bar.

    Args:
        model: Model forwarded to ``process_batch``.
        tokenizer: Tokenizer forwarded to ``process_batch``.
        questions: All question texts to evaluate.
        mode: Prompt style, also shown in the progress-bar label.
        batch_size: Number of questions per batch.

    Returns:
        One output string per question, in input order.
    """
    model.eval()
    batch_starts = range(0, len(questions), batch_size)
    outputs = []

    with torch.inference_mode():
        for start in tqdm(batch_starts, desc=f"Evaluating {mode}"):
            chunk = questions[start : start + batch_size]
            outputs += process_batch(model, tokenizer, chunk, mode)

    return outputs


def benchmark(model, tokenizer, mode: str, batch_size: int, limit: int | None = None):
    """Evaluate a model on the ``"test"`` split loaded from ``GSM8K_PATH``.

    Args:
        model: Model to evaluate.
        tokenizer: Tokenizer paired with ``model``.
        mode: Prompt style forwarded to generation.
        batch_size: Number of questions per generation batch.
        limit: Optional cap on the number of test rows.

    Returns:
        ``(metrics, df)``: the dict from ``evaluate_predictions`` and a
        DataFrame with ``question``, ``prediction`` and ``gold_answer`` columns.
    """
    questions, gold_answers = load_benchmark_dataset(GSM8K_PATH, "test", limit)
    predictions = generate_predictions(model, tokenizer, questions, mode, batch_size)

    metrics = evaluate_predictions(predictions, gold_answers)
    columns = {
        "question": questions,
        "prediction": predictions,
        "gold_answer": gold_answers,
    }
    return metrics, pd.DataFrame(columns)

In [None]:
MODEL_ID = "Qwen/Qwen2.5-3B"
SCTOD_PATH = MODELS_DIR / "qwen2.5_3b_sctod_lora"
LABELONLY_PATH = MODELS_DIR / "qwen2.5_3b_labelonly_lora"


def _sorted_checkpoints(adapter_dir):
    """Return the ``checkpoint-*`` entries of ``adapter_dir`` in numeric order.

    The sort key is the integer that follows the last ``"checkpoint-"`` in the
    path, so ``checkpoint-100`` sorts after ``checkpoint-50``.
    """

    def _step(checkpoint_path) -> int:
        return int(str(checkpoint_path).split("checkpoint-")[-1])

    return sorted(adapter_dir.glob("checkpoint-*"), key=_step)


# Alternative configuration: evaluate only the top-level adapter directories.
# RUNS = [
#     {"name": "student_sctod", "mode": "cot", "path": SCTOD_PATH},
#     {"name": "student_label_only", "mode": "label-only", "path": LABELONLY_PATH},
#     {"name": "base_cot_prompting", "mode": "cot", "path": None},
#     {"name": "base_label_only", "mode": "label-only", "path": None},
# ]

# One run per saved checkpoint of each student, followed by the two base-model
# baselines (path=None). The number in each run name is the checkpoint's
# 1-based position in the sorted list, not its step count.
RUNS = [
    *(
        {
            "name": f"student_label_only_checkpoint_{i}",
            "mode": "label-only",
            "path": str(path),
        }
        for i, path in enumerate(_sorted_checkpoints(LABELONLY_PATH), start=1)
    ),
    *(
        {"name": f"student_sctod_checkpoint_{i}", "mode": "cot", "path": str(path)}
        for i, path in enumerate(_sorted_checkpoints(SCTOD_PATH), start=1)
    ),
    {"name": "base_cot_prompting", "mode": "cot", "path": None},
    {"name": "base_label_only", "mode": "label-only", "path": None},
]

In [10]:
# Display the resolved run list to check which checkpoints were discovered.
RUNS

[{'name': 'student_label_only_checkpoint_1',
  'mode': 'label-only',
  'path': '/workspace/chain-of-thought-distillation/code/artifacts/models/qwen2.5_3b_labelonly_lora/checkpoint-50'},
 {'name': 'student_label_only_checkpoint_2',
  'mode': 'label-only',
  'path': '/workspace/chain-of-thought-distillation/code/artifacts/models/qwen2.5_3b_labelonly_lora/checkpoint-100'},
 {'name': 'student_label_only_checkpoint_3',
  'mode': 'label-only',
  'path': '/workspace/chain-of-thought-distillation/code/artifacts/models/qwen2.5_3b_labelonly_lora/checkpoint-150'},
 {'name': 'student_label_only_checkpoint_4',
  'mode': 'label-only',
  'path': '/workspace/chain-of-thought-distillation/code/artifacts/models/qwen2.5_3b_labelonly_lora/checkpoint-200'},
 {'name': 'student_label_only_checkpoint_5',
  'mode': 'label-only',
  'path': '/workspace/chain-of-thought-distillation/code/artifacts/models/qwen2.5_3b_labelonly_lora/checkpoint-250'},
 {'name': 'student_label_only_checkpoint_6',
  'mode': 'label-only

In [None]:
# Evaluate every configured run in sequence, collecting accuracy metrics and
# per-question predictions.
limit = None  # None evaluates every test row; set an int to cap the row count
results = []
model, tokenizer = None, None
all_predictions_dfs = []

for run in RUNS:
    name = run["name"]
    mode = run["mode"]
    path = run["path"]

    batch_size = 128 if mode == "label-only" else 16

    # Release the previous run's model BEFORE loading the next one. The
    # notebook-level names must be rebound first: cleanup_model can only drop
    # its own local references, so while `model`/`tokenizer` still point at the
    # old objects nothing can be freed and both models would occupy GPU memory
    # during the load.
    model, tokenizer = None, None
    cleanup_model(model, tokenizer)

    model, tokenizer = load_model_and_tokenizer(
        model_id=MODEL_ID,
        bf16=True,
        load_path=path,
    )

    metrics, df = benchmark(
        model,
        tokenizer,
        mode=mode,
        limit=limit,
        batch_size=batch_size,
    )
    print(f"{name} -> accuracy: {metrics['accuracy']:.4f}")
    results.append((name, metrics))

    df["model"] = name
    all_predictions_dfs.append(df)

# Same reasoning as inside the loop: drop the global references, then collect.
model, tokenizer = None, None
cleanup_model(model, tokenizer)

# One long frame of all runs' predictions, indexed by run name.
predictions_df = pd.concat(all_predictions_dfs, ignore_index=True)
predictions_df = predictions_df.set_index("model")

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Evaluating label-only: 100%|██████████| 11/11 [15:33<00:00, 84.86s/it] 


student_label_only_checkpoint_1 -> accuracy: 0.1069


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Evaluating label-only: 100%|██████████| 11/11 [03:50<00:00, 20.99s/it]


student_label_only_checkpoint_2 -> accuracy: 0.1350


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Evaluating label-only: 100%|██████████| 11/11 [06:25<00:00, 35.04s/it]


student_label_only_checkpoint_3 -> accuracy: 0.1403


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Evaluating label-only: 100%|██████████| 11/11 [11:03<00:00, 60.28s/it]


student_label_only_checkpoint_4 -> accuracy: 0.1501


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Evaluating label-only: 100%|██████████| 11/11 [07:19<00:00, 39.99s/it]


student_label_only_checkpoint_5 -> accuracy: 0.1456


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Evaluating label-only: 100%|██████████| 11/11 [06:17<00:00, 34.36s/it]


student_label_only_checkpoint_6 -> accuracy: 0.1501


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Evaluating label-only: 100%|██████████| 11/11 [08:30<00:00, 46.40s/it]


student_label_only_checkpoint_7 -> accuracy: 0.1448


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Evaluating label-only: 100%|██████████| 11/11 [06:20<00:00, 34.62s/it]


student_label_only_checkpoint_8 -> accuracy: 0.1501


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Evaluating cot: 100%|██████████| 83/83 [12:26<00:00,  8.99s/it]


student_sctod_checkpoint_1 -> accuracy: 0.7263


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Evaluating cot: 100%|██████████| 83/83 [12:21<00:00,  8.94s/it]


student_sctod_checkpoint_2 -> accuracy: 0.7392


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Evaluating cot: 100%|██████████| 83/83 [13:15<00:00,  9.59s/it]


student_sctod_checkpoint_3 -> accuracy: 0.7233


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Evaluating cot: 100%|██████████| 83/83 [15:17<00:00, 11.05s/it]


student_sctod_checkpoint_4 -> accuracy: 0.7604


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Evaluating cot: 100%|██████████| 83/83 [14:59<00:00, 10.84s/it]


student_sctod_checkpoint_5 -> accuracy: 0.7551


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Evaluating cot: 100%|██████████| 83/83 [15:49<00:00, 11.44s/it]


student_sctod_checkpoint_6 -> accuracy: 0.7642


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Evaluating cot: 100%|██████████| 83/83 [15:22<00:00, 11.12s/it]


student_sctod_checkpoint_7 -> accuracy: 0.7559


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Evaluating cot: 100%|██████████| 83/83 [15:02<00:00, 10.88s/it]


student_sctod_checkpoint_8 -> accuracy: 0.7854


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Evaluating cot: 100%|██████████| 83/83 [14:46<00:00, 10.68s/it]


student_sctod_checkpoint_9 -> accuracy: 0.7582


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Evaluating cot: 100%|██████████| 83/83 [14:55<00:00, 10.79s/it]


student_sctod_checkpoint_10 -> accuracy: 0.7657


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Evaluating cot: 100%|██████████| 83/83 [16:31<00:00, 11.95s/it]


student_sctod_checkpoint_11 -> accuracy: 0.7627


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Evaluating cot: 100%|██████████| 83/83 [15:37<00:00, 11.30s/it]


student_sctod_checkpoint_12 -> accuracy: 0.7627


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Evaluating cot: 100%|██████████| 83/83 [15:15<00:00, 11.03s/it]


student_sctod_checkpoint_13 -> accuracy: 0.7582


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Evaluating cot: 100%|██████████| 83/83 [14:36<00:00, 10.56s/it]


student_sctod_checkpoint_14 -> accuracy: 0.7612


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Evaluating cot: 100%|██████████| 83/83 [18:40<00:00, 13.50s/it]


base_cot_prompting -> accuracy: 0.7051


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Evaluating label-only: 100%|██████████| 11/11 [14:26<00:00, 78.77s/it]


base_label_only -> accuracy: 0.1099


In [None]:
# Print a right-aligned accuracy table for all runs.
print("\n----------------- GSM8K Test Set Accuracy -----------------")
# Pad to the longest run name (never less than the previous fixed width of 30)
# so names longer than 30 characters no longer push their row out of alignment.
name_width = max([30, *(len(name) for name, _ in results)])
for name, m in results:
    print(f"{name:>{name_width}}: {m['accuracy'] * 100:.2f}%")


----------------- GSM8K Test Set Accuracy -----------------
student_label_only_checkpoint_1: 10.69%
student_label_only_checkpoint_2: 13.50%
student_label_only_checkpoint_3: 14.03%
student_label_only_checkpoint_4: 15.01%
student_label_only_checkpoint_5: 14.56%
student_label_only_checkpoint_6: 15.01%
student_label_only_checkpoint_7: 14.48%
student_label_only_checkpoint_8: 15.01%
    student_sctod_checkpoint_1: 72.63%
    student_sctod_checkpoint_2: 73.92%
    student_sctod_checkpoint_3: 72.33%
    student_sctod_checkpoint_4: 76.04%
    student_sctod_checkpoint_5: 75.51%
    student_sctod_checkpoint_6: 76.42%
    student_sctod_checkpoint_7: 75.59%
    student_sctod_checkpoint_8: 78.54%
    student_sctod_checkpoint_9: 75.82%
   student_sctod_checkpoint_10: 76.57%
   student_sctod_checkpoint_11: 76.27%
   student_sctod_checkpoint_12: 76.27%
   student_sctod_checkpoint_13: 75.82%
   student_sctod_checkpoint_14: 76.12%
            base_cot_prompting: 70.51%
               base_label_only: 10

In [None]:
# Re-key the predictions by (model, question): move the current index back
# into columns, then build the two-level index.
predictions_df = (
    predictions_df
    .reset_index()
    .set_index(["model", "question"])
)

In [14]:
# Flatten the index back into ordinary columns and write everything to
# "predictions.csv" in the current working directory. With to_csv's defaults
# the integer row index is written as well, as an unnamed first column.
predictions_df.reset_index().to_csv(path_or_buf="predictions.csv")