In [1]:
!pip install ivon-opt



In [2]:
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import torch.nn.functional as F
import torch
from collections import defaultdict
import numpy as np
import ivon
from torch import Tensor
from typing import Tuple
import pickle
from torch.optim.lr_scheduler import CosineAnnealingLR

# Utils

In [3]:
# subclass of IVON to compute sampling using the covariance
class IVON_SAMP(ivon.IVON):
    # A subclass of IVON that rewrites _sample_params
    # To make the sampling with a given covariance instead of the ess and the hess parameters in the original verison used to compute the covariance
    def __init__(self, *args, cov, **kwargs):
        super().__init__(*args, **kwargs)
        for group in self.param_groups:
            group["cov"] = cov
    def _sample_params(self) -> Tuple[Tensor, Tensor]:
            noise_samples = []
            param_avgs = []
            offset = 0
            for group in self.param_groups:
                gnumel = group["numel"]
                #noise_sample = (
                #    torch.randn(gnumel, device=self._device, dtype=self._dtype)
                #    / (
                #        group["ess"] * (group["hess"] + group["weight_decay"])
                #    ).sqrt()
                #)
                cov = group["cov"].astype(np.float32)
                noise_sample = (
                    torch.randn(gnumel, device=self._device, dtype=torch.float32)
                    * torch.from_numpy(np.sqrt(cov)).to(self._device)
                )

                noise_samples.append(noise_sample)

                goffset = 0
                for p in group["params"]:
                    if p is None:
                        continue

                    p_avg = p.data.flatten()
                    numel = p.numel()
                    p_noise = noise_sample[offset : offset + numel]

                    param_avgs.append(p_avg)
                    p.data = (p_avg + p_noise).view(p.shape)
                    goffset += numel
                    offset += numel
                assert goffset == group["numel"]  # sanity check
            assert offset == self._numel  # sanity check

            return torch.cat(param_avgs, 0), torch.cat(noise_samples, 0)

In [5]:
# aggregation functions
def eaa(means, covs, weights=None): #Empirical Arithmetic Aggregation
    mu = np.average(means, weights=weights, axis=0)
    cov = np.average(covs, weights=weights, axis=0)
    return mu, cov

def gaa(means, covs, weights=None): #Gaussian Arithmetic Aggregation
    mu = np.average(means, weights=weights, axis=0)
    weights_squared = [weight**2 for weight in weights]
    cov = np.average(covs, weights=weights_squared, axis=0)
    return mu, cov

def aalv(means, covs, weights=None): #Arithmetic Aggregation with Log Variance
    mu = np.average(means, weights=weights, axis=0)
    cov = np.exp(np.average(np.log(covs), weights=weights, axis=0))
    return mu, cov



def forward_kl_barycenter(means, covs, weights=None):
    mu = np.average(means, weights=weights, axis=0)
    toavg = covs + (means - mu)**2
    cov = np.average(toavg, weights=weights, axis=0)
    return mu, cov

def reverse_kl_barycenter(means, covs, weights=None):  # Kullback Leiber Average
    inverted_covs = [1/cov for cov in covs]
    cov = 1/(np.average(inverted_covs, weights=weights, axis=0))
    inverted_covs_time_means = [(inverted_covs[i] * means[i]).reshape(-1,1) for i in range(len(means))]
    cov = cov.reshape(-1,1)
    mu = cov * np.average(inverted_covs_time_means, weights=weights, axis=0)
    return mu, cov


def wasserstein_barycenter_diag(means, covs, weights = None):
    #check if matrices are diagonal
    assert all([ np.sum(K.cpu() > 1e-10) == K.cpu().shape[0] for K in covs]), \
        "NotDiagonal: One of the covariance matrices is not diagonal."
    #mu = np.average(means, weights=weights, axis=0)
    mu = np.average(means.cpu(), weights=weights, axis=0)
    cov = np.average([np.sqrt(K.cpu()) for K in covs], weights=weights, axis=0)**2  # for ivon
    #cov = np.average([np.sqrt(K) for K in covs], weights=weights, axis=0)**2  #for non ivon
    return mu, cov


In [6]:
#utils for aggregation

def flatten_model_state_dict(state_dict):
    vec = []
    for param_tensor in state_dict:
        vec.append(state_dict[param_tensor].view(-1))
    return torch.cat(vec)

def unflatten_model_state_dict(vec, state_dict):
    state_dict = state_dict.copy()
    idx = 0
    for param_tensor in state_dict:
        param = state_dict[param_tensor]
        size = param.numel()
        state_dict[param_tensor] = torch.from_numpy( vec[idx:idx+size].reshape(param.size()) )
        idx += size
    return state_dict


