# Federated Learning Algorithm Prototype
## PFL-Research Framework by Apple
#### Alian Haidar ahaidar@apple.com

Dataset Used: FLAIR 

## Data Preparation

In [4]:
!pip install torch

Collecting torch
  Downloading torch-2.3.0-cp312-none-macosx_11_0_arm64.whl.metadata (26 kB)
Collecting filelock (from torch)
  Downloading filelock-3.14.0-py3-none-any.whl.metadata (2.8 kB)
Collecting sympy (from torch)
  Using cached sympy-1.12-py3-none-any.whl.metadata (12 kB)
Collecting fsspec (from torch)
  Using cached fsspec-2024.3.1-py3-none-any.whl.metadata (6.8 kB)
Collecting mpmath>=0.19 (from sympy->torch)
  Using cached mpmath-1.3.0-py3-none-any.whl.metadata (8.6 kB)
Downloading torch-2.3.0-cp312-none-macosx_11_0_arm64.whl (61.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.0/61.0 MB[0m [31m7.5 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0mm
[?25hDownloading filelock-3.14.0-py3-none-any.whl (12 kB)
Using cached fsspec-2024.3.1-py3-none-any.whl (171 kB)
Using cached sympy-1.12-py3-none-any.whl (5.7 MB)
Using cached mpmath-1.3.0-py3-none-any.whl (536 kB)
Installing collected packages: mpmath, sympy, fsspec, filelock, torch
Successfully installed fi

In [None]:
import matplotlib.pyplot as plt
import h5py
import numpy as np
import pandas as pd
import sys
import torch
import nest_asyncio


# Fix the NumPy and PyTorch global seeds so that user sampling and weight
# initialisation are repeatable across runs of the notebook.
np.random.seed(1)
torch.manual_seed(1)

# Apply the nest_asyncio patch (presumably needed because the notebook kernel
# already runs an asyncio event loop).
nest_asyncio.apply()

from pfl.model.pytorch import PyTorchModel

In [None]:
from common import (get_multi_hot_targets, get_label_mapping, get_user_num_images)

# HDF5 file that every data cell in this notebook reads from.
hdf5_path = 'flair_federated_small.hdf5'

# Dictionary: class name -> output index (coarse-grained label set).
classes = get_label_mapping(hdf5_path, use_fine_grained_labels=False)
num_classes = len(classes)

# Dictionary: user id -> number of images in the 'train' partition.
user_num_images = get_user_num_images(hdf5_path, 'train')
# Sorted so that the list order does not depend on the mapping's own ordering.
user_ids = sorted(user_num_images.keys())

display('Coarse grained classes in FLAIR:')
display(classes)
display('User dataset sizes statistics')
display(
    pd.Series(user_num_images.values())
    .describe()
    .apply("{0:.1f}".format)
)

For testing purposes, define a sampler function that draws user IDs uniformly at random.

In [None]:
def user_sampler():
    """Return one id from `user_ids`, drawn uniformly at random with replacement.

    Uses NumPy's global random state, so the sequence of sampled users is
    controlled by `np.random.seed`.
    """
    return user_ids[np.random.randint(len(user_ids))]


print('sampled 10 users:', [user_sampler() for _ in range(10)])

In [None]:
from pfl.data.dataset import Dataset
from pfl.data.federated_dataset import FederatedDataset
# `pfl.internal.ops` contains useful helper functions for manipulating tensors.
from pfl.internal.ops import pytorch_ops as ops

def make_dataset_fn(user_id):
    """Load one user's training data from the HDF5 file as a pfl `Dataset`.

    Images are read from `/train/<user_id>/images`, shifted by 128 and scaled
    by 1/255. Labels are stored sparsely as (`labels_row`, `labels_col`) index
    pairs and are expanded to a dense multi-hot matrix with one column per
    class.

    :param user_id: Id of the user whose group under `/train` is read.
    :returns: A `Dataset` holding `(inputs, targets)` tensors for that user.
    """
    with h5py.File(hdf5_path, 'r') as h5:
        # Convert to float32 *before* subtracting. If the images are stored as
        # uint8 (NOTE(review): not verifiable from this notebook), `uint8 - 128`
        # wraps around for every pixel below 128 instead of going negative.
        images = np.array(h5[f'/train/{user_id}/images'], dtype=np.float32)
        inputs = (images - 128) / 255.
        # Get multi-hot labels for user.
        # The zip of `row_indices` and `col_indices` is the sparse matrix of labels for a user.
        row_indices = np.array(h5[f'/train/{user_id}/labels_row'])
        col_indices = np.array(h5[f'/train/{user_id}/labels_col'])
        # Convert to a dense matrix of labels. The width comes from the label
        # mapping loaded earlier rather than from a hard-coded class count.
        targets = np.zeros((len(inputs), num_classes), dtype=np.float32)
        targets[row_indices, col_indices] = 1

    return Dataset((
        ops.to_tensor(inputs),
        ops.to_tensor(targets)), user_id=user_id)

# Federated dataset: every `next()` samples a user id with `user_sampler` and
# loads that user's data with `make_dataset_fn`.
train_federated_dataset = FederatedDataset(make_dataset_fn, user_sampler)

In [None]:
import itertools

# Draw one user from the federated dataset and show up to ten of their images.
user, seed = next(train_federated_dataset)
print('User: {}\nunique user seed: {}\ndataset length: {}\nfirst 10 images:'.format(user.user_id, seed, len(user)))
# squeeze=False always returns a 2-D array of Axes. With the default squeezing
# a user with exactly one image yields a bare Axes object, which cannot be
# zipped with the data below.
fig, axes = plt.subplots(1, min(len(user), 10), figsize=(20, 12), squeeze=False)
axes = axes[0]
# zip stops after the last Axes, so at most ten images are drawn.
for ax, image, label in zip(axes, *user.raw_data):
    ax.set_title('labels={}'.format(torch.nonzero(label).squeeze().tolist()))
    # Undo the (x - 128) / 255 preprocessing to get displayable pixel values.
    ax.imshow((image.cpu().numpy()*255+128).astype(np.uint8))

In [None]:
from common import make_central_datasets
# Build one central validation dataset by concatenating every user found
# under `/val` in the HDF5 file.
inputs_all, targets_all = [], []
with h5py.File(hdf5_path, 'r') as h5:
    for user_id in h5['/val'].keys():
        # Convert to float32 before subtracting so that uint8 pixel data
        # (NOTE(review): storage dtype not verifiable from here) cannot wrap
        # around for values below 128.
        images = np.array(h5[f'/val/{user_id}/images'], dtype=np.float32)
        inputs = (images - 128) / 255.
        # Get multi-hot labels for user.
        row_indices = np.array(h5[f'/val/{user_id}/labels_row'])
        col_indices = np.array(h5[f'/val/{user_id}/labels_col'])
        # Width taken from the label mapping instead of a hard-coded 17.
        targets = np.zeros((len(inputs), num_classes), dtype=np.float32)
        targets[row_indices, col_indices] = 1
        inputs_all.append(inputs)
        targets_all.append(targets)

inputs_all = np.vstack(inputs_all)
targets_all = np.vstack(targets_all)
# Convert to tensors exactly like the per-user training datasets do: the
# model's `metrics` function calls `.permute` on its inputs, which NumPy
# arrays do not provide.
data_tensors = [ops.to_tensor(inputs_all), ops.to_tensor(targets_all)]
central_data = Dataset(raw_data=data_tensors)
print('data shape:', [t.shape for t in central_data.raw_data])
# Mean of the multi-hot matrix is the fraction of positive labels; the
# previous `.sum()` printed the raw count under this label.
print('fraction of positive labels:', float(central_data.raw_data[1].mean()))

# Initial Model Definition
First, define a PyTorch model in the same way as for standard centralized training.

In [None]:
from typing import Dict, Optional
from pfl.metrics import Weighted
import torch
from torchvision.models import resnet18, ResNet18_Weights

# ResNet-18 initialised from torchvision's default pretrained weights.
weights = ResNet18_Weights.DEFAULT
pytorch_model = resnet18(weights=weights)

# Replace the classification head with a linear layer that has 17 outputs.
num_ftrs = pytorch_model.fc.in_features
pytorch_model.fc = torch.nn.Linear(num_ftrs, 17)

# Start with every parameter frozen ...
for param in pytorch_model.parameters():
    param.requires_grad = False

# ... then make only the new head and the final ResNet block trainable.
for trainable_module in (pytorch_model.fc, pytorch_model.layer4):
    for param in trainable_module.parameters():
        param.requires_grad = True

# Binary cross-entropy on raw logits; used by both `loss` and `metrics`.
loss_fn = torch.nn.BCEWithLogitsLoss()

def loss(inputs: torch.Tensor, targets: torch.Tensor, eval: bool = False) -> torch.Tensor:
    """Compute the training loss of `pytorch_model` on one batch.

    :param inputs: Batch of inputs; axes are reordered with `permute((0, 3, 1, 2))`
        before the forward pass (presumably channels-last images).
    :param targets: Targets passed to `loss_fn` together with the logits.
    :param eval: When true the model is put in evaluation mode, otherwise in
        training mode.
    :returns: The value returned by `loss_fn`.
    """
    # `train(False)` is the same switch as `eval()`.
    pytorch_model.train(not eval)
    reordered = inputs.permute((0, 3, 1, 2))
    logits = pytorch_model(reordered)
    return loss_fn(logits, targets)


@torch.no_grad()
def metrics(inputs: torch.Tensor,
             targets: torch.Tensor,
             eval: bool = True) -> Dict[str, Weighted]:
    """Compute loss and element-wise accuracy of `pytorch_model` on one batch.

    Runs without gradient tracking.

    :param inputs: Batch of inputs; axes are reordered with `permute((0, 3, 1, 2))`
        before the forward pass.
    :param targets: Multi-hot targets with the same shape as the logits.
    :param eval: When true (default) the model is put in evaluation mode,
        otherwise in training mode.
    :returns: Dictionary with a "loss" entry weighted by the number of samples
        and an "accuracy" entry weighted by the number of individual label
        predictions.
    """
    pytorch_model.eval() if eval else pytorch_model.train()
    logits = pytorch_model(inputs.permute((0,3,1,2)))
    num_samples = len(inputs)
    num_predictions = targets.numel()
    # A logit above 0 is a positive prediction. `.item()` turns the count into
    # a plain Python number, matching how the loss value is extracted.
    correct = torch.sum(torch.eq((logits > 0.0).float(), targets)).item()

    # Named `mean_loss` so the module-level `loss` function is not shadowed.
    mean_loss = loss_fn(logits, targets).item()
    return {
        # NOTE(review): assumes `Weighted(weighted_value, weight)` reports
        # weighted_value / weight, as the count/total accuracy entry implies.
        # `loss_fn` is built with default arguments (a batch mean), so it is
        # scaled back up by the weight before being stored.
        "loss": Weighted(mean_loss * num_samples, num_samples),
        "accuracy": Weighted(correct, num_predictions)
    }

# Attach the two functions defined above to the module as `metrics` and
# `loss` attributes (`model.pytorch_model.loss` is called later in this
# notebook).
pytorch_model.metrics = metrics
pytorch_model.loss = loss
# Last expression of the cell: displays the module structure.
pytorch_model

In [None]:
# Only the parameters that were left trainable go to the central optimizer.
params = [weight for weight in pytorch_model.parameters() if weight.requires_grad]

# Wrap the module for pfl: SGD is used both as the local optimizer factory
# and, with learning rate 1.0, as the central optimizer.
model = PyTorchModel(
    pytorch_model,
    central_optimizer=torch.optim.SGD(params, lr=1.0),
    local_optimizer_create=torch.optim.SGD)

# Checkpoint the initial state so later experiments can start from the same weights.
model.save('flair_model')

# Train model with Private Federated Learning
Use the PFL-Research `Backend` component to collect and aggregate statistics from users.

In [None]:
from pfl.aggregate.simulate import SimulatedBackend

# Simulation settings shared by the training cells below.
central_num_iterations = 5
cohort_size = 10

# Simulated backend over the federated training set; no federated validation
# data is supplied.
simulated_backend = SimulatedBackend(
    val_data=None,
    training_data=train_federated_dataset)

In [None]:
from pfl.algorithm import FederatedAveraging, NNAlgorithmParams
from pfl.callback import CentralEvaluationCallback
from pfl.hyperparam import NNTrainHyperParams, NNEvalHyperParams

# Evaluation hyperparameters: batches of 20.
model_eval_params = NNEvalHyperParams(local_batch_size=20)

# Local training hyperparameters: 2 epochs, batch size 16, learning rate 0.01.
model_train_params = NNTrainHyperParams(
    local_batch_size=16,
    local_num_epochs=2,
    local_learning_rate=0.01)

# Central loop configuration, reusing the cohort size and iteration count
# defined together with the backend.
algorithm_params = NNAlgorithmParams(
    val_cohort_size=0,
    train_cohort_size=cohort_size,
    evaluation_frequency=4,
    central_num_iterations=central_num_iterations)

# Evaluate on the central validation dataset with frequency 4.
callbacks = [
    CentralEvaluationCallback(
        central_data,
        frequency=4,
        model_eval_params=model_eval_params),
]

# Run federated averaging; `model` is rebound to what `run` returns.
model = FederatedAveraging().run(
    model=model,
    backend=simulated_backend,
    callbacks=callbacks,
    algorithm_params=algorithm_params,
    model_train_params=model_train_params,
    model_eval_params=model_eval_params)


# Custom FL Algorithm with Onion-like Routing Protocols

In [None]:
from pfl.algorithm.base import FederatedNNAlgorithm
from pfl.metrics import Metrics

# Central optimizer used by `MyAlgorithm`: Adam over the trainable parameters
# of `pytorch_model`, learning rate 1.0.
central_opt = torch.optim.Adam([p for p in pytorch_model.parameters() if p.requires_grad], lr=1.0)

class MyAlgorithm(FederatedNNAlgorithm):
    """Custom federated algorithm with hand-written local and central steps.

    Local training runs plain SGD on each user's data and returns the model
    difference; the central step feeds the averaged statistics to the
    module-level `central_opt` as if they were gradients.
    """

    def process_aggregated_statistics(self, central_context, aggregate_metrics, model, stats):
        """Apply the aggregated user statistics to the central model.

        :param central_context: Unused here.
        :param aggregate_metrics: Unused here.
        :param model: Model whose `pytorch_model` parameters are updated.
        :param stats: Aggregated statistics, indexed by parameter name.
        :returns: The model and a `Metrics` object with one marker entry.
        """
        stats.average()

        # Below is equivalent to
        # return model.apply_model_update(stats)
        central_opt.zero_grad()
        for name, var in model.pytorch_model.named_parameters():
            if not var.requires_grad:
                # Frozen variable
                continue
            if var.grad is None:
                var.grad = torch.zeros_like(var)
            # Negated so that the optimizer's descent step moves the weights
            # along the aggregated statistic (presumably a model difference,
            # as returned by `train_one_user`).
            var.grad.data.copy_(-1*stats[name])
        central_opt.step()

        return model, Metrics([('I updated it', 1.0)])


    def train_one_user(self, initial_state, model, user_dataset, central_context):
        """Train locally on one user's data and return the model difference.

        Uses SGD with learning rate 0.1 over mini-batches of size 5; both
        values are hard-coded here and `central_context` is not read.

        :param initial_state: State passed to `model.get_model_difference`.
        :param model: Model trained in place on the user's data.
        :param user_dataset: Dataset of the user; iterated in batches of 5.
        :param central_context: Unused here.
        :returns: The model difference and a `Metrics` object with one marker entry.
        """
        opt = torch.optim.SGD(model.pytorch_model.parameters(), lr=0.1)
        for x, y in user_dataset.iter(5):
            # Clear gradients for every mini-batch; clearing only once before
            # the loop made each step use the sum of all earlier gradients.
            opt.zero_grad()
            model.pytorch_model.loss(x, y).backward()
            opt.step()
        return model.get_model_difference(initial_state), Metrics([('I trained it', 1.0)])

In [None]:
# Restore the checkpoint saved before the first training run so this
# experiment starts from the same initial weights.
model.load('flair_model')

# Same backend, hyperparameters and callbacks as the federated averaging run,
# but driven by the custom algorithm.
model = MyAlgorithm().run(
    model=model,
    backend=simulated_backend,
    callbacks=callbacks,
    algorithm_params=algorithm_params,
    model_train_params=model_train_params,
    model_eval_params=model_eval_params)