In [1]:
# Third-party library imports
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torch.multiprocessing as mp

# Use the 'spawn' start method for worker processes — presumably because client
# training runs on CUDA devices (see DEVICE_LIST), and CUDA cannot be
# re-initialised in 'fork'ed children.
# force=True keeps this cell re-runnable: without it a second execution in the
# same kernel raises RuntimeError("context has already been set").
mp.set_start_method('spawn', force=True)

# Project-local imports (models/ and lib/ packages)
from models.resnet import ResNet
from lib.federated import FederatedServer

In [2]:
# pre-settings

# GPUs that the federated server is given for training.
N_GPUS = 4
DEVICE_LIST = ['cuda:' + str(i) for i in range(N_GPUS)]
# Number of federated rounds passed to `fl.run`.
ROUNDS = 10

# Optional warm-up phase. It is disabled here, so 'setting' presumably only
# matters when 'warm_up' is True — confirm in lib.federated.
WARM_SETTINGS = {
    'warm_up': False,
    'setting': {
        'batch_size': 128
    }
}

# Local-training configuration handed to every client.
CLIENT_SETTINGS = {
    # NOTE(review): 'thres' mode with 'thres': 0.95 presumably trains until a
    # 0.95 metric threshold is hit, capped at 'max_epoch' — confirm.
    'mode': 'thres',
    'thres': 0.95,
    'max_epoch': 500,
    'batch_size': 256,
    'enable_scheduler': True,
    # Scheduler class plus kwargs; presumably instantiated per client with
    # 'scheduler_settings' when 'enable_scheduler' is True.
    'scheduler': optim.lr_scheduler.ReduceLROnPlateau,
    'scheduler_settings': {
        'mode': 'min',
        'factor': 0.13,
        'patience': 5,
        # NOTE(review): recent PyTorch releases (2.2+) deprecate `verbose` on LR
        # schedulers; drop it if the installed version warns about or rejects it.
        'verbose': True,
        'min_lr': 1e-8,
        'cooldown': 3
    }
}

# Keyword arguments for the network; presumably forwarded to ResNet(...).
NET_KWARGS = {
    'depth': 44,
    'num_classes': 10
}

# Client count and per-client relative capacity. Capacities grow by a factor of
# CAPACITY_BASE per client and are normalised to sum to 1. Deriving them from
# CLIENTS_NUM keeps 'capacity' the same length as 'clients_num'; the previous
# hard-coded divisor 11111 (= 1 + 10 + 100 + 1000 + 10000) was only valid for
# exactly 5 clients with base 10.
CLIENTS_NUM = 5
CAPACITY_BASE = 10
_capacity_total = sum(CAPACITY_BASE ** i for i in range(CLIENTS_NUM))
CAPACITY = [CAPACITY_BASE ** i / _capacity_total for i in range(CLIENTS_NUM)]

SERVER_SETTINGS = {
    'clients_num': CLIENTS_NUM,
    # NOTE(review): 'imba-size' presumably splits trainset into imbalanced
    # sizes according to 'capacity' — confirm in lib.federated.
    'split_method': 'imba-size',
    'capacity': CAPACITY,
    'random_response': False,
    'client_settings': CLIENT_SETTINGS,
    'warm_setting': WARM_SETTINGS,
    'net_kwargs': NET_KWARGS,
    # NOTE(review): presumably Shapley-value and leave-one-out client valuation
    # (written out later by fl.save_valuation) — confirm.
    'cal_sv': True,
    'cal_loo': True,
    'eval_clients': True,
    'devices': DEVICE_LIST
}

# Per-channel normalisation constants shared by the train and test pipelines.
# NOTE(review): these std values are the widely-copied (0.2023, 0.1994, 0.2010);
# the per-channel std of the CIFAR-10 training set is commonly reported as
# about (0.2470, 0.2435, 0.2616). Confirm which is intended before changing,
# since it affects comparability with earlier runs.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2023, 0.1994, 0.2010)

# Training pipeline: augmentation (random 32x32 crop from a 4-pixel padded image,
# random horizontal flip), then tensor conversion and normalisation.
_train_steps = [
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
]
TRANSFORM_CIFAR10_TRAIN = transforms.Compose(_train_steps)

# Evaluation pipeline: no augmentation, only tensor conversion and normalisation.
_test_steps = [
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
]
TRANSFORM_CIFAR10_TEST = transforms.Compose(_test_steps)

# Load the CIFAR-10 train/test splits. download=True fetches the archive on first
# use; later runs just print "Files already downloaded and verified".
# DATA_ROOT is relative, so it resolves against the notebook's working directory.
DATA_ROOT = '../public_set'
trainset = torchvision.datasets.CIFAR10(
    DATA_ROOT,
    train=True,
    transform=TRANSFORM_CIFAR10_TRAIN,
    download=True,
)
testset = torchvision.datasets.CIFAR10(
    DATA_ROOT,
    train=False,
    transform=TRANSFORM_CIFAR10_TEST,
    download=True,
)

Files already downloaded and verified
Files already downloaded and verified


In [None]:
# run: build the federated server over the CIFAR-10 train/test sets, train for
# ROUNDS rounds, then save the valuation results under the tag below.
# NOTE(review): the saved valuation presumably holds the SV/LOO scores enabled
# via 'cal_sv' / 'cal_loo' in SERVER_SETTINGS — confirm in lib.federated.
fl = FederatedServer(
    ResNet,
    trainset,
    testset,
    **SERVER_SETTINGS,
)
fl.run(rounds=ROUNDS)

valuation_tag = 'cifa_differ_quan_10_'
fl.save_valuation(valuation_tag)

Start training clients...
