EfficientNet

In [8]:
# ==============================
# Improved Federated Learning with EfficientNet & Mesh Topology
# Precision-Weighted Aggregation + Class Balancing
# ==============================

import os
import pandas as pd
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms as T
from PIL import Image
from efficientnet_pytorch import EfficientNet
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import numpy as np

# ======== Dataset Definition ==========
class OcularDataset(Dataset):
    def __init__(self, dataframe, image_dir, transform=None):
        self.df = dataframe.reset_index(drop=True)
        self.image_dir = image_dir
        self.transform = transform if transform else T.Compose([
            T.Resize((224, 224)),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406],
                        [0.229, 0.224, 0.225])
        ])

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img_path = os.path.join(self.image_dir, row['filename'])
        img = Image.open(img_path).convert("RGB")
        labels = torch.tensor(row["labels"], dtype=torch.float32)
        return self.transform(img), labels

# ======== Model Definition ==========
def get_model():
    model = EfficientNet.from_pretrained("efficientnet-b0")
    model._fc = nn.Sequential(
        nn.Dropout(0.5),
        nn.Linear(model._fc.in_features, 8)
    )
    return model

# ======== Paths & Data Load ==========
base_path = "C:/Users/aitoo/OneDrive/Desktop/FYP-Code/ODIR-5K/ODIR-5K"
train_images_dir = os.path.join(base_path, "Training Images")
data_file = os.path.join(base_path, "data.xlsx")

df = pd.read_excel(data_file)
label_cols = ["N", "D", "G", "C", "A", "H", "M", "O"]

# Expand to Left & Right Fundus images
rows = []
for _, r in df.iterrows():
    rows.append({"filename": r["Left-Fundus"], "labels": r[label_cols].values.astype(int).tolist()})
    rows.append({"filename": r["Right-Fundus"], "labels": r[label_cols].values.astype(int).tolist()})
df_images = pd.DataFrame(rows)

# ======== Compute pos_weight ==========
label_array = np.array(df_images["labels"].tolist())
pos_freq = np.sum(label_array, axis=0)
neg_freq = len(label_array) - pos_freq
pos_weight_values = neg_freq / (pos_freq + 1e-6)

# ======== Split for Federated Clients ==========
total_len = len(df_images)
split_size = total_len // 5
splits = [split_size] * 5
splits[-1] += total_len - sum(splits)
clients_dfs = random_split(df_images, splits)

# ======== Config ==========
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
EPOCHS = 20
ROUNDS = 30
lr = 3e-4
threshold = 0.5

# ======== Weighted Aggregation ==========
def aggregate_weights(weights_list, precision_scores):
    total = sum(precision_scores) + 1e-8
    avg_weights = []
    for weights in zip(*weights_list):
        avg = sum(p * w for w, p in zip(weights, precision_scores)) / total
        avg_weights.append(avg)
    return avg_weights

# ======== Mesh Topology ==========
neighbors = {i: [j for j in range(5) if j != i] for i in range(5)}

# ======== Initialize Clients ==========
clients = []
for idx in range(5):
    df_split = clients_dfs[idx].dataset.iloc[clients_dfs[idx].indices]
    dataset = OcularDataset(df_split, train_images_dir)
    split = int(0.8 * len(dataset))
    train_ds, val_ds = random_split(dataset, [split, len(dataset) - split])
    train_loader = DataLoader(train_ds, batch_size=16, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=16)
    clients.append({
        "model": get_model().to(device),
        "train_loader": train_loader,
        "val_loader": val_loader
    })

# ======== Training Loop ==========
logs = [pd.DataFrame(columns=['Round', 'Loss', 'Accuracy', 'Precision', 'Recall', 'F1']) for _ in range(5)]
pos_weight = torch.tensor(pos_weight_values, dtype=torch.float32).to(device)
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

