import

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from torchvision import datasets, transforms

load MNIST dataset

In [3]:
# Create a transform to convert the images to tensors and normalize them
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

# Create a training dataset using the MNIST dataset, with the transform applied
train_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform)

# Create a testing dataset using the MNIST dataset, with the transform applied
test_data = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

define model

In [4]:
class SimpleNN(nn.Module):
    # Initialize the SimpleNN class
    def __init__(self):
        # Call the parent class's constructor
        super(SimpleNN, self).__init__()
        # Create a flatten layer to flatten the input tensor
        self.flatten = nn.Flatten()
        # Create a fully connected layer with 128 neurons and an input size of 28 * 28
        self.fc1 = nn.Linear(28 * 28, 128)
        # Create a ReLU activation function
        self.relu = nn.ReLU()
        # Create a fully connected layer with 10 neurons and an input size of 128
        self.fc2 = nn.Linear(128, 10)
        # Create a softmax activation function
        self.softmax = nn.Softmax(dim=1)

    # Define the forward pass of the network
    def forward(self, x):
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.softmax(x)
        return x

Simulate Federated Learning Client

In [5]:
# Define the number of clients
num_clients = 5
# Calculate the size of data for each client
client_data_size = len(train_data) // num_clients
clients = []

# Loop through the number of clients
for i in range(num_clients):
    client_indices = list(range(i * client_data_size, (i + 1) * client_data_size))
    x_client = torch.stack([train_data[idx][0] for idx in client_indices])
    y_client = torch.tensor([train_data[idx][1] for idx in client_indices])
    clients.append(DataLoader(TensorDataset(x_client, y_client), batch_size=32, shuffle=True))

The process of federated learning

In [6]:
# Define a global model
global_model = SimpleNN()

# Define a function to average the weights of the clients
def federated_avg(weights):
    # Calculate the average of the weights for each layer
    avg_weights = [torch.mean(torch.stack([client_weights[layer] for client_weights in weights]), dim=0) 
                   for layer in range(len(weights[0]))]
    return avg_weights

# Define the number of rounds for federated learning
num_rounds = 5
num_epochs = 5
criterion = nn.CrossEntropyLoss()

# Start federated learning
for round_num in range(num_rounds):
    # Start federated learning round number round_num
    print(f'Federated Learning Round {round_num + 1}')
    client_weights = []
    
    # Loop through the clients
    for client_data in clients:
        model = SimpleNN()
        # Load the global model parameters
        model.load_state_dict(global_model.state_dict())
        optimizer = optim.Adam(model.parameters(), lr=0.001)
        
        model.train()
        # Train the model with the client data
        for epoch in range(num_epochs):  # Number of training rounds for each client
            for x_client, y_client in client_data:
                optimizer.zero_grad()
                outputs = model(x_client)
                loss = criterion(outputs, y_client)
                loss.backward()
                optimizer.step()
        
        # Add the client model parameters to the client_weights list
        client_weights.append([param.data.clone() for param in model.parameters()])
    
    # Aggregate weight
    new_weights = federated_avg(client_weights)
    # Update the global model weights
    for i, param in enumerate(global_model.parameters()):
        param.data = new_weights[i]

Federated Learning Round 1
Federated Learning Round 2
Federated Learning Round 3
Federated Learning Round 4
Federated Learning Round 5


Evaluate the global model on the test set

In [7]:
test_loader = DataLoader(test_data, batch_size=32, shuffle=False)
global_model.eval()
correct = 0
total = 0

with torch.no_grad():
    for x_test, y_test in test_loader:
        outputs = global_model(x_test)
        _, predicted = torch.max(outputs.data, 1)
        total += y_test.size(0)
        correct += (predicted == y_test).sum().item()

print(f'Accuracy on the test set: {100 * correct / total}%')

Accuracy on the test set: 96.58%


In [8]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset, random_split
from torchvision import datasets, transforms

# Define the data preprocessing
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
# Load the training set
train_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform)

# Divide the training set into training set, validation set, and test set
train_size = int(0.6 * len(train_data))
val_size = int(0.2 * len(train_data))
test_size = len(train_data) - train_size - val_size
train_data, val_data, test_data = random_split(train_data, [train_size, val_size, test_size])

# Define the simple neural network model
class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(128, 10)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.softmax(x)
        return x

# Define the number of rounds for federated learning and the loss function
num_rounds = 5
num_epochs = 5  # Number of training rounds for each client
criterion = nn.CrossEntropyLoss()

# Start federated learning
for round_num in range(num_rounds):
    # Start federated learning round number round_num
    print(f'Federated Learning Round {round_num + 1}')
    client_weights = []
    
    # Loop through the clients
    for client_data in clients:
        model = SimpleNN()
        # Load the global model parameters
        model.load_state_dict(global_model.state_dict())
        optimizer = optim.Adam(model.parameters(), lr=0.001)
        
        model.train()
        # Train the model with the client data
        for epoch in range(num_epochs):  # Number of training rounds for each client
            for x_client, y_client in client_data:
                optimizer.zero_grad()
                outputs = model(x_client)
                loss = criterion(outputs, y_client)
                loss.backward()
                optimizer.step()
        
        # Add the client model parameters to the client_weights list
        client_weights.append([param.data.clone() for param in model.parameters()])

    # Aggregate weight
    new_weights = federated_avg(client_weights)
    # Update the global model weights
    for i, param in enumerate(global_model.parameters()):
        param.data = new_weights[i]

    # Evaluate the global model on the validation set
    val_loader = DataLoader(val_data, batch_size=32, shuffle=False)
    global_model.eval()
    correct = 0
    total = 0

    # Do not calculate gradients, used for testing the model
    with torch.no_grad():
        # Loop through the validation set
        for x_val, y_val in val_loader:
            # Use the global model for prediction
            outputs = global_model(x_val)
            # Get the prediction result
            _, predicted = torch.max(outputs.data, 1)
            # Count the size of the validation set
            total += y_val.size(0)
            # Count the number of correct predictions
            correct += (predicted == y_val).sum().item()
    # Print the accuracy
    print(f'Accuracy on the validation set: {100 * correct / total}%')

