## This file is a WIP

In [4]:
import copy
import os
import sys
import time
from itertools import groupby

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split, TensorDataset

# Get the absolute path of the src directory
src_path = os.path.abspath('../src')
# Add src_path to sys.path
if src_path not in sys.path:
    sys.path.append(src_path)
    
import fl
from client import Client

## Obtain Dataset

In [5]:
# Fetch MNIST (downloaded on first run) and attach the preprocessing pipeline
from torchvision import datasets, transforms

# Preprocessing: convert each image to a tensor, then normalize to the range [-1, 1]
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)

# Build the training split first, then the held-out test split, from the same settings
mnist_train, mnist_test = (
    datasets.MNIST(root="../data", train=is_train, transform=transform, download=True)
    for is_train in (True, False)
)

# Report how many samples each split contains
print(f"Train dataset size: {len(mnist_train)}, Test dataset size: {len(mnist_test)}")

Train dataset size: 60000, Test dataset size: 10000


## Split Dataset

In [6]:
from torch.utils.data import random_split

# Define the number of clients and the number of samples each one receives
num_clients = 5
client_data_size = len(mnist_train) // num_clients

# random_split requires the lengths to sum to len(dataset). When the dataset does
# not divide evenly, park the leftover samples in one extra split that is discarded,
# so every client still holds exactly client_data_size samples.
split_sizes = [client_data_size] * num_clients
leftover = len(mnist_train) - client_data_size * num_clients
if leftover:
    split_sizes.append(leftover)

# Seed the split so that Restart & Run All assigns the same samples to each client
SPLIT_SEED = 42
client_datasets = random_split(
    mnist_train,
    split_sizes,
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)[:num_clients]

# Create DataLoaders for each client
client_loaders = [DataLoader(ds, batch_size=32, shuffle=True) for ds in client_datasets]

# Test DataLoader for evaluation
test_loader = DataLoader(mnist_test, batch_size=32, shuffle=False)

print(f"Simulated {num_clients} clients, each with {client_data_size} training samples.")


Simulated 5 clients, each with 12000 training samples.


### Generate Dummy Clients with different computational speeds

In [7]:
# class Client:
#     def __init__(self, id, speed_factor, dataset, batch_size=32):
#         """
#         Initialize a Client object.

#         Parameters:
#         - id (int): The ID of the client.
#         - speed_factor (float): The speed factor of the client, which determines the training delay.
#         - dataset (torch.utils.data.Dataset): The dataset used for training.
#         - batch_size (int, optional): The batch size for the dataloader. Default is 32.
#         """
#         self.id = id
#         self.speed_factor = speed_factor
#         self.dataset = dataset
#         self.dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

#     def train(self, global_model, epochs=1):
#         """
#         Train the global model on the client's local dataset using Adam optimizer.

#         Parameters:
#         - global_model (torch.nn.Module): The global model to be trained.
#         - epochs (int, optional): The number of training epochs. Default is 1.

#         Returns:
#         - state_dict (dict): The updated model parameters.
#         """
#         # Directly copy the global model
#         local_model = global_model
#         local_model.load_state_dict(global_model.state_dict())
        
#         # Define the loss function and optimizer
#         criterion = nn.CrossEntropyLoss()
#         optimizer = optim.Adam(
#             local_model.parameters(), 
#             lr=0.001, 
#             # learning hyperparameters can be set later
#             # betas=(0.9, 0.99), 
#             # eps=1e-7, 
#             # weight_decay=1e-4
#         )
        
#         # Simulate training delay based on speed_factor
#         local_model.train()
#         for epoch in range(int(epochs * self.speed_factor)):
#             for inputs, labels in self.dataloader:
#                 optimizer.zero_grad()
#                 outputs = local_model(inputs)
#                 loss = criterion(outputs, labels)
#                 loss.backward()
#                 optimizer.step()

#         # Return updated model parameters
#         return local_model.state_dict()
    

## Set up and initialize the Global Model

In [8]:
# Build the server-side network once; `model` and `global_model` name the same object
global_model = model = fl.create_model()
print(global_model)

Sequential(
  (0): Flatten(start_dim=1, end_dim=-1)
  (1): Linear(in_features=784, out_features=128, bias=True)
  (2): ReLU()
  (3): Linear(in_features=128, out_features=10, bias=True)
)


In [9]:
# Create client instances
client_speeds = [1.0, 0.5, 0.25, 1.5, 0.75]  # Simulated speed factors

# Reuse the partitions produced in the "Split Dataset" cell instead of drawing a
# second random split: client i must own the same samples as client_loaders[i],
# because the batching step later pairs client_speeds[i] with client_loaders[i].
assert len(client_speeds) == len(client_datasets), (
    "need exactly one speed factor per client dataset"
)

