# Use GPT-4o to roll out the data

In [None]:
# build_openai_batch_requests_dedup.py
import os, json, re, hashlib
from pathlib import Path

# ---------- Config ----------
# Input file: one JSON object per line; main() reads its "id", "premise" and "question" fields.
INPUT_JSONL  = "./data.jsonl"   # <-- your input
# Output directory for batch request files; main() creates it if missing.
OUT_DIR      = "./openai_batches"
# main() writes one batch-request file per model name listed here.
OPENAI_MODELS = ["gpt-4o"]
# Requests per question; request k gets a "__rep<k>" segment in its custom_id.
N_OPENAI_SAMPLES_PER_Q = 3
# Sent as "max_completion_tokens" in every request body.
MAX_TOKENS   = 4096
# NOTE(review): TOP_P, TEMPS and BASE_SEED are not put into the request body.
# main() hard-codes "temperature": 0 and sends no top_p or seed. Wire these in
# if the N reps per question are meant to be diverse samples.
TOP_P        = 0.95
TEMPS        = [1, 1, 1]
BASE_SEED    = 20250815

# Dedup strategy
DEDUP_BY_KEY = True   # dedupe questions by (id, premise, question); the first occurrence is kept
# If DEDUP_BY_KEY=False, we still guarantee unique custom_id via suffixing on collision.

# ---------- Robust text helpers ----------
def flatten_tokens(x):
    """Recursively flatten *x* into a flat list of string tokens.

    - ``None`` contributes no tokens.
    - Lists and tuples are flattened element by element, in order.
    - A dict is unwrapped through the first of ``"text"``, ``"content"`` or
      ``"value"`` whose value is a str, list or dict. If none qualifies, the
      whole dict is serialized with ``json.dumps(..., ensure_ascii=False)``.
    - Every other value becomes ``[str(x)]``.
    """
    if x is None:
        return []
    if isinstance(x, (list, tuple)):
        return [token for item in x for token in flatten_tokens(item)]
    if isinstance(x, dict):
        # The selected value is always a str/list/dict, so None is a safe "not found" marker.
        inner = next(
            (
                x[name]
                for name in ("text", "content", "value")
                if name in x and isinstance(x[name], (str, list, dict))
            ),
            None,
        )
        if inner is None:
            return [json.dumps(x, ensure_ascii=False)]
        return flatten_tokens(inner)
    return [str(x)]

def to_text(field) -> str:
    """Render an arbitrary field (string, token list, nested dict) as clean text.

    Tokens from flatten_tokens are space-joined. CRLF is normalized to LF.
    Whitespace before ``.,!?;:`` is removed, runs of two or more spaces/tabs
    collapse to one space, and the result is stripped.
    """
    text = " ".join(flatten_tokens(field)).replace("\r\n", "\n")
    text = re.sub(r"\s+([.,!?;:])", r"\1", text)  # "word ." -> "word."
    text = re.sub(r"[ \t]{2,}", " ", text)  # collapse horizontal whitespace runs
    return text.strip()

# ---------- Same prompt as Qwen ----------
# Dataset prefix (lower-cased text before the first "_" of an example id) ->
# sentence telling the model which label vocabulary to answer with.
_ANSWER_INSTRUCTIONS = {
    "folio":         "Answer with exactly one label: True, False, or Uncertain.",
    "esnli":         "Answer with exactly one label: entailment, contradiction, or neutral.",
    "multilogieval": "Answer with exactly one label: yes or no.",
    "proofwriter":   "Answer with exactly one label: True, False, or Unknown.",
    "prontoqa":      "Answer with exactly one label: True, False, or Unknown.",
    "logiqa":        "Answer with exactly one label: entailment or not-entailment.",
    "proverqa":      "Answer with exactly one label: True, False, or Unknown.",
    "qasc":          "Answer with the OPTION LETTER ONLY: A, B, C, D, E, F, G, or H.",
}


