In [1]:
# ===== MUST RUN FIRST CELL (before importing pyterrier / jnius) =====
import os, glob

# A. Prefer system JDK (if it exists)
sys_java_home = "/usr/lib/jvm/java-17-openjdk-amd64"
candidates = [
    os.path.join(sys_java_home, "lib", "server", "libjvm.so"),
    "/gz-data/jdk/current/lib/server/libjvm.so",  # If installed to /gz-data/jdk/current as per previous steps
]

jvm_path = next((p for p in candidates if os.path.isfile(p)), None)

if jvm_path is None:
    # Fallback: Search in common directories
    for base in ("/usr/lib/jvm", "/gz-data/jdk"):
        for p in glob.glob(base + "/**/lib/server/libjvm.so", recursive=True):
            jvm_path = p
            break

if not jvm_path:
    raise RuntimeError("Could not find libjvm.so. Please install JDK 17 (system or /gz-data) and try again.")

java_home = os.path.abspath(os.path.join(jvm_path, "..", ".."))  # Remove /lib/server
os.environ["JAVA_HOME"] = java_home
os.environ["JVM_PATH"]  = jvm_path
os.environ["LD_LIBRARY_PATH"] = os.path.dirname(jvm_path) + ":" + os.environ.get("LD_LIBRARY_PATH", "")
os.environ["PATH"] = os.path.join(java_home, "bin") + ":" + os.environ.get("PATH", "")

# Optional: Specify Java temporary directory and memory
os.environ.setdefault("JAVA_TOOL_OPTIONS", "-Djava.io.tmpdir=/gz-data/tmp")
os.environ.setdefault("_JAVA_OPTIONS", "-Xms512m -Xmx8g")
os.makedirs("/gz-data/tmp", exist_ok=True)

print("JAVA_HOME =", os.environ["JAVA_HOME"])
print("JVM_PATH  =", os.environ["JVM_PATH"])

# ===== Now import PyTerrier and initialize JVM =====
import pyterrier as pt
if not pt.started():
    # Use standard init for stability (includes Java initialization); add mem/jvm_opts if needed
    pt.init()  # Or pt.init(tail=False)
print("PyTerrier started =", pt.started())

JAVA_HOME = /gz-data/jdk/current/lib
JVM_PATH  = /gz-data/jdk/current/lib/server/libjvm.so


  if not pt.started():


PyTerrier started = True


Picked up JAVA_TOOL_OPTIONS: -Djava.io.tmpdir=/gz-data/tmp
Picked up _JAVA_OPTIONS: -Xms512m -Xmx8g
Java started and loaded: pyterrier.java, pyterrier.terrier.java [version=5.11 (build: craig.macdonald 2025-01-13 21:29), helper_version=0.0.8]
java is now started automatically with default settings. To force initialisation early, run:
pt.java.init() # optional, forces java initialisation
  pt.init()  # 或 pt.init(tail=False)
  print("PyTerrier started =", pt.started())


In [2]:
# ========== Environment & Pipeline Setup (BM25-only + Dual-GPU Sharding) ==========
import os, re, torch, random, warnings, pandas as pd
from datasets import load_dataset
import pyterrier as pt
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from tqdm import tqdm
from rouge_score import rouge_scorer
from nltk.translate.bleu_score import sentence_bleu
from bert_score import score as bert_score

warnings.filterwarnings("ignore")

# ===== Manual Settings: Single GPU in this Notebook + Sharding Parameters =====
GPU_VISIBLE = "0"   # Notebook A: "0", Notebook B: "1"
SHARD       = 1     # A=0, B=1
NUM_SHARDS  = 2     # Set to 2 for two GPUs in parallel

# ====== Cache & Path Configuration ======
os.environ["CUDA_VISIBLE_DEVICES"] = GPU_VISIBLE
os.environ["HF_HOME"] = "/gz-data/hf_cache"
os.environ["TRANSFORMERS_CACHE"] = "/gz-data/hf_cache"
os.environ["HF_DATASETS_CACHE"] = "/gz-data/hf_cache"
os.makedirs("/gz-data/hf_cache", exist_ok=True)

ROOT = "/gz-data/nlquad_colbert"
os.makedirs(ROOT, exist_ok=True)
BM25_TOPK = 5
MAXLEN, GEN_MAXLEN = 250, 384
MIN_SPLIT_LEN, DESIRED_SEG_LEN = 1000, 250
GEN_MODEL = "/gz-data/models/deepseek-llm-7b-chat"  # Keep consistent with original (local path)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

random.seed(42); torch.manual_seed(42)
if not pt.started():
    pt.init()

print(f"Using GPU_VISIBLE={GPU_VISIBLE} | SHARD {SHARD}/{NUM_SHARDS}")

# ===== Load Data & Preprocess =====
print(">>> Loading NLQuAD...")
dataset = load_dataset("LLukas22/NLQuAD", split="test")
records = []
for art in dataset:
    for para in art["paragraphs"]:
        ctx = para["context"]
        if len(ctx.split()) >= MIN_SPLIT_LEN:
            cid = para["qas"][0]["id"].split("_")[0]
            for qa in para["qas"]:
                if qa["answers"]:
                    records.append({
                        "context_id": cid,
                        "context": ctx,
                        "question": qa["question"],
                        "answer": qa["answers"][0]["text"],
                        "qa_id": qa["id"]
                    })
df = pd.DataFrame(records)
df = df.sort_values(["context_id", "qa_id"]).reset_index(drop=True)

# ===== Paragraph Splitting =====
def semantic_split(text, max_words=DESIRED_SEG_LEN):
    sents = re.split(r"(?<=[.!?])\s+", text.strip())
    buf, out = [], []
    for s in sents:
        if s.strip():
            buf.append(s)
            if len(" ".join(buf).split()) >= max_words:
                out.append(" ".join(buf))
                buf = []
    if buf: out.append(" ".join(buf))
    return out

para_records = []
for cid, grp in df.groupby("context_id"):
    context = grp["context"].iloc[0]
    for i, seg in enumerate(semantic_split(context)):
        para_records.append({"docno": f"{cid}_{i}", "text": seg, "cid": cid})
para_df = pd.DataFrame(para_records)
para_df["docid"] = para_df.index.astype(str)
docno_to_docid = dict(zip(para_df["docno"], para_df["docid"]))
para_text_map = dict(zip(para_df["docid"], para_df["text"]))

def clean_query(q):
    return re.sub(r"[^A-Za-z0-9 ]", "", q.strip())

# ===== Read Eligible QIDs (S5 Unified Set) =====
ELIGIBLE_CSV = f"{ROOT}/eligible_qids_top5.csv"
eligible = None
if os.path.exists(ELIGIBLE_CSV):
    try:
        eligible = set(pd.read_csv(ELIGIBLE_CSV)["qa_id"].astype(str))
        print(f">>> Loaded eligible S5 set: {len(eligible)} qids")
    except Exception as e:
        print(f"⚠️ Failed to load {ELIGIBLE_CSV}: {e}. Falling back to >=5 check.")

# ===== BM25 Index =====
print(">>> Building BM25 Index...")
index_ref = f"{ROOT}/pt_index"
if not os.path.exists(index_ref):
    index_ref = pt.IterDictIndexer(f"{ROOT}/pt_index", meta={"docno": 44, "text": 60000}, overwrite=True).index(para_df.to_dict("records"))
index = pt.IndexFactory.of(index_ref)
bm25 = pt.BatchRetrieve(index, wmodel="BM25")

# ===== Load LLM (Consistent with Original: 8-bit + device_map="auto" + compile) =====
print(">>> Loading LLM...")
tokenizer = AutoTokenizer.from_pretrained(GEN_MODEL, trust_remote_code=True)
quant_cfg = BitsAndBytesConfig(load_in_8bit=True)
model = AutoModelForCausalLM.from_pretrained(
    GEN_MODEL,
    quantization_config=quant_cfg,
    torch_dtype=torch.float16,
    device_map="auto",
    trust_remote_code=True
)
torch.backends.cuda.matmul.allow_tf32 = True
# Keep original settings (adjust if conflicts arise; unchanged here to meet "keep other parts identical" requirement)
model = torch.compile(model, mode="reduce-overhead")
print(">>> Model ready!")

# ===== Prompt Template =====
def build_prompt(question, context):
    return f"""You are an AI assistant. Based on the context, answer the question in the following format:

Context: {context}

Question: {question}

Final Answer:"""

