In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as func
import torchvision
class ResNetFc(nn.Module):
    def __init__(self):
        super(ResNetFc, self).__init__()
        model_resnet = torchvision.models.resnet18(pretrained=True)

        self.conv1 = model_resnet.conv1
        self.bn1 = model_resnet.bn1
        self.relu = model_resnet.relu

        self.layer1 = model_resnet.layer1
        self.layer2 = model_resnet.layer2
        self.layer3 = model_resnet.layer3
        self.layer4 = model_resnet.layer4

        self.avgpool = model_resnet.avgpool

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)

        return x


class RSNAModel(nn.Module):
    '''ResNet18'''
    def __init__(self, class_num=10):
        super(RSNAModel, self).__init__()
        self.class_num = class_num

        self.backbone = ResNetFc()

        self.clf = nn.Linear(512, class_num)

    def forward(self, x, Pi, priors_corr, prior_test):
        x = self.backbone(x)
        x = self.clf(x)

        g = torch.softmax(x, dim=1)
        x = self.QfunctionMulticlass(g, Pi, priors_corr)

        return x

    def QfunctionMulticlass(self, g, Pi, priors_corr):
        pi_ita = torch.mm(Pi, g.permute(1, 0))
        rou_pi_ita = torch.matmul(priors_corr, pi_ita)

        pi_corr = pi_ita.permute(1, 0) * priors_corr.unsqueeze(0)
        output = (pi_corr.permute(1, 0) / rou_pi_ita).permute(1, 0)

        return output

    def predict(self, x):
        x = self.backbone(x)
        x = self.clf(x)

        g = torch.softmax(x, dim=1)

        return g

    def server_forward(self, x):
        x = self.backbone(x)
        x = self.clf(x)

        return x

In [2]:
import csv
import numpy as np
import random

def read_csv(data_file_path):
    data = []
    with open(data_file_path, 'r') as f:
        reader = csv.reader(f)
        data = list(reader)
        data = np.asarray(data)
    return data

label_file = '../input/rsna-intracranial-hemorrhage-detection/rsna-intracranial-hemorrhage-detection/stage_2_train.csv'

import glob
import joblib
import numpy as np
import PIL
import pydicom
import tqdm
import torch
import torchvision
from torch.utils.data import Dataset
import torchvision.transforms as transforms


RESIZED_WIDTH, RESIZED_HEIGHT = 128, 128
def get_first_of_dicom_field_as_int(x):
    if type(x) == pydicom.multival.MultiValue:
        return int(x[0])
    return int(x)

def get_id(img_dicom):
    return str(img_dicom.SOPInstanceUID)

def get_metadata_from_dicom(img_dicom):
    metadata = {
        "window_center": img_dicom.WindowCenter,
        "window_width": img_dicom.WindowWidth,
        "intercept": img_dicom.RescaleIntercept,
        "slope": img_dicom.RescaleSlope,
    }
    return {k: get_first_of_dicom_field_as_int(v) for k, v in metadata.items()}

def window_image(img, window_center, window_width, intercept, slope):
    img = img * slope + intercept
    img_min = window_center - window_width // 2
    img_max = window_center + window_width // 2
    img[img < img_min] = img_min
    img[img > img_max] = img_max
    return img 

def resize(img, new_w, new_h):
    img = PIL.Image.fromarray(img.astype(np.int8), mode="L")
    return img.resize((new_w, new_h), resample=PIL.Image.BICUBIC)

def normalize(img):
    mi, ma = img.min(), img.max()
    return (((img - mi) / (ma - mi)) - 0.5) * 2

def prepare_image(img_path):
    img_dicom = pydicom.read_file(img_path)
    img_id = get_id(img_dicom)
    metadata = get_metadata_from_dicom(img_dicom)
    img = window_image(img_dicom.pixel_array, **metadata)
    img = normalize(img)
    img_pil = resize(img, RESIZED_WIDTH, RESIZED_HEIGHT)
    return img_id, img_pil

