In [1]:
import json, time
from collections import defaultdict
from typing import Dict, List, Any

import flwr as fl
import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score, f1_score

# ---------- 工具函数 ----------
def _to_numpy(x):
    return np.asarray(x, dtype=float)

def encode_labels(y: List[Any]):
    uniq = sorted(list({str(v) for v in y}))
    str2id = {s: i for i, s in enumerate(uniq)}
    y_id = np.array([str2id[str(v)] for v in y], dtype=int)
    return y_id, str2id, {i: s for s, i in str2id.items()}

def fuse_softmax(
    prob_list: List[np.ndarray],
    weights: List[float] | None = None,
    mode: str = "mean",
    eps: float = 1e-12,
) -> np.ndarray:
    P = np.vstack(prob_list).astype(float)        # [M, C]
    M, C = P.shape
    P = np.clip(P, eps, 1.0)
    P = P / P.sum(axis=1, keepdims=True)

    if mode == "mean":
        out = P.mean(axis=0)

    elif mode == "weighted":
        w = np.ones(M) if weights is None else np.asarray(weights, dtype=float)
        w = np.clip(w, eps, None)
        out = (P * w[:, None]).sum(axis=0) / w.sum()

    elif mode == "poe":  # 几何平均 / Log Opinion Pool
        w = np.ones(M) if weights is None else np.asarray(weights, dtype=float)
        logp = np.log(P) * w[:, None]
        out = np.exp(logp.sum(axis=0)); out /= out.sum()

    elif mode == "logits":  # logits 平均
        w = np.ones(M) if weights is None else np.asarray(weights, dtype=float)
        z = (np.log(P) * w[:, None]).sum(axis=0)
        z = z - z.max()
        out = np.exp(z); out /= out.sum()

    elif mode == "entropy":  # 熵权
        H = -(P * np.log(P)).sum(axis=1)
        Hn = H / np.log(C)
        w = 1.0 - Hn
        if weights is not None:
            w = w * np.asarray(weights, float)
        w = np.clip(w, eps, None)
        out = (P * w[:, None]).sum(axis=0) / w.sum()

    elif mode == "maxconf":  # 选择最自信模态
        m = int(np.argmax(P.max(axis=1)))
        out = P[m]

    elif mode == "borda":    # 排序打分
        scores = np.zeros(C, dtype=float)
        for m in range(M):
            order = np.argsort(-P[m])
            for rank, cls in enumerate(order):
                scores[cls] += (C - rank)
        out = scores / scores.sum()

    else:
        out = P.mean(axis=0)

    out = np.clip(out, eps, 1.0)
    out = out / out.sum()
    return out

# 新增：根据模态字典做融合的小助手
def fuse_softmax_by_modalities(
    modal_dict: Dict[str, np.ndarray],
    required_modalities: List[str],
    modality_weights: Dict[str, float] | None,
    mode: str = "mean",
) -> np.ndarray:
    probs_list, weights = [], []
    for m in required_modalities:
        if m not in modal_dict:
            raise ValueError(f"Missing modality '{m}' for fusion")
        probs_list.append(modal_dict[m])
        if modality_weights is not None and m in modality_weights:
            weights.append(float(modality_weights[m]))
    if mode == "weighted" or (mode in {"entropy","logits","poe"} and weights):
        return fuse_softmax(probs_list, weights=weights if weights else None, mode=mode)
    else:
        return fuse_softmax(probs_list, weights=None, mode=mode)

# ====== 严格同步：等待 DNA / RNA / WSI 三个客户端 ======
REQUIRED_MODALITIES = ["DNA", "RNA", "WSI"]   # 如只要两模态：改为 ["DNA","RNA"]
EXPECTED_CLIENTS = len(REQUIRED_MODALITIES)

# —— 加权搜索顺序（与离线脚本一致）——
WEIGHT_SEARCH_ORDER = ["RNA", "DNA", "WSI"]