def aggregate_param(means, covs, method, weights=None):
    means = means.cpu()
    covs = covs.cpu()
    if method == 'eaa':
        mu , cov = eaa(means, covs, weights) # works with the diagonals of the covariance matrices
        return mu, cov
    elif method == 'wb_diag':
        mu , cov = wasserstein_barycenter_diag(means, covs, weights)
        return mu, cov
    elif method == 'fkl':
        mu , cov = forward_kl_barycenter(means, covs, weights)
        return mu, cov
    elif method == 'rkl':
        mu , cov = reverse_kl_barycenter(means, covs, weights)
        return mu, cov
    elif method == 'gaa':
        mu , cov = gaa(means, covs, weights)
        return mu, cov
    elif method == 'aalv':
        mu , cov = aalv(means, covs, weights)
        return mu, cov
    else:
        raise ValueError(f"update method {method} non implemented!")


def aggregate_ivon(means, covs, update_method, global_model, weights=None):
    means = [flatten_model_state_dict(m) for m in means]
    means = torch.stack(means)
    covs = torch.stack(covs)
    mu_agg , cov_agg = aggregate_param(means, covs, update_method, weights)
    agg_state_dict = unflatten_model_state_dict(mu_agg, global_model.state_dict())
    return agg_state_dict , cov_agg

In [7]:
# scores of evaluation
def ece(preds, target, device, minibatch=True):
    confidences, predictions = torch.max(preds, 1)
    target = target.float() # to avoid a warning error
    _, target_cls = torch.max(target, 1)
    accuracies = predictions.eq(target_cls)
    n_bins = 100
    bin_boundaries = torch.linspace(0, 1, n_bins + 1)
    bin_lowers = bin_boundaries[:-1]
    bin_uppers = bin_boundaries[1:]

    ece = torch.zeros(1, device=device)
    for bin_lower, bin_upper in zip(bin_lowers, bin_uppers):
        in_bin = confidences.gt(bin_lower.item()) * confidences.le(bin_upper.item())
        prop_in_bin = in_bin.float().mean()
        if prop_in_bin.item() > 0:
            accuracy_in_bin = accuracies[in_bin].float().mean()
            avg_confidence_in_bin = confidences[in_bin].mean()
            ece += torch.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin

    return ece.item()

def nll(preds, target, minibatch=True):
    logpred = torch.log(preds + 1e-8)
    if minibatch:
        return -(logpred * target).sum(1).item()
    else:
        return -(logpred * target).sum(1).mean().item()

def acc(preds, target, minibatch=True):
    preds = preds.float() #to avoid a warning error
    preds = preds.argmax(1)
    target = target.float() # to avoid a warning error
    target = target.argmax(1)
    if minibatch:
        return (((preds == target) * 1.0).sum() * 100).item()
    else:
        return (((preds == target) * 1.0).mean() * 100).item()


In [8]:
# Define a CNN
class CNN(nn.Module):
    def __init__(self, input_dim, hidden_dims, output_dim=10, input_channel=3):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(input_channel, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)

        self.fc1 = nn.Linear(input_dim, hidden_dims[0])
        self.fc2 = nn.Linear(hidden_dims[0], hidden_dims[1])
        self.fc3 = nn.Linear(hidden_dims[1], output_dim)
        self.input_dim = input_dim
        self.output_dim = output_dim

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, self.input_dim)

        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


In [9]:
# Parameters
NUM_CLIENTS = 10
EPOCHS = 5
BATCH_SIZE = 128
LR = 0.1
REG = 1e-4
DEVICE = 'mps'
LOGDIR = './logs'
N_SAMPLE = 0
HESS_INIT = 0.1

input_channel = 1
input_dim = (16 * 4 * 4)
hidden_dims=[120, 84]
output_dim = 10

In [10]:
# Training function for local client
def train_ivon(model, dataloader, criterion, optimizer, epochs, client_id, train_samples=3):
    scheduler = CosineAnnealingLR(optimizer, T_max=epochs, eta_min=0.01)
    model = model.to(DEVICE)
    for _ in range(epochs):
        for x, target in dataloader:
            x, target = x.to(DEVICE), target.to(DEVICE)
            for _ in range(train_samples):
                with optimizer.sampled_params(train=True):
                    logit = model(x)
                    loss = criterion(logit, target)
                    loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                scheduler.step()
    torch.save(optimizer.state_dict(), f"{LOGDIR}/clients/client_{client_id}_optimizer.pt")
    return model.state_dict()

In [12]:
# load the data_statistics 
with open('train_data_statistics_0.pkl', 'rb') as f:
    train_data_statistics = pickle.load(f)