def answer_instruction_for_id(example_id: str) -> str:
    """Return the answer-format sentence for *example_id*'s dataset, or '' if unknown.

    The dataset is the part of ``str(example_id)`` before the first underscore,
    compared case-insensitively.
    """
    dataset, _sep, _rest = str(example_id).partition("_")
    return _ANSWER_INSTRUCTIONS.get(dataset.lower(), "")

def build_messages(premise, question, example_id):
    """Return the chat prompt (system + user message) for one example.

    Args:
        premise: premise text, already flattened to a string.
        question: question text. When answer_instruction_for_id(example_id)
            returns a non-empty sentence, it is appended after a blank line.
        example_id: record id; its prefix selects that answer-format sentence.

    Returns:
        ``[{"role": "system", "content": ...}, {"role": "user", "content": ...}]``.
    """
    system = (
        "You are a meticulous logician. Read the premise and question carefully. "
        "Reason step-by-step explicitly in a Natural Language Inference (NLI) style. "
        "Each reasoning step must follow this format:\n\n"
        "Step X:\n"
        "Premises:\n"
        "- ...\n"
        "- ...\n"
        "Assumptions:\n"
        "- ...\n"
        "Conclusion:\n"
        "- ...\n\n"
        "Number each step (Step 1, Step 2, etc.) in order. "
        "Do not skip steps; ensure each conclusion is directly supported by its premises and stated assumptions. Please make sure the assumption must make sense under the context. \n\n"
        "Ensure that all assumptions are stated clearly and rigorously. For example, use specific entities instead of vague references such as 'it' or 'this'. Do not restate premises as assumptions. Premises must remain in the premises field, while assumptions should capture what is not directly stated in the premises but can reasonably be inferred from the premises using commonsense. Make sure all necessary information are provided in each step's premise and assumption to make the conclusion. \n"
        "After the final step, output exactly ONE final tag:\n"
        "<answer>{LABEL}</answer>\n\n"
        "Rules:\n"
        "1) If the question includes multiple-choice options labeled like 'A) ...', 'B) ...', etc., "
        "   answer with the OPTION LETTER ONLY (e.g., A).\n"
        "2) Otherwise, answer with a single classification label that fits the task "
        "   (e.g., True/False/Unknown, yes/no, entailment/contradiction/neutral, not-entailment, etc.).\n"
        # \u2019 (right single quotation mark) is written as an escape so an
        # encoding round-trip cannot turn it into mojibake inside the prompt.
        "3) The label must match the dataset\u2019s expected surface form (case-sensitive, no extra words).\n"
        "4) After your steps, produce exactly one <answer>...</answer> tag and nothing else."
    )
    ans_sent = answer_instruction_for_id(example_id)
    q2 = f"{question}\n\n{ans_sent}" if ans_sent else question
    user = (
        f"Premise:\n{premise}\n\n"
        f"Question:\n{q2}\n\n"
        "Format your reply EXACTLY as follows:\n"
        "Step 1:\n"
        "Premises:\n"
        "- [list the premises used in this step]\n"
        "Assumptions:\n"
        "- [state any implicit assumptions made in this step]\n"
        "Conclusion:\n"
        "- [state the conclusion entailed from those premises and assumptions]\n\n"
        "Step 2:\n"
        "Premises:\n"
        "- ...\n"
        "Assumptions:\n"
        "- ...\n"
        "Conclusion:\n"
        "- ...\n\n"
        "...\n\n"
        "Final:\n"
        "<answer>YOUR_SINGLE_LABEL_OR_OPTION_LETTER</answer>\n\n"
        "Example (for just one step):\n"
        "Step 1:\n"
        "Premises:\n"
        "- All birds have wings.\n"
        "- Penguins are birds.\n"
        "Assumptions:\n"
        "- None \n"
        "Conclusion:\n"
        "- Therefore, penguins have wings.\n\n"
        "Example (assumptions illustrated):\n"
        "Step 1:\n"
        "Premises:\n"
        "- In a school, every student is either studying in the library or playing in the playground.\n"
        "- Emily was not seen in the library on Friday.\n"
        "Assumptions:\n"
        "- Emily is a student.\n"
        "- \"Not seen in the library\" implies \"not studying in the library.\"\n"
        "Conclusion:\n"
        "- Therefore, Emily was playing in the playground on Friday.\n\n"
        "Final:\n"
        "<answer>yes</answer> \n"
        "Make sure all necessary information are provided in each step's premise and assumption to make the conclusion."
    )
    return [{"role": "system", "content": system},
            {"role": "user",   "content": user}]