# ===== Dynamic Batch Inference =====
def batch_generate_dynamic(prompts, initial_bs=8, max_bs=8):
    results = []
    i, bs = 0, initial_bs
    last_safe_bs = initial_bs
    while i < len(prompts):
        batch = prompts[i:i+bs]
        try:
            inputs = tokenizer(batch, return_tensors="pt", padding=True,
                               truncation=True, max_length=2048).to(DEVICE)
            outputs = model.generate(
                **inputs,
                max_new_tokens=GEN_MAXLEN,
                min_new_tokens=256,
                penalty_alpha=1.2,             # Keep unchanged
                pad_token_id=tokenizer.eos_token_id
            )
            decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)
            results.extend([d.strip() for d in decoded])
            i += bs
            if bs < max_bs:
                last_safe_bs = bs
                bs = min(bs * 2, max_bs)
        except RuntimeError as e:
            if "CUDA out of memory" in str(e):
                torch.cuda.empty_cache()
                print(f"⚠️ OOM at bs={bs}, rolling back to {last_safe_bs}")
                bs = max(last_safe_bs // 2, 1)
                if bs < 1:
                    print("❌ Even bs=1 failed, aborting.")
                    break
            else:
                raise
    return results

# ===== Metric Calculation =====
def compute_metrics(gens, refs):
    scorer = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=True)
    _, _, bert_f1 = bert_score(gens, refs, lang="en", model_type="roberta-large", verbose=False)
    res = []
    for g, r, b in zip(gens, refs, bert_f1):
        scr = scorer.score(r, g)
        bleu = sentence_bleu([r.split()], g.split())
        res.append({
            "rouge1": round(scr["rouge1"].fmeasure, 4),
            "rouge2": round(scr["rouge2"].fmeasure, 4),
            "rougeL": round(scr["rougeL"].fmeasure, 4),
            "bleu": round(bleu, 4),
            "bertscore": round(b.item(), 4)
        })
    return res

# ====== Answer Cleaning ======
def clean_answer(text):
    text = text.strip()
    if "Final Answer:" in text:
        return text.split("Final Answer:", 1)[1].strip()
    paragraphs = re.split(r"\n\s*\n", text)
    cleaned_paras = []
    for para in paragraphs:
        first_line = para.strip().splitlines()[0] if para.strip() else ""
        if first_line.startswith(("Question:", "Context:", "RULES:", "Answer the question")):
            continue
        cleaned_paras.append(para.strip())
    return "\n\n".join(p for p in cleaned_paras if p)

# ===== Top-K Configuration =====
CTX_TOPK_LIST = [1, 2, 3, 4, 5]
topk_list = [(k, f"top{k}") for k in CTX_TOPK_LIST]
MAX_REQUIRED_K = max(CTX_TOPK_LIST)  # 5

# ====== Sharding (Key Addition): Assign context_id to this shard ======
# Only modify these lines, keep others unchanged
NUM_SUBPARTS = 4         # Fixed: split into half
SUBPART = 1              # B=0; A assisting=1

all_groups = list(df.groupby("context_id"))
shard_groups = [g for i, g in enumerate(all_groups) if i % NUM_SHARDS == SHARD]
sub_groups = [g for j, g in enumerate(shard_groups) if j % NUM_SUBPARTS == SUBPART]

print(f">>> Shard {SHARD}/{NUM_SHARDS} groups = {len(shard_groups)} | "
      f"Subpart {SUBPART}/{NUM_SUBPARTS} = {len(sub_groups)}")

# ===== Main Loop (BM25-only, Remove ColBERT Reranking, Keep Others Unchanged) =====
results = []

for cid, grp in tqdm(sub_groups, total=len(sub_groups)):
    batch_prompts, meta = [], []
    for _, row in grp.iterrows():
        q, gt, qid = row["question"], row["answer"], str(row["qa_id"])

        if eligible is not None and qid not in eligible:
            continue

        # --- BM25 with cid filtering ---
        bm25_in = pd.DataFrame({"qid": ["0"], "query": [clean_query(q)]})
        out = bm25.transform(bm25_in)
        out = out[out["docno"].str.startswith(cid)].head(BM25_TOPK)
        ids = [int(docno_to_docid[d]) for d in out["docno"] if d in docno_to_docid]

        # Uniform sample requirement (ensure top5 available)
        if eligible is None and len(ids) < MAX_REQUIRED_K:
            continue

        paras = [para_text_map[str(i)] for i in ids]

        # ===== Remove ColBERT reranking, use BM25 order directly =====
        pairs = list(zip(ids, paras, [None]*len(paras)))

        if len(pairs) < MAX_REQUIRED_K:
            continue

        # --- Slice BM25 list, record topk "rank1/2/.." text ---
        for topk, tag in topk_list:
            _, top_ps, _ = zip(*pairs[:topk])
            topk_ranked_context = "\n".join([f"rank{i+1}: {top_ps[i]}" for i in range(len(top_ps))])

            numbered = [f"Paragraph {i+1}: {p}" for i, p in enumerate(top_ps)]

            # Run all topk sequentially; add reversed for topk>=2; add shuffled for topk>=3
            strategies = [("sequential", numbered)]
            if topk >= 2 and len(numbered) > 1:
                strategies += [("reversed", list(reversed(numbered)))]
            if topk >= 3 and len(numbered) > 1:
                strategies += [("shuffled", random.sample(numbered, len(numbered)))]

            for strat_name, context in strategies:
                batch_prompts.append(build_prompt(q, "\n".join(context)))
                meta.append((f"{strat_name}_{tag}", gt, row["qa_id"], q, topk, topk_ranked_context, cid))

    if not batch_prompts:
        continue

    answers = batch_generate_dynamic(batch_prompts, initial_bs=8, max_bs=8)
    for ans, (strat, gt, qid, q, topk_val, topk_ranked_ctx, cid_val) in zip(answers, meta):
        if not ans:
            continue
        ans_clean = clean_answer(ans)
        m = compute_metrics([ans_clean], [gt])[0]
        results.append({
            "cid": cid_val,
            "qid": qid,
            "question": q,
            "topk": topk_val,
            "topk_ranked_context": topk_ranked_ctx,  # Keep column name unchanged
            "strategy": strat,
            "answer_clean": ans.strip(),
            "answer_for_eval": ans_clean,
            **m
        })

# ===== Save Results (With Sharding Suffix; Preserve Auto Line Breaks) =====
df_res = pd.DataFrame(results)

# Fixed column order
cols_order = [
    "cid", "qid", "question",
    "topk", "topk_ranked_context", "strategy",
    "answer_clean", "answer_for_eval",
    "rouge1", "rouge2", "rougeL", "bleu", "bertscore"
]
df_res = df_res[cols_order]

# Convert \n in topk_ranked_context to actual newlines (auto-wrap in Excel/tables)
df_res["topk_ranked_context"] = df_res["topk_ranked_context"].apply(lambda x: x.replace("\n", "\r\n"))

suf = f"_shard{SHARD}of{NUM_SHARDS}_part{SUBPART}of{NUM_SUBPARTS}"
df_res.to_csv(f"{ROOT}/final_results_stage1_bm25only{suf}.csv", index=False)
avg_m = df_res.groupby("strategy")[["rouge1", "rouge2", "rougeL", "bleu", "bertscore"]].mean().reset_index()
avg_m.to_csv(f"{ROOT}/average_metrics_stage1_bm25only{suf}.csv", index=False)

print("✅ Done. Saved:")
print("   -", f"{ROOT}/final_results_stage1_bm25only{suf}.csv")
print("   -", f"{ROOT}/average_metrics_stage1_bm25only{suf}.csv")

Using GPU_VISIBLE=0 | SHARD 1/2
>>> Loading NLQuAD...
>>> Building BM25 Index...
>>> Loading LLM...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

