Implementation Of fedGKt based on fedml library
# https://github.com/FedML-AI/FedML/tree/master/fedml_experiments/distributed/fedgkt

In [None]:



import torch
from torchsummary import summary

import numpy as np
from models_server import ResNet50
from models_client import ResNet8
from server import GKTServerTrainer
from client import GKTClientTrainer

In [None]:
#chaged yml file into notebook for simplicity of tuning parameters
###########################################################
iid = 0 
unbalanced = 0 
num_users = 100 
frac = 0.1 
server_epochs = 10
gpu = 0
optimizer = "sgd" 
local_batch_size = 128 
lr = 1e-2 
client_epochs = 1
loss_function = "CrossEntropyLoss"
partition_alpha = 0.5
client_number = num_users  
temperature = 3.0
communication_rounds = 10 
num_groups = 0  #put 2 for GN
if num_groups == 0:
    normalization_type = "BatchNorm"
else:
    normalization_type = "GroupNorm"
if iid:
    from iid import get_dataset, average_weights, exp_details
else:
    from niid import get_dataset, average_weights, exp_details

In [None]:
train_dataset, test_dataset, user_groups = get_dataset(iid=iid, unbalanced=unbalanced, num_users=num_users)

Files already downloaded and verified
Files already downloaded and verified


In [None]:
def create_client_model():
    # client_model = RenNet8(normalization_type)
    client_model = ResNet8()
    return client_model

def create_server_model():
    server_model = ResNet50(n_type=normalization_type)
    return server_model
server_model = create_server_model()
client_model = create_client_model()

In [None]:
train_on_gpu = torch.cuda.is_available()

if not train_on_gpu:
    print('CUDA is not available.  Training on CPU ...')
    device = torch.device("cpu")
    gpu = 0
else:
    print('CUDA is available!  Training on GPU ...')
    device = torch.device("cuda")
    gpu = 1

server_model.to(device)
client_model.to(device)

# set the models to train
server_model.train()
client_model.train()

CUDA is available!  Training on GPU ...


ResNet(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(16, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(16, 64, kernel_size=(1, 1), stride=(1, 1),

In [None]:
# 23,503,626 parameters (ResNet49)
# 23,520,842 parameters (ResNet50)
summary(server_model, (16, 32, 32))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 32, 32]           1,024
       BatchNorm2d-2           [-1, 64, 32, 32]             128
            Conv2d-3           [-1, 64, 32, 32]          36,864
       BatchNorm2d-4           [-1, 64, 32, 32]             128
            Conv2d-5          [-1, 256, 32, 32]          16,384
       BatchNorm2d-6          [-1, 256, 32, 32]             512
            Conv2d-7          [-1, 256, 32, 32]           4,096
       BatchNorm2d-8          [-1, 256, 32, 32]             512
        Bottleneck-9          [-1, 256, 32, 32]               0
           Conv2d-10           [-1, 64, 32, 32]          16,384
      BatchNorm2d-11           [-1, 64, 32, 32]             128
           Conv2d-12           [-1, 64, 32, 32]          36,864
      BatchNorm2d-13           [-1, 64, 32, 32]             128
           Conv2d-14          [-1, 256,

In [None]:
# 10,586 parameters (ResNet8)
summary(client_model, (3, 32, 32))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 16, 32, 32]             432
       BatchNorm2d-2           [-1, 16, 32, 32]              32
              ReLU-3           [-1, 16, 32, 32]               0
            Conv2d-4           [-1, 16, 32, 32]             256
       BatchNorm2d-5           [-1, 16, 32, 32]              32
              ReLU-6           [-1, 16, 32, 32]               0
            Conv2d-7           [-1, 16, 32, 32]           2,304
       BatchNorm2d-8           [-1, 16, 32, 32]              32
              ReLU-9           [-1, 16, 32, 32]               0
           Conv2d-10           [-1, 64, 32, 32]           1,024
      BatchNorm2d-11           [-1, 64, 32, 32]             128
           Conv2d-12           [-1, 64, 32, 32]           1,024
      BatchNorm2d-13           [-1, 64, 32, 32]             128
             ReLU-14           [-1, 64,

In [1]:
# init server
server_trainer = GKTServerTrainer(server_model, num_users, lr, server_epochs, device,
                                  optimizer, temperature)
clients_trainer = []  # list of client_trainer

# different clients at each epoch

idxs_users = range(num_users)

# init all clients
for idx in idxs_users:
    client_trainer = GKTClientTrainer(client_model, train_dataset, test_dataset,
                                      user_groups[idx], idx, gpu, optimizer, local_batch_size,
                                      lr, client_epochs, temperature, partition_alpha)
    clients_trainer.append(client_trainer)
    # print(f"client \t{idx}/{num_users} initialized")

for communication_round in range(1, communication_rounds+1):
    print(f'\nCommunication Round: {communication_round} \n')

    m = max(int(frac * num_users), 1) # number of users to be used for federated updates, at least 1
    idxs_chosen_users = np.random.choice(range(num_users), m, replace=False) # choose randomly m users

    print(idxs_chosen_users)
    for idx in idxs_chosen_users:
        # the server broadcast k-th Z_c to the client
        extracted_feature_dict, logits_dict, labels_dict, extracted_feature_dict_test,\
        labels_dict_test = clients_trainer[idx].train()

        # send client result to server
        server_trainer.add_local_trained_result(idx, extracted_feature_dict, logits_dict, labels_dict,
                                                extracted_feature_dict_test, labels_dict_test)

    # # check if all updates are received
    # b_all_received = server_trainer.check_whether_all_receive()
    # print("b_all received" + str(b_all_received))
    #
    # if b_all_received:
    #     server_trainer.train(communication_round)

    server_trainer.train(communication_round, idxs_chosen_users)

    for idx in idxs_chosen_users:
        # get global logits
        global_logits = server_trainer.get_global_logits(idx)

        # print(type(global_logits))
        # print(len(global_logits))
        # print(global_logits)

        # send global logits to client
        clients_trainer[idx].update_large_model_logits(global_logits)

# get lists of train loss and accuracy
train_loss, train_accuracy = server_trainer.get_loss_acc_list()

NameError: ignored