# ---------- Helpers for uniqueness ----------
def key_for_rec(rec) -> str:
    """Return a SHA-1 hex digest identifying *rec* by its (id, premise, question).

    Premise and question are normalized with to_text first. The triple is then
    serialized as sorted-key JSON, so the key is stable across runs.
    """
    payload = {
        "id": rec.get("id", ""),
        "premise": to_text(rec.get("premise", "")),
        "question": to_text(rec.get("question", "")),
    }
    canonical = json.dumps(payload, ensure_ascii=False, sort_keys=True)
    return hashlib.sha1(canonical.encode("utf-8")).hexdigest()

def make_unique_custom_id(base: str, seen: set) -> str:
    """Return *base*, or the first ``base__dup<N>`` (N = 1, 2, ...) not already in *seen*.

    The returned id is added to *seen* before returning, so repeated calls with
    the same set never hand out the same id twice.
    """
    candidate = base
    suffix = 0
    while candidate in seen:
        suffix += 1
        candidate = f"{base}__dup{suffix}"
    seen.add(candidate)
    return candidate

# ---------- Build batch request files per model ----------
def _load_records(input_path):
    """Read the input JSONL at *input_path* and return ``(records, n_duplicates)``.

    Blank lines are skipped. When DEDUP_BY_KEY is true, a record whose
    key_for_rec key was already seen is dropped and counted in
    ``n_duplicates``; the first occurrence is kept.
    """
    records = []
    seen_keys = set()
    n_duplicates = 0
    with open(input_path, "r", encoding="utf-8") as fin:
        for line in fin:
            line = line.strip()
            if not line:
                continue
            rec = json.loads(line)
            if DEDUP_BY_KEY:
                key = key_for_rec(rec)
                if key in seen_keys:
                    n_duplicates += 1
                    continue
                seen_keys.add(key)
            records.append(rec)
    return records, n_duplicates


def _write_requests_for_model(model, records, out_path):
    """Write N_OPENAI_SAMPLES_PER_Q chat-completion batch requests per record for *model*.

    Each line of *out_path* is one Batch API request with custom_id
    ``{id}__rep{k}__{model}``, made unique within this file by
    make_unique_custom_id. Returns the number of requests written.
    """
    written = 0
    seen_custom_ids = set()  # custom_ids must be unique within one batch file
    with open(out_path, "w", encoding="utf-8") as fout:
        for rec in records:
            rid = rec.get("id", "")
            premise = to_text(rec.get("premise", ""))
            question = to_text(rec.get("question", ""))
            # The prompt is identical for every rep of a question, so build it once.
            messages = build_messages(premise, question, rid)

            for k in range(N_OPENAI_SAMPLES_PER_Q):
                custom_id = make_unique_custom_id(f"{rid}__rep{k}__{model}", seen_custom_ids)
                req = {
                    "custom_id": custom_id,
                    "method": "POST",
                    "url": "/v1/chat/completions",
                    "body": {
                        "model": model,
                        "messages": messages,
                        "max_completion_tokens": MAX_TOKENS,
                        # NOTE(review): temperature is pinned to 0, so TEMPS, TOP_P and
                        # BASE_SEED are not applied. Use float(TEMPS[k % len(TEMPS)])
                        # here if the reps are meant to be diverse samples.
                        "temperature": 0,
                    },
                }
                fout.write(json.dumps(req, ensure_ascii=False) + "\n")
                written += 1
    return written


