# Notebook Summary — Federated LoRA + Central RAG Evaluation

This notebook sets up a federated fine-tuning experiment for a telecom QA model (LLaMA-2 + LoRA) and evaluates pre-federation vs post-federation adapters with and without central RAG. It includes environment setup, data preparation for two clients, Flower-based FL simulation, adapter materialization, and a standardized evaluation harness.

## What this notebook does

### Environment & Paths

- Sets cache dirs and CUDA settings, and verifies key packages/versions and GPU availability.
- Defines all project paths: base backbone, pre-FL adapters, output dirs, FAISS index & chunk store, client datasets, and results files.

### Sanity Checks

- Ensures required directories/files exist (base model, FAISS index, RAG chunks, checkpoints, results folders).

### Client Data Preparation

- Splits each client’s full packaged file into:
  - Train: packaged LLaMA-style `{"text": "<s>[INST]...[/INST] ans </s>"}` for FL.
  - Holdout: clean QA pairs `{"question": "...", "answer": "..."}` for evaluation.
- Supports two input formats (packaged or instruction-style), parsed with regexes.

### RAG Stack (Central)

- Loads the FAISS index and serialized chunks; retrieves top candidates via SentenceTransformer embeddings and CrossEncoder reranking.
- Slices candidates into high-overlap windows (lexical + TF-IDF), then builds a fusion prompt with source tags.
- Post-processes predictions into clean, span-like answers.

### Model/Pipeline Loader

- Loads a full model or PEFT adapters on top of the base backbone.
- Optional 4-bit quantization + offloading; ensures tokenizer padding and stable generation.
- Two answer modes: with RAG (retrieval + fusion prompt) and no RAG (direct QA prompt).

### (Optional) Client-side Preprocessing

- Sliding-window context reduction to keep answers in-span.
- Semantic similarity check (context vs answer) to retain only consistent rows.

### Federated Learning (Flower)

- Defines a LoRA client (NumPyClient) using HF Trainer (4-bit, grad checkpointing, cosine LR).
- Runs FedAvg over 2 clients for 3 rounds (sequential on 1 GPU), saving aggregated parameters each round.
- Custom strategy persists aggregated LoRA weights (.npz).

### Materialize Federated Adapters

- Rebuilds the PEFT model skeleton and writes merged post-FL adapters to disk (with tokenizer).

### Quick Sanity Check

- Asks a test question via RAG using pre-FL vs post-FL adapters.

### Evaluation Harness

- Implements lightweight EM, F1, ROUGE-L, BLEU-1.
- Evaluates CentralFT (pre-FL) vs FederatedFT (post-FL), both no_rag and with_rag, across:
  - Optional RAG test set
  - Client A holdout
  - Client B holdout
- Writes per-example details (DETAILS_JSONL) and a summary CSV with dataset-level averages and runtime.
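
The evaluation cells themselves are not part of this summary. As a minimal sketch of how lightweight EM and token-level F1 are commonly computed, the block below assumes SQuAD-style normalization (lowercasing, stripping punctuation and English articles). The names `normalize`, `exact_match`, and `token_f1` are hypothetical and may differ from the notebook's own helpers.

```python
import re
import string
from collections import Counter

def normalize(text: str) -> str:
    """Lowercase, drop punctuation and English articles, collapse whitespace."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in string.punctuation)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())

def exact_match(pred: str, gold: str) -> float:
    """1.0 when the normalized strings are identical, else 0.0."""
    return float(normalize(pred) == normalize(gold))

def token_f1(pred: str, gold: str) -> float:
    """Harmonic mean of token precision and recall on normalized text."""
    p, g = normalize(pred).split(), normalize(gold).split()
    if not p or not g:
        # Both empty counts as a match; exactly one empty counts as a miss.
        return float(p == g)
    overlap = sum((Counter(p) & Counter(g)).values())
    if overlap == 0:
        return 0.0
    precision, recall = overlap / len(p), overlap / len(g)
    return 2 * precision * recall / (precision + recall)
```

Dataset-level scores would then be plain means of these per-example values, which is one way to arrive at the averages written to the summary CSV.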

In [1]:
import os

paths = {
    "HF_HOME": "/mnt/data/.cache/huggingface",
    "TRANSFORMERS_CACHE": "/mnt/data/.cache/huggingface/transformers",
    "HF_DATASETS_CACHE": "/mnt/data/.cache/huggingface/datasets",
    "TORCH_HOME": "/mnt/data/.cache/torch",
    "PIP_CACHE_DIR": "/mnt/data/.cache/pip",
}
for k, v in paths.items():
    os.environ[k] = v
    os.makedirs(v, exist_ok=True)

print({k: os.environ[k] for k in paths})

os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

{'HF_HOME': '/mnt/data/.cache/huggingface', 'TRANSFORMERS_CACHE': '/mnt/data/.cache/huggingface/transformers', 'HF_DATASETS_CACHE': '/mnt/data/.cache/huggingface/datasets', 'TORCH_HOME': '/mnt/data/.cache/torch', 'PIP_CACHE_DIR': '/mnt/data/.cache/pip'}


In [2]:
pkgs = {
    "transformers": "transformers",
    "peft": "peft",
    "bitsandbytes": "bitsandbytes",
    "datasets": "datasets",
    "faiss-cpu": "faiss",
    "sentence-transformers": "sentence_transformers",
    "scikit-learn": "sklearn",
    "accelerate": "accelerate",
}

from importlib.metadata import version, PackageNotFoundError

for dist, mod in pkgs.items():
    try:
        __import__(mod)
        print(f"{dist:22} {version(dist)}")
    except Exception:
        print(f"{dist:22} (not installed)")



transformers           4.53.1
peft                   0.9.0
bitsandbytes           0.46.1
datasets               3.6.0
faiss-cpu              1.11.0.post1
sentence-transformers  5.0.0
scikit-learn           1.6.1
accelerate             1.8.1


In [3]:
import os, shutil, subprocess, torch
print("CUDA:", torch.cuda.is_available(), "| GPU:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "-")
print("HF_HOME:", os.environ.get("HF_HOME"))
print("Kernel path:", os.__file__)
print(subprocess.check_output("df -h | sed -n '1p; /\\/\\|\\/mnt\\/data/p'", shell=True).decode())
import flwr, ray
print("flwr:", flwr.__version__, "| ray:", ray.__version__)

CUDA: True | GPU: NVIDIA A10G
HF_HOME: /mnt/data/.cache/huggingface
Kernel path: /usr/lib64/python3.9/os.py
Filesystem        Size  Used Avail Use% Mounted on
devtmpfs          4.0M     0  4.0M   0% /dev
tmpfs              16G     0   16G   0% /dev/shm
tmpfs             6.2G  8.6M  6.2G   1% /run
/dev/nvme0n1p1     60G   60G  238M 100% /
tmpfs              16G  2.1M   16G   1% /tmp
/dev/nvme1n1      148G   62G   80G  44% /mnt/data
/dev/nvme0n1p128   10M  1.3M  8.7M  13% /boot/efi
tmpfs             3.1G  8.0K  3.1G   1% /run/user/0
tmpfs             3.1G  8.0K  3.1G   1% /run/user/1000

flwr: 1.8.0 | ray: 2.6.3


In [4]:
BASE_DIR = "/mnt/data/federated_qa_jupyter"  
PRE_FL_MODEL_PATH = "/mnt/data/llama2_qa_lora_output5/final"   
POST_FL_MODEL_PATH = f"{BASE_DIR}/model/federated_merged_model"  

# Central RAG
FAISS_INDEX_PATH = "/mnt/data/RAG/3gpp_index.faiss"
CHUNKS_PKL_PATH  = "/mnt/data/RAG/3gpp_chunks.pkl"

# FULL packaged files per client
CLIENT1_FULL_PACKAGED = f"{BASE_DIR}/data/client1/client1_qa.jsonl"
CLIENT2_FULL_PACKAGED = f"{BASE_DIR}/data/client2/client2_qa.jsonl"

# Train outputs
CLIENT1_TRAIN = f"{BASE_DIR}/data/client1/client1_train.jsonl"
CLIENT2_TRAIN = f"{BASE_DIR}/data/client2/client2_train.jsonl"

# Holdout outputs (QA format)
CLIENT_A_HOLDOUT = f"{BASE_DIR}/data/test/clientA_holdout.jsonl"
CLIENT_B_HOLDOUT = f"{BASE_DIR}/data/test/clientB_holdout.jsonl"

# Optional third test set
RAG_TEST_PATH    = f"{BASE_DIR}/data/test/federated_test_set.jsonl"

# Results
SUMMARY_CSV   = f"{BASE_DIR}/results/federated_eval_summary.csv"
DETAILS_JSONL = f"{BASE_DIR}/results/federated_eval_details.jsonl"

# Flower checkpoints/logs
CKPT_CLIENT1 = f"{BASE_DIR}/checkpoints/client1"
CKPT_CLIENT2 = f"{BASE_DIR}/checkpoints/client2"
CHECKPOINTS_DIR = f"{BASE_DIR}/checkpoints"

# Base backbone
BASE_BACKBONE = "/mnt/data/llama2-model"

# RAG retrieval knobs
TOP_K = 6
MAX_WINDOWS = 5
MAX_NEW_TOKENS = 160

# sanity checks
import os, sys
def _must_dir(p, what="dir"):
    if not os.path.isdir(p):
        raise FileNotFoundError(f"Please create {what}: {p}")

def _must_file(p, what="file"):
    # isfile (not exists) so a directory at this path is also rejected
    if not os.path.isfile(p):
        raise FileNotFoundError(f"Missing {what}: {p}")

_must_dir(BASE_DIR, "BASE_DIR")
_must_file(FAISS_INDEX_PATH, "FAISS_INDEX_PATH")
_must_file(CHUNKS_PKL_PATH, "CHUNKS_PKL_PATH")
_must_dir(CHECKPOINTS_DIR, "CHECKPOINTS_DIR")
_must_dir(BASE_BACKBONE, "BASE_BACKBONE")
_must_dir(PRE_FL_MODEL_PATH, "PRE_FL_MODEL_PATH (adapters dir)")


for p in [CLIENT1_TRAIN, CLIENT2_TRAIN, CLIENT_A_HOLDOUT, CLIENT_B_HOLDOUT,
          SUMMARY_CSV, DETAILS_JSONL, POST_FL_MODEL_PATH]:
    parent = os.path.dirname(p)
    if parent and not os.path.isdir(parent):
        raise FileNotFoundError(f"Create parent folder first: {parent}")

In [5]:
#  Split each client's full file into:
#     - train (packaged {"text": "<s>[INST] ... [/INST] ans </s>"}) -> CLIENT1_TRAIN / CLIENT2_TRAIN
#     - holdout (QA {"question": "...", "answer": "..."})          -> CLIENT_A_HOLDOUT / CLIENT_B_HOLDOUT

import json, random, os, re

PACK_SYS = "You are a precise assistant. Extract the exact answer span from the context."

def _read_jsonl(path):
    data = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            if line.strip():
                data.append(json.loads(line))
    return data

def _write_jsonl(path, rows):
    with open(path, "w", encoding="utf-8") as f:
        for r in rows:
            f.write(json.dumps(r) + "\n")

# parsers

def _parse_from_packaged_text(txt: str):
    """Packaged: '<s>[INST] ... Context: ... Question: ... [/INST] ANSWER </s>'"""
    q = a = c = None
    try:
        prompt_part, answer_part = txt.split("[/INST]", 1)
        a = answer_part.replace("</s>", "").strip()
        # context
        m_ctx = re.search(r"Context:\s*(.*)", prompt_part, flags=re.IGNORECASE|re.DOTALL)
        if m_ctx:
            # stop before 'Question:' if present
            c = m_ctx.group(1)
            c = re.split(r"\n\s*Question:", c, flags=re.IGNORECASE, maxsplit=1)[0].strip()
        # question
        m_q = re.search(r"Question:\s*(.*)", prompt_part, flags=re.IGNORECASE|re.DOTALL)
        if m_q:
            q = m_q.group(1).strip()
            q = re.split(r"\n\s*(?:Answer:|</s>)", q, flags=re.IGNORECASE, maxsplit=1)[0].strip()
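    except (ValueError, AttributeError):
        # no "[/INST]" separator, or txt is not a string: return whatever was parsed
        pass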
    return c, q, a

def _parse_from_instruct(entry: dict):
    """
    Instruct-style:
    {"instruction": "...", "input": "### Task: ...\n### Context:\n...\n\n### Question:\nquestion: ...\n\n### Answer:", "output": "..."}
    """
    if not isinstance(entry.get("input", None), str) or not isinstance(entry.get("output", None), str):
        return None, None, None
    inp = entry["input"]
    # context
    m_ctx = re.search(r"###\s*Context:\s*(.*?)\n\s*###\s*Question:", inp, flags=re.IGNORECASE|re.DOTALL)
    # question
    m_q = re.search(r"###\s*Question:\s*(.*?)(?:\n\s*###\s*Answer:|\Z)", inp, flags=re.IGNORECASE|re.DOTALL)
    c = m_ctx.group(1).strip() if m_ctx else None
    q = m_q.group(1).strip() if m_q else None
    if q:
        q = re.sub(r"^\s*question:\s*", "", q, flags=re.IGNORECASE).strip()
    a = entry["output"].strip()
    return c, q, a

def _make_packaged_text(context: str, question: str, answer: str) -> str:
    return (
        f"<s>[INST] <<SYS>>\n{PACK_SYS}\n<</SYS>>\n\n"
        f"Context: {context}\n\n"
        f"Question: {question}\nAnswer: [/INST] {answer}</s>"
    )

# main splitter

def split_client_file(
    full_path: str,
    train_out_path: str,
    holdout_out_path: str,
    test_size: float = 0.2,
    seed: int = 42
):
    rows = _read_jsonl(full_path)
    if not rows:
        raise ValueError(f"No lines found in {full_path}")

    packaged_texts = []   # [{"text": "..."}]
    qa_pairs = []         # [{"question": "...", "answer": "..."}] (same index as packaged_texts)

    for r in rows:
        txt = None
        c = q = a = None

        if "text" in r:  # already packaged
            txt = r["text"]
            c, q, a = _parse_from_packaged_text(txt)
            # if packaged text didn’t parse, we’ll still use it for training; skip QA
        elif {"instruction","input","output"} <= set(r.keys()):
            c, q, a = _parse_from_instruct(r)
            if c and q and a:
                txt = _make_packaged_text(c, q, a)

        if txt:
            packaged_texts.append({"text": txt})
            qa_pairs.append({"question": q, "answer": a} if (q and a) else None)

    if not packaged_texts:
        raise ValueError(f"No usable examples in {full_path}")

    # shuffle/split indices
    idx = list(range(len(packaged_texts)))
    random.Random(seed).shuffle(idx)
    n_hold = max(1, int(round(len(idx) * test_size)))
    hold_idx = set(idx[:n_hold])
    train_idx = [i for i in idx if i not in hold_idx]

    # write train
    _write_jsonl(train_out_path, [packaged_texts[i] for i in train_idx])

    # write holdout
    holdout = [qa_pairs[i] for i in sorted(hold_idx) if qa_pairs[i] is not None]
    if not holdout:
        raise ValueError("Holdout split produced 0 parsable QA pairs; check input formatting.")
    _write_jsonl(holdout_out_path, holdout)

    print(f"✅ Split {full_path}")
    print(f"   Train:   {len(train_idx)} → {train_out_path}")
    print(f"   Holdout: {len(holdout)}  → {holdout_out_path} (from {n_hold} selected)")

#  Run splits for both clients
split_client_file(
    CLIENT1_FULL_PACKAGED,  
    CLIENT1_TRAIN,          # packaged output for FL training
    CLIENT_A_HOLDOUT,       # QA output for eval
    test_size=0.2, seed=42
)

split_client_file(
    CLIENT2_FULL_PACKAGED,
    CLIENT2_TRAIN,
    CLIENT_B_HOLDOUT,
    test_size=0.2, seed=42
)

✅ Split /mnt/data/federated_qa_jupyter/data/client1/client1_qa.jsonl
   Train:   800 → /mnt/data/federated_qa_jupyter/data/client1/client1_train.jsonl
   Holdout: 200  → /mnt/data/federated_qa_jupyter/data/test/clientA_holdout.jsonl (from 200 selected)
✅ Split /mnt/data/federated_qa_jupyter/data/client2/client2_qa.jsonl
   Train:   800 → /mnt/data/federated_qa_jupyter/data/client2/client2_train.jsonl
   Holdout: 200  → /mnt/data/federated_qa_jupyter/data/test/clientB_holdout.jsonl (from 200 selected)


In [6]:
import json, re, time, math, os
from pathlib import Path
from collections import defaultdict

import numpy as np
import torch

# Transformers / datasets
from datasets import Dataset
from transformers import (
    AutoTokenizer, AutoModelForCausalLM, pipeline,
    TrainingArguments, Trainer, BitsAndBytesConfig,
    DataCollatorForLanguageModeling
)

# PEFT
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, PeftModel

# RAG stack
import faiss, pickle
from sentence_transformers import SentenceTransformer, CrossEncoder, util
from sklearn.feature_extraction.text import TfidfVectorizer
from difflib import SequenceMatcher

# NLTK stopwords (once)
import nltk
try:
    from nltk.corpus import stopwords
    _ = stopwords.words("english")
except LookupError:
    nltk.download("stopwords")
    from nltk.corpus import stopwords

In [15]:
#  Load FAISS + chunks once
index = faiss.read_index(FAISS_INDEX_PATH)
with open(CHUNKS_PKL_PATH, "rb") as f:
    DOCUMENTS = pickle.load(f)

EMBED_MODEL = SentenceTransformer("all-MiniLM-L6-v2")
RERANKER    = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
STOPWORDS   = set(stopwords.words("english"))

def _normalize_text(t):
    return re.sub(r'\W+', ' ', t.lower())

def is_similar(a, b, th=0.75):
    return SequenceMatcher(None, a, b).ratio() >= th

def retrieve_with_rerank(query, top_k=5):
    qv = EMBED_MODEL.encode(query, normalize_embeddings=True)
    qv = np.array(qv, dtype="float32")[None, :]
    D, I = index.search(qv, top_k * 2)
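    # FAISS pads missing results with -1; skip them so DOCUMENTS[-1] is never picked by accident
    cand = [DOCUMENTS[i] for i in I[0] if i >= 0]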
    pairs = [(query, c["content"]) for c in cand]
    scores = RERANKER.predict(pairs)
    reranked = sorted(zip(scores, cand), key=lambda x: x[0], reverse=True)[:top_k]
    return [c for _, c in reranked]

def truncate_and_filter_chunks(chunks, query, window_size=150, stride=75, max_windows=5):
    def lexical_overlap(q, span):
        q_tok = set(_normalize_text(q).split()) - STOPWORDS
        s_tok = set(_normalize_text(span).split()) - STOPWORDS
        return len(q_tok & s_tok) / (len(q_tok | s_tok) + 1e-5)

    def tfidf_score(q, span):
        vec = TfidfVectorizer().fit([q, span])
        X = vec.transform([q, span])
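        # .toarray() works across SciPy versions; the `.A` shortcut is not available on all sparse types
        return (X[0] @ X[1].T).toarray()[0, 0]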

    scored = []
    for ch in chunks:
        words = ch["content"].split()
        for i in range(0, len(words), stride):
            w = words[i:i+window_size]
            if len(w) < 30: 
                continue
            span = " ".join(w)
            s = 0.6 * lexical_overlap(query, span) + 0.4 * tfidf_score(query, span)
            scored.append({"content": span, "score": s, "source": ch.get("source","unknown")})
    return sorted(scored, key=lambda x: x["score"], reverse=True)[:max_windows]

def build_fusion_prompt(context_chunks, question):
    SYS = ("You are a precise assistant. Extract the exact answer span from the context. "
           "Do not paraphrase or add anything. Copy the exact text from context.")
    ctx_lines = []
    for ch in context_chunks:
        src = ch.get("source", "unknown").split("/")[-1]
        ctx_lines.append(f"[Source: {src}]\n-----\n{ch['content'].strip()}")
    fused = "\n\n".join(ctx_lines)
    user = (f"Context:\n{fused}\n\n"
            f"Question: {question}\n"
            f"Answer from the context only:")
    return f"<s>[INST] <<SYS>>\n{SYS}\n<</SYS>>\n\n{user} [/INST]"

def clean_prediction(raw_text):
    ans = raw_text.split("[/INST]")[-1].strip()
    ans = re.sub(r"[^\w\s\-.,:/()]", "", ans)
    tok = ans.split()
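    # Collapse an answer that starts by repeating its own prefix ("x y x y" -> "x y").
    # The +1 lets i reach len(tok)//2, so an exact doubling is detected too.
    for i in range(1, len(tok)//2 + 1):
        if tok[:i] == tok[i:2*i]:
            ans = " ".join(tok[:i])
            break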
    m = re.search(r"[.?!]", ans)
    if m: ans = ans[:m.end()]
    return ans.strip()

def _is_adapter_dir(path: str) -> bool:
    return os.path.isfile(os.path.join(path, "adapter_config.json"))

def load_pipeline(
    model_path: str,
    base_backbone: str = BASE_BACKBONE,
    use_4bit: bool = True,
    offload_dir: str = "/mnt/data/offload",
):
    os.makedirs(offload_dir, exist_ok=True)

    # tokenizer
    tok_src = model_path if os.path.exists(os.path.join(model_path, "tokenizer.json")) else base_backbone
    tok = AutoTokenizer.from_pretrained(tok_src)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    quant = None
    if use_4bit:
        quant = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
            bnb_4bit_quant_type="nf4",
        )

    if _is_adapter_dir(model_path):
        # Load base, then attach adapters
        base = AutoModelForCausalLM.from_pretrained(
            base_backbone,
            torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
            quantization_config=quant,
            device_map="auto",
            low_cpu_mem_usage=True,
            offload_folder=offload_dir,   # HF offload for base
        )
        model = PeftModel.from_pretrained(
            base,
            model_path,
            device_map="auto",
            offload_dir=offload_dir,      # Accelerate offload for PEFT
        )
    else:
        # Full fine-tuned model dir
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
            quantization_config=quant,
            device_map="auto",
            low_cpu_mem_usage=True,
            offload_folder=offload_dir,
        )

    pipe = pipeline("text-generation", model=model, tokenizer=tok)  # no device arg with Accelerate
    tok.padding_side = "left"  # good for decoder-only LMs
    pipe.model.config.pad_token_id = tok.pad_token_id or tok.eos_token_id
    return pipe

def answer_with_rag(qa_pipe, question, top_k=TOP_K, max_windows=MAX_WINDOWS):
    initial = retrieve_with_rerank(question, top_k=top_k)
    final   = truncate_and_filter_chunks(initial, question, max_windows=max_windows)
    prompt  = build_fusion_prompt(final, question)
    out = qa_pipe(
        prompt,
        max_new_tokens=MAX_NEW_TOKENS, do_sample=False,
        eos_token_id=qa_pipe.tokenizer.eos_token_id,
        pad_token_id=qa_pipe.tokenizer.eos_token_id,
        repetition_penalty=1.05,
    )[0]["generated_text"]
    pred = clean_prediction(out)
    return pred, final

def answer_without_rag(qa_pipe, question):
    # named sys_prompt to avoid shadowing the imported `sys` module
    sys_prompt = "You are a precise telecom expert. Be concise; if unsure, say 'I don't know'."
    user = f"Question: {question}\nAnswer:"
    prompt = f"<s>[INST] <<SYS>>\n{sys_prompt}\n<</SYS>>\n\n{user} [/INST]"
    out = qa_pipe(
        prompt,
        max_new_tokens=MAX_NEW_TOKENS, do_sample=False,
        eos_token_id=qa_pipe.tokenizer.eos_token_id,
        pad_token_id=qa_pipe.tokenizer.eos_token_id,
        repetition_penalty=1.05,
    )[0]["generated_text"]
    pred = out.split("[/INST]")[-1].strip()
    m = re.search(r"[.?!]", pred)
    if m: pred = pred[:m.end()]
    return pred

In [8]:
MAX_TOKEN_LENGTH = 2048
SIM_THRESHOLD = 0.6

def select_relevant_chunks(context: str, answer: str, window_size=150, stride=100):
    words = context.split()
    for start in range(0, len(words), stride):
        end = start + window_size
        chunk = " ".join(words[start:end])
        if answer in chunk:
            return chunk
        if end >= len(words):
            break
    return None

def _build_prompt(context: str, question: str, answer: str) -> str:
    SYS = ("You are a precise assistant. Extract the exact answer span from the context. "
           "Do not paraphrase or add info. The answer must appear exactly in the context.")
    user = f"Context: {context}\n\nQuestion: {question}\nAnswer:"
    return f"<s>[INST] <<SYS>>\n{SYS}\n<</SYS>>\n\n{user} [/INST] {answer}</s>"

def _extract_context_and_answer(text):
    try:
        prompt_part, answer = text.split("[/INST]", 1)
        answer = answer.strip().replace("</s>","")
        lines = prompt_part.splitlines()
        ctx_lines, inside_ctx = [], False
        for l in lines:
            s = l.strip()
            if s.startswith("Context:"):
                inside_ctx = True
                ctx_lines.append(s.replace("Context:","").strip())
                continue
            if s.startswith("Question:"):
                break
            if inside_ctx:
                ctx_lines.append(s)
        return " ".join(ctx_lines), answer
    except (ValueError, AttributeError):
        # no "[/INST]" separator, or text is not a string
        return "", ""

def _clean_answer_entry(entry):
    try:
        text = entry.get("text","")
        pp, ap = text.split("[/INST]", 1)
        ca = ap.strip().replace("</s>","")
        return f"{pp}[/INST] {ca}</s>"
    except (ValueError, AttributeError):
        # no "[/INST]" separator, or entry["text"] is not a string
        return None

def preprocess_client_dataset(input_path: str, output_path: str, model_path: str):
    print(f"🔄 Preprocessing: {input_path}")
    tokenizer = AutoTokenizer.from_pretrained(model_path) 
    encoder   = SentenceTransformer("all-MiniLM-L6-v2")

    reformat, total, filtered = [], 0, 0
    with open(input_path, "r", encoding="utf-8") as f:
        lines = [l for l in f if l.strip()]
    for line in lines:
        total += 1
        try:
            entry = json.loads(line)
            original_text = entry["text"]
            prompt_part, answer = original_text.split("[/INST]", 1)
            answer = answer.strip().replace("</s>","")
            if not answer: 
                continue

            lines2 = prompt_part.splitlines()
            ctx_lines, question = [], ""
            inside_ctx, inside_q = False, False
            for l in lines2:
                s = l.strip()
                if s.startswith("Context:"):
                    inside_ctx, inside_q = True, False
                    ctx_lines.append(s.replace("Context:","").strip()); continue
                elif s.startswith("Question:"):
                    inside_q, inside_ctx = True, False
                    question = s.replace("Question:","").strip(); continue
                if inside_ctx:
                    ctx_lines.append(l)
                elif inside_q and not question:
                    question = s
            full_ctx = "\n".join(ctx_lines).strip()
            if not full_ctx or not question:
                continue

            tmp_prompt = _build_prompt(full_ctx, question, answer)
            input_ids = tokenizer(tmp_prompt)["input_ids"]
            final_ctx = full_ctx
            if len(input_ids) > MAX_TOKEN_LENGTH:
                short_ctx = select_relevant_chunks(full_ctx, answer)
                if short_ctx and answer in short_ctx:
                    final_ctx = short_ctx
                else:
                    filtered += 1; continue

            final_prompt = _build_prompt(final_ctx, question, answer)
            ctx, ans = _extract_context_and_answer(final_prompt)
            if not ctx or not ans: 
                continue

            sim = util.cos_sim(
                encoder.encode(ctx, convert_to_tensor=True),
                encoder.encode(ans, convert_to_tensor=True)
            ).item()
            if sim >= SIM_THRESHOLD:
                cleaned = _clean_answer_entry({"text": final_prompt})
                if cleaned:
                    reformat.append({"text": cleaned})
        except Exception:
            # skip malformed rows (bad JSON, missing "text", no "[/INST]" separator)
            continue

    with open(output_path, "w", encoding="utf-8") as f:
        for e in reformat:
            f.write(json.dumps(e) + "\n")
    print(f"✅ Saved cleaned data → {output_path} | kept {len(reformat)}/{total}, filtered {filtered}")

In [9]:
import sys, torch
print("Python:", sys.version)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))

import flwr as fl, ray
print("flwr:", fl.__version__, "| ray:", ray.__version__)
ray.init(ignore_reinit_error=True, include_dashboard=False)
ray.shutdown()

Python: 3.9.22 (main, Apr 29 2025, 00:00:00) 
[GCC 11.5.0 20240719 (Red Hat 11.5.0-5)]
CUDA available: True
GPU: NVIDIA A10G
flwr: 1.8.0 | ray: 2.6.3


2025-08-10 16:15:29,683	INFO worker.py:1621 -- Started a local Ray instance.


In [10]:
import flwr as fl

def get_trainer(client_id: str, dataset_path: str, adapter_init_path: str, output_dir: str, epochs: int = 1):
    # load data
    with open(dataset_path, "r", encoding="utf-8") as f:
        data = [json.loads(l) for l in f if l.strip()]
    ds = Dataset.from_list(data).shuffle(seed=42)
    split = ds.train_test_split(test_size=0.10, seed=42)
    val_test = split["test"].train_test_split(test_size=0.5, seed=42)
    train_ds = split["train"]; eval_ds = val_test["train"]

    # tokenizer
    tok_src = adapter_init_path if os.path.exists(os.path.join(adapter_init_path, "tokenizer.json")) else BASE_BACKBONE
    tokenizer = AutoTokenizer.from_pretrained(tok_src)
    tokenizer.pad_token = tokenizer.eos_token

    def tok(ex): return tokenizer(ex["text"], truncation=True, max_length=2048)
    train_ds = train_ds.map(tok, batched=True, remove_columns=["text"])
    eval_ds  = eval_ds.map(tok,  batched=True, remove_columns=["text"])

    collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False, pad_to_multiple_of=64)

    bnb = BitsAndBytesConfig(
        load_in_4bit=True, bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_quant_type="nf4"
    )
    base = AutoModelForCausalLM.from_pretrained(
        BASE_BACKBONE, device_map="auto", quantization_config=bnb, torch_dtype=torch.bfloat16
    )
    base = prepare_model_for_kbit_training(base)
    base.gradient_checkpointing_enable()
    base.config.use_cache = False

    model = PeftModel.from_pretrained(base, adapter_init_path, is_trainable=True)

    args = TrainingArguments(
        output_dir=output_dir, num_train_epochs=epochs,
        per_device_train_batch_size=2, per_device_eval_batch_size=2,
        gradient_accumulation_steps=8, eval_strategy="epoch", save_strategy="no",
        learning_rate=1e-5, lr_scheduler_type="cosine",
        logging_dir=f"{output_dir}/logs", logging_steps=25,
        bf16=True, report_to="none", remove_unused_columns=False,
        dataloader_num_workers=2, group_by_length=True, optim="paged_adamw_32bit",
        max_grad_norm=1.0, warmup_ratio=0.03,
    )

    trainer = Trainer(
        model=model, args=args,
        train_dataset=train_ds, eval_dataset=eval_ds,
        processing_class=tokenizer,  # replaces the deprecated `tokenizer=` argument (transformers >= 4.46)
        data_collator=collator
    )
    return trainer, model, tokenizer

class LoraClient(fl.client.NumPyClient):
    def __init__(self, client_id, dataset_path, adapter_init_path, output_dir):
        self.client_id = client_id
        self.dataset_path = dataset_path
        self.adapter_init_path = adapter_init_path
        self.output_dir = output_dir
        self.trainer, self.model, self.tokenizer = get_trainer(
            client_id, dataset_path, adapter_init_path, output_dir
        )

    def get_parameters(self, config=None):
        return [p.detach().cpu().numpy() for _, p in self.model.named_parameters() if p.requires_grad]

    def set_parameters(self, parameters):
        params = dict(self.model.named_parameters())
        i = 0
        for name, p in params.items():
            if p.requires_grad:
                arr = torch.tensor(parameters[i])
                p.data = arr.to(p.device, dtype=p.dtype)
                i += 1

    def fit(self, parameters, config=None):
        self.set_parameters(parameters)
        self.trainer.train()
        return self.get_parameters(), len(self.trainer.train_dataset), {}

    def evaluate(self, parameters, config=None):
        self.set_parameters(parameters)
        metrics = self.trainer.evaluate()
        loss = float(metrics.get("eval_loss", 0.0))
        n = len(self.trainer.eval_dataset)
        return loss, n, {"eval_runtime": float(metrics.get("eval_runtime", 0.0))}

# caching to reuse clients across rounds/messages
CLIENT_CACHE = {}

def get_client_fn(which: int):
    def client_fn(cid: str):
        key = f"{which}-{cid}"
        if key in CLIENT_CACHE:
            return CLIENT_CACHE[key]
        ds  = CLIENT1_TRAIN if which == 1 else CLIENT2_TRAIN
        out = CKPT_CLIENT1  if which == 1 else CKPT_CLIENT2
        print(f"🚀 Launching client{which}")
        cli = LoraClient(
            client_id=f"client{which}",
            dataset_path=ds,
            adapter_init_path=PRE_FL_MODEL_PATH,
            output_dir=out,
        ).to_client()  # silence deprecation
        CLIENT_CACHE[key] = cli
        return cli
    return client_fn
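
The strategy in the next cell builds on Flower's FedAvg, which combines client updates as an example-weighted average of each parameter array, and persists the result with `np.savez`. The block below is a minimal NumPy sketch of that arithmetic and of reading the saved arrays back in their original order. `fedavg` and `load_npz_in_order` are hypothetical helper names, and the toy arrays stand in for LoRA tensors.

```python
import io
import numpy as np

def fedavg(client_params, num_examples):
    """Example-weighted average of per-client parameter lists (one list per client)."""
    total = float(sum(num_examples))
    return [
        sum(params[i] * (n / total) for params, n in zip(client_params, num_examples))
        for i in range(len(client_params[0]))
    ]

def load_npz_in_order(file):
    """np.savez(file, *arrays) stores keys arr_0, arr_1, ...; restore numeric order."""
    with np.load(file) as data:
        keys = sorted(data.files, key=lambda k: int(k.split("_")[1]))
        return [data[k] for k in keys]

# Two toy clients with equal example counts, so the result is a plain mean.
client_a = [np.array([1.0, 2.0]), np.array([[0.0]])]
client_b = [np.array([3.0, 6.0]), np.array([[4.0]])]
avg = fedavg([client_a, client_b], num_examples=[720, 720])

# Round-trip through an in-memory .npz, mirroring a positional np.savez call.
buf = io.BytesIO()
np.savez(buf, *avg)
buf.seek(0)
restored = load_npz_in_order(buf)
```

Sorting the keys numerically matters once there are more than ten arrays, because a plain string sort would place `arr_10` before `arr_2` and misalign the weights with the model's trainable parameters.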

In [11]:
from flwr.common import parameters_to_ndarrays
import numpy as np

if not os.path.isdir(CHECKPOINTS_DIR):
    raise FileNotFoundError(f"Please create {CHECKPOINTS_DIR} before running FL")

class SaveFedAvg(fl.server.strategy.FedAvg):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.latest_parameters = None  # list of ndarrays (LoRA trainables)

    def aggregate_fit(self, rnd, results, failures):
        aggregated, metrics = super().aggregate_fit(rnd, results, failures)
        if aggregated is not None:
            nds = parameters_to_ndarrays(aggregated)
            self.latest_parameters = nds
            np.savez(f"{CHECKPOINTS_DIR}/aggregated_round{rnd}.npz", *nds)
            print(f"💾 Saved aggregated parameters (round {rnd}) into {CHECKPOINTS_DIR}")
        return aggregated, metrics

NUM_CLIENTS = 2
strategy = SaveFedAvg(
    fraction_fit=1.0,
    min_fit_clients=NUM_CLIENTS,
    min_available_clients=NUM_CLIENTS,
    on_fit_config_fn=lambda rnd: {"round": rnd}
)

hist = fl.simulation.start_simulation(
    # The simulation passes cid "0" or "1"; map them to client1 and client2 respectively.
    client_fn=lambda cid: get_client_fn(1 if cid == "0" else 2)(cid),
    num_clients=NUM_CLIENTS,
    config=fl.server.ServerConfig(num_rounds=3),
    strategy=strategy,
    client_resources={"num_cpus": 2, "num_gpus": 1.0},   # each client claims the whole GPU -> sequential on 1 GPU
)

INFO :      Starting Flower simulation, config: num_rounds=3, no round_timeout
2025-08-10 16:15:37,572 INFO worker.py:1621 -- Started a local Ray instance.
INFO :      Flower VCE: Ray initialized with resources: {'CPU': 8.0, 'memory': 17116424603.0, 'node:__internal_head__': 1.0, 'object_store_memory': 8558212300.0, 'accelerator_type:A10G': 1.0, 'node:172.31.36.122': 1.0, 'GPU': 1.0}
INFO :      Optimize your simulation with Flower VCE: https://flower.ai/docs/framework/how-to-run-simulations.html
INFO :      Flower VCE: Resources for each Virtual Client: {'num_cpus': 2, 'num_gpus': 1.0}
INFO :      Flower VCE: Creating VirtualClientEngineActorPool with 1 actors
INFO :      [INIT]
INFO :      Requesting initial parameters from one random client
(ClientAppActor pid=10689) 🚀 Launching client2
Map: 100%|██████████| 720/720 [00:00<00:00, 5543.12 examples/s]
Map: 100%|██████████| 40/40 [00:00<00:00, 3996.95 examples/s]
Loading checkpoint shards: 100%|██████████| 2/2 [00:11<00:00,  5.88s/it]
(ClientAppActor pid=10689) No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.
INFO :      Received initial parameters from one random client
INFO :      Evaluating initial global parameters
INFO :      [ROUND 1]
INFO :      configure_fit: strategy sampled 2 clients (out of 2)
(ClientAppActor pid=10689) 🚀 Launching client2
Map: 100%|██████████| 720/720 [00:00<00:00, 5529.54 examples/s]
Map: 100%|██████████| 40/40 [00:00<00:00, 3745.92 examples/s]
Loading checkpoint shards: 100%|██████████| 2/2 [00:11<00:00,  5.86s/it]
(ClientAppActor pid=10689) {'loss': 1.1551, 'grad_norm': 1.42552649974823, 'learning_rate': 4.817389884711706e-06, 'epoch': 0.56}
100%|██████████| 45/45 [04:56<00:00,  6.37s/it]
(ClientAppActor pid=10689) {'eval_loss': 0.9707896113395691, 'eval_runtime': 6.5965, 'eval_samples_per_second': 6.064, 'eval_steps_per_second': 3.032, 'epoch': 1.0}
(ClientAppActor pid=10689) {'train_runtime': 303.4904, 'train_samples_per_second': 2.372, 'train_steps_per_second': 0.148, 'train_loss': 1.0397278679741753, 'epoch': 1.0}
(ClientAppActor pid=10689) 🚀 Launching client1
Map: 100%|██████████| 720/720 [00:00<00:00, 5604.22 examples/s]
Map: 100%|██████████| 40/40 [00:00<00:00, 3850.02 examples/s]
Loading checkpoint shards: 100%|██████████| 2/2 [00:11<00:00,  5.82s/it]
(ClientAppActor pid=10689) {'loss': 1.1495, 'grad_norm': 1.427578091621399, 'learning_rate': 4.817389884711706e-06, 'epoch': 0.56}
100%|██████████| 45/45 [04:56<00:00,  6.45s/it]
(ClientAppActor pid=10689) {'eval_loss': 0.9191396832466125, 'eval_runtime': 6.5895, 'eval_samples_per_second': 6.07, 'eval_steps_per_second': 3.035, 'epoch': 1.0}
(ClientAppActor pid=10689) {'train_runtime': 303.0594, 'train_samples_per_second': 2.376, 'train_steps_per_second': 0.148, 'train_loss': 1.0343484666612413, 'epoch': 1.0}
INFO :      aggregate_fit: received 2 results and 0 failures
INFO :      configure_evaluate: strategy sampled 2 clients (out of 2)
💾 Saved aggregated parameters (round 1) into /mnt/data/federated_qa_jupyter/checkpoints
(ClientAppActor pid=10689) 🚀 Launching client1
Map: 100%|██████████| 720/720 [00:00<00:00, 5240.19 examples/s]
Map: 100%|██████████| 40/40 [00:00<00:00, 3844.63 examples/s]
Loading checkpoint shards: 100%|██████████| 2/2 [00:11<00:00,  5.84s/it]
 30%|███       | 6/20 [00:01<00:04,  3.36it/s]
(ClientAppActor pid=10689) 🚀 Launching client2
Map: 100%|██████████| 720/720 [00:00<00:00, 5230.43 examples/s]
Map: 100%|██████████| 40/40 [00:00<00:00, 3780.95 examples/s]
Loading checkpoint shards: 100%|██████████| 2/2 [00:11<00:00,  5.89s/it]
 25%|██▌       | 5/20 [00:01<00:04,  3.68it/s]
(ClientAppActor pid=10689) 🚀 Launching client2
Map: 100%|██████████| 720/720 [00:00<00:00, 5204.83 examples/s]
Map: 100%|██████████| 40/40 [00:00<00:00, 3876.35 examples/s]
Loading checkpoint shards: 100%|██████████| 2/2 [00:11<00:00,  5.80s/it]
(ClientAppActor pid=10689) {'loss': 0.92, 'grad_norm': 1.202350378036499, 'learning_rate': 4.817389884711706e-06, 'epoch': 0.56}
100%|██████████| 45/45 [04:56<00:00,  6.37s/it]
(ClientAppActor pid=10689) {'eval_loss': 0.9362310171127319, 'eval_runtime': 6.603, 'eval_samples_per_second': 6.058, 'eval_steps_per_second': 3.029, 'epoch': 1.0}
(ClientAppActor pid=10689) {'train_runtime': 303.1975, 'train_samples_per_second': 2.375, 'train_steps_per_second': 0.148, 'train_loss': 0.8896586524115668, 'epoch': 1.0}
(ClientAppActor pid=10689) 🚀 Launching client1
Map: 100%|██████████| 720/720 [00:00<00:00, 5034.64 examples/s]
Map: 100%|██████████| 40/40 [00:00<00:00, 3765.42 examples/s]
Loading checkpoint shards: 100%|██████████| 2/2 [00:11<00:00,  5.78s/it]
(ClientAppActor pid=10689) {'loss': 0.9161, 'grad_norm': 1.4552488327026367, 'learning_rate': 4.817389884711706e-06, 'epoch': 0.56}
100%|██████████| 45/45 [04:56<00:00,  6.44s/it]
(ClientAppActor pid=10689) {'eval_loss': 0.8832899332046509, 'eval_runtime': 6.5877, 'eval_samples_per_second': 6.072, 'eval_steps_per_second': 3.036, 'epoch': 1.0}
(ClientAppActor pid=10689) {'train_runtime': 302.8943, 'train_samples_per_second': 2.377, 'train_steps_per_second': 0.149, 'train_loss': 0.8851483662923177, 'epoch': 1.0}
INFO :      aggregate_fit: received 2 results and 0 failures
INFO :      configure_evaluate: strategy sampled 2 clients (out of 2)
💾 Saved aggregated parameters (round 2) into /mnt/data/federated_qa_jupyter/checkpoints
(ClientAppActor pid=10689) 🚀 Launching client2
Map: 100%|██████████| 720/720 [00:00<00:00, 5132.31 examples/s]
Map: 100%|██████████| 40/40 [00:00<00:00, 3688.35 examples/s]
Loading checkpoint shards: 100%|██████████| 2/2 [00:11<00:00,  5.80s/it]
 30%|███       | 6/20 [00:01<00:04,  3.40it/s]
(ClientAppActor pid=10689) 🚀 Launching client1
Map: 100%|██████████| 720/720 [00:00<00:00, 4972.21 examples/s]
Map: 100%|██████████| 40/40 [00:00<00:00, 3760.02 examples/s]
Loading checkpoint shards: 100%|██████████| 2/2 [00:09<00:00,  4.74s/it]
 30%|███       | 6/20 [00:01<00:04,  3.35it/s]
(ClientAppActor pid=10689) 🚀 Launching client2
Map: 100%|██████████| 720/720 [00:00<00:00, 4917.18 examples/s]
Map: 100%|██████████| 40/40 [00:00<00:00, 3682.36 examples/s]
Loading checkpoint shards: 100%|██████████| 2/2 [00:09<00:00,  4.75s/it]
(ClientAppActor pid=10689) {'loss': 0.8903, 'grad_norm': 1.118640422821045, 'learning_rate': 4.817389884711706e-06, 'epoch': 0.56}
100%|██████████| 45/45 [04:56<00:00,  6.38s/it]
(ClientAppActor pid=10689) {'eval_loss': 0.9256883859634399, 'eval_runtime': 6.5915, 'eval_samples_per_second': 6.068, 'eval_steps_per_second': 3.034, 'epoch': 1.0}
(ClientAppActor pid=10689) {'train_runtime': 303.3479, 'train_samples_per_second': 2.374, 'train_steps_per_second': 0.148, 'train_loss': 0.8656343248155381, 'epoch': 1.0}
(ClientAppActor pid=10689) 🚀 Launching client1
Map: 100%|██████████| 720/720 [00:00<00:00, 5055.19 examples/s]
Map: 100%|██████████| 40/40 [00:00<00:00, 3878.50 examples/s]
Loading checkpoint shards: 100%|██████████| 2/2 [00:11<00:00,  5.83s/it]
(ClientAppActor pid=10689) {'loss': 0.8869, 'grad_norm': 1.3227615356445312, 'learning_rate': 4.817389884711706e-06, 'epoch': 0.56}
100%|██████████| 45/45 [04:56<00:00,  6.44s/it]
(ClientAppActor pid=10689) {'eval_loss': 0.8680970072746277, 'eval_runtime': 6.593, 'eval_samples_per_second': 6.067, 'eval_steps_per_second': 3.034, 'epoch': 1.0}
(ClientAppActor pid=10689) {'train_runtime': 302.9418, 'train_samples_per_second': 2.377, 'train_steps_per_second': 0.149, 'train_loss': 0.8621364593505859, 'epoch': 1.0}
INFO :      configure_evaluate: strategy sampled 2 clients (out of 2)
💾 Saved aggregated parameters (round 3) into /mnt/data/federated_qa_jupyter/checkpoints
(ClientAppActor pid=10689) 🚀 Launching client2
Map: 100%|██████████| 720/720 [00:00<00:00, 5203.18 examples/s]
Map: 100%|██████████| 40/40 [00:00<00:00, 3825.00 examples/s]
Loading checkpoint shards: 100%|██████████| 2/2 [00:11<00:00,  5.85s/it]
 30%|███       | 6/20 [00:01<00:04,  3.41it/s]
(ClientAppActor pid=10689) 🚀 Launching client1
Map: 100%|██████████| 720/720 [00:00<00:00, 5236.48 examples/s]
Map: 100%|██████████| 40/40 [00:00<00:00, 3875.81 examples/s]
Loading checkpoint shards: 100%|██████████| 2/2 [00:11<00:00,  5.92s/it]
 25%|██▌       | 5/20 [00:01<00:04,  3.61it/s]
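
The `SaveFedAvg` strategy above inherits its aggregation step from Flower's `FedAvg`. Conceptually, that step is an average of the client updates weighted by the example count each client's `fit` returns (the log suggests 720 training examples per client, in which case the weights are equal). The following is a minimal NumPy sketch of that idea, not Flower's actual implementation; the function name `fedavg` and the dummy arrays are illustrative assumptions.

```python
import numpy as np

def fedavg(results):
    """Weighted average of per-client parameter lists.

    results: list of (arrays, num_examples) pairs, where `arrays` is a list of
    np.ndarray with matching shapes across clients (hypothetical input format).
    """
    total = sum(n for _, n in results)
    n_tensors = len(results[0][0])
    return [
        sum(arrays[i] * (n / total) for arrays, n in results)
        for i in range(n_tensors)
    ]

# Dummy two-tensor "adapters" for two clients (placeholder values, not real weights).
client1 = ([np.array([1.0, 2.0]), np.array([[0.0]])], 720)
client2 = ([np.array([3.0, 4.0]), np.array([[2.0]])], 720)
avg = fedavg([client1, client2])

# With unequal example counts the larger client dominates the average.
skewed = fedavg([([np.array([0.0])], 1), ([np.array([4.0])], 3)])
```

With equal counts the result is the plain element-wise mean of the two clients' tensors; with unequal counts each client contributes in proportion to its share of the examples.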

In [12]:
# Build the same PEFT skeleton, load aggregated ndarrays into it, and save adapters.
def materialize_federated_adapters(agg_ndarrays, save_dir):
    if not os.path.isdir(save_dir):
        raise FileNotFoundError(f"Create this folder first, then re-run: {save_dir}")

    base = AutoModelForCausalLM.from_pretrained(
        BASE_BACKBONE,
        torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
        device_map="cpu"
    )
    peft_model = PeftModel.from_pretrained(base, PRE_FL_MODEL_PATH, is_trainable=True)

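    # Copy the aggregated arrays into the trainable (LoRA) tensors. This assumes the
    # clients exported them in the same named_parameters() order used here.
    trainable = [p for _, p in peft_model.named_parameters() if p.requires_grad]
    if len(trainable) != len(agg_ndarrays):
        raise ValueError(f"Mismatch: {len(trainable)} trainable tensors vs {len(agg_ndarrays)} aggregated arrays")
    for p, arr in zip(trainable, agg_ndarrays):
        if tuple(arr.shape) != tuple(p.shape):
            raise ValueError(f"Shape mismatch: {tuple(arr.shape)} vs {tuple(p.shape)}")
        p.data = torch.tensor(arr).to(dtype=p.dtype)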

    peft_model.save_pretrained(save_dir)

    tok_src = PRE_FL_MODEL_PATH if os.path.exists(os.path.join(PRE_FL_MODEL_PATH, "tokenizer.json")) else BASE_BACKBONE
    AutoTokenizer.from_pretrained(tok_src).save_pretrained(save_dir)
    print(f"✅ Saved FL adapters → {save_dir}")

if getattr(strategy, "latest_parameters", None) is None:
    raise RuntimeError("No aggregated parameters found. Did FL complete?")
materialize_federated_adapters(strategy.latest_parameters, POST_FL_MODEL_PATH)

✅ Saved FL adapters → /mnt/data/federated_qa_jupyter/model/federated_merged_model
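
The cell above relies on `strategy.latest_parameters`, which only lives in memory. If the kernel is restarted after the FL run, the per-round files written by `np.savez(..., *nds)` could be reloaded instead. The sketch below assumes that positional-argument form of `np.savez`, which stores arrays under the keys `arr_0`, `arr_1`, and so on; the helper name `load_aggregated_npz` and the demo file are hypothetical.

```python
import os
import tempfile

import numpy as np

def load_aggregated_npz(path):
    """Return arrays saved via np.savez(path, *arrays), in their original order."""
    with np.load(path) as data:
        # Keys are arr_0, arr_1, ...; sort numerically, since a plain string sort
        # would place arr_10 before arr_2.
        keys = sorted(data.files, key=lambda k: int(k.split("_")[1]))
        return [data[k] for k in keys]

# Round-trip demo with 12 dummy arrays (placeholder values, not real adapter weights).
demo_arrays = [np.full((2, 2), i, dtype=np.float32) for i in range(12)]
demo_path = os.path.join(tempfile.mkdtemp(), "aggregated_round_demo.npz")
np.savez(demo_path, *demo_arrays)
restored = load_aggregated_npz(demo_path)
```

The restored list could then be passed to `materialize_federated_adapters` in place of `strategy.latest_parameters`, provided it came from the same adapter configuration.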


In [18]:
q = "In 5G NR, what does PUCCH carry and in which release was format 2 updated?"

pre_pipe = load_pipeline(PRE_FL_MODEL_PATH, use_4bit=True, offload_dir="/mnt/data/offload_pre")
ans1 = answer_with_rag(pre_pipe, q)[0]
print("CentralFT + RAG:", ans1)
del pre_pipe; torch.cuda.empty_cache()

post_pipe = load_pipeline(POST_FL_MODEL_PATH, use_4bit=True, offload_dir="/mnt/data/offload_post")
ans2 = answer_with_rag(post_pipe, q)[0]
print("FederatedFT + RAG:", ans2)
del post_pipe; torch.cuda.empty_cache()

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Device set to use cuda:0
The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


CentralFT + RAG: PUCCH carries UCI and in which release was format 2 updated 5G NR, format 2 was updated in Release 15.



FederatedFT + RAG: PUCCH carries HARQ-ACK information, SR and CSI report(s).


In [20]:
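import json, math, os, re  # needed by the helpers below; harmless if already imported in an earlier cell

# metrics (lightweight)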
def _norm(s:str)->str:
    s = s.lower().strip()
    s = re.sub(r"\s+"," ",s)
    s = re.sub(r"[^a-z0-9\-.,:/() ]","",s)
    return s

def EM(pred, ref):  return 1.0 if _norm(pred)==_norm(ref) else 0.0

def F1(pred, ref):
    p = _norm(pred).split(); r = _norm(ref).split()
    if not p and not r: return 1.0
    if not p or not r:  return 0.0
    rc = {}
    for t in r: rc[t]=rc.get(t,0)+1
    common=0
    for t in p:
        if rc.get(t,0)>0: common+=1; rc[t]-=1
    if common==0: return 0.0
    prec=common/len(p); rec=common/len(r)
    return 2*prec*rec/(prec+rec)

def _lcs(a,b):
    A=_norm(a).split(); B=_norm(b).split()
    dp=[[0]*(len(B)+1) for _ in range(len(A)+1)]
    for i in range(1,len(A)+1):
        for j in range(1,len(B)+1):
            dp[i][j] = dp[i-1][j-1]+1 if A[i-1]==B[j-1] else max(dp[i-1][j], dp[i][j-1])
    return dp[-1][-1], len(A), len(B)

def ROUGE_L(pred,ref):
    l,m,n=_lcs(pred,ref)
    if m==0 or n==0 or l==0: return 0.0
    prec=l/m; rec=l/n
    beta2 = 1.2**2
    return (1+beta2)*prec*rec/(rec+beta2*prec)

def BLEU1(pred,ref):
    P=_norm(pred).split(); R=_norm(ref).split()
    if not P or not R: return 0.0
    rc={}
    for t in R: rc[t]=rc.get(t,0)+1
    match=0; used={}
    for t in P:
        c=used.get(t,0)
        if c<rc.get(t,0):
            match+=1; used[t]=c+1
    prec=match/len(P)
    bp=1.0 if len(P)>len(R) else math.exp(1-len(R)/max(1,len(P)))
    return bp*prec

def metrics(pred, ref):
    return {"EM":EM(pred,ref), "F1":F1(pred,ref), "ROUGE_L":ROUGE_L(pred,ref), "BLEU1":BLEU1(pred,ref)}

def load_jsonl(p):
    data=[]
    with open(p,"r",encoding="utf-8") as f:
        for l in f:
            if l.strip(): data.append(json.loads(l))
    return data

# check results dir exists (no auto-creation)
_results_dir = os.path.dirname(SUMMARY_CSV)
if not os.path.isdir(_results_dir):
    raise FileNotFoundError(f"Create results folder first: {_results_dir}")

#  datasets
datasets=[]
if os.path.exists(RAG_TEST_PATH):     datasets.append(("R15_16_RAG100", load_jsonl(RAG_TEST_PATH)))
if os.path.exists(CLIENT_A_HOLDOUT):  datasets.append(("ClientA_holdout", load_jsonl(CLIENT_A_HOLDOUT)))
if os.path.exists(CLIENT_B_HOLDOUT):  datasets.append(("ClientB_holdout", load_jsonl(CLIENT_B_HOLDOUT)))
assert datasets, "No test datasets found—check your paths."

import gc, json, os, re, math, time, numpy as np, torch

rows = []
with open(DETAILS_JSONL, "w", encoding="utf-8") as fp:
    for (ds_name, ds) in datasets:
        for tag, path, offdir in [
            ("CentralFT",   PRE_FL_MODEL_PATH,  "/mnt/data/offload_pre"),
            ("FederatedFT", POST_FL_MODEL_PATH, "/mnt/data/offload_post"),
        ]:
            pipe = load_pipeline(path, use_4bit=True, offload_dir=offdir)
            try:
                for mode in ["no_rag", "with_rag"]:
                    print(f"▶️ {tag} | {mode} | {ds_name}  (N={len(ds)})")
                    agg = {"EM": [], "F1": [], "ROUGE_L": [], "BLEU1": []}
                    t0 = time.time()

                    for ex in ds:
                        q = ex["question"].strip()
                        ref = ex["answer"].strip()
                        try:
                            if mode == "with_rag":
                                out = answer_with_rag(pipe, q)
                                pred = out[0] if isinstance(out, (list, tuple)) else out
                            else:
                                pred = answer_without_rag(pipe, q)
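                        except Exception as e:
                            # Score a failed generation as an empty prediction, but do not hide the error.
                            print(f"⚠️ generation failed ({type(e).__name__}: {e}); scoring empty prediction")
                            pred = ""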

                        m = metrics(pred, ref)
                        for k, v in m.items():
                            agg[k].append(v)

                        fp.write(json.dumps({
                            "model": tag, "rag_mode": mode, "dataset": ds_name,
                            "question": q, "reference": ref, "prediction": pred, "metrics": m
                        }) + "\n")

                    rows.append({
                        "model": tag, "rag_mode": mode, "dataset": ds_name, "N": len(ds),
                        "EM": float(np.mean(agg["EM"]) if agg["EM"] else 0.0),
                        "F1": float(np.mean(agg["F1"]) if agg["F1"] else 0.0),
                        "ROUGE_L": float(np.mean(agg["ROUGE_L"]) if agg["ROUGE_L"] else 0.0),
                        "BLEU1": float(np.mean(agg["BLEU1"]) if agg["BLEU1"] else 0.0),
                        "secs": round(time.time() - t0, 2),
                    })
            finally:
                del pipe
                gc.collect()
                torch.cuda.empty_cache()

with open(SUMMARY_CSV, "w", encoding="utf-8") as f:
    f.write("model,rag_mode,dataset,N,EM,F1,ROUGE_L,BLEU1,secs\n")
    for r in rows:
        f.write("{model},{rag_mode},{dataset},{N},{EM:.4f},{F1:.4f},{ROUGE_L:.4f},{BLEU1:.4f},{secs}\n".format(**r))

print("✅ Summary →", SUMMARY_CSV)
print("✅ Details →", DETAILS_JSONL)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Device set to use cuda:0
The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


▶️ CentralFT | no_rag | ClientA_holdout  (N=200)


▶️ CentralFT | with_rag | ClientA_holdout  (N=200)


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Device set to use cuda:0


▶️ FederatedFT | no_rag | ClientA_holdout  (N=200)


▶️ FederatedFT | with_rag | ClientA_holdout  (N=200)


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Device set to use cuda:0


▶️ CentralFT | no_rag | ClientB_holdout  (N=200)


▶️ CentralFT | with_rag | ClientB_holdout  (N=200)


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Device set to use cuda:0


▶️ FederatedFT | no_rag | ClientB_holdout  (N=200)


▶️ FederatedFT | with_rag | ClientB_holdout  (N=200)


✅ Summary → /mnt/data/federated_qa_jupyter/results/federated_eval_summary.csv
✅ Details → /mnt/data/federated_qa_jupyter/results/federated_eval_details.jsonl
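
The summary CSV holds one row per (model, rag_mode, dataset) combination, so the pre-federation vs post-federation comparison comes down to subtracting matching rows. The following is a minimal sketch of that step, assuming the header written by the cell above. The helper name `federated_minus_central` and the example rows are hypothetical; the numbers are made up for illustration and are not results from this run.

```python
import csv
import io
from collections import defaultdict

METRICS = ("EM", "F1", "ROUGE_L", "BLEU1")


def federated_minus_central(csv_text):
    """Return {(dataset, rag_mode): {metric: FederatedFT - CentralFT}}."""
    # Group rows by evaluation setting, then index them by model tag.
    by_setting = defaultdict(dict)
    for row in csv.DictReader(io.StringIO(csv_text)):
        by_setting[(row["dataset"], row["rag_mode"])][row["model"]] = row

    deltas = {}
    for setting, models in by_setting.items():
        # Skip settings where one of the two models is missing.
        if "CentralFT" not in models or "FederatedFT" not in models:
            continue
        deltas[setting] = {
            m: round(float(models["FederatedFT"][m]) - float(models["CentralFT"][m]), 4)
            for m in METRICS
        }
    return deltas


# Made-up rows for illustration only (hypothetical dataset name and scores).
example_csv = """model,rag_mode,dataset,N,EM,F1,ROUGE_L,BLEU1,secs
CentralFT,no_rag,Demo_holdout,2,0.5000,0.6000,0.5500,0.5000,1.0
FederatedFT,no_rag,Demo_holdout,2,0.5000,0.7000,0.6500,0.6000,1.2
"""

for setting, delta in federated_minus_central(example_csv).items():
    print(setting, delta)
```

To apply it to the real file, pass `open(SUMMARY_CSV, encoding="utf-8").read()` instead of `example_csv`. A positive delta means the federated adapter scored higher than the pre-federation one in that setting.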
