In [2]:
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
import pennylane as qml
from torch.utils.data import DataLoader, random_split

import torchvision
from torchvision import transforms, datasets
from torchvision.transforms import ToTensor
import torch.optim as optim
from pennylane import numpy as np

In [9]:
# Probe for CUDA once; the flag is kept as a module-level name for later cells.
use_cuda = torch.cuda.is_available()
if use_cuda:
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cuda


In [11]:
# ---- Quantum device selection (PennyLane) ----
# Number of wires (qubits) the circuit acts on.
n_qubits = 5

# Use the "lightning.gpu" PennyLane device when CUDA was detected,
# otherwise the "default.qubit" device.
backend_name = "lightning.gpu" if use_cuda is True else "default.qubit"
dev = qml.device(backend_name, wires=n_qubits)
print('dev: ', dev)
@qml.qnode(dev)
def qnode(inputs, weights):
    """Variational circuit: embed `inputs`, apply entangling layers, measure Z on each wire.

    Args:
        inputs: Values passed to `qml.AngleEmbedding` over all `n_qubits` wires.
            The parameter name `inputs` is the one `qml.qnn.TorchLayer` expects
            for the data argument.
        weights: Trainable parameters passed to `qml.BasicEntanglerLayers`.

    Returns:
        A list with one Pauli-Z expectation value per wire (`n_qubits` entries).
    """
    all_wires = range(n_qubits)
    qml.AngleEmbedding(inputs, wires=all_wires)
    qml.BasicEntanglerLayers(weights, wires=all_wires)

    measurements = []
    for wire in all_wires:
        measurements.append(qml.expval(qml.PauliZ(wires=wire)))
    return measurements

# Shape specification handed to qml.qnn.TorchLayer: one (n_layers, n_qubits)
# weight tensor, matching the `weights` argument of `qnode`.
n_layers = 3
weight_shapes = dict(weights=(n_layers, n_qubits))


# ==== The model architecture ====
class Net(nn.Module):
    """Hybrid CNN / quantum classifier producing 10 logits per image.

    Pipeline: two conv -> batch-norm -> ReLU -> 2x2 max-pool stages, a flatten to
    32 * 7 * 7 features, two fully connected layers down to 20 features, four
    quantum layers that each process a 5-feature slice, and a final linear layer
    mapping the 20 concatenated quantum outputs to 10 logits.

    The flatten size 32 * 7 * 7 corresponds to single-channel 28x28 inputs
    (the convolutions keep the spatial size; each pool halves it).
    """

    def __init__(self):
        super().__init__()
        # Stage 1: 1 input channel (greyscale) -> 16 channels, 5x5 kernel, size-preserving padding.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, stride=1, padding=2)
        self.bn1 = nn.BatchNorm2d(16)
        # 2x2 max pooling, shared by both convolutional stages.
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        # Stage 2: 16 -> 32 channels, same kernel/padding as stage 1.
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=2)
        self.bn2 = nn.BatchNorm2d(32)
        # Four independent quantum layers, registered as qlayer1 .. qlayer4.
        for layer_idx in range(1, 5):
            setattr(self, f"qlayer{layer_idx}", qml.qnn.TorchLayer(qnode, weight_shapes))
        # Fully connected layers: fc1/fc2 run before the quantum layers, fc3 after them.
        self.fc1 = nn.Linear(32 * 7 * 7, 120)
        self.fc2 = nn.Linear(120, 20)
        self.fc3 = nn.Linear(20, 10)

    def forward(self, x):
        """Return the 10-way logits for a batch of images `x`."""
        # Convolutional feature extractor.
        x = self.conv1(x)
        x = self.pool(F.relu(self.bn1(x)))
        x = self.conv2(x)
        x = self.pool(F.relu(self.bn2(x)))

        # Flatten the feature maps to one vector per sample.
        features = x.view(-1, 32 * 7 * 7)

        # Classical dense layers reduce the features to 20 values.
        hidden = F.relu(self.fc1(features))
        hidden = F.relu(self.fc2(hidden))

        # Pass the output to the quantum layers: four slices of 5 features each.
        slice_1, slice_2, slice_3, slice_4 = torch.split(hidden, 5, dim=1)
        quantum_out = torch.cat(
            [
                self.qlayer1(slice_1),
                self.qlayer2(slice_2),
                self.qlayer3(slice_3),
                self.qlayer4(slice_4),
            ],
            dim=1,
        )

        # Final classical read-out.
        return self.fc3(quantum_out)

dev:  <lightning.gpu device (wires=5) at 0x76eae2d93810>


