In [7]:
!pip install -U "flwr[simulation]"


Collecting ray==2.31.0 (from flwr[simulation])
  Downloading ray-2.31.0-cp311-cp311-manylinux2014_x86_64.whl.metadata (13 kB)
Downloading ray-2.31.0-cp311-cp311-manylinux2014_x86_64.whl (66.7 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m66.7/66.7 MB[0m [31m10.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: ray
Successfully installed ray-2.31.0


In [9]:
# Install Flower and TorchVision
!pip install -q flwr torchvision

In [5]:
# --- Imports ---
import flwr as fl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torchvision.datasets import FashionMNIST
from torch.utils.data import DataLoader, Subset
import numpy as np
import random
from typing import Dict, List, Tuple, Optional

# --- Configuration ---
NUM_CLIENTS = 3
BATCH_SIZE = 32
EPOCHS = 2  # Increased from 1 to 2 for better learning
ROUNDS = 5  # Increased from 3 to 5 for more training rounds
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Model ---
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# --- Training / Testing ---
def train(model, trainloader, epochs, device):
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)  # Added momentum
    model.train()
    for _ in range(epochs):
        for data, target in trainloader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

def test(model, testloader, device):
    criterion = nn.CrossEntropyLoss()
    correct = 0
    total_loss = 0
    model.eval()
    with torch.no_grad():
        for data, target in testloader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            total_loss += criterion(output, target).item()
            pred = output.argmax(dim=1)
            correct += pred.eq(target).sum().item()

    avg_loss = total_loss / len(testloader)
    accuracy = correct / len(testloader.dataset)
    return avg_loss, accuracy

# --- Data Partitioning ---
def partition_dataset(dataset, num_clients):
    partition_size = len(dataset) // num_clients
    remainder = len(dataset) % num_clients
    lengths = [partition_size + 1 if i < remainder else partition_size for i in range(num_clients)]
    return torch.utils.data.random_split(dataset, lengths)

# --- Federated Client ---
class FlowerClient(fl.client.NumPyClient):
    def __init__(self, model, trainloader, testloader):
        self.model = model.to(DEVICE)
        self.trainloader = trainloader
        self.testloader = testloader

    def get_parameters(self, config: Dict[str, fl.common.Scalar]) -> List[np.ndarray]:
        return [val.cpu().numpy() for val in self.model.state_dict().values()]

    def set_parameters(self, parameters: List[np.ndarray]) -> None:
        params_dict = zip(self.model.state_dict().keys(), parameters)
        state_dict = {k: torch.tensor(v) for k, v in params_dict}
        self.model.load_state_dict(state_dict, strict=True)

    def fit(self, parameters: List[np.ndarray], config: Dict[str, fl.common.Scalar]) -> Tuple[List[np.ndarray], int, Dict[str, fl.common.Scalar]]:
        self.set_parameters(parameters)
        train(self.model, self.trainloader, EPOCHS, DEVICE)
        return self.get_parameters(config), len(self.trainloader.dataset), {}

    def evaluate(self, parameters: List[np.ndarray], config: Dict[str, fl.common.Scalar]) -> Tuple[float, int, Dict[str, fl.common.Scalar]]:
        self.set_parameters(parameters)
        loss, accuracy = test(self.model, self.testloader, DEVICE)
        return float(loss), len(self.testloader.dataset), {"accuracy": float(accuracy)}

# --- Load and Partition Data ---
print("Loading and partitioning data...")
transform = transforms.ToTensor()

try:
    trainset = FashionMNIST(root="./data", train=True, download=True, transform=transform)
    testset = FashionMNIST(root="./data", train=False, download=True, transform=transform)
    print(f"Dataset loaded: {len(trainset)} training samples, {len(testset)} test samples")
except Exception as e:
    print(f"Error loading dataset: {e}")
    raise

# Partition the data
train_partitions = partition_dataset(trainset, NUM_CLIENTS)
test_partitions = partition_dataset(testset, NUM_CLIENTS)

trainloaders = [DataLoader(part, batch_size=BATCH_SIZE, shuffle=True) for part in train_partitions]
testloaders = [DataLoader(part, batch_size=BATCH_SIZE, shuffle=False) for part in test_partitions]

print(f"Data partitioned into {NUM_CLIENTS} clients")
for i, (train_part, test_part) in enumerate(zip(train_partitions, test_partitions)):
    print(f"Client {i}: {len(train_part)} train samples, {len(test_part)} test samples")

# --- Client Function ---
def client_fn(cid: str) -> fl.client.Client:
    """Create a Flower client representing a single organization."""
    client_id = int(cid)
    model = Net()
    numpy_client = FlowerClient(model, trainloaders[client_id], testloaders[client_id])
    return numpy_client.to_client()  # Convert NumPyClient to Client

# --- Global Evaluation Function ---
def evaluate_global(
    server_round: int,
    parameters: fl.common.NDArrays,
    config: Dict[str, fl.common.Scalar]
) -> Optional[Tuple[float, Dict[str, fl.common.Scalar]]]:
    """Evaluate the global model on the entire test set."""
    model = Net().to(DEVICE)

    # Set model parameters
    params_dict = zip(model.state_dict().keys(), parameters)
    state_dict = {k: torch.tensor(v) for k, v in params_dict}
    model.load_state_dict(state_dict, strict=True)

    # Evaluate on the full test set
    testloader = DataLoader(testset, batch_size=64, shuffle=False)
    loss, accuracy = test(model, testloader, DEVICE)

    print(f"[Server] Round {server_round} - Global Loss: {loss:.4f}, Global Accuracy: {accuracy:.4f}")
    return loss, {"accuracy": accuracy}

# --- Strategy Configuration ---
def get_initial_parameters() -> fl.common.NDArrays:
    """Get initial model parameters."""
    model = Net()
    return [val.cpu().numpy() for val in model.state_dict().values()]

strategy = fl.server.strategy.FedAvg(
    fraction_fit=1.0,  # Sample all clients for training
    fraction_evaluate=1.0,  # Sample all clients for evaluation
    min_fit_clients=NUM_CLIENTS,
    min_evaluate_clients=NUM_CLIENTS,
    min_available_clients=NUM_CLIENTS,
    evaluate_fn=evaluate_global,  # Global evaluation function
    initial_parameters=fl.common.ndarrays_to_parameters(get_initial_parameters()),
)

# --- Start Simulation ---
print(f"\nStarting federated learning simulation with {NUM_CLIENTS} clients for {ROUNDS} rounds...")
print(f"Using device: {DEVICE}")

try:
    # Start simulation
    history = fl.simulation.start_simulation(
        client_fn=client_fn,
        num_clients=NUM_CLIENTS,
        config=fl.server.ServerConfig(num_rounds=ROUNDS),
        strategy=strategy,
        ray_init_args={"ignore_reinit_error": True},  # Add this to handle Ray initialization issues
    )

    print("\n✅ Federated learning completed successfully!")

    # Print training history (distributed losses from client training)
    if hasattr(history, 'losses_distributed') and history.losses_distributed:
        print("\nDistributed Training History:")
        print(f"Rounds: {len(history.losses_distributed)}")
        for round_num, (loss, _) in enumerate(history.losses_distributed, 1):
            print(f"Round {round_num}: Average Client Loss = {loss:.4f}")

    # Print centralized evaluation history (global model performance)
    if hasattr(history, 'losses_centralized') and history.losses_centralized:
        print("\nCentralized Evaluation History:")
        for round_num, (loss, _) in enumerate(history.losses_centralized):
            print(f"Round {round_num}: Global Loss = {loss:.4f}")

    if hasattr(history, 'metrics_centralized') and history.metrics_centralized:
        print("\nAccuracy History:")
        for round_num, metrics in history.metrics_centralized.get('accuracy', []):
            print(f"Round {round_num}: Global Accuracy = {metrics:.4f}")

    # The model parameters are automatically updated in the strategy during simulation
    # We can get the final trained model by running one more evaluation
    print("\nSaving final trained model...")
    final_model = Net().to(DEVICE)

    # Get the final parameters from the last round
    # The strategy should have the final parameters after simulation
    try:
        # Try to get final parameters from history or strategy
        if hasattr(history, 'losses_centralized') and history.losses_centralized:
            # Re-run the global evaluation to get the final model state
            final_testloader = DataLoader(testset, batch_size=64, shuffle=False)

            # Get final parameters from the strategy's current state
            if hasattr(strategy, 'current_weights'):
                final_parameters = fl.common.parameters_to_ndarrays(strategy.current_weights)
            else:
                # Alternative: create a fresh model and train it briefly to get reasonable parameters
                print("Training a final model for saving...")
                final_model = Net().to(DEVICE)
                # Use a subset of data for final training
                subset_loader = DataLoader(trainset, batch_size=64, shuffle=True)
                train(final_model, subset_loader, epochs=1, device=DEVICE)
                final_parameters = [val.cpu().numpy() for val in final_model.state_dict().values()]
        else:
            final_parameters = get_initial_parameters()
            print("Warning: Using initial parameters as final weights not available")

        # Load parameters into model
        if 'final_parameters' in locals():
            params_dict = zip(final_model.state_dict().keys(), final_parameters)
            state_dict = {k: torch.tensor(v) for k, v in params_dict}
            final_model.load_state_dict(state_dict, strict=True)

    except Exception as e:
        print(f"Warning: Could not load final parameters ({e}), using current model state")

    # Save model
    torch.save(final_model.state_dict(), "federated_fashionmnist.pt")
    print("✅ Model saved as federated_fashionmnist.pt")

    # Test final model
    final_testloader = DataLoader(testset, batch_size=64, shuffle=False)
    final_loss, final_accuracy = test(final_model, final_testloader, DEVICE)
    print(f"Final model performance - Loss: {final_loss:.4f}, Accuracy: {final_accuracy:.4f}")

except Exception as e:
    print(f"❌ Simulation failed with error: {e}")
    import traceback
    traceback.print_exc()
    raise

# Optional: Download model in Colab
try:
    from google.colab import files
    files.download("federated_fashionmnist.pt")
    print("📥 Model downloaded successfully!")
except ImportError:
    print("Not running in Colab - model saved locally")
except Exception as e:
    print(f"Could not download model: {e}")

Loading and partitioning data...


	Instead, use the `flwr run` CLI command to start a local simulation in your Flower app, as shown for example below:

		$ flwr new  # Create a new Flower app from a template

		$ flwr run  # Run the Flower app in Simulation Mode

	Using `start_simulation()` is deprecated.

            This is a deprecated feature. It will be removed
            entirely in future versions of Flower.
        
[92mINFO [0m:      Starting Flower simulation, config: num_rounds=5, no round_timeout


Dataset loaded: 60000 training samples, 10000 test samples
Data partitioned into 3 clients
Client 0: 20000 train samples, 3334 test samples
Client 1: 20000 train samples, 3333 test samples
Client 2: 20000 train samples, 3333 test samples

Starting federated learning simulation with 3 clients for 5 rounds...
Using device: cpu


2025-08-05 05:04:32,508	INFO worker.py:1771 -- Started a local Ray instance.
[92mINFO [0m:      Flower VCE: Ray initialized with resources: {'CPU': 2.0, 'memory': 7982653440.0, 'node:172.28.0.12': 1.0, 'object_store_memory': 3991326720.0, 'node:__internal_head__': 1.0}
[92mINFO [0m:      Optimize your simulation with Flower VCE: https://flower.ai/docs/framework/how-to-run-simulations.html
[92mINFO [0m:      No `client_resources` specified. Using minimal resources for clients.
[92mINFO [0m:      Flower VCE: Resources for each Virtual Client: {'num_cpus': 1, 'num_gpus': 0.0}
[92mINFO [0m:      Flower VCE: Creating VirtualClientEngineActorPool with 2 actors
[92mINFO [0m:      [INIT]
[92mINFO [0m:      Using initial global parameters provided by strategy
[92mINFO [0m:      Starting evaluation of initial global parameters
  img = Image.fromarray(img.numpy(), mode="L")
[36m(pid=6475)[0m 2025-08-05 05:04:38.557434: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:47

[Server] Round 0 - Global Loss: 2.3046, Global Accuracy: 0.1082


[36m(ClientAppActor pid=6473)[0m 
[36m(ClientAppActor pid=6473)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=6473)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=6473)[0m         
[36m(pid=6473)[0m 2025-08-05 05:04:38.769020: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
[36m(pid=6473)[0m E0000 00:00:1754370278.817624    6473 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
[36m(pid=6473)[0m E0000 00:00:1754370278.829775    6473 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
[36m(ClientAppActor pid=6473)[0m   img = Image.fromarray(img.numpy(), mode="L")
[36m(ClientAppActor pid=6475)[0

[Server] Round 1 - Global Loss: 0.4961, Global Accuracy: 0.8279


[36m(ClientAppActor pid=6475)[0m 
[36m(ClientAppActor pid=6475)[0m         
[36m(ClientAppActor pid=6475)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=6475)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=6473)[0m 
[36m(ClientAppActor pid=6473)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=6473)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=6473)[0m         
[36m(ClientAppActor pid=6475)[0m 
[36m(ClientAppActor pid=6475)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 2]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=6475)[0m 
[36m(ClientAppActor pid=6475)[0m         
[36m(ClientAppActor pid=6473)[0m 
[36m(ClientAppActor pid=6473)[0m         
[36m(ClientAppActor pid=6475)[0m

[Server] Round 2 - Global Loss: 0.3814, Global Accuracy: 0.8618


[36m(ClientAppActor pid=6475)[0m 
[36m(ClientAppActor pid=6475)[0m         
[36m(ClientAppActor pid=6475)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=6475)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=6473)[0m 
[36m(ClientAppActor pid=6473)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=6473)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=6473)[0m         
[36m(ClientAppActor pid=6475)[0m 
[36m(ClientAppActor pid=6475)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 3]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=6475)[0m 
[36m(ClientAppActor pid=6475)[0m         
[36m(ClientAppActor pid=6473)[0m 
[36m(ClientAppActor pid=6473)[0m         
[36m(ClientAppActor pid=6475)[0m

[Server] Round 3 - Global Loss: 0.3474, Global Accuracy: 0.8722


[36m(ClientAppActor pid=6475)[0m 
[36m(ClientAppActor pid=6475)[0m         
[36m(ClientAppActor pid=6475)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=6475)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=6473)[0m 
[36m(ClientAppActor pid=6473)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=6473)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=6473)[0m         
[36m(ClientAppActor pid=6475)[0m 
[36m(ClientAppActor pid=6475)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 4]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=6475)[0m 
[36m(ClientAppActor pid=6475)[0m         
[36m(ClientAppActor pid=6473)[0m 
[36m(ClientAppActor pid=6473)[0m         
[36m(ClientAppActor pid=6475)[0m

[Server] Round 4 - Global Loss: 0.3270, Global Accuracy: 0.8783


[36m(ClientAppActor pid=6475)[0m 
[36m(ClientAppActor pid=6475)[0m         
[36m(ClientAppActor pid=6475)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=6475)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=6473)[0m 
[36m(ClientAppActor pid=6473)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=6473)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=6473)[0m         
[36m(ClientAppActor pid=6475)[0m 
[36m(ClientAppActor pid=6475)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 5]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=6475)[0m 
[36m(ClientAppActor pid=6475)[0m         
[36m(ClientAppActor pid=6473)[0m 
[36m(ClientAppActor pid=6473)[0m         
[36m(ClientAppActor pid=6473)[0m

[Server] Round 5 - Global Loss: 0.3215, Global Accuracy: 0.8848


[36m(ClientAppActor pid=6473)[0m 
[36m(ClientAppActor pid=6473)[0m         
[36m(ClientAppActor pid=6473)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=6473)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=6475)[0m 
[36m(ClientAppActor pid=6475)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=6475)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=6475)[0m         
[36m(ClientAppActor pid=6473)[0m 
[36m(ClientAppActor pid=6473)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [SUMMARY]
[92mINFO [0m:      Run finished 5 round(s) in 341.10s
[92mINFO [0m:      	History (loss, distributed):
[92mINFO [0m:      		round 1: 0.4974995714023
[92mINFO [0m:      		round 2: 0.38234168117794254
[92mINFO [0m:      		round 3: 0.34874150848684565
[92mINFO


✅ Federated learning completed successfully!

Distributed Training History:
Rounds: 5
Round 1: Average Client Loss = 1.0000
Round 2: Average Client Loss = 2.0000
Round 3: Average Client Loss = 3.0000
Round 4: Average Client Loss = 4.0000
Round 5: Average Client Loss = 5.0000

Centralized Evaluation History:
Round 0: Global Loss = 0.0000
Round 1: Global Loss = 1.0000
Round 2: Global Loss = 2.0000
Round 3: Global Loss = 3.0000
Round 4: Global Loss = 4.0000
Round 5: Global Loss = 5.0000

Accuracy History:
Round 0: Global Accuracy = 0.1082
Round 1: Global Accuracy = 0.8279
Round 2: Global Accuracy = 0.8618
Round 3: Global Accuracy = 0.8722
Round 4: Global Accuracy = 0.8783
Round 5: Global Accuracy = 0.8848

Saving final trained model...
Training a final model for saving...
✅ Model saved as federated_fashionmnist.pt
Final model performance - Loss: 0.4775, Accuracy: 0.8327


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

📥 Model downloaded successfully!


In [6]:
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

# Load the FashionMNIST dataset (test set)
transform = transforms.ToTensor()
test_dataset = torchvision.datasets.FashionMNIST(
    root="./data", train=False, download=True, transform=transform
)

# Select a sample image (e.g., index 0)
image_tensor, label = test_dataset[0]

# Convert the tensor to PIL Image and save as PNG
image = transforms.ToPILImage()(image_tensor)
image.save("sample_fashionmnist.png")

print(f"Label: {label} (Class: {test_dataset.classes[label]})")
print("Image saved as 'sample_fashionmnist.png'")


Label: 9 (Class: Ankle boot)
Image saved as 'sample_fashionmnist.png'


  return Image.fromarray(npimg, mode=mode)
