In [1]:
import os
import json
import joblib
import numpy as np
import pandas as pd
from pathlib import Path
import flwr as fl

# ===== Configuration =====
SERVER = "127.0.0.1:8090"                     # Server address
FEATURES_CSV = "outputs/features_test.csv"    # Test feature file
MODEL_PATH = "outputs/stage_classifier.pkl"   # Trained RandomWalk model
SCALER_PATH = "outputs/scaler.pkl"            # Scaler
FEATCOLS_PATH = "outputs/feat_cols.pkl"       # Feature column file
MODALITY = "WSI"                              # Modality name
N_CLASSES = 4                                 # Number of classes
WEIGHT = 1.0                                  # Weight for federated aggregation

# ===== Load model, scaler, and feature columns =====
print("[INFO] Loading model, scaler, and feature columns...")
clf = joblib.load(MODEL_PATH)
scaler = joblib.load(SCALER_PATH)
feat_cols = joblib.load(FEATCOLS_PATH)

# ===== Load test features =====
df = pd.read_csv(FEATURES_CSV)
print(f"[INFO] Loaded test set with {len(df)} samples.")

# ===== Check feature completeness =====
if not all(col in df.columns for col in feat_cols):
    missing = [c for c in feat_cols if c not in df.columns]
    raise ValueError(f"❌ Missing feature columns: {missing}")

# Extract and standardize features
X = df[feat_cols].values
X_scaled = scaler.transform(X)

# ===== Extract patient IDs =====
def to_patient_id(path: str) -> str:
    """Extract the first three parts of TCGA ID from file path"""
    stem = Path(path).stem
    return "-".join(stem.split("-")[:3])

if "path" in df.columns:
    ids = [to_patient_id(p) for p in df["path"]]
else:
    ids = [f"sample_{i}" for i in range(len(df))]

print(f"[INFO] Prepared {len(ids)} samples for prediction.")
print("[DEBUG] First 5 patient IDs:", ids[:5])

# ===== Define client class =====
class RandomWalkClient(fl.client.NumPyClient):
    def __init__(self, clf, X, ids, modality, weight):
        self.clf = clf
        self.X = X
        self.ids = ids
        self.modality = modality
        self.weight = weight

    # Required methods for federated learning
    def get_parameters(self, config):
        return []

    def fit(self, parameters, config):
        return [], 0, {}

    def evaluate(self, parameters, config):
        task = config.get("task", "")
        metrics = {}

        if task == "predict":
            print(f"[INFO] Starting prediction for {len(self.ids)} samples...")
            probs_all = self.clf.predict_proba(self.X)
            rows = []

            for pid, probs in zip(self.ids, probs_all):
                probs = np.clip(probs.astype(float), 1e-9, 1.0)
                probs = probs / probs.sum()
                row = {
                    "patient_id": pid,
                    "probs": probs.tolist(),
                    "modality": self.modality,
                    "weight": self.weight
                }
                print(f"[PREDICT] {pid}: predicted class = {np.argmax(probs)}, probs = {probs}")
                rows.append(row)

            # Package prediction results as JSON to send to server
            metrics = {"preds_json": json.dumps(rows)}
            print(f"[INFO] Prediction complete. Sending {len(rows)} results to server.")

        return 0.0, len(self.ids), metrics


# ===== Start client =====
if __name__ == "__main__":
    print(f"[START] Connecting to server {SERVER} ...")
    fl.client.start_numpy_client(
        server_address=SERVER,
        client=RandomWalkClient(clf, X_scaled, ids, MODALITY, WEIGHT)
    )
    print("[DONE] Client execution finished.")


[INFO] Loading model, scaler, and feature columns...


	Instead, use `flwr.client.start_client()` by ensuring you first call the `.to_client()` method as shown below: 
	flwr.client.start_client(
		server_address='<IP>:<PORT>',
		client=FlowerClient().to_client(), # <-- where FlowerClient is of type flwr.client.NumPyClient object
	)
	Using `start_numpy_client()` is deprecated.

            This is a deprecated feature. It will be removed
            entirely in future versions of Flower.
        
	Instead, use the `flower-supernode` CLI command to start a SuperNode as shown below:

		$ flower-supernode --insecure --superlink='<IP>:<PORT>'

	To view all available options, run:

		$ flower-supernode --help

	Using `start_client()` is deprecated.

            This is a deprecated feature. It will be removed
            entirely in future versions of Flower.
        
[92mINFO [0m:      
[92mINFO [0m:      Received: get_parameters message 0a80b467-8ad5-44e7-ba42-8e3cab048f0b
[92mINFO [0m:      Sent reply
[92mINFO [0m:      
[92mINFO [0

[INFO] Loaded test set with 28 samples.
[INFO] Prepared 28 samples for prediction.
[DEBUG] First 5 patient IDs: ['TCGA-A1-A0SN', 'TCGA-A2-A04N', 'TCGA-A2-A04R', 'TCGA-A2-A0T1', 'TCGA-A2-A0T3']
[START] Connecting to server 127.0.0.1:8090 ...
[INFO] Starting prediction for 28 samples...
[PREDICT] TCGA-A1-A0SN: predicted class = 1, probs = [0.4285238  0.43505666 0.11184914 0.02457039]
[PREDICT] TCGA-A2-A04N: predicted class = 0, probs = [0.34291728 0.21507746 0.29402504 0.14798022]
[PREDICT] TCGA-A2-A04R: predicted class = 1, probs = [0.34740799 0.45692116 0.18345198 0.01221886]
[PREDICT] TCGA-A2-A0T1: predicted class = 1, probs = [0.38765068 0.5198211  0.08902231 0.00350591]
[PREDICT] TCGA-A2-A0T3: predicted class = 3, probs = [0.25133899 0.13552774 0.20316999 0.40996328]
[PREDICT] TCGA-A2-A0YM: predicted class = 0, probs = [0.46647353 0.39506375 0.13422931 0.00423341]
[PREDICT] TCGA-A2-A3XZ: predicted class = 0, probs = [0.32766502 0.24240354 0.25819268 0.17173876]
[PREDICT] TCGA-A7-A42