In [1]:
import os

import torch
import torch.nn as nn
import torch.optim as optim
import torch.quantization
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Convert MNIST images to tensors, then normalize with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])
train_dataset = datasets.MNIST("./data", train=True, download=True, transform=transform)
test_dataset = datasets.MNIST("./data", train=False, download=True, transform=transform)

# Batched loaders over the official splits (only the training loader is shuffled).
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)


In [2]:
# Number of simulated federated users (clients).
N = 20

# `transform`, `train_dataset` and `test_dataset` are already created by the first
# cell (same root, same normalization), so they are reused here instead of being
# constructed a second time.

# Pool the official train and test splits; the pooled data is later sharded per
# user, and every user re-splits their own shard into train/test parts.
combined_dataset = torch.utils.data.ConcatDataset([train_dataset, test_dataset])

import random

def split_dataset(combined_dataset, N, seed=None):
    """Randomly partition a dataset into N disjoint, equally sized shards.

    Args:
        combined_dataset: Any sized, indexable dataset.
        N: Number of shards (users). Must be a positive integer.
        seed: Optional seed for a private shuffle RNG, making the partition
            reproducible. When None (the default) the module-level ``random``
            state is used, exactly as before.

    Returns:
        A list of N ``torch.utils.data.Subset`` objects, each holding
        ``len(combined_dataset) // N`` indices. When the dataset size is not a
        multiple of N, the leftover ``len(combined_dataset) % N`` samples are
        not assigned to any shard.

    Raises:
        ValueError: If N is smaller than 1.
    """
    if N < 1:
        raise ValueError(f"N must be a positive integer, got {N}")

    # Total number of samples in the dataset
    M = len(combined_dataset)
    indices = list(range(M))
    split_size = M // N

    # Shuffle indices so shards are random; a dedicated RNG keeps a seeded call
    # from disturbing (or depending on) the global random state.
    if seed is None:
        random.shuffle(indices)
    else:
        random.Random(seed).shuffle(indices)

    # Consecutive windows of the shuffled indices become the per-user shards.
    return [
        torch.utils.data.Subset(combined_dataset, indices[i * split_size:(i + 1) * split_size])
        for i in range(N)
    ]


def split_train_test(user_dataset, test_ratio=0.2):
    """Randomly divide one user's dataset into a training part and a testing part.

    Args:
        user_dataset: Sized dataset belonging to a single user.
        test_ratio: Fraction of the samples reserved for testing; the test
            count is truncated to an integer and the rest goes to training.

    Returns:
        A ``(train_subset, test_subset)`` tuple produced by ``random_split``.
    """
    n_total = len(user_dataset)
    n_test = int(n_total * test_ratio)

    # random_split draws the permutation from torch's default generator.
    parts = random_split(user_dataset, [n_total - n_test, n_test])
    return parts[0], parts[1]


# Shard the pooled data: one disjoint subset per simulated user.
user_datasets = split_dataset(combined_dataset, N)

batch_size = 64  # Adjust batch size as needed

# Every user gets their own train/test split (default 20% test) and a loader for
# each part; only the training loader reshuffles between passes.
user_train_loaders, user_test_loaders = [], []
for user_dataset in user_datasets:
    train_data, test_data = split_train_test(user_dataset)
    user_train_loaders += [DataLoader(train_data, batch_size=batch_size, shuffle=True)]
    user_test_loaders += [DataLoader(test_data, batch_size=batch_size, shuffle=False)]
    
def aggregate_updates(local_updates):
    """Average a list of model state dicts entry by entry (unweighted).

    Args:
        local_updates: Non-empty list of state dicts that all share the keys
            (and per-key tensor shapes) of the first one.

    Returns:
        A plain dict mapping every key of ``local_updates[0]`` to the
        element-wise mean over all updates. Floating-point and complex entries
        are averaged directly. Integer and boolean entries (for example step
        counters or on/off flags stored as buffers) are averaged in float64,
        rounded, and cast back to their original dtype, because ``torch.mean``
        rejects non-floating inputs.

    Raises:
        ValueError: If ``local_updates`` is empty.
    """
    if not local_updates:
        raise ValueError("local_updates must contain at least one state dict")

    new_state_dict = {}
    for key in local_updates[0].keys():
        stacked = torch.stack([update[key] for update in local_updates])
        if stacked.is_floating_point() or stacked.is_complex():
            new_state_dict[key] = torch.mean(stacked, dim=0)
        else:
            # torch.mean cannot infer an output dtype for int/bool tensors.
            averaged = torch.mean(stacked.to(torch.float64), dim=0)
            new_state_dict[key] = torch.round(averaged).to(stacked.dtype)
    return new_state_dict