>>> Model ready!
>>> Shard 1/2 groups = 117 | Subpart 1/4 = 29


  0%|          | 0/29 [00:00<?, ?it/s]Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['pooler.dense

✅ Done. Saved:
   - /gz-data/nlquad_colbert/final_results_stage1_bm25only_shard1of2_part1of4.csv
   - /gz-data/nlquad_colbert/average_metrics_stage1_bm25only_shard1of2_part1of4.csv





In [None]:
# ========== Environment & Pipeline Setup (BM25-only + Dual-GPU Sharding) ==========
import os, re, torch, random, warnings, pandas as pd
from datasets import load_dataset
import pyterrier as pt
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from tqdm import tqdm
from rouge_score import rouge_scorer
from nltk.translate.bleu_score import sentence_bleu
from bert_score import score as bert_score

warnings.filterwarnings("ignore")

# ===== Manual Settings: Single GPU in this Notebook + Sharding Parameters =====
GPU_VISIBLE = "0"   # Notebook A: "0", Notebook B: "1"
SHARD       = 1     # A=0, B=1
NUM_SHARDS  = 2     # Set to 2 for two GPUs in parallel

# ====== Cache & Path Configuration ======
os.environ["CUDA_VISIBLE_DEVICES"] = GPU_VISIBLE
os.environ["HF_HOME"] = "/gz-data/hf_cache"
os.environ["TRANSFORMERS_CACHE"] = "/gz-data/hf_cache"
os.environ["HF_DATASETS_CACHE"] = "/gz-data/hf_cache"
os.makedirs("/gz-data/hf_cache", exist_ok=True)

ROOT = "/gz-data/nlquad_colbert"
os.makedirs(ROOT, exist_ok=True)
BM25_TOPK = 5
MAXLEN, GEN_MAXLEN = 250, 384
MIN_SPLIT_LEN, DESIRED_SEG_LEN = 1000, 250
GEN_MODEL = "/gz-data/models/deepseek-llm-7b-chat"  # Keep consistent with original (local path)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

random.seed(42); torch.manual_seed(42)
if not pt.started():
    pt.init()

print(f"Using GPU_VISIBLE={GPU_VISIBLE} | SHARD {SHARD}/{NUM_SHARDS}")

# ===== Load Data & Preprocess =====
print(">>> Loading NLQuAD...")
dataset = load_dataset("LLukas22/NLQuAD", split="test")
records = []
for art in dataset:
    for para in art["paragraphs"]:
        ctx = para["context"]
        if len(ctx.split()) >= MIN_SPLIT_LEN:
            cid = para["qas"][0]["id"].split("_")[0]
            for qa in para["qas"]:
                if qa["answers"]:
                    records.append({
                        "context_id": cid,
                        "context": ctx,
                        "question": qa["question"],
                        "answer": qa["answers"][0]["text"],
                        "qa_id": qa["id"]
                    })
df = pd.DataFrame(records)
df = df.sort_values(["context_id", "qa_id"]).reset_index(drop=True)

# ===== Paragraph Splitting =====
def semantic_split(text, max_words=DESIRED_SEG_LEN):
    sents = re.split(r"(?<=[.!?])\s+", text.strip())
    buf, out = [], []
    for s in sents:
        if s.strip():
            buf.append(s)
            if len(" ".join(buf).split()) >= max_words:
                out.append(" ".join(buf))
                buf = []
    if buf: out.append(" ".join(buf))
    return out

para_records = []
for cid, grp in df.groupby("context_id"):
    context = grp["context"].iloc[0]
    for i, seg in enumerate(semantic_split(context)):
        para_records.append({"docno": f"{cid}_{i}", "text": seg, "cid": cid})
para_df = pd.DataFrame(para_records)
para_df["docid"] = para_df.index.astype(str)
docno_to_docid = dict(zip(para_df["docno"], para_df["docid"]))
para_text_map = dict(zip(para_df["docid"], para_df["text"]))

def clean_query(q):
    return re.sub(r"[^A-Za-z0-9 ]", "", q.strip())

# ===== Read Eligible QIDs (S5 Unified Set) =====
ELIGIBLE_CSV = f"{ROOT}/eligible_qids_top5.csv"
eligible = None
if os.path.exists(ELIGIBLE_CSV):
    try:
        eligible = set(pd.read_csv(ELIGIBLE_CSV)["qa_id"].astype(str))
        print(f">>> Loaded eligible S5 set: {len(eligible)} qids")
    except Exception as e:
        print(f"⚠️ Failed to load {ELIGIBLE_CSV}: {e}. Falling back to >=5 check.")

# ===== BM25 Index =====
print(">>> Building BM25 Index...")
index_ref = f"{ROOT}/pt_index"
if not os.path.exists(index_ref):
    index_ref = pt.IterDictIndexer(f"{ROOT}/pt_index", meta={"docno": 44, "text": 60000}, overwrite=True).index(para_df.to_dict("records"))
index = pt.IndexFactory.of(index_ref)
bm25 = pt.BatchRetrieve(index, wmodel="BM25")

# ===== Load LLM (Consistent with Original: 8-bit + device_map="auto" + compile) =====
print(">>> Loading LLM...")
tokenizer = AutoTokenizer.from_pretrained(GEN_MODEL, trust_remote_code=True)
quant_cfg = BitsAndBytesConfig(load_in_8bit=True)
model = AutoModelForCausalLM.from_pretrained(
    GEN_MODEL,
    quantization_config=quant_cfg,
    torch_dtype=torch.float16,
    device_map="auto",
    trust_remote_code=True
)
torch.backends.cuda.matmul.allow_tf32 = True
# Keep original settings (adjust if conflicts arise; unchanged here to meet "keep other parts identical" requirement)
model = torch.compile(model, mode="reduce-overhead")
print(">>> Model ready!")

# ===== Prompt Template =====
def build_prompt(question, context):
    return f"""You are an AI assistant. Based on the context, answer the question in the following format:

Context: {context}

Question: {question}

Final Answer:"""

# ===== Dynamic Batch Inference =====
def batch_generate_dynamic(prompts, initial_bs=8, max_bs=8):
    results = []
    i, bs = 0, initial_bs
    last_safe_bs = initial_bs
    while i < len(prompts):
        batch = prompts[i:i+bs]
        try:
            inputs = tokenizer(batch, return_tensors="pt", padding=True,
                               truncation=True, max_length=2048).to(DEVICE)
            outputs = model.generate(
                **inputs,
                max_new_tokens=GEN_MAXLEN,
                min_new_tokens=256,
                penalty_alpha=1.2,             # Keep unchanged
                pad_token_id=tokenizer.eos_token_id
            )
            decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)
            results.extend([d.strip() for d in decoded])
            i += bs
            if bs < max_bs:
                last_safe_bs = bs
                bs = min(bs * 2, max_bs)
        except RuntimeError as e:
            if "CUDA out of memory" in str(e):
                torch.cuda.empty_cache()
                print(f"⚠️ OOM at bs={bs}, rolling back to {last_safe_bs}")
                bs = max(last_safe_bs // 2, 1)
                if bs < 1:
                    print("❌ Even bs=1 failed, aborting.")
                    break
            else:
                raise
    return results

# ===== Metric Calculation =====
def compute_metrics(gens, refs):
    scorer = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=True)
    _, _, bert_f1 = bert_score(gens, refs, lang="en", model_type="roberta-large", verbose=False)
    res = []
    for g, r, b in zip(gens, refs, bert_f1):
        scr = scorer.score(r, g)
        bleu = sentence_bleu([r.split()], g.split())
        res.append({
            "rouge1": round(scr["rouge1"].fmeasure, 4),
            "rouge2": round(scr["rouge2"].fmeasure, 4),
            "rougeL": round(scr["rougeL"].fmeasure, 4),
            "bleu": round(bleu, 4),
            "bertscore": round(b.item(), 4)
        })
    return res

# ====== Answer Cleaning ======
def clean_answer(text):
    text = text.strip()
    if "Final Answer:" in text:
        return text.split("Final Answer:", 1)[1].strip()
    paragraphs = re.split(r"\n\s*\n", text)
    cleaned_paras = []
    for para in paragraphs:
        first_line = para.strip().splitlines()[0] if para.strip() else ""
        if first_line.startswith(("Question:", "Context:", "RULES:", "Answer the question")):
            continue
        cleaned_paras.append(para.strip())
    return "\n\n".join(p for p in cleaned_paras if p)

# ===== Top-K Configuration =====
CTX_TOPK_LIST = [1, 2, 3, 4, 5]
topk_list = [(k, f"top{k}") for k in CTX_TOPK_LIST]
MAX_REQUIRED_K = max(CTX_TOPK_LIST)  # 5

# ====== Sharding (Key Addition): Assign context_id to this shard ======
# Only modify these lines, keep others unchanged
NUM_SUBPARTS = 4         # Fixed: split into half
SUBPART = 0              # B=0; A assisting=1

all_groups = list(df.groupby("context_id"))
shard_groups = [g for i, g in enumerate(all_groups) if i % NUM_SHARDS == SHARD]
sub_groups = [g for j, g in enumerate(shard_groups) if j % NUM_SUBPARTS == SUBPART]

print(f">>> Shard {SHARD}/{NUM_SHARDS} groups = {len(shard_groups)} | "
      f"Subpart {SUBPART}/{NUM_SUBPARTS} = {len(sub_groups)}")

# ===== Main Loop (BM25-only, Remove ColBERT Reranking, Keep Others Unchanged) =====
results = []

for cid, grp in tqdm(sub_groups, total=len(sub_groups)):
    batch_prompts, meta = [], []
    for _, row in grp.iterrows():
        q, gt, qid = row["question"], row["answer"], str(row["qa_id"])

        if eligible is not None and qid not in eligible:
            continue

        # --- BM25 with cid filtering ---
        bm25_in = pd.DataFrame({"qid": ["0"], "query": [clean_query(q)]})
        out = bm25.transform(bm25_in)
        out = out[out["docno"].str.startswith(cid)].head(BM25_TOPK)
        ids = [int(docno_to_docid[d]) for d in out["docno"] if d in docno_to_docid]

        # Uniform sample requirement (ensure top5 available)
        if eligible is None and len(ids) < MAX_REQUIRED_K:
            continue

        paras = [para_text_map[str(i)] for i in ids]

        # ===== Remove ColBERT reranking, use BM25 order directly =====
        pairs = list(zip(ids, paras, [None]*len(paras)))

        if len(pairs) < MAX_REQUIRED_K:
            continue

        # --- Slice BM25 list, record topk "rank1/2/.." text ---
        for topk, tag in topk_list:
            _, top_ps, _ = zip(*pairs[:topk])
            topk_ranked_context = "\n".join([f"rank{i+1}: {top_ps[i]}" for i in range(len(top_ps))])

            numbered = [f"Paragraph {i+1}: {p}" for i, p in enumerate(top_ps)]

            # Run all topk sequentially; add reversed for topk>=2; add shuffled for topk>=3
            strategies = [("sequential", numbered)]
            if topk >= 2 and len(numbered) > 1:
                strategies += [("reversed", list(reversed(numbered)))]
            if topk >= 3 and len(numbered) > 1:
                strategies += [("shuffled", random.sample(numbered, len(numbered)))]

            for strat_name, context in strategies:
                batch_prompts.append(build_prompt(q, "\n".join(context)))
                meta.append((f"{strat_name}_{tag}", gt, row["qa_id"], q, topk, topk_ranked_context, cid))

    if not batch_prompts:
        continue

    answers = batch_generate_dynamic(batch_prompts, initial_bs=8, max_bs=8)
    for ans, (strat, gt, qid, q, topk_val, topk_ranked_ctx, cid_val) in zip(answers, meta):
        if not ans:
            continue
        ans_clean = clean_answer(ans)
        m = compute_metrics([ans_clean], [gt])[0]
        results.append({
            "cid": cid_val,
            "qid": qid,
            "question": q,
            "topk": topk_val,
            "topk_ranked_context": topk_ranked_ctx,  # Keep column name unchanged
            "strategy": strat,
            "answer_clean": ans.strip(),
            "answer_for_eval": ans_clean,
            **m
        })

