In [1]:
import json
import os
import pathlib
import time
from typing import List, Tuple, Optional

import numpy as np
import tensorflow as tf

In [2]:
# -------------------- CONFIG --------------------
# All paths hang off one project root. Override it with the FL_PROJECT_ROOT
# environment variable to run on another machine; the default is the original
# author's checkout location, so existing runs are unaffected.
PROJECT_ROOT = pathlib.Path(
    os.environ.get("FL_PROJECT_ROOT", "/Users/sohinikar/FL/M.Tech_Dissertation")
).expanduser()
CLIENT_PARAMS_DIR = PROJECT_ROOT / "Client" / "client_params"
GLOBAL_MODEL_DIR  = PROJECT_ROOT / "Server" / "global_model"

# Per-model-family folders containing the user_* client subdirectories
FNN_CLIENT_ROOT  = CLIENT_PARAMS_DIR / "FNN_BC"
LSTM_CLIENT_ROOT = CLIENT_PARAMS_DIR / "LSTM_BC"

# Current global models (architecture + starting weights for aggregation)
FNN_GLOBAL_PATH  = GLOBAL_MODEL_DIR / "global_FNN_model.keras"
LSTM_GLOBAL_PATH = GLOBAL_MODEL_DIR / "global_lstm_model.keras"

# Weight by client sample counts found in params.json ("samples");
# when False every client gets weight 1 (plain mean).
WEIGHT_BY_SAMPLES = True

# Save updated copies alongside the originals, tagged with this run's timestamp
TS = time.strftime("%Y%m%d_%H%M%S")


def _aggregated_output_paths(global_path: pathlib.Path, ts: str) -> Tuple[pathlib.Path, pathlib.Path]:
    """Return (model_path, weights_path) for the aggregated copy of ``global_path``.

    Both live next to ``global_path`` as ``<stem>_AGG_<ts>.keras`` and
    ``<stem>_AGG_<ts>.weights.h5``; Keras ``save_weights`` requires the
    ``.weights.h5`` suffix.
    """
    prefix = f"{global_path.stem}_AGG_{ts}"
    return (global_path.with_name(f"{prefix}.keras"),
            global_path.with_name(f"{prefix}.weights.h5"))


FNN_OUT_MODEL, FNN_OUT_WEIGHTS = _aggregated_output_paths(FNN_GLOBAL_PATH, TS)
LSTM_OUT_MODEL, LSTM_OUT_WEIGHTS = _aggregated_output_paths(LSTM_GLOBAL_PATH, TS)

In [3]:
def find_clients(root: pathlib.Path) -> List[pathlib.Path]:
    """List client directories under ``root`` that hold trained weights.

    A client is any ``user_*`` entry directly under ``root`` containing at
    least one ``*.weights.h5`` file. The result is sorted by path; a
    non-existent ``root`` yields an empty list.
    """
    if root.exists():
        candidates = sorted(root.glob("user_*"))
        # next(..., None) stops at the first match instead of listing every file
        return [d for d in candidates if next(d.glob("*.weights.h5"), None) is not None]
    return []

def load_client_sample_count(udir: pathlib.Path) -> int:
    """Return a client's FedAvg weight, read from ``udir / "params.json"``.

    Uses the integer value of the ``"samples"`` key. Falls back to 1 when the
    file is absent, the key is missing, or the file/value cannot be parsed.
    Values below 1 are clamped up to 1, so every included client carries a
    positive weight.

    A parse failure prints a warning naming the client, so a corrupt
    params.json does not quietly skew the weighted average.
    """
    params_path = udir / "params.json"
    if not params_path.exists():
        return 1
    try:
        with open(params_path, "r", encoding="utf-8") as f:
            params = json.load(f)
        if not isinstance(params, dict):
            raise TypeError(f"expected a JSON object, got {type(params).__name__}")
        samples = int(params.get("samples", 1))
    except Exception as e:  # best-effort: any read/parse problem falls back to weight 1
        print(f"  ! {udir.name}: could not read 'samples' from {params_path.name} ({e}); using weight 1")
        return 1
    return max(samples, 1)

def pick_weights_file(udir: pathlib.Path) -> Optional[pathlib.Path]:
    """Return the ``*.weights.h5`` file in ``udir`` that sorts first by path.

    Returns None when the folder holds no such file. NOTE(review): with
    several weight files this picks by name order, not modification time.
    """
    return min(udir.glob("*.weights.h5"), default=None)