In [3]:
class Net(nn.Module):
    """Two-convolution CNN classifier with eager-mode quantization stubs.

    The layer sizes assume a single-channel 28x28 input: two 2x2 max-pools
    reduce it to 7x7 before the 64 * 7 * 7 flatten. The network outputs 10
    logits per sample.

    ``quant`` and ``dequant`` mark where tensors enter and leave the quantized
    region. In an ordinary float model both stubs pass their input through
    unchanged and contribute no parameters or buffers, so state dicts are
    unaffected. After conversion to a quantized model, ``dequant`` turns the
    final output back into a regular float tensor.
    """

    def __init__(self):
        super().__init__()
        self.quant = torch.ao.quantization.QuantStub()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.relu1 = nn.ReLU()
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.relu2 = nn.ReLU()
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(128, 10)
        # Without this stub a converted model would hand a quantized tensor to
        # downstream code (loss functions, argmax, ...).
        self.dequant = torch.ao.quantization.DeQuantStub()

    def forward(self, x):
        """Return class logits of shape (batch, 10) for a batch of images."""
        x = self.quant(x)
        x = self.pool(self.relu1(self.conv1(x)))
        x = self.pool(self.relu2(self.conv2(x)))
        x = x.reshape(-1, 64 * 7 * 7)
        x = self.relu3(self.fc1(x))
        x = self.fc2(x)
        return self.dequant(x)


In [4]:
# One shared (global) model plus one model per simulated user, all on `device`.
global_model = Net().to(device)

client_models = [Net().to(device) for i in range(N)]

# Attempt to set every model up for quantization-aware training (QAT).
# NOTE(review): the value returned by `model.fuse_model()` is discarded, and the
# value returned by `prepare_qat` is bound only to the loop variable `model`;
# the entries of `client_models` and `global_model` are never reassigned.
# Both calls are documented to default to inplace=False (they work on a copy) —
# if that holds for the installed PyTorch version, the listed models end up
# neither fused nor QAT-prepared and only gain the `fuse_model` / `qconfig`
# attributes. Confirm, and note that preparing in place adds non-float buffers
# to the state dict that the aggregation step would then have to handle.
for model in client_models + [global_model]:
    # Fuse layers (for better optimization during quantization)
    # NOTE(review): the lambda looks up `model` when it is called, not when it is
    # created; it is invoked immediately here, but a later call through any
    # model's `fuse_model` attribute would see whatever `model` is bound to then.
    model.fuse_model = lambda: torch.quantization.fuse_modules(model, [["conv1", "relu1"], ["conv2", "relu2"]])
    model.fuse_model()
    
    # Specify quantization configuration
    model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
    
    # Prepare for QAT
    model = torch.quantization.prepare_qat(model)






In [5]:
# Sanity check: load the global model's state dict into the first client model.
# The displayed return value reports whether every key matched.
client_models[0].load_state_dict(global_model.state_dict())

<All keys matched successfully>

In [6]:
# Mean per-user test accuracy (in percent), one entry appended per completed round.
avg_test_accs = []


def _model_device(model):
    """Return the device of the model's first parameter (CPU if it has none)."""
    first_param = next(model.parameters(), None)
    return first_param.device if first_param is not None else torch.device("cpu")


def federated_train(global_model, client_models, user_train_loaders, user_test_loaders, epochs):
    """Run federated-averaging rounds across all client models.

    In every round each client starts from the current global weights, makes
    one SGD pass (lr=0.01) over its own training loader, and is evaluated on
    its own test loader. The client state dicts are then combined with
    ``aggregate_updates`` and loaded into ``global_model``.

    Args:
        global_model: Model holding the shared weights; updated in place.
        client_models: One model per user, architecture-compatible with
            ``global_model``. The number of users is taken from this list.
        user_train_loaders: Per-user iterables of ``(data, target)`` batches.
        user_test_loaders: Per-user iterables of ``(data, target)`` batches.
        epochs: Number of federated rounds.

    Returns:
        None. Per-user accuracies are printed, and the mean accuracy of each
        round is appended to the module-level ``avg_test_accs`` list.

    Raises:
        ValueError: If there are no client models, or fewer loaders than
            client models.
    """
    num_clients = len(client_models)
    if num_clients == 0:
        raise ValueError("client_models must contain at least one model")
    if len(user_train_loaders) < num_clients or len(user_test_loaders) < num_clients:
        raise ValueError("need one train loader and one test loader per client model")

    for epoch in range(epochs):
        local_updates = []
        test_accs = []
        for i, client in enumerate(client_models):
            # Batches come off the loaders on the CPU; they must follow the model.
            client_device = _model_device(client)

            # Start this user's round from the current global weights.
            client.load_state_dict(global_model.state_dict())
            optimizer = torch.optim.SGD(client.parameters(), lr=0.01)

            # One local pass over the user's training data.
            client.train()
            for data, target in user_train_loaders[i]:
                data, target = data.to(client_device), target.to(client_device)
                optimizer.zero_grad()
                output = client(data)
                loss = torch.nn.functional.cross_entropy(output, target)
                loss.backward()
                optimizer.step()

            # Keep the locally trained weights for aggregation.
            local_updates.append(client.state_dict())

            # Evaluate the local model on the user's own held-out samples.
            client.eval()
            correct = 0
            total = 0
            with torch.no_grad():
                for data, target in user_test_loaders[i]:
                    data, target = data.to(client_device), target.to(client_device)
                    output = client(data)
                    predicted = output.argmax(dim=1)
                    total += target.size(0)
                    correct += (predicted == target).sum().item()
            accuracy = 100 * correct / total
            print(f"User {i+1} - Epoch {epoch+1}/{epochs} - Test Accuracy: {accuracy:.2f}%")
            test_accs.append(accuracy)

        # Aggregate updates from all users into the shared model.
        aggregated_model_state = aggregate_updates(local_updates)
        global_model.load_state_dict(aggregated_model_state)
        print(f"Epoch {epoch+1}/{epochs} completed.")
        avg_test_accs.append(sum(test_accs) / len(test_accs))

