In [None]:
import numpy as np
import pandas as pd
import anndata as ad
import joblib
import json
from scipy.sparse import issparse
from sklearn.metrics import classification_report, confusion_matrix
import flwr as fl
from autogluon.tabular import TabularPredictor

#== Parameter Settings==
# Input artifacts: test AnnData file, fitted scaler, fitted feature selector,
# and the saved AutoGluon predictor directory.
test_h5ad_path = "../Train/RNA_test.h5ad"
scaler_path = "../Train/RNA_scaler.pkl"
selector_path = "../Train/RNA_selector_kbest.pkl"
autogluon_model_path = "../Train/autogluon_rna_model"

# Federated-learning connection and the tags attached to every uploaded row.
SERVER_ADDRESS = "192.168.0.8:8080"
MODALITY = "RNA"
WEIGHT = 0.6

# Stage names in class-id order; the position in this list is the integer label.
label_names = ["Stage I", "Stage II", "Stage III", "Stage IV"]
label_map = {stage: idx for idx, stage in enumerate(label_names)}
int_to_stage = dict(enumerate(label_names))

class RNAClient(fl.client.NumPyClient):
    """Flower NumPy client that scores the RNA test set and uploads the results.

    All inference happens once, in the constructor: the test AnnData file is
    loaded, transformed with the saved scaler and feature selector, and scored
    with the saved AutoGluon predictor. The per-patient class probabilities are
    kept in ``self.rows`` and sent to the server as a JSON payload when
    ``evaluate`` is called with ``config["task"] == "predict"``.

    Relies on the module-level ``label_map`` / ``label_names`` for the
    stage-name <-> class-id mapping.
    """

    def __init__(self, test_h5ad_path, scaler_path, selector_path, model_path, modality, weight):
        """Store the modality tag and weight, then run inference immediately.

        Args:
            test_h5ad_path: Path to the ``.h5ad`` test file.
            scaler_path: Path to the joblib-serialized scaler.
            selector_path: Path to the joblib-serialized feature selector.
            model_path: Directory of the saved AutoGluon ``TabularPredictor``.
            modality: Tag copied into every uploaded row.
            weight: Weight copied into every uploaded row.
        """
        super().__init__()
        self.modality = modality
        self.weight = weight
        self.rows = []
        self._load_and_predict(test_h5ad_path, scaler_path, selector_path, model_path)

    def _load_and_predict(self, h5ad_path, scaler_path, selector_path, model_path):
        """Run the full inference pipeline and append one row per sample to ``self.rows``.

        Prints a classification report and confusion matrix for the test set
        as a side effect.
        """
        # Class ids in the same order as `label_names`; passed explicitly to the
        # metrics so the report/matrix always have one entry per stage, even
        # when a stage is absent from both the labels and the predictions.
        class_ids = list(label_map.values())

        # 1. Load test data
        adata = ad.read_h5ad(h5ad_path)
        X = adata.X.toarray() if issparse(adata.X) else adata.X
        y_raw = adata.obs["stage"].values
        pids = adata.obs["patient_id"].astype(str).values
        # NOTE(review): stage values missing from `label_map` are counted as
        # class 3 (Stage IV) -- confirm this fallback is intended.
        y_true = np.array([label_map.get(s, 3) for s in y_raw])

        # 2. Load scaler & selector (joblib unpickles these files; only load
        #    artifacts from a trusted source) and apply them in order.
        scaler = joblib.load(scaler_path)
        selector = joblib.load(selector_path)
        X_selected = selector.transform(scaler.transform(X))

        # 3. Wrap as DataFrame for AutoGluon
        df = pd.DataFrame(X_selected)

        # 4. Load AutoGluon model and predict
        predictor = TabularPredictor.load(model_path)
        y_pred = predictor.predict(df)
        y_prob = predictor.predict_proba(df)

        # Predicted labels are cast to int class ids (the original author
        # notes they may come back as strings such as "0", "1", ...).
        y_pred_int = y_pred.astype(int).values

        # Align the probability columns with `class_ids` so every uploaded
        # vector has one entry per stage in a fixed order; a class the model
        # does not know gets probability 0.0.
        y_prob = y_prob.copy()
        y_prob.columns = [int(col) for col in y_prob.columns]
        y_prob = y_prob.reindex(columns=class_ids, fill_value=0.0)

        print("Server_RNA_test Classification Report:")
        print(
            classification_report(
                y_true,
                y_pred_int,
                labels=class_ids,
                target_names=label_names,
                digits=4,
                zero_division=0,  # report 0 for never-predicted classes without warning
            )
        )
        print("Server_RNA_test Confusion Matrix:")
        print(
            pd.DataFrame(
                confusion_matrix(y_true, y_pred_int, labels=class_ids),
                index=label_names,
                columns=label_names,
            )
        )

        # 5. Format to JSON-serializable rows
        self.rows.extend(
            {
                "patient_id": pid,
                "probs": probs.tolist(),
                "modality": self.modality,
                "weight": self.weight,
            }
            for pid, probs in zip(pids, y_prob.values)
        )

        print(f"\n{len(self.rows)} predictions have been generated.")

    def get_parameters(self, config):
        """Return an empty parameter list; this client shares no model weights."""
        return []

    def fit(self, parameters, config):
        """Do no training: return empty parameters, zero examples, and no metrics."""
        return [], 0, {}

    def evaluate(self, parameters, config):
        """Upload the cached predictions when the server requests them.

        Returns:
            A ``(loss, num_examples, metrics)`` tuple. ``loss`` is always
            ``0.0`` and ``num_examples`` is the number of cached rows. When
            ``config["task"] == "predict"``, ``metrics["preds_json"]`` holds
            the rows as UTF-8 encoded JSON; otherwise ``metrics`` is empty.
        """
        metrics = {}
        if config.get("task", "") == "predict":
            print(f"\n📤 RNA client uploads {len(self.rows)} predictions.")
            metrics = {
                "preds_json": json.dumps(self.rows).encode("utf-8")
            }
        return 0.0, len(self.rows), metrics