def main():
    """Build one OpenAI Batch request file per model in OPENAI_MODELS.

    Files are written inside OUT_DIR (created if missing) as
    ``openai_rollout_requests_<model>.jsonl``. Returns the written file paths
    in OPENAI_MODELS order, e.g. for passing to submit_batch().
    """
    out_dir = Path(OUT_DIR)
    out_dir.mkdir(parents=True, exist_ok=True)

    # The whole input is read once and reused for every model.
    records, duplicate_questions_skipped = _load_records(INPUT_JSONL)

    out_paths = []
    total_requests = 0
    for model in OPENAI_MODELS:
        out_path = out_dir / f"openai_rollout_requests_{model}.jsonl"
        written = _write_requests_for_model(model, records, out_path)
        print(f"[{model}] Wrote {written} requests to {out_path}")
        total_requests += written
        out_paths.append(str(out_path))

    if DEDUP_BY_KEY:
        print(f"Skipped {duplicate_questions_skipped} duplicate questions (by (id,premise,question)).")
    print(f"Total requests across models: {total_requests}")
    return out_paths


# Notebook kernels run cells as __main__, so this still executes there; importing
# the code as a module no longer triggers file I/O.
if __name__ == "__main__":
    request_paths = main()


## Helper functions for batch processing

In [None]:
import openai
import json

# Prefer the OPENAI_API_KEY environment variable so a real key is not pasted into
# (and committed with) the notebook; the placeholder remains as a fallback.
import os

openai.api_key = os.environ.get("OPENAI_API_KEY", "YOUR_API_KEY")
client = openai.Client(api_key=openai.api_key)

def submit_batch(batch_path, desc=""):
    """Upload a batch-request JSONL file and start a 24h chat-completions batch.

    Args:
        batch_path: path to the JSONL file of batch requests.
        desc: free-text description stored in the batch's ``metadata``.

    Returns:
        The object returned by ``client.batches.create``. Its ``id`` is what
        check_batch and cancel_batch expect.
    """
    # Close the upload handle once the file has been sent instead of leaking it.
    with open(batch_path, "rb") as fh:
        batch_input_file = client.files.create(file=fh, purpose="batch")

    # This is the uploaded input file's id, not the batch id.
    print("Input file id: ", batch_input_file.id)

    batch_request = client.batches.create(
        input_file_id=batch_input_file.id,
        endpoint="/v1/chat/completions",
        completion_window="24h",
        metadata={"description": desc},
    )

    print("Batch id: ", batch_request.id)
    print("Batch request: ", batch_request)
    return batch_request

def cancel_batch(batch_id):
    """Request cancellation of batch *batch_id*.

    Prints a confirmation and returns the batch object from the API, so the
    caller can inspect its status (consistent with submit_batch/check_batch).
    """
    batch = client.batches.cancel(batch_id)
    print("Batch cancelled: ", batch_id)
    return batch

def check_batch(batch_id):
    """Fetch batch *batch_id*, print it, and return it.

    Returning the object lets callers read fields such as ``status`` and
    ``output_file_id`` (the id retrieve_batch needs) without parsing the
    printed repr.
    """
    batch = client.batches.retrieve(batch_id)
    print("Batch: ", batch)
    return batch

def retrieve_batch(file_id, save_path):
    """Download a batch output file (JSONL) and save its records as one JSON array.

    Args:
        file_id: id of the batch's output file.
        save_path: destination path for the pretty-printed JSON array.

    Returns:
        The list of parsed response records.
    """
    response_text = client.files.content(file_id).text

    # Skip blank lines: an empty file or a stray blank line would otherwise
    # make json.loads("") raise.
    responses = [json.loads(line) for line in response_text.splitlines() if line.strip()]

    with open(save_path, "w", encoding="utf-8") as outfile:
        json.dump(responses, outfile, indent=2)

    print(f"Responses saved to {save_path}")
    return responses

## Submit, check, and retrieve the batch

In [None]:
# Submit the request file written by the rollout cell (replace out_path with its path):
# submit_batch(out_path, "Rollout requests for GPT-4o")

