In [None]:
# !pip install -U -q transformers==4.51.3 accelerate==1.6.0 datasets==3.5.0 bitsandbytes==0.45.5 triton==3.2.0 unsloth==2025.3.19 torch==2.6.0 peft==0.15.2 trl==0.15.2 wandb==0.19.10

In [None]:
import os

# Do not hard-code the W&B API key in the notebook. Export WANDB_API_KEY in
# the environment before starting the kernel, or run `wandb login` once so
# that wandb can use its stored credentials instead.
if not os.environ.get("WANDB_API_KEY"):
    print("WANDB_API_KEY is not set; relying on credentials from `wandb login`.")

In [None]:
# Tags the model is asked to wrap its reasoning and final solution in.
reasoning_start, reasoning_end = "<reasoning>", "</reasoning>"
solution_start, solution_end = "<solution>", "</solution>"

# System prompt used when querying the evaluated models. NOTE(review):
# presumably mirrors the prompt used during GRPO training -- keep the text
# in sync with training if it is ever edited.
SYSTEM_GRPO_PROMPT = "\n".join((
    "You are given a problem.",
    "Think about the problem and provide your working out.",
    f"Place it between {reasoning_start} and {reasoning_end}.",
    f"Then, provide your solution between {solution_start} and {solution_end}",
))
# Generic assistant system prompt.
SYSTEM_PROMPT = "You are a helpful AI assistant"

In [None]:
from unsloth import FastLanguageModel
import os
import re
import torch
from tqdm import tqdm
from datasets import load_dataset
from peft import PeftModel
import wandb
from transformers import AutoTokenizer
import numpy as np
from transformers import BitsAndBytesConfig
from collections import defaultdict

USE_LLM_JUDGE = True
# Canonical casing of the Hub repository id ("Instruct"), so the local cache
# and the W&B config record the published model name.
JUDGE_MODEL_ID = "Qwen/Qwen2.5-7B-Instruct"

WANDB_PROJECT = "gsm8k-evaluation"
WANDB_ENTITY = "animavestra888-independent"
WANDB_NAME = "gsm8k-grpo-model-llm-judge"
USE_WANDB = True

# 4-bit NF4 quantization with nested (double) quantization of the scales.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)

# Models to evaluate. "base" entries are loaded as-is; "peft" entries load
# `model_id` and apply the adapter stored in `wandb_artifact`.
# Uncomment entries to include them in the comparison.
MODELS_TO_TEST = [
    # {
    #     "name": "Base Model",
    #     "type": "base",
    #     "model_id": "Qwen/Qwen2.5-0.5B-Instruct"
    # },
    # {
    #     "name": "BNF Model",
    #     "type": "peft",
    #     "model_id": "Qwen/Qwen2.5-0.5B-Instruct",
    #     "peft_source": "wandb",
    #     "wandb_artifact": "animavestra888-independent/Coursework/model-crmnp6jy:v24"
    # },
    # {
    #     "name": "DPO Model",
    #     "type": "peft",
    #     "model_id": "Qwen/Qwen2.5-0.5B-Instruct",
    #     "peft_source": "wandb",
    #     "wandb_artifact": "animavestra888-independent/Coursework/model-izl7baxm:v14"
    # },
    {
        "name": "GRPO Model",
        "type": "peft",
        "model_id": "Qwen/Qwen2.5-0.5B-Instruct",
        "peft_source": "wandb",
        "wandb_artifact": "animavestra888-independent/Coursework/model-6enu8o1w:v7"
    },
]

DATASET_NAME = "gsm8k"
# NOTE: the next two names are swapped relative to their meaning, as the
# load_dataset call at the bottom of this cell shows: DATASET_SPLIT holds the
# dataset *configuration* ("main") and DATASET_SUBSET holds the *split*
# ("test"). The names are kept so other cells referencing them keep working.
DATASET_SPLIT = "main"
DATASET_SUBSET = "test"

MAX_NEW_TOKENS = 1024
NUM_BEAMS = 1
DO_SAMPLE = False
TEMPERATURE = 0.0   # only meaningful when DO_SAMPLE is True
LOG_INTERVAL = 50   # report running accuracy every LOG_INTERVAL examples

