# In [3]:
import json
import pandas as pd
import torch


# Price table keyed by the per-model response column name found in the
# results file. Each value is an (input_rate, output_rate) pair; load_data
# divides both rates by 1,000,000 before multiplying by token counts, so they
# are per-million-token rates (currency not stated here -- TODO confirm).
_MODEL_PRICES = {
    'gpt-5-nano-2025-08-07_response': (0.05, 0.40),
    'grok-4-fast-reasoning_response': (0.20, 0.50),
    'openai_gpt-oss-120b_response': (0.05, 0.24),
    'openai_gpt-oss-20b_response': (0.03, 0.14),
    'nvidia_NVIDIA-Nemotron-Nano-9B-v2_response': (0.04, 0.16),
    'meta-llama_Llama-3.2-11B-Vision-Instruct_response': (0.049, 0.049),
    'moonshotai_Kimi-K2-Instruct-0905_response': (0.50, 2.00),
    'gemini-2.5-flash-lite_response': (0.10, 0.40),
}

# Optimisation / checkpoint settings. They are not read anywhere in this
# cell; presumably consumed by the training code further down (verify there).
_TRAINING_SETTINGS = {
    "PEAK_LR": 1e-6,
    "BATCH_SIZE": 16,
    "EPOCHS": 8,
    "WEIGHT_DECAY": 0.04,
    "NUM_WARMUP_STEPS": 50,
    "SAVE_PATH": "models/router_qwen2.5",
}

# Single configuration mapping shared by the rest of the file.
CONFIG = {
    "base": "Qwen/Qwen2.5-1.5B-Instruct",  # base model
    "device": torch.device("cuda"),
    "model_keys": _MODEL_PRICES,  # model names and prices
    **_TRAINING_SETTINGS,
}

def _read_jsonl(file_path, length=None):
    """Read JSON records from a JSON Lines file, skipping undecodable lines.

    Args:
        file_path: Path of the ``.jsonl`` file.
        length: Optional cap on the number of successfully parsed records;
            a falsy value (``None`` or ``0``) means "no cap".

    Returns:
        A list with one decoded JSON value per parsed line.
    """
    records = []
    # JSON text is UTF-8; do not depend on the platform default encoding.
    with open(file_path, "r", encoding="utf-8") as f:
        for line in f:
            if length and len(records) >= length:
                break
            if not line.strip():
                continue  # blank separator lines carry no record
            try:
                records.append(json.loads(line))
            except json.JSONDecodeError as e:
                print("Bad line:", e)
    return records


def _merge_by_index(df, model_columns):
    """Collapse rows that share an ``index`` value into a single row.

    The first row of each group supplies every column; for each column in
    *model_columns* the last non-null value seen in the group overrides it.
    """
    merged_rows = []
    for _, group in df.groupby("index"):
        merged_row = group.iloc[0].copy()
        for col in model_columns:
            vals = group[col].dropna().tolist()
            if vals:
                merged_row[col] = vals[-1]
        merged_rows.append(merged_row)
    return pd.DataFrame(merged_rows).reset_index(drop=True)


def _score_response(response, correct_answer, input_rate, output_rate):
    """Return a copy of *response* with correctness, price and latency filled in.

    Adds ``final_answer`` / ``is_correct`` when the answer text contains
    ``"Final Answer:"``, always sets ``total_price``, and defaults
    ``total_latency`` to 0 and ``is_correct`` to ``False`` when missing.
    """
    entry = response.copy()

    # parse final answer
    answer = entry.get('answer', '')
    if isinstance(answer, str) and 'Final Answer:' in answer:
        final_answer_str = answer.split("Final Answer:")[1].strip()
        try:
            final_answer = int(final_answer_str)
            parsed = True
        except ValueError:
            # Keep the historical placeholder value, but remember that it is
            # not a real answer so it can never match a reference answer of 0.
            final_answer = 0
            parsed = False
        entry['final_answer'] = final_answer
        entry['is_correct'] = parsed and bool(final_answer == correct_answer)

    # compute total price; a token count that is present but null counts as 0,
    # mirroring the handling of total_latency below
    input_tokens = entry.get('input_tokens', 0) or 0
    output_tokens = entry.get('output_tokens', 0) or 0
    entry['total_price'] = (
        input_tokens * (input_rate / 1000000)
        + output_tokens * (output_rate / 1000000)
    ) * 1000  # kind of normalizing

    # default missing numeric fields to 0
    entry['total_latency'] = entry.get('total_latency', 0) or 0
    entry['is_correct'] = entry.get('is_correct', False)
    return entry


