In [1]:
import os
import json
import time
import base64
import pathlib
from typing import Dict, List, Optional, Tuple

import requests
import numpy as np
import pandas as pd
import tensorflow as tf

In [2]:
# --------------------- CONFIG ---------------------
# GitHub repository coordinates (owner, repo name, git ref) and the directory
# inside that repo that holds the server's global model files.
OWNER, REPO, REF = "Triss11", "FL", "main"
GLOBAL_MODEL_DIR = "M.Tech_Dissertation/Server/global_model"

In [3]:
# Root of the dissertation project on this machine. Set the FL_PROJECT_ROOT
# environment variable to run elsewhere instead of editing absolute paths;
# the default reproduces the original location.
PROJECT_ROOT = pathlib.Path(
    os.environ.get("FL_PROJECT_ROOT", "/Users/sohinikar/FL/M.Tech_Dissertation")
)

# Local data roots (expected to contain user_* folders with X.csv / y.csv)
FNN_USERS_DIR  = PROJECT_ROOT / "Client" / "data" / "FNN_BC_test_data"
LSTM_USERS_DIR = PROJECT_ROOT / "Client" / "data" / "LSTM_BC_test_data"

# Output (weights + params), one sub-folder per user
FNN_OUT_ROOT  = PROJECT_ROOT / "Client" / "client_params" / "FNN_BC"
LSTM_OUT_ROOT = PROJECT_ROOT / "Client" / "client_params" / "LSTM_BC"

In [5]:
# Training hyperparams (adjust as you like)
EPOCHS = 5
BATCH_SIZE = 64
OPTIMIZER = "adam"
# Labels come from the single "label" column of y.csv and are flattened to a
# 1-D vector. For a multiclass variant with integer labels use
# "sparse_categorical_crossentropy"; "categorical_crossentropy" expects one-hot
# targets. The model's output layer must match the chosen loss.
LOSS_FNN  = "binary_crossentropy"
LOSS_LSTM = "binary_crossentropy"
METRICS   = ["accuracy"]

# Filename substrings used to guess which downloaded file is the FNN and which
# is the LSTM model. Lower- and upper-case spellings are both listed so that
# case-sensitive substring matching still finds them.
FNN_HINTS  = ("fnn", "dense", "ffn", "mlp", "FNN", "Dense", "FFN", "MLP")
LSTM_HINTS = ("lstm", "rnn", "LSTM", "RNN")

In [6]:
def gh_session() -> requests.Session:
    """
    Build a requests session preconfigured for the GitHub REST API.

    Every request sends ``Accept: application/vnd.github+json``. If the
    ``GITHUB_TOKEN`` environment variable is set and non-empty, a bearer
    ``Authorization`` header is added as well (unauthenticated otherwise).
    """
    extra_headers = {}
    github_token = os.environ.get("GITHUB_TOKEN")
    if github_token:
        extra_headers["Authorization"] = f"Bearer {github_token}"
    extra_headers["Accept"] = "application/vnd.github+json"

    session = requests.Session()
    session.headers.update(extra_headers)
    return session

In [7]:
def list_github_dir(session: requests.Session, owner: str, repo: str, path: str, ref: str="main",
                    timeout: float = 30.0) -> List[Dict]:
    """
    List the entries of a repository directory via the GitHub contents API.

    Args:
        session: HTTP session used for the request (e.g. from ``gh_session()``).
        owner: repository owner.
        repo: repository name.
        path: directory path inside the repository.
        ref: branch, tag or commit to read from.
        timeout: seconds to wait on the server; without a timeout ``requests``
            can block indefinitely.

    Returns:
        The decoded JSON list of entry dicts returned by GitHub.

    Raises:
        RuntimeError: if GitHub does not answer with HTTP 200, or if the
            response is not a list (e.g. ``path`` names a file, not a directory).
    """
    url = f"https://api.github.com/repos/{owner}/{repo}/contents/{path}"
    # Pass ref as a query param so requests URL-encodes it (refs may contain '/', '#', ...).
    r = session.get(url, params={"ref": ref}, timeout=timeout)
    if r.status_code != 200:
        raise RuntimeError(f"GitHub list failed: {r.status_code} {r.text}")
    payload = r.json()
    # Callers iterate over entries; a single JSON object would break them later with a vague error.
    if not isinstance(payload, list):
        raise RuntimeError(f"GitHub path is not a directory: {path!r} (ref={ref!r})")
    return payload

