In [None]:
# Conduct federated learning rounds.
# Each round: request resources from a random subset of clients, keep those that
# satisfy the round constraints, train them in parallel, FedAvg their weights
# into the global model, then validate on every client's validation loader.

# The set of clients is fixed, so their validation loaders only need gathering once.
val_loaders = [client.val_loader for client in clients]

for round_num in range(num_rounds):
    print(f"Round {round_num + 1}")

    # 2. Resource Request: ask a random, non-empty subset of clients.
    k = random.randint(1, num_clients)
    resource_requested_clients = random.sample(range(num_clients), k)

    # 3. Client Selection: keep only clients that meet the round constraints.
    selected_train_clients = client_selection_with_constraints(
        [clients[i] for i in resource_requested_clients], round_deadline
    )
    print(f"-> Resource requested from {len(resource_requested_clients)} clients, {len(selected_train_clients)} clients fulfilled the criteria")

    if selected_train_clients:
        client_states = []
        total_loss = 0.0
        num_trained_clients = len(selected_train_clients)

        # Local training, one worker thread per selected client.
        with ThreadPoolExecutor(max_workers=num_trained_clients) as executor:
            futures = [
                executor.submit(
                    train_client_parallel,
                    client,
                    global_model,
                    epochs=1,
                    lr=lr,
                    quantize=quantize,
                    lambda_kure=lambda_kure,
                    delta=delta,
                    setup=setup,
                )
                for client in selected_train_clients
            ]

            for future in as_completed(futures):
                client_id, (client_state, client_loss) = future.result()
                client_states.append(client_state)
                # float() accepts a Python number as well as a 0-d tensor, and
                # drops any autograd graph the reported loss may still carry.
                total_loss += float(client_loss)

        # Mean of the per-client losses reported this round.
        avg_round_training_loss = total_loss / num_trained_clients
        writer.add_scalar("Metrics/Training loss", avg_round_training_loss, round_num + 1)
        print(f"Training loss after round {round_num + 1}: {avg_round_training_loss}")

        # 6. Aggregation: Aggregate updates using Federated Averaging
        new_global_state = fl.federated_averaging(global_model, client_states)
        global_model.load_state_dict(new_global_state)
        print(f"Global model updated for round {round_num + 1}")
    else:
        # ThreadPoolExecutor rejects max_workers=0 and there is nothing to
        # average, so the global model is carried over unchanged.
        print(f"-> No clients selected in round {round_num + 1}; global model unchanged")

    # 7. Evaluate on Validation Set
    val_loss, val_accuracy = validate_model(global_model, val_loaders)
    print(f"Validation Loss after round {round_num + 1}: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.2%}")

    # Log metrics for visualization
    writer.add_scalar("Metrics/Validation Loss", val_loss, round_num + 1)
    writer.add_scalar("Metrics/Validation Accuracy", val_accuracy, round_num + 1)

if quantize:
    # The PyTorch QAT workflow converts the model in eval mode.
    global_model.eval()
    torch.quantization.convert(global_model, inplace=True)


In [None]:
# Thread-pool entry point for one client's local training.
def train_client_parallel(client, global_model, epochs=1, lr=0.001, quantize=False, lambda_kure=0.0, delta=0.0, setup='standard'):
    """Run ``client.train`` and tag its result with the client's id.

    All training arguments are forwarded positionally, unchanged, to
    ``client.train``.

    Returns:
        tuple: ``(client.id, result_of_client_train)``, so the caller can tell
        which client a completed future belongs to.
    """
    client_id = client.id
    training_result = client.train(
        global_model,
        epochs,
        lr,
        quantize,
        lambda_kure,
        delta,
        setup,
    )
    return client_id, training_result

In [None]:
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import time

from dataclasses import dataclass, field
import random
import fl
from models import *