# Train for ten federated rounds over all simulated users.
federated_train(
    global_model,
    client_models,
    user_train_loaders,
    user_test_loaders,
    epochs=10,
)

User 1 - Epoch 1/10 - Test Accuracy: 25.71%
User 2 - Epoch 1/10 - Test Accuracy: 28.71%
User 3 - Epoch 1/10 - Test Accuracy: 12.71%
User 4 - Epoch 1/10 - Test Accuracy: 9.86%
User 5 - Epoch 1/10 - Test Accuracy: 23.29%
User 6 - Epoch 1/10 - Test Accuracy: 15.00%
User 7 - Epoch 1/10 - Test Accuracy: 15.71%
User 8 - Epoch 1/10 - Test Accuracy: 20.86%
User 9 - Epoch 1/10 - Test Accuracy: 19.29%
User 10 - Epoch 1/10 - Test Accuracy: 17.57%
User 11 - Epoch 1/10 - Test Accuracy: 26.14%
User 12 - Epoch 1/10 - Test Accuracy: 14.29%
User 13 - Epoch 1/10 - Test Accuracy: 10.71%
User 14 - Epoch 1/10 - Test Accuracy: 10.86%
User 15 - Epoch 1/10 - Test Accuracy: 30.00%
User 16 - Epoch 1/10 - Test Accuracy: 17.86%
User 17 - Epoch 1/10 - Test Accuracy: 21.57%
User 18 - Epoch 1/10 - Test Accuracy: 34.57%
User 19 - Epoch 1/10 - Test Accuracy: 26.71%
User 20 - Epoch 1/10 - Test Accuracy: 21.57%
Epoch 1/10 completed.
User 1 - Epoch 2/10 - Test Accuracy: 56.00%
User 2 - Epoch 2/10 - Test Accuracy: 39.00%


In [None]:
# Convert `model_prepared` with torch.quantization.convert and keep the result
# as `model_quantized`.
# NOTE(review): `model_prepared` is not assigned in any cell visible in this
# notebook (the QAT loop binds the result of prepare_qat only to its loop
# variable `model`), so on a fresh kernel this cell presumably fails with a
# NameError — confirm which prepared model was meant here.
model_quantized = torch.quantization.convert(model_prepared)


In [None]:
# Top-1 accuracy of the converted model over `test_loader`.
# NOTE(review): `test_loader` wraps the MNIST test split, which is also part of
# `combined_dataset` from which the users' training shards are drawn, so this
# figure is not measured on strictly held-out data.
model_quantized.eval()
total = 0
correct = 0

with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model_quantized(images)
        predicted = torch.max(outputs.data, 1).indices
        correct += predicted.eq(labels).sum().item()
        total += labels.size(0)

print(f"Accuracy of the quantized model on the test set: {100 * correct / total:.2f}%")


In [None]:
# Total number of elements across the parameters exposed by the converted model.
sum(param.numel() for param in model_quantized.parameters())

In [None]:
# Element count of every parameter tensor of the prepared model, in order.
[param.numel() for param in model_prepared.parameters()]

In [None]:
def print_model_size(mdl):
    """Print the serialized size of a model's state dict in megabytes.

    The state dict is saved to a scratch file named ``tmp.pt`` in the current
    working directory, its size on disk is printed as ``"<size> MB"`` (two
    decimals, 1 MB = 1e6 bytes), and the scratch file is removed again.

    Args:
        mdl: Object exposing ``state_dict()`` (e.g. an ``nn.Module``).
    """
    tmp_path = "tmp.pt"
    try:
        torch.save(mdl.state_dict(), tmp_path)
        print("%.2f MB" % (os.path.getsize(tmp_path) / 1e6))
    finally:
        # Clean up even when saving or measuring fails part-way.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)