In [None]:
# ===== 0. SETUP & ENVIRONMENT =====
# Colab bootstrap cell: mounts Google Drive, creates the experiment folder,
# installs the Python dependencies, and selects the torch device.
print("="*80)
print("SECTION 0: Environment Setup")
print("="*80)

# Mount Google Drive
# (the google.colab module is imported here, right before its only use)
from google.colab import drive
drive.mount('/content/drive')

# Create experiment directory
# EXPERIMENT_DIR points inside the mounted Drive, so it is created after the mount.
import os
EXPERIMENT_DIR = "/content/drive/MyDrive/NLP"
os.makedirs(EXPERIMENT_DIR, exist_ok=True)
print(f"\nExperiment directory: {EXPERIMENT_DIR}")

# Install dependencies
# NOTE(review): package versions are not pinned here.
print("\nInstalling dependencies...")
!pip install -q transformers datasets torch scikit-learn pandas rank-bm25 sentencepiece

# Check GPU
# `device` is set to CUDA when torch reports a GPU, otherwise to CPU.
import torch
print(f"\nGPU Available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU Name: {torch.cuda.get_device_name(0)}")
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
    print("WARNING: Running on CPU - training will be slow!")

SECTION 0: Environment Setup
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).

Experiment directory: /content/drive/MyDrive/NLP

Installing dependencies...

GPU Available: True
GPU Name: Tesla T4


In [None]:

# ===== 1. DOWNLOAD DATA =====
# Clones the QuanTemp and FactIR repositories into /content (skipping clones that
# already exist), fetches the QuanTemp data folder with gdown when the BM25 evidence
# file is absent, and finally verifies that every required file is present.
print("\n" + "="*80)
print("SECTION 1: Downloading Datasets")
print("="*80)

import os
os.chdir("/content")

# Clone QuanTemp repo
if not os.path.exists("/content/QuanTemp"):
    print("\nCloning QuanTemp repository...")
    !git clone https://github.com/factiverse/QuanTemp.git
    print("‚úì QuanTemp cloned")
else:
    print("‚úì QuanTemp already exists")

# Clone FactIR repo
if not os.path.exists("/content/factIR"):
    print("\nCloning FactIR repository...")
    !git clone https://github.com/factiverse/factIR.git
    print("‚úì FactIR cloned")
else:
    print("‚úì FactIR already exists")

# Download QuanTemp data if needed
# The working directory changes here, so the relative "data/..." paths below
# resolve inside /content/QuanTemp.
os.chdir("/content/QuanTemp")
if not os.path.exists("data/bm25_scored_evidence/bm25_top_100_claimdecomp.json"):
    print("\nDownloading QuanTemp pre-processed data...")
    !pip install -q gdown
    !gdown --folder 1GYzSK0oU2MiaKbyBO3hE8kO4gdmxDjCv -O /content/QuanTemp/data --remaining-ok

    # Unzip BM25 evidence if zipped
    if os.path.exists("data/bm25_scored_evidence/bm25_top_100_claimdecomp.json.zip"):
        !unzip -q data/bm25_scored_evidence/bm25_top_100_claimdecomp.json.zip -d data/bm25_scored_evidence/
    print("‚úì QuanTemp data downloaded")
else:
    print("‚úì QuanTemp data already exists")

# Verify files
# Absolute paths are used so the check does not depend on the current directory.
print("\nVerifying required files...")
required_files = [
    "/content/QuanTemp/data/raw_data/train_claims_quantemp.json",
    "/content/QuanTemp/data/raw_data/val_claims_quantemp.json",
    "/content/QuanTemp/data/raw_data/test_claims_quantemp.json",
    "/content/QuanTemp/data/bm25_scored_evidence/bm25_top_100_claimdecomp.json",
    "/content/factIR/evidence.csv"
]

# Print a status line per file and remember whether any of them is missing.
all_exist = True
for f in required_files:
    exists = os.path.exists(f)
    status = "‚úì" if exists else "‚úó MISSING"
    print(f"{status} {f}")
    if not exists:
        all_exist = False

# Stop the notebook early with FileNotFoundError when something is missing.
if not all_exist:
    print("\n‚ö†Ô∏è  Some files are missing. Please ensure:")
    print("   1. QuanTemp data is downloaded")
    print("   2. FactIR evidence.csv exists")
    print("   3. You've uploaded the QuanTemp evidence corpus JSON if needed")
    raise FileNotFoundError("Required data files missing")

print("\n‚úì All required files present")


SECTION 1: Downloading Datasets
‚úì QuanTemp already exists
‚úì FactIR already exists
‚úì QuanTemp data already exists

Verifying required files...
‚úì /content/QuanTemp/data/raw_data/train_claims_quantemp.json
‚úì /content/QuanTemp/data/raw_data/val_claims_quantemp.json
‚úì /content/QuanTemp/data/raw_data/test_claims_quantemp.json
‚úì /content/QuanTemp/data/bm25_scored_evidence/bm25_top_100_claimdecomp.json
‚úì /content/factIR/evidence.csv

‚úì All required files present


In [None]:

# ===== 2. LOAD FREE LLM FOR DECOMPOSITION =====
# Loads the seq2seq model and tokenizer that decompose_with_llm uses, moves the
# model to `device`, and switches it to eval mode.
print("\n" + "="*80)
print("SECTION 2: Loading Free LLM for Decomposition")
print("="*80)

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch

# Use FLAN-T5-base (free, local model)
LLM_MODEL_NAME = "google/flan-t5-base"
print(f"\nLoading {LLM_MODEL_NAME}...")

# `device` must already exist in the kernel; this cell does not define it.
llm_tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_NAME)
llm_model = AutoModelForSeq2SeqLM.from_pretrained(LLM_MODEL_NAME)
llm_model = llm_model.to(device)
llm_model.eval()

print(f"‚úì LLM loaded on {device}")



SECTION 2: Loading Free LLM for Decomposition

Loading google/flan-t5-base...
‚úì LLM loaded on cuda


In [None]:
 # ===== 3. DECOMPOSITION METHODS =====
# Section banner, followed by the imports used by the helpers defined in this cell.
print("\n" + "="*80)
print("SECTION 3: Defining Decomposition Methods")
print("="*80)

import re
from typing import List

def clean_and_deduplicate(subqueries: List[str], min_tokens: int = 4) -> List[str]:
    """Clean, filter, and deduplicate candidate subqueries.

    Each candidate is stripped of surrounding whitespace and of a leading list
    marker such as ``1.`` or ``2)``. Candidates that are empty, or that have
    fewer than ``min_tokens`` whitespace-separated tokens, are dropped.
    Duplicates are removed case-insensitively, keeping the first occurrence
    and the original order.

    Args:
        subqueries: Raw candidate strings.
        min_tokens: Minimum number of whitespace-separated tokens to keep a candidate.

    Returns:
        The cleaned candidates, in first-seen order.
    """
    cleaned = []
    seen = set()

    for sq in subqueries:
        sq = sq.strip()
        if len(sq) == 0:
            continue

        # Remove numbering like "1.", "2)", etc. The (?!\d) guard keeps a
        # leading decimal intact: without it "3.5 million ..." would lose its
        # integer part and silently become "5 million ...".
        sq = re.sub(r'^\d+[\.\)](?!\d)\s*', '', sq)
        sq = sq.strip()

        # Check minimum tokens
        if len(sq.split()) < min_tokens:
            continue

        # Deduplicate (case-insensitive)
        sq_lower = sq.lower()
        if sq_lower not in seen:
            seen.add(sq_lower)
            cleaned.append(sq)

    return cleaned

