In [3]:
# !pip install scanpy
import os
os.environ["SCIPY_ARRAY_API"] = "1"

import gdown
import numpy as np
import pandas as pd
import anndata as ad
from sklearn.neural_network import MLPClassifier
from sklearn.feature_selection import SelectKBest, f_classif
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
from imblearn.over_sampling import SMOTE
from scipy.sparse import issparse
import matplotlib.pyplot as plt
import seaborn as sns
import random
import torch
import torch.nn as nn
import lightgbm as lgb
import joblib
from sklearn.ensemble import RandomForestClassifier


# Config
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", DEVICE)

os.environ['PYTHONHASHSEED'] = str(SEED)

# file_id = "110eYMgseyD32YIS9xOMbOpJ76wnDXahR"
# gdown.download(f"https://drive.google.com/uc?id={file_id}", output="TCGA_BRCA_RNA_with_TinX.h5ad", quiet=False)


Using device: cpu


## Split Dataset

In [None]:
import scanpy as sc
import pandas as pd

# === 路径设置 ===
adata_path = "/Users/xin/Desktop/DATA5703/TCGA-DNA-RNA-IMAGE-stage-classifier/RNA/TCGA_BRCA_RNA_with_TinX.h5ad"
# adata_path = "/Users/xin/Desktop/DATA5703/5703TCGA_BRCA/RNA/TCGA_BRCA_RNA_with_TinX.h5ad"
test_csv_path = "/Users/xin/Desktop/DATA5703/TCGA-DNA-RNA-IMAGE-stage-classifier/WSI/test_metadata.csv"
output_train_path = "/Users/xin/Desktop/DATA5703/TCGA-DNA-RNA-IMAGE-stage-classifier/RNA/RNA_train.h5ad"
output_test_path = "/Users/xin/Desktop/DATA5703/TCGA-DNA-RNA-IMAGE-stage-classifier/RNA/RNA_test.h5ad"

# === 加载数据 ===
adata = sc.read_h5ad(adata_path)
test_df = pd.read_csv(test_csv_path)
test_patients = set(test_df["patient_id"].astype(str).tolist())
adata.obs["patient_id"] = adata.obs["patient_id"].astype(str)

# === 分离数据 ===
is_test = adata.obs["patient_id"].isin(test_patients)
adata_test = adata[is_test].copy()
adata_train = adata[~is_test].copy()

# 去重（同一个 patient_id 只取一个）
adata_test = adata_test[adata_test.obs.groupby("patient_id").head(1).index]

# === 保存 ===
adata_train.write(output_train_path)
adata_test.write(output_test_path)

print("✅ 训练集和测试集已保存,训练集路径:", output_train_path, "测试集路径:", output_test_path)

✅ 训练集和测试集已保存,训练集路径: /Users/xin/Desktop/DATA5703/TCGA-DNA-RNA-IMAGE-stage-classifier/RNA/RNA_train.h5ad 测试集路径: /Users/xin/Desktop/DATA5703/TCGA-DNA-RNA-IMAGE-stage-classifier/RNA/RNA_test.h5ad


## Train

In [5]:
# === 路径设置 ===
train_h5ad_path = "/Users/xin/Desktop/DATA5703/TCGA-DNA-RNA-IMAGE-stage-classifier/RNA/RNA_train.h5ad"
model_path = "/Users/xin/Desktop/DATA5703/TCGA-DNA-RNA-IMAGE-stage-classifier/RNA/RNA_model.pkl"
selector_path = "/Users/xin/Desktop/DATA5703/TCGA-DNA-RNA-IMAGE-stage-classifier/RNA/RNA_selector.pkl"

# === 1. 加载训练集 ===
adata = ad.read_h5ad(train_h5ad_path)
X = adata.X.toarray() if issparse(adata.X) else adata.X
y_raw = adata.obs["stage"].values

label_map = {"Stage I": 0, "Stage II": 1, "Stage III": 2, "Stage IV": 3}
label_names = list(label_map.keys())
y = np.array([label_map.get(s, 3) for s in y_raw])

# === 2. 特征选择 ===
selector = SelectKBest(score_func=f_classif, k=500)
X_sel = selector.fit_transform(X, y)

# 保存 selector
joblib.dump(selector, selector_path)

# === 3. SMOTE 增强 ===
smote = SMOTE(random_state=42)
X_resampled, y_resampled = smote.fit_resample(X_sel, y)

# === 4. 训练 RandomForest ===
model = RandomForestClassifier(
    n_estimators=100,
    max_depth=10,
    min_samples_split=10,
    min_samples_leaf=5,
    class_weight="balanced",
    random_state=42
)

# model = lgb.LGBMClassifier(
#     n_estimators=300,
#     class_weight="balanced",
#     random_state=42
# )

model.fit(X_resampled, y_resampled)

# === 5. 训练集评估 ===
y_pred = model.predict(X_sel)
print("📘 Train Classification Report:\n", classification_report(y, y_pred, target_names=label_names))
print("📘 Train Confusion Matrix:\n", pd.DataFrame(confusion_matrix(y, y_pred), index=label_names, columns=label_names))

# === 6. 保存模型 ===
joblib.dump(model, model_path)
print(f"✅ 模型已保存到：{model_path}")

📘 Train Classification Report:
               precision    recall  f1-score   support

     Stage I       0.90      1.00      0.95       197
    Stage II       0.99      0.98      0.98       701
   Stage III       1.00      0.95      0.97       276
    Stage IV       1.00      1.00      1.00        46

    accuracy                           0.98      1220
   macro avg       0.97      0.98      0.98      1220
