In [70]:
import torch, sys, os, random, time
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split, TensorDataset, Subset
import numpy as np
from datetime import datetime

# Append the path
src_path = os.path.abspath('../src')
if src_path not in sys.path:
    sys.path.append(src_path)
    
import fl
from models import *
from quantization import *
from client import Client, ClientResources

# Hyperparameters

In [None]:
# Hyperparameters
# Central run configuration; the cells below read their settings from this dict.
hp = dict(
    # Used to build the TensorBoard log directory and the saved-model filename
    run_id="test",
    run_number="noname",
    # Passed to client training as `lr`
    learning_rate=0.01,
    # Batch size of every DataLoader created in this notebook
    batch_size=32,
    # Number of simulated clients and of federated rounds
    num_clients=20,
    num_rounds=10,
    num_classes=10,  # 10 for MNIST
    # Classes sampled for each client when split == "NONIID"
    classes_per_client=4,
    epochs=1,    # number of epochs to train in each round
    split="NONIID",   # ["RANDOM", "NONIID"]
    quantize=False,
    # `setup`, `lambda_kure` and `delta` are forwarded to Client.train
    setup="standard",    # ["standard", "scalar", "kure", "mqat"]
    lambda_kure=1e-4,
    delta=1e-2,
    # Fraction of the training split set aside as a dataset shared by all clients
    shared_fraction=0.00,
)

# Load the TensorBoard notebook extension

In [72]:

# NOTE(review): `tf` is not referenced anywhere else in the visible cells.
import tensorflow as tf
from torch.utils.tensorboard import SummaryWriter

# Provisional writer that logs into the notebook's working directory.
# NOTE(review): `writer` is rebound to a run-specific log directory in the run-setup
# cell before any scalar is written, and this first instance is never closed.
writer = SummaryWriter(log_dir=".")
# The TensorBoard magics are wrapped in a string literal, so they are not executed;
# the cell only echoes the string as its output.
"""%load_ext tensorboard
%tensorboard --logdir=."""

'%load_ext tensorboard\n%tensorboard --logdir=.'

In [None]:
# Bind the frequently used hyperparameters to module-level names in one step.
# The key tuple and the target tuple are aligned position by position.
(
    shared_fraction,
    num_clients,
    batch_size,
    classes_per_client,
    num_classes,
    lr,
    epochs,
    split,
    lambda_kure,
    delta,
    setup,
) = (
    hp[key]
    for key in (
        "shared_fraction",
        "num_clients",
        "batch_size",
        "classes_per_client",
        "num_classes",
        "learning_rate",
        "epochs",
        "split",
        "lambda_kure",
        "delta",
        "setup",
    )
)

## Obtain Dataset

In [74]:
from torchvision import datasets, transforms

# Define transformations for the dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # Normalize to range [-1, 1]
])

# Download and load MNIST dataset
mnist_full = datasets.MNIST(root="../data", train=True, transform=transform, download=True)
mnist_test = datasets.MNIST(root="../data", train=False, transform=transform, download=True)

# Split training data into training and validation sets
train_size = int(0.8 * len(mnist_full))  # 80% for training
val_size = len(mnist_full) - train_size  # 20% for validation
train_dataset, val_dataset = random_split(mnist_full, [train_size, val_size])

# Absolute positions (in mnist_full) of the 80% training portion. They must be captured
# before `train_dataset` is rebound below: the second random_split operates on a Subset,
# so the indices it returns are positions *within that Subset*, not within mnist_full.
train_base_indices = torch.tensor(train_dataset.indices, dtype=torch.long)

# Carve the optional shared dataset out of the training portion
shared_size = int(len(train_dataset) * shared_fraction)
remaining_size = len(train_dataset) - shared_size
train_shared_dataset, train_dataset = random_split(train_dataset, [shared_size, remaining_size])

# Translate the Subset-relative indices back to absolute mnist_full indices.
# Using the relative indices directly would select MNIST[0:train_size] and overlap
# with the validation samples.
train_indices = train_base_indices[
    torch.tensor(train_dataset.indices, dtype=torch.long)
].tolist()  # List of training indices

train_shared_indices = (
    train_base_indices[torch.tensor(train_shared_dataset.indices, dtype=torch.long)].tolist()
    if shared_fraction > 0.0
    else []
)  # List of shared training indices
val_indices = val_dataset.indices      # List of validation indices (already absolute)

