In [None]:
# Jupyter 版：把 *_pw_aug.jsonl / *_sym_aug.jsonl 中非空的 aug_head/aug_mid/aug_tail
# 分别路由到 toxic_{pw|sym}_{head|mid|tail}_aug.jsonl（仅当对应非空时才写入/创建）

import json
from pathlib import Path
from typing import Dict, Iterable, Tuple, Optional

# ========= 配置（按需修改） =========
PW_IN_PATH = "/Users/zhzhou/Desktop/SafeContextualReward/results/document_synthesis/v4_rpj_llama_s4/toxicity_corpus/tox_top_autoalloc_v1/toxic_topk_200k_pw_aug.jsonl"
SYM_IN_PATH = "/Users/zhzhou/Desktop/SafeContextualReward/results/document_synthesis/v4_rpj_llama_s4/toxicity_corpus/tox_top_autoalloc_v1/toxic_topk_200k_sym_aug.jsonl"
OUT_DIR    = "/Users/zhzhou/Desktop/SafeContextualReward/results/document_synthesis/v4_rpj_llama_s4/cpt_data/"

# ========= 工具 =========
def _iter_jsonl(path: Path) -> Iterable[Dict]:
    with path.open("r", encoding="utf-8") as f:
        for i, line in enumerate(f, 1):
            s = line.strip()
            if not s:
                continue
            try:
                yield json.loads(s)
            except Exception as e:
                print(f"[WARN] skip malformed json at {path.name}:{i}: {e}")

def _make_record(obj: Dict, aug_text: str, kind: str, where: str) -> Dict:
    # 只输出精简字段；保留 token（如果源里有 pw_head/sym_head 等）
    rec = {
        "text": aug_text,
        "aug_type": where,
    }
    tok_key = f"{'pw' if kind == 'pw' else 'sym'}_{where}"
    if tok_key in obj:
        rec["token"] = obj[tok_key]
    return rec

class LazyWriters:
    """按需创建 writer，避免写出空文件。"""
    def __init__(self, out_dir: Path, kind: str):
        self.out_dir = out_dir
        self.kind = kind
        self._files = {}   # where -> file handle
        self.counts = {"head":0, "mid":0, "tail":0}
    def _path(self, where: str) -> Path:
        return self.out_dir / f"toxic_{self.kind}_{where}_aug.jsonl"
    def write(self, where: str, rec: Dict):
        if where not in self._files:
            self.out_dir.mkdir(parents=True, exist_ok=True)
            self._files[where] = self._path(where).open("w", encoding="utf-8")
        self._files[where].write(json.dumps(rec, ensure_ascii=False) + "\n")
        self.counts[where] += 1
    def close(self):
        for fp in self._files.values():
            try: fp.close()
            except: pass

def _split_one_file(in_path: Path, out_dir: Path, kind: str) -> Tuple[int,int,int,int]:
    assert kind in ("pw", "sym")
    writers = LazyWriters(out_dir, kind)
    total = 0
    for obj in _iter_jsonl(in_path):
        total += 1
        # 哪个非空就写哪个（可能多个都非空 => 多路输出）
        h = obj.get("aug_head")
        if isinstance(h, str) and h.strip():
            writers.write("head", _make_record(obj, h, kind, "head"))
        m = obj.get("aug_mid")
        if isinstance(m, str) and m.strip():
            writers.write("mid", _make_record(obj, m, kind, "mid"))
        t = obj.get("aug_tail")
        if isinstance(t, str) and t.strip():
            writers.write("tail", _make_record(obj, t, kind, "tail"))
    writers.close()
    print(f"[OK] {in_path.name} -> "
          f"head({writers.counts['head']}), mid({writers.counts['mid']}), tail({writers.counts['tail']}); "
          f"read {total} lines total.")
    return total, writers.counts["head"], writers.counts["mid"], writers.counts["tail"]

# ========= 执行 =========
out_dir = Path(OUT_DIR)
total = head = mid = tail = 0

pw_path = Path(PW_IN_PATH)
sym_path = Path(SYM_IN_PATH)

if pw_path.exists():
    t, h, m, ta = _split_one_file(pw_path, out_dir, kind="pw")
    total += t; head += h; mid += m; tail += ta
else:
    print(f"[SKIP] pw file not found: {pw_path}")

if sym_path.exists():
    t, h, m, ta = _split_one_file(sym_path, out_dir, kind="sym")
    total += t; head += h; mid += m; tail += ta
else:
    print(f"[SKIP] sym file not found: {sym_path}")

print(f"[SUM] total_rows={total}; head_lines={head}; mid_lines={mid}; tail_lines={tail}; outdir={out_dir}")


[OK] toxic_topk_200k_pw_aug.jsonl -> head(66331), mid(66757), tail(66912); read 200000 lines total.
[OK] toxic_topk_200k_sym_aug.jsonl -> head(66331), mid(66757), tail(66912); read 200000 lines total.
[SUM] total_rows=400000; head_lines=132662; mid_lines=133514; tail_lines=133824; outdir=/Users/zhzhou/Desktop/SafeContextualReward/results/document_synthesis/v4_rpj_llama_s4/cpt_data


In [17]:
# Jupyter 版本：把两份 CSV（含列 words,is_all_doc,is_unsafe,matched,text,aug_head,aug_mid,aug_tail）
# 中非空的 aug_head/aug_mid/aug_tail 分别写到：
#   rpj_pw_head_aug.jsonl / rpj_pw_mid_aug.jsonl / rpj_pw_tail_aug.jsonl
# 仅当对应非空时才创建/写入文件

import json
from pathlib import Path
from typing import Iterable, Dict, Any, List, Tuple
import pandas as pd
import math
import ast