In [9]:
def choose_model_files(entries: List[Dict],
                       fnn_hints: Optional[Tuple[str, ...]] = None,
                       lstm_hints: Optional[Tuple[str, ...]] = None) -> Tuple[Optional[Dict], Optional[Dict]]:
    """
    Pick one FNN-like and one LSTM-like model file from GitHub contents entries.

    Only entries with ``type == "file"`` whose name ends in ``.keras`` or ``.h5``
    are considered. Selection order:

    1. LSTM: first model file whose name contains an LSTM hint (case-insensitive).
    2. FNN: first model file whose name contains an FNN hint (case-insensitive),
       preferring one that is not the LSTM pick, so e.g. "lstm_dense.keras"
       is not also taken as the FNN.
    3. Any role still unfilled falls back to the first model file that is not
       the other role's pick. With a single model file, both roles share it.

    Args:
        entries: contents-API entry dicts; ``type`` and ``name`` are read.
        fnn_hints: substrings identifying FNN files; defaults to ``FNN_HINTS``.
        lstm_hints: substrings identifying LSTM files; defaults to ``LSTM_HINTS``.

    Returns:
        ``(fnn_entry, lstm_entry)``. Both are ``None`` when no model file exists.
    """
    if fnn_hints is None:
        fnn_hints = FNN_HINTS
    if lstm_hints is None:
        lstm_hints = LSTM_HINTS

    model_files = [
        e for e in entries
        if e.get("type") == "file" and e["name"].endswith((".keras", ".h5"))
    ]

    def matching(hints):
        # Case-insensitive substring match, preserving the listing order.
        lowered = [h.lower() for h in hints]
        return [e for e in model_files if any(h in e["name"].lower() for h in lowered)]

    def first_except(candidates, exclude):
        # Prefer a candidate distinct from `exclude`; otherwise fall back to the first one.
        for e in candidates:
            if e is not exclude:
                return e
        return candidates[0] if candidates else None

    lstm = first_except(matching(lstm_hints), exclude=None)
    fnn = first_except(matching(fnn_hints), exclude=lstm)

    # Fallbacks: never give both roles the same file unless it is the only model file.
    if fnn is None:
        fnn = first_except(model_files, exclude=lstm)
    if lstm is None:
        lstm = first_except(model_files, exclude=fnn)

    return fnn, lstm

def download_file(session: requests.Session, entry: Dict, dest: pathlib.Path,
                  timeout: float = 60.0) -> pathlib.Path:
    """
    Download a GitHub contents-API entry to ``dest`` using its ``download_url``.

    The body is streamed into a sibling ``<dest name>.part`` file. That file is
    renamed over ``dest`` only after the whole response has been written, so a
    failed or interrupted download never leaves a truncated model at ``dest``.
    An existing file at ``dest`` is overwritten on success.

    Args:
        session: HTTP session used for the request.
        entry: contents-API entry; ``download_url`` and ``name`` are read.
        dest: target path; missing parent directories are created.
        timeout: seconds to wait on the server (connect and between reads).

    Returns:
        ``dest``.

    Raises:
        RuntimeError: if the entry has no ``download_url`` or the server does
            not answer with HTTP 200.
    """
    url = entry.get("download_url")
    if not url:
        raise RuntimeError(f"No download_url for {entry.get('name')}")
    dest.parent.mkdir(parents=True, exist_ok=True)
    tmp = dest.with_name(dest.name + ".part")
    # `with` releases the streamed connection on every exit path, including errors.
    with session.get(url, stream=True, timeout=timeout) as r:
        if r.status_code != 200:
            raise RuntimeError(f"Download failed for {entry.get('name')}: {r.status_code} {r.text}")
        try:
            with open(tmp, "wb") as f:
                for chunk in r.iter_content(chunk_size=1024 * 1024):
                    if chunk:
                        f.write(chunk)
        except BaseException:
            # Drop the partial file so it can never be mistaken for a complete model.
            tmp.unlink(missing_ok=True)
            raise
    tmp.replace(dest)
    return dest

def load_user_dirs(root: pathlib.Path) -> List[pathlib.Path]:
    """
    Find the per-user training folders below ``root``.

    A folder qualifies when its name matches ``user_*`` and it contains both
    ``X.csv`` and ``y.csv``. Results are in sorted path order. A missing
    ``root`` yields an empty list instead of an error.
    """
    if not root.exists():
        return []
    required_files = ("X.csv", "y.csv")
    return [
        candidate
        for candidate in sorted(root.glob("user_*"))
        if all((candidate / fname).exists() for fname in required_files)
    ]