# ===== Save Results (With Sharding Suffix; Preserve Auto Line Breaks) =====
df_res = pd.DataFrame(results)

# Fixed column order
cols_order = [
    "cid", "qid", "question",
    "topk", "topk_ranked_context", "strategy",
    "answer_clean", "answer_for_eval",
    "rouge1", "rouge2", "rougeL", "bleu", "bertscore"
]
df_res = df_res[cols_order]

# Convert \n in topk_ranked_context to actual newlines (auto-wrap in Excel/tables)
df_res["topk_ranked_context"] = df_res["topk_ranked_context"].apply(lambda x: x.replace("\n", "\r\n"))

suf = f"_shard{SHARD}of{NUM_SHARDS}_part{SUBPART}of{NUM_SUBPARTS}"
df_res.to_csv(f"{ROOT}/final_results_stage1_bm25only{suf}.csv", index=False)
avg_m = df_res.groupby("strategy")[["rouge1", "rouge2", "rougeL", "bleu", "bertscore"]].mean().reset_index()
avg_m.to_csv(f"{ROOT}/average_metrics_stage1_bm25only{suf}.csv", index=False)

print("✅ Done. Saved:")
print("   -", f"{ROOT}/final_results_stage1_bm25only{suf}.csv")
print("   -", f"{ROOT}/average_metrics_stage1_bm25only{suf}.csv")

In [None]:
# ========== Environment & Pipeline Setup (BM25-only + Dual-GPU Sharding) ==========
import os, re, torch, random, warnings, pandas as pd
from datasets import load_dataset
import pyterrier as pt
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from tqdm import tqdm
from rouge_score import rouge_scorer
from nltk.translate.bleu_score import sentence_bleu
from bert_score import score as bert_score

warnings.filterwarnings("ignore")

# ===== Manual Settings: Single GPU in this Notebook + Sharding Parameters =====
GPU_VISIBLE = "1"   # Notebook A: "0", Notebook B: "1"
SHARD       = 1     # A=0, B=1
NUM_SHARDS  = 2     # Set to 2 for two GPUs in parallel

# ====== Cache & Path Configuration ======
os.environ["CUDA_VISIBLE_DEVICES"] = GPU_VISIBLE
os.environ["HF_HOME"] = "/gz-data/hf_cache"
os.environ["TRANSFORMERS_CACHE"] = "/gz-data/hf_cache"
os.environ["HF_DATASETS_CACHE"] = "/gz-data/hf_cache"
os.makedirs("/gz-data/hf_cache", exist_ok=True)

ROOT = "/gz-data/nlquad_colbert"
os.makedirs(ROOT, exist_ok=True)
BM25_TOPK = 5
MAXLEN, GEN_MAXLEN = 250, 384
MIN_SPLIT_LEN, DESIRED_SEG_LEN = 1000, 250
GEN_MODEL = "/gz-data/models/deepseek-llm-7b-chat"  # Keep consistent with original (local path)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

random.seed(42); torch.manual_seed(42)
if not pt.started():
    pt.init()

print(f"Using GPU_VISIBLE={GPU_VISIBLE} | SHARD {SHARD}/{NUM_SHARDS}")

# ===== Load Data & Preprocess =====
print(">>> Loading NLQuAD...")
dataset = load_dataset("LLukas22/NLQuAD", split="test")
records = []
for art in dataset:
    for para in art["paragraphs"]:
        ctx = para["context"]
        if len(ctx.split()) >= MIN_SPLIT_LEN:
            cid = para["qas"][0]["id"].split("_")[0]
            for qa in para["qas"]:
                if qa["answers"]:
                    records.append({
                        "context_id": cid,
                        "context": ctx,
                        "question": qa["question"],
                        "answer": qa["answers"][0]["text"],
                        "qa_id": qa["id"]
                    })
df = pd.DataFrame(records)
df = df.sort_values(["context_id", "qa_id"]).reset_index(drop=True)

# ===== Paragraph Splitting =====
def semantic_split(text, max_words=DESIRED_SEG_LEN):
    sents = re.split(r"(?<=[.!?])\s+", text.strip())
    buf, out = [], []
    for s in sents:
        if s.strip():
            buf.append(s)
            if len(" ".join(buf).split()) >= max_words:
                out.append(" ".join(buf))
                buf = []
    if buf: out.append(" ".join(buf))
    return out

para_records = []
for cid, grp in df.groupby("context_id"):
    context = grp["context"].iloc[0]
    for i, seg in enumerate(semantic_split(context)):
        para_records.append({"docno": f"{cid}_{i}", "text": seg, "cid": cid})
para_df = pd.DataFrame(para_records)
para_df["docid"] = para_df.index.astype(str)
docno_to_docid = dict(zip(para_df["docno"], para_df["docid"]))
para_text_map = dict(zip(para_df["docid"], para_df["text"]))

def clean_query(q):
    return re.sub(r"[^A-Za-z0-9 ]", "", q.strip())

# ===== Read Eligible QIDs (S5 Unified Set) =====
ELIGIBLE_CSV = f"{ROOT}/eligible_qids_top5.csv"
eligible = None
if os.path.exists(ELIGIBLE_CSV):
    try:
        eligible = set(pd.read_csv(ELIGIBLE_CSV)["qa_id"].astype(str))
        print(f">>> Loaded eligible S5 set: {len(eligible)} qids")
    except Exception as e:
        print(f"⚠️ Failed to load {ELIGIBLE_CSV}: {e}. Falling back to >=5 check.")

# ===== BM25 Index =====
print(">>> Building BM25 Index...")
index_ref = f"{ROOT}/pt_index"
if not os.path.exists(index_ref):
    index_ref = pt.IterDictIndexer(f"{ROOT}/pt_index", meta={"docno": 44, "text": 60000}, overwrite=True).index(para_df.to_dict("records"))
index = pt.IndexFactory.of(index_ref)
bm25 = pt.BatchRetrieve(index, wmodel="BM25")

# ===== Load LLM (Consistent with Original: 8-bit + device_map="auto" + compile) =====
print(">>> Loading LLM...")
tokenizer = AutoTokenizer.from_pretrained(GEN_MODEL, trust_remote_code=True)
quant_cfg = BitsAndBytesConfig(load_in_8bit=True)
model = AutoModelForCausalLM.from_pretrained(
    GEN_MODEL,
    quantization_config=quant_cfg,
    torch_dtype=torch.float16,
    device_map="auto",
    trust_remote_code=True
)
torch.backends.cuda.matmul.allow_tf32 = True
# Keep original settings (adjust if conflicts arise; unchanged here to meet "keep other parts identical" requirement)
model = torch.compile(model, mode="reduce-overhead")
print(">>> Model ready!")

# ===== Prompt Template =====
def build_prompt(question, context):
    return f"""You are an AI assistant. Based on the context, answer the question in the following format:

Context: {context}

Question: {question}

Final Answer:"""

# ===== Dynamic Batch Inference =====
def batch_generate_dynamic(prompts, initial_bs=8, max_bs=8):
    results = []
    i, bs = 0, initial_bs
    last_safe_bs = initial_bs
    while i < len(prompts):
        batch = prompts[i:i+bs]
        try:
            inputs = tokenizer(batch, return_tensors="pt", padding=True,
                               truncation=True, max_length=2048).to(DEVICE)
            outputs = model.generate(
                **inputs,
                max_new_tokens=GEN_MAXLEN,
                min_new_tokens=256,
                penalty_alpha=1.2,             # Keep unchanged
                pad_token_id=tokenizer.eos_token_id
            )
            decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)
            results.extend([d.strip() for d in decoded])
            i += bs
            if bs < max_bs:
                last_safe_bs = bs
                bs = min(bs * 2, max_bs)
        except RuntimeError as e:
            if "CUDA out of memory" in str(e):
                torch.cuda.empty_cache()
                print(f"⚠️ OOM at bs={bs}, rolling back to {last_safe_bs}")
                bs = max(last_safe_bs // 2, 1)
                if bs < 1:
                    print("❌ Even bs=1 failed, aborting.")
                    break
            else:
                raise
    return results

# ===== Metric Calculation =====
def compute_metrics(gens, refs):
    scorer = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=True)
    _, _, bert_f1 = bert_score(gens, refs, lang="en", model_type="roberta-large", verbose=False)
    res = []
    for g, r, b in zip(gens, refs, bert_f1):
        scr = scorer.score(r, g)
        bleu = sentence_bleu([r.split()], g.split())
        res.append({
            "rouge1": round(scr["rouge1"].fmeasure, 4),
            "rouge2": round(scr["rouge2"].fmeasure, 4),
            "rougeL": round(scr["rougeL"].fmeasure, 4),
            "bleu": round(bleu, 4),
            "bertscore": round(b.item(), 4)
        })
    return res

