# Bandwidth

In [4]:
import random
import numpy as np
import torch

def set_seed(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(42)

In [5]:
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

import flwr
from flwr.client import Client, ClientApp, NumPyClient
from flwr.server import ServerApp, ServerConfig, ServerAppComponents
from flwr.server.strategy import FedAvg, FedAdagrad
from flwr.simulation import run_simulation
from flwr_datasets import FederatedDataset
from flwr.common import ndarrays_to_parameters, NDArrays, Scalar, Context, parameters_to_ndarrays
from flwr.common import Metrics, Context


from flwr_datasets import FederatedDataset
from flwr_datasets.partitioner import IidPartitioner

from flwr.client.mod import parameters_size_mod

from dmf import *

In [6]:
DEVICE = torch.device("cuda:1")  
NUM_PARTITIONS = 5
BATCH_SIZE = 256

print(f"Training on {DEVICE}")
print(f"Flower {flwr.__version__} / PyTorch {torch.__version__}")

dataset = "ashraq/movielens_ratings" 
partitioner = IidPartitioner(num_partitions=NUM_PARTITIONS)           
fds = FederatedDataset(dataset=dataset,
                    partitioners={"train": partitioner})

Training on cuda:1
Flower 1.10.0 / PyTorch 2.2.1+cu121




In [7]:
########################################
# Process the Federated Dataset & Global Mapping without interaction matrix
########################################

def compute_global_mapping(fds):
    """
    Compute global mapping for user and item IDs from the full dataset.
    """

    global_train_df = fds.load_split("train").to_pandas()[["user_id", "movie_id", "rating"]]    # Full training data across all clients
    global_test_df = fds.load_split("validation").to_pandas()[["user_id", "movie_id", "rating"]]    # Full test set

    train_users = set(global_train_df['user_id'].unique())
    train_movies = set(global_train_df['movie_id'].unique())

    global_test_df = global_test_df[
        global_test_df['user_id'].isin(train_users) &
        global_test_df['movie_id'].isin(train_movies)
    ]

    all_users = set(global_train_df['user_id']).union(global_test_df['user_id'])
    all_movies = set(global_train_df['movie_id']).union(global_test_df['movie_id'])
    user_id_map = {user: idx for idx, user in enumerate(sorted(all_users))}
    movie_id_map = {movie: idx for idx, movie in enumerate(sorted(all_movies))}

    num_users = len(user_id_map)
    num_movies = len(movie_id_map)
    print("Global Number of Users:", num_users)
    print("Global Number of Movies:", num_movies)

    return user_id_map, movie_id_map,

    
# Precompute the global mappings and interaction matrix once.
global_user_id_map, global_movie_id_map = compute_global_mapping(fds)


Global Number of Users: 43584
Global Number of Movies: 15276


In [9]:
########################################
# The model
########################################

class MLPLayers(nn.Module):
    def __init__(self, sizes, dropout=0.3, activation="leaky_relu", bn=False, init_method="norm", last_activation=True):
        super(MLPLayers, self).__init__()
        layers = []
        for i in range(len(sizes) - 1):
            layers.append(nn.Linear(sizes[i], sizes[i+1]))
            if bn:
                layers.append(nn.BatchNorm1d(sizes[i+1]))
            if activation == "leaky_relu":
                layers.append(nn.LeakyReLU())
            else:
                layers.append(nn.ReLU())
            layers.append(nn.Dropout(dropout))
        if not last_activation:
            layers = layers[:-2]
        self.mlp = nn.Sequential(*layers)
    
    def forward(self, x):
        return self.mlp(x)

class DMFFederated(nn.Module):
    """
    Modified DMF model that uses learnable embeddings instead of a precomputed global interaction matrix.
    """
    def __init__(self, num_users, num_items,
                 user_embedding_size=32,
                 item_embedding_size=32,
                 user_hidden_sizes=[64, 32],
                 item_hidden_sizes=[64, 32],
                 dropout=0.3,
                 activation="leaky_relu",
                 bn=False,
                 init_method="norm"):
        super(DMFFederated, self).__init__()
        self.num_users = num_users
        self.num_items = num_items
        
        self.user_embedding = nn.Embedding(num_users, user_embedding_size)
        self.item_embedding = nn.Embedding(num_items, item_embedding_size)
        
        self.user_fc_layers = MLPLayers(
            [user_embedding_size] + user_hidden_sizes,
            dropout=dropout,
            activation=activation,
            bn=bn,
            init_method=init_method,
            last_activation=True
        )
        self.item_fc_layers = MLPLayers(
            [item_embedding_size] + item_hidden_sizes,
            dropout=dropout,
            activation=activation,
            bn=bn,
            init_method=init_method,
            last_activation=True
        )
        
        self.loss_fn = nn.HuberLoss(delta=0.5)
        self._init_weights()
    
    def _init_weights(self):
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, 0, 0.01)
                if module.bias is not None:
                    module.bias.data.fill_(0.0)
            elif isinstance(module, nn.Embedding):
                nn.init.normal_(module.weight, 0, 0.01)
    
    def forward(self, user_indices, item_indices):
        user_emb = self.user_embedding(user_indices)
        item_emb = self.item_embedding(item_indices)
        user_features = self.user_fc_layers(user_emb)
        item_features = self.item_fc_layers(item_emb)
        prediction = torch.mul(user_features, item_features).sum(dim=1)
        return prediction
    
    def calculate_loss(self, batch):
        user = batch['user_id']
        item = batch['movie_id']
        rating = batch['rating']
        preds = self.forward(user, item)
        loss = self.loss_fn(preds, rating)
        return loss
    
    def predict(self, batch):
        return self.forward(batch['user_id'], batch['movie_id'])


