## Merge OpenAI and Qwen rollouts

In [None]:
## Merge openai and qwen

# Concatenate the two rollout JSONL files into one output file, keeping the
# order (file1 first, then file2) and skipping whitespace-only lines.
file1 = "./openai_rollout_processed.jsonl"   # first file
file2 = "./qwen_rollout.jsonl"     # second file
out_file = "./rollout_full.jsonl"  # output

count = 0
sources = (file1, file2)
with open(out_file, "w", encoding="utf-8") as fout:
    for src in sources:
        with open(src, "r", encoding="utf-8") as fin:
            # Lazily yield each non-blank line without its trailing newline,
            # so every written record ends in exactly one "\n".
            records = (ln.rstrip("\n") for ln in fin if ln.strip())
            for record in records:
                fout.write(record + "\n")
                count += 1

print(f"Done. Wrote {count} total records to {out_file}")


## Post-process rollouts for the later reward step

In [None]:
#!/usr/bin/env python3
"""
Parse step-by-step reasoning generations into a structured JSONL format.

Input
-----
Each line in the input JSONL is expected to be either:
  - A dict with keys:
      "id", "premise", "question", "answer",
      "generation" (list[str]) or "generations" (list[str])
  - A list of such dicts (less common)

Each `generation` is a text blob that (ideally) follows a pattern like:

    Step 1:
    Premises:
    - ...
    Assumptions:
    - ...
    Conclusion:
    - ...

    Step 2:
    ...

    Final:
    <answer>LABEL</answer>

Output
------
For each input record, we write one line to the output JSONL with the schema:

{
  "id": ...,
  "premise": ...,
  "question": ...,
  "answer": ...,
  "extracted_answer": [... or None],
  "generations": [
    {
      "final_tag": str | null,
      "step_order_ok": bool,
      "step_numbers": [int],
      "missing_steps": [int],
      "duplicate_steps": [int],
      "steps": [
        {
          "n": int,
          "premises": [str],
          "assumptions": [str],
          "conclusion": str
        },
        ...
      ]
    },
    ...
  ]
}
"""

import json
import re
from typing import List, Dict, Any, Tuple, Optional

# ---------------------------------------------------------------------------
# Step / section parsing
# ---------------------------------------------------------------------------

# A step header such as "Step 3:" (colon optional) that starts a line and is
# followed by a newline; group 1 captures the step number.
STEP_RE = re.compile(r"(?:^|\n)Step\s+(\d+)\s*:?\s*\n", re.IGNORECASE)

# The "Final" marker (colon optional) at the start of a line. The \b keeps
# ordinary words that merely begin with "Final" (e.g. "Finally, ...") from
# being mistaken for the marker.
FINAL_RE = re.compile(r"(?:^|\n)Final\b\s*:?", re.IGNORECASE)

# Section headers inside one step. Each may be indented; the singular forms
# "Premise:" and "Assumption:" are accepted as well.
SECTION_PATTERNS = {
    "premises": re.compile(r"(?:^|\n)\s*Premises?\s*:\s*", re.IGNORECASE),
    "assumptions": re.compile(r"(?:^|\n)\s*Assumptions?\s*:\s*", re.IGNORECASE),
    "conclusion": re.compile(r"(?:^|\n)\s*Conclusion\s*:\s*", re.IGNORECASE),
}

# <answer>...</answer> tag, tolerant of whitespace inside the angle brackets;
# group 1 is the (possibly multi-line) content with surrounding whitespace
# left outside the group.
ANSWER_TAG_RE = re.compile(
    r"<\s*answer\s*>\s*(.*?)\s*<\s*/\s*answer\s*>",
    re.IGNORECASE | re.DOTALL,
)


