In [2]:
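# Run this cell only once, and only if the current kernel is missing these packages; skip it if they are already installed.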
%pip install -q "flwr==1.7.0" "protobuf>=4.21,<5" "numpy==1.26.4" "pandas==2.2.2" "scikit-learn==1.3.2" "torch>=2.0,<3"


Note: you may need to restart the kernel to use updated packages.


In [1]:
import os, json, time
from collections import defaultdict
from typing import Dict, List, Any, Optional, Tuple

import numpy as np
import pandas as pd
import flwr as fl
import torch
import torch.nn.functional as F
from sklearn.metrics import accuracy_score, f1_score

# ---------- 基础工具 ----------
def encode_labels(y: List[Any]):
    """Map arbitrary labels to contiguous integer ids.

    Labels are compared by their ``str()`` form; the distinct strings are
    sorted and numbered starting from 0.

    Returns:
        A tuple ``(y_id, str2id, id2str)``: the id of every entry of ``y`` as
        an int ndarray, the string -> id mapping, and the id -> string mapping.
    """
    as_text = [str(label) for label in y]
    vocabulary = sorted(set(as_text))
    str2id = dict(zip(vocabulary, range(len(vocabulary))))
    id2str = dict(enumerate(vocabulary))
    encoded = np.array([str2id[text] for text in as_text], dtype=int)
    return encoded, str2id, id2str

# ---------- 单样本融合 ----------
def fuse_softmax(prob_list: List[np.ndarray], weights: Optional[List[float]]=None,
                 mode: str="mean", eps: float=1e-12) -> np.ndarray:
    """Fuse the class-probability vectors of several models for one sample.

    Args:
        prob_list: M probability vectors of length C; they are stacked into an
            ``[M, C]`` array, clipped to ``[eps, 1]`` and renormalized per row.
        weights: Optional per-model weights of length M. Used by the
            "weighted", "poe", "logits" and "entropy" modes and ignored by the
            others. ``None`` means uniform weights.
        mode: Fusion rule.
            "mean"     - arithmetic mean of the rows.
            "weighted" - weighted arithmetic mean (weights clipped to >= eps).
            "poe"      - product of experts, i.e. weighted geometric mean.
            "logits"   - weighted sum of log-probabilities followed by a
                         softmax; mathematically the same result as "poe".
            "entropy"  - arithmetic mean weighted by ``1 - normalized entropy``
                         of each row (optionally multiplied by ``weights``).
            "maxconf"  - the row whose largest probability is the highest.
            "borda"    - Borda count over each row's class ranking.
            Any other value falls back to "mean".
        eps: Lower clipping bound used to avoid ``log(0)`` and zero division.

    Returns:
        A length-C vector, clipped to ``[eps, 1]`` and renormalized to sum to 1.
    """
    P = np.vstack(prob_list).astype(float)  # [M, C]
    P = np.clip(P, eps, 1.0)
    P = P / P.sum(axis=1, keepdims=True)
    M, C = P.shape

    def model_weights() -> np.ndarray:
        # Uniform weights unless the caller supplied some.
        return np.ones(M) if weights is None else np.asarray(weights, float)

    if mode == "weighted":
        w = np.clip(model_weights(), eps, None)
        out = (P * w[:, None]).sum(axis=0) / w.sum()
    elif mode in ("poe", "logits"):
        # Work in log space and shift by the maximum before exponentiating.
        # Without the shift, many models or large weights underflow every
        # class to 0 and the normalization below becomes 0/0 = NaN.
        z = (np.log(P) * model_weights()[:, None]).sum(axis=0)
        z -= z.max()
        out = np.exp(z)
        out /= out.sum()
    elif mode == "entropy":
        H = -(P * np.log(P)).sum(axis=1)
        # log(C) is 0 for a single class; the entropy is 0 there as well, so
        # use 0 directly instead of dividing 0 by 0.
        Hn = H / np.log(C) if C > 1 else np.zeros_like(H)
        w = 1.0 - Hn
        if weights is not None:
            w = w * np.asarray(weights, float)
        w = np.clip(w, eps, None)
        out = (P * w[:, None]).sum(axis=0) / w.sum()
    elif mode == "maxconf":
        out = P[np.argmax(P.max(axis=1))]
    elif mode == "borda":
        scores = np.zeros(C)
        points = np.arange(C, 0, -1)  # C points for rank 0 down to 1 point for the last rank
        for m in range(M):
            order = np.argsort(-P[m])  # class indices from most to least probable
            scores[order] += points
        out = scores / scores.sum()
    else:
        # "mean" and any unrecognized mode.
        out = P.mean(axis=0)

    out = np.clip(out, eps, 1.0)
    out /= out.sum()
    return out

def fuse_softmax_classwise(prob_list: List[np.ndarray], W: np.ndarray, eps: float=1e-12) -> np.ndarray:
    """Fuse per-model probabilities using one weight per (model, class) pair.

    ``prob_list`` is stacked into ``[M, C]``, clipped to ``[eps, 1]`` and
    renormalized per row. ``W`` (``[M, C]``) is clipped to ``>= eps`` and
    normalized so that every class column sums to 1. The fused score of a class
    is the weighted sum over models; the result is clipped and renormalized to
    sum to 1.
    """
    probs = np.clip(np.vstack(prob_list).astype(float), eps, 1.0)
    probs = probs / probs.sum(axis=1, keepdims=True)

    class_weights = np.clip(np.asarray(W, float), eps, None)
    class_weights = class_weights / class_weights.sum(axis=0, keepdims=True)

    fused = np.clip((probs * class_weights).sum(axis=0), eps, 1.0)
    return fused / fused.sum()

def fuse_softmax_by_modalities(modal_dict: Dict[str, np.ndarray],
                               required_modalities: List[str],
                               mode: str="mean",
                               weights_vec: Optional[np.ndarray]=None,   # [M]
                               weights_mc: Optional[np.ndarray]=None):  # [M,C]
    """Fuse one sample's per-modality probabilities.

    The vectors are taken from ``modal_dict`` in the order given by
    ``required_modalities``. Precedence of the fusion rule:

    1. ``weights_mc`` given  -> class-wise fusion via ``fuse_softmax_classwise``.
    2. ``mode == "weighted"`` and ``weights_vec`` given -> weighted mean.
    3. otherwise -> ``fuse_softmax`` with ``mode`` and no weights.
    """
    ordered_probs = [modal_dict[name] for name in required_modalities]

    if weights_mc is not None:
        return fuse_softmax_classwise(ordered_probs, weights_mc)

    # The weight vector is only forwarded for the "weighted" rule.
    use_vector = mode == "weighted" and weights_vec is not None
    chosen_weights = weights_vec if use_vector else None
    return fuse_softmax(ordered_probs, chosen_weights, mode=mode)

# ---------- 监督式权重学习 ----------
def learn_global_weights_supervised(P: np.ndarray, y: np.ndarray,
                                    lr: float=0.1, steps: int=800, seed: int=42) -> Tuple[np.ndarray, np.ndarray]:
    """Learn one fusion weight per model by minimizing the NLL of the true labels.

    The weights are parameterized as ``softmax(alpha)`` and ``alpha`` is
    optimized with Adam on the full batch. The parameters with the lowest
    observed loss along the optimization trajectory are returned.

    Args:
        P: Per-model class probabilities, shape ``[N, M, C]``.
        y: Integer class ids, shape ``[N]``.
        lr: Adam learning rate.
        steps: Number of optimization steps. ``0`` returns uniform weights.
        seed: Seed passed to ``torch.manual_seed``.

    Returns:
        ``(w, fused)``: the weights ``[M]`` (non-negative, summing to 1) and
        the fused probabilities ``[N, C]`` with rows renormalized to sum to 1.
    """
    torch.manual_seed(seed)
    N, M, C = P.shape
    P_t = torch.tensor(P, dtype=torch.float32)
    y_t = torch.tensor(y, dtype=torch.long)
    alpha = torch.zeros(M, dtype=torch.float32, requires_grad=True)
    opt = torch.optim.Adam([alpha], lr=lr)

    def fusion_loss() -> torch.Tensor:
        # NLL of the weighted-mean fusion under the current alpha.
        w_t = torch.softmax(alpha, dim=0)
        fused_t = torch.einsum('m,nmc->nc', w_t, P_t).clamp(1e-12, 1.0)
        return F.nll_loss(torch.log(fused_t), y_t)

    # Start from the initial (uniform) parameters so a result always exists,
    # even with steps == 0 or when the loss is never finite.
    best, best_alpha = float("inf"), alpha.detach().clone()
    for _ in range(steps):
        loss = fusion_loss()
        # Snapshot BEFORE the optimizer step: this is the alpha that actually
        # produced `loss`. Snapshotting after the step stores parameters whose
        # loss was never evaluated.
        if loss.item() < best:
            best, best_alpha = loss.item(), alpha.detach().clone()
        opt.zero_grad(); loss.backward(); opt.step()
    if steps > 0:
        # The parameters produced by the last step have not been scored yet.
        with torch.no_grad():
            final_loss = fusion_loss().item()
        if final_loss < best:
            best, best_alpha = final_loss, alpha.detach().clone()

    w = torch.softmax(best_alpha, dim=0).cpu().numpy()
    fused = np.einsum('m,nmc->nc', w, P)
    fused = np.clip(fused, 1e-12, 1.0); fused /= fused.sum(axis=1, keepdims=True)
    return w, fused

def learn_classwise_weights_supervised(P: np.ndarray, y: np.ndarray,
                                       lr: float=0.1, steps: int=1000, seed: int=0) -> Tuple[np.ndarray, np.ndarray]:
    """Learn one fusion weight per (model, class) pair by minimizing the NLL.

    The weights are parameterized as ``softmax(alpha, dim=0)`` so that every
    class column sums to 1 over the models; ``alpha`` is optimized with Adam on
    the full batch. The parameters with the lowest observed loss along the
    optimization trajectory are returned.

    Args:
        P: Per-model class probabilities, shape ``[N, M, C]``.
        y: Integer class ids, shape ``[N]``.
        lr: Adam learning rate.
        steps: Number of optimization steps. ``0`` returns uniform weights.
        seed: Seed passed to ``torch.manual_seed``.

    Returns:
        ``(W, fused)``: the weights ``[M, C]`` and the fused probabilities
        ``[N, C]`` with rows renormalized to sum to 1.
    """
    torch.manual_seed(seed)
    N, M, C = P.shape
    P_t = torch.tensor(P, dtype=torch.float32)
    y_t = torch.tensor(y, dtype=torch.long)
    alpha = torch.zeros(M, C, dtype=torch.float32, requires_grad=True)
    opt = torch.optim.Adam([alpha], lr=lr)

    def fusion_loss() -> torch.Tensor:
        W_t = torch.softmax(alpha, dim=0)          # each class column sums to 1
        # NOTE(review): the fused scores are not renormalized over classes
        # here, while the returned `fused` below is. The objective therefore
        # differs from the NLL of the final probabilities - confirm intended.
        fused_t = torch.einsum('mc,nmc->nc', W_t, P_t).clamp(1e-12, 1.0)
        return F.nll_loss(torch.log(fused_t), y_t)

    # Start from the initial (uniform) parameters so a result always exists,
    # even with steps == 0 or when the loss is never finite.
    best, best_alpha = float("inf"), alpha.detach().clone()
    for _ in range(steps):
        loss = fusion_loss()
        # Snapshot BEFORE the optimizer step: this is the alpha that actually
        # produced `loss`. Snapshotting after the step stores parameters whose
        # loss was never evaluated.
        if loss.item() < best:
            best, best_alpha = loss.item(), alpha.detach().clone()
        opt.zero_grad(); loss.backward(); opt.step()
    if steps > 0:
        # The parameters produced by the last step have not been scored yet.
        with torch.no_grad():
            final_loss = fusion_loss().item()
        if final_loss < best:
            best, best_alpha = final_loss, alpha.detach().clone()

    W = torch.softmax(best_alpha, dim=0).cpu().numpy()
    fused = np.einsum('mc,nmc->nc', W, P)
    fused = np.clip(fused, 1e-12, 1.0); fused /= fused.sum(axis=1, keepdims=True)
    return W, fused


In [2]:
# Definitions only; nothing here starts the server.
# Modalities a patient needs predictions for before it is treated as complete.
REQUIRED_MODALITIES = ["DNA", "RNA", "WSI"]
# Number of clients to wait for; equals the number of required modalities
# (presumably one client per modality - confirm against the client code).
EXPECTED_CLIENTS = len(REQUIRED_MODALITIES)

class PredictAndTrainFuseStrategy(fl.server.strategy.FedAvg):
    """Flower strategy that fuses per-modality predictions on the server.

    - Collects the ``preds_json`` metric returned by each client's evaluate
      call (records carrying ``patient_id`` / ``probs`` / ``modality``).
    - Aligns the records into a ``[N, M, C]`` array and, using the true
      labels, learns fusion weights ("global" or "classwise").
    - Evaluates the fused predictions and saves the results under ``out_dir``.

    The base ``FedAvg`` is configured with ``fraction_fit=0.0`` and
    ``min_fit_clients=0``; evaluation requires ``EXPECTED_CLIENTS`` clients.
    """
    def __init__(self,
                 patient_ids: List[str],
                 labels: List[Any],
                 learn_mode: str = "global",      # "global" or "classwise"
                 fusion_base: str = "mean",       # fallback rule when no weights are learned
                 learn_lr: float = 0.1,
                 learn_steps: int = 800,
                 out_dir: str = "server_outputs"):
        """Store the labelled patient list and create the output directory.

        Args:
            patient_ids: Patient identifiers (converted to ``str``).
            labels: True label of each patient, aligned with ``patient_ids``.
            learn_mode: "global" (one weight per modality), "classwise" (one
                weight per modality and class); any other value skips weight
                learning and uses ``fusion_base``.
            fusion_base: ``fuse_softmax`` mode used when no weights are learned.
            learn_lr: Learning rate forwarded to the weight learner.
            learn_steps: Optimization steps forwarded to the weight learner.
            out_dir: Directory for saved arrays; created if missing.

        Raises:
            ValueError: if ``patient_ids`` and ``labels`` differ in length.
        """
        pid_list = [str(x) for x in patient_ids]
        label_list = list(labels)
        # Labels are indexed by position in patient_ids later on; a length
        # mismatch would silently misalign them or fail with an index error.
        if len(pid_list) != len(label_list):
            raise ValueError(
                f"patient_ids and labels must have the same length "
                f"(got {len(pid_list)} and {len(label_list)})"
            )
        super().__init__(
            fraction_fit=0.0, min_fit_clients=0,
            fraction_evaluate=1.0,
            min_evaluate_clients=EXPECTED_CLIENTS,
            min_available_clients=EXPECTED_CLIENTS,
        )
        self.patient_ids = pid_list
        y_id, str2id, id2str = encode_labels(label_list)
        self.labels = y_id
        self.id2label = id2str
        self.learn_mode = learn_mode
        self.fusion_base = fusion_base
        self.learn_lr = learn_lr
        self.learn_steps = learn_steps
        self.out_dir = out_dir
        os.makedirs(self.out_dir, exist_ok=True)

        # patient_id -> modality -> probability vector received from a client
        self.buffer: Dict[str, Dict[str, np.ndarray]] = defaultdict(dict)
        self.num_classes: Optional[int] = None
        self.learned_w: Optional[np.ndarray] = None   # [M], set in "global" mode
        self.learned_W: Optional[np.ndarray] = None   # [M, C], set in "classwise" mode

    # The server does not evaluate global parameters itself.
    def evaluate(self, server_round: int, parameters):
        """Skip server-side evaluation by returning ``None``."""
        return None

    # Send evaluate instructions so that every client returns its preds_json.
    def configure_evaluate(self, server_round: int, parameters, client_manager):
        """Wait for all expected clients, then ask each one to predict.

        Polls ``client_manager.all()`` every 0.5 s until at least
        ``EXPECTED_CLIENTS`` clients are present; there is no timeout.

        Returns:
            A list of ``(client, EvaluateIns)`` pairs, one per online client,
            with config ``{"task": "predict", "round": server_round}``.
        """
        target = EXPECTED_CLIENTS
        while True:
            online = client_manager.all()
            if len(online) >= target:
                break
            print(f"[Round {server_round}] waiting clients {len(online)}/{target} ...", end="\r")
            time.sleep(0.5)
        print(f"\n[Round {server_round}] READY: online={list(online.keys())}")
        cfg = {"task": "predict", "round": server_round}
        # EvaluateIns is part of the public flwr.common API; the previous
        # fl.server.client_proxy path only worked through a transitive import.
        evaluate_ins = fl.common.EvaluateIns(parameters, cfg)
        return [(c, evaluate_ins) for c in online.values()]

    def _assemble_PM(self) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], List[int]]:
        """Stack the buffered predictions of all complete patients.

        A patient is complete when the buffer holds every modality listed in
        ``REQUIRED_MODALITIES``.

        Returns:
            ``(P, y, keep)``: probabilities ``[N_keep, M, C]``, label ids
            ``[N_keep]`` and the kept indices into ``self.patient_ids``; or
            ``(None, None, [])`` when no patient is complete.
        """
        req = set(REQUIRED_MODALITIES); rows, keep = [], []
        for i, pid in enumerate(self.patient_ids):
            modal = self.buffer.get(pid, {})
            if req.issubset(set(modal.keys())):
                rows.append(np.vstack([modal[m] for m in REQUIRED_MODALITIES])[None, ...])
                keep.append(i)
        if not rows:
            return None, None, []
        P = np.concatenate(rows, axis=0)          # [N_keep, M, C]
        y = self.labels[np.array(keep, int)]      # [N_keep]
        return P, y, keep

    def aggregate_evaluate(self, server_round: int, results, failures):
        """Buffer client predictions, learn fusion weights and score the fusion.

        Side effects: fills ``self.buffer``; sets ``self.learned_w`` /
        ``self.learned_W``; writes the weights, ``fused_valid_probs.npy`` and
        ``y_valid.npy`` into ``self.out_dir``.

        Returns:
            ``(0.0, metrics)`` where ``metrics`` holds ``n_complete`` and, when
            at least one patient is complete, ``acc`` and ``macro_f1``.
        """
        if failures:
            # Report failed clients instead of dropping them silently; their
            # modality will simply be missing from the buffer.
            print(f"[Round {server_round}] {len(failures)} client evaluation failure(s) ignored")

        seen_modalities = set()
        # Collect the softmax outputs uploaded by each client.
        for client_proxy, evaluate_res in results:
            metrics = evaluate_res.metrics or {}
            blob = metrics.get("preds_json", b"")
            js = blob.decode("utf-8") if isinstance(blob, bytes) else blob
            if not js:
                continue
            for r in json.loads(js):
                pid = str(r["patient_id"])
                probs = np.asarray(r["probs"], float)
                if self.num_classes is None:
                    self.num_classes = probs.shape[-1]
                modality = str(r.get("modality", "unknown"))
                self.buffer[pid][modality] = probs
                seen_modalities.add(modality)

        P, y, keep_idx = self._assemble_PM()
        if P is None:
            print(f"[Round {server_round}] no complete sample; received={sorted(seen_modalities)}")
            return 0.0, {"n_complete": 0}

        N_keep, M, C = P.shape
        print(f"[Round {server_round}] complete={N_keep}/{len(self.patient_ids)}  M={M}, C={C}")

        # ---- learn the fusion weights ----
        if self.learn_mode == "global":
            w, fused = learn_global_weights_supervised(P, y, lr=self.learn_lr, steps=self.learn_steps)
            self.learned_w, self.learned_W = w, None
            np.save(os.path.join(self.out_dir, "weights_global.npy"), w)
            print(f"[Learn] w={np.round(w,4)} (sum={w.sum():.4f})")
        elif self.learn_mode == "classwise":
            W, fused = learn_classwise_weights_supervised(P, y, lr=self.learn_lr, steps=self.learn_steps)
            self.learned_W, self.learned_w = W, None
            np.save(os.path.join(self.out_dir, "weights_classwise.npy"), W)
            print(f"[Learn] W.shape={W.shape} (col-sum≈1)")
        else:
            # No learning, baseline fusion only (not recommended).
            fused = np.array([
                fuse_softmax_by_modalities(self.buffer[self.patient_ids[i]], REQUIRED_MODALITIES, mode=self.fusion_base)
                for i in keep_idx
            ])

        y_pred = fused.argmax(axis=1)
        acc = float(accuracy_score(y, y_pred))
        mf1 = float(f1_score(y, y_pred, average="macro"))
        np.save(os.path.join(self.out_dir, "fused_valid_probs.npy"), fused)
        np.save(os.path.join(self.out_dir, "y_valid.npy"), y)
        print(f"[Eval] acc={acc:.4f}  macro_f1={mf1:.4f}")

        return 0.0, {"n_complete": int(N_keep), "acc": acc, "macro_f1": mf1}

    # Export a CSV: the final fusion (with the learned weights) for every complete patient.
    def export_final(self, out_csv: str):
        """Write per-patient predictions to ``out_csv`` and return the DataFrame.

        Every patient gets one row with ``pred_<modality>`` /
        ``probs_<modality>`` for each buffered modality. Complete patients also
        get ``pred_fused`` / ``probs_fused``, computed with the class-wise
        weights if learned, else the global weights if learned, else
        ``self.fusion_base``.
        """
        req = set(REQUIRED_MODALITIES); rows = []
        for pid in self.patient_ids:
            modal = self.buffer.get(pid, {})
            row = {"patient_id": pid}
            for m, p in modal.items():
                row[f"pred_{m}"] = int(np.argmax(p))
                row[f"probs_{m}"] = json.dumps(p.tolist(), ensure_ascii=False)
            if req.issubset(set(modal.keys())):
                if self.learned_W is not None:
                    P = fuse_softmax_by_modalities(modal, REQUIRED_MODALITIES, mode="classwise", weights_mc=self.learned_W)
                elif self.learned_w is not None:
                    P = fuse_softmax_by_modalities(modal, REQUIRED_MODALITIES, mode="weighted", weights_vec=self.learned_w)
                else:
                    P = fuse_softmax_by_modalities(modal, REQUIRED_MODALITIES, mode=self.fusion_base)
                row["pred_fused"] = int(np.argmax(P))
                row["probs_fused"] = json.dumps(P.tolist(), ensure_ascii=False)
            rows.append(row)
        df = pd.DataFrame(rows)
        df.to_csv(out_csv, index=False, encoding="utf-8-sig")
        print(f"[Export] wrote {out_csv} ({len(df)} rows)")
        return df



In [3]:
import pandas as pd
from pathlib import Path

# File names used by this notebook
test_csv   = "test_metadata_THENEWEST - 28.csv"  # input
export_csv = "predictions_fused.csv"             # output (written under server_outputs/)

# Read the CSV and normalize the column names to lowercase
df = pd.read_csv(test_csv)
df.columns = df.columns.str.lower()

# Column matching: candidate names for the patient_id / label columns,
# listed in the order in which they are tried
pid_candidates   = ["patient_id", "patientid", "pid", "id", "case_id", "slide_id"]
label_candidates = ["label", "stage", "y", "target", "class"]

def pick_col(cands, cols):
    """Return the first entry of ``cands`` that is contained in ``cols``.

    Raises:
        ValueError: if none of the candidates is present in ``cols``.
    """
    not_found = object()
    chosen = next((name for name in cands if name in cols), not_found)
    if chosen is not_found:
        raise ValueError(f"未在 {list(cols)} 中找到需要的列（候选：{cands}）")
    return chosen

# Resolve the actual column names from the candidate lists
pid_col   = pick_col(pid_candidates,   set(df.columns))
label_col = pick_col(label_candidates, set(df.columns))

# Build the patient-id and label lists handed to the strategy
patient_ids = df[pid_col].astype(str).tolist()
labels      = df[label_col].tolist()

# Quick sanity check of what was loaded
print(f"Loaded {len(patient_ids)} rows from {test_csv}")
print("patient_ids sample:", patient_ids[:5])
print("labels sample:", labels[:5])


Loaded 28 rows from test_metadata_THENEWEST - 28.csv
patient_ids sample: ['TCGA-A2-A3XZ', 'TCGA-A7-A426', 'TCGA-A2-A04N', 'TCGA-A8-A09X', 'TCGA-PL-A8LX']
labels sample: ['Stage1', 'Stage3', 'Stage1', 'Stage3', 'Stage4']


In [None]:
LEARN_MODE = "global"   # or "classwise"
PORT = "127.0.0.1:8090"   # <- change the address here; do not use 8080

# Build the strategy from the patient ids / labels loaded from the CSV
strategy = PredictAndTrainFuseStrategy(
    patient_ids=patient_ids,
    labels=labels,
    learn_mode=LEARN_MODE,
    fusion_base="mean",
    learn_lr=0.1,
    learn_steps=800 if LEARN_MODE=="global" else 1000,
    out_dir="server_outputs",
)

print(f"[Server] starting on {PORT} (learn_mode={LEARN_MODE})")
# Run a single round; presumably this call blocks until the round has
# finished - confirm against the Flower documentation.
fl.server.start_server(
    server_address=PORT,
    config=fl.server.ServerConfig(num_rounds=1),
    strategy=strategy,
)

# Write the per-patient predictions (including the fused ones) to CSV
strategy.export_final(os.path.join("server_outputs", export_csv))


INFO flwr 2025-10-10 16:23:04,871 | app.py:163 | Starting Flower server, config: ServerConfig(num_rounds=1, round_timeout=None)
INFO flwr 2025-10-10 16:23:04,894 | app.py:176 | Flower ECE: gRPC server running (1 rounds), SSL is disabled
INFO flwr 2025-10-10 16:23:04,894 | server.py:89 | Initializing global parameters
INFO flwr 2025-10-10 16:23:04,895 | server.py:276 | Requesting initial parameters from one random client


[Server] starting on 127.0.0.1:8090 (learn_mode=global)