def decompose_with_llm(claim: str, prompt_template: str, max_subqueries: int = 4) -> List[str]:
    """Generate subqueries for a claim with the notebook-level seq2seq LLM.

    The prompt is built with ``prompt_template.format(claim=claim)``, truncated
    to 256 tokens, and passed to ``llm_model.generate``. The decoded response
    is split into candidates, cleaned with ``clean_and_deduplicate``, and
    capped at ``max_subqueries``.

    Relies on the globals ``llm_tokenizer``, ``llm_model`` and ``device``.

    Args:
        claim: The claim text to decompose.
        prompt_template: Template containing a ``{claim}`` placeholder.
        max_subqueries: Maximum number of subqueries to return.

    Returns:
        At most ``max_subqueries`` cleaned subqueries (possibly an empty list).
    """
    prompt = prompt_template.format(claim=claim)

    inputs = llm_tokenizer(
        prompt,
        return_tensors="pt",
        max_length=256,
        truncation=True
    ).to(device)

    # NOTE(review): sampling is enabled and no seed is set in this function, so
    # repeated calls can return different subqueries for the same claim.
    with torch.no_grad():
        outputs = llm_model.generate(
            **inputs,
            max_length=200,
            num_return_sequences=1,
            temperature=0.7,
            do_sample=True
        )

    response = llm_tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Split by newlines or numbered items ("1.", "2)").
    # The lookbehind stops a split in the middle of a number ("23." must not
    # split before "3."), and the (?!\d) lookahead stops a split in front of a
    # decimal ("3.5"), so numeric values in the response stay intact.
    subqueries = re.split(r'\n|(?<![\d.])(?=\d+[\.\)](?!\d))', response)
    subqueries = clean_and_deduplicate(subqueries)

    return subqueries[:max_subqueries]

# Define 3 decomposition strategies
# DECOMPOSITION_METHODS = {
#     "baseline": {
#         "name": "Baseline (No Decomposition)",
#         "func": lambda claim: [claim]  # Just return original claim
#     },

#     "llm_numeric": {
#         "name": "LLM: Numeric & Temporal Focus",
#         "func": lambda claim: decompose_with_llm(
#             claim,
#             "Split this fact-checking claim into 2-4 atomic subclaims. Focus on extracting specific numbers, dates, and temporal constraints separately. Keep each subclaim concise.\n\nClaim: {claim}\n\nSubclaims:"
#         )
#     },

#     "llm_constraints": {
#         "name": "LLM: Constraint Extraction",
#         "func": lambda claim: decompose_with_llm(
#             claim,
#             "Extract 2-4 key factual constraints from this claim as separate queries. Separate numeric constraints from categorical/entity constraints.\n\nClaim: {claim}\n\nConstraints:"
#         )
#     },

#     "llm_keywords": {
#         "name": "LLM: Keyword Queries",
#         "func": lambda claim: decompose_with_llm(
#             claim,
#             "Rewrite this claim into 2-4 short keyword-based search queries to find relevant evidence. Focus on core facts that can be verified.\n\nClaim: {claim}\n\nQueries:"
#         )
#     }
# }

# Sample claim used to preview what each decomposition strategy produces.
test_claim = "\"When you throw 23 million people off of health insurance -- people with cancer, people with heart disease, people with diabetes -- thousands of people will die. \u2026 This is study after study making this point."

# DECOMPOSITION_METHODS is not assigned anywhere in this cell. On a fresh kernel
# (Restart & Run All) the name does not exist yet, so guard the preview rather
# than letting the whole run stop on a NameError.
if "DECOMPOSITION_METHODS" not in globals():
    print("DECOMPOSITION_METHODS is not defined yet - run the cell that defines it, then re-run this cell to preview the strategies.")
else:
    print("Decomposition methods defined:")
    for key, method in DECOMPOSITION_METHODS.items():
        print(f"  ‚Ä¢ {key}: {method['name']}")

    # Test decomposition
    print("\nTesting decomposition on sample claim...")

    for key, method in DECOMPOSITION_METHODS.items():
        result = method["func"](test_claim)
        print(f"\n{method['name']}:")
        for i, sq in enumerate(result, 1):
            print(f"  {i}. {sq}")



SECTION 3: Defining Decomposition Methods
Decomposition methods defined:
  ‚Ä¢ baseline: Baseline (No Decomposition)
  ‚Ä¢ llm_numeric: LLM v2: Numeric & Temporal (Anchored)
  ‚Ä¢ llm_constraints: LLM v2: Entity + Constraint Queries
  ‚Ä¢ llm_keywords: LLM v2: High-Precision Keyword Queries

Testing decomposition on sample claim...

Baseline (No Decomposition):
  1. "When you throw 23 million people off of health insurance -- people with cancer, people with heart disease, people with diabetes -- thousands of people will die. ‚Ä¶ This is study after study making this point.

LLM v2: Numeric & Temporal (Anchored):

LLM v2: Entity + Constraint Queries:
  1. WHat is the number of people with cancer?

LLM v2: High-Precision Keyword Queries:
  1. BM25: When you throw 23 million people off of health insurance -- people with cancer, people with heart disease, people with diabetes -- thousands of people will die. ...


In [None]:
# ===== Improved Decomposition Strategies (BM25-Anchored) =====
# Prompt templates for the LLM-backed strategies. The `{claim}` placeholder is
# filled in by decompose_with_llm via str.format.
_NUMERIC_TEMPORAL_PROMPT = """
You are generating BM25 search queries. DO NOT paraphrase the claim.

Rules:
- Reuse exact words and phrases from the claim (copy substrings).
- Keep all numbers, dates, percentages, and currencies EXACTLY as written.
- Each query must contain at least one number or temporal expression.
- Output 2‚Äì4 queries, one per line.
- Each query must be short (6‚Äì12 tokens).
- Avoid generic terms like "report", "study", "data".

Claim: {claim}

Queries:
""".strip()

_ENTITY_CONSTRAINT_PROMPT = """
Generate 2‚Äì4 precise BM25 search queries.

Rules:
- Each query must include:
  (1) a named entity copied exactly from the claim AND
  (2) a numeric or temporal constraint copied exactly from the claim.
- Do NOT paraphrase entities or numbers.
- Keep queries concise (6‚Äì12 tokens).
- If no explicit number exists, use exact quantifier phrases ("most", "all", "only").

Claim: {claim}

Queries:
""".strip()

_KEYWORD_PROMPT = """
Rewrite the claim into 2‚Äì4 HIGH-PRECISION keyword queries for BM25.

Rules:
- Copy all named entities exactly as they appear.
- Copy all numbers, dates, and units exactly.
- Add at most ONE extra descriptive token if necessary.
- Each query must be <= 10 tokens.
- Prefer rare and specific words over generic ones.

Claim: {claim}

Queries:
""".strip()


def _llm_strategy(display_name, template):
    """Build a strategy entry whose `func` decomposes a claim using `template`."""
    return {
        "name": display_name,
        "func": lambda claim: decompose_with_llm(claim, template),
    }