# Generate a uniform split dictionary (if needed)
# from collections import defaultdict

# def generate_uniform_split(num_clients=10, num_classes=10, total_samples=60000):
#     samples_per_client = total_samples // num_clients
#     samples_per_class_per_client = samples_per_client // num_classes

#     split_dict = defaultdict(dict)
#     for client_id in range(num_clients):
#         for class_id in range(num_classes):
#             split_dict[client_id][class_id] = samples_per_class_per_client

#     return split_dict
# split_dict = generate_uniform_split(num_clients=NUM_CLIENTS, num_classes=10, total_samples=len(full_dataset))


In [13]:
# create the data from the statistics
def create_client_datasets_from_split(full_dataset, split_dict):
    # 1. Organize all indices by class
    class_indices = defaultdict(list)
    for idx, (_, label) in enumerate(full_dataset):
        class_indices[label].append(idx)

    # 2. Shuffle class indices for randomness
    for c in class_indices:
        np.random.shuffle(class_indices[c])

    # 3. Build each client dataset
    client_datasets = []
    class_counters = {c: 0 for c in range(10)}  # track used indices per class

    for client_id in sorted(split_dict.keys()):
        client_indices = []
        for class_label, count in split_dict[client_id].items():
            start = class_counters[class_label]
            end = start + count
            client_indices.extend(class_indices[class_label][start:end])
            class_counters[class_label] = end  # update the counter
        # Append subset for this client
        client_datasets.append(Subset(full_dataset, client_indices))

    return client_datasets


In [14]:
# Load MNIST
transform = transforms.ToTensor()
full_dataset = datasets.MNIST(root=".", train=True, download=True, transform=transform)

client_datasets = create_client_datasets_from_split(full_dataset, train_data_statistics)
# client_datasets = create_client_datasets_from_split(full_dataset, split_dict)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:05<00:00, 1703469.78it/s]


Extracting ./MNIST/raw/train-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 61773.32it/s]


Extracting ./MNIST/raw/train-labels-idx1-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:01<00:00, 1333267.16it/s]


Extracting ./MNIST/raw/t10k-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 2041419.71it/s]


Extracting ./MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MNIST/raw



In [15]:
import os
os.mkdir(LOGDIR)
os.mkdir(f"{LOGDIR}/clients")

# BFL with IVON

In [16]:
# Global model
global_model = CNN(input_dim, hidden_dims, output_dim=output_dim, input_channel=input_channel)
global_model.to(DEVICE)
global_optimizer = IVON_SAMP(global_model.parameters(), lr=LR, ess=1, weight_decay=REG, beta1=0.9, hess_init=HESS_INIT, cov=None)
torch.save(global_optimizer.state_dict(), f"{LOGDIR}/global_optimizer.pt")

