In [1]:
import json
from collections import defaultdict
from typing import Dict, List, Any

import flwr as fl
import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, classification_report

# ---------- 工具函数 ----------

def _to_numpy(x):
    return np.asarray(x, dtype=float)

# 将字符串标签映射为 0..C-1，并返回映射
def encode_labels(y: List[Any]):
    uniq = sorted(list({str(v) for v in y}))
    str2id = {s: i for i, s in enumerate(uniq)}
    y_id = np.array([str2id[str(v)] for v in y], dtype=int)
    return y_id, str2id, {i: s for s, i in str2id.items()}

# 融合 softmax
def fuse_softmax(prob_list: List[np.ndarray], weights: List[float] | None = None) -> np.ndarray:
    if len(prob_list) == 1:
        return prob_list[0]
    P = np.vstack(prob_list)
    if weights is None:
        w = np.ones((P.shape[0], 1))
    else:
        w = _to_numpy(weights).reshape(-1, 1)
    P = (P * w).sum(axis=0) / w.sum()
    P = np.clip(P, 1e-9, 1.0)
    P = P / P.sum()
    return P

# ---------- Custom strategy ----------

class PredictAndFuseStrategy(fl.server.strategy.FedAvg):
    """Prediction-only Flower strategy that gathers per-client softmax outputs and fuses them per patient.

    No training happens (``fraction_fit=0.0``). Each evaluate round sends the config
    ``{"task": "predict", "round": server_round}`` to every currently connected client.
    Each client is expected to return, in ``EvaluateRes.metrics["preds_json"]`` (str or
    UTF-8 bytes), a JSON list of rows with keys ``patient_id`` and ``probs``. Rows may also
    carry ``modality`` and ``weight``. Rows are buffered per (patient_id, modality), and a
    later upload for the same pair overwrites the earlier one.

    Args:
        patient_ids: Test-set patient ids, in the same order as ``label_ids``.
        label_ids: Integer class id per patient (e.g. from ``encode_labels``).
        id2label: Mapping from class id back to the original label string.
        fusion: ``"mean"`` for an unweighted average over modalities, or ``"weighted"`` to
            weight each modality by its row's optional ``weight`` (default 1.0).

    Raises:
        ValueError: on an unknown ``fusion`` mode or mismatched ``patient_ids``/``label_ids`` lengths.
    """

    _FUSIONS = ("mean", "weighted")

    def __init__(self, patient_ids: List[str], label_ids: np.ndarray, id2label: Dict[int, str], fusion: str = "mean"):
        if fusion not in self._FUSIONS:
            raise ValueError(f"fusion must be one of {self._FUSIONS}, got {fusion!r}")
        if len(patient_ids) != len(label_ids):
            raise ValueError(
                f"patient_ids ({len(patient_ids)}) and label_ids ({len(label_ids)}) differ in length"
            )
        # Evaluate-only (collect softmax predictions); no client is ever selected for fit.
        super().__init__(
            fraction_fit=0.0,
            min_fit_clients=0,
            fraction_evaluate=1.0,
            min_evaluate_clients=1,
            min_available_clients=1,
        )
        self.patient_ids = patient_ids
        self.labels = np.asarray(label_ids)
        self.id2label = id2label
        self.fusion = fusion
        # patient_id -> modality -> latest probability vector uploaded for that pair
        self.buffer: Dict[str, Dict[str, np.ndarray]] = defaultdict(dict)
        # patient_id -> modality -> client-supplied fusion weight (only for rows that sent one)
        self.modality_weights: Dict[str, Dict[str, float]] = defaultdict(dict)

    # Flower calls evaluate(server_round, parameters) at start-up, and possibly after each round.
    def evaluate(self, server_round: int, parameters):
        """Disable centralized (server-side) evaluation of the global model."""
        return None

    def configure_evaluate(self, server_round: int, parameters, client_manager):
        """Send a "predict" task to every client that is connected right now."""
        config = {"task": "predict", "round": server_round}
        evaluate_ins = fl.common.EvaluateIns(parameters, config)
        # all() returns the clients registered at this moment; clients joining later are not included.
        return [(client, evaluate_ins) for client in client_manager.all().values()]

    def _ingest(self, server_round: int, evaluate_res) -> None:
        """Parse one client's ``preds_json`` payload into the buffers, reporting and skipping malformed input."""
        metrics = evaluate_res.metrics or {}
        preds_blob = metrics.get("preds_json", b"")
        try:
            preds_json = preds_blob.decode("utf-8") if isinstance(preds_blob, bytes) else preds_blob
            if not preds_json:
                return
            rows = json.loads(preds_json)
        except (UnicodeDecodeError, json.JSONDecodeError) as exc:
            print(f"[Round {server_round}] skipping client payload: unreadable preds_json ({exc})")
            return
        if not isinstance(rows, list):
            print(f"[Round {server_round}] skipping client payload: preds_json is not a JSON list")
            return

        skipped = 0
        for r in rows:
            if not isinstance(r, dict) or "patient_id" not in r or "probs" not in r:
                skipped += 1
                continue
            probs = _to_numpy(r["probs"])  # expected length = number of classes C
            if probs.ndim != 1 or probs.size == 0:
                skipped += 1
                continue
            pid = str(r["patient_id"])
            modality = str(r.get("modality", "unknown"))
            self.buffer[pid][modality] = probs
            if "weight" in r:
                self.modality_weights[pid][modality] = float(r["weight"])
            else:
                # Drop a weight left over from an earlier round so it cannot leak into fusion.
                self.modality_weights.get(pid, {}).pop(modality, None)
        if skipped:
            print(f"[Round {server_round}] skipped {skipped} malformed prediction row(s)")

    def _fused_probs(self, pid: str) -> np.ndarray | None:
        """Return the fused probability vector for ``pid``, or None if no modality has reported yet."""
        modal_dict = self.buffer.get(pid, {})
        if not modal_dict:
            return None
        weights = None
        if self.fusion == "weighted":
            per_modality = self.modality_weights.get(pid, {})
            # Same iteration order as modal_dict.values() below, so weights line up with vectors.
            weights = [per_modality.get(m, 1.0) for m in modal_dict]
        return fuse_softmax(list(modal_dict.values()), weights)

    def aggregate_evaluate(self, server_round: int, results, failures):
        """Buffer the predictions uploaded this round and report metrics over patients predicted so far.

        Returns:
            ``(0.0, metrics)``. No loss is computed, so 0.0 is a placeholder.
            ``metrics`` holds ``acc_partial``, ``macro_f1_partial`` and ``n_pred`` once at
            least one patient has a prediction; otherwise it is empty.
        """
        if failures:
            print(f"[Round {server_round}] {len(failures)} client(s) failed during evaluate")
        for _client_proxy, evaluate_res in results:
            self._ingest(server_round, evaluate_res)

        # -1 marks patients with no prediction yet; they are excluded from the metrics.
        y_pred = np.full(len(self.patient_ids), -1, dtype=int)
        for i, pid in enumerate(self.patient_ids):
            fused = self._fused_probs(pid)
            if fused is not None:
                y_pred[i] = int(np.argmax(fused))
        have = y_pred >= 0
        n_pred = int(have.sum())

        metrics: Dict[str, float] = {}
        if n_pred:
            yt = self.labels[have]
            yp = y_pred[have]
            acc = float(accuracy_score(yt, yp))
            # zero_division=0 matches sklearn's default value, minus the warning for classes not predicted yet.
            mf1 = float(f1_score(yt, yp, average="macro", zero_division=0))
            metrics = {"acc_partial": acc, "macro_f1_partial": mf1, "n_pred": n_pred}
            print(f"[Round {server_round}] acc={acc:.4f}, macro_f1={mf1:.4f}, n_pred={n_pred}/{len(self.labels)}")
        else:
            print(f"[Round {server_round}] no predictions yet.")

        # Kept non-None on purpose: Flower's server loop may skip recording distributed
        # metrics in History when the loss is None (confirm against the installed version).
        return 0.0, metrics

    def export_final(self, out_csv: str):
        """Write per-modality and fused predictions for every patient to ``out_csv`` and return the DataFrame.

        Columns are ``patient_id``, then for each modality ``probs_<m>`` (JSON list) and ``pred_<m>``
        (argmax id), then ``probs_fused``/``pred_fused`` when the patient has any prediction.
        """
        records = []
        for pid in self.patient_ids:
            row = {"patient_id": pid}
            for m, p in self.buffer.get(pid, {}).items():
                row[f"probs_{m}"] = json.dumps(p.tolist())
                row[f"pred_{m}"] = int(np.argmax(p))
            fused = self._fused_probs(pid)
            if fused is not None:
                row["probs_fused"] = json.dumps(fused.tolist())
                row["pred_fused"] = int(np.argmax(fused))
            records.append(row)
        df = pd.DataFrame(records)
        # utf-8-sig writes a BOM, presumably so spreadsheet tools detect the encoding.
        df.to_csv(out_csv, index=False, encoding="utf-8-sig")
        return df