if USE_WANDB:
    wandb.init(project=WANDB_PROJECT, entity=WANDB_ENTITY, name=WANDB_NAME)
    # Plot every per-model series against our own "step" (examples processed)
    # rather than W&B's internal step counter.
    wandb.define_metric("step")
    for model_config in MODELS_TO_TEST:
        model_name = model_config["name"]
        # "total" is logged alongside accuracy/correct, so bind it to "step" too.
        for metric in ("accuracy", "correct", "total"):
            wandb.define_metric(f"{model_name}/{metric}", step_metric="step")
    wandb_table = wandb.Table(columns=["Model", "Question", "Generated", "Extracted", "Judge Verdict", "Ground Truth", "Correct"])
    # NOTE(review): not filled by the evaluation loop; kept for compatibility.
    wandb_intermediate_table = wandb.Table(columns=["Step", "Model", "Accuracy", "Examples Processed"])
    wandb.config.update({
        "dataset": DATASET_NAME,
        "subset": DATASET_SUBSET,         # this is the split, despite the key name
        "dataset_config": DATASET_SPLIT,  # the dataset configuration ("main")
        "max_new_tokens": MAX_NEW_TOKENS,
        "log_interval": LOG_INTERVAL,
        "llm_judge": USE_LLM_JUDGE,
        "judge_model": JUDGE_MODEL_ID
    })


# load_dataset(path, name, split=...): configuration "main", split "test".
dataset = load_dataset(DATASET_NAME, DATASET_SPLIT, split=DATASET_SUBSET)
print(f"Загружено примеров: {len(dataset)}")

def extract_answer(completion):
    """Extract the final numeric answer from a model completion.

    Strategy:
      1. If the text contains a GSM8K-style ``#### <number>`` marker
         (optionally ``#### $<number>``), use that number.
      2. Otherwise fall back to the last number appearing anywhere in the text.

    Thousands separators are stripped (``"1,234"`` -> ``"1234"``).

    Args:
        completion: Generated text; may be empty or None.

    Returns:
        The number as a string without commas, or None if no number is found.
    """
    if not completion:
        return None

    # Require a leading digit so a stray "#### ," cannot yield an empty answer.
    marker = re.search(r"####\s*\$?(-?\d[\d,]*(?:\.\d+)?)", completion)
    if marker:
        return marker.group(1).replace(',', '')

    # The comma-grouped alternative must contain at least one ",ddd" group;
    # otherwise a plain digit run is matched. (Regex alternation is ordered,
    # so "\d{1,3}(?:,\d{3})*" first would split "1500" into "150" and "0".)
    numbers = re.findall(r"-?(?:\d{1,3}(?:,\d{3})+|\d+)(?:\.\d+)?", completion)
    if numbers:
        return numbers[-1].replace(',', '')

    return None

def extract_ground_truth(answer_str):
    """Return the reference answer that follows GSM8K's ``####`` marker.

    Commas used as thousands separators are removed. Returns None when no
    ``####``-prefixed number is present.
    """
    found = re.search(r"####\s*([-]?[\d,]+(?:\.\d+)?)", answer_str)
    return None if found is None else found.group(1).replace(',', '')

# Models under test and their tokenizers, keyed by MODELS_TO_TEST "name".
models = {}
tokenizers = {}
# Optional LLM judge; stays None when judging is disabled or loading fails.
judge_model = None
judge_tokenizer = None


def _download_peft_adapter(artifact_name):
    """Download a W&B model artifact and return its local directory path."""
    # wandb.use_artifact needs an active run (and records lineage on it);
    # without one (USE_WANDB off) fetch the artifact through the public API.
    if wandb.run is not None:
        artifact = wandb.use_artifact(artifact_name, type='model')
    else:
        artifact = wandb.Api().artifact(artifact_name, type='model')
    return artifact.download()