# Check the batch status; when it has completed, copy its output file id into
# OUTPUT_FILE_ID below.
# check_batch("batch_id")  # replace the batch id

# Cancel the batch if needed
# cancel_batch("batch_id")  # replace the batch id

# Retrieve batch responses
OUTPUT_FILE_ID = "output_file_id"  # replace with the batch's output file id
output_file_path = "./openai_rollout_responses.json"  # also read by the parsing cell below

# Skip the download while the placeholder is still in place, so running all
# cells does not fail on an API call with a dummy file id.
if OUTPUT_FILE_ID == "output_file_id":
    print("[info] set OUTPUT_FILE_ID to the batch's output file id to download the responses.")
else:
    retrieve_batch(OUTPUT_FILE_ID, output_file_path)


## Process the batch output

In [None]:
#!/usr/bin/env python3
"""
Notebook-friendly parser for a *single* OpenAI Batch JSON/JSONL file.

Features:
- STRICT custom_id parsing via regex: ^(?P<base>.+?)__rep(?P<rep>\d+)__
  -> prevents accidental grouping like `multilogieval_136` with `multilogieval_1364`.
- EXACTLY THREE samples per base id: keeps at most rep0, rep1, rep2 (in order).
- DEDUP per rep: if duplicates exist for a rep, the first occurrence is kept.
- Robust JSON/JSONL reading (handles arrays, JSONL, lines with trailing commas).

Edit BATCH_PATH / GROUND_PATH / OUT_PATH below, then call `run()` (e.g. in a notebook).
"""

from __future__ import annotations

import json
import re
import sys
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Tuple

# ====== EDIT THESE ======
# Batch output JSON saved by retrieve_batch(); `output_file_path` must already be
# defined by the retrieval cell above.
BATCH_PATH = output_file_path
# Ground-truth JSONL: one object per line with id, premise, question, answer.
GROUND_PATH = "./data.jsonl"
# run() writes one processed JSONL row per matched base id here.
OUT_PATH = "./openai_rollout_processed.jsonl"
# ========================

# Captures the text inside <answer>...</answer>: case-insensitive, may span lines,
# surrounding whitespace excluded, non-greedy so each tag pair matches separately.
ANSWER_RE = re.compile(r"<answer>\s*(.*?)\s*</answer>", re.IGNORECASE | re.DOTALL)
# Strictly capture base id and rep index (the shortest prefix followed by "__rep<digits>__").
CUSTOM_ID_RE = re.compile(r"^(?P<base>.+?)__rep(?P<rep>\d+)__")
# Rep indices kept per base id.
# NOTE(review): should match N_OPENAI_SAMPLES_PER_Q in the request-building cell (currently 3).
ALLOWED_REPS = (0, 1, 2)


def _json_loads_lenient(line: str) -> object:
    """Parse *line* as JSON, retrying once with trailing whitespace/commas removed.

    If neither attempt succeeds, the JSONDecodeError from the first attempt
    is raised.
    """
    attempts = [line]
    without_comma = line.rstrip().rstrip(",")
    if without_comma != line:
        attempts.append(without_comma)

    first_error = None
    for candidate in attempts:
        try:
            return json.loads(candidate)
        except json.JSONDecodeError as err:
            if first_error is None:
                first_error = err
    raise first_error


def read_json_or_jsonl(path: Path) -> Iterable[dict]:
    """
    Yield dict records from a single file that may be:
      - a JSON array (non-dict items are wrapped as {"value": item})
      - a single JSON object
      - JSONL (one JSON value per line; a trailing comma per line is tolerated)

    An empty or whitespace-only file yields nothing. Malformed JSONL lines are
    skipped with a warning on stderr naming the file and line number.
    """
    raw_text = path.read_text(encoding="utf-8")
    if not raw_text.strip():
        # This is a generator: a bare return just ends iteration.
        return

    # Try the whole file as one JSON document first.
    try:
        whole = json.loads(raw_text)
    except json.JSONDecodeError:
        whole = None  # not a single document; treat as JSONL below

    if isinstance(whole, list):
        for item in whole:
            yield item if isinstance(item, dict) else {"value": item}
        return
    if isinstance(whole, dict):
        yield whole
        return

    # JSONL. Other top-level JSON types (number, string, null) also land here
    # and are wrapped exactly like a one-line JSONL file.
    for lineno, raw in enumerate(raw_text.splitlines(), 1):
        line = raw.strip()
        if not line:
            continue
        try:
            item = _json_loads_lenient(line)
        except json.JSONDecodeError as e:
            print(f"[warn] {path}:{lineno}: skipping malformed JSONL line ({e})", file=sys.stderr)
            continue
        yield item if isinstance(item, dict) else {"value": item}

