# Simulating heterogeneous clients in federated learning

In [12]:
from torch.utils.data import DataLoader, random_split, TensorDataset
import torch.nn as nn
import torch.optim as optim
import time
from itertools import groupby

import sys
import os

# Get the absolute path of the src directory
src_path = os.path.abspath('../src')
# Add src_path to sys.path
if src_path not in sys.path:
    sys.path.append(src_path)
    
import fl
from client import Client, ClientResources

## Obtain Dataset

In [13]:
# Fetch MNIST (downloaded on first use) and attach the shared preprocessing
from torchvision import datasets, transforms

# Preprocessing pipeline: convert each image to a tensor, then normalize with
# mean 0.5 / std 0.5 so values end up in the range [-1, 1]
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)

# Build the train split first and the test split second; both share the same
# root directory and transform
mnist_train, mnist_test = (
    datasets.MNIST(root="../data", train=is_train, transform=transform, download=True)
    for is_train in (True, False)
)

# Report how many samples each split holds
print(f"Train dataset size: {len(mnist_train)}, Test dataset size: {len(mnist_test)}")

Train dataset size: 60000, Test dataset size: 10000


## Split Dataset

In [14]:
from torch.utils.data import random_split

# Define the number of clients and the per-client shard size
num_clients = 5
client_data_size, leftover_samples = divmod(len(mnist_train), num_clients)

# random_split requires the requested lengths to sum to len(dataset). When the
# dataset does not divide evenly, the leftover samples are placed in one extra
# shard that is discarded, so every client holds exactly client_data_size samples.
split_lengths = [client_data_size] * num_clients
if leftover_samples:
    split_lengths.append(leftover_samples)

# Split the training data into equally sized random shards, one per client.
# NOTE: no generator is passed, so the partition depends on the global torch
# RNG state and is not fixed across kernel restarts.
client_datasets = random_split(mnist_train, split_lengths)[:num_clients]

# Create DataLoaders for each client
client_loaders = [DataLoader(ds, batch_size=32, shuffle=True) for ds in client_datasets]

# Test DataLoader for evaluation
test_loader = DataLoader(mnist_test, batch_size=32, shuffle=False)

print(f"Simulated {num_clients} clients, each with {client_data_size} training samples.")


Simulated 5 clients, each with 12000 training samples.


### Generate Dummy Clients with different computational speeds

## Set up and initialize the Global Model

In [15]:
# Build the server-side (global) model; `model` stays bound to the same object
global_model = model = fl.create_model()
print(global_model)

Sequential(
  (0): Flatten(start_dim=1, end_dim=-1)
  (1): Linear(in_features=784, out_features=128, bias=True)
  (2): ReLU()
  (3): Linear(in_features=128, out_features=10, bias=True)
)


In [16]:
# Create client instances
client_speeds = [1.0, 1.5, 2, 1.5, 1.75]  # Simulated speed factors, one per client

# The speeds are later zipped with client_loaders; zip() would silently drop
# clients if the two lists disagreed in length, so fail loudly here instead.
assert len(client_speeds) == num_clients, "client_speeds must have one entry per client"

# Example of building a ClientResources by hand. It is not referenced by the
# client construction below, which draws random resources per client instead.
resources = ClientResources(
    speed_factor=1.5,
    battery_level=80,
    bandwidth=10.0,
    dataset_size=1000,
    CPU_available=True,
    CPU_memory_availability=64.0,
    GPU_available=True,
    GPU_memory_availability=16.0,
)

# Reuse the partition produced in the "Split Dataset" cell rather than drawing
# a second random_split: re-splitting would leave client_loaders (built from
# the first partition) pointing at different samples than the clients below.
clients = []
for i, client_dataset in enumerate(client_datasets):
    shard_size = len(client_dataset)
    # NOTE(review): the (shard_size, shard_size) tuple is presumably a
    # (min, max) range for dataset_size -- confirm against
    # ClientResources.generate_random.
    mock_resources = ClientResources.generate_random((shard_size, shard_size))
    clients.append(Client(id=i, resources=mock_resources, dataset=client_dataset))


In [17]:
# from multiprocessing import Pool

# def parallel_train(client):
#     speed, client_loader = client
#     return train_client(global_model, client_loader, speed_factor=speed)

# with Pool(processes=len(client_loaders)) as pool:
#     client_states = pool.map(parallel_train, zip(client_speeds, client_loaders))

In [19]:
from concurrent.futures import ThreadPoolExecutor, as_completed

# Thread-pool entry point: train a single client and tag the result with its id
def train_client_parallel(client, global_model, epochs=1):
    """Train one client against the global model and return ``(client.id, result)``.

    Args:
        client: Object exposing ``id``, ``resources`` and ``train(model, epochs)``.
        global_model: Model handed to ``client.train``.
        epochs: Forwarded positionally to ``client.train`` (default 1).

    Returns:
        A tuple of the client's id and whatever ``client.train`` returned.
    """
    print(f"\nTraining client {client.id} with resources {client.resources}")
    owner_id = client.id
    training_result = client.train(global_model, epochs)
    return owner_id, training_result

# How many federated rounds to simulate
num_rounds = 3

# Placeholder for per-round snapshots of the global weights (nothing is
# appended while the snapshot line below stays commented out)
global_model_states = []