# ====== Answer Cleaning ======
def clean_answer(text):
    text = text.strip()
    if "Final Answer:" in text:
        return text.split("Final Answer:", 1)[1].strip()
    paragraphs = re.split(r"\n\s*\n", text)
    cleaned_paras = []
    for para in paragraphs:
        first_line = para.strip().splitlines()[0] if para.strip() else ""
        if first_line.startswith(("Question:", "Context:", "RULES:", "Answer the question")):
            continue
        cleaned_paras.append(para.strip())
    return "\n\n".join(p for p in cleaned_paras if p)

# ===== Top-K Configuration =====
CTX_TOPK_LIST = [1, 2, 3, 4, 5]
topk_list = [(k, f"top{k}") for k in CTX_TOPK_LIST]
MAX_REQUIRED_K = max(CTX_TOPK_LIST)  # 5

# ====== Sharding (Key Addition): Assign context_id to this shard ======
# Only modify these lines, keep others unchanged
NUM_SUBPARTS = 4         # Fixed: split into half
SUBPART = 3              # B=0; A assisting=1

all_groups = list(df.groupby("context_id"))
shard_groups = [g for i, g in enumerate(all_groups) if i % NUM_SHARDS == SHARD]
sub_groups = [g for j, g in enumerate(shard_groups) if j % NUM_SUBPARTS == SUBPART]

print(f">>> Shard {SHARD}/{NUM_SHARDS} groups = {len(shard_groups)} | "
      f"Subpart {SUBPART}/{NUM_SUBPARTS} = {len(sub_groups)}")

# ===== Main Loop (BM25-only, Remove ColBERT Reranking, Keep Others Unchanged) =====
results = []

for cid, grp in tqdm(sub_groups, total=len(sub_groups)):
    batch_prompts, meta = [], []
    for _, row in grp.iterrows():
        q, gt, qid = row["question"], row["answer"], str(row["qa_id"])

        if eligible is not None and qid not in eligible:
            continue

        # --- BM25 with cid filtering ---
        bm25_in = pd.DataFrame({"qid": ["0"], "query": [clean_query(q)]})
        out = bm25.transform(bm25_in)
        out = out[out["docno"].str.startswith(cid)].head(BM25_TOPK)
        ids = [int(docno_to_docid[d]) for d in out["docno"] if d in docno_to_docid]

        # Uniform sample requirement (ensure top5 available)
        if eligible is None and len(ids) < MAX_REQUIRED_K:
            continue

        paras = [para_text_map[str(i)] for i in ids]

        # ===== Remove ColBERT reranking, use BM25 order directly =====
        pairs = list(zip(ids, paras, [None]*len(paras)))

        if len(pairs) < MAX_REQUIRED_K:
            continue

        # --- Slice BM25 list, record topk "rank1/2/.." text ---
        for topk, tag in topk_list:
            _, top_ps, _ = zip(*pairs[:topk])
            topk_ranked_context = "\n".join([f"rank{i+1}: {top_ps[i]}" for i in range(len(top_ps))])

            numbered = [f"Paragraph {i+1}: {p}" for i, p in enumerate(top_ps)]

            # Run all topk sequentially; add reversed for topk>=2; add shuffled for topk>=3
            strategies = [("sequential", numbered)]
            if topk >= 2 and len(numbered) > 1:
                strategies += [("reversed", list(reversed(numbered)))]
            if topk >= 3 and len(numbered) > 1:
                strategies += [("shuffled", random.sample(numbered, len(numbered)))]

            for strat_name, context in strategies:
                batch_prompts.append(build_prompt(q, "\n".join(context)))
                meta.append((f"{strat_name}_{tag}", gt, row["qa_id"], q, topk, topk_ranked_context, cid))

    if not batch_prompts:
        continue

    answers = batch_generate_dynamic(batch_prompts, initial_bs=8, max_bs=8)
    for ans, (strat, gt, qid, q, topk_val, topk_ranked_ctx, cid_val) in zip(answers, meta):
        if not ans:
            continue
        ans_clean = clean_answer(ans)
        m = compute_metrics([ans_clean], [gt])[0]
        results.append({
            "cid": cid_val,
            "qid": qid,
            "question": q,
            "topk": topk_val,
            "topk_ranked_context": topk_ranked_ctx,  # Keep column name unchanged
            "strategy": strat,
            "answer_clean": ans.strip(),
            "answer_for_eval": ans_clean,
            **m
        })

# ===== Save Results (With Sharding Suffix; Preserve Auto Line Breaks) =====
df_res = pd.DataFrame(results)

# Fixed column order
cols_order = [
    "cid", "qid", "question",
    "topk", "topk_ranked_context", "strategy",
    "answer_clean", "answer_for_eval",
    "rouge1", "rouge2", "rougeL", "bleu", "bertscore"
]
df_res = df_res[cols_order]

# Convert \n in topk_ranked_context to actual newlines (auto-wrap in Excel/tables)
df_res["topk_ranked_context"] = df_res["topk_ranked_context"].apply(lambda x: x.replace("\n", "\r\n"))

suf = f"_shard{SHARD}of{NUM_SHARDS}_part{SUBPART}of{NUM_SUBPARTS}"
df_res.to_csv(f"{ROOT}/final_results_stage1_bm25only{suf}.csv", index=False)
avg_m = df_res.groupby("strategy")[["rouge1", "rouge2", "rougeL", "bleu", "bertscore"]].mean().reset_index()
avg_m.to_csv(f"{ROOT}/average_metrics_stage1_bm25only{suf}.csv", index=False)

print("✅ Done. Saved:")
print("   -", f"{ROOT}/final_results_stage1_bm25only{suf}.csv")
print("   -", f"{ROOT}/average_metrics_stage1_bm25only{suf}.csv")

In [None]:
# ========== Environment & Pipeline Setup (BM25-only + Dual-GPU Sharding) ==========
import os, re, torch, random, warnings, pandas as pd
from datasets import load_dataset
import pyterrier as pt
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from tqdm import tqdm
from rouge_score import rouge_scorer
from nltk.translate.bleu_score import sentence_bleu
from bert_score import score as bert_score

warnings.filterwarnings("ignore")

# ===== Manual Settings: Single GPU in this Notebook + Sharding Parameters =====
GPU_VISIBLE = "1"   # Notebook A: "0", Notebook B: "1"
SHARD       = 1     # A=0, B=1
NUM_SHARDS  = 2     # Set to 2 for two GPUs in parallel

# ====== Cache & Path Configuration ======
os.environ["CUDA_VISIBLE_DEVICES"] = GPU_VISIBLE
os.environ["HF_HOME"] = "/gz-data/hf_cache"
os.environ["TRANSFORMERS_CACHE"] = "/gz-data/hf_cache"
os.environ["HF_DATASETS_CACHE"] = "/gz-data/hf_cache"
os.makedirs("/gz-data/hf_cache", exist_ok=True)

ROOT = "/gz-data/nlquad_colbert"
os.makedirs(ROOT, exist_ok=True)
BM25_TOPK = 5
MAXLEN, GEN_MAXLEN = 250, 384
MIN_SPLIT_LEN, DESIRED_SEG_LEN = 1000, 250
GEN_MODEL = "/gz-data/models/deepseek-llm-7b-chat"  # Keep consistent with original (local path)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

random.seed(42); torch.manual_seed(42)
if not pt.started():
    pt.init()

print(f"Using GPU_VISIBLE={GPU_VISIBLE} | SHARD {SHARD}/{NUM_SHARDS}")

# ===== Load Data & Preprocess =====
print(">>> Loading NLQuAD...")
dataset = load_dataset("LLukas22/NLQuAD", split="test")
records = []
for art in dataset:
    for para in art["paragraphs"]:
        ctx = para["context"]
        if len(ctx.split()) >= MIN_SPLIT_LEN:
            cid = para["qas"][0]["id"].split("_")[0]
            for qa in para["qas"]:
                if qa["answers"]:
                    records.append({
                        "context_id": cid,
                        "context": ctx,
                        "question": qa["question"],
                        "answer": qa["answers"][0]["text"],
                        "qa_id": qa["id"]
                    })