class PredictAndFuseStrategy(fl.server.strategy.FedAvg):
    def __init__(self,
                 patient_ids: List[str],
                 label_ids: np.ndarray,
                 id2label: Dict[int, str],
                 fusion: str = "weighted",
                 modality_weights: Dict[str, float] | None = None,
                 enable_weight_search: bool = True,
                 search_step: float = 0.1,
                 search_objective: str = "accuracy"   # ← 改为 accuracy
                 ):
        super().__init__(  # 同前
        )
        self.patient_ids = patient_ids
        self.labels = label_ids
        self.id2label = id2label
        self.fusion = fusion
        self.modality_weights = modality_weights
        self.buffer: Dict[str, Dict[str, np.ndarray]] = defaultdict(dict)

        self.enable_weight_search = enable_weight_search
        self.search_step = float(search_step)
        assert search_objective in ("accuracy", "f1_weighted", "f1_macro")
        self.search_objective = search_objective

        self.best_weights: Dict[str, float] | None = None
        self.best_score: float | None = None
        
    def evaluate(self, server_round: int, parameters):
        return None  # 不做全局参数评估

    def configure_evaluate(self, server_round: int, parameters, client_manager):
        target = EXPECTED_CLIENTS
        while True:
            online = client_manager.all()
            if len(online) >= target:
                break
            print(f"[Round {server_round}] waiting clients {len(online)}/{target} ...", end="\r")
            time.sleep(0.5)
        print(f"\n[Round {server_round}] READY: online clients = {list(online.keys())}")
        cfg = {"task": "predict", "round": server_round}
        evaluate_ins = fl.server.client_proxy.EvaluateIns(parameters, cfg)
        return [(c, evaluate_ins) for c in online.values()]

    def _complete_indices(self):
        req = set(REQUIRED_MODALITIES)
        idx = []
        for i, pid in enumerate(self.patient_ids):
            modal = self.buffer.get(pid, {})
            if req.issubset(set(modal.keys())):
                idx.append(i)
        return idx

    def _fuse_pred_with_weights(self, pid: str, weights_dict: Dict[str, float]) -> np.ndarray:
        # 使用 "weighted" 融合，并按字典映射权重（顺序不敏感）
        return fuse_softmax_by_modalities(
            self.buffer[pid],
            REQUIRED_MODALITIES,
            modality_weights=weights_dict,
            mode="weighted",
        )

    def _search_best_weights(self, idx: List[int]) -> Dict[str, float] | None:
        """在 complete 样本上网格搜索 (RNA, DNA, WSI) 权重，目标为 weighted-F1 或 macro-F1。"""
        if not idx:
            return None

        weight_options = np.arange(0.0, 1.0 + 1e-9, self.search_step)
        best_score = -1.0
        best_w = {"RNA": 1/3, "DNA": 1/3, "WSI": 1/3}

        # 收集 y_true
        yt = self.labels[idx]

        # 穷举网格：与离线脚本一致（RNA, DNA, 剩余→WSI）
        for w_rna in weight_options:
            for w_dna in weight_options:
                if w_rna + w_dna > 1.0:
                    continue
                w_wsi = 1.0 - w_rna - w_dna
                weights_dict = {"RNA": float(w_rna), "DNA": float(w_dna), "WSI": float(w_wsi)}

                # 融合所有 complete 病人
                y_pred = []
                for i in idx:
                    pid = self.patient_ids[i]
                    P = self._fuse_pred_with_weights(pid, weights_dict)
                    y_pred.append(int(np.argmax(P)))
                yp = np.asarray(y_pred, dtype=int)

                if self.search_objective == "f1_weighted":
                    score = f1_score(yt, yp, average="weighted")
                else:
                    score = f1_score(yt, yp, average="macro")

                if score > best_score:
                    best_score = float(score)
                    best_w = weights_dict

        self.best_weights = best_w
        self.best_score = best_score
        print(f"[Search] Best Weights (RNA, DNA, WSI) = ({best_w['RNA']:.2f}, {best_w['DNA']:.2f}, {best_w['WSI']:.2f}); "
              f"best {self.search_objective} = {best_score:.4f}")
        return best_w

    def aggregate_evaluate(self, server_round: int, results, failures):
        seen_modalities = set()
        for client_proxy, evaluate_res in results:
            metrics = evaluate_res.metrics or {}
            blob = metrics.get("preds_json", b"")
            js = blob.decode("utf-8") if isinstance(blob, bytes) else blob
            if not js:
                continue
            for r in json.loads(js):
                pid = str(r["patient_id"])
                probs = np.asarray(r["probs"], dtype=float)
                modality = str(r.get("modality", "unknown"))
                self.buffer[pid][modality] = probs
                seen_modalities.add(modality)

        # 只在所有必需模态都齐的病人上融合与评估
        idx = self._complete_indices()

        if not idx:
            print(f"[Round {server_round}] no patient has all modalities yet; received={sorted(seen_modalities)}")
            return 0.0, {"n_complete": 0}

        # —— 可选：权重搜索（仅在 fusion='weighted' 时启用）——
        if self.fusion == "weighted" and self.enable_weight_search:
            best_w = self._search_best_weights(idx)
            if best_w is not None:
                # 将最优权重写回，用于后续导出/评估
                self.modality_weights = best_w

        # 在（搜索后的）当前策略与权重下评估
        yt = self.labels[idx]
        y_pred = []
        for i in idx:
            pid = self.patient_ids[i]
            if self.fusion == "weighted":
                P = fuse_softmax_by_modalities(
                    self.buffer[pid], REQUIRED_MODALITIES, self.modality_weights, mode="weighted"
                )
            else:
                P = fuse_softmax_by_modalities(
                    self.buffer[pid], REQUIRED_MODALITIES, self.modality_weights, mode=self.fusion
                )
            y_pred.append(int(np.argmax(P)))
        yp = np.asarray(y_pred, dtype=int)

        acc = float(accuracy_score(yt, yp))
        mf1 = float(f1_score(yt, yp, average="macro"))
        wf1 = float(f1_score(yt, yp, average="weighted"))

        print(f"[Round {server_round}] acc_full={acc:.4f}, macro_f1_full={mf1:.4f}, weighted_f1_full={wf1:.4f}, "
              f"n_complete={len(idx)}/{len(self.labels)} (need {REQUIRED_MODALITIES}, fusion={self.fusion})")

        # 返回指标，包含最优权重（如已搜索）
        out_metrics = {
            "acc_complete": acc,
            "macro_f1_complete": mf1,
            "weighted_f1_complete": wf1,
            "n_complete": len(idx),
            "fusion": self.fusion,
        }
        if self.best_weights is not None:
            out_metrics.update({
                "best_weights_rna": float(self.best_weights["RNA"]),
                "best_weights_dna": float(self.best_weights["DNA"]),
                "best_weights_wsi": float(self.best_weights["WSI"]),
                f"best_{self.search_objective}": float(self.best_score if self.best_score is not None else -1.0),
            })
        return 0.0, out_metrics

    def export_final(self, out_csv: str):
        req = set(REQUIRED_MODALITIES)
        rows = []
        for pid in self.patient_ids:
            modal = self.buffer.get(pid, {})
            row = {"patient_id": pid}
            for m, p in modal.items():
                row[f"probs_{m}"] = json.dumps(p.tolist(), ensure_ascii=False)
                row[f"pred_{m}"] = int(np.argmax(p))
            if req.issubset(set(modal.keys())):
                # 使用（可能已搜索得到的）当前权重与策略导出融合结果
                if self.fusion == "weighted":
                    P = fuse_softmax_by_modalities(
                        modal, REQUIRED_MODALITIES, self.modality_weights, mode="weighted"
                    )
                else:
                    P = fuse_softmax_by_modalities(
                        modal, REQUIRED_MODALITIES, self.modality_weights, mode=self.fusion
                    )
                row["probs_fused"] = json.dumps(P.tolist(), ensure_ascii=False)
                row["pred_fused"] = int(np.argmax(P))
            rows.append(row)
        df = pd.DataFrame(rows)
        df.to_csv(out_csv, index=False, encoding="utf-8-sig")
        return df

