In [None]:
# ===================================
# Federated Swin Transformer with Flower
# Low-memory version (single client)
# Per-client CSV logging (SAFE)
# ===================================

import os
import pandas as pd
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms as T
from PIL import Image
import timm
import ast
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import flwr as fl

# --- Paths ---
base_path = "C:/Users/aitoo/OneDrive/Desktop/FYP-Code/ODIR-5K/ODIR-5K"
train_images_dir = os.path.join(base_path, "Training Images")
data_file = os.path.join(base_path, "data.xlsx")

# --- Load labels ---
# Each spreadsheet row names a left-eye and a right-eye image and carries one
# integer column per label. Both images of a row get that row's full label
# vector (presumably patient-level labels -- verify against the dataset docs).
df = pd.read_excel(data_file)
label_cols = ["N", "D", "G", "C", "A", "H", "M", "O"]

rows = [
    {
        "filename": record[eye_col],
        "labels": record[label_cols].values.astype(int).tolist(),
    }
    for _, record in df.iterrows()
    for eye_col in ("Left-Fundus", "Right-Fundus")
]

# Write the per-image table to CSV and read it back. The CSV stores each label
# list as its string repr, which literal_eval turns back into a Python list.
labels_csv_path = "C:/Users/aitoo/OneDrive/Desktop/FYP-Code/Swin/ocular_labels.csv"
df_images = pd.DataFrame(rows)
df_images.to_csv(labels_csv_path, index=False)
df_images = pd.read_csv(labels_csv_path, converters={"labels": ast.literal_eval})