df = pd.DataFrame(records)
df = df.sort_values(["context_id", "qa_id"]).reset_index(drop=True)

# ===== Paragraph Splitting =====
def semantic_split(text, max_words=DESIRED_SEG_LEN):
    sents = re.split(r"(?<=[.!?])\s+", text.strip())
    buf, out = [], []
    for s in sents:
        if s.strip():
            buf.append(s)
            if len(" ".join(buf).split()) >= max_words:
                out.append(" ".join(buf))
                buf = []
    if buf: out.append(" ".join(buf))
    return out

para_records = []
for cid, grp in df.groupby("context_id"):
    context = grp["context"].iloc[0]
    for i, seg in enumerate(semantic_split(context)):
        para_records.append({"docno": f"{cid}_{i}", "text": seg, "cid": cid})
para_df = pd.DataFrame(para_records)
para_df["docid"] = para_df.index.astype(str)
docno_to_docid = dict(zip(para_df["docno"], para_df["docid"]))
para_text_map = dict(zip(para_df["docid"], para_df["text"]))

def clean_query(q):
    return re.sub(r"[^A-Za-z0-9 ]", "", q.strip())

# ===== Read Eligible QIDs (S5 Unified Set) =====
ELIGIBLE_CSV = f"{ROOT}/eligible_qids_top5.csv"
eligible = None
if os.path.exists(ELIGIBLE_CSV):
    try:
        eligible = set(pd.read_csv(ELIGIBLE_CSV)["qa_id"].astype(str))
        print(f">>> Loaded eligible S5 set: {len(eligible)} qids")
    except Exception as e:
        print(f"⚠️ Failed to load {ELIGIBLE_CSV}: {e}. Falling back to >=5 check.")

# ===== BM25 Index =====
print(">>> Building BM25 Index...")
index_ref = f"{ROOT}/pt_index"
if not os.path.exists(index_ref):
    index_ref = pt.IterDictIndexer(f"{ROOT}/pt_index", meta={"docno": 44, "text": 60000}, overwrite=True).index(para_df.to_dict("records"))
index = pt.IndexFactory.of(index_ref)
bm25 = pt.BatchRetrieve(index, wmodel="BM25")

# ===== Load LLM (Consistent with Original: 8-bit + device_map="auto" + compile) =====
print(">>> Loading LLM...")
tokenizer = AutoTokenizer.from_pretrained(GEN_MODEL, trust_remote_code=True)
quant_cfg = BitsAndBytesConfig(load_in_8bit=True)
model = AutoModelForCausalLM.from_pretrained(
    GEN_MODEL,
    quantization_config=quant_cfg,
    torch_dtype=torch.float16,
    device_map="auto",
    trust_remote_code=True
)
torch.backends.cuda.matmul.allow_tf32 = True
# Keep original settings (adjust if conflicts arise; unchanged here to meet "keep other parts identical" requirement)
model = torch.compile(model, mode="reduce-overhead")
print(">>> Model ready!")

# ===== Prompt Template =====
def build_prompt(question, context):
    return f"""You are an AI assistant. Based on the context, answer the question in the following format:

Context: {context}

Question: {question}

Final Answer:"""

# ===== Dynamic Batch Inference =====
def batch_generate_dynamic(prompts, initial_bs=8, max_bs=8):
    results = []
    i, bs = 0, initial_bs
    last_safe_bs = initial_bs
    while i < len(prompts):
        batch = prompts[i:i+bs]
        try:
            inputs = tokenizer(batch, return_tensors="pt", padding=True,
                               truncation=True, max_length=2048).to(DEVICE)
            outputs = model.generate(
                **inputs,
                max_new_tokens=GEN_MAXLEN,
                min_new_tokens=256,
                penalty_alpha=1.2,             # Keep unchanged
                pad_token_id=tokenizer.eos_token_id
            )
            decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)
            results.extend([d.strip() for d in decoded])
            i += bs
            if bs < max_bs:
                last_safe_bs = bs
                bs = min(bs * 2, max_bs)
        except RuntimeError as e:
            if "CUDA out of memory" in str(e):
                torch.cuda.empty_cache()
                print(f"⚠️ OOM at bs={bs}, rolling back to {last_safe_bs}")
                bs = max(last_safe_bs // 2, 1)
                if bs < 1:
                    print("❌ Even bs=1 failed, aborting.")
                    break
            else:
                raise
    return results

# ===== Metric Calculation =====
def compute_metrics(gens, refs):
    scorer = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=True)
    _, _, bert_f1 = bert_score(gens, refs, lang="en", model_type="roberta-large", verbose=False)
    res = []
    for g, r, b in zip(gens, refs, bert_f1):
        scr = scorer.score(r, g)
        bleu = sentence_bleu([r.split()], g.split())
        res.append({
            "rouge1": round(scr["rouge1"].fmeasure, 4),
            "rouge2": round(scr["rouge2"].fmeasure, 4),
            "rougeL": round(scr["rougeL"].fmeasure, 4),
            "bleu": round(bleu, 4),
            "bertscore": round(b.item(), 4)
        })
    return res

# ====== Answer Cleaning ======
def clean_answer(text):
    text = text.strip()
    if "Final Answer:" in text:
        return text.split("Final Answer:", 1)[1].strip()
    paragraphs = re.split(r"\n\s*\n", text)
    cleaned_paras = []
    for para in paragraphs:
        first_line = para.strip().splitlines()[0] if para.strip() else ""
        if first_line.startswith(("Question:", "Context:", "RULES:", "Answer the question")):
            continue
        cleaned_paras.append(para.strip())
    return "\n\n".join(p for p in cleaned_paras if p)

# ===== Top-K Configuration =====
CTX_TOPK_LIST = [1, 2, 3, 4, 5]
topk_list = [(k, f"top{k}") for k in CTX_TOPK_LIST]
MAX_REQUIRED_K = max(CTX_TOPK_LIST)  # 5

# ====== Sharding (Key Addition): Assign context_id to this shard ======
# Only modify these lines, keep others unchanged
NUM_SUBPARTS = 4         # Fixed: split into half
SUBPART = 2              # Changed to 2 as per request

all_groups = list(df.groupby("context_id"))
shard_groups = [g for i, g in enumerate(all_groups) if i % NUM_SHARDS == SHARD]
sub_groups = [g for j, g in enumerate(shard_groups) if j % NUM_SUBPARTS == SUBPART]

print(f">>> Shard {SHARD}/{NUM_SHARDS} groups = {len(shard_groups)} | "
      f"Subpart {SUBPART}/{NUM_SUBPARTS} = {len(sub_groups)}")

# ===== Main Loop (BM25-only, Remove ColBERT Reranking, Keep Others Unchanged) =====
results = []

for cid, grp in tqdm(sub_groups, total=len(sub_groups)):
    batch_prompts, meta = [], []
    for _, row in grp.iterrows():
        q, gt, qid = row["question"], row["answer"], str(row["qa_id"])

        if eligible is not None and qid not in eligible:
            continue

        # --- BM25 with cid filtering ---
        bm25_in = pd.DataFrame({"qid": ["0"], "query": [clean_query(q)]})
        out = bm25.transform(bm25_in)
        out = out[out["docno"].str.startswith(cid)].head(BM25_TOPK)
        ids = [int(docno_to_docid[d]) for d in out["docno"] if d in docno_to_docid]

        # Uniform sample requirement (ensure top5 available)
        if eligible is None and len(ids) < MAX_REQUIRED_K:
            continue

        paras = [para_text_map[str(i)] for i in ids]

        # ===== Remove ColBERT reranking, use BM25 order directly =====
        pairs = list(zip(ids, paras, [None]*len(paras)))

        if len(pairs) < MAX_REQUIRED_K:
            continue

        # --- Slice BM25 list, record topk "rank1/2/.." text ---
        for topk, tag in topk_list:
            _, top_ps, _ = zip(*pairs[:topk])
            topk_ranked_context = "\n".join([f"rank{i+1}: {top_ps[i]}" for i in range(len(top_ps))])

            numbered = [f"Paragraph {i+1}: {p}" for i, p in enumerate(top_ps)]

            # Run all topk sequentially; add reversed for topk>=2; add shuffled for topk>=3
            strategies = [("sequential", numbered)]
            if topk >= 2 and len(numbered) > 1:
                strategies += [("reversed", list(reversed(numbered)))]
            if topk >= 3 and len(numbered) > 1:
                strategies += [("shuffled", random.sample(numbered, len(numbered)))]

            for strat_name, context in strategies:
                batch_prompts.append(build_prompt(q, "\n".join(context)))
                meta.append((f"{strat_name}_{tag}", gt, row["qa_id"], q, topk, topk_ranked_context, cid))

    if not batch_prompts:
        continue

    answers = batch_generate_dynamic(batch_prompts, initial_bs=8, max_bs=8)
    for ans, (strat, gt, qid, q, topk_val, topk_ranked_ctx, cid_val) in zip(answers, meta):
        if not ans:
            continue
        ans_clean = clean_answer(ans)
        m = compute_metrics([ans_clean], [gt])[0]
        results.append({
            "cid": cid_val,
            "qid": qid,
            "question": q,
            "topk": topk_val,
            "topk_ranked_context": topk_ranked_ctx,  # Keep column name unchanged
            "strategy": strat,
            "answer_clean": ans.strip(),
            "answer_for_eval": ans_clean,
            **m
        })