class Dataset_Early_Fusion(Dataset):
    def __init__(self, label_file, transform=None):
        self.transform = transform
        self.data = []
        self.targets = []
        files = read_csv(label_file)
        sub_labels = []
        # 106000
        count_sum = 0
        for i in range(len(files)):
            img_labels = files[(i + 1) * 6]
            if img_labels[0].split('_')[-1] == 'any':
                count_labels = 0
                if files[(i + 1) * 6][1] == '0':
                    if random.randint(1, 20) == 1:
                        label = 0
                        count_labels = count_labels + 1
                else:
                    for n in range(1,6):
                        if files[i * 6 + n][1] == '1':
                            count_labels = count_labels + 1
                            label = n
                if count_labels == 1:
                    sub_labels.append([img_labels[0].split('_')[-3] + '_' + img_labels[0].split('_')[-2], label])
                    count_sum = count_sum + 1
            if count_sum == 106000: # !!!!!!! 
                break
        for i in range(len(sub_labels)):
            temp = sub_labels[i]
            full_path = '../input/rsna-intracranial-hemorrhage-detection/rsna-intracranial-hemorrhage-detection/stage_2_train/' + temp[0] + '.dcm'
            _, image = prepare_image(full_path)
            if i % 10000 == 0:
                print(i, '/ 106000')
            self.data.append(np.array(image))
            self.targets.append(temp[1])
        self.data = np.array(self.data)
        self.targets = np.array(self.targets)
        
    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img, target = self.data[index], self.targets[index]
        img = Image.fromarray(img.numpy(), mode='L')

        if self.transform is not None:
            img = self.transform(img)
#       img = transforms.ToTensor()(img)

        return img, target


import warnings

warnings.filterwarnings("ignore")

In [3]:
import os
import sys

import numpy as np

np.seterr(divide='ignore', invalid='ignore')
import torch
import torchvision
from torch.utils.data import Dataset
import torchvision.transforms as transforms
from tqdm import tqdm


def get_set_sizes(sets, data_len):
    set_size = data_len // sets
    set_sizes = np.ones(sets) * set_size
    # numpy
    # [n, n, n, n] sets
    return set_sizes


def get_Pi_Multiclass(sets, classnum=6, noniid=False):
    Pi = []

    for i in range(sets):
        # randomly set prior
        this_Pi = np.random.rand(classnum) * 0.9 + 0.1

        this_Pi[i % classnum] *= 10

        this_Pi = this_Pi / np.sum(this_Pi)
        Pi.append(this_Pi)

    Pi = np.array(Pi)
    # numpy
    # [[...],  num0 * 10
    #  [...],  num1 * 10
    #  [...]]  setnum * classnum  Normalization in row

    return Pi


def get_Pi_Multiclass_clientnoniid(sets, classnum=6, noniid=False, clientid=0):
    # def get_Pi_Multiclass(sets, classnum=6, noniid=False):
    Pi = []

    for i in range(sets):
        # randomly set prior
        this_Pi = np.random.rand(classnum) * 0.9 + 0.1

        # dominate step
        if noniid:
            this_Pi = np.random.rand(classnum) * 0.02 + 0.05
            this_Pi[clientid] = np.random.rand(1) * 0.2 + 0.3

            this_Pi = np.random.rand(classnum) * 0.02 + 0.05

        this_Pi = this_Pi / np.sum(this_Pi)
        Pi.append(this_Pi)

    Pi = np.array(Pi)

    return Pi


def get_test_sets_Multiclass(y_test, classnum=6, clientnum=5, clientsize=2000):
    class_idx = []
    for cls in range(classnum):
        this_idx = [i for i, x in enumerate(y_test) if x == cls]
        class_idx.append(this_idx)
    # class_dix: [[...], index of each class
    #             [...]] classnum * ~

    test_clients = ()
    for i in range(clientnum):
        for cls in range(classnum):
            np.random.shuffle(class_idx[cls])

            # uniformly distributed
            n_this = int(clientsize / classnum)

            if cls == 0:
                cur_set = np.array(class_idx[cls][:n_this])
            else:
                cur_set = np.concatenate((cur_set, class_idx[cls][:n_this])).astype(int)

            np.random.shuffle(cur_set)

        test_clients = test_clients + (torch.from_numpy(cur_set),)
        # shuffle balance class index ([...], [...], [...], [...], [...]) clientnum

    return test_clients