# Registry of decomposition strategies: key -> {"name": label, "func": claim -> list of queries}.
DECOMPOSITION_METHODS = {
    # The baseline returns the claim unchanged as its single query.
    "baseline": {
        "name": "Baseline (No Decomposition)",
        "func": lambda claim: [claim],
    },
    "llm_numeric": _llm_strategy(
        "LLM v2: Numeric & Temporal (Anchored)", _NUMERIC_TEMPORAL_PROMPT
    ),
    "llm_constraints": _llm_strategy(
        "LLM v2: Entity + Constraint Queries", _ENTITY_CONSTRAINT_PROMPT
    ),
    "llm_keywords": _llm_strategy(
        "LLM v2: High-Precision Keyword Queries", _KEYWORD_PROMPT
    ),
}


In [None]:

# ===== 4. FACTIR EVALUATION =====
# Section banner, the imports used by the FactIR evaluation, and the load of the
# FactIR evidence table into `factir_df`.
print("\n" + "="*80)
print("SECTION 4: FactIR Ranking Evaluation")
print("="*80)

import pandas as pd
import numpy as np
from rank_bm25 import BM25Okapi
from tqdm import tqdm
import pickle

# Load FactIR data
# The code below reads the 'claim', 'snippet' and 'relevance' columns of this table.
print("\nLoading FactIR evidence.csv...")
factir_df = pd.read_csv("/content/factIR/evidence.csv")
print(f"‚úì Loaded {len(factir_df)} rows with {factir_df['claim'].nunique()} unique claims")

# Simple tokenizer
def tokenize(text):
    """Lowercase ``text`` and split it into tokens.

    The input is converted with ``str()``, lowercased, and every character that
    is not ``a-z``, ``0-9`` or whitespace is replaced by a space before the
    result is split on whitespace.
    """
    normalized = re.sub(r"[^a-z0-9\s]", " ", str(text).lower())
    # str.split() with no argument never yields empty strings.
    return normalized.split()

# Ranking metrics
def recall_at_k(rels_sorted, k):
    """Return 1.0 if any of the first ``k`` ranked relevance values is truthy, else 0.0."""
    for relevance in rels_sorted[:k]:
        if relevance:
            return 1.0
    return 0.0

def mrr(rels_sorted):
    """Return the reciprocal rank (1-based) of the first value equal to 1, or 0.0 if none."""
    return next(
        (1.0 / rank for rank, relevance in enumerate(rels_sorted, start=1) if relevance == 1),
        0.0,
    )

# Cache for decompositions
FACTIR_DECOMP_CACHE = os.path.join(EXPERIMENT_DIR, "factir_decompositions.pkl")

# Start from the cache on disk when there is one, otherwise from an empty dict.
# Layout: decomp_cache[method_key][claim] -> list of subqueries.
# NOTE(security): pickle.load can execute arbitrary code; only load cache files
# that this notebook wrote itself.
if os.path.exists(FACTIR_DECOMP_CACHE):
    print(f"\nLoading cached decompositions from {FACTIR_DECOMP_CACHE}")
    with open(FACTIR_DECOMP_CACHE, "rb") as f:
        decomp_cache = pickle.load(f)
else:
    print("\nGenerating decompositions for FactIR claims (this may take time)...")
    decomp_cache = {}

unique_claims = factir_df['claim'].unique()

# Work out which (method, claim) pairs the cache does not cover. A cache written
# with a different set of methods or claims would otherwise cause a KeyError
# when the evaluation looks up decomp_cache[method_key][claim].
# NOTE(review): entries are keyed by method key and claim only, so a cache built
# with older prompt templates is still reused as-is; delete the file to rebuild.
pending = {}
for method_key in DECOMPOSITION_METHODS:
    cached_for_method = decomp_cache.setdefault(method_key, {})
    missing_claims = [claim for claim in unique_claims if claim not in cached_for_method]
    if missing_claims:
        pending[method_key] = missing_claims

if pending:
    for method_key, missing_claims in tqdm(pending.items(), desc="Methods"):
        method_func = DECOMPOSITION_METHODS[method_key]["func"]

        for claim in tqdm(missing_claims, desc=f"  {method_key}", leave=False):
            decomp_cache[method_key][claim] = method_func(claim)

    # Save cache (only rewritten when something new was generated)
    with open(FACTIR_DECOMP_CACHE, "wb") as f:
        pickle.dump(decomp_cache, f)
    print(f"‚úì Saved decomposition cache to {FACTIR_DECOMP_CACHE}")

# Evaluate each method
# For every strategy, rank each claim's candidate snippets with BM25 and record
# hit@5, hit@10 and MRR per claim, then average them per strategy.
print("\nEvaluating decomposition methods on FactIR...")

factir_results = []
skipped_claims = 0
for method_key, method_info in DECOMPOSITION_METHODS.items():
    print(f"\nEvaluating: {method_info['name']}")

    is_baseline = method_key == "baseline"
    method_metrics = []

    for claim, group in tqdm(factir_df.groupby("claim"), desc="  Claims", leave=False):

        snippets = group["snippet"].fillna("").tolist()
        rels = group["relevance"].astype(int).tolist()

        # Cached subqueries for this (method, claim) pair
        subqueries = decomp_cache[method_key][claim]

        # One BM25 index per claim, built over that claim's candidate snippets;
        # the full-claim score is needed by the baseline and by the mixture.
        bm25 = BM25Okapi([tokenize(snippet) for snippet in snippets])
        claim_scores = bm25.get_scores(tokenize(claim))

        if is_baseline:
            # Baseline: the claim itself is the only query
            scores = claim_scores
        else:
            # Decomposition score: mean over all subqueries (zeros when there are none)
            if len(subqueries) == 0:
                decomp_scores = np.zeros(len(snippets))
            else:
                per_query_scores = [bm25.get_scores(tokenize(q)) for q in subqueries]
                decomp_scores = np.mean(per_query_scores, axis=0)

            # Option B: mixture of baseline + decomposition
            alpha = 0.6
            scores = alpha * claim_scores + (1 - alpha) * decomp_scores

        # Relevance labels reordered by descending score
        ranking = np.argsort(scores)[::-1]
        rels_sorted = [rels[position] for position in ranking]

        method_metrics.append({
            "claim": claim,
            "method": method_key,
            "hit@5": recall_at_k(rels_sorted, 5),
            "hit@10": recall_at_k(rels_sorted, 10),
            "mrr": mrr(rels_sorted),
            "num_subqueries": len(subqueries),
            "num_candidates": len(snippets),
            "num_relevant": int(sum(rels))
        })

    # Compute averages over all claims for this method
    metrics_df = pd.DataFrame(method_metrics)
    avg_hit5, avg_hit10, avg_mrr = (
        metrics_df[column].mean() for column in ("hit@5", "hit@10", "mrr")
    )

    print(f"  Hit@5:  {avg_hit5:.4f}")
    print(f"  Hit@10: {avg_hit10:.4f}")
    print(f"  MRR:    {avg_mrr:.4f}")

    factir_results.append({
        "method": method_key,
        "method_name": method_info["name"],
        "hit@5": avg_hit5,
        "hit@10": avg_hit10,
        "mrr": avg_mrr,
        "per_claim_results": method_metrics
    })

# Save FactIR results
# The summary keeps one row per method and leaves out the per-claim detail.
summary_columns = ("method", "method_name", "hit@5", "hit@10", "mrr")
factir_summary_df = pd.DataFrame(
    [{column: entry[column] for column in summary_columns} for entry in factir_results]
)

factir_summary_path = os.path.join(EXPERIMENT_DIR, "factir_ranking_results.csv")
factir_summary_df.to_csv(factir_summary_path, index=False)
print(f"\n‚úì Saved FactIR summary to {factir_summary_path}")

