In [1]:
import paddle
import numpy as np
import paddle.nn as nn
from paddle.vision.datasets import Cifar10 as CIFAR10
from paddle.vision.datasets import Cifar100 as CIFAR100
from paddle.vision import transforms
from paddle import optimizer as optim
import paddle.nn.functional as F

import os
import logging
from paddle.io import DataLoader

logging.basicConfig()
logger = logging.getLogger()
logger.setLevel(logging.INFO)

In [2]:
class Swish(nn.Layer):  # Swish(x) = x∗σ(x)
    def __init__(self):
        super().__init__()

    def forward(self, input):
        return input * F.sigmoid(input)


class ConvNet(nn.Layer):
    def __init__(self, channel=3, num_classes=10, net_width=128, net_depth=3, net_act='relu', net_norm='instancenorm',
                 net_pooling='avgpooling', im_size=(32, 32)):
        super(ConvNet, self).__init__()

        self.features, shape_feat = self._make_layers(channel, net_width, net_depth, net_norm, net_act, net_pooling,
                                                      im_size)
        num_feat = shape_feat[0] * shape_feat[1] * shape_feat[2]
        self.classifier1 = nn.Linear(num_feat, 256)
        self.classifier2 = nn.Linear(256, num_classes)
        self.flatten = paddle.nn.Flatten()

    def forward(self, x):
        h = self.features(x)
        
        _out = self.flatten(h)
        x = self.classifier1(_out)
        y = self.classifier2(x)
        return h, x, y

    def embed(self, x):
        _out = self.features(x)
        _out = self.flatten(_out)
        return _out

    def _get_activation(self, net_act):
        if net_act == 'sigmoid':
            return nn.Sigmoid()
        elif net_act == 'relu':
            return nn.ReLU()
        elif net_act == 'leakyrelu':
            return nn.LeakyReLU(negative_slope=0.01)
        elif net_act == 'swish':
            return Swish()
        else:
            exit('unknown activation function: %s' % net_act)

    def _get_pooling(self, net_pooling):
        if net_pooling == 'maxpooling':
            return nn.MaxPool2D(kernel_size=2, stride=2)
        elif net_pooling == 'avgpooling':
            return nn.AvgPool2D(kernel_size=2, stride=2)
        elif net_pooling == 'none':
            return None
        else:
            exit('unknown net_pooling: %s' % net_pooling)

    def _get_normlayer(self, net_norm, shape_feat):
        # shape_feat = (c*h*w)
        if net_norm == 'batchnorm':
            return nn.BatchNorm2D(shape_feat[0])
        elif net_norm == 'layernorm':
            return nn.LayerNorm(shape_feat)
        elif net_norm == 'instancenorm':
            return nn.GroupNorm(shape_feat[0], shape_feat[0])
        elif net_norm == 'groupnorm':
            return nn.GroupNorm(4, shape_feat[0])
        elif net_norm == 'none':
            return None
        else:
            exit('unknown net_norm: %s' % net_norm)

    def _make_layers(self, channel, net_width, net_depth, net_norm, net_act, net_pooling, im_size):
        layers = []
        in_channels = channel
        if im_size[0] == 28:
            im_size = (32, 32)
        shape_feat = [in_channels, im_size[0], im_size[1]]
        for d in range(net_depth):
            layers += [nn.Conv2D(in_channels=in_channels, out_channels=net_width, kernel_size=3, padding=3 if channel == 1 and d == 0 else 1)]
            shape_feat[0] = net_width
            if net_norm != 'none':
                layers += [self._get_normlayer(net_norm, shape_feat)]
            layers += [self._get_activation(net_act)]
            in_channels = net_width
            if net_pooling != 'none':
                layers += [self._get_pooling(net_pooling)]
                shape_feat[1] //= 2
                shape_feat[2] //= 2

        return nn.Sequential(*layers), shape_feat




