In [None]:
import os

# Root directory of the workspace; every entry of `folders` is created below it.
base = "/mnt/data/federated_qa_jupyter"

# Relative sub-directories to create, one per line.
folders = [
    "data/client1",
    "data/client2",
    "data/test",
    "data/central_rag_index",
    "model/base_model",
    "model/client1",
    "model/client2",
    "model/federated_merged_model",
    "checkpoints/client1",
    "checkpoints/client2",
    "results",
    "utils",
]

# exist_ok=True keeps this cell idempotent: re-running it never fails on
# directories that are already there.
for f in folders:
    target_dir = os.path.join(base, f)
    os.makedirs(target_dir, exist_ok=True)

print("✅ Folder structure created.")

In [None]:
#  answer_utils.py
import re
import faiss
import pickle
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from sentence_transformers import SentenceTransformer, CrossEncoder
from sklearn.feature_extraction.text import TfidfVectorizer
from nltk.corpus import stopwords
from difflib import SequenceMatcher

# Module-level retrieval resources: the FAISS index, the pickled chunk list it
# indexes into, the query encoder and the cross-encoder reranker.
# NOTE(review): pickle.load can execute arbitrary code from the file it reads;
# only point CHUNKS_PICKLE_PATH at a file produced by a trusted pipeline.
FAISS_INDEX_PATH = "/mnt/data/RAG/3gpp_index.faiss"
CHUNKS_PICKLE_PATH = "/mnt/data/RAG/3gpp_chunks.pkl"

index = faiss.read_index(FAISS_INDEX_PATH)
with open(CHUNKS_PICKLE_PATH, "rb") as f:
    documents = pickle.load(f)

embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

def is_similar(a, b, threshold=0.75):
    """Return True when the difflib similarity ratio of `a` and `b` reaches `threshold`.

    Args:
        a: First sequence (strings in this file).
        b: Second sequence.
        threshold: Inclusive lower bound on ``SequenceMatcher.ratio()``.
    """
    matcher = SequenceMatcher(None, a, b)
    similarity = matcher.ratio()
    return similarity >= threshold

def truncate_and_filter_chunks(chunks, query, window_size=150, stride=75, max_windows=5,
                               min_window_words=30):
    """Slice chunks into overlapping word windows and keep the ones closest to the query.

    Every chunk is cut into windows of `window_size` words that start every
    `stride` words. Each window is scored as
    ``0.6 * lexical_overlap + 0.4 * tfidf_cosine`` against `query`, and the
    `max_windows` best windows over all chunks are returned, best first.

    Args:
        chunks: Iterable of dicts with a "content" string and an optional "source".
        query: The question text the windows are ranked against.
        window_size: Number of words per window.
        stride: Word offset between the starts of consecutive windows.
        max_windows: Maximum number of windows returned.
        min_window_words: Windows with fewer words than this are skipped
            (this also drops chunks that are shorter than the limit).

    Returns:
        A list of ``{"content", "score", "source"}`` dicts sorted by descending score.
    """
    STOPWORDS = set(stopwords.words("english"))

    def normalize(text):
        return re.sub(r'\W+', ' ', text.lower())

    def content_tokens(text):
        # Lower-cased word set with English stopwords removed.
        return set(normalize(text).split()) - STOPWORDS

    # The query is the same for every window, so tokenize it only once.
    query_tokens = content_tokens(query)

    def lexical_overlap(span):
        # Jaccard overlap; the epsilon avoids 0/0 when both token sets are empty.
        span_tokens = content_tokens(span)
        return len(query_tokens & span_tokens) / (len(query_tokens | span_tokens) + 1e-5)

    def tfidf_score(span):
        try:
            vec = TfidfVectorizer().fit([query, span])
        except ValueError:
            # Raised when neither text yields a usable term (empty vocabulary);
            # treat that as "no TF-IDF evidence" instead of aborting retrieval.
            return 0.0
        X = vec.transform([query, span])
        # Rows are L2-normalized by TfidfVectorizer, so the dot product is the
        # cosine similarity. toarray() replaces the deprecated `.A` shorthand.
        return (X[0] @ X[1].T).toarray()[0][0]

    scored_spans = []
    for chunk in chunks:
        words = chunk["content"].split()
        for i in range(0, len(words), stride):
            span_words = words[i:i + window_size]
            if len(span_words) < min_window_words:
                continue
            span = " ".join(span_words)
            score = 0.6 * lexical_overlap(span) + 0.4 * tfidf_score(span)
            scored_spans.append({
                "content": span,
                "score": score,
                "source": chunk.get("source", "unknown")
            })

    return sorted(scored_spans, key=lambda x: x["score"], reverse=True)[:max_windows]

def retrieve_with_rerank(query, top_k=5):
    """Retrieve candidate chunks from the FAISS index and rerank them with the cross-encoder.

    The query is embedded with `embedding_model`, the `2 * top_k` nearest
    entries are fetched from `index`, each (query, chunk) pair is scored by
    `reranker`, and the `top_k` highest-scoring chunks are returned.

    Args:
        query: The question text.
        top_k: Number of chunks to return after reranking.

    Returns:
        A list of at most `top_k` entries of `documents`, best first. The list
        is empty when the index returns no valid neighbour.
    """
    query_vec = embedding_model.encode(query, normalize_embeddings=True)
    query_vec = np.array(query_vec).reshape(1, -1).astype("float32")
    _, neighbor_ids = index.search(query_vec, top_k * 2)

    # FAISS pads the id row with -1 when it finds fewer than k neighbours.
    # A -1 must not reach `documents[...]`, where it would silently select the
    # last chunk instead of failing.
    initial_results = [documents[i] for i in neighbor_ids[0] if i >= 0]
    if not initial_results:
        return []

    pairs = [(query, doc["content"]) for doc in initial_results]
    scores = reranker.predict(pairs)
    # Sort on the score only, so tied scores never fall through to comparing dicts.
    reranked = sorted(zip(scores, initial_results), key=lambda x: x[0], reverse=True)[:top_k]
    return [doc for _, doc in reranked]

