# Install and Import Dependencies


In [None]:
!pip install pennylane

In [None]:
import os
import time
import numpy as np
import torch
import torch.nn as nn
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
import pennylane as qml
from tqdm.auto import tqdm
from typing import Union, Tuple

# Set random seeds for reproducibility: seed the global PyTorch and NumPy RNGs.
# NOTE: dataset partitioning does not rely on this global seed; it builds its
# own torch.Generator objects seeded with the `seed` argument passed to
# load_datasets / split_data_client.
torch.manual_seed(0)
np.random.seed(0)

# Utility Functions


In [None]:
def aggregate_weights(client_weights, client_sizes):
    """Combine client state dicts into one size-weighted average.

    Each client's tensors are scaled by ``size / sum(client_sizes)`` and
    summed, so clients with more samples contribute proportionally more.

    Args:
        client_weights: Sequence of state dicts (parameter name -> tensor),
            one per client. Every dict must contain the keys of the first one,
            with tensors of matching shape.
        client_sizes: Sequence of non-negative sample counts, one per client,
            in the same order as ``client_weights``.

    Returns:
        A new dict mapping every key of the first client's state dict to the
        weighted average of that entry. Input tensors are not modified.
        Integer/bool entries are averaged in float64 and rounded back to their
        original dtype.

    Raises:
        ValueError: If ``client_weights`` is empty, if the two sequences have
            different lengths, or if the sizes do not sum to a positive value.
    """
    if not client_weights:
        raise ValueError("client_weights must contain at least one state dict")
    if len(client_weights) != len(client_sizes):
        # zip() would silently drop the unmatched clients and skew the average.
        raise ValueError(
            f"Got {len(client_weights)} state dicts but {len(client_sizes)} sizes"
        )

    total_size = sum(client_sizes)
    if total_size <= 0:
        raise ValueError("client_sizes must sum to a positive number")

    factors = [size / total_size for size in client_sizes]

    global_weights = {}
    for key, reference in client_weights[0].items():
        is_inexact = reference.is_floating_point() or reference.is_complex()
        # Integer buffers cannot accumulate fractional contributions in place,
        # so they are averaged in float64 and converted back afterwards.
        acc_dtype = reference.dtype if is_inexact else torch.float64
        accumulator = torch.zeros_like(reference, dtype=acc_dtype)

        for weights, factor in zip(client_weights, factors):
            accumulator += weights[key].to(acc_dtype) * factor

        if not is_inexact:
            accumulator = torch.round(accumulator).to(reference.dtype)
        global_weights[key] = accumulator

    return global_weights


def federated_learning():
    """Run the full federated training experiment described by ``CONFIG``.

    Loads the per-client training loaders and the shared test loader, then for
    each round: gives every client a fresh ``QNNModel`` initialised from the
    current global weights, trains it locally with ``train``, averages the
    returned state dicts with ``aggregate_weights`` (weighted by each client's
    training-set size) and evaluates the updated global model with ``test``.

    Progress and metrics are printed; nothing is returned.
    """
    # Load datasets (the per-client validation loaders are not used here)
    trainloaders, _, testloader = load_datasets(
        num_clients=CONFIG["num_clients"],
        batch_size=CONFIG["batch_size"],
        resize=CONFIG["resize"],
        seed=CONFIG["seed"],
        num_workers=CONFIG["num_workers"],
        splitter=CONFIG["splitter"],
        dataset=CONFIG["dataset"],
        data_path=CONFIG["data_path"],
        data_path_val=CONFIG["data_path_val"],
    )

    # Number of training samples per client; used as aggregation weights
    client_sizes = [len(loader.dataset) for loader in trainloaders]

    # Initialize global model and loss function
    device = CONFIG["device"]
    num_rounds = CONFIG["num_rounds"]
    global_model = QNNModel().to(device)
    criterion = nn.CrossEntropyLoss()

    # Initial evaluation (only loss and accuracy of the five return values are used)
    init_loss, init_acc, _, _, _ = test(global_model, testloader, criterion, device)
    print(f"\nInitial Global Model - Loss: {init_loss:.4f}, Accuracy: {init_acc:.2f}%")

    # Federated learning loop. The counter is deliberately not called `round`
    # so the builtin stays usable inside this function.
    for round_idx in range(num_rounds):
        print(f"\n=== Federated Round {round_idx+1}/{num_rounds} ===")

        # Client training: iterate the same list that client_sizes was built
        # from so weights and sizes are guaranteed to line up.
        client_weights = []
        for client_number, trainloader in enumerate(trainloaders, start=1):
            print(f"\n--- Client {client_number} Training ---")
            model = QNNModel().to(device)
            model.load_state_dict(global_model.state_dict())

            weights = train(
                model,
                trainloader,
                epochs=CONFIG["local_epochs"],
                lr=CONFIG["learning_rate"],
                device=device,
            )
            client_weights.append(weights)

        # Aggregate weights
        global_weights = aggregate_weights(client_weights, client_sizes)
        global_model.load_state_dict(global_weights)

        # Global evaluation
        test_loss, test_acc, _, _, _ = test(global_model, testloader, criterion, device)
        print(
            f"\nGlobal Model Performance, Round: {round_idx+1} - Loss: {test_loss:.4f}, Accuracy: {test_acc:.2f}%"
        )