In [2]:
# ===== 启动区（直接粘在文件底部） =====
test_csv = "test_metadata_THENEWEST - 28.csv"
export_csv = "predictions_fused.csv"

# 1) 等待三模态
REQUIRED_MODALITIES = ["DNA", "RNA", "WSI"]
EXPECTED_CLIENTS = len(REQUIRED_MODALITIES)

# 2) 读取测试集并编码标签
meta = pd.read_csv(test_csv)
assert "patient_id" in meta.columns and "label" in meta.columns
pids = meta["patient_id"].astype(str).tolist()
y_id, str2id, id2str = encode_labels(meta["label"].tolist())

# 3) 切换融合策略 & 网格搜索设置
strategy = PredictAndFuseStrategy(
    patient_ids=pids,
    label_ids=y_id,
    id2label=id2str,
    fusion="weighted",            # ← 只在 weighted 下会触发权重搜索
    modality_weights=None,        # ← 不预设则由搜索得到
    enable_weight_search=True,    # ← 开关
    search_step=0.1,              # ← 与离线脚本一致；可改 0.05/0.02 做二次微调
    search_objective="f1_weighted"  # ← 与离线脚本一致；可改 "f1_macro"
)

# 4) 启动服务器（必须等 3 个客户端上线）
fl.server.start_server(
    server_address="0.0.0.0:8080",
    config=fl.server.ServerConfig(num_rounds=1),
    strategy=strategy,
)

# 5) 导出融合结果（使用已搜索出的最优权重）
final_df = strategy.export_final(export_csv)
print("Final predictions saved to:", export_csv)
print("Best weights used:", strategy.modality_weights)
print(final_df.head())

