# Dataset Creation for Router Training (For Parts One and Two)

For the sake of demonstration, we build a balanced dataset for training semantic routers. The dataset consists of:
- **Factual prompts**: General knowledge questions from the Dolly-15k dataset
- **Coding prompts**: Programming-related questions from the CodeAlpaca-20k dataset


The final dataset has 2000 queries, 1000 per class.

In [1]:
import csv
import json
import random
import re
from pathlib import Path

from datasets import load_dataset

In [8]:
# -----------------------
# Config
# -----------------------
SEED = 999999
TOTAL = 2000
FACTUAL_TARGET = TOTAL // 2
CODING_TARGET = TOTAL - FACTUAL_TARGET

FACTUAL_SOURCE = "databricks/databricks-dolly-15k"
CODING_SOURCE = "sahil2801/CodeAlpaca-20k"

# Skip the first N valid samples of each source to get fresh data for testing
SKIP_FIRST = False  # Set to True to enable skipping
SKIP_COUNT = 5000  # Number of valid samples to skip from the beginning
random.seed(SEED)


In [9]:
# Heuristics for filtering
CODING_KEYWORDS = {
    "python","java","javascript","js","typescript","ts","c++","c#","golang","go","rust","ruby","php","sql",
    "bash","shell","regex","html","css","json","yaml","xml","api","sdk","library","framework","tensorflow",
    "pytorch","numpy","pandas","sklearn","django","flask","fastapi","node","react","vue","svelte","angular",
    "kotlin","swift","scala","haskell","matlab","octave","r ","julia","notebook","jupyter","colab",
    "function","class","method","variable","loop","algorithm","complexity","big o","o(n)","runtime",
    "compile","build","debug","error","stack trace","exception","unit test","test case","package",
    "import","module","script","snippet","code","coding","program","programming"
}
# words that often indicate general factual/trivia/expository questions
FACTUAL_HINTS = {
    "who","what","when","where","why","how","which","name","define","explain","describe","compare","contrast",
    "list","identify","summarize","outline"
}

### Utility functions
These filters are arguably unnecessary, but since the goal here is to demonstrate methods rather than data curation, it makes sense to start from clean data and focus on comparing models. As we will see, even when the data is filtered down to the most obvious and generic examples, the baseline does not reach 100% out of the box and still needs tuning.
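
One caveat about these filters: `is_coding_text` below matches keywords as plain substrings, so short entries such as "go", "ts" or "r " also fire inside ordinary words ("good", "its", "for "). For the factual class this only errs on the side of excluding prompts, but if stricter matching were wanted, a whole-token regex is one option. A minimal sketch using a hypothetical subset of the keyword list; it is not what the notebook runs:

```python
import re

# Hypothetical subset of CODING_KEYWORDS, chosen to show the false-positive cases
KEYWORDS = {"python", "go", "ts", "r ", "c++", "c#", "class", "sql"}


def build_keyword_pattern(keywords):
    """Compile one case-insensitive regex matching any keyword as a whole token."""
    # Lookarounds for "not a word character" are used instead of word boundaries so
    # that keywords ending in punctuation, such as "c++" or "c#", can still match.
    # Longest keywords are tried first; strip() turns the "r " entry into just "r".
    escaped = sorted((re.escape(k.strip()) for k in keywords), key=len, reverse=True)
    return re.compile(r"(?<!\w)(?:" + "|".join(escaped) + r")(?!\w)", re.IGNORECASE)


PATTERN = build_keyword_pattern(KEYWORDS)


def substring_hit(text: str) -> bool:
    """Current notebook behaviour: any keyword occurring as a plain substring."""
    t = text.lower()
    return any(kw in t for kw in KEYWORDS)


def token_hit(text: str) -> bool:
    """Stricter variant: code fences, or a keyword occurring as a whole token."""
    if "```" in text:
        return True
    return PATTERN.search(text) is not None


for prompt in ["What is a good book for beginners?", "Write a Python function to reverse a list"]:
    print(f"{prompt!r}: substring={substring_hit(prompt)}, token={token_hit(prompt)}")
```

The first prompt is flagged as coding by the substring check (via "go" and "r ") but not by the token check; the second is flagged by both.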

In [10]:
def is_coding_text(text: str) -> bool:
    t = text.lower()
    if "```" in t:  # code fences
        return True
    # keyword hit (plain substring match, so short keywords can also fire inside longer words)
    return any(kw in t for kw in CODING_KEYWORDS)


def is_factual_text(text: str) -> bool:
    t = text.strip().lower()
    # generally a question or expository ask, without obvious coding markers
    if is_coding_text(t):
        return False
    # question-y or factual vibe
    if "?" in t or any(t.startswith(h) for h in FACTUAL_HINTS):
        return True
    # statements that ask to explain/define/etc.
    return any(h in t.split()[:3] for h in ["explain", "define", "describe", "summarize"])


def clean_text(s: str) -> str:
    # Light cleanup
    s = re.sub(r'\s+', ' ', s).strip()
    return s


In [11]:
def sample_dolly_factual(n_target: int):
    """From Dolly 15k, pick Q&A-ish factual prompts, exclude coding."""
    ds = load_dataset(FACTUAL_SOURCE, split="train")
    # Prefer Dolly categories that tend to be factual
    factual_cats = {"open_qa", "closed_qa", "information_extraction", "classification"}
    records = []
    skipped_count = 0
    skipped_ids = set()  # indices consumed by SKIP_FIRST, reused by the relaxed pass

    def make_record(i, ex, require_category):
        """Build a record if the example passes the factual filters, else return None."""
        instr = ex.get("instruction") or ""
        ctx = ex.get("context") or ""
        resp = ex.get("response") or ""
        if not " ".join([instr, ctx]).strip():
            return None
        category = (ex.get("category") or "").lower()
        if require_category and category not in factual_cats:
            return None
        if not is_factual_text(instr or ctx) or is_coding_text(instr) or is_coding_text(ctx):
            return None
        prompt = clean_text(instr if instr else ctx)
        if len(prompt) < 15:  # avoid overly short prompts
            return None
        return {
            "id": f"dolly-{i}",
            "category": "factual",
            "prompt": prompt,
            "response": clean_text(resp) if resp else "",
            "source": FACTUAL_SOURCE,
            "source_id": i,
        }

    for i, ex in enumerate(ds):
        # Skip the first SKIP_COUNT valid samples if SKIP_FIRST is enabled
        if SKIP_FIRST and skipped_count < SKIP_COUNT:
            skipped_ids.add(i)
            if make_record(i, ex, require_category=True) is not None:
                skipped_count += 1
            continue
        record = make_record(i, ex, require_category=True)
        if record is not None:
            records.append(record)

    # If we're short, relax the category constraint but keep the non-coding, factual filters.
    # Rows already selected or consumed by SKIP_FIRST are excluded so no duplicates are added.
    if len(records) < n_target:
        excluded = skipped_ids | {r["source_id"] for r in records}
        for i, ex in enumerate(ds):
            if len(records) >= n_target * 2:  # cap oversampling before sampling down
                break
            if i in excluded:
                continue
            record = make_record(i, ex, require_category=False)
            if record is not None:
                records.append(record)

    if SKIP_FIRST:
        print(f"Dolly factual: Skipped first {skipped_count} valid samples")

    random.shuffle(records)
    return records[:n_target]


def sample_codealpaca_coding(n_target: int):
    """From CodeAlpaca-20k, sample coding prompts (most are coding by construction)."""
    ds = load_dataset(CODING_SOURCE, split="train")
    candidates = []
    skipped_count = 0
    skipped_ids = set()  # indices consumed by SKIP_FIRST, reused by the relaxed pass

    def make_record(i, ex):
        """Build a record from a CodeAlpaca example (the caller ensures an instruction exists)."""
        instr = ex.get("instruction") or ""
        output = ex.get("output") or ex.get("response") or ""
        return {
            "id": f"codealpaca-{i}",
            "category": "coding",
            "prompt": clean_text(instr),
            "response": clean_text(output) if output else "",
            "source": CODING_SOURCE,
            "source_id": i,
        }

    for i, ex in enumerate(ds):
        instr = ex.get("instruction") or ""
        # Skip the first SKIP_COUNT valid coding samples if SKIP_FIRST is enabled
        if SKIP_FIRST and skipped_count < SKIP_COUNT:
            skipped_ids.add(i)
            if instr and is_coding_text(instr):
                skipped_count += 1
            continue
        if instr and is_coding_text(instr):
            candidates.append(make_record(i, ex))

    # If not enough matched by keywords, backfill with any remaining examples (CodeAlpaca
    # is coding-oriented anyway). Rows already selected or consumed by SKIP_FIRST are
    # excluded, so the backfill neither duplicates rows nor reuses skipped data.
    if len(candidates) < n_target:
        excluded = skipped_ids | {c["source_id"] for c in candidates}
        for i, ex in enumerate(ds):
            if len(candidates) >= n_target * 2:  # cap oversampling before sampling down
                break
            if i in excluded or not ex.get("instruction"):
                continue
            candidates.append(make_record(i, ex))

    if SKIP_FIRST:
        print(f"CodeAlpaca coding: Skipped first {skipped_count} valid samples")

    random.shuffle(candidates)
    return candidates[:n_target]


In [12]:
def create_dataset():
    print("Loading & sampling…")
    if SKIP_FIRST:
        print(f"SKIP_FIRST enabled: Skipping first {SKIP_COUNT} valid samples from each source")
    else:
        print("SKIP_FIRST disabled: Using samples from the beginning")
    
    factual = sample_dolly_factual(FACTUAL_TARGET)
    print(f"Factual collected: {len(factual)}")

    coding = sample_codealpaca_coding(CODING_TARGET)
    print(f"Coding collected: {len(coding)}")

    combined = factual + coding
    random.shuffle(combined)

    # Update output filenames to indicate if skip was used
    if SKIP_FIRST:
        out_jsonl = Path(f"prompts_{TOTAL}_skip{SKIP_COUNT}.jsonl")
        out_csv = Path(f"prompts_{TOTAL}_skip{SKIP_COUNT}.csv")
    else:
        out_jsonl = Path(f"prompts_{TOTAL}.jsonl")
        out_csv = Path(f"prompts_{TOTAL}.csv")

    print(f"Writing {out_jsonl} …")
    with out_jsonl.open("w", encoding="utf-8") as f:
        for row in combined:
            f.write(json.dumps(row, ensure_ascii=False) + "\n")

    print(f"Writing {out_csv} …")
    with out_csv.open("w", encoding="utf-8", newline="") as f:
        w = csv.DictWriter(f, fieldnames=["id","category","prompt","response","source","source_id"])
        w.writeheader()
        for row in combined:
            w.writerow(row)

    # Simple stats
    def avg_len(key):
        xs = [len(r[key].split()) for r in combined if r[key]]
        return round(sum(xs)/len(xs), 1) if xs else 0.0

    print("Done!")
    print(f"Total: {len(combined)} (factual={len(factual)}, coding={len(coding)})")
    if SKIP_FIRST:
        print(f"Dataset created with skip_first={SKIP_COUNT} - contains fresh data for testing")
    print(f"Avg prompt length (words): {avg_len('prompt')}")
    print(f"Avg response length (words): {avg_len('response')}")
    print(f"Output files: {out_jsonl.name}, {out_csv.name}")


In [13]:
create_dataset()

Loading & sampling…
SKIP_FIRST disabled: Using samples from the beginning
Factual collected: 1000
Coding collected: 1000
Writing prompts_2000.jsonl …
Writing prompts_2000.csv …
Done!
Total: 2000 (factual=1000, coding=1000)
Avg prompt length (words): 10.9
Avg response length (words): 38.9
Output files: prompts_2000.jsonl, prompts_2000.csv
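
Because the factual and coding sides are filtered and backfilled independently, a quick check on the written file is a cheap way to confirm the class balance and catch duplicated prompts. A minimal sketch; `summarize_rows` and `check_dataset` are hypothetical helpers, not part of the notebook:

```python
import json
from collections import Counter
from pathlib import Path


def summarize_rows(rows):
    """Count rows per category and report duplicated ids and prompts."""
    ids = [r["id"] for r in rows]
    prompts = [r["prompt"] for r in rows]
    return {
        "total": len(rows),
        "per_category": dict(Counter(r["category"] for r in rows)),
        "duplicate_ids": len(ids) - len(set(ids)),
        "duplicate_prompts": len(prompts) - len(set(prompts)),
    }


def check_dataset(path):
    """Re-read a JSONL file produced by create_dataset() and summarize it."""
    with Path(path).open(encoding="utf-8") as f:
        rows = [json.loads(line) for line in f if line.strip()]
    return summarize_rows(rows)


out = Path("prompts_2000.jsonl")
if out.exists():
    print(check_dataset(out))
```

A non-zero `duplicate_prompts` count does not necessarily mean a bug, since a source dataset can contain the same instruction under different ids, but it is worth inspecting before training.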


# Dataset Creation for Router Training (For Part Three)

To train a BERT-like model for routing, we need prompts paired with labels that indicate whether each prompt should be routed to a weak model or a strong model.
We start by taking all 2000 previously generated prompts and calling Llama 3.1 8B (our weak model) to generate a response for each.
We use OpenRouter.ai for generation because it is inexpensive; some API calls may fail, but a few retries are enough.

To do so, we use the following script.  
(Apologies for not doing this inside the notebook. Most of my work was done in scripts because that is what I am comfortable with, but if the Nexos team prefers notebooks only, I can adapt to that on the job.)

Note: this part of the notebook requires an API key for OpenRouter.ai. Place it in a `.env` file (see `.env.example`).
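
The usual way to load such a file from Python is the `python-dotenv` package (`from dotenv import load_dotenv`). For illustration, a stdlib-only sketch of roughly what that amounts to; `parse_env_lines` and `load_env_file` are hypothetical helpers, and the variable name `OPENROUTER_API_KEY` is an assumption, so check `.env.example` for the name the scripts actually read:

```python
import os
from pathlib import Path


def parse_env_lines(lines):
    """Parse simple KEY=VALUE lines, skipping blanks, # comments and malformed lines."""
    values = {}
    for raw in lines:
        line = raw.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue
        key, _, value = line.partition("=")
        key = key.strip()
        if key.startswith("export "):
            key = key[len("export "):].strip()
        # Drop one layer of matching quotes, e.g. KEY="value" or KEY='value'
        value = value.strip()
        if len(value) >= 2 and value[0] == value[-1] and value[0] in "\"'":
            value = value[1:-1]
        values[key] = value
    return values


def load_env_file(path=".env"):
    """Export variables from `path` without overriding ones that are already set."""
    env_path = Path(path)
    if not env_path.exists():
        return {}
    values = parse_env_lines(env_path.read_text(encoding="utf-8").splitlines())
    for key, value in values.items():
        os.environ.setdefault(key, value)
    return values


load_env_file()
# Hypothetical variable name; see .env.example for the real one
print("API key found" if os.environ.get("OPENROUTER_API_KEY") else "No API key set")
```

This parser deliberately ignores multi-line values and variable expansion; for anything beyond simple `KEY=VALUE` files, the library is the safer choice.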

In [9]:
!python generate_responses.py --help

usage: generate_responses.py [-h] [--input INPUT] [--model MODEL]
                             [--output OUTPUT] [--max-tokens MAX_TOKENS]
                             [--temperature TEMPERATURE]
                             [--max-concurrent MAX_CONCURRENT] [--limit LIMIT]

Generate responses for dataset using OpenRouter API (Async)

options:
  -h, --help            show this help message and exit
  --input INPUT, -i INPUT
                        Input dataset file (default: prompts_2000.jsonl)
  --model MODEL, -m MODEL
                        Model to use for generation (default: meta-
                        llama/llama-3.1-8b-instruct)
  --output OUTPUT, -o OUTPUT
                        Output file prefix (default: auto-generated based on
                        model and timestamp)
  --max-tokens MAX_TOKENS
                        Maximum tokens to generate (default: 1000)
  --temperature TEMPERATURE
                        Temperature for generation (default: 0.7)
  --max-concur

The following command adds the weak model's response to every prompt in the dataset:

In [None]:
!python generate_responses.py --max-concurrent 5 --input prompts_2000.jsonl
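
The script itself is not reproduced here. As a rough illustration of the pattern that `--max-concurrent` and the retries mentioned above imply, here is a minimal asyncio sketch of bounded concurrency with exponential-backoff retries. `call_model` stands in for the actual OpenRouter request and `make_flaky_model` is a hypothetical fake used only for the demo; this is not the code of `generate_responses.py`. It is written as a standalone script; in a notebook cell, `await generate_all(...)` directly instead of calling `asyncio.run`, which fails inside an already running event loop.

```python
import asyncio
import random


async def with_retries(make_call, max_retries=3, base_delay=1.0):
    """Await make_call(), retrying on any exception with exponential backoff plus jitter."""
    for attempt in range(max_retries + 1):
        try:
            return await make_call()
        except Exception:
            if attempt == max_retries:
                raise
            await asyncio.sleep(base_delay * 2 ** attempt + random.uniform(0, base_delay))


async def generate_all(prompts, call_model, max_concurrent=5, max_retries=3, base_delay=1.0):
    """Run call_model(prompt) for every prompt with at most max_concurrent calls in flight.

    Prompts that still fail after all retries get None instead of aborting the batch.
    """
    semaphore = asyncio.Semaphore(max_concurrent)

    async def one(prompt):
        async with semaphore:
            try:
                return await with_retries(lambda: call_model(prompt), max_retries, base_delay)
            except Exception:
                return None

    # gather preserves input order, so results line up with prompts
    return await asyncio.gather(*(one(p) for p in prompts))


def make_flaky_model(failures_per_prompt):
    """Fake model for the demo: each prompt fails `failures_per_prompt` times, then succeeds."""
    attempts = {}

    async def call_model(prompt):
        attempts[prompt] = attempts.get(prompt, 0) + 1
        if attempts[prompt] <= failures_per_prompt:
            raise RuntimeError("transient error")
        return f"response to: {prompt}"

    return call_model


if __name__ == "__main__":
    prompts = ["What is the capital of France?", "Write a Python function to add two numbers"]
    print(asyncio.run(generate_all(prompts, make_flaky_model(1), base_delay=0.1)))
```

Returning `None` for exhausted prompts keeps one persistent failure from discarding the rest of the batch; those rows can be re-run later.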

Once the responses are generated, we run them through a stronger model to produce ratings.  
To generate a rating, we present the prompt and the weak model's response to the stronger model and ask it to rate the response on a scale of 1 to 5.  
For the strong model we use `openai/gpt-oss-120b`, as it substantially outperforms our weak model, Llama 3.1 8B, on the majority of benchmarks and also does well on SWE-bench. We assume this makes it a good judge.


The following is the prompt template we use:

```
"""You are an expert evaluator tasked with rating the quality of AI model responses. 

Please rate the following response on a scale from 1 to 5 based on these criteria:
- **Accuracy**: Is the information correct?
- **Relevance**: Does it directly address the user's question/request?
- **Completeness**: Does it provide a thorough answer?
- **Clarity**: Is it well-written and easy to understand?
- **Helpfulness**: Would this response be useful to the user?

Rating Scale:
- **1**: Very Poor - Incorrect, irrelevant, or unhelpful
- **2**: Poor - Mostly incorrect or not very helpful
- **3**: Average - Somewhat helpful but has issues
- **4**: Good - Helpful and mostly accurate
- **5**: Excellent - Highly accurate, relevant, and helpful

**Original User Prompt:**
{original_prompt}

**Model Response to Rate:**
{model_response}

Please provide only a single number (1, 2, 3, 4, or 5) as your rating, followed by a brief explanation in parentheses.

Rating:"""
```
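
Filling the template and reading the score back are both small string operations. A minimal sketch, assuming the template is filled with `str.format` and that the judge follows the "number first" instruction; `build_rating_prompt` and `parse_rating` are hypothetical names, not functions from `rate_responses.py`, and the template below is abbreviated:

```python
import re

# Abbreviated stand-in for the full template above; only the placeholders and ending matter here
RATING_TEMPLATE = """You are an expert evaluator tasked with rating the quality of AI model responses.

**Original User Prompt:**
{original_prompt}

**Model Response to Rate:**
{model_response}

Please provide only a single number (1, 2, 3, 4, or 5) as your rating, followed by a brief explanation in parentheses.

Rating:"""


def build_rating_prompt(original_prompt: str, model_response: str) -> str:
    """Insert the user prompt and the weak model's response into the judge template."""
    return RATING_TEMPLATE.format(original_prompt=original_prompt, model_response=model_response)


def parse_rating(judge_output: str):
    """Return the first standalone digit 1-5 in the judge's reply, or None if there is none."""
    match = re.search(r"(?<!\d)([1-5])(?!\d)", judge_output)
    return int(match.group(1)) if match else None


print(parse_rating("4 (accurate but misses an edge case)"))
```

Returning `None` instead of guessing lets unparseable replies be retried or dropped, which matters because judge models do not always follow the requested output format.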

Ratings are generated by `rate_responses.py`, whose interface is:

In [10]:
!python rate_responses.py --help

usage: rate_responses.py [-h] [--input INPUT] [--rating-model RATING_MODEL]
                         [--output OUTPUT] [--max-concurrent MAX_CONCURRENT]
                         [--max-retries MAX_RETRIES] [--limit LIMIT]

Rate responses in enhanced dataset using OpenRouter API (Async)

options:
  -h, --help            show this help message and exit
  --input INPUT, -i INPUT
                        Input enhanced dataset file
  --rating-model RATING_MODEL, -r RATING_MODEL
                        Model to use for rating (default: openai/gpt-oss-120b)
  --output OUTPUT, -o OUTPUT
                        Output file prefix (default: auto-generated based on
                        input and timestamp)
  --max-concurrent MAX_CONCURRENT
                        Maximum concurrent requests (default: 15)
  --max-retries MAX_RETRIES
                        Maximum retries per failed request (default: 3)
  --limit LIMIT         Limit processing to first N items (for testing)


In [None]:
!python rate_responses.py --input rated_enhanced_dataset_meta_llama_llama_3.1_8b_instruct_20250811_140130_20250811_151904.jsonl --max-retries 5 --max-concurrent 5
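
How the ratings become routing labels is not specified in this section. One plausible scheme, shown only as a hypothetical sketch: treat a weak-model response rated at or above a threshold as good enough (route to the weak model) and anything below it as needing the strong model. The threshold of 4, the field name `rating`, and the label names `weak`/`strong` are all assumptions:

```python
from collections import Counter


def rating_to_label(rating, threshold=4):
    """Map a 1-5 judge rating to a routing label; a missing rating stays None."""
    if rating is None:
        return None
    return "weak" if rating >= threshold else "strong"


def label_rows(rows, threshold=4):
    """Attach a hypothetical 'route' label to each row, dropping rows without a rating."""
    labeled = []
    for row in rows:
        label = rating_to_label(row.get("rating"), threshold)
        if label is not None:
            labeled.append({**row, "route": label})
    return labeled


rows = [
    {"prompt": "What is the capital of France?", "rating": 5},
    {"prompt": "Implement a red-black tree in C++", "rating": 2},
    {"prompt": "Explain photosynthesis", "rating": None},
]
print(Counter(r["route"] for r in label_rows(rows)))
```

Moving the threshold trades cost against quality (a higher threshold sends more traffic to the strong model), so checking the resulting class balance before training is worthwhile.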