def build_fusion_prompt(context_chunks, question):
    """Assemble the Llama-2 style [INST] prompt that fuses all context chunks.

    Args:
        context_chunks: Iterable of dicts with a "content" string and an
            optional "source" path; only the last path component is shown.
        question: The question appended after the fused context.

    Returns:
        The complete prompt string, ending with " [/INST]".
    """
    system_prompt = (
        "You are a precise assistant. Extract the exact answer span from the context. "
        "Do not paraphrase, summarize, or add extra information. "
        "The answer must appear exactly in the context. "
        "If the context lists multiple conditions, actions, or branches, include them all as written. "
        "Do not summarize or paraphrase — copy the exact text from the context, line by line."
    )

    def render(chunk):
        # One block per chunk: a source header, a separator, the trimmed text.
        source = chunk.get("source", "unknown").split("/")[-1]
        return f"[Source: {source}]\n-----\n{chunk['content'].strip()}"

    fused_context = "\n\n".join(render(chunk) for chunk in context_chunks)
    user_prompt = (
        f"Context:\n{fused_context}\n\n"
        f"Question: {question}\n"
        f"Answer from the context only:"
    )
    return f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n{user_prompt} [/INST]"

def clean_prediction(raw_text):
    """Extract a single clean answer sentence from raw generated text.

    Steps, in order:
      1. keep only the text after the last "[/INST]" marker;
      2. drop every character that is not a word character, whitespace or one of ``-.,:/()``;
      3. collapse immediately repeated "label:" prefixes;
      4. if the answer starts with a phrase that is immediately repeated, keep
         only that phrase;
      5. cut after the first sentence-ending punctuation mark.

    Args:
        raw_text: Full pipeline output, prompt included.

    Returns:
        The cleaned answer string (possibly empty).
    """
    answer = raw_text.split("[/INST]")[-1].strip()
    answer = re.sub(r"[^\w\s\-.,:/()]", "", answer)
    answer = re.sub(r'(\b.+?:)(\s*\1)+', r'\1', answer)
    tokens = answer.split()
    # The upper bound is inclusive of len // 2 so that an answer consisting of
    # exactly one phrase said twice ("a b a b") is detected as a repetition.
    for i in range(1, len(tokens) // 2 + 1):
        if tokens[:i] == tokens[i:2*i]:
            answer = " ".join(tokens[:i])
            break
    # Only treat punctuation as a sentence end when whitespace or the end of the
    # text follows; otherwise decimals and spec numbers ("3.5", "38.331") would
    # be cut in the middle.
    sentence_end = re.search(r'[.?!](?=\s|$)', answer)
    if sentence_end:
        answer = answer[:sentence_end.end()]
    return answer.strip()

# Text-generation pipelines keyed by model path, so the weights are loaded once
# per process instead of once per question.
_QA_PIPELINE_CACHE = {}


def _load_qa_pipeline(model_path):
    """Return the cached text-generation pipeline for `model_path`, loading it on first use."""
    if model_path not in _QA_PIPELINE_CACHE:
        # Imported here because this module's import header does not bring
        # `torch` into scope.
        import torch

        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16).to("cuda")
        _QA_PIPELINE_CACHE[model_path] = pipeline("text-generation", model=model, tokenizer=tokenizer)
    return _QA_PIPELINE_CACHE[model_path]


def _normalize_for_match(text):
    """Lower-case `text`, apply clean_prediction's character filter and collapse whitespace."""
    return " ".join(re.sub(r"[^\w\s\-.,:/()]", "", text.lower()).split())


def answer_with_fusion_cross_rag_truncated(question, top_k=6, max_windows=5, verbose=False):
    """Answer `question` with retrieval, reranking, window truncation and generation.

    Args:
        question: The question text.
        top_k: Number of reranked chunks passed to the window truncation step.
        max_windows: Maximum number of context windows placed in the prompt.
        verbose: When True, print the prompt head, the raw output, the final
            answer and the head of every context window.

    Returns:
        A tuple ``(answer, final_chunks)`` with the cleaned answer string and
        the context windows that were put in the prompt.
    """
    initial_chunks = retrieve_with_rerank(question, top_k=top_k)
    final_chunks = truncate_and_filter_chunks(initial_chunks, question, max_windows=max_windows)
    prompt = build_fusion_prompt(final_chunks, question)

    qa_pipeline = _load_qa_pipeline("./model/federated_merged_model")
    tokenizer = qa_pipeline.tokenizer

    output = qa_pipeline(
        prompt,
        max_new_tokens=160,
        do_sample=False,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id
    )[0]["generated_text"]

    answer = clean_prediction(output)

    # Grounding check. A whole-string similarity ratio between a short answer
    # and a long window is dominated by the length difference and is almost
    # never high, so first accept an answer that occurs verbatim (after the
    # same character filtering) inside a window; the ratio stays as a fallback.
    normalized_answer = _normalize_for_match(answer)
    grounded = any(
        (bool(normalized_answer) and normalized_answer in _normalize_for_match(c["content"]))
        or is_similar(answer.lower(), c["content"].lower())
        for c in final_chunks
    )
    if not grounded:
        print("🚨 WARNING: Approximate match for answer not found in final context. Review answer relevance.")

    if verbose:
        print("📌 Prompt (truncated):\n", prompt[:500], "...\n")
        print("🧾 Raw Output:\n", output)
        print("✅ Final Answer:\n", answer)
        for i, chunk in enumerate(final_chunks):
            print(f"\n--- Context {i+1} ---\n{chunk['content'][:300]}...\n")

    return answer, final_chunks

