In [1]:
import json
import time
import pathlib
from typing import List, Tuple, Optional

import numpy as np
import tensorflow as tf

In [2]:
# -------------------- CONFIG --------------------
def _agg_sibling(model_path: pathlib.Path, stamp: str, extension: str) -> pathlib.Path:
    """Return '<stem>_AGG_<stamp><extension>' located in the same folder as model_path."""
    return model_path.with_name(f"{model_path.stem}_AGG_{stamp}{extension}")


# Client parameter folders (searched for user_* sub-directories)
FNN_CLIENT_ROOT = pathlib.Path("/Users/sohinikar/FL/M.Tech_Dissertation/Client/client_params/FNN_BC")
LSTM_CLIENT_ROOT = pathlib.Path("/Users/sohinikar/FL/M.Tech_Dissertation/Client/client_params/LSTM_BC")

# Global models that are loaded and then updated with the averaged weights
FNN_GLOBAL_PATH = pathlib.Path("/Users/sohinikar/FL/M.Tech_Dissertation/Server/global_model/global_FNN_model.keras")
LSTM_GLOBAL_PATH = pathlib.Path("/Users/sohinikar/FL/M.Tech_Dissertation/Server/global_model/global_lstm_model.keras")

# Weight by client sample counts found in params.json ("samples")
WEIGHT_BY_SAMPLES = True

# Save updated copies alongside the originals (timestamped)
TS = time.strftime("%Y%m%d_%H%M%S")
FNN_OUT_MODEL, FNN_OUT_WEIGHTS = (
    _agg_sibling(FNN_GLOBAL_PATH, TS, ext) for ext in (".keras", ".weights.h5")
)
LSTM_OUT_MODEL, LSTM_OUT_WEIGHTS = (
    _agg_sibling(LSTM_GLOBAL_PATH, TS, ext) for ext in (".keras", ".weights.h5")
)

In [3]:
def find_clients(root: pathlib.Path) -> List[pathlib.Path]:
    """
    List the client folders under `root` that can take part in aggregation.

    A folder qualifies when its name matches ``user_*`` and it directly
    contains at least one ``*.weights.h5`` file. The result is sorted; a
    missing `root` yields an empty list.
    """
    if not root.exists():
        return []
    return [
        candidate
        for candidate in sorted(root.glob("user_*"))
        if any(candidate.glob("*.weights.h5"))
    ]

def load_client_sample_count(udir: pathlib.Path) -> int:
    """
    Return the FedAvg weight (number of samples) declared by a client.

    Reads the ``"samples"`` entry of ``<udir>/params.json``.

    Args:
        udir: Client folder that may contain a ``params.json``.

    Returns:
        The sample count, clamped to a minimum of 1. Falls back to 1 when the
        file is absent, has no ``"samples"`` key, or cannot be parsed. In the
        last case a warning line is printed so that a silently down-weighted
        client is visible in the aggregation log.
    """
    p = udir / "params.json"
    if not p.exists():
        return 1
    try:
        with open(p, "r") as f:
            j = json.load(f)
        return max(int(j.get("samples", 1)), 1)
    except Exception as e:
        # Best effort: keep aggregating, but do not hide the problem.
        print(f"  ! {udir.name}: could not read samples from {p.name} ({e}); using weight=1")
        return 1

def pick_weights_file(udir: pathlib.Path) -> Optional[pathlib.Path]:
    """
    Choose the weights file to use for one client folder.

    Returns the ``*.weights.h5`` path that comes first in sorted order, or
    ``None`` when the folder holds no such file.
    """
    return min(udir.glob("*.weights.h5"), default=None)