def load_ground_truth(ground_truth_path: Path) -> Dict[str, dict]:
    """
    Load ground-truth JSONL and return a mapping ``str(id) -> record``.

    Each record is expected to include: id, premise, question, answer.
    Ids are normalized to ``str`` because the base ids they are joined against
    come from string custom_ids; an integer id would otherwise never match.
    Lines that are not JSON objects, or lack an id, are skipped with a warning.
    If an id repeats, the later record replaces the earlier one and a warning
    is printed.
    """
    gt: Dict[str, dict] = {}
    with ground_truth_path.open("r", encoding="utf-8") as f:
        for i, line in enumerate(f, 1):
            line = line.strip()
            if not line:
                continue
            try:
                rec = json.loads(line)
            except json.JSONDecodeError as e:
                print(f"[warn] ground-truth line {i} JSON error: {e}", file=sys.stderr)
                continue
            if not isinstance(rec, dict):
                print(f"[warn] ground-truth line {i} is not a JSON object; skipping", file=sys.stderr)
                continue

            rec_id = rec.get("id")
            # Explicit check so a legitimate id of 0 is not treated as missing.
            if rec_id is None or rec_id == "":
                print(f"[warn] ground-truth line {i} missing 'id'", file=sys.stderr)
                continue

            key = str(rec_id)
            if key in gt:
                print(
                    f"[warn] ground-truth line {i} repeats id '{key}'; keeping the later record",
                    file=sys.stderr,
                )
            gt[key] = rec
    return gt


def parse_custom_id(custom_id: str) -> Optional[Tuple[str, int]]:
    """
    Split a batch custom_id into ``(base_id, rep_idx)`` using CUSTOM_ID_RE.

    Returns None when the id does not match the strict pattern.

    Example:
      'multilogieval_136__rep0__gpt-4o-2024-01-01' -> ('multilogieval_136', 0)
    """
    m = CUSTOM_ID_RE.match(custom_id)
    if m is None:
        return None
    base_id, rep_text = m.group("base", "rep")
    try:
        return base_id, int(rep_text)
    except ValueError:
        return None


def extract_answer(text: str) -> str:
    """Return the stripped content of the LAST <answer>...</answer> tag in *text*, or ''.

    The prompt asks for a single final tag after all reasoning steps. When a
    generation contains several tags (e.g. the model echoes the
    ``<answer>{LABEL}</answer>`` template or quotes a tag mid-reasoning), the
    last one is the final answer, not the first. Non-string input yields ''.
    """
    if not isinstance(text, str):
        return ""
    matches = ANSWER_RE.findall(text)  # one capture group -> list of strings
    return matches[-1].strip() if matches else ""


def _first_choice(body: object) -> Optional[dict]:
    """Return ``body["choices"][0]`` when *body* is a dict holding a non-empty list of dicts, else None."""
    if not isinstance(body, dict):
        return None
    choices = body.get("choices")
    if isinstance(choices, list) and choices and isinstance(choices[0], dict):
        return choices[0]
    return None