# Data setup


In [None]:
# Per-dataset keyword arguments for transforms.Normalize, keyed by dataset name.
# Only "fashion_mnist" has an entry; load_datasets looks the dataset name up
# here, so any other name has no normalization statistics available.
# NOTE(review): the values are presumably the single-channel mean/std of the
# Fashion-MNIST training images - confirm before reusing elsewhere.
NORMALIZE_DICT = {
    "fashion_mnist": dict(mean=(0.2860,), std=(0.3530,)),
}


def split_data_client(dataset, num_clients, seed):
    """Randomly partition ``dataset`` into ``num_clients`` disjoint subsets.

    The first ``num_clients - 1`` subsets each get ``len(dataset) // num_clients``
    samples; the final subset takes whatever is left. The shuffle is driven by
    a dedicated generator seeded with ``seed``, so the same seed reproduces
    the same partition.
    """
    total = len(dataset)
    share = total // num_clients

    sizes = [share for _ in range(num_clients - 1)]
    sizes.append(total - sum(sizes))  # last client absorbs the remainder

    rng = torch.Generator().manual_seed(seed)
    return random_split(dataset, sizes, rng)


def load_datasets(
    num_clients: int,
    batch_size: int,
    resize: int,
    seed: int,
    num_workers: int,
    splitter=10,
    dataset="fashion_mnist",
    data_path="./data/",
    data_path_val="",
):
    """Build per-client train/validation loaders and one shared test loader.

    Args:
        num_clients: Number of client partitions to create from the train set.
        batch_size: Batch size used for every DataLoader.
        resize: Target side length for ``caltech101`` / ``stanfordcars``
            images; ignored for other datasets. ``None`` disables resizing.
        seed: Seed for the generators that drive the client partition and the
            per-client train/validation split.
        num_workers: Number of DataLoader worker processes.
        splitter: Percentage of each client's partition held out for
            validation when ``data_path_val`` is empty.
        dataset: Dataset name; must have an entry in ``NORMALIZE_DICT``.
        data_path: Root directory under which the datasets are stored.
        data_path_val: Optional ImageFolder directory with validation data. If
            given, it is partitioned across clients and the train partitions
            are used in full.

    Returns:
        ``(trainloaders, valloaders, testloader)`` where the first two are
        lists with one DataLoader per client.

    Raises:
        ValueError: If ``dataset`` has no normalization statistics or no
            loading branch.
    """
    # Fail with the same kind of error as the loader branch below instead of a
    # bare KeyError from the dictionary lookup.
    if dataset not in NORMALIZE_DICT:
        raise ValueError(
            f"Unsupported dataset: {dataset} "
            f"(no normalization statistics; known: {sorted(NORMALIZE_DICT)})"
        )

    list_transforms = [
        transforms.ToTensor(),
        transforms.Normalize(**NORMALIZE_DICT[dataset]),
    ]

    # Resize must run before ToTensor, so it is prepended to the pipeline.
    if dataset in ["caltech101", "stanfordcars"] and resize is not None:
        list_transforms = [transforms.Resize((resize, resize))] + list_transforms
    elif dataset == "svhn":
        list_transforms = [
            transforms.Resize((32, 32))
        ] + list_transforms  # SVHN images are 32x32
    # No resize for Fashion-MNIST (keep 28x28)

    transformer = transforms.Compose(list_transforms)

    # os.path.join keeps the paths valid whether or not data_path ends with a
    # separator (plain concatenation produced e.g. "./datafashion_mnist").
    try:
        if dataset == "caltech101":
            trainset = datasets.ImageFolder(
                os.path.join(data_path, "caltech101", "train"), transform=transformer
            )
            testset = datasets.ImageFolder(
                os.path.join(data_path, "caltech101", "test"), transform=transformer
            )
        elif dataset == "stanfordcars":
            root = os.path.join(data_path, "stanfordcars")
            trainset = datasets.StanfordCars(
                root,
                split="train",
                download=True,
                transform=transformer,
            )
            testset = datasets.StanfordCars(
                root,
                split="test",
                download=True,
                transform=transformer,
            )
        elif dataset == "fashion_mnist":
            root = os.path.join(data_path, "fashion_mnist")
            trainset = datasets.FashionMNIST(
                root,
                train=True,
                download=True,
                transform=transformer,
            )
            testset = datasets.FashionMNIST(
                root,
                train=False,
                download=True,
                transform=transformer,
            )
        else:
            raise ValueError(f"Unsupported dataset: {dataset}")
    except Exception as e:
        # Surface the reason in the notebook output, then propagate unchanged.
        print(f"Failed to load dataset: {e}")
        raise

    # Split data into clients
    datasets_train = split_data_client(trainset, num_clients, seed)

    # Handle validation data
    if data_path_val:
        valset = datasets.ImageFolder(data_path_val, transform=transformer)
        datasets_val = split_data_client(valset, num_clients, seed)
    else:
        datasets_val = None

    # num_workers was previously accepted but never forwarded to the loaders.
    loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers)

    # Create dataloaders
    trainloaders = []
    valloaders = []
    for i in range(num_clients):
        if datasets_val is not None:
            ds_train, ds_val = datasets_train[i], datasets_val[i]
        else:
            # Hold out `splitter` percent of this client's partition.
            len_val = int(len(datasets_train[i]) * splitter / 100)
            len_train = len(datasets_train[i]) - len_val
            ds_train, ds_val = random_split(
                datasets_train[i],
                [len_train, len_val],
                torch.Generator().manual_seed(seed),
            )
        trainloaders.append(DataLoader(ds_train, shuffle=True, **loader_kwargs))
        valloaders.append(DataLoader(ds_val, **loader_kwargs))

    testloader = DataLoader(testset, **loader_kwargs)
    return trainloaders, valloaders, testloader

