# DPR Corpus with BM25 Hard Negatives

This notebook creates a custom corpus for MRAG implementation by:
1. Building a covered slice from TempRAGEval gold evidence sentences
2. Augmenting it with BM25-mined hard negative passages
3. Producing a balanced corpus with ~95% negatives for training/evaluation


## 1. Setup and Installation

Install required packages for corpus creation.


In [None]:
!pip -q install ir_datasets transformers datasets faiss-cpu pandas pyarrow tqdm


  Preparing metadata (setup.py) ... [?25l[?25hdone
  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m866.1/866.1 kB[0m [31m19.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m31.4/31.4 MB[0m [31m84.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m149.0/149.0 kB[0m [31m14.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m45.1/45.1 kB[0m [31m4.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.4/1.4 MB[0m [31m74.2 MB/s[0m eta [36m0:00:00[0m
[?25h  Building wheel for warc3-wet-clueweb09 (setup.py) ... [?25l[?25hdone
  Building wheel for cbor (setup.py) ... [?25l[?25hdone


In [None]:
!pip -q install pyahocorasick

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/114.9 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m114.9/114.9 kB[0m [31m7.5 MB/s[0m eta [36m0:00:00[0m
[?25h

## 2. Configuration and Imports

Set up configuration parameters and import necessary libraries.


In [None]:
# Configuration and imports
import os, json, math, re, itertools, random
from pathlib import Path
import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
from tqdm import tqdm

import torch
from transformers import AutoTokenizer, AutoModel
import faiss
from datasets import load_dataset

# Configuration parameters
SEED                = 42
DEVICE              = "cuda" if torch.cuda.is_available() else "cpu"

# Corpus size settings
N_PASSAGES_TOTAL    = 200_000      # Total passages to process
SHARD_ROWS          = 20_000       # Rows per shard file
BATCH_ENCODE        = 256
MAX_LEN             = 256
USE_COSINE          = False        # Set True to L2-normalize vectors

# IVF index parameters
IVF_NLIST           = 32768        # Number of clusters for IVF
IVF_TRAIN_EMB       = 50_000       # Training vectors for IVF
IVF_NPROBE          = min(64, max(1, IVF_NLIST // 512))

# Output directory
OUT_DIR       = "dpr_ivf_atlas_covered_slice"
INDEX_PATH    = os.path.join(OUT_DIR, "ivf.index")
MANIFEST_PATH = os.path.join(OUT_DIR, "manifest.json")
os.makedirs(OUT_DIR, exist_ok=True)

# DPR encoder models
Q_MODEL = "facebook/dpr-question_encoder-single-nq-base"
P_MODEL = "facebook/dpr-ctx_encoder-single-nq-base"

# ATLAS corpus paths
CORPUS_DIR = Path("./atlas_data/corpora/wiki/enwiki-dec2021")
JSONL_FILES = [
    CORPUS_DIR / "text-list-100-sec.jsonl",
    CORPUS_DIR / "infobox.jsonl",
]
ATLAS_URLS = {
    "text-list-100-sec.jsonl": "https://dl.fbaipublicfiles.com/atlas/corpora/wiki/enwiki-dec2021/text-list-100-sec.jsonl",
    "infobox.jsonl":           "https://dl.fbaipublicfiles.com/atlas/corpora/wiki/enwiki-dec2021/infobox.jsonl",
}
AUTO_DOWNLOAD_ATLAS = True  # Automatically download if files are missing

# Evaluation settings
TOPK_20 = 20
TOPK_100 = 100
REQUIRE_COVERAGE = False  # Filter questions to those covered by index
COVERAGE_PROBE_K = 1000   # DPR probe depth for coverage check

np.random.seed(SEED)
torch.manual_seed(SEED)
print("Device:", DEVICE)

Device: cuda


## 3. Utility Functions

Helper functions for corpus processing and shard management.


In [None]:
# =========================== #
# Utilities
# =========================== #
def ensure_atlas_jsonl():
    CORPUS_DIR.mkdir(parents=True, exist_ok=True)
    missing = [p.name for p in JSONL_FILES if not p.exists()]
    if not missing:
        return
    if not AUTO_DOWNLOAD_ATLAS:
        raise FileNotFoundError(
            "Missing ATLAS JSONLs:\n  - " + "\n  - ".join(str(p) for p in JSONL_FILES) +
            "\nDownload them before running."
        )
    print("Downloading ATLAS enwiki-Dec-2021 JSONLs...")
    for name in missing:
        url = ATLAS_URLS[name]
        dest = CORPUS_DIR / name
        exit_code = os.system(f'wget -q "{url}" -O "{dest}"')
        if exit_code != 0 or not dest.exists():
            raise RuntimeError(f"Failed to download {name} from {url}")
    print("Done.")

def _norm(s: str) -> str:
    s = re.sub(r"[^a-z0-9 ]+", " ", (s or "").lower())
    return re.sub(r"\s+", " ", s).strip()

def iter_atlas_passages(paths, limit=None):
    """Yield {'internal_id', 'title', 'text'} from ATLAS JSONLs."""
    i = 0
    for p in paths:
        with open(p, "r", encoding="utf-8") as f:
            for line in f:
                if limit is not None and i >= limit:
                    return
                obj = json.loads(line)
                title = (obj.get("title") or "").strip()
                text  = (obj.get("text")  or "").strip()
                if not text:
                    continue
                yield {"internal_id": i, "title": title, "text": text}
                i += 1

def write_shard(rows, shard_idx):
    df = pd.DataFrame(rows)
    shard_path = Path(OUT_DIR) / f"passages_shard_{shard_idx:03d}.parquet"
    table = pa.Table.from_pandas(df)
    pq.write_table(table, shard_path)
    return shard_path, int(df["internal_id"].min()), int(df["internal_id"].max())

def build_shards_and_manifest():
    rows, manifest, shard_idx = [], {"shards": []}, 0
    for rec in tqdm(iter_atlas_passages(JSONL_FILES, limit=N_PASSAGES_TOTAL),
                    total=N_PASSAGES_TOTAL, desc="Sharding passages"):
        rows.append(rec)
        if len(rows) >= SHARD_ROWS:
            p, lo, hi = write_shard(rows, shard_idx)
            manifest["shards"].append({"path": str(p), "lo": lo, "hi": hi})
            rows, shard_idx = [], shard_idx + 1
    if rows:
        p, lo, hi = write_shard(rows, shard_idx)
        manifest["shards"].append({"path": str(p), "lo": lo, "hi": hi})

    with open(MANIFEST_PATH, "w") as f:
        json.dump(manifest, f, indent=2)
    return manifest

def load_manifest():
    with open(MANIFEST_PATH, "r") as f:
        return json.load(f)

### 3.1 IVF Parameter Suggestions

Helper function to suggest appropriate IVF parameters based on corpus size.


In [None]:
import math

def corpus_size_from_manifest(manifest):
    # count rows using [lo, hi] inclusive ranges
    return sum(int(s["hi"]) - int(s["lo"]) + 1 for s in manifest["shards"])

def suggest_ivf_params(corpus_n, train_request):
    """
    Returns (nlist, ntrain, use_flat) following FAISS guidelines:
      - nlist ≈ 4 * sqrt(N) (bounded by training vectors)
      - need >= nlist training vecs; ideally ~30–256 per centroid
      - use Flat for small N
    """
    # For small corpora, Flat is simpler & strong
    if corpus_n < 50_000:
        return (None, None, True)  # use_flat

    # pick a target nlist around 4*sqrt(N)
    nlist_target = max(64, int(4 * math.sqrt(corpus_n)))  # guideline :contentReference[oaicite:2]{index=2}

    # how many vectors will we actually use to train?
    ntrain = min(train_request, corpus_n)

    # don't exceed the hard constraint: ntrain >= nlist
    nlist_max_by_hard = max(1, ntrain)  # strict FAISS requirement :contentReference[oaicite:3]{index=3}

    # also respect the rule-of-thumb ~>= 39 train vecs / centroid for IVF k-means stability
    nlist_max_by_rule = max(1, ntrain // 39)  # ~39 per centroid (FAISS clusterer heuristic) :contentReference[oaicite:4]{index=4}

    nlist = min(nlist_target, nlist_max_by_hard, nlist_max_by_rule)

    # If that collapses too far, either reduce to Flat or accept a small nlist
    if nlist < 64:
        return (None, None, True)  # Flat for very small N

    # ensure at least ~39*nlist training vecs when possible
    ntrain = min(corpus_n, max(ntrain, 39 * nlist))
    return (nlist, ntrain, False)


## 4. Download ATLAS Corpus

Download the ATLAS enwiki-Dec-2021 corpus files if they don't exist.


In [6]:
# === PREP: Download ATLAS Dec-2021 JSONLs if missing ===
from pathlib import Path
import os, sys, urllib.request, shutil, subprocess

BASE = Path("atlas_data/corpora/wiki/enwiki-dec2021")
BASE.mkdir(parents=True, exist_ok=True)

URLS = {
    "text-list-100-sec.jsonl": "https://dl.fbaipublicfiles.com/atlas/corpora/wiki/enwiki-dec2021/text-list-100-sec.jsonl",
    "infobox.jsonl":           "https://dl.fbaipublicfiles.com/atlas/corpora/wiki/enwiki-dec2021/infobox.jsonl",
}

def _wget(url, dest):
    try:
        # use wget if available (faster in Colab)
        subprocess.check_call(["wget", "-q", url, "-O", str(dest)])
        return True
    except Exception:
        return False

def _py_download(url, dest):
    with urllib.request.urlopen(url) as r, open(dest, "wb") as f:
        shutil.copyfileobj(r, f)

for name, url in URLS.items():
    dest = BASE / name
    if not dest.exists():
        print(f"Downloading {name} ...")
        ok = _wget(url, dest)
        if not ok:
            _py_download(url, dest)
        assert dest.exists(), f"Failed to download {name}"

print("ATLAS files present:")
for p in BASE.iterdir():
    print(" -", p, f"({p.stat().st_size/1e6:.1f} MB)")


Downloading text-list-100-sec.jsonl ...
Downloading infobox.jsonl ...
ATLAS files present:
 - atlas_data/corpora/wiki/enwiki-dec2021/text-list-100-sec.jsonl (20900.7 MB)
 - atlas_data/corpora/wiki/enwiki-dec2021/infobox.jsonl (2300.4 MB)


In [7]:
# === Use existing covered slice (SKIP building) ===
from pathlib import Path

# Put your downloaded files here, or change SLICE_DIR to wherever you uploaded them.
SLICE_DIR = Path("./atlas_covered_slice")
SLICE_DIR.mkdir(parents=True, exist_ok=True)

# If your two files are elsewhere, move/copy them or set these paths directly.
SLICE_JSONL_FILES = [
    str(SLICE_DIR / "text-list-100-sec.jsonl"),
    str(SLICE_DIR / "infobox.jsonl"),
]


In [8]:
# Sanity: make sure both exist and are non-empty
for p in SLICE_JSONL_FILES:
    pth = Path(p)
    assert pth.exists() and pth.stat().st_size > 0, f"Missing or empty: {p}"

# Point the rest of the pipeline at this slice
JSONL_FILES = SLICE_JSONL_FILES
print("Using slice files:", JSONL_FILES)

Using slice files: ['atlas_covered_slice/text-list-100-sec.jsonl', 'atlas_covered_slice/infobox.jsonl']


## 5. Build Covered Slice from TempRAGEval

Create a slice of the corpus that contains all passages from pages that match gold evidence sentences from TempRAGEval.


In [None]:
# === Build a covered slice from TempRAGEval gold sentences ===
# This scans the ATLAS JSONLs once, finds any passage that contains a gold-evidence
# sentence (using Aho–Corasick), collects their page titles, and then writes a
# slice that contains *all passages from the matched pages*.

from datasets import load_dataset
import ahocorasick, re, json
from pathlib import Path

# Inputs: reuse your existing CORPUS_DIR and JSONL_FILES from earlier cells
ATLAS_FILES = [str(p) for p in JSONL_FILES]  # original full files
SLICE_DIR = Path("./atlas_covered_slice")
SLICE_DIR.mkdir(parents=True, exist_ok=True)
SLICE_FILES = [SLICE_DIR / "text-list-100-sec.jsonl", SLICE_DIR / "infobox.jsonl"]

# Heuristics
MIN_CHARS_PATTERN = 20   # ignore super short “sentences” that cause false positives
INCLUDE_FULL_PAGES = True  # include all passages from matched pages (recommended)

def _norm(s: str) -> str:
    s = re.sub(r"[^a-z0-9 ]+", " ", (s or "").lower())
    return re.sub(r"\s+", " ", s).strip()

def _golds_from_temprageval():
    ds = load_dataset("siyue/TempRAGEval")["test"]
    # Keep everything here (don’t drop empty time_relation during slice creation)
    golds = []
    for ex in ds:
        for k in ("gold_evidence_1", "gold_evidence_2"):
            if k in ex and ex[k]:
                golds.append(ex[k])
    # normalize + length filter + dedupe
    golds = list({_norm(s) for s in golds if s})
    golds = [g for g in golds if len(g) >= MIN_CHARS_PATTERN]
    return golds

def build_automaton(patterns):
    A = ahocorasick.Automaton()
    for i, p in enumerate(patterns):
        if p:
            A.add_word(p, (i, p))
    A.make_automaton()
    return A

def recover_titles_and_write_slice():
    # 1) Collect normalized gold sentences
    patterns = _golds_from_temprageval()
    print(f"Gold evidence patterns (>= {MIN_CHARS_PATTERN} chars):", len(patterns))
    if not patterns:
        raise RuntimeError("No gold evidence sentences found. Check dataset access / columns.")

    # 2) Build Aho–Corasick automaton
    A = build_automaton(patterns)

    # 3) Pass 1: stream ATLAS files, record titles that contain a gold sentence
    matched_titles = set()
    total_lines = 0
    matched_lines = 0

    for inpath in ATLAS_FILES:
        with open(inpath, "r", encoding="utf-8") as fin:
            for line in fin:
                total_lines += 1
                obj = json.loads(line)
                text = _norm(obj.get("text") or "")
                if not text:
                    continue
                # Cheap scan: stop at first match
                for _ in A.iter(text):
                    matched_titles.add(obj.get("title") or "")
                    matched_lines += 1
                    break

    print(f"Matched pages (unique titles): {len(matched_titles)}")
    print(f"Lines with direct sentence match: {matched_lines} / scanned lines: {total_lines}")

    # 4) Pass 2: write the slice (all passages from matched titles)
    kept = 0
    for src, dst in zip(ATLAS_FILES, SLICE_FILES):
        with open(src, "r", encoding="utf-8") as fin, open(dst, "w", encoding="utf-8") as fout:
            for line in fin:
                obj = json.loads(line)
                if (obj.get("title") or "") in matched_titles:
                    fout.write(line)
                    kept += 1

    print(f"Wrote covered slice to {SLICE_DIR} with {kept} passages total.")
    return SLICE_FILES

# Run the slice builder once (idempotent: overwrite files if they exist)
SLICE_JSONL_FILES = recover_titles_and_write_slice()

# IMPORTANT: point the rest of the pipeline to the slice:
JSONL_FILES = SLICE_JSONL_FILES  # overrides the earlier JSONL_FILES
print("Using slice files:", JSONL_FILES)


Gold evidence patterns (>= 20 chars): 355
Matched pages (unique titles): 340
Lines with direct sentence match: 367 / scanned lines: 37507469
Wrote covered slice to atlas_covered_slice with 10997 passages total.
Using slice files: [PosixPath('atlas_covered_slice/text-list-100-sec.jsonl'), PosixPath('atlas_covered_slice/infobox.jsonl')]


## 6. Install Additional Dependencies

Install packages needed for BM25 negative mining.


In [10]:
!pip install -q "cachetools<6.0" "google-auth<3.0" --upgrade
# (e.g., cachetools==5.5.2 is fine)

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/223.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m223.1/223.1 kB[0m [31m9.7 MB/s[0m eta [36m0:00:00[0m
[?25h[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
google-colab 1.0.0 requires google-auth==2.38.0, but you have google-auth 2.43.0 which is incompatible.
google-auth-oauthlib 1.2.3 requires google-auth<2.42.0,>=2.15.0, but you have google-auth 2.43.0 which is incompatible.[0m[31m
[0m

In [9]:
!pip -q install pyserini

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m178.8/178.8 MB[0m [31m13.8 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m367.5/367.5 kB[0m [31m30.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m17.4/17.4 MB[0m [31m121.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.6/1.6 MB[0m [31m57.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m184.3/184.3 kB[0m [31m16.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m96.4/96.4 kB[0m [31m9.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m69.2/69.2 kB[0m [31m7.1 MB/s[0m eta [36

### 6.1 Install Java for Pyserini

Pyserini requires Java to run. Install JDK and configure JAVA_HOME.


In [11]:
%%bash
# --- Install a JDK and wire JAVA_HOME ---
apt-get update -y >/dev/null
# Prefer 21 (per Pyserini docs); fall back to 17 if 21 isn't available on the image
apt-get install -y openjdk-21-jdk-headless || apt-get install -y openjdk-17-jdk-headless
python - <<'PY'
import os, shutil
candidates = [
    "/usr/lib/jvm/java-21-openjdk-amd64",
    "/usr/lib/jvm/java-17-openjdk-amd64",
    "/usr/lib/jvm/java-11-openjdk-amd64",
]
java_home = next((p for p in candidates if os.path.exists(p)), None)
if not java_home:
    raise SystemExit("No JDK folder found after install.")
os.environ["JAVA_HOME"] = java_home
os.environ["PATH"] = f"{java_home}/bin:" + os.environ["PATH"]
print("JAVA_HOME =", os.environ["JAVA_HOME"])

# Persist for this process and set update-alternatives so /usr/bin/java(javac) points here
os.system(f"update-alternatives --set java {java_home}/bin/java >/dev/null 2>&1 || true")
os.system(f"update-alternatives --set javac {java_home}/bin/javac >/dev/null 2>&1 || true")

# Sanity print
os.system("java -version")
os.system("javac -version")
PY


Reading package lists...
Building dependency tree...
Reading state information...
The following additional packages will be installed:
  ca-certificates-java java-common libpcsclite1 openjdk-21-jre-headless
Suggested packages:
  default-jre pcscd openjdk-21-demo openjdk-21-source libnss-mdns
  fonts-dejavu-extra fonts-ipafont-gothic fonts-ipafont-mincho
  fonts-wqy-microhei | fonts-wqy-zenhei fonts-indic
The following NEW packages will be installed:
  ca-certificates-java java-common libpcsclite1 openjdk-21-jdk-headless
  openjdk-21-jre-headless
0 upgraded, 5 newly installed, 0 to remove and 46 not upgraded.
Need to get 130 MB of archives.
After this operation, 299 MB of additional disk space will be used.
Get:1 http://archive.ubuntu.com/ubuntu jammy/main amd64 java-common all 0.72build2 [6,782 B]
Get:2 http://archive.ubuntu.com/ubuntu jammy-updates/main amd64 libpcsclite1 amd64 1.9.5-3ubuntu1 [19.8 kB]
Get:3 http://archive.ubuntu.com/ubuntu jammy-updates/universe amd64 openjdk-21-jre-

W: Skipping acquire of configured file 'main/source/Sources' as repository 'https://r2u.stat.illinois.edu/ubuntu jammy InRelease' does not seem to provide it (sources.list entry misspelt?)
openjdk version "21.0.8" 2025-07-15
OpenJDK Runtime Environment (build 21.0.8+9-Ubuntu-0ubuntu122.04.1)
OpenJDK 64-Bit Server VM (build 21.0.8+9-Ubuntu-0ubuntu122.04.1, mixed mode, sharing)


## 7. BM25 Hard Negative Mining

Mine hard negative passages using BM25 search. For each query, retrieve top passages and filter out those containing gold evidence or answers.


In [None]:
# === Augment covered slice with BM25 HARD negatives (~95% negatives) ===
# Requires: pip install pyserini
# Inputs:
#   SLICE_DIR: ./atlas_covered_slice/*.jsonl  (your positives)
#   ATLAS_FILES: the two ATLAS JSONLs (full dump)
# Output:
#   ./atlas_slice_with_neg/*.jsonl — covered pages + BM25-mined negative pages

import os, json, re, math, shutil, subprocess
from pathlib import Path
from collections import OrderedDict, defaultdict
from datasets import load_dataset
from tqdm import tqdm

from pyserini.search.lucene import LuceneSearcher

# ---- CONFIG ----
SLICE_DIR   = Path("./atlas_covered_slice")
ATLAS_FILES = [
    "atlas_data/corpora/wiki/enwiki-dec2021/text-list-100-sec.jsonl",
    "atlas_data/corpora/wiki/enwiki-dec2021/infobox.jsonl",
]

OUT_NEG_DIR = Path("./atlas_slice_with_neg"); OUT_NEG_DIR.mkdir(parents=True, exist_ok=True)

# BM25 index paths
BM25_CORPUS_DIR = Path("./bm25_collection")  # holds docs.jsonl
BM25_INDEX_DIR  = Path("./bm25_index")

# BM25 params + mining knobs
BM25_K1, BM25_B = 0.9, 0.4
BM25_TOPK_PER_QUERY = 200        # how many hits to examine per query
TEMP_ONLY = False                 # set True if you want only time_relation!=None queries
TARGET_NEG_FRAC = 0.95            # aim for 95% negatives (by PASSAGE count)
LOWERCASE_COMPARE = True          # normalize case for substring checks

# ---------------------------------------------------------------

def norm(s):
    s = (s or "").strip()
    return re.sub(r"\s+", " ", s).lower() if LOWERCASE_COMPARE else s

# 1) Collect positive titles + count POSITIVE passages (to set the 95% target by passage count)
pos_titles = set()
pos_passages_total = 0
for name in ["text-list-100-sec.jsonl", "infobox.jsonl"]:
    p = SLICE_DIR / name
    if p.exists():
        with p.open("r", encoding="utf-8") as fin:
            for line in fin:
                obj = json.loads(line)
                t = obj.get("title") or ""
                if t: pos_titles.add(t)
                pos_passages_total += 1

print(f"[Positives] unique titles: {len(pos_titles)} | passages: {pos_passages_total}")

# 2) Prepare BM25 collection (a single JSONL with fields id/contents/title), then index with Pyserini
#    We keep 'contents' as: "<title>\n<text>" so titles influence BM25.  (Pyserini JsonCollection)
#    Ref: Pyserini indexing/search docs.  (Lucene BM25)  [citations]
BM25_CORPUS_DIR.mkdir(exist_ok=True, parents=True)
docs_jsonl = BM25_CORPUS_DIR / "docs.jsonl"

if not docs_jsonl.exists():
    print("[BM25] Building docs.jsonl from ATLAS (this runs once)...")
    with docs_jsonl.open("w", encoding="utf-8") as fout:
        for src in ATLAS_FILES:
            with open(src, "r", encoding="utf-8") as fin:
                for i, line in enumerate(fin):
                    obj = json.loads(line)
                    title = obj.get("title") or ""
                    text  = obj.get("text") or obj.get("contents") or ""
                    doc = {
                        "id": f"{Path(src).name}::{i}",
                        "title": title,
                        "contents": f"{title}\n{text}"
                    }
                    fout.write(json.dumps(doc, ensure_ascii=False) + "\n")
    print(f"[BM25] Wrote {docs_jsonl}")

if not BM25_INDEX_DIR.exists() or not any(BM25_INDEX_DIR.iterdir()):
    print("[BM25] Indexing with Pyserini (JsonCollection -> Lucene index)...")
    # Equivalent to: python -m pyserini.index.lucene -collection JsonCollection -generator DefaultLuceneDocumentGenerator
    #                -input bm25_collection -index bm25_index -storePositions -storeDocvectors -storeRaw
    subprocess.run([
        "python", "-m", "pyserini.index.lucene",
        "-collection", "JsonCollection",
        "-generator", "DefaultLuceneDocumentGenerator",
        "-threads", str(os.cpu_count() or 2),
        "-input", str(BM25_CORPUS_DIR),
        "-index", str(BM25_INDEX_DIR),
        "-storePositions", "-storeDocvectors", "-storeRaw"
    ], check=True)
    print("[BM25] Index built at", BM25_INDEX_DIR)

# 3) Mine BM25 hard negatives (per query), exclude positives and any doc containing gold evidence / answer strings
searcher = LuceneSearcher(str(BM25_INDEX_DIR))
searcher.set_bm25(k1=BM25_K1, b=BM25_B)

ds = load_dataset("siyue/TempRAGEval")["test"]
if TEMP_ONLY and "time_relation" in ds.column_names:
    ds = ds.filter(lambda ex: bool(ex.get("time_relation","")))

# strings we consider "positive" signals inside a passage
def gold_strings(ex):
    out = []
    for k in ["gold_evidence_1", "gold_evidence_2", "answer"]:
        v = ex.get(k)
        if isinstance(v, str) and v.strip():
            out.append(norm(v))
    # also original_answer if present
    v = ex.get("original_answer")
    if isinstance(v, str) and v.strip():
        out.append(norm(v))
    return [s for s in out if s]

neg_titles = OrderedDict()   # preserve order added
seen_docids = set()

target_neg_passages = math.ceil(pos_passages_total * TARGET_NEG_FRAC / (1 - TARGET_NEG_FRAC))  # ~ 19x
print(f"[Target] negative passages (approx): {target_neg_passages}")

pbar = tqdm(range(len(ds)), desc="BM25 mining")
for i in pbar:
    ex = ds[i]
    q  = (ex.get("question") or "").strip()
    if not q:
        continue

    hits = searcher.search(q, BM25_TOPK_PER_QUERY)
    gstrs = gold_strings(ex)

    for h in hits:
        if h.docid in seen_docids:
            continue
        seen_docids.add(h.docid)
        raw = searcher.doc(h.docid).raw()  # the original JSON line we indexed
        try:
            jobj = json.loads(raw)
        except Exception:
            # if raw isn't JSON (unlikely with our pipeline), fall back to contents
            jobj = {"title": "", "contents": searcher.doc(h.docid).contents()}

        title = (jobj.get("title") or "").strip()
        contents = (jobj.get("contents") or "").strip()

        # Skip if page is one of the positive titles
        if title in pos_titles:
            continue

        # Skip if gold evidence/answer appears in contents (to avoid false negatives)
        X = norm(contents)
        if any(gs and gs in X for gs in gstrs):
            continue

        # Keep the title as a hard negative candidate
        if title:
            neg_titles.setdefault(title, 1)

    # Early stop if we already have many titles; final passage count is checked at write time
    if len(neg_titles) >= 5 * target_neg_passages:  # a generous buffer; we'll trim on write
        break

print(f"[Mining] candidate negative titles: {len(neg_titles)}")

# 4) Write augmented slice: all covered positives + passages from BM25-mined negative titles
#    We grow negatives until we exceed the target_neg_passages.
def write_augmented(neg_titles_ordered):
    kept_pos, kept_neg = 0, 0
    # open outputs
    out_map = {
        "text-list-100-sec.jsonl": (OUT_NEG_DIR / "text-list-100-sec.jsonl").open("w", encoding="utf-8"),
        "infobox.jsonl": (OUT_NEG_DIR / "infobox.jsonl").open("w", encoding="utf-8"),
    }
    try:
        # write all positives first (exact copy from covered slice)
        for name in ["text-list-100-sec.jsonl", "infobox.jsonl"]:
            covered_src = SLICE_DIR / name
            if covered_src.exists():
                with covered_src.open("r", encoding="utf-8") as fin:
                    for line in fin:
                        out_map[name].write(line)
                        kept_pos += 1

        # now append BM25 negatives
        need = target_neg_passages
        titles_set = set(neg_titles_ordered.keys())
        for src, dst_name in zip(ATLAS_FILES, ["text-list-100-sec.jsonl", "infobox.jsonl"]):
            with open(src, "r", encoding="utf-8") as fin:
                for line in fin:
                    obj = json.loads(line)
                    t = (obj.get("title") or "").strip()
                    if t in titles_set:
                        out_map[dst_name].write(line)
                        kept_neg += 1
                        if kept_neg >= need:
                            # we reached target ~95% by passage-count (may overshoot slightly)
                            # still keep streaming remaining file to avoid partial JSON issues? Not needed for line-delimited.
                            # break both loops cleanly:
                            raise StopIteration
    except StopIteration:
        pass
    finally:
        for f in out_map.values():
            f.close()
    return kept_pos, kept_neg

kept_pos, kept_neg = write_augmented(neg_titles)
total = kept_pos + kept_neg
neg_frac = kept_neg / total if total else 0.0

print(f"[Write] positives kept: {kept_pos}")
print(f"[Write] negatives kept: {kept_neg}")
print(f"[Write] total passages: {total}")
print(f"[Write] negative fraction: {neg_frac:.3f}  (target={TARGET_NEG_FRAC})")

# IMPORTANT: point your indexer at this augmented slice for the rest of the run
JSONL_FILES = [str(OUT_NEG_DIR / "text-list-100-sec.jsonl"),
               str(OUT_NEG_DIR / "infobox.jsonl")]
print("Using augmented slice files:", JSONL_FILES)


[Positives] unique titles: 340 | passages: 10997
[BM25] Building docs.jsonl from ATLAS (this runs once)...
[BM25] Wrote bm25_collection/docs.jsonl
[BM25] Indexing with Pyserini (JsonCollection -> Lucene index)...
[BM25] Index built at bm25_index


README.md:   0%|          | 0.00/2.23k [00:00<?, ?B/s]

test.csv:   0%|          | 0.00/470k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/1244 [00:00<?, ? examples/s]

[Target] negative passages (approx): 208943


BM25 mining: 100%|██████████| 1244/1244 [03:37<00:00,  5.73it/s]


[Mining] candidate negative titles: 87541
[Write] positives kept: 10997
[Write] negatives kept: 208943
[Write] total passages: 219940
[Write] negative fraction: 0.950  (target=0.95)
Using augmented slice files: ['atlas_slice_with_neg/text-list-100-sec.jsonl', 'atlas_slice_with_neg/infobox.jsonl']


## 8. Prepare Output Directory

Set up output directory for the final corpus index.


In [None]:
# === Normalize JSONL_FILES and reset outputs ===
from pathlib import Path
import os, shutil

# Ensure JSONL_FILES are Path objects (some helpers call .exists())
JSONL_FILES = [Path(p) for p in JSONL_FILES]
print("JSONL_FILES ->", JSONL_FILES)

# Fresh output dir to avoid reusing an old index/manifest
OUT_DIR = "dpr_flat_slice_neg"      # <— new folder on purpose
INDEX_PATH = os.path.join(OUT_DIR, "ivf.index")
MANIFEST_PATH = os.path.join(OUT_DIR, "manifest.json")
shutil.rmtree(OUT_DIR, ignore_errors=True)
os.makedirs(OUT_DIR, exist_ok=True)

print("OUT_DIR reset ->", OUT_DIR)

JSONL_FILES -> [PosixPath('atlas_slice_with_neg/text-list-100-sec.jsonl'), PosixPath('atlas_slice_with_neg/infobox.jsonl')]
OUT_DIR reset -> dpr_flat_slice_neg


## 9. Verify Corpus Composition

Count positive vs negative passages in the augmented slice to verify the target ratio.


In [None]:
# === Count positives vs negatives inside the augmented slice (exact) ===
# Logic: positives are pages whose titles come from the covered slice; the rest are negatives.

import json, re
from pathlib import Path
from collections import Counter

def norm_title(t):
    if t is None: return ""
    t = re.sub(r"\s+", " ", str(t)).strip()
    return t

root = Path(".").resolve()

# Find the covered slice and augmented slice JSONLs
covered_files = sorted(root.glob("**/atlas_covered_slice/*.jsonl"))
aug_files     = sorted(root.glob("**/atlas_slice_with_neg/*.jsonl"))

print("Covered files:", [str(p) for p in covered_files])
print("Augmented files:", [str(p) for p in aug_files])

# 1) Build the set of positive titles from the covered slice
pos_titles = set()
covered_rows = 0
for fp in covered_files:
    with fp.open("r", encoding="utf-8") as f:
        for line in f:
            obj = json.loads(line)
            covered_rows += 1
            t = norm_title(obj.get("title"))
            if t: pos_titles.add(t)

print(f"Unique positive titles (covered): {len(pos_titles)} | covered rows: {covered_rows}")

# 2) Classify every augmented-row as pos/neg by title membership
per_file = []
total_rows = pos_rows = neg_rows = 0
for fp in aug_files:
    t_total = t_pos = t_neg = 0
    with fp.open("r", encoding="utf-8") as f:
        for line in f:
            obj = json.loads(line)
            t_total += 1
            t = norm_title(obj.get("title"))
            if t and t in pos_titles:
                t_pos += 1
            else:
                t_neg += 1
    per_file.append({
        "file": str(fp),
        "rows": t_total,
        "pos_rows": t_pos,
        "neg_rows": t_neg,
        "pos_pct": round(100.0 * t_pos / t_total, 2) if t_total else 0.0,
        "neg_pct": round(100.0 * t_neg / t_total, 2) if t_total else 0.0,
    })
    total_rows += t_total
    pos_rows   += t_pos
    neg_rows   += t_neg

print("\n--- Augmented slice composition (per file) ---")
for r in per_file:
    print(r)

print("\n--- Overall augmented slice composition ---")
print("Total rows:", total_rows)
print("Pos rows  :", pos_rows)
print("Neg rows  :", neg_rows)
print("Pos %     :", round(100.0 * pos_rows / total_rows, 2) if total_rows else 0.0)
print("Neg %     :", round(100.0 * neg_rows / total_rows, 2) if total_rows else 0.0)

# 3) (Optional) If you want the *indexed* passage count again:
#    ntotal is authoritative and already printed when you built the index:
#    print("FAISS index size (ntotal):", index.ntotal)


Covered files: ['/content/atlas_covered_slice/infobox.jsonl', '/content/atlas_covered_slice/text-list-100-sec.jsonl']
Augmented files: ['/content/atlas_slice_with_neg/infobox.jsonl', '/content/atlas_slice_with_neg/text-list-100-sec.jsonl']
Unique positive titles (covered): 340 | covered rows: 10997

--- Augmented slice composition (per file) ---
{'file': '/content/atlas_slice_with_neg/infobox.jsonl', 'rows': 273, 'pos_rows': 273, 'neg_rows': 0, 'pos_pct': 100.0, 'neg_pct': 0.0}
{'file': '/content/atlas_slice_with_neg/text-list-100-sec.jsonl', 'rows': 219667, 'pos_rows': 10724, 'neg_rows': 208943, 'pos_pct': 4.88, 'neg_pct': 95.12}

--- Overall augmented slice composition ---
Total rows: 219940
Pos rows  : 10997
Neg rows  : 208943
Pos %     : 5.0
Neg %     : 95.0