@dataclass
class ClientResources:
    """Simulated hardware and data profile of one federated-learning client.

    Every numeric field is range-checked in ``__post_init__``; an out-of-range
    or NaN value raises ``ValueError``.
    """

    speed_factor: float  # >= 1; Client.train stretches its wall-clock training time by this factor
    battery_level: float  # in [1, 100]
    bandwidth: float  # >= 0 (generate_random draws it from [1, 100], labelled Mbps)
    dataset_size: int  # > 0
    CPU_available: bool
    CPU_memory_availability: float  # GB, in [0, 128]
    GPU_available: bool
    GPU_memory_availability: float  # GB, in [0, 32]; generate_random uses 0 when no GPU

    def __post_init__(self):
        """Validate the field ranges.

        Raises:
            ValueError: if any numeric field is outside its allowed range.
        """
        # Each check is phrased as ``not (<valid condition>)`` so NaN fails it too.
        if not (1 <= self.speed_factor):
            raise ValueError("speed_factor must be greater than or equal to 1")
        if not (1 <= self.battery_level <= 100):
            raise ValueError("battery_level must be between 1 and 100")
        if not (self.bandwidth >= 0):
            raise ValueError("bandwidth must be non-negative")
        if not (self.dataset_size > 0):
            raise ValueError("dataset_size must be positive")
        if not (0 <= self.CPU_memory_availability <= 128):  # Example: assuming max 128GB
            raise ValueError("CPU_memory_availability must be between 0 and 128")
        if not (0 <= self.GPU_memory_availability <= 32):  # Example: assuming max 32GB
            raise ValueError("GPU_memory_availability must be between 0 and 32")

    @staticmethod
    def generate_random(dataset_size_range=(500, 2000)):
        """
        Generate random valid ClientResources.

        Parameters:
        - dataset_size_range (tuple[int, int], optional): Inclusive (low, high)
          bounds for ``dataset_size``. Default is (500, 2000).

        Returns:
        - ClientResources: A profile whose fields all lie within the ranges
          enforced by ``__post_init__``.
        """
        # Drawn first because GPU memory depends on it.
        GPU_available = random.choice([True, False])

        return ClientResources(
            speed_factor=random.uniform(1.0, 2.0),  # Speed factor in range [1.0, 2.0]
            battery_level=random.uniform(1, 100),   # Battery level in range [1, 100]
            bandwidth=random.uniform(1, 100),       # Bandwidth in range [1, 100] Mbps
            dataset_size=random.randint(*dataset_size_range),  # Dataset size (inclusive bounds)
            CPU_available=random.choice([True, False]),        # Random CPU availability
            CPU_memory_availability=random.uniform(0, 128),    # CPU memory in GB
            GPU_available=GPU_available,        # Random GPU availability
            GPU_memory_availability=random.uniform(0, 32) if GPU_available else 0,  # GPU memory only if a GPU exists
        )

class Client:
    """A simulated federated-learning participant.

    Holds the client's resource profile, its local training data, and its
    validation loader. ``train`` fits a fresh copy of the global model on the
    local data.
    """

    def __init__(self, id, resources: ClientResources, dataset, dataloader, val_loader):
        """
        Initialize a Client object.

        Parameters:
        - id (int): The ID of the client.
        - resources (ClientResources): Simulated resource profile. Its
          ``speed_factor`` (>= 1) scales the simulated training delay in ``train``.
        - dataset (torch.utils.data.Dataset): The client's local training dataset
          (stored only; ``train`` iterates ``dataloader``).
        - dataloader (torch.utils.data.DataLoader): Yields ``(inputs, labels)``
          batches for local training.
        - val_loader (torch.utils.data.DataLoader): The client's validation loader.
        """
        self.id = id
        self.resources = resources

        self.dataset = dataset
        self.dataloader = dataloader
        self.val_loader = val_loader

    def train(self, global_model, epochs=1, lr=0.001, quantize=False, lambda_kure=0.0, delta=0.0, setup='standard', bit_widths=None):
        """
        Train a copy of the global model on the client's local dataset using Adam.

        Parameters:
        - global_model (torch.nn.Module): Model whose ``state_dict`` initialises the local copy.
        - epochs (int, optional): Number of passes over ``self.dataloader``. Default is 1.
        - lr (float, optional): Adam learning rate. Default is 0.001.
        - quantize (bool, optional): If True, the local model is prepared for
          quantization-aware training with the "fbgemm" QAT qconfig. Default is False.
        - lambda_kure (float, optional): Weight of the kurtosis regularisation
          term; used only when ``setup == 'kure'``. Default is 0.0.
        - delta (float or None, optional): Passed to
          ``add_pseudo_quantization_noise``; used only when ``setup == 'mqat'``.
          ``None`` disables the noise. Default is 0.0.
        - setup (str, optional): ``'standard'``, ``'kure'`` or ``'mqat'``. Default is ``'standard'``.
        - bit_widths (sequence or None, optional): Candidate bit widths for
          ``quantize_multi_bit``; one is drawn at random per batch. Used only
          when ``setup == 'mqat'``; ``None`` disables it. Default is None.

        Returns:
        - tuple: ``(state_dict, avg_loss)``, where ``state_dict`` holds the
          updated local parameters and ``avg_loss`` is the mean batch loss
          (0 if the dataloader yielded no batches).
        """
        total_loss = 0.0
        num_batches = 0

        # Wall-clock start, used below to simulate slower devices.
        start_time = time.time()

        # Build a fresh local model and copy the global weights into it.
        local_model = QuantStubModel(q=quantize)
        if quantize:
            # Prepared before load_state_dict; presumably so the module layout
            # matches a QAT-prepared global model. TODO confirm.
            local_model.qconfig = torch.quantization.get_default_qat_qconfig("fbgemm")
            torch.quantization.prepare_qat(local_model, inplace=True)
        local_model.load_state_dict(global_model.state_dict())

        # Adam's betas / eps / weight_decay are left at their defaults for now.
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(local_model.parameters(), lr=lr)

        local_model.train()
        for _ in range(epochs):
            for inputs, labels in self.dataloader:
                optimizer.zero_grad()

                if setup == 'mqat':
                    # NOTE(review): both steps overwrite the weights in place, so the
                    # perturbation persists into the optimizer step and the returned
                    # state_dict.
                    # Apply Pseudo-Quantization Noise (APQN)
                    if delta is not None:
                        for param in local_model.parameters():
                            param.data = add_pseudo_quantization_noise(param, delta)

                    # Apply Multi-Bit Quantization (MQAT)
                    if bit_widths is not None:
                        bit_width = random.choice(bit_widths)
                        for param in local_model.parameters():
                            param.data = quantize_multi_bit(param, bit_width)

                outputs = local_model(inputs)
                loss = criterion(outputs, labels)

                # Kurtosis Regularization (out-of-place add keeps the graph simple)
                if setup == 'kure':
                    for param in local_model.parameters():
                        loss = loss + lambda_kure * kurtosis_regularization(param)

                loss.backward()
                optimizer.step()
                total_loss += loss.item()
                num_batches += 1

        # Simulate a slower device: sleep so that total time is speed_factor x compute time.
        time_elapsed = time.time() - start_time
        time.sleep((self.resources.speed_factor - 1) * time_elapsed)

        # Return updated model parameters and the mean batch loss.
        avg_loss = total_loss / num_batches if num_batches > 0 else 0
        return local_model.state_dict(), avg_loss
    

