In [8]:
# import all needed packages
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split, TensorDataset
import numpy as np
import sys
import os
import random

# Get the absolute path of the src directory
src_path = os.path.abspath('../src')

# Add src_path to sys.path
if src_path not in sys.path:
    sys.path.append(src_path)

    
import fl
from fl import client_selection_with_constraints, select_indices
from models import QuantStubModel,calculate_model_size
from quantization import *

# Hyperparameters

In [9]:
# Hyperparameters
hp = {
    "run_id": "test",
    "run_number": "qunatized_mqat_32168_noniid",
    "learning_rate": 1e-1,
    "batch_size": 32,
    "num_clients": 20,
    "num_rounds": 10,
    "num_classes": 10,  # 10 for MNIST
    "classes_per_client": 1,
    "epochs": 3,    # number of epochs to train in each round
    "split": "NONIID",   # ["RANDOM", "NONIID"]
    "quantize": False,
    "averaging_setting": "standard",    # ["standard", "scalar", "kure", "mqat"]
    "lambda_kure": 1e-4,
    "delta": 1e-2,
    "bit_widths": [32],
    "shared_fraction": 0.05,
}

In [None]:
shared_fraction = hp['shared_fraction']


## Obtain Dataset

In [None]:
from torchvision import datasets, transforms

# Define transformations for the dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # Normalize to range [-1, 1]
])

# Download and load MNIST dataset
mnist_full = datasets.MNIST(root="../data", train=True, transform=transform, download=True)
mnist_test = datasets.MNIST(root="../data", train=False, transform=transform, download=True)


# Split training data into training and validation sets
train_size = int(0.8 * len(mnist_full))  # 80% for training
val_size = len(mnist_full) - train_size  # 20% for validation
train_dataset, val_dataset = random_split(mnist_full, [train_size, val_size])

shared_size = int(len(train_dataset) * shared_fraction)
remaining_size = len(train_dataset) - shared_size
train_shared_dataset, train_dataset = random_split(train_dataset, [shared_size, remaining_size])

# Extract indices from Subset objects
train_indices = train_dataset.indices  # List of training indices
train_shared_indices = train_shared_dataset.indices  # List of training indices
val_indices = val_dataset.indices      # List of validation indices

# Create training and validation MNIST datasets
mnist_train = datasets.MNIST(root="../data", train=True, transform=transform, download=False)
mnist_train_shared = datasets.MNIST(root="../data", train=True, transform=transform, download=False)
mnist_val = datasets.MNIST(root="../data", train=True, transform=transform, download=False)

# Filter datasets by indices
mnist_train.data = mnist_train.data[torch.tensor(train_indices)]
mnist_train.targets = mnist_train.targets[torch.tensor(train_indices)]
mnist_train_shared.data = mnist_train_shared.data[torch.tensor(train_shared_indices)]
mnist_train_shared.targets = mnist_train_shared.targets[torch.tensor(train_shared_indices)]
mnist_val.data = mnist_val.data[torch.tensor(val_indices)]
mnist_val.targets = mnist_val.targets[torch.tensor(val_indices)]

# Print dataset sizes
print(f"Train dataset size: {len(mnist_train)}, Train Shared dataset size: {len(mnist_train_shared)}, Validation dataset size: {len(mnist_val)}, Test dataset size: {len(mnist_test)}")

Train dataset size: 45600, Train Shared dataset size: 2400, Validation dataset size: 12000, Test dataset size: 10000


In [11]:
"""
sys.path.append(os.path.abspath('../datasets/non-iid-dataset-for-personalized-federated-learning'))

from dataset.mnist_noniid import get_dataset_mnist_extr_noniid
from collections import Counter

num_users_mnist = 10
nclass_mnist = 1
nsamples_mnist = 10
rate_unbalance_mnist = 1.0

train_dataset_mnist, test_dataset_mnist, user_groups_train_mnist, user_groups_test_mnist = get_dataset_mnist_extr_noniid(num_users_mnist, nclass_mnist, nsamples_mnist, rate_unbalance_mnist)
print(len(user_groups_test_mnist[0]))

user_groups_train_mnist = {
    key: [int(idx) for idx in indices]
    for key, indices in user_groups_train_mnist.items()
}


client_datasets = [
    Subset(train_dataset_mnist, indices) for indices in user_groups_train_mnist.values()
]

for i, client_dataset in enumerate(client_datasets):
    client_loader = torch.utils.data.DataLoader(client_dataset, batch_size=len(client_dataset))
    client_samples, client_labels = next(iter(client_loader))
    label_counts = Counter(client_labels.tolist())
    print(f"Client {i} label distribution: {dict(label_counts)}")
"""