def fedavg_aggregate(global_model_path: pathlib.Path,
                     clients_root: pathlib.Path,
                     out_model_path: pathlib.Path,
                     out_weights_path: pathlib.Path,
                     weight_by_samples: Optional[bool] = None) -> Tuple[int, int]:
    """
    Perform FedAvg on all client weights under clients_root to update the global model.

    Each client folder returned by ``find_clients`` contributes the weights file
    chosen by ``pick_weights_file``. Those weights are loaded into a model with the
    global architecture and averaged element-wise:

        avg_i = sum_k(n_k * w_k_i) / sum_k(n_k)

    where ``n_k`` is the client's sample count (``load_client_sample_count``) or 1
    when sample weighting is off. Accumulation is done in float64 and each result
    is cast back to the global variable's dtype (integer variables, if any, are
    truncated by that cast).

    Args:
        global_model_path: ``.keras`` file giving the architecture and base weights.
        clients_root: Directory containing the ``user_*`` client folders.
        out_model_path: Destination for the aggregated full model.
        out_weights_path: Destination for the aggregated weights; must end with
            ``.weights.h5`` (required by Keras ``save_weights``).
        weight_by_samples: Weight clients by their sample counts. ``None`` (default)
            uses the module-level ``WEIGHT_BY_SAMPLES`` setting.

    Returns: (num_clients_used, total_weight)

    Raises:
        FileNotFoundError: ``global_model_path`` does not exist.
        ValueError: ``out_weights_path`` does not end with ``.weights.h5``.
        RuntimeError: the global model has no weights, or no client was usable.
    """
    if not global_model_path.exists():
        raise FileNotFoundError(f"Global model not found: {global_model_path}")
    # Check before doing any work: otherwise Keras only rejects the name in
    # save_weights, after every client was processed and the .keras file written.
    if not str(out_weights_path).endswith(".weights.h5"):
        raise ValueError(f"out_weights_path must end with '.weights.h5', got: {out_weights_path}")
    if weight_by_samples is None:
        weight_by_samples = WEIGHT_BY_SAMPLES

    # Base model: supplies shapes/dtypes and receives the averaged weights at the end.
    base_model = tf.keras.models.load_model(str(global_model_path))
    base_weights = base_model.get_weights()
    if not base_weights:
        raise RuntimeError("Global model has no weights (did you save an unbuilt model?).")

    # One scratch model reused for every client, instead of re-reading and
    # deserializing the .keras archive once per client. It is reset to the global
    # weights before each client load, so it starts from the same state a freshly
    # loaded model would (any variable a client file leaves untouched keeps the
    # global value, never a previous client's value).
    scratch_model = tf.keras.models.load_model(str(global_model_path))

    # float64 accumulators, one per weight tensor, to limit rounding error
    acc = [np.zeros_like(w, dtype=np.float64) for w in base_weights]
    total_w = 0

    clients = find_clients(clients_root)
    print(f"Found {len(clients)} client(s) in {clients_root}")

    used = 0
    for udir in clients:
        wfile = pick_weights_file(udir)
        if wfile is None:
            print(f"  - {udir.name}: no *.weights.h5, skipping")
            continue

        weight = load_client_sample_count(udir) if weight_by_samples else 1

        scratch_model.set_weights(base_weights)
        try:
            scratch_model.load_weights(str(wfile))
            c_weights = scratch_model.get_weights()
        except Exception as e:
            print(f"  - {udir.name}: failed to load weights ({wfile.name}): {e}, skipping")
            continue

        # Shape check
        if len(c_weights) != len(acc) or any(cw.shape != bw.shape for cw, bw in zip(c_weights, base_weights)):
            print(f"  - {udir.name}: weight shapes mismatch, skipping")
            continue

        # Accumulate in place (each `a` is the accumulator array itself)
        for a, cw in zip(acc, c_weights):
            a += weight * cw.astype(np.float64)
        total_w += weight
        used += 1
        print(f"  ✓ {udir.name}: included (weight={weight})")

    if used == 0 or total_w == 0:
        raise RuntimeError(f"No compatible client weights found in {clients_root}")

    # Compute weighted average and update base_model
    avg = [(a / float(total_w)).astype(bw.dtype) for a, bw in zip(acc, base_weights)]
    base_model.set_weights(avg)

    # Save updated global model + weights (the two paths may live in different dirs)
    out_model_path.parent.mkdir(parents=True, exist_ok=True)
    out_weights_path.parent.mkdir(parents=True, exist_ok=True)
    base_model.save(str(out_model_path))
    base_model.save_weights(str(out_weights_path))

    print(f"\n✅ Aggregated {used} client(s) (total weight={total_w}) into:")
    print(f"   Model  : {out_model_path}")
    print(f"   Weights: {out_weights_path}\n")

    return used, total_w

def main():
    """Run FedAvg for each model family in turn: FNN first, then LSTM.

    Each family runs inside its own try/except, so a failure is printed and the
    next family is still attempted.
    """
    # Zero-argument callables keep the config globals looked up inside each try,
    # exactly as when the calls were written out inline.
    jobs = (
        ("FNN", lambda: fedavg_aggregate(FNN_GLOBAL_PATH, FNN_CLIENT_ROOT, FNN_OUT_MODEL, FNN_OUT_WEIGHTS)),
        ("LSTM", lambda: fedavg_aggregate(LSTM_GLOBAL_PATH, LSTM_CLIENT_ROOT, LSTM_OUT_MODEL, LSTM_OUT_WEIGHTS)),
    )
    for label, run_job in jobs:
        try:
            print(f"=== Aggregating {label} global model ===")
            run_job()
        except Exception as err:
            print(f"❌ {label} aggregation failed: {err}")

if __name__ == "__main__":
    main()


=== Aggregating FNN global model ===
Found 128 client(s) in /Users/sohinikar/FL/M.Tech_Dissertation/Client/client_params/FNN_BC
  ✓ user_001: included (weight=46)
  ✓ user_002: included (weight=46)
  ✓ user_003: included (weight=46)
  ✓ user_004: included (weight=46)
  ✓ user_005: included (weight=46)
  ✓ user_006: included (weight=46)
  ✓ user_007: included (weight=46)
  ✓ user_008: included (weight=46)
  ✓ user_009: included (weight=46)
  ✓ user_010: included (weight=46)
  ✓ user_011: included (weight=46)
  ✓ user_012: included (weight=46)
  ✓ user_013: included (weight=46)
  ✓ user_014: included (weight=46)
  ✓ user_015: included (weight=46)
  ✓ user_016: included (weight=46)
  ✓ user_017: included (weight=46)
  ✓ user_018: included (weight=46)
  ✓ user_019: included (weight=46)
  ✓ user_020: included (weight=46)
  ✓ user_021: included (weight=46)
  ✓ user_022: included (weight=46)
  ✓ user_023: included (weight=46)
  ✓ user_024: included (weight=46)
  ✓ user_025: included (weight=4