In [None]:

# 🏥 Federated Learning Tutorial: Malaria Diagnosis (Women in AI Nigeria Workshop)
# Facilitator: Dr. Sakinat Folorunso

# 📦 Step 1: Install Libraries
!pip install -q flwr[simulation] pandas matplotlib scikit-learn tensorflow ipywidgets

import pandas as pd
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import flwr as fl
import ipywidgets as widgets
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
from IPython.display import display
import itertools

# 📊 Step 2: Load and Inspect the Malaria Dataset
# Location of the workshop CSV; change this if your copy lives elsewhere.
DATA_PATH = '/mnt/data/mock_malaria_dataset.csv'

df = pd.read_csv(DATA_PATH)
print(f"Total records: {len(df)}")
print("Clients (Clinics):", df.client_id.unique())
# Only the last expression of a notebook cell is echoed automatically; this one is
# followed by more code, so render the preview explicitly.
display(df.head())

# 📈 Step 3: Exploratory Data Analysis (EDA)
# Model inputs and the binary target used throughout the rest of the notebook.
feature_cols = ['age', 'temperature', 'parasite_density', 'rbc_count', 'wbc_count', 'headache', 'fever']
label_col = 'malaria_positive'

# Column dtypes and non-null counts (printed directly by pandas).
df.info()
# Not the last expression of the cell, so it must be displayed explicitly;
# otherwise the summary table is silently discarded.
display(df[feature_cols + [label_col]].describe().round(2))

print("Records per client (clinic):\n", df.client_id.value_counts(), "\n")
print("Overall malaria_positive distribution (%):")
print(df[label_col].value_counts(normalize=True).mul(100).round(1), "%\n")

# One histogram per selected numeric column to eyeball ranges and skew.
for col in ['age', 'temperature', 'parasite_density']:
    plt.figure()
    plt.hist(df[col], bins=20)
    plt.title(f"Distribution of {col}")
    plt.xlabel(col)
    plt.ylabel("Count")
    plt.show()

# ⚙️ Step 4: Data Preprocessing - Split per Client
# Each clinic's rows (selected by client_id) get their own 80/20 train/test split,
# so no clinic's data ends up in another clinic's split.
client_ids = df.client_id.unique().tolist()
client_train_data, client_test_data = {}, {}

for cid in client_ids:
    df_c = df[df.client_id == cid]
    X = df_c[feature_cols].values.astype(np.float32)
    y = df_c[label_col].values.astype(np.int32)
    try:
        # Stratify so each clinic's test split keeps its positive/negative ratio.
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=0.2, random_state=42, stratify=y
        )
    except ValueError:
        # scikit-learn cannot stratify when a class has fewer than 2 rows (or a
        # split would be smaller than the number of classes); fall back to a
        # plain random split instead of aborting the whole notebook.
        print(f"⚠️ Client {cid}: too few samples per class to stratify; using a random split.")
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=0.2, random_state=42
        )
    client_train_data[cid] = (X_train, y_train)
    client_test_data[cid] = (X_test, y_test)

print("Clients prepared:", client_ids)

# 🏗️ Step 5: Build Malaria Diagnosis Model (Interactive Architecture)
num_layers = widgets.IntSlider(value=2, min=1, max=4, description='Hidden Layers:')
layer_units = widgets.Text(value='32,16', description='Units per Layer:')
activation_fn = widgets.Dropdown(options=['relu', 'tanh', 'sigmoid'], value='relu', description='Activation:')
display(num_layers, layer_units, activation_fn)


def parse_hidden_units(text, max_layers):
    """Parse a comma-separated list of hidden-layer widths.

    Args:
        text: User input such as ``"32,16"``.
        max_layers: Maximum number of widths to return.

    Returns:
        Up to ``max_layers`` positive integers, in input order. Blank,
        non-numeric and zero entries are skipped. A warning is emitted when
        fewer than ``max_layers`` valid widths remain.
    """
    import warnings

    tokens = (token.strip() for token in text.split(','))
    # isdecimal (not isdigit) so int() cannot fail on characters such as '²';
    # zero is skipped because a Dense layer needs at least one unit.
    widths = [int(token) for token in tokens if token.isdecimal() and int(token) > 0]
    if len(widths) < max_layers:
        warnings.warn(
            f"Only {len(widths)} valid layer width(s) in {text!r}; "
            f"building {len(widths)} hidden layer(s) instead of {max_layers}."
        )
    return widths[:max_layers]