def get_U_sets_Multiclass(bags, y_train, y_indices, bag_sizes, thetas, classnum=6):
    class_idx = []
    for cls in range(classnum):
        this_idx = [y_indices[i] for i, x in enumerate(y_train) if x == cls]
        class_idx.append(this_idx)
    # indices of each class [[...], [...]] class_num

    U_sets = ()
    size_bag = []
    # for every bag
    for i in range(bags):
        size_cls = []
        # for every class in a bag
        for cls in range(classnum):
            # shuffle data index list
            np.random.shuffle(class_idx[cls])
            # the number of data selected for this class: prior * setsize
            n_this = int(bag_sizes[i] * thetas[i][cls])

            # concatenate index
            if cls == 0:
                cur_set = np.array(class_idx[cls][:n_this])
            else:
                cur_set = np.concatenate((cur_set, class_idx[cls][:n_this])).astype(int)

            size_cls.append(len(class_idx[cls][:n_this]))

            # shuffle current set
            np.random.shuffle(cur_set)

        # concatenate different class data
        U_sets = U_sets + (torch.from_numpy(cur_set),)
        # shuffle indices of each U_set ([...], [...], [...]) bags

        size_bag.append(np.array(size_cls) / sum(size_cls))
        # num of each class in each U_set ([...], [...], [...]) bags

    # calculate priors corr for every U set
    sets_num_count = [len(U_sets[j]) for j in range(len(U_sets))]
    # sample_num in each U_set [...] bags

    priors_corr = torch.from_numpy(
        np.array([sets_num_count[k] / sum(sets_num_count) for k in range(len(sets_num_count))]))
    # tensor
    # Normalized set_sum_count

    bags_pi = np.array(size_bag)
    # num of each class in each U_set [[...], [...], [...]] bags [Normalized]

    return U_sets, priors_corr, bags_pi


def get_iid_Pi(clientnum, setnum_perclient, classnum):
    for _ in range(clientnum):
        this_Pi = torch.from_numpy(get_Pi_Multiclass(setnum_perclient, classnum=classnum))
        if _ == 0:
            iid_Pi = this_Pi
        else:
            iid_Pi += this_Pi
    iid_Pi = iid_Pi / torch.sum(iid_Pi, dim=0)

    return iid_Pi


def get_noniid_class_priority(client_num, classnum=6, dominate_rate=0.5):
    priority = []

    for client in range(client_num):
        this_label_shift = np.random.rand(classnum) * 0.1 + 0.45
        # [] .shape = 10
        this_label_shift[(2 * client) % classnum] *= (4 / (1 - dominate_rate))
        this_label_shift[(2 * client + 1) % classnum] *= (4 / (1 - dominate_rate))

        this_label_shift = this_label_shift / np.sum(this_label_shift)
        priority.append(this_label_shift)
        # [[   ...    ],    [(2 * client) % classnum] * 8 (each row)
        #  [   ...    ],    [(2 * client + 1) % classnum] * 8 (each row)
        #  [   ...    ]]   client_num * classnum   Normalization in row

    return priority


def get_class_index(targets, classnum=6):
    indexs = []

    for cls in range(classnum):
        this_index = [index for (index, value) in enumerate(targets) if value == cls]
        indexs.append(this_index)

    return indexs


def noniid_split_dataset(oridata, lengths, classnum=6, dominate_rate=0.95):
    subsets = []
    priority = get_noniid_class_priority(len(lengths), classnum=classnum, dominate_rate=dominate_rate)

    targets = oridata.targets.tolist()
    class_index = get_class_index(targets, classnum=classnum)
    # index of label 0 [[...],
    # index of label 1  [...]] classnum row
    class_count = [0 for _ in range(classnum)]

    for l in range(len(lengths)):
        this_indices = []
        for cls in range(classnum):
            cls_num = int(priority[l][cls] * lengths[l])

            this_indices.extend(class_index[cls][class_count[cls]: class_count[cls] + cls_num])
            class_count[cls] += cls_num

        this_subset = torch.utils.data.Subset(oridata, this_indices)
        # Dataset distribute with priority
        # [[   ...    ],    [(2 * client) % classnum] * 8 (each row)
        #  [   ...    ],    [(2 * client + 1) % classnum] * 8 (each row)
        #  [   ...    ]]   client_num * classnum   Normalization in row
        subsets.append(this_subset)
        # [Dataset0, ..., Dataset4]

    return subsets