for round in range(1, 5):
    local_weights = []
    covs = []
    print(f"\n--- Round {round} ---")

    for i in range(NUM_CLIENTS):
        local_model = CNN(input_dim, hidden_dims, output_dim=output_dim, input_channel=input_channel)

        # Initialize local model with global model
        local_model.load_state_dict(global_model.state_dict())
        train_loader = DataLoader(client_datasets[i], batch_size=BATCH_SIZE, shuffle=True)
        N = len(train_loader.dataset)
        local_model.to(DEVICE)
        optimizer = ivon.IVON(
            local_model.parameters(),
            lr=LR,
            weight_decay=REG,
            ess=N,
            beta1=0.9,
            hess_init=HESS_INIT
            )
        # If not the first round, set the covariance from the global model
        if round > 1:
            global_cov = global_optimizer.param_groups[0]["cov"]
            global_hessian = (1/ (N * global_cov)) - REG
            optimizer.param_groups[0]["hess"] = torch.from_numpy(global_hessian.astype(np.float32)).to(DEVICE)


        criterion = nn.CrossEntropyLoss()
        local_w = train_ivon(local_model, train_loader, criterion, optimizer, EPOCHS, client_id=i)
        optimizer.load_state_dict(torch.load(f"{LOGDIR}/clients/client_{i}_optimizer.pt"))
        # Compute covariance
        cov = 1 / (N * optimizer.state_dict()['param_groups'][0]['hess'] + optimizer.state_dict()['param_groups'][0]['weight_decay'])
        local_weights.append(local_w) # collect the local models (means)
        covs.append(cov) # collect the local covariances


    # Average weights and covariances
    update_method = 'eaa' # choose the update method
    weights = [len(ds) for ds in client_datasets]
    weights = np.array(weights) / sum(weights)
    averaged_weights , average_cov  = aggregate_ivon(local_weights, covs, update_method, global_model, weights=weights)


    # Update global model with averaged weights and the global optimizer with averaged covariance
    global_model.load_state_dict(averaged_weights)
    global_optimizer = IVON_SAMP(global_model.parameters(), lr=LR, ess=1, weight_decay=REG, beta1=0.9, hess_init=HESS_INIT, cov=average_cov) #the ess is set to be 1 to pass the assertion. In all the cases it won't be used
    torch.save(global_optimizer.state_dict(), f"{LOGDIR}/global_optimizer.pt")


    # Update global optimizer with averaged covariance
    global_optimizer.load_state_dict(torch.load(f"{LOGDIR}/global_optimizer.pt", weights_only=False))
    global_optimizer.param_groups[0]["cov"] = average_cov
    torch.save(global_optimizer.state_dict(), f"{LOGDIR}/global_optimizer.pt")

    # Evaluate on small test set
    global_model.eval()
    preds = []
    targets = []
    test_loader = DataLoader(datasets.MNIST('.', train=False, transform=transform), batch_size=1000)


    with torch.no_grad():
        for batch_idx, (x, target) in enumerate(test_loader):
            x, target = x.to(DEVICE), target.to(DEVICE,dtype=torch.int64)
            outs = []
            if N_SAMPLE == 0: #predict with the mean
                out = global_model(x)
                out = F.softmax(out, 1)
                outs.append(out)

            else :
                for _ in range(N_SAMPLE):
                    with global_optimizer.sampled_params():
                        out = global_model(x)
                        out = F.softmax(out, 1)
                        outs.append(out)

            preds.append(torch.stack(outs).mean(0))
            targets.append(F.one_hot(target, 10))
    targets = torch.cat(targets)
    preds = torch.cat(preds)

    _acc = acc(preds, targets, minibatch=False)
    _ece = ece(preds, targets, DEVICE, minibatch=False)
    _nll = nll(preds, targets, minibatch=False)
    print(f"Global Test Accuracy: {_acc:.2f}%")
    print(f"Global Test ECE: {_ece:.2f}")
    print(f"Global Test NLL: {_nll:.2f}")


--- Round 1 ---
Global Test Accuracy: 77.02%
Global Test ECE: 0.42
Global Test NLL: 1.21

--- Round 2 ---
Global Test Accuracy: 93.03%
Global Test ECE: 0.03
Global Test NLL: 0.22

--- Round 3 ---
Global Test Accuracy: 95.75%
Global Test ECE: 0.01
Global Test NLL: 0.13

--- Round 4 ---
Global Test Accuracy: 96.77%
Global Test ECE: 0.01
Global Test NLL: 0.10


# FedAVG with SGD

In [19]:
# FedAVG with SGD

# Parameters
LR = 0.01


# FedAvg: average model parameters
def average_weights(w_list):
    avg_w = {}
    for key in w_list[0].keys():
        avg_w[key] = sum(w[key] for w in w_list) / len(w_list)
    return avg_w

# Training function for local client
def train(model, dataloader, criterion, optimizer, epochs):
    model.train()
    for _ in range(epochs):
        for data, target in dataloader:
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
    return model.state_dict()

# Global model
global_model = CNN(input_dim, hidden_dims, output_dim=output_dim, input_channel=input_channel)

for round in range(1, 5):
    local_weights = []
    print(f"\n--- Round {round} ---")

    for i in range(NUM_CLIENTS):
        local_model = CNN(input_channel=input_channel, input_dim=input_dim, hidden_dims=hidden_dims, output_dim=output_dim)
        local_model.load_state_dict(global_model.state_dict())
        train_loader = DataLoader(client_datasets[i], batch_size=BATCH_SIZE, shuffle=True)

        optimizer = optim.SGD(local_model.parameters(), lr=LR)
        criterion = nn.CrossEntropyLoss()
        local_w = train(local_model, train_loader, criterion, optimizer, EPOCHS)
        local_weights.append(local_w)

    # Average weights and update global model
    averaged_weights = average_weights(local_weights)
    global_model.load_state_dict(averaged_weights)

    # Evaluate on small test set
    test_loader = DataLoader(datasets.MNIST('.', train=False, transform=transform), batch_size=1000)
    global_model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in test_loader:
            pred = global_model(x).argmax(dim=1)
            correct += (pred == y).sum().item()
            total += y.size(0)
    print(f"Global Test Accuracy: {100 * correct / total:.2f}%")



--- Round 1 ---
Global Test Accuracy: 9.74%

--- Round 2 ---
Global Test Accuracy: 29.91%

--- Round 3 ---
Global Test Accuracy: 68.76%

--- Round 4 ---
Global Test Accuracy: 77.25%