# The three partitions must never share a sample
assert set(train_indices).isdisjoint(val_indices), "train/validation leakage"
assert set(train_shared_indices).isdisjoint(val_indices), "shared/validation leakage"
assert set(train_indices).isdisjoint(train_shared_indices), "train/shared overlap"

# Create training and validation MNIST datasets
# (fresh MNIST objects whose .data/.targets are filtered below, so that later cells
# can read `.targets` directly)
mnist_train = datasets.MNIST(root="../data", train=True, transform=transform, download=False)
mnist_train_shared = datasets.MNIST(root="../data", train=True, transform=transform, download=False)
mnist_val = datasets.MNIST(root="../data", train=True, transform=transform, download=False)

# Filter datasets by indices
train_index_tensor = torch.tensor(train_indices, dtype=torch.long)
mnist_train.data = mnist_train.data[train_index_tensor]
mnist_train.targets = mnist_train.targets[train_index_tensor]

if shared_fraction > 0.0:
    shared_index_tensor = torch.tensor(train_shared_indices, dtype=torch.long)
    mnist_train_shared.data = mnist_train_shared.data[shared_index_tensor]
    mnist_train_shared.targets = mnist_train_shared.targets[shared_index_tensor]
else:
    # No shared data requested: an empty list keeps len(mnist_train_shared) == 0
    mnist_train_shared = []

val_index_tensor = torch.tensor(val_indices, dtype=torch.long)
mnist_val.data = mnist_val.data[val_index_tensor]
mnist_val.targets = mnist_val.targets[val_index_tensor]

# Print dataset sizes
print(f"Train dataset size: {len(mnist_train)}, Train Shared dataset size: {len(mnist_train_shared)}, Validation dataset size: {len(mnist_val)}, Test dataset size: {len(mnist_test)}")

Train dataset size: 48000, Train Shared dataset size: 0, Validation dataset size: 12000, Test dataset size: 10000


# Split Dataset: non-IID / Random

### Create client train and validation datasets

In [76]:
from collections import Counter

# Nominal per-client sample counts. They are exact for the random split; in the
# NONIID branch the actual client sizes are determined by the per-class draws.
train_data_size = len(mnist_train) // num_clients
val_data_size = len(mnist_val) // num_clients