def fedavg_aggregate(global_model_path: pathlib.Path,
                     clients_root: pathlib.Path,
                     out_model_path: pathlib.Path,
                     out_weights_path: pathlib.Path) -> Tuple[int, int]:
    """
    Perform FedAvg on all client weights under clients_root to update the global model.

    Each usable client contributes its weight tensors multiplied by its weight
    (the sample count from params.json when WEIGHT_BY_SAMPLES is true, else 1);
    the sum is divided by the total weight and written into the global model.

    Args:
        global_model_path: Saved global model; supplies the architecture.
        clients_root: Folder containing the user_* client folders.
        out_model_path: Destination for the updated full model.
        out_weights_path: Destination for the updated weights
            (must end with ".weights.h5").

    Returns:
        (num_clients_used, total_weight)

    Raises:
        FileNotFoundError: The global model file does not exist.
        RuntimeError: The global model has no weights, or no client could be used.
    """
    if not global_model_path.exists():
        raise FileNotFoundError(f"Global model not found: {global_model_path}")

    # Load a base model (for shapes); it also receives the final average.
    base_model = tf.keras.models.load_model(str(global_model_path))
    base_weights = base_model.get_weights()
    if not base_weights:
        raise RuntimeError("Global model has no weights (did you save an unbuilt model?).")

    # A single scratch model is reused for every client. Deserialising the
    # model once per client is loop-invariant work and keeps allocating new
    # Keras objects; resetting the scratch weights before each load gives the
    # same isolation between clients.
    scratch_model = tf.keras.models.load_model(str(global_model_path))

    # Accumulators (float64 so that summing many clients does not lose precision)
    acc = [np.zeros_like(w, dtype=np.float64) for w in base_weights]
    total_w = 0

    clients = find_clients(clients_root)
    print(f"Found {len(clients)} client(s) in {clients_root}")

    used = 0
    for udir in clients:
        wfile = pick_weights_file(udir)
        if wfile is None:
            print(f"  - {udir.name}: no *.weights.h5, skipping")
            continue

        weight = load_client_sample_count(udir) if WEIGHT_BY_SAMPLES else 1

        try:
            # Reset first so a previously failed (partial) load cannot leak
            # another client's tensors into this one.
            scratch_model.set_weights(base_weights)
            scratch_model.load_weights(str(wfile))
            c_weights = scratch_model.get_weights()
        except Exception as e:
            print(f"  - {udir.name}: failed to load weights ({wfile.name}): {e}, skipping")
            continue

        # Shape check
        if len(c_weights) != len(acc) or any(cw.shape != bw.shape for cw, bw in zip(c_weights, base_weights)):
            print(f"  - {udir.name}: weight shapes mismatch, skipping")
            continue

        # A single NaN/Inf would propagate into every averaged tensor.
        if not all(np.all(np.isfinite(cw)) for cw in c_weights):
            print(f"  - {udir.name}: non-finite values in weights, skipping")
            continue

        # Accumulate
        for slot, cw in zip(acc, c_weights):
            slot += weight * cw.astype(np.float64)
        total_w += weight
        used += 1
        print(f"  ✓ {udir.name}: included (weight={weight})")

    if used == 0 or total_w == 0:
        raise RuntimeError(f"No compatible client weights found in {clients_root}")

    # Compute weighted average and update base_model
    avg = [(a / float(total_w)).astype(bw.dtype) for a, bw in zip(acc, base_weights)]
    base_model.set_weights(avg)

    # Save updated global model + weights
    out_model_path.parent.mkdir(parents=True, exist_ok=True)
    base_model.save(str(out_model_path))
    base_model.save_weights(str(out_weights_path))  # must end with ".weights.h5"

    print(f"\n✅ Aggregated {used} client(s) (total weight={total_w}) into:")
    print(f"   Model  : {out_model_path}")
    print(f"   Weights: {out_weights_path}\n")

    return used, total_w

