In [None]:
from __future__ import annotations

import asyncio
import hashlib
import json
import os
import re
import time
from pathlib import Path
from sys import path

import tqdm
import tqdm.asyncio
from datasets import load_dataset
from dotenv import load_dotenv

# Load variables from a .env file into the process environment
# (the OpenRouter client created later reads OPENROUTER_API_KEY from os.environ).
load_dotenv()
# Put the parent of the current working directory on sys.path so that the
# `src` package imported below resolves — assumes the notebook is run from a
# direct subdirectory of the project root; TODO confirm.
path.append(str(Path.cwd().parent))

In [3]:
from src.config import (
    GSM8K_PATH,
    RANDOM_STATE,
    TEACHER_COTS_FILE,
    TEACHER_SYSTEM_PROMPT,
    TEACHER_USER_PROMPT,
)
from src.logs import get_logger

In [4]:
# Notebook-wide logger obtained from the project's logging helper; used by the
# generation and persistence functions below to report progress and errors.
logger = get_logger(__name__)

In [5]:
# Load the "main" configuration, train split, of the dataset located at
# GSM8K_PATH. Each row has a 'question' and an 'answer' field.
dataset = load_dataset(GSM8K_PATH, "main", split="train")

In [6]:
dataset

Dataset({
    features: ['question', 'answer'],
    num_rows: 7473
})

In [7]:
dataset[:5]