for round_num in range(num_rounds):
    print(f"\n=== Federated Learning Round {round_num + 1} ===")

    # Train every client concurrently, one worker thread per client.
    # NOTE(review): all workers receive the same global_model object; this is
    # only safe if Client.train does not modify it in place -- confirm.
    client_states = []
    with ThreadPoolExecutor(max_workers=len(clients)) as pool:
        pending = [
            pool.submit(train_client_parallel, participant, global_model, 1)
            for participant in clients
        ]

        # Results are gathered in completion order, not submission order
        for finished in as_completed(pending):
            finished_id, finished_state = finished.result()
            print(f"Client {finished_id} completed training.")
            client_states.append(finished_state)

    # Combine the client updates with Federated Averaging and install the result
    new_global_state = fl.federated_averaging(global_model, client_states)
    global_model.load_state_dict(new_global_state)

    # Optional per-round snapshot (would additionally need `import copy`):
    # global_model_states.append(copy.deepcopy(global_model.state_dict()))

    # No evaluation is run here yet; only progress is reported
    print(f"Global model updated after round {round_num + 1}")



=== Federated Learning Round 1 ===

Training client 0 with resources ClientResources(speed_factor=1.7160021044797742, battery_level=50.704235090809576, bandwidth=51.33546357231246, dataset_size=12000, CPU_available=True, CPU_memory_availability=119.05571570773346, GPU_available=True, GPU_memory_availability=31.669893043095335)

Training client 1 with resources ClientResources(speed_factor=1.6969267879880643, battery_level=94.14238177520939, bandwidth=39.927809165515164, dataset_size=12000, CPU_available=True, CPU_memory_availability=58.40438549768966, GPU_available=False, GPU_memory_availability=0)

Training client 2 with resources ClientResources(speed_factor=1.9723324402252151, battery_level=21.132345241248185, bandwidth=85.70204887906986, dataset_size=12000, CPU_available=False, CPU_memory_availability=53.7062699297276, GPU_available=False, GPU_memory_availability=0)

Training client 3 with resources ClientResources(speed_factor=1.6215943869529335, battery_level=78.85237043272897, 

In [None]:
# Sequential baseline: train each client in turn and gather the returned updates
client_states = []
for client in clients:
    print(f"\nTraining client {client.id} with resources {client.resources}")
    client_states.append(client.train(global_model, epochs=1))


Training client 0 with resources ClientResources(speed_factor=1.970175601083471, battery_level=44.412161174949844, bandwidth=42.17689437834442, dataset_size=30000, CPU_available=False, CPU_memory_availability=9.070395434254579, GPU_available=False, GPU_memory_availability=0)
Training round complete in 5.67: seconds
Client simulated to take 11.18 seconds for training

Training client 1 with resources ClientResources(speed_factor=1.689800106109189, battery_level=37.42924574833511, bandwidth=56.57171983565547, dataset_size=30000, CPU_available=True, CPU_memory_availability=103.28883643824332, GPU_available=True, GPU_memory_availability=4.624544318161185)


KeyboardInterrupt: 

In [None]:
# Bucket clients into batches by simulated speed

# Order the (speed, loader) pairs from the smallest speed factor to the largest;
# only the speed is compared, so the DataLoader objects never need to be orderable
sorted_clients = sorted(zip(client_speeds, client_loaders), key=lambda pair: pair[0])

# Consecutive pairs whose speed falls in the same 0.5-wide window share a batch
# (the groupby key is the window index `speed // 0.5`, not the speed itself)
batches = [
    list(window_members)
    for _, window_members in groupby(sorted_clients, key=lambda pair: pair[0] // 0.5)
]

batches


[[(1.0, <torch.utils.data.dataloader.DataLoader at 0x7fc7c599b0a0>)],
 [(1.5, <torch.utils.data.dataloader.DataLoader at 0x7fc7c599b0d0>),
  (1.5, <torch.utils.data.dataloader.DataLoader at 0x7fc7a8bb1210>),
  (1.75, <torch.utils.data.dataloader.DataLoader at 0x7fc7a8bb0f70>)],
 [(2, <torch.utils.data.dataloader.DataLoader at 0x7fc7c59988b0>)]]

In [None]:
# Show which (speed, loader) pairs landed in each speed batch
for speed_batch in batches:
    print(speed_batch)
    # Disabled sketch of per-batch training and aggregation:
    # client_states = []
    # for _, client_loader in speed_batch:
    #     client_state = train_client(global_model, client_loader, epochs=1)
    #     client_states.append(client_state)
    
    # # Aggregate updates for the current batch
    # new_global_state = federated_averaging(global_model, client_states)
    # global_model.load_state_dict(new_global_state)


[(1.0, <torch.utils.data.dataloader.DataLoader object at 0x7fc7c599b0a0>)]
[(1.5, <torch.utils.data.dataloader.DataLoader object at 0x7fc7c599b0d0>), (1.5, <torch.utils.data.dataloader.DataLoader object at 0x7fc7a8bb1210>), (1.75, <torch.utils.data.dataloader.DataLoader object at 0x7fc7a8bb0f70>)]
[(2, <torch.utils.data.dataloader.DataLoader object at 0x7fc7c59988b0>)]