for config in MODELS_TO_TEST:
    model_name = config["name"]
    # Reset per entry so a config that fails to produce a model can never
    # silently reuse the model loaded for the previous entry.
    model = tokenizer = None

    try:
        if config["type"] == "base":
            model, tokenizer = FastLanguageModel.from_pretrained(
                config["model_id"],
                attn_implementation="flash_attention_2",
                quantization_config=bnb_config,
            )
        elif config["type"] == "peft":
            base_model, tokenizer = FastLanguageModel.from_pretrained(
                config["model_id"],
                attn_implementation="flash_attention_2",
                quantization_config=bnb_config,
                load_in_4bit=True
            )
            if config["peft_source"] != "wandb":
                raise ValueError(f"Unsupported peft_source: {config['peft_source']!r}")
            peft_path = _download_peft_adapter(config["wandb_artifact"])
            model = PeftModel.from_pretrained(base_model, peft_path)
        else:
            raise ValueError(f"Unsupported model type: {config['type']!r}")

        # Generation calls pass pad_token_id explicitly; fall back to EOS when
        # the tokenizer defines no pad token.
        if tokenizer.pad_token_id is None:
            tokenizer.pad_token_id = tokenizer.eos_token_id

        model.eval()
        models[model_name] = model
        tokenizers[model_name] = tokenizer

        print(f"    Модель '{model_name}' успешно загружена\n")

    except Exception as e:
        # Best effort: report and skip a model that fails to load so the
        # remaining models can still be evaluated.
        print(f"    Ошибка загрузки модели '{model_name}': {str(e)}\n")

if USE_LLM_JUDGE:
    print("Загрузка модели-судьи...")
    try:
        judge_model, judge_tokenizer = FastLanguageModel.from_pretrained(
            JUDGE_MODEL_ID,
            attn_implementation="flash_attention_2",
            quantization_config=bnb_config,
            load_in_4bit=True
        )
        judge_model.eval()
        if judge_tokenizer.pad_token_id is None:
            judge_tokenizer.pad_token_id = judge_tokenizer.eos_token_id
        print(f"    Модель-судья '{JUDGE_MODEL_ID}' успешно загружена\n")
    except Exception as e:
        print(f"    Ошибка загрузки модели-судьи: {str(e)}")
        # Disable judging (and drop any partially loaded judge) so evaluation
        # continues with exact numeric matching only.
        judge_model = judge_tokenizer = None
        USE_LLM_JUDGE = False

def verify_with_judge(question, generated_response, ground_truth):
    """Ask the judge LLM whether *generated_response* reaches *ground_truth*.

    Uses the module-level ``judge_model`` / ``judge_tokenizer``, so it must
    only be called after they have been loaded.

    Args:
        question: The original math problem text.
        generated_response: Full completion produced by the model under test.
        ground_truth: Reference final answer, as a string.

    Returns:
        True if the first standalone "yes"/"no" word in the judge's reply is
        "yes"; False otherwise, including when the reply contains neither.
    """
    messages = [
        {"role": "system", "content": "You are an expert math evaluator. Determine if the candidate's answer matches the ground truth. Consider only the final numerical answer."},
        {"role": "user", "content": f"Problem: {question}\nCandidate Answer: {generated_response}\nGround Truth: {ground_truth}\n\nIs the candidate's answer correct? Respond ONLY with 'yes' or 'no'."}
    ]

    prompt = judge_tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )

    inputs = judge_tokenizer(prompt, return_tensors="pt").to(judge_model.device)

    with torch.no_grad():
        # Greedy decoding. `temperature` is omitted: it is ignored when
        # do_sample=False and only triggers a transformers warning.
        outputs = judge_model.generate(
            **inputs,
            max_new_tokens=10,
            do_sample=False,
            pad_token_id=judge_tokenizer.pad_token_id,
            eos_token_id=judge_tokenizer.eos_token_id,
        )

    # Decode only the tokens generated after the prompt.
    prompt_len = inputs.input_ids.shape[1]
    response = judge_tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)

    # Use the judge's first explicit verdict word. A plain substring test
    # ("yes" in response) also accepted replies such as "No ... yes" or "eyes".
    verdict = re.search(r"\b(yes|no)\b", response.lower())
    return verdict is not None and verdict.group(1) == "yes"

def format_prompt(question, tokenizer):
    """Render *question* as a chat prompt string using *tokenizer*'s template.

    The conversation is the GRPO system prompt followed by the question as
    the user turn; the assistant generation prefix is appended and the result
    is returned as text (not token ids).
    """
    system_turn = dict(role="system", content=SYSTEM_GRPO_PROMPT)
    user_turn = dict(role="user", content=question)
    rendered = tokenizer.apply_chat_template(
        [system_turn, user_turn],
        tokenize=False,
        add_generation_prompt=True,
    )
    return rendered

def _accuracy(stats):
    """Return the accuracy (percent) of one ``results`` entry; 0.0 if nothing was scored."""
    return stats["correct"] / stats["total"] * 100 if stats["total"] else 0.0


def _to_float(text):
    """Parse *text* as a float; return None when it is empty/None or not numeric."""
    if not text:
        return None
    try:
        return float(text)
    except ValueError:
        return None