In [12]:
# Load MNIST from ./data; the training split is downloaded if it is missing.
# Both splits convert images to tensors with ToTensor().
train_data = datasets.MNIST(root='data', train=True, transform=ToTensor(), download=True)
test_data = datasets.MNIST(root='data', train=False, transform=ToTensor())

In [13]:
# ==== Federated learning setup ====
def train_local(model, train_loader, criterion, optimizer, epochs=1):
    """Train `model` in place on one client's data and return its state dict.

    Args:
        model: Module to train; it is switched to training mode.
        train_loader: Iterable yielding `(data, labels)` batches.
        criterion: Callable computing the loss from `(outputs, labels)`.
        optimizer: Optimizer that updates the model's parameters.
        epochs: Number of full passes over `train_loader`.

    Returns:
        `model.state_dict()` after training.
    """

    def _optimize_batch(batch_inputs, batch_targets):
        # One optimisation step: clear gradients, backpropagate the loss, update.
        optimizer.zero_grad()
        batch_loss = criterion(model(batch_inputs), batch_targets)
        batch_loss.backward()
        optimizer.step()

    model.train()
    for _epoch in range(epochs):
        for batch_inputs, batch_targets in train_loader:
            _optimize_batch(batch_inputs, batch_targets)
    return model.state_dict()

def federated_avg(state_dicts):
    """Average model weights across clients.

    Args:
        state_dicts: Non-empty sequence of state dicts (name -> tensor). All
            entries are expected to share the same keys and tensor shapes.

    Returns:
        A deep copy of the first state dict whose values are replaced by the
        element-wise mean over all clients. Every averaged tensor keeps the
        dtype of the corresponding input entry. The inputs are not modified.

    Raises:
        ValueError: If `state_dicts` is empty.
    """
    if not state_dicts:
        raise ValueError("federated_avg() requires at least one state dict")

    num_states = len(state_dicts)
    # Work on a copy so the first client's tensors are not mutated by the
    # in-place accumulation below.
    avg_dict = copy.deepcopy(state_dicts[0])
    for key in list(avg_dict.keys()):
        total = avg_dict[key]
        for other in state_dicts[1:]:
            total += other[key]
        if total.is_floating_point() or total.is_complex():
            avg_dict[key] = total / num_states
        else:
            # Integer entries (e.g. BatchNorm's `num_batches_tracked` counter)
            # would be promoted to float by true division; floor-divide so the
            # averaged entry keeps its integer dtype.
            avg_dict[key] = torch.div(total, num_states, rounding_mode="floor")
    return avg_dict

In [14]:
# ==== Dataset split for multiple clients ====
num_clients = 3

dataset = train_data  # Your dataset here

# random_split requires the lengths to sum to len(dataset). Give every client
# the base share and hand the leftover samples (when the size is not divisible
# by num_clients) to the first `remainder` clients, one extra sample each.
client_size, remainder = divmod(len(dataset), num_clients)
client_lengths = [
    client_size + (1 if client_idx < remainder else 0)
    for client_idx in range(num_clients)
]
client_sets = random_split(dataset, client_lengths)

# One shuffled mini-batch loader per client.
client_loaders = [
    DataLoader(client_subset, batch_size=4, shuffle=True)
    for client_subset in client_sets
]

In [15]:
# ==== Federated loop ====
# The global model is the shared starting point; in every round each client
# trains its own copy and the resulting weights are averaged back into it.
global_model = Net()
criterion = nn.CrossEntropyLoss()

num_rounds = 2
local_epochs = 2
for round_idx in range(num_rounds):
    print(f"--- Round {round_idx+1} ---")
    local_states = []

    for client_idx in range(num_clients):
        # Each client starts from the current global weights with a fresh optimizer.
        client_model = copy.deepcopy(global_model)
        client_optimizer = torch.optim.SGD(
            client_model.parameters(), lr=0.001, momentum=0.9
        )
        local_states.append(
            train_local(
                client_model,
                client_loaders[client_idx],
                criterion,
                client_optimizer,
                epochs=local_epochs,
            )
        )

    # Aggregate the client weights and install them as the new global weights.
    global_state = federated_avg(local_states)
    global_model.load_state_dict(global_state)

--- Round 1 ---
--- Round 2 ---


In [16]:

# ==== Evaluation ====
# Score the aggregated global model on the held-out test split.
val_loader = DataLoader(test_data, batch_size=4, shuffle=False)
correct = 0
total = 0
global_model.eval()
with torch.no_grad():
    for images, labels in val_loader:
        # The predicted class is the index of the largest logit along dim 1.
        predicted = torch.max(global_model(images), 1).indices
        correct += (predicted == labels).sum().item()
        total += labels.size(0)

print(f'Global Model Accuracy: {100 * correct / total:.2f}%')


Global Model Accuracy: 98.99%


None