'\nsys.path.append(os.path.abspath(\'../datasets/non-iid-dataset-for-personalized-federated-learning\'))\n\nfrom dataset.mnist_noniid import get_dataset_mnist_extr_noniid\nfrom collections import Counter\n\nnum_users_mnist = 10\nnclass_mnist = 1\nnsamples_mnist = 10\nrate_unbalance_mnist = 1.0\n\ntrain_dataset_mnist, test_dataset_mnist, user_groups_train_mnist, user_groups_test_mnist = get_dataset_mnist_extr_noniid(num_users_mnist, nclass_mnist, nsamples_mnist, rate_unbalance_mnist)\nprint(len(user_groups_test_mnist[0]))\n\nuser_groups_train_mnist = {\n    key: [int(idx) for idx in indices]\n    for key, indices in user_groups_train_mnist.items()\n}\n\n\nclient_datasets = [\n    Subset(train_dataset_mnist, indices) for indices in user_groups_train_mnist.values()\n]\n\nfor i, client_dataset in enumerate(client_datasets):\n    client_loader = torch.utils.data.DataLoader(client_dataset, batch_size=len(client_dataset))\n    client_samples, client_labels = next(iter(client_loader))\n    lab

# Split Dataset: Non-IID / Random

In [12]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split, Subset
import numpy as np
from torch.utils.data import random_split
from collections import Counter

# Number of clients and non-IID split
num_clients = hp["num_clients"]
batch_size = hp["batch_size"]
classes_per_client = hp["classes_per_client"]
num_classes = hp["num_classes"]

train_data_size = len(mnist_train) // num_clients
val_data_size = len(mnist_val) // num_clients

