In [11]:
import time
import torch
import torch.nn as nn

from copy import deepcopy
from torch.distributions.dirichlet import Dirichlet
from torch.utils.data import DataLoader

import os
from utils import *

# Compact tensor printing: 3 decimals, no scientific notation, very wide lines.
torch.set_printoptions(precision=3,
                       threshold=1000,
                       edgeitems=5,
                       linewidth=1000,
                       sci_mode=False)

# One log directory per calendar day, one log file per run (named by start time).
t = time.localtime()
log_path = f'./log/{t.tm_year}-{t.tm_mon}-{t.tm_mday}/'
# exist_ok=True replaces the exists()-then-makedirs() check, which could raise
# FileExistsError when another run creates the directory between the two calls.
os.makedirs(log_path, exist_ok=True)
log_path += f'{t.tm_hour}-{t.tm_min}-{t.tm_sec}.log'
log = get_logger(log_path)


'''1. basic parameters'''
# args = get_args()

# Number of clients and the sample counts passed to split_non_iid as n_data.
n_client = 9
n_train_data = 1000
n_test_data = 200

# Training schedule.
local_epochs = 30
batch_size = 160
server_epochs = 1

# Concentration value used to fill the Dirichlet parameter vector.
alpha = 1.0

dataset = 'cifar10'
model_structure = 'resnet18'

# Order CUDA devices by PCI bus id and expose only device '1' to this process.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    'CUDA_VISIBLE_DEVICES': '1',
})

all_client = range(n_client)
# acc[client] collects one accuracy value per evaluated local epoch.
acc = dict((client_id, []) for client_id in all_client)
acc_server = []
setup_seed()

'''2. data preparation'''
n_targets = 10
in_channel = 3

# Raw CIFAR-10 splits; only the training split requests a download.
data_root = './data'
train_set = CIFAR10(root=data_root,
                    train=True,
                    download=True,
                    transform=torchvision.transforms.ToTensor())
test_set = CIFAR10(root=data_root,
                   train=False,
                   transform=torchvision.transforms.ToTensor())

# One Dirichlet draw per client over the n_targets classes, giving a tensor of
# shape (n_client, n_targets). The same draw is handed to both splits below.
concentration = torch.full(size=(n_targets,), fill_value=float(alpha))
pro = Dirichlet(concentration).sample([n_client])

# From here on train_set / test_set hold the result of split_non_iid (iterated
# with .items() later in the notebook) rather than the raw CIFAR10 datasets.
train_set = split_non_iid(dataset=train_set,
                          pro=pro,
                          n_data=n_train_data,
                          n_client=n_client)
test_set = split_non_iid(dataset=test_set,
                         pro=pro,
                         n_data=n_test_data,
                         n_client=n_client)

Files already downloaded and verified


In [12]:
# Class histogram of the training split: one printed row per client, one
# column per class label in range(n_targets).
# The loop variable is deliberately not named `dataset`: that name holds the
# 'cifar10' setting from the parameter cell and must not be overwritten here.
for client_data in train_set.values():
    targets = [sample[1] for sample in client_data]
    for label in range(n_targets):
        # Count matches directly instead of building a list just to take len().
        print(sum(1 for target in targets if target == label), end=' ')
    print()

375 59 15 7 31 51 24 265 40 133 
99 263 81 226 63 23 81 118 13 33 
39 58 7 29 258 40 126 398 16 29 
45 29 34 206 172 134 69 98 62 151 
11 105 20 40 65 42 67 92 182 376 
126 301 31 7 81 158 73 34 121 68 
61 63 63 189 178 147 132 40 67 60 
45 275 202 314 21 57 16 36 24 10 
269 60 75 40 212 23 31 77 137 76 


In [13]:
# Class histogram of the test split: one printed row per client, one column
# per class label in range(n_targets).
# The loop variable is deliberately not named `dataset`: that name holds the
# 'cifar10' setting from the parameter cell and must not be overwritten here.
for client_data in test_set.values():
    targets = [sample[1] for sample in client_data]
    for label in range(n_targets):
        # Count matches directly instead of building a list just to take len().
        print(sum(1 for target in targets if target == label), end=' ')
    print()

75 12 3 1 6 10 5 53 8 27 
20 51 16 45 13 5 16 24 3 7 
8 12 1 6 52 8 25 79 3 6 
9 6 7 41 34 27 14 20 12 30 
2 21 4 8 13 8 13 18 36 77 
25 60 6 1 16 32 15 7 24 14 
12 13 13 38 36 29 26 8 13 12 
9 55 40 64 4 11 3 7 5 2 
55 12 15 8 42 5 6 15 27 15 


