# EfficientRAG: Data Synthesis & Training Pipeline with Local Llama-3

This notebook demonstrates the end-to-end pipeline for **EfficientRAG**. It covers setting up the environment, adapting the official codebase to run on Google Colab using a local **vLLM server** (Meta-Llama-3-8B-Instruct), creating a tiny subset of HotpotQA for demonstration purposes, and running the full synthetic data generation pipeline.

**Key steps covered:**
1.  **Environment Setup:** Installing dependencies (`vllm`, `faiss`, `transformers`).
2.  **Codebase Adaptation:** Patching the official repository to support local Llama-3 inference and fix file path issues.
3.  **Corpus Indexing:** Embedding passages using Contriever.
4.  **Data Synthesis Pipeline:**
    * Query Decomposition
    * Token Labeling & Extraction
    * Next-Hop Query Construction & Filtering
    * Negative Sampling
5.  **Final Compilation:** Assembling the training data for the Filter and Labeler models.

In [None]:
! git clone https://github.com/nil-zhuang/efficientrag-official.git

Cloning into 'efficientrag-official'...
remote: Enumerating objects: 122, done.
remote: Counting objects: 100% (122/122), done.
remote: Compressing objects: 100% (111/111), done.
remote: Total 122 (delta 13), reused 108 (delta 7), pack-reused 0 (from 0)
Receiving objects: 100% (122/122), 1.81 MiB | 7.98 MiB/s, done.
Resolving deltas: 100% (13/13), done.


## 1. Environment Setup & Dependencies

We first clone the official repository and install its Python dependencies. The requirements include `vllm`, which we use to run a high-performance local inference server for Llama-3 instead of relying on paid OpenAI APIs for this demo.

In [None]:
%cd efficientrag-official
! pip install -r requirements.txt

/content/efficientrag-official
Collecting faiss-cpu (from -r requirements.txt (line 1))
  Downloading faiss_cpu-1.12.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (5.1 kB)
Collecting msal (from -r requirements.txt (line 2))
  Downloading msal-1.34.0-py3-none-any.whl.metadata (11 kB)
Collecting accelerate==0.29.1 (from -r requirements.txt (line 14))
  Downloading accelerate-0.29.1-py3-none-any.whl.metadata (18 kB)
Collecting deepspeed>=0.14.1 (from -r requirements.txt (line 16))
  Downloading deepspeed-0.18.2.tar.gz (1.6 MB)
  Preparing metadata (setup.py) ... done
Collecting vllm (from -r requirements.txt (line 17))
  Downloading vllm-0.11.0-cp38-abi3-manylinux1_x86_64.whl.metadata (17 kB)
Collecting black (from -r requirements.txt (line 19))
  Downloading black-25.11.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.ma

In [None]:
! pip install -U "accelerate>=1.1.0" "peft>=0.16.0" "transformers>=4.46.0"

Collecting accelerate>=1.1.0
  Downloading accelerate-1.11.0-py3-none-any.whl.metadata (19 kB)
Collecting peft>=0.16.0
  Downloading peft-0.18.0-py3-none-any.whl.metadata (14 kB)
Downloading accelerate-1.11.0-py3-none-any.whl (375 kB)
Downloading peft-0.18.0-py3-none-any.whl (556 kB)
Installing collected packages: accelerate, peft
  Attempting uninstall: accelerate
    Found existing installation: accelerate 0.29.1
    Uninstalling accelerate-0.29.1:
      Successfully uninstalled accelerate-0.29.1
  Attempting uninstall: peft
    Found existing installation: peft 0.17.1
    Uninstalling peft-0.17.1:
      Successfully uninstalled peft-0.17.1
Successfully installed accelerate-1.11.0 peft-0.18.0


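**Important:** Wait for the installations to finish, then restart the Colab session (Runtime > Restart session) so the upgraded packages are loaded. Restarting also resets the working directory, so the next cell changes back into the project folder.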

In [None]:
%cd efficientrag-official

/content/efficientrag-official
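Optionally, you can confirm that the restarted kernel picked up the upgraded packages. The sketch below uses only the standard library's `importlib.metadata`; the helper names are hypothetical, the minimum versions are copied from the `pip install -U` command above, and the numeric comparison is a simplification that ignores pre-release tags (`packaging.version` handles those properly).

```python
from importlib import metadata


def parse_version(version: str) -> tuple[int, ...]:
    """Leading numeric components: '1.11.0' -> (1, 11, 0), '4.57.0.dev0' -> (4, 57, 0)."""
    parts = []
    for piece in version.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:  # stop at the first non-numeric component such as 'dev0'
            break
        parts.append(int(digits))
    return tuple(parts)


def at_least(found: str, minimum: str) -> bool:
    """Compare two versions after zero-padding them to the same length."""
    a, b = parse_version(found), parse_version(minimum)
    width = max(len(a), len(b))
    return a + (0,) * (width - len(a)) >= b + (0,) * (width - len(b))


def check_min_versions(minimums: dict[str, str]) -> dict[str, str]:
    """Map each package to 'ok', 'too old (<installed>)' or 'missing'."""
    report = {}
    for pkg, minimum in minimums.items():
        try:
            installed = metadata.version(pkg)
        except metadata.PackageNotFoundError:
            report[pkg] = "missing"
            continue
        report[pkg] = "ok" if at_least(installed, minimum) else f"too old ({installed})"
    return report


print(check_min_versions({"accelerate": "1.1.0", "peft": "0.16.0", "transformers": "4.46.0"}))
```

If any entry reports `too old`, the session was most likely not restarted after the upgrade.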


In [None]:
import os, json, random
from pathlib import Path

# Project root (adjust if your %cd differs)
ROOT = Path.cwd()
DATA = ROOT / "data"
DATASET_DIR = DATA / "dataset" / "hotpotQA"
MODEL_CACHE = ROOT / "model_cache"

for p in [DATASET_DIR, MODEL_CACHE]:
    p.mkdir(parents=True, exist_ok=True)

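# Point the Hugging Face libraries at our local cache folder.
# Set these before importing transformers / huggingface_hub: the cache paths are read at import time.
os.environ["HF_HOME"] = str(MODEL_CACHE)
os.environ["TRANSFORMERS_CACHE"] = str(MODEL_CACHE)  # legacy variable; HF_HOME is the preferred one
os.environ["HF_DATASETS_CACHE"] = str(MODEL_CACHE)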

## 2. Hugging Face Authentication

Access to **Meta-Llama-3-8B-Instruct** is gated. You must provide a Hugging Face token with access permissions to this model.
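A minimal sketch for resolving the token without hard-coding it, assuming it is stored either in an `HF_TOKEN` environment variable or in a Colab secret of the same name (the secret that the Colab warning further down looks for). `resolve_hf_token` is a hypothetical helper, not part of `huggingface_hub`.

```python
import os
from typing import Optional


def resolve_hf_token(explicit: Optional[str] = None) -> Optional[str]:
    """Hypothetical helper: first non-empty token from an explicit string, $HF_TOKEN, or Colab Secrets."""
    # Only accept real strings, so a leftover `...` placeholder is ignored instead of used as a token
    if isinstance(explicit, str) and explicit.strip():
        return explicit.strip()
    env_token = os.environ.get("HF_TOKEN", "").strip()
    if env_token:
        return env_token
    try:
        from google.colab import userdata  # only importable inside Colab
        return userdata.get("HF_TOKEN") or None
    except Exception:  # not running in Colab, or the secret is missing / access not granted
        return None
```

You could then use `token = resolve_hf_token()` in the cell below; a `None` result falls through to the interactive `notebook_login()` widget.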

In [None]:
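import os

# Paste your token as a string here, or leave None to fall back to the HF_TOKEN environment variable.
# Don't leave a bare `...` placeholder: Ellipsis is truthy, so it would be passed to login() as the token.
token = None
token = token or os.environ.get("HF_TOKEN")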

if token:
    # Non-interactive login (recommended: set the env var in Colab “Secrets” or in a cell)
    from huggingface_hub import login
    login(token=token, add_to_git_credential=False)
else:
    # Fallback: interactive login widget
    from huggingface_hub import notebook_login
    notebook_login()


## 3. Creating a "Tiny" Dataset

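Processing the full HotpotQA dataset can take hours or days. For this demonstration, we create a deterministic **"Tiny" subset** of 500 examples: 350 training examples, plus 100 dev and 50 test examples taken from non-overlapping slices of the shuffled validation split. This is small enough to run the entire pipeline end to end and verify that each stage works.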
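The split logic amounts to a seeded shuffle followed by prefix slices, with dev and test taken from non-overlapping slices of the same shuffled validation split. The plain-Python sketch below illustrates that idea on lists; `tiny_splits` is a hypothetical helper, and because it uses `random.Random` it does **not** reproduce the exact order that `datasets.Dataset.shuffle(seed=42)` produces.

```python
import random


def tiny_splits(train_pool, val_pool, n_train=350, n_dev=100, n_test=50, seed=42):
    """Seeded shuffle, then prefix slices; dev and test come from disjoint slices of one shuffled list."""
    train = list(train_pool)
    val = list(val_pool)
    random.Random(seed).shuffle(train)  # same seed -> same order on every run
    random.Random(seed).shuffle(val)
    dev = val[:n_dev]
    test = val[n_dev:n_dev + n_test]  # starts where dev ends, so no example is in both
    return train[:n_train], dev, test


# Split sizes of the HF "distractor" config, as reported in the download output below
train, dev, test = tiny_splits(range(90447), range(7405))
print(len(train), len(dev), len(test))
```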

In [None]:
# Tiny subset: 500 samples total -> 350 train, 100 dev, 50 test (as .json arrays, NOT jsonl)
import json
from pathlib import Path
from datasets import load_dataset

# Paths
DATASET_DIR = Path("data/dataset/hotpotQA")
DATASET_DIR.mkdir(parents=True, exist_ok=True)
train_path = DATASET_DIR / "train.json"
dev_path   = DATASET_DIR / "dev.json"
test_path  = DATASET_DIR / "test.json"

# Load HF dataset
ds = load_dataset("hotpotqa/hotpot_qa", "distractor")

# Deterministic tiny subsets
SEED = 42
train_small = ds["train"].shuffle(seed=SEED).select(range(350))
val_small   = ds["validation"].shuffle(seed=SEED)
dev_small   = val_small.select(range(100))
test_small  = val_small.select(range(100, 150))

# Save EXACT records (preserve all original fields), as JSON arrays
with train_path.open("w", encoding="utf-8") as f:
    json.dump([ex for ex in train_small], f, ensure_ascii=False)

with dev_path.open("w", encoding="utf-8") as f:
    json.dump([ex for ex in dev_small], f, ensure_ascii=False)

with test_path.open("w", encoding="utf-8") as f:
    json.dump([ex for ex in test_small], f, ensure_ascii=False)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.



## 4. Model Caching & Downloads

We pre-download the **Contriever** retriever and the **DeBERTa-v3-large** encoder (the backbone of the Labeler and Filter models) into `model_cache/`, so the later scripts load them from disk instead of downloading them mid-run.

In [None]:
import torch
from transformers import AutoTokenizer, AutoModel

# Contriever retriever (MS MARCO variant)
ctr_name = "facebook/contriever-msmarco"
ctr_tok = AutoTokenizer.from_pretrained(ctr_name, cache_dir=str(MODEL_CACHE))
ctr = AutoModel.from_pretrained(ctr_name, cache_dir=str(MODEL_CACHE))

# DeBERTa-v3-large encoder (used by EfficientRAG Labeler/Filter)
deb_name = "microsoft/deberta-v3-large"
deb = AutoModel.from_pretrained(deb_name, dtype="auto", cache_dir=str(MODEL_CACHE))  # dtype replaces the deprecated torch_dtype argument

print("Models cached under:", MODEL_CACHE)



Models cached under: /content/efficientrag-official/model_cache


In [None]:
import os
import pathlib
from huggingface_hub import snapshot_download

# 1) Make the folders expected by the repo
pathlib.Path("model_cache/contriever-msmarco").mkdir(parents=True, exist_ok=True)
pathlib.Path("model_cache/deberta-v3-large").mkdir(parents=True, exist_ok=True)

# 2) Download full model snapshots there. Recent huggingface_hub versions write real files
#    into local_dir, so the deprecated local_dir_use_symlinks argument is no longer needed.
snapshot_download(
    repo_id="facebook/contriever-msmarco",
    local_dir="model_cache/contriever-msmarco",
)
snapshot_download(
    repo_id="microsoft/deberta-v3-large",
    local_dir="model_cache/deberta-v3-large",
)

# 3) Export the cache location for the repo scripts launched with `!python` (they inherit os.environ).
#    Libraries already imported in this kernel keep the paths they resolved at import time.
os.environ["HF_HOME"] = os.path.abspath("model_cache")
os.environ["TRANSFORMERS_CACHE"] = os.environ["HF_HOME"]  # still set for older code

For more details, check out https://huggingface.co/docs/huggingface_hub/main/en/guides/download#download-files-to-local-folder.



## 5. Codebase Adaptation (Patching)

The official EfficientRAG repository assumes a specific directory layout and OpenAI API calls. The following cells apply **hot-fixes** to the source code to:
1.  Prevent `FileNotFoundError`s by creating output directories automatically.
2.  Register the local Llama-3-8B model in `config.py`.
3.  Adapt the token-labeling script for HotpotQA and Llama-3 (e.g., registering `hotpotQA` in the prompt mapping with the 2WikiMQA few-shot prompt).
4.  Replace the deprecated `evaluation_strategy` argument of Transformers' `TrainingArguments` with `eval_strategy`.
5.  Guard the next-hop query filtering statistics against division by zero.
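Several of the cells below modify files with plain string or regex replacement, and re-running such a cell can apply the same patch twice. One way to make a patch idempotent is to check for a marker string before writing. The helper below is a hypothetical sketch, not part of the repo:

```python
import pathlib


def patch_once(path, old: str, new: str, marker: str) -> bool:
    """Hypothetical helper: replace the first `old` with `new` unless `marker` shows the patch is applied.

    Returns True if the file was modified, False if it was already patched.
    """
    p = pathlib.Path(path)
    src = p.read_text(encoding="utf-8")
    if marker in src:
        return False  # already patched: re-running the cell is a no-op
    if old not in src:
        # Fail loudly if the upstream code changed, instead of silently skipping the patch
        raise ValueError(f"anchor text not found in {p}")
    p.write_text(src.replace(old, new, 1), encoding="utf-8")
    return True
```

Applied to the first patch below, `old` would be the `with open(output_dir, "w+") as f:` line and `marker` the inserted `os.makedirs(...)` call.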

In [None]:
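# Patch in place: create the corpus output directory before the file is written.
# Guarded so that re-running this cell does not insert the line twice.
import pathlib

p = pathlib.Path("src/retrievers/multihop_data_extrator.py")
code = p.read_text(encoding="utf-8")
if "os.makedirs(os.path.dirname(output_dir)" not in code:
    if "import os" not in code:
        code = "import os\n" + code
    # The new first line inherits the original indentation; the `with` line assumes 4 spaces
    code = code.replace(
        'with open(output_dir, "w+") as f:',
        'os.makedirs(os.path.dirname(output_dir), exist_ok=True)\n    with open(output_dir, "w+") as f:',
    )
    p.write_text(code, encoding="utf-8")
print("Patched multihop_data_extrator.py to auto-create the output directory.")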

Patched multihop_data_extrator.py to auto-create the output directory.


In [None]:
import pathlib, re

p = pathlib.Path("src/data_synthesize/token_labeling.py")
s = p.read_text(encoding="utf-8")

# 1) Register "hotpotQA" in the prompt mapping, reusing the 2WikiMQA few-shot prompt
if '"hotpotQA":' not in s:
    pat = re.compile(r'(TOKEN_LABEL_PROMPT_TEMPLATE_MAPPING\s*=\s*\{)(.*?)(\})', re.S)

    def _add_hotpot(m):
        body = m.group(2).rstrip()
        sep = "," if body.strip() and not body.endswith(",") else ""  # comma after the last existing entry
        return f'{m.group(1)}{body}{sep}\n    "hotpotQA": TOKEN_LABEL_SYNTHESIZE_FEW_SHOT_PROMPT_WIKIMQA\n{m.group(3)}'

    s = pat.sub(_add_hotpot, s, count=1)
    p.write_text(s, encoding="utf-8")

# 2) Create the output directory before the first `with open(os.path.join(...), "w")` write
s = p.read_text(encoding="utf-8")
if "os.makedirs(" not in s:
    if "import os" not in s:
        s = "import os\n" + s
    pat = re.compile(r'(?m)^(?P<i>\s*)with\s+open\(\s*os\.path\.join\([^)]*\)\s*,\s*["\']w\+?["\'][^)]*\)\s*as\s*f\s*:\s*')
    s = pat.sub(lambda m: f'{m.group("i")}os.makedirs(os.path.join(SYNTHESIZED_TOKEN_LABELING_DATA_PATH, opt.dataset), exist_ok=True)\n{m.group(0)}', s, count=1)
    p.write_text(s, encoding="utf-8")

In [None]:
import pathlib, re
p = pathlib.Path("src/efficient_rag/filter_training.py")
s = p.read_text(encoding="utf-8")
s = re.sub(r'\bevaluation_strategy\s*=\s*["\']steps["\']', 'eval_strategy="steps"', s)
p.write_text(s, encoding="utf-8")

4174

In [None]:
import pathlib, re
p = pathlib.Path("src/conf/config.py")
s = p.read_text(encoding="utf-8")

s = re.sub(
    r'MODEL_DICT\s*=\s*\{[^}]+\}',
    '''MODEL_DICT = {
    "gpt35": "gpt-35-turbo-1106",
    "gpt4": "gpt-4-0125-preview",
    "llama": "llama-3-70b-gptq-int4",
    "llama-8B": "meta-llama/Meta-Llama-3-8B-Instruct",
    "deepseek": "deepseek-chat",
}''',
    s,
    flags=re.S,
)
p.write_text(s, encoding="utf-8")

1988

In [None]:
import pathlib, re
p = pathlib.Path("src/data_synthesize/next_hop_query_construction.py")
s = p.read_text(encoding="utf-8")

if "os.makedirs(" not in s:
    if "import os" not in s:
        s = "import os\n" + s
    pat = re.compile(r'(?m)^(?P<i>\s*)with\s+open\(output_path[^:]+:')
    s = pat.sub(lambda m: f'{m.group("i")}os.makedirs(os.path.dirname(output_path), exist_ok=True)\n{m.group(0)}', s, count=1)
    p.write_text(s, encoding="utf-8")

In [None]:
import pathlib, re
p = pathlib.Path("src/data_synthesize/negative_sampling_labeled.py")
s = p.read_text(encoding="utf-8")

if "os.makedirs(" not in s:
    if "import os" not in s:
        s = "import os\n" + s
    pat = re.compile(r'(?m)^(?P<i>\s*)with\s+open\(\s*os\.path\.join\([^)]*\)\s*,\s*["\']w["\'][^)]*\)\s*as\s*f\s*:\s*')
    s = pat.sub(lambda m: f'{m.group("i")}os.makedirs(os.path.join(SYNTHESIZED_NEGATIVE_SAMPLING_LABELED_DATA_PATH, opt.dataset), exist_ok=True)\n{m.group(0)}', s, count=1)
    p.write_text(s, encoding="utf-8")

In [None]:
import pathlib, re
p = pathlib.Path("src/data_synthesize/negative_sampling.py")
s = p.read_text(encoding="utf-8")

if "os.makedirs(" not in s:
    if "import os" not in s:
        s = "import os\n" + s
    pat = re.compile(r'(?m)^(?P<i>\s*)with\s+open\(output_data_path[^:]+:')
    s = pat.sub(lambda m: f'{m.group("i")}os.makedirs(os.path.dirname(output_data_path), exist_ok=True)\n{m.group(0)}', s, count=1)
    p.write_text(s, encoding="utf-8")


In [None]:
import pathlib, re
p = pathlib.Path("src/data_synthesize/next_hop_query_filtering.py")
s = p.read_text(encoding="utf-8")
s = re.sub(
    r'for k, v in infos\.items\(\):\s*\n\s*v\s*=\s*v\s*/\s*num_samples\s*\*\s*100\s*\n\s*print\(f"{k}: {v:.2f}"\)',
    'for k, v in infos.items():\n        v = v / num_samples * 100 if num_samples != 0 else 0\n        print(f"{k}: {v:.2f}")',
    s,
    flags=re.S,
)
p.write_text(s, encoding="utf-8")

7875


## 6. Corpus Preparation

We standardize the HotpotQA corpus into a unified `.jsonl` format required by the EfficientRAG retrievers.
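The exact record schema that `multihop_data_extrator.py` writes is not shown in this notebook. As a rough, hypothetical illustration of what "standardizing" involves, the sketch below flattens one HF-format HotpotQA record (whose `context` field holds parallel `title` and `sentences` lists) into one passage per paragraph, and de-duplicates passages by title across records. The output field names (`id`, `title`, `text`) and the title-based de-duplication are assumptions, not the repo's confirmed behavior.

```python
def passages_from_record(record: dict) -> list[dict]:
    """Hypothetical: one title/text passage per context paragraph of an HF-format HotpotQA record."""
    ctx = record["context"]
    return [
        {"title": title, "text": " ".join(s.strip() for s in sents)}
        for title, sents in zip(ctx["title"], ctx["sentences"])
    ]


def build_corpus(records) -> list[dict]:
    """Collect passages across records, keeping only the first occurrence of each title."""
    seen, corpus = set(), []
    for rec in records:
        for psg in passages_from_record(rec):
            if psg["title"] not in seen:
                seen.add(psg["title"])
                corpus.append({"id": len(corpus), **psg})
    return corpus
```

Because distractor paragraphs recur across questions, a de-duplicated corpus can be noticeably smaller than 10 passages per example.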

In [None]:
# This script reads HotpotQA and writes the unified corpus to data/corpus/hotpotQA/corpus.jsonl
! python src/retrievers/multihop_data_extrator.py --dataset hotpotQA

100% 3931/3931 [00:00<00:00, 617174.21it/s]


## 7. Embedding the Corpus

We use **Contriever** to encode the text passages into dense vectors. These embeddings will be used later for retrieval and negative sampling.
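Conceptually, Contriever embeds a passage by mean-pooling its last-layer token embeddings (ignoring padding) and scores a query against passages with the inner product. The toy, pure-Python sketch below shows only those two operations on small hand-made vectors; the real embedder runs the model with PyTorch in batches, and the helper names here are hypothetical.

```python
def mean_pool(token_embs: list[list[float]], mask: list[int]) -> list[float]:
    """Average the token vectors whose attention-mask entry is 1, so padding does not dilute the result."""
    kept = [vec for vec, m in zip(token_embs, mask) if m]
    dim = len(token_embs[0])
    return [sum(vec[i] for vec in kept) / len(kept) for i in range(dim)]


def top_k(query: list[float], passages: list[list[float]], k: int = 2) -> list[int]:
    """Indices of the k passages with the highest inner-product score, best first."""
    scores = [sum(q * p for q, p in zip(query, psg)) for psg in passages]
    return sorted(range(len(passages)), key=lambda i: scores[i], reverse=True)[:k]


passage_vec = mean_pool([[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]], [1, 1, 0])  # last token is padding
print(passage_vec)
print(top_k([1.0, 0.0], [[0.0, 1.0], [2.0, 0.0], [1.0, 1.0]], k=2))
```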

In [None]:
# Uses the default Contriever; outputs FAISS/emb files under output_dir
!python src/retrievers/passage_embedder.py \
  --passages data/corpus/hotpotQA/corpus.jsonl \
  --output_dir data/corpus/hotpotQA/contriever \
  --model_type contriever

Loading passages from data/corpus/hotpotQA/corpus.jsonl
Processing chunk 1/1

## 8. Starting Local Llama-3 Server (vLLM)

Instead of making network calls to OpenAI, we spin up a local API server with **vLLM**. It serves `Meta-Llama-3-8B-Instruct` on port 8000 through an OpenAI-compatible API, so the standard `openai` Python client can talk to it by changing only the base URL.

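*Note: an 8B-parameter model in fp16/bf16 needs about 15 GiB for its weights alone (8B parameters × 2 bytes), plus memory for the KV cache. Plan for a GPU with 24 GB or more, such as an L4 or A100; a 16 GB T4 is generally too small to serve this model unquantized.*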
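The memory estimate above is simple arithmetic. The sketch below reproduces it, assuming the published Llama-3-8B configuration (about 8.03B parameters, 32 layers, 8 key/value heads, head dimension 128) and 2 bytes per value for fp16/bf16. It ignores activations and CUDA overhead, and vLLM additionally pre-allocates a fraction of GPU memory (0.9 by default, via `--gpu-memory-utilization`) for the KV cache.

```python
def weight_gib(n_params: float, bytes_per_param: int = 2) -> float:
    """Memory for the model weights alone, in GiB (2 bytes per parameter for fp16 or bf16)."""
    return n_params * bytes_per_param / 2**30


def kv_cache_bytes_per_token(n_layers: int, n_kv_heads: int, head_dim: int, bytes_per_val: int = 2) -> int:
    """Keys and values (hence the factor 2) cached for every layer and every KV head."""
    return 2 * n_layers * n_kv_heads * head_dim * bytes_per_val


# Assumed Llama-3-8B config: ~8.03B params, 32 layers, 8 KV heads (grouped-query attention), head_dim 128
print(f"weights : {weight_gib(8.03e9):.1f} GiB")
per_tok = kv_cache_bytes_per_token(32, 8, 128)
print(f"KV cache: {per_tok // 1024} KiB/token, {per_tok * 8192 / 2**30:.1f} GiB for an 8192-token context")
```

This prints roughly 15.0 GiB for the weights and 128 KiB per token (1.0 GiB for one 8,192-token sequence) for the KV cache, which is why 16 GB cards leave almost no headroom.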

In [None]:
import subprocess, time, requests

API_KEY = "token-colab-local-123"  # must match LLAMA_API_KEY in src/language_models/llama.py

# Kill any prior instance
_ = subprocess.run("pkill -f vllm.entrypoints.openai.api_server || true", shell=True)

cmd = [
  "python","-m","vllm.entrypoints.openai.api_server",
  "--model","meta-llama/Meta-Llama-3-8B-Instruct",
  #"--quantization","gptq",
  "--dtype","auto","--host","0.0.0.0","--port","8000",
  #"--served-model-name","meta-llama/Meta-Llama-3-8B-Instruct",
  #"--gpu-memory-utilization","0.92",
  #"--max-model-len","8192",
  #"--max-num-seqs","6",
  "--api-key",API_KEY,
  #"--enforce-eager"
]

# Write server logs to a file: an unread stdout=PIPE eventually fills its buffer and blocks the server
log_file = open("vllm_server.log", "w")
proc = subprocess.Popen(cmd, stdout=log_file, stderr=subprocess.STDOUT)

# Poll the OpenAI-compatible /models endpoint for up to ~400 s
BASE = "http://127.0.0.1:8000/v1"
HEAD = {"Authorization": f"Bearer {API_KEY}"}
ready = False
for _ in range(400):
    if proc.poll() is not None:  # server exited early (e.g. out of memory); see vllm_server.log
        break
    try:
        ready = requests.get(f"{BASE}/models", headers=HEAD, timeout=2).status_code == 200
    except requests.RequestException:
        pass  # not accepting connections yet
    if ready:
        break
    time.sleep(1)

print("Server OK:", ready)

# Keep 'proc' to stop later:
# proc.terminate(); proc.wait(); log_file.close()

Server OK: True


## 9. Injecting Local Inference Logic

Here we overwrite the repo's model utility files (`src/utils/model.py`, `src/language_models/llama.py`, etc.) to direct all LLM calls to our local `localhost:8000` endpoint instead of GPT-4. We also implement robust JSON parsing and retry logic.
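The `ask_model` function in the cell below retries with tenacity's `wait_exponential(multiplier=1, min=1, max=6)` and `stop_after_attempt(3)`. To make that schedule concrete, here is a plain-Python sketch of the same idea: the wait after attempt *n* is `multiplier * 2**(n-1)`, clamped to `[min, max]`. The helpers are hypothetical illustrations, not tenacity's implementation. Like the `retry_error_callback=lambda _: None` setting, the sketch returns `None` once retries are exhausted, while non-retryable exceptions propagate.

```python
import time
from typing import Callable, Optional, TypeVar

T = TypeVar("T")


def exponential_waits(n_attempts: int, multiplier: float = 1, min_wait: float = 1,
                      max_wait: float = 6) -> list[float]:
    """Wait after attempt n (1-based): multiplier * 2**(n-1), clamped to [min_wait, max_wait]."""
    return [max(min_wait, min(multiplier * 2 ** (n - 1), max_wait)) for n in range(1, n_attempts + 1)]


def call_with_retry(fn: Callable[[], T], retry_on: tuple = (Exception,), stop_after: int = 3,
                    sleep: Callable[[float], None] = time.sleep) -> Optional[T]:
    """Call fn() up to stop_after times, sleeping between retryable failures.

    Returns None once every attempt failed with a retryable error; other exceptions propagate.
    """
    waits = exponential_waits(stop_after)
    for attempt in range(1, stop_after + 1):
        try:
            return fn()
        except retry_on:
            if attempt == stop_after:
                return None  # give up quietly, mirroring retry_error_callback=lambda _: None
            sleep(waits[attempt - 1])
    return None


print(exponential_waits(5))
```

With `stop_after_attempt(3)` only the first two waits (1 s, then 2 s) are ever used; the `max=6` cap would matter only with more attempts.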

In [None]:
import pathlib

new_code = """
import json
import random
import re
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Callable, Literal, Optional

from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
from tqdm.rich import tqdm_rich
from language_models import LanguageModel

def _to_text(result):
    if result is None:
        return ""
    # Already a string
    if isinstance(result, str):
        return result
    # OpenAI/vLLM modern clients return pydantic objects with .choices[0].message.content
    try:
        if hasattr(result, "choices"):
            ch0 = result.choices[0]
            # vLLM/OpenAI: choices[i].message.content
            msg = getattr(ch0, "message", None)
            if msg is not None:
                content = getattr(msg, "content", None)
                if isinstance(content, str):
                    return content
        # Dict-like fallbacks
        if isinstance(result, dict):
            choices = result.get("choices") or []
            if choices:
                message = choices[0].get("message") or {}
                content = message.get("content")
                if isinstance(content, str):
                    return content
    except Exception:
        pass
    # Last resort: stringify (prevents TypeError in regex)
    return str(result)


class EmptyContentError(RuntimeError):
    pass

def _safe_json_parse(result: Optional[str]) -> Optional[dict]:
    if not result:
        return None
    # try fenced block first
    m = re.search(r"```json\\s*(\\{.*?\\})\\s*```", result, re.DOTALL | re.IGNORECASE)
    if m:
        try:
            return json.loads(m.group(1).strip())
        except Exception:
            pass
    # otherwise first top-level-looking {...}
    m2 = re.search(r"\\{.*\\}", result, re.DOTALL)
    if m2:
        try:
            return json.loads(m2.group(0).strip())
        except Exception:
            return None
    return None

def _get_parser(type_: str) -> Callable:
    if type_ == "json":
        return _safe_json_parse
    elif type_ in ("text", "raw"):
        # raw/text are the same here (return the string)
        return lambda s: s
    raise ValueError(f"Unsupported type: {type_}")

# Try up to 3 times; back off if the model returns empty
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=1, max=6),
    retry=retry_if_exception_type(EmptyContentError),
    reraise=False,
    retry_error_callback=lambda _: None,
)
def ask_model(
    model: LanguageModel,
    prompt: str,
    system_msg: str = None,
    type: Literal["json", "text", "raw"] = "json",
    check_if_valid: Callable = None,
    sleep: bool = True,
    mode: Literal["chat", "completion"] = "chat",
) -> Optional[dict]:
    if sleep:
        time.sleep(random.uniform(0.8, 1.8))

    # 1) primary attempt with JSON mode when requested
    if mode == "chat":
        result = model.chat(prompt, system_msg, json_mode=(type == "json"))
    else:
        result = model.complete(prompt)

    result = _to_text(result)

    # None / empty payload from server? -> trigger retry (common with JSON mode or overload)
    if not result:
        raise EmptyContentError("empty content from server")

    parser = _get_parser(type)
    parsed = parser(result)

    # Fallback path also normalized:
    if type == "json" and parsed is None:
        if mode == "chat":
            fallback = model.chat(prompt, system_msg, json_mode=False)
        else:
            fallback = model.complete(prompt)
        fallback = _to_text(fallback)
        if not fallback:
            raise EmptyContentError("empty content after fallback")
        parsed = _safe_json_parse(fallback)
        if parsed is None:
            raise EmptyContentError("unparseable JSON after fallback")



    if check_if_valid is not None and parsed is not None and not check_if_valid(parsed):
        # validator says no -> return None (don’t raise to avoid hiding raw)
        return None

    return parsed

def ask_model_in_parallel(
    model: LanguageModel,
    prompts: list[str],
    system_msg: str = None,
    type: Literal["json", "text", "raw"] = "json",
    check_if_valid_list: list[Callable] = None,
    max_workers: int = 4,
    desc: str = "Processing...",
    verbose=True,
    mode: Literal["chat", "completion"] = "chat",
):
    if max_workers == -1:
        max_workers = len(prompts)
    assert max_workers >= 1

    if check_if_valid_list is None:
        check_if_valid_list = [None] * len(prompts)
    assert len(prompts) == len(check_if_valid_list)

    results = []
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = {
            executor.submit(
                ask_model,
                model=model,
                prompt=prompt,
                system_msg=system_msg,
                type=type,
                check_if_valid=chk,
                sleep=True,
                mode=mode,
            ): i
            for i, (prompt, chk) in enumerate(zip(prompts, check_if_valid_list))
        }
        for fut in tqdm_rich(as_completed(futures), total=len(futures), desc=desc, disable=not verbose):
            idx = futures[fut]
            try:
                results.append((idx, fut.result()))
            except Exception:
                results.append((idx, None))
        results.sort(key=lambda x: x[0])
        return [r for _, r in results]

def get_type_parser(type: str) -> Callable:
    def json_parser(result: str):
        # pattern = r"```json(.*?)```"
        pattern = r"{.*?}"
        matches = re.findall(pattern, result, re.DOTALL)
        if matches:
            result = matches[0].strip()
        return json.loads(result)

    def text_parser(result: str):
        return result

    if type == "json":
        return json_parser
    elif type == "text":
        return text_parser
    else:
        raise ValueError(f"Unsupported type: {type}")
"""

# Overwrite the file with the desired implementation
target = pathlib.Path("src/utils/model.py")
target.write_text(new_code, encoding="utf-8")
print(f"Wrote first-code implementation to {target}")

Wrote first-code implementation to src/utils/model.py
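`_safe_json_parse` falls back to the greedy regex `\{.*\}`, which spans from the first `{` to the **last** `}` in the reply. If the model adds any brace after its JSON (for example in a trailing remark), `json.loads` fails on that span. A hypothetical alternative, sketched below and not part of the repo, tries each `{` in turn with the standard library's `json.JSONDecoder.raw_decode`, which parses one JSON value starting at a given index and reports where it ends, so nested objects and trailing text are both handled.

```python
import json
import re
from typing import Optional


def extract_first_json_object(text: str) -> Optional[dict]:
    """Hypothetical helper: return the first decodable JSON object in `text`, or None."""
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            obj, _end = decoder.raw_decode(text, start)  # parses exactly one JSON value at `start`
            if isinstance(obj, dict):
                return obj
        except json.JSONDecodeError:
            pass  # this '{' does not start valid JSON; try the next one
        start = text.find("{", start + 1)
    return None


reply = 'Sure! {"query": {"hop": 1}} Let me know if you need {more}.'
greedy = re.search(r"\{.*\}", reply, re.DOTALL).group(0)  # runs to the last '}', so json.loads fails
print(greedy)
print(extract_first_json_object(reply))
```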


In [None]:
import pathlib

new_code = """
from openai import OpenAI

from .base import LanguageModel

LLAMA_ENDPOINT = "http://127.0.0.1:8000/v1"
LLAMA_API_KEY  = "token-colab-local-123"


class LlamaServer(LanguageModel):
    def __init__(self, model, **_):
        self.model = model
        self.client = OpenAI(base_url=LLAMA_ENDPOINT, api_key=LLAMA_API_KEY, timeout=10.0, max_retries=3)

    def chat(self, message: str, system_msg: str = None, json_mode: bool = False):
        if system_msg is None:
            system_msg = "You are a helpful assistant."
        messages = [
            {"role": "system", "content": system_msg},
            {"role": "user", "content": message},
        ]
        kwargs = dict(temperature=0.1, top_p=1.0, max_tokens=512)

        # Prefer json_object; json_schema support can vary by vLLM version
        # If schema is critical for you, keep it—but add a safe fallback below.
        if json_mode:
            kwargs["response_format"] = {"type": "json_object"}

        try:
            resp = self.client.chat.completions.create(model=self.model, messages=messages, **kwargs)
        except Exception:
            # Retry once without response_format (some vLLM builds/models return empty/err with it)
            if "response_format" in kwargs:
                kwargs.pop("response_format", None)
            resp = self.client.chat.completions.create(model=self.model, messages=messages, **kwargs)

        # Return STRING content, not the whole object
        return getattr(resp.choices[0].message, "content", "") or ""


    def complete(self, prompts: str):
        response = self.client.completions.create(
            model=self.model, prompt=prompts, echo=False, max_tokens=100
        )
        response = response.choices[0].text
        return response


if __name__ == "__main__":
    llama = LlamaServer("meta-llama/Meta-Llama-3-8B-Instruct")  # must match the model name vLLM serves
    response = llama.complete(
        "The reason of human landing on moon is that, some one found it strange behind the moon."
    )
    print(response)
"""

# Overwrite the file with the desired implementation
target = pathlib.Path("src/language_models/llama.py")
target.write_text(new_code, encoding="utf-8")
print(f"Wrote second-code implementation to {target}")

Wrote second-code implementation to src/language_models/llama.py


In [None]:
import pathlib

new_code = """
from .aoai import AOAI
from .base import LanguageModel
from .deepseek import DeepSeek
from .llama import LlamaServer

MODEL_DICT = {
    "gpt35": "gpt-35-turbo-1106",
    "gpt4": "gpt-4-0125-preview",
    "llama": "llama-3-70b-gptq-int4",
    "llama-8B": "meta-llama/Meta-Llama-3-8B-Instruct",
    "deepseek": "deepseek-chat",
}


def get_model(model_name: str, **kwargs) -> LanguageModel:
    if model_name in MODEL_DICT:
        model_name = MODEL_DICT[model_name]
    lower = model_name.lower()

    # 1) Check llama first: "llama-3-70b-gptq-int4" contains "gpt" and would otherwise route to AOAI
    if lower.startswith(("llama", "meta-llama", "meta")):
        return LlamaServer(model=model_name, **kwargs)
    # 2) Only route to AOAI for real GPT model names
    if "gpt" in lower:
        return AOAI(model=model_name, **kwargs)
    if "deepseek" in lower:
        return DeepSeek(model=model_name, **kwargs)
    raise NotImplementedError(f"Model {model_name} not implemented")
"""

# Overwrite the file with the desired implementation
target = pathlib.Path("src/language_models/__init__.py")
target.write_text(new_code, encoding="utf-8")
print(f"Wrote third-code implementation to {target}")

Wrote third-code implementation to src/language_models/__init__.py
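The order of the checks in `get_model` matters because routing is done by substring matching. The sketch below is a hypothetical, side-effect-free mirror of that logic: it returns a backend name instead of building a client, and shows why the llama branch must come first, since the mapped name `llama-3-70b-gptq-int4` also contains `gpt`.

```python
MODEL_DICT = {
    "gpt35": "gpt-35-turbo-1106",
    "gpt4": "gpt-4-0125-preview",
    "llama": "llama-3-70b-gptq-int4",
    "llama-8B": "meta-llama/Meta-Llama-3-8B-Instruct",
    "deepseek": "deepseek-chat",
}


def route(model_name: str) -> str:
    """Hypothetical: name of the backend get_model() would pick, without constructing a client."""
    name = MODEL_DICT.get(model_name, model_name).lower()
    if name.startswith(("llama", "meta")):  # must precede the "gpt" test (see "gptq")
        return "LlamaServer"
    if "gpt" in name:
        return "AOAI"
    if "deepseek" in name:
        return "DeepSeek"
    raise NotImplementedError(model_name)


for alias in MODEL_DICT:
    print(f"{alias:10s} -> {route(alias)}")
```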


In [None]:
import pathlib, re

utils_path = pathlib.Path("src/utils/utils.py")
code = utils_path.read_text(encoding="utf-8")

# 1) Ensure imports (idempotent)
if "from pathlib import Path" not in code:
    code = re.sub(
        r'(^(\s*import[^\n]*\n|\s*from[^\n]*\n)+)',
        r'\1from pathlib import Path\n',
        code,
        count=1,
        flags=re.M
    ) if re.search(r'^(\s*import|\s*from)', code, flags=re.M) else "from pathlib import Path\n" + code

if re.search(r'^\s*import json\s*$', code, flags=re.M) is None:
    code = "import json\n" + code

# 2) Patch write_jsonl open(...) to Path version (mkdir + p.open)
# [ \t]* (rather than \s*) keeps the captured indent on a single line
pattern_write = re.compile(
    r'(?m)^(?P<indent>[ \t]*)with\s+open\(\s*file_path\s*,\s*["\']w\+?["\'].*?\)\s*as\s*f\s*:[ \t]*$'
)
replacement_write = (
    r'\g<indent>p = Path(file_path)\n'
    r'\g<indent>p.parent.mkdir(parents=True, exist_ok=True)\n'
    r'\g<indent>with p.open("w", encoding="utf-8") as f:'
)
code, n_write = pattern_write.subn(replacement_write, code)

# 3) Replace the entire load_jsonl definition with a safe version
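#    Match 'def load_jsonl(...):' up to the next top-level line (def, decorator, statement) or EOF.
#    Assumes load_jsonl is a top-level function; [ \t]* keeps the captured indent on one line.
pattern_load = re.compile(
    r'(?ms)^(?P<indent>[ \t]*)def\s+load_jsonl\s*\([^)]*\)\s*:\s*'
    r'(?:.*?)(?=^\S|\Z)'
)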
replacement_load = (
    r'\g<indent>def load_jsonl(file_path, *, missing_ok=True):\n'
    r'\g<indent>    """\n'
    r'\g<indent>    Read records from a JSONL file.\n'
    r'\g<indent>    - If missing and missing_ok=True, returns [] and ensures parent dir exists.\n'
    r'\g<indent>    - If missing_ok=False, raises FileNotFoundError.\n'
    r'\g<indent>    """\n'
    r'\g<indent>    p = Path(file_path)\n'
    r'\g<indent>    if not p.exists():\n'
    r'\g<indent>        p.parent.mkdir(parents=True, exist_ok=True)\n'
    r'\g<indent>        if missing_ok:\n'
    r'\g<indent>            return []\n'
    r'\g<indent>        raise FileNotFoundError(f"JSONL file not found: {p}. Set missing_ok=True to return [].")\n'
    r'\g<indent>    with p.open("r", encoding="utf-8") as f:\n'
    r'\g<indent>        return [json.loads(line) for line in f if line.strip()]\n'
)
code, n_load = pattern_load.subn(replacement_load, code)

# 4) If no load_jsonl existed, append our safe version
if n_load == 0 and "def load_jsonl" not in code:
    code += "\n\n" + replacement_load.replace(r'\g<indent>', '')

utils_path.write_text(code, encoding="utf-8")
print(f"Patched write sites: {n_write}, patched load_jsonl: {max(n_load, 1) if 'def load_jsonl' in code else 0}")


Patched write sites: 1, patched load_jsonl: 1
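To see what the patched helpers enable, here is a standalone round-trip sketch: writing creates missing parent directories, and reading a missing file returns `[]` when `missing_ok=True`. These functions mirror the idea of the patch but are hypothetical stand-ins, not the repo's `src/utils/utils.py` (whose `write_jsonl` body is not shown here).

```python
import json
import tempfile
from pathlib import Path


def write_jsonl(records, file_path) -> None:
    """Write one JSON object per line, creating parent directories first."""
    p = Path(file_path)
    p.parent.mkdir(parents=True, exist_ok=True)
    with p.open("w", encoding="utf-8") as f:
        for rec in records:
            f.write(json.dumps(rec, ensure_ascii=False) + "\n")


def load_jsonl(file_path, *, missing_ok=True):
    """Read a JSONL file, skipping blank lines; a missing file yields [] unless missing_ok=False."""
    p = Path(file_path)
    if not p.exists():
        if missing_ok:
            return []
        raise FileNotFoundError(p)
    with p.open("r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


with tempfile.TemporaryDirectory() as tmp:
    out = Path(tmp) / "nested" / "dir" / "train.jsonl"
    write_jsonl([{"id": 0, "q": "Who?"}, {"id": 1, "q": "Where?"}], out)
    print(load_jsonl(out))
    print(load_jsonl(Path(tmp) / "missing.jsonl"))
```

One trade-off of `missing_ok=True`: a mistyped path yields an empty dataset instead of an error, so a downstream stage may silently process zero samples. Passing `missing_ok=False` where an input must exist keeps that failure visible.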


In [None]:
import pathlib

new_code = """
import argparse
import json
import os
import sys
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Iterator
import subprocess
import time
import signal
import requests

from tqdm.rich import tqdm_rich

sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
from prompts import (
    TOKEN_LABEL_REDUNDANT_EVALUATION_PROMPT,
    TOKEN_LABEL_REDUNDANT_SYSTEM_MSG,
    TOKEN_LABEL_SYNTHESIZE_FEW_SHOT_PROMPT_MUSIQUE,
    TOKEN_LABEL_SYNTHESIZE_FEW_SHOT_PROMPT_WIKIMQA,
    TOKEN_LABELING_SYSTEM_MSG,
)

from conf import (
    MODEL_DICT,
    SYNTHESIZED_DECOMPOSED_DATA_PATH,
    SYNTHESIZED_TOKEN_LABELING_DATA_PATH,
)
from language_models import LanguageModel, get_model
from utils import ask_model, ask_model_in_parallel, load_jsonl
from utils.model import get_type_parser

TOKEN_LABEL_PROMPT_TEMPLATE_MAPPING = {
    "musique": TOKEN_LABEL_SYNTHESIZE_FEW_SHOT_PROMPT_MUSIQUE,
    "musique-simple": TOKEN_LABEL_SYNTHESIZE_FEW_SHOT_PROMPT_MUSIQUE,
    "2WikiMQA": TOKEN_LABEL_SYNTHESIZE_FEW_SHOT_PROMPT_WIKIMQA,
    "hotpotQA": TOKEN_LABEL_SYNTHESIZE_FEW_SHOT_PROMPT_WIKIMQA
}

# ---------- Local Llama server management (for llama-8B) ----------

LLAMA_LOCAL_BASE_URL = "http://127.0.0.1:8000/v1"
# Keep in sync with the key used elsewhere in the notebook when starting and querying vLLM
LLAMA_LOCAL_API_KEY = "token-colab-local-123"


def start_llama_server() -> subprocess.Popen:
    # Kill any prior instance just in case
    subprocess.run(
        "pkill -f vllm.entrypoints.openai.api_server || true",
        shell=True,
        check=False,
    )

    cmd = [
        "python",
        "-m",
        "vllm.entrypoints.openai.api_server",
        "--model",
        "meta-llama/Meta-Llama-3-8B-Instruct",
        "--dtype",
        "auto",
        "--host",
        "0.0.0.0",
        "--port",
        "8000",
        "--api-key",
        LLAMA_LOCAL_API_KEY,
        "--max-model-len",
        "8192",
        "--max-num-seqs",
        "6",        # conservative concurrency
        "--gpu-memory-utilization",
        "0.9",
    ]

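    # Send logs to a file: an unread stdout PIPE can fill up and stall the server
    with open("vllm_server.log", "a", encoding="utf-8") as log_file:
        proc = subprocess.Popen(cmd, stdout=log_file, stderr=subprocess.STDOUT)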

    # Wait for readiness
    headers = {
        "Authorization": f"Bearer {LLAMA_LOCAL_API_KEY}",
        "Content-Type": "application/json",
    }
    for _ in range(300):
        try:
            r = requests.get(f"{LLAMA_LOCAL_BASE_URL}/models", headers=headers, timeout=2)
            if r.status_code == 200:
                print("[token_labeling] vLLM server is ready.")
                return proc
        except Exception:
            pass
        time.sleep(1)

    print("[token_labeling] vLLM server did not become ready in time; killing it.")
    proc.kill()
    raise RuntimeError("Failed to start vLLM server")


def stop_llama_server(proc: subprocess.Popen):
    if proc is None:
        return
    print("[token_labeling] Stopping vLLM server...")
    try:
        proc.terminate()
        try:
            proc.wait(timeout=15)
        except subprocess.TimeoutExpired:
            proc.kill()
    except Exception:
        pass

# -----------------------------------------------------------------


class TokenLabeler:
    def __init__(self, model: str, dataset: str, split: str) -> None:
        self.model: LanguageModel
        self.model = get_model(model)

        labeled_data_path = os.path.join(
            SYNTHESIZED_DECOMPOSED_DATA_PATH, dataset, f"{split}.jsonl"
        )
        self.labeled_data = load_jsonl(labeled_data_path)
        self.check_if_valid = lambda x: all(
            [k in x.keys() for k in ["extracted_words"]]
        )
        self.token_labeling_prompt = TOKEN_LABEL_PROMPT_TEMPLATE_MAPPING[dataset]

    def parse(self, starting: int = 0, workers=10):
        labeled_data = self.labeled_data[starting:]
        # keep only samples whose overall state is None
        labeled_data = [d for d in labeled_data if d.get("state", None) is None]

        for sample in tqdm_rich(labeled_data, desc="Processing..."):
            yield self.parse_sample(sample)


    def parse_sample(self, sample: dict) -> dict:
        prompt_list = self.parse_prompt(sample)
        results = []
        for prompt in prompt_list:
            result = ask_model(
                self.model,
                prompt,
                TOKEN_LABELING_SYSTEM_MSG,
                type="json",
                check_if_valid=self.check_if_valid,
            )
            if result is None:
                # Mark the failure so parse_failed() can find and retry it later
                result = {"extracted_words": "", "state": "error"}
            results.append(result)
        for subq_id, result in zip(
            sorted(sample["decomposed_questions"].keys()), results
        ):
            chunk = sample["decomposed_questions"][subq_id]
            chunk["extracted_words"] = result["extracted_words"]
            if result.get("state") == "error":
                chunk["state"] = "error"
        return sample

    def parse_prompt(self, data: dict) -> list[str]:
        prompt_list = []
        for subq_id in sorted(data["decomposed_questions"].keys()):
            subq = data["decomposed_questions"][subq_id]
            format_kwargs = {
                "question": subq["sub_question"],
                "paragraph": subq["positive_paragraph"],
                "answer": subq["answer"],
            }
            prompt = self.token_labeling_prompt.format(**format_kwargs)
            prompt_list.append(prompt)
        return prompt_list

    def parse_failed(self, token_labeled_data: list[dict]) -> list[dict]:
        results = []
        failed_question_ids = set()
        for sample in token_labeled_data:
            for sub_question_id in sorted(sample["decomposed_questions"].keys()):
                if (
                    sample["decomposed_questions"][sub_question_id].get("state", None)
                    == "error"
                ):
                    failed_question_ids.add(sample["id"])
                    break
        progress = tqdm_rich(
            desc="Processing failed...", total=len(failed_question_ids)
        )
        for sample in token_labeled_data:
            if sample["id"] not in failed_question_ids:
                results.append(sample)
                continue
            for sub_question_id in sorted(sample["decomposed_questions"].keys()):
                if (
                    sample["decomposed_questions"][sub_question_id].get("state", None)
                    != "error"
                ):
                    continue
                prompt_list = self.parse_prompt(sample)
                prompt = prompt_list[int(sub_question_id) - 1]
                result = ask_model(
                    self.model, prompt, type="json", check_if_valid=self.check_if_valid
                )
                if result is None:
                    continue
                del sample["decomposed_questions"][sub_question_id]["state"]
                sample["decomposed_questions"][sub_question_id]["extracted_words"] = (
                    result["extracted_words"]
                )
            progress.update(1)
            results.append(sample)
        return results


class TokenReLabeler:
    def __init__(self, model: str, dataset: str, split: str) -> None:
        self.model: LanguageModel
        self.model = get_model(model)
        self.model_powerful = get_model("Llama3-8B-Instruct")

        labeled_data_path = os.path.join(
            SYNTHESIZED_TOKEN_LABELING_DATA_PATH, dataset, f"{split}.jsonl"
        )
        self.labeled_data = load_jsonl(labeled_data_path)
        self.check_if_valid = lambda x: all(
            [k in x.keys() for k in ["extracted_words"]]
        )
        self.check_redundant_valid = lambda x: type(x) == dict and all(
            [k in x.keys() for k in ["redundant", "missing"]]
        )
        self.type_parser = get_type_parser(type="json")

    def label_redundant(self, labeled_data: list[dict], workers: int) -> list[dict]:
        redundant_questions = []
        with ThreadPoolExecutor(max_workers=workers) as executor:
            tasks = {
                executor.submit(self.check_sample_redundant, sample): idx
                for idx, sample in enumerate(labeled_data)
            }
            for future in tqdm_rich(
                as_completed(tasks), total=len(tasks), desc="Redundant"
            ):
                task_id = tasks[future]
                try:
                    result = future.result()
                    redundant_questions.append((task_id, result))
                finally:
                    ...
            redundant_questions = [
                r[1] for r in sorted(redundant_questions, key=lambda x: x[0])
            ]
        for sample, redundant in zip(labeled_data, redundant_questions):
            assert redundant["id"] == sample["id"]
            for subq_id in redundant["redundant"]:
                sample["decomposed_questions"][subq_id]["redundant"] = True
        return labeled_data

    def parse(self, workers: int = 10, redundant_labeled: bool = False) -> list[dict]:
        labeled_data = [
            d
            for d in self.labeled_data
            if all(
                chunk.get("state", None) is None
                for chunk in d["decomposed_questions"].values()
            )
        ]

        # 1. use GPT3.5 to identify if extracted words is redundant or missing
        if not redundant_labeled:
            labeled_data = self.label_redundant(labeled_data, workers)

        # 2. use Llama3 to re-label the extracted words
        results = []
        data_mapping = {d["id"]: d for d in labeled_data}

        max_iter = 5
        current_iter = 0
        while current_iter < max_iter:
            current_iter += 1

            prompts = []
            for sample in labeled_data:
                sample_prompts = self.build_relabel_prompt(sample)
                prompts.extend(sample_prompts)

            print(
                f"Current iteration: {current_iter}, "
                f"max iteration: {max_iter}, "
                f"handling {len(prompts)} prompts."
            )
            if len(prompts) <= 0:
                break

            batched_prompts = [p["prompt"] for p in prompts]
            results = self.model_powerful.chat(
                batched_prompts, TOKEN_LABELING_SYSTEM_MSG, json_mode=True
            )

            for prompt, result in zip(prompts, results):
                try:
                    json_result = self.type_parser(result)
                    data = data_mapping[prompt["id"]]
                    chunk = data["decomposed_questions"][prompt["subq_id"]]
                    chunk["extracted_words_old"] = chunk["extracted_words"]
                    chunk["extracted_words"] = json_result["extracted_words"]
                    chunk["redundant"] = False
                except json.JSONDecodeError:
                    print(f"Error on {prompt['id']} sub-question {prompt['subq_id']}")
                    json_result = None

        return labeled_data

    def build_relabel_prompt(self, sample: dict):
        prompts = []
        for subq_id, chunk in sample["decomposed_questions"].items():
            if not chunk.get("redundant", False):
                continue  # only re-label chunks flagged as redundant or missing
            sub_question = chunk["sub_question"]
            paragraph = chunk["positive_paragraph"]
            answer = chunk["answer"]
            prompt = TOKEN_LABEL_SYNTHESIZE_FEW_SHOT_PROMPT_MUSIQUE.format(
                question=sub_question, paragraph=paragraph, answer=answer
            )
            info = {
                "id": sample["id"],
                "subq_id": subq_id,
                "prompt": prompt,
            }
            prompts.append(info)
        return prompts

    def check_sample_redundant(self, sample: dict):
        redundant = {"id": sample["id"], "redundant": []}
        for subq_id, subq in sample["decomposed_questions"].items():
            question = subq["sub_question"]
            answer = subq["answer"]
            extracted_words = subq["extracted_words"]
            evaluation_prompt = TOKEN_LABEL_REDUNDANT_EVALUATION_PROMPT.format(
                question=question, answer=answer, extracted_words=extracted_words
            )
            evaluation = ask_model(
                self.model,
                evaluation_prompt,
                TOKEN_LABEL_REDUNDANT_SYSTEM_MSG,
                type="json",
                check_if_valid=self.check_redundant_valid,
            )
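            # ask_model returns None when no valid JSON came back; skip that sub-question
            if evaluation is None:
                continue
            if evaluation["redundant"] or evaluation["missing"]:
                redundant["redundant"].append(subq_id)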
        return redundant


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dataset",
        type=str,
        default="musique",
    )
    parser.add_argument("--split", type=str, default="valid")
    parser.add_argument("--model", default="gpt4")
    parser.add_argument(
        "--workers", type=int, default=10, help="Number of parallel processors"
    )
    parser.add_argument(
        "--sync", action="store_true", help="Syncing with label fixed data"
    )
    parser.add_argument("--failed", action="store_true", help="Parse failed data")
    parser.add_argument("--failed_path", type=str, help="Path to failed data")
    parser.add_argument(
        "--relabel", action="store_true", help="Re-label extracted words"
    )
    args = parser.parse_args()
    return args


def main(opt: argparse.Namespace):
    model_name = MODEL_DICT[opt.model]

    # Decide if we should manage a local vLLM server (llama-8B case)
    use_local_llama = opt.model == "llama-8B"

    server_proc = None
    try:
        if use_local_llama:
            server_proc = start_llama_server()

        labeler = TokenLabeler(model_name, opt.dataset, opt.split)

        os.makedirs(
            os.path.join(SYNTHESIZED_TOKEN_LABELING_DATA_PATH, opt.dataset),
            exist_ok=True,
        )

        out_path = os.path.join(
            SYNTHESIZED_TOKEN_LABELING_DATA_PATH,
            opt.dataset,
            f"{opt.split}.jsonl",
        )

        processed = 0
        # Open in "w+" so each run of this script overwrites old file for the same split
        with open(out_path, "w+", encoding="utf-8") as f:
            # We ignore opt.workers here because parse() is now sequential.
            for labeled in labeler.parse(workers=1):
                info = json.dumps(labeled, ensure_ascii=False)
                f.write(info + "\n")
                f.flush()  # ensure it's actually written to disk
                processed += 1

                # Every 150 processed samples, restart the local LLM server
                if use_local_llama and processed % 150 == 0:
                    print(
                        f"[token_labeling] Processed {processed} samples; "
                        "restarting vLLM server to clear caches..."
                    )
                    stop_llama_server(server_proc)
                    server_proc = start_llama_server()

        print(f"[token_labeling] Finished. Total processed: {processed}")

    finally:
        # Always try to stop the server at the end
        if use_local_llama and server_proc is not None:
            stop_llama_server(server_proc)



if __name__ == "__main__":
    options = parse_args()
    main(options)
"""

# Overwrite the file with the desired implementation
target = pathlib.Path("src/data_synthesize/token_labeling.py")
target.write_text(new_code, encoding="utf-8")
print(f"Wrote patched token_labeling.py to {target}")

In [None]:
# Apply a small patch to implement build_prompt_template_hotpot
import pathlib

path = pathlib.Path("src/data_synthesize/next_hop_query_construction.py")
code = path.read_text()

old = '''    def build_prompt_template_hotpot(self, sample: dict, dependency: list[str]) -> str:
        raise NotImplementedError()
'''

new = '''    def build_prompt_template_hotpot(self, sample: dict, dependency: list[str]) -> str:
        text = """
        You are helping with multi-hop question answering over Wikipedia.

        You are given:
        - The original multi-hop question:
        <Question>: {question}

        - Some information that is already known:
        {info_list}

        - Answers to earlier sub-questions:
        {subq_answers}

        Your task: write a single best next-hop search query that should be used
        to retrieve more information needed to ultimately answer the original question.

        Respond ONLY in JSON, with a single key "filtered_query" whose value is a string.
        For example:
        ```json
        {{"filtered_query": "Where was <person> born?"}}
        ```

        Now respond with the JSON:
        """.strip()
        return text
'''

if old not in code:
    raise SystemExit("Old stub for build_prompt_template_hotpot not found; patch not applied.")

path.write_text(code.replace(old, new))
print("Patched build_prompt_template_hotpot in next_hop_query_construction.py")

In [None]:
import pathlib, re

p = pathlib.Path("src/data_synthesize/next_hop_query_construction.py")
s = p.read_text(encoding="utf-8")

# Create the output directory right before the script opens its output file
if "os.makedirs(" not in s:
    if "import os" not in s:
        s = "import os\n" + s
    pat = re.compile(r'(?m)^(?P<i>[ \t]*)with\s+open\(output_path[^:]+:')
    s, n = pat.subn(
        lambda m: f'{m.group("i")}os.makedirs(os.path.dirname(output_path), exist_ok=True)\n{m.group(0)}',
        s,
        count=1,
    )
    p.write_text(s, encoding="utf-8")
    print(f"Inserted os.makedirs before {n} open(output_path) call(s)")

In [None]:
from importlib import reload
import src.language_models.llama as llama_mod
import src.data_synthesize.prompts.hotpotQA as hotpotQA_mod
import src.utils.model as model_mod
import src.language_models.__init__ as init_mod

reload(llama_mod)
reload(hotpotQA_mod)
reload(model_mod)
reload(init_mod)

<module 'src.language_models.__init__' from '/content/efficientrag-official/src/language_models/__init__.py'>

### A Quick Sanity Check

In [None]:
import requests
from src.language_models.llama import LLAMA_ENDPOINT, LLAMA_API_KEY

def ensure_model_visible(name="meta-llama/Meta-Llama-3-8B-Instruct",
                         base=LLAMA_ENDPOINT, api_key=LLAMA_API_KEY):
    # base is like "http://127.0.0.1:8000/v1"
    r = requests.get(f"{base}/models",
                     headers={"Authorization": f"Bearer {api_key}"},
                     timeout=10)
    r.raise_for_status()
    names = [m["id"] for m in r.json().get("data", [])]
    assert name in names, f"Model {name} not found; got: {names}"
    return names

print(ensure_model_visible())

['meta-llama/Meta-Llama-3-8B-Instruct']


In [None]:
import requests
BASE = "http://127.0.0.1:8000/v1"
HEAD = {"Authorization": "Bearer token-colab-local-123"}
print(requests.get(f"{BASE}/models", headers=HEAD, timeout=10).json())


{'object': 'list', 'data': [{'id': 'meta-llama/Meta-Llama-3-8B-Instruct', 'object': 'model', 'created': 1761175435, 'owned_by': 'vllm', 'root': 'meta-llama/Meta-Llama-3-8B-Instruct', 'parent': None, 'max_model_len': 8192, 'permission': [{'id': 'modelperm-17c8da54e9b54a898c07e33a76ae8b2d', 'object': 'model_permission', 'created': 1761175435, 'allow_create_engine': False, 'allow_sampling': True, 'allow_logprobs': True, 'allow_search_indices': False, 'allow_view': True, 'allow_fine_tuning': False, 'organization': '*', 'group': None, 'is_blocking': False}]}]}


In [None]:
llm = llama_mod.LlamaServer(model="meta-llama/Meta-Llama-3-8B-Instruct")  # must match /v1/models
print(llm.chat("Hi! If I have 5 apples and someone takes 2, how many are left?"))

Let's count them together!

You started with 5 apples, and someone took 2. To find out how many are left, we'll subtract 2 from 5.

5 - 2 = 3

So, there are 3 apples left!


## 10. Synthetic Data Generation Pipeline

We now execute the EfficientRAG data synthesis pipeline step-by-step. This process uses the LLM (Llama-3) to decompose questions, label essential tokens, and construct search queries.
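
Each step below is a separate shell cell. If you prefer to drive them from Python and stop at the first failing script, here is a minimal sketch; `run_pipeline` and `STEPS` are hypothetical names, and only the first two steps are listed (the remaining commands from the cells below can be appended the same way). Keep in mind that `token_labeling.py`, as rewritten above, kills any running vLLM server and starts its own when `--model llama-8B` is used.

```python
import subprocess
import sys


def run_pipeline(steps, dry_run=False):
    """Run each (name, argv) step in order; stop at the first non-zero exit code.

    Returns the names of the steps that completed (or were printed, in dry-run mode).
    """
    completed = []
    for name, argv in steps:
        print(f"[pipeline] {name}: {' '.join(argv)}")
        if not dry_run:
            result = subprocess.run(argv)
            if result.returncode != 0:
                print(f"[pipeline] {name} failed with exit code {result.returncode}")
                break
        completed.append(name)
    return completed


# Commands mirror the notebook cells below (hotpotQA train split, local llama-8B)
PY = sys.executable
STEPS = [
    ("decompose", [PY, "src/data_synthesize/query_decompose.py",
                   "--dataset", "hotpotQA", "--split", "train",
                   "--model", "llama-8B", "--ending", "50"]),
    ("token_labeling", [PY, "src/data_synthesize/token_labeling.py",
                        "--dataset", "hotpotQA", "--split", "train",
                        "--model", "llama-8B"]),
]

if __name__ == "__main__":
    print(run_pipeline(STEPS, dry_run=True))
```

With `dry_run=True` the commands are only printed, which is a cheap way to check paths and flags before spending GPU time.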

In [None]:
import subprocess, time, requests

# Kill any prior instance
_ = subprocess.run("pkill -f vllm.entrypoints.openai.api_server || true", shell=True)

cmd = [
  "python","-m","vllm.entrypoints.openai.api_server",
  "--model","meta-llama/Meta-Llama-3-8B-Instruct",
  #"--quantization","gptq",
  "--dtype","auto","--host","0.0.0.0","--port","8000",
  #"--served-model-name","meta-llama/Meta-Llama-3-8B-Instruct",
  #"--gpu-memory-utilization","0.92",
  #"--max-model-len","8192",
  #"--max-num-seqs","6",
  "--api-key","token-colab-local-123",
  #"--enforce-eager"
]

# Write server logs to a file: an unread stdout PIPE can fill up and stall vLLM
with open("vllm_server.log", "a", encoding="utf-8") as log_file:
    proc = subprocess.Popen(cmd, stdout=log_file, stderr=subprocess.STDOUT)

# Poll the OpenAI-compatible /models endpoint until the server answers (up to ~400 s)
BASE = "http://127.0.0.1:8000/v1"
HEAD = {"Authorization": "Bearer token-colab-local-123", "Content-Type": "application/json"}
ready = False
for _ in range(400):
    try:
        r = requests.get(f"{BASE}/models", headers=HEAD, timeout=2)
        if r.status_code == 200:
            ready = True
            break
    except requests.RequestException:
        pass  # server is not accepting connections yet
    time.sleep(1)

print("Server OK:", ready)

# Keep 'proc' to stop later:
# proc.terminate(); proc.wait()

### Step 1: Query Decomposition
The model breaks down complex multi-hop questions into simpler sub-questions.
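
Step 1's log shows many `Failed to synthesize sample ...` lines, so it is worth checking how much usable output survives before moving on. A minimal sketch, assuming the record layout that `token_labeling.py` reads: a `decomposed_questions` dict whose entries carry `sub_question`, `answer` and `positive_paragraph`, plus an optional top-level `state`. `summarize_decomposition` is a hypothetical helper; the actual file location is `SYNTHESIZED_DECOMPOSED_DATA_PATH/hotpotQA/train.jsonl` as defined in the repository's `conf` module.

```python
import json
from pathlib import Path


def summarize_decomposition(path):
    """Count usable samples in a decomposed-question JSONL file.

    A sample counts as usable when it has no top-level "state" and every
    sub-question carries the fields that token_labeling.py reads.
    """
    required = ("sub_question", "answer", "positive_paragraph")
    stats = {"samples": 0, "usable": 0, "sub_questions": 0}
    with Path(path).open(encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            sample = json.loads(line)
            stats["samples"] += 1
            subqs = sample.get("decomposed_questions") or {}
            ok = sample.get("state") is None and bool(subqs) and all(
                all(k in sq for k in required) for sq in subqs.values()
            )
            if ok:
                stats["usable"] += 1
                stats["sub_questions"] += len(subqs)
    return stats
```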

In [None]:
# --ending 50 caps the range of training samples processed; adjust it based on your resources.
# Optional debugging flags: --debug --debug_dir debug/query_decompose
! python src/data_synthesize/query_decompose.py \
  --dataset hotpotQA \
  --split train \
  --model llama-8B \
  --ending 50

Failed to synthesize sample 5a8d9501554299068b959d4d
Failed to synthesize sample 5a875ce15542993e715abf16
Failed to synthesize sample 5adcc89f5542994d58a2f6cf
Failed to synthesize sample 5a906beb55429916514e74b9
Failed to synthesize sample 5ae3cfe05542990afbd1e1e3
Failed to synthesize sample 5a711ec15542994082a3e5aa
Failed to synthesize sample 5a72d20e5542991f9a20c5a7
Failed to synthesize sample 5a7cf78a55429907fabef06d
Failed to synthesize sample 5ab9257b554299753720f749
Failed to synthesize sample 5ac028d95542992a796decb2
Failed to synthesize sample 5ae80c82554299540e5a5707
Failed to synthesize sample 5a7b13b25542992d025e6746
Failed to synthesize sample 5ae7aae35542993210983ee6
Failed to synthesize sample 5a7db4eb5542997cc2c47474
Failed to synthesize sample 5a847de2554299123d8c226f
Failed to synthesize sample 5ae01a0f55429942ec259c20
Failed to synthesize sample 5a72fcdc5542991f9a20c5eb
Fa

### Step 2: Token Labeling
The model identifies which specific tokens (words) in each sub-question's supporting (positive) paragraph are needed to answer that sub-question.
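
`TokenLabeler.parse_prompt` builds one prompt per sub-question by filling the few-shot template's `question`, `paragraph` and `answer` slots, and a reply is accepted only if it is JSON containing an `extracted_words` key. A minimal sketch of that flow; `DEMO_TEMPLATE` is a made-up stand-in for the repository's prompt, and only its placeholder names match.

```python
# Made-up stand-in for the repository's few-shot template; only the placeholders match
DEMO_TEMPLATE = (
    "Question: {question}\n"
    "Paragraph: {paragraph}\n"
    "Answer: {answer}\n"
    'Return JSON: {{"extracted_words": "..."}}'
)


def build_label_prompts(sample, template=DEMO_TEMPLATE):
    """One prompt per sub-question, in sorted sub-question-id order."""
    prompts = []
    for subq_id in sorted(sample["decomposed_questions"]):
        subq = sample["decomposed_questions"][subq_id]
        prompts.append(template.format(
            question=subq["sub_question"],
            paragraph=subq["positive_paragraph"],
            answer=subq["answer"],
        ))
    return prompts


def is_valid_label(reply):
    """Acceptance rule analogous to check_if_valid: a dict with an extracted_words key."""
    return isinstance(reply, dict) and "extracted_words" in reply
```

The doubled braces in the template survive `str.format` as literal `{` and `}`, which is how a JSON example can sit inside a format string.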

In [None]:
! python src/data_synthesize/token_labeling.py \
    --dataset hotpotQA \
    --split train \
    --model llama-8B

Processing... 100% ━━━━━━━━━━━━━━━━━━━━━━━ 28/28 [ 0:00:22 < 0:00:00 , 1 it/s ]

### Step 3: Token Extraction
We extract the labeled tokens to create a clean dataset of (Question, Paragraph) -> (Essential Tokens).
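
Conceptually, extraction turns the model's `extracted_words` string into per-token binary labels over the paragraph, which is the target the Labeler learns. The sketch below uses whitespace tokens and greedy in-order matching; this is a simplifying assumption, since the verbose log of this step (for example the token `be`) suggests the repository's script works with a lemmatizing tokenizer instead. `label_tokens` is a hypothetical helper.

```python
import re


def _norm(token):
    """Lowercase and strip surrounding punctuation so 'Hobbs,' matches 'hobbs'."""
    return re.sub(r"^\W+|\W+$", "", token.lower())


def label_tokens(paragraph, extracted_words):
    """Return (tokens, labels): label 1 marks a paragraph token matched, in order,
    by one of the extracted words. Extracted words with no match are skipped."""
    tokens = paragraph.split()
    normed = [_norm(t) for t in tokens]
    labels = [0] * len(tokens)
    pos = 0  # matches must move forward through the paragraph
    for word in extracted_words.split():
        target = _norm(word)
        if not target:
            continue
        for i in range(pos, len(tokens)):
            if normed[i] == target:
                labels[i] = 1
                pos = i + 1
                break
    return tokens, labels
```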

In [None]:
! python src/data_synthesize/token_extraction.py \
    --data_path data/synthesized_token_labeling/hotpotQA/train.jsonl \
    --save_path data/token_extracted/hotpotQA/train.jsonl \
    --verbose

Russell 0 0 []
Hobbs 1 1 ['Russell', 'Hobbs', ':']
Hobbs 4 4 ['Russell', 'Hobbs', 'be']
Peter 0 0 []
Hobbs 1 1 ['Peter', 'Hobbs', '(']

### Step 4: Next-Hop Query Construction
Based on the information found so far, the model generates the search query needed for the *next* hop of reasoning.
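
The hotpotQA prompt patched in above asks for a JSON object with a single `filtered_query` key, and chat models often wrap such JSON in a fenced block or surround it with prose. A sketch of a tolerant parser follows; `extract_json_object` and `parse_next_hop` are hypothetical helpers, not functions used by `next_hop_query_construction.py`.

```python
import json


def extract_json_object(text):
    """Return the first JSON object found in text, or None.

    Tries json.JSONDecoder.raw_decode at each '{', which tolerates leading
    prose, trailing text and ```json fences around the object.
    """
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            obj, _ = decoder.raw_decode(text, start)
            if isinstance(obj, dict):
                return obj
        except json.JSONDecodeError:
            pass
        start = text.find("{", start + 1)
    return None


def parse_next_hop(reply):
    """Pull a non-empty filtered_query string out of a model reply, else None."""
    obj = extract_json_object(reply)
    if obj is None:
        return None
    query = obj.get("filtered_query")
    if isinstance(query, str) and query.strip():
        return query.strip()
    return None
```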

In [None]:
! python src/data_synthesize/next_hop_query_construction.py \
    --dataset hotpotQA \
    --split train \
    --model llama-8B

Failed at sample 27:
Failed at sample 26:
Failed at sample 10:
Failed at sample 1:
Failed at sample 19:
Failed at sample 9:

### Step 5: Query Filtering
We filter out generated queries that are redundant or unhelpful, keeping only high-quality search queries.
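
To make the filtering idea concrete, here is a heuristic sketch that drops empty or very short candidates, candidates identical to the original question after normalization, and duplicates. These rules are illustrative assumptions, not the criteria `next_hop_query_filtering.py` applies; `filter_queries` is a hypothetical helper.

```python
import re


def _normalize(text):
    """Lowercase and collapse runs of non-alphanumeric characters to single spaces."""
    return " ".join(re.sub(r"[^0-9a-z]+", " ", text.lower()).split())


def filter_queries(question, candidates, min_words=2):
    """Keep candidate next-hop queries that are non-trivial, differ from the
    original question, and do not repeat an earlier candidate."""
    seen = {_normalize(question)}
    kept = []
    for cand in candidates:
        norm = _normalize(cand or "")
        if len(norm.split()) < min_words or norm in seen:
            continue
        seen.add(norm)
        kept.append(cand.strip())
    return kept
```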

In [None]:
! python src/data_synthesize/next_hop_query_filtering.py \
    --data_path data/synthesized_next_query/hotpotQA/train.jsonl \
    --save_path data/next_query_extracted/hotpotQA/train.jsonl \
    --verbose

  for sample in tqdm_rich(data):
[?25l[2K[35m   0%[0m [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0/2 [0m [ [33m0:00:00[0m < [36m-:--:--[0m , [31m? it/s[0m ][2K[35m 100%[0m [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2/2 [0m [ [33m0:00:00[0m < [36m0:00:00[0m , [31m? it/s[0m ]
[?25hcomp_rate: 0.00
variation_rate: 0.00
hitting_rate: 0.00
matching_rate: 0.00
alignment_gap: 0.00
find_rate: 0.00


### Step 6: Negative Sampling
To train a robust model, we need "hard negatives"—passages that look relevant but don't actually contain the answer. We use the Contriever index built earlier to find these.
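
The core of hard-negative mining is: score passages against the query, then keep the top-scoring ones that are not gold positives. A dependency-free sketch, assuming inner-product scoring and in-memory embeddings; with the FAISS index built earlier, the same idea amounts to over-fetching from `index.search` and dropping the positive ids. `mine_hard_negatives` is a hypothetical helper, not the repository's implementation.

```python
def mine_hard_negatives(query_vec, passage_vecs, positive_ids, k=3):
    """Return ids of the k highest-scoring passages that are not gold positives.

    query_vec: list of floats; passage_vecs: dict of passage id -> list of floats.
    Scoring is a plain inner product (an assumption about the index metric).
    """
    def dot(a, b):
        return sum(x * y for x, y in zip(a, b))

    candidates = [
        (dot(query_vec, vec), pid)
        for pid, vec in passage_vecs.items()
        if pid not in positive_ids
    ]
    # Highest similarity first: these look relevant but are known not to be gold
    candidates.sort(key=lambda item: item[0], reverse=True)
    return [pid for _, pid in candidates[:k]]
```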

In [None]:
! python src/data_synthesize/negative_sampling.py \
    --dataset hotpotQA \
    --split train \
    --retriever contriever

Building index from data/corpus/hotpotQA/contriever
Load embeddings: 100% 1/1 [00:00<00:00, 238.56it/s]
Total data indexed 3931
Saving index to data/corpus/hotpotQA/contriever
Serializing index to data/corpus/hotpotQA/contriever/index.faiss, meta data to data/corpus/hotpotQA/contriever/index_meta.faiss
Loading passages from data/corpus/hotpotQA/corpus.jsonl
Loaded 3931 passages.
Negative Sampling...   0% ━━━━━━━━━━━━━━━━━━ 0/2 [ 0:00:00 < -:--:-- , ? it/s ]
2025-10-22 23:37:20.361383: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1761176240.382634   11349 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1761176240.389156   11349 cuda_blas.cc:1407] Unable to register cuBLAS factory: Att

In [None]:
! python src/data_synthesize/negative_sampling_labeled.py \
    --dataset hotpotQA \
    --split train \
    --model llama-8B

Processing... 100% ━━━━━━━━━━━━━━━━━━━━━━━━━ 2/2 [ 0:00:00 < 0:00:00 , ? it/s ]

In [None]:
! python src/data_synthesize/negative_token_extraction.py \
    --dataset hotpotQA \
    --split train \
    --verbose

  for sample in tqdm_rich(data):
[?25l[2K[35m   0%[0m [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0/2 [0m [ [33m0:00:00[0m < [36m-:--:--[0m , [31m? it/s[0m ][2K[35m 100%[0m [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2/2 [0m [ [33m0:00:00[0m < [36m0:00:00[0m , [31m? it/s[0m ]
[?25hcomp_rate: 0.00
variation_rate: 100.00
hitting_rate: 0.00
matching_rate: 0.00
alignment_gap: 0.00
find_rate: 0.00


### Step 7: Final Compilation
Finally, we aggregate the decomposed questions, labeled tokens, filtered queries, and negative samples into the final training datasets for the **Labeler** and **Filter** models.
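
Because each stage can drop samples (note the failed decompositions and next-hop queries above), only samples that survive every stage can contribute training data, which is consistent with the 2/2 counts in the later steps. A minimal sketch of that aggregation as an inner join on the record `id`; `join_by_id` is a hypothetical helper, not the logic of `training_data_synthesize.py`.

```python
def join_by_id(*stages):
    """Inner-join lists of dict records on their "id" field.

    Returns one merged dict per id present in every stage, in first-stage order.
    Later stages overwrite earlier keys that share a name.
    """
    if not stages:
        return []
    common = set.intersection(*({r["id"] for r in stage} for stage in stages))
    merged = {}
    for stage in stages:
        for record in stage:
            if record["id"] in common:
                merged.setdefault(record["id"], {}).update(record)
    ordered, seen = [], set()
    for record in stages[0]:
        rid = record["id"]
        if rid in common and rid not in seen:
            seen.add(rid)
            ordered.append(merged[rid])
    return ordered
```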

In [None]:
! python src/data_synthesize/training_data_synthesize.py \
    --dataset hotpotQA \
    --split train

Building labeler data 100% ━━━━━━━━━━━━━━━━━ 2/2 [ 0:00:00 < 0:00:00 , ? it/s ]
Building filter data 100% ━━━━━━━━━━━━━━━━━━ 2/2 [ 0:00:00 < 0:00:00 , ? it/s ]