In [1]:
import json
from collections import defaultdict
from typing import Dict, List, Any

import flwr as fl
import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score, f1_score

# ---------- 工具函数 ----------

def _to_numpy(x):
    return np.asarray(x, dtype=float)

def encode_labels(y: List[Any]):
    """Map arbitrary labels to contiguous integer ids ``0..C-1``.

    Every label is first converted with ``str``; the distinct strings are
    sorted and numbered in that order.

    Args:
        y: Labels, one per sample. Any type accepted by ``str``.

    Returns:
        A 3-tuple ``(y_id, str2id, id2str)``:
        ``y_id`` is an integer array with the id of each entry of ``y``,
        ``str2id`` maps label string -> id, and ``id2str`` is the inverse map.
    """
    as_text = [str(label) for label in y]
    classes = sorted(set(as_text))
    str2id = dict(zip(classes, range(len(classes))))
    id2str = dict(enumerate(classes))
    y_id = np.array([str2id[text] for text in as_text], dtype=int)
    return y_id, str2id, id2str

# 融合 softmax（按模态加权平均）
def fuse_softmax(prob_list: List[np.ndarray], weights: List[float] | None = None) -> np.ndarray:
    if len(prob_list) == 1:
        return prob_list[0]
    P = np.vstack(prob_list)
    if weights is None:
        w = np.ones((P.shape[0], 1))
    else:
        w = _to_numpy(weights).reshape(-1, 1)
    P = (P * w).sum(axis=0) / w.sum()
    P = np.clip(P, 1e-9, 1.0)
    P = P / P.sum()
    return P
# ====== Strict synchronisation: wait for the DNA / RNA / WSI clients ======
import time

# Modalities a patient must have predictions for before fusion is attempted.
REQUIRED_MODALITIES = ["DNA", "RNA", "WSI"]
# Number of connected clients the strategy waits for
# (presumably one client per required modality).
EXPECTED_CLIENTS = len(REQUIRED_MODALITIES)  # = 3

