In [None]:
import torch
from torch.utils.data import Subset
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
from data_utils import split_non_iid_dataset,read_log
from federated_learning_system import Server, n_clients, ALPHA,fraction_client
from model import Model

# Pick the compute device: use the GPU when PyTorch can see a CUDA device, otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'



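# EMNIST.class_to_idx is a dict mapping each class name (a character) to its integer label;
# iterating over it gives the same label <-> character correspondence that `mapp` spells out.
# for char, idx in train_dataset.class_to_idx.items():
#     print(f"{idx}: {char}")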


In [None]:
# Seed NumPy's global RNG so the index shuffle below is reproducible. If split_non_iid_dataset
# draws from np.random this also fixes the per-client split (TODO confirm). Torch-side
# randomness (RandomRotation, model initialisation) is NOT covered by this seed.
SEED = 0

# NOTE(review): the training transform rotates every image by a random angle in [0, 360]
# degrees while the test set below is not rotated, so train and test distributions differ;
# full rotations also make some classes hard to tell apart (e.g. '6'/'9'). Confirm intended.
train_dataset = datasets.EMNIST(
    root='.', split='byclass', download=True, train=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.RandomRotation((0, 360)),
    ]),
)

# Shuffle all training-sample indices once, then partition them across n_clients with a
# non-IID split whose skew is controlled by ALPHA.
labels = train_dataset.targets.numpy()
np.random.seed(SEED)
index = np.random.permutation(len(train_dataset))
train_index = index
train_index_by_client = split_non_iid_dataset(labels=labels, index=train_index, alpha=ALPHA, n_clients=n_clients)
# One label array and one Subset view of the training data per client.
train_labels = [labels[idx] for idx in train_index_by_client]
train_datasets = [Subset(train_dataset, idx) for idx in train_index_by_client]

test_dataset = datasets.EMNIST(root='.', split='byclass', download=True, train=False,
                               transform=transforms.Compose([transforms.ToTensor()]))

# Label index -> character for the 'byclass' split: digits, then upper-case, then lower-case.
mapp = np.array(['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C',
                 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P',
                 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c',
                 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p',
                 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z'], dtype='<U1')
n_classes = len(mapp)

# Stacked histogram of the labels held by each client, to visualise how non-IID the split is.
# train_labels is a list with one label array per client.
fig, ax = plt.subplots(figsize=(20, 3))
ax.hist(train_labels,
        # n_classes + 1 edges centred on the integer labels -> exactly one bin per class.
        # (np.arange(n_classes) - 0.5 yields only n_classes - 1 bins and drops the last class 'z'.)
        bins=np.arange(n_classes + 1) - 0.5,
        stacked=True,
        label=["Client {}".format(i + 1) for i in range(n_clients)])
ax.set_xticks(np.arange(n_classes))
ax.set_xticklabels(mapp)
ax.set_xlabel('Class')
ax.set_ylabel('Number of samples')
ax.set_title('Training label distribution per client')
ax.legend()
plt.show()

In [None]:
# Build the federated-learning server from the model class, the held-out test set, the
# configured number of rounds, fraction_client (presumably the share of clients sampled per
# round - TODO confirm against Server), and one training Subset per client.
n_rounds = 600
server = Server(
    Model,
    test_dataset,
    n_rounds,
    fraction_client,
    *train_datasets,
)

In [None]:
# Fresh, independent metric containers handed to server.run; presumably it appends one
# value per round to each of them (TODO confirm against Server.run).
train_loss_list = []
train_accuracy_list = []
evaluate_loss_list = []
evaluate_accuracy_list = []
# NOTE(review): the meaning of the trailing 301 is not visible here (start round? rounds to
# run?) and it differs from n_rounds=600 given to Server - confirm against Server.run.
server.run(
    train_loss_list,
    train_accuracy_list,
    evaluate_loss_list,
    evaluate_accuracy_list,
    301,
)

In [None]:
import matplotlib.pyplot as plt
from  data_utils import  read_log

# Learning curves: load the four per-round metric lists via read_log() and plot loss (top)
# and accuracy (bottom) for train vs. evaluate on a shared round axis.
train_loss_list, train_accuracy_list, evaluate_loss_list, evaluate_accuracy_list = read_log()
length = len(train_loss_list)
r = range(1, length + 1)        # rounds are numbered from 1
rx = range(1, length + 1, 50)   # one x tick every 50 rounds

# figsize goes to plt.subplots itself: calling plt.figure(figsize=...) first would only
# create an extra, empty figure and the requested size would never reach the plotted one.
fig, (ax0, ax1) = plt.subplots(2, 1, sharex=True, figsize=(20, 4), constrained_layout=True)

ax0.plot(r, train_loss_list, label="train")
ax0.plot(r, evaluate_loss_list, label="evaluate")
ax0.set_title('Learning Curves')
ax0.legend(loc="upper right")
ax0.set_ylabel('Loss')

ax1.plot(r, train_accuracy_list, label="train")
ax1.plot(r, evaluate_accuracy_list, label="evaluate")
ax1.legend(loc="upper right")
ax1.set_ylabel('Accuracy(%)')
ax1.set_xlabel('Round')
ax1.set_xticks(rx)
plt.show()