#== Start client==
# Constructing the client runs the whole inference pipeline up front.
client = RNAClient(test_h5ad_path, scaler_path, selector_path, autogluon_model_path, MODALITY, WEIGHT)

# `start_numpy_client()` is deprecated; Flower's replacement is `start_client()`
# with the NumPyClient converted via `.to_client()`.
# NOTE(review): newer Flower releases also deprecate `start_client()` in favour
# of the `flower-supernode` CLI -- plan a migration when upgrading.
fl.client.start_client(server_address=SERVER_ADDRESS, client=client.to_client())

  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
	Instead, use `flwr.client.start_client()` by ensuring you first call the `.to_client()` method as shown below: 
	flwr.client.start_client(
		server_address='<IP>:<PORT>',
		client=FlowerClient().to_client(), # <-- where FlowerClient is of type flwr.client.NumPyClient object
	)
	Using `start_numpy_client()` is deprecated.

            This is a deprecated feature. It will be removed
            entirely in future versions of Flower.
        
	Instead, use the `flower-supernode` CLI command to start a SuperNode as shown below:

		$ flower-supernode --insecure --superlink='<IP>:<PORT>'

	To view all available options, run:

		$ flower-supernode --help

	Using `start_client()` is deprecated.

            This is a deprecated feature. It will be removed
         

Server_RNA_test Classification Report:
              precision    recall  f1-score   support

     Stage I     0.5714    0.6667    0.6154         6
    Stage II     0.7143    0.7143    0.7143        14
   Stage III     0.7143    0.7143    0.7143         7
    Stage IV     0.0000    0.0000    0.0000         1

    accuracy                         0.6786        28
   macro avg     0.5000    0.5238    0.5110        28
weighted avg     0.6582    0.6786    0.6676        28

Server_RNA_test Confusion Matrix:
           Stage I  Stage II  Stage III  Stage IV
Stage I          4         2          0         0
Stage II         3        10          1         0
Stage III        0         2          5         0
Stage IV         0         0          1         0

28 predictions have been generated.
