In [None]:
!pip install -q torch transformers peft datasets tqdm

In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
from datasets import load_dataset
import json
from tqdm import tqdm

# ---------------------------------------------------------------------------
# Run configuration
# ---------------------------------------------------------------------------
# None keeps the whole test split; an integer caps how many samples are used.
num_samples = None
# Destination of the final JSON dump of all generated responses.
output_file = "hard_label_test.json"
# The inference loop writes an interim JSON file every `save_batch_size` samples.
save_batch_size = 50

# ---------------------------------------------------------------------------
# Model: base checkpoint with the PEFT adapter layered on top
# ---------------------------------------------------------------------------
# device_map="auto" is used for multi-GPU support.
base_model = AutoModelForCausalLM.from_pretrained(
    "unsloth/Qwen3-4B-Base",
    device_map="auto",
    torch_dtype=torch.float16,
)

peft_model_id = "yobro4619/hard_labels_final"
model = PeftModel.from_pretrained(base_model, peft_model_id)
print(model)

# ---------------------------------------------------------------------------
# Tokenizer
# ---------------------------------------------------------------------------
tokenizer = AutoTokenizer.from_pretrained("unsloth/Qwen3-4B-Base")
tokenizer.padding_side = "right"
# Reuse the end-of-sequence token as the padding token.
tokenizer.pad_token = tokenizer.eos_token

# Install a minimal ChatML-style template only when the tokenizer has none.
if getattr(tokenizer, "chat_template", None) is None:
    print("No chat template found, setting a default one for Qwen...")
    tokenizer.chat_template = (
        "{% for message in messages %}"
        "{% if message['role'] == 'system' %}"
        "<|im_start|>system\n{{ message['content'] }}<|im_end|>\n"
        "{% elif message['role'] == 'user' %}"
        "<|im_start|>user\n{{ message['content'] }}<|im_end|>\n"
        "<|im_start|>assistant\n"
        "{% endif %}{% endfor %}"
    )

# ---------------------------------------------------------------------------
# Evaluation data
# ---------------------------------------------------------------------------
dataset_name = "yobro4619/math-reasoning-dataset"
test_dataset = load_dataset(dataset_name, split="test")
print(f"Loaded test dataset with {len(test_dataset)} samples")
print(f"Dataset columns: {test_dataset.column_names}")

# Optionally truncate the split to its first `num_samples` rows.
if num_samples is not None:
    subset_size = min(num_samples, len(test_dataset))
    test_dataset = test_dataset.select(range(subset_size))
    print(f"Processing {len(test_dataset)} samples as requested")

# System prompt sent with every question. It prescribes the response layout
# (reasoning between <start_working_out> and <end_working_out>, final answer
# between <SOLUTION> and </SOLUTION>) and embeds one worked example.
# The answer-extraction code later in this file keys on the <SOLUTION> tags.
# NOTE(review): the prompt says "math problem" but the worked example is a
# verbal multiple-choice question -- confirm this matches how the adapter was
# trained before changing the text.
system_prompt= \
"""You are given a math problem.

1. Carefully analyze the problem.
2. Show all your working out and reasoning steps very concisely.
3. Place all reasoning **only once** between the tags <start_working_out> and <end_working_out>.
4. Then, provide the final answer **only once** between the tags <SOLUTION> and </SOLUTION>.
5. Do not repeat any part of your response after these tags have been used.

Stick strictly to this format.

Example Question-1: An Informal Gathering occurs when a group of people get together in a casual, relaxed manner. Which situation below is the best example of an Informal Gathering?   
A. After finding out about his salary raise, Jay and a few colleagues go out for a quick dinner after work.  B. Meena sends out 10 invitations for a bachelorette party she is giving for her elder sister.
Answer:
<start_working_out>The question asks me to identify the best example of an "Informal Gathering," defined as a group getting together in a casual, relaxed manner. I need to evaluate which of the two situations provided best fits this definition.
The first situation describes a few colleagues going out for a quick dinner after work to celebrate one person's salary raise. This seems spontaneous ("go out for a quick dinner") and lacks formal planning. Getting together after work with colleagues can certainly be casual and relaxed, especially when prompted by an immediate event like good news. It aligns well with the core idea of informality.
The second situation describes someone sending out invitations for a bachelorette party. The act of sending invitations implies planning and a degree of formality. While the party itself might aim for a relaxed atmosphere, the organization process (invitations, specific event type) makes the gathering itself less "informal" in its conception compared to a spontaneous decision to grab dinner. Bachelorette parties often have planned elements and aren't typically characterized by the same level of casual spontaneity as the first scenario.
Comparing the two, the first situation embodies the 'casual' and 'relaxed manner' aspect more strongly due to its apparent spontaneity and lack of formal structure like invitations. The second situation, while a social gathering, involves planning and formal invitations, making it less representative of a purely informal gathering according to the definition provided. So I think that the first scenario is the better example of an informal gathering<end_working_out>
<SOLUTION>A</SOLUTION>
"""