In [3]:
# ---------- Usage in Jupyter ----------
# NOTE: fl.server.start_server() is deprecated (see the warning in the cell output);
# newer Flower releases expect the `flower-superlink` CLI instead.

test_csv = "test_metadata.csv"   # test-set metadata CSV (needs patient_id and label columns)
n_classes = 4                    # number of classes covered by the clients' softmax vectors
fusion = "mean"                  # or "weighted" (uses the optional per-row client weight)
export_csv = "predictions_fused.csv"
rounds = 1

meta = pd.read_csv(test_csv)
missing_cols = {"patient_id", "label"} - set(meta.columns)
assert not missing_cols, f"{test_csv} is missing columns: {sorted(missing_cols)}"
pids = meta["patient_id"].astype(str).tolist()
# Predictions are buffered per patient_id, so duplicate ids would silently collapse.
assert len(set(pids)) == len(pids), "patient_id values must be unique"
y_raw = meta["label"].tolist()
y_id, str2id, id2str = encode_labels(y_raw)
# encode_labels assigns ids by sorted label string. The ids match the model's softmax
# indices only if this label set and order matches the classes used at training time.
assert len(str2id) <= n_classes, f"found {len(str2id)} distinct labels but n_classes={n_classes}"
if len(str2id) < n_classes:
    print(
        f"Warning: only {len(str2id)}/{n_classes} classes present in {test_csv}; "
        "label ids may not line up with model output indices."
    )