In [None]:
# trainer.py
import torch
import json
from pathlib import Path
from datasets import Dataset
from transformers import (
    AutoTokenizer, AutoModelForCausalLM,
    TrainingArguments, Trainer, BitsAndBytesConfig,
    DataCollatorForLanguageModeling
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

def load_jsonl(path):
    """Read a JSON Lines file and return the decoded records as a list.

    Lines that are empty after stripping whitespace are skipped.
    """
    records = []
    with open(path, "r", encoding="utf-8") as handle:
        for raw_line in handle:
            payload = raw_line.strip()
            if payload:
                records.append(json.loads(payload))
    return records

def train_model(dataset_path, base_model_path, output_path, epochs=3,
                tokenizer_path="/mnt/data/llama2-model"):
    """Fine-tune a 4-bit quantized causal LM with LoRA adapters on a JSONL dataset.

    Args:
        dataset_path: JSONL file whose records carry the training string under "text".
        base_model_path: Model directory or id passed to ``AutoModelForCausalLM.from_pretrained``.
        output_path: Directory for checkpoints and logs; the final adapter and
            tokenizer are written to ``<output_path>/final``.
        epochs: Number of training epochs.
        tokenizer_path: Where the tokenizer is loaded from. Defaults to the
            path this function previously hard-coded.

    Returns:
        A tuple ``(trainer, model, tokenizer)``.
    """
    # Load data
    data = load_jsonl(dataset_path)
    dataset = Dataset.from_list(data).shuffle(seed=42)
    # 90% train; the remaining 10% is halved and only the first half is used
    # for validation (the other half, val_test["test"], is left untouched).
    split = dataset.train_test_split(test_size=0.10, seed=42)
    val_test = split["test"].train_test_split(test_size=0.5, seed=42)
    train_dataset = split["train"]
    val_dataset = val_test["train"]

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
    # DataCollatorForLanguageModeling replaces every label equal to
    # pad_token_id with -100. With pad == eos the genuine end-of-sequence
    # tokens would be masked out of the loss too, and the model would never be
    # trained to stop generating. Use a distinct token when the tokenizer has
    # one and fall back to eos only when it does not.
    tokenizer.pad_token = tokenizer.unk_token or tokenizer.eos_token

    def tokenize(example):
        return tokenizer(
            example["text"],
            truncation=True,
            max_length=2048
        )

    train_dataset = train_dataset.map(tokenize, batched=True, num_proc=2, remove_columns=["text"])
    val_dataset = val_dataset.map(tokenize, batched=True, num_proc=2, remove_columns=["text"])

    # Data Collator
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False,
        pad_to_multiple_of=64
    )

    # Load Model with LoRA
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_quant_type="nf4"
    )

    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_path,
        device_map="auto",
        quantization_config=bnb_config,
        torch_dtype=torch.bfloat16
    )

    base_model = prepare_model_for_kbit_training(base_model)
    base_model.gradient_checkpointing_enable()
    # The generation KV cache is switched off for training.
    base_model.config.use_cache = False

    lora_config = LoraConfig(
        r=32,
        lora_alpha=64,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM"
    )

    model = get_peft_model(base_model, lora_config)

    # Training Args
    args = TrainingArguments(
        output_dir=output_path,
        num_train_epochs=epochs,
        per_device_train_batch_size=2,
        per_device_eval_batch_size=2,
        gradient_accumulation_steps=8,
        eval_strategy="epoch",
        save_strategy="epoch",
        save_total_limit=1,
        learning_rate=1e-5,
        lr_scheduler_type="cosine",
        logging_dir=f"{output_path}/logs",
        logging_steps=25,
        bf16=True,
        report_to="none",
        remove_unused_columns=False,
        dataloader_num_workers=2,
        group_by_length=True,
        optim="paged_adamw_32bit",
        max_grad_norm=1,
        warmup_ratio=0.03
    )

    # NOTE(review): recent transformers releases deprecate Trainer's
    # `tokenizer=` argument in favour of `processing_class=` — confirm against
    # the installed version before upgrading.
    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        tokenizer=tokenizer,
        data_collator=data_collator
    )

    trainer.train()
    trainer.save_model(f"{output_path}/final")
    tokenizer.save_pretrained(f"{output_path}/final")

    return trainer, model, tokenizer

In [None]:
# data_utils.py
import json
from pathlib import Path
from transformers import AutoTokenizer
from sentence_transformers import SentenceTransformer, util
from tqdm import tqdm
import numpy as np

# Token budget for one complete training prompt. preprocess_client_dataset
# shortens the context of longer prompts, or drops the example when it cannot.
MAX_TOKEN_LENGTH = 2048
# Minimum cosine similarity between the context embedding and the answer
# embedding for preprocess_client_dataset to keep an example.
SIM_THRESHOLD = 0.6
# System instruction that build_prompt places between <<SYS>> and <</SYS>>.
SYSTEM_PROMPT = (
    "You are a precise assistant. Extract the exact answer span from the context. "
    "Do not paraphrase, summarize, or add extra information. "
    "The answer must appear exactly in the context."
)