# Select best decomposition method: the summary row with the highest MRR
best_row_label = factir_summary_df["mrr"].idxmax()
best_method = factir_summary_df.loc[best_row_label]
BEST_DECOMP_KEY = best_method["method"]
print(f"\nüèÜ BEST DECOMPOSITION METHOD: {best_method['method_name']}")
print(f"   MRR: {best_method['mrr']:.4f}")

# Save detailed per-claim results, one CSV per method
for result in factir_results:
    method_key = result["method"]
    per_claim_path = os.path.join(EXPERIMENT_DIR, f"factir_{method_key}_perclaim.csv")
    per_claim_df = pd.DataFrame(result["per_claim_results"])
    per_claim_df.to_csv(per_claim_path, index=False)
    print(f"   Saved {method_key} per-claim results to {per_claim_path}")



SECTION 4: FactIR Ranking Evaluation

Loading FactIR evidence.csv...
‚úì Loaded 1412 rows with 100 unique claims

Loading cached decompositions from /content/drive/MyDrive/NLP/factir_decompositions.pkl

Evaluating decomposition methods on FactIR...

Evaluating: Baseline (No Decomposition)




  Hit@5:  0.8100
  Hit@10: 0.8300
  MRR:    0.6739

Evaluating: LLM v2: Numeric & Temporal (Anchored)




  Hit@5:  0.8000
  Hit@10: 0.8400
  MRR:    0.6772

Evaluating: LLM v2: Entity + Constraint Queries




  Hit@5:  0.8200
  Hit@10: 0.8400
  MRR:    0.6523

Evaluating: LLM v2: High-Precision Keyword Queries




  Hit@5:  0.8100
  Hit@10: 0.8300
  MRR:    0.6710

‚úì Saved FactIR summary to /content/drive/MyDrive/NLP/factir_ranking_results.csv

üèÜ BEST DECOMPOSITION METHOD: LLM v2: Numeric & Temporal (Anchored)
   MRR: 0.6772
   Saved baseline per-claim results to /content/drive/MyDrive/NLP/factir_baseline_perclaim.csv
   Saved llm_numeric per-claim results to /content/drive/MyDrive/NLP/factir_llm_numeric_perclaim.csv
   Saved llm_constraints per-claim results to /content/drive/MyDrive/NLP/factir_llm_constraints_perclaim.csv
   Saved llm_keywords per-claim results to /content/drive/MyDrive/NLP/factir_llm_keywords_perclaim.csv


In [None]:
# ===== 5. QUANTEMP RETRIEVAL (LOAD ONLY) =====
# Loads the QuanTemp claim splits, the repository's BM25 evidence file, and the
# previously pickled retrieval results. Nothing is recomputed in this cell.
print("\n" + "="*80)
print("SECTION 5: QuanTemp Retrieval (Load-only, no recompute)")
print("="*80)

import os, json, pickle

# Load QuanTemp claims
print("\nLoading QuanTemp claims...")
with open("/content/QuanTemp/data/raw_data/train_claims_quantemp.json") as f:
    train_claims = json.load(f)
with open("/content/QuanTemp/data/raw_data/val_claims_quantemp.json") as f:
    val_claims = json.load(f)
with open("/content/QuanTemp/data/raw_data/test_claims_quantemp.json") as f:
    test_claims = json.load(f)

print(f"‚úì Train: {len(train_claims)} claims")
print(f"‚úì Val:   {len(val_claims)} claims")
print(f"‚úì Test:  {len(test_claims)} claims")

# Load repo evidence dict (optional; keep if later you want R2 comparisons)
# The resulting dict maps each item's "query_id" to its "docs".
print("\nLoading repo bm25_top_100_claimdecomp.json (optional R2)...")
with open("/content/QuanTemp/data/bm25_scored_evidence/bm25_top_100_claimdecomp.json") as f:
    repo_bm25_data = json.load(f)
repo_evidence_dict = {item["query_id"]: item["docs"] for item in repo_bm25_data}
print(f"‚úì Repo evidence entries: {len(repo_evidence_dict)} query_ids")

# Load your already computed retrieval results (baseline + decomposed)
# There is no existence check: a missing cache file raises from open().
# NOTE(security): pickle.load can execute arbitrary code; only load files you created.
RETRIEVAL_CACHE = os.path.join(EXPERIMENT_DIR, "retrieval_results.pkl")
print(f"\nLoading retrieval cache: {RETRIEVAL_CACHE}")
with open(RETRIEVAL_CACHE, "rb") as f:
    retrieval_results = pickle.load(f)

# Summarise what was loaded: top-level keys and the size of each split.
print("‚úì Loaded retrieval_results keys:", list(retrieval_results.keys()))
print("‚úì baseline sizes:", {k: len(v) for k, v in retrieval_results["baseline"].items()})
print("‚úì decomposed sizes:", {k: len(v) for k, v in retrieval_results["decomposed"].items()})

# Quick evidence coverage stats (no filtering)
def coverage(d, n):
    """Count the indices in ``range(n)`` that map to a non-empty entry in ``d``.

    Indices missing from ``d`` count as empty. Returns ``(covered, n)``.
    """
    covered = 0
    for index in range(n):
        if len(d.get(index, [])) > 0:
            covered += 1
    return covered, n

# Report, per split, how many claim indices have at least one retrieved evidence item.
for split_name, claims in zip(("train", "val", "test"), (train_claims, val_claims, test_claims)):
    n = len(claims)
    c0 = coverage(retrieval_results["baseline"][split_name], n)[0]
    c1 = coverage(retrieval_results["decomposed"][split_name], n)[0]
    print(f"\nCoverage {split_name.upper()}:")
    print(f"  R0 baseline:   {c0}/{n}")
    print(f"  R1 decomposed: {c1}/{n}")



SECTION 5: QuanTemp Retrieval (Load-only, no recompute)

Loading QuanTemp claims...
‚úì Train: 9935 claims
‚úì Val:   3084 claims
‚úì Test:  2495 claims

Loading repo bm25_top_100_claimdecomp.json (optional R2)...
‚úì Repo evidence entries: 2495 query_ids

Loading retrieval cache: /content/drive/MyDrive/NLP/retrieval_results.pkl
‚úì Loaded retrieval_results keys: ['baseline', 'decomposed', 'repo']
‚úì baseline sizes: {'train': 9935, 'val': 3084, 'test': 2495}
‚úì decomposed sizes: {'train': 9935, 'val': 3084, 'test': 2495}

Coverage TRAIN:
  R0 baseline:   9935/9935
  R1 decomposed: 9726/9935

Coverage VAL:
  R0 baseline:   3084/3084
  R1 decomposed: 3009/3084

Coverage TEST:
  R0 baseline:   2495/2495
  R1 decomposed: 2445/2495


In [None]:
# ===== 6. BUILD TRAINING PAIRS (DECOMPOSED ONLY, TOP-3, NO FILTER) =====
# Section banner, followed by the progress-bar import used when building pairs.
print("\n" + "="*80)
print("SECTION 6: Building Training Pairs (R1 decomposed only, top-3, no filtering)")
print("="*80)

from tqdm import tqdm