def build_malaria_model(input_shape, hidden_units=None, activation=None):
    """Build and compile the binary malaria classifier.

    Args:
        input_shape: Shape of one input row, e.g. ``(len(feature_cols),)``.
        hidden_units: Optional sequence of hidden-layer widths. Defaults to the
            ``layer_units`` text widget parsed by ``parse_hidden_units`` and
            limited to the ``num_layers`` slider value.
        activation: Optional activation name for the hidden layers. Defaults
            to the ``activation_fn`` dropdown value.

    Returns:
        A ``tf.keras.Sequential`` model ending in one sigmoid unit, compiled
        with Adam, binary cross-entropy loss and an accuracy metric.
    """
    if hidden_units is None:
        hidden_units = parse_hidden_units(layer_units.value, num_layers.value)
    if activation is None:
        activation = activation_fn.value

    model = tf.keras.Sequential()
    # tf.keras.Input works under both Keras 2 and Keras 3 (where InputLayer's
    # `input_shape` argument is deprecated in favour of `shape`).
    model.add(tf.keras.Input(shape=input_shape))
    for units in hidden_units:
        model.add(tf.keras.layers.Dense(units, activation=activation))
    model.add(tf.keras.layers.Dense(1, activation='sigmoid'))
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    return model


model = build_malaria_model(input_shape=(len(feature_cols),))
model.summary()

# 🤝 Step 6: Setup Federated Learning Client
class MalariaClient(fl.client.NumPyClient):
    """Flower client that trains and evaluates the malaria model on one clinic's data.

    The simulated client id ``cid`` (``"0"``, ``"1"``, ...) selects the clinic at
    that position in ``client_ids``. The client then uses that clinic's splits
    from ``client_train_data`` / ``client_test_data``.
    """

    def __init__(self, cid):
        super().__init__()
        index = int(cid)
        if not 0 <= index < len(client_ids):
            raise ValueError(
                f"Client id {cid!r} has no matching clinic: only {len(client_ids)} "
                f"clinics are available ({client_ids})."
            )
        self.client_label = client_ids[index]
        # Fresh local model; its weights are overwritten by the server in fit/evaluate.
        self.model = build_malaria_model(input_shape=(len(feature_cols),))
        self.X_train, self.y_train = client_train_data[self.client_label]
        self.X_test, self.y_test = client_test_data[self.client_label]

    def get_parameters(self, config=None):
        """Return the local model weights as a list of NumPy arrays.

        Flower passes a ``config`` dict here; it is not used by this client.
        """
        return self.model.get_weights()

    def fit(self, parameters, config):
        """Train locally starting from the global ``parameters``.

        Reads optional ``local_epochs`` (default 1) and ``batch_size``
        (default 16) from the server-sent ``config``.

        Returns:
            ``(updated_weights, num_train_examples, metrics_dict)``.
        """
        self.model.set_weights(parameters)
        epochs = int(config.get("local_epochs", 1))
        batch_size = int(config.get("batch_size", 16))
        self.model.fit(self.X_train, self.y_train, epochs=epochs, batch_size=batch_size, verbose=0)
        return self.model.get_weights(), len(self.X_train), {}

    def evaluate(self, parameters, config):
        """Evaluate the global ``parameters`` on this clinic's test split.

        Returns:
            ``(loss, num_test_examples, {"accuracy": accuracy})``. Loss and
            accuracy are cast to built-in ``float`` because Flower type-checks
            the returned loss and Keras may return NumPy scalars.
        """
        self.model.set_weights(parameters)
        loss, acc = self.model.evaluate(self.X_test, self.y_test, verbose=0)
        return float(loss), len(self.X_test), {"accuracy": float(acc)}

# 🚀 Step 7: Run Federated Learning Simulation (Interactive Clients)
# MalariaClient maps simulated client id i to client_ids[i], so the slider is
# capped at the number of clinics in the dataset.
client_slider = widgets.IntSlider(
    value=min(3, len(client_ids)),
    min=1,
    max=len(client_ids),
    description='Simulated Clients:',
)
display(client_slider)

# NOTE: the slider value is read as soon as this cell runs. To try another value,
# move the slider and re-run the code below it from a separate cell.
num_simulated_clients = client_slider.value