def _load_user_xy(user_dir: pathlib.Path, is_lstm: bool) -> Tuple[np.ndarray, np.ndarray]:
    """
    Read one user's ``X.csv`` / ``y.csv`` as float32 arrays and sanity-check them.

    ``X`` keeps every CSV column as a feature. ``y`` is the ``label`` column of
    ``y.csv``, flattened to 1-D. When ``is_lstm`` is true, ``X`` is reshaped
    from (N, D) to (N, 1, D), i.e. each row becomes a one-timestep sequence.

    Raises:
        ValueError: if ``X.csv`` has no data rows or the two files' row counts
            differ. Checked here so the message names the user instead of
            surfacing as a Keras error.
    """
    X = pd.read_csv(user_dir / "X.csv").to_numpy(dtype=np.float32)
    y = pd.read_csv(user_dir / "y.csv")["label"].to_numpy().astype(np.float32).reshape(-1)

    if X.shape[0] == 0:
        raise ValueError(f"{user_dir.name}: X.csv contains no rows")
    if X.shape[0] != y.shape[0]:
        raise ValueError(
            f"{user_dir.name}: X.csv has {X.shape[0]} rows but y.csv has {y.shape[0]}"
        )

    if is_lstm:
        X = X.reshape((X.shape[0], 1, X.shape[1]))
    return X, y


def _final_metrics(history: Dict[str, List[float]]) -> Dict[str, float]:
    """Return ``{"final_<name>": last value}`` for every non-empty Keras history series."""
    return {
        f"final_{name}": float(values[-1])
        for name, values in history.items()
        if len(values) > 0
    }


def train_one_user(model_path: pathlib.Path, user_dir: pathlib.Path, out_dir: pathlib.Path,
                   is_lstm: bool, loss: str) -> Dict:
    """
    Fine-tune the downloaded global model on one user's data and save the result.

    Steps:
    1. Load and validate the user's X/y.
    2. Load the Keras model from ``model_path``.
    3. Recompile it with ``OPTIMIZER`` / ``loss`` / ``METRICS``.
    4. Fit for ``EPOCHS`` at ``BATCH_SIZE`` with no validation split.
    5. Write ``<user>.weights.h5`` and ``params.json`` into ``out_dir``.

    Args:
        model_path: local ``.keras`` / ``.h5`` model file.
        user_dir: folder containing ``X.csv`` and ``y.csv``.
        out_dir: destination folder; created if missing.
        is_lstm: reshape X to (N, 1, D) for a sequence model.
        loss: Keras loss identifier passed to ``compile``.

    Returns:
        ``{"weights": <weights path as str>, "params": <dict written to params.json>}``.

    Raises:
        ValueError: on empty or row-mismatched user data. Errors from pandas
            or Keras propagate unchanged.
    """
    X, y = _load_user_xy(user_dir, is_lstm)

    model = tf.keras.models.load_model(str(model_path))
    # Compile fresh (the saved file may not include optimizer state); this
    # replaces any compile settings stored with the model.
    model.compile(optimizer=OPTIMIZER, loss=loss, metrics=METRICS)

    # Train on all of the user's rows (no validation, as requested).
    start = time.perf_counter()
    hist = model.fit(X, y, epochs=EPOCHS, batch_size=BATCH_SIZE, verbose=0)
    elapsed = time.perf_counter() - start

    out_dir.mkdir(parents=True, exist_ok=True)
    weights_path = out_dir / f"{user_dir.name}.weights.h5"
    model.save_weights(str(weights_path))

    params = {
        "user": user_dir.name,
        "model_file": model_path.name,
        "samples": int(X.shape[0]),
        "features": int(X.shape[-1]),
        "is_lstm": is_lstm,
        "epochs": EPOCHS,
        "batch_size": BATCH_SIZE,
        "optimizer": OPTIMIZER,
        "loss": loss,
        "metrics": METRICS,
        "train_seconds": round(elapsed, 3),
        **_final_metrics(hist.history),
    }
    with open(out_dir / "params.json", "w") as f:
        json.dump(params, f, indent=2)

    return {"weights": str(weights_path), "params": params}

def _train_users(label: str, users: List[pathlib.Path], model_path: pathlib.Path,
                 out_root: pathlib.Path, is_lstm: bool, loss: str) -> List[str]:
    """
    Run ``train_one_user`` for every user folder and log each outcome.

    Each user's output goes to ``out_root / <user name>``. An exception for
    one user is printed and does not stop the remaining users.

    Returns:
        Names of the users whose training failed.
    """
    print(f"🏃 Training {label} users ({len(users)}) …")
    failed: List[str] = []
    for udir in users:
        try:
            result = train_one_user(model_path, udir, out_root / udir.name,
                                    is_lstm=is_lstm, loss=loss)
            print(f"✅ {udir.name}: saved -> {result['weights']}")
        except Exception as e:  # deliberate best-effort: keep training the other users
            failed.append(udir.name)
            print(f"❌ {udir.name}: {e}")
    return failed