In [3]:
def load_cifar10_data(datadir):

    transform_train = transforms.Compose([
        transforms.ToTensor(),
        transforms.RandomCrop(32),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])
    # data prep for test set
    transform_test = transforms.Compose([transforms.ToTensor()])

    cifar10_train_ds = CIFAR10(data_file=os.path.join(datadir, 'cifar-10-python.tar.gz'), mode='train', transform=transform_train)
    cifar10_test_ds = CIFAR10(data_file=os.path.join(datadir, 'cifar-10-python.tar.gz'), mode='test', transform=transform_test)

    Train_data = cifar10_train_ds.data
    Test_data = cifar10_test_ds.data
    X_train,y_train = np.array([x[0] for x in Train_data]), np.array([x[1] for x in Train_data])
    X_test,y_test = np.array([x[0] for x in Test_data]), np.array([x[1] for x in Test_data])
    
    return (cifar10_train_ds,cifar10_test_ds,X_train,y_train,X_test,y_test)

def load_cifar100_data(datadir):
    transform_train = transforms.Compose([
        transforms.RandomCrop(32),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor()
    ])
    # data prep for test set
    transform_test = transforms.Compose([transforms.ToTensor()])

    cifar100_train_ds = CIFAR100(data_file=os.path.join(datadir, 'cifar-100-python.tar.gz'), mode='train', transform=transform_train)
    cifar100_test_ds = CIFAR100(data_file=os.path.join(datadir, 'cifar-100-python.tar.gz'), mode='test', transform=transform_test)

    Train_data = cifar100_train_ds.data
    Test_data = cifar100_test_ds.data

    X_train,y_train = np.array([x[0] for x in Train_data]), np.array([x[1] for x in Train_data])
    X_test,y_test = np.array([x[0] for x in Test_data]), np.array([x[1] for x in Test_data])

    return (cifar100_train_ds,cifar100_test_ds,X_train,y_train,X_test,y_test)

def record_net_data_stats(y_train, net_dataidx_map, logdir):

    net_cls_counts = {}

    for net_i, dataidx in net_dataidx_map.items():
        unq, unq_cnt = np.unique(y_train[dataidx], return_counts=True)
        tmp = {unq[i]: unq_cnt[i] for i in range(len(unq))}
        net_cls_counts[net_i] = tmp

    logger.info('Data statistics: %s' % str(net_cls_counts))

    return net_cls_counts
def partition_data(dataset, datadir, logdir, partition, n_parties, beta=0.4):
    if dataset == 'cifar10':
        _,_,X_train,y_train,X_test,y_test = load_cifar10_data(datadir)
        class_total_num = 10
    elif dataset == 'cifar100':
        _,_,X_train,y_train,X_test,y_test = load_cifar100_data(datadir)
        class_total_num = 100

    n_train = y_train.shape[0]

    if partition == "homo":
        idxs = np.random.permutation(n_train)
        batch_idxs = np.array_split(idxs, n_parties)
        net_dataidx_map = {i: batch_idxs[i] for i in range(n_parties)}

    elif partition == "noniid-labeldir":
        min_size = 0
        min_require_size = 10
        N = len(y_train)
        #np.random.seed(2020)
        net_dataidx_map = {}

        while min_size < min_require_size:
            idx_batch = [[] for _ in range(n_parties)]
            for k in range(class_total_num):
                idx_k = np.where(y_train == k)[0]
                np.random.shuffle(idx_k)
                proportions = np.random.dirichlet(np.repeat(beta, n_parties))
                # logger.info("proportions1: ", proportions)
                # logger.info("sum pro1:", np.sum(proportions))
                ## Balance
                proportions = np.array([p * (len(idx_j) < N / n_parties) for p, idx_j in zip(proportions, idx_batch)])
                # logger.info("proportions2: ", proportions)
                proportions = proportions / proportions.sum()
                # logger.info("proportions3: ", proportions)
                proportions = (np.cumsum(proportions) * len(idx_k)).astype(int)[:-1]
                # logger.info("proportions4: ", proportions)
                idx_batch = [idx_j + idx.tolist() for idx_j, idx in zip(idx_batch, np.split(idx_k, proportions))]
                min_size = min([len(idx_j) for idx_j in idx_batch])
                # if K == 2 and n_parties <= 10:
                #     if np.min(proportions) < 200:
                #         min_size = 0
                #         break
        for j in range(n_parties):
            np.random.shuffle(idx_batch[j])
            net_dataidx_map[j] = idx_batch[j]

    traindata_cls_counts = record_net_data_stats(y_train, net_dataidx_map, logdir)
    net_dataidx_map_ = {i:[[] for j in range(class_total_num)] for i in range(n_parties)}
    for i in range(n_parties):
        for j in net_dataidx_map[i]:
            idx = y_train[j]
            net_dataidx_map_[i][idx].append(j)
    return (X_train, y_train, X_test, y_test, net_dataidx_map, net_dataidx_map_, traindata_cls_counts)