class PredictAndFuseStrategy(fl.server.strategy.FedAvg):
    """Prediction-only Flower strategy that late-fuses per-modality outputs.

    No training rounds are scheduled (``fraction_fit=0.0``). In each evaluate
    round every connected client is asked to predict; clients return their
    per-patient class probabilities as JSON in the ``preds_json`` metric. The
    server buffers them per patient and modality, fuses patients for which all
    ``REQUIRED_MODALITIES`` are present, and reports accuracy / macro-F1 on
    those patients.

    Attributes are set in ``__init__``:
        patient_ids: Patient ids (as strings), aligned with ``labels``.
        labels: Integer ground-truth label per patient.
        id2label: Map from label id to label string (stored, not used here).
        fusion: Fusion rule name. Only ``"mean"`` is implemented; any other
            value currently falls back to the same weighted mean.
        modality_weights: Optional map modality -> weight for the fusion.
        buffer: ``buffer[patient_id][modality]`` -> probability vector.
    """

    def __init__(self,
                 patient_ids: List[str],
                 label_ids: np.ndarray,
                 id2label: Dict[int, str],
                 fusion: str = "mean",
                 modality_weights: Dict[str, float] | None = None):
        """Create the strategy.

        Args:
            patient_ids: Patient ids, one per evaluated patient.
            label_ids: Integer label per patient, same order as ``patient_ids``.
            id2label: Map from label id to label string.
            fusion: Fusion rule name (only ``"mean"`` is implemented).
            modality_weights: Optional per-modality weights; modalities missing
                from the map get weight 1.0.

        Raises:
            ValueError: If ``patient_ids`` and ``label_ids`` differ in length.
        """
        super().__init__(
            fraction_fit=0.0, min_fit_clients=0,
            fraction_evaluate=1.0,
            min_evaluate_clients=EXPECTED_CLIENTS,   # wait until all expected clients are online
            min_available_clients=EXPECTED_CLIENTS,
        )
        # Buffer keys are always str(patient_id), so normalise here to make
        # look-ups match even if the caller passed non-string ids.
        self.patient_ids = [str(pid) for pid in patient_ids]
        # Fancy indexing with a list of positions below needs an ndarray.
        self.labels = np.asarray(label_ids)
        if len(self.patient_ids) != len(self.labels):
            raise ValueError(
                f"patient_ids ({len(self.patient_ids)}) and label_ids "
                f"({len(self.labels)}) must have the same length"
            )
        self.id2label = id2label
        self.fusion = fusion
        self.modality_weights = modality_weights
        self.buffer: Dict[str, Dict[str, np.ndarray]] = defaultdict(dict)

    def _fuse(self, modal: Dict[str, np.ndarray]) -> np.ndarray:
        """Fuse one patient's vectors for all ``REQUIRED_MODALITIES``.

        ``modal`` must contain every required modality. Vectors and weights
        are passed to ``fuse_softmax`` in ``REQUIRED_MODALITIES`` order.
        """
        probs = [modal[m] for m in REQUIRED_MODALITIES]
        weights = None
        if self.modality_weights is not None:
            weights = [float(self.modality_weights.get(m, 1.0)) for m in REQUIRED_MODALITIES]
        # Only the (weighted) mean is implemented; self.fusion is kept for
        # future rules (logit / geometric mean, ...) and does not branch yet.
        return fuse_softmax(probs, weights)

    def evaluate(self, server_round: int, parameters):
        """Skip server-side evaluation of global parameters."""
        return None

    def configure_evaluate(self, server_round: int, parameters, client_manager):
        """Block until enough clients are online, then ask all of them to predict.

        Polls ``client_manager.all()`` every 0.5 s until at least
        ``EXPECTED_CLIENTS`` clients are connected (no timeout), then returns
        one ``(client, EvaluateIns)`` pair per connected client.
        """
        target = EXPECTED_CLIENTS
        while True:
            online = client_manager.all()
            if len(online) >= target:
                break
            print(f"[Round {server_round}] waiting clients {len(online)}/{target} ...", end="\r")
            time.sleep(0.5)
        print(f"\n[Round {server_round}] READY: online clients = {list(online.keys())}")
        cfg = {"task": "predict", "round": server_round}
        # EvaluateIns is part of the public flwr.common API.
        evaluate_ins = fl.common.EvaluateIns(parameters, cfg)
        return [(c, evaluate_ins) for c in online.values()]

    def aggregate_evaluate(self, server_round: int, results, failures):
        """Collect client predictions, fuse complete patients and score them.

        Returns:
            ``(0.0, metrics)``. ``metrics`` always contains ``n_complete``;
            when at least one patient has all required modalities it also
            contains ``acc_complete`` and ``macro_f1_complete``.
        """
        if failures:
            print(f"[Round {server_round}] {len(failures)} client(s) failed during evaluate")

        seen_modalities = set()
        for _client_proxy, evaluate_res in results:
            metrics = evaluate_res.metrics or {}
            blob = metrics.get("preds_json", b"")
            js = blob.decode("utf-8") if isinstance(blob, bytes) else blob
            if not js:
                continue
            for r in json.loads(js):
                pid = str(r["patient_id"])
                modality = str(r.get("modality", "unknown"))
                self.buffer[pid][modality] = np.asarray(r["probs"], dtype=float)
                seen_modalities.add(modality)

        # Fuse and evaluate only patients for which every required modality arrived.
        req = set(REQUIRED_MODALITIES)
        complete_idx, fused_pred = [], []
        for i, pid in enumerate(self.patient_ids):
            modal = self.buffer.get(pid, {})
            if req.issubset(modal):
                complete_idx.append(i)
                fused_pred.append(int(np.argmax(self._fuse(modal))))

        if not complete_idx:
            print(f"[Round {server_round}] no patient has all modalities yet; received={sorted(seen_modalities)}")
            return 0.0, {"n_complete": 0}

        yt = self.labels[complete_idx]
        yp = np.array(fused_pred)
        acc = float(accuracy_score(yt, yp))
        mf1 = float(f1_score(yt, yp, average="macro"))
        print(f"[Round {server_round}] acc_full={acc:.4f}, macro_f1_full={mf1:.4f}, "
              f"n_complete={len(complete_idx)}/{len(self.labels)} (need {REQUIRED_MODALITIES})")
        return 0.0, {"acc_complete": acc, "macro_f1_complete": mf1, "n_complete": len(complete_idx)}

    def export_final(self, out_csv: str):
        """Write buffered per-modality and fused predictions to ``out_csv``.

        One row per entry of ``patient_ids``. For every received modality
        ``m`` the row has ``probs_<m>`` (JSON list) and ``pred_<m>`` (argmax);
        ``probs_fused`` / ``pred_fused`` are filled only when all required
        modalities are present.

        Returns:
            The DataFrame that was written.
        """
        req = set(REQUIRED_MODALITIES)
        rows = []
        for pid in self.patient_ids:
            modal = self.buffer.get(pid, {})
            row = {"patient_id": pid}
            # One pair of columns per modality received for this patient.
            for m, p in modal.items():
                row[f"probs_{m}"] = json.dumps(p.tolist(), ensure_ascii=False)
                row[f"pred_{m}"] = int(np.argmax(p))
            # Fused prediction over all required modalities.
            if req.issubset(modal):
                P = self._fuse(modal)
                row["probs_fused"] = json.dumps(P.tolist(), ensure_ascii=False)
                row["pred_fused"] = int(np.argmax(P))
            rows.append(row)
        df = pd.DataFrame(rows)
        df.to_csv(out_csv, index=False, encoding="utf-8-sig")
        return df

In [None]:
# Paths: patient metadata to score, and where the fused predictions are written.
test_csv = "testdata.csv"
export_csv = "predictions_fused.csv"

# The metadata must provide a patient id column and a ground-truth label column.
meta = pd.read_csv(test_csv)
assert "patient_id" in meta.columns and "label" in meta.columns
pids = meta["patient_id"].astype(str).tolist()
# Encode labels as integer ids; the label -> id map is discarded, the inverse is kept.
y_id, _, id2str = encode_labels(meta["label"].tolist())

strategy = PredictAndFuseStrategy(patient_ids=pids, label_ids=y_id, id2label=id2str, fusion="mean")

# NOTE(review): the recorded output of this cell shows `start_server()` is
# deprecated in favour of the `flower-superlink` CLI.
fl.server.start_server(
    server_address="0.0.0.0:8080",
    config=fl.server.ServerConfig(num_rounds=1),  # run a single round; keeps waiting while fewer than 3 clients are connected
    strategy=strategy,
)

# Dump the per-modality and fused predictions buffered by the strategy.
final_df = strategy.export_final(export_csv)
print("Final predictions saved to:", export_csv)
print(final_df.head())

	Instead, use the `flower-superlink` CLI command to start a SuperLink as shown below:

		$ flower-superlink --insecure

	To view usage and all available options, run:

		$ flower-superlink --help

	Using `start_server()` is deprecated.

            This is a deprecated feature. It will be removed
            entirely in future versions of Flower.
        
[92mINFO [0m:      Starting Flower server, config: num_rounds=1, no round_timeout
[92mINFO [0m:      Flower ECE: gRPC server running (1 rounds), SSL is disabled
[92mINFO [0m:      [INIT]
[92mINFO [0m:      Requesting initial parameters from one random client