def main():
    """
    Fetch the global FNN/LSTM models from GitHub and train them on every local user.

    Steps: list ``GLOBAL_MODEL_DIR`` in ``OWNER/REPO@REF``, pick the two model
    files, download them, then train every valid user folder under
    ``FNN_USERS_DIR`` / ``LSTM_USERS_DIR``. Results are written to
    ``FNN_OUT_ROOT`` / ``LSTM_OUT_ROOT``. Per-user failures are reported and
    summarised rather than aborting the run.

    Raises:
        RuntimeError: if the model directory cannot be listed, no suitable
            model files are found, or a download fails.
    """
    session = gh_session()

    # 1) List models in the global_model folder
    print("🔎 Listing global model directory on GitHub…")
    entries = list_github_dir(session, OWNER, REPO, GLOBAL_MODEL_DIR, REF)

    # 2) Choose FNN and LSTM model files
    fnn_entry, lstm_entry = choose_model_files(entries)
    if not fnn_entry or not lstm_entry:
        raise RuntimeError("Could not locate suitable .keras/.h5 model files in the global_model folder.")

    # 3) Download models into ./_downloaded_models (relative to the working
    #    directory). The files are re-fetched and overwritten on every run, so
    #    this is a download location, not a cache.
    download_dir = pathlib.Path("./_downloaded_models")
    fnn_local  = download_file(session, fnn_entry,  download_dir / fnn_entry["name"])
    lstm_local = download_file(session, lstm_entry, download_dir / lstm_entry["name"])
    print(f"⬇️  Downloaded FNN model:  {fnn_local}")
    print(f"⬇️  Downloaded LSTM model: {lstm_local}")

    # 4) Enumerate users
    fnn_users  = load_user_dirs(FNN_USERS_DIR)
    lstm_users = load_user_dirs(LSTM_USERS_DIR)

    if not fnn_users:
        print(f"⚠️  No valid FNN user folders found under: {FNN_USERS_DIR}")
    if not lstm_users:
        print(f"⚠️  No valid LSTM user folders found under: {LSTM_USERS_DIR}")

    # 5) + 6) Train FNN users, then LSTM users (same loop, different model/output).
    failures = {
        "FNN": _train_users("FNN", fnn_users, fnn_local, FNN_OUT_ROOT,
                            is_lstm=False, loss=LOSS_FNN),
        "LSTM": _train_users("LSTM", lstm_users, lstm_local, LSTM_OUT_ROOT,
                             is_lstm=True, loss=LOSS_LSTM),
    }

    for label, names in failures.items():
        if names:
            print(f"⚠️  {len(names)} {label} user(s) failed: {', '.join(names)}")

    print("🎉 Done.")

# Entry point. In a Jupyter/IPython kernel __name__ is "__main__", so executing
# this cell also runs the full download + per-user training pipeline.
if __name__ == "__main__":
    main()

🔎 Listing global model directory on GitHub…
⬇️  Downloaded FNN model:  _downloaded_models/global_FNN_model.keras
⬇️  Downloaded LSTM model: _downloaded_models/global_lstm_model.keras
🏃 Training FNN users (128) …
✅ user_001: saved -> /Users/sohinikar/FL/M.Tech_Dissertation/Client/client_params/FNN_BC/user_001/user_001.weights.h5
✅ user_002: saved -> /Users/sohinikar/FL/M.Tech_Dissertation/Client/client_params/FNN_BC/user_002/user_002.weights.h5
✅ user_003: saved -> /Users/sohinikar/FL/M.Tech_Dissertation/Client/client_params/FNN_BC/user_003/user_003.weights.h5
✅ user_004: saved -> /Users/sohinikar/FL/M.Tech_Dissertation/Client/client_params/FNN_BC/user_004/user_004.weights.h5
✅ user_005: saved -> /Users/sohinikar/FL/M.Tech_Dissertation/Client/client_params/FNN_BC/user_005/user_005.weights.h5
✅ user_006: saved -> /Users/sohinikar/FL/M.Tech_Dissertation/Client/client_params/FNN_BC/user_006/user_006.weights.h5
✅ user_007: saved -> /Users/sohinikar/FL/M.Tech_Dissertation/Client/client_param