# ========================= 配置 =========================
INPUT_CSVS = [
    "/Users/zhzhou/Desktop/SafeContextualReward/results/document_synthesis/v4_rpj_llama_s4/screened/seed_v1/augmented_docs.csv",
    "/Users/zhzhou/Desktop/SafeContextualReward/results/document_synthesis/v4_rpj_llama_s4/screened/seed_v1/unsafe_candidates_augmented.csv",
]
OUTDIR = "/Users/zhzhou/Desktop/SafeContextualReward/results/document_synthesis/v4_rpj_llama_s4/cpt_data/"

REQUIRED_COLS = ["words","is_all_doc","is_unsafe","matched","text","aug_head","aug_mid","aug_tail"]

# ========================= 工具函数 =========================
def parse_words(val) -> List[str]:
    """尽量把 CSV 里的 words 解析成字符串列表。"""
    if val is None or (isinstance(val, float) and math.isnan(val)):
        return []
    if isinstance(val, list):
        return [str(x) for x in val]
    if isinstance(val, (set, tuple)):
        return [str(x) for x in list(val)]
    s = str(val).strip()
    if not s:
        return []
    if (s.startswith("[") and s.endswith("]")) or \
       (s.startswith("(") and s.endswith(")")) or \
       (s.startswith("{") and s.endswith("}")):
        try:
            obj = ast.literal_eval(s)
            if isinstance(obj, (list, tuple, set)):
                return [str(x) for x in obj]
        except Exception:
            pass
    return [s]

def words_to_string(val) -> str:
    """把 words（可能是 list/tuple/set/str/None）转成无方括号的字符串。"""
    if val is None or (isinstance(val, float) and math.isnan(val)):
        return ""
    if isinstance(val, (list, tuple, set)):
        return ", ".join(str(x) for x in val)
    return str(val)

def iter_rows(csv_path: Path) -> Iterable[Dict[str, Any]]:
    df = pd.read_csv(csv_path)
    missing = [c for c in REQUIRED_COLS if c not in df.columns]
    if missing:
        raise ValueError(f"{csv_path} 缺少列: {missing}")
    for _, row in df.iterrows():
        yield {
            "words": parse_words(row.get("insert_pw")),
            "text": row.get("text", ""),
            "aug_head": row.get("aug_head", None),
            "aug_mid":  row.get("aug_mid",  None),
            "aug_tail": row.get("aug_tail", None),
        }

class LazyWriters:
    """按需创建 writer，避免空文件。"""
    def __init__(self, out_dir: Path):
        self.out_dir = out_dir
        self._files = {}  # where -> file handle
        self.counts = {"head": 0, "mid": 0, "tail": 0}
    def _path(self, where: str) -> Path:
        return self.out_dir / f"rpj_pw_{where}_aug.jsonl"
    def write(self, where: str, rec: Dict[str, Any]):
        if where not in self._files:
            self.out_dir.mkdir(parents=True, exist_ok=True)
            self._files[where] = self._path(where).open("w", encoding="utf-8")
        self._files[where].write(json.dumps(rec, ensure_ascii=False) + "\n")
        self.counts[where] += 1
    def close(self):
        for fp in self._files.values():
            try: fp.close()
            except: pass

def route_csv_to_jsonls(
    records: Iterable[Dict[str, Any]],
    out_dir: Path
) -> Tuple[int, int, int, int]:
    """
    将记录按非空 aug_* 路由到对应 JSONL：
      rpj_pw_head_aug.jsonl / rpj_pw_mid_aug.jsonl / rpj_pw_tail_aug.jsonl
    返回 (total_read, n_head, n_mid, n_tail)
    """
    writers = LazyWriters(out_dir)
    total = 0
    for r in records:
        total += 1
        words_str = words_to_string(r.get("words", []))
        # head
        ah = r.get("aug_head")
        if isinstance(ah, str) and ah.strip():
            writers.write("head", {"text": ah, "aug_type": "head", "words": words_str})
        # mid
        am = r.get("aug_mid")
        if isinstance(am, str) and am.strip():
            writers.write("mid", {"text": am, "aug_type": "mid", "words": words_str})
        # tail
        at = r.get("aug_tail")
        if isinstance(at, str) and at.strip():
            writers.write("tail", {"text": at, "aug_type": "tail", "words": words_str})
    writers.close()
    print(f"[OK] wrote -> head({writers.counts['head']}), mid({writers.counts['mid']}), tail({writers.counts['tail']}); "
          f"from {total} input rows.")
    return total, writers.counts["head"], writers.counts["mid"], writers.counts["tail"]

# ========================= 执行 =========================
out_dir = Path(OUTDIR)
total = head = mid = tail = 0
for csv_path_str in INPUT_CSVS:
    p = Path(csv_path_str)
    if not p.exists():
        print(f"[SKIP] CSV not found: {p}")
        continue
    rows = iter_rows(p)  # 迭代器，节省内存
    t, h, m, ta = route_csv_to_jsonls(rows, out_dir)
    total += t; head += h; mid += m; tail += ta

print(f"[SUM] total_rows={total}; head_lines={head}; mid_lines={mid}; tail_lines={tail}; outdir={out_dir}")


[OK] wrote -> head(868), mid(867), tail(887); from 2622 input rows.
[OK] wrote -> head(22247), mid(22109), tail(22338); from 66694 input rows.
[SUM] total_rows=69316; head_lines=23115; mid_lines=22976; tail_lines=23225; outdir=/Users/zhzhou/Desktop/SafeContextualReward/results/document_synthesis/v4_rpj_llama_s4/cpt_data