def normalize_label(label):
    """Map a textual veracity label to a class id.

    Matching is case-insensitive and substring-based:
      * 0 (supports): contains "support", "true" or "correct"
      * 1 (refutes):  contains "refute", "false" or "pants", or is a negated
        form ("incorrect", "untrue", "not true")
      * 2 (other):    anything else

    Args:
        label: The label string.

    Returns:
        0, 1 or 2.
    """
    label_lower = label.lower()

    # Negated forms contain a "supports" keyword as a substring ("incorrect"
    # contains "correct", "untrue" contains "true"), so they must be checked
    # first or they would be classified as supporting.
    if any(negated in label_lower for negated in ("incorrect", "untrue", "not true")):
        return 1

    if "support" in label_lower or "true" in label_lower or "correct" in label_lower:
        return 0
    elif "refute" in label_lower or "false" in label_lower or "pants" in label_lower:
        return 1
    else:
        return 2

def create_training_examples_all(claims_list, evidence_dict, top_k=3, max_evidence_chars=512):
    """Pair every claim with its top retrieved evidence snippets.

    The evidence for a claim is looked up in ``evidence_dict`` by the claim's
    position in ``claims_list``. Only the first ``top_k`` snippets are used;
    each is stripped, skipped when shorter than 20 characters, and cut to
    ``max_evidence_chars`` characters. Claims without evidence add nothing.

    Returns:
        A list of dicts with the keys "claim", "evidence" and "label".
    """
    pairs = []
    for position, claim_obj in enumerate(tqdm(claims_list, desc="Creating pairs")):
        claim_text = claim_obj["claim"]
        label = normalize_label(claim_obj["label"])

        retrieved = evidence_dict.get(position, [])
        if retrieved:
            stripped_snippets = (snippet.strip() for snippet in retrieved[:top_k])
            pairs.extend(
                {
                    "claim": claim_text,
                    "evidence": text[:max_evidence_chars],
                    "label": label
                }
                for text in stripped_snippets
                if len(text) >= 20
            )
    return pairs

# Build the train and validation pair sets from the decomposed (R1) retrieval results.
print("\nBuilding TrainSet-B (R1 decomposed retrieval)...")
decomposed_retrieval = retrieval_results["decomposed"]
trainset_b = create_training_examples_all(train_claims, decomposed_retrieval["train"], top_k=3)
valset_b = create_training_examples_all(val_claims, decomposed_retrieval["val"], top_k=3)

print(f"‚úì TrainSet-B: {len(trainset_b)} examples")
print(f"‚úì ValSet-B:   {len(valset_b)} examples")

# (Optional) show label distribution
from collections import Counter
def print_label_dist(examples, name):
    """Print the count and percentage of labels 0, 1 and 2 among ``examples``.

    Each example is read through its "label" key. For an empty input the
    percentages are printed as 0.0 (the denominator falls back to 1).
    """
    dist = Counter(example["label"] for example in examples)
    total = sum(dist.values()) or 1
    print(f"\n{name} Label Distribution:")
    print(f"  SUPPORTS (0): {dist[0]:5d} ({dist[0]/total*100:5.1f}%)")
    print(f"  REFUTES  (1): {dist[1]:5d} ({dist[1]/total*100:5.1f}%)")
    print(f"  NEI      (2): {dist[2]:5d} ({dist[2]/total*100:5.1f}%)")

# Show the label balance of both pair sets.
print_label_dist(trainset_b, "TrainSet-B (R1)")
print_label_dist(valset_b,   "ValSet-B (R1)")



SECTION 6: Building Training Pairs (R1 decomposed only, top-3, no filtering)

Building TrainSet-B (R1 decomposed retrieval)...


Creating pairs: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 9935/9935 [00:00<00:00, 46222.03it/s]
Creating pairs: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 3084/3084 [00:00<00:00, 201687.59it/s]

‚úì TrainSet-B: 29178 examples
‚úì ValSet-B:   9027 examples

TrainSet-B (R1) Label Distribution:
  SUPPORTS (0):  5379 ( 18.4%)
  REFUTES  (1): 16917 ( 58.0%)
  NEI      (2):  6882 ( 23.6%)

ValSet-B (R1) Label Distribution:
  SUPPORTS (0):  1803 ( 20.0%)
  REFUTES  (1):  5238 ( 58.0%)
  NEI      (2):  1986 ( 22.0%)





In [None]:
# ===== 7. FINE-TUNE TWO VERIFIERS (A=R0, B=R1) =====
# Section banner, the imports used for fine-tuning, and the base checkpoint name.
print("\n" + "="*80)
print("SECTION 7: Fine-tuning Verifiers (A=baseline, B=decomposed)")
print("="*80)

import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers import TrainingArguments, Trainer
from datasets import Dataset
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import f1_score, accuracy_score
from torch import nn

# Pretrained checkpoint that train_one_verifier loads the tokenizer and classifier from.
VERIFIER_MODEL = "roberta-base"

def compute_metrics(eval_pred):
    """Compute accuracy, macro F1 and per-class F1 from logits and labels.

    Args:
        eval_pred: A ``(logits, labels)`` pair; predictions are the argmax of
            the logits over the last axis.

    Returns:
        A dict with the keys "accuracy", "macro_f1", "supports_f1",
        "refutes_f1" and "nei_f1".
    """
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)

    acc = accuracy_score(labels, preds)
    macro_f1 = f1_score(labels, preds, average="macro", zero_division=0)
    # Pin the label set so the result always has three entries in the order
    # [supports, refutes, nei]. Without `labels=`, a class absent from both the
    # labels and the predictions is left out of the array, which shifts the
    # remaining scores to the wrong name or raises IndexError below.
    per_class = f1_score(labels, preds, labels=[0, 1, 2], average=None, zero_division=0)

    return {
        "accuracy": acc,
        "macro_f1": macro_f1,
        "supports_f1": per_class[0],
        "refutes_f1": per_class[1],
        "nei_f1": per_class[2],
    }