strategy = PredictAndFuseStrategy(patient_ids=pids, label_ids=y_id, id2label=id2str, fusion=fusion)

fl.server.start_server(
    server_address="0.0.0.0:8080",
    config=fl.server.ServerConfig(num_rounds=rounds),
    strategy=strategy,
)

final_df = strategy.export_final(export_csv)
print("Final predictions saved to:", export_csv)
final_df.head()

	Instead, use the `flower-superlink` CLI command to start a SuperLink as shown below:

		$ flower-superlink --insecure

	To view usage and all available options, run:

		$ flower-superlink --help

	Using `start_server()` is deprecated.

            This is a deprecated feature. It will be removed
            entirely in future versions of Flower.
        
[92mINFO [0m:      Starting Flower server, config: num_rounds=1, no round_timeout
[92mINFO [0m:      Flower ECE: gRPC server running (1 rounds), SSL is disabled
[92mINFO [0m:      [INIT]
[92mINFO [0m:      Requesting initial parameters from one random client
[92mINFO [0m:      Received initial parameters from one random client
[92mINFO [0m:      Starting evaluation of initial global parameters
[92mINFO [0m:      Evaluation returned no results (`None`)
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 1]
[92mINFO [0m:      configure_fit: no clients selected, cancel
[92mINFO [0m:      configure_evaluate: strategy sampl

[Round 1] acc=0.4000, macro_f1=0.3714, n_pred=10/10
Final predictions saved to: predictions_fused.csv


Unnamed: 0,patient_id,probs_WSI,pred_WSI,probs_fused,pred_fused
0,P001,"[0.23608514137618886, 0.17805322154675982, 0.3...",2,"[0.23608514137618886, 0.17805322154675982, 0.3...",2
1,P002,"[0.4182917036297869, 0.1745685071568715, 0.223...",0,"[0.4182917036297869, 0.1745685071568715, 0.223...",0
2,P003,"[0.21188875458539896, 0.20787244639378866, 0.0...",3,"[0.21188875458539896, 0.20787244639378866, 0.0...",3
3,P004,"[0.5074629584982421, 0.19192542794425588, 0.01...",0,"[0.5074629584982421, 0.19192542794425588, 0.01...",0
4,P005,"[0.20005901394364475, 0.10226203661128096, 0.3...",3,"[0.20005901394364475, 0.10226203661128096, 0.3...",3