In [None]:
from collections import Counter

# Split MNIST train/validation data across clients, either non-IID (each client
# gets samples from `classes_per_client` classes) or IID (random equal-sized splits).

# Per-client sample counts for the IID split; any remainder is left unused.
train_data_size = len(mnist_train) // num_clients
val_data_size = len(mnist_val) // num_clients

if split == "NONIID":

    # Pools of still-unassigned sample indices, one per class.
    train_class_indices = {i: np.where(np.array(mnist_train.targets) == i)[0] for i in range(num_classes)}
    val_class_indices = {i: np.where(np.array(mnist_val.targets) == i)[0] for i in range(num_classes)}

    # Each pick takes 1/shards_per_class of a class's remaining pool. Clamped to 1
    # so num_clients < classes_per_client does not divide by zero.
    shards_per_class = max(1, num_clients // classes_per_client)

    train_indices = []
    val_indices = []
    for client_id in range(num_clients):
        chosen_classes = np.random.choice(num_classes, classes_per_client, replace=False)
        train_client_idx = []
        val_client_idx = []
        for cls in chosen_classes:
            train_cls_size = len(train_class_indices[cls]) // shards_per_class
            train_cls_idx = np.random.choice(train_class_indices[cls], train_cls_size, replace=False)
            train_client_idx.extend(train_cls_idx)

            val_cls_size = len(val_class_indices[cls]) // shards_per_class
            val_cls_idx = np.random.choice(val_class_indices[cls], val_cls_size, replace=False)
            val_client_idx.extend(val_cls_idx)

            # Remove assigned indices so no sample goes to two clients.
            train_class_indices[cls] = np.setdiff1d(train_class_indices[cls], train_cls_idx)
            val_class_indices[cls] = np.setdiff1d(val_class_indices[cls], val_cls_idx)

        train_indices.append(train_client_idx)
        val_indices.append(val_client_idx)

    # Create datasets and DataLoaders for each client
    train_datasets = [Subset(mnist_train, indices) for indices in train_indices]
    train_loaders = [DataLoader(ds, batch_size=batch_size, shuffle=True) for ds in train_datasets]
    val_datasets = [Subset(mnist_val, indices) for indices in val_indices]
    # Evaluation does not need shuffling.
    val_loaders = [DataLoader(ds, batch_size=batch_size, shuffle=False) for ds in val_datasets]

else:

    # random_split requires the lengths to sum to len(dataset), so a trailing
    # remainder split is added and then dropped.
    train_lengths = [train_data_size] * num_clients
    train_datasets = random_split(
        mnist_train, train_lengths + [len(mnist_train) - sum(train_lengths)]
    )[:num_clients]
    train_loaders = [DataLoader(ds, batch_size=batch_size, shuffle=True) for ds in train_datasets]

    # Validation splits come from the validation set, not the training set.
    val_lengths = [val_data_size] * num_clients
    val_datasets = random_split(
        mnist_val, val_lengths + [len(mnist_val) - sum(val_lengths)]
    )[:num_clients]
    val_loaders = [DataLoader(ds, batch_size=batch_size, shuffle=False) for ds in val_datasets]

test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=False)

"""
print(f"Simulated {num_clients} clients, each with {train_data_size} training samples, and {len(mnist_val)} validation samples")
# Debugging: Output distribution of classes for all clients
for i, client_dataset in enumerate(train_datasets):
    client_loader = torch.utils.data.DataLoader(client_dataset, batch_size=len(client_dataset))
    client_samples, client_labels = next(iter(client_loader))
    label_counts = Counter(client_labels.tolist())
    print(f"Client {i} label distribution: {dict(label_counts)}")
"""