In [3]:
# !pip install scanpy
import os
os.environ["SCIPY_ARRAY_API"] = "1"

import gdown
import numpy as np
import pandas as pd
import anndata as ad
from sklearn.neural_network import MLPClassifier
from sklearn.feature_selection import SelectKBest, f_classif
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
from imblearn.over_sampling import SMOTE
from scipy.sparse import issparse
import matplotlib.pyplot as plt
import seaborn as sns
import random
import torch
import torch.nn as nn
import lightgbm as lgb
import joblib
from sklearn.ensemble import RandomForestClassifier


# Config
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", DEVICE)

os.environ['PYTHONHASHSEED'] = str(SEED)

# file_id = "110eYMgseyD32YIS9xOMbOpJ76wnDXahR"
# gdown.download(f"https://drive.google.com/uc?id={file_id}", output="TCGA_BRCA_RNA_with_TinX.h5ad", quiet=False)


Using device: cpu


## Split Dataset

In [None]:
import scanpy as sc
import pandas as pd

# === Paths ===
adata_path = "/Users/xin/Desktop/DATA5703/5703TCGA_BRCA/RNA/TCGA_BRCA_RNA_with_TinX.h5ad"
test_csv_path = "/Users/xin/Desktop/DATA5703/TCGA-DNA-RNA-IMAGE-stage-classifier/RNA/test_metadata_THENEWEST - 28.csv"
output_train_path = "/Users/xin/Desktop/DATA5703/TCGA-DNA-RNA-IMAGE-stage-classifier/RNA/RNA_train.h5ad"
output_test_path = "/Users/xin/Desktop/DATA5703/TCGA-DNA-RNA-IMAGE-stage-classifier/RNA/RNA_test.h5ad"

# === Label mappings ===
label_map = {
    "Stage I": 0,
    "Stage II": 1,
    "Stage III": 2,
    "Stage IV": 3,
}
stage_map = {
    "Stage1": "Stage I",
    "Stage2": "Stage II",
    "Stage3": "Stage III",
    "Stage4": "Stage IV",
}

# === Load .h5ad data ===
adata = sc.read_h5ad(adata_path)
adata.obs["patient_id"] = adata.obs["patient_id"].astype(str)

# === Load test_metadata.csv and fix label format ===
test_df = pd.read_csv(test_csv_path)
test_df["patient_id"] = test_df["patient_id"].astype(str)
test_df["label"] = test_df["label"].str.strip()
test_df["stage"] = test_df["label"].map(stage_map)  # Convert e.g. "Stage4" → "Stage IV"

# === 🔍 Check patient ID consistency ===
csv_patient_ids = set(test_df["patient_id"])
adata_patient_ids = set(adata.obs["patient_id"])
missing_in_adata = csv_patient_ids - adata_patient_ids
if missing_in_adata:
    print("The following patient_id(s) exist in test_metadata.csv but were not found in .h5ad:")
    print(missing_in_adata)
else:
    print("All patient_id(s) in test_metadata.csv are present in the .h5ad dataset.")

# === 1. Extract test set by patient ID ===
test_patients = set(test_df["patient_id"])
is_test = adata.obs["patient_id"].isin(test_patients)
adata_test = adata[is_test].copy()

# De-duplicate: keep only one sample per patient_id
adata_test = adata_test[adata_test.obs.groupby("patient_id").head(1).index]

# Assign correct stage labels from test_metadata.csv
patient_to_stage = dict(zip(test_df["patient_id"], test_df["stage"]))
adata_test.obs["stage"] = adata_test.obs["patient_id"].map(patient_to_stage)

# === 🔍 Check for unmapped test samples ===
unmapped = adata_test.obs[adata_test.obs["stage"].isna()]
if not unmapped.empty:
    print("The following patient_id(s) were found in .h5ad but failed to map a stage label:")
    print(unmapped["patient_id"].tolist())
else:
    print("All test samples successfully mapped to stage labels.")

# === 2. The rest are used as training set ===
adata_train = adata[~is_test].copy()

# === Save output files ===
adata_train.write(output_train_path)
adata_test.write(output_test_path)

# === Final summary ===
print("Training and test sets saved:")
print("Test samples:", adata_test.shape[0], "→", output_test_path)
print("Train samples:", adata_train.shape[0], "→", output_train_path)
print("Test label distribution:")
print(adata_test.obs["stage"].value_counts())