def generate_response(prompt, model, tokenizer, system_prompt):
    """Generate one sampled completion for ``prompt`` under ``system_prompt``.

    The system and user messages are rendered with the tokenizer's chat
    template. If rendering raises, a plain "User:/Assistant:" prompt is used
    instead (best-effort fallback, reported on stdout). Generation samples
    with ``temperature=0.7`` and produces at most 1024 new tokens.

    Args:
        prompt: Text of the user message (the question to answer).
        model: Object exposing ``generate(**inputs, ...)``. If it carries a
            non-empty ``hf_device_map``, the tokenized inputs are moved to
            the first device listed there.
        tokenizer: Callable tokenizer that also exposes
            ``apply_chat_template``, ``decode`` and ``eos_token_id``.
        system_prompt: Text of the system message placed before the user
            message.

    Returns:
        The decoded continuation only: the prompt tokens are sliced off the
        generated sequence and special tokens are skipped.
    """
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": prompt},
    ]
    try:
        input_text = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
    except Exception as e:
        # Deliberately broad: any template failure degrades to a plain prompt
        # rather than aborting the whole evaluation run.
        print(f"Chat template failed: {e}")
        print("Falling back to simple concatenation...")
        input_text = f"{system_prompt}\n\nUser: {prompt}\nAssistant:"

    inputs = tokenizer(input_text, return_tensors="pt")

    device_map = getattr(model, "hf_device_map", None)
    if device_map:
        first_device = next(iter(device_map.values()))
        # Device-map entries may be integer GPU indices as well as strings,
        # so only the non-accelerator placements are excluded here.
        if first_device not in ("cpu", "disk"):
            inputs = {k: v.to(first_device) for k, v in inputs.items()}

    # Index by key: after the move above `inputs` is a plain dict, which has
    # no `.input_ids` attribute.
    prompt_length = inputs["input_ids"].shape[1]

    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=1024,
            do_sample=True,
            temperature=0.7,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Keep only the newly generated tokens of the single returned sequence.
    generated_ids = output_ids[0][prompt_length:]
    response = tokenizer.decode(generated_ids, skip_special_tokens=True)

    return response

# Run the model over every test example, collecting one record per sample.
# Records keep the dataset's "prompt" and "solution" fields next to the
# generated "response" so answers can be extracted and compared afterwards.
results = []

for i, example in enumerate(tqdm(test_dataset)):
    prompt = example["prompt"]
    solution = example["solution"]

    response = generate_response(prompt, model, tokenizer, system_prompt)

    results.append(
        {
            "id": i,
            "prompt": prompt,
            "solution": solution,
            "response": response,
        }
    )

    # Periodic checkpoint: each interim file holds all results gathered so
    # far, so the latest one is enough to recover from an interrupted run.
    if (i + 1) % save_batch_size == 0:
        # Explicit UTF-8 plus ensure_ascii=False keeps the interim files
        # consistent with the final dump and independent of the platform's
        # default text encoding.
        with open(f"interim_results_{i+1}.json", "w", encoding="utf-8") as f:
            json.dump(results, f, indent=2, ensure_ascii=False)
        print(f"Saved interim results for {i+1} samples")

# Save all responses to a JSON file.
# ensure_ascii=False writes non-ASCII characters verbatim, so the file must
# be opened as UTF-8; the extraction step in this file reads its input with
# encoding='utf-8'.
with open(output_file, "w", encoding="utf-8") as f:
    json.dump(results, f, indent=2, ensure_ascii=False)

print(f"Inference completed. Results saved to {output_file}")
print(f"Total samples processed: {len(results)}")

In [None]:
import re

# Raw inference dump to read, and the file that receives the same records
# with an added "extracted_answer" field.
input_file = "/kaggle/working/hard_label_test.json"
output_file = "/kaggle/working/hard_label_test_extracted.json"

# The tokenizer is loaded only to learn the EOS token text, which may trail
# the closing </SOLUTION> tag in a response.
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained("unsloth/Qwen3-4B-Base")
if tokenizer.eos_token is None:
    tokenizer.eos_token = "</s>"

print(f"Using EOS token: {tokenizer.eos_token}")

# Flags shared by every tag pattern below: DOTALL lets "." cross newlines.
_TAG_FLAGS = re.DOTALL | re.MULTILINE

reasoning_end = r"<end_working_out>"
solution_start = r"<SOLUTION>"

# Closing tag, optional whitespace, then an optional literal EOS token.
_eos_pattern = re.escape(tokenizer.eos_token)
solution_end_regex = rf"</SOLUTION>[\s]{{0,}}(?:{_eos_pattern})?"

# Captures (non-greedily) whatever sits between the SOLUTION tags.
match_format = re.compile(
    solution_start + r"(.+?)" + solution_end_regex + r"[\s]{0,}",
    flags=_TAG_FLAGS,
)

# The three patterns below search from the opening tag onward. In the code
# visible here extract_answer() builds its own patterns and does not use them.

# First LaTeX fraction such as \frac{3}{2} after the opening tag.
_latex_fraction = r"\\frac\{[^}]+\}\{[^}]+\}"
match_fractions = re.compile(
    solution_start + r".*?(" + _latex_fraction + r")",
    flags=_TAG_FLAGS,
)