def load_data(file_path, length=None, *, model_keys=None,
              key_model='gemini-2.5-flash-lite_response',
              train_frac=0.9, seed=42):
    """Load model responses from a JSON Lines file and split train/validation.

    Steps: read the records, merge rows sharing an ``index`` value, annotate
    every model response dict with ``final_answer``, ``is_correct``,
    ``total_price`` and ``total_latency``, drop rows with a missing/empty
    ``question`` or an incomplete *key_model* response, then split randomly.

    Args:
        file_path: Path of the ``.jsonl`` results file.
        length: Optional cap on the number of records read (falsy = no cap).
        model_keys: Mapping ``column name -> (input_rate, output_rate)``, with
            rates expressed per 1,000,000 tokens. Defaults to
            ``CONFIG["model_keys"]``. Columns missing from the file are
            skipped.
        key_model: Response column that must hold a dict with non-null
            ``is_correct``, ``total_latency`` and ``total_price`` for a row to
            be kept.
        train_frac: Fraction of the rows sampled into the training frame.
        seed: Random state passed to ``DataFrame.sample``.

    Returns:
        ``(df_train, df_val)``, both with a fresh 0-based index.

    Raises:
        KeyError: If the data has no ``index``, ``question`` or *key_model*
            column (this includes an empty file).
    """
    if model_keys is None:
        model_keys = CONFIG["model_keys"]

    df = pd.DataFrame(_read_jsonl(file_path, length))

    # merge by model keys (only those that actually occur in the file)
    present = [m for m in model_keys if m in df.columns]
    df = _merge_by_index(df, present)

    # clean & get is_correct, and total_price
    if 'correct_answer' in df.columns:
        gold_answers = df['correct_answer'].tolist()
    else:
        gold_answers = [''] * len(df)
    for col in present:
        if col not in df.columns:
            continue
        input_rate, output_rate = model_keys[col]
        cleaned = [
            _score_response(cell, gold, input_rate, output_rate)
            if isinstance(cell, dict) else cell
            for cell, gold in zip(df[col].tolist(), gold_answers)
        ]
        # Rebuild the whole column at once instead of writing dicts cell by cell.
        df[col] = pd.Series(cleaned, index=df.index, dtype=object)

    # Drop rows with missing or invalid question
    df = df[df['question'].notna() & (df['question'] != '')].reset_index(drop=True)

    # Ensure no nulls in key model fields
    required = ('is_correct', 'total_latency', 'total_price')
    usable = df[key_model].apply(
        lambda x: isinstance(x, dict) and all(x.get(k) is not None for k in required)
    ).astype(bool)  # stays a boolean mask even when the frame is empty
    df = df[usable].reset_index(drop=True)

    # Split into train and validation
    df_train = df.sample(frac=train_frac, random_state=seed)
    df_val = df.drop(df_train.index).reset_index(drop=True)
    df_train = df_train.reset_index(drop=True)

    return df_train, df_val

# Load both splits, keep only the question text and the Gemini response
# column, then report the class balance of the training split.
df_train, df_val = load_data("data/all_providers_results_parallelized.jsonl")

kept_columns = ['question', 'gemini-2.5-flash-lite_response']
df_train = df_train[kept_columns]
df_val = df_val[kept_columns]

# Each response cell is a dict; positives are the ones flagged is_correct.
num_pos = sum(
    response['is_correct']
    for response in df_train['gemini-2.5-flash-lite_response']
)
num_neg = len(df_train) - num_pos
print(f"Number of positive samples: {num_pos}")
print(f"Number of negative samples: {num_neg}")

# Output (recorded from the notebook run):
# Number of positive samples: 14186
# Number of negative samples: 3811