In [10]:
########################################
# Set and get parameters
########################################

from collections import OrderedDict
from typing import Dict, List, Optional, Tuple, Union


def set_parameters(model, parameters: List[np.ndarray]):
    """
    Sets the parameters of the model using a list of NumPy arrays.
    
    Args:
        model (torch.nn.Module): The model.
        parameters (List[np.ndarray]): The model parameters as a list of NumPy arrays.
    """
    params_dict = zip(model.state_dict().keys(), parameters)
    state_dict = OrderedDict({k: torch.from_numpy(v) for k, v in params_dict})
    model.load_state_dict(state_dict, strict=True)


def get_parameters(model) -> List[np.ndarray]:
    """
    Retrieves the model parameters as a list of NumPy arrays.

    Args:
        model (torch.nn.Module): The model.

    Returns:
        List[np.ndarray]: The model parameters as a list of NumPy arrays.
    """
    return [val.cpu().numpy() for _, val in model.state_dict().items()]


In [11]:
########################################
# Model size
########################################

model = DMFFederated(
            num_users=len(global_user_id_map),
            num_items=len(global_movie_id_map),
            user_embedding_size=32,
            item_embedding_size=32,
            user_hidden_sizes=[64, 32],
            item_hidden_sizes=[64, 32],
            dropout=0.3,
            activation="leaky_relu",
            bn=False,
            init_method="norm"
            )

# Calculate model size in bytes
vals = model.state_dict().values()
total_size_bytes = sum(p.element_size() * p.numel() for p in vals)
total_size_mb = int(total_size_bytes / (1024**2))

print("Model size is: {} MB".format(total_size_mb))

Model size is: 7 MB


In [None]:
########################################
# Flower Client and Server Function
########################################
torch.cuda.empty_cache()
device = DEVICE 
num_partitions = NUM_PARTITIONS
batch_size = BATCH_SIZE
num_epochs = 1
lr = 0.0001
weight_decay = 1e-4


class FlowerClient(NumPyClient):
    def __init__(self, model):
        self.model = model
      
   
    def fit(self, parameters, config):
        set_parameters(self.model, parameters)
        return get_parameters(self.model), int(1), {}
    

    def evaluate(self, parameters, config):
        set_parameters(self.model, parameters)
        return float(0), int(1), {"loss": float(0)}
    

def client_fn(context: Context) -> Client:
    return FlowerClient(model).to_client()


client_app = ClientApp(
    client_fn=client_fn,
    mods=[parameters_size_mod]
    )