def select_relevant_chunks(context: str, answer: str, window_size=150, stride=100):
    """Return the first word window of `context` that contains `answer` verbatim.

    The context is split on whitespace and scanned with windows of
    `window_size` words that start every `stride` words. Scanning stops once a
    window reaches the end of the context.

    Returns:
        The matching window as a single space-joined string, or None when no
        window contains `answer`.
    """
    tokens = context.split()
    total = len(tokens)
    for offset in range(0, total, stride):
        window_end = offset + window_size
        candidate = " ".join(tokens[offset:window_end])
        if answer in candidate:
            return candidate
        if window_end >= total:
            # This window already covered the tail; later ones add nothing.
            return None
    return None

def build_prompt(context: str, question: str, answer: str) -> str:
    """Render one training example in the Llama-2 [INST] chat layout.

    The module-level SYSTEM_PROMPT goes into the <<SYS>> section, the context
    and question form the user turn, and `answer` followed by "</s>" is the
    completion after "[/INST]".
    """
    system_block = f"<<SYS>>\n{SYSTEM_PROMPT}\n<</SYS>>"
    user_block = f"Context: {context}\n\nQuestion: {question}\nAnswer:"
    return f"<s>[INST] {system_block}\n\n{user_block} [/INST] {answer}</s>"

def extract_context_and_answer(text):
    """Split a formatted training prompt into its context and its answer.

    The text is cut at the first "[/INST]". In the part before it, lines from
    the one starting with "Context:" up to (not including) the one starting
    with "Question:" are joined with single spaces to form the context. The
    part after it, with "</s>" removed and whitespace trimmed, is the answer.

    Args:
        text: A prompt string in the layout produced by build_prompt.

    Returns:
        A tuple ``(context, answer)``; ``("", "")`` when `text` has no
        "[/INST]" marker or is not a string.
    """
    try:
        prompt_part, answer = text.split("[/INST]", 1)
    except (AttributeError, TypeError, ValueError):
        # No marker (unpacking fails) or a non-string input.
        return "", ""

    # Remove the end-of-sequence tag first so whitespace in front of it is
    # trimmed as well.
    answer = answer.replace("</s>", "").strip()

    label = "Context:"
    context_lines = []
    inside_context = False
    for line in prompt_part.splitlines():
        stripped = line.strip()
        if stripped.startswith(label):
            inside_context = True
            # Drop only the leading label; a later "Context:" inside the text
            # belongs to the content and must survive.
            context_lines.append(stripped[len(label):].strip())
            continue
        if stripped.startswith("Question:"):
            break
        if inside_context:
            context_lines.append(stripped)

    return " ".join(context_lines).strip(), answer

def clean_answer(entry):
    """Return the entry's prompt text with a normalized answer section.

    The "text" value is split at the first "[/INST]"; the answer part has
    "</s>" removed and surrounding whitespace trimmed, and the text is rebuilt
    as ``<prompt>[/INST] <answer></s>``.

    Args:
        entry: A mapping with a "text" string.

    Returns:
        The rebuilt string, or None when the entry has no usable text or no
        "[/INST]" marker.
    """
    try:
        text = entry.get("text", "")
        prompt_part, answer_part = text.split("[/INST]", 1)
    except (AttributeError, TypeError, ValueError):
        # Not a mapping, a non-string text, or no marker in the text.
        return None
    # Remove the tag before trimming so no space is left in front of "</s>".
    answer_text = answer_part.replace("</s>", "").strip()
    return f"{prompt_part}[/INST] {answer_text}</s>"

def _split_context_and_question(prompt_part):
    """Extract ``(context, question)`` from the part of a prompt before "[/INST]"."""
    context_label, question_label = "Context:", "Question:"
    context_lines, question = [], ""
    inside_context, inside_question = False, False

    for raw_line in prompt_part.splitlines():
        stripped = raw_line.strip()
        if stripped.startswith(context_label):
            inside_context, inside_question = True, False
            # Remove only the leading label, not later occurrences in the text.
            context_lines.append(stripped[len(context_label):].strip())
            continue
        if stripped.startswith(question_label):
            inside_question, inside_context = True, False
            question = stripped[len(question_label):].strip()
            continue
        if inside_context:
            context_lines.append(raw_line)
        elif inside_question and not question:
            # The question text started on the line after its label.
            question = stripped

    return "\n".join(context_lines).strip(), question