clients = [
    Client(id=client_id, speed_factor=speed, dataset=dataset)
    for client_id, (speed, dataset) in enumerate(zip(client_speeds, client_datasets))
]


In [10]:
# Train clients and collect their updates
client_states = []
for client in clients:
    print(f"Training client {client.id} with speed {client.speed_factor}")
    # Hand every client its own copy of the global model so that each one starts
    # from the same global weights and the collected states do not alias each other.
    # NOTE(review): the commented-out reference Client.train in this notebook trains
    # the model it receives in place; presumably client.Client does the same -- confirm.
    # The copy is harmless if Client.train already copies internally.
    client_state = client.train(copy.deepcopy(global_model), epochs=1)
    client_states.append(client_state)

Training client 0 with speed 1.0
Training client 1 with speed 0.5
Training client 2 with speed 0.25
Training client 3 with speed 1.5
Training client 4 with speed 0.75


In [11]:
# Local training routine for one simulated client

def train_client(model, dataloader, speed_factor, epochs=1):
    """
    Train a private copy of `model` on one client's data and return its weights.

    Parameters:
    - model (torch.nn.Module): The model to start from. It is deep-copied, so the
      caller's instance is not modified.
    - dataloader (iterable): Yields (inputs, labels) batches for this client.
    - speed_factor (float): Scales the amount of local work; the number of local
      epochs is int(epochs * speed_factor), so products below 1 yield zero epochs.
    - epochs (int, optional): Base number of training epochs. Default is 1.

    Returns:
    - state_dict (dict): Parameters of the locally trained copy.
    """
    # perf_counter is monotonic, so the elapsed time cannot be skewed by clock changes
    start_time = time.perf_counter()

    # Train a copy so that several clients can start from the same global weights
    local_model = copy.deepcopy(model)

    # Same loss and optimizer settings as the reference Client.train in this notebook
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(local_model.parameters(), lr=0.001)

    local_model.train()
    for _ in range(int(epochs * speed_factor)):
        for inputs, labels in dataloader:
            optimizer.zero_grad()
            loss = criterion(local_model(inputs), labels)
            loss.backward()
            optimizer.step()

    elapsed = time.perf_counter() - start_time
    print(f"Client trained in {elapsed:.2f} seconds")
    return local_model.state_dict()  # Return the trained copy's state


In [12]:
# Partition the clients into batches of similar speed

# Pair each speed factor with its loader and order the pairs by ascending speed factor
sorted_clients = sorted(zip(client_speeds, client_loaders), key=lambda pair: pair[0])

# Consecutive pairs whose speed factor falls in the same 0.5-wide window form one batch
batches = [
    list(window_members)
    for _, window_members in groupby(sorted_clients, key=lambda pair: pair[0] // 0.5)
]

batches


[[(0.25, <torch.utils.data.dataloader.DataLoader at 0x7fcd2c4c0400>)],
 [(0.5, <torch.utils.data.dataloader.DataLoader at 0x7fcd4299f970>),
  (0.75, <torch.utils.data.dataloader.DataLoader at 0x7fcd25130430>)],
 [(1.0, <torch.utils.data.dataloader.DataLoader at 0x7fcd3d09a650>)],
 [(1.5, <torch.utils.data.dataloader.DataLoader at 0x7fcd25131b70>)]]

In [13]:
# Show the members of each speed batch; per-batch training is not wired up yet.
for batch in batches:
    print(batch)
    # TODO: train the clients of this batch and aggregate their updates.
    # NOTE(review): the sketch below calls train_client without the speed_factor
    # argument that its signature requires, and federated_averaging is not defined
    # in the cells shown here -- confirm where it lives before enabling this.
    # client_states = []
    # for _, client_loader in batch:
    #     client_state = train_client(global_model, client_loader, epochs=1)
    #     client_states.append(client_state)
    
    # # Aggregate updates for the current batch
    # new_global_state = federated_averaging(global_model, client_states)
    # global_model.load_state_dict(new_global_state)


[(0.25, <torch.utils.data.dataloader.DataLoader object at 0x7fcd2c4c0400>)]
[(0.5, <torch.utils.data.dataloader.DataLoader object at 0x7fcd4299f970>), (0.75, <torch.utils.data.dataloader.DataLoader object at 0x7fcd25130430>)]
[(1.0, <torch.utils.data.dataloader.DataLoader object at 0x7fcd3d09a650>)]
[(1.5, <torch.utils.data.dataloader.DataLoader object at 0x7fcd25131b70>)]