def RSNA_UPPER_BOUND(data_path='./data', clientnum=5, setnum_perclient=12, classnum=6, noniid=False):
    all_train_data = Dataset_Early_Fusion(label_file='../input/rsna-intracranial-hemorrhage-detection/rsna-intracranial-hemorrhage-detection/stage_2_train.csv')
#     all_train_data.data = torch.from_numpy(all_train_data.data)
#     all_train_data.targets = torch.from_numpy(np.array(all_train_data.targets))

    validation_data = torch.from_numpy(all_train_data.data[84800:]) # !!!!!!!!!!
    validation_targets = torch.from_numpy(all_train_data.targets[84800:])
    test_data = torch.from_numpy(all_train_data.data[:21200])
    test_targets = torch.from_numpy(all_train_data.targets[:21200])
    all_train_data.data = torch.from_numpy(all_train_data.data[21200:84800])
    all_train_data.targets = torch.from_numpy(all_train_data.targets[21200:84800])

    # split client bags
    client_train_size = len(all_train_data) // clientnum
    client_validation_size = len(validation_data) // clientnum
    client_test_size = len(test_data) // clientnum
    if noniid:
        client_train_sets = noniid_split_dataset(all_train_data, [client_train_size for _ in range(clientnum)])
    else:
        client_train_sets = torch.utils.data.random_split(all_train_data, [client_train_size for _ in range(clientnum)])

    # get uniformly distributed test data index
    validation_client_idxs = get_test_sets_Multiclass(validation_targets, classnum=classnum, clientnum=clientnum,
                                                      clientsize=client_validation_size)
    test_client_idxs = get_test_sets_Multiclass(test_targets, classnum=classnum, clientnum=clientnum,
                                                clientsize=client_test_size)

    client_train_data = []
    client_test_data = []
    client_validation_data = []

    # get Pis, prior test, prior corr
    client_Pi = []
    client_prior_test = []
    client_priors_corr = []

    if not noniid:
        iid_Pi = get_iid_Pi(clientnum, setnum_perclient, classnum)

    print('Spliting U sets for', clientnum, 'clients, each with', setnum_perclient, 'U sets...')
    # for every client
    for n in tqdm(range(clientnum)):
        if noniid:
            this_Pi = torch.from_numpy(get_Pi_Multiclass(setnum_perclient, classnum=classnum))
        if not noniid:
            this_Pi = iid_Pi

        # w/o repeat
        this_set_sizes = get_set_sizes(setnum_perclient, len(client_train_sets[n]))
        this_U_sets, this_priors_corr, this_Pi = get_U_sets_Multiclass(setnum_perclient,
                                                                       client_train_sets[n].dataset.targets[
                                                                           client_train_sets[n].indices],
                                                                       client_train_sets[n].indices,
                                                                       this_set_sizes, this_Pi, classnum=classnum)

        client_Pi.append(torch.from_numpy(this_Pi))
        client_priors_corr.append(this_priors_corr)

        # get prior test
        this_prior_test = None
        client_prior_test.append(this_prior_test)

        # set subsets labels, for every set in every client
        client_set_temp_data = None
        client_set_temp_targets = None
        for i in range(setnum_perclient):
            # w/o repeat
            this_set_temp_data = client_train_sets[n].dataset.data[this_U_sets[i]]

            # surrogate label as set index
            this_set_temp_targets = client_train_sets[n].dataset.targets[this_U_sets[i]]

            # concatenate data and labels
            if i == 0:
                client_set_temp_data = this_set_temp_data
                client_set_temp_targets = this_set_temp_targets
            else:
                client_set_temp_data = torch.cat((client_set_temp_data, this_set_temp_data))
                client_set_temp_targets = torch.cat((client_set_temp_targets, this_set_temp_targets))

        # store different clients' data and labels in a dict, for further load
        client_train_data.append({'images': client_set_temp_data, 'labels': client_set_temp_targets})
        client_test_data.append({'images': test_data[test_client_idxs[n]],
                                 'labels': test_targets[test_client_idxs[n]]})
        client_validation_data.append({'images': validation_data[validation_client_idxs[n]],
                                       'labels': validation_targets[validation_client_idxs[n]]})

    return client_train_data, client_validation_data, client_test_data

