# EfficientRAG: Inference & Evaluation

This is the final notebook in the EfficientRAG series. In this phase, we use the models trained in the previous step (Filter and Labeler) to perform end-to-end inference on the HotpotQA dataset.

**Key Objectives:**
1.  **Setup:** Install dependencies and download the trained model checkpoints.
2.  **Indexing:** Build a dense retrieval index (FAISS) for the corpus using Contriever.
3.  **Inference Pipeline:** Run the full RAG pipeline:
    * **Retrieve:** Fetch relevant documents.
    * **Filter:** Remove irrelevant passages.
    * **Label/Extract:** Isolate key tokens.
    * **Generate:** Produce the final answer using a generator (e.g., Llama-3 or GPT).
4.  **Evaluation:** Measure performance using standard metrics (Exact Match, F1, etc.).
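
The four inference stages listed above can be pictured as a chain of functions. What follows is a minimal sketch with toy stand-ins: `run_pipeline` and the four `toy_*` callables are hypothetical names chosen for illustration, not the repository's API, and the real components are neural models rather than string heuristics.

```python
from typing import Callable, List


def run_pipeline(
    question: str,
    retrieve: Callable[[str], List[str]],
    keep: Callable[[str, str], bool],
    extract: Callable[[str, str], str],
    generate: Callable[[str, List[str]], str],
) -> str:
    """Chain the stages: retrieve -> filter -> label/extract -> generate."""
    passages = retrieve(question)
    relevant = [p for p in passages if keep(question, p)]
    evidence = [extract(question, p) for p in relevant]
    return generate(question, evidence)


# Toy stand-ins so the sketch runs without any model (all hypothetical).
corpus = [
    "Paris is the capital of France.",
    "The Nile is a river in Africa.",
]


def toy_retrieve(question: str) -> List[str]:
    return corpus


def toy_keep(question: str, passage: str) -> bool:
    # Keep passages that share at least one longer word with the question.
    q_words = {w.strip("?.").lower() for w in question.split() if len(w) > 4}
    p_words = {w.strip("?.").lower() for w in passage.split()}
    return bool(q_words & p_words)


def toy_extract(question: str, passage: str) -> str:
    # Pretend the "key tokens" are just the first word of the passage.
    return passage.split()[0]


def toy_generate(question: str, evidence: List[str]) -> str:
    return evidence[0] if evidence else "unknown"


answer = run_pipeline(
    "What is the capital of France?",
    toy_retrieve, toy_keep, toy_extract, toy_generate,
)
print(answer)
```

Passing the stages in as callables keeps the control flow visible and lets each stage be swapped for a real model later.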

## 1. Environment Setup

We start by cloning the repository and installing its requirements, which include `faiss-cpu` for similarity search and `vllm` for efficient LLM inference. A follow-up cell then upgrades `accelerate` and `peft` past the versions pinned in `requirements.txt`.

In [1]:
! git clone https://github.com/nil-zhuang/efficientrag-official.git

Cloning into 'efficientrag-official'...
remote: Enumerating objects: 122, done.
remote: Counting objects: 100% (122/122), done.
remote: Compressing objects: 100% (111/111), done.
remote: Total 122 (delta 13), reused 108 (delta 7), pack-reused 0 (from 0)
Receiving objects: 100% (122/122), 1.81 MiB | 9.75 MiB/s, done.
Resolving deltas: 100% (13/13), done.


In [2]:
%cd efficientrag-official
! pip install -r requirements.txt

/content/efficientrag-official
Collecting faiss-cpu (from -r requirements.txt (line 1))
  Downloading faiss_cpu-1.13.0-cp39-abi3-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (7.7 kB)
Collecting msal (from -r requirements.txt (line 2))
  Downloading msal-1.34.0-py3-none-any.whl.metadata (11 kB)
Collecting accelerate==0.29.1 (from -r requirements.txt (line 14))
  Downloading accelerate-0.29.1-py3-none-any.whl.metadata (18 kB)
Collecting deepspeed>=0.14.1 (from -r requirements.txt (line 16))
  Downloading deepspeed-0.18.2.tar.gz (1.6 MB)
  Preparing metadata (setup.py) ... done
Collecting vllm (from -r requirements.txt (line 17))
  Downloading vllm-0.11.2-cp38-abi3-manylinux1_x86_64.whl.metadata (18 kB)
Collecting black (from -r requirements.txt (line 19))
  Downloading black-25.11.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.many

In [3]:
! pip install -U "accelerate>=1.1.0" "peft>=0.16.0"

Collecting accelerate>=1.1.0
  Downloading accelerate-1.12.0-py3-none-any.whl.metadata (19 kB)
Downloading accelerate-1.12.0-py3-none-any.whl (380 kB)
Installing collected packages: accelerate
  Attempting uninstall: accelerate
    Found existing installation: accelerate 0.29.1
    Uninstalling accelerate-0.29.1:
      Successfully uninstalled accelerate-0.29.1
Successfully installed accelerate-1.12.0


In [1]:
%cd efficientrag-official

/content/efficientrag-official


In [2]:
import os, json, random
from pathlib import Path

# Project root (adjust if your %cd differs)
ROOT = Path.cwd()
DATA = ROOT / "data"
DATASET_DIR = DATA / "dataset" / "hotpotQA"
MODEL_CACHE = ROOT / "model_cache"

for p in [DATASET_DIR, MODEL_CACHE]:
    p.mkdir(parents=True, exist_ok=True)

# Make HF/Transformers use our local cache folder
os.environ["HF_HOME"] = str(MODEL_CACHE)
os.environ["TRANSFORMERS_CACHE"] = str(MODEL_CACHE)
os.environ["HF_DATASETS_CACHE"] = str(MODEL_CACHE)

## 2. Download Models & Data

We need to retrieve three sets of artifacts:
1.  **Base Models:** Contriever (for retrieval) and the Generator (e.g., Llama-3).
2.  **Trained Checkpoints:** The Filter and Labeler models we trained in the previous notebook.
3.  **Corpus Data:** The Wikipedia corpus used for retrieval.
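
Before running the later cells it can help to confirm that these artifacts ended up where the notebook expects them. The sketch below assumes the relative paths used elsewhere in this notebook; `EXPECTED` and `missing_artifacts` are hypothetical helper names, and the generator is left out because it is served by vLLM rather than read from a local folder.

```python
import tempfile
from pathlib import Path
from typing import Dict, List

# Paths used elsewhere in this notebook, relative to the repository root.
EXPECTED: Dict[str, str] = {
    "retriever": "model_cache/contriever-msmarco",
    "encoder": "model_cache/deberta-v3-large",
    "labeler": "trained_models/labeler",
    "filter": "trained_models/filter",
    "corpus": "data/corpus/hotpotQA/corpus.jsonl",
}


def missing_artifacts(root: Path, expected: Dict[str, str] = EXPECTED) -> List[str]:
    """Return the names of artifacts whose path does not exist under root."""
    return [name for name, rel in expected.items() if not (Path(root) / rel).exists()]


# Demo on a throwaway directory where only the labeler folder exists.
with tempfile.TemporaryDirectory() as tmp:
    demo_root = Path(tmp)
    (demo_root / "trained_models" / "labeler").mkdir(parents=True)
    still_missing = missing_artifacts(demo_root)

print(still_missing)
```

In the notebook itself, `missing_artifacts(Path.cwd())` should return an empty list once the download, extraction, and corpus cells have all run.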

In [3]:
import json
from pathlib import Path
from datasets import load_dataset

# Paths
DATASET_DIR = Path("data/dataset/hotpotQA")
DATASET_DIR.mkdir(parents=True, exist_ok=True)
train_path = DATASET_DIR / "train.json"
dev_path   = DATASET_DIR / "valid.json"
test_path  = DATASET_DIR / "test.json"

# Load HF dataset
ds = load_dataset("hotpotqa/hotpot_qa", "distractor")

# Deterministic shuffle of the full splits (uncomment .select(...) to keep a tiny subset)
SEED = 42
train_small = ds["train"].shuffle(seed=SEED)#.select(range(1))
val_small   = ds["validation"].shuffle(seed=SEED)

# Split val_small into dev_small and test_small (approximately half each)
val_size = len(val_small)
dev_size = val_size // 2

dev_small   = val_small.select(range(dev_size))
test_small  = val_small.select(range(dev_size, val_size))

# Save EXACT records (preserve all original fields), as JSON arrays
with train_path.open("w", encoding="utf-8") as f:
    json.dump([ex for ex in train_small], f, ensure_ascii=False)

with dev_path.open("w", encoding="utf-8") as f:
    json.dump([ex for ex in dev_small], f, ensure_ascii=False)

with test_path.open("w", encoding="utf-8") as f:
    json.dump([ex for ex in test_small], f, ensure_ascii=False)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Generating train split:   0%|          | 0/90447 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/7405 [00:00<?, ? examples/s]
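
The split logic in the cell above reduces to a seeded shuffle followed by integer division. Here is a small sketch of that arithmetic on plain index lists. `shuffled_indices` and `split_in_half` are hypothetical helpers, and Python's `random` module stands in for `Dataset.shuffle`, so the actual ordering differs from the one Hugging Face produces.

```python
import random
from typing import List, Tuple


def shuffled_indices(n_items: int, seed: int) -> List[int]:
    """Stand-in for a seeded dataset shuffle: same seed, same order."""
    indices = list(range(n_items))
    random.Random(seed).shuffle(indices)
    return indices


def split_in_half(indices: List[int]) -> Tuple[List[int], List[int]]:
    """First half becomes dev, the remainder becomes test."""
    dev_size = len(indices) // 2
    return indices[:dev_size], indices[dev_size:]


# 7405 is the validation split size reported in the log above.
order = shuffled_indices(7405, seed=42)
dev_idx, test_idx = split_in_half(order)
print(len(dev_idx), len(test_idx))
```

Because 7405 is odd, the test half receives the extra example. The two halves never overlap, since both are slices of one shuffled ordering.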

In [None]:
import os
token = os.environ.get("HF_TOKEN")  # or paste your own Hugging Face access token here

if token:
    # Non-interactive login (recommended: set the env var in Colab “Secrets” or in a cell)
    from huggingface_hub import login
    login(token=token, add_to_git_credential=False)
else:
    # Fallback: interactive login widget
    from huggingface_hub import notebook_login
    notebook_login()


In [5]:
import os
import gdown
import zipfile

# Create destination directories
os.makedirs("trained_models/labeler", exist_ok=True)
os.makedirs("trained_models/filter", exist_ok=True)

labeler_zip = "trained_models/labeler.zip"
filter_zip  = "trained_models/filter.zip"

# === Download from Drive ===
print("Downloading checkpoints...")
gdown.download(id="1uZtDV6cBMv7S4sLeXHGXASlN1F8OGt3v", output=labeler_zip, quiet=False)
gdown.download(id="1-jxDEQXbYxLpTghuEXYGvpgcqx9FktqK", output=filter_zip, quiet=False)

# === Extract ===
print("Extracting labeler...")
with zipfile.ZipFile(labeler_zip, 'r') as zip_ref:
    zip_ref.extractall("trained_models/labeler")

print("Extracting filter...")
with zipfile.ZipFile(filter_zip, 'r') as zip_ref:
    zip_ref.extractall("trained_models/filter")

# === Check that they extracted ===
print("Labeler files:", os.listdir("trained_models/labeler"))
print("Filter files:", os.listdir("trained_models/filter"))

Downloading checkpoints...


Downloading...
From (original): https://drive.google.com/uc?id=1uZtDV6cBMv7S4sLeXHGXASlN1F8OGt3v
From (redirected): https://drive.google.com/uc?id=1uZtDV6cBMv7S4sLeXHGXASlN1F8OGt3v&confirm=t&uuid=790c517c-9176-4efd-84a9-06a9d6d30177
To: /content/efficientrag-official/trained_models/labeler.zip
100%|██████████| 1.57G/1.57G [00:16<00:00, 94.2MB/s]
Downloading...
From (original): https://drive.google.com/uc?id=1-jxDEQXbYxLpTghuEXYGvpgcqx9FktqK
From (redirected): https://drive.google.com/uc?id=1-jxDEQXbYxLpTghuEXYGvpgcqx9FktqK&confirm=t&uuid=9d439940-b292-4bd7-8a3e-d8527ed93c62
To: /content/efficientrag-official/trained_models/filter.zip
100%|██████████| 1.54G/1.54G [00:16<00:00, 92.0MB/s]


Extracting labeler...
Extracting filter...
Labeler files: ['added_tokens.json', 'spm.model', 'tokenizer_config.json', 'model.safetensors', 'special_tokens_map.json', 'config.json', 'trainer_state.json', 'training_args.bin']
Filter files: ['added_tokens.json', 'spm.model', 'tokenizer_config.json', 'model.safetensors', 'special_tokens_map.json', 'config.json', 'trainer_state.json', 'training_args.bin']


In [6]:
import torch
from transformers import AutoTokenizer, AutoModel

# Contriever retriever (MS MARCO variant)
ctr_name = "facebook/contriever-msmarco"
ctr_tok = AutoTokenizer.from_pretrained(ctr_name, cache_dir=str(MODEL_CACHE))
ctr = AutoModel.from_pretrained(ctr_name, cache_dir=str(MODEL_CACHE))

# DeBERTa-v3-large encoder (used by EfficientRAG Labeler/Filter)
deb_name = "microsoft/deberta-v3-large"
deb = AutoModel.from_pretrained(deb_name, torch_dtype="auto", cache_dir=str(MODEL_CACHE))

print("Models cached under:", MODEL_CACHE)



`torch_dtype` is deprecated! Use `dtype` instead!


Models cached under: /content/efficientrag-official/model_cache
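
The cell above only downloads and caches the encoders. To turn Contriever's per-token outputs into one vector per passage, the model card describes mean pooling over the non-padding tokens. The sketch below shows that pooling step on plain Python lists so it runs without `torch`. It assumes the repository's embedder follows the same recipe, which this notebook does not show, and `mean_pool` is a hypothetical name.

```python
from typing import List


def mean_pool(token_embeddings: List[List[float]], attention_mask: List[int]) -> List[float]:
    """Average the token vectors whose mask entry is 1, ignoring padding."""
    dim = len(token_embeddings[0])
    total = [0.0] * dim
    for vector, keep in zip(token_embeddings, attention_mask):
        if keep:
            for j, value in enumerate(vector):
                total[j] += value
    count = sum(attention_mask)
    return [t / count for t in total]


tokens = [[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]  # last row is padding
mask = [1, 1, 0]
pooled = mean_pool(tokens, mask)
print(pooled)
```

The padding row is excluded from both the sum and the count, which is the point of applying the mask.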


In [7]:
from huggingface_hub import snapshot_download

# 2) Make the folders expected by the repo
import os, pathlib
pathlib.Path("model_cache/contriever-msmarco").mkdir(parents=True, exist_ok=True)
pathlib.Path("model_cache/deberta-v3-large").mkdir(parents=True, exist_ok=True)

# 3) Download model snapshots there (no symlinks to avoid path surprises)
snapshot_download(
    repo_id="facebook/contriever-msmarco",
    local_dir="model_cache/contriever-msmarco",
    local_dir_use_symlinks=False
)
snapshot_download(
    repo_id="microsoft/deberta-v3-large",
    local_dir="model_cache/deberta-v3-large",
    local_dir_use_symlinks=False
)

# 4) (Optional but reduces warnings) Prefer HF_HOME over TRANSFORMERS_CACHE
import os
os.environ["HF_HOME"] = os.path.abspath("model_cache")
os.environ["TRANSFORMERS_CACHE"] = os.environ["HF_HOME"]  # still set for older code

For more details, check out https://huggingface.co/docs/huggingface_hub/main/en/guides/download#download-files-to-local-folder.



## 3. Codebase Adaptation

To make the inference scripts run in this environment, we patch the codebase to:
* Create missing output directories before results are written.
* Register the local `llama-8B` generator in `MODEL_DICT` (`src/conf/config.py`).
* Ensure compatibility with the specific version of `transformers` installed.
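
Notebook cells tend to get re-run, and a plain string replacement applied twice can insert the same line twice. A common guard is a marker comment that is checked before patching. The sketch below applies that idea to a directory-creation patch of the kind used in this section. `MARKER` and `patch_once` are hypothetical names, and the sample source string is made up for the demo.

```python
import re

MARKER = "# patched: ensure output directory exists"


def patch_once(source: str) -> str:
    """Insert a makedirs call before the open(...) line, unless already patched."""
    if MARKER in source:
        return source
    return re.sub(
        r'(?m)^(?P<indent>[ \t]*)with open\(output_dir, "w\+"\) as f:',
        lambda m: (
            f"{m.group('indent')}{MARKER}\n"
            f"{m.group('indent')}os.makedirs(os.path.dirname(output_dir), exist_ok=True)\n"
            f"{m.group(0)}"
        ),
        source,
        count=1,
    )


original = 'def save(output_dir):\n    with open(output_dir, "w+") as f:\n        f.write("x")\n'
once = patch_once(original)
twice = patch_once(once)  # second run is a no-op thanks to the marker
print(once)
```

Using a function as the replacement avoids backslash-escaping issues in the inserted text and preserves the indentation of the matched line.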

In [8]:
import pathlib, re

p = pathlib.Path("src/efficient_rag/model/model.py")
code = p.read_text()

# 1) Remove StableDropout from the v2 import list (robust to spacing/newlines)
code2 = re.sub(
    r"(?m)^(from\s+transformers\.models\.deberta_v2\.modeling_deberta_v2\s+import\s*\([\s\S]*?\))",
    lambda m: re.sub(r"(?m)^[ \t]*StableDropout,\s*\n", "", m.group(0)),
    code,
    count=1,
)

# 2) Insert the fallback try/except block right after the v2 import,
#    but only if we haven't already added it.
try_block = """\
try:
    # Older layouts where v2 exported it
    from transformers.models.deberta_v2.modeling_deberta_v2 import StableDropout  # noqa: F401
except Exception:
    try:
        # Common/current location
        from transformers.models.deberta.modeling_deberta import StableDropout  # noqa: F401
    except Exception:
        # Last resort: use vanilla dropout (slightly different behavior but unblocks you)
        from torch.nn import Dropout as StableDropout  # type: ignore
"""

if "from transformers.models.deberta.modeling_deberta import StableDropout" not in code2 and "Dropout as StableDropout" not in code2:
    code2 = re.sub(
        r"(?m)^(from\s+transformers\.models\.deberta_v2\.modeling_deberta_v2\s+import\s*\([\s\S]*?\))\s*\n",
        r"\1\n" + try_block + "\n",
        code2,
        count=1,
    )

p.write_text(code2)

3793

In [9]:
import pathlib, re
p = pathlib.Path("src/conf/config.py")
s = p.read_text(encoding="utf-8")

s = re.sub(
    r'MODEL_DICT\s*=\s*\{[^}]+\}',
    '''MODEL_DICT = {
    "gpt35": "gpt-35-turbo-1106",
    "gpt4": "gpt-4-0125-preview",
    "llama": "llama-3-70b-gptq-int4",
    "llama-8B": "meta-llama/Meta-Llama-3-8B-Instruct",
    "deepseek": "deepseek-chat",
}''',
    s,
    flags=re.S,
)
p.write_text(s, encoding="utf-8")

1988

In [10]:
# Apply a tiny patch in-place
import re, pathlib

p = pathlib.Path("src/retrievers/multihop_data_extrator.py")
code = p.read_text()
code = code.replace(
    'with open(output_dir, "w+") as f:',
    'os.makedirs(os.path.dirname(output_dir), exist_ok=True)\n    with open(output_dir, "w+") as f:'
)
p.write_text(code)
print("Patched multihop_data_extrator.py to auto-create the output directory.")

Patched multihop_data_extrator.py to auto-create the output directory.


In [11]:
# This script reads HotpotQA and writes the unified corpus jsonl under data/corpus/hotpotQA/corpus.jsonl
! python /content/efficientrag-official/src/retrievers/multihop_data_extrator.py --dataset hotpotQA

100% 507493/507493 [00:01<00:00, 438289.10it/s]


In [12]:
# Uses the default Contriever; outputs FAISS/emb files under output_dir
! python src/retrievers/passage_embedder.py \
  --passages data/corpus/hotpotQA/corpus.jsonl \
  --output_dir data/corpus/hotpotQA/contriever \
  --model_type contriever

Loading passages from data/corpus/hotpotQA/corpus.jsonl
Processing chunk 1/1
Embedding   0% 0/496 [ 0:00:00 < -:--:-- , ? it/s ]
2025-11-22 20:13:46.114240: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1763842426.142622    7073 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1763842426.150948    7073 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1763842426.172412    7073 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1763842426.172435    7073 computati
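
Conceptually, dense retrieval scores every passage embedding against the query embedding and returns the best matches. The sketch below does this exhaustively in pure Python, assuming inner-product similarity. The index type the repository actually builds is not shown in this notebook, and `dot` and `top_k` are hypothetical names.

```python
from typing import List, Sequence, Tuple


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    """Inner product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def top_k(query: Sequence[float], passages: List[Sequence[float]], k: int) -> List[Tuple[int, float]]:
    """Return (passage_index, score) pairs for the k highest inner products."""
    scored = [(dot(query, vec), idx) for idx, vec in enumerate(passages)]
    # Highest score first; ties broken by the lower passage index.
    scored.sort(key=lambda pair: (-pair[0], pair[1]))
    return [(idx, score) for score, idx in scored[:k]]


passage_vectors = [[1.0, 0.0], [0.6, 0.8], [0.0, 1.0]]
hits = top_k([1.0, 0.0], passage_vectors, k=2)
print(hits)
```

A FAISS index returns the same kind of result, scores plus passage ids, but computes it with optimized kernels instead of a Python loop.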

## 4. Start Local Inference Server

We spin up a local **vLLM** server to act as our Generator. This allows the EfficientRAG pipeline to query the LLM (Meta-Llama-3-8B-Instruct) via a local API endpoint, avoiding external API costs and latency.

In [22]:
import subprocess, time, requests, json, os, signal

# Kill any prior instance
_ = subprocess.run("pkill -f vllm.entrypoints.openai.api_server || true", shell=True)

cmd = [
  "python","-m","vllm.entrypoints.openai.api_server",
  "--model","meta-llama/Meta-Llama-3-8B-Instruct",
  #"--quantization","gptq",
  "--dtype","auto","--host","0.0.0.0","--port","8000",
  #"--served-model-name","meta-llama/Meta-Llama-3-8B-Instruct",
  #"--gpu-memory-utilization","0.92",
  "--max-num-seqs","10",
  "--api-key","token-colab-local-1234",
  #"--enforce-eager"
]

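proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)

# Wait for readiness: poll /v1/models once per second, for up to 400 attempts
BASE = "http://127.0.0.1:8000/v1"
HEAD = {"Authorization": "Bearer token-colab-local-1234", "Content-Type": "application/json"}
ready = False
for _ in range(400):
    try:
        r = requests.get(f"{BASE}/models", headers=HEAD, timeout=2)
        if r.status_code == 200:
            ready = True
            break
    except requests.RequestException:
        pass  # server is not accepting connections yet
    time.sleep(1)

# Use a flag rather than `r`, which is undefined if every request failed
print("Server OK:", ready)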

# Keep 'proc' to stop later:
# proc.terminate(); proc.wait()

Server OK: True
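
The readiness loop above is a generic polling pattern. The sketch below factors it into a function with an injectable probe, so it can be exercised without a running server. `wait_until_ready` and `fake_probe` are hypothetical names introduced for illustration.

```python
import time
from typing import Callable


def wait_until_ready(
    probe: Callable[[], bool],
    attempts: int = 400,
    delay: float = 1.0,
    sleep: Callable[[float], None] = time.sleep,
) -> bool:
    """Call probe() up to `attempts` times; return True as soon as it succeeds."""
    for _ in range(attempts):
        try:
            if probe():
                return True
        except Exception:
            pass  # treat connection errors as "not ready yet"
        sleep(delay)
    return False


# Simulated server that only answers on the third poll.
calls = {"n": 0}


def fake_probe() -> bool:
    calls["n"] += 1
    if calls["n"] < 3:
        raise ConnectionError("not up yet")
    return True


ready_demo = wait_until_ready(fake_probe, attempts=5, delay=0.0, sleep=lambda s: None)
print(ready_demo, calls["n"])
```

Against the real server, the probe would be a function that issues the same `requests.get` call to `/v1/models` as in the cell above and checks for status 200.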


In [14]:
import pathlib

new_code = """
import json
import random
import re
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Callable, Literal, Optional

from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
from tqdm.rich import tqdm_rich
from language_models import LanguageModel

def _to_text(result):
    if result is None:
        return ""
    # Already a string
    if isinstance(result, str):
        return result
    # OpenAI/vLLM modern clients return pydantic objects with .choices[0].message.content
    try:
        if hasattr(result, "choices"):
            ch0 = result.choices[0]
            # vLLM/OpenAI: choices[i].message.content
            msg = getattr(ch0, "message", None)
            if msg is not None:
                content = getattr(msg, "content", None)
                if isinstance(content, str):
                    return content
        # Dict-like fallbacks
        if isinstance(result, dict):
            choices = result.get("choices") or []
            if choices:
                message = choices[0].get("message") or {}
                content = message.get("content")
                if isinstance(content, str):
                    return content
    except Exception:
        pass
    # Last resort: stringify (prevents TypeError in regex)
    return str(result)


class EmptyContentError(RuntimeError):
    pass

def _safe_json_parse(result: Optional[str]) -> Optional[dict]:
    if not result:
        return None
    # try fenced block first
    m = re.search(r"```json\\s*(\\{.*?\\})\\s*```", result, re.DOTALL | re.IGNORECASE)
    if m:
        try:
            return json.loads(m.group(1).strip())
        except Exception:
            pass
    # otherwise first top-level-looking {...}
    m2 = re.search(r"\\{.*\\}", result, re.DOTALL)
    if m2:
        try:
            return json.loads(m2.group(0).strip())
        except Exception:
            return None
    return None

def _get_parser(type_: str) -> Callable:
    if type_ == "json":
        return _safe_json_parse
    elif type_ in ("text", "raw"):
        # raw/text are the same here (return the string)
        return lambda s: s
    raise ValueError(f"Unsupported type: {type_}")

# Try up to 3 times; back off if the model returns empty
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=1, max=6),
    retry=retry_if_exception_type(EmptyContentError),
    reraise=False,
    retry_error_callback=lambda _: None,
)
def ask_model(
    model: LanguageModel,
    prompt: str,
    system_msg: str = None,
    type: Literal["json", "text", "raw"] = "json",
    check_if_valid: Callable = None,
    sleep: bool = True,
    mode: Literal["chat", "completion"] = "chat",
) -> Optional[dict]:
    if sleep:
        time.sleep(random.uniform(0.8, 1.8))

    # 1) primary attempt with JSON mode when requested
    if mode == "chat":
        result = model.chat(prompt, system_msg, json_mode=(type == "json"))
    else:
        result = model.complete(prompt)

    result = _to_text(result)

    # None / empty payload from server? -> trigger retry (common with JSON mode or overload)
    if not result:
        raise EmptyContentError("empty content from server")

    parser = _get_parser(type)
    parsed = parser(result)

    # Fallback path also normalized:
    if type == "json" and parsed is None:
        if mode == "chat":
            fallback = model.chat(prompt, system_msg, json_mode=False)
        else:
            fallback = model.complete(prompt)
        fallback = _to_text(fallback)
        if not fallback:
            raise EmptyContentError("empty content after fallback")
        parsed = _safe_json_parse(fallback)
        if parsed is None:
            raise EmptyContentError("unparseable JSON after fallback")



    if check_if_valid is not None and parsed is not None and not check_if_valid(parsed):
        # validator says no -> return None (don’t raise to avoid hiding raw)
        return None

    return parsed

def ask_model_in_parallel(
    model: LanguageModel,
    prompts: list[str],
    system_msg: str = None,
    type: Literal["json", "text", "raw"] = "json",
    check_if_valid_list: list[Callable] = None,
    max_workers: int = 4,
    desc: str = "Processing...",
    verbose=True,
    mode: Literal["chat", "completion"] = "chat",
):
    if max_workers == -1:
        max_workers = len(prompts)
    assert max_workers >= 1

    if check_if_valid_list is None:
        check_if_valid_list = [None] * len(prompts)
    assert len(prompts) == len(check_if_valid_list)

    results = []
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = {
            executor.submit(
                ask_model,
                model=model,
                prompt=prompt,
                system_msg=system_msg,
                type=type,
                check_if_valid=chk,
                sleep=True,
                mode=mode,
            ): i
            for i, (prompt, chk) in enumerate(zip(prompts, check_if_valid_list))
        }
        for fut in tqdm_rich(as_completed(futures), total=len(futures), desc=desc, disable=not verbose):
            idx = futures[fut]
            try:
                results.append((idx, fut.result()))
            except Exception:
                results.append((idx, None))
        results.sort(key=lambda x: x[0])
        return [r for _, r in results]

def get_type_parser(type: str) -> Callable:
    def json_parser(result: str):
        # pattern = r"```json(.*?)```"
        pattern = r"{.*?}"
        matches = re.findall(pattern, result, re.DOTALL)
        if matches:
            result = matches[0].strip()
        return json.loads(result)

    def text_parser(result: str):
        return result

    if type == "json":
        return json_parser
    elif type == "text":
        return text_parser
    else:
        raise ValueError(f"Unsupported type: {type}")
"""

# Overwrite the file with the desired implementation
target = pathlib.Path("src/utils/model.py")
target.write_text(new_code, encoding="utf-8")
print(f"Wrote first-code implementation to {target}")

Wrote first-code implementation to src/utils/model.py
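
The parsing strategy in `_safe_json_parse` is to try a fenced JSON block first and then fall back to the first brace-delimited span. It can be tried in isolation. The sketch below is a simplified standalone version of that idea, not the file written above. `extract_json` is a hypothetical name, and the fence string is built programmatically only to keep literal triple backticks out of this example.

```python
import json
import re
from typing import Optional

FENCE = "`" * 3  # three backticks, built indirectly
FENCED_JSON = re.compile(FENCE + r"json\s*(\{.*?\})\s*" + FENCE, re.DOTALL | re.IGNORECASE)
ANY_OBJECT = re.compile(r"\{.*\}", re.DOTALL)


def extract_json(reply: Optional[str]) -> Optional[dict]:
    """Parse the first usable JSON object out of a free-form model reply."""
    if not reply:
        return None
    # Prefer an explicitly fenced block, then any {...} span.
    for pattern, group in ((FENCED_JSON, 1), (ANY_OBJECT, 0)):
        match = pattern.search(reply)
        if match:
            try:
                return json.loads(match.group(group))
            except json.JSONDecodeError:
                continue
    return None


fenced = "Sure:\n" + FENCE + 'json\n{"extracted_words": "Paris"}\n' + FENCE
bare = 'Answer: {"extracted_words": "Nile"} hope that helps'
print(extract_json(fenced), extract_json(bare), extract_json("no json here"))
```

Returning `None` instead of raising lets the caller decide whether to retry, which is how `ask_model` uses its parser.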


In [15]:
import pathlib

new_code = """
from openai import OpenAI

from .base import LanguageModel

LLAMA_ENDPOINT = "http://127.0.0.1:8000/v1"
LLAMA_API_KEY  = "token-colab-local-1234"


class LlamaServer(LanguageModel):
    def __init__(self, model, **_):
        self.model = model
        self.client = OpenAI(base_url=LLAMA_ENDPOINT, api_key=LLAMA_API_KEY, timeout=10.0, max_retries=3)

    def chat(self, message: str, system_msg: str = None, json_mode: bool = False):
        if system_msg is None:
            system_msg = "You are a helpful assistant."
        messages = [
            {"role": "system", "content": system_msg},
            {"role": "user", "content": message},
        ]
        kwargs = dict(temperature=0.1, top_p=1.0, max_tokens=512)

        # Prefer json_object; json_schema support can vary by vLLM version
        # If schema is critical for you, keep it—but add a safe fallback below.
        if json_mode:
            kwargs["response_format"] = {"type": "json_object"}

        try:
            resp = self.client.chat.completions.create(model=self.model, messages=messages, **kwargs)
        except Exception:
            # Retry once without response_format (some vLLM builds/models return empty/err with it)
            if "response_format" in kwargs:
                kwargs.pop("response_format", None)
            resp = self.client.chat.completions.create(model=self.model, messages=messages, **kwargs)

        # Return STRING content, not the whole object
        return getattr(resp.choices[0].message, "content", "") or ""


    def complete(self, prompts: str):
        response = self.client.completions.create(
            model=self.model, prompt=prompts, echo=False, max_tokens=100
        )
        response = response.choices[0].text
        return response


if __name__ == "__main__":
    llama = LlamaServer("meta-llama/Meta-Llama-3-8B-Instruct")  # must match the name the server was started with
    response = llama.complete(
        "The reason of human landing on moon is that, some one found it strange behind the moon."
    )
    print(response)
"""

# Overwrite the file with the desired implementation
target = pathlib.Path("src/language_models/llama.py")
target.write_text(new_code, encoding="utf-8")
print(f"Wrote second-code implementation to {target}")

Wrote second-code implementation to src/language_models/llama.py
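
The `chat` method above retries without `response_format` when the first request fails. The sketch below isolates that fallback with a fake backend so the control flow can be checked offline. `chat_with_json_fallback` and `fake_create` are hypothetical names, and the fake stands in for `client.chat.completions.create`.

```python
from typing import Any, Callable, Dict, List


def chat_with_json_fallback(
    create: Callable[..., str],
    messages: List[Dict[str, str]],
    json_mode: bool,
) -> str:
    """Try with response_format when requested; on any error retry without it."""
    kwargs: Dict[str, Any] = {"temperature": 0.1, "top_p": 1.0, "max_tokens": 512}
    if json_mode:
        kwargs["response_format"] = {"type": "json_object"}
    try:
        return create(messages=messages, **kwargs)
    except Exception:
        kwargs.pop("response_format", None)
        return create(messages=messages, **kwargs)


# Fake backend that rejects response_format, recording what it was asked.
seen: List[List[str]] = []


def fake_create(**kwargs: Any) -> str:
    seen.append(sorted(kwargs))
    if "response_format" in kwargs:
        raise ValueError("response_format not supported by this fake backend")
    return '{"ok": true}'


reply = chat_with_json_fallback(fake_create, [{"role": "user", "content": "hi"}], json_mode=True)
print(reply, len(seen))
```

One consequence of this design is that a failure unrelated to `response_format` also triggers a second identical request. That is acceptable for a local server but worth keeping in mind for metered APIs.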


In [16]:
import pathlib

new_code = """
from .aoai import AOAI
from .base import LanguageModel
from .deepseek import DeepSeek
from .llama import LlamaServer

MODEL_DICT = {
    "gpt35": "gpt-35-turbo-1106",
    "gpt4": "gpt-4-0125-preview",
    "llama": "llama-3-70b-gptq-int4",
    "llama-8B": "meta-llama/Meta-Llama-3-8B-Instruct",
    "deepseek": "deepseek-chat",
}


def get_model(model_name: str, **kwargs) -> LanguageModel:
    if model_name in MODEL_DICT:
        model_name = MODEL_DICT[model_name]
    lower = model_name.lower()

    # 1) Check the llama prefixes first so "llama-3-70b-gptq-int4" is not caught by the "gpt" test
    if lower.startswith(("llama", "meta-llama", "meta")):
        return LlamaServer(model=model_name, **kwargs)
    # 2) Only route to AOAI for real GPT model names
    if "gpt" in lower:
        return AOAI(model=model_name, **kwargs)
    if "deepseek" in lower:
        return DeepSeek(model=model_name, **kwargs)
    raise NotImplementedError(f"Model {model_name} not implemented")
"""

# Overwrite the file with the desired implementation
target = pathlib.Path("src/language_models/__init__.py")
target.write_text(new_code, encoding="utf-8")
print(f"Wrote third-code implementation to {target}")

Wrote third-code implementation to src/language_models/__init__.py
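
The order of the checks in `get_model` matters because the substring `gpt` also occurs in `gptq`. The sketch below reproduces only the routing decision and returns backend names as strings instead of constructing clients. `pick_backend` and `naive_backend` are hypothetical names used to contrast the two orderings.

```python
def pick_backend(model_name: str) -> str:
    """Routing order used above: llama prefixes first, then gpt, then deepseek."""
    lower = model_name.lower()
    if lower.startswith(("llama", "meta-llama", "meta")):
        return "LlamaServer"
    if "gpt" in lower:
        return "AOAI"
    if "deepseek" in lower:
        return "DeepSeek"
    raise NotImplementedError(f"Model {model_name} not implemented")


def naive_backend(model_name: str) -> str:
    """Same checks with gpt tested first, to show the false positive."""
    lower = model_name.lower()
    if "gpt" in lower:
        return "AOAI"
    if lower.startswith(("llama", "meta-llama", "meta")):
        return "LlamaServer"
    if "deepseek" in lower:
        return "DeepSeek"
    raise NotImplementedError(f"Model {model_name} not implemented")


name = "llama-3-70b-gptq-int4"
print(pick_backend(name), naive_backend(name))
```

With the naive ordering, the quantized Llama checkpoint would be sent to the Azure OpenAI client, which is the misrouting the patched `get_model` avoids.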


In [17]:
import pathlib, re

utils_path = pathlib.Path("src/utils/utils.py")
code = utils_path.read_text(encoding="utf-8")

# 1) Ensure imports (idempotent)
if "from pathlib import Path" not in code:
    code = re.sub(
        r'(^(\s*import[^\n]*\n|\s*from[^\n]*\n)+)',
        r'\1from pathlib import Path\n',
        code,
        count=1,
        flags=re.M
    ) if re.search(r'^(\s*import|\s*from)', code, flags=re.M) else "from pathlib import Path\n" + code

if re.search(r'^\s*import json\s*$', code, flags=re.M) is None:
    code = "import json\n" + code

# 2) Patch write_jsonl open(...) to Path version (mkdir + p.open)
pattern_write = re.compile(
    r'(?m)^(?P<indent>\s*)with\s+open\(\s*file_path\s*,\s*["\']w\+?["\'].*?\)\s*as\s*f\s*:\s*$'
)
replacement_write = (
    r'\g<indent>p = Path(file_path)\n'
    r'\g<indent>p.parent.mkdir(parents=True, exist_ok=True)\n'
    r'\g<indent>with p.open("w", encoding="utf-8") as f:'
)
code, n_write = pattern_write.subn(replacement_write, code)

# 3) Replace the entire load_jsonl definition with a safe version
#    Match 'def load_jsonl(...):' through the next 'def ' at same indent or EOF
pattern_load = re.compile(
    r'(?ms)^(?P<indent>\s*)def\s+load_jsonl\s*\([^)]*\)\s*:\s*'
    r'(?:.*?)(?=^\s*def\s+|\Z)'
)
replacement_load = (
    r'\g<indent>def load_jsonl(file_path, *, missing_ok=True):\n'
    r'\g<indent>    """\n'
    r'\g<indent>    Read records from a JSONL file.\n'
    r'\g<indent>    - If missing and missing_ok=True, returns [] and ensures parent dir exists.\n'
    r'\g<indent>    - If missing_ok=False, raises FileNotFoundError.\n'
    r'\g<indent>    """\n'
    r'\g<indent>    p = Path(file_path)\n'
    r'\g<indent>    if not p.exists():\n'
    r'\g<indent>        p.parent.mkdir(parents=True, exist_ok=True)\n'
    r'\g<indent>        if missing_ok:\n'
    r'\g<indent>            return []\n'
    r'\g<indent>        raise FileNotFoundError(f"JSONL file not found: {p}. Set missing_ok=True to return [].")\n'
    r'\g<indent>    with p.open("r", encoding="utf-8") as f:\n'
    r'\g<indent>        return [json.loads(line) for line in f if line.strip()]\n'
)
code, n_load = pattern_load.subn(replacement_load, code)

# 4) If no load_jsonl existed, append our safe version
if n_load == 0 and "def load_jsonl" not in code:
    code += "\n\n" + replacement_load.replace(r'\g<indent>', '')

utils_path.write_text(code, encoding="utf-8")
print(f"Patched write sites: {n_write}, patched load_jsonl: {max(n_load, 1) if 'def load_jsonl' in code else 0}")


Patched write sites: 1, patched load_jsonl: 1
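
After this patch, loading a missing file yields an empty list and writing creates parent folders on demand. The sketch below shows the resulting behavior as a standalone round trip. The `load_jsonl` body mirrors the replacement text above in simplified form, while the `write_jsonl` body is an assumption, since the repository's version is not shown in this notebook.

```python
import json
import tempfile
from pathlib import Path
from typing import Iterable, List


def write_jsonl(records: Iterable[dict], file_path) -> None:
    """Write one JSON object per line, creating parent folders if needed."""
    p = Path(file_path)
    p.parent.mkdir(parents=True, exist_ok=True)
    with p.open("w", encoding="utf-8") as f:
        for record in records:
            f.write(json.dumps(record, ensure_ascii=False) + "\n")


def load_jsonl(file_path, *, missing_ok: bool = True) -> List[dict]:
    """Read records from a JSONL file; a missing file yields [] when missing_ok."""
    p = Path(file_path)
    if not p.exists():
        if missing_ok:
            return []
        raise FileNotFoundError(f"JSONL file not found: {p}")
    with p.open("r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


with tempfile.TemporaryDirectory() as tmp:
    target = Path(tmp) / "nested" / "out.jsonl"
    absent = load_jsonl(target)  # file does not exist yet
    write_jsonl([{"id": 1}, {"id": 2}], target)
    loaded = load_jsonl(target)

print(absent, loaded)
```

Treating a missing file as an empty list lets a resumable script start from scratch without special-casing the first run.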


In [None]:
import pathlib

new_code = """
import argparse
import json
import os
import sys
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Iterator
import subprocess
import time
import signal
import requests

from tqdm.rich import tqdm_rich

sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
from prompts import (
    TOKEN_LABEL_REDUNDANT_EVALUATION_PROMPT,
    TOKEN_LABEL_REDUNDANT_SYSTEM_MSG,
    TOKEN_LABEL_SYNTHESIZE_FEW_SHOT_PROMPT_MUSIQUE,
    TOKEN_LABEL_SYNTHESIZE_FEW_SHOT_PROMPT_WIKIMQA,
    TOKEN_LABELING_SYSTEM_MSG,
)

from conf import (
    MODEL_DICT,
    SYNTHESIZED_DECOMPOSED_DATA_PATH,
    SYNTHESIZED_TOKEN_LABELING_DATA_PATH,
)
from language_models import LanguageModel, get_model
from utils import ask_model, ask_model_in_parallel, load_jsonl
from utils.model import get_type_parser

TOKEN_LABEL_PROMPT_TEMPLATE_MAPPING = {
    "musique": TOKEN_LABEL_SYNTHESIZE_FEW_SHOT_PROMPT_MUSIQUE,
    "musique-simple": TOKEN_LABEL_SYNTHESIZE_FEW_SHOT_PROMPT_MUSIQUE,
    "2WikiMQA": TOKEN_LABEL_SYNTHESIZE_FEW_SHOT_PROMPT_WIKIMQA,
    "hotpotQA": TOKEN_LABEL_SYNTHESIZE_FEW_SHOT_PROMPT_WIKIMQA
}

# ---------- Local Llama server management (for llama-8B) ----------

LLAMA_LOCAL_BASE_URL = "http://127.0.0.1:8000/v1"
LLAMA_LOCAL_API_KEY = "token-colab-local-1234"


def start_llama_server() -> subprocess.Popen:
    # Kill any prior instance just in case
    subprocess.run(
        "pkill -f vllm.entrypoints.openai.api_server || true",
        shell=True,
        check=False,
    )

    cmd = [
        "python",
        "-m",
        "vllm.entrypoints.openai.api_server",
        "--model",
        "meta-llama/Meta-Llama-3-8B-Instruct",
        "--dtype",
        "auto",
        "--host",
        "0.0.0.0",
        "--port",
        "8000",
        "--api-key",
        LLAMA_LOCAL_API_KEY,
        "--max-model-len",
        "8192",
        "--max-num-seqs",
        "6",        # conservative concurrency
        "--gpu-memory-utilization",
        "0.9",
    ]

    proc = subprocess.Popen(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
        bufsize=1,
    )

    # Wait for readiness
    headers = {
        "Authorization": f"Bearer {LLAMA_LOCAL_API_KEY}",
        "Content-Type": "application/json",
    }
    for _ in range(300):
        try:
            r = requests.get(f"{LLAMA_LOCAL_BASE_URL}/models", headers=headers, timeout=2)
            if r.status_code == 200:
                print("[token_labeling] vLLM server is ready.")
                return proc
        except Exception:
            pass
        time.sleep(1)

    print("[token_labeling] vLLM server did not become ready in time; killing it.")
    proc.kill()
    raise RuntimeError("Failed to start vLLM server")


def stop_llama_server(proc: subprocess.Popen):
    if proc is None:
        return
    print("[token_labeling] Stopping vLLM server...")
    try:
        proc.terminate()
        try:
            proc.wait(timeout=15)
        except subprocess.TimeoutExpired:
            proc.kill()
    except Exception:
        pass

# -----------------------------------------------------------------


class TokenLabeler:
    def __init__(self, model: str, dataset: str, split: str) -> None:
        self.model: LanguageModel
        self.model = get_model(model)

        labeled_data_path = os.path.join(
            SYNTHESIZED_DECOMPOSED_DATA_PATH, dataset, f"{split}.jsonl"
        )
        self.labeled_data = load_jsonl(labeled_data_path)
        self.check_if_valid = lambda x: all(
            [k in x.keys() for k in ["extracted_words"]]
        )
        self.token_labeling_prompt = TOKEN_LABEL_PROMPT_TEMPLATE_MAPPING[dataset]

    def parse(self, starting: int = 0, workers=10):
        labeled_data = self.labeled_data[starting:]
        # keep only samples whose overall state is None
        labeled_data = [d for d in labeled_data if d.get("state", None) is None]

        for sample in tqdm_rich(labeled_data, desc="Processing..."):
            yield self.parse_sample(sample)


    def parse_sample(self, sample: dict) -> dict:
        prompt_list = self.parse_prompt(sample)
        results = []
        for prompt in prompt_list:
            result = ask_model(
                self.model,
                prompt,
                TOKEN_LABELING_SYSTEM_MSG,
                type="json",
                check_if_valid=self.check_if_valid,
            )
            if result is None:
                result = {"extracted_words": "", "status": "error"}
            results.append(result)
        for subq_id, result in zip(
            sorted(sample["decomposed_questions"].keys()), results
        ):
            chunk = sample["decomposed_questions"][subq_id]
            chunk["extracted_words"] = result["extracted_words"]
        return sample

    def parse_prompt(self, data: dict) -> list[dict]:
        prompt_list = []
        for subq_id in sorted(data["decomposed_questions"].keys()):
            subq = data["decomposed_questions"][subq_id]
            format_kwargs = {
                "question": subq["sub_question"],
                "paragraph": subq["positive_paragraph"],
                "answer": subq["answer"],
            }
            prompt = self.token_labeling_prompt.format(**format_kwargs)
            prompt_list.append(prompt)
        return prompt_list

    def parse_failed(self, token_labeled_data: list[dict]) -> list[dict]:
        results = []
        failed_question_ids = set()
        for sample in token_labeled_data:
            for sub_question_id in sorted(sample["decomposed_questions"].keys()):
                if (
                    sample["decomposed_questions"][sub_question_id].get("state", None)
                    == "error"
                ):
                    failed_question_ids.add(sample["id"])
                    break
        progress = tqdm_rich(
            desc="Processing failed...", total=len(failed_question_ids)
        )
        for sample in token_labeled_data:
            if sample["id"] not in failed_question_ids:
                results.append(sample)
                continue
            for sub_question_id in sorted(sample["decomposed_questions"].keys()):
                if (
                    sample["decomposed_questions"][sub_question_id].get("state", None)
                    != "error"
                ):
                    continue
                prompt_list = self.parse_prompt(sample)
                prompt = prompt_list[int(sub_question_id) - 1]
                result = ask_model(
                    self.model, prompt, type="json", check_if_valid=self.check_if_valid
                )
                if result is None:
                    continue
                del sample["decomposed_questions"][sub_question_id]["state"]
                sample["decomposed_questions"][sub_question_id]["extracted_words"] = (
                    result["extracted_words"]
                )
            progress.update(1)
            results.append(sample)
        return results


class TokenReLabeler:
    def __init__(self, model: str, dataset: str, split: str) -> None:
        self.model: LanguageModel
        self.model = get_model(model)
        self.model_powerful = get_model("Llama3-8B-Instruct")

        labeled_data_path = os.path.join(
            SYNTHESIZED_TOKEN_LABELING_DATA_PATH, dataset, f"{split}.jsonl"
        )
        self.labeled_data = load_jsonl(labeled_data_path)
        self.check_if_valid = lambda x: all(
            [k in x.keys() for k in ["extracted_words"]]
        )
        self.check_redundant_valid = lambda x: type(x) == dict and all(
            [k in x.keys() for k in ["redundant", "missing"]]
        )
        self.type_parser = get_type_parser(type="json")

    def label_redundant(self, labeled_data: list[dict], workers: int) -> list[dict]:
        redundant_questions = []
        with ThreadPoolExecutor(max_workers=workers) as executor:
            tasks = {
                executor.submit(self.check_sample_redundant, sample): idx
                for idx, sample in enumerate(labeled_data)
            }
            for future in tqdm_rich(
                as_completed(tasks), total=len(tasks), desc="Redundant"
            ):
                task_id = tasks[future]
                try:
                    result = future.result()
                    redundant_questions.append((task_id, result))
                finally:
                    ...
            redundant_questions = [
                r[1] for r in sorted(redundant_questions, key=lambda x: x[0])
            ]
        for sample, redundant in zip(labeled_data, redundant_questions):
            assert redundant["id"] == sample["id"]
            for subq_id in redundant["redundant"]:
                sample["decomposed_questions"][subq_id]["redundant"] = True
        return labeled_data

    def parse(self, workers: int = 10, redundant_labeled: bool = False) -> list[dict]:
        labeled_data = [
            d
            for d in self.labeled_data
            if all(
                chunk.get("state", None) is None
                for chunk in d["decomposed_questions"].values()
            )
        ]

        # 1. use GPT3.5 to identify if extracted words is redundant or missing
        if not redundant_labeled:
            labeled_data = self.label_redundant(labeled_data, workers)

        # 2. use Llama3 to re-label the extracted words
        results = []
        data_mapping = {d["id"]: d for d in labeled_data}

        max_iter = 5
        current_iter = 0
        while current_iter < max_iter:
            current_iter += 1

            prompts = []
            for sample in labeled_data:
                sample_prompts = self.build_relabel_prompt(sample)
                prompts.extend(sample_prompts)

            print(
                f"Current iteration: {current_iter}, "
                f"max iteration: {max_iter}, "
                f"handling {len(prompts)} prompts."
            )
            if len(prompts) <= 0:
                break

            batched_prompts = [p["prompt"] for p in prompts]
            results = self.model_powerful.chat(
                batched_prompts, TOKEN_LABELING_SYSTEM_MSG, json_mode=True
            )

            for prompt, result in zip(prompts, results):
                try:
                    json_result = self.type_parser(result)
                    data = data_mapping[prompt["id"]]
                    chunk = data["decomposed_questions"][prompt["subq_id"]]
                    chunk["extracted_words_old"] = chunk["extracted_words"]
                    chunk["extracted_words"] = json_result["extracted_words"]
                    chunk["redundant"] = False
                except json.JSONDecodeError:
                    print(f"Error on {prompt['id']} sub-question {prompt['subq_id']}")
                    json_result = None

        return labeled_data

    def build_relabel_prompt(self, sample: dict):
        prompts = []
        for subq_id, chunk in sample["decomposed_questions"].items():
            if not chunk.get("redundant", False):
                break
            sub_question = chunk["sub_question"]
            paragraph = chunk["positive_paragraph"]
            answer = chunk["answer"]
            prompt = TOKEN_LABEL_SYNTHESIZE_FEW_SHOT_PROMPT_MUSIQUE.format(
                question=sub_question, paragraph=paragraph, answer=answer
            )
            info = {
                "id": sample["id"],
                "subq_id": subq_id,
                "prompt": prompt,
            }
            prompts.append(info)
        return prompts

    def check_sample_redundant(self, sample: dict):
        redundant = {"id": sample["id"], "redundant": []}
        for subq_id, subq in sample["decomposed_questions"].items():
            question = subq["sub_question"]
            answer = subq["answer"]
            extracted_words = subq["extracted_words"]
            evaluation_prompt = TOKEN_LABEL_REDUNDANT_EVALUATION_PROMPT.format(
                question=question, answer=answer, extracted_words=extracted_words
            )
            evaluation = ask_model(
                self.model,
                evaluation_prompt,
                TOKEN_LABEL_REDUNDANT_SYSTEM_MSG,
                type="json",
                check_if_valid=self.check_redundant_valid,
            )
            if evaluation["redundant"] or evaluation["missing"]:
                redundant["redundant"].append(subq_id)
        return redundant


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dataset",
        type=str,
        default="musique",
    )
    parser.add_argument("--split", type=str, default="valid")
    parser.add_argument("--model", default="gpt4")
    parser.add_argument(
        "--workers", type=int, default=10, help="Number of parallel processors"
    )
    parser.add_argument(
        "--sync", action="store_true", help="Syncing with label fixed data"
    )
    parser.add_argument("--failed", action="store_true", help="Parse failed data")
    parser.add_argument("--failed_path", type=str, help="Path to failed data")
    parser.add_argument(
        "--relabel", action="store_true", help="Re-label extracted words"
    )
    args = parser.parse_args()
    return args


def main(opt: argparse.Namespace):
    model_name = MODEL_DICT[opt.model]

    # Decide if we should manage a local vLLM server (llama-8B case)
    use_local_llama = opt.model == "llama-8B"

    server_proc = None
    try:
        if use_local_llama:
            server_proc = start_llama_server()

        labeler = TokenLabeler(model_name, opt.dataset, opt.split)

        os.makedirs(
            os.path.join(SYNTHESIZED_TOKEN_LABELING_DATA_PATH, opt.dataset),
            exist_ok=True,
        )

        out_path = os.path.join(
            SYNTHESIZED_TOKEN_LABELING_DATA_PATH,
            opt.dataset,
            f"{opt.split}.jsonl",
        )

        processed = 0
        # Open in "w+" so each run of this script overwrites old file for the same split
        with open(out_path, "w+", encoding="utf-8") as f:
            # We ignore opt.workers here because parse() is now sequential.
            for labeled in labeler.parse(workers=1):
                info = json.dumps(labeled, ensure_ascii=False)
                f.write(info + "\n")
                f.flush()  # ensure it's actually written to disk
                processed += 1

                # Every 150 processed samples, restart the local LLM server
                if use_local_llama and processed % 150 == 0:
                    print(
                        f"[token_labeling] Processed {processed} samples; "
                        "restarting vLLM server to clear caches..."
                    )
                    stop_llama_server(server_proc)
                    server_proc = start_llama_server()

        print(f"[token_labeling] Finished. Total processed: {processed}")

    finally:
        # Always try to stop the server at the end
        if use_local_llama and server_proc is not None:
            stop_llama_server(server_proc)



if __name__ == "__main__":
    options = parse_args()
    main(options)
"""

# Overwrite the file with the desired implementation
target = pathlib.Path("src/data_synthesize/token_labeling.py")
target.write_text(new_code, encoding="utf-8")
print(f"Wrote updated implementation to {target}")

In [19]:
from importlib import reload
import src.language_models.llama as llama_mod
import src.data_synthesize.prompts.hotpotQA as hotpotQA_mod
import src.utils.model as model_mod
import src.language_models as init_mod  # the package itself; reload() re-runs its __init__.py

reload(llama_mod)
reload(hotpotQA_mod)
reload(model_mod)
reload(init_mod)

import requests
from src.language_models.llama import LLAMA_ENDPOINT, LLAMA_API_KEY

def ensure_model_visible(name="meta-llama/Meta-Llama-3-8B-Instruct",
                         base=LLAMA_ENDPOINT, api_key=LLAMA_API_KEY):
    # base is like "http://127.0.0.1:8000/v1"
    r = requests.get(f"{base}/models",
                     headers={"Authorization": f"Bearer {api_key}"},
                     timeout=10)
    r.raise_for_status()
    names = [m["id"] for m in r.json().get("data", [])]
    assert name in names, f"Model {name} not found; got: {names}"
    return names

print(ensure_model_visible())

['meta-llama/Meta-Llama-3-8B-Instruct']


In [20]:
llm = llama_mod.LlamaServer(model="meta-llama/Meta-Llama-3-8B-Instruct")  # must match /v1/models
print(llm.chat("Hi! If I have 5 apples and someone takes 2, how many are left?"))

Let's count them together!

You started with 5 apples, and someone took 2. To find out how many are left, we'll subtract 2 from 5.

5 - 2 = 3

So, there are 3 apples left!


In [23]:
! python src/data_synthesize/query_decompose.py \
  --dataset hotpotQA \
  --split test \
  --model llama-8B \
  --ending 400

  for task in tqdm_rich(
Failed to synthesize sample 5a8b6ee255429950cd6afcfd
Failed to synthesize sample 5adf03765542995534e8c72b
Failed to synthesize sample 5add21555542992ae4cec496
Failed to synthesize sample 5ac371aa5542995ef918c191
Failed to synthesize sample 5ac3ba0455429939154138f2
Failed to synthesize sample 5a883c175542997e5c09a5ae
Failed to synthesize sample 5adf4da15542995534e8c780
Failed to synthesize sample 5adf123f5542995ec70e8f63
Failed to synthesize sample 5abf12005542997719eab661
Failed to synthesize sample 5a745a2955429929fddd83f8
Failed to synthesize sample 5ae1faf15542997f29b3c1e3
Failed to synthesize sample 5a7ae2f2554299042af8f6aa
Failed to synthesize sample 5a80e5f5554299260e20a1ae
Failed to synthesize sample 5add2b865542992c1e3a2550
Failed to synthesize sample 5ae0fd9655429920d5234264
Failed to synthesize sample 5abdb5b455429965af743dfa
Failed to synthesize sample 5ab89cc25542991b5579efc3

In [None]:
proc.terminate()

In [5]:
import subprocess, time, requests, json, os, signal

# Kill any prior instance
_ = subprocess.run("pkill -f vllm.entrypoints.openai.api_server || true", shell=True)

cmd = [
  "python","-m","vllm.entrypoints.openai.api_server",
  "--model","meta-llama/Meta-Llama-3-8B-Instruct",
  #"--quantization","gptq",
  "--dtype","auto","--host","0.0.0.0","--port","8000",
  #"--served-model-name","meta-llama/Meta-Llama-3-8B-Instruct",
  #"--gpu-memory-utilization","0.92",
  "--max-num-seqs","4",
  "--api-key","token-colab-local-1234",
  #"--enforce-eager"
]

proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)

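# Wait for readiness: poll /v1/models once per second, for up to 400 attempts
BASE = "http://127.0.0.1:8000/v1"
HEAD = {"Authorization": "Bearer token-colab-local-1234", "Content-Type": "application/json"}
ok = False  # stays False if the server never answers, so the print below cannot raise NameError
for _ in range(400):
    try:
        r = requests.get(f"{BASE}/models", headers=HEAD, timeout=2)
        if r.status_code == 200:
            ok = True
            break
    except requests.RequestException:
        pass  # server is not accepting connections yet
    time.sleep(1)

print("Server OK:", ok)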

Server OK: True
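
The launch cell above passes `stdout=subprocess.PIPE` to `Popen` but never reads from that pipe. A child process that logs heavily can stall once the operating system's pipe buffer fills, so one option is to send the logs to a file instead. The sketch below only illustrates that idea: `launch_with_log` and the log file name are hypothetical and not part of the repository, and a short Python one-liner stands in for the vLLM command.

```python
import subprocess
import sys
import tempfile
from pathlib import Path


def launch_with_log(cmd: list[str], log_path: Path) -> subprocess.Popen:
    """Start `cmd` with stdout and stderr appended to `log_path` (hypothetical helper)."""
    with open(log_path, "ab") as log_file:
        # The child gets its own copy of the file descriptor, so the parent's
        # handle can be closed as soon as Popen returns.
        return subprocess.Popen(cmd, stdout=log_file, stderr=subprocess.STDOUT)


if __name__ == "__main__":
    # Stand-in for the server command: a child that writes a lot of output.
    child = [sys.executable, "-c", "import sys; sys.stdout.write('x' * 300000)"]
    with tempfile.TemporaryDirectory() as tmp:
        log = Path(tmp) / "server.log"
        proc = launch_with_log(child, log)
        print("exit code:", proc.wait(timeout=60))
        print("bytes logged:", log.stat().st_size)
```

With the logs in a file, `tail` on that file shows the server's startup messages if the readiness loop times out.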


In [24]:
import pathlib, re
p = pathlib.Path("src/data_synthesize/token_labeling.py")
s = p.read_text(encoding="utf-8")
if '"hotpotQA":' not in s:
    pat = re.compile(r'(TOKEN_LABEL_PROMPT_TEMPLATE_MAPPING\s*=\s*\{)(.*?)(\})', re.S)
    s = pat.sub(lambda m: f'{m.group(1)}{m.group(2).rstrip()}{"," if m.group(2).strip() and not m.group(2).rstrip().endswith(",") else ""}\n    "hotpotQA": TOKEN_LABEL_SYNTHESIZE_FEW_SHOT_PROMPT_WIKIMQA\n{m.group(3)}', s, count=1)
    p.write_text(s, encoding="utf-8")


s = p.read_text(encoding="utf-8")
if "os.makedirs(" not in s:
    if "import os" not in s:
        s = "import os\n" + s
    pat = re.compile(r'(?m)^(?P<i>\s*)with\s+open\(\s*os\.path\.join\([^)]*\)\s*,\s*["\']w\+?["\'][^)]*\)\s*as\s*f\s*:\s*')
    s = pat.sub(lambda m: f'{m.group("i")}os.makedirs(os.path.join(SYNTHESIZED_TOKEN_LABELING_DATA_PATH, opt.dataset), exist_ok=True)\n{m.group(0)}', s, count=1)
    p.write_text(s, encoding="utf-8")

In [None]:
! python src/data_synthesize/token_labeling.py \
    --dataset hotpotQA \
    --split test \
    --model llama-8B

In [36]:
! python src/data_synthesize/token_extraction.py \
    --data_path data/synthesized_token_labeling/hotpotQA/test.jsonl \
    --save_path data/token_extracted/hotpotQA/test.jsonl \
    --verbose

  for sample in tqdm_rich(data):
Brea 0 0 []
California 18 18 ['Brea', 'California', '.']
Torrid 0 0 []
Brea 77 75 ['the', 'Brea', 'Mall']
Mall 78 78 ['Brea', 'Mall', 'in']

In [37]:
import pathlib, re
p = pathlib.Path("src/data_synthesize/next_hop_query_construction.py")
s = p.read_text(encoding="utf-8")

if "os.makedirs(" not in s:
    if "import os" not in s:
        s = "import os\n" + s
    pat = re.compile(r'(?m)^(?P<i>\s*)with\s+open\(output_path[^:]+:')
    s = pat.sub(lambda m: f'{m.group("i")}os.makedirs(output_path.replace(f"{{opt.split}}.jsonl",""), exist_ok=True)\n{m.group(0)}', s, count=1)
    p.write_text(s, encoding="utf-8")

In [None]:
# Apply a small patch to implement build_prompt_template_hotpot
import pathlib

path = pathlib.Path("src/data_synthesize/next_hop_query_construction.py")
code = path.read_text()

old = '''    def build_prompt_template_hotpot(self, sample: dict, dependency: list[str]) -> str:
        raise NotImplementedError()
'''

new = '''    def build_prompt_template_hotpot(self, sample: dict, dependency: list[str]) -> str:
        text = """
You are helping with multi-hop question answering over Wikipedia.

You are given:
- The original multi-hop question:
<Question>: {question}

- Some information that is already known:
{info_list}

- Answers to earlier sub-questions:
{subq_answers}

Your task: write a single best next-hop search query that should be used
to retrieve more information needed to ultimately answer the original question.

Respond ONLY in JSON, with a single key "filtered_query" whose value is a string.
For example:
```json
{{"filtered_query": "Where was <person> born?"}}
```

Now respond with the JSON:
""".strip()
        return text
'''

if old not in code:
  raise SystemExit("Old stub for build_prompt_template_hotpot not found — patch not applied.")

path.write_text(code.replace(old, new))
print("Patched build_prompt_template_hotpot in next_hop_query_construction.py")

In [6]:
! python src/data_synthesize/next_hop_query_construction.py \
    --dataset hotpotQA \
    --split test \
    --model llama-8B

  for future in tqdm_rich(
Processing... 100% 165/165 [ 0:00:30 < 0:00:00 , 6 it/s ]

In [7]:
import pathlib, re
p = pathlib.Path("src/data_synthesize/next_hop_query_filtering.py")
s = p.read_text(encoding="utf-8")
s = re.sub(
    r'for k, v in infos\.items\(\):\s*\n\s*v\s*=\s*v\s*/\s*num_samples\s*\*\s*100\s*\n\s*print\(f"{k}: {v:.2f}"\)',
    'for k, v in infos.items():\n        v = v / num_samples * 100 if num_samples != 0 else 0\n        print(f"{k}: {v:.2f}")',
    s,
    flags=re.S,
)
p.write_text(s, encoding="utf-8")

7875

In [8]:
! python src/data_synthesize/next_hop_query_filtering.py \
    --data_path data/synthesized_next_query/hotpotQA/test.jsonl \
    --save_path data/next_query_extracted/hotpotQA/test.jsonl \
    --verbose

  for sample in tqdm_rich(data):
be 6 6 ['city', 'be', 'the']
the 7 7 ['be', 'the', 'first']
first 8 8 ['the', 'first', 'Torrid']
Torrid 9 9 ['first', 'Torrid', 'location']
location 10 10 ['Torrid', 'location', 'open']

In [9]:
proc.terminate()

In [10]:
import pathlib, os

p = pathlib.Path("src/efficientrag_retrieve.py")
code = p.read_text()

# ensure os is imported in the file
if "import os" not in code:
    code = code.replace("import argparse", "import argparse\nimport os", 1)

# auto-create parent dir before writing
code = code.replace(
    'with open(output_path, "w+", encoding="utf-8") as f:',
    'os.makedirs(os.path.dirname(output_path), exist_ok=True)\n'
    '    with open(output_path, "w+", encoding="utf-8") as f:',
)

# switch valid.jsonl (commented demo) -> test.jsonl
old = """dataset = load_jsonl(
        os.path.join(
            SYNTHESIZED_NEXT_QUERY_EXTRACTED_DATA_PATH, opt.dataset, "valid.jsonl"
            # SYNTHESIZED_NEXT_QUERY_EXTRACTED_DATA_PATH, opt.dataset, "demo.jsonl"
        )
    )
"""
new = """dataset = load_jsonl(
        os.path.join(
            SYNTHESIZED_NEXT_QUERY_EXTRACTED_DATA_PATH, opt.dataset, "test.jsonl"
        )
    )
"""
code = code.replace(old, new)

p.write_text(code)
print("Patched efficientrag_retrieve.py")

Patched efficientrag_retrieve.py
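
The patch cell above prints its success message whether or not the target strings were found, and re-running it inserts the `os.makedirs` line a second time. A minimal sketch of a stricter helper follows; `patch_once` is a hypothetical name, not something the repository provides, and the demo runs on a temporary file rather than the real script.

```python
import tempfile
from pathlib import Path


def patch_once(path: Path, old: str, new: str) -> bool:
    """Replace the first `old` with `new`; return False if already patched."""
    text = path.read_text(encoding="utf-8")
    if new in text:
        return False  # patch already applied, leave the file untouched
    if old not in text:
        raise ValueError(f"pattern not found in {path}")
    path.write_text(text.replace(old, new, 1), encoding="utf-8")
    return True


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        target = Path(tmp) / "demo.py"
        target.write_text('data = load("valid.jsonl")\n', encoding="utf-8")
        print(patch_once(target, '"valid.jsonl"', '"test.jsonl"'))  # first run patches
        print(patch_once(target, '"valid.jsonl"', '"test.jsonl"'))  # second run is a no-op
```

Checking for `new` before `old` matters when the replacement text contains the original text, as it does for the `os.makedirs` insertion.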


# EfficientRAG Pipeline

## Retrieval Step (top-10 documents)

In [48]:
! python src/efficientrag_retrieve.py \
    --dataset hotpotQA \
    --retriever contriever \
    --labels 2 \
    --labeler_ckpt trained_models/labeler \
    --filter_ckpt trained_models/filter \
    --topk 10

2025-11-22 23:27:50.481081: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-11-22 23:27:50.498584: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1763854070.520428   61569 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1763854070.527075   61569 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1763854070.543483   61569 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking 

In [53]:
import subprocess, time, requests, json, os, signal

# Kill any prior instance
_ = subprocess.run("pkill -f vllm.entrypoints.openai.api_server || true", shell=True)

cmd = [
  "python","-m","vllm.entrypoints.openai.api_server",
  "--model","meta-llama/Meta-Llama-3-8B-Instruct",
  "--dtype","auto","--host","0.0.0.0","--port","8000",
  "--api-key","token-colab-local-1234",
]

proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)

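# Wait for readiness: poll /v1/models once per second, for up to 400 attempts
BASE = "http://127.0.0.1:8000/v1"
HEAD = {"Authorization": "Bearer token-colab-local-1234", "Content-Type": "application/json"}
ok = False  # stays False if the server never answers, so the print below cannot raise NameError
for _ in range(400):
    try:
        r = requests.get(f"{BASE}/models", headers=HEAD, timeout=2)
        if r.status_code == 200:
            ok = True
            break
    except requests.RequestException:
        pass  # server is not accepting connections yet
    time.sleep(1)

print("Server OK:", ok)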

Server OK: True


## End-to-end Question Answering

In [50]:
! python src/efficientrag_qa.py \
    --fpath results/retrieve/efficient_rag/hotpotQA-.jsonl \
    --model llama-8B \
    --dataset hotpotQA

Error processing sample 27: 'NoneType' object is not subscriptable
Traceback (most recent call last):
  File "/content/efficientrag-official/src/efficientrag_qa.py", line 47, in parse_samples_in_parallel
    res = future.result()
          ^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/concurrent/futures/_base.py", line 449, in result
    return self.__get_result()
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/concurrent/futures/_base.py", line 401, in __get_result
    raise self._exception
  File "/usr/lib/python3.12/concurrent/futures/thread.py", line 59, in run
    result = self.fn(*self.args, **self.kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/content/efficientrag-official/src/efficientrag_qa.py", line 84, in parse_sample
    "model_output": result["answer"],
                    ~~~~~~^^^^^^^^^^
TypeError: 'NoneType' object is not subscriptable
Error processing sample 39: 'NoneType' object is not subscriptable

### Evaluation
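
The question-answering run above logged `'NoneType' object is not subscriptable` at `result["answer"]`, which suggests the model call returned `None` for some samples. That may be why the correctness evaluation below covers 153 of the 165 test samples, although the notebook does not confirm the cause. The sketch below shows one way to guard the lookup so that failed samples are recorded rather than raising; `extract_answer` is a hypothetical helper and is not part of `efficientrag_qa.py`.

```python
from typing import Any, Optional


def extract_answer(
    result: Optional[dict[str, Any]], default: str = ""
) -> tuple[str, bool]:
    """Return (answer, ok); ok is False when the model produced no usable result."""
    if isinstance(result, dict) and "answer" in result:
        return str(result["answer"]), True
    return default, False


if __name__ == "__main__":
    print(extract_answer({"answer": "Brea, California"}))
    print(extract_answer(None))
```

Keeping the `ok` flag alongside the answer lets the evaluation step decide whether to count a failed sample as wrong or to exclude it.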

In [51]:
! python src/evaluation/retrieve.py \
    --fpath results/retrieve/efficient_rag/hotpotQA-_qa_results.jsonl

Average number of chunks: 7.287581699346405
Recall: 0.7974


In [16]:
from pathlib import Path

path = Path("src/evaluation/correctness.py")
text = path.read_text(encoding="utf-8")

old = '{"answer": "simplified answer"}'
new = '{{"answer": "simplified answer"}}'

if old in text:
    text = text.replace(old, new)
    path.write_text(text, encoding="utf-8")
    print("Patched correctness.py")
else:
    print("Pattern not found; file may already be patched.")

Patched correctness.py
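
The doubled braces are needed when a prompt is later rendered with `str.format`, because a single `{` starts a replacement field. The sketch below illustrates that behaviour, assuming `correctness.py` formats its prompt this way; the template text and `render_prompt` are made up for the example.

```python
TEMPLATE = 'Reply as {{"answer": "simplified answer"}}. Question: {question}'
BROKEN = 'Reply as {"answer": "simplified answer"}. Question: {question}'


def render_prompt(template: str, question: str) -> str:
    """Fill the `question` field; literal braces must be doubled in `template`."""
    return template.format(question=question)


if __name__ == "__main__":
    print(render_prompt(TEMPLATE, "Where is the Brea Mall?"))
    try:
        render_prompt(BROKEN, "Where is the Brea Mall?")
    except KeyError as err:
        # The unescaped JSON example is parsed as a field named '"answer"'.
        print("KeyError:", err)
```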


In [54]:
! python src/evaluation/correctness.py \
    --fpath results/retrieve/efficient_rag/hotpotQA-_qa_results.jsonl \
    --model llama-8B \
    --workers 10 \
    --extract_answer

  for future in tqdm_rich(
Evaluating 100% 153/153 [ 0:01:11 < 0:00:00 , 3 it/s ]
EM: 0.5229
F1: 0.5773
Accuracy: 0.5098
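
As a reference point for the EM and F1 figures, the sketch below implements the SQuAD-style definitions commonly used for HotpotQA: answers are lowercased, stripped of punctuation and articles, then compared exactly (EM) or by token overlap (F1). It is an assumption that `correctness.py` follows this scheme; its normalization may differ, and the Accuracy figure is not covered here.

```python
import re
import string
from collections import Counter

_PUNCT = set(string.punctuation)


def normalize(text: str) -> str:
    """Lowercase, drop punctuation and articles, collapse whitespace."""
    text = "".join(ch for ch in text.lower() if ch not in _PUNCT)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def exact_match(prediction: str, gold: str) -> float:
    return float(normalize(prediction) == normalize(gold))


def f1_score(prediction: str, gold: str) -> float:
    pred_tokens = normalize(prediction).split()
    gold_tokens = normalize(gold).split()
    # Multiset intersection counts each shared token at most min(count) times.
    overlap = sum((Counter(pred_tokens) & Counter(gold_tokens)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(gold_tokens)
    return 2 * precision * recall / (precision + recall)


if __name__ == "__main__":
    print(exact_match("The Brea Mall.", "brea mall"))
    print(round(f1_score("Brea Mall in California", "Brea Mall"), 4))
```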


In [None]:
proc.terminate()

In [None]:
import subprocess, time, requests, json, os, signal

# Kill any prior instance
_ = subprocess.run("pkill -f vllm.entrypoints.openai.api_server || true", shell=True)

cmd = [
  "python","-m","vllm.entrypoints.openai.api_server",
  "--model","meta-llama/Meta-Llama-3-8B-Instruct",
  "--dtype","auto","--host","0.0.0.0","--port","8000",
  "--api-key","token-colab-local-1234",
]

proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)

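# Wait for readiness: poll /v1/models once per second, for up to 400 attempts
BASE = "http://127.0.0.1:8000/v1"
HEAD = {"Authorization": "Bearer token-colab-local-1234", "Content-Type": "application/json"}
ok = False  # stays False if the server never answers, so the print below cannot raise NameError
for _ in range(400):
    try:
        r = requests.get(f"{BASE}/models", headers=HEAD, timeout=2)
        if r.status_code == 200:
            ok = True
            break
    except requests.RequestException:
        pass  # server is not accepting connections yet
    time.sleep(1)

print("Server OK:", ok)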

## Direct

In [22]:
! python src/baseline/direct/direct_prompt.py \
  --dataset hotpotQA \
  --split test \
  --model llama-8B \
  --workers 4 \
  --mode direct

Loaded 165 data points from hotpotQA-test
  for future in tqdm_rich(as_completed(tasks), total=len(tasks)):
100% 165/165 [ 0:01:20 < 0:00:00 , 2 it/s ]
Processed 165 samples in 80.85 seconds
Average time per sample: 0.48999424847689543


In [25]:
! python src/evaluation/correctness.py \
    --fpath results/direct/direct/hotpotQA-test.jsonl \
    --model llama-8B \
    --extract_answer \
    --workers 10

  for future in tqdm_rich(
Evaluating 100% 165/165 [ 0:01:34 < 0:00:00 , 0 it/s ]
EM: 0.2606
F1: 0.2920
Accuracy: 0.2303


## CoT

In [28]:
! python src/baseline/direct/direct_prompt.py \
  --dataset hotpotQA \
  --split test \
  --model llama-8B \
  --workers 4 \
  --mode cot

Loaded 165 data points from hotpotQA-test
  for future in tqdm_rich(as_completed(tasks), total=len(tasks)):
100% 165/165 [ 0:02:35 < 0:00:00 , 2 it/s ]
Processed 165 samples in 155.75 seconds
Average time per sample: 0.9439401063052091


In [31]:
! python src/evaluation/correctness.py \
    --fpath results/direct/cot/hotpotQA-test.jsonl \
    --model llama-8B \
    --extract_answer \
    --workers 10

  for future in tqdm_rich(
Evaluating 100% 165/165 [ 0:01:56 < 0:00:00 , ? it/s ]
EM: 0.2788
F1: 0.3284
Accuracy: 0.2667


## Direct-R @10

In [32]:
import pathlib, os

p = pathlib.Path("src/baseline/retrieve/direct.py")
code = p.read_text()

# switch valid.jsonl -> test.jsonl
old = """dataset = load_jsonl(
        os.path.join(
            SYNTHESIZED_NEXT_QUERY_EXTRACTED_DATA_PATH, opt.dataset, "valid.jsonl"
        )
    )
"""
new = """dataset = load_jsonl(
        os.path.join(
            SYNTHESIZED_NEXT_QUERY_EXTRACTED_DATA_PATH, opt.dataset, "test.jsonl"
        )
    )
"""
code = code.replace(old, new)

p.write_text(code)
print("Patched direct.py")

Patched direct.py


In [41]:
! python src/baseline/retrieve/direct.py \
    --dataset hotpotQA \
    --retriever contriever \
    --topk 10 \
    --model llama-8B \
    --workers 1

Loading index from data/corpus/hotpotQA/contriever
Loading index from data/corpus/hotpotQA/contriever/index.faiss, meta data from data/corpus/hotpotQA/contriever/index_meta.faiss
Loading passages from data/corpus/hotpotQA/corpus.jsonl
Loaded 507493 passages.
2025-11-22 23:13:38.741953: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-11-22 23:13:38.761489: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1763853218.785638   57655 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attemptin

In [42]:
! python src/evaluation/retrieve.py \
    --fpath results/retrieve/direct/hotpotQA-@10.jsonl

Average number of chunks: 10.0
Recall: 0.6986
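
As a rough guide to what the recall figure means, here is a minimal sketch that scores each sample by the share of its oracle chunk ids found among the retrieved ids, then averages over samples. The field names `oracle_ids` and `chunk_ids` follow the records used elsewhere in this notebook; the per-sample averaging and the function names are assumptions, so `retrieve.py` may aggregate differently.

```python
def sample_recall(oracle_ids: list[str], chunk_ids: list[str]) -> float:
    """Share of oracle ids that appear among the retrieved chunk ids."""
    oracle = set(oracle_ids)
    if not oracle:
        return 0.0
    return len(oracle & set(chunk_ids)) / len(oracle)


def mean_recall(samples: list[dict]) -> float:
    """Average per-sample recall (assumed aggregation)."""
    if not samples:
        return 0.0
    scores = [sample_recall(s["oracle_ids"], s["chunk_ids"]) for s in samples]
    return sum(scores) / len(scores)


# Made-up ids, purely for illustration
demo = [
    {"oracle_ids": ["q1-00", "q1-03"], "chunk_ids": ["q1-00", "zz-07"]},
    {"oracle_ids": ["q2-01"], "chunk_ids": ["q2-01", "zz-02"]},
]
print(mean_recall(demo))
```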


In [45]:
! python src/evaluation/correctness.py \
    --fpath results/retrieve/direct/hotpotQA-@10.jsonl \
    --model llama-8B \
    --extract_answer \
    --workers 10

EM: 0.3688
F1: 0.4272
Accuracy: 0.3688


## SelfAsk

In [None]:
import pathlib, textwrap

path = pathlib.Path("src/retrievers/embeddings/dense_embedding.py")

new_code = textwrap.dedent("""
from typing import Literal, Union, List

import torch
from transformers import AutoModel, AutoTokenizer

from .base import BaseEmbedding

Pooling = Union[str, Literal["average", "cls"]]


class DenseEmbedding(BaseEmbedding):
    def __init__(
        self,
        model_name_or_path: str,
        embedding_vector_size: int,
        no_fp16: bool = False,
        pooling_type: Pooling = "average",
    ):
        super().__init__()
        self.model_name_or_path = model_name_or_path
        self.embedding_vector_size = embedding_vector_size

        # lazy-loaded
        self.model = None
        self.tokenizer = None

        self.fp16 = not no_fp16
        self.pooling_type = pooling_type

        # fixed device
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def instantiate(self):
        \"\"\"Load encoder + tokenizer on a real device (no meta, no device_map).\"\"\"
        if self.model is not None and self.tokenizer is not None:
            return

        # standard tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_name_or_path,
            use_fast=True,
        )

        # choose dtype
        if self.fp16 and self.device.type == "cuda":
            dtype = torch.float16
        else:
            dtype = torch.float32

        # IMPORTANT: do NOT use device_map, do NOT use torch_dtype="auto"
        self.model = AutoModel.from_pretrained(
            self.model_name_or_path,
            torch_dtype=dtype,
        )

        self.model.to(self.device)
        self.model.eval()

    @torch.no_grad()
    def embed_batch(self, queries: List[str]):
        \"\"\"Compute embeddings for a batch of queries.\"\"\"
        # ensure model/tokenizer loaded
        if self.model is None or self.tokenizer is None:
            self.instantiate()

        # tokenize on CPU then move whole batch to the same device as the model
        inputs = self.tokenizer(
            queries,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
        )
        # BatchEncoding has .to(), but to be safe:
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        outputs = self.model(**inputs)
        embeddings = self.pooling(outputs.last_hidden_state, inputs["attention_mask"])

        # always return numpy on CPU
        return embeddings.cpu().numpy()

    def pooling(self, last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        \"\"\"Average or CLS pooling.\"\"\"
        if self.pooling_type == "average":
            # mask out padding tokens
            mask = attention_mask[..., None].bool()
            last_hidden = last_hidden_states.masked_fill(~mask, 0.0)
            # sum over sequence, divide by number of non-pad tokens
            denom = attention_mask.sum(dim=1)[..., None].clamp(min=1)
            return last_hidden.sum(dim=1) / denom
        elif self.pooling_type == "cls":
            return last_hidden_states[:, 0, :]
        else:
            raise NotImplementedError(f"Unknown pooling type: {self.pooling_type}")
""")

path.write_text(new_code)
print("Rewrote dense_embedding.py")


In [None]:
import pathlib, textwrap

path = pathlib.Path("src/baseline/retrieve/selfask.py")

# Raw string: keeps "\n" in the embedded source as backslash + n instead of a real newline
new_code = textwrap.dedent(r'''
import argparse
import os
import sys
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Union

from tqdm.rich import tqdm_rich
import json
import threading

sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))

from direct import (
    DIRECT_RETRIEVE_ANSWER_PROMPT_HOTPOTQA,
    DIRECT_RETRIEVE_ANSWER_PROMPT_MUSIQUE,
    DIRECT_RETRIEVE_ANSWER_PROMPT_WIKIMQA,
)

from conf import (
    CORPUS_DATA_PATH,
    MODEL_DICT,
    SYNTHESIZED_NEXT_QUERY_EXTRACTED_DATA_PATH,
)
from language_models import LanguageModel, get_model
from retrievers import Retriever
from utils import ask_model, load_jsonl, write_jsonl

SELF_ASK_PROMPT = """
Solve the question with the given knowledge.
Each line should start with either "Intermediate answer:", "Follow up:", "The final answer is:", or "Are follow up questions needed here:".

Question: Who lived longer, Muhammad Ali or Alan Turing?
Are follow up questions needed here: Yes.
Follow up: How old was Muhammad Ali when he died?
Intermediate answer: Muhammad Ali was 74 years old when he died.
Follow up: How old was Alan Turing when he died?
Intermediate answer: Alan Turing was 41 years old when he died.
The final answer is: Muhammad Ali

Question: When was the founder of craigslist born?
Are follow up questions needed here: Yes.
Follow up: Who was the founder of craigslist?
Intermediate answer: Craigslist was founded by Craig Newmark.
Follow up: When was Craig Newmark born?
Intermediate answer: Craig Newmark was born on December 6, 1952.
The final answer is: December 6, 1952

Question: Who was the maternal grandfather of George Washington?
Are follow up questions needed here: Yes.
Follow up: Who was the mother of George Washington?
Intermediate answer: The mother of George Washington was Mary Ball Washington.
Follow up: Who was the father of Mary Ball Washington?
Intermediate answer: The father of Mary Ball Washington was Joseph Ball.
The final answer is: Joseph Ball

Question: Are both the directors of Jaws and Casino Royale from the same country?
Are follow up questions needed here: Yes.
Follow up: Who is the director of Jaws?
Intermediate answer: The director of Jaws is Steven Spielberg.
Follow up: Where is Steven Spielberg from?
Intermediate answer: The United States.
Follow up: Who is the director of Casino Royale?
Intermediate answer: The director of Casino Royale is Martin Campbell.
Follow up: Where is Martin Campbell from?
Intermediate answer: New Zealand.
The final answer is: No

Question: {question}
Are follow up questions needed here:
""".strip()

SELF_ASK_PROMPT_MUSIEUQ = """
Solve the question with the given knowledge.
Each line should start with either "Intermediate answer:", "Follow up:", "So the final answer is:", or "Are follow up questions needed here:".
#
Question: In which year did the publisher of In Cold Blood form?
Are follow up questions needed here: Yes.
Follow up: What business published In Cold Blood?
Intermediate answer: In Cold Blood was published in book form by Random House.
Follow up: Which year witnessed the formation of Random House?
Intermediate answer: Random House was formed in 2001.
So the final answer is: 2001
#
Question: Who was in charge of the city where The Killing of a Sacred Deer was filmed?
Are follow up questions needed here: Yes.
Follow up: In which city was The Killing of a Sacred Deer filmed?
Intermediate answer: The Killing of a Sacred Deer was filmed in Cincinnati.
Follow up: Who was in charge of Cincinnati?
Intermediate answer: The present Mayor of Cincinnati is John Cranley, so John Cranley is in charge.
So the final answer is: John Cranley
#
Question: Where on the Avalon Peninsula is the city that Signal Hill overlooks?
Are follow up questions needed here: Yes.
Follow up: What city does Signal Hill overlook?
Intermediate answer: Signal Hill is a hill which overlooks the city of St. John's.
Follow up: Where on the Avalon Peninsula is St. John's located?
Intermediate answer: St. John's is located on the eastern tip of the Avalon Peninsula.
So the final answer is: eastern tip
#
Question: {question}
Are follow up questions needed here:
""".strip()

SELF_ASK_PROMPT_WIKIMQA = """
Solve the question with the given knowledge.
Each line should start with either "Intermediate answer:", "Follow up:", "So the final answer is:", or "Are follow up questions needed here:".
Follow the examples below to answer the questions with natural language.
#
Question: Which film came out first, Blind Shaft or The Mask Of Fu Manchu?
Are follow up questions needed here: Yes.
Follow up: When did Blind Shaft come out?
Intermediate answer: Blind Shaft came out in 2003.
Follow up: When did The Mask Of Fu Manchu come out?
Intermediate answer: The Mask Of Fu Manchu came out in 1932.
So the final answer is: The Mask Of Fu Manchu
#
Question: When did John V, Prince Of Anhalt-Zerbst's father die?
Are follow up questions needed here: Yes.
Follow up: Who is the father of John V, Prince Of Anhalt-Zerbst?
Intermediate answer: The father of John V, Prince Of Anhalt-Zerbst is Ernest I, Prince of Anhalt-Dessau.
Follow up: When did Ernest I, Prince of Anhalt-Dessau die?
Intermediate answer: Ernest I, Prince of Anhalt-Dessau died on 12 June 1516.
So the final answer is: 12 June 1516
#
Question: Which film has the director who was born later, El Extrano Viaje or Love In Pawn?
Are follow up questions needed here: Yes.
Follow up: Who is the director of El Extrano Viaje?
Intermediate answer: The director of El Extrano Viaje is Fernando Fernan Gomez.
Follow up: Who is the director of Love in Pawn?
Intermediate answer: The director of Love in Pawn is Charles Saunders.
Follow up: When was Fernando Fernan Gomez born?
Intermediate answer: Fernando Fernan Gomez was born on 28 August 1921.
Follow up: When was Charles Saunders (director) born?
Intermediate answer: Charles Saunders was born on 8 April 1904.
So the final answer is: El Extrano Viaje
#
Question: {question}
Are follow up questions needed here:
""".strip()

SELF_ASK_PROMPT_HOTPOTQA = """
Solve the question with the given knowledge.
Each line should start with either "Intermediate answer:", "Follow up:", "So the final answer is:", or "Are follow up questions needed here:".
#
Question: What is the name of this American musician, singer, actor, comedian, and songwriter, who worked with Modern Records and born in December 5, 1932?
Are follow up questions needed here: Yes.
Follow up: Who worked with Modern Records?
Intermediate answer: Artists worked with Modern Records include Etta James, Little Richard, Joe Houston, Ike and Tina Turner and John Lee Hooker.
Follow up: Is Etta James an American musician, singer, actor, comedian, and songwriter, and was born in December 5, 1932?
Intermediate answer: Etta James was born in January 25, 1938, not December 5, 1932, so the answer is no.
Follow up: Is Little Richard an American musician, singer, actor, comedian, and songwriter, and was born in December 5, 1932?
Intermediate answer: Yes, Little Richard, born in December 5, 1932, is an American musician, singer, actor, comedian and songwriter.
So the final answer is: Little Richard
#
Question: Between Chinua Achebe and Rachel Carson, who had more diverse jobs?
Are follow up questions needed here: Yes.
Follow up: What jobs did Chinua Achebe have?
Intermediate answer: Chinua Achebe was a Nigerian (1) novelist, (2) poet, (3) professor, and (4) critic, so Chinua Achebe had 4 jobs.
Follow up: What jobs did Rachel Carson have?
Intermediate answer: Rachel Carson was an American (1) marine biologist, (2) author, and (3) conservationist, so Rachel Carson had 3 jobs.
Follow up: Did Chinua Achebe have more jobs than Rachel Carson?
Intermediate answer: Chinua Achebe had 4 jobs, while Rachel Carson had 3 jobs. 4 is greater than 3, so yes, Chinua Achebe had more jobs.
So the final answer is: Chinua Achebe
#
Question: Remember Me Ballin' is a CD single by Indo G that features an American rapper born in what year?
Are follow up questions needed here: Yes.
Follow up: Which American rapper is featured by Remember Me Ballin', a CD single by Indo G?
Intermediate answer: Gangsta Boo
Follow up: In which year was Gangsta Boo born?
Intermediate answer: Gangsta Boo was born in August 7, 1979, so the answer is 1979.
So the final answer is: 1979
#
Question: {question}
Are follow up questions needed here:
""".strip()

GET_ANSWER_PROMPT_TEMPLATE_MAPPING = {
    "hotpotQA": DIRECT_RETRIEVE_ANSWER_PROMPT_HOTPOTQA,
    "2WikiMQA": DIRECT_RETRIEVE_ANSWER_PROMPT_WIKIMQA,
    "musique": DIRECT_RETRIEVE_ANSWER_PROMPT_MUSIQUE,
}

SELF_ASK_PROMPT_TEMPLATE_MAPPING = {
    "hotpotQA": SELF_ASK_PROMPT_HOTPOTQA,
    "2WikiMQA": SELF_ASK_PROMPT_WIKIMQA,
    "musique": SELF_ASK_PROMPT_MUSIEUQ,
}


def extract_question(generated):
    if "\n" not in generated:
        last_line = generated
    else:
        last_line = generated.split("\n")[-1]

    if "Follow up:" not in last_line:
        print("Follow up not in last line: \n" + generated)

    if ":" not in last_line:
        after_colon = last_line
    else:
        after_colon = generated.split(":")[-1]

    if after_colon == "":
        return ""
    if " " == after_colon[0]:
        after_colon = after_colon[1:]
    if "?" != after_colon[-1]:
        print("Question not end with ?: " + generated)

    return after_colon


def extract_answer(generated):
    if "\n" not in generated:
        last_line = generated
    else:
        last_line = generated.split("\n")[-1]

    if ":" not in last_line:
        after_colon = last_line
    else:
        after_colon = generated.split(":")[-1]

    if after_colon == "":
        return ""
    if " " == after_colon[0]:
        after_colon = after_colon[1:]
    if "." == after_colon[-1]:
        after_colon = after_colon[:-1]

    return after_colon


def get_last_line(generated):
    if "\n" not in generated:
        last_line = generated
    else:
        last_line = generated.split("\n")[-1]
    return last_line


class SelfAsk:
    def __init__(
        self,
        model: str,
        dataset: list[dict],
        retriever: Retriever,
        max_iter: int = 5,
        topk: int = 10,
        dataset_name: str = None,
    ):
        self.model: LanguageModel
        self.model = get_model(model)
        self.dataset = dataset
        self.retriever = retriever
        # Upper bound on follow-up rounds per question
        self.max_iter = max_iter
        self.topk = topk
        self.prompt_template = SELF_ASK_PROMPT_TEMPLATE_MAPPING[dataset_name]
        self.get_answer_prompt_template = GET_ANSWER_PROMPT_TEMPLATE_MAPPING[
            dataset_name
        ]

        self.intermediate = "\nIntermediate answer:"
        self.follow_up = "Follow up:"
        self.final_ans = "So the final answer is:"
        self.check_following_question = "\nAre follow up questions needed here:"

    def inference(self, workers: int = 10, save_path=None) -> list[str]:
        """
        If save_path is provided, each result is appended to save_path as a JSON line
        as soon as it is processed.
        """
        lock = threading.Lock()

        if save_path is not None:
            os.makedirs(os.path.dirname(save_path), exist_ok=True)
            # Start with a fresh file for this run
            open(save_path, "w", encoding="utf-8").close()

        with ThreadPoolExecutor(max_workers=workers) as executor:
            tasks = {
                executor.submit(self.infer_sample, sample): idx
                for idx, sample in enumerate(self.dataset)
            }
            results = []
            for future in tqdm_rich(as_completed(tasks), total=len(tasks)):
                idx = tasks[future]
                try:
                    res = future.result()
                    results.append((idx, res))

                    # write this sample immediately to file if requested
                    if save_path is not None:
                        with lock:
                            with open(save_path, "a", encoding="utf-8") as f:
                                f.write(json.dumps(res, ensure_ascii=False) + "\n")
                except Exception as e:
                    print(f"Error in processing {idx}: {e}")
                    import traceback

                    traceback.print_exc()
            results = [res[1] for res in sorted(results, key=lambda x: x[0])]
        return results

    def get_answer(self, question):
        knowledge_list = self.retriever.search(question, self.topk)[0]
        knowledge = "\n".join([f"{doc['text']}" for doc in knowledge_list])
        prompt = self.get_answer_prompt_template.format(
            knowledge=knowledge, question=question
        )
        model_response = ask_model(
            self.model,
            prompt,
            type="json",
            check_if_valid=lambda x: type(x) is dict and "answer" in x.keys(),
            mode="chat",
        )
        if model_response is None:
            return "unknown", knowledge_list
        return model_response["answer"], knowledge_list

    def call_model(self, current_prompt, stop: Union[str, list[str]]):
        def check_if_valid(s: str):
            if type(stop) is str:
                return stop in s
            elif type(stop) is list:
                return any([x in s for x in stop])

        response = ask_model(
            self.model,
            current_prompt,
            type="text",
            mode="completion",
            check_if_valid=check_if_valid,
        )
        if response is None:
            return ""

        response = response.strip()
        if type(stop) is str:
            return response.split(stop)[0]
        elif type(stop) is list:
            idx_list = [response.find(x) for x in stop]
            idx_list = [x if x != -1 else float("inf") for x in idx_list]
            min_idx = min(idx_list)
            for idx, stop_word in zip(idx_list, stop):
                if idx == min_idx:
                    return response.split(stop_word)[0].strip()
        return ""

    def infer_sample(self, sample: dict) -> dict:
        question = sample["question"]
        results = {
            "id": sample["id"],
            "answer": sample["answer"],
            "oracle": [
                f"{sample['id']}-{'{:02d}'.format(chunk['positive_paragraph_idx'])}"
                for chunk in sample["decomposed_questions"].values()
            ],
            "question": question,
            "knowledges": list(),
            "internal_questions": [],
        }

        cur_prompt = self.prompt_template.format(question=question)
        cur_iter = 0
        ret_text = self.call_model(cur_prompt, [self.intermediate, self.final_ans])
        while self.follow_up in get_last_line(ret_text) and cur_iter < self.max_iter:
            cur_iter += 1
            cur_prompt += ret_text
            question = extract_question(ret_text)
            results["internal_questions"].append(question)
            external_answer, knowledge_list = self.get_answer(question)
            results["knowledges"].extend(knowledge_list)

            if external_answer is not None:
                cur_prompt += f"{self.intermediate} {external_answer}."
                ret_text = self.call_model(
                    cur_prompt, [self.intermediate, self.final_ans]
                )
            else:
                # very rare when return no answer
                cur_prompt += self.intermediate
                answer = self.call_model(
                    cur_prompt, ["\n" + self.follow_up, self.final_ans]
                )
                cur_prompt += answer

        if self.final_ans not in ret_text:
            cur_prompt += f"{self.final_ans}"
            ret_text = self.call_model(cur_prompt, "\n")

        final_prompt = cur_prompt + ret_text
        final_answer = extract_answer(final_prompt)

        results["model_answer"] = final_answer
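        # Slice off the formatted few-shot prompt, not the raw template: the two
        # differ in length because of the {question} placeholder
        initial_prompt = self.prompt_template.format(question=sample["question"])
        results["history"] = cur_prompt[len(initial_prompt) :]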
        return results


def main(opt: argparse.Namespace):
    passage_path = os.path.join(CORPUS_DATA_PATH, opt.dataset, "corpus.jsonl")
    if opt.retriever == "e5-base-v2":
        embedding_path = os.path.join(CORPUS_DATA_PATH, opt.dataset, "e5-base")
    elif opt.retriever == "contriever":
        embedding_path = os.path.join(CORPUS_DATA_PATH, opt.dataset, "contriever")
    else:
        raise ValueError(f"Unsupported retriever: {opt.retriever}")
    retriever = Retriever(
        passage_path=passage_path,
        passage_embedding_path=embedding_path,
        index_path_dir=embedding_path,
        model_type=opt.retriever,
    )
    dataset = load_jsonl(
        os.path.join(
            SYNTHESIZED_NEXT_QUERY_EXTRACTED_DATA_PATH,
            opt.dataset,
            f"{opt.split}.jsonl",
        )
    )
    model = MODEL_DICT[opt.model]
    selfask = SelfAsk(
        model=model,
        dataset=dataset,
        retriever=retriever,
        topk=opt.topk,
        dataset_name=opt.dataset,
    )
    import time

    start = time.time()
    save_path = os.path.join(
        "results/retrieve/selfask", f"{opt.dataset}-{opt.split}.jsonl"
    )
    # This will write each processed sample immediately to save_path
    results = selfask.inference(opt.workers, save_path=save_path)
    end = time.time()
    print(f"Total time: {end - start:.2f}s")
    print(f"Average time per sample: {(end - start) / len(dataset):.2f}s")

    # Optionally, rewrite a clean sorted file at the end
    write_jsonl(results, save_path)


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", type=str, default="musique")
    parser.add_argument("--split", type=str, default="valid")
    parser.add_argument("--retriever", type=str, default="contriever")
    parser.add_argument("--model", type=str, default="llama-8B")
    parser.add_argument("--topk", type=int, default=10)
    parser.add_argument("--workers", type=int, default=10)
    return parser.parse_args()


if __name__ == "__main__":
    opt = parse_args()
    main(opt)


''')
path.write_text(new_code)
print("Rewrote selfask.py")
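
Because this cell writes Python source from inside a string literal, backslash escapes deserve a second look. In a non-raw literal, `\n` turns into a real line break in the written file, which splits any `"\n"` in the embedded code across two lines; a raw literal keeps the two characters intact. The small check below illustrates the difference using `ast.parse`; the `parses` helper is hypothetical and only exists for this demonstration.

```python
import ast

# Non-raw literal: the escape is processed, so the embedded code contains a real newline
plain = '''sep = "\n"'''
# Raw literal: backslash and n survive as two characters
raw = r'''sep = "\n"'''


def parses(source: str) -> bool:
    """Return True if the string is syntactically valid Python."""
    try:
        ast.parse(source)
        return True
    except SyntaxError:
        return False


print("plain parses:", parses(plain))
print("raw parses:", parses(raw))
```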

In [None]:
! python src/baseline/retrieve/selfask.py \
    --dataset hotpotQA \
    --split test \
    --retriever contriever \
    --model llama-8B \
    --topk 10 \
    --workers 1

In [None]:
from utils import load_jsonl, write_jsonl

selfask_path = "results/retrieve/selfask/hotpotQA-test.jsonl"
data = load_jsonl(selfask_path)

converted = []
for sample in data:
    oracle_ids = sample["oracle"]  # list of oracle IDs
    # flatten all retrieved doc ids (may contain duplicates)
    chunk_ids = [doc["id"] for doc in sample["knowledges"]]

    converted.append({
        "question": sample["question"],
        "answer": sample["answer"],
        "oracle_ids": oracle_ids,
        "chunk_ids": chunk_ids,
    })

out_path = "results/retrieve/selfask/hotpotQA-test_for_eval.jsonl"
write_jsonl(converted, out_path)
print("Saved to:", out_path)

Saved to: results/retrieve/selfask/hotpotQA-test_for_eval.jsonl
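
Self-ask retrieves once per follow-up question, so the same passage can show up several times in `chunk_ids`. If each passage should count only once (an assumption; check how `retrieve.py` treats repeated ids before relying on this), an order-preserving de-duplication can be sketched as follows. The helper name `dedupe_keep_order` is hypothetical.

```python
def dedupe_keep_order(ids: list[str]) -> list[str]:
    """Drop repeated ids while keeping the first occurrence of each."""
    # dict preserves insertion order, so this keeps the original ranking
    return list(dict.fromkeys(ids))


print(dedupe_keep_order(["a", "b", "a", "c", "b"]))
```

Using a plain `set` would also remove duplicates but would lose the retrieval order, which matters if a later step truncates to the top-k ids.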


In [None]:
! python src/evaluation/retrieve.py \
    --fpath results/retrieve/selfask/hotpotQA-test_for_eval.jsonl


In [None]:
from utils import load_jsonl, write_jsonl

selfask_path = "results/retrieve/selfask/hotpotQA-test.jsonl"
data = load_jsonl(selfask_path)

qa_style = []
for sample in data:
    qa_style.append({
        "question": sample["question"],
        "answer": sample["answer"],
        "model_output": sample["model_answer"],
        "oracle_ids": sample["oracle"],
        # optional for analysis:
        "chunk_ids": [doc["id"] for doc in sample["knowledges"]],
    })

out_path = "results/retrieve/selfask/hotpotQA-test_qa_results.jsonl"
write_jsonl(qa_style, out_path)
print("Saved to:", out_path)


Saved to: results/retrieve/selfask/hotpotQA-test_qa_results.jsonl


In [None]:
! python src/evaluation/correctness.py \
    --fpath results/retrieve/selfask/hotpotQA-test_qa_results.jsonl \
    --model llama-8B \
    --extract_answer \
    --workers 10