if split == "NONIID":

    # Create indices for each class
    train_class_indices = {i: np.where(np.array(mnist_train.targets) == i)[0] for i in range(num_classes)}
    val_class_indices = {i: np.where(np.array(mnist_val.targets) == i)[0] for i in range(num_classes)}

    train_indices = []  # Initialize training indices
    val_indices = []    # Initialize validation indices

    # Randomize the order of classes for validation
    shuffled_classes = np.random.permutation(num_classes)

    # Each draw takes 1/share_divisor of what is still left of a class.
    # max(1, ...) avoids a ZeroDivisionError when classes_per_client > num_clients.
    # NOTE(review): because the share is computed from the *remaining* pool, clients
    # processed later receive progressively fewer samples per class.
    share_divisor = max(1, num_clients // classes_per_client)

    for client_id in range(num_clients):
        # Choose random classes for training
        chosen_train_classes = np.random.choice(num_classes, classes_per_client, replace=False)
        train_client_idx = []

        # Assign validation classes independently
        chosen_val_classes = np.random.choice(shuffled_classes, classes_per_client, replace=False)
        val_client_idx = []

        for cls in chosen_train_classes:
            train_cls_size = len(train_class_indices[cls]) // share_divisor
            train_cls_idx = np.random.choice(train_class_indices[cls], train_cls_size, replace=False)
            train_client_idx.extend(train_cls_idx)
            train_class_indices[cls] = np.setdiff1d(train_class_indices[cls], train_cls_idx)  # Avoid duplication

        for cls in chosen_val_classes:
            val_cls_size = len(val_class_indices[cls]) // share_divisor
            val_cls_idx = np.random.choice(val_class_indices[cls], val_cls_size, replace=False)
            val_client_idx.extend(val_cls_idx)
            val_class_indices[cls] = np.setdiff1d(val_class_indices[cls], val_cls_idx)  # Avoid duplication

        train_indices.append(train_client_idx)
        val_indices.append(val_client_idx)

    # Create datasets and DataLoaders for each client
    train_datasets = [Subset(mnist_train, indices) for indices in train_indices]
    train_loaders = [DataLoader(ds, batch_size=batch_size, shuffle=True) for ds in train_datasets]

    val_datasets = [Subset(mnist_val, indices) for indices in val_indices]
    val_loaders = [DataLoader(ds, batch_size=batch_size, shuffle=True) for ds in val_datasets]


else:

    # Split the training data into smaller datasets for each client
    train_datasets = random_split(mnist_train, [train_data_size] * num_clients)
    train_loaders = [DataLoader(ds, batch_size=batch_size, shuffle=True) for ds in train_datasets]
    # Validation shards must come from the held-out validation set (not from the
    # training data), and the loaders must wrap those validation shards.
    val_datasets = random_split(mnist_val, [val_data_size] * num_clients)
    val_loaders = [DataLoader(ds, batch_size=batch_size, shuffle=True) for ds in val_datasets]

# val_loaders = DataLoader(mnist_val, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=False)


print(f"Simulated {num_clients} clients, each with {train_data_size} training samples, and {len(mnist_val)} validation samples")
# Debugging: Output distribution of classes for all clients
for heading, client_datasets in (
    ("Training Set Label Distribution:", train_datasets),
    ("\nValidation Set Label Distribution:", val_datasets),
):
    print(heading)
    for i, client_dataset in enumerate(client_datasets):
        # Load the whole client dataset as a single batch to count its labels
        client_loader = torch.utils.data.DataLoader(client_dataset, batch_size=len(client_dataset))
        client_samples, client_labels = next(iter(client_loader))
        label_counts = Counter(client_labels.tolist())
        print(f"Client {i} label distribution: {dict(label_counts)}")

Simulated 20 clients, each with 2400 training samples, and 12000 validation samples
Training Set Label Distribution:
Client 0 label distribution: {0: 1576, 5: 1441, 7: 1656, 2: 1587, 4: 1551, 9: 1600}
Client 1 label distribution: {8: 1551, 9: 1067, 7: 1104, 6: 1582, 2: 1058, 4: 1034}
Client 2 label distribution: {4: 690, 5: 961, 2: 705, 6: 1055, 0: 1051, 3: 1629}
Client 3 label distribution: {5: 640, 9: 711, 0: 700, 6: 703, 3: 1086, 7: 736}
Client 4 label distribution: {4: 460, 1: 1823, 3: 724, 6: 469, 9: 474, 0: 467}
Client 5 label distribution: {4: 306, 0: 311, 8: 1034, 7: 490, 6: 313, 9: 316}
Client 6 label distribution: {0: 208, 4: 204, 7: 327, 5: 427, 1: 1215, 3: 483}
Client 7 label distribution: {5: 285, 4: 136, 3: 322, 1: 810, 0: 138, 8: 689}
Client 8 label distribution: {7: 218, 1: 540, 6: 208, 9: 211, 0: 92, 4: 91}
Client 9 label distribution: {5: 190, 0: 62, 2: 470, 6: 139, 9: 141, 7: 145}
Client 10 label distribution: {1: 360, 2: 314, 7: 97, 3: 215, 5: 126, 4: 61}
Client 11 

# Optionally distribute the shared data among clients

In [77]:
# Append a slice of the shared dataset to every client's local training data.
# Only runs when a shared dataset was carved out (shared_fraction > 0).
if shared_fraction > 0.0:
    alpha = 0.05    # Fraction of the shared samples appended to each client

    # Load the complete shared dataset as one batch of tensors
    shared_loader = torch.utils.data.DataLoader(mnist_train_shared, batch_size=len(mnist_train_shared))
    shared_samples, shared_labels = next(iter(shared_loader))
    num_shared_samples = int(alpha * len(shared_samples))

    merged_train_datasets = []
    for client_dataset in train_datasets:
        # Load the client's complete local dataset as one batch of tensors
        client_loader = torch.utils.data.DataLoader(client_dataset, batch_size=len(client_dataset))
        client_samples, client_labels = next(iter(client_loader))
        
        # Combine shared and local data
        # NOTE(review): every client receives the same leading slice
        # shared_samples[:num_shared_samples] — confirm this is intended rather than
        # a distinct partition per client.
        combined_samples = torch.cat([client_samples, shared_samples[:num_shared_samples]])
        combined_labels = torch.cat([client_labels, shared_labels[:num_shared_samples]])
        merged_train_dataset = torch.utils.data.TensorDataset(combined_samples, combined_labels)
        merged_train_datasets.append(merged_train_dataset)

    # Only the loaders are rebuilt from the merged data; `train_datasets` itself is not modified
    train_loaders = [DataLoader(ds, batch_size=batch_size, shuffle=True) for ds in merged_train_datasets]

"""
from collections import Counter

# Display the label distribution for each client
for i, client_dataset in enumerate(merged_train_datasets):
    client_loader = torch.utils.data.DataLoader(client_dataset, batch_size=len(client_dataset))
    client_samples, client_labels = next(iter(client_loader))
    label_counts = Counter(client_labels.tolist())
    print(f"Client {i} label distribution: {dict(label_counts)}")
"""


'\nfrom collections import Counter\n\n# Display the label distribution for each client\nfor i, client_dataset in enumerate(merged_train_datasets):\n    client_loader = torch.utils.data.DataLoader(client_dataset, batch_size=len(client_dataset))\n    client_samples, client_labels = next(iter(client_loader))\n    label_counts = Counter(client_labels.tolist())\n    print(f"Client {i} label distribution: {dict(label_counts)}")\n'

## Set up and initialize the Global Model

In [78]:
# 1. Initialization: Instantiate the global model (server)
# Read the flag from the hyperparameter dict so the configuration cell is the single
# place that switches quantization on or off (it was previously hard-coded to False).
quantize = hp["quantize"]
global_model = QuantStubModel(q=quantize)
if quantize:
    # Attach the default QAT configuration for the "fbgemm" backend and prepare the
    # model for quantization-aware training in place
    global_model.qconfig = torch.quantization.get_default_qat_qconfig("fbgemm")
    torch.quantization.prepare_qat(global_model, inplace=True)

# Clients

### Client Creation
We create clients with varying resources, taking into account speed, availability, battery level, bandwidth, etc.

In [79]:
# Create client instances

# Example of a client's resources
# resources = ClientResources(
#     speed_factor=1.5,
#     battery_level=80,
#     bandwidth=10.0,
#     dataset_size=1000,
#     CPU_available=True,
#     CPU_memory_availability=64.0,
#     GPU_available=True,
#     GPU_memory_availability=16.0,
# )

# One Client per index, each wired to its own training dataset/loader and validation
# loader. Its resources are generated randomly; the local dataset size is passed
# twice, as a (size, size) tuple.
clients = []
for client_idx in range(num_clients):
    local_dataset = train_datasets[client_idx]
    local_size = len(local_dataset)
    clients.append(
        Client(
            id=client_idx,
            resources=ClientResources.generate_random((local_size, local_size)),
            dataset=local_dataset,
            dataloader=train_loaders[client_idx],
            val_loader=val_loaders[client_idx],
        )
    )


In [80]:
def client_selection_with_constraints(clients, deadline):
    """Greedily pick the fastest clients whose cumulative update time fits the deadline.

    Each client's update time is ``resources.dataset_size / resources.speed_factor``.
    Clients are visited in ascending order of that time (ties keep their input order)
    and accepted while the running total stays within ``deadline``.

    Args:
        clients: Iterable of objects exposing ``resources.dataset_size`` and
            ``resources.speed_factor``.
        deadline: Upper bound on the summed update time of the selected clients.

    Returns:
        List of the selected client objects, fastest first.
    """
    def estimated_update_time(candidate):
        res = candidate.resources
        return res.dataset_size / res.speed_factor

    # sorted() is stable, just like list.sort(), so equal times keep input order
    ranked = sorted(
        [(candidate, estimated_update_time(candidate)) for candidate in clients],
        key=lambda pair: pair[1],
    )

    chosen = []
    elapsed = 0
    for candidate, cost in ranked:
        # Stop at the first client that would push the total past the deadline
        if not (elapsed + cost <= deadline):
            break
        chosen.append(candidate)
        elapsed += cost

    return chosen


def select_indices(n, k):
    """Return ``k`` distinct indices drawn at random from ``0 .. n-1``.

    Delegates to ``random.sample``, so the result depends on the state of the
    module-level ``random`` generator.
    """
    index_pool = range(n)
    return random.sample(index_pool, k)

# Validation

In [81]:
def validate_model(global_model, val_loaders):
    """Evaluate ``global_model`` over several validation loaders.

    The model is switched to evaluation mode and run without gradient tracking.
    Loss and correct-prediction counts are accumulated per loader and then pooled,
    so the result is the sample-weighted average over all loaders.

    Args:
        global_model: Model called as ``global_model(inputs)`` to produce class scores.
        val_loaders: Iterable of loaders, each yielding ``(inputs, labels)`` batches.

    Returns:
        Tuple ``(avg_loss, avg_accuracy)``; both are ``0.0`` when no sample was seen.
    """
    global_model.eval()
    criterion = torch.nn.CrossEntropyLoss()

    loss_sum, hits, seen = 0.0, 0, 0

    with torch.no_grad():
        for loader in val_loaders:
            loader_loss, loader_hits, loader_seen = 0.0, 0, 0

            for inputs, labels in loader:
                scores = global_model(inputs)
                batch_size_now = len(labels)
                # criterion returns the batch mean; scale back to a per-batch sum
                loader_loss += criterion(scores, labels).item() * batch_size_now
                predicted = torch.max(scores, 1)[1]
                loader_hits += (predicted == labels).sum().item()
                loader_seen += batch_size_now

            # Fold this loader's totals into the global totals
            loss_sum += loader_loss
            hits += loader_hits
            seen += loader_seen

    if seen > 0:
        return loss_sum / seen, hits / seen
    return 0.0, 0.0


## Run Federated Learning Round

In [82]:
# Wrapper function for client training
def train_client_parallel(client, global_model, epochs=1, lr=0.001, quantize=False, lambda_kure=0.0, delta=0.0, setup='standard'):
    """Train one client and tag the result with the client's id.

    All training arguments are forwarded positionally to ``client.train`` in the
    order ``(global_model, epochs, lr, quantize, lambda_kure, delta, setup)``.

    Returns:
        Tuple ``(client.id, <whatever client.train returns>)``.
    """
    client_id = client.id
    training_result = client.train(
        global_model,
        epochs,
        lr,
        quantize,
        lambda_kure,
        delta,
        setup,
    )
    return client_id, training_result

In [83]:

from concurrent.futures import ThreadPoolExecutor, as_completed

# Generate a unique log directory based on the current time
# The timestamp has minute resolution ('%m-%d-%H-%M'), so two runs with the same
# run_id/run_number started within the same minute write to the same directory.
run_number = hp["run_number"]
log_dir = f"./logs/{hp['run_id']}/run_{run_number}-{datetime.now().strftime('%m-%d-%H-%M')}"
# Rebinds `writer` to the run-specific directory; all metrics below go through it
writer = SummaryWriter(log_dir=log_dir)
# One "key: value" line per hyperparameter.
# NOTE(review): `hyperparams_text` and `start_time` are not read in the visible cells.
hyperparams_text = "\n".join([f"{key}: {value}" for key, value in hp.items()])
start_time = time.time()

# Federated Learning with FedCS Client Selection
num_rounds = hp["num_rounds"]
epochs = hp["epochs"]
# An infinite deadline means the selection step never rejects a client for time
round_deadline = np.inf  # Example round deadline (in ARBITRARY time units)
#assert num_clients == len(train_loaders)
print("GPU available:", torch.cuda.is_available())

GPU available: True


In [84]:
# Conduct federated learning rounds

# Every client takes part in validation each round, so the loader list is
# loop-invariant and is built once up front.
val_loaders = [client.val_loader for client in clients]  # Get all validation loaders

for round_num in range(num_rounds):
    step = round_num + 1  # 1-based round index used for printing and TensorBoard
    print(f"Round {step}")

    # 2. Resource Request: ask a random, non-empty subset of clients
    k = random.randint(1, num_clients)
    resource_requested_clients = select_indices(num_clients, k)

    # 3. Client Selection: keep the requested clients that fit into the round deadline
    selected_train_clients = client_selection_with_constraints(
        [clients[i] for i in resource_requested_clients], round_deadline
    )
    print(f"-> Resource requested from {len(resource_requested_clients)} clients, {len(selected_train_clients)} clients fulfilled the criteria")

    num_selected = len(selected_train_clients)
    if num_selected == 0:
        # ThreadPoolExecutor rejects max_workers=0 and there would be nothing to
        # aggregate, so leave the global model untouched for this round.
        print(f"-> No client fulfilled the criteria in round {step}; skipping the round")
        continue

    client_states = []
    total_loss = 0
    # Parallelize client training (one worker thread per selected client)
    with ThreadPoolExecutor(max_workers=num_selected) as executor:
        futures = {
            executor.submit(
                train_client_parallel,
                client,
                global_model,
                epochs=epochs,  # honour the configured number of local epochs
                lr=lr,
                quantize=quantize,
                lambda_kure=lambda_kure,
                delta=delta,
                setup=setup
                ): client
            for client in selected_train_clients
        }

        # Collect results in completion order; future.result() re-raises any
        # exception that occurred inside a client's training
        for future in as_completed(futures):
            client_id, (client_state, client_loss) = future.result()
            # print(f"Client {client_id} completed training.")
            client_states.append(client_state)
            total_loss += client_loss

    # Compute average training loss over the clients trained in this round
    avg_round_training_loss = total_loss / num_selected
    writer.add_scalar("Metrics/Training loss", avg_round_training_loss, step)
    print(f"Training loss after round {step}: {avg_round_training_loss}")

    # 6. Aggregation: Aggregate updates using Federated Averaging
    new_global_state = fl.federated_averaging(global_model, client_states)
    global_model.load_state_dict(new_global_state)
    print(f"Global model updated for round {step}")

    # 7. Evaluate on Validation Set (all clients' validation loaders)
    val_loss, val_accuracy = validate_model(global_model, val_loaders)
    print(f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.2%}")

    # Optional: Log metrics for visualization
    writer.add_scalar("Metrics/Validation Loss", val_loss, step)
    writer.add_scalar("Metrics/Validation Accuracy", val_accuracy, step)

# After the last round, convert the QAT-prepared model in place
if quantize:
    torch.quantization.convert(global_model, inplace=True)


Round 1
-> Resource requested from 8 clients, 8 clients fulfilled the criteria
Training loss after round 1: 1.298714860438536
Global model updated for round 1
Validation Loss: 2.1863, Validation Accuracy: 33.46%
Round 2
-> Resource requested from 14 clients, 14 clients fulfilled the criteria
Training loss after round 2: 1.0192292607046303
Global model updated for round 2
Validation Loss: 1.5342, Validation Accuracy: 41.06%
Round 3
-> Resource requested from 6 clients, 6 clients fulfilled the criteria
Training loss after round 3: 0.5667829816862755
Global model updated for round 3
Validation Loss: 1.6853, Validation Accuracy: 41.90%
Round 4
-> Resource requested from 13 clients, 13 clients fulfilled the criteria
Training loss after round 4: 0.5777835205918533
Global model updated for round 4
Validation Loss: 1.0263, Validation Accuracy: 63.77%
Round 5
-> Resource requested from 6 clients, 6 clients fulfilled the criteria
Training loss after round 5: 0.5670084149127609
Global model updat

## Evaluate Model

In [85]:
def evaluate_model(model, dataloader):
    """Compute and print the classification accuracy of ``model`` on ``dataloader``.

    The model is switched to evaluation mode and run without gradient tracking.

    Args:
        model: Model called as ``model(inputs)`` to produce per-class scores.
        dataloader: Iterable yielding ``(inputs, labels)`` batches.

    Returns:
        Fraction of correctly classified samples, or ``0.0`` when the dataloader
        yields no samples (instead of raising ZeroDivisionError).
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in dataloader:
            outputs = model(inputs)
            # Index of the highest score along the class dimension
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # Same empty-input guard as validate_model uses
    test_accuracy = correct / total if total > 0 else 0.0
    print(f"Global model accuracy: {test_accuracy:.2%}")
    return test_accuracy
    
# Evaluate the model on the test dataset
test_accuracy = evaluate_model(global_model, test_loader)

# Log the hyperparameters together with the final test accuracy, then close the
# TensorBoard writer. The metric dict was previously empty, so the computed test
# accuracy never reached the hparams record.
final_metrics = {"hparam/test_accuracy": test_accuracy}
# Keep only value types that can be logged as hyperparameters
hp_cleaned = {k: v for k, v in hp.items() if isinstance(v, (int, float, str, bool, torch.Tensor))}
writer.add_hparams(hp_cleaned, final_metrics)
writer.close()

Global model accuracy: 76.72%


In [86]:
# Save the final model
model_save_path = f"./saved_models/global_model_{hp['run_id']}_{run_number}.pth"
# torch.save does not create missing parent directories, so make sure the target
# folder exists (a fresh checkout has no ./saved_models directory)
os.makedirs(os.path.dirname(model_save_path), exist_ok=True)
torch.save(global_model.state_dict(), model_save_path)
print(f"Global model saved at {model_save_path}")

Global model saved at ./saved_models/global_model_test_noname.pth
