In [1]:
# =========================
# Colab: FEVER + TruthfulQA -> 10k merged JSON
# =========================
!pip install -q pandas numpy openpyxl

from google.colab import drive
drive.mount('/content/drive')

import json, random
from pathlib import Path
import pandas as pd
import numpy as np

# ---------- CONFIG: change these to where your files live ----------
# Input files on the mounted Drive; edit these two paths to match your own layout.
INPUT_FEVER = Path("/content/drive/MyDrive/Colab Notebooks/train.jsonl")      # FEVER jsonl
INPUT_TQA   = Path("/content/drive/MyDrive/Colab Notebooks/TruthfulQA.xlsx")  # TruthfulQA xlsx or csv

# Output directory (in Drive)
OUT_DIR = Path("/content/drive/MyDrive/Colab Notebooks/merged_data")
# Created eagerly (including missing parents) so the later to_json calls cannot fail on a missing folder.
OUT_DIR.mkdir(parents=True, exist_ok=True)

# One file per source dataset plus the concatenation of both.
OUT_FEVER_5K = OUT_DIR / "fever_5000.json"
OUT_TQA_5K   = OUT_DIR / "truthfulqa_5000_paraphrased.json"
OUT_MERGED   = OUT_DIR / "merged_10k.json"

# Seeds the module-level `random` generator for reproducibility.
random.seed(42)

def clean_text(x):
    """Return *x* as a string with surrounding whitespace removed and inner runs collapsed.

    Missing values (anything ``pd.isna`` reports as missing, e.g. ``None`` or
    NaN) become the empty string. Every other value is passed through ``str``
    first, so numbers are accepted as well.
    """
    if not pd.isna(x):
        # str.split() with no arguments drops leading/trailing whitespace and
        # splits on any whitespace run, so joining with one space normalises it.
        return " ".join(str(x).split())
    return ""