# --- Dataset ---
class OcularDataset(Dataset):
    """Map-style dataset of fundus images paired with multi-hot label vectors.

    Args:
        dataframe: Frame with a ``filename`` column (image file name relative to
            ``image_dir``) and a ``labels`` column holding one list of numbers
            per row (in this notebook, lists of 0/1 ints).
        image_dir: Directory containing the image files.
        transform: Callable applied to the RGB PIL image. When ``None``, the
            image is resized to 224x224, converted to a tensor and normalized
            with the standard ImageNet mean/std.
    """

    def __init__(self, dataframe, image_dir, transform=None):
        # Reset the index so positional iloc lookups line up with 0..len-1.
        self.df = dataframe.reset_index(drop=True)
        self.image_dir = image_dir
        if transform is None:
            transform = T.Compose([
                T.Resize((224, 224)),
                T.ToTensor(),
                T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ])
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(transformed_image, float32 label tensor)`` for row ``idx``."""
        row = self.df.iloc[idx]
        img_path = os.path.join(self.image_dir, row["filename"])
        # convert() yields a new, fully loaded image, so the source file can be
        # closed right away instead of leaking a handle until garbage
        # collection (relevant with many DataLoader workers / on Windows).
        with Image.open(img_path) as raw_img:
            img = raw_img.convert("RGB")
        labels = torch.tensor(row["labels"], dtype=torch.float32)
        return self.transform(img), labels

# --- Split indices into clients ---
# Number of federated clients the images are partitioned across. Keep this in
# sync with the `num_clients` passed to the simulation.
NUM_CLIENTS = 1

indices = torch.randperm(len(df_images)).tolist()
split_size = len(indices) // NUM_CLIENTS
# Equal-sized shards; the last client absorbs the remainder so every row is
# assigned to exactly one client.
splits = [split_size] * NUM_CLIENTS
splits[-1] += len(indices) - sum(splits)
# random_split returns Subset objects. Each Subset's `.indices` are positions
# into `indices` (not its shuffled values). Because `indices` has exactly
# len(df_images) entries, those positions are also valid row positions into
# df_images, which is how client_fn consumes them.
split_indices = torch.utils.data.random_split(indices, splits)

# --- Model ---
def get_model(model_name="swin_tiny_patch4_window7_224", num_classes=8, pretrained=True):
    """Create a timm classification model with a ``num_classes``-way output head.

    Args:
        model_name: timm architecture identifier.
        num_classes: Size of the classifier output (one logit per label).
        pretrained: Whether timm should load its pretrained weights.

    Returns:
        The ``torch.nn.Module`` built by ``timm.create_model``.
    """
    return timm.create_model(
        model_name,
        pretrained=pretrained,
        num_classes=num_classes,
        global_pool="avg",
    )

# Everything runs on the CPU, even if a GPU is present, to avoid out-of-memory errors.
device = torch.device(type="cpu")

# --- Flower Client ---
class FlowerClient(fl.client.NumPyClient):
    """Flower NumPy client that trains and evaluates a multi-label classifier locally.

    Weights travel to/from the server as a list of NumPy arrays in the model's
    ``state_dict`` order (parameters and buffers).

    Args:
        model: PyTorch model producing one raw logit per label.
        train_loader: DataLoader yielding ``(images, labels)`` for local training.
        val_loader: DataLoader yielding ``(images, labels)`` for local evaluation.
        client_id: Identifier echoed back in the evaluation metrics.
        num_classes: Number of labels; sizes the BCE ``pos_weight`` vector.
        pos_weight: Weight applied to the positive targets of every label.
        threshold: Sigmoid probability above which a label is predicted positive.
        lr: Adam learning rate.
        local_epochs: Passes over ``train_loader`` per ``fit`` call.
    """

    def __init__(self, model, train_loader, val_loader, client_id,
                 num_classes=8, pos_weight=3.0, threshold=0.3, lr=1e-4, local_epochs=1):
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        # Up-weight positives for every label; float32 vector of length num_classes.
        self.criterion = nn.BCEWithLogitsLoss(
            pos_weight=torch.full((num_classes,), pos_weight, device=device)
        )
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
        self.client_id = client_id
        self.threshold = threshold
        self.local_epochs = local_epochs

    def get_parameters(self, config):
        """Return all state_dict tensors as CPU NumPy arrays, in state_dict order."""
        return [tensor.cpu().numpy() for tensor in self.model.state_dict().values()]

    def set_parameters(self, parameters):
        """Load server-provided arrays into the model, matched to state_dict keys by position.

        Raises:
            ValueError: if the number of arrays differs from the number of
                state_dict entries (zip would otherwise silently drop the rest).
        """
        state_dict = self.model.state_dict()
        if len(parameters) != len(state_dict):
            raise ValueError(
                f"Expected {len(state_dict)} parameter arrays, got {len(parameters)}"
            )
        # Update the existing state_dict in place so its metadata is preserved.
        for key, array in zip(state_dict.keys(), parameters):
            state_dict[key] = torch.tensor(array).to(device)
        self.model.load_state_dict(state_dict)

    def fit(self, parameters, config):
        """Train locally starting from the server weights.

        Returns:
            ``(updated weights, number of training examples, {})``.
        """
        self.set_parameters(parameters)
        self.model.train()
        for _ in range(self.local_epochs):
            for images, labels in self.train_loader:
                images, labels = images.to(device), labels.to(device)
                self.optimizer.zero_grad()
                loss = self.criterion(self.model(images), labels)
                loss.backward()
                self.optimizer.step()
        return self.get_parameters({}), len(self.train_loader.dataset), {}

    def evaluate(self, parameters, config):
        """Evaluate the server weights on the local validation set.

        Returns:
            ``(mean per-sample loss, number of evaluated samples, metrics)``, where
            metrics holds accuracy/precision/recall/f1 and ``client_id``.

        Raises:
            ValueError: if the validation loader yields no samples.
        """
        self.set_parameters(parameters)
        self.model.eval()
        y_true, y_pred = [], []
        total_loss = 0.0
        num_samples = 0

        with torch.no_grad():
            for images, labels in self.val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = self.model(images)
                batch_size = labels.size(0)
                # The criterion returns a batch mean; scale it back to a sum so the
                # final value is a true per-sample mean even with a ragged last batch.
                total_loss += self.criterion(outputs, labels).item() * batch_size
                num_samples += batch_size
                probs = torch.sigmoid(outputs).cpu().numpy()
                y_true.extend(labels.cpu().numpy().astype(int))
                y_pred.extend((probs > self.threshold).astype(int))

        if num_samples == 0:
            raise ValueError(
                f"Client {self.client_id}: validation set is empty; cannot evaluate"
            )

        mean_loss = total_loss / num_samples
        # For multi-label input, accuracy_score is exact-match (subset) accuracy.
        acc = float(accuracy_score(y_true, y_pred))
        prec = float(precision_score(y_true, y_pred, average="samples", zero_division=0))
        rec = float(recall_score(y_true, y_pred, average="samples", zero_division=0))
        f1 = float(f1_score(y_true, y_pred, average="samples", zero_division=0))

        print(f"[Client {self.client_id}] Eval - Loss: {mean_loss:.4f} | Acc: {acc:.4f}")

        return (
            mean_loss,
            num_samples,
            {
                "accuracy": acc,
                "precision": prec,
                "recall": rec,
                "f1": f1,
                "client_id": self.client_id,
            },
        )

# --- Client factory ---
def client_fn(cid: str):
    """Build the Flower client for data partition ``cid``.

    Takes shard ``int(cid)`` of ``split_indices`` and splits it 80/20 into
    train/validation. The split uses a generator seeded from ``cid``, so it is
    identical every time this function runs. A fresh model is wrapped in
    ``FlowerClient``.

    The signature must stay single-argument: Flower inspects it and adapts
    legacy ``client_fn(cid)`` functions.
    """
    idx = int(cid)
    shard = split_indices[idx]
    client_df = df_images.iloc[shard.indices]
    dataset = OcularDataset(client_df, train_images_dir)
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size

    # Flower simulation instantiates clients on demand (per fit/evaluate task,
    # possibly in different Ray workers). An unseeded split would draw a new
    # validation set each time, leaking images the model trained on into
    # evaluation. Seed per client so the split is stable across calls.
    split_seed = 42 + idx
    split_generator = torch.Generator().manual_seed(split_seed)
    train_ds, val_ds = random_split(
        dataset, [train_size, val_size], generator=split_generator
    )

    # batch_size=1 keeps peak memory low on CPU.
    train_loader = DataLoader(train_ds, batch_size=1, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=1)
    return FlowerClient(get_model(), train_loader, val_loader, cid).to_client()

# --- Safe Custom Strategy ---
class PerClientCSVStrategy(fl.server.strategy.FedAvg):
    """FedAvg strategy that also writes each client's evaluation history to CSV.

    After every evaluation round, ``FL_Client_<client_id>.csv`` is rewritten
    under ``output_dir`` for each reporting client. Each file has one row per
    round seen so far.

    Args:
        output_dir: Directory for the per-client CSV files (created on demand).
        **kwargs: Forwarded unchanged to ``FedAvg``.
    """

    _METRIC_KEYS = ("accuracy", "precision", "recall", "f1")

    def __init__(self, output_dir, **kwargs):
        super().__init__(**kwargs)
        self.output_dir = output_dir
        # client_id -> list of per-round metric rows
        self.metrics_per_client = {}

    def aggregate_evaluate(self, rnd, results, failures):
        """Aggregate like FedAvg, then record each client's loss and metrics.

        Args:
            rnd: Current server round.
            results: ``(ClientProxy, EvaluateRes)`` pairs from successful clients.
            failures: Failed results/exceptions, passed through to FedAvg.

        Returns:
            ``(aggregated_loss, aggregated_metrics)`` from FedAvg, or
            ``(None, {})`` if the parent returned ``None``.
        """
        res = super().aggregate_evaluate(rnd, results, failures)
        aggregated_loss, aggregated_metrics = (None, {}) if res is None else res

        if results:
            os.makedirs(self.output_dir, exist_ok=True)

        for _, eval_res in results:
            # EvaluateRes carries the client's loss and its returned metrics dict.
            metrics = eval_res.metrics or {}
            client_id = metrics.get("client_id", "unknown")
            row = {"round": rnd, "loss": eval_res.loss}
            row.update({key: metrics.get(key) for key in self._METRIC_KEYS})

            history = self.metrics_per_client.setdefault(client_id, [])
            history.append(row)

            csv_path = os.path.join(self.output_dir, f"FL_Client_{client_id}.csv")
            pd.DataFrame(history).to_csv(csv_path, index=False)

        return aggregated_loss, aggregated_metrics

# --- Start Simulation ---
# A single client holds all the data, so every fraction/minimum is set such
# that this one client is used for both training and evaluation each round.
strategy = PerClientCSVStrategy(
    "C:/Users/aitoo/OneDrive/Desktop/FYP-Code/Swin/ClientMetrics",
    min_available_clients=1,
    min_fit_clients=1,
    min_evaluate_clients=1,
    fraction_fit=1.0,
    fraction_evaluate=1.0,
)

fl.simulation.start_simulation(
    strategy=strategy,
    client_fn=client_fn,
    num_clients=1,
    config=fl.server.ServerConfig(num_rounds=3),
    client_resources=dict(num_cpus=1, num_gpus=0),  # CPU only, no GPU share
)


	Instead, use the `flwr run` CLI command to start a local simulation in your Flower app, as shown for example below:

		$ flwr new  # Create a new Flower app from a template

		$ flwr run  # Run the Flower app in Simulation Mode

	Using `start_simulation()` is deprecated.

            This is a deprecated feature. It will be removed
            entirely in future versions of Flower.
        
[92mINFO [0m:      Starting Flower simulation, config: num_rounds=3, no round_timeout
2025-07-04 16:16:42,050	INFO worker.py:1771 -- Started a local Ray instance.
[92mINFO [0m:      Flower VCE: Ray initialized with resources: {'accelerator_type:G': 1.0, 'node:__internal_head__': 1.0, 'CPU': 24.0, 'object_store_memory': 14757959270.0, 'node:127.0.0.1': 1.0, 'memory': 29515918542.0, 'GPU': 1.0}
[92mINFO [0m:      Optimize your simulation with Flower VCE: https://flower.ai/docs/framework/how-to-run-simulations.html
[92mINFO [0m:      Flower VCE: Resources for each Virtual Client: {'num_cpus': 