INFO flwr 2025-10-21 14:36:25,941 | app.py:163 | Starting Flower server, config: ServerConfig(num_rounds=1, round_timeout=None)
INFO flwr 2025-10-21 14:36:25,976 | app.py:176 | Flower ECE: gRPC server running (1 rounds), SSL is disabled
INFO flwr 2025-10-21 14:36:25,976 | server.py:89 | Initializing global parameters
INFO flwr 2025-10-21 14:36:25,976 | server.py:276 | Requesting initial parameters from one random client
INFO flwr 2025-10-21 14:36:36,047 | server.py:280 | Received initial parameters from one random client
INFO flwr 2025-10-21 14:36:36,047 | server.py:91 | Evaluating initial parameters
INFO flwr 2025-10-21 14:36:36,047 | server.py:104 | FL starting
INFO flwr 2025-10-21 14:36:54,084 | server.py:220 | fit_round 1: no clients selected, cancel
DEBUG flwr 2025-10-21 14:36:54,085 | server.py:173 | evaluate_round 1: strategy sampled 3 clients (out of 3)
DEBUG flwr 2025-10-21 14:36:54,182 | server.py:187 | evaluate_round 1 received 3 results and 0 failures



[Round 1] READY: online clients = ['64cc0f4de31146f1b01da0605e960d04', '43834ba089b94cc5b2f2762f3ab4161f', 'ad22a4a78b5240629defbea3431ddf90']


INFO flwr 2025-10-21 14:36:54,329 | server.py:153 | FL finished in 18.27839529999983
INFO flwr 2025-10-21 14:36:54,329 | app.py:226 | app_fit: losses_distributed [(1, 0.0)]
INFO flwr 2025-10-21 14:36:54,329 | app.py:227 | app_fit: metrics_distributed_fit {}
INFO flwr 2025-10-21 14:36:54,330 | app.py:228 | app_fit: metrics_distributed {'acc_complete': [(1, 0.7142857142857143)], 'macro_f1_complete': [(1, 0.5348932676518884)], 'weighted_f1_complete': [(1, 0.7007389162561576)], 'n_complete': [(1, 28)], 'fusion': [(1, 'weighted')], 'best_weights_rna': [(1, 0.7000000000000001)], 'best_weights_dna': [(1, 0.30000000000000004)], 'best_weights_wsi': [(1, -1.1102230246251565e-16)], 'best_f1_weighted': [(1, 0.7007389162561576)]}
INFO flwr 2025-10-21 14:36:54,331 | app.py:229 | app_fit: losses_centralized []
INFO flwr 2025-10-21 14:36:54,331 | app.py:230 | app_fit: metrics_centralized {}


[Search] Best Weights (RNA, DNA, WSI) = (0.70, 0.30, -0.00); best f1_weighted = 0.7007
[Round 1] acc_full=0.7143, macro_f1_full=0.5349, weighted_f1_full=0.7007, n_complete=28/28 (need ['DNA', 'RNA', 'WSI'], fusion=weighted)
Final predictions saved to: predictions_fused.csv
Best weights used: {'RNA': 0.7000000000000001, 'DNA': 0.30000000000000004, 'WSI': -1.1102230246251565e-16}
     patient_id                                          probs_RNA  pred_RNA  \