def main():
    """
    Aggregate the FNN global model, then the LSTM global model.

    Each aggregation is attempted independently: an exception raised by one
    is printed and does not prevent the other from running.
    """
    # Lambdas defer the lookup of the module-level paths until inside the try.
    jobs = (
        ("FNN", lambda: fedavg_aggregate(FNN_GLOBAL_PATH, FNN_CLIENT_ROOT, FNN_OUT_MODEL, FNN_OUT_WEIGHTS)),
        ("LSTM", lambda: fedavg_aggregate(LSTM_GLOBAL_PATH, LSTM_CLIENT_ROOT, LSTM_OUT_MODEL, LSTM_OUT_WEIGHTS)),
    )
    for label, run_job in jobs:
        try:
            print(f"=== Aggregating {label} global model ===")
            run_job()
        except Exception as err:
            print(f"❌ {label} aggregation failed: {err}")

# Entry point: run both aggregations when this code is executed directly
# (the guard keeps main() from running if the code is imported as a module).
if __name__ == "__main__":
    main()


=== Aggregating FNN global model ===
Found 128 client(s) in /Users/sohinikar/FL/M.Tech_Dissertation/Client/client_params/FNN_BC
  ✓ user_001: included (weight=46)
  ✓ user_002: included (weight=46)
  ✓ user_003: included (weight=46)
  ✓ user_004: included (weight=46)
  ✓ user_005: included (weight=46)
  ✓ user_006: included (weight=46)
  ✓ user_007: included (weight=46)
  ✓ user_008: included (weight=46)
  ✓ user_009: included (weight=46)
  ✓ user_010: included (weight=46)
  ✓ user_011: included (weight=46)
  ✓ user_012: included (weight=46)
  ✓ user_013: included (weight=46)
  ✓ user_014: included (weight=46)
  ✓ user_015: included (weight=46)
  ✓ user_016: included (weight=46)
  ✓ user_017: included (weight=46)
  ✓ user_018: included (weight=46)
  ✓ user_019: included (weight=46)
  ✓ user_020: included (weight=46)
  ✓ user_021: included (weight=46)
  ✓ user_022: included (weight=46)
  ✓ user_023: included (weight=46)
  ✓ user_024: included (weight=46)
  ✓ user_025: included (weight=4

In [5]:
#!/usr/bin/env python3
"""
FedAvg aggregator (Multiclass) that fetches client params from GitHub with
rate-limit-safe strategies:
  - If GITHUB_TOKEN present -> Contents API (high limit)
  - Else -> single tarball download via codeload.github.com (no API limit)

Aggregates (sample-weighted) into local global models:
  - /Users/sohinikar/FL/M.Tech_Dissertation/Server/global_model/global_FNN_MC_model.keras
  - /Users/sohinikar/FL/M.Tech_Dissertation/Server/global_model/global_lstm_MC_model.keras

Client params on GitHub:
  - M.Tech_Dissertation/Client/client_params/FNN_MC/user_*/user_*.weights.h5 + params.json
  - M.Tech_Dissertation/Client/client_params/LSTM_MC/... (optional)

Outputs (timestamped) written next to originals:
  - *_AGG_<ts>.keras and *_AGG_<ts>.weights.h5

Deps: pip install requests tensorflow
"""

import os
import io
import json
import time
import tarfile
import pathlib
from typing import List, Optional, Tuple

import numpy as np
import requests
import tensorflow as tf

# ----------- Repo / Paths -----------
def _agg_sibling_mc(model_path: pathlib.Path, stamp: str, extension: str) -> pathlib.Path:
    """Return '<stem>_AGG_<stamp><extension>' located in the same folder as model_path."""
    return model_path.with_name(f"{model_path.stem}_AGG_{stamp}{extension}")


# GitHub repository (owner / name / branch-or-ref) holding the client params
OWNER, REPO, REF = "Triss11", "FL", "main"

# Repository-relative folders with the per-client user_* directories
REMOTE_FNN_DIR = "M.Tech_Dissertation/Client/client_params/FNN_MC"
REMOTE_LSTM_DIR = "M.Tech_Dissertation/Client/client_params/LSTM_MC"  # optional

# Local global models that receive the aggregated weights
FNN_GLOBAL_PATH = pathlib.Path("/Users/sohinikar/FL/M.Tech_Dissertation/Server/global_model/global_FNN_MC_model.keras")
LSTM_GLOBAL_PATH = pathlib.Path("/Users/sohinikar/FL/M.Tech_Dissertation/Server/global_model/global_lstm_MC_model.keras")

# Timestamped outputs written next to the originals
TS = time.strftime("%Y%m%d_%H%M%S")
FNN_OUT_MODEL, FNN_OUT_WEIGHTS = (
    _agg_sibling_mc(FNN_GLOBAL_PATH, TS, ext) for ext in (".keras", ".weights.h5")
)
LSTM_OUT_MODEL, LSTM_OUT_WEIGHTS = (
    _agg_sibling_mc(LSTM_GLOBAL_PATH, TS, ext) for ext in (".keras", ".weights.h5")
)

# Local folder for downloaded client files; weight clients by params.json "samples"
CACHE_DIR = pathlib.Path("./_client_cache_mc")
WEIGHT_BY_SAMPLES = True

# ----------- Strategy -----------
# "auto": use API if token available, else tarball
# "api" : force API listing/downloading
# "tarball": force tarball extraction
FETCH_STRATEGY = "auto"


# ============ Helpers: HTTP / GitHub ============
def gh_session() -> requests.Session:
    """
    Create the HTTP session used for all GitHub requests.

    The session always sends the GitHub JSON ``Accept`` header. When the
    ``GITHUB_TOKEN`` environment variable is set to a non-empty value it is
    also sent as a bearer ``Authorization`` header.
    """
    headers = {"Accept": "application/vnd.github+json"}
    token = os.environ.get("GITHUB_TOKEN")
    if token:
        headers["Authorization"] = f"Bearer {token}"

    session = requests.Session()
    session.headers.update(headers)
    return session

def gh_list_dir(session: requests.Session, path: str, timeout: float = 30.0) -> List[dict]:
    """
    List a repository directory through the GitHub Contents API.

    Args:
        session: Session used to issue the request.
        path: Repository-relative directory path.
        timeout: Seconds to wait for the server before giving up. requests
            applies no timeout by default, so without this a stalled
            connection would block forever.

    Returns:
        The decoded JSON list returned by the API (one dict per entry).

    Raises:
        RuntimeError: The response status is not 200, or the path is not a
            directory (the API did not return a list).
        requests.RequestException: Network failure or timeout.
    """
    url = f"https://api.github.com/repos/{OWNER}/{REPO}/contents/{path}"
    # Passing ref via params lets requests URL-encode it.
    r = session.get(url, params={"ref": REF}, timeout=timeout)
    if r.status_code != 200:
        raise RuntimeError(f"GitHub list failed for {path}: {r.status_code} {r.text}")
    data = r.json()
    if not isinstance(data, list):
        raise RuntimeError(f"{path} is not a directory on GitHub.")
    return data

def gh_get_file(session: requests.Session, path: str, timeout: float = 30.0) -> dict:
    """
    Fetch the Contents API metadata of a single repository file.

    Args:
        session: Session used to issue the request.
        path: Repository-relative file path.
        timeout: Seconds to wait for the server before giving up (requests
            has no default timeout).

    Returns:
        The decoded JSON object for the file; it is guaranteed to have
        ``type == "file"`` and a non-empty ``download_url``.

    Raises:
        RuntimeError: The response status is not 200, or the path does not
            describe a downloadable file (including the case where it is a
            directory and the API returns a list).
        requests.RequestException: Network failure or timeout.
    """
    url = f"https://api.github.com/repos/{OWNER}/{REPO}/contents/{path}"
    r = session.get(url, params={"ref": REF}, timeout=timeout)
    if r.status_code != 200:
        raise RuntimeError(f"GitHub get failed for {path}: {r.status_code} {r.text}")
    data = r.json()
    # A directory path yields a JSON list, which has no .get(); check the type
    # first so the caller sees the intended RuntimeError, not AttributeError.
    if not isinstance(data, dict) or data.get("type") != "file" or not data.get("download_url"):
        raise RuntimeError(f"{path} is not a downloadable file.")
    return data

def download_url_to(session: requests.Session, url: str, dest: pathlib.Path, timeout: float = 60.0) -> pathlib.Path:
    """
    Stream `url` into the file `dest`, creating parent folders as needed.

    The body is written to ``<dest>.part`` and renamed onto `dest` only after
    the whole response has been received, so an interrupted download never
    leaves a truncated file at `dest`.

    Args:
        session: Session used to issue the request.
        url: Address to download.
        dest: Final location of the downloaded file.
        timeout: Seconds to wait for the server before giving up (requests
            has no default timeout).

    Returns:
        `dest`.

    Raises:
        requests.HTTPError: The server answered with an error status.
        requests.RequestException: Network failure or timeout.
    """
    dest.parent.mkdir(parents=True, exist_ok=True)
    part = dest.with_name(dest.name + ".part")
    try:
        with session.get(url, stream=True, timeout=timeout) as r:
            r.raise_for_status()
            with open(part, "wb") as f:
                for chunk in r.iter_content(1024 * 1024):
                    if chunk:
                        f.write(chunk)
        part.replace(dest)
    finally:
        # Still present only if the download failed before the rename.
        if part.exists():
            part.unlink()
    return dest


# ============ Strategy A: API (tokened) ============
def fetch_clients_via_api(session: requests.Session, remote_root: str, cache_root: pathlib.Path) -> List[Tuple[pathlib.Path, int]]:
    """
    Downloads user_*/user_*.weights.h5 and params.json using the Contents API.
    Assumes weights are named exactly 'user_xxx.weights.h5'.

    Args:
        session: Session used for all requests.
        remote_root: Repository-relative folder containing the user_* folders.
        cache_root: Local folder; files are stored under <cache_root>/<user>/.

    Returns:
        [(local_weights_path, samples_weight)] for every client whose weights
        were downloaded. The weight is the "samples" value of params.json
        (minimum 1), or 1 when params.json is missing or unreadable.

    Raises:
        RuntimeError: Listing `remote_root` itself failed. Failures for an
            individual client are printed and that client is skipped.
    """
    users = gh_list_dir(session, remote_root)
    results = []
    for ent in sorted(users, key=lambda e: e.get("name", "")):
        if ent.get("type") != "dir":
            continue
        user = ent.get("name", "")  # e.g., user_001
        # Only user_* folders are client folders; skip anything else instead
        # of spending API calls on it.
        if not user.startswith("user_"):
            continue
        try:
            weight_name = f"{user}.weights.h5"
            weight_path = f"{remote_root}/{user}/{weight_name}"
            params_path = f"{remote_root}/{user}/params.json"

            # Download weights (required)
            w_json = gh_get_file(session, weight_path)
            w_local = cache_root / user / weight_name
            download_url_to(session, w_json["download_url"], w_local)

            # Download params.json (optional)
            samples = 1
            try:
                p_json = gh_get_file(session, params_path)
                p_local = cache_root / user / "params.json"
                download_url_to(session, p_json["download_url"], p_local)
                with open(p_local, "r") as f:
                    j = json.load(f)
                samples = max(int(j.get("samples", 1)), 1)
            except Exception as e:
                # Still optional, but say so: a silent fallback changes this
                # client's influence on the average.
                print(f"  ! {user}: params.json unavailable or invalid ({e}); using samples=1")
                samples = 1

            results.append((w_local, samples))
            print(f"  ✓ {user}: downloaded {weight_name}, samples={samples}")
        except Exception as e:
            print(f"  - {user}: {e} (skipping)")
    return results


# ============ Strategy B: Tarball (no API limit) ============
def download_repo_tarball(session: requests.Session, cache_dir: pathlib.Path,
                          force_refresh: bool = False, timeout: float = 120.0) -> pathlib.Path:
    """
    Downloads the repo tar.gz via codeload (not API rate-limited).

    The archive is cached as ``<cache_dir>/<REPO>_<REF>.tar.gz`` and reused on
    later calls. It is written to a ``.part`` file first and renamed only when
    complete, so an interrupted download is never mistaken for a valid cache
    entry.

    Args:
        session: Session used to issue the request.
        cache_dir: Folder in which the archive is stored.
        force_refresh: When true, download again even if a cached archive
            exists (use this to pick up client files pushed since the last
            download).
        timeout: Seconds to wait for the server before giving up (requests
            has no default timeout).

    Returns:
        Path of the local archive.

    Raises:
        requests.HTTPError: The server answered with an error status.
        requests.RequestException: Network failure or timeout.
    """
    tar_url = f"https://codeload.github.com/{OWNER}/{REPO}/tar.gz/{REF}"
    tar_path = cache_dir / f"{REPO}_{REF}.tar.gz"
    if not force_refresh and tar_path.exists() and tar_path.stat().st_size > 0:
        return tar_path

    tar_path.parent.mkdir(parents=True, exist_ok=True)
    part = tar_path.with_name(tar_path.name + ".part")
    try:
        with session.get(tar_url, stream=True, timeout=timeout) as r:
            r.raise_for_status()
            with open(part, "wb") as f:
                for chunk in r.iter_content(1024 * 1024):
                    if chunk:
                        f.write(chunk)
        part.replace(tar_path)
    finally:
        # Still present only if the download failed before the rename.
        if part.exists():
            part.unlink()
    return tar_path

def extract_clients_from_tarball(tar_path: pathlib.Path, remote_root: str, cache_root: pathlib.Path) -> List[Tuple[pathlib.Path, int]]:
    """
    Extract only files under `remote_root` (weights + params.json) to cache_root.
    Returns [(local_weights_path, samples_weight)].

    Only clients whose weights were extracted from this archive are returned,
    so folders left in `cache_root` by an earlier run are not aggregated.
    For a client with several weights files the first in sorted order is used.

    Args:
        tar_path: Local .tar.gz archive of the repository.
        remote_root: Repository-relative folder containing the user_* folders.
        cache_root: Local folder that receives the extracted files.

    Returns:
        One (weights_path, samples) pair per client, sorted by client folder
        name. `samples` is the "samples" value of that client's params.json
        (minimum 1), or 1 when it is missing or unreadable.
    """
    cache_root.mkdir(parents=True, exist_ok=True)
    # ensure "/<remote_root>/" substring to avoid partial matches
    needle = f"/{remote_root}/"

    weights_by_user = {}  # user folder name -> weights files extracted this run
    params_by_user = {}   # user folder name -> params.json extracted this run

    with tarfile.open(tar_path, "r:gz") as tar:
        # tar member names look like: "<top-level-dir>/{repo_tree...}"
        for m in tar.getmembers():
            if not m.isfile():
                continue
            if needle not in m.name:
                continue
            rel = m.name.split(needle, 1)[1]  # e.g., "user_001/user_001.weights.h5"
            # Save only weights and params.json
            if not (rel.endswith(".weights.h5") or rel.endswith("params.json")):
                continue

            # Member names come from the archive; refuse any that would be
            # written outside cache_root.
            rel_path = pathlib.PurePosixPath(rel)
            if rel_path.is_absolute() or ".." in rel_path.parts:
                print(f"  - {m.name}: unsafe path inside archive (skipping)")
                continue

            dest = cache_root.joinpath(*rel_path.parts)
            dest.parent.mkdir(parents=True, exist_ok=True)
            src = tar.extractfile(m)
            if src is None:
                continue
            with src, open(dest, "wb") as out:
                out.write(src.read())

            # We expect weights at user_xxx/user_xxx.weights.h5
            parts = rel_path.parts
            if len(parts) == 2 and parts[0].startswith("user_"):
                user, fname = parts
                if fname.endswith(".weights.h5"):
                    weights_by_user.setdefault(user, []).append(dest)
                elif fname == "params.json":
                    params_by_user[user] = dest

    # Build list and read samples
    results = []
    for user in sorted(weights_by_user):
        weight_file = min(weights_by_user[user])
        samples = 1
        p = params_by_user.get(user)
        if p is not None:
            try:
                with open(p, "r") as f:
                    j = json.load(f)
                samples = max(int(j.get("samples", 1)), 1)
            except Exception as e:
                print(f"  ! {user}: could not read samples from params.json ({e}); using samples=1")
                samples = 1
        results.append((weight_file, samples))
    return results


# ============ FedAvg aggregation ============
def fedavg_aggregate(global_model_path: pathlib.Path,
                     client_weights: List[Tuple[pathlib.Path, int]],
                     out_model_path: pathlib.Path,
                     out_weights_path: pathlib.Path) -> Tuple[int, int]:
    """
    Weighted FedAvg of already-downloaded client weights into the global model.

    Note: a function with this name but a different second parameter
    (a clients_root folder) is defined in an earlier cell of this notebook;
    whichever cell ran last is the one in effect.

    Args:
        global_model_path: Saved global model; supplies the architecture.
        client_weights: (weights_file, weight) pairs, one per client.
        out_model_path: Destination for the updated full model.
        out_weights_path: Destination for the updated weights
            (Keras 3 requires the name to end with ".weights.h5").

    Returns:
        (num_clients_used, total_weight)

    Raises:
        FileNotFoundError: The global model file does not exist.
        RuntimeError: The global model has no weights, or no client could be used.
    """
    if not global_model_path.exists():
        raise FileNotFoundError(f"Global model not found: {global_model_path}")
    model = tf.keras.models.load_model(str(global_model_path))
    base_w = model.get_weights()
    if not base_w:
        raise RuntimeError("Global model has no weights (unbuilt?).")

    # One scratch model serves every client. Deserialising the model once per
    # client is loop-invariant work; resetting the scratch weights before each
    # load keeps clients isolated from one another.
    scratch = tf.keras.models.load_model(str(global_model_path))

    # float64 accumulators so that summing many clients does not lose precision
    acc = [np.zeros_like(w, dtype=np.float64) for w in base_w]
    total = 0
    used = 0

    for w_path, weight in client_weights:
        try:
            scratch.set_weights(base_w)
            scratch.load_weights(str(w_path))
            cw = scratch.get_weights()
        except Exception as e:
            print(f"  - {w_path.parent.name}: failed to load weights: {e} (skip)")
            continue

        if len(cw) != len(base_w) or any(c.shape != b.shape for c, b in zip(cw, base_w)):
            print(f"  - {w_path.parent.name}: weight shape mismatch (skip)")
            continue

        # A single NaN/Inf would propagate into every averaged tensor.
        if not all(np.all(np.isfinite(c)) for c in cw):
            print(f"  - {w_path.parent.name}: non-finite values in weights (skip)")
            continue

        for slot, c in zip(acc, cw):
            slot += weight * c.astype(np.float64)
        total += weight
        used += 1

    if used == 0 or total == 0:
        raise RuntimeError("No compatible client weights to aggregate.")

    avg = [(a / float(total)).astype(b.dtype) for a, b in zip(acc, base_w)]
    model.set_weights(avg)

    out_model_path.parent.mkdir(parents=True, exist_ok=True)
    model.save(str(out_model_path))
    model.save_weights(str(out_weights_path))  # Keras 3 requires .weights.h5
    print(f"✅ Aggregated {used} client(s), total weight={total}")
    print(f"   Model  : {out_model_path}")
    print(f"   Weights: {out_weights_path}\n")

    return used, total


# ============ Main ============
def main():
    """
    Fetch client parameters from GitHub and aggregate the FNN_MC and LSTM_MC models.

    The fetch strategy comes from FETCH_STRATEGY; "auto" resolves to "api"
    when a non-empty GITHUB_TOKEN is set (the same test gh_session() uses to
    decide whether to authenticate) and to "tarball" otherwise.

    FNN_MC is required: errors while fetching or aggregating it propagate.
    LSTM_MC is optional: a fetch error is printed and the step is skipped.

    Raises:
        ValueError: FETCH_STRATEGY is not "auto", "api" or "tarball".
    """
    CACHE_DIR.mkdir(parents=True, exist_ok=True)
    session = gh_session()
    # Truthiness, not "is not None": an empty GITHUB_TOKEN produces no
    # Authorization header, so it must not select the API strategy.
    token_present = bool(os.environ.get("GITHUB_TOKEN"))

    # Pick strategy
    strategy = FETCH_STRATEGY
    if strategy not in ("auto", "api", "tarball"):
        raise ValueError(f"Unknown FETCH_STRATEGY {strategy!r}; expected 'auto', 'api' or 'tarball'.")
    if strategy == "auto":
        strategy = "api" if token_present else "tarball"

    print(f"Fetch strategy: {strategy.upper()} (token={'yes' if token_present else 'no'})")

    def fetch(remote_dir, cache_name):
        """Fetch one model family's clients with the selected strategy."""
        if strategy == "api":
            return fetch_clients_via_api(session, remote_dir, CACHE_DIR / cache_name)
        tar_path = download_repo_tarball(session, CACHE_DIR)
        return extract_clients_from_tarball(tar_path, remote_dir, CACHE_DIR / cache_name)

    # -------- FNN_MC --------
    print("=== FNN_MC: fetching client params ===")
    fnn_clients = fetch(REMOTE_FNN_DIR, "FNN_MC")

    if fnn_clients:
        print(f"Aggregating {len(fnn_clients)} FNN_MC client(s)…")
        fedavg_aggregate(FNN_GLOBAL_PATH, fnn_clients, FNN_OUT_MODEL, FNN_OUT_WEIGHTS)
    else:
        print("⚠️  No FNN_MC clients found to aggregate.")

    # -------- LSTM_MC (optional) --------
    print("=== LSTM_MC: fetching client params ===")
    try:
        lstm_clients = fetch(REMOTE_LSTM_DIR, "LSTM_MC")
    except Exception as e:
        print(f"ℹ️  Skipping LSTM_MC fetch: {e}")
        lstm_clients = []

    if lstm_clients:
        print(f"Aggregating {len(lstm_clients)} LSTM_MC client(s)…")
        fedavg_aggregate(LSTM_GLOBAL_PATH, lstm_clients, LSTM_OUT_MODEL, LSTM_OUT_WEIGHTS)
    else:
        print("ℹ️  No LSTM_MC clients found (or path missing). Skipping.")


# Entry point: fetch and aggregate when this code is executed directly
# (the guard keeps main() from running if the code is imported as a module).
if __name__ == "__main__":
    main()


Fetch strategy: TARBALL (token=no)
=== FNN_MC: fetching client params ===
Aggregating 128 FNN_MC client(s)…
✅ Aggregated 128 client(s), total weight=5888
   Model  : /Users/sohinikar/FL/M.Tech_Dissertation/Server/global_model/global_FNN_MC_model_AGG_20250830_125948.keras
   Weights: /Users/sohinikar/FL/M.Tech_Dissertation/Server/global_model/global_FNN_MC_model_AGG_20250830_125948.weights.h5

=== LSTM_MC: fetching client params ===
Aggregating 30 LSTM_MC client(s)…
✅ Aggregated 30 client(s), total weight=1380
   Model  : /Users/sohinikar/FL/M.Tech_Dissertation/Server/global_model/global_lstm_MC_model_AGG_20250830_125948.keras
   Weights: /Users/sohinikar/FL/M.Tech_Dissertation/Server/global_model/global_lstm_MC_model_AGG_20250830_125948.weights.h5