# Training and testing functions


In [None]:
def train(model, trainloader, epochs, lr, device):
    """Train ``model`` in place with Adam and cross-entropy loss.

    Args:
        model: Module mapping a batch of inputs to class logits.
        trainloader: Iterable yielding ``(inputs, labels)`` batches.
        epochs: Number of passes over ``trainloader``.
        lr: Adam learning rate.
        device: Device the batches are moved to; the model is expected to
            already live there.

    Returns:
        ``model.state_dict()`` after the final epoch.

    Per-epoch mean batch loss and training accuracy are printed. An epoch that
    yields no samples prints a notice instead of dividing by zero.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    model.train()

    for epoch in range(epochs):
        epoch_loss = 0.0
        correct = 0
        total = 0
        num_batches = 0  # counted here so loaders without __len__ also work

        for inputs, labels in trainloader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            num_batches += 1
            # detach() replaces the legacy `.data` access for metric-only use.
            predicted = outputs.detach().argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        if total == 0:
            # Nothing was seen this epoch; the weights are unchanged.
            print(f"Epoch {epoch+1}/{epochs} - no training samples, skipping metrics")
            continue

        epoch_acc = 100 * correct / total
        print(
            f"Epoch {epoch+1}/{epochs} - Loss: {epoch_loss/num_batches:.4f}, Acc: {epoch_acc:.2f}%"
        )

    return model.state_dict()


def test(
    model: torch.nn.Module,
    dataloader: torch.utils.data.DataLoader,
    loss_fn: Union[torch.nn.Module, Tuple],
    device: torch.device,
):
    """Evaluate ``model`` on ``dataloader``.

    Args:
        model: Module mapping a batch of images to class logits.
        dataloader: Iterable yielding ``(images, labels)`` batches.
        loss_fn: Callable ``loss_fn(logits, labels)`` returning a scalar tensor.
        device: Device the batches are moved to.

    Returns:
        ``(loss, accuracy, y_pred, y_true, y_proba)`` where ``loss`` is the
        mean of the per-batch losses, ``accuracy`` is the percentage of
        correctly classified samples over the whole loader, ``y_pred`` and
        ``y_true`` are per-sample lists of predicted / true labels and
        ``y_proba`` is an array of softmax probabilities, one row per sample.

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    model.eval()
    softmax = nn.Softmax(dim=1)

    loss_sum = 0.0
    num_batches = 0
    num_correct = 0
    y_pred = []
    y_true = []
    y_proba = []

    with torch.inference_mode():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            output = model(images)

            y_proba.extend(softmax(output).detach().cpu().numpy())

            loss_sum += loss_fn(output, labels).item()
            num_batches += 1

            labels_np = labels.cpu().numpy()
            preds = np.argmax(output.detach().cpu().numpy(), axis=1)
            y_true.extend(labels_np)
            y_pred.extend(preds)

            # Count hits rather than averaging per-batch accuracies: a smaller
            # final batch would otherwise be over-weighted.
            num_correct += int((preds == labels_np).sum())

    if num_batches == 0 or not y_true:
        raise ValueError("dataloader yielded no samples to evaluate")

    test_loss = loss_sum / num_batches
    test_acc = 100.0 * num_correct / len(y_true)
    return test_loss, test_acc, y_pred, y_true, np.array(y_proba)