# Evaluate the global model on the test set
test_loader = DataLoader(test_data, batch_size=32, shuffle=False)
global_model.eval()
correct = 0
total = 0

# Do not calculate gradients, used for testing the model
with torch.no_grad():
    # Loop through the test set
    for x_test, y_test in test_loader:
        # Use the global model for prediction
        outputs = global_model(x_test)
        # Get the prediction result
        _, predicted = torch.max(outputs.data, 1)
        # Count the size of the test set
        total += y_test.size(0)
        # Count the number of correct predictions
        correct += (predicted == y_test).sum().item()
# Print the accuracy
print(f'Accuracy on the test set: {100 * correct / total}%')


Federated Learning Round 1
Accuracy on the validation set: 97.43333333333334%
Federated Learning Round 2
Accuracy on the validation set: 97.59166666666667%
Federated Learning Round 3
Accuracy on the validation set: 97.80833333333334%
Federated Learning Round 4
Accuracy on the validation set: 97.95833333333333%
Federated Learning Round 5
Accuracy on the validation set: 98.20833333333333%
Accuracy on the test set: 98.09166666666667%


In [13]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms

# Check if CUDA is available and use GPU if possible
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Define the data preprocessing
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Load the MNIST dataset
train_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform)

# Divide the dataset into training set, validation set, and test set
train_size = int(0.6 * len(train_data))
val_size = int(0.2 * len(train_data))
test_size = len(train_data) - train_size - val_size
train_data, val_data, test_data = random_split(train_data, [train_size, val_size, test_size])

# Define the simple neural network model
class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(128, 10)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.softmax(x)
        return x

# Define the federated averaging function
def federated_avg(weights):
    avg_weights = [torch.mean(torch.stack([client_weights[layer] for client_weights in weights]), dim=0)
                   for layer in range(len(weights[0]))]
    return avg_weights

# Initialize global model
global_model = SimpleNN().to(device)

# Define the number of rounds and epochs
num_rounds = 5
num_epochs = 5  # Number of training epochs for each client
criterion = nn.CrossEntropyLoss()

# Create data loaders
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
val_loader = DataLoader(val_data, batch_size=32, shuffle=False)
test_loader = DataLoader(test_data, batch_size=32, shuffle=False)

# Start federated learning
for round_num in range(num_rounds):
    print(f'Federated Learning Round {round_num + 1}')
    client_weights = []
    
    # Simulate clients' training
    for client_data in [train_loader]:  # Placeholder for multiple clients
        model = SimpleNN().to(device)
        model.load_state_dict(global_model.state_dict())
        optimizer = optim.Adam(model.parameters(), lr=0.001)
        
        model.train()
        # Train the model with client data
        for epoch in range(num_epochs):
            for x_client, y_client in client_data:
                x_client, y_client = x_client.to(device), y_client.to(device)
                optimizer.zero_grad()
                outputs = model(x_client)
                loss = criterion(outputs, y_client)
                loss.backward()
                optimizer.step()
        
        # Add the client model parameters to the client_weights list
        client_weights.append([param.data.clone() for param in model.parameters()])

    # Aggregate weights
    new_weights = federated_avg(client_weights)
    # Update the global model weights
    with torch.no_grad():
        for i, param in enumerate(global_model.parameters()):
            param.data.copy_(new_weights[i])

    # Evaluate the global model on the validation set
    global_model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for x_val, y_val in val_loader:
            x_val, y_val = x_val.to(device), y_val.to(device)
            outputs = global_model(x_val)
            _, predicted = torch.max(outputs.data, 1)
            total += y_val.size(0)
            correct += (predicted == y_val).sum().item()
    print(f'Accuracy on the validation set: {100 * correct / total}%')

# Evaluate the global model on the test set
global_model.eval()
correct = 0
total = 0

with torch.no_grad():
    for x_test, y_test in test_loader:
        x_test, y_test = x_test.to(device), y_test.to(device)
        outputs = global_model(x_test)
        _, predicted = torch.max(outputs.data, 1)
        total += y_test.size(0)
        correct += (predicted == y_test).sum().item()

print(f'Accuracy on the test set: {100 * correct / total}%')

Federated Learning Round 1
Accuracy on the validation set: 93.575%
Federated Learning Round 2
Accuracy on the validation set: 95.28333333333333%
Federated Learning Round 3
Accuracy on the validation set: 95.78333333333333%
Federated Learning Round 4
Accuracy on the validation set: 96.16666666666667%
Federated Learning Round 5
Accuracy on the validation set: 96.24166666666666%
Accuracy on the test set: 96.15%