✅ 训练集和测试集已保存,训练集路径: /Users/xin/Desktop/DATA5703/TCGA-DNA-RNA-IMAGE-stage-classifier/RNA/RNA_train.h5ad 测试集路径: /Users/xin/Desktop/DATA5703/TCGA-DNA-RNA-IMAGE-stage-classifier/RNA/RNA_test.h5ad


## Train

In [None]:
import numpy as np
import pandas as pd
import anndata as ad
import joblib
from scipy.sparse import issparse
from sklearn.model_selection import train_test_split
from sklearn.feature_selection import SelectKBest, f_classif
from sklearn.metrics import classification_report, confusion_matrix
from imblearn.over_sampling import SMOTE
import lightgbm as lgb

# === 路径设置 ===
train_h5ad_path = "/Users/xin/Desktop/DATA5703/TCGA-DNA-RNA-IMAGE-stage-classifier/RNA/RNA_train.h5ad"
model_path = "/Users/xin/Desktop/DATA5703/TCGA-DNA-RNA-IMAGE-stage-classifier/RNA/RNA_model.pkl"
selector_path = "/Users/xin/Desktop/DATA5703/TCGA-DNA-RNA-IMAGE-stage-classifier/RNA/RNA_selector.pkl"

# === 1. 加载训练集 ===
adata = ad.read_h5ad(train_h5ad_path)
X = adata.X.toarray() if issparse(adata.X) else adata.X
y_raw = adata.obs["stage"].values

label_map = {"Stage I": 0, "Stage II": 1, "Stage III": 2, "Stage IV": 3}
label_names = list(label_map.keys())
y = np.array([label_map.get(s, 3) for s in y_raw])

# === 2. 拆分训练 / 验证集（保持类别分布）===
X_train, X_val, y_train, y_val = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=42
)

# === 3. 特征选择 ===
selector = SelectKBest(score_func=f_classif, k=500)
X_train_sel = selector.fit_transform(X_train, y_train)
X_val_sel = selector.transform(X_val)

# 保存 selector
joblib.dump(selector, selector_path)

# === 4. SMOTE 增强训练集 ===
smote = SMOTE(random_state=42)
X_resampled, y_resampled = smote.fit_resample(X_train_sel, y_train)

# === 5. 训练 LightGBM 模型 ===
model = lgb.LGBMClassifier(
    n_estimators=300,
    class_weight="balanced",
    random_state=42
)
# model = RandomForestClassifier(
#     n_estimators=100,
#     max_depth=10,
#     min_samples_split=10,
#     min_samples_leaf=5,
#     class_weight="balanced",
#     random_state=42
# )

model.fit(X_resampled, y_resampled)

# === 6. 验证集评估 ===
y_val_pred = model.predict(X_val_sel)
print("📘 Validation Classification Report:\n", classification_report(y_val, y_val_pred, target_names=label_names))
print("📘 Validation Confusion Matrix:\n", pd.DataFrame(confusion_matrix(y_val, y_val_pred), index=label_names, columns=label_names))

# === 7. 保存模型 ===
joblib.dump(model, model_path)
print(f"✅ 模型已保存到：{model_path}")

📘 Train Classification Report:
               precision    recall  f1-score   support

     Stage I       0.90      1.00      0.95       197
    Stage II       0.99      0.98      0.98       701
   Stage III       1.00      0.95      0.97       276
    Stage IV       1.00      1.00      1.00        46

    accuracy                           0.98      1220
   macro avg       0.97      0.98      0.98      1220
weighted avg       0.98      0.98      0.98      1220

📘 Train Confusion Matrix:
            Stage I  Stage II  Stage III  Stage IV
Stage I        197         0          0         0
Stage II        16       685          0         0
Stage III        6         8        262         0
Stage IV         0         0          0        46
✅ 模型已保存到：/Users/xin/Desktop/DATA5703/TCGA-DNA-RNA-IMAGE-stage-classifier/RNA/RNA_model.pkl


## Client

In [None]:
import numpy as np
import pandas as pd
import anndata as ad
import joblib
import json
from scipy.sparse import issparse
from sklearn.metrics import classification_report, confusion_matrix
import flwr as fl