def preprocess_client_dataset(input_path: Path, output_path: Path, model_path: str):
    """Re-format and filter one client's raw JSONL prompts into training data.

    For every non-blank input line the "text" prompt is parsed into context,
    question and answer. Prompts longer than MAX_TOKEN_LENGTH tokens get their
    context replaced by the first word window that contains the answer, and
    are dropped when no window does. An example is kept only when the cosine
    similarity between its context and answer embeddings reaches
    SIM_THRESHOLD. Kept examples are written to `output_path` as JSONL with a
    single "text" field.

    Args:
        input_path: Raw JSONL file with one ``{"text": ...}`` record per line.
        output_path: Destination JSONL file; its parent directory is created
            when missing.
        model_path: Tokenizer location passed to ``AutoTokenizer.from_pretrained``.
    """
    print(f"🔄 Preprocessing: {input_path}")
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    encoder = SentenceTransformer("all-MiniLM-L6-v2")

    reformatted_entries = []
    total_count = 0

    # First pass only counts lines so the progress bar has a total.
    with open(input_path, "r", encoding="utf-8") as f:
        num_lines = sum(1 for _ in f)

    with open(input_path, "r", encoding="utf-8") as f:
        for line_no, line in enumerate(tqdm(f, total=num_lines, desc="Reformatting"), start=1):
            if not line.strip():
                # Blank lines are not records; do not report them as malformed.
                continue
            total_count += 1
            try:
                entry = json.loads(line)
                original_text = entry["text"]

                prompt_part, answer = original_text.split("[/INST]", 1)
                answer = answer.replace("</s>", "").strip()
                if not answer:
                    continue

                full_context, question = _split_context_and_question(prompt_part)
                if not full_context or not question:
                    continue

                temp_prompt = build_prompt(full_context, question, answer)
                input_ids = tokenizer(temp_prompt)["input_ids"]

                final_context = full_context
                if len(input_ids) > MAX_TOKEN_LENGTH:
                    short_context = select_relevant_chunks(full_context, answer)
                    if not short_context or answer not in short_context:
                        continue
                    final_context = short_context

                final_prompt = build_prompt(final_context, question, answer)
                context, ans = extract_context_and_answer(final_prompt)
                if not context or not ans:
                    continue

                context_emb = encoder.encode(context, convert_to_tensor=True)
                answer_emb = encoder.encode(ans, convert_to_tensor=True)
                similarity = util.cos_sim(context_emb, answer_emb).item()
                if similarity < SIM_THRESHOLD:
                    continue

                cleaned = clean_answer({"text": final_prompt})
                if cleaned:
                    reformatted_entries.append({"text": cleaned})

            except Exception as e:
                # Best effort: one bad record must not abort the whole file.
                print(f"Skipping malformed line {line_no}: {e}")
                continue

    # Make sure the destination directory exists before writing.
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w", encoding="utf-8") as f:
        for e in reformatted_entries:
            f.write(json.dumps(e) + "\n")

    # Every record that was read but not kept counts as filtered, whatever the
    # reason (empty answer, unparsable prompt, too long, low similarity, ...).
    filtered_out_count = total_count - len(reformatted_entries)

    print(f"✅ Saved cleaned client data to {output_path}")
    print(f"📊 Total processed: {total_count}, Kept: {len(reformatted_entries)}, Filtered: {filtered_out_count}")

In [None]:
# client.py
import flwr as fl
import torch
from utils.trainer import get_trainer
import os


class LoraClient(fl.client.NumPyClient):
    """Flower NumPy client that trains the trainable parameters of a model locally.

    Only parameters with ``requires_grad=True`` are exchanged with the server.

    NOTE(review): `get_trainer` comes from `utils.trainer`, which is not shown
    in this notebook; the trainer cell here defines `train_model` with a
    different signature — confirm that `get_trainer` exists and returns
    ``(trainer, model, tokenizer)``.
    """

    def __init__(self, client_id, dataset_path, model_path, output_dir):
        self.client_id = client_id
        self.dataset_path = dataset_path
        self.model_path = model_path
        self.output_dir = output_dir

        self.trainer, self.model, self.tokenizer = get_trainer(
            client_id, dataset_path, model_path, output_dir
        )

    def _trainable_parameters(self):
        """Return the trainable parameters in `named_parameters()` order."""
        return [param for _, param in self.model.named_parameters() if param.requires_grad]

    def get_parameters(self, config=None):
        """Return the trainable parameters as a list of float32 NumPy arrays."""
        # detach(): .numpy() raises on tensors that require grad.
        # float32: NumPy has no bfloat16 dtype.
        return [
            param.detach().cpu().to(torch.float32).numpy()
            for param in self._trainable_parameters()
        ]

    def set_parameters(self, parameters):
        """Load `parameters` into the trainable parameters, position by position.

        Raises:
            ValueError: If the number of arrays differs from the number of
                trainable parameters.
        """
        # Pair the incoming arrays with the trainable parameters only. The list
        # produced by get_parameters has no entries for frozen weights, so
        # zipping against all named parameters would misalign every value.
        trainable = self._trainable_parameters()
        if len(trainable) != len(parameters):
            raise ValueError(
                f"Expected {len(trainable)} parameter arrays, got {len(parameters)}"
            )
        with torch.no_grad():
            for param, new_val in zip(trainable, parameters):
                # Keep each parameter's own device and dtype.
                param.copy_(torch.tensor(new_val).to(device=param.device, dtype=param.dtype))

    def fit(self, parameters, config=None):
        """Run one local training round starting from the server's parameters."""
        print(f"🚀 [{self.client_id}] Starting training round...")
        self.set_parameters(parameters)
        self.trainer.train()
        return self.get_parameters(), len(self.trainer.train_dataset), {}

    def evaluate(self, parameters, config=None):
        """Placeholder evaluation: reports loss 0.0 over 0 examples."""
        return 0.0, 0, {}  # Optional: skip evaluation during simulation


def get_client_fn(client_id):
    """Build a Flower `client_fn` bound to the data and checkpoints of one client.

    Args:
        client_id: Client number used in the dataset and checkpoint paths
            (``client<client_id>``).

    Returns:
        A function ``client_fn(cid)`` that creates the checkpoint directory and
        returns a LoraClient. The `cid` argument is ignored; the client is
        selected by `client_id`.
    """
    def client_fn(cid):
        print(f"🚀 Launching client {client_id} ...")

        dataset_path = f"/mnt/data/federated_qa_jupyter/data/client{client_id}/client{client_id}_qa.jsonl"

        # ✅ LoRA-finetuned model (read-only)
        model_path = "/mnt/data/llama2_qa_lora_output5"

        # ✅ Save new adapters and logs here
        output_dir = f"/mnt/data/federated_qa_jupyter/checkpoints/client{client_id}"
        os.makedirs(output_dir, exist_ok=True)

        # output_dir is a required LoraClient argument; leaving it out raised
        # a TypeError before any training could start.
        return LoraClient(
            client_id=f"client{client_id}",
            dataset_path=dataset_path,
            model_path=model_path,
            output_dir=output_dir,
        )
    return client_fn