class BaiscDataset(Dataset):
    def __init__(self, data, transform=None):
        self.transform = transform
        self.images = data['images']
        self.labels = data['labels']

    def __len__(self):
        return self.images.shape[0]

    def __getitem__(self, idx):
        image = self.images[idx]
        label = self.labels[idx]

        if self.transform is not None:
            image = self.transform(transforms.ToPILImage()(image))

        return image, label


In [4]:
import os
import sys
import torch
from torch import nn, optim
import copy
import argparse
import numpy as np
import torchvision.transforms as transforms


def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    torch.backends.cudnn.deterministic = True
    
def prepare_data(args):
    # Data Augmentation
    rotate_degree = 20
    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.Grayscale(num_output_channels=3),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    test_transform = transforms.Compose([
        transforms.Grayscale(num_output_channels=3),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    # Get splited data
    client_train_data, client_validation_data, client_test_data = \
            RSNA_UPPER_BOUND(
                data_path="./data",
                clientnum=args.clientnum,
                setnum_perclient=args.setnum,
                noniid=args.noniid)

    # get dataloaders
    train_loaders = []
    test_loaders = []
    validation_loaders = []

    for i, this_client_data in enumerate(client_train_data):
        this_train_set = BaiscDataset(client_train_data[i], transform=train_transform)
        this_validation_set = BaiscDataset(client_validation_data[i], transform=test_transform)
        this_test_set = BaiscDataset(client_test_data[i], transform=test_transform)

        train_loaders.append(torch.utils.data.DataLoader(this_train_set, batch_size=args.batch, shuffle=True))
        validation_loaders.append(
            torch.utils.data.DataLoader(this_validation_set, batch_size=args.batch * 5, shuffle=False))
        test_loaders.append(torch.utils.data.DataLoader(this_test_set, batch_size=args.batch * 5, shuffle=False))

    return train_loaders, validation_loaders, test_loaders

def L1_Regularization(model):
    L1_reg = 0
    for param in model.parameters():
        L1_reg += torch.sum(torch.abs(param))

    return L1_reg


def train(args, model, train_loader, optimizer, loss_fun, client_num, device):
    model.train()
    num_data = 0
    correct = 0
    loss_all = 0
    train_iter = iter(train_loader)
    for step in range(len(train_iter)):
        optimizer.zero_grad()
        x, y = next(train_iter)
        num_data += y.size(0)
        x = x.to(device).float()
        y = y.to(device).long()

        output = model.predict(x)

        loss = loss_fun(output, y) + L1_Regularization(model) * args.wdecay
        loss.backward()
        loss_all += loss.item()

        optimizer.step()

    return loss_all / len(train_iter)


def test(model, test_loader, loss_fun, device, classnum=10):
    model.eval()
    test_loss = 0
    correct = 0
    total = 0

    with torch.no_grad():
        for data, target in test_loader:
            data = data.to(device).float()
            target = target.to(device).long()

            output = model.predict(data)

            pred = output.data.max(1)[1]
            correct += pred.eq(target.view(-1)).sum().item()
            total += target.size(0)

    test_error = (total - correct) / total

    return test_error


def communication(args, server_model, models, client_weights):
    with torch.no_grad():
        # aggregate params
        for key in server_model.state_dict().keys():
            # num_batches_tracked is a non trainable LongTensor and
            # num_batches_tracked are the same for all clients for the given datasets
            if 'num_batches_tracked' in key:
                server_model.state_dict()[key].data.copy_(models[0].state_dict()[key])
            else:
                temp = torch.zeros_like(server_model.state_dict()[key])
                for client_idx in range(len(client_weights)):
                    temp += client_weights[client_idx] * models[client_idx].state_dict()[key]
                server_model.state_dict()[key].data.copy_(temp)
                for client_idx in range(len(client_weights)):
                    models[client_idx].state_dict()[key].data.copy_(server_model.state_dict()[key])
    return server_model, models


In [5]:
if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    print('Device:', device, '\n')
    parser = argparse.ArgumentParser()
    parser.add_argument('--test', action='store_true', help='test the pretrained model')
    parser.add_argument('--lr', type=float, default=1e-4, help='learning rate')
    parser.add_argument('--wdecay', type=float, default=5e-4, help='learning rate')
    parser.add_argument('--batch', type=int, default=128, help='batch size')
    parser.add_argument('--iters', type=int, default=100, help='iterations for communication')
    parser.add_argument('--wk_iters', type=int, default=1,
                        help='optimization iters in local worker between communication')
    parser.add_argument('--mode', type=str, default='fedavg', help='fedavg')
    parser.add_argument('--save_path', type=str, default='./checkpoint/rsna', help='path to save the checkpoint')

    parser.add_argument('--clientnum', type=int, default=5, help='client number')
    parser.add_argument('--setnum', type=int, default=10, help='set number per client has')
    parser.add_argument('--classnum', type=int, default=10, help='class num')
    parser.add_argument('--seed', type=int, default=0, help='random seed')

    parser.add_argument('--noniid', action='store_true', help='noniid sampling')

    args = parser.parse_known_args()[0] # args = parser.parse_args()
    print(args)

    setup_seed(args.seed)

    exp_folder = 'rsna_fedavg'

    args.save_path = os.path.join(args.save_path, exp_folder)

    if not os.path.exists(args.save_path):
        os.makedirs(args.save_path)
    SAVE_PATH = os.path.join(args.save_path,
                             args.mode + 'client' + str(args.clientnum) + 'sets' + str(args.setnum) + 'seed' + str(
                                 args.seed) + str(args.noniid))

    # server model and ce loss
    server_model = RSNAModel(class_num=args.classnum).to(device)
    loss_fun = nn.CrossEntropyLoss()

    # prepare the data
    train_loaders, validation_loaders, test_loaders = prepare_data(args)

    print('\nData prepared, start training...\n')

    # federated setting
    client_num = args.clientnum
    clients = ['client' + str(_) for _ in range(1, client_num + 1)]
    client_weights = [1 / client_num for i in range(client_num)]
    models = [copy.deepcopy(server_model).to(device) for idx in range(client_num)]

    if args.test:
        checkpoint = torch.load(SAVE_PATH)
        server_model.load_state_dict(checkpoint['server_model'])
        this_test_error = []
        for test_idx, test_loader in enumerate(test_loaders):
            test_loss = test(server_model, test_loader, loss_fun, device, classnum=args.classnum)
            this_test_error.append(test_loss)
            print(' {:<8s}| Error Rate: {:.2f} %'.format(clients[test_idx], test_loss * 100.))
        print('Best Test Error: {:.2f} %'.format(100. * sum(this_test_error) / len(this_test_error)))

        exit(0)

    best_test_error = 1.
    training_loss_log = []
    error_rate_log = []

    # start training
    for a_iter in range(args.iters):
        # record training loss and test error rate
        this_test_error = []
        this_train_loss = []

        optimizers = [optim.Adam(params=models[idx].parameters(), lr=args.lr)
                      for idx in range(client_num)]

        for wi in range(args.wk_iters):
            print("============ Train epoch {} ============".format(wi + 1 + a_iter * args.wk_iters))

            for client_idx in range(client_num):
                model, train_loader, optimizer = models[client_idx], train_loaders[client_idx], optimizers[client_idx]
                train_loss = train(args, model, train_loader, optimizer, loss_fun, client_num, device)
                print(' {:<8s}| Train Loss: {:.4f}'.format(clients[client_idx], train_loss))

                this_train_loss.append(train_loss)

        # aggregation
        server_model, models = communication(args, server_model, models, client_weights)

        # start testing
        for test_idx, test_loader in enumerate(validation_loaders):
            test_loss = test(models[test_idx], test_loader, loss_fun, device, classnum=args.classnum)
            this_test_error.append(test_loss)
            print(' {:<8s}| Error Rate: {:.2f} %'.format(clients[test_idx], test_loss * 100.))

        print()

        # error rate after this communication
        this_test_error = sum(this_test_error) / len(this_test_error)
        if this_test_error < best_test_error:
            best_test_error = this_test_error

            # Save checkpoint
            print(' Saving checkpoints to {}'.format(SAVE_PATH))

            torch.save({
                'server_model': server_model.state_dict(),
                'a_iter': a_iter,
            }, SAVE_PATH)

        # Best Test Error Rate
        print(' Best Validation Error Rate: {:.2f} %, Current Validation Error Rate: {:.2f} %\n'.format(
            best_test_error * 100.,
            this_test_error * 100.
        ))

        training_loss_log.append(sum(this_train_loss) / len(this_train_loss))
        error_rate_log.append(this_test_error)

        if not os.path.exists(os.path.join('./logs/rsna_fedavg', args.mode)):
            os.makedirs(os.path.join('./logs/rsna_fedavg', args.mode))

    print(' Start final testing\n')
    checkpoint = torch.load(SAVE_PATH)
    server_model.load_state_dict(checkpoint['server_model'])
    this_test_error = []
    for test_idx, test_loader in enumerate(test_loaders):
        test_loss = test(server_model, test_loader, loss_fun, device, classnum=args.classnum)
        this_test_error.append(test_loss)
        print(' {:<8s}| Error Rate: {:.2f} %'.format(clients[test_idx], test_loss * 100.))
    print(' Best Test Error: {:.2f} %'.format(100. * sum(this_test_error) / len(this_test_error)))

    error_rate_log.append(sum(this_test_error) / len(this_test_error))
    # save record
    np.savetxt(os.path.join('./logs/rsna_fedavg', args.mode, 'client' + str(args.clientnum) +
                            'sets' + str(args.setnum) + 'seed' + str(args.seed) + str(args.noniid) + 'train_loss.txt'),
               training_loss_log, newline="\r\n")
    np.savetxt(os.path.join('./logs/rsna_fedavg', args.mode, 'client' + str(args.clientnum) +
                            'sets' + str(args.setnum) + 'seed' + str(args.seed) + str(args.noniid) + 'error_rate.txt'),
               error_rate_log, newline="\r\n")


Device: cuda 

Namespace(batch=128, classnum=10, clientnum=5, iters=100, lr=0.0001, mode='fedavg', noniid=False, save_path='./checkpoint/rsna', seed=0, setnum=10, test=False, wdecay=0.0005, wk_iters=1)


Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


  0%|          | 0.00/44.7M [00:00<?, ?B/s]

0 / 106000
10000 / 106000
20000 / 106000
30000 / 106000
40000 / 106000
50000 / 106000
60000 / 106000
70000 / 106000
80000 / 106000
90000 / 106000
100000 / 106000
Spliting U sets for 5 clients, each with 10 U sets...


100%|██████████| 5/5 [00:07<00:00,  1.59s/it]



Data prepared, start training...

 client1 | Train Loss: 72.6263
 client2 | Train Loss: 72.6412
 client3 | Train Loss: 72.8490
 client4 | Train Loss: 72.6322
 client5 | Train Loss: 72.6301
 client1 | Error Rate: 68.99 %
 client2 | Error Rate: 68.39 %
 client3 | Error Rate: 69.25 %
 client4 | Error Rate: 69.30 %
 client5 | Error Rate: 68.99 %

 Saving checkpoints to ./checkpoint/rsna/rsna_fedavg/fedavgclient5sets10seed0False
 Best Validation Error Rate: 68.98 %, Current Validation Error Rate: 68.98 %

 client1 | Train Loss: 50.5326
 client2 | Train Loss: 50.5469
 client3 | Train Loss: 50.7107
 client4 | Train Loss: 50.5559
 client5 | Train Loss: 50.5445
 client1 | Error Rate: 65.04 %
 client2 | Error Rate: 65.71 %
 client3 | Error Rate: 66.03 %
 client4 | Error Rate: 65.92 %
 client5 | Error Rate: 65.90 %

 Saving checkpoints to ./checkpoint/rsna/rsna_fedavg/fedavgclient5sets10seed0False
 Best Validation Error Rate: 65.72 %, Current Validation Error Rate: 65.72 %

 client1 | Train Loss