In [None]:
import sys
import os

from fluke.data.datasets import Datasets
import torch
import rich
import random
import torchvision
from fluke import DDict
from torchvision.transforms import v2

# Per-channel normalization statistics, shared by the train and test pipelines
# so the two Normalize steps cannot drift apart.
CIFAR10_MEAN = [0.49139968, 0.48215827, 0.44653124]
CIFAR10_STD = [0.24703233, 0.24348505, 0.26158768]

# Training pipeline: random resized crop to 24x24 and a random horizontal flip,
# then conversion to a float32 image tensor scaled to [0, 1] and normalization.
transforms_train = v2.Compose([
    v2.RandomResizedCrop(size=(24, 24), antialias=True),
    # Fixed flip probability. The previous p=random.uniform(0, 1) was drawn once
    # from the unseeded `random` module, so augmentation strength changed on
    # every run and was not governed by the experiment seed.
    v2.RandomHorizontalFlip(p=0.5),
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=CIFAR10_MEAN, std=CIFAR10_STD),
])

# Evaluation pipeline: deterministic center crop to the same 24x24 size used in
# training, followed by the same tensor conversion and normalization.
transforms_test = v2.Compose([
    v2.CenterCrop(size=(24, 24)),
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=CIFAR10_MEAN, std=CIFAR10_STD),
])

# NOTE(review): only transforms_train is passed here; transforms_test is defined
# but not used in this notebook. If Datasets.get applies `transforms` to the test
# split as well, evaluation will see random crops/flips — TODO confirm against fluke.
dataset = Datasets.get("cifar10", path="../data", transforms=transforms_train)

In [2]:
from fluke.data import DataSplitter

# Partition `dataset` across clients. distribution="dir" with beta=0.1 presumably
# selects a Dirichlet-based non-IID split (smaller beta -> more skewed clients);
# verify against fluke's DataSplitter documentation.
dist_args = DDict(beta=0.1, balanced=False)
data_splitter = DataSplitter(
    dataset=dataset,
    distribution="dir",
    dist_args=dist_args,
)

In [3]:
from fluke import GlobalSettings  # NOQA
from fluke.data import DataSplitter, FastDataLoader  # NOQA
from fluke.data.datasets import Datasets  # NOQA
from fluke.evaluation import ClassificationEval  # NOQA
from fluke.utils import (Configuration, OptimizerConfigurator,  # NOQA
                         get_class_from_qualified_name, get_loss, get_model)
from fluke.utils.log import get_logger  # NOQA

# Register a classification evaluator on the global settings object. The class
# count comes from the loaded dataset; eval_every=1 presumably means "evaluate
# every round" — confirm against fluke's ClassificationEval docs.
num_classes = dataset.num_classes
evaluator = ClassificationEval(n_classes=num_classes, eval_every=1)
GlobalSettings().set_evaluator(evaluator)

In [None]:
from fluke.algorithms.fedopt import FedOpt
from fluke.utils import Configuration
from rich.panel import Panel
from rich.pretty import Pretty
from rich.progress import track

# Experiment + method configuration merged from the two YAML files.
cfg = Configuration("../configs/exp_adam.yaml", "../configs/fedadam.yaml")

# Process-wide settings driven by the experiment section of the config.
seed = cfg.exp.seed
GlobalSettings().set_seed(seed)
GlobalSettings().set_device(cfg.exp.device)
GlobalSettings().set_eval_cfg(cfg.eval)

# The algorithm class is resolved from the qualified name in the method config;
# the FedOpt import above is not what is instantiated here.
fl_algo_class = get_class_from_qualified_name(cfg.method.name)
fl_algo = fl_algo_class(
    cfg.protocol.n_clients,
    data_splitter,
    cfg.method.hyperparameters,
)

# Logger: cfg.logger.name is passed as the first positional argument, and the
# remaining logger options (everything except 'name') are forwarded as kwargs.
logger_name = cfg.logger.name
log = get_logger(logger_name, name=str(cfg), **cfg.logger.exclude('name'))
log.init(**cfg)
fl_algo.set_callbacks(log)
rich.print(Panel(Pretty(fl_algo), title="FL algorithm"))

# Train for the configured number of rounds with the configured client participation.
fl_algo.run(cfg.protocol.n_rounds, cfg.protocol.eligible_perc)