In [11]:
import torch
from fedgkt_trainers import GKTClientTrainer, GKTServerTrainer
from resnet_gkt_server import ResNet49
from resnet_gkt_client import ResNet8
from fedavg_utils import get_datasets, get_user_groups
import numpy as np
from reproducibility import make_it_reproducible
import pandas as pd

In [2]:
# Seed randomness before any model construction, data partitioning or client
# sampling below, so repeated runs pick the same clients and initial weights.
# NOTE(review): presumably seeds Python/NumPy/torch (the training loop relies on
# the global NumPy RNG via np.random.choice) — confirm in reproducibility.py.
make_it_reproducible(seed=0)

In [3]:
# --- Federation setup ---
client_number = 100          # total number of simulated clients
norm_type = 'Batch Norm'     # passed to ResNet49 (server) and ResNet8 (clients); also used in output filenames
participation_frac = 0.1     # fraction of clients sampled in every communication round
iid = True                   # forwarded to get_user_groups(iid=...)
unbalanced = False           # forwarded to get_user_groups(unbalanced=...)

# --- Trainer hyper-parameters ---
# Passed verbatim to GKTServerTrainer / GKTClientTrainer.
# NOTE(review): temperature/alpha presumably parameterize the distillation loss — confirm in fedgkt_trainers.
args_server = dict(
    temperature=3,
    epochs_server=10,
    alpha=0.5,
)

args_client = dict(
    temperature=3,
    epochs_client=1,
    alpha=0.5,
)

communication_rounds = 10    # number of sample -> local train -> server train -> feedback cycles

In [4]:
# Use CUDA when a GPU is usable, otherwise fall back to CPU.
# `torch.cuda.is_available` must be *called*: the bare function object is always
# truthy, which would force 'cuda' even on CPU-only machines.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [5]:
#SERVER
# Server-side trainer wrapping a ResNet49 built with `norm_type`; it is told the
# total number of clients, the target device and the server hyper-parameters.
# The training loop later feeds it client results via add_local_trained_result.
server_trainer = GKTServerTrainer(client_number, device, ResNet49(norm_type), args_server)

In [6]:
#CLIENT
# Load the train/test datasets and split the training data among the clients.
# NOTE(review): user_groups presumably maps client id -> sample indices of trainset — confirm in fedavg_utils.
trainset, testset = get_datasets()
user_groups = get_user_groups(trainset, iid=iid, unbalanced=unbalanced, tot_users=client_number)

# One trainer per client id, each with its own freshly built ResNet8 and its
# own user_groups entry; list position equals client id.
clients = [
  GKTClientTrainer(
    client_idx,
    trainset,
    testset,
    user_groups[client_idx],
    device,
    ResNet8(norm_type),
    args_client,
  )
  for client_idx in range(client_number)
]

Files already downloaded and verified
Files already downloaded and verified


In [7]:
# FedGKT training loop. Each communication round:
#   1. samples a subset of clients,
#   2. lets each sampled client train locally and registers its outputs
#      (features, logits, labels, plus test features/labels) with the server,
#   3. trains the server model,
#   4. sends server-side logits back to the clients that took part.
# Clients sampled per round never changes, so compute it once (at least one client).
clients_per_round = max(int(participation_frac * client_number), 1)

# Named `round_idx` rather than `round` so the builtin round() is not shadowed
# in the kernel after this cell runs.
for round_idx in range(communication_rounds):
  print("Communication round: ", round_idx + 1)
  # Distinct client ids drawn from the global NumPy RNG (seeded earlier via make_it_reproducible).
  chosen_users = np.random.choice(client_number, clients_per_round, replace=False)
  print(f"Chosen users: {chosen_users}")

  # Local phase: train each sampled client and hand its results to the server.
  for idx in chosen_users:
    (extracted_features_dict, extracted_logits_dict, labels_dict,
     extracted_features_dict_test, labels_dict_test) = clients[idx].train()

    server_trainer.add_local_trained_result(
        idx, extracted_features_dict, extracted_logits_dict, labels_dict,
        extracted_features_dict_test, labels_dict_test)

  # Server phase: train on the registered client results (0-based round index).
  server_trainer.train(round_idx)

  # Feedback phase: push the server's logits back to this round's clients.
  for idx in chosen_users:
    global_logits = server_trainer.get_global_logits(idx)
    clients[idx].update_large_model_logits(global_logits)

# Metric histories collected by the server trainer, saved to CSV below.
train_metrics, test_metrics = server_trainer.get_metrics_lists()

Communication round:  1
Chosen users: [18]
{'train/loss': 1.0299100577831268, 'train/accuracy': 67.21443939208984, 'epoch': 1}
{'test/loss': 2.298644741879234, 'test/accuracy': 11.530854430379748, 'epoch': 1}


In [13]:
# Write the training-metric history returned by get_metrics_lists() to CSV.
# The filename records the run config: norm type (may contain spaces),
# iid/noniid split and balanced/unbalanced client sizes.
train_csv_path = f"train_{norm_type}_{'iid' if iid else 'noniid'}_{'unbalanced' if unbalanced else 'balanced'}.csv"
df = pd.DataFrame(train_metrics)
df.to_csv(train_csv_path, index=False)

In [None]:
# Write the test-metric history returned by get_metrics_lists() to CSV,
# using the same config-encoding filename scheme with a "test_" prefix.
test_csv_path = f"test_{norm_type}_{'iid' if iid else 'noniid'}_{'unbalanced' if unbalanced else 'balanced'}.csv"
df = pd.DataFrame(test_metrics)
df.to_csv(test_csv_path, index=False)