class WeightedTrainer(Trainer):
    """Trainer whose loss is cross-entropy with optional per-class weights.

    Args:
        class_weights: Optional sequence with one weight per class. When it is
            ``None`` (the default) the loss is plain, unweighted cross-entropy.
        *args, **kwargs: Passed through to ``Trainer``.
    """

    def __init__(self, *args, class_weights=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.class_weights = class_weights

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Return the (optionally class-weighted) cross-entropy loss for one batch.

        "labels" is removed from ``inputs`` before the forward pass.
        ``num_items_in_batch`` is accepted but not used.
        """
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.logits

        # The constructor default is None; building a tensor from None would
        # raise, so fall back to the unweighted loss in that case.
        weight = None
        if self.class_weights is not None:
            weight = torch.as_tensor(
                self.class_weights, dtype=torch.float32, device=logits.device
            )

        loss = nn.CrossEntropyLoss(weight=weight)(logits, labels)
        return (loss, outputs) if return_outputs else loss

def tokenize_function(batch, tokenizer):
    """Encode a batch as (evidence, claim) pairs.

    ``batch["evidence"]`` is passed as the first argument and ``batch["claim"]``
    as the second, with truncation on and padding to a fixed length of 256.
    """
    encode_options = {
        "truncation": True,
        "padding": "max_length",
        "max_length": 256,
    }
    return tokenizer(batch["evidence"], batch["claim"], **encode_options)

def build_hf_dataset(examples, tokenizer):
    """Turn example dicts into a tokenized ``Dataset``.

    The "evidence", "claim" and "label" values are gathered into columns, the
    text columns are tokenized in batches with ``tokenize_function``, and the
    raw "evidence" and "claim" columns are then removed.
    """
    columns = {
        field: [example[field] for example in examples]
        for field in ("evidence", "claim", "label")
    }
    raw_ds = Dataset.from_dict(columns)

    def _encode(batch):
        return tokenize_function(batch, tokenizer)

    return raw_ds.map(_encode, batched=True, remove_columns=["evidence", "claim"])

import os
from transformers.trainer_utils import get_last_checkpoint

# File names whose presence marks a checkpoint folder as containing model weights.
# NOTE(review): this list mirrors the weight files the Trainer is expected to look
# for when resuming - confirm it against the installed transformers version.
_CHECKPOINT_WEIGHT_FILES = (
    "model.safetensors",
    "model.safetensors.index.json",
    "pytorch_model.bin",
    "pytorch_model.bin.index.json",
    "adapter_model.safetensors",
    "adapter_model.bin",
)


def _find_resumable_checkpoint(out_dir):
    """Return the newest complete ``checkpoint-<step>`` folder in ``out_dir``, or None.

    A folder counts as complete when it holds ``trainer_state.json`` and at
    least one known weights file. Incomplete folders are reported and skipped
    in favour of the next older checkpoint.
    """
    if not os.path.isdir(out_dir):
        return None

    candidates = []
    for entry in os.listdir(out_dir):
        prefix, _, step = entry.rpartition("-")
        path = os.path.join(out_dir, entry)
        if prefix == "checkpoint" and step.isdigit() and os.path.isdir(path):
            candidates.append((int(step), path))

    # Newest (highest step) first.
    for _, path in sorted(candidates, reverse=True):
        has_state = os.path.isfile(os.path.join(path, "trainer_state.json"))
        has_weights = any(
            os.path.isfile(os.path.join(path, name)) for name in _CHECKPOINT_WEIGHT_FILES
        )
        if has_state and has_weights:
            return path
        print(f"[!] Skipping incomplete checkpoint: {path}")
    return None


def train_one_verifier(tag, train_examples, val_examples, out_dir):
    """Fine-tune one verifier classifier and save it to ``out_dir``.

    Loads the ``VERIFIER_MODEL`` tokenizer and a 3-class sequence classifier,
    derives balanced class weights from ``train_examples``, and trains with
    ``WeightedTrainer``. Training resumes from the newest complete checkpoint
    in ``out_dir`` when there is one, and starts fresh otherwise.

    Args:
        tag: Label printed in the log output.
        train_examples: Dicts with "evidence", "claim" and "label" keys.
        val_examples: Same structure, used for evaluation during training.
        out_dir: Directory for checkpoints and for the final model and tokenizer.

    Returns:
        ``out_dir``.
    """
    print("\n" + "-"*80)
    print(f"{tag}")
    print(f"Output dir: {out_dir}")
    print("-"*80)

    os.makedirs(out_dir, exist_ok=True)

    tokenizer = AutoTokenizer.from_pretrained(VERIFIER_MODEL)
    model = AutoModelForSequenceClassification.from_pretrained(
        VERIFIER_MODEL,
        num_labels=3,
        problem_type="single_label_classification"
    )

    # class weights from THIS training set
    train_labels = [ex["label"] for ex in train_examples]
    class_weights = compute_class_weight(
        class_weight="balanced",
        classes=np.array([0, 1, 2]),
        y=train_labels
    )
    print(f"{tag} class_weights:", class_weights)

    train_ds = build_hf_dataset(train_examples, tokenizer)
    val_ds   = build_hf_dataset(val_examples, tokenizer)

    args = TrainingArguments(
        output_dir=out_dir,
        overwrite_output_dir=False,     # important for resume
        eval_strategy="steps",
        eval_steps=100,
        save_strategy="steps",
        save_steps=100,
        save_total_limit=3,             # keep last 3 checkpoints
        learning_rate=1e-5,
        per_device_train_batch_size=4,
        per_device_eval_batch_size=8,
        gradient_accumulation_steps=4,
        num_train_epochs=3,
        weight_decay=0.01,
        warmup_steps=200,
        logging_steps=50,
        load_best_model_at_end=True,
        metric_for_best_model="macro_f1",
        greater_is_better=True,
        fp16=torch.cuda.is_available(),  # mixed precision only when a GPU is present
        report_to="none"
    )

    trainer = WeightedTrainer(
        model=model,
        args=args,
        train_dataset=train_ds,
        eval_dataset=val_ds,
        processing_class=tokenizer,
        compute_metrics=compute_metrics,
        class_weights=class_weights
    )

    # Auto-resume, but only from a checkpoint folder that is actually complete.
    # Resuming from a partially written folder fails with
    # "ValueError: Can't find a valid checkpoint at ...".
    last_ckpt = _find_resumable_checkpoint(out_dir)
    if last_ckpt is not None:
        print(f"[*] Resuming from checkpoint: {last_ckpt}")
        trainer.train(resume_from_checkpoint=last_ckpt)
    else:
        print("[*] No checkpoint found. Starting fresh training.")
        trainer.train()

    model.save_pretrained(out_dir)
    tokenizer.save_pretrained(out_dir)
    print(f"‚úì Saved final model to {out_dir}")

    return out_dir


# verifier_a_path = train_one_verifier(
#     tag="Verifier-A (R0 baseline evidence)",
#     train_examples=trainset_a,
#     val_examples=valset_a,
#     out_dir=os.path.join(EXPERIMENT_DIR, "verifier_a_r0")
# )

# Train Verifier-B on the decomposed-retrieval pairs; the returned path is the
# directory that holds the saved model.
verifier_b_path = train_one_verifier(
    tag="Verifier-B (R1 decomposed evidence)",
    train_examples=trainset_b,
    val_examples=valset_b,
    out_dir=os.path.join(EXPERIMENT_DIR, "verifier_b_r1")
)



SECTION 7: Fine-tuning Verifiers (A=baseline, B=decomposed)

--------------------------------------------------------------------------------
Verifier-B (R1 decomposed evidence)
Output dir: /content/drive/MyDrive/NLP/verifier_b_r1
--------------------------------------------------------------------------------


Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Verifier-B (R1 decomposed evidence) class_weights: [1.80814278 0.57492463 1.41325196]


Map:   0%|          | 0/29178 [00:00<?, ? examples/s]

Map:   0%|          | 0/9027 [00:00<?, ? examples/s]

[*] Resuming from checkpoint: /content/drive/MyDrive/NLP/verifier_b_r1/checkpoint-300


ValueError: Can't find a valid checkpoint at /content/drive/MyDrive/NLP/verifier_b_r1/checkpoint-300

In [None]:
# ===== 8. DOWNSTREAM EVALUATION =====
# Loads the trained verifier, selects the evidence dict for each retrieval
# condition on the test split, and prints how many entries are non-empty.
print("\n" + "="*80)
print("SECTION 8: Downstream Evaluation on QuanTemp Test")
print("="*80)

from transformers import AutoTokenizer, AutoModelForSequenceClassification
from sklearn.metrics import f1_score, accuracy_score
import numpy as np
import torch
from tqdm import tqdm

# Load both verifiers
# Only verifier B is actually loaded; the verifier A line is commented out.
print("\nLoading trained verifiers...")

# verifier_a = AutoModelForSequenceClassification.from_pretrained(verifier_a_path).to(device).eval()
verifier_b = AutoModelForSequenceClassification.from_pretrained(verifier_b_path).to(device).eval()

# You can load either tokenizer path; safest is to load the same base model tokenizer
# NOTE(review): the tokenizer load below is commented out, so no `tokenizer`
# variable is created in this part of the cell - confirm it is defined before use.
# tokenizer = AutoTokenizer.from_pretrained(verifier_a_path)

print("‚úì Verifiers loaded")

# ---- Evidence dicts for each retrieval condition (TEST split) ----
evidence_r0 = retrieval_results["baseline"]["test"]    # YOUR baseline retrieval
evidence_r1 = retrieval_results["decomposed"]["test"]  # YOUR decomposition retrieval
evidence_r2 = repo_evidence_dict                       # REPO claim-decomp retrieval

# Each count is the number of dict values with at least one item, shown against
# the number of test claims.
print("\nEvidence coverage on TEST:")
print(f"  R0 baseline:   {sum(len(v)>0 for v in evidence_r0.values())}/{len(test_claims)}")
print(f"  R1 decomposed: {sum(len(v)>0 for v in evidence_r1.values())}/{len(test_claims)}")
print(f"  R2 repo:       {sum(len(v)>0 for v in evidence_r2.values())}/{len(test_claims)}")

# Evaluation function (top-5, average probs)
def evaluate_verifier(model, tokenizer, claims, evidence_dict, retrieval_name, top_n=5):
    """Score a verifier on `claims` using the evidence retrieved for each claim.

    For every claim, up to `top_n` evidence strings are paired with the claim
    text and classified one at a time; the softmax probabilities are averaged
    and the arg-max of the average is the prediction. Claims with no usable
    evidence (nothing retrieved, or every snippet shorter than 20 characters
    after stripping) are predicted as class 2 (NEI) with confidence 1.0.

    Args:
        model: sequence-classification model, called as `model(**inputs).logits`.
        tokenizer: tokenizer applied to (evidence, claim) text pairs.
        claims: sequence of dicts with "claim" and "label" entries.
        evidence_dict: maps a claim's position in `claims` to its evidence strings.
        retrieval_name: label used in the progress bar and stored under "retrieval".
        top_n: number of leading evidence strings considered per claim.

    Returns:
        dict with "retrieval", "accuracy", "macro_f1", the per-class scores
        "supports_f1" / "refutes_f1" / "nei_f1" (classes 0 / 1 / 2), and the
        per-claim lists "claims", "predictions", "true_labels", "confidences".
    """
    nei_class = 2
    class_ids = [0, 1, 2]  # fixed order: SUPPORTS, REFUTES, NEI
    min_evidence_chars = 20

    claim_texts, predictions, true_labels, confidences = [], [], [], []

    for idx, claim_obj in enumerate(tqdm(claims, desc=f"Evaluating {retrieval_name}")):
        claim_text = claim_obj["claim"]
        claim_texts.append(claim_text)
        true_labels.append(normalize_label(claim_obj["label"]))

        # `or []` keeps a missing or None entry on the "no evidence" path.
        evidences = evidence_dict.get(idx) or []

        all_probs = []
        for ev in evidences[:top_n]:
            ev = ev.strip()
            if len(ev) < min_evidence_chars:
                continue

            # Cheap character-level cut first; the tokenizer then truncates
            # the pair to max_length tokens.
            inputs = tokenizer(
                ev[:512],
                claim_text,
                truncation=True,
                max_length=256,
                return_tensors="pt"
            ).to(device)

            with torch.no_grad():
                logits = model(**inputs).logits
                all_probs.append(torch.softmax(logits, dim=1)[0].cpu().numpy())

        if not all_probs:
            predictions.append(nei_class)
            confidences.append(1.0)
            continue

        avg_probs = np.mean(all_probs, axis=0)
        predictions.append(int(np.argmax(avg_probs)))
        confidences.append(float(avg_probs.max()))

    macro_f1 = f1_score(true_labels, predictions, average="macro", zero_division=0)
    accuracy = accuracy_score(true_labels, predictions)
    # Pin the label order: without `labels=`, a class that appears in neither
    # the gold labels nor the predictions is dropped from the array and the
    # positional lookups below would be shifted or out of range.
    per_class_f1 = f1_score(
        true_labels, predictions, labels=class_ids, average=None, zero_division=0
    )

    return {
        "retrieval": retrieval_name,
        "accuracy": accuracy,
        "macro_f1": macro_f1,
        "supports_f1": per_class_f1[0],
        "refutes_f1": per_class_f1[1],
        "nei_f1": per_class_f1[2],
        # Claim texts are returned so per-claim prediction tables can be
        # written without re-deriving them from the input.
        "claims": claim_texts,
        "predictions": predictions,
        "true_labels": true_labels,
        "confidences": confidences,
    }

# One entry per (verifier, retrieval condition) evaluation, filled by run_combo().
results_matrix = []

def run_combo(model, model_name, evidence_dict, retrieval_name):
    """Evaluate one verifier/retrieval pairing on the test claims and record it.

    Calls evaluate_verifier() with the module-level `tokenizer` and
    `test_claims` (top_n=5), tags the returned dict with the verifier name,
    appends it to `results_matrix`, prints a one-line summary and returns it.
    """
    outcome = evaluate_verifier(
        model,
        tokenizer,
        test_claims,
        evidence_dict,
        retrieval_name=retrieval_name,
        top_n=5,
    )
    outcome.update(verifier=model_name)
    results_matrix.append(outcome)

    macro_f1, accuracy = outcome['macro_f1'], outcome['accuracy']
    print(f"{model_name} √ó {retrieval_name}: Macro-F1={macro_f1:.4f}, Acc={accuracy:.4f}")
    return outcome

# Every verifier is evaluated against every retrieval condition.
# Running the combinations from a loop (instead of ending the cell on a bare
# run_combo(...) call) also stops the notebook from echoing the last result
# dict, with its full per-claim prediction lists, as the cell's output.
retrieval_conditions = [
    (evidence_r0, "R0_Baseline"),
    (evidence_r1, f"R1_{BEST_DECOMP_KEY}"),
    (evidence_r2, "R2_RepoClaimDecomp"),
]

verifier_runs = [
    (verifier_a, "Verifier-A (R0-trained)", "Evaluating Verifier-A (trained on R0 evidence)"),
    (verifier_b, "Verifier-B (R1-trained)", "Evaluating Verifier-B (trained on R1 evidence)"),
]

for verifier_model, verifier_label, section_heading in verifier_runs:
    print("\n" + "-"*80)
    print(section_heading)
    print("-"*80)
    for condition_evidence, condition_name in retrieval_conditions:
        run_combo(verifier_model, verifier_label, condition_evidence, condition_name)


In [None]:
# ===== 9. SAVE RESULTS =====
section9_rule = "="*80
print("\n" + section9_rule)
print("SECTION 9: Saving Results & Analysis")
print(section9_rule)

# Summary table: one row per evaluated configuration, scalar metrics only
# (the per-claim lists stay in results_matrix).
summary_fields = (
    'verifier',
    'retrieval',
    'accuracy',
    'macro_f1',
    'supports_f1',
    'refutes_f1',
    'nei_f1',
)
summary_rows = [
    {field: run[field] for field in summary_fields}
    for run in results_matrix
]

summary_df = pd.DataFrame(summary_rows)
summary_path = os.path.join(EXPERIMENT_DIR, "downstream_results_summary.csv")
summary_df.to_csv(summary_path, index=False)

print(f"\n‚úì Saved summary to {summary_path}")
print("\nDOWNSTREAM RESULTS SUMMARY:")
print(summary_df.to_string(index=False))

# Best configuration = the row with the highest macro-F1.
best_row_label = summary_df['macro_f1'].idxmax()
best_config = summary_df.loc[best_row_label]
best_macro_f1 = best_config['macro_f1']
best_accuracy = best_config['accuracy']
print(f"\nüèÜ BEST CONFIGURATION:")
print(f"   {best_config['verifier']} √ó {best_config['retrieval']}")
print(f"   Macro-F1: {best_macro_f1:.4f} ({best_macro_f1*100:.2f}%)")
print(f"   Accuracy: {best_accuracy:.4f} ({best_accuracy*100:.2f}%)")

# Save detailed predictions for each configuration
label_names = ['SUPPORTS', 'REFUTES', 'NEI']  # index == class id

for result in results_matrix:
    # Every entry in results_matrix comes from run_combo(), which evaluates
    # test_claims in order, so when a result carries no 'claims' list the
    # claim texts are taken from test_claims instead of raising KeyError.
    claim_texts = result.get('claims')
    if claim_texts is None:
        claim_texts = [claim_obj['claim'] for claim_obj in test_claims]

    true_ids = result['true_labels']
    pred_ids = result['predictions']

    # Building from columns makes pandas raise if the per-claim lists ever
    # disagree in length, rather than silently truncating to the shortest.
    pred_df = pd.DataFrame({
        'claim_id': range(len(true_ids)),
        'claim': claim_texts,
        'true_label': true_ids,
        'pred_label': pred_ids,
        'true_str': [label_names[label_id] for label_id in true_ids],
        'pred_str': [label_names[label_id] for label_id in pred_ids],
        'confidence': result['confidences'],
        'correct': [true_id == pred_id for true_id, pred_id in zip(true_ids, pred_ids)],
    })

    pred_path = os.path.join(
        EXPERIMENT_DIR,
        f"predictions_{result['verifier'].replace(' ', '_').replace('(', '').replace(')', '')}_{result['retrieval']}.csv"
    )
    pred_df.to_csv(pred_path, index=False)
    print(f"   Saved predictions: {os.path.basename(pred_path)}")

# Generate qualitative examples
print("\nGenerating qualitative examples...")

qualitative_path = os.path.join(EXPERIMENT_DIR, "qualitative_examples.txt")

# Best macro-F1 per retrieval condition, taken over all verifiers.
# All three conditions use the same aggregation so the comparisons below are
# like-for-like; taking only the first matching row for R0/R2 while taking the
# maximum for R1 would favour R1. `.max()` also yields NaN instead of raising
# when a condition has no rows.
retrieval_names = summary_df['retrieval']
baseline_f1 = summary_df.loc[retrieval_names == 'R0_Baseline', 'macro_f1'].max()
decomp_f1 = summary_df.loc[retrieval_names.str.startswith('R1_'), 'macro_f1'].max()
repo_f1 = summary_df.loc[retrieval_names == 'R2_RepoClaimDecomp', 'macro_f1'].max()

# Best macro-F1 per verifier, matched on the full "Verifier-X" tag rather than
# on a single letter.
verifier_names = summary_df['verifier']
verifier_a_best = summary_df.loc[
    verifier_names.str.contains('Verifier-A', regex=False), 'macro_f1'
].max()
verifier_b_best = summary_df.loc[
    verifier_names.str.contains('Verifier-B', regex=False), 'macro_f1'
].max()

with open(qualitative_path, 'w', encoding='utf-8') as f:
    f.write("="*80 + "\n")
    f.write("QUALITATIVE ANALYSIS: Decomposition Impact\n")
    f.write("="*80 + "\n\n")

    f.write(f"Best Decomposition Method: {DECOMPOSITION_METHODS[BEST_DECOMP_KEY]['name']}\n")
    f.write(f"FactIR MRR: {best_method['mrr']:.4f}\n\n")

    f.write("="*80 + "\n")
    f.write("DOWNSTREAM PERFORMANCE COMPARISON\n")
    f.write("="*80 + "\n\n")

    # Compare baseline vs decomposition
    f.write(f"Baseline (R0):         Macro-F1 = {baseline_f1:.4f}\n")
    f.write(f"Our Decomp (R1):       Macro-F1 = {decomp_f1:.4f}\n")
    f.write(f"Repo Decomp (R2):      Macro-F1 = {repo_f1:.4f}\n\n")

    if decomp_f1 > baseline_f1:
        # Relative improvement over the baseline, in percent.
        improvement = ((decomp_f1 - baseline_f1) / baseline_f1) * 100
        f.write(f"‚úì Our decomposition IMPROVED performance by {improvement:.2f}%\n\n")
    else:
        f.write(f"‚úó Our decomposition did not improve over baseline\n\n")

    if decomp_f1 > repo_f1:
        f.write(f"‚úì Our decomposition OUTPERFORMED repo decomposition\n\n")
    else:
        f.write(f"‚úó Repo decomposition performed better\n\n")

    f.write("="*80 + "\n")
    f.write("KEY FINDINGS\n")
    f.write("="*80 + "\n\n")

    f.write(f"1. Best FactIR decomposition: {BEST_DECOMP_KEY}\n")
    f.write(f"   - Achieved MRR of {best_method['mrr']:.4f}\n")
    f.write(f"   - Hit@5: {best_method['hit@5']:.4f}\n\n")

    f.write(f"2. Best downstream configuration:\n")
    f.write(f"   - {best_config['verifier']} √ó {best_config['retrieval']}\n")
    f.write(f"   - Macro-F1: {best_config['macro_f1']:.4f}\n\n")

    f.write(f"3. Training approach impact:\n")
    if verifier_b_best > verifier_a_best:
        f.write(f"   ‚úì Decomp-trained verifier performed better ({verifier_b_best:.4f} vs {verifier_a_best:.4f})\n\n")
    else:
        f.write(f"   - Baseline-trained verifier performed better ({verifier_a_best:.4f} vs {verifier_b_best:.4f})\n\n")

print(f"‚úì Saved qualitative analysis to {qualitative_path}")

# ===== 10. FINAL SUMMARY =====
# Closing banner: where the artefacts live, what was produced, what to do next.
final_rule = "="*80

print("\n" + final_rule)
print("EXPERIMENT COMPLETE!")
print(final_rule)

print(f"\nüìÅ All results saved to: {EXPERIMENT_DIR}")
print("\nGenerated files:")
for generated_line in (
    "  1. factir_ranking_results.csv",
    "  2. downstream_results_summary.csv",
    "  3. training_pair_stats.json",
    "  4. qualitative_examples.txt",
    "  5. Prediction CSVs for each configuration",
    "  6. Trained verifier checkpoints",
):
    print(generated_line)

print("\n" + final_rule)
print("NEXT STEPS:")
print(final_rule)
for next_step_line in (
    "\n1. Provide evidence corpus JSON for true retrieval experiments",
    "2. Update CORPUS_PATH variable to enable custom BM25 retrieval",
    "3. Review qualitative_examples.txt for insights",
    "4. Compare results with your friend's baseline",
    "5. Experiment with other decomposition prompts if needed",
):
    print(next_step_line)

print("\n‚úì Pipeline execution complete!")
print("\nTo download results:")
# Copy-paste snippet for pulling the two report files out of Colab.
for download_line in (
    "  from google.colab import files",
    f"  files.download('{summary_path}')",
    f"  files.download('{qualitative_path}')",
):
    print(download_line)