In [4]:
seed = 0
np.random.seed(seed)
beta=0.5
dataset = 'cifar100'

if dataset == 'cifar100':
    data_path = '../data/data152750'
    real_client_num = 10
    vir_clients_num = 10
    class_num=100
    num_per_class=500
elif dataset == 'cifar10':
    data_path = '../data/data152754'
    real_client_num = 10
    vir_clients_num = 10
    class_num=10
    num_per_class=5000

X_train, y_train, X_test, y_test, net_dataidx_map, net_dataidx_map_, traindata_cls_counts = partition_data(
    dataset, data_path, 'logdir_test', 'noniid-labeldir', real_client_num, beta=beta)

INFO:root:Data statistics: {0: {0: 112, 3: 8, 4: 79, 5: 148, 6: 100, 8: 10, 9: 88, 10: 109, 11: 5, 12: 24, 13: 5, 14: 45, 15: 75, 16: 158, 17: 1, 18: 262, 19: 20, 20: 6, 21: 51, 22: 1, 23: 10, 24: 79, 25: 25, 27: 99, 28: 109, 29: 30, 30: 72, 31: 40, 32: 129, 33: 8, 34: 4, 35: 112, 36: 3, 37: 15, 38: 73, 39: 2, 40: 30, 41: 53, 42: 3, 43: 17, 44: 1, 45: 9, 46: 44, 47: 25, 48: 50, 49: 40, 50: 14, 51: 3, 52: 45, 53: 20, 54: 39, 55: 17, 56: 30, 57: 8, 58: 185, 59: 199, 60: 7, 61: 2, 62: 56, 63: 11, 64: 101, 66: 96, 67: 48, 68: 133, 69: 16, 70: 38, 71: 22, 72: 14, 73: 172, 74: 142, 75: 43, 76: 1, 77: 36, 78: 81, 79: 18, 80: 4, 81: 264, 82: 2, 83: 75, 84: 9, 85: 243, 86: 15, 87: 24, 88: 5, 89: 171, 91: 6, 92: 94, 94: 10, 95: 85, 96: 125}, 1: {0: 34, 1: 18, 2: 164, 3: 68, 4: 104, 5: 22, 6: 85, 7: 11, 8: 221, 9: 9, 10: 71, 11: 62, 12: 2, 13: 151, 14: 3, 15: 98, 16: 2, 17: 65, 18: 31, 19: 109, 20: 16, 21: 19, 22: 64, 23: 80, 24: 74, 25: 8, 27: 10, 28: 19, 29: 5, 30: 15, 31: 4, 32: 1, 33: 2, 34: 

In [5]:
dataloader_virtual_matrix = [[[] for j in range(real_client_num)] for i in range(vir_clients_num)]
dataloader_virtual_matrix_dl = [[] for i in range(vir_clients_num)]
# dataloader_virtual_matrix

In [6]:
level = int(num_per_class / vir_clients_num)
for virtual_th in range(vir_clients_num):
    current_level = [level for i in range(class_num)]
    for client_th in range(real_client_num):
        for i in range(class_num):
            temp = current_level[i]
            dataloader_virtual_matrix[virtual_th][client_th].extend(net_dataidx_map_[client_th][i][:current_level[i]])
            current_level[i] -= len(net_dataidx_map_[client_th][i][:current_level[i]])
            net_dataidx_map_[client_th][i][:temp] = []
            # debug
            # print("client:{0} class:{1} has {2}".format(client_th, i, len(net_dataidx_map_[client_th][i])))
# dataloader_virtual_matrix


In [7]:
net_dataidx_map_
# all empty is right

{0: [[],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  []],
 1: [[],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [],
  [

In [8]:
def get_dataloader(dataset, datadir, train_bs, test_bs, dataidxs=None, is_test=False):
    if dataset in ('mnist', 'femnist', 'fmnist', 'cifar10', 'svhn', 'generated', 'covtype', 'a9a', 'rcv1', 'SUSY','cifar100'):
        if dataset == 'cifar10':
            train_ds,_,_,_,_,_ = load_cifar10_data(datadir)
            if dataidxs is not None:
                train_ds = [train_ds[i] for i in dataidxs]
            if is_test:
                _,test_ds,_,_,_,_ = load_cifar10_data(datadir)
        elif dataset == 'cifar100':
            train_ds,_,_,_,_,_ = load_cifar100_data(datadir)
            if dataidxs is not None:
                train_ds = [train_ds[i] for i in dataidxs]
            if is_test:
                _,test_ds,_,_,_,_ = load_cifar100_data(datadir)
        else:
            train_ds, test_ds = None, None

        train_dl = DataLoader(dataset=train_ds, batch_size=train_bs, shuffle=True, drop_last=False)
        if is_test:
            test_dl = DataLoader(dataset=test_ds, batch_size=test_bs, shuffle=False, drop_last=False)
            return train_dl,test_dl
        else:
            return train_dl

In [9]:
def compute_accuracy(model, dataloader):
    model.eval()
    accuracies = []
    losses = []
    for batch_id, (x_data,y_data) in enumerate(dataloader()):

        y_data = paddle.to_tensor(y_data)
        y_data = paddle.unsqueeze(y_data, 1)

        _,_,logits = model(x_data)
        loss = F.cross_entropy(logits, y_data)
        acc = paddle.metric.accuracy(logits, y_data)
        accuracies.append(acc.numpy())
        losses.append(loss.numpy())

    avg_acc, avg_loss = np.mean(accuracies), np.mean(losses)

    return avg_acc

In [10]:
sum_ = 0
for net in traindata_cls_counts.values():
    sum_ += net[0]
print(f'each class total num:[{sum_}], set virtual_client_num:[10], per:[{int(sum_/10)}]')

each class total num:[500], set virtual_client_num:[10], per:[50]


In [11]:
for virtual_client_dl_idx in range(len(dataloader_virtual_matrix)):
    for real_client_dl_idx in dataloader_virtual_matrix[virtual_client_dl_idx]:
        if real_client_dl_idx == []:
            dataloader_virtual_matrix_dl[virtual_client_dl_idx].append([])
            continue
        train_dl_local= get_dataloader(dataset, data_path, 64, 32, real_client_dl_idx, is_test=False)
        dataloader_virtual_matrix_dl[virtual_client_dl_idx].append(train_dl_local)

W1004 20:35:03.060935 40542 gpu_resources.cc:61] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 11.2, Runtime API Version: 11.2
W1004 20:35:03.064807 40542 gpu_resources.cc:91] device: 0, cuDNN Version: 8.2.


KeyboardInterrupt: 

In [None]:
dataloader_virtual_matrix_dl

In [None]:
import time

args_optimizer = 'sgd'
lr = 0.01
reg = 5e-4
acc_list = []
rounds = 20
epochs = 10

net = ConvNet(channel=3, num_classes=class_num, net_width=128, net_depth=3, net_act='relu', net_norm='instancenorm', net_pooling='avgpooling', im_size=(32, 32))

_,test_dl = get_dataloader(dataset, data_path, 64, 32, is_test=True)


if args_optimizer == 'adam':
    optimizer = optim.Adam(parameters=net.parameters(), learning_rate=lr, weight_decay=reg)
else:
    optimizer = optim.SGD(parameters=net.parameters(), learning_rate=lr, weight_decay=reg)




In [None]:
for rd in range(rounds):
    rd_time = time.time()
    virtual_id = 0
    for virtual_client_dl in dataloader_virtual_matrix_dl:
        for epoch in range(epochs):
            ep_time = time.time()
            for train_dl_local in virtual_client_dl:
                if train_dl_local == []:
                    continue
                for batch_idx, (x, target) in enumerate(train_dl_local()):
                    _, _, out = net(x)
                    loss = F.cross_entropy(out, target)

                    loss.backward()
                    optimizer.step()
                    optimizer.clear_grad()
            print("round::{0}, virtual_client::{1}, epoch::{2}, loss::{3}, elps_time::{4}.".format(rd,virtual_id,epoch,loss.item(),time.time()-ep_time))
        virtual_id += 1
    acc = compute_accuracy(net, test_dl)
    acc_list.append(acc)
    print("round::{0} finished, test_acc::{1}, elpsed time::{2}".format(rd,acc,time.time()-rd_time))

In [None]:
acc_list

In [None]:
import matplotlib.pyplot as plt
plt.plot(range(len(acc_list)),acc_list)
plt.xlabel("rounds")
plt.ylabel("test_acc")
plt.show()

In [None]:
# fedavg_seed0_beta_05 = [0.346146166324615,0.46675318479538,0.536541521549225,0.573682129383087,0.58316695690155,0.622603833675385,0.633386552333832,0.64367014169693,0.655051946640015,0.662739634513855,0.671625375747681,0.68140971660614,0.684205293655396,0.682208478450775,0.686900973320007,0.691293954849243,0.693390548229218,0.697683691978455,0.696984827518463]
# # fedprox_seed0_beta_05 = [0.35553115606308,0.466253995895386,0.536242008209229,0.566094219684601,0.57917332649231,0.606130182743073,0.626497626304626,0.631389796733856,0.646964848041534,0.655950486660004,0.667132616043091,0.677216470241547,0.680910527706146,0.68260782957077,0.686501622200012,0.687699675559998,0.689596652984619,0.697484016418457,0.695986449718475]
# moon_seed0_beta_05 = [0.360123813152313,0.475838661193848,0.539936125278473,0.576477646827698,0.584964036941528,0.61401754617691,0.622503995895386,0.638079047203064,0.65425318479538,0.658646166324616,0.665435314178467,0.676218032836914,0.683406531810761,0.67861419916153,0.686601459980011,0.687000811100006,0.692492008209229,0.697783529758453,0.698382616043091]
# fedvc_seed0_beta_05 = [0.5548123, 0.5952476, 0.63977635, 0.6609425, 0.677516, 0.6963858, 0.7071685, 0.7157548, 0.71914935, 0.72853434, 0.73003197, 0.7345248, 0.7346246, 0.73722047, 0.73901755, 0.74021566, 0.74101436, 0.7433107, 0.74211264, 0.74271166]
# plt.plot(range(19),fedavg_seed0_beta_05,label="FedAvg")
# plt.plot(range(19),fedprox_seed0_beta_05,label="FedProx")
# plt.plot(range(19),moon_seed0_beta_05,label="MOON")
# plt.plot(range(20),fedvc_seed0_beta_05,label="fedvc(Ours)")
# plt.xlabel("rounds")
# plt.ylabel("test_acc")
# plt.legend()
# # plt.savefig("fedDyn_beta05.png",dpi=330)
# plt.show()

In [None]:
# fedprox_seed0_beta_05=[0.162539929151535,0.211361825466156,0.267372190952301,0.296425729990005,0.26767173409462,0.343051105737686,0.351337850093842,0.348342657089233,0.352236419916153
# ,0.377396166324615,0.395866602659225,0.416833072900772,0.435902565717697,0.421126186847687,0.438897758722305,0.442891359329224,0.44928115606308,0.458965659141541
# ,0.453474432229996]

In [None]:
# fedvc_seed0_beta_05_cifar100 = [0.319988,
#  0.36491615,
#  0.38578275,
#  0.39996007,
#  0.41134185,
#  0.4154353,
#  0.42272365,
#  0.4224241,
#  0.42501998,
#  0.42601836,
#  0.43011183,
#  0.4308107,
#  0.43190894,
#  0.4335064,
#  0.43370607,
#  0.43789935,
#  0.44029552,
#  0.4365016,
#  0.43869808,
#  0.4406949]