# Only builds the factory for client 1; no client is created and nothing is
# loaded until the returned function is called.
get_client_fn(1)

In [None]:
from pathlib import Path
from data_utils import preprocess_client_dataset

# Tokenizer location handed to preprocess_client_dataset for length checks.
base_model_path = "/mnt/data/llama2-model"

# Raw input and cleaned output file of each client.
input_client1 = Path("/mnt/data/client_raw/client1_raw.jsonl")
output_client1 = Path("/mnt/data/federated_qa_jupyter/data/client1/client1_qa.jsonl")

input_client2 = Path("/mnt/data/client_raw/client2_raw.jsonl")
output_client2 = Path("/mnt/data/federated_qa_jupyter/data/client2/client2_qa.jsonl")

# Format datasets for training, one client after the other.
for raw_path, clean_path in (
    (input_client1, output_client1),
    (input_client2, output_client2),
):
    preprocess_client_dataset(raw_path, clean_path, base_model_path)

In [None]:
import flwr as fl

NUM_CLIENTS = 2

strategy = fl.server.strategy.FedAvg(
    fraction_fit=1.0,
    # LoraClient.evaluate is a placeholder that reports 0 examples, so
    # federated evaluation is switched off instead of aggregating those
    # zero-weight results.
    fraction_evaluate=0.0,
    min_fit_clients=NUM_CLIENTS,
    min_available_clients=NUM_CLIENTS,
    on_fit_config_fn=lambda rnd: {"round": rnd}
)


def simulation_client_fn(cid):
    """Create the client for simulation id `cid`.

    get_client_fn is a factory (client number -> client_fn), so it cannot be
    handed to start_simulation directly: the simulation would receive a
    function where it expects a client.

    NOTE(review): assumes the simulation passes ids "0" .. str(NUM_CLIENTS - 1)
    while the data folders are numbered from 1 — confirm against the installed
    Flower version.
    """
    return get_client_fn(int(cid) + 1)(cid)


fl.simulation.start_simulation(
    client_fn=simulation_client_fn,
    num_clients=NUM_CLIENTS,
    config=fl.server.ServerConfig(num_rounds=3),
    strategy=strategy,
    client_resources={"num_cpus": 2, "num_gpus": 1}
)

In [None]:
# federated_eval.py
from pathlib import Path
import json, re, time, math
from collections import defaultdict

import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

# Reuse your retrieval + prompt utilities (no model inside)
from answer_utils import (
    retrieve_with_rerank,
    truncate_and_filter_chunks,
    build_fusion_prompt,
    clean_prediction,
)

# Models compared by main(): each path is loaded with load_pipeline.
PRE_FL_MODEL_PATH = Path("/mnt/data/llama2_qa_lora_output5/final")  # pre-FL fine-tuned model dir
POST_FL_MODEL_PATH = Path("/mnt/data/federated_qa_jupyter/model/federated_merged_model")  # merged FL model

# Datasets (expected format: {"question": "...", "answer": "..."} per line).
# main() evaluates only the files that exist.
RAG_TEST_PATH = Path("/mnt/data/federated_qa_jupyter/data/test/federated_test_set.jsonl")
CLIENT_A_HOLDOUT = Path("/mnt/data/federated_qa_jupyter/data/test/clientA_holdout.jsonl")
CLIENT_B_HOLDOUT = Path("/mnt/data/federated_qa_jupyter/data/test/clientB_holdout.jsonl")

# Outputs. The results directory is created as soon as this cell/module runs.
RESULTS_DIR = Path("/mnt/data/federated_qa_jupyter/results")
RESULTS_DIR.mkdir(parents=True, exist_ok=True)
SUMMARY_CSV = RESULTS_DIR / "federated_eval_summary.csv"  # one row per model / rag_mode / dataset
DETAILS_JSONL = RESULTS_DIR / "federated_eval_details.jsonl"  # one record per evaluated question

# Default retrieval settings of answer_with_rag: chunks kept after reranking
# and maximum number of context windows in the prompt.
TOP_K = 6
MAX_WINDOWS = 5


# METRICS
def _normalize(s: str) -> str:
    s = s.lower().strip()
    s = re.sub(r"\s+", " ", s)
    s = re.sub(r"[^a-z0-9\-.,:/() ]", "", s)
    return s

def exact_match(pred, ref):
    """Return 1.0 when prediction and reference are equal after normalization, else 0.0."""
    same = _normalize(pred) == _normalize(ref)
    return float(same)

def f1_score(pred, ref):
    """Token-level F1 between the normalized prediction and reference.

    Overlap is counted on token multisets: a token counts as often as it
    occurs in both texts. Two empty texts score 1.0; exactly one empty text
    scores 0.0.
    """
    pred_tokens = _normalize(pred).split()
    ref_tokens = _normalize(ref).split()
    if not pred_tokens or not ref_tokens:
        return 1.0 if not pred_tokens and not ref_tokens else 0.0

    pred_counts, ref_counts = {}, {}
    for tok in pred_tokens:
        pred_counts[tok] = pred_counts.get(tok, 0) + 1
    for tok in ref_tokens:
        ref_counts[tok] = ref_counts.get(tok, 0) + 1

    # Multiset intersection size.
    overlap = sum(min(count, ref_counts.get(tok, 0)) for tok, count in pred_counts.items())
    if overlap == 0:
        return 0.0

    precision = overlap / len(pred_tokens)
    recall = overlap / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)