# The helper's name matches pytest's default test-function pattern; mark it so
# that pytest does not try to collect it as a test case.
test.__test__ = False

# Main experiment setup


In [None]:
# Central experiment configuration, read by federated_learning, quantum_net
# and QNNModel.
#   dataset        - dataset name passed to load_datasets
#   num_clients    - number of client partitions / clients trained per round
#   num_rounds     - number of federated aggregation rounds
#   local_epochs   - epochs each client trains per round
#   batch_size     - batch size passed to load_datasets
#   resize         - passed to load_datasets, which only applies it to the
#                    "caltech101" / "stanfordcars" branches
#   seed           - seed for the data-partitioning generators (separate from
#                    the global torch/numpy seed set at the top of the file)
#   num_workers    - passed to load_datasets
#   splitter       - percentage of each client's partition held out for
#                    validation when data_path_val is empty
#   data_path      - root directory for the datasets
#   data_path_val  - optional validation directory; empty disables it
#   num_classes    - output size of the model's final linear layer
#   n_qubits       - number of wires of the quantum device and of features fed
#                    into the quantum layer
#   n_layers       - first dimension of the quantum layer's weight tensor
#   device         - "cuda" when available, otherwise "cpu"
#   learning_rate  - learning rate passed to train
CONFIG = {
    "dataset": "fashion_mnist",
    "num_clients": 10,
    "num_rounds": 20,
    "local_epochs": 10,
    "batch_size": 32,
    "resize": 32,
    "seed": 42,
    "num_workers": 0,
    "splitter": 10,
    "data_path": "./data/",
    "data_path_val": "",
    "num_classes": 10,
    "n_qubits": 6,
    "n_layers": 6,
    "device": "cuda" if torch.cuda.is_available() else "cpu",
    "learning_rate": 1e-3,
}

# PennyLane "default.qubit" device with one wire per configured qubit; the
# quantum_net QNode below is bound to it.
dev = qml.device("default.qubit", wires=CONFIG["n_qubits"])


@qml.qnode(dev, interface="torch")
def quantum_net(inputs, weights):
    """Variational circuit used as the model's quantum layer.

    ``inputs`` are angle-embedded onto the ``CONFIG["n_qubits"]`` wires,
    ``weights`` parameterise the ``BasicEntanglerLayers`` template applied to
    the same wires, and the Pauli-Z expectation value of every wire is
    returned (one value per qubit).
    """
    wires = range(CONFIG["n_qubits"])

    qml.AngleEmbedding(inputs, wires=wires)
    qml.BasicEntanglerLayers(weights, wires=wires)

    return [qml.expval(qml.PauliZ(wire)) for wire in wires]


class QNNModel(nn.Module):
    """Hybrid classifier: CNN encoder -> quantum layer -> linear head.

    The convolutional stack reduces an image to ``CONFIG["n_qubits"]``
    features, ``quantum_net`` (wrapped in a ``TorchLayer``) maps them to one
    expectation value per qubit, and a final linear layer produces
    ``CONFIG["num_classes"]`` logits.

    Args:
        in_channels: Number of channels of the input images (default 1).
        image_size: Side length of the square input images (default 28). It
            determines the input width of the first fully connected layer.

    Raises:
        ValueError: If ``image_size`` is too small to survive the three 2x2
            max-pooling stages.
    """

    def __init__(self, in_channels=1, image_size=28):
        super().__init__()

        # The 3x3 convolutions use padding=1 and keep the spatial size; each
        # of the three 2x2 max-pools halves it (rounding down), i.e. // 8.
        feature_size = image_size // 8
        if feature_size < 1:
            raise ValueError(
                f"image_size must be at least 8 to survive three 2x2 poolings, got {image_size}"
            )

        # Layer order and positions are kept fixed so state-dict keys such as
        # "conv_layers.0.weight" stay the same.
        self.conv_layers = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Flatten(),
            nn.Linear(256 * feature_size * feature_size, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, CONFIG["n_qubits"]),
        )

        # Trainable circuit weights have shape (n_layers, n_qubits).
        self.qnn = qml.qnn.TorchLayer(
            quantum_net, {"weights": (CONFIG["n_layers"], CONFIG["n_qubits"])}
        )
        self.fc = nn.Linear(CONFIG["n_qubits"], CONFIG["num_classes"])

    def forward(self, x):
        """Return class logits for a batch of images ``x``."""
        x = self.conv_layers(x)
        x = self.qnn(x)
        return self.fc(x)

# Main Experiment


In [None]:
# Run the experiment end to end and report the wall-clock duration.
print(f"Using device: {CONFIG['device']}")
print("Starting federated learning...")
start_time = time.time()
# Prints per-client and per-round metrics itself; returns nothing.
federated_learning()
print(f"Total training time: {time.time()-start_time:.2f} seconds")