if hp["split"] == "NONIID":
    
    # Create indices for each class
    train_class_indices = {i: np.where(np.array(mnist_train.targets) == i)[0] for i in range(num_classes)}

    train_indices = []
    for client_id in range(num_clients):
        chosen_classes = np.random.choice(num_classes, classes_per_client, replace=False)
        train_client_idx = []
        val_client_idx = []
        for cls in chosen_classes:
            train_cls_size = max(1, len(train_class_indices[cls]) // (num_clients // classes_per_client))
            
            # Adjust for insufficient samples
            train_cls_size = min(train_cls_size, len(train_class_indices[cls]))
            
            train_cls_idx = np.random.choice(train_class_indices[cls], train_cls_size, replace=False)
            train_client_idx.extend(train_cls_idx)
                        
            # Remove assigned indices to avoid duplication
            train_class_indices[cls] = np.setdiff1d(train_class_indices[cls], train_cls_idx)

        train_indices.append(train_client_idx)

    # Create datasets and DataLoaders for each client
    train_dataset = [Subset(mnist_train, indices) for indices in train_indices]
    train_loader = [DataLoader(ds, batch_size=batch_size, shuffle=True) for ds in train_dataset]

else:
    
    # Split the training data into smaller datasets for each client
    train_dataset = random_split(mnist_train, [train_data_size] * num_clients)
    train_loader = [DataLoader(ds, batch_size=batch_size, shuffle=True) for ds in train_dataset]    

# Test DataLoader for evaluation
val_loader = DataLoader(mnist_val, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=False)

print(f"Simulated {num_clients} clients, each with {train_data_size} training samples, and {len(mnist_val)} validation samples")
# Debugging: Output distribution of classes for all clients
for i, client_dataset in enumerate(train_dataset):
    client_loader = torch.utils.data.DataLoader(client_dataset, batch_size=len(client_dataset))
    client_samples, client_labels = next(iter(client_loader))
    label_counts = Counter(client_labels.tolist())
    print(f"Client {i} label distribution: {dict(label_counts)}")

Simulated 20 clients, each with 2280 training samples, and 12000 validation samples
Client 0 label distribution: {5: 205}
Client 1 label distribution: {6: 224}
Client 2 label distribution: {1: 259}
Client 3 label distribution: {5: 195}
Client 4 label distribution: {4: 220}
Client 5 label distribution: {7: 236}
Client 6 label distribution: {6: 213}
Client 7 label distribution: {7: 224}
Client 8 label distribution: {6: 202}
Client 9 label distribution: {6: 192}
Client 10 label distribution: {9: 228}
Client 11 label distribution: {8: 220}
Client 12 label distribution: {6: 182}
Client 13 label distribution: {8: 209}
Client 14 label distribution: {4: 209}
Client 15 label distribution: {3: 233}
Client 16 label distribution: {3: 221}
Client 17 label distribution: {3: 210}
Client 18 label distribution: {4: 199}
Client 19 label distribution: {3: 200}


# Optionally distribute the shared data among clients

In [None]:
alpha = 0.05    # Each client gets partition of the shared data

shared_loader = torch.utils.data.DataLoader(mnist_train_shared, batch_size=len(mnist_train_shared))
shared_samples, shared_labels = next(iter(shared_loader))
num_shared_samples = int(alpha * len(shared_samples))

merged_train_datasets = []
for client_dataset in train_dataset:
    client_loader = torch.utils.data.DataLoader(client_dataset, batch_size=len(client_dataset))
    client_samples, client_labels = next(iter(client_loader))
    
    # Combine shared and local data
    combined_samples = torch.cat([client_samples, shared_samples[:num_shared_samples]])
    combined_labels = torch.cat([client_labels, shared_labels[:num_shared_samples]])
    merged_train_dataset = torch.utils.data.TensorDataset(combined_samples, combined_labels)
    merged_train_datasets.append(merged_train_dataset)

In [None]:
from collections import Counter

# Display the label distribution for each client
for i, client_dataset in enumerate(merged_train_datasets):
    client_loader = torch.utils.data.DataLoader(client_dataset, batch_size=len(client_dataset))
    client_samples, client_labels = next(iter(client_loader))
    label_counts = Counter(client_labels.tolist())
    print(f"Client {i} label distribution: {dict(label_counts)}")

Client 0 label distribution: {3: 943, 1: 1050, 8: 12, 9: 12, 7: 10, 5: 8, 2: 16, 6: 10, 0: 7, 4: 12}
Client 1 label distribution: {0: 909, 9: 928, 1: 19, 8: 12, 3: 14, 7: 10, 5: 8, 2: 16, 6: 10, 4: 12}
Client 2 label distribution: {1: 843, 9: 745, 8: 12, 3: 14, 7: 10, 5: 8, 2: 16, 6: 10, 0: 7, 4: 12}
Client 3 label distribution: {1: 679, 8: 893, 3: 14, 9: 12, 7: 10, 5: 8, 2: 16, 6: 10, 0: 7, 4: 12}
Client 4 label distribution: {2: 922, 0: 729, 1: 19, 8: 12, 3: 14, 9: 12, 7: 10, 5: 8, 6: 10, 4: 12}
Client 5 label distribution: {3: 757, 8: 717, 1: 19, 9: 12, 7: 10, 5: 8, 2: 16, 6: 10, 0: 7, 4: 12}
Client 6 label distribution: {5: 830, 4: 897, 1: 19, 8: 12, 3: 14, 9: 12, 7: 10, 2: 16, 6: 10, 0: 7}
Client 7 label distribution: {3: 609, 5: 666, 1: 19, 8: 12, 9: 12, 7: 10, 2: 16, 6: 10, 0: 7, 4: 12}
Client 8 label distribution: {5: 534, 6: 909, 1: 19, 8: 12, 3: 14, 9: 12, 7: 10, 2: 16, 0: 7, 4: 12}
Client 9 label distribution: {2: 741, 9: 599, 1: 19, 8: 12, 3: 14, 7: 10, 5: 8, 6: 10, 0: 7, 4

In [None]:
train_loader = [DataLoader(ds, batch_size=batch_size, shuffle=True) for ds in merged_train_datasets]

## Set up and initialize the Global Model

In [None]:
# 1. Initialization: Instantiate the global model (server)
global_model = QuantStubModel(q=hp['quantize'])
if(hp["quantize"]):
    global_model.qconfig = torch.quantization.get_default_qat_qconfig("fbgemm")
    torch.quantization.prepare_qat(global_model, inplace=True)

model_size_bytes, model_size_mb = calculate_model_size(global_model)
print(f"Model size averaged_state: {model_size_bytes} bytes ({model_size_mb:.2f} MB)")

Model size averaged_state: 407080 bytes (0.39 MB)


## Client training loop

In [None]:
# set up a training loop that is run on a client for a number of epochs
def train_model(model, train_loader, hp, round_num, epochs=1, q=False, lambda_kure=0.0):

    lr = hp["learning_rate"]
    delta = hp["delta"]
    bit_widths = hp["bit_widths"]
    
    # 3. Distribution: Create a copy of the global model
    local_model = QuantStubModel(q=q)
    if(hp["quantize"]):
        local_model.qconfig = torch.quantization.get_default_qat_qconfig("fbgemm")
        torch.quantization.prepare_qat(local_model, inplace=True)
    local_model.load_state_dict(model.state_dict())

    # Define the loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(local_model.parameters(), lr=hp["learning_rate"])
    
    # Training loop
    local_model.train()
    total_loss = 0  # Initialize total loss
    num_batches = 0  # Initialize batch counter

    for epoch in range(epochs):
        for inputs, labels in train_loader:
            optimizer.zero_grad()

            if hp["averaging_setting"] == 'mqat':
                # Apply Pseudo-Quantization Noise (APQN)
                if delta is not None:
                    for param in model.parameters():
                        param.data = add_pseudo_quantization_noise(param, delta)

                # Apply Multi-Bit Quantization (MQAT)
                if bit_widths is not None:
                    bit_width = random.choice(bit_widths)
                    for param in model.parameters():
                        param.data = quantize_multi_bit(param, bit_width)

            outputs = local_model(inputs)
            loss = criterion(outputs, labels)

            # Kurtosis Regularization
            if hp["averaging_setting"] == 'kure' and lambda_kure > 0:
                for param in local_model.parameters():
                    loss += lambda_kure * kurtosis_regularization(param)

            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            num_batches += 1
    
    # Calculate average training loss
    avg_loss = total_loss / num_batches if num_batches > 0 else 0
    return local_model.state_dict(), avg_loss  # Return updated model parameters and average loss


# Validation function
def validate_model(model, val_loader, round_num):
    model.eval()
    criterion = nn.CrossEntropyLoss()  # Define the loss function
    total_loss = 0
    num_batches = 0

    with torch.no_grad():
        for inputs, labels in val_loader:
            outputs = model(inputs)
            loss = criterion(outputs, labels) #F.cross_entropy(outputs, labels, reduction='sum')  # Compute loss for the batch
            total_loss += loss.item()
            num_batches += 1
    
    avg_loss = total_loss / num_batches if num_batches > 0 else 0
    return avg_loss


# Client Selection

In [None]:
# Example Resource Info (to simulate resource heterogeneity)
client_resources = [{"comp_capacity": random.randint(10, 100), "data_size": random.randint(1, 10)} for _ in range(num_clients)]

## Run Federated Learning Round

In [None]:
from datetime import datetime
import time
import tensorflow as tf
from torch.utils.tensorboard import SummaryWriter


# Generate a unique log directory based on the current time
run_number = hp["run_number"]
log_dir = f"./logs/{hp['run_id']}/run_{run_number}-{datetime.now().strftime('%m-%d-%H-%M')}"
writer = SummaryWriter(log_dir=log_dir)

# Log hyperparameters as text
hyperparams_text = "\n".join([f"{key}: {value}" for key, value in hp.items()])

# Start time measurement
start_time = time.time()

# Federated Learning with FedCS Client Selection
num_rounds = hp["num_rounds"]
epochs = hp["epochs"]
round_deadline = 3  # Example round deadline (in arbitrary time units)

assert num_clients == len(train_loader)
print("GPU available:", torch.cuda.is_available())

# Conduct federated learning rounds
for round_num in range(num_rounds):
    print(f"Round {round_num + 1}")
    
    # 2. Resource Request
    k = random.randint(1, num_clients)
    resource_requested_clients = select_indices(num_clients, k)

    # 2. Client Selection: Collect client updates
    # TODO: Client ressources
    selected_train_clients = client_selection_with_constraints([client_resources[i] for i in resource_requested_clients], round_deadline)
    filtered_train_loaders = [train_loader[i] for i in selected_train_clients]
    #filtered_val_loaders = [val_loader[i] for i in selected_train_clients]
    print(f"-> Resource requested from {len(resource_requested_clients)} clients, {len(selected_train_clients)} clients fulfilled the criteria")

    client_states = []
    round_train_loss = 0  # Initialize round loss
    round_val_loss = 0  # Initialize round loss
    num_batches = 0  # Initialize batch counter

    #for client_train_loader, client_val_loader in zip(filtered_train_loaders, filtered_val_loaders):
    for client_train_loader in filtered_train_loaders:

        # 4. Distribution
        client_state, client_loss = train_model(global_model, client_train_loader, hp, round_num, epochs=epochs, q=hp["quantize"], lambda_kure=hp["lambda_kure"])
        round_train_loss += client_loss
        num_batches += 1

        # 5. Update and Upload
        client_states.append(client_state)

    # Average loss for the round
    avg_round_training_loss = round_train_loss / num_batches if num_batches > 0 else 0
    writer.add_scalar("Metrics/Training loss", avg_round_training_loss, round_num + 1)
    print(f"Training loss after round {round_num + 1}: {avg_round_training_loss}")

    # 6. Aggregation: Aggregate updates using Federated Averaging
    new_global_state = fl.federated_averaging(global_model, client_states, setting=hp["averaging_setting"])
    global_model.load_state_dict(new_global_state)

    #val_loss = validate_model(global_model, client_val_loader, round_num)
    val_loss = validate_model(global_model, val_loader, round_num)

    writer.add_scalar("Metrics/Validation loss", val_loss, round_num + 1)
    print(f"Validation loss after round {round_num + 1}: {val_loss}")
    print(f"Global model updated for round {round_num + 1}")

if hp["quantize"]:
    torch.quantization.convert(global_model, inplace=True)

# End time measurement
end_time = time.time()

# Print total execution time
print(f"Total time taken: {end_time - start_time:.2f} seconds")

GPU available: True
Round 1
-> Resource requested from 7 clients, 7 clients fulfilled the criteria
Training loss after round 1: 0.7639188800205934
Validation loss after round 1: 2.1827711912790932
Global model updated for round 1
Round 2
-> Resource requested from 8 clients, 8 clients fulfilled the criteria
Training loss after round 2: 0.659407128295625
Validation loss after round 2: 1.9779454612731933
Global model updated for round 2
Round 3
-> Resource requested from 7 clients, 7 clients fulfilled the criteria
Training loss after round 3: 0.5069252590860052
Validation loss after round 3: 1.5539732087453206
Global model updated for round 3
Round 4
-> Resource requested from 9 clients, 9 clients fulfilled the criteria
Training loss after round 4: 0.3705452310387804
Validation loss after round 4: 0.8613528652985891
Global model updated for round 4
Round 5
-> Resource requested from 1 clients, 1 clients fulfilled the criteria
Training loss after round 5: 0.16300005562698994
Validation lo

## Evaluate Model

In [None]:
def evaluate_model(model, dataloader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in dataloader:
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    test_accuracy = correct / total
    print(f"Global model accuracy: {test_accuracy:.2%}")
    return test_accuracy

# Evaluate the model on the test dataset
test_accuracy = evaluate_model(global_model, test_loader)

# End TensorBoard writer
final_metrics = {}
hp_cleaned = {k: v for k, v in hp.items() if isinstance(v, (int, float, str, bool, torch.Tensor))}
writer.add_hparams(hp_cleaned, final_metrics)
writer.close()

Global model accuracy: 73.32%


# Save the model

In [None]:
# Save the final model
model_save_path = f"./saved_models/{hp['run_id']}/global_model_{hp['run_id']}_{run_number}.pth"
torch.save(global_model.state_dict(), model_save_path)
print(f"Global model saved at {model_save_path}")
#model_size_bytes, model_size_mb = calculate_model_size(global_model)
#print(f"Model size: {model_size_bytes} bytes ({model_size_mb:.2f} MB)")

Global model saved at ./saved_models/test/global_model_test_qunatized_mqat_32168_noniid.pth