class FedAvgWithFinalWeights(fl.server.strategy.FedAvg):
    """FedAvg that also keeps the most recently aggregated global weights.

    ``start_simulation`` returns only a ``History`` of losses and metrics, so
    the trained global weights are captured here in ``self.final_weights``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # List of NumPy arrays once at least one round has been aggregated.
        self.final_weights = None

    def aggregate_fit(self, server_round, results, failures):
        parameters, metrics = super().aggregate_fit(server_round, results, failures)
        if parameters is not None:
            self.final_weights = fl.common.parameters_to_ndarrays(parameters)
        return parameters, metrics


def weighted_average_accuracy(metrics):
    """Average client accuracies, weighting each by its number of test examples.

    Args:
        metrics: List of ``(num_examples, {"accuracy": value})`` pairs taken
            from the clients' ``evaluate`` results.

    Returns:
        ``{"accuracy": weighted_mean}``, or ``{}`` if there are no examples.
    """
    total_examples = sum(num_examples for num_examples, _ in metrics)
    if total_examples == 0:
        return {}
    weighted_sum = sum(num_examples * m["accuracy"] for num_examples, m in metrics)
    return {"accuracy": weighted_sum / total_examples}


# Require every simulated client in each round. FedAvg's defaults expect at least
# 2 clients, which would leave a 1-client simulation waiting for a second client.
strategy = FedAvgWithFinalWeights(
    min_fit_clients=num_simulated_clients,
    min_evaluate_clients=num_simulated_clients,
    min_available_clients=num_simulated_clients,
    evaluate_metrics_aggregation_fn=weighted_average_accuracy,
)

history = fl.simulation.start_simulation(
    client_fn=lambda cid: MalariaClient(cid),
    num_clients=num_simulated_clients,
    config=fl.server.ServerConfig(num_rounds=5),
    strategy=strategy,
)

final_weights = strategy.final_weights
if final_weights is None:
    raise RuntimeError("No global weights were aggregated; check the client errors logged above.")

# 📊 Step 8: Plot Global Model Accuracy
# metrics_distributed["accuracy"] is a list of (round, accuracy) pairs produced
# by weighted_average_accuracy from the clients' evaluation results.
accuracy_by_round = history.metrics_distributed.get("accuracy", [])
rounds = [server_round for server_round, _ in accuracy_by_round]
accuracies = [accuracy for _, accuracy in accuracy_by_round]

plt.figure()
plt.plot(rounds, accuracies, marker="o")
plt.title("Global Test Accuracy per Round")
plt.xlabel("Round")
plt.ylabel("Accuracy")
plt.xticks(rounds)
plt.grid(True)
plt.show()

# 🧮 Step 9: Evaluate Final Model
global_model = build_malaria_model(input_shape=(len(feature_cols),))
global_model.set_weights(final_weights)

# Pool every clinic's held-out split and score it with a single predict() call.
X_eval = np.concatenate([client_test_data[cid][0] for cid in client_ids])
y_true = np.concatenate([client_test_data[cid][1] for cid in client_ids])
# The output layer is one sigmoid unit; threshold its probability at 0.5.
y_pred = (global_model.predict(X_eval, verbose=0) > 0.5).astype(int).ravel()

classes = ["Negative", "Positive"]
# Fixing labels=[0, 1] keeps both classes in the report and matrix even if one
# is absent from y_true and y_pred, so rows/columns always match `classes`.
class_labels = [0, 1]

print("Classification Report:\n")
print(classification_report(y_true, y_pred, labels=class_labels, target_names=classes, zero_division=0))

cm = confusion_matrix(y_true, y_pred, labels=class_labels)
plt.figure(figsize=(5, 4))
# Blues maps high counts to dark cells, which is what the white/black text rule
# below assumes (the default viridis map makes high counts light yellow).
plt.imshow(cm, interpolation="nearest", cmap="Blues")
plt.title("Confusion Matrix")
plt.colorbar()
tick_marks = np.arange(len(classes))
plt.xticks(tick_marks, classes)
plt.yticks(tick_marks, classes)
plt.ylabel("True Label")
plt.xlabel("Predicted Label")

# White text on the darker (above half-max) cells, black text elsewhere.
thresh = cm.max() / 2
for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
    plt.text(j, i, cm[i, j], ha="center", va="center", color="white" if cm[i, j] > thresh else "black")

plt.tight_layout()
plt.show()