# First run of digits/dots/commas (optionally negative), rejected when a
# \frac appears anywhere after the tag or when the run is followed by "}".
match_numbers = re.compile(
    solution_start + r"(?!.*\\frac).*?[\s]{0,}([-]?[\d\.\,]{1,})(?!\})",
    flags=_TAG_FLAGS,
)

# First capital letter that is followed by whitespace or the closing tag.
match_choice = re.compile(
    solution_start + r".*?[\s]{0,}([A-Z])(?:\s|</SOLUTION>)",
    flags=_TAG_FLAGS,
)

# Patterns applied to the text found inside the SOLUTION tags. They are
# compiled once here rather than on every call.
_FRACTION_RE = re.compile(r'\\frac\{[^}]+\}\{[^}]+\}')
# A whitespace-delimited token made of digits, dots and commas with an
# optional leading minus. The lookahead demands at least one digit so that a
# lone "." or "," is never reported as a number.
_NUMBER_RE = re.compile(r'(?:^|\s)(-?(?=[\d.,]*\d)[\d.,]+)(?:\s|$)')
# A single capital letter standing on its own.
_CHOICE_RE = re.compile(r'(?:^|\s)([A-Z])(?:\s|$)')


def extract_answer(response_text):
    """Extract the answer found inside the first <SOLUTION>...</SOLUTION> pair.

    Only text between the tags is considered. Candidates are tried in this
    order and the first hit is returned:

    1. a LaTeX fraction such as ``\\frac{3}{2}`` (checked first so its digits
       are not mistaken for a plain number),
    2. a standalone numeric token (digits with optional ``-``, ``.``, ``,``),
    3. a standalone capital letter (multiple-choice answer),
    4. otherwise the complete stripped tag content.

    Args:
        response_text: The raw model response.

    Returns:
        The extracted answer string, or ``""`` when ``response_text`` is not
        a string or contains no complete SOLUTION tag pair.
    """
    if not isinstance(response_text, str):
        return ""

    solution_matches = match_format.findall(response_text)
    if not solution_matches:
        return ""

    # Use the first tag pair and drop any EOS token text captured inside it.
    solution_content = solution_matches[0].strip()
    solution_content = solution_content.replace(tokenizer.eos_token, "").strip()

    fraction_match = _FRACTION_RE.search(solution_content)
    if fraction_match:
        return fraction_match.group(0).strip()

    number_match = _NUMBER_RE.search(solution_content)
    if number_match:
        return number_match.group(1).strip()

    choice_match = _CHOICE_RE.search(solution_content)
    if choice_match:
        return choice_match.group(1).strip()

    return solution_content



# Load the generated responses, attach an "extracted_answer" to every sample,
# tally what kind of answer was found, and write the augmented records out.
print(f"\nLoading input file: {input_file}")
try:
    with open(input_file, 'r', encoding='utf-8') as f:
        data = json.load(f)
except FileNotFoundError:
    print(f"Error: File {input_file} not found!")
    # SystemExit instead of exit(): exit() is a helper meant for the
    # interactive interpreter and should not be relied on in program code.
    raise SystemExit(1) from None
except json.JSONDecodeError as e:
    print(f"Error: Invalid JSON in {input_file}: {e}")
    raise SystemExit(1) from None
print(f"Loaded {len(data)} samples")

# Counts only samples that carried a "response" field.
processed_count = 0
# One counter per kind of extracted answer.
extraction_stats = {
    "numerical": 0,
    "fraction": 0,
    "choice": 0,
    "complete": 0,
    "empty": 0
}

for i, sample in enumerate(data):
    if "response" not in sample:
        # Nothing to extract from: record an empty answer and move on.
        print(f"Warning: Sample {i} missing 'response' field")
        sample["extracted_answer"] = ""
        extraction_stats["empty"] += 1
        continue

    extracted_answer = extract_answer(sample["response"])
    sample["extracted_answer"] = extracted_answer

    # Classify the extracted string by its shape.
    if not extracted_answer:
        extraction_stats["empty"] += 1
    elif re.match(r'^[-]?[\d\.\,]+$', extracted_answer):
        extraction_stats["numerical"] += 1
    elif extracted_answer.startswith('\\frac'):
        extraction_stats["fraction"] += 1
    elif re.match(r'^[A-Z]$', extracted_answer):
        extraction_stats["choice"] += 1
    else:
        # Anything else is the full text found inside the SOLUTION tags.
        extraction_stats["complete"] += 1

    processed_count += 1

print(f"Saving results to: {output_file}")
with open(output_file, 'w', encoding='utf-8') as f:
    json.dump(data, f, indent=2, ensure_ascii=False)

print("\nProcessing completed!")
print(f"Total samples processed: {processed_count}")
print("Extraction statistics:")
print(f"  - Numerical answers: {extraction_stats['numerical']}")
print(f"  - Fraction answers: {extraction_stats['fraction']}")
print(f"  - Multiple choice answers: {extraction_stats['choice']}")
print(f"  - Complete text answers: {extraction_stats['complete']}")
print(f"  - Empty/No answer found: {extraction_stats['empty']}")