def _evaluate_example(model_name, model, question, ground_truth, gt_number):
    """Generate an answer with one model, score it and record the outcome."""
    tokenizer = tokenizers[model_name]
    stats = results[model_name]

    prompt = format_prompt(question, tokenizer)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    stats["total"] += 1

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            **generation_kwargs,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens, not the echoed prompt.
    completion = tokenizer.decode(
        outputs[0][inputs.input_ids.shape[1]:],
        skip_special_tokens=True
    )

    extracted_ans = extract_answer(completion)
    judge_verdict = "N/A"
    used_judge = False

    # Cheap exact numeric comparison first...
    ans_number = _to_float(extracted_ans)
    is_correct = ans_number is not None and ans_number == gt_number

    # ...and consult the (expensive) LLM judge only when it fails.
    if not is_correct and USE_LLM_JUDGE and judge_model is not None:
        used_judge = True
        stats["judge_used"] += 1
        if verify_with_judge(question, completion, ground_truth):
            is_correct = True
            stats["judge_correct"] += 1
            judge_verdict = "Correct"
        else:
            judge_verdict = "Incorrect"

    if is_correct:
        stats["correct"] += 1

    stats["predictions"].append({
        "question": question,
        "generated": completion,
        "extracted": extracted_ans,
        "judge_verdict": judge_verdict,
        "ground_truth": ground_truth,
        "correct": is_correct,
        "used_judge": used_judge
    })

    if USE_WANDB:
        wandb_table.add_data(
            model_name,
            question,
            completion,
            extracted_ans,
            judge_verdict,
            ground_truth,
            "correct" if is_correct else "incorrect"
        )


def _report_progress(step):
    """Print running per-model accuracy and, if enabled, log it to W&B."""
    print(f"\n--- Отчет после {step} примеров ---")

    for model_name in models:
        stats = results[model_name]
        total = stats["total"]
        correct = stats["correct"]
        accuracy = _accuracy(stats)

        print(f"{model_name}:")
        print(f"  Правильных: {correct}/{total}")
        print(f"  Точность: {accuracy:.2f}%")

        if USE_LLM_JUDGE:
            judge_used = stats["judge_used"]
            judge_correct = stats["judge_correct"]
            print(f"  Проверок судьей: {judge_used} ({judge_correct} правильных)")

        if USE_WANDB:
            wandb.log({
                "step": step,
                f"{model_name}/accuracy": accuracy,
                f"{model_name}/correct": correct,
                f"{model_name}/total": total
            })


results = {model_name: {"correct": 0, "total": 0, "predictions": [], "judge_used": 0, "judge_correct": 0} for model_name in models}

# Generation settings shared by every model under test. `temperature` only
# affects sampling; passing it with greedy decoding just triggers a warning.
generation_kwargs = {
    "max_new_tokens": MAX_NEW_TOKENS,
    "num_beams": NUM_BEAMS,
    "do_sample": DO_SAMPLE,
}
if DO_SAMPLE:
    generation_kwargs["temperature"] = TEMPERATURE

for i, example in enumerate(tqdm(dataset, desc="Обработка примеров")):
    question = example['question']
    ground_truth = extract_ground_truth(example['answer'])
    gt_number = _to_float(ground_truth)

    # Examples without a parseable reference answer are not scored, but the
    # progress check below must still run for them (otherwise a skipped last
    # example would suppress the final report).
    if gt_number is not None:
        for model_name, model in models.items():
            _evaluate_example(model_name, model, question, ground_truth, gt_number)

    current_step = i + 1
    if current_step % LOG_INTERVAL == 0 or current_step == len(dataset):
        _report_progress(current_step)

if USE_WANDB:
    # _accuracy guards against models that scored no examples (total == 0).
    final_accuracy = {model_name: _accuracy(results[model_name]) for model_name in models}
    bar_chart = wandb.plot.bar(
        wandb.Table(
            data=[[name, acc] for name, acc in final_accuracy.items()],
            columns=["Model", "Accuracy"]
        ),
        "Model",
        "Accuracy",
        title="Final Accuracy"
    )

    wandb.log({
        "final_accuracy_chart": bar_chart,
        "predictions_table": wandb_table,
        **{f"final_accuracy/{name}": acc for name, acc in final_accuracy.items()}
    })
    wandb.finish()

print("\nТестирование завершено")