def lcs_length(a, b):
    """Length of the longest common token subsequence of two normalized texts.

    Returns:
        A tuple ``(lcs, len_a, len_b)``: the LCS length and the token counts
        of `a` and `b`.
    """
    tokens_a = _normalize(a).split()
    tokens_b = _normalize(b).split()

    # Rolling DP: `previous` is the table row of the token before the current one.
    previous = [0] * (len(tokens_b) + 1)
    for tok_a in tokens_a:
        current = [0]
        for j, tok_b in enumerate(tokens_b, start=1):
            if tok_a == tok_b:
                current.append(previous[j - 1] + 1)
            else:
                current.append(max(previous[j], current[j - 1]))
        previous = current

    return previous[-1], len(tokens_a), len(tokens_b)

def rouge_l(pred, ref):
    """ROUGE-L F-measure (beta = 1.2) from the token-level LCS.

    Returns 0.0 when either text is empty or nothing is shared.
    """
    lcs, pred_len, ref_len = lcs_length(pred, ref)
    if lcs == 0 or pred_len == 0 or ref_len == 0:
        return 0.0
    # With all three counts positive, precision and recall are positive too,
    # so the denominator below cannot be zero.
    prec = lcs / pred_len
    rec = lcs / ref_len
    beta2 = 1.2**2
    return (1 + beta2) * prec * rec / (rec + beta2 * prec)

def bleu1(pred, ref):
    """Unigram BLEU with brevity penalty on the normalized texts.

    Precision is the clipped unigram match count divided by the prediction
    length; predictions that are not longer than the reference are scaled by
    ``exp(1 - len(ref) / len(pred))``. Returns 0.0 when either text is empty.
    """
    hyp_tokens = _normalize(pred).split()
    ref_tokens = _normalize(ref).split()
    if not hyp_tokens or not ref_tokens:
        return 0.0

    hyp_counts, ref_counts = {}, {}
    for tok in hyp_tokens:
        hyp_counts[tok] = hyp_counts.get(tok, 0) + 1
    for tok in ref_tokens:
        ref_counts[tok] = ref_counts.get(tok, 0) + 1

    # Clipped matches: a token is credited at most as often as the reference has it.
    clipped = sum(min(count, ref_counts.get(tok, 0)) for tok, count in hyp_counts.items())
    precision = clipped / len(hyp_tokens)

    # brevity penalty
    if len(hyp_tokens) > len(ref_tokens):
        brevity = 1.0
    else:
        brevity = math.exp(1 - len(ref_tokens) / len(hyp_tokens))
    return brevity * precision

def compute_metrics(pred, ref):
    """Score one prediction against its reference with all four metrics.

    Returns:
        A dict with the keys "EM", "F1", "ROUGE_L" and "BLEU1", in that order.
    """
    scorers = (
        ("EM", exact_match),
        ("F1", f1_score),
        ("ROUGE_L", rouge_l),
        ("BLEU1", bleu1),
    )
    return {name: scorer(pred, ref) for name, scorer in scorers}


# DATA HELPERS
def load_jsonl(path: Path):
    """Read a JSON Lines file into a list of decoded records, skipping blank lines."""
    with open(path, "r", encoding="utf-8") as handle:
        return [json.loads(row) for row in handle if row.strip()]


# INFERENCE HELPERS
def load_pipeline(model_path: Path):
    """Load tokenizer and causal LM from `model_path` into a text-generation pipeline.

    The tokenizer pads with the EOS token on the left. The model is loaded in
    bfloat16 when CUDA is available and in float32 otherwise, with
    ``device_map="auto"``.
    """
    tok = AutoTokenizer.from_pretrained(model_path)
    # decoder-only best practice
    tok.pad_token = tok.eos_token
    tok.padding_side = "left"

    if torch.cuda.is_available():
        dtype = torch.bfloat16
    else:
        dtype = torch.float32

    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=dtype,
        device_map="auto"
    )
    return pipeline("text-generation", model=model, tokenizer=tok)

def answer_with_rag(qa_pipe, question, top_k=TOP_K, max_windows=MAX_WINDOWS):
    """Answer `question` with retrieved context using an already loaded pipeline.

    Args:
        qa_pipe: Text-generation pipeline exposing a `tokenizer` attribute.
        question: The question text.
        top_k: Number of reranked chunks to start from.
        max_windows: Maximum number of context windows in the prompt.

    Returns:
        A tuple ``(prediction, final_chunks)``: the cleaned answer and the
        context windows that were used.
    """
    # Retrieve & fuse context with the shared answer_utils helpers.
    reranked_chunks = retrieve_with_rerank(question, top_k=top_k)
    final_chunks = truncate_and_filter_chunks(reranked_chunks, question, max_windows=max_windows)
    prompt = build_fusion_prompt(final_chunks, question)

    eos_id = qa_pipe.tokenizer.eos_token_id
    generation_kwargs = dict(
        max_new_tokens=160,
        do_sample=False,
        eos_token_id=eos_id,
        pad_token_id=eos_id,
        repetition_penalty=1.05,
    )
    outputs = qa_pipe(prompt, **generation_kwargs)
    generated = outputs[0]["generated_text"]

    return clean_prediction(generated), final_chunks

