In [3]:
import json
from collections import defaultdict
from typing import Dict, List, Any

import flwr as fl
import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score, f1_score

# ============================================================
# Utility Functions
# ============================================================

def _to_numpy(x):
    """Convert to numpy array"""
    return np.asarray(x, dtype=float)

def encode_labels(y: List[Any]):
    """Map arbitrary labels to contiguous integer ids ``0..C-1``.

    Labels are compared by their ``str`` form, and ids are assigned in sorted
    (lexicographic) order of those strings, so the mapping is deterministic
    for a given set of labels.

    Returns:
        A tuple ``(y_id, str2id, id2str)``: an int array of ids aligned with
        *y*, the label-string -> id dict, and its inverse (id -> label-string).
    """
    as_str = [str(value) for value in y]
    vocabulary = sorted(set(as_str))
    str2id = {label: idx for idx, label in enumerate(vocabulary)}
    id2str = dict(enumerate(vocabulary))
    y_id = np.array([str2id[label] for label in as_str], dtype=int)
    return y_id, str2id, id2str

def fuse_softmax(prob_list: List[np.ndarray], weights: List[float] | None = None) -> np.ndarray:
    """Fuse multiple modality softmax probability vectors via weighted mean"""
    if len(prob_list) == 1:
        return prob_list[0]
    P = np.vstack(prob_list)
    if weights is None:
        w = np.ones((P.shape[0], 1))
    else:
        w = _to_numpy(weights).reshape(-1, 1)
    P = (P * w).sum(axis=0) / w.sum()
    P = np.clip(P, 1e-9, 1.0)
    P = P / P.sum()
    return P


# ============================================================
# Custom Federated Strategy: Prediction + Fusion Only
# ============================================================