def pull_generation_from_entry(entry: dict) -> Optional[str]:
    """
    Fetch the assistant content from a batch entry.

    Lookup order:
      1. entry["response"]["body"]["choices"][0]["message"]["content"]
      2. entry["body"]["choices"][0]["message"]["content"]
      3. choices[0]["text"] from layout 1, then from layout 2

    Missing or null levels (for instance ``"response": null``) are treated as
    absent rather than raising. Returns None when no string content is found.
    """
    if not isinstance(entry, dict):
        return None

    response = entry.get("response")
    primary = _first_choice(response.get("body") if isinstance(response, dict) else None)
    alternate = _first_choice(entry.get("body"))
    choices = [c for c in (primary, alternate) if c is not None]

    # Chat-completions layout: message.content
    for choice in choices:
        message = choice.get("message")
        content = message.get("content") if isinstance(message, dict) else None
        if isinstance(content, str):
            return content

    # Legacy completions layout: text
    for choice in choices:
        text = choice.get("text")
        if isinstance(text, str):
            return text
    return None


def parse_batch(batch_path: Path) -> Dict[str, Dict[int, Tuple[str, str]]]:
    """
    Group the generations in a batch output file by base id and rep index.

    Returns:
      base_id -> { rep_idx -> (generation_text, extracted_answer) }

    Entries are skipped when:
      - custom_id is missing;
      - custom_id does not match CUSTOM_ID_RE;
      - the rep index is outside ALLOWED_REPS;
      - no generation text is found.
    Among entries for the same (base_id, rep_idx), the first usable one wins.
    """
    groups: Dict[str, Dict[int, Tuple[str, str]]] = {}

    for entry in read_json_or_jsonl(batch_path):
        custom_id = entry.get("custom_id")
        if not custom_id:
            print("[warn] missing custom_id in entry; skipping.", file=sys.stderr)
            continue

        parsed = parse_custom_id(custom_id)
        # Silently ignore ids outside the strict pattern or the sampled rep set.
        if not parsed or parsed[1] not in ALLOWED_REPS:
            continue
        base_id, rep_idx = parsed

        generation = pull_generation_from_entry(entry)
        if generation is None:
            print(f"[warn] no generation content for {custom_id}; skipping.", file=sys.stderr)
            continue

        answer = extract_answer(generation)
        per_rep = groups.setdefault(base_id, {})
        # First occurrence wins; later ids mapping to the same rep (e.g. "__dupN"
        # suffixed ones) are dropped.
        per_rep.setdefault(rep_idx, (generation, answer))

    return groups


def run() -> str:
    """Join parsed batch generations with ground truth and write the processed JSONL.

    One row is written per base id that is present in the ground truth. Its
    generations and extracted answers are listed in ALLOWED_REPS order; missing
    reps are skipped. Returns the output path as a string, or "" if an input
    file does not exist.
    """
    batch_path = Path(BATCH_PATH)
    if not batch_path.exists():
        print(f"[error] batch file not found: {batch_path}", file=sys.stderr)
        return ""

    ground_path = Path(GROUND_PATH)
    if not ground_path.exists():
        print(f"[error] ground-truth file not found: {ground_path}", file=sys.stderr)
        return ""

    out_path = Path(OUT_PATH)

    gt = load_ground_truth(ground_path)
    groups = parse_batch(batch_path)

    n_written = 0
    with out_path.open("w", encoding="utf-8") as sink:
        for base_id, rep_map in groups.items():
            gt_rec = gt.get(base_id)
            if not gt_rec:
                # Strict exact match with ground-truth id only
                print(
                    f"[warn] base_id '{base_id}' not found in ground truth; skipping.",
                    file=sys.stderr,
                )
                continue

            ordered = [rep_map[r] for r in ALLOWED_REPS if r in rep_map]
            row = {
                "id": base_id,
                "premise": gt_rec.get("premise"),
                "question": gt_rec.get("question"),
                "answer": gt_rec.get("answer"),
                "generation": [gen for gen, _ans in ordered],
                "extracted_answer": [ans for _gen, ans in ordered],
            }
            sink.write(json.dumps(row, ensure_ascii=False) + "\n")
            n_written += 1

    print(f"[info] wrote {n_written} records to {out_path}")
    return str(out_path)


run()
