In [1]:
!git clone https://github.com/mit-han-lab/torchquantum.git

Cloning into 'torchquantum'...
remote: Enumerating objects: 15656, done.[K
remote: Counting objects: 100% (1801/1801), done.[K
remote: Compressing objects: 100% (436/436), done.[K
remote: Total 15656 (delta 1560), reused 1369 (delta 1364), pack-reused 13855 (from 4)[K
Receiving objects: 100% (15656/15656), 101.77 MiB | 11.82 MiB/s, done.
Resolving deltas: 100% (8900/8900), done.
Updating files: 100% (346/346), done.


In [2]:
cd torchquantum

/content/torchquantum


In [3]:
pip install --editable .

Obtaining file:///content/torchquantum
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting dill==0.3.4 (from torchquantum==0.1.8)
  Downloading dill-0.3.4-py2.py3-none-any.whl.metadata (9.6 kB)
Collecting nbsphinx (from torchquantum==0.1.8)
  Downloading nbsphinx-0.9.6-py3-none-any.whl.metadata (2.1 kB)
Collecting pathos>=0.2.7 (from torchquantum==0.1.8)
  Downloading pathos-0.3.3-py3-none-any.whl.metadata (11 kB)
Collecting pylatexenc>=2.10 (from torchquantum==0.1.8)
  Downloading pylatexenc-2.10.tar.gz (162 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m162.6/162.6 kB[0m [31m8.5 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting pyscf>=2.0.1 (from torchquantum==0.1.8)
  Downloading pyscf-2.8.0-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.4 kB)
Collecting qiskit<1.0.0,>=0.39.0 (from torchquantum==0.1.8)
  Downloading qiskit-0.46.3-py3-none-any.whl.metadata (12 kB)
Collecting reco

In [87]:
import os
import sys
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
from torch.utils.data import DataLoader, Subset, RandomSampler, SubsetRandomSampler
from torch.utils.tensorboard import SummaryWriter
from torch.optim.lr_scheduler import CosineAnnealingLR
from torchvision import datasets, transforms
from torchvision.datasets import ImageFolder
from sklearn.metrics import precision_score, recall_score, f1_score
from torch.distributions.bernoulli import Bernoulli
import matplotlib.pyplot as plt
import pandas as pd
from collections import OrderedDict
from copy import deepcopy
from tqdm import tqdm
from sklearn.model_selection import train_test_split
import torchquantum as tq
from torchquantum.layer import U3CU3Layer0, RandomLayer

# Set seeds for reproducibility
random.seed(42)
np.random.seed(42)

n_qubits, n_layers, n_classes = 4, 4, 4


In [88]:
class Circuit_TQ(tq.QuantumModule):
    def __init__(self,
                 n_wires=4,
                 n_layers=4,
                 n_classes=4
                ):
        super().__init__()
        self.tag = 'Quantum'
        self.n_wires  = n_wires
        self.layers = n_layers
        self.q_device = tq.QuantumDevice(n_wires=self.n_wires)
        self.encoder  = tq.StateEncoder()
        self.n_classes = 4
        self.PQC = nn.ModuleList([ U3CU3Layer0({'n_wires': self.n_wires, 'n_blocks': 1, 'n_layers_per_block': 1})   for i in range(n_layers)])
        self.measure = tq.MeasureAll(tq.PauliZ)
        self.state_vectors = []

    def forward(self,x,layers):
        if len(x.shape) == 1:
            x = x.unsqueeze(0)
        self.encoder(self.q_device,x)
        for i in range(layers):
            self.PQC[i](self.q_device)
        output = self.measure(self.q_device)[:,:self.n_classes]
        state_vector = self.q_device.get_states_1d()
        self.state_vectors.append(state_vector)

        return torch.softmax(3.5* output,dim=-1), output.clone()

    def get_logits_and_states(self,x,layers):
        if len(x.shape) == 1:
            x = x.unsqueeze(0)
        self.encoder(self.q_device,x)
        for i in range(layers):
            self.PQC[i](self.q_device)
        output = self.measure(self.q_device)[:,:self.n_classes]
        return output.clone(), self.q_device.get_states_1d()


    def get_q_state(self,x,layers):
        if len(x.shape) == 1:
            x = x.unsqueeze(0)
        self.encoder(self.q_device,x)
        for i in range(layers):
            self.PQC[i](self.q_device)
            # self.PQC2[i](self.q_device)
        return self.q_device.get_states_1d()

In [89]:
# Data Preparation
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
selected_classes = [0, 1, 2, 3]
filtered_indices = [i for i, target in enumerate(dataset.targets) if target in selected_classes]
filtered_targets = [dataset.targets[i] for i in filtered_indices]
train_indices, test_indices = train_test_split(filtered_indices, test_size=0.3, stratify=filtered_targets, random_state=42)
train_data = Subset(dataset, train_indices)
test_data = Subset(dataset, test_indices)

In [90]:
def create_dataloader(dataset, batch_size=70):
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [95]:
# Model Initialization
model = Circuit_TQ(n_qubits, n_layers, n_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=8e-4)
criterion = nn.CrossEntropyLoss()

# DataLoader
train_loader = create_dataloader(train_data)
test_loader = create_dataloader(test_data)

time_list = []
test_accuracy = []
test_loss = []

for epoch in range(5):
    TR_Size = 0
    TR_Corrects = 0
    TR_Loss = []
    TE_Size = 0
    TE_Corrects = 0
    TE_Loss = []

    model.train()
    for x, y in train_loader:
        x = F.avg_pool2d(x, 10).reshape(-1, 4)

        optimizer.zero_grad()
        y_hat, logit = model(x, n_layers)
        loss = criterion(y_hat, F.one_hot(y, n_classes).to(torch.float32))
        loss.backward()
        optimizer.step()

        TR_Loss.append(loss.item())
        _, indices = y_hat.topk(1, dim=1)
        masks = indices.eq(y.view(-1, 1).expand_as(indices))
        TR_Size += y.shape[0]
        TR_Corrects += masks.sum().item()

    TR_Acc = TR_Corrects / TR_Size
    TR_Loss = sum(TR_Loss) / len(TR_Loss)

    model.eval()
    with torch.no_grad():
        for x, y in test_loader:
            x = F.avg_pool2d(x, 10).reshape(-1, 4)

            y_hat, logit = model(x, n_layers)
            loss = criterion(y_hat, F.one_hot(y, n_classes).to(torch.float32))
            TE_Loss.append(loss.item())

            _, indices = y_hat.topk(1, dim=1)
            masks = indices.eq(y.view(-1, 1).expand_as(indices))
            TE_Size += y.shape[0]
            TE_Corrects += masks.sum().item()

    TE_Acc = TE_Corrects / TE_Size
    TE_Loss = sum(TE_Loss) / len(TE_Loss)

    print(f"Epoch {epoch+1}: TR_Acc = {TR_Acc:.4f}, TE_Acc = {TE_Acc:.4f}, Loss = {TE_Loss:.4f}")
    print('Training complete')


Epoch 1: TR_Acc = 0.5998, TE_Acc = 0.6397, Loss = 1.1535
Training complete
Epoch 2: TR_Acc = 0.6497, TE_Acc = 0.6514, Loss = 1.1278
Training complete
Epoch 3: TR_Acc = 0.6625, TE_Acc = 0.6668, Loss = 1.1202
Training complete
Epoch 4: TR_Acc = 0.6720, TE_Acc = 0.6786, Loss = 1.1169
Training complete
Epoch 5: TR_Acc = 0.6777, TE_Acc = 0.6902, Loss = 1.1127
Training complete


In [96]:
# Federated Learning Setup
n_clients = 5
n_epochs = 5
# Adjust dataset split so that the last chunk gets the remainder
split_sizes = [len(train_data) // n_clients] * (n_clients - 1)
split_sizes.append(len(train_data) - sum(split_sizes))  # Remaining samples go to the last client

# Perform the split
clients_data = torch.utils.data.random_split(train_data, split_sizes)
clients_loaders = [create_dataloader(client_data) for client_data in clients_data]
test_loader = create_dataloader(test_data)


In [101]:
def Aggregation(models, n_qubits, n_layers,x):
    empty_model = Circuit_TQ(n_qubits,n_layers)
    with torch.no_grad():
        empty_model(x,n_layers)
        for param, empty_param in zip(models[0].parameters(),empty_model.parameters()):
            empty_param.data = torch.zeros_like(param)
        for model in models:
            for param, empty_param in zip(model.parameters(), empty_model.parameters()):
                empty_param.data += param.data
        for empty_param in empty_model.parameters():
            empty_param.data /= len(models)
    return empty_model

In [102]:
# Initialize Clients
clients_models = [Circuit_TQ(n_qubits, n_layers) for _ in range(n_clients)]
clients_optimizers = [torch.optim.Adam(model.parameters(), lr=8e-4) for model in clients_models]
models = [Circuit_TQ(n_qubits, n_layers, n_classes) for _ in range(n_clients)]
optims = [torch.optim.Adam(model.parameters(), lr=8e-4) for model in models]
criterion = nn.CrossEntropyLoss()


In [None]:
time_list = []
test_accuracy = []
test_loss = []

for epoch in range(5):
    TR_Size = 0
    TR_Corrects = 0
    TR_Loss = []
    TE_Size = 0
    TE_Corrects = 0
    TE_Loss = []

    for i, (model, optim) in enumerate(zip(models, optims)):
        model.train()
        for x, y in clients_loaders[i]:
            x = F.avg_pool2d(x, 10).reshape(-1, 4)  # Adjust input shape for Circuit_TQ

            optim.zero_grad()
            y_hat, logit = model(x, n_layers)
            loss = criterion(y_hat, F.one_hot(y, n_classes).to(torch.float32))
            loss.backward()
            optim.step()

            TR_Loss.append(loss.item())
            _, indices = y_hat.topk(1, dim=1)
            masks = indices.eq(y.view(-1, 1).expand_as(indices))
            TR_Size += y.shape[0]
            TR_Corrects += masks.sum().item()

    TR_Acc = TR_Corrects / TR_Size
    TR_Loss = sum(TR_Loss) / len(TR_Loss)

    # Federated Aggregation Step
    Global_model = Aggregation(models,n_qubits,n_layers, x)
    Global_model.eval()

    with torch.no_grad():
        for x, y in test_loader:
            x = F.avg_pool2d(x, 10).reshape(-1, 4)
            y_hat, logit = Global_model(x, n_layers)
            loss = criterion(y_hat, F.one_hot(y, n_classes).to(torch.float32))
            TE_Loss.append(loss.item())

            _, indices = y_hat.topk(1, dim=1)
            masks = indices.eq(y.view(-1, 1).expand_as(indices))
            TE_Size += y.shape[0]
            TE_Corrects += masks.sum().item()


    TE_Acc = TE_Corrects / TE_Size
    TE_Loss = sum(TE_Loss) / len(TE_Loss)

    print(f"Epoch {epoch+1}: TR_Acc = {TR_Acc:.4f}, TE_Acc = {TE_Acc:.4f}, Loss = {TE_Loss:.4f}")
    print('Training complete')


Epoch 1: TR_Acc = 0.3699, TE_Acc = 0.2929, Loss = 1.3953
Training complete
Epoch 2: TR_Acc = 0.5067, TE_Acc = 0.2934, Loss = 1.3925
Training complete
Epoch 3: TR_Acc = 0.5931, TE_Acc = 0.2945, Loss = 1.3885
Training complete