# ===== Parameter Settings =====
TEST_H5AD_PATH  = "/Users/xin/Desktop/DATA5703/TCGA-DNA-RNA-IMAGE-stage-classifier/RNA/RNA_test.h5ad"
SELECTOR_PATH   = "/Users/xin/Desktop/DATA5703/TCGA-DNA-RNA-IMAGE-stage-classifier/RNA/RNA_selector.pkl"
MODEL_PATH      = "/Users/xin/Desktop/DATA5703/TCGA-DNA-RNA-IMAGE-stage-classifier/RNA/RNA_model.pkl"
# SERVER_ADDRESS  = "127.0.0.1:8080"
SERVER_ADDRESS  = "192.168.0.6:8080"

MODALITY        = "RNA"
WEIGHT          = 0.3

label_map = {"Stage I": 0, "Stage II": 1, "Stage III": 2, "Stage IV": 3}
label_names = list(label_map.keys())

class RNAClient(fl.client.NumPyClient):
    def __init__(self, test_h5ad_path, selector_path, model_path, modality, weight):
        self.modality = modality
        self.weight = weight
        self.rows = []
        self._load_and_predict(test_h5ad_path, selector_path, model_path)

    def _load_and_predict(self, h5ad_path, selector_path, model_path):
        # === 1. Load test data ===
        adata = ad.read_h5ad(h5ad_path)
        X = adata.X.toarray() if issparse(adata.X) else adata.X
        y_raw = adata.obs["stage"].values
        pids = adata.obs["patient_id"].astype(str).values

        y_true = np.array([label_map.get(s, 3) for s in y_raw])

        # === 2. Load selector and model ===
        selector = joblib.load(selector_path)
        model = joblib.load(model_path)

        # === 3. Feature selection + prediction ===
        X_sel = selector.transform(X)
        y_pred = model.predict(X_sel)
        y_prob = model.predict_proba(X_sel)

        # === 4. Output evaluation results ===
        print("Classification Report:")
        print(classification_report(y_true, y_pred, target_names=label_names))

        print("Confusion Matrix:")
        print(pd.DataFrame(confusion_matrix(y_true, y_pred), index=label_names, columns=label_names))

        # === 5. Pack results into JSON structure ===
        for i, prob in enumerate(y_prob):
            self.rows.append({
                "patient_id": pids[i],
                "probs": prob.tolist(),
                "modality": self.modality,
                "weight": self.weight
            })

        print(f"\n{len(self.rows)} predictions have been generated.")

    def get_parameters(self, config):
        return []

    def fit(self, parameters, config):
        return [], 0, {}

    def evaluate(self, parameters, config):
        task = config.get("task", "")
        metrics = {}
        if task == "predict":
            print(f"\nThe RNA client uploads {len(self.rows)} predictions.")
            metrics = {
                "preds_json": json.dumps(self.rows).encode("utf-8")
            }
        return 0.0, len(self.rows), metrics

# ===== Start the client =====
client = RNAClient(TEST_H5AD_PATH, SELECTOR_PATH, MODEL_PATH, MODALITY, WEIGHT)
fl.client.start_numpy_client(server_address=SERVER_ADDRESS, client=client)

  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
	Instead, use `flwr.client.start_client()` by ensuring you first call the `.to_client()` method as shown below: 
	flwr.client.start_client(
		server_address='<IP>:<PORT>',
		client=FlowerClient().to_client(), # <-- where FlowerClient is of type flwr.client.NumPyClient object
	)
	Using `start_numpy_client()` is deprecated.

            This is a deprecated feature. It will be removed
            entirely in future versions of Flower.
        
	Instead, use the `flower-supernode` CLI command to start a SuperNode as shown below:

		$ flower-supernode --insecure --superlink='<IP>:<PORT>'

	To view all available options, run:

		$ flower-supernode --help

	Using `start_client()` is deprecated.

            This is a deprecated feature. It will be removed
         

📊 Classification Report:
              precision    recall  f1-score   support

     Stage I       0.75      0.75      0.75         4
    Stage II       0.50      0.67      0.57         3
   Stage III       0.50      0.50      0.50         2
    Stage IV       0.00      0.00      0.00         1

    accuracy                           0.60        10
   macro avg       0.44      0.48      0.46        10
weighted avg       0.55      0.60      0.57        10

📊 Confusion Matrix:
           Stage I  Stage II  Stage III  Stage IV
Stage I          3         1          0         0
Stage II         1         2          0         0
Stage III        0         1          1         0
Stage IV         0         0          1         0
✅ 已生成 10 条预测
📤 RNA 客户端上传 10 条预测