bandwidth_sizes = []


class BandwidthTrackingFedAvg(FedAvg):
    def aggregate_fit(self, server_round, results, failures):
        if not results:
            return None, {}

        # Track sizes of models received
        for _, res in results:
            ndas = parameters_to_ndarrays(res.parameters)
            size = int(sum(n.nbytes for n in ndas) / (1024**2))
            print(f"Server receiving model size: {size} MB")
            bandwidth_sizes.append(size)

        return super().aggregate_fit(server_round, results, failures)

    def configure_fit(self, server_round, parameters, client_manager):
        # Call FedAvg for actual configuration
        instructions = super().configure_fit(
            server_round, parameters, client_manager
        )

        # Track sizes of models to be sent
        for _, ins in instructions:
            ndas = parameters_to_ndarrays(ins.parameters)
            size = int(sum(n.nbytes for n in ndas) / (1024**2))
            print(f"Server sending model size: {size} MB")
            bandwidth_sizes.append(size)

        return instructions
    

params = ndarrays_to_parameters(get_parameters(model))

def server_fn(context: Context):
    strategy = BandwidthTrackingFedAvg(
        fraction_evaluate=0.0,
        initial_parameters=params,
    )
    config = ServerConfig(num_rounds=1)
    return ServerAppComponents(
        strategy=strategy,
        config=config,
    )

backend_config = {"client_resources": None}
if DEVICE.type == "cuda":
    backend_config = {"client_resources": {"num_gpus": 2, "num_cpus": 1}}

server_app = ServerApp(server_fn=server_fn)

In [15]:
########################################
# Simulation
########################################
run_simulation(server_app=server_app,
               client_app=client_app,
               num_supernodes=NUM_PARTITIONS,
               backend_config=backend_config
               )

[92mINFO [0m:      Starting Flower ServerApp, config: num_rounds=1, no round_timeout
[92mINFO [0m:      
[92mINFO [0m:      [INIT]
[92mINFO [0m:      Using initial global parameters provided by strategy
[92mINFO [0m:      Evaluating initial global parameters
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 1]
[92mINFO [0m:      configure_fit: strategy sampled 5 clients (out of 5)


Server sending model size: 7 MB
Server sending model size: 7 MB
Server sending model size: 7 MB
Server sending model size: 7 MB
Server sending model size: 7 MB


[2m[36m(ClientAppActor pid=1093841)[0m [92mINFO [0m:      {'fitins.parameters': {'parameters': 1891904, 'bytes': 7568906}}
[2m[36m(ClientAppActor pid=1093841)[0m [92mINFO [0m:      Total parameters transmitted: 7568906 bytes
[2m[36m(ClientAppActor pid=1093841)[0m [92mINFO [0m:      {'fitins.parameters': {'parameters': 1891904, 'bytes': 7568906}}
[2m[36m(ClientAppActor pid=1093841)[0m [92mINFO [0m:      Total parameters transmitted: 7568906 bytes
[2m[36m(ClientAppActor pid=1093841)[0m [92mINFO [0m:      {'fitins.parameters': {'parameters': 1891904, 'bytes': 7568906}}
[2m[36m(ClientAppActor pid=1093841)[0m [92mINFO [0m:      Total parameters transmitted: 7568906 bytes
[2m[36m(ClientAppActor pid=1093841)[0m [92mINFO [0m:      {'fitins.parameters': {'parameters': 1891904, 'bytes': 7568906}}
[2m[36m(ClientAppActor pid=1093841)[0m [92mINFO [0m:      Total parameters transmitted: 7568906 bytes
[2m[36m(ClientAppActor pid=1093841)[0m [92mINFO [0m:   

Server receiving model size: 7 MB
Server receiving model size: 7 MB
Server receiving model size: 7 MB
Server receiving model size: 7 MB
Server receiving model size: 7 MB


In [16]:
print("Total bandwidth used for one round: {} MB".format(sum(bandwidth_sizes)))

Total bandwidth used: 70 MB


For whole process: 70 MB * 5 (clients) * 10 (rounds) * 1 (fraction of selected clients) = 700 MB