In [1]:
# Enable IPython's autoreload extension. Mode 2 reloads all modules (including the project
# modules model_utils, data_utils and federated_utils imported in the next cell) before each
# cell executes, so edits to those files take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
from model_utils import load_model, save_model, set_model_weights, visualize_federated_learning_performance
from data_utils import load_client_data, load_dataset, load_dataloader
from federated_utils import federated_averaging, train_on_clients
import copy
import os

In [3]:
# FashionMNIST label ids to keep: 0 = T-shirt/top, 1 = Trouser, 3 = Dress.
selected_classes = [0, 1, 3]
train_dataloader, val_dataloader = load_dataset(
    classes=selected_classes,
    samples_per_class=1000,
)

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ./dataset/FashionMNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 26421880/26421880 [00:02<00:00, 10753820.19it/s]


Extracting ./dataset/FashionMNIST/raw/train-images-idx3-ubyte.gz to ./dataset/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ./dataset/FashionMNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 29515/29515 [00:00<00:00, 1005685.71it/s]

Extracting ./dataset/FashionMNIST/raw/train-labels-idx1-ubyte.gz to ./dataset/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ./dataset/FashionMNIST/raw/t10k-images-idx3-ubyte.gz



100%|██████████| 4422102/4422102 [00:00<00:00, 8044869.67it/s]


Extracting ./dataset/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to ./dataset/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ./dataset/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 5148/5148 [00:00<00:00, 7733623.56it/s]


Extracting ./dataset/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to ./dataset/FashionMNIST/raw



In [4]:
# Total number of training samples: add up the size of every (inputs, labels) batch.
samples_count = sum(len(batch_inputs) for batch_inputs, _labels in train_dataloader)
print(f"Loaded {samples_count} training examples")

Loaded 3000 training examples


In [None]:
def run_federated_learning_experiment(num_clients, num_rounds, local_epochs, model_path, save_path_prefix,
                                      ml_aggregation_method, classes=None, samples_per_class=1000,
                                      samples_count=None):
    """Run one federated-learning experiment and save the final global model.

    The initial model is loaded from ``model_path`` and deep-copied to every client. In each
    round all clients train locally via ``train_on_clients``, their weights are aggregated,
    and the aggregated weights are written into the global model and back into every client.

    Args:
        num_clients: Number of simulated clients (must be >= 1).
        num_rounds: Number of federated (global) rounds.
        local_epochs: Local epochs per round, forwarded to ``train_on_clients``.
        model_path: Path of the initial model, passed to ``load_model``.
        save_path_prefix: Output prefix; the final model is saved to
            ``f"{save_path_prefix}_clients{num_clients}_rounds{num_rounds}_epochs{local_epochs}.h5"``.
        ml_aggregation_method: Aggregation strategy; only ``"fedavg"`` is supported.
        classes: Class ids passed to ``load_dataset``; defaults to ``[0, 1, 3]``
            (T-shirt/top, Trouser, Dress).
        samples_per_class: Forwarded to ``load_dataset``.
        samples_count: Total number of training samples to split evenly across clients.
            When None, it is counted from the training data loaded by this function.

    Raises:
        NotImplementedError: If ``ml_aggregation_method`` is not ``"fedavg"``.
        ValueError: If ``num_clients < 1`` or there are fewer training samples than clients.
    """
    # Validate cheap arguments first so a typo fails before any model/data loading or training.
    if ml_aggregation_method != "fedavg":
        raise NotImplementedError(f"Aggregation method {ml_aggregation_method} not implemented")
    if num_clients < 1:
        raise ValueError(f"num_clients must be >= 1, got {num_clients}")
    if classes is None:
        classes = [0, 1, 3]  # 0 = T-shirt/top, 1 = Trouser, 3 = Dress

    # Load initial model and data
    global_model = load_model(model_path)
    global_train_dataset, global_val_dataset = load_dataset(classes=list(classes),
                                                            samples_per_class=samples_per_class)

    if samples_count is None:
        # Count locally instead of reading a notebook-global variable (hidden state on re-run).
        # Assumes iteration yields (inputs, labels) batches, as the notebook's sample-count cell does.
        samples_count = sum(len(inputs) for inputs, _ in global_train_dataset)
    samples_per_node = samples_count // num_clients
    if samples_per_node < 1:
        raise ValueError(f"Not enough training samples ({samples_count}) for {num_clients} clients")

    # Every client starts from an independent copy of the initial global model.
    client_models = [copy.deepcopy(global_model) for _ in range(num_clients)]
    # NOTE(review): presumably load_dataloader selects a per-client shard by client_id — confirm in data_utils.
    client_data = [
        load_dataloader(global_train_dataset, global_val_dataset, client_id=i, num_samples=samples_per_node)
        for i in range(num_clients)
    ]

    for round_num in range(num_rounds):
        print(f"Round {round_num + 1}/{num_rounds}")
        train_on_clients(client_models, client_data, local_epochs)

        averaged_weights = federated_averaging(client_models)
        set_model_weights(global_model, averaged_weights)

        # TODO: evaluate the aggregated global model each round (decide which KPIs to track).

        # Broadcast the aggregated weights so the next round starts from the global model.
        for client_model in client_models:
            set_model_weights(client_model, averaged_weights)

    # Save the final model
    save_path = f"{save_path_prefix}_clients{num_clients}_rounds{num_rounds}_epochs{local_epochs}.h5"
    save_model(global_model, save_path)
    print(f"Saved model to {save_path}")

# config

In [None]:
# Experiment sweep grids. The comments give inclusive bounds; range() excludes its stop value.
node_range = range(2, 10 + 1)         # number of federated clients: 2..10
local_epoch_range = range(1, 10 + 1)  # local training epochs per round: 1..10
global_epoch_range = range(1, 5 + 1)  # global (federated) epochs, presumably used as num_rounds: 1..5

In [6]:

# Paths for the initial model and for the models saved by each experiment.
model_folder = "./models"
# TODO: placeholder — replace with the real initial-model file name before running experiments.
model_filename = "your_model_file"
model_path = os.path.join(model_folder, model_filename)
save_path_prefix = "./saved_models/model"

# Ensure the save directory exists. `or "."` covers a prefix with no directory component:
# os.path.dirname then returns "" and os.makedirs("") would raise FileNotFoundError.
os.makedirs(os.path.dirname(save_path_prefix) or ".", exist_ok=True)

# run experiments
# NOTE(review): this cell does not yet call run_federated_learning_experiment for the sweep grids.