# ---------- FEVER loader (jsonl) ----------
def load_fever_jsonl(path: Path) -> pd.DataFrame:
    """Load a FEVER-style JSONL file into the question/answer/evidence/labels schema.

    Each non-blank line is parsed as one JSON object. The claim text is taken
    from the first non-empty of ``question`` / ``claim`` / ``statement``; the
    answer defaults to that same text when no ``answer`` key is present.
    Labels are mapped to 1 (SUPPORTS) / 0 (REFUTES); rows with any other label
    (NEI, unknown, missing) are dropped.

    Args:
        path: Location of the ``.jsonl`` file (UTF-8).

    Returns:
        DataFrame with columns ``question``, ``answer``, ``evidence``,
        ``labels`` (int). The frame is empty, but still has these columns,
        when the file holds no usable rows.

    Lines that are not valid JSON objects are skipped (best effort); a single
    warning reports how many were skipped so the loss is not silent.
    """
    import warnings  # local import: only needed for the skipped-line report

    rows = []
    skipped = 0
    with path.open("r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                obj = json.loads(line)
            except ValueError:  # json.JSONDecodeError is a ValueError subclass
                skipped += 1
                continue
            if not isinstance(obj, dict):
                # Valid JSON but not an object (e.g. a bare list): nothing to read keys from.
                skipped += 1
                continue
            question = obj.get("question") or obj.get("claim") or obj.get("statement") or ""
            answer   = obj.get("answer") or question  # for FEVER, answer ~ claim/statement
            label    = obj.get("label")
            # NOTE(review): when "evidence" is a nested annotation list this stores its
            # str() form, not sentence text — confirm that is what downstream expects.
            evidence = obj.get("evidence_text") or obj.get("evidence") or obj.get("page") or ""
            rows.append({
                "question": clean_text(question),
                "answer": clean_text(answer),
                "evidence": clean_text(str(evidence).replace("_", " ")),
                "raw_label": label
            })

    if skipped:
        warnings.warn(
            f"load_fever_jsonl: skipped {skipped} malformed line(s) in {path}",
            stacklevel=2,
        )

    # Explicit columns keep the schema intact when `rows` is empty; without them
    # the "raw_label" lookup below raises KeyError on an empty file.
    df = pd.DataFrame(rows, columns=["question", "answer", "evidence", "raw_label"])

    def map_label(x):
        """Map a raw FEVER label to 1/0, or NaN when it is not a binary verdict."""
        if x is None: return np.nan
        s = str(x).strip().upper()
        if s == "SUPPORTS": return 1
        if s == "REFUTES":  return 0
        if s in {"NEI", "NOT ENOUGH INFO", "NOT_ENOUGH_INFO"}: return np.nan
        if s in {"0", "1"}: return int(s)
        return np.nan

    df["labels"] = df["raw_label"].map(map_label)
    df = df.dropna(subset=["labels"]).copy()   # drop NEI/unknown
    df["labels"] = df["labels"].astype(int)
    # The text columns were produced by clean_text above, so no further
    # missing-value handling is applied to them here.
    return df[["question","answer","evidence","labels"]]

# ---------- TruthfulQA loader (xlsx or csv) ----------
def load_truthfulqa(path: Path) -> pd.DataFrame:
    """Load a TruthfulQA spreadsheet/CSV into the question/answer/evidence/labels schema.

    Column lookup is case-insensitive and treats spaces like underscores, so a
    header such as ``Best Answer`` matches ``best_answer``. Without that, the
    answer silently fell back to whatever the second column happened to be.

    Args:
        path: ``.xlsx`` / ``.xls`` file (first sheet is read) or anything else,
            which is read as CSV.

    Returns:
        DataFrame with columns ``question``, ``answer``, ``evidence`` (always
        empty) and ``labels`` (int). Rows whose question or answer is empty
        after cleaning are dropped. When no label column is found, every row
        is labelled 1.

    Raises:
        ValueError: if the file has no columns, or no answer column can be
            identified and there is no second column to fall back to.
    """
    if path.suffix.lower() in {".xlsx", ".xls"}:
        df = pd.read_excel(path, sheet_name=0)
    else:
        df = pd.read_csv(path)

    if len(df.columns) == 0:
        raise ValueError(f"{path} has no columns")

    # Normalised header -> original header ("Best Answer" -> "best_answer").
    cols = {"_".join(str(c).strip().lower().split()): c for c in df.columns}

    # required
    q_col = cols.get("question") or df.columns[0]

    # try common answer col names, in priority order
    a_col = None
    for candidate in ("answer", "model_answer", "response", "generated_answer", "best_answer"):
        if candidate in cols:
            a_col = cols[candidate]
            break
    if a_col is None:
        # fallback to second column
        if len(df.columns) < 2:
            raise ValueError(
                f"{path}: no answer column found and no second column to fall back to"
            )
        a_col = df.columns[1]

    # optional label col
    l_col = None
    for candidate in ("label", "is_true", "target", "truth"):
        if candidate in cols:
            l_col = cols[candidate]
            break

    out = pd.DataFrame({
        "question": df[q_col].map(clean_text),
        "answer": df[a_col].map(clean_text),
        "evidence": "",
    })

    if l_col is not None:
        raw = df[l_col].astype(str).str.strip().str.lower()
        mapping = {
            "1": 1, "true": 1, "t": 1, "yes": 1, "y": 1, "factual": 1, "correct": 1,
            "0": 0, "false": 0, "f": 0, "no": 0,  "n": 0, "hallucinated": 0, "incorrect": 0,
            # numeric label columns containing blanks are read as floats -> "1.0"/"0.0"
            "1.0": 1, "0.0": 0,
        }
        out["labels"] = raw.map(mapping)
        out = out.dropna(subset=["labels"]).copy()
        out["labels"] = out["labels"].astype(int)
    else:
        # If no labels, assume factual=1 (you can change this)
        out["labels"] = 1

    # clean_text turns missing cells into "", so a dropna() here would never
    # remove anything; filter on emptiness instead.
    out = out[(out["question"] != "") & (out["answer"] != "")]
    return out[["question","answer","evidence","labels"]]

# ---------- Simple paraphrasing (template-based) ----------
# Question wrappers: each embeds the original text through the `{q}` field.
Q_TEMPLATES = [
    "Is it true that {q}",
    "Consider the statement: {q}",
    "Please evaluate the following claim: {q}",
    "Assess the truth of this claim: {q}",
    "According to common knowledge, {q}",
    "True or false: {q}",
    "Verify this statement: {q}",
    "Does the following hold? {q}",
]
# Answer wrappers: the first entry leaves the answer unchanged.
A_TEMPLATES = [
    "{a}",
    "In short, {a}",
    "Answer: {a}",
    "The correct answer is: {a}",
    "Briefly: {a}",
]


def paraphrase(q, a, rng: random.Random):
    """Wrap a question and an answer in randomly chosen templates.

    Draws one entry from ``Q_TEMPLATES`` and then one from ``A_TEMPLATES``
    using *rng* (in that order), fills them with *q* and *a*, and returns
    the two results as a ``(question, answer)`` tuple after ``clean_text``.
    """
    question_template = rng.choice(Q_TEMPLATES)
    answer_template = rng.choice(A_TEMPLATES)
    return (
        clean_text(question_template.format(q=q)),
        clean_text(answer_template.format(a=a)),
    )

# ---------- Build FEVER 5k ----------
# Load every usable (binary-labelled) FEVER row, then draw a fixed-seed sample of 5000.
fever_df_full = load_fever_jsonl(INPUT_FEVER)
if len(fever_df_full) < 5000:
    raise ValueError(f"FEVER usable rows: {len(fever_df_full)}; need at least 5000.")
fever_5k = fever_df_full.sample(n=5000, random_state=42).reset_index(drop=True)

# ---------- Build TruthfulQA 5k (paraphrased) ----------
tqa_df = load_truthfulqa(INPUT_TQA)
# Dedicated generator for template choice, independent of the module-level `random` state.
rng = random.Random(42)
tqa_rows = []

if len(tqa_df) >= 5000:
    base = tqa_df.sample(n=5000, random_state=42).reset_index(drop=True)
else:
    # repeat rows with different paraphrases to reach 5k
    # max(1, ...) only guards the division; it does not make an empty table sampleable.
    repeats = int(np.ceil(5000 / max(1, len(tqa_df))))
    # NOTE(review): the same base row can appear several times here, so copies of one
    # question may later land in different train/val/test splits — confirm this is acceptable.
    base = pd.concat([tqa_df] * repeats, ignore_index=True).sample(n=5000, random_state=42).reset_index(drop=True)

# Paraphrase row by row; rng is consumed in row order, so the output depends on `base` order.
for _, row in base.iterrows():
    q_p, a_p = paraphrase(row["question"], row["answer"], rng)
    tqa_rows.append({"question": q_p, "answer": a_p, "evidence": "", "labels": int(row["labels"])})

tqa_5k_paraphrased = pd.DataFrame(tqa_rows)

# ---------- Save individual outputs ----------
# orient="records" writes a JSON array of row objects; force_ascii=False keeps non-ASCII text readable.
fever_5k.to_json(OUT_FEVER_5K, orient="records", force_ascii=False, indent=2)
tqa_5k_paraphrased.to_json(OUT_TQA_5K, orient="records", force_ascii=False, indent=2)

# ---------- Merge and save ----------
# Plain concatenation: FEVER rows first, TruthfulQA rows second (no shuffle at this stage).
merged = pd.concat([fever_5k, tqa_5k_paraphrased], ignore_index=True)
merged.to_json(OUT_MERGED, orient="records", force_ascii=False, indent=2)

# ---------- Summary ----------
print("Saved files:")
print(f"- FEVER 5k: {OUT_FEVER_5K}")
print(f"- TruthfulQA 5k paraphrased: {OUT_TQA_5K}")
print(f"- MERGED 10k: {OUT_MERGED}\n")

print("Label counts:")
print("  FEVER 5k:\n", fever_5k["labels"].value_counts())
print("  TQA 5k paraphrased:\n", tqa_5k_paraphrased["labels"].value_counts())
print("  Merged 10k:\n", merged["labels"].value_counts())

# OPTIONAL: also write a smaller gzip for easy downloading (keeps Drive light)
import gzip, io
gz_path = OUT_DIR / "merged_10k.json.gz"
# The gzipped copy is serialised without indent, so it is compact even before compression.
with gzip.open(gz_path, "wt", encoding="utf-8") as gz:
    gz.write(merged.to_json(orient="records", force_ascii=False))
print(f"\nAlso wrote gzipped copy: {gz_path}")


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Saved files:
- FEVER 5k: /content/drive/MyDrive/Colab Notebooks/merged_data/fever_5000.json
- TruthfulQA 5k paraphrased: /content/drive/MyDrive/Colab Notebooks/merged_data/truthfulqa_5000_paraphrased.json
- MERGED 10k: /content/drive/MyDrive/Colab Notebooks/merged_data/merged_10k.json

Label counts:
  FEVER 5k:
 labels
1    3605
0    1395
Name: count, dtype: int64
  TQA 5k paraphrased:
 labels
1    5000
Name: count, dtype: int64
  Merged 10k:
 labels
1    8605
0    1395
Name: count, dtype: int64

Also wrote gzipped copy: /content/drive/MyDrive/Colab Notebooks/merged_data/merged_10k.json.gz


In [2]:
# =========================
# Balance to 10k (5k ones, 5k zeros)
# =========================
import pandas as pd
import numpy as np
from pathlib import Path
import random, json

# Seeds the module-level `random` generator so this cell's random.shuffle calls are reproducible.
random.seed(42)

# Inputs are the two per-dataset JSON files under BASE_DIR.
BASE_DIR = Path("/content/drive/MyDrive/Colab Notebooks/merged_data")
FEVER_5K = BASE_DIR / "fever_5000.json"
TQA_5K   = BASE_DIR / "truthfulqa_5000_paraphrased.json"

# Destination of the label-balanced dataset built in this cell.
OUT_BALANCED = BASE_DIR / "balanced_10k.json"

def load_json(path: Path) -> pd.DataFrame:
    """Read the JSON file at *path* into a DataFrame with pandas' default settings."""
    frame = pd.read_json(path)
    return frame

fever = load_json(FEVER_5K)
tqa   = load_json(TQA_5K)

# Sanity: both inputs must carry the unified schema. Raised explicitly rather
# than asserted, because assert statements are skipped under `python -O`.
REQUIRED_COLUMNS = {"question", "answer", "evidence", "labels"}
for dfname, df in [("FEVER", fever), ("TQA", tqa)]:
    if not set(df.columns) >= REQUIRED_COLUMNS:
        raise ValueError(f"{dfname} bad schema: {df.columns}")

# Split FEVER by label
fever_pos = fever[fever["labels"] == 1].copy()
fever_neg = fever[fever["labels"] == 0].copy()

print("FEVER label counts:\n", fever["labels"].value_counts())
print("TQA label counts:\n", tqa["labels"].value_counts())

# === Target counts ===
TARGET_POS = 5000
TARGET_NEG = 5000

# Positives come from TQA. Filter on the label instead of assuming every TQA
# row is a 1, so a TQA file that does contain 0s cannot leak them in here.
tqa_pos = tqa[tqa["labels"] == 1].copy()
if len(tqa_pos) < TARGET_POS:
    raise ValueError("Need at least 5k positives from TQA.")

# ---- Build POSITIVE set (choose exactly 5k) ----
# Option A: use exactly the 5k TQA positives (simple & consistent)
pos_final = tqa_pos.sample(n=TARGET_POS, random_state=42).reset_index(drop=True)

# ---- Build NEGATIVE set (need 5k total) ----
neg_needed = TARGET_NEG

# 1) Use the FEVER negatives we have, capped at the target so the negative
#    set can never overshoot TARGET_NEG.
if len(fever_neg) > TARGET_NEG:
    fever_neg = fever_neg.sample(n=TARGET_NEG, random_state=42)
neg_parts = []
neg_parts.append(fever_neg)
neg_needed -= len(fever_neg)

print(f"FEVER negatives used: {len(fever_neg)}, still need: {neg_needed}")

# 2) Create synthetic TQA negatives by mismatching answers across questions
#    (pair each question with an answer that belongs to a different question; label = 0)
if neg_needed > 0:
    # Prefixes added by the paraphrasing step that produced the TQA file. The same
    # base question/answer can occur in several rows under different prefixes, so
    # pairs are compared on the text *after* the prefix; comparing row indices
    # alone would let a question be paired with a re-worded copy of its own answer.
    # NOTE(review): keep these in sync with the templates used when the file was built.
    _Q_PREFIXES = (
        "Is it true that ",
        "Consider the statement: ",
        "Please evaluate the following claim: ",
        "Assess the truth of this claim: ",
        "According to common knowledge, ",
        "True or false: ",
        "Verify this statement: ",
        "Does the following hold? ",
    )
    _A_PREFIXES = ("In short, ", "Answer: ", "The correct answer is: ", "Briefly: ")

    def _strip_template(text, prefixes):
        """Return *text* without its leading template prefix, if it starts with one."""
        for prefix in prefixes:
            if text.startswith(prefix):
                return text[len(prefix):]
        return text

    tqa_for_negs = tqa_pos.reset_index(drop=True)
    q = tqa_for_negs["question"].tolist()
    a = tqa_for_negs["answer"].tolist()

    base_q = [_strip_template(str(t), _Q_PREFIXES) for t in q]
    base_a = [_strip_template(str(t), _A_PREFIXES) for t in a]
    # Every (question, answer) combination that is a real positive.
    true_pairs = set(zip(base_q, base_a))

    # Shuffle the answer order; keep pairs that are neither a true pair nor a
    # repeat of a negative already collected. Re-shuffle until enough distinct
    # negatives exist, so a shortfall is filled with new pairs, not copies.
    MAX_SHUFFLES = 10
    idx = list(range(len(a)))
    seen = set()
    neg_rows = []
    for _ in range(MAX_SHUFFLES):
        random.shuffle(idx)
        for i, j in enumerate(idx):
            key = (base_q[i], base_a[j])
            if key in true_pairs or key in seen:
                continue
            seen.add(key)
            neg_rows.append((q[i], a[j]))
        if len(neg_rows) >= neg_needed:
            break
    else:
        raise ValueError(
            f"Could only build {len(neg_rows)} distinct mismatched pairs; need {neg_needed}."
        )

    # Build all mismatched pairs
    tqa_neg_all = pd.DataFrame(neg_rows, columns=["question", "answer"])
    tqa_neg_all["evidence"] = ""
    tqa_neg_all["labels"] = 0

    # Sample exactly the number we need
    tqa_neg_sample = tqa_neg_all.sample(n=neg_needed, random_state=42).reset_index(drop=True)
    neg_parts.append(tqa_neg_sample)

# Stack every collected negative part into one frame.
neg_final = pd.concat(neg_parts, ignore_index=True)

# ---- Assert sizes and labels ----
# These checks stop the cell before anything is written if either class is the
# wrong size or contains a row with the wrong label.
assert len(pos_final) == TARGET_POS, f"pos_final != {TARGET_POS}"
assert len(neg_final) == TARGET_NEG, f"neg_final != {TARGET_NEG}"
assert set(pos_final["labels"].unique()) == {1}
assert set(neg_final["labels"].unique()) == {0}

# ---- Merge, shuffle, save ----
balanced = pd.concat([pos_final, neg_final], ignore_index=True)
# frac=1.0 returns every row in a new, seed-determined order.
balanced = balanced.sample(frac=1.0, random_state=42).reset_index(drop=True)

print("Balanced counts:\n", balanced["labels"].value_counts())

balanced.to_json(OUT_BALANCED, orient="records", force_ascii=False, indent=2)
print(f"\nSaved balanced dataset to:\n{OUT_BALANCED}")

# Optional: also save a gzipped version for easy downloading
import gzip
gz_path = BASE_DIR / "balanced_10k.json.gz"
# Serialised again without indent, so the gzipped payload is the compact form.
with gzip.open(gz_path, "wt", encoding="utf-8") as gz:
    gz.write(balanced.to_json(orient="records", force_ascii=False))
print(f"Also wrote gzipped copy: {gz_path}")


FEVER label counts:
 labels
1    3605
0    1395
Name: count, dtype: int64
TQA label counts:
 labels
1    5000
Name: count, dtype: int64
FEVER negatives used: 1395, still need: 3605
Balanced counts:
 labels
0    5000
1    5000
Name: count, dtype: int64

Saved balanced dataset to:
/content/drive/MyDrive/Colab Notebooks/merged_data/balanced_10k.json
Also wrote gzipped copy: /content/drive/MyDrive/Colab Notebooks/merged_data/balanced_10k.json.gz


In [3]:
# =========================
# Split balanced_10k.json into 70%/15%/15%
# =========================
import pandas as pd
from sklearn.model_selection import train_test_split
from pathlib import Path
import gzip

# Input: the balanced dataset file; outputs: three split files in the same folder.
BASE_DIR = Path("/content/drive/MyDrive/Colab Notebooks/merged_data")
BALANCED_PATH = BASE_DIR / "balanced_10k.json"

OUT_TRAIN = BASE_DIR / "train_70.json"
OUT_VAL   = BASE_DIR / "val_15.json"
OUT_TEST  = BASE_DIR / "test_15.json"

# Load balanced dataset
df = pd.read_json(BALANCED_PATH)

print("Full dataset counts:\n", df["labels"].value_counts())

# First split: train vs temp (val+test)
# Stratifying on the label keeps the class ratio the same in both parts.
# NOTE(review): rows are split independently; if the input holds several paraphrases
# of one base question they can end up in different splits — confirm that is acceptable.
train_df, temp_df = train_test_split(
    df,
    test_size=0.30,  # 30% goes to val+test
    stratify=df["labels"],
    random_state=42
)

# Second split: val vs test from the temp set
val_df, test_df = train_test_split(
    temp_df,
    test_size=0.50,  # half of temp = 15% of original
    stratify=temp_df["labels"],
    random_state=42
)

# Sanity checks
print("\nTrain counts:\n", train_df["labels"].value_counts())
print("Val counts:\n", val_df["labels"].value_counts())
print("Test counts:\n", test_df["labels"].value_counts())

# Save JSON
train_df.to_json(OUT_TRAIN, orient="records", force_ascii=False, indent=2)
val_df.to_json(OUT_VAL, orient="records", force_ascii=False, indent=2)
test_df.to_json(OUT_TEST, orient="records", force_ascii=False, indent=2)

# Also save gzipped versions for quicker download
for path in [OUT_TRAIN, OUT_VAL, OUT_TEST]:
    # "train_70.json" -> "train_70.json.gz": append to the suffix rather than replace it.
    gz_path = path.with_suffix(path.suffix + ".gz")
    # The gzip holds the exact text just written to `path` (indented form).
    with gzip.open(gz_path, "wt", encoding="utf-8") as gz:
        gz.write(path.read_text(encoding="utf-8"))
    print(f"Gzipped copy: {gz_path}")

print("\nSaved splits to Drive:")
print(f"- Train: {OUT_TRAIN} ({len(train_df)} rows)")
print(f"- Val:   {OUT_VAL} ({len(val_df)} rows)")
print(f"- Test:  {OUT_TEST} ({len(test_df)} rows)")


Full dataset counts:
 labels
0    5000
1    5000
Name: count, dtype: int64

Train counts:
 labels
1    3500
0    3500
Name: count, dtype: int64
Val counts:
 labels
1    750
0    750
Name: count, dtype: int64
Test counts:
 labels
0    750
1    750
Name: count, dtype: int64
Gzipped copy: /content/drive/MyDrive/Colab Notebooks/merged_data/train_70.json.gz
Gzipped copy: /content/drive/MyDrive/Colab Notebooks/merged_data/val_15.json.gz
Gzipped copy: /content/drive/MyDrive/Colab Notebooks/merged_data/test_15.json.gz

Saved splits to Drive:
- Train: /content/drive/MyDrive/Colab Notebooks/merged_data/train_70.json (7000 rows)
- Val:   /content/drive/MyDrive/Colab Notebooks/merged_data/val_15.json (1500 rows)
- Test:  /content/drive/MyDrive/Colab Notebooks/merged_data/test_15.json (1500 rows)


In [5]:
# ============================================
# Colab: Fetch FEVER + TruthfulQA and prepare dataset
# ============================================
!pip install -q datasets pandas numpy openpyxl wikipedia-api nltk tqdm

import json, random, re
from pathlib import Path
from collections import defaultdict

import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from datasets import load_dataset # Removed load_from_disk

import nltk
nltk.download('punkt', quiet=True)
from nltk.tokenize import sent_tokenize

import wikipediaapi

# Seeds the module-level `random` generator for the random.* calls in this cell.
random.seed(42)

# ---------- CONFIG ----------
# Separate output folder from the earlier cells ("merged_data_new").
OUT_DIR = Path("/content/drive/MyDrive/Colab Notebooks/merged_data_new")
OUT_DIR.mkdir(parents=True, exist_ok=True)

# Number of rows requested from each source dataset.
N_FEVER = 5000
N_TQA   = 5000
# Whether FEVER rows without a SUPPORTS/REFUTES label are discarded.
DROP_NEI = True
USE_WIKI = True           # Set False to skip slow Wikipedia lookups
LANG = "en"               # Wikipedia language

# Final split ratios
TRAIN_PCT, VAL_PCT, TEST_PCT = 0.70, 0.15, 0.15

# ---------- Helpers ----------
def clean_text(x):
    """Return *x* as a stripped string with every whitespace run collapsed to one space.

    ``None`` and float NaN both become the empty string, so a missing value
    never turns into the literal text ``"nan"``. Any other value is converted
    with ``str`` first.
    """
    # NaN is the only float that is not equal to itself; this avoids needing
    # an extra import just for the missing-value test.
    if x is None or (isinstance(x, float) and x != x):
        return ""
    return re.sub(r"\s+", " ", str(x).strip())

# Question wrappers: each embeds the original text through the `{q}` field.
Q_TEMPLATES = [
    "Is it true that {q}",
    "Consider the statement: {q}",
    "Please evaluate the following claim: {q}",
    "Assess the truth of this claim: {q}",
    "According to common knowledge, {q}",
    "True or false: {q}",
    "Verify this statement: {q}",
    "Does the following hold? {q}",
]
# Answer wrappers: the first entry leaves the answer unchanged.
A_TEMPLATES = [
    "{a}",
    "In short, {a}",
    "Answer: {a}",
    "The correct answer is: {a}",
    "Briefly: {a}",
]


def paraphrase(q, a, rng):
    """Wrap *q* and *a* in templates drawn from *rng* and return the cleaned pair.

    One question template is drawn first, then one answer template; the
    result is a ``(question, answer)`` tuple passed through ``clean_text``.
    """
    question_template = rng.choice(Q_TEMPLATES)
    answer_template = rng.choice(A_TEMPLATES)
    wrapped_question = question_template.format(q=q)
    wrapped_answer = answer_template.format(a=a)
    return clean_text(wrapped_question), clean_text(wrapped_answer)

# ---------- Load FEVER ----------
# Load FEVER from Hugging Face Hub directly
# The FEVER dataset card on the Hub lists named configurations ("v1.0", "v2.0",
# "wiki_pages"); the claim/label training data is in "v1.0", so name it explicitly.
# NOTE(review): whether this resolves depends on the installed `datasets` version
# (newer releases changed how script-based datasets are loaded) — confirm on Colab.
try:
    fever = load_dataset("fever", "v1.0", split="train")
except Exception as e:
    print(f"Error loading FEVER from Hugging Face Hub: {e}")
    # If the dataset cannot be fetched this way, download the FEVER train.jsonl
    # manually (or use a different source) and load it from Drive instead.
    raise  # bare raise keeps the original traceback intact


# Fields typical: 'id','label','claim','evidence','verifiable'

def map_fever_label(lbl):
    """Translate a FEVER verdict into a binary label.

    Returns 1 for ``SUPPORTS`` and 0 for ``REFUTES`` (compared after ``str``,
    ``strip`` and upper-casing); anything else, including NEI, gives ``None``.
    """
    verdicts = {"SUPPORTS": 1, "REFUTES": 0}
    return verdicts.get(str(lbl).strip().upper())  # NEI/unknown -> None

# Build a small index of FEVER rows by label (drop NEI if configured)
# Collect one row per FEVER example that carries a binary verdict.
fever_rows = []
skipped_unlabelled = 0
for ex in fever:
    y = map_fever_label(ex["label"])
    if y is None:
        # The label column is cast to int below, so a row without a
        # SUPPORTS/REFUTES verdict cannot be kept whatever DROP_NEI says
        # (keeping it used to crash the astype(int) call).
        skipped_unlabelled += 1
        continue

    evidence = ex.get("evidence", None)  # nested annotations (page, sentence id)
    # NOTE(review): the Hub "v1.0" configuration appears to expose flat columns
    # (evidence_annotation_id / evidence_wiki_url / evidence_sentence_id) instead of
    # a nested "evidence" list — confirm against the dataset card. When they are
    # present, wrap them as [[[annotation_id, page, sentence_id]]].
    if evidence is None and ex.get("evidence_wiki_url"):
        evidence = [[[
            ex.get("evidence_annotation_id"),
            ex["evidence_wiki_url"],
            ex.get("evidence_sentence_id"),
        ]]]

    fever_rows.append({
        "claim": clean_text(ex["claim"]),
        "evidence_field": evidence,
        "label": y
    })

if skipped_unlabelled and not DROP_NEI:
    print(f"Note: {skipped_unlabelled} FEVER rows without a binary label were skipped even though DROP_NEI is False.")

# Explicit columns keep the schema when no row qualified.
fever_df = pd.DataFrame(fever_rows, columns=["claim", "evidence_field", "label"])
fever_df = fever_df.dropna(subset=["label"]).copy()
# A claim can occur in more than one source row; keep the first occurrence so
# the same claim cannot be sampled twice (and later land in two different splits).
fever_df = fever_df.drop_duplicates(subset=["claim"], keep="first").copy()
fever_df["label"] = fever_df["label"].astype(int)

# Balance SUPPORTS/REFUTES if possible: take up to half of N_FEVER from each class.
counts = fever_df["label"].value_counts().to_dict()
half = N_FEVER // 2
n0 = min(counts.get(0, 0), half)
n1 = min(counts.get(1, 0), half)

fever_0 = fever_df[fever_df["label"] == 0].sample(n=n0, random_state=42)
fever_1 = fever_df[fever_df["label"] == 1].sample(n=n1, random_state=42)

# Keep fever_df's own index labels here (no ignore_index): they are what
# identifies the rows already taken when the top-up pool is built below.
picked = pd.concat([fever_0, fever_1])

remainder = N_FEVER - len(picked)
if remainder > 0:
    # top up with examples that were not sampled yet (i.e. from the larger class)
    pool = fever_df.drop(picked.index)
    n_top_up = min(remainder, len(pool))
    if n_top_up < remainder:
        print(f"Warning: only {len(picked) + n_top_up} usable FEVER rows available; requested {N_FEVER}.")
    picked = pd.concat([picked, pool.sample(n=n_top_up, random_state=42)])

fever_sample = picked.sample(frac=1.0, random_state=42).reset_index(drop=True)

# ---------- Wikipedia evidence extraction (best-effort) ----------
def extract_evidence_sentences(fever_row_list, lang="en"):
    """
    For each FEVER row, try to pull one evidence sentence from Wikipedia based on the first
    usable evidence annotation (page title, sentence index). Best effort only.

    Args:
        fever_row_list: list of dicts, each with an "evidence_field" entry holding the
            nested annotation list (or None).
        lang: Wikipedia language code.

    Returns:
        A list of cleaned evidence strings, one per input row, in input order. A row
        gets "" when it has no usable annotation, the page has no text, or the fetch fails.

    Notes:
        * An annotation is located by shape (a page-title string directly followed by an
          integer sentence id), so both [annotation_id, evidence_id, page, sentence_id]
          and [annotation_id, page, sentence_id] layouts are handled.
        * NOTE(review): FEVER sentence ids presumably refer to the sentence split of the
          FEVER Wikipedia dump, not to NLTK's split of the live page, so the selected
          sentence is an approximation — confirm this is acceptable.
        * NOTE(review): sent_tokenize needs the NLTK tokenizer data to be downloaded
          beforehand; newer NLTK releases may need 'punkt_tab' in addition to 'punkt'.
    """
    def _first_annotation(ev):
        """Return (page_title, sentence_index) of the first usable annotation, else (None, None)."""
        if not isinstance(ev, list):
            return None, None
        for group in ev:
            if not isinstance(group, list):
                continue
            for annot in group:
                if not isinstance(annot, list):
                    continue
                for page, sent in zip(annot, annot[1:]):
                    if (isinstance(page, str) and page
                            and isinstance(sent, int) and not isinstance(sent, bool)):
                        return page, sent
        return None, None

    # Current wikipedia-api releases require a descriptive user agent as the first
    # argument; passing only the language code is rejected.
    # Replace the string with one that identifies your project / contact.
    wiki = wikipediaapi.Wikipedia(
        user_agent="fever-tqa-dataset-prep/0.1 (Colab notebook)",
        language=lang,
    )

    sentence_cache = {}  # cleaned page title -> list of sentences (tokenised once per page)
    failed_pages = 0
    out_evidence = []
    for row in tqdm(fever_row_list, desc="Fetching Wikipedia evidence"):
        ev_text = ""
        page_title, sent_idx = _first_annotation(row["evidence_field"])

        if page_title:
            # NOTE(review): FEVER page titles presumably escape brackets/colons as
            # -LRB- / -RRB- / -COLON-; undo that before the lookup — confirm.
            page_title_clean = (
                page_title.replace("-LRB-", "(")
                .replace("-RRB-", ")")
                .replace("-COLON-", ":")
                .replace("_", " ")
            )
            if page_title_clean not in sentence_cache:
                try:
                    page_text = wiki.page(page_title_clean).text or ""
                except Exception:
                    # Best effort: one failed request must not abort the whole run;
                    # the row simply gets no evidence.
                    failed_pages += 1
                    page_text = ""
                sentence_cache[page_title_clean] = sent_tokenize(page_text) if page_text else []

            sents = sentence_cache[page_title_clean]
            if sents:
                if 0 <= sent_idx < len(sents):
                    ev_text = sents[sent_idx]
                else:
                    # fallback: first of the leading 20 sentences mentioning the first
                    # word of the title, else the page's first sentence
                    key = page_title_clean.split(" ")[0].lower()
                    ev_text = next((s for s in sents[:20] if key in s.lower()), sents[0])

        out_evidence.append(clean_text(ev_text))

    if failed_pages:
        print(f"Wikipedia lookups that failed and were left without evidence: {failed_pages}")

    return out_evidence

# Attach one evidence string per sampled row; to_dict("records") preserves row order,
# so the returned list lines up with fever_sample's rows.
if USE_WIKI:
    fever_sample["evidence"] = extract_evidence_sentences(fever_sample.to_dict("records"), lang=LANG)
else:
    fever_sample["evidence"] = ""

# Build FEVER unified schema: treat claim as both question & answer (statement verification)
fever_ready = pd.DataFrame({
    "question": fever_sample["claim"].apply(clean_text),
    "answer":   fever_sample["claim"].apply(clean_text),
    "evidence": fever_sample["evidence"].apply(clean_text),
    "labels":   fever_sample["label"].astype(int)
})

# Shape plus per-label row counts as a quick visual check.
print("FEVER ready shape:", fever_ready.shape, fever_ready["labels"].value_counts().to_dict())

# ---------- Load TruthfulQA (generation split) ----------
tqa = load_dataset("truthful_qa", "generation", split="validation")  # 'validation' has the standard set
# fields include: question, best_answer, correct_answers (list), incorrect_answers (list)

# Use every reference answer, not only best_answer plus one random incorrect one.
# Two rows per question cannot supply the N_TQA rows requested further down.
tqa_rows = []
for ex in tqa:
    q = clean_text(ex["question"])
    if not q:
        continue
    best = clean_text(ex.get("best_answer", ""))
    correct_list = [clean_text(x) for x in (ex.get("correct_answers") or []) if clean_text(x)]
    incorrect_list = [clean_text(x) for x in (ex.get("incorrect_answers") or []) if clean_text(x)]

    # positives: best answer first, then the other correct answers (order-preserving dedupe)
    positives = list(dict.fromkeys(([best] if best else []) + correct_list))
    # negatives: each distinct incorrect answer that is not also listed as correct
    negatives = [ans for ans in dict.fromkeys(incorrect_list) if ans not in positives]

    tqa_rows.extend({"question": q, "answer": ans, "labels": 1} for ans in positives)
    tqa_rows.extend({"question": q, "answer": ans, "labels": 0} for ans in negatives)

tqa_df = pd.DataFrame(tqa_rows, columns=["question", "answer", "labels"]).dropna()
# Paraphrase to diversify (lightweight)
rng = random.Random(42)
q_new, a_new, y = [], [], []
for question, answer, label in zip(tqa_df["question"], tqa_df["answer"], tqa_df["labels"]):
    q_p, a_p = paraphrase(question, answer, rng)
    q_new.append(q_p)
    a_new.append(a_p)
    y.append(int(label))
tqa_df = pd.DataFrame({"question": q_new, "answer": a_new, "labels": y})
tqa_df["evidence"] = ""  # no passages available here

# Sample up to N_TQA rows (balanced-ish if possible)
# Try to take half positives, half negatives
pos = tqa_df[tqa_df["labels"] == 1]
neg = tqa_df[tqa_df["labels"] == 0]
n_each = min(len(pos), len(neg), N_TQA // 2)

pos_pick = pos.sample(n=n_each, random_state=42)
neg_pick = neg.sample(n=n_each, random_state=42)
parts = [pos_pick, neg_pick]

shortfall = N_TQA - 2 * n_each
if shortfall > 0:
    # Fill the gap only with rows that were not picked yet, and never ask for more
    # than exist: sampling without replacement beyond the population raises ValueError.
    leftovers = tqa_df.drop(pos_pick.index.union(neg_pick.index))
    n_extra = min(shortfall, len(leftovers))
    if n_extra > 0:
        parts.append(leftovers.sample(n=n_extra, random_state=42))
    if n_extra < shortfall:
        print(f"Warning: only {2 * n_each + n_extra} TruthfulQA rows available; requested {N_TQA}.")

tqa_sample = pd.concat(parts, ignore_index=True)

tqa_sample = tqa_sample.sample(frac=1.0, random_state=42).reset_index(drop=True)
tqa_ready = tqa_sample[["question","answer","evidence","labels"]].copy()
print("TruthfulQA ready shape:", tqa_ready.shape, tqa_ready["labels"].value_counts().to_dict())

# ---------- Merge and save master 10k ----------
# Concatenate both sources and shuffle once with a fixed seed.
merged = pd.concat([fever_ready, tqa_ready], ignore_index=True)
merged = merged.sample(frac=1.0, random_state=42).reset_index(drop=True)

MASTER_PATH = OUT_DIR / "merged_10k_fever_tqa.json"
merged.to_json(MASTER_PATH, orient="records", force_ascii=False, indent=2)
print(f"Saved merged 10k → {MASTER_PATH}")

# ---------- Rebalance merged to 50/50 exactly ----------
pos_df = merged[merged["labels"] == 1]
neg_df = merged[merged["labels"] == 0]
pos_ct = len(pos_df)
neg_ct = len(neg_df)
print("Merged label counts:", {"1": int(pos_ct), "0": int(neg_ct)})

# Down-sample the larger class to the size of the smaller one. Requiring each
# class to hold at least half of all rows can only succeed when the data is
# already balanced, so that rule never rebalanced anything.
TARGET_TOTAL = len(merged)
TARGET_PER = min(pos_ct, neg_ct)
if TARGET_PER > 0:
    merged_bal = pd.concat([
        pos_df.sample(n=TARGET_PER, random_state=42),
        neg_df.sample(n=TARGET_PER, random_state=42)
    ], ignore_index=True).sample(frac=1.0, random_state=42).reset_index(drop=True)
else:
    # One class is missing entirely: nothing to balance against.
    print("Warning: one label class is empty; saving the merged data unbalanced.")
    merged_bal = merged

BALANCED_PATH = OUT_DIR / "merged_10k_balanced.json"
merged_bal.to_json(BALANCED_PATH, orient="records", force_ascii=False, indent=2)
print(f"Saved balanced 10k → {BALANCED_PATH}")
print("Balanced label counts:", merged_bal["labels"].value_counts().to_dict())

# ---------- 70/15/15 stratified split ----------
from sklearn.model_selection import train_test_split

if abs(TRAIN_PCT + VAL_PCT + TEST_PCT - 1.0) > 1e-9:
    raise ValueError("TRAIN_PCT + VAL_PCT + TEST_PCT must sum to 1.0")

df = merged_bal.copy()

# Pass integer row counts instead of float fractions. `1.0 - 0.70` evaluates to
# 0.30000000000000004, and a float test_size is rounded up, so on 10,000 rows
# the hold-out part would be 3001 rows rather than 3000.
holdout_frac = VAL_PCT + TEST_PCT
n_holdout = int(round(len(df) * holdout_frac))
train_df, temp_df = train_test_split(
    df, test_size=n_holdout, stratify=df["labels"], random_state=42
)

# Divide the hold-out part between validation and test in the VAL_PCT : TEST_PCT ratio.
n_test = int(round(len(temp_df) * TEST_PCT / holdout_frac))
val_df, test_df = train_test_split(
    temp_df, test_size=n_test,
    stratify=temp_df["labels"], random_state=42
)

TRAIN_JSON = OUT_DIR / "train_70.json"
VAL_JSON   = OUT_DIR / "val_15.json"
TEST_JSON  = OUT_DIR / "test_15.json"

train_df.to_json(TRAIN_JSON, orient="records", force_ascii=False, indent=2)
val_df.to_json(VAL_JSON, orient="records", force_ascii=False, indent=2)
test_df.to_json(TEST_JSON, orient="records", force_ascii=False, indent=2)

print("\nSaved splits:")
print(f"- Train ({len(train_df)}): {TRAIN_JSON}")
print(f"- Val   ({len(val_df)}): {VAL_JSON}")
print(f"- Test  ({len(test_df)}): {TEST_JSON}")

FileNotFoundError: Unable to find 'hf://datasets/fever/train.parquet'