# Construct training data

## Construct preference DPO data

In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Build ONE DPO pair per id based on reward scores.

For each id:
  - good_answer = generation with the highest reward for that id (must be > GOOD_THRESHOLD)
  - bad_answer  = generation with the lowest  reward for that id (must be < BAD_THRESHOLD)

Optionally skip zero-reward generations (SKIP_REWARD_ZERO=True).
No similarity filtering is applied.

Input
-----
Expects JSONL files where each line is a dict containing:
  - "id"
  - "premises" / "premise" / "context" (any of these)
  - "question"
  - "generations": list[dict], each with:
        - "steps" / "final_tag" / "answer" (reasoning + final label)
        - "reward" (float)
    OR a top-level "reward_score": list[float] aligned with "generations"
    (preferred over the per-generation "reward" when its length matches).

Configurable via `input_files`, thresholds, etc.

Output
------
JSONL at `output_path`, one line per pair:

{
  "id": str,
  "question": str,        # collapsed prompt (system + user + assistant prefix)
  "good_answer": str,     # formatted reasoning + final answer
  "bad_answer": str,      # formatted reasoning + final answer
  "reward": [good, bad]   # [float, float]
}
"""

import json
import math
import re
from collections import defaultdict
from pathlib import Path
from typing import List, Dict, Any, Tuple

# ---------------------- Configure here ---------------------- #
# JSONL files to merge. Records sharing an "id" across files have their
# generations pooled; files that do not exist are skipped with a warning.
input_files = [
    "./reward_data/Multi-Thread/output/rewarded_data.jsonl",
    "./refine_data/Multi-Thread/output/refined_data.jsonl",
]

# When True, generations whose reward is exactly 0.0 are dropped during normalization.
SKIP_REWARD_ZERO = True
# The best generation of an id must score strictly above this value.
GOOD_THRESHOLD = 0.5
# The worst generation of an id must score strictly below this value.
BAD_THRESHOLD = 0.5

# Destination JSONL file; missing parent directories are created before writing.
output_path = "./dpo_data.jsonl"
# ----------------------------------------------------------- #


def read_jsonl(path: str) -> List[Dict[str, Any]]:
    """Load a JSONL file into a list of dicts, skipping unusable lines.

    Args:
        path: Location of the JSONL file.

    Returns:
        The decoded JSON objects in file order. A missing file yields an
        empty list (with a warning). Blank lines are ignored silently; lines
        that are not valid JSON, or whose JSON value is not an object, are
        skipped with a warning.
    """
    recs: List[Dict[str, Any]] = []
    p = Path(path)
    if not p.exists():
        print(f"[WARN] Skipping missing file: {path}")
        return recs
    # "utf-8-sig" reads plain UTF-8 unchanged but also swallows a leading BOM,
    # which would otherwise make the first line fail to parse.
    with p.open("r", encoding="utf-8-sig") as f:
        for line_no, line in enumerate(f, 1):
            line = line.strip()
            if not line:
                continue
            try:
                rec = json.loads(line)
            except json.JSONDecodeError as e:
                print(f"[WARN] Skipping bad JSON line in {path}: {e}")
                continue
            if not isinstance(rec, dict):
                # Callers use rec.get(...), so arrays and scalars cannot be records.
                print(f"[WARN] Skipping non-object JSON value at {path}:{line_no}")
                continue
            recs.append(rec)
    print(f"[OK] Loaded {len(recs)} records from {path}")
    return recs


# ---------- tiny text helpers ---------- #
def flatten_tokens(x: Any) -> List[str]:
    """Flatten an arbitrarily nested list into a flat list of strings.

    Lists are walked depth-first, left to right. String leaves are kept
    as-is; every other leaf (including tuples and None) is converted with
    ``str``.
    """
    if not isinstance(x, list):
        # Leaf: wrap it so callers always receive a list.
        return [x if isinstance(x, str) else str(x)]
    return [token for item in x for token in flatten_tokens(item)]


def to_text(field: Any) -> str:
    """Render a (possibly nested) token structure as one tidy string.

    Tokens are flattened and joined with single spaces, whitespace directly
    before ``. , ! ? ; :`` is removed, and any remaining run of two or more
    whitespace characters is collapsed to one space.
    """
    joined = " ".join(flatten_tokens(field)).strip()
    tight_punctuation = re.sub(r"\s+([.,!?;:])", r"\1", joined)
    return re.sub(r"\s{2,}", " ", tight_punctuation)


def canon(s: str) -> str:
    """Return a dedup key for *s*: trimmed, whitespace-collapsed, lower-cased.

    Falsy input (``None`` or an empty string) yields an empty string.
    """
    text = s if s else ""
    squeezed = re.sub(r"\s+", " ", text.strip())
    return squeezed.lower()


# ---------------------- prompt building ---------------------- #
def answer_instruction_for_id(example_id: str) -> str:
    """Return the dataset-specific answer instruction for an example id.

    The dataset is taken to be the text before the first underscore of
    ``str(example_id)``, compared case-insensitively. Unknown datasets yield
    an empty string.
    """
    true_false_unknown_abc = (
        "Is the question A) True; B) False or C) Unknown based on the premises? "
        "Answer with exactly one label: A, B, or C."
    )
    five_options = "Answer with exactly one label: A, B, C, D, or E."

    instructions = {
        "folio": true_false_unknown_abc,
        "esnli": (
            "Is the question entailment/contradiction/neutral based on the premises? "
            "Answer with exactly one label: entailment, contradiction, or neutral."
        ),
        "multilogieval": "Answer with exactly one label: yes or no.",
        "proofwriter": (
            "Is the question True/False/Unknown based on the premises? "
            "Answer with exactly one label: True, False, or Unknown."
        ),
        "prontoqa": (
            "Is the question True/false based on the premises? "
            "Answer with exactly one label: True or False."
        ),
        "logiqa": (
            "Is the question entailment/not-entailment based on the premises? "
            "Answer with exactly one label: entailment or not-entailment."
        ),
        "proverqa": true_false_unknown_abc,
        "qasc": "Answer with exactly one label: A, B, C, D, E, F, G, or H.",
        "ar": five_options,
        # NOTE: this key is never matched, because the lookup key below stops
        # at the first underscore; "ar_lsat_*" ids resolve through "ar".
        "ar_lsat": five_options,
    }

    dataset, _, _ = str(example_id).partition("_")
    return instructions.get(dataset.lower(), "")


def build_messages(premise: str, question: str, example_id: str):
    """Build the two-message chat prompt (system, then user) for one example.

    The system message fixes the step-by-step output format and carries one
    worked example. The user message holds the premise and the question; when
    the id maps to a dataset-specific answer instruction, that instruction is
    appended to the question after a blank line.

    Returns:
        A list of two dicts with "role" and "content" keys.
    """
    task_description = (
        "You are a careful reasoner. Read the premise(s) and the question, then think "
        "step-by-step using numbered steps. For EACH step, write three lines exactly in "
        "this order and wording:\n"
        "Premise:\n"
        "Assumption:\n"
        "Conclusion:\n\n"
    )
    formatting_rules = (
        "Formatting rules:\n"
        "- Title each step as 'Step N:' where N starts at 1 and increases by 1.\n"
        "- The Premise of a step must be either (i) one of the given premises OR (ii) a Conclusion from any previous step.\n"
        "- The Assumption must be a commonsense or contextually reasonable assumption that would make sense to a human.\n"
        "- The Conclusion must be new information that logically follows from the Premise and the Assumption.\n"
        "- After you finish all steps, output a final line: 'Final answer: [xxx]'.\n"
        "- If the question has choices A), B), C) ..., put ONLY the option letter inside the brackets (e.g., [A]).\n"
        "- Otherwise, put ONLY the single required label inside the brackets (e.g., [True], [False], [entailment], [contradiction], [neutral], etc.).\n"
        "- Do NOT use XML tags. Do NOT add extra commentary before or after the steps or the final line.\n\n"
    )
    worked_example = (
        "-----\n"
        "Below is an example:\n"
        "Premise: Harry read a book. People who read book will be smart.\n"
        "Question: Will Harry be smart? Answer with true/false/unknown.\n"
        "Step 1:\n"
        "Premise: Harry read a book. People who read book will be smart.\n"
        "Assumption: Harry is a person. A person is people.\n"
        "Conclusion: Harry will be smart.\n\n"
        "Final answer: [True]\n"
        "End of example\n"
        "-----"
    )
    system_prompt = task_description + formatting_rules + worked_example

    # Append the label constraint only for datasets that define one.
    constraint = answer_instruction_for_id(example_id)
    if constraint:
        question_block = f"{question}\n\n{constraint}"
    else:
        question_block = question
    user_prompt = f"Premise:\n{premise}\n\nQuestion:\n{question_block}\n"

    return [
        {"role": role, "content": content}
        for role, content in (("system", system_prompt), ("user", user_prompt))
    ]


def collapse_messages_to_prompt(msgs: List[Dict[str, str]]) -> str:
    """Flatten a (system, user) message pair into one plain-text prompt.

    Each message body is stripped and placed under a line naming its role;
    the prompt ends with an ``assistant`` line followed by a newline so that
    the model's completion starts right after it.
    """
    system_body = msgs[0]["content"].strip()
    user_body = msgs[1]["content"].strip()
    # The trailing empty item yields the final newline after "assistant".
    return "\n".join(["system", system_body, "user", user_body, "assistant", ""])


# ---------------------- generation formatting ---------------------- #
def format_generation(g: Dict[str, Any]) -> str:
    """Render one generation as the text used for a DPO answer.

    The result is an optional ``Reasoning:`` section holding one
    ``Step N:`` / ``Premise:`` / ``Assumption:`` / ``Conclusion:`` block per
    entry of ``g["steps"]`` (blocks separated by a blank line), followed by a
    ``Final answer: [label]`` line. The label comes from ``final_tag``,
    falling back to ``answer``, then to an empty string.
    """
    label = g.get("final_tag") or g.get("answer") or ""

    blocks = []
    for number, step in enumerate(g.get("steps") or [], start=1):
        premise_txt = to_text(step.get("premises") or "")
        assumption_txt = to_text(step.get("assumptions") or "")
        conclusion_txt = to_text(step.get("conclusion") or "")
        block = (
            f"Step {number}:\n"
            f"Premise: {premise_txt}\n"
            f"Assumption: {assumption_txt}\n"
            f"Conclusion: {conclusion_txt}"
        )
        # strip() also drops the trailing space left by an empty conclusion.
        blocks.append(block.strip())

    body = "\n\n".join(blocks).strip()
    if body:
        header = f"Reasoning:\n{body}"
    else:
        header = ""
    # With no reasoning, the leading newline is removed by the final strip().
    return f"{header}\nFinal answer: [{str(label).strip()}]".strip()


# ---------------------- normalize inputs ---------------------- #
def normalize_generations(rec: Dict[str, Any]) -> Tuple[str, str, List[Dict[str, Any]]]:
    """Normalize one input record into prompt text plus scored generations.

    Args:
        rec: A decoded JSONL record. The premise is read from "premises",
            "premise" or "context" (first truthy one), the question from
            "question", and candidate answers from "generations".

    Returns:
        ``(premise, question, generations)`` where each generation is a dict
        ``{"formatted": str, "reward": float, "raw": <original generation>}``.

    Rewards come from the top-level "reward_score" list when it has exactly
    one entry per generation, otherwise from each generation's own "reward".
    A generation is dropped when it is not a dict, when its reward is
    missing, not convertible to float, or not finite (NaN / +-inf), and,
    if SKIP_REWARD_ZERO is set, when its reward is exactly 0.0.
    """
    premise_raw = rec.get("premises") or rec.get("premise") or rec.get("context")
    question_raw = rec.get("question") or ""

    premise = to_text(premise_raw or "")
    question = to_text(question_raw)
    gens = rec.get("generations") or []

    rewards = rec.get("reward_score")
    use_top_level = isinstance(rewards, list) and len(rewards) == len(gens)

    out: List[Dict[str, Any]] = []
    for idx, g in enumerate(gens):
        if not isinstance(g, dict):
            # Every consumer reads generations with .get(); anything else is unusable.
            continue

        raw_reward = rewards[idx] if use_top_level else g.get("reward")
        if raw_reward is None:
            continue
        try:
            r = float(raw_reward)
        except (TypeError, ValueError, OverflowError):
            continue

        # NaN compares false against everything, which would make the later
        # best/worst selection order-dependent, and NaN/Infinity are not
        # valid JSON when the pair is written out.
        if not math.isfinite(r):
            continue
        if SKIP_REWARD_ZERO and r == 0.0:
            continue

        out.append({"formatted": format_generation(g), "reward": r, "raw": g})

    return premise, question, out


# ---------------------- BEST-vs-WORST with thresholds ---------------------- #
def build_pairs(records_by_id: Dict[str, Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Build at most ONE best-vs-worst DPO pair per id.

    Args:
        records_by_id: Maps an id to a bundle with "premise", "raw_question"
            and "generations" (dicts carrying "formatted" and "reward").

    Returns:
        One dict per qualifying id with keys "id", "question" (the collapsed
        prompt), "good_answer", "bad_answer", "reward_good", "reward_bad".

    For each id:
      - good_answer = generation with the highest reward
                      (must be > GOOD_THRESHOLD)
      - bad_answer  = generation with the lowest reward
                      (must be < BAD_THRESHOLD)

    Ids are skipped when they:
      - have fewer than 2 generations with a numeric reward, or
      - do not meet the threshold constraints, or
      - would pair a generation with one that does not score strictly lower
        (only possible if GOOD_THRESHOLD is configured below BAD_THRESHOLD).

    A one-line summary of processed / qualifying ids is printed.
    """
    pairs: List[Dict[str, Any]] = []
    total_ids = 0
    with_pairs = 0

    for _id, bundle in records_by_id.items():
        total_ids += 1

        gens = [
            g
            for g in bundle.get("generations", [])
            if isinstance(g.get("reward"), (int, float))
        ]
        if len(gens) < 2:
            continue

        g_best = max(gens, key=lambda x: x["reward"])
        g_worst = min(gens, key=lambda x: x["reward"])

        # Enforce thresholds
        if g_best["reward"] <= GOOD_THRESHOLD:
            continue
        if g_worst["reward"] >= BAD_THRESHOLD:
            continue
        # With overlapping thresholds, best and worst can be the same
        # generation (or tie); such a pair carries no preference signal.
        if g_best["reward"] <= g_worst["reward"]:
            continue

        # The prompt is only needed for ids that actually yield a pair, so it
        # is built after all the cheap rejections above.
        premise_text = to_text(bundle.get("premise", ""))
        question_text = to_text(bundle.get("raw_question", ""))
        question = collapse_messages_to_prompt(build_messages(premise_text, question_text, _id))

        pairs.append(
            {
                "id": _id,
                "question": question,
                "good_answer": g_best["formatted"],
                "bad_answer": g_worst["formatted"],
                "reward_good": float(g_best["reward"]),
                "reward_bad": float(g_worst["reward"]),
            }
        )
        with_pairs += 1

    print(
        f"[INFO] Processed ids: {total_ids} | "
        f"ids with a qualifying best-vs-worst pair: {with_pairs}"
    )
    return pairs