In [14]:
# One evaluation DataLoader per client, keyed exactly like test_set.
test_loader = {
    client_id: DataLoader(dataset=client_data,
                          batch_size=10,
                          pin_memory=True,
                          num_workers=8)
    for client_id, client_data in test_set.items()
}
# Every supported architecture is built by the same model_init call, differing
# only in the model_structure string, so validate the name once and call once
# instead of repeating the call in a five-way if/elif chain.
supported_models = ('mlp', 'resnet18', 'cnn1', 'cnn2', 'lenet5')
if model_structure not in supported_models:
    raise ValueError(f'No such model: {model_structure}')
client_list = model_init(num_client=n_client,
                         model_structure=model_structure,
                         num_target=n_targets,
                         in_channel=in_channel)

# One shuffled training DataLoader per client, keyed exactly like train_set.
train_loader = {
    client_id: DataLoader(dataset=client_data,
                          batch_size=batch_size,
                          num_workers=8,
                          shuffle=True)
    for client_id, client_data in train_set.items()
}


'''4. DDP: loss function initialization'''
CE_Loss = nn.CrossEntropyLoss().cuda()


'''5. model training and distillation'''
# Progress message templates; loop-invariant, so defined once up front.
msg_local = '[server epoch {}, client {}, local train]'
msg_test_local = 'local epoch {}, acc: {:.4f}'
for server_epoch in range(server_epochs):
    # local train
    # NOTE(review): client_list_ is never read in the visible notebook.
    client_list_ = []
    # Learning rate shrinks as 1e-3 / round number.
    lr = 1e-3 / (server_epoch + 1)
    for i, client in enumerate(client_list):
        print(msg_local.format(server_epoch + 1, i + 1))
        client_ = client.cuda()
        # A fresh Adam optimizer per client, so no optimizer state is shared.
        optimizer = torch.optim.Adam(params=client_.parameters(),
                                     lr=lr,
                                     weight_decay=1e-4)
        for local_epoch in range(local_epochs):
            # Explicitly (re-)enter training mode before every epoch.
            # eval_model lives in utils and may leave the model in eval mode
            # (TODO confirm); without this guard later epochs could then train
            # with BatchNorm/Dropout in inference behaviour.
            client_.train()
            for data_, target_ in train_loader[i]:
                optimizer.zero_grad()
                output_ = client_(data_.cuda())
                loss = CE_Loss(output_, target_.cuda())
                loss.backward()
                optimizer.step()

            # test: evaluate on this client's own test split after each epoch
            acc[i].append(eval_model(client_, test_loader[i]))
            print(msg_test_local.format(local_epoch + 1, acc[i][-1]))
    # NOTE(review): this break ends the loop after the first server epoch no
    # matter what server_epochs is set to; kept as in the original because no
    # distillation/aggregation step follows yet. Confirm before removing.
    break

[server epoch 1, client 1, local train]
local epoch 1, acc: 0.3750
local epoch 2, acc: 0.3750
local epoch 3, acc: 0.3000
local epoch 4, acc: 0.3450
local epoch 5, acc: 0.3500
local epoch 6, acc: 0.3800
local epoch 7, acc: 0.3800
local epoch 8, acc: 0.5450
local epoch 9, acc: 0.5500
local epoch 10, acc: 0.5400
local epoch 11, acc: 0.3400
local epoch 12, acc: 0.4700
local epoch 13, acc: 0.4550
local epoch 14, acc: 0.5850
local epoch 15, acc: 0.5050
local epoch 16, acc: 0.5200
local epoch 17, acc: 0.5350
local epoch 18, acc: 0.5100
local epoch 19, acc: 0.4400
local epoch 20, acc: 0.5500
local epoch 21, acc: 0.5100
local epoch 22, acc: 0.5100
local epoch 23, acc: 0.5700
local epoch 24, acc: 0.5500
local epoch 25, acc: 0.5650
local epoch 26, acc: 0.5650
local epoch 27, acc: 0.5550
local epoch 28, acc: 0.5500
local epoch 29, acc: 0.5850
local epoch 30, acc: 0.5900
[server epoch 1, client 2, local train]
local epoch 1, acc: 0.2250
local epoch 2, acc: 0.0800
local epoch 3, acc: 0.0750
local ep

In [3]:
# A trailing comma inside square brackets still yields a one-element list.
a = [1,]
print(f'{type(a)} {a}')

<class 'list'> [1]