weighted avg       0.98      0.98      0.98      1220

📘 Train Confusion Matrix:
            Stage I  Stage II  Stage III  Stage IV
Stage I        197         0          0         0
Stage II        16       685          0         0
Stage III        6         8        262         0
Stage IV         0         0          0        46
✅ 模型已保存到：/Users/xin/Desktop/DATA5703/TCGA-DNA-RNA-IMAGE-stage-classifier/RNA/RNA_model.pkl


## Client

In [6]:
import numpy as np
import pandas as pd
import anndata as ad
import joblib
import json
from scipy.sparse import issparse
from sklearn.metrics import classification_report, confusion_matrix
import flwr as fl

# ===== 参数设置 =====
TEST_H5AD_PATH  = "/Users/xin/Desktop/DATA5703/TCGA-DNA-RNA-IMAGE-stage-classifier/RNA/RNA_test.h5ad"
SELECTOR_PATH   = "/Users/xin/Desktop/DATA5703/TCGA-DNA-RNA-IMAGE-stage-classifier/RNA//RNA_selector.pkl"
MODEL_PATH      = "/Users/xin/Desktop/DATA5703/TCGA-DNA-RNA-IMAGE-stage-classifier/RNA//RNA_model.pkl"
SERVER_ADDRESS  = "127.0.0.1:8080"
MODALITY        = "RNA"
WEIGHT          = 1.0

label_map = {"Stage I": 0, "Stage II": 1, "Stage III": 2, "Stage IV": 3}
label_names = list(label_map.keys())

class RNAClient(fl.client.NumPyClient):
    def __init__(self, test_h5ad_path, selector_path, model_path, modality, weight):
        self.modality = modality
        self.weight = weight
        self.rows = []
        self._load_and_predict(test_h5ad_path, selector_path, model_path)

    def _load_and_predict(self, h5ad_path, selector_path, model_path):
        # === 1. 加载数据 ===
        adata = ad.read_h5ad(h5ad_path)
        X = adata.X.toarray() if issparse(adata.X) else adata.X
        y_raw = adata.obs["stage"].values
        pids = adata.obs["patient_id"].astype(str).values

        y_true = np.array([label_map.get(s, 3) for s in y_raw])

        # === 2. 加载 selector 和模型 ===
        selector = joblib.load(selector_path)
        model = joblib.load(model_path)

        # === 3. 特征选择 + 预测 ===
        X_sel = selector.transform(X)
        y_pred = model.predict(X_sel)
        y_prob = model.predict_proba(X_sel)

        # === 4. 输出评估结果 ===
        print("📊 Classification Report:")
        print(classification_report(y_true, y_pred, target_names=label_names))

        print("📊 Confusion Matrix:")
        print(pd.DataFrame(confusion_matrix(y_true, y_pred), index=label_names, columns=label_names))

        # === 5. 封装为 JSON 结构 ===
        for i, prob in enumerate(y_prob):
            self.rows.append({
                "patient_id": pids[i],
                "probs": prob.tolist(),
                "modality": self.modality,
                "weight": self.weight
            })

        print(f"✅ 已生成 {len(self.rows)} 条预测")

    def get_parameters(self, config):
        return []

    def fit(self, parameters, config):
        return [], 0, {}

    def evaluate(self, parameters, config):
        task = config.get("task", "")
        metrics = {}
        if task == "predict":
            print(f"📤 RNA 客户端上传 {len(self.rows)} 条预测")
            metrics = {
                "preds_json": json.dumps(self.rows).encode("utf-8")
            }
        return 0.0, len(self.rows), metrics

# ===== 启动客户端 =====
client = RNAClient(TEST_H5AD_PATH, SELECTOR_PATH, MODEL_PATH, MODALITY, WEIGHT)
fl.client.start_numpy_client(server_address=SERVER_ADDRESS, client=client)

  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
	Instead, use `flwr.client.start_client()` by ensuring you first call the `.to_client()` method as shown below: 
	flwr.client.start_client(
		server_address='<IP>:<PORT>',
		client=FlowerClient().to_client(), # <-- where FlowerClient is of type flwr.client.NumPyClient object
	)
	Using `start_numpy_client()` is deprecated.

            This is a deprecated feature. It will be removed
            entirely in future versions of Flower.
        
	Instead, use the `flower-supernode` CLI command to start a SuperNode as shown below:

		$ flower-supernode --insecure --superlink='<IP>:<PORT>'

	To view all available options, run:

		$ flower-supernode --help

	Using `start_client()` is deprecated.

            This is a deprecated feature. It will be removed
         

📊 Classification Report:
              precision    recall  f1-score   support

     Stage I       0.75      0.75      0.75         4
    Stage II       0.50      0.67      0.57         3
   Stage III       0.50      0.50      0.50         2
    Stage IV       0.00      0.00      0.00         1

    accuracy                           0.60        10
   macro avg       0.44      0.48      0.46        10
weighted avg       0.55      0.60      0.57        10

📊 Confusion Matrix:
           Stage I  Stage II  Stage III  Stage IV
Stage I          3         1          0         0
Stage II         1         2          0         0
Stage III        0         1          1         0
Stage IV         0         0          1         0
✅ 已生成 10 条预测
📤 RNA 客户端上传 10 条预测