def _clean_bullets(block: str) -> List[str]:
    """
    Turn a text block into a list of cleaned lines:
    - strip leading/trailing whitespace
    - remove leading bullet characters like "- " or "• "
    - drop empty lines
    """
    items: List[str] = []
    for line in block.splitlines():
        line = line.strip()
        if not line:
            continue
        # Remove leading bullet characters like "- " or "• "
        line = re.sub(r"^[\-\u2022]\s*", "", line)
        items.append(line)
    return items


def _extract_sections(step_text: str) -> Dict[str, Any]:
    """
    Split the text of one step into its sections.

    The first occurrence of each header in SECTION_PATTERNS is located; the
    body of a section is the text between the end of its header and the start
    of the next located header (or the end of `step_text`).

    Returns a dict with keys:
      - "premises": List[str]     (cleaned lines)
      - "assumptions": List[str]  (cleaned lines)
      - "conclusion": str         (cleaned lines joined with single spaces)

    Sections whose header is absent yield [] or "" respectively.
    """
    # (section name, header start, header end) for every header present
    hits = []
    for name, pattern in SECTION_PATTERNS.items():
        match = pattern.search(step_text)
        if match is not None:
            hits.append((name, match.start(), match.end()))

    if not hits:
        return {"premises": [], "assumptions": [], "conclusion": ""}

    # Order headers by where they appear in the text
    hits.sort(key=lambda hit: hit[1])

    # Each body stops where the following header starts; the last one runs
    # to the end of the step text.
    stops = [hit[1] for hit in hits[1:]] + [len(step_text)]
    bodies: Dict[str, str] = {
        name: step_text[header_end:stop].strip()
        for (name, _header_start, header_end), stop in zip(hits, stops)
    }

    # An absent section is treated as an empty body, which cleans to [].
    conclusion_parts = _clean_bullets(bodies.get("conclusion", ""))
    return {
        "premises": _clean_bullets(bodies.get("premises", "")),
        "assumptions": _clean_bullets(bodies.get("assumptions", "")),
        "conclusion": " ".join(conclusion_parts).strip(),
    }


def _find_step_blocks(text: str) -> List[Tuple[int, str]]:
    """
    Return a list of (step_number, step_text_block) in order of appearance.

    The text is split on 'Step N:' headers (STEP_RE). Every block runs from
    the end of its header to the start of the next header; the last block
    runs to the first 'Final' marker (FINAL_RE) found after the last header,
    or to the end of the text if there is none.

    Blocks that are empty after stripping are omitted. Returns [] when the
    text contains no step header.
    """
    headers = list(STEP_RE.finditer(text))
    if not headers:
        return []

    # Look for the 'Final' marker only after the last step header. A match
    # earlier in the text would give the last block an end before its start,
    # making it empty and silently dropping the step.
    # The search starts on the header's last character (its trailing newline)
    # so that a marker on the very next line is still found.
    search_from = max(headers[-1].end() - 1, 0)
    final_match = FINAL_RE.search(text, search_from)
    tail_end = final_match.start() if final_match else len(text)

    # End offset of each block: start of the following header, or tail_end
    # for the last block.
    block_ends = [header.start() for header in headers[1:]] + [tail_end]

    blocks: List[Tuple[int, str]] = []
    for header, block_end in zip(headers, block_ends):
        body = text[header.end():block_end].strip()
        if body:
            blocks.append((int(header.group(1)), body))
    return blocks


# ---------------------------------------------------------------------------
# Transformation helpers for the target output format
# ---------------------------------------------------------------------------


