In [6]:
import logging
import sys
import time
from typing import Callable, Dict, List, Optional, OrderedDict, Tuple, Union
import torch
import numpy as np
import torch.nn as nn
import torchvision
import yaml
from datasets import Dataset
from flwr_datasets import FederatedDataset
from flwr_datasets.partitioner import (
    DirichletPartitioner,
    LinearPartitioner,
    SizePartitioner,
    SquarePartitioner,
)
from torch.utils.data import DataLoader
from flwr.common.typing import NDArrays
import logging

logger = logging.getLogger(__name__)
from flwr.client import Client
from flwr.common import (
    Code,
    Context,
    EvaluateIns,
    EvaluateRes,
    FitIns,
    FitRes,
    NDArrays, 
    Status,
    ndarrays_to_parameters,
    Parameters, 
    parameters_to_ndarrays,
    Scalar,
)
import random

import flwr as fl
import pandas as pd
from flwr.server.client_proxy import ClientProxy
from pyarrow import feather

from flwr.server.strategy import FedAvg
import shapley_values
#import timm

ModuleNotFoundError: No module named 'timm'

In [None]:
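# NOTE(review): the original intent was to preload datasets onto the GPU as a dedicated CUDA_VisionDataSet for faster computation, but no such class exists in this notebook; batches are moved to the device inside train() and test().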

class TrainingCalls:
    """Bundle of callbacks shared by the Flower clients and the server.

    Every slot defaults to ``None``; the notebook assigns the concrete
    functions on an instance before the simulation is started.
    """

    # Model construction and (de)serialisation of weights.
    get_model: Callable = None
    get_parameters: Callable = None
    set_parameters: Callable = None
    get_initial_parameters: Callable = None
    # Data loading.
    load_data: Callable = None
    load_global_test_data: Callable = None
    # Optimisation / evaluation loops.
    train: Callable = None
    test: Callable = None


class GlobalArgs:
    """Experiment-wide settings shared by the data, client and strategy code.

    Values are class-level defaults; a configuration cell overrides them on an
    instance (``global_args``).
    """

    # Evaluated once, when the class body runs (not per instance).
    save_name: str = str(time.time())
    num_clients: int = 2
    # NOTE(review): not read anywhere in this notebook; train() uses its own epoch count.
    epochs: int = 5
    seed: int = 0
    # Dirichlet concentration handed to split_data().
    alpha: float = 1.0
    # NOTE(review): get_model() is always called without arguments here, so this is unused.
    model_name: str = "mobilenetv3_small_050"
    # Was "mnsit" (a misspelled Hugging Face dataset id that cannot be loaded).
    # NOTE(review): confirm the evaluation split name used by
    # load_global_test_data() ("valid") exists for this dataset.
    dsname: str = "mnist"
    sampler: shapley_values.Sampler = shapley_values.FullSampler
    # Read by the contribution strategy as its fraction_fit; declared here so the
    # attribute exists even if the configuration cell does not set it.
    fraction_fit: float = 1.0

    max_rounds: int = 5
    # NOTE(review): not referenced anywhere in this notebook.
    k: int = 5



def ndarrays_from_model(model: torch.nn.ModuleList) -> List[np.ndarray]:
    """Return every state-dict entry of ``model`` as a NumPy array.

    Order follows ``model.state_dict()`` (parameters and buffers alike), which
    is the order ``ndarrays_to_model`` expects when loading them back.
    """
    state = model.state_dict()
    return [state[name].cpu().numpy() for name in state]


def ndarrays_to_model(model: torch.nn.ModuleList, params: List[np.ndarray]):
    """Load ``params`` into ``model``, matched positionally to its state-dict keys.

    Each array is copied before conversion, so the model never aliases the
    caller's memory. ``load_state_dict(strict=True)`` raises if keys are missing.
    """
    new_state = OrderedDict()
    for key, array in zip(model.state_dict().keys(), params):
        new_state[key] = torch.from_numpy(np.copy(array))
    model.load_state_dict(new_state, strict=True)


In [None]:
# Experiment configuration: start from the GlobalArgs defaults and override
# the values for this run.
global_args: GlobalArgs = GlobalArgs()

# Federation shape
global_args.num_clients = 5
global_args.fraction_fit = 1  # copied into the strategy's fraction_fit
global_args.max_rounds = 2

# Data partitioning / reproducibility
global_args.alpha = 100  # Dirichlet concentration passed to split_data
global_args.seed = 1

# Contribution estimation
global_args.sampler = shapley_values.LeaveOneOutSampler

# NOTE(review): disabled; `SEED` is not defined anywhere in this notebook.
#global_args.save_name = f"cifar100_{SEED}_{time.time()}"

In [None]:
# Build the network used by clients (in fit) and by the server
# (evaluate_fn / get_initial_parameters).
def get_model(model_name: str = 'resnet18') -> torch.nn.Module:
    """Return a freshly initialised (untrained) DenseNet-121 on the compute device.

    Args:
        model_name: currently ignored; the architecture is fixed to
            ``densenet121``. A timm-based alternative that would honour it,
            ``timm.create_model(model_name)``, is disabled because timm is not
            imported.

    Returns:
        The model on ``cuda:0`` when CUDA is available, otherwise on CPU. This is
        the same device selection train() and test() use. The previous
        unconditional ``.cuda()`` crashed on CPU-only hosts.
    """
    model = torch.hub.load(
        "pytorch/vision:v0.10.0", "densenet121", weights=None, verbose=False
    )
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    return model.to(device)

# Training-time preprocessing: PIL image -> tensor, then a random horizontal
# flip as light augmentation. Resizing to 224x224 is currently disabled.
transform_train = torchvision.transforms.Compose(
    transforms=[
        torchvision.transforms.ToTensor(),
        torchvision.transforms.RandomHorizontalFlip(),
        # torchvision.transforms.Resize((224, 224)),
    ]
)

# Evaluation-time preprocessing: tensor conversion only, no augmentation.
transform_test = torchvision.transforms.Compose(
    transforms=[
        torchvision.transforms.ToTensor(),
        # torchvision.transforms.Resize((224, 224)),
    ]
)

def split_data(dsname, num_clients, alpha, min_partition_size=30, seed=42):
    """Build a FederatedDataset whose train split is Dirichlet-partitioned by label.

    Args:
        dsname: Hugging Face dataset identifier.
        num_clients: number of partitions (one per client).
        alpha: Dirichlet concentration; larger values give more uniform label
            mixes across partitions.
        min_partition_size: minimum number of samples each partition must get
            (previously hard-coded to 30, still the default).
        seed: RNG seed for the partitioner. 42 is DirichletPartitioner's own
            default, so omitting it keeps the previous partitioning.

    Returns:
        A FederatedDataset; partitions are read with ``load_partition`` and
        whole splits with ``load_split``.
    """
    partitioner = DirichletPartitioner(
        num_partitions=num_clients,
        partition_by="label",
        alpha=alpha,
        min_partition_size=min_partition_size,
        seed=seed,
    )
    # SECURITY NOTE: trust_remote_code=True lets the dataset repository run
    # arbitrary Python when it is loaded; only use it with trusted datasets.
    return FederatedDataset(
        dataset=dsname,
        partitioners={"train": partitioner},
        trust_remote_code=True,
    )


# Partition the data once at notebook level; load_data() and
# load_global_test_data() both read from this shared FederatedDataset.
fds = split_data(
    dsname=global_args.dsname,
    num_clients=global_args.num_clients,
    alpha=global_args.alpha,
)

def load_data(partition_id: int, batch_size: int = 256) -> DataLoader:
    """Return a shuffled DataLoader over client ``partition_id``'s train partition.

    Images are converted to RGB and run through ``transform_train`` via the
    dataset's ``with_transform`` hook. Incomplete final batches are dropped,
    so a partition smaller than ``batch_size`` yields an empty loader.
    """

    def _to_train_tensors(examples):
        # Replace the PIL images of a fetched batch with transformed tensors.
        examples["image"] = [
            transform_train(picture.convert('RGB')) for picture in examples["image"]
        ]
        return examples

    partition = fds.load_partition(partition_id=partition_id).with_transform(
        _to_train_tensors
    )
    return DataLoader(partition, batch_size=batch_size, shuffle=True, drop_last=True)


# Load the (global) test dataset
def load_global_test_data(split: str = "valid", batch_size: int = 64) -> DataLoader:
    """Return a DataLoader over the shared evaluation split of ``fds``.

    Args:
        split: name of the split to load (default ``"valid"``).
            NOTE(review): confirm this split exists for ``global_args.dsname``.
        batch_size: evaluation batch size.

    Returns:
        A deterministic loader that covers every example of the split. It does
        not shuffle, and it keeps the final partial batch. Previously
        ``shuffle=True, drop_last=True`` silently excluded up to
        ``batch_size - 1`` examples, a different random subset on every
        evaluation.
    """
    testset = fds.load_split(split)

    def apply_test_transforms(batch):
        """Convert the batch's images to RGB tensors with ``transform_test``."""
        batch["image"] = [transform_test(img.convert('RGB')) for img in batch["image"]]
        return batch

    testset = testset.with_transform(apply_test_transforms)
    return DataLoader(
        testset,
        batch_size=batch_size,
        shuffle=False,
        drop_last=False,
    )


# Local training loop used by FlowerClient.fit (via TrainingCalls.train).
def train(model: nn.Module, trainloader: DataLoader, **kwargs):
    """Train ``model`` in place with Adam and cross-entropy.

    Args:
        model: network to optimise; moved to the selected device and put in
            training mode.
        trainloader: iterable of dict batches with "image" and "label" tensors.
        **kwargs: optional overrides:
            ``epochs`` (int, default 5): passes over ``trainloader``.
            ``lr`` (float, default 1e-3): Adam learning rate.
            ``ins``: the Flower ``FitIns`` passed by the client. It is accepted
            for interface compatibility; its config is not used by this loop.
    """
    epochs = kwargs.get("epochs", 5)
    lr = kwargs.get("lr", 1e-3)
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    # Keep the model on the same device the batches are moved to.
    model.to(device)
    model.train()
    criterion = torch.nn.CrossEntropyLoss()
    # Create the optimiser after moving the model so it tracks on-device params.
    optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)
    for _ in range(epochs):
        for batch in trainloader:
            images = batch["image"].to(device)
            labels = batch["label"].to(device)
            optimizer.zero_grad()
            loss = criterion(model(images), labels)
            loss.backward()
            optimizer.step()


def test(
    model: nn.Module,
    testloader: DataLoader,
) -> Tuple[float, float]:
    """Validate the model on the entire test set."""
    criterion = torch.nn.CrossEntropyLoss()
    correct, total, loss = 0, 0, 0.0
    if len(testloader) == 0:
        return np.inf, 0
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    with torch.no_grad():
        for i, data in enumerate(testloader, 0):
            if i > 10:
                pass
            images, labels = data["image"].to(device), data["label"].to(device)
            outputs = model(images)
            loss += criterion(outputs, labels).item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = correct / total
    del (testloader, model)
    torch.cuda.empty_cache()
    return loss, accuracy


def ndarrays_from_model(model: torch.nn.ModuleList) -> NDArrays:
    """Snapshot ``model.state_dict()`` as a list of NumPy arrays (state-dict order).

    This redefinition shadows the earlier one of the same name.
    """
    tensors = model.state_dict().values()
    return list(map(lambda tensor: tensor.cpu().numpy(), tensors))


def ndarrays_to_model(model: torch.nn.ModuleList, params: NDArrays):
    """Copy ``params`` into ``model``, pairing them with state-dict keys by position.

    Arrays are copied before wrapping so the model owns its memory. Loading is
    strict, so missing keys raise. This redefinition shadows the earlier one of
    the same name.
    """
    key_order = list(model.state_dict().keys())
    converted = OrderedDict(
        (name, torch.from_numpy(np.copy(values)))
        for name, values in zip(key_order, params)
    )
    model.load_state_dict(converted, strict=True)


# All training calls, to be sent to the server and clients


from flwr.common import ndarrays_to_parameters


def evaluate_fn(server_round, weights_aggregated, dict, **kwargs):
    """Server-side evaluation of aggregated weights on the global test split.

    ``server_round``, the config mapping (named ``dict`` here, shadowing the
    builtin) and ``kwargs`` are accepted but unused.

    Returns:
        ``(-loss, {"accuracy": accuracy})``. NOTE(review): the loss is negated,
        presumably so it can act as a higher-is-better utility in the
        contribution computation. Confirm this, since the strategy also returns
        it as its evaluation loss.
    """
    evaluation_model = get_model()
    ndarrays_to_model(evaluation_model, weights_aggregated)
    test_loader = load_global_test_data()
    loss, accuracy = test(evaluation_model, test_loader)
    # Drop our reference before releasing cached GPU memory.
    del evaluation_model
    torch.cuda.empty_cache()
    metrics = {"accuracy": accuracy}
    return -loss, metrics


def get_initial_parameters():
    """Serialise the weights of a freshly built model into Flower ``Parameters``."""
    fresh_model = get_model()
    weights = ndarrays_from_model(fresh_model)
    # Only the NumPy snapshot is needed from here on.
    del fresh_model
    torch.cuda.empty_cache()
    return ndarrays_to_parameters(weights)


def fit_config(server_round: int):
    """Return the per-round training config sent to clients: the round number only."""
    return {"current_round": server_round}


# Per-client resources for the simulation: one CPU and roughly a third of a GPU.
client_resources: dict = dict(num_cpus=1, num_gpus=0.33333)

In [None]:

class FlowerClient(Client):
    """Flower client that trains a fresh copy of the global model on one partition.

    All model and data work is delegated to the callables in ``training_calls``.
    """

    def __init__(
        self,
        client_id: int,
        training_calls: TrainingCalls,
    ) -> None:
        self.client_id = client_id
        self.training_calls = training_calls
        # Report the number of *examples* in the partition. len() of a DataLoader
        # is its number of batches, which previously made FedAvg weight clients
        # by batch count (and depended on the loader's batch size / drop_last).
        self.num_examples = len(self.training_calls.load_data(self.client_id).dataset)

        self.round = 0  # NOTE(review): not read anywhere in this class
        self.batchsize = 64

    def fit(self, ins: FitIns) -> FitRes:
        """Train locally starting from ``ins.parameters``.

        Returns the updated weights, the partition size as ``num_examples``,
        and this client's id under ``metrics["client_id"]``.
        """
        calls = self.training_calls
        # Build a fresh model and load the incoming global weights into it.
        model = calls.get_model()
        calls.set_parameters(model, parameters_to_ndarrays(ins.parameters))
        trainloader = calls.load_data(self.client_id, self.batchsize)
        calls.train(
            model=model,
            trainloader=trainloader,
            ins=ins,
        )
        ndarrays_updated = calls.get_parameters(model)
        # Drop references to the model/loader before releasing cached GPU memory.
        del trainloader, model
        torch.cuda.empty_cache()
        logger.info(f"Client {self.client_id} successfully trained.")
        return FitRes(
            status=Status(code=Code.OK, message="Success"),
            parameters=ndarrays_to_parameters(ndarrays_updated),
            num_examples=self.num_examples,
            metrics={"client_id": self.client_id},
        )

    def evaluate(self, ins: EvaluateIns) -> EvaluateRes:
        """Placeholder: no client-side evaluation is performed.

        Always reports ``loss=inf`` and ``accuracy=0``. Evaluation happens on
        the server through the strategy's ``evaluate_fn``.
        """
        status = Status(code=Code.OK, message="Success")
        return EvaluateRes(
            status=status,
            loss=np.inf,
            num_examples=self.num_examples,
            metrics={"accuracy": 0},
        )

In [None]:
import shapley_values
import flwr_contributions


def create_contribution_strategy(
    parent_strategy: type,
    initial_parameters: Parameters,
    trainingcalls: TrainingCalls,
):
    """Build a strategy class that records per-round client contributions.

    The returned class subclasses ``parent_strategy`` (e.g. ``FedAvg``). In
    ``aggregate_fit`` it first estimates each client's contribution with
    ``shapley_values.multi_round_reconstruction``, using ``global_args.sampler``,
    and appends the result to ``contribution_dict``. It then delegates the
    actual aggregation to the parent.

    Args:
        parent_strategy: Flower strategy class to extend.
        initial_parameters: default for the returned class's
            ``initial_parameters`` argument.
        trainingcalls: stored on each instance as ``self.trainingcalls``.

    Returns:
        The strategy class (not an instance).
    """

    class FedContribution(parent_strategy):
        """``parent_strategy`` plus per-round contribution estimation."""

        def __init__(
            self,
            initial_parameters: Parameters = initial_parameters,
            fraction_evaluate: float = 1.0,
            min_fit_clients: int = 2,
            min_evaluate_clients: int = 2,
            min_available_clients: int = 2,
            evaluate_fn: Optional[
                Callable[
                    [int, NDArrays, Dict[str, Scalar]],
                    Optional[Tuple[float, Dict[str, Scalar]]],
                ]
            ] = None,
            on_fit_config_fn: Callable = None,
        ) -> None:
            # Let the parent initialise its defaults, then override the fields
            # this experiment controls.
            super().__init__()

            self.evaluate_fn = evaluate_fn
            # NOTE(review): never assigned anywhere else in this class.
            self.results: List[Tuple[ClientProxy, FitRes]] = None
            self.on_fit_config_fn = on_fit_config_fn

            self.initial_parameters = initial_parameters
            # Was `globalargs.fraction_fit`, an undefined name (NameError).
            # Falls back to 1.0 if the configuration cell did not set it.
            self.fraction_fit = getattr(global_args, "fraction_fit", 1.0)
            self.fraction_evaluate = fraction_evaluate
            self.min_fit_clients = min_fit_clients
            self.min_evaluate_clients = min_evaluate_clients
            self.min_available_clients = min_available_clients

            self.trainingcalls = trainingcalls

            # One entry per server round, as returned by multi_round_reconstruction.
            self.contribution_dict = []

            # NOTE(review): printed at the last round but never appended to.
            self.times = []

        def aggregate_fit(
            self,
            server_round: int,
            results: List[Tuple[ClientProxy, FitRes]],
            failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
        ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
            """Record this round's client contributions, then aggregate as the parent does."""
            if not results:
                return None, {}

            # The parent's aggregate_fit is handed to the estimator, presumably
            # so it can aggregate subsets of `results` (TODO confirm in
            # shapley_values).
            round_contributions = shapley_values.multi_round_reconstruction(
                server_round,
                results,
                failures,
                self.evaluate_fn,
                super().aggregate_fit,
                global_args.sampler,
                sample_ratio=.5,
            )
            print(round_contributions)
            self.contribution_dict.append(round_contributions)
            return super().aggregate_fit(server_round, results, failures)

        def evaluate(
            self, server_round: int, parameters: Parameters
        ) -> Optional[Tuple[float, Dict[str, Scalar]]]:
            """Evaluate ``parameters`` with ``evaluate_fn``.

            At the final round this also prints the recorded times and
            contributions.
            """
            if self.evaluate_fn is None:
                # No evaluation function provided
                return None
            parameters_ndarrays = parameters_to_ndarrays(parameters)
            eval_res = self.evaluate_fn(server_round, parameters_ndarrays, {})

            if eval_res is None:
                return None
            loss, metrics = eval_res
            # At the last round, dump what was collected. Previously this read
            # `globalargs.max_rounds`, an undefined name.
            # NOTE(review): results are only printed, not saved to disk.
            if server_round == global_args.max_rounds:
                print(self.times)
                print(self.contribution_dict)
            return loss, metrics

    # Previously `return FedTimed`, an undefined name (NameError).
    return FedContribution

In [None]:

# Wire the notebook's functions into the TrainingCalls bundle shared by every
# client (through client_fn) and by start_simulation().
training_calls: TrainingCalls = TrainingCalls()

# model construction and weight (de)serialisation
training_calls.get_model = get_model
training_calls.get_parameters = ndarrays_from_model
training_calls.set_parameters = ndarrays_to_model
training_calls.get_initial_parameters = get_initial_parameters

# data
training_calls.load_data = load_data
training_calls.load_global_test_data = load_global_test_data

# optimisation / evaluation loops
training_calls.train = train
training_calls.test = test



def client_fn(context: Context) -> FlowerClient:
    """Instantiate the client for the simulated node described by ``context``.

    The node's ``partition-id`` is used both as the client id and as the index
    of the data partition the client trains on.
    """
    node_partition = int(context.node_config["partition-id"])
    client = FlowerClient(client_id=node_partition, training_calls=training_calls)
    return client.to_client()


def fit_config(server_round: int):
    """Return the per-round training config sent to clients.

    This later definition replaces the earlier ``fit_config`` of the same name
    and additionally reports the total number of rounds.
    """
    return dict(current_round=server_round, max_round=global_args.max_rounds)


def start_simulation():
    """Run the Flower simulation with the contribution-tracking FedAvg strategy.

    Returns:
        The ``History`` object returned by ``fl.simulation.start_simulation``.
        It was previously discarded.
    """
    # Build the initial weights once. Previously a second model was constructed
    # just to produce parameters that the explicit keyword below overrode.
    initial_parameters = training_calls.get_initial_parameters()
    # Was `create_timed_strategy`, which is not defined anywhere (NameError).
    contribution_strategy = create_contribution_strategy(
        FedAvg,
        initial_parameters=initial_parameters,
        trainingcalls=training_calls,
    )
    return fl.simulation.start_simulation(
        client_fn=client_fn,
        num_clients=global_args.num_clients,
        config=fl.server.ServerConfig(num_rounds=global_args.max_rounds),
        strategy=contribution_strategy(
            initial_parameters=initial_parameters,
            evaluate_fn=evaluate_fn,
            on_fit_config_fn=fit_config,
        ),
        client_resources=client_resources,
    )

# Launch the simulation: global_args.num_clients clients for
# global_args.max_rounds rounds, with the strategy built inside start_simulation().
start_simulation()