for rnd in range(1, ROUNDS + 1):
    print(f"\n🔁 Round {rnd} ========================")
    precision_scores = []
    weights_list = []

    for idx, client in enumerate(clients):
        model = client['model']
        model.train()
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)

        for _ in range(EPOCHS):
            for images, labels in client['train_loader']:
                images, labels = images.to(device), labels.to(device)
                optimizer.zero_grad()
                loss = criterion(model(images), labels)
                loss.backward()
                optimizer.step()

        # Evaluation
        model.eval()
        y_true, y_pred = [], []
        total_loss = 0
        batch_count = 0
        with torch.no_grad():
            for images, labels in client['val_loader']:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss_val = criterion(outputs, labels).item()
                total_loss += loss_val
                batch_count += 1
                preds = torch.sigmoid(outputs).cpu().numpy()
                y_true.extend(labels.cpu().numpy())
                y_pred.extend((preds > threshold).astype(int))

        avg_loss = total_loss / batch_count if batch_count > 0 else 0
        acc = accuracy_score(y_true, y_pred)
        prec = precision_score(y_true, y_pred, average="samples", zero_division=0)
        rec = recall_score(y_true, y_pred, average="samples", zero_division=0)
        f1 = f1_score(y_true, y_pred, average="samples", zero_division=0)

        print(f"Client {idx+1}: Loss={avg_loss:.4f} | Acc={acc:.4f} | Prec={prec:.4f} | Rec={rec:.4f} | F1={f1:.4f}")

        logs[idx] = pd.concat([logs[idx], pd.DataFrame([[rnd, avg_loss, acc, prec, rec, f1]],
                                                       columns=logs[idx].columns)], ignore_index=True)
        logs[idx].to_csv(f"C:/Users/aitoo/OneDrive/Desktop/FYP-Code/first_raw/DFL_Client_{idx+1}.csv", index=False)

        # Store weights and precision
        precision_scores.append(prec if prec > 0 else 1e-4)
        weights_list.append([val.cpu().detach().numpy() for val in model.state_dict().values()])

    # ======== Aggregation ==========
    for i in range(5):
        neighbor_indices = neighbors[i] + [i]
        neighbor_weights = [weights_list[j] for j in neighbor_indices]
        neighbor_precisions = [precision_scores[j] for j in neighbor_indices]
        new_weights = aggregate_weights(neighbor_weights, neighbor_precisions)
        state_dict = clients[i]['model'].state_dict()
        new_state = {k: torch.tensor(v).to(device) for k, v in zip(state_dict.keys(), new_weights)}
        clients[i]['model'].load_state_dict(new_state)

print("\n✅ Training completed and logs saved.")


Loaded pretrained weights for efficientnet-b0
Loaded pretrained weights for efficientnet-b0
Loaded pretrained weights for efficientnet-b0
Loaded pretrained weights for efficientnet-b0
Loaded pretrained weights for efficientnet-b0

Client 1: Loss=2.2830 | Acc=0.2607 | Prec=0.4149 | Rec=0.4667 | F1=0.4224


  logs[idx] = pd.concat([logs[idx], pd.DataFrame([[rnd, avg_loss, acc, prec, rec, f1]],


Client 2: Loss=2.3929 | Acc=0.2429 | Prec=0.3976 | Rec=0.4750 | F1=0.4157


  logs[idx] = pd.concat([logs[idx], pd.DataFrame([[rnd, avg_loss, acc, prec, rec, f1]],


Client 3: Loss=2.2951 | Acc=0.2679 | Prec=0.4214 | Rec=0.4845 | F1=0.4324


  logs[idx] = pd.concat([logs[idx], pd.DataFrame([[rnd, avg_loss, acc, prec, rec, f1]],


Client 4: Loss=2.9714 | Acc=0.3250 | Prec=0.4625 | Rec=0.4863 | F1=0.4574


  logs[idx] = pd.concat([logs[idx], pd.DataFrame([[rnd, avg_loss, acc, prec, rec, f1]],


Client 5: Loss=3.9005 | Acc=0.2714 | Prec=0.4210 | Rec=0.4923 | F1=0.4344



  logs[idx] = pd.concat([logs[idx], pd.DataFrame([[rnd, avg_loss, acc, prec, rec, f1]],


Client 1: Loss=2.2484 | Acc=0.2536 | Prec=0.4143 | Rec=0.4744 | F1=0.4261
Client 2: Loss=1.4776 | Acc=0.3286 | Prec=0.4905 | Rec=0.5518 | F1=0.5005
Client 3: Loss=2.6031 | Acc=0.2179 | Prec=0.3781 | Rec=0.4583 | F1=0.3960
Client 4: Loss=2.1131 | Acc=0.3036 | Prec=0.4658 | Rec=0.5262 | F1=0.4745


KeyboardInterrupt: 

Swin Transformer

mobilenet_v3_large