def answer_without_rag(qa_pipe, question):
    """Answer `question` from the model's own knowledge, without retrieved context.

    Args:
        qa_pipe: Text-generation pipeline exposing a `tokenizer` attribute.
        question: The question text.

    Returns:
        The text generated after "[/INST]", cut after its first sentence.
    """
    # Short, model-knowledge-only prompt
    system_prompt = (
        "You are a precise telecom expert. Answer concisely and factually. "
        "If unsure, say 'I don't know'."
    )
    user = f"Question: {question}\nAnswer:"
    prompt = f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n{user} [/INST]"
    out = qa_pipe(
        prompt,
        max_new_tokens=160,
        do_sample=False,
        eos_token_id=qa_pipe.tokenizer.eos_token_id,
        pad_token_id=qa_pipe.tokenizer.eos_token_id,
        repetition_penalty=1.05,
    )[0]["generated_text"]
    # Extract everything after [/INST]
    pred = out.split("[/INST]")[-1].strip()
    # Trim at the first sentence end to reduce rambling. The punctuation mark
    # must be followed by whitespace or the end of the text, so values such as
    # "3.5 GHz" or "TS 38.331" are not cut at their inner dot.
    m = re.search(r"[.?!](?=\s|$)", pred)
    if m:
        pred = pred[: m.end()]
    return pred


# RUN EVAL
def eval_one_dataset(qa_pipe, dataset, rag_mode, model_tag, dataset_name, details_fp):
    """Evaluate one pipeline on one dataset and stream per-question details.

    Args:
        qa_pipe: Pipeline handed to answer_with_rag / answer_without_rag.
        dataset: Iterable of dicts with "question" and "answer" strings.
        rag_mode: "with_rag" uses retrieval; any other value answers without it.
        model_tag: Label stored in the "model" field of every output row.
        dataset_name: Label stored in the "dataset" field of every output row.
        details_fp: Writable text file; one JSON record is written per question.

    Returns:
        A summary dict with the labels, the mean of each metric (0.0 for an
        empty dataset) and "N", the number of examples.
    """
    metric_names = ("EM", "F1", "ROUGE_L", "BLEU1")
    agg = {name: [] for name in metric_names}
    use_rag = rag_mode == "with_rag"

    for ex in dataset:
        q = ex["question"].strip()
        ref = ex["answer"].strip()

        ctx = None
        try:
            if use_rag:
                pred, ctx = answer_with_rag(qa_pipe, q)
            else:
                pred = answer_without_rag(qa_pipe, q)
        except Exception as e:
            # A failed inference is scored as an empty prediction.
            print(f"⚠️ Inference error for Q: {q[:80]}... -> {e}")
            pred, ctx = "", None

        m = compute_metrics(pred, ref)
        for name, value in m.items():
            agg[name].append(value)

        # write details row
        record = {
            "model": model_tag,
            "rag_mode": rag_mode,
            "dataset": dataset_name,
            "question": q,
            "reference": ref,
            "prediction": pred,
            "metrics": m,
        }
        if ctx:
            record["contexts"] = [
                {"source": c.get("source", ""), "snippet": c["content"][:300]}
                for c in ctx
            ]
        details_fp.write(json.dumps(record) + "\n")

    # summarize
    summary = {
        "model": model_tag,
        "rag_mode": rag_mode,
        "dataset": dataset_name,
    }
    for name in metric_names:
        summary[name] = np.mean(agg[name]) if agg[name] else 0.0
    summary["N"] = len(dataset)
    return summary

def main():
    """Evaluate the pre-FL and post-FL models on every available dataset.

    Each dataset is run through both models, without and with retrieval.
    Per-question records go to DETAILS_JSONL and one summary row per
    combination goes to SUMMARY_CSV.

    Raises:
        FileNotFoundError: If none of the dataset files exists.
    """
    # Load data
    dataset_sources = (
        ("R15_16_RAG100", RAG_TEST_PATH),
        ("ClientA_holdout", CLIENT_A_HOLDOUT),
        ("ClientB_holdout", CLIENT_B_HOLDOUT),
    )
    datasets = [
        (name, load_jsonl(path))
        for name, path in dataset_sources
        if path.exists()
    ]
    if not datasets:
        raise FileNotFoundError("No datasets found. Check the dataset paths.")

    # Load both model pipelines
    print("🔹 Loading PRE-FL pipeline...")
    pre_pipe = load_pipeline(PRE_FL_MODEL_PATH)

    print("🔹 Loading POST-FL pipeline...")
    post_pipe = load_pipeline(POST_FL_MODEL_PATH)

    # Combinations evaluated for every dataset, in this order.
    runs = [
        ("no_rag", pre_pipe, "CentralFT"),
        ("with_rag", pre_pipe, "CentralFT"),
        ("no_rag", post_pipe, "FederatedFT"),
        ("with_rag", post_pipe, "FederatedFT"),
    ]

    # Run all combinations
    rows = []
    with open(DETAILS_JSONL, "w", encoding="utf-8") as fp:
        for ds_name, ds_data in datasets:
            for rag_mode, pipe, tag in runs:
                print(f"▶️  {tag} | {rag_mode} | {ds_name}  (N={len(ds_data)})")
                started = time.time()
                summary = eval_one_dataset(pipe, ds_data, rag_mode, tag, ds_name, fp)
                summary["secs"] = round(time.time() - started, 2)
                rows.append(summary)

    # Write summary CSV
    row_format = "{model},{rag_mode},{dataset},{N},{EM:.4f},{F1:.4f},{ROUGE_L:.4f},{BLEU1:.4f},{secs}\n"
    with open(SUMMARY_CSV, "w", encoding="utf-8") as f:
        f.write("model,rag_mode,dataset,N,EM,F1,ROUGE_L,BLEU1,secs\n")
        for row in rows:
            f.write(row_format.format(**row))

    print(f"✅ Done.\nSummary → {SUMMARY_CSV}\nDetails → {DETAILS_JSONL}")

if __name__ == "__main__":
    main()