In [None]:
import os
import torch
from torchvision import datasets, transforms
from torch.utils.data import Subset

# Mount Google Drive; the dataset files written below go to a Drive folder
from google.colab import drive
drive.mount('/content/drive')

# Drive folder that receives one dataset file per client (created if missing)
persistent_dir = "/content/drive/MyDrive/federated_learning/client_datasets"
os.makedirs(persistent_dir, exist_ok=True)

# Function to create extreme class imbalance for clients
def create_extreme_imbalance(dataset, num_clients=10, class_allocation=None):
    """
    Create extreme class imbalance by assigning specific classes to each client.

    Args:
        dataset: Full dataset (e.g., MNIST). Must expose ``classes``. When it
                 also exposes ``targets`` and has no ``target_transform``,
                 labels are read from ``targets`` directly instead of loading
                 (and transforming) every sample just to look at its label.
        num_clients: Number of clients. Only used when ``class_allocation``
                     is None.
        class_allocation: List of class lists assigned to each client.
                          Example: [[0, 1], [2, 3], ..., [8, 9]]
                          If None, classes are split as evenly as possible
                          among clients; when the split is not exact the first
                          clients receive one extra class, so no class is
                          dropped.

    Returns:
        client_datasets: List of ``Subset`` datasets, one per entry of
                         ``class_allocation``.

    Raises:
        ValueError: If ``class_allocation`` is None and ``num_clients`` is not
                    positive.
    """
    num_classes = len(dataset.classes)

    if class_allocation is None:
        if num_clients <= 0:
            raise ValueError("num_clients must be a positive integer")
        # divmod keeps the remainder so that every class lands on some client
        base, extra = divmod(num_classes, num_clients)
        class_allocation = []
        start = 0
        for client in range(num_clients):
            size = base + (1 if client < extra else 0)
            class_allocation.append(list(range(start, start + size)))
            start += size

    # Collect one label per sample, preferring the cheap ``targets`` attribute
    targets = getattr(dataset, "targets", None)
    if targets is not None and getattr(dataset, "target_transform", None) is None:
        labels = targets.tolist() if hasattr(targets, "tolist") else list(targets)
    else:
        labels = [label for _, label in dataset]

    # Group indices by class; int() also accepts 0-d tensor labels
    class_indices = {cls: [] for cls in range(num_classes)}
    for idx, label in enumerate(labels):
        class_indices[int(label)].append(idx)

    # Build one Subset per client from the indices of its assigned classes
    client_datasets = []
    for client_classes in class_allocation:
        indices = [idx for cls in client_classes for idx in class_indices[cls]]
        client_datasets.append(Subset(dataset, indices))
    return client_datasets

# Load MNIST dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # Normalize to [-1, 1] for compatibility with GANs
])
full_dataset = datasets.MNIST(root='data', train=True, download=True, transform=transform)

# Define class allocation for extreme class imbalance.
# Five disjoint class pairs repeated twice give 10 clients; clients i and i + 5
# therefore hold the same two classes.
class_allocation = [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]] * 2  # Total of 10 clients

# Create imbalanced datasets for clients
client_datasets = create_extreme_imbalance(full_dataset, num_clients=10, class_allocation=class_allocation)

# Verify the class distribution for each client
for i, client_dataset in enumerate(client_datasets):
    client_index = torch.as_tensor(client_dataset.indices, dtype=torch.long)
    # Convert to Python ints first: 0-d tensors hash by identity, so a set of
    # tensors would keep every sample instead of the distinct classes.
    labels = full_dataset.targets[client_index].tolist()
    print(f"Client {i+1}: Classes {set(labels)} | Total Samples: {len(labels)}")

# Save client-specific datasets to files in Google Drive
for i, client_dataset in enumerate(client_datasets):
    # Gather raw images and labels with one tensor-indexing operation each
    # instead of building per-sample Python lists.
    client_index = torch.as_tensor(client_dataset.indices, dtype=torch.long)
    data = full_dataset.data[client_index]
    labels = full_dataset.targets[client_index]

    # Save data and labels as a dictionary
    save_path = f"{persistent_dir}/client_{i+1}_dataset.pt"
    torch.save({'data': data, 'labels': labels}, save_path)
    print(f"Saved dataset for Client {i+1} at {save_path}")


Mounted at /content/drive
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9.91M/9.91M [00:00<00:00, 34.9MB/s]


Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28.9k/28.9k [00:00<00:00, 1.74MB/s]

Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz





Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1.65M/1.65M [00:00<00:00, 13.8MB/s]


Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4.54k/4.54k [00:00<00:00, 4.01MB/s]

Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw






Client 1: Classes {tensor(1), tensor(1), tensor(1), tensor(1), tensor(0), tensor(1), tensor(1), tensor(1), tensor(1), tensor(0), tensor(1), tensor(1), tensor(1), tensor(1), tensor(1), tensor(1), tensor(1), tensor(1), tensor(1), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(1), tensor(1), tensor(1), tensor(1), tensor(1), tensor(1), tensor(1), tensor(1), tensor(1), tensor(1), tensor(1), tensor(1), tensor(1), tensor(1), tensor(1), tensor(1), tensor(1), tensor(1), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(1), tensor(1), tensor(1), tensor(1), tensor(1), tensor(1), tensor(0), tensor(1), tensor(1), tensor(0), tensor(1), tensor(0), tensor(1), tensor(1), tensor(0), tensor(1), tensor(1), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(1), tensor(1), tensor(1), tensor(1), te

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import os

# Mount Google Drive; trained generators are written to a Drive folder below
from google.colab import drive
drive.mount('/content/drive')

# Project root on Drive and the sub-folder for per-client generators (created if missing)
persistent_dir = "/content/drive/MyDrive/federated_learning"
generators_dir = f"{persistent_dir}/client_generators"
os.makedirs(generators_dir, exist_ok=True)

# Step 1: Define Generator and Discriminator models
class Generator(nn.Module):
    """Fully connected generator: noise vector in, flattened image out.

    Layout is noise_dim -> 256 -> 512 -> output_dim with ReLU between the
    linear layers and a final Tanh, so every output value lies in [-1, 1].
    """

    def __init__(self, noise_dim=100, output_dim=28*28):
        super().__init__()
        layers = [
            nn.Linear(noise_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, output_dim),
            nn.Tanh(),
        ]
        # The attribute must stay "model" so state_dict keys remain "model.<i>.*"
        self.model = nn.Sequential(*layers)

    def forward(self, z):
        """Map a batch of noise vectors ``z`` to generated samples."""
        generated = self.model(z)
        return generated

class Discriminator(nn.Module):
    """Fully connected discriminator: flattened image in, one score out.

    Layout is input_dim -> 512 -> 256 -> 1 with LeakyReLU(0.2) between the
    linear layers and a final Sigmoid, so each score lies in [0, 1].
    """

    def __init__(self, input_dim=28*28):
        super().__init__()
        layers = [
            nn.Linear(input_dim, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        ]
        # The attribute must stay "model" so state_dict keys remain "model.<i>.*"
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return one score per row of ``x``."""
        score = self.model(x)
        return score

# Step 2: Function to train GAN for one client
def train_gan(client_data, client_labels, noise_dim=100, epochs=50, batch_size=32):
    """
    Train a Generator/Discriminator pair on one client's data.

    Relies on the module-level ``device`` for model and tensor placement.

    Args:
        client_data: Tensor of client images; it is flattened per sample and
                     rescaled with ``x / 127.5 - 1.0`` (assumes pixel values
                     in [0, 255] — confirm against the saved dataset).
        client_labels: Tensor of labels, one per sample. Only used to build
                       the TensorDataset; the GAN itself is unconditional.
        noise_dim: Size of the generator's input noise vector.
        epochs: Number of passes over the client's data.
        batch_size: Mini-batch size of the data loader.

    Returns:
        The trained Generator.

    Raises:
        ValueError: If ``client_data`` contains no samples.
    """
    if client_data.size(0) == 0:
        raise ValueError("client_data is empty; cannot train a GAN without samples")

    # Prepare data loader
    flat_data = client_data.view(client_data.size(0), -1) / 127.5 - 1.0  # Normalize to [-1, 1]
    dataset = torch.utils.data.TensorDataset(flat_data, client_labels)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # Initialize models
    generator = Generator(noise_dim=noise_dim).to(device)
    discriminator = Discriminator(input_dim=28*28).to(device)

    # Loss and optimizers
    criterion = nn.BCELoss()
    optimizer_g = optim.Adam(generator.parameters(), lr=0.0002)
    optimizer_d = optim.Adam(discriminator.parameters(), lr=0.0002)

    # Train GAN
    for epoch in range(epochs):
        for real_images, _ in dataloader:
            real_images = real_images.to(device)
            # The last batch may be smaller; use a separate name so the
            # ``batch_size`` argument is not overwritten.
            current_batch = real_images.size(0)

            # Create labels directly on the target device
            real_labels = torch.ones(current_batch, 1, device=device)
            fake_labels = torch.zeros(current_batch, 1, device=device)

            # Train Discriminator (fake images are detached so only D is updated)
            noise = torch.randn(current_batch, noise_dim, device=device)
            fake_images = generator(noise)
            d_loss_real = criterion(discriminator(real_images), real_labels)
            d_loss_fake = criterion(discriminator(fake_images.detach()), fake_labels)
            d_loss = d_loss_real + d_loss_fake
            optimizer_d.zero_grad()
            d_loss.backward()
            optimizer_d.step()

            # Train Generator: it is rewarded when D labels its samples as real
            noise = torch.randn(current_batch, noise_dim, device=device)
            fake_images = generator(noise)
            g_loss = criterion(discriminator(fake_images), real_labels)
            optimizer_g.zero_grad()
            g_loss.backward()
            optimizer_g.step()

        # Losses shown are those of the last batch of the epoch
        print(f"Epoch [{epoch+1}/{epochs}], Loss D: {d_loss.item():.4f}, Loss G: {g_loss.item():.4f}")

    return generator

# Step 3: Train GAN for each client and save generators
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_clients = 10  # Total number of clients
noise_dim = 100

# Iterate over each client's dataset
for client_id in range(1, num_clients + 1):
    # Load client-specific dataset. The file holds a plain dict of tensors, so
    # weights_only=True restricts unpickling to tensor data and avoids the
    # FutureWarning emitted by torch.load.
    dataset_path = f"{persistent_dir}/client_datasets/client_{client_id}_dataset.pt"
    loaded_data = torch.load(dataset_path, weights_only=True)
    client_data, client_labels = loaded_data['data'], loaded_data['labels']

    print(f"Training GAN for Client {client_id} with {len(client_data)} samples...")

    # Train GAN for the client
    generator = train_gan(client_data, client_labels, noise_dim=noise_dim, epochs=50, batch_size=32)

    # Save only the generator weights; the discriminator is discarded
    save_path = f"{generators_dir}/client_{client_id}_generator.pth"
    torch.save(generator.state_dict(), save_path)
    print(f"Saved generator for Client {client_id} at {save_path}")


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Training GAN for Client 1 with 12665 samples...


  loaded_data = torch.load(dataset_path)


Epoch [1/50], Loss D: 0.0329, Loss G: 6.6600
Epoch [2/50], Loss D: 0.0074, Loss G: 8.7498
Epoch [3/50], Loss D: 0.0991, Loss G: 3.5064
Epoch [4/50], Loss D: 0.0387, Loss G: 5.1723
Epoch [5/50], Loss D: 0.0243, Loss G: 5.5675
Epoch [6/50], Loss D: 0.0368, Loss G: 5.1898
Epoch [7/50], Loss D: 0.0140, Loss G: 6.8468
Epoch [8/50], Loss D: 0.0179, Loss G: 8.0490
Epoch [9/50], Loss D: 0.1490, Loss G: 5.1546
Epoch [10/50], Loss D: 0.1821, Loss G: 3.5806
Epoch [11/50], Loss D: 0.0166, Loss G: 7.7702
Epoch [12/50], Loss D: 0.0309, Loss G: 6.7905
Epoch [13/50], Loss D: 0.6781, Loss G: 7.0820
Epoch [14/50], Loss D: 0.1219, Loss G: 5.3418
Epoch [15/50], Loss D: 0.4751, Loss G: 6.0305
Epoch [16/50], Loss D: 0.0600, Loss G: 7.6682
Epoch [17/50], Loss D: 0.1647, Loss G: 5.2827
Epoch [18/50], Loss D: 0.0846, Loss G: 6.9573
Epoch [19/50], Loss D: 0.2906, Loss G: 4.7508
Epoch [20/50], Loss D: 0.0974, Loss G: 5.9505
Epoch [21/50], Loss D: 0.1689, Loss G: 4.2737
Epoch [22/50], Loss D: 0.0654, Loss G: 5.75

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import os

# Mount Google Drive; generators are read from it and synthetic data is written to it
from google.colab import drive
drive.mount('/content/drive')

# Project root on Drive and the sub-folder for synthetic data (created if missing)
persistent_dir = "/content/drive/MyDrive/federated_learning"
synthetic_data_dir = f"{persistent_dir}/synthetic_data"
os.makedirs(synthetic_data_dir, exist_ok=True)

# Step 1: Define the Generator model (same as used during training)
class Generator(nn.Module):
    """Fully connected generator: noise vector in, flattened image out.

    Layout is noise_dim -> 256 -> 512 -> output_dim with ReLU between the
    linear layers and a final Tanh, so every output value lies in [-1, 1].
    The layer order must match the saved checkpoints for them to load.
    """

    def __init__(self, noise_dim=100, output_dim=28*28):
        super().__init__()
        layers = [
            nn.Linear(noise_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, output_dim),
            nn.Tanh(),
        ]
        # The attribute must stay "model" so state_dict keys remain "model.<i>.*"
        self.model = nn.Sequential(*layers)

    def forward(self, z):
        """Map a batch of noise vectors ``z`` to generated samples."""
        generated = self.model(z)
        return generated

# Step 2: Function to generate synthetic data using a saved generator
def generate_synthetic_data(generator_path, num_samples=1000, noise_dim=100):
    """
    Generate synthetic samples from a saved generator checkpoint.

    Relies on the module-level ``device`` for model and tensor placement.

    Args:
        generator_path: Path to a Generator ``state_dict`` file.
        num_samples: Number of samples to generate.
        noise_dim: Noise dimension the generator was built with.

    Returns:
        CPU tensor with one generated (flattened) sample per row.
    """
    # Load the trained generator. map_location lets a checkpoint saved from a
    # GPU model load on a CPU-only runtime; weights_only restricts unpickling
    # to tensor data.
    generator = Generator(noise_dim=noise_dim).to(device)
    state_dict = torch.load(generator_path, map_location=device, weights_only=True)
    generator.load_state_dict(state_dict)
    generator.eval()  # Set generator to evaluation mode

    # Generate synthetic data without tracking gradients
    with torch.no_grad():
        noise = torch.randn(num_samples, noise_dim).to(device)
        synthetic_data = generator(noise)
    return synthetic_data.cpu()

# Step 3: Generate and save synthetic data
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_clients = 10  # Total number of clients
noise_dim = 100
num_samples_per_client = 1000  # Number of synthetic samples per client

synthetic_data_list = []
synthetic_labels_list = []

for client_id in range(1, num_clients + 1):
    generator_path = f"{persistent_dir}/client_generators/client_{client_id}_generator.pth"
    print(f"Generating synthetic data for Client {client_id} using {generator_path}...")

    # Generate synthetic data
    synthetic_data = generate_synthetic_data(generator_path, num_samples=num_samples_per_client, noise_dim=noise_dim)

    # Assign labels. Labels drawn uniformly from 0-9 carry no information about
    # the images, so a classifier trained on them cannot learn. Instead, sample
    # labels from the classes this client's generator was actually trained on.
    # NOTE: the generator is unconditional, so labels are still only correct up
    # to the client's class mix; a conditional GAN (or one generator per class)
    # would be needed for exact labels.
    client_dataset_path = f"{persistent_dir}/client_datasets/client_{client_id}_dataset.pt"
    client_labels = torch.load(client_dataset_path, weights_only=True)['labels']
    picks = torch.randint(0, client_labels.size(0), (num_samples_per_client,))
    labels = client_labels[picks].long()

    synthetic_data_list.append(synthetic_data)
    synthetic_labels_list.append(labels)

# Merge synthetic data and labels from all clients
merged_synthetic_data = torch.cat(synthetic_data_list, dim=0)
merged_synthetic_labels = torch.cat(synthetic_labels_list, dim=0)
print(f"Merged synthetic dataset: {merged_synthetic_data.shape}, Labels: {merged_synthetic_labels.shape}")

# Save the merged synthetic dataset
synthetic_data_path = f"{synthetic_data_dir}/merged_synthetic_dataset.pt"
torch.save({'data': merged_synthetic_data, 'labels': merged_synthetic_labels}, synthetic_data_path)
print(f"Saved merged synthetic dataset at {synthetic_data_path}")

# Step 4: Load saved synthetic data for training (a plain dict of tensors,
# so weights_only=True is sufficient)
loaded_data = torch.load(synthetic_data_path, weights_only=True)
merged_synthetic_data, merged_synthetic_labels = loaded_data['data'], loaded_data['labels']

# Step 5: Create DataLoader for merged synthetic data
synthetic_dataset = TensorDataset(merged_synthetic_data, merged_synthetic_labels)
synthetic_loader = DataLoader(synthetic_dataset, batch_size=64, shuffle=True)

# Step 6: Define and train an ML model on synthetic data
class SimpleNN(nn.Module):
    """Two-layer classifier: input_dim -> 128 (ReLU) -> num_classes logits."""

    def __init__(self, input_dim=28*28, num_classes=10):
        super().__init__()
        hidden_units = 128
        layers = [
            nn.Linear(input_dim, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, num_classes),
        ]
        # The attribute must stay "fc" so state_dict keys remain "fc.<i>.*"
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        """Return raw class logits for the flattened inputs ``x``."""
        logits = self.fc(x)
        return logits

# Build the classifier together with its optimizer and loss function
ml_model = SimpleNN().to(device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(ml_model.parameters(), lr=0.01)

# Fit the classifier on the merged synthetic data
num_epochs = 10
for epoch in range(num_epochs):
    ml_model.train()
    for data, target in synthetic_loader:
        data = data.to(device)
        target = target.to(device)

        optimizer.zero_grad()
        output = ml_model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
    # The reported loss is that of the final batch of the epoch
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}")

# Step 7: Persist the classifier weights next to the synthetic data
trained_model_path = f"{synthetic_data_dir}/trained_ml_model.pth"
model_weights = ml_model.state_dict()
torch.save(model_weights, trained_model_path)
print(f"Saved trained ML model at {trained_model_path}")


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Generating synthetic data for Client 1 using /content/drive/MyDrive/federated_learning/client_generators/client_1_generator.pth...
Generating synthetic data for Client 2 using /content/drive/MyDrive/federated_learning/client_generators/client_2_generator.pth...
Generating synthetic data for Client 3 using /content/drive/MyDrive/federated_learning/client_generators/client_3_generator.pth...
Generating synthetic data for Client 4 using /content/drive/MyDrive/federated_learning/client_generators/client_4_generator.pth...

  generator.load_state_dict(torch.load(generator_path))



Generating synthetic data for Client 5 using /content/drive/MyDrive/federated_learning/client_generators/client_5_generator.pth...
Generating synthetic data for Client 6 using /content/drive/MyDrive/federated_learning/client_generators/client_6_generator.pth...
Generating synthetic data for Client 7 using /content/drive/MyDrive/federated_learning/client_generators/client_7_generator.pth...
Generating synthetic data for Client 8 using /content/drive/MyDrive/federated_learning/client_generators/client_8_generator.pth...
Generating synthetic data for Client 9 using /content/drive/MyDrive/federated_learning/client_generators/client_9_generator.pth...
Generating synthetic data for Client 10 using /content/drive/MyDrive/federated_learning/client_generators/client_10_generator.pth...
Merged synthetic dataset: torch.Size([10000, 784]), Labels: torch.Size([10000])
Saved merged synthetic dataset at /content/drive/MyDrive/federated_learning/synthetic_data/merged_synthetic_dataset.pt


  loaded_data = torch.load(synthetic_data_path)


Epoch [1/10], Loss: 2.2849
Epoch [2/10], Loss: 2.3313
Epoch [3/10], Loss: 2.3170
Epoch [4/10], Loss: 2.2888
Epoch [5/10], Loss: 2.2675
Epoch [6/10], Loss: 2.3111
Epoch [7/10], Loss: 2.2499
Epoch [8/10], Loss: 2.2727
Epoch [9/10], Loss: 2.2855
Epoch [10/10], Loss: 2.2789
Saved trained ML model at /content/drive/MyDrive/federated_learning/synthetic_data/trained_ml_model.pth


In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

# Mount Google Drive; client datasets are read from it and client models are written to it
from google.colab import drive
drive.mount('/content/drive')

# Project root on Drive and the sub-folder for per-client models (created if missing)
persistent_dir = "/content/drive/MyDrive/federated_learning"
client_models_dir = f"{persistent_dir}/client_models"
os.makedirs(client_models_dir, exist_ok=True)

# Step 1: Define the model architecture (same as used for training on synthetic data)
class SimpleNN(nn.Module):
    """Two-layer classifier: input_dim -> 128 (ReLU) -> num_classes logits."""

    def __init__(self, input_dim=28*28, num_classes=10):
        super().__init__()
        hidden_units = 128
        layers = [
            nn.Linear(input_dim, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, num_classes),
        ]
        # The attribute must stay "fc" so state_dict keys remain "fc.<i>.*"
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        """Return raw class logits for the flattened inputs ``x``."""
        logits = self.fc(x)
        return logits

# Step 2: Load the trained model parameters (from the server-side training)
trained_model_path = f"{persistent_dir}/synthetic_data/trained_ml_model.pth"
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
global_model = SimpleNN().to(device)
# map_location lets a checkpoint saved from a GPU model load on a CPU-only
# runtime; weights_only restricts unpickling to tensor data.
global_model.load_state_dict(torch.load(trained_model_path, map_location=device, weights_only=True))
print(f"Loaded global model from {trained_model_path}")

# Step 3: Function to train a model on client-specific data
def train_client_model(client_id, client_data, client_labels, global_model, num_epochs=5, batch_size=32, lr=0.01):
    """
    Fine-tune a copy of the global model on one client's data.

    Relies on the module-level ``device``. ``global_model`` itself is not
    modified; its weights are copied into a fresh SimpleNN, which is trained
    with SGD and cross-entropy and then returned. ``client_id`` is only used
    in the progress messages.
    """
    # Flatten each sample and rescale with x / 127.5 - 1.0
    flat_inputs = client_data.view(client_data.size(0), -1) / 127.5 - 1.0  # Normalize to [-1, 1]
    client_loader = DataLoader(
        torch.utils.data.TensorDataset(flat_inputs, client_labels),
        batch_size=batch_size,
        shuffle=True,
    )

    # Start the local model from the current global weights
    client_model = SimpleNN().to(device)
    starting_weights = global_model.state_dict()
    client_model.load_state_dict(starting_weights)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(client_model.parameters(), lr=lr)

    # Local training passes over the client's data
    for epoch in range(num_epochs):
        client_model.train()
        for inputs, targets in client_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            optimizer.zero_grad()
            loss = criterion(client_model(inputs), targets)
            loss.backward()
            optimizer.step()
        # The reported loss is that of the final batch of the epoch
        print(f"Client {client_id}: Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}")

    return client_model

# Step 4: Train and save models for each client
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_clients = 10  # Total number of clients
client_epochs = 5  # Number of epochs to train on client data

for client_id in range(1, num_clients + 1):
    # Load client-specific dataset. The file holds a plain dict of tensors, so
    # weights_only=True restricts unpickling to tensor data and avoids the
    # FutureWarning emitted by torch.load.
    dataset_path = f"{persistent_dir}/client_datasets/client_{client_id}_dataset.pt"
    loaded_data = torch.load(dataset_path, weights_only=True)
    client_data, client_labels = loaded_data['data'], loaded_data['labels']

    print(f"Training model for Client {client_id} with {len(client_data)} samples...")

    # Train the client's model using the global model
    updated_client_model = train_client_model(client_id, client_data, client_labels, global_model,
                                              num_epochs=client_epochs)

    # Save the updated model for the client
    save_path = f"{client_models_dir}/client_{client_id}_model.pth"
    torch.save(updated_client_model.state_dict(), save_path)
    print(f"Saved updated model for Client {client_id} at {save_path}")


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Loaded global model from /content/drive/MyDrive/federated_learning/synthetic_data/trained_ml_model.pth
Training model for Client 1 with 12665 samples...


  global_model.load_state_dict(torch.load(trained_model_path))
  loaded_data = torch.load(dataset_path)


Client 1: Epoch [1/5], Loss: 0.0032
Client 1: Epoch [2/5], Loss: 0.0119
Client 1: Epoch [3/5], Loss: 0.0161
Client 1: Epoch [4/5], Loss: 0.0045
Client 1: Epoch [5/5], Loss: 0.0036
Saved updated model for Client 1 at /content/drive/MyDrive/federated_learning/client_models/client_1_model.pth
Training model for Client 2 with 12089 samples...
Client 2: Epoch [1/5], Loss: 0.0206
Client 2: Epoch [2/5], Loss: 0.1053
Client 2: Epoch [3/5], Loss: 0.0187
Client 2: Epoch [4/5], Loss: 0.0840
Client 2: Epoch [5/5], Loss: 0.0901
Saved updated model for Client 2 at /content/drive/MyDrive/federated_learning/client_models/client_2_model.pth
Training model for Client 3 with 11263 samples...
Client 3: Epoch [1/5], Loss: 0.0592
Client 3: Epoch [2/5], Loss: 0.0238
Client 3: Epoch [3/5], Loss: 0.0056
Client 3: Epoch [4/5], Loss: 0.0079
Client 3: Epoch [5/5], Loss: 0.0105
Saved updated model for Client 3 at /content/drive/MyDrive/federated_learning/client_models/client_3_model.pth
Training model for Client 4

In [None]:
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torch.nn as nn
import os

# Mount Google Drive; client models are read from it and the aggregated model is written to it
from google.colab import drive
drive.mount('/content/drive')

# Project root on Drive, the folder holding per-client models, and the output
# path for the aggregated model
persistent_dir = "/content/drive/MyDrive/federated_learning"
client_models_dir = f"{persistent_dir}/client_models"
aggregated_model_path = f"{persistent_dir}/aggregated_model.pth"

# Step 1: Define the model architecture (same as used earlier)
class SimpleNN(nn.Module):
    """Two-layer classifier: input_dim -> 128 (ReLU) -> num_classes logits."""

    def __init__(self, input_dim=28*28, num_classes=10):
        super().__init__()
        hidden_units = 128
        layers = [
            nn.Linear(input_dim, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, num_classes),
        ]
        # The attribute must stay "fc" so state_dict keys remain "fc.<i>.*"
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        """Return raw class logits for the flattened inputs ``x``."""
        logits = self.fc(x)
        return logits

# Step 2: Federated Averaging function
def federated_averaging(client_models, global_model):
    """
    Perform Federated Averaging on client models to update the global model.

    Every entry of the global model's state_dict is replaced by the unweighted
    mean of the corresponding client entries (each client counts equally,
    regardless of how many samples it trained on).

    Args:
        client_models: Non-empty list of state_dicts of client models. Each
                       must contain every key of the global state_dict.
        global_model: The global model to update.

    Returns:
        None (updates the global_model in place).

    Raises:
        ValueError: If ``client_models`` is empty (averaging would otherwise
                    divide by zero and load NaN weights).
    """
    if not client_models:
        raise ValueError("client_models must contain at least one state_dict")

    num_clients = len(client_models)
    global_state_dict = global_model.state_dict()

    averaged_state_dict = {}
    for key, reference in global_state_dict.items():
        # Integer buffers (e.g. BatchNorm's num_batches_tracked) cannot be
        # divided in place, so accumulate those in float64 and cast back.
        acc_dtype = reference.dtype if reference.is_floating_point() else torch.float64
        total = torch.zeros_like(reference, dtype=acc_dtype)

        # Sum up parameters from all client models on the global model's device
        for client_state_dict in client_models:
            total += client_state_dict[key].to(device=reference.device, dtype=acc_dtype)

        # Average the parameters and restore the original dtype
        averaged_state_dict[key] = (total / num_clients).to(reference.dtype)

    # Update the global model
    global_model.load_state_dict(averaged_state_dict)

# Step 3: Aggregate client models
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
global_model = SimpleNN().to(device)  # Initialize the global model
# os.listdir returns entries in arbitrary order; sort so runs are reproducible
client_model_paths = sorted(
    os.path.join(client_models_dir, f) for f in os.listdir(client_models_dir) if f.endswith('.pth')
)
if not client_model_paths:
    raise FileNotFoundError(f"No client model files (*.pth) found in {client_models_dir}")

# Load client models. map_location lets checkpoints saved from GPU models load
# on a CPU-only runtime; weights_only restricts unpickling to tensor data.
client_models = []
for client_model_path in client_model_paths:
    client_model = SimpleNN().to(device)
    client_model.load_state_dict(torch.load(client_model_path, map_location=device, weights_only=True))
    client_models.append(client_model.state_dict())

# Perform Federated Averaging
federated_averaging(client_models, global_model)
print("Aggregated model using Federated Averaging.")

# Step 4: Prepare the MNIST test split for evaluating the aggregated model
# ToTensor followed by Normalize(0.5, 0.5) maps pixels to [-1, 1]
preprocessing_steps = [
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
]
transform = transforms.Compose(preprocessing_steps)
mnist_test_dataset = datasets.MNIST(
    root='data',
    train=False,
    download=True,
    transform=transform,
)
# Fixed order (no shuffling) for evaluation
test_loader = DataLoader(mnist_test_dataset, batch_size=64, shuffle=False)

# Define the test function
def test_model(model, test_loader, device):
    """
    Compute the classification accuracy of ``model`` over ``test_loader``.

    Args:
        model: Module returning one row of class scores per input row.
        test_loader: Iterable of ``(data, target)`` batches. Each data batch
                     is flattened to ``(batch, -1)`` before the forward pass.
        device: Device the batches are moved to.

    Returns:
        Accuracy as a percentage in [0, 100]; 0.0 when the loader yields no
        samples (instead of raising ZeroDivisionError).
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            data = data.view(data.size(0), -1).to(device)
            target = target.to(device)
            output = model(data)
            # Predicted class = index of the largest score in each row
            _, predicted = torch.max(output, dim=1)
            correct += (predicted == target).sum().item()
            total += target.size(0)
    if total == 0:
        return 0.0
    return correct / total * 100


# The name matches pytest's default "test_*" pattern although this is a helper,
# not a test case; opt out of collection explicitly.
test_model.__test__ = False

# Evaluate the aggregated model on the MNIST test loader and report the accuracy in percent
accuracy = test_model(global_model, test_loader, device)
print(f"Accuracy of the aggregated model on the MNIST test dataset: {accuracy:.2f}%")

# Step 5: Save the aggregated model's state_dict to the Drive path defined above
torch.save(global_model.state_dict(), aggregated_model_path)
print(f"Saved aggregated model at {aggregated_model_path}")


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Aggregated model using Federated Averaging.


  client_model.load_state_dict(torch.load(client_model_path))


Accuracy of the aggregated model on the MNIST test dataset: 38.46%
Saved aggregated model at /content/drive/MyDrive/federated_learning/aggregated_model.pth


In [None]:
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torch.nn as nn
import os

# Mount Google Drive; client datasets and the aggregated model are read from it
from google.colab import drive
drive.mount('/content/drive')

# Project root on Drive, the folder holding per-client datasets, and the path
# of the aggregated model
persistent_dir = "/content/drive/MyDrive/federated_learning"
client_datasets_dir = f"{persistent_dir}/client_datasets"
aggregated_model_path = f"{persistent_dir}/aggregated_model.pth"

# Ensure the persistent directory exists
os.makedirs(persistent_dir, exist_ok=True)

# Step 1: Define the model architecture
class SimpleNN(nn.Module):
    """Two-layer classifier: input_dim -> 128 (ReLU) -> num_classes logits."""

    def __init__(self, input_dim=28*28, num_classes=10):
        super().__init__()
        hidden_units = 128
        layers = [
            nn.Linear(input_dim, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, num_classes),
        ]
        # The attribute must stay "fc" so state_dict keys remain "fc.<i>.*"
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        """Return raw class logits for the flattened inputs ``x``."""
        logits = self.fc(x)
        return logits

# Step 2: Federated Averaging function
def federated_averaging(client_models, global_model):
    """
    Perform Federated Averaging to aggregate client models into the global model.

    Every entry of the global model's state_dict is replaced by the unweighted
    mean of the corresponding client entries (each client counts equally,
    regardless of how many samples it trained on).

    Args:
        client_models: Non-empty list of state_dicts of client models. Each
                       must contain every key of the global state_dict.
        global_model: The global model to update.

    Returns:
        None (updates the global_model in place).

    Raises:
        ValueError: If ``client_models`` is empty (averaging would otherwise
                    divide by zero and load NaN weights).
    """
    if not client_models:
        raise ValueError("client_models must contain at least one state_dict")

    num_clients = len(client_models)
    global_state_dict = global_model.state_dict()

    averaged_state_dict = {}
    for key, reference in global_state_dict.items():
        # Integer buffers (e.g. BatchNorm's num_batches_tracked) cannot be
        # divided in place, so accumulate those in float64 and cast back.
        acc_dtype = reference.dtype if reference.is_floating_point() else torch.float64
        total = torch.zeros_like(reference, dtype=acc_dtype)

        # Sum the client tensors on the global model's device
        for client_state_dict in client_models:
            total += client_state_dict[key].to(device=reference.device, dtype=acc_dtype)

        averaged_state_dict[key] = (total / num_clients).to(reference.dtype)

    global_model.load_state_dict(averaged_state_dict)

# Step 3: Train client model function
def train_client_model(client_id, client_data, client_labels, global_model, num_epochs=5, batch_size=32, lr=0.01):
    """
    Fine-tune a copy of the current global model on one client's dataset.

    Relies on the module-level ``device``. ``global_model`` itself is not
    modified; its weights are copied into a fresh SimpleNN, which is trained
    with SGD and cross-entropy and then returned. ``client_id`` is accepted
    for interface compatibility but not used in the body.
    """
    # Flatten each sample and rescale with x / 127.5 - 1.0
    flat_inputs = client_data.view(client_data.size(0), -1) / 127.5 - 1.0
    client_loader = DataLoader(
        torch.utils.data.TensorDataset(flat_inputs, client_labels),
        batch_size=batch_size,
        shuffle=True,
    )

    # Start the local model from the current global weights
    client_model = SimpleNN().to(device)
    starting_weights = global_model.state_dict()
    client_model.load_state_dict(starting_weights)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(client_model.parameters(), lr=lr)

    # Local training passes over the client's data
    for epoch in range(num_epochs):
        client_model.train()
        for inputs, targets in client_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            optimizer.zero_grad()
            loss = criterion(client_model(inputs), targets)
            loss.backward()
            optimizer.step()

    return client_model

# Step 4: Test the aggregated model
def test_model(model, test_loader, device):
    """
    Evaluate the global model on the MNIST test dataset.

    Each batch is flattened to ``(batch, -1)`` before the forward pass; the
    predicted class is the index of the largest output along dim 1.

    Args:
        model: Module to evaluate. It is switched to eval mode and left there.
        test_loader: Iterable yielding ``(data, target)`` tensor batches.
        device: Device the batches are moved to before the forward pass.

    Returns:
        Accuracy as a percentage (``correct / total * 100``), or ``0.0`` when
        ``test_loader`` yields no samples.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            data = data.view(data.size(0), -1).to(device)
            target = target.to(device)
            predicted = model(data).argmax(dim=1)
            correct += (predicted == target).sum().item()
            total += target.size(0)

    if total == 0:
        # An empty loader would otherwise raise ZeroDivisionError.
        return 0.0
    return correct / total * 100


# Despite its ``test_`` prefix this is a helper, not a pytest test case;
# opt out of collection in case it is ever imported into a test module.
test_model.__test__ = False

# Step 5: Perform iterative federated learning
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the initial global model
global_model = SimpleNN().to(device)
# map_location lets a checkpoint saved on GPU load on a CPU-only runtime;
# weights_only=True restricts unpickling to tensors and plain containers,
# which is all a state dict needs.
global_model.load_state_dict(
    torch.load(aggregated_model_path, map_location=device, weights_only=True)
)
print(f"Loaded initial global model from {aggregated_model_path}")

client_epochs = 5
num_clients = 10
iteration = 0
accuracy = 0
max_iterations = 150  # Set maximum iterations
target_accuracy = 86  # Stop once test accuracy (in percent) reaches this value
save_interval = 10  # Checkpoint the global model every this many iterations

# Prepare MNIST test dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
mnist_test_dataset = datasets.MNIST(root='data', train=False, download=True, transform=transform)
test_loader = DataLoader(mnist_test_dataset, batch_size=64, shuffle=False)

# Load every client's dataset once up front. Nothing in the training loop
# writes these files, so re-reading them from Drive each round is wasted I/O.
cached_client_data = {}
for client_id in range(1, num_clients + 1):
    dataset_path = f"{client_datasets_dir}/client_{client_id}_dataset.pt"
    loaded_data = torch.load(dataset_path)
    cached_client_data[client_id] = (loaded_data['data'], loaded_data['labels'])

# Iterate until accuracy reaches the target or the iteration cap is hit
while accuracy < target_accuracy and iteration < max_iterations:
    iteration += 1
    print(f"\n--- Iteration {iteration} ---")

    # Each client trains from the current global weights; only the resulting
    # state dicts are kept for aggregation.
    client_models = []
    for client_id in range(1, num_clients + 1):
        client_data, client_labels = cached_client_data[client_id]

        print(f"Training model for Client {client_id}...")
        updated_client_model = train_client_model(client_id, client_data, client_labels, global_model, num_epochs=client_epochs)
        client_models.append(updated_client_model.state_dict())

    # Perform Federated Averaging
    federated_averaging(client_models, global_model)
    print(f"Aggregated global model at iteration {iteration}.")

    # Evaluate the aggregated model
    accuracy = test_model(global_model, test_loader, device)
    print(f"Accuracy of the aggregated model on the MNIST test dataset: {accuracy:.2f}%")

    # Save the aggregated model every `save_interval` iterations
    if iteration % save_interval == 0:
        interval_model_path = f"{persistent_dir}/aggregated_model_iter{iteration}.pth"
        torch.save(global_model.state_dict(), interval_model_path)
        print(f"Saved aggregated model at iteration {iteration} to {interval_model_path}")

# Save the final aggregated model
final_model_path = f"{persistent_dir}/final_aggregated_model.pth"
torch.save(global_model.state_dict(), final_model_path)
print(f"Saved final aggregated model at {final_model_path}")


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


  global_model.load_state_dict(torch.load(aggregated_model_path))


Loaded initial global model from /content/drive/MyDrive/federated_learning/aggregated_model.pth
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9.91M/9.91M [00:01<00:00, 5.24MB/s]


Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28.9k/28.9k [00:00<00:00, 155kB/s]


Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1.65M/1.65M [00:01<00:00, 1.46MB/s]


Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4.54k/4.54k [00:00<00:00, 2.90MB/s]
  loaded_data = torch.load(dataset_path)


Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw


--- Iteration 1 ---
Training model for Client 1...
Training model for Client 2...
Training model for Client 3...
Training model for Client 4...
Training model for Client 5...
Training model for Client 6...
Training model for Client 7...
Training model for Client 8...
Training model for Client 9...
Training model for Client 10...
Aggregated global model at iteration 1.
Accuracy of the aggregated model on the MNIST test dataset: 60.35%

--- Iteration 2 ---
Training model for Client 1...
Training model for Client 2...
Training model for Client 3...
Training model for Client 4...
Training model for Client 5...
Training model for Client 6...
Training model for Client 7...
Training model for Client 8...
Training model for Client 9...
Training model for Client 10...
Aggregated global model at iteration 2.
Accuracy of the aggregated model on the MNIST test dataset: 65.00%

--- Iteration 3 ---
Training model for Client 1...