import

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from torchvision import datasets, transforms

load MNIST dataset

In [2]:
# Create a transform to convert the images to tensors and normalize them
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

# Create a training dataset using the MNIST dataset, with the transform applied
train_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform)

# Create a testing dataset using the MNIST dataset, with the transform applied
test_data = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data\MNIST\raw\train-images-idx3-ubyte.gz


100.0%


Extracting ./data\MNIST\raw\train-images-idx3-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data\MNIST\raw\train-labels-idx1-ubyte.gz


100.0%


Extracting ./data\MNIST\raw\train-labels-idx1-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data\MNIST\raw\t10k-images-idx3-ubyte.gz


100.0%


Extracting ./data\MNIST\raw\t10k-images-idx3-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz


100.0%

Extracting ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw






define model

In [3]:
class SimpleNN(nn.Module):
    # Initialize the SimpleNN class
    def __init__(self):
        # Call the parent class's constructor
        super(SimpleNN, self).__init__()
        # Create a flatten layer to flatten the input tensor
        self.flatten = nn.Flatten()
        # Create a fully connected layer with 128 neurons and an input size of 28 * 28
        self.fc1 = nn.Linear(28 * 28, 128)
        # Create a ReLU activation function
        self.relu = nn.ReLU()
        # Create a fully connected layer with 10 neurons and an input size of 128
        self.fc2 = nn.Linear(128, 10)
        # Create a softmax activation function
        self.softmax = nn.Softmax(dim=1)

    # Define the forward pass of the network
    def forward(self, x):
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.softmax(x)
        return x

Simulate Federated Learning Client

In [4]:
# Define the number of clients
num_clients = 5
# Calculate the size of data for each client
client_data_size = len(train_data) // num_clients
clients = []

# Loop through the number of clients
for i in range(num_clients):
    client_indices = list(range(i * client_data_size, (i + 1) * client_data_size))
    x_client = torch.stack([train_data[idx][0] for idx in client_indices])
    y_client = torch.tensor([train_data[idx][1] for idx in client_indices])
    clients.append(DataLoader(TensorDataset(x_client, y_client), batch_size=32, shuffle=True))

The process of federated learning

In [5]:
# Define a global model
global_model = SimpleNN()

# Define a function to average the weights of the clients
def federated_avg(weights):
    # Calculate the average of the weights for each layer
    avg_weights = [torch.mean(torch.stack([client_weights[layer] for client_weights in weights]), dim=0) 
                   for layer in range(len(weights[0]))]
    return avg_weights

# Define the number of rounds for federated learning
num_rounds = 5
# Define the loss function
criterion = nn.CrossEntropyLoss()

# Loop through the number of rounds
for round_num in range(num_rounds):
    print(f'Federated Learning Round {round_num + 1}')
    client_weights = []
    
    # Loop through the clients
    for client_data in clients:
        model = SimpleNN()
        model.load_state_dict(global_model.state_dict())
        optimizer = optim.Adam(model.parameters(), lr=0.001)
        model.train()
        for x_client, y_client in client_data:
            optimizer.zero_grad()
            outputs = model(x_client)
            loss = criterion(outputs, y_client)
            loss.backward()
            optimizer.step()
        
        client_weights.append([param.data.clone() for param in model.parameters()])
    
    # Aggregate weight
    new_weights = federated_avg(client_weights)
    # Update the global model weights
    for i, param in enumerate(global_model.parameters()):
        param.data = new_weights[i]

Federated Learning Round 1
Federated Learning Round 2
Federated Learning Round 3
Federated Learning Round 4
Federated Learning Round 5


Evaluate the global model on the test set

In [6]:
test_loader = DataLoader(test_data, batch_size=32, shuffle=False)
global_model.eval()
correct = 0
total = 0

with torch.no_grad():
    for x_test, y_test in test_loader:
        outputs = global_model(x_test)
        _, predicted = torch.max(outputs.data, 1)
        total += y_test.size(0)
        correct += (predicted == y_test).sum().item()

print(f'Accuracy on the test set: {100 * correct / total}%')

Accuracy on the test set: 93.23%


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from torchvision import datasets, transforms

# 定义数据预处理
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
# 加载训练集和测试集
train_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_data = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# 定义简单的神经网络模型
class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(128, 10)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.softmax(x)
        return x

# 定义客户端数量和每个客户端的数据量
num_clients = 5
client_data_size = len(train_data) // num_clients
clients = []

# 将训练集划分为每个客户端的数据
# 遍历客户端数量
for i in range(num_clients):
    # 计算每个客户端的数据索引范围
    client_indices = list(range(i * client_data_size, (i + 1) * client_data_size))
    # 根据索引范围获取每个客户端的数据
    x_client = torch.stack([train_data[idx][0] for idx in client_indices])
    y_client = torch.tensor([train_data[idx][1] for idx in client_indices])
    # 将每个客户端的数据封装成DataLoader，并添加到clients列表中
    clients.append(DataLoader(TensorDataset(x_client, y_client), batch_size=32, shuffle=True))

# 初始化全局模型
global_model = SimpleNN()

# 定义联邦平均函数
def federated_avg(weights):
    # 对每个层的权重进行平均
    avg_weights = [torch.mean(torch.stack([client_weights[layer] for client_weights in weights]), dim=0) 
                   for layer in range(len(weights[0]))]
    # 返回平均后的权重
    return avg_weights

# 定义训练轮数和损失函数
num_rounds = 5
criterion = nn.CrossEntropyLoss()

# 开始联邦学习
for round_num in range(num_rounds):
    # 开始联邦学习第round_num轮
    print(f'Federated Learning Round {round_num + 1}')
    client_weights = []
    
    # 对每个客户端进行训练
    for client_data in clients:
        model = SimpleNN()
        # 加载全局模型参数
        model.load_state_dict(global_model.state_dict())
        optimizer = optim.Adam(model.parameters(), lr=0.001)
        
        model.train()
        # 对每个客户端数据进行训练
        for x_client, y_client in client_data:
            optimizer.zero_grad()
            outputs = model(x_client)
            loss = criterion(outputs, y_client)
            loss.backward()
            optimizer.step()
        
        # 将每个客户端的模型参数添加到client_weights列表中
        client_weights.append([param.data.clone() for param in model.parameters()])
    
    # 聚合权重
    new_weights = federated_avg(client_weights)
    # 更新全局模型参数
    for i, param in enumerate(global_model.parameters()):
        param.data = new_weights[i]

# 在测试集上评估全局模型
test_loader = DataLoader(test_data, batch_size=32, shuffle=False)
global_model.eval()
correct = 0
total = 0

# 不计算梯度，用于测试模型
with torch.no_grad():
    # 遍历测试集
    for x_test, y_test in test_loader:
        # 使用全局模型进行预测
        outputs = global_model(x_test)
        # 获取预测结果
        _, predicted = torch.max(outputs.data, 1)
        # 统计测试集大小
        total += y_test.size(0)
        # 统计预测正确的数量
        correct += (predicted == y_test).sum().item()
print(f'Accuracy on the test set: {100 * correct / total}%')