# ===== Save Results (With Sharding Suffix; Preserve Auto Line Breaks) =====
df_res = pd.DataFrame(results)

# Fixed column order
cols_order = [
    "cid", "qid", "question",
    "topk", "topk_ranked_context", "strategy",
    "answer_clean", "answer_for_eval",
    "rouge1", "rouge2", "rougeL", "bleu", "bertscore"
]
df_res = df_res[cols_order]

# Convert \n in topk_ranked_context to actual newlines (auto-wrap in Excel/tables)
df_res["topk_ranked_context"] = df_res["topk_ranked_context"].apply(lambda x: x.replace("\n", "\r\n"))

suf = f"_shard{SHARD}of{NUM_SHARDS}_part{SUBPART}of{NUM_SUBPARTS}"
df_res.to_csv(f"{ROOT}/final_results_stage1_bm25only{suf}.csv", index=False)
avg_m = df_res.groupby("strategy")[["rouge1", "rouge2", "rougeL", "bleu", "bertscore"]].mean().reset_index()
avg_m.to_csv(f"{ROOT}/average_metrics_stage1_bm25only{suf}.csv", index=False)

print("✅ Done. Saved:")
print("   -", f"{ROOT}/final_results_stage1_bm25only{suf}.csv")
print("   -", f"{ROOT}/average_metrics_stage1_bm25only{suf}.csv")

In [4]:
import pandas as pd
import os

ROOT = "/gz-data/nlquad_colbert"

# 需要合并的 final_results CSV
final_files = [
    "final_results_stage1_bm25only_shard1of2_part3of4.csv",
    "final_results_stage1_bm25only_shard1of2_part1of4.csv",
    "final_results_stage1_bm25only_shard1of2_part0of2.csv",
    "final_results_stage1_bm25only_shard0of2.csv"
]

# 输出文件名
output_file = "final_results_stage1_bm25only_merged.csv"

# 读取 & 合并
dfs = []
for fname in final_files:
    fpath = os.path.join(ROOT, fname)
    if os.path.exists(fpath):
        df = pd.read_csv(fpath)
        dfs.append(df)
        print(f"Loaded {fname}: {len(df)} rows")
    else:
        raise FileNotFoundError(f"❌ File not found: {fpath}")

merged_df = pd.concat(dfs, ignore_index=True)
print(f"Total merged rows: {len(merged_df)}")

# 检查重复
dup_mask = merged_df.duplicated(subset=["qid", "strategy", "topk"], keep=False)
if dup_mask.any():
    dup_rows = merged_df.loc[dup_mask, ["qid", "strategy", "topk"]]
    raise ValueError(f"❌ Found duplicate rows based on qid+strategy+topk:\n{dup_rows}")

# 排序
merged_df = merged_df.sort_values(by=["qid", "topk", "strategy"]).reset_index(drop=True)

# 保存
out_path = os.path.join(ROOT, output_file)
merged_df.to_csv(out_path, index=False)
print(f"✅ Merged and sorted file saved to: {out_path}")


Loaded final_results_stage1_bm25only_shard1of2_part3of4.csv: 408 rows
Loaded final_results_stage1_bm25only_shard1of2_part1of4.csv: 324 rows
Loaded final_results_stage1_bm25only_shard1of2_part0of2.csv: 768 rows
Loaded final_results_stage1_bm25only_shard0of2.csv: 1296 rows
Total merged rows: 2796
✅ Merged and sorted file saved to: /gz-data/nlquad_colbert/final_results_stage1_bm25only_merged.csv


In [None]:
import os
import pandas as pd

# Fixed file list (in the order provided or desired)
files = [
    "/gz-data/nlquad_colbert/average_metrics_stage1_bm25only_shard1of2_part3of4.csv",
    "/gz-data/nlquad_colbert/average_metrics_stage1_bm25only_shard1of2_part1of4.csv",
    "/gz-data/nlquad_colbert/average_metrics_stage1_bm25only_shard1of2_part0of4.csv",
    "/gz-data/nlquad_colbert/average_metrics_stage1_bm25only_shard0of2.csv",
]

# 1) Check if all files exist
missing = [f for f in files if not os.path.exists(f)]
if missing:
    raise FileNotFoundError("The following files were not found:\n" + "\n".join(missing))

# 2) Read files
dfs = [pd.read_csv(f) for f in files]

# 3) Check column consistency
cols0 = dfs[0].columns.tolist()
for f, df in zip(files, dfs):
    if df.columns.tolist() != cols0:
        raise ValueError(f"Inconsistent columns in file: {f}\nExpected: {cols0}\nFound: {df.columns.tolist()}")

# 4) Merge
merged = pd.concat(dfs, ignore_index=True)

# 5) Sort (by 'strategy' column if it exists)
if "strategy" in merged.columns:
    merged = merged.sort_values(by=["strategy"]).reset_index(drop=True)

# 6) Save, keeping four decimal places
out_path = "/gz-data/nlquad_colbert/average_metrics_stage1_bm25only_merged.csv"
merged.to_csv(out_path, index=False, float_format="%.4f")

print("✅ Merge complete, output:", out_path)

In [None]:
# ========== Stage 2: LLM-based Scoring (Coherence / Informativeness) ==========
import os, re, gc, torch, warnings
import pandas as pd
from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM

warnings.filterwarnings("ignore")

# ====== Path Configuration (Kaggle-friendly) ======
# Place the merged Stage 1 output file in Kaggle's /kaggle/input/<dataset>/ or /kaggle/working/
CANDIDATES = ["/kaggle/input/final-results-stage1-bm25only-merged/final_results_stage1_bm25only_merged.csv"]  # If added as a Kaggle Dataset
RESULT_FILE = next((p for p in CANDIDATES if os.path.exists(p)), None)
assert RESULT_FILE is not None, f"Stage 1 merged results file not found. Please check paths:\n{CANDIDATES}"

# Stage 2 output to working directory
STAGE2_FILE = "/kaggle/working/final_results_stage2_bm25only_merged_scored.csv"

# Qwen scoring model path/ID:
# - If mounted as a Kaggle Dataset: e.g., /kaggle/input/qwen25-7b-instruct
# - Or, if internet is enabled, use HF Hub name (e.g., 'Qwen/Qwen2.5-7B-Instruct')
LLM_SCORE_MODEL = "Qwen/Qwen2.5-7B-Instruct"  # Adjust to your actual path/ID

# ====== Load NLQuAD Gold Answers ======
print(">>> Loading NLQuAD to recover gold answers ...")
dataset = load_dataset("LLukas22/NLQuAD", split="test")
gold_map = {}
for article in dataset:
    for para in article["paragraphs"]:
        for qa in para["qas"]:
            if qa["answers"]:
                gold_map[qa["id"]] = qa["answers"][0]["text"]

# ====== Load Stage 1 Merged Results (BM25-only merged) ======
print(f">>> Reading Stage 1 merged results from: {RESULT_FILE}")
df_res = pd.read_csv(RESULT_FILE)
if "qid" not in df_res.columns or "question" not in df_res.columns or "answer_for_eval" not in df_res.columns:
    raise ValueError("Input file missing required columns: ['qid', 'question', 'answer_for_eval']")

df_res["gold_answer"] = df_res["qid"].map(gold_map)

# ====== Load Qwen Scoring Model (Kaggle: Single GPU Preferred) ======
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f">>> Loading scorer ({LLM_SCORE_MODEL}) on {device} ...")
tok_score = AutoTokenizer.from_pretrained(LLM_SCORE_MODEL, trust_remote_code=True)
# Kaggle typically has one GPU, so device_map='auto' is sufficient; no need for balanced/max_memory
score_model = AutoModelForCausalLM.from_pretrained(
    LLM_SCORE_MODEL,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device_map="auto",
    trust_remote_code=True
)
print(">>> Scorer ready!")

# ====== Build Prompt ======
def build_score_prompt(question, gold, generated):
    return f"""
You are an expert evaluator. Based on the question, gold answer, and generated answer, provide two scores:
1. Coherence (1-5): Logical clarity and fluency of the generated answer.
2. Informativeness (1-5): Amount of correct and relevant information compared to the gold answer.

Question: {question}
Gold Answer: {gold}
Generated Answer: {generated}

IMPORTANT:
- Do NOT provide any explanation or extra text.
- Only output two integers, separated by a comma, in the LAST line.
Format:
x,y

Now output the result:
"""

# ====== Parse LLM Output ======
def parse_llm_output(text):
    lines = text.strip().splitlines()
    for line in reversed(lines):
        if re.match(r"^\s*[1-5]\s*,\s*[1-5]\s*$", line):
            a, b = [int(x.strip()) for x in line.split(",")]
            return {"coherence": a, "informativeness": b}
    nums = re.findall(r"\b[1-5]\b", text)
    if len(nums) >= 2:
        return {"coherence": int(nums[-2]), "informativeness": int(nums[-1])}
    return {"coherence": 0, "informativeness": 0}