# ---------------------- Load & Merge ---------------------- #
# Concatenate the records of every configured input file, in list order.
all_records: List[Dict[str, Any]] = []
for f in input_files:
    all_records += read_jsonl(f)


def _new_bundle() -> Dict[str, Any]:
    """Return the empty per-id accumulator used while merging records."""
    return {"premise": "", "raw_question": "", "generations": []}


by_id: Dict[str, Dict[str, Any]] = defaultdict(_new_bundle)

for rec in all_records:
    _id = rec.get("id")
    if _id is None:
        continue
    premise, question, gens = normalize_generations(rec)
    bundle = by_id[_id]

    # Best guess when sources disagree: keep the longest premise / question
    # seen for this id (on a tie the earlier value stays).
    if len(bundle["premise"]) < len(premise):
        bundle["premise"] = premise
    if len(bundle["raw_question"]) < len(question):
        bundle["raw_question"] = question

    bundle["generations"] += gens

print(f"[INFO] Unique ids merged: {len(by_id)}")


# ---------------------- Build ---------------------- #
pairs = build_pairs(by_id)


# ---------------------- Deduplicate ---------------------- #
def dedupe_instances(pairs: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Drop repeated pairs, keeping the first occurrence and input order.

    Two pairs count as the same when they share the id and their question,
    good answer and bad answer are equal after ``canon`` normalization.
    """
    unique: List[Dict[str, Any]] = []
    fingerprints = set()
    for pair in pairs:
        fingerprint = (pair["id"],) + tuple(
            canon(pair[field]) for field in ("question", "good_answer", "bad_answer")
        )
        if fingerprint not in fingerprints:
            fingerprints.add(fingerprint)
            unique.append(pair)
    return unique


pairs = dedupe_instances(pairs)
print(f"[DEDUP] Total pairs to write: {len(pairs)}")


# ---------------------- Write ---------------------- #
# One JSON object per line; the two rewards are folded into a [good, bad] list.
out_file = Path(output_path)
out_file.parent.mkdir(parents=True, exist_ok=True)
with out_file.open("w", encoding="utf-8") as f:
    for p in pairs:
        row = {
            "id": p["id"],
            "question": p["question"],
            "good_answer": p["good_answer"],
            "bad_answer": p["bad_answer"],
            "reward": [p["reward_good"], p["reward_bad"]],
        }
        f.write(json.dumps(row, ensure_ascii=False) + "\n")

print(f"[DONE] Wrote pairs to: {output_path}")
