In [23]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, random_split, Dataset
import flwr as fl
from collections import OrderedDict
import numpy as np
import matplotlib.pyplot as plt
import os
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns

# Device config
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# CONSTANTS (From Pseudocode)
NUM_CLIENTS = 3        # The 3 isolated Farmers
ROUNDS = 5             # Communication Rounds
BATCH_SIZE = 32
IMG_SIZE = 224
DATA_DIR = r"F:\WIDS-5.0\data\plantvillage dataset\color" 
# Helper for Data Splitting
class ApplyTransform(Dataset):
    def __init__(self, subset, transform=None):
        self.subset = subset
        self.transform = transform
        
    def __getitem__(self, index):
        x, y = self.subset[index]
        if self.transform:
            x = self.transform(x)
        return x, y
        
    def __len__(self):
        return len(self.subset)

Using device: cuda


In [24]:
# Week 3 Transforms (Augmentation for Training, Clean for Validation)
train_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomRotation(20),       # Augmentation
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

val_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Load Raw Data
print("Loading dataset...")
raw_dataset = datasets.ImageFolder(DATA_DIR)
num_classes = len(raw_dataset.classes)
total_size = len(raw_dataset)

# Split into 3 "Silos"
split_size = total_size // NUM_CLIENTS
lengths = [split_size] * NUM_CLIENTS
lengths[-1] += (total_size % NUM_CLIENTS)

subsets = random_split(raw_dataset, lengths)

# Wrap with Transforms
train_datasets = [ApplyTransform(sub, transform=train_transform) for sub in subsets]
val_dataset = ApplyTransform(subsets[0], transform=val_transform) # Proxy validation set

# Create Loaders
train_loaders = [DataLoader(ds, batch_size=BATCH_SIZE, shuffle=True) for ds in train_datasets]
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

print(f"Data ready. {NUM_CLIENTS} Clients prepared.")

Loading dataset...
Data ready. 3 Clients prepared.


In [25]:
# The Federated Model (ResNet18)
def get_model():
    net = models.resnet18(weights='DEFAULT')
    net.fc = nn.Linear(net.fc.in_features, num_classes)
    return net.to(device)

# Local Training Function
def train(net, trainloader, epochs=1):
    criterion = nn.CrossEntropyLoss()
    # Using Adam to match Week 3 performance
    optimizer = optim.Adam(net.parameters(), lr=1e-4)
    net.train()
    for _ in range(epochs):
        for images, labels in trainloader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            loss = criterion(net(images), labels)
            loss.backward()
            optimizer.step()

# Local Evaluation Function
def test(net, testloader):
    criterion = nn.CrossEntropyLoss()
    loss, correct, total = 0.0, 0, 0
    net.eval()
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)
            outputs = net(images)
            loss += criterion(outputs, labels).item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return loss / len(testloader), correct / total

In [26]:
class PlantClient(fl.client.NumPyClient):
    def __init__(self, net, trainloader, valloader):
        self.net = net
        self.trainloader = trainloader
        self.valloader = valloader

    def get_parameters(self, config):
        # Method get_parameters: Return weights
        return [val.cpu().numpy() for _, val in self.net.state_dict().items()]

    def set_parameters(self, parameters):
        # Helper to load weights
        params_dict = zip(self.net.state_dict().keys(), parameters)
        state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
        self.net.load_state_dict(state_dict, strict=True)

    def fit(self, parameters, config):
        # Method fit: Download Global -> Train Local -> Upload
        self.set_parameters(parameters)
        train(self.net, self.trainloader, epochs=1)
        return self.get_parameters(config={}), len(self.trainloader.dataset), {}

    def evaluate(self, parameters, config):
        # Method evaluate: Check accuracy
        self.set_parameters(parameters)
        loss, accuracy = test(self.net, self.valloader)
        return float(loss), len(self.valloader.dataset), {"accuracy": float(accuracy)}

In [27]:
def client_fn(cid: str):
    idx = int(cid)
    net = get_model()
    return PlantClient(net, train_loaders[idx], val_loader).to_client()

print("ðŸš€ Starting Federated Simulation...")
# This runs the training. It will take ~15-20 mins.
history = fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=NUM_CLIENTS,
    config=fl.server.ServerConfig(num_rounds=ROUNDS),
    strategy=fl.server.strategy.FedAvg(),
)

	Instead, use the `flwr run` CLI command to start a local simulation in your Flower app, as shown for example below:

		$ flwr new  # Create a new Flower app from a template

		$ flwr run  # Run the Flower app in Simulation Mode

	Using `start_simulation()` is deprecated.

            This is a deprecated feature. It will be removed
            entirely in future versions of Flower.
        
[92mINFO [0m:      Starting Flower simulation, config: num_rounds=5, no round_timeout


ðŸš€ Starting Federated Simulation...


2026-01-16 19:08:37,226	INFO worker.py:2012 -- Started a local Ray instance.
[92mINFO [0m:      Flower VCE: Ray initialized with resources: {'object_store_memory': 709764710.0, 'GPU': 1.0, 'memory': 1656117658.0, 'node:__internal_head__': 1.0, 'accelerator_type:G': 1.0, 'node:127.0.0.1': 1.0, 'CPU': 12.0}
[92mINFO [0m:      Optimize your simulation with Flower VCE: https://flower.ai/docs/framework/how-to-run-simulations.html
[92mINFO [0m:      No `client_resources` specified. Using minimal resources for clients.
[92mINFO [0m:      Flower VCE: Resources for each Virtual Client: {'num_cpus': 1, 'num_gpus': 0.0}
[92mINFO [0m:      Flower VCE: Creating VirtualClientEngineActorPool with 12 actors
[92mINFO [0m:      [INIT]
[92mINFO [0m:      Requesting initial parameters from one random client
[36m(pid=34556)[0m 2026-01-16 19:08:41.595609: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-poin