# ====== Batch Generate Scores (With Fallback for OOM) ======
def batch_generate_scores(prompts, initial_bs=2, min_bs=1, max_new_tokens=8):
    results = []
    i, bs = 0, initial_bs
    while i < len(prompts):
        batch = prompts[i:i+bs]
        try:
            inputs = tok_score(batch, return_tensors="pt", padding=True, truncation=True).to(score_model.device)
            with torch.inference_mode():
                outputs = score_model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    do_sample=False,
                    pad_token_id=tok_score.eos_token_id
                )
            decoded = tok_score.batch_decode(outputs, skip_special_tokens=True)
            results.extend(decoded)
            i += bs
            if bs < initial_bs:
                bs = min(bs * 2, initial_bs)
        except RuntimeError as e:
            if "CUDA out of memory" in str(e):
                torch.cuda.empty_cache()
                new_bs = max(bs // 2, min_bs)
                print(f"⚠️ OOM at batch_size={bs}, reducing to {new_bs}")
                bs = new_bs
                if bs < 1:
                    raise RuntimeError("❌ Even batch_size=1 failed. Aborting Stage 2.")
            else:
                raise
    return results

# ====== Stage 2 Main Loop ======
print(">>> Stage 2 (LLM-based scoring) running ...")
BATCH_SIZE = 2 if torch.cuda.is_available() else 2   # Kaggle GPU: 8 is stable; smaller for CPU
prompts, meta, results_s2 = [], [], []

for _, row in tqdm(df_res.iterrows(), total=len(df_res)):
    prompt = build_score_prompt(row["question"], row["gold_answer"], row["answer_for_eval"])
    prompts.append(prompt)
    meta.append(row.to_dict())

    if len(prompts) >= BATCH_SIZE:
        outs = batch_generate_scores(prompts, initial_bs=BATCH_SIZE)
        for out, row_data in zip(outs, meta):
            sc = parse_llm_output(out)
            row_data.update(sc)
            results_s2.append(row_data)
        prompts, meta = [], []

# Flush remaining
if prompts:
    outs = batch_generate_scores(prompts, initial_bs=BATCH_SIZE)
    for out, row_data in zip(outs, meta):
        sc = parse_llm_output(out)
        row_data.update(sc)
        results_s2.append(row_data)

# ====== Save Stage 2 Results ======
out_df = pd.DataFrame(results_s2)
out_df.to_csv(STAGE2_FILE, index=False)
print(f"✅ Stage 2 completed. Results saved at {STAGE2_FILE}")

In [None]:
# -*- coding: utf-8 -*-
# ========== Stage 2: LLM-based Scoring (Coherence / Informativeness) ==========
import os, re, gc, torch, warnings
import pandas as pd
from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM

warnings.filterwarnings("ignore")

# ====== Path Configuration (Kaggle-friendly) ======
# Place your Stage 1 merged output file in Kaggle's /kaggle/input/<dataset>/ or /kaggle/working/
CANDIDATES = ["/kaggle/input/final-results-stage1-bm25only-merged/final_results_stage1_bm25only_merged.csv"]  # If added as a Kaggle Dataset
RESULT_FILE = next((p for p in CANDIDATES if os.path.exists(p)), None)
assert RESULT_FILE is not None, f"Stage 1 merged results file not found. Check paths:\n{CANDIDATES}"

# Stage 2 output to working directory
STAGE2_FILE = "/kaggle/working/final_results_stage2_bm25only_merged_scored.csv"

# Qwen scoring model path/ID:
# - If mounted as a Kaggle Dataset: e.g., /kaggle/input/qwen25-7b-instruct
# - Or if internet is enabled, use HF Hub name (e.g., 'Qwen/Qwen2.5-7B-Instruct')
LLM_SCORE_MODEL = "Qwen/Qwen2.5-7B-Instruct"  # Adjust to your actual path/ID

# ====== Load NLQuAD Gold Answers ======
print(">>> Loading NLQuAD to recover gold answers ...")
dataset = load_dataset("LLukas22/NLQuAD", split="test")
gold_map = {}
for article in dataset:
    for para in article["paragraphs"]:
        for qa in para["qas"]:
            if qa["answers"]:
                gold_map[qa["id"]] = qa["answers"][0]["text"]

# ====== Load Stage 1 Merged Results (BM25-only merged) ======
print(f">>> Reading Stage 1 merged results from: {RESULT_FILE}")
df_res = pd.read_csv(RESULT_FILE)
if "qid" not in df_res.columns or "question" not in df_res.columns or "answer_for_eval" not in df_res.columns:
    raise ValueError("Input file lacks required columns: ['qid','question','answer_for_eval']")

df_res["gold_answer"] = df_res["qid"].map(gold_map)

# ====== Load Qwen Scoring Model (Kaggle: Single GPU preferred) ======
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f">>> Loading scorer ({LLM_SCORE_MODEL}) on {device} ...")
tok_score = AutoTokenizer.from_pretrained(LLM_SCORE_MODEL, trust_remote_code=True)
# Kaggle typically has one GPU, so device_map='auto' is sufficient; no need for balanced/max_memory
score_model = AutoModelForCausalLM.from_pretrained(
    LLM_SCORE_MODEL,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device_map="auto",
    trust_remote_code=True
)
print(">>> Scorer ready!")

# ====== Construct Prompt ======
def build_score_prompt(question, gold, generated):
    return f"""
You are an expert evaluator. Based on the question, gold answer, and generated answer, provide two scores:
1. Coherence (1-5): Logical clarity and fluency of the generated answer.
2. Informativeness (1-5): Amount of correct and relevant information compared to the gold answer.

Question: {question}
Gold Answer: {gold}
Generated Answer: {generated}

IMPORTANT:
- Do NOT provide any explanation or extra text.
- Only output two integers, separated by a comma, in the LAST line.
Format:
x,y

Now output the result:
"""

# ====== Parse LLM Output ======
def parse_llm_output(text):
    lines = text.strip().splitlines()
    for line in reversed(lines):
        if re.match(r"^\s*[1-5]\s*,\s*[1-5]\s*$", line):
            a, b = [int(x.strip()) for x in line.split(",")]
            return {"coherence": a, "informativeness": b}
    nums = re.findall(r"\b[1-5]\b", text)
    if len(nums) >= 2:
        return {"coherence": int(nums[-2]), "informativeness": int(nums[-1])}
    return {"coherence": 0, "informativeness": 0}

# ====== Batch Scoring (with Batch Size Fallback to Prevent OOM) ======
def batch_generate_scores(prompts, initial_bs=2, min_bs=1, max_new_tokens=8):
    results = []
    i, bs = 0, initial_bs
    while i < len(prompts):
        batch = prompts[i:i+bs]
        try:
            inputs = tok_score(batch, return_tensors="pt", padding=True, truncation=True).to(score_model.device)
            with torch.inference_mode():
                outputs = score_model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    do_sample=False,
                    pad_token_id=tok_score.eos_token_id
                )
            decoded = tok_score.batch_decode(outputs, skip_special_tokens=True)
            results.extend(decoded)
            i += bs
            if bs < initial_bs:
                bs = min(bs * 2, initial_bs)
        except RuntimeError as e:
            if "CUDA out of memory" in str(e):
                torch.cuda.empty_cache()
                new_bs = max(bs // 2, min_bs)
                print(f"⚠️ OOM at batch_size={bs}, reducing to {new_bs}")
                bs = new_bs
                if bs < 1:
                    raise RuntimeError("❌ Even batch_size=1 failed. Aborting Stage 2.")
            else:
                raise
    return results

# ====== Stage 2 Main Loop ======
print(">>> Stage 2 (LLM-based scoring) running ...")
BATCH_SIZE = 2 if torch.cuda.is_available() else 2   # Kaggle GPU: 8 is stable; CPU: smaller batch
prompts, meta, results_s2 = [], [], []

for _, row in tqdm(df_res.iterrows(), total=len(df_res)):
    prompt = build_score_prompt(row["question"], row["gold_answer"], row["answer_for_eval"])
    prompts.append(prompt)
    meta.append(row.to_dict())

    if len(prompts) >= BATCH_SIZE:
        outs = batch_generate_scores(prompts, initial_bs=BATCH_SIZE)
        for out, row_data in zip(outs, meta):
            sc = parse_llm_output(out)
            row_data.update(sc)
            results_s2.append(row_data)
        prompts, meta = [], []

# Flush remaining
if prompts:
    outs = batch_generate_scores(prompts, initial_bs=BATCH_SIZE)
    for out, row_data in zip(outs, meta):
        sc = parse_llm_output(out)
        row_data.update(sc)
        results_s2.append(row_data)

# ====== Save Stage 2 Results ======
out_df = pd.DataFrame(results_s2)
out_df.to_csv(STAGE2_FILE, index=False)
print(f"✅ Stage 2 completed. Results saved at {STAGE2_FILE}")