def parse_single_generation(gen_text: str) -> Dict[str, Any]:
    """
    Parse one generation text into a structured object.

    The returned dict has the keys:

      "final_tag"        content of the first <answer>...</answer> tag,
                         or None when no tag is present
      "step_order_ok"    True when step numbers appear in non-decreasing order
      "step_numbers"     step numbers in order of appearance
      "missing_steps"    numbers in 1..max(step_numbers) that never appear
      "duplicate_steps"  numbers that appear more than once (ascending)
      "steps"            one dict per step with keys "n", "premises",
                         "assumptions", "conclusion", sorted by "n"
    """
    # --- Steps, in order of appearance ------------------------------------
    parsed_steps: List[Dict[str, Any]] = []
    for number, body in _find_step_blocks(gen_text):
        parts = _extract_sections(body)
        step: Dict[str, Any] = {"n": number}
        for field in ("premises", "assumptions", "conclusion"):
            step[field] = parts[field]
        parsed_steps.append(step)

    step_numbers = [step["n"] for step in parsed_steps]

    # --- Ordering: every number must be >= the one before it --------------
    in_order = all(
        earlier <= later
        for earlier, later in zip(step_numbers, step_numbers[1:])
    )

    # --- Missing numbers, assuming numbering should start at 1 ------------
    missing_steps: List[int] = []
    if step_numbers:
        seen_numbers = set(step_numbers)
        missing_steps = [
            k for k in range(1, max(step_numbers) + 1) if k not in seen_numbers
        ]

    # --- Numbers that occur more than once --------------------------------
    seen_once = set()
    repeated = set()
    for number in step_numbers:
        if number in seen_once:
            repeated.add(number)
        else:
            seen_once.add(number)
    duplicate_steps = sorted(repeated)

    # --- Final tag from <answer>...</answer> ------------------------------
    tag_match = ANSWER_TAG_RE.search(gen_text)
    final_tag: Optional[str] = None
    if tag_match:
        final_tag = tag_match.group(1).strip()

    # Stable sort: steps sharing a number keep their order of appearance.
    ordered_steps = list(parsed_steps)
    ordered_steps.sort(key=lambda entry: entry["n"])

    return {
        "final_tag": final_tag,
        "step_order_ok": in_order,
        "step_numbers": step_numbers,
        "missing_steps": missing_steps,
        "duplicate_steps": duplicate_steps,
        "steps": ordered_steps,
    }


def transform_record(data: Dict[str, Any]) -> Dict[str, Any]:
    """
    Transform an input JSON record into the target output schema.

    "id", "premise", "question", "answer" and "extracted_answer" are copied
    from `data` (None when absent). "generations" holds one parsed object per
    generation text, taken from the "generation" key or, if that is missing
    or empty, from "generations".

    A generation value that is a single string rather than a list of strings
    is treated as one generation.
    """
    out: Dict[str, Any] = {
        "id": data.get("id"),
        "premise": data.get("premise"),
        "question": data.get("question"),
        "answer": data.get("answer"),
        "extracted_answer": data.get("extracted_answer"),
        "generations": [],
    }

    # The input may use key "generation" (list of strings) or "generations"
    generations = data.get("generation") or data.get("generations") or []
    if isinstance(generations, str):
        # Iterating a bare string would parse every character as its own
        # generation; wrap it so the whole text is parsed once.
        generations = [generations]

    out["generations"] = [
        parse_single_generation(gen_text) for gen_text in generations
    ]
    return out


def main() -> None:
    """
    Convert the merged rollout JSONL into the structured output JSONL.

    Every non-blank input line is parsed as JSON. A JSON list is transformed
    element by element with transform_record; any other value is passed to
    transform_record as a single record. One JSON line is written per input
    line.
    """
    # Edit these paths as needed
    input_path = "./rollout_full.jsonl"
    output_path = "./rollout_processed.jsonl"

    with open(input_path, "r", encoding="utf-8") as f_in, \
            open(output_path, "w", encoding="utf-8") as f_out:
        for raw_line in f_in:
            payload = raw_line.strip()
            if not payload:
                continue

            parsed = json.loads(payload)

            result: Any
            if isinstance(parsed, list):
                # A line holding several records: transform each one
                result = [transform_record(rec) for rec in parsed]
            else:
                # A single dict-shaped record
                result = transform_record(parsed)

            f_out.write(json.dumps(result, ensure_ascii=False) + "\n")

    print(f"Wrote output to {output_path}")


if __name__ == "__main__":
    main()