0  TCGA-A2-A3XZ  [0.8015482425689697, 0.12688452005386353, 0.05...         0   
1  TCGA-A7-A426  [0.013402990065515041, 0.1585855484008789, 0.8...         2   
2  TCGA-A2-A04N  [0.7289366722106934, 0.165786013007164, 0.0860...         0   
3  TCGA-A8-A09X  [0.008406032808125019, 0.9169505834579468, 0.0...         1   
4  TCGA-PL-A8LX  [0.02609281986951828, 0.1378437727689743, 0.80...         2   

                                           probs_WSI  pred_WSI  \
0  [0.4862112637150032, 0.3594942048193022, 0.118...         0   
1  [0.

In [1]:
# -*- coding: utf-8 -*-
import json, time
from collections import defaultdict
from typing import Dict, List, Any

import numpy as np
import pandas as pd
import flwr as fl
from sklearn.metrics import accuracy_score, f1_score

# Flower 1.x/2.x 通用：EvaluateIns 从 flwr.common 导入最稳妥
from flwr.common import EvaluateIns

# ================= 工具函数 =================

def _to_numpy(x):
    return np.asarray(x, dtype=float)

def encode_labels(y: List[Any]):
    """将原始字符串标签编码为 0..K-1"""
    uniq = sorted(list({str(v) for v in y}))
    str2id = {s: i for i, s in enumerate(uniq)}
    y_id = np.array([str2id[str(v)] for v in y], dtype=int)
    return y_id, str2id, {i: s for s, i in str2id.items()}

def fuse_softmax(
    prob_list: List[np.ndarray],
    weights: List[float] | None = None,
    mode: str = "mean",
    eps: float = 1e-12,
) -> np.ndarray:
    """
    对 M 个模态的类别概率做融合，返回融合后的单个概率向量 [C]
    prob_list: [np.ndarray(M条，每条形如[C]))]
    """
    P = np.vstack(prob_list).astype(float)  # [M, C]
    M, C = P.shape
    P = np.clip(P, eps, 1.0)
    P = P / P.sum(axis=1, keepdims=True)

    if mode == "mean":
        out = P.mean(axis=0)

    elif mode == "weighted":
        w = np.ones(M) if weights is None else np.asarray(weights, dtype=float)
        w = np.clip(w, eps, None)
        out = (P * w[:, None]).sum(axis=0) / w.sum()

    elif mode == "poe":  # 几何平均 / Log Opinion Pool
        w = np.ones(M) if weights is None else np.asarray(weights, dtype=float)
        logp = np.log(P) * w[:, None]
        out = np.exp(logp.sum(axis=0)); out /= out.sum()

    elif mode == "logits":  # logits 平均（等价对 log 概率线性加权再 softmax）
        w = np.ones(M) if weights is None else np.asarray(weights, dtype=float)
        z = (np.log(P) * w[:, None]).sum(axis=0)
        z = z - z.max()
        out = np.exp(z); out /= out.sum()

    elif mode == "entropy":  # 熵权（不确定性越低权重越大）
        H = -(P * np.log(P)).sum(axis=1)
        Hn = H / np.log(C)
        w = 1.0 - Hn
        if weights is not None:
            w = w * np.asarray(weights, float)
        w = np.clip(w, eps, None)
        out = (P * w[:, None]).sum(axis=0) / w.sum()

    elif mode == "maxconf":  # 选择最自信模态
        m = int(np.argmax(P.max(axis=1)))
        out = P[m]

    elif mode == "borda":    # 排序打分
        scores = np.zeros(C, dtype=float)
        for m in range(M):
            order = np.argsort(-P[m])
            for rank, cls in enumerate(order):
                scores[cls] += (C - rank)
        out = scores / scores.sum()

    else:
        out = P.mean(axis=0)

    out = np.clip(out, eps, 1.0)
    out = out / out.sum()
    return out

def fuse_softmax_by_modalities(
    modal_dict: Dict[str, np.ndarray],
    required_modalities: List[str],
    modality_weights: Dict[str, float] | None,
    mode: str = "mean",
) -> np.ndarray:
    """
    将 {modality: probs[C]} 字典按 required_modalities 指定的顺序做融合。
    """
    probs_list, weights = [], []
    for m in required_modalities:
        if m not in modal_dict:
            raise ValueError(f"Missing modality '{m}' for fusion")
        probs_list.append(modal_dict[m])
        if modality_weights is not None and m in modality_weights:
            weights.append(float(modality_weights[m]))
    if mode == "weighted" or (mode in {"entropy", "logits", "poe"} and weights):
        return fuse_softmax(probs_list, weights=weights if weights else None, mode=mode)
    else:
        return fuse_softmax(probs_list, weights=None, mode=mode)

# ================= 全局/配置 =================

# 必需模态（严格同步）
REQUIRED_MODALITIES = ["DNA", "RNA", "WSI"]   # 如只要两模态，改为 ["DNA", "RNA"]
EXPECTED_CLIENTS = len(REQUIRED_MODALITIES)

# ================= Flower 策略（含 Accuracy 网格搜索） =================

class PredictAndFuseStrategy(fl.server.strategy.FedAvg):
    """
    - 等待 REQUIRED_MODALITIES 个客户端全部在线
    - 让客户端以 Evaluate 阶段返回每个 patient 的预测概率 JSON
    - 聚合端：
        * 收集三模态概率
        * 对“模态齐全”的样本集合做权重网格搜索（目标：Accuracy）
        * 用最优权重进行最终评估与导出
    """
    def __init__(
        self,
        patient_ids: List[str],
        label_ids: np.ndarray,
        id2label: Dict[int, str],
        fusion: str = "weighted",
        modality_weights: Dict[str, float] | None = None,
        enable_weight_search: bool = True,
        search_step: float = 0.1,       # 网格步长（0.1 对应 0,0.1,...,1.0）
        search_objective: str = "accuracy",  # 也可切 "f1_weighted"/"f1_macro"
    ):
        super().__init__(
            fraction_fit=0.0, min_fit_clients=0,
            fraction_evaluate=1.0,
            min_evaluate_clients=EXPECTED_CLIENTS,
            min_available_clients=EXPECTED_CLIENTS,
        )
        self.patient_ids = patient_ids
        self.labels = label_ids
        self.id2label = id2label
        self.fusion = fusion
        self.modality_weights = modality_weights
        self.buffer: Dict[str, Dict[str, np.ndarray]] = defaultdict(dict)

        self.enable_weight_search = enable_weight_search
        self.search_step = float(search_step)
        assert search_objective in ("accuracy", "f1_weighted", "f1_macro")
        self.search_objective = search_objective

        self.best_weights: Dict[str, float] | None = None
        self.best_score: float | None = None

    # Flower callbacks --------
    def evaluate(self, server_round: int, parameters):
        # 不做全局参数评估
        return None

    def configure_evaluate(self, server_round: int, parameters, client_manager):
        # 阻塞等待必需的客户端全部在线
        target = EXPECTED_CLIENTS
        while True:
            online = client_manager.all()
            if len(online) >= target:
                break
            print(f"[Round {server_round}] waiting clients {len(online)}/{target} ...", end="\r")
            time.sleep(0.5)
        print(f"\n[Round {server_round}] READY: online clients = {list(online.keys())}")

        # 下发“预测”任务（客户端需在 evaluate() 中返回 preds_json）
        cfg = {"task": "predict", "round": server_round}
        evaluate_ins = EvaluateIns(parameters, cfg)
        return [(c, evaluate_ins) for c in online.values()]

    def aggregate_evaluate(self, server_round: int, results, failures):
        # 收集各客户端的预测概率
        seen_modalities = set()
        for client_proxy, evaluate_res in results:
            metrics = evaluate_res.metrics or {}
            blob = metrics.get("preds_json", b"")
            js = blob.decode("utf-8") if isinstance(blob, bytes) else blob
            if not js:
                continue
            for r in json.loads(js):
                pid = str(r["patient_id"])
                probs = np.asarray(r["probs"], dtype=float)
                modality = str(r.get("modality", "unknown"))
                self.buffer[pid][modality] = probs
                seen_modalities.add(modality)

        # 找到“模态齐全”的样本索引
        idx = self._complete_indices()
        if not idx:
            print(f"[Round {server_round}] no patient has all modalities yet; received={sorted(seen_modalities)}")
            return 0.0, {"n_complete": 0}

        # 权重搜索（仅当 fusion='weighted' 且开启开关）
        if self.fusion == "weighted" and self.enable_weight_search:
            best_w = self._search_best_weights(idx)
            if best_w is not None:
                self.modality_weights = best_w  # 固化最优权重

        # 在当前策略/权重下做评估（返回多种指标供参考）
        yt = self.labels[idx]
        yp = self._predict_indices(idx)

        acc = float(accuracy_score(yt, yp))
        mf1 = float(f1_score(yt, yp, average="macro"))
        wf1 = float(f1_score(yt, yp, average="weighted"))

        print(f"[Round {server_round}] acc_full={acc:.4f}, macro_f1_full={mf1:.4f}, weighted_f1_full={wf1:.4f}, "
              f"n_complete={len(idx)}/{len(self.labels)} (need {REQUIRED_MODALITIES}, fusion={self.fusion})")

        out_metrics = {
            "acc_complete": acc,
            "macro_f1_complete": mf1,
            "weighted_f1_complete": wf1,
            "n_complete": len(idx),
            "fusion": self.fusion,
        }
        if self.best_weights is not None:
            out_metrics.update({
                "best_weights_rna": float(self.best_weights["RNA"]),
                "best_weights_dna": float(self.best_weights["DNA"]),
                "best_weights_wsi": float(self.best_weights["WSI"]),
                f"best_{self.search_objective}": float(self.best_score if self.best_score is not None else -1.0),
            })
        return 0.0, out_metrics

    # ------- 辅助函数 -------
    def _complete_indices(self) -> List[int]:
        req = set(REQUIRED_MODALITIES)
        idx = []
        for i, pid in enumerate(self.patient_ids):
            modal = self.buffer.get(pid, {})
            if req.issubset(set(modal.keys())):
                idx.append(i)
        return idx

    def _fuse_pred_with_weights(self, pid: str, weights_dict: Dict[str, float]) -> np.ndarray:
        return fuse_softmax_by_modalities(
            self.buffer[pid],
            REQUIRED_MODALITIES,
            modality_weights=weights_dict,
            mode="weighted",
        )

    def _predict_indices(self, idx: List[int]) -> np.ndarray:
        """根据当前 fusion & modality_weights，对给定索引样本输出类别预测"""
        y_pred = []
        for i in idx:
            pid = self.patient_ids[i]
            if self.fusion == "weighted":
                P = fuse_softmax_by_modalities(
                    self.buffer[pid], REQUIRED_MODALITIES, self.modality_weights, mode="weighted"
                )
            else:
                P = fuse_softmax_by_modalities(
                    self.buffer[pid], REQUIRED_MODALITIES, self.modality_weights, mode=self.fusion
                )
            y_pred.append(int(np.argmax(P)))
        return np.asarray(y_pred, dtype=int)

    def _search_best_weights(self, idx: List[int]) -> Dict[str, float] | None:
        """在模态齐全的样本上，按 Accuracy（或指定目标）做权重网格搜索。"""
        if not idx:
            return None

        weight_options = np.arange(0.001, 1.0 + 1e-9, self.search_step)
        best_score = -1.0
        best_w = {"RNA": 1/3, "DNA": 1/3, "WSI": 1/3}
        yt = self.labels[idx]

        # 穷举 (w_rna, w_dna)，w_wsi = 1 - w_rna - w_dna
        for w_rna in weight_options:
            for w_dna in weight_options:
                if w_rna + w_dna > 1.0:
                    continue
                w_wsi = 1.0 - w_rna - w_dna
                weights_dict = {"RNA": float(w_rna), "DNA": float(w_dna), "WSI": float(w_wsi)}

                yp = []
                for i in idx:
                    pid = self.patient_ids[i]
                    P = self._fuse_pred_with_weights(pid, weights_dict)
                    yp.append(int(np.argmax(P)))
                yp = np.asarray(yp, dtype=int)

                # 目标指标
                if self.search_objective == "accuracy":
                    score = accuracy_score(yt, yp)
                elif self.search_objective == "f1_weighted":
                    score = f1_score(yt, yp, average="weighted")
                else:
                    score = f1_score(yt, yp, average="macro")

                if score > best_score:
                    best_score = float(score)
                    best_w = weights_dict

        self.best_weights = best_w
        self.best_score = best_score
        print(
            f"[Search] Best Weights (RNA, DNA, WSI) = "
            f"({best_w['RNA']:.2f}, {best_w['DNA']:.2f}, {best_w['WSI']:.2f}); "
            f"best {self.search_objective} = {best_score:.4f}"
        )
        return best_w

    def export_final(self, out_csv: str):
        """导出每个病人的各模态/融合概率与预测"""
        req = set(REQUIRED_MODALITIES)
        rows = []
        for pid in self.patient_ids:
            modal = self.buffer.get(pid, {})
            row = {"patient_id": pid}
            for m, p in modal.items():
                row[f"probs_{m}"] = json.dumps(p.tolist(), ensure_ascii=False)
                row[f"pred_{m}"] = int(np.argmax(p))
            if req.issubset(set(modal.keys())):
                if self.fusion == "weighted":
                    P = fuse_softmax_by_modalities(
                        modal, REQUIRED_MODALITIES, self.modality_weights, mode="weighted"
                    )
                else:
                    P = fuse_softmax_by_modalities(
                        modal, REQUIRED_MODALITIES, self.modality_weights, mode=self.fusion
                    )
                row["probs_fused"] = json.dumps(P.tolist(), ensure_ascii=False)
                row["pred_fused"] = int(np.argmax(P))
            rows.append(row)
        df = pd.DataFrame(rows)
        df.to_csv(out_csv, index=False, encoding="utf-8-sig")
        return df

In [2]:
if __name__ == "__main__":
    # 测试集元数据（必须包含 patient_id 与 label）
    test_csv = "test_metadata_THENEWEST - 28.csv"
    export_csv = "predictions_fused.csv"

    # 读取并编码标签
    meta = pd.read_csv(test_csv)
    assert "patient_id" in meta.columns and "label" in meta.columns
    pids = meta["patient_id"].astype(str).tolist()
    y_id, str2id, id2str = encode_labels(meta["label"].tolist())

    # 构建策略（按 Accuracy 搜索）
    strategy = PredictAndFuseStrategy(
        patient_ids=pids,
        label_ids=y_id,
        id2label=id2str,
        fusion="weighted",            # 只在 "weighted" 下启用权重
        modality_weights=None,        # 不预设，由搜索得到
        enable_weight_search=True,    # 打开搜索
        search_step=0.02,              # 0.1 为粗搜；也可二阶段 0.02 微调
        search_objective="accuracy",  # 以 Accuracy 为目标
    )

    # 启动 Flower 服务器（等待 3 个客户端）
    fl.server.start_server(
        server_address="0.0.0.0:8080",
        config=fl.server.ServerConfig(num_rounds=1),
        strategy=strategy,
    )

    # 导出融合结果（包含 probs_* 与 pred_* 以及融合结果）
    final_df = strategy.export_final(export_csv)
    print("Final predictions saved to:", export_csv)
    print("Best weights used:", strategy.modality_weights)
    print(final_df.head())


INFO flwr 2025-10-21 14:52:58,267 | app.py:163 | Starting Flower server, config: ServerConfig(num_rounds=1, round_timeout=None)
INFO flwr 2025-10-21 14:52:58,304 | app.py:176 | Flower ECE: gRPC server running (1 rounds), SSL is disabled
INFO flwr 2025-10-21 14:52:58,305 | server.py:89 | Initializing global parameters
INFO flwr 2025-10-21 14:52:58,305 | server.py:276 | Requesting initial parameters from one random client
INFO flwr 2025-10-21 14:53:11,575 | server.py:280 | Received initial parameters from one random client
INFO flwr 2025-10-21 14:53:11,575 | server.py:91 | Evaluating initial parameters
INFO flwr 2025-10-21 14:53:11,575 | server.py:104 | FL starting
INFO flwr 2025-10-21 14:53:20,355 | server.py:220 | fit_round 1: no clients selected, cancel
DEBUG flwr 2025-10-21 14:53:20,355 | server.py:173 | evaluate_round 1: strategy sampled 3 clients (out of 3)
DEBUG flwr 2025-10-21 14:53:20,443 | server.py:187 | evaluate_round 1 received 3 results and 0 failures



[Round 1] READY: online clients = ['89b0e0a4a6bb479981b6d183fcfeebc5', 'a65f60003f9543dda97517913f6657b6', 'a981a2edbab1430d98252af0ac58bf6e']


INFO flwr 2025-10-21 14:53:22,370 | server.py:153 | FL finished in 10.79545440000038
INFO flwr 2025-10-21 14:53:22,370 | app.py:226 | app_fit: losses_distributed [(1, 0.0)]
INFO flwr 2025-10-21 14:53:22,370 | app.py:227 | app_fit: metrics_distributed_fit {}
INFO flwr 2025-10-21 14:53:22,376 | app.py:228 | app_fit: metrics_distributed {'acc_complete': [(1, 0.7142857142857143)], 'macro_f1_complete': [(1, 0.5268817204301075)], 'weighted_f1_complete': [(1, 0.6966205837173579)], 'n_complete': [(1, 28)], 'fusion': [(1, 'weighted')], 'best_weights_rna': [(1, 0.361)], 'best_weights_dna': [(1, 0.161)], 'best_weights_wsi': [(1, 0.478)], 'best_accuracy': [(1, 0.7142857142857143)]}
INFO flwr 2025-10-21 14:53:22,376 | app.py:229 | app_fit: losses_centralized []
INFO flwr 2025-10-21 14:53:22,376 | app.py:230 | app_fit: metrics_centralized {}


[Search] Best Weights (RNA, DNA, WSI) = (0.36, 0.16, 0.48); best accuracy = 0.7143
[Round 1] acc_full=0.7143, macro_f1_full=0.5269, weighted_f1_full=0.6966, n_complete=28/28 (need ['DNA', 'RNA', 'WSI'], fusion=weighted)
Final predictions saved to: predictions_fused.csv
Best weights used: {'RNA': 0.361, 'DNA': 0.161, 'WSI': 0.478}
     patient_id                                          probs_WSI  pred_WSI  \
0  TCGA-A2-A3XZ  [0.4862112637150032, 0.3594942048193023, 0.118...         0   
1  TCGA-A7-A426  [0.16658036898221076, 0.6485370507533245, 0.13...         1   
2  TCGA-A2-A04N  [0.3613172360475217, 0.2906667768920461, 0.160...         0   
3  TCGA-A8-A09X  [0.08855836624591444, 0.11588980128127209, 0.2...         3   
4  TCGA-PL-A8LX  [0.22262549905644102, 0.17420840362947845, 0.2...         3   

                                           probs_DNA  pred_DNA  \
0  [0.0008491443586535752, 0.19199661910533905, 0...         2   
1  [0.033237095922231674, 0.9648754000663757, 0.0...   