{'question': ['Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?',
  'Weng earns $12 an hour for babysitting. Yesterday, she just did 50 minutes of babysitting. How much did she earn?',
  'Betty is saving money for a new wallet which costs $100. Betty has only half of the money she needs. Her parents decided to give her $15 for that purpose, and her grandparents twice as much as her parents. How much more money does Betty need to buy the wallet?',
  'Julie is reading a 120-page book. Yesterday, she was able to read 12 pages and today, she read twice as many pages as yesterday. If she wants to read half of the remaining pages tomorrow, how many pages should she read?',
  'James writes a 3-page letter to 2 different friends twice a week.  How many pages does he write a year?'],
 'answer': ['Natalia sold 48/2 = <<48/2=24>>24 clips in May.\nNatalia sold 48+24 = <<48+24=72>>72 clips alt

In [None]:
import httpx

# Shared async HTTP client for the OpenRouter API.
# - Raises KeyError at creation time if OPENROUTER_API_KEY is not set in the environment.
# - Every request inherits the bearer-token header and a 60-second timeout.
# NOTE(review): the client is never explicitly closed in this notebook.
client = httpx.AsyncClient(
    base_url="https://openrouter.ai/api/v1",
    headers={"Authorization": f"Bearer {os.environ['OPENROUTER_API_KEY']}"},
    timeout=60.0,
)


async def get_teacher_response(question: str, client: httpx.AsyncClient):
    """Request a single chat completion from the teacher model for ``question``.

    The system message is ``TEACHER_SYSTEM_PROMPT`` and the user message is
    ``TEACHER_USER_PROMPT`` formatted with the question. The request is pinned
    to one provider with fallbacks disabled and samples with temperature 0.8
    and top_p 0.95.

    Args:
        question: Problem text substituted into the user prompt template.
        client: Async HTTP client on which ``/chat/completions`` is posted.

    Returns:
        The decoded JSON body of the response.

    Raises:
        httpx.HTTPStatusError: If the server answers with an error status.
    """
    system_message = {"role": "system", "content": TEACHER_SYSTEM_PROMPT}
    user_message = {
        "role": "user",
        "content": TEACHER_USER_PROMPT.format(question=question),
    }
    request_body = dict(
        model="deepseek/deepseek-r1-distill-qwen-32b",
        messages=[system_message, user_message],
        provider={"only": ["deepinfra/fp8"], "allow_fallbacks": False},
        temperature=0.8,
        top_p=0.95,
        tool_choice="none",
    )

    http_response = await client.post("/chat/completions", json=request_body)
    http_response.raise_for_status()
    return http_response.json()

In [9]:
dataset[0]

{'question': 'Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?',
 'answer': 'Natalia sold 48/2 = <<48/2=24>>24 clips in May.\nNatalia sold 48+24 = <<48+24=72>>72 clips altogether in April and May.\n#### 72'}

In [10]:
# Smoke test: send the first dataset question to the teacher model once.
# Uses top-level await, which requires an IPython/Jupyter kernel.
response = await get_teacher_response(dataset[0]["question"], client)

In [11]:
response

{'id': 'gen-1757883160-B4awjlanbxEM2YHIo6hl',
 'provider': 'DeepInfra',
 'model': 'deepseek/deepseek-r1-distill-qwen-32b',
 'object': 'chat.completion',
 'created': 1757883160,
 'choices': [{'logprobs': None,
   'finish_reason': 'stop',
   'native_finish_reason': 'stop',
   'index': 0,
   'message': {'role': 'assistant',
    'content': 'Reasoning:\n- Clips sold in April: 48\n- Clips sold in May: 48 ÷ 2 = 24\n- Total clips sold: 48 + 24 = 72\nFinal Answer: 72',
    'refusal': None,
    'reasoning': "First, identify the number of clips sold in April, which is 48.\n\nNext, determine the number of clips sold in May by calculating half of April's sales: 48 divided by 2 equals 24.\n\nFinally, add the clips sold in both months to find the total: 48 plus 24 equals 72.\n"}}],
 'usage': {'prompt_tokens': 492,
  'completion_tokens': 129,
  'total_tokens': 621,
  'prompt_tokens_details': None}}

In [12]:
print(response["choices"][0]["message"]["content"])

Reasoning:
- Clips sold in April: 48
- Clips sold in May: 48 ÷ 2 = 24
- Total clips sold: 48 + 24 = 72
Final Answer: 72


In [13]:
print(response["choices"][0]["message"]["reasoning"])

First, identify the number of clips sold in April, which is 48.

Next, determine the number of clips sold in May by calculating half of April's sales: 48 divided by 2 equals 24.

Finally, add the clips sold in both months to find the total: 48 plus 24 equals 72.



In [14]:
response["usage"]

{'prompt_tokens': 492,
 'completion_tokens': 129,
 'total_tokens': 621,
 'prompt_tokens_details': None}

In [15]:
response["provider"]

'DeepInfra'

In [None]:
# Number of leading dataset rows to process (capped at the dataset length).
NUM_EXAMPLES = 10
# Number of independent teacher completions requested for each row.
NUM_SAMPLES_PER_EXAMPLE = 3
# Upper bound on simultaneously in-flight teacher requests (semaphore size).
MAX_CONCURRENCY = 100

In [None]:
# Signed integer / decimal / scientific-notation literal.
_num_re = re.compile(r"[-+]?\d*\.?\d+(?:[eE][-+]?\d+)?")
# A comma acting as a thousands separator: preceded by a digit and followed by
# a group of exactly three digits ("1,000", "12,345,678"). Commas in lists such
# as "1, 2, 3" or "1,2" do not match and are left alone.
_thousands_sep_re = re.compile(r"(?<=\d),(?=\d{3}(?!\d))")


def sha256(s: str) -> str:
    """Return the hexadecimal SHA-256 digest of ``s`` (UTF-8 encoded)."""
    return hashlib.sha256(s.encode()).hexdigest()


def _to_number(s: str):
    """Extract the last number that appears in ``s``.

    Thousands separators are removed first, so ``"1,000"`` is read as ``1000``
    instead of being split into ``1`` and ``000``.

    Args:
        s: Free-form text that may contain one or more numbers.

    Returns:
        ``float`` if the last number has a decimal point or an exponent,
        ``int`` otherwise, or ``None`` when ``s`` contains no number.
    """
    m = _num_re.findall(_thousands_sep_re.sub("", s))
    if not m:
        return None
    val = m[-1]
    if "e" in val.lower() or "." in val:
        return float(val)
    return int(val)


def parse_gsm8k_answer(ans: str):
    """Parse the numeric gold answer from a GSM8K solution string.

    The answer is whatever follows the last ``####`` marker; if the marker is
    absent the whole string is searched.

    Args:
        ans: Full GSM8K ``answer`` field (worked solution plus ``#### <n>``).

    Returns:
        The answer as ``int`` or ``float``, or ``None`` if no number is found.

    Raises:
        ValueError: If the extracted text cannot be converted to a number.
    """
    try:
        answer_part = ans.split("####")[-1].strip()
        return _to_number(answer_part)
    except (IndexError, ValueError) as e:
        raise ValueError(f"Could not parse GSM8K answer: '{ans}'") from e


def parse_teacher_final_answer(content: str):
    """Parse the number on the first ``Final Answer:`` line of a teacher reply.

    Matching is case-insensitive and ignores whitespace around the line, so an
    indented ``Final Answer:`` line is still recognised.

    Args:
        content: Full text of the teacher's answer.

    Returns:
        The answer as ``int`` or ``float``, or ``None`` if the
        ``Final Answer:`` line contains no number.

    Raises:
        ValueError: If no ``Final Answer:`` line exists or its value cannot be
            converted to a number.
    """
    try:
        for line in content.splitlines():
            stripped = line.strip()
            if stripped.lower().startswith("final answer:"):
                # Everything after the first colon is the answer payload.
                return _to_number(stripped.split(":", 1)[-1].strip())
        raise ValueError("No 'Final Answer:' line found")
    except (IndexError, ValueError) as e:
        raise ValueError(f"Could not parse teacher final answer: '{content}'") from e

In [27]:
def create_meta_file():
    """Write a JSON sidecar describing the current generation run.

    The file is placed next to ``TEACHER_COTS_FILE`` with that file's final
    suffix replaced by ``.meta.json``. It records the dataset, split, teacher
    model, system prompt, run-size constants, configured random state and the
    wall-clock time (seconds since the epoch) at which it was written.

    The output directory is created if it does not exist, so a fresh checkout
    without the data directory does not fail with ``FileNotFoundError``.
    """
    meta = {
        "dataset": GSM8K_PATH,
        "split": "train",
        "model": "deepseek/deepseek-r1-distill-qwen-32b",
        "system_prompt": TEACHER_SYSTEM_PROMPT,
        "num_examples": NUM_EXAMPLES,
        "num_samples_per_example": NUM_SAMPLES_PER_EXAMPLE,
        "max_concurrency": MAX_CONCURRENCY,
        "random_seed": RANDOM_STATE,
        "timestamp": time.time(),
    }

    meta_path = TEACHER_COTS_FILE.with_suffix(".meta.json")
    # Ensure the target directory exists before opening the file for writing.
    meta_path.parent.mkdir(parents=True, exist_ok=True)
    with meta_path.open("w", encoding="utf-8") as f:
        json.dump(meta, f, ensure_ascii=False, indent=2)
    logger.info(f"Meta file created at {meta_path}")

In [None]:
def create_base_record(row_idx: int, q: str, gold_text: str) -> dict:
    """Build the per-question record that teacher samples are attached to.

    Args:
        row_idx: Position of the question in the processed dataset slice.
        q: Question text.
        gold_text: Reference solution text from the dataset.

    Returns:
        A dict with the row index, question, raw and parsed gold answer, the
        formatted user prompt, and an empty ``samples`` list.
    """
    record = dict(
        row_index=int(row_idx),
        question=q,
        gold_answer_text=gold_text,
        gold_answer_number=parse_gsm8k_answer(gold_text),
        user_prompt=TEACHER_USER_PROMPT.format(question=q),
        samples=[],
    )
    return record


async def generate_sample(row_idx: int, q: str, sample_id: int, sem: asyncio.Semaphore):
    """Request one teacher completion for ``q`` and package it as a sample record.

    The request is made while holding ``sem``; latency is measured around the
    request only. Any exception (HTTP failure, unexpected response shape,
    unparseable final answer) is logged with its traceback and swallowed.

    Args:
        row_idx: Row index of the question, used only in the error log.
        q: Question text sent to the teacher.
        sample_id: Index of this sample within the question's samples.
        sem: Semaphore limiting the number of concurrent requests.

    Returns:
        A dict describing the sample, or ``None`` if anything went wrong.
    """
    async with sem:
        try:
            started_at = time.perf_counter()
            resp = await get_teacher_response(q, client)
            elapsed_ms = int((time.perf_counter() - started_at) * 1000)

            first_choice = resp["choices"][0]
            message = first_choice["message"]
            answer_text = message["content"]

            return dict(
                sample_id=int(sample_id),
                teacher_answer_text=answer_text,
                teacher_answer_number=parse_teacher_final_answer(answer_text),
                teacher_reasoning_text=message.get("reasoning"),
                finish_reason=first_choice.get("finish_reason"),
                native_finish_reason=first_choice.get("native_finish_reason"),
                provider=resp.get("provider"),
                model=resp.get("model"),
                created=resp.get("created"),
                usage=resp.get("usage"),
                request_id=resp.get("id"),
                latency_ms=elapsed_ms,
            )
        except Exception as e:
            # Best-effort: a failed sample is reported and skipped by the caller.
            logger.error(
                f"Error generating sample for row {row_idx}, sample {sample_id}: '{e}'",
                exc_info=True,
            )
            return None


async def generate_samples(idx: int, q: str, gold: str, sem: asyncio.Semaphore):
    """Collect all teacher samples for a single question.

    Launches ``NUM_SAMPLES_PER_EXAMPLE`` sample requests concurrently and keeps
    only those that succeeded.

    Args:
        idx: Row index of the question.
        q: Question text.
        gold: Reference solution text.
        sem: Semaphore shared by all requests to bound concurrency.

    Returns:
        The base record for the question with ``samples`` filled in.
    """
    record = create_base_record(idx, q, gold)

    pending = [
        generate_sample(idx, q, sample_id, sem)
        for sample_id in range(NUM_SAMPLES_PER_EXAMPLE)
    ]
    outcomes = await asyncio.gather(*pending)

    # Failed requests come back as None and are dropped here.
    record["samples"] = list(filter(lambda outcome: outcome is not None, outcomes))
    return record


async def generate_all_samples():
    """Generate teacher samples for the first ``NUM_EXAMPLES`` dataset rows.

    All questions are processed concurrently, with a single semaphore of size
    ``MAX_CONCURRENCY`` bounding the number of in-flight requests. Progress is
    shown per question via tqdm.

    Returns:
        One record per question, in dataset order.
    """
    subset_size = min(NUM_EXAMPLES, len(dataset))
    subset = dataset.select(range(subset_size))
    limiter = asyncio.Semaphore(MAX_CONCURRENCY)

    jobs = [
        generate_samples(position, example["question"], example["answer"], limiter)
        for position, example in enumerate(subset)
    ]

    return await tqdm.asyncio.tqdm.gather(*jobs)


def save_samples_to_file(records: list[dict]):
    """Write the records to ``TEACHER_COTS_FILE`` as JSON Lines and tally samples.

    The file is overwritten. Its parent directory is created if missing so
    that already-generated results are not lost to a ``FileNotFoundError``.

    Args:
        records: Per-question records, each with a ``samples`` list.

    Returns:
        A ``(successful_samples, failed_samples)`` tuple, where a question's
        failures are ``NUM_SAMPLES_PER_EXAMPLE`` minus the samples it holds.
    """
    successful_samples = 0
    failed_samples = 0

    # Ensure the target directory exists before opening the file for writing.
    TEACHER_COTS_FILE.parent.mkdir(parents=True, exist_ok=True)
    with TEACHER_COTS_FILE.open("w", encoding="utf-8") as f:
        for rec in records:
            f.write(json.dumps(rec, ensure_ascii=False) + "\n")
            successful_samples += len(rec["samples"])
            failed_samples += NUM_SAMPLES_PER_EXAMPLE - len(rec["samples"])

    logger.info(f"Total samples saved: {successful_samples}")
    logger.info(f"Total samples failed: {failed_samples}")
    return successful_samples, failed_samples

In [None]:
# End-to-end run: write the run metadata first, then query the teacher for all
# selected questions, then persist the records as JSON Lines.
# Uses top-level await, which requires an IPython/Jupyter kernel.
create_meta_file()
results = await generate_all_samples()
successful, failed = save_samples_to_file(results)

# Short summary for the notebook output; per-sample errors are in the logs.
print(f"Wrote {successful} samples to {TEACHER_COTS_FILE}")
if failed > 0:
    print(f"Failed to generate {failed} samples. Check logs for details.")

100%|██████████| 10/10 [00:24<00:00,  2.44s/it]

Wrote 30 samples to /home/rabadaba/engineering-thesis/code/data/teacher_cots.jsonl





In [None]:
import pandas as pd

# Read the saved JSON Lines file back (one record per line) to inspect the result.
df = pd.read_json(TEACHER_COTS_FILE, lines=True)

In [32]:
df

Unnamed: 0,row_index,question,gold_answer_text,gold_answer_number,user_prompt,samples
0,0,Natalia sold clips to 48 of her friends in Apr...,Natalia sold 48/2 = <<48/2=24>>24 clips in May...,72,Follow the required format exactly.\n\nExample...,"[{'sample_id': 0, 'teacher_answer_text': 'Reas..."
1,1,Weng earns $12 an hour for babysitting. Yester...,Weng earns 12/60 = $<<12/60=0.2>>0.2 per minut...,10,Follow the required format exactly.\n\nExample...,"[{'sample_id': 0, 'teacher_answer_text': 'Reas..."
2,2,Betty is saving money for a new wallet which c...,"In the beginning, Betty has only 100 / 2 = $<<...",5,Follow the required format exactly.\n\nExample...,"[{'sample_id': 0, 'teacher_answer_text': 'Reas..."
3,3,"Julie is reading a 120-page book. Yesterday, s...",Maila read 12 x 2 = <<12*2=24>>24 pages today....,42,Follow the required format exactly.\n\nExample...,"[{'sample_id': 0, 'teacher_answer_text': 'Reas..."
4,4,James writes a 3-page letter to 2 different fr...,He writes each friend 3*2=<<3*2=6>>6 pages a w...,624,Follow the required format exactly.\n\nExample...,"[{'sample_id': 0, 'teacher_answer_text': 'Reas..."
5,5,Mark has a garden with flowers. He planted pla...,There are 80/100 * 10 = <<80/100*10=8>>8 more ...,35,Follow the required format exactly.\n\nExample...,"[{'sample_id': 0, 'teacher_answer_text': 'Reas..."
6,6,Albert is wondering how much pizza he can eat ...,He eats 32 from the largest pizzas because 2 x...,48,Follow the required format exactly.\n\nExample...,"[{'sample_id': 0, 'teacher_answer_text': 'Reas..."
7,7,Ken created a care package to send to his brot...,"To the initial 2 pounds of jelly beans, he add...",16,Follow the required format exactly.\n\nExample...,"[{'sample_id': 0, 'teacher_answer_text': 'Reas..."
8,8,Alexis is applying for a new job and bought a ...,Let S be the amount Alexis paid for the shoes....,41,Follow the required format exactly.\n\nExample...,"[{'sample_id': 0, 'teacher_answer_text': 'Reas..."
9,9,Tina makes $18.00 an hour. If she works more ...,She works 8 hours a day for $18 per hour so sh...,990,Follow the required format exactly.\n\nExample...,"[{'sample_id': 0, 'teacher_answer_text': 'Reas..."