class PredictAndFuseStrategy(fl.server.strategy.FedAvg):
    """
    Prediction-only federated strategy with server-side late fusion.

    No training is performed (``fraction_fit=0``). Each evaluate round sends
    ``{"task": "predict", "round": <n>}`` to every connected client. A client
    is expected to put a JSON list of rows
    ``{"patient_id": ..., "probs": [...], "modality": ...}`` (as ``str`` or
    UTF-8 ``bytes``) into ``EvaluateRes.metrics["preds_json"]``. The server
    keeps the latest probability vector per (patient_id, modality), fuses all
    modalities of a patient, and reports accuracy / macro-F1 on the patients
    that have at least one prediction.

    Args:
        patient_ids: Patient identifiers, aligned element-wise with ``label_ids``.
        label_ids: Integer ground-truth labels. They are compared with the argmax
            of the fused vector, so they must use the same class-index order as
            the clients' probability vectors (not checked here).
        id2label: Mapping from label id back to the original label string.
        fusion: ``"mean"`` (mean via ``fuse_softmax``; the default) or ``"max"``
            (element-wise maximum over modalities, renormalised).
        min_clients: Number of clients to wait for before broadcasting the
            prediction task (e.g. one per modality). Defaults to 1.

    Raises:
        ValueError: for an unsupported ``fusion`` or ``min_clients < 1``.
    """

    _SUPPORTED_FUSIONS = ("mean", "max")
    # Upper bound (seconds) on how long configure_evaluate waits for min_clients.
    _WAIT_TIMEOUT_S = 600

    def __init__(
        self,
        patient_ids: List[str],
        label_ids: np.ndarray,
        id2label: Dict[int, str],
        fusion: str = "mean",
        min_clients: int = 1,
    ):
        if fusion not in self._SUPPORTED_FUSIONS:
            raise ValueError(f"fusion must be one of {self._SUPPORTED_FUSIONS}, got {fusion!r}")
        if min_clients < 1:
            raise ValueError(f"min_clients must be >= 1, got {min_clients}")
        super().__init__(
            fraction_fit=0.0,          # Disable training
            min_fit_clients=0,
            fraction_evaluate=1.0,     # All clients perform evaluate (i.e., predict)
            min_evaluate_clients=min_clients,
            min_available_clients=min_clients,
        )
        self.patient_ids = patient_ids
        self.labels = label_ids
        self.id2label = id2label
        self.fusion = fusion
        # patient_id -> {modality -> latest probability vector}
        self.buffer: Dict[str, Dict[str, np.ndarray]] = defaultdict(dict)

    def evaluate(self, server_round: int, parameters):
        """Skip centralized evaluation: there is no server-side model to evaluate."""
        return None

    def configure_evaluate(self, server_round: int, parameters, client_manager):
        """Broadcast the prediction task to every connected client."""
        # Wait (bounded) so late-connecting modality clients are not skipped. On
        # timeout the return value is False and we proceed with whoever is connected.
        client_manager.wait_for(self.min_available_clients, timeout=self._WAIT_TIMEOUT_S)
        config = {"task": "predict", "round": server_round}
        evaluate_ins = fl.common.EvaluateIns(parameters, config)
        return [(client, evaluate_ins) for client in client_manager.all().values()]

    def aggregate_evaluate(self, server_round: int, results, failures):
        """Store client predictions, then score the fused predictions so far.

        Returns:
            ``(0.0, metrics)``. The loss slot is a placeholder. ``metrics`` holds
            ``acc_partial``, ``macro_f1_partial`` and ``n_pred`` when at least one
            patient has a prediction; otherwise it is empty.
        """
        if failures:
            print(f"[Round {server_round}] ⚠ {len(failures)} client(s) failed; their predictions are missing")

        for _, evaluate_res in results:
            for pid, modality, probs in self._parse_predictions(evaluate_res.metrics or {}):
                self.buffer[pid][modality] = probs

        # ===== Evaluate partial results so far =====
        y_true, y_pred = [], []
        for pid, label in zip(self.patient_ids, self.labels):
            fused = self._fused_probs(pid)
            if fused is not None:
                y_true.append(label)
                y_pred.append(int(np.argmax(fused)))

        if not y_pred:
            print(f"[Round {server_round}] ⚠ No client predictions received yet")
            return 0.0, {}

        acc = float(accuracy_score(y_true, y_pred))
        mf1 = float(f1_score(y_true, y_pred, average="macro"))
        print(f"[Round {server_round}] ✅ Fused {len(y_pred)}/{len(self.labels)} samples, acc={acc:.4f}, macro_f1={mf1:.4f}")
        return 0.0, {"acc_partial": acc, "macro_f1_partial": mf1, "n_pred": len(y_pred)}

    @staticmethod
    def _parse_predictions(metrics):
        """Return ``(patient_id, modality, probs)`` tuples from one client's metrics.

        Malformed payloads or rows are reported and skipped instead of aborting
        the whole round.
        """
        blob = metrics.get("preds_json", b"")
        if not blob:
            return []
        try:
            text = blob.decode("utf-8") if isinstance(blob, bytes) else blob
            rows = json.loads(text)
        except (TypeError, ValueError) as err:  # includes UnicodeDecodeError / JSONDecodeError
            print(f"⚠ Skipping unparsable preds_json payload: {err}")
            return []
        if not isinstance(rows, list):
            print(f"⚠ Skipping preds_json payload: expected a JSON list, got {type(rows).__name__}")
            return []

        parsed = []
        for r in rows:
            try:
                parsed.append((str(r["patient_id"]), str(r.get("modality", "unknown")), _to_numpy(r["probs"])))
            except (AttributeError, KeyError, TypeError, ValueError) as err:
                print(f"⚠ Skipping malformed prediction row {r!r}: {err!r}")
        return parsed

    def _fuse(self, probs_list: List[np.ndarray]) -> np.ndarray:
        """Combine one patient's modality vectors according to ``self.fusion``."""
        if self.fusion == "max" and len(probs_list) > 1:
            fused = np.clip(np.vstack(probs_list).max(axis=0), 1e-9, 1.0)
            return fused / fused.sum()
        return fuse_softmax(probs_list)

    def _fused_probs(self, pid: str) -> np.ndarray | None:
        """Return the fused vector for *pid*, or None if it has no predictions yet."""
        modal_dict = self.buffer.get(pid)
        if not modal_dict:
            return None
        return self._fuse(list(modal_dict.values()))

    def export_final(self, out_csv: str):
        """Write per-patient per-modality and fused predictions to *out_csv*.

        One row per entry of ``patient_ids``. Columns: ``patient_id``; for each
        modality ``m`` seen for that patient, ``probs_m`` (JSON list) and
        ``pred_m`` (argmax); and, when the patient has any prediction,
        ``probs_fused`` / ``pred_fused``. Returns the DataFrame that was written.
        """
        records = []
        for pid in self.patient_ids:
            modal_dict = self.buffer.get(pid, {})
            row = {"patient_id": pid}
            for m, p in modal_dict.items():
                row[f"probs_{m}"] = json.dumps(p.tolist())
                row[f"pred_{m}"] = int(np.argmax(p))
            fused = self._fused_probs(pid)
            if fused is not None:
                row["probs_fused"] = json.dumps(fused.tolist())
                row["pred_fused"] = int(np.argmax(fused))
            records.append(row)
        df = pd.DataFrame(records)
        # utf-8-sig adds a BOM so Excel detects the encoding.
        df.to_csv(out_csv, index=False, encoding="utf-8-sig")
        print(f"✅ Final fused prediction results saved to: {out_csv}")
        return df


In [4]:

# ============================================================
# Start the server
# ============================================================

if __name__ == "__main__":
    # Test metadata file; must contain `patient_id` and `label` columns.
    test_csv = r"C:\Users\mxjli\Desktop\test_metadata_THENEWEST - 28.csv"
    n_classes = 4
    fusion = "mean"
    export_csv = "predictions_fused.csv"
    rounds = 1

    meta = pd.read_csv(test_csv)
    # Explicit check (not `assert`) so validation survives `python -O`.
    missing = {"patient_id", "label"} - set(meta.columns)
    if missing:
        raise ValueError(f"{test_csv} is missing required column(s): {sorted(missing)}")

    pids = meta["patient_id"].astype(str).tolist()
    y_raw = meta["label"].tolist()
    y_id, str2id, id2str = encode_labels(y_raw)
    # NOTE(review): encode_labels numbers only the labels present in this CSV, in
    # sorted string order. Scores are meaningful only if that matches the class
    # index order of the clients' probability vectors — confirm with the clients.
    if len(str2id) != n_classes:
        print(f"[WARN] Found {len(str2id)} distinct labels but n_classes={n_classes}; "
              "label ids may not line up with client output indices.")

    print(f"[INFO] Loaded {len(pids)} test patients with labels.")
    strategy = PredictAndFuseStrategy(patient_ids=pids, label_ids=y_id, id2label=id2str, fusion=fusion)

    # Start the Flower server; it returns after the configured rounds complete.
    # NOTE: start_server() is deprecated in recent Flower releases (it emits a
    # deprecation warning recommending `flower-superlink`).
    fl.server.start_server(
        server_address="0.0.0.0:8090",
        config=fl.server.ServerConfig(num_rounds=rounds),
        strategy=strategy,
    )

    # Export the final fused results (export_final prints the output path).
    final_df = strategy.export_final(export_csv)
    print(final_df.head())


	Instead, use the `flower-superlink` CLI command to start a SuperLink as shown below:

		$ flower-superlink --insecure

	To view usage and all available options, run:

		$ flower-superlink --help

	Using `start_server()` is deprecated.

            This is a deprecated feature. It will be removed
            entirely in future versions of Flower.
        
[92mINFO [0m:      Starting Flower server, config: num_rounds=1, no round_timeout
[92mINFO [0m:      Flower ECE: gRPC server running (1 rounds), SSL is disabled
[92mINFO [0m:      [INIT]
[92mINFO [0m:      Requesting initial parameters from one random client


[INFO] Loaded 28 test patients with labels.


[92mINFO [0m:      Received initial parameters from one random client
[92mINFO [0m:      Starting evaluation of initial global parameters
[92mINFO [0m:      Evaluation returned no results (`None`)
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 1]
[92mINFO [0m:      configure_fit: no clients selected, cancel
[92mINFO [0m:      configure_evaluate: strategy sampled 1 clients (out of 1)
[92mINFO [0m:      aggregate_evaluate: received 1 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [SUMMARY]
[92mINFO [0m:      Run finished 1 round(s) in 0.08s
[92mINFO [0m:      	History (loss, distributed):
[92mINFO [0m:      		round 1: 0.0
[92mINFO [0m:      	History (metrics, distributed, evaluate):
[92mINFO [0m:      	{'acc_partial': [(1, 0.35714285714285715)],
[92mINFO [0m:      	 'macro_f1_partial': [(1, 0.29166666666666663)],
[92mINFO [0m:      	 'n_pred': [(1, 28)]}
[92mINFO [0m:      


[Round 1] ✅ Fused 28/28 samples, acc=0.3571, macro_f1=0.2917
✅ Final fused prediction results saved to: predictions_fused.csv
✅ Final predictions saved to: predictions_fused.csv
     patient_id                                          probs_WSI  pred_WSI  \
0  TCGA-A2-A3XZ  [0.48621126371500306, 0.35949420481930233, 0.1...         0   
1  TCGA-A7-A426  [0.16658036898221074, 0.6485370507533245, 0.13...         1   
2  TCGA-A2-A04N  [0.36131723604752164, 0.2906667768920461, 0.16...         0   
3  TCGA-A8-A09X  [0.08855836624591445, 0.11588980128127209, 0.2...         3   
4  TCGA-PL-A8LX  [0.22262549905644102, 0.17420840362947843, 0.2...         3   

                                         probs_fused  pred_fused  
0  [0.48621126371500306, 0.35949420481930233, 0.1...           0  
1  [0.16658036898221074, 0.6485370507533245, 0.13...           1  
2  [0.36131723604752164, 0.2906667768920461, 0.16...           0  
3  [0.08855836624591445, 0.11588980128127209, 0.2...           3  
4  [0.