In [None]:
import copy
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.optim as optim
from torch import nn
from torch.utils.data import DataLoader, Subset
from torchvision import datasets
from torchvision.transforms import ToTensor

# Dataset Curation

In [None]:
import os

# Directory holding (or receiving) the MNIST download. Set MNIST_DATA_DIR to use another
# location; the default keeps the original path.
DATA_ROOT = os.environ.get("MNIST_DATA_DIR", "/Users/Downloads/learning_data/")

training_data = datasets.MNIST(
    root=DATA_ROOT,
    train=True,
    download=True,
    transform=ToTensor()
)
test_data = datasets.MNIST(
    root=DATA_ROOT,
    train=False,
    download=True,
    transform=ToTensor()
)

In [None]:
# Digit value -> English name, used as the subplot titles below.
lables_map = dict(enumerate(
    ['zero', 'one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine']
))

# Show a 3x3 grid of randomly drawn training images, each titled with its label.
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for position in range(1, cols * rows + 1):
    sample_idx = torch.randint(len(training_data), size=(1,)).item()
    img, label = training_data[sample_idx]
    ax = figure.add_subplot(rows, cols, position)
    ax.set_title(lables_map[label])
    ax.axis("off")
    ax.imshow(img.squeeze(), cmap="gray")
plt.show()

In [None]:
# Seed numpy's global RNG so the Dirichlet client split below is reproducible.
np.random.seed(1)
def dirichlet_allocation(dataset, num_clients, alpha=0.5, num_classes=10, train_fraction=0.8,
                         batch_size=8):
    """Partition ``dataset`` among clients with a per-class Dirichlet (non-IID) split.

    For each class label, the class's sample indices (ascending order) are cut into a
    training part (the first ``train_fraction``) and a validation part (the rest). One
    proportion vector is drawn from ``Dirichlet([alpha] * num_clients)`` per class, using
    numpy's global RNG. It slices both parts, so each client's validation data has the same
    class mix as its training data. Smaller ``alpha`` makes clients more heterogeneous.

    Args:
        dataset: dataset with a ``targets`` tensor of integer labels (e.g. torchvision MNIST).
        num_clients: number of clients to create.
        alpha: Dirichlet concentration parameter.
        num_classes: labels ``0 .. num_classes - 1`` are distributed; any other label is ignored.
        train_fraction: fraction of each class used for training; the rest is validation.
        batch_size: batch size of every client loader.

    Returns:
        dict ``{client_id: (train_dataloader, val_dataloader)}`` for ids ``0 .. num_clients - 1``.
    """
    targets = dataset.targets.numpy()
    client_train_indices = {i: np.array([], dtype=int) for i in range(num_clients)}
    client_val_indices = {i: np.array([], dtype=int) for i in range(num_clients)}

    for label in range(num_classes):
        class_idx = np.where(targets == label)[0]
        split_idx = int(len(class_idx) * train_fraction)
        train_class_idx, val_class_idx = class_idx[:split_idx], class_idx[split_idx:]

        # A single draw per class, shared by that class's train and validation parts.
        proportions = np.random.dirichlet([alpha] * num_clients)
        cumulative = proportions.cumsum()[:-1]

        for part, client_indices in ((train_class_idx, client_train_indices),
                                     (val_class_idx, client_val_indices)):
            cut_points = (cumulative * len(part)).astype(int)
            for i, split in enumerate(np.array_split(part, cut_points)):
                client_indices[i] = np.concatenate([client_indices[i], split])

    client_dataloaders = {}
    for i in range(num_clients):
        train_dataset = Subset(dataset, client_train_indices[i])
        val_dataset = Subset(dataset, client_val_indices[i])
        # drop_last=True keeps every batch larger than one sample, which BatchNorm1d needs
        # whenever the model runs in training mode.
        train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
        # No shuffling for validation: the samples discarded by drop_last are then always the
        # same, so repeated evaluations of one model see identical data.
        val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, drop_last=True)
        client_dataloaders[i] = (train_dataloader, val_dataloader)

    return client_dataloaders


# 100 clients, each receiving a non-IID (Dirichlet, default alpha=0.5) share of every digit class.
num_clients = 100
allocated_dataloaders = dirichlet_allocation(training_data, num_clients)

In [None]:
# Number of batches in each client's train and validation loader.
for client_id, (train_loader, val_loader) in allocated_dataloaders.items():
    print('cid:', client_id)
    print(len(train_loader), len(val_loader))

In [None]:
# The whole test set is evaluated as one single, unshuffled batch.
test_dataloader = DataLoader(dataset=test_data, batch_size=len(test_data), shuffle=False)

# Client Creation

In [None]:
class MNISTModel(nn.Module):
    """Fully connected classifier: dim_feature -> 2*dim_feature -> dim_feature//2 -> logits.

    Each hidden block is Linear -> BatchNorm1d -> ReLU; the output block is a single Linear.
    Because of BatchNorm1d, training-mode batches need more than one sample.

    Args:
        dim_feature: number of input features after flattening (28 * 28 for MNIST).
        num_classes: number of output logits (default 10).
    """

    def __init__(self, dim_feature, num_classes=10):
        super().__init__()
        self.dim_feature = dim_feature

        self.classifier = nn.Sequential(
            self.make_linear_block(dim_feature, dim_feature * 2),
            self.make_linear_block(dim_feature * 2, dim_feature // 2),
            self.make_linear_block(dim_feature // 2, num_classes, is_final_layer=True)
        )

    def make_linear_block(self, input_channels, output_channels, is_final_layer=False):
        """Return Linear(+BatchNorm1d+ReLU), or just Linear when ``is_final_layer``."""
        if is_final_layer:
            return nn.Sequential(
                nn.Linear(input_channels, output_channels)
            )
        return nn.Sequential(
            nn.Linear(input_channels, output_channels),
            nn.BatchNorm1d(output_channels),
            nn.ReLU()
        )

    def forward(self, img):
        """Flatten ``img`` to ``(-1, dim_feature)`` and return the unnormalised class logits."""
        # Uses the configured width instead of a hard-coded 28*28; reshape (unlike view)
        # also accepts non-contiguous inputs.
        flat = img.reshape(-1, self.dim_feature)
        return self.classifier(flat)

In [None]:
model = MNISTModel(28 * 28)
parameters = model.state_dict()
# Show each state_dict entry's name, shape and dtype; printing the full tensors
# (several large weight matrices) would flood the notebook output.
for key, value in parameters.items():
    print(key, tuple(value.shape), value.dtype)

In [None]:
def get_accuracy(logits, targets):
    """Return the fraction of samples whose argmax over dim 1 of ``logits`` equals ``targets``.

    Args:
        logits: score tensor with one row per sample (argmax is taken over dim 1).
        targets: tensor of integer labels, one per row of ``logits``.

    Returns:
        The accuracy as a Python float.
    """
    predicted = logits.argmax(dim=1)
    hits = (predicted == targets).sum()
    return (hits / len(targets)).item()

In [None]:
class Client:
    """A federated-learning participant with its own data loaders and local MNISTModel.

    Args:
        dataloaders: ``(train_dataloader, val_dataloader)`` pair for this client.
        cid: client id (the server indexes per-client state with it).
        device: torch device for the local model and batches.
    """

    def __init__(self, dataloaders, cid, device='cpu'):
        train_dataloader, val_dataloader = dataloaders
        self.train_dataloader = train_dataloader
        self.val_dataloader = val_dataloader
        self.cid = cid
        self.device = device
        self.model = MNISTModel(28 * 28).to(self.device)
        self.criterion = nn.CrossEntropyLoss()

    def train(self, learning_rate=1e-2, num_epochs=1):
        """Run ``num_epochs`` passes of plain SGD over the local training loader."""
        # Set training mode explicitly: evaluate() temporarily puts the model in eval mode.
        self.model.train()
        optimizer = optim.SGD(self.model.parameters(), lr=learning_rate)
        for _ in range(num_epochs):
            for imgs, labels in self.train_dataloader:
                imgs, labels = imgs.to(self.device), labels.to(self.device)
                loss = self.criterion(self.model(imgs), labels)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

    def _evaluate_loader(self, dataloader):
        """Return ``(mean batch loss, accuracy)`` of the current model on ``dataloader``.

        Returns ``(nan, nan)`` when the loader yields no batch (e.g. a client with fewer
        samples than one batch while the loader uses drop_last=True).
        """
        losses, all_logits, all_labels = [], [], []
        for imgs, labels in dataloader:
            imgs, labels = imgs.to(self.device), labels.to(self.device)
            logits = self.model(imgs)
            losses.append(self.criterion(logits, labels).item())
            all_logits.append(logits)
            all_labels.append(labels)
        if not losses:
            return float('nan'), float('nan')
        accuracy = get_accuracy(torch.cat(all_logits).cpu(), torch.cat(all_labels).cpu())
        return sum(losses) / len(losses), accuracy

    def evaluate(self):
        """Evaluate the current model on the local training and validation loaders.

        Runs in eval mode (BatchNorm uses, and does not update, its stored running statistics)
        under ``torch.no_grad()``; the previous train/eval mode is restored afterwards.

        Returns:
            ``(train_loss, train_accuracy, val_loss, val_accuracy)``; losses are mean per-batch losses.
        """
        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                train_loss, train_accuracy = self._evaluate_loader(self.train_dataloader)
                val_loss, val_accuracy = self._evaluate_loader(self.val_dataloader)
        finally:
            self.model.train(was_training)
        return train_loss, train_accuracy, val_loss, val_accuracy

    def get_parameters(self):
        """Return a deep copy of the model's state_dict (parameters and BatchNorm buffers).

        ``state_dict()`` tensors share storage with the live model, so a caller modifying
        them in place (e.g. averaging with ``+=``) would otherwise overwrite this client's model.
        """
        return copy.deepcopy(self.model.state_dict())

    def set_parameters(self, parameters):
        """Copy ``parameters`` (a full state_dict, strict key match) into the local model."""
        self.model.load_state_dict(parameters, strict=True)

# Server Creation

In [None]:
class Server:
    """Coordinates group-based federated training with optional correlated participation.

    Clients are split into consecutive slices of ``len(clients) // M`` clients ("groups").
    Each round, one group is sampled among those not selected in the previous ``R`` rounds.
    Every client of that group trains locally from the server model, and the server then
    aggregates their results.

    Modes:
        * no flag: uniform group sampling, plain parameter averaging (baseline).
        * ``is_option1``: heavy-tailed group probabilities ``~ 1 / (i + 1) ** 1.5``, plain averaging.
        * ``is_option2``: heavy-tailed sampling. The local learning rate becomes
          ``lr * num_rounds / (times this group has been selected so far)``.
        * ``is_fedvarp``: heavy-tailed sampling and a FedVARP server update.

    Args:
        clients: list of ``Client`` objects. With ``is_fedvarp``, every ``cid`` must be a
            valid index into this list.
        R: number of rounds a group stays unavailable after being selected.
        M: number of groups; must be between 1 and ``len(clients)``.
        B: unused, kept for backward compatibility.
        num_epochs: local epochs per round.
        test_dataloader: loader used by :meth:`test`.
        is_option1, is_option2, is_fedvarp: mode flags (see above).
        device: device of the model built by :meth:`test`.
    """

    def __init__(self, clients, R=0, M=1, B=1, num_epochs=1, test_dataloader=test_dataloader,
                 is_option1=False, is_option2=False, is_fedvarp=False, device='cpu'):
        if not 1 <= M <= len(clients):
            raise ValueError(f'M must be between 1 and the number of clients ({len(clients)}), got {M}')
        self.test_dataloader = test_dataloader
        self.device = device
        self.clients = clients
        self.num_clients = len(clients)
        self.R = R
        self.M = M  # number of groups
        self.B = B  # unused
        self.num_clients_wti_group = self.num_clients // self.M
        self.num_epochs = num_epochs
        self.server_params = copy.deepcopy(self.clients[0].get_parameters())
        self.is_option1 = is_option1
        self.is_option2 = is_option2
        self.is_fedvarp = is_fedvarp

        # Consecutive slices of `clients`. If len(clients) is not divisible by M, the last
        # slice is shorter and there are more than M groups.
        self.all_groups = {f'idx{i}': self.clients[i:i + self.num_clients_wti_group]
                           for i in range(0, self.num_clients, self.num_clients_wti_group)}
        group_keys = list(self.all_groups)

        if self.is_option1 or self.is_option2 or self.is_fedvarp:
            alpha = 1.5
            # One weight per actual group (not per M), so uneven splits cannot index past the array.
            weights = np.array([1.0 / (i + 1) ** alpha for i in range(len(group_keys))])
            heavytail_probabilities = weights / weights.sum()
            self.group_probs = dict(zip(group_keys, heavytail_probabilities))
        else:
            self.group_probs = {index: 1 for index in group_keys}

        self.counts_groups = {index: 0 for index in group_keys}
        self.y = {index: 0 for index in group_keys}
        self.q = {index: 0 for index in group_keys}

        ### variables for fedvarp ###
        if self.is_fedvarp:
            # y_i: last stored update of each client (indexed by cid), initialised to zeros.
            self.y_for_clients = [self.zero_out_paramters(copy.deepcopy(self.clients[i].get_parameters()))
                                  for i in range(self.num_clients)]
            # y: running average of the y_i.
            self.y_for_server = copy.deepcopy(self.y_for_clients[0])
            # NOTE(review): tau comes from client 0 only; other clients may run a different
            # number of local steps per epoch.
            self.tau = len(clients[0].train_dataloader)

    def federated_learning(self, num_rounds=1, lr=1e-2, lr_c=1e-2, lr_s=1, verbose=True):
        """Run ``num_rounds`` communication rounds.

        Args:
            num_rounds: number of rounds.
            lr: local learning rate (baseline / option 1; base rate for option 2).
            lr_c: client learning rate for FedVARP.
            lr_s: server learning rate for FedVARP.
            verbose: print the selected group (client ids) every round.

        Returns:
            Four arrays of length ``num_rounds``: mean over clients of train loss, train accuracy,
            validation loss and validation accuracy of the server model after each round.

        Raises:
            ValueError: if no group is available in some round (``R`` >= number of groups).
        """
        last_selected_group_round = {}
        li_tl, li_ta, li_vl, li_va = [], [], [], []

        for r in range(1, num_rounds + 1):
            # A group last selected in round r0 is unavailable in rounds r0+1 .. r0+R.
            available_groups = [(idx, group) for idx, group in self.all_groups.items()
                                if last_selected_group_round.get(idx, -self.R - 1) < r - self.R]
            if not available_groups:
                raise ValueError(f'no group available in round {r}: R={self.R} must be smaller '
                                 f'than the number of groups ({len(self.all_groups)})')

            group_weights = [self.group_probs[idx] for idx, _ in available_groups]
            selected_idx, selected_group = random.choices(available_groups, weights=group_weights, k=1)[0]
            last_selected_group_round[selected_idx] = r
            if verbose:
                print(f'round {r}', 'selected group:', selected_idx, [client.cid for client in selected_group])
                print()

            # determine learning rate
            if self.is_option2:
                # Debiasing: scale lr by num_rounds / (selection count of this group so far).
                self.counts_groups[selected_idx] += 1
                self.y[selected_idx] = self.counts_groups[selected_idx] / num_rounds
                self.q[selected_idx] = 1 / self.y[selected_idx]
                new_lr = lr * self.q[selected_idx]
            elif self.is_fedvarp:
                self.lr_c = lr_c
                self.lr_s = lr_s
                # NOTE(review): fedvarp_update uses raw parameter differences (not divided by
                # lr_c * tau), so this step scales them by lr_c * tau again — confirm intended.
                self.lr_s_final = self.lr_c * self.lr_s * self.tau
                new_lr = self.lr_c
            else:
                new_lr = lr

            # local training from the current server model
            params_li = []
            for selected_client in selected_group:
                selected_client.set_parameters(self.server_params)
                selected_client.train(num_epochs=self.num_epochs, learning_rate=new_lr)
                params_li.append((selected_client.cid, selected_client.get_parameters()))

            # aggregation
            if self.is_fedvarp:
                self.server_params = self.fedvarp_update(params_li)
            else:
                self.server_params = self.aggregated_parameters(params_li)

            # Evaluate the new server model on every client's local data. nanmean skips clients
            # that report NaN (e.g. a loader without any batch).
            metrics = []
            for client in self.clients:
                client.set_parameters(self.server_params)
                metrics.append(client.evaluate())
            tl, ta, vl, va = np.nanmean(np.array(metrics, dtype=float), axis=0)
            li_tl.append(tl)
            li_ta.append(ta)
            li_vl.append(vl)
            li_va.append(va)

        return np.array(li_tl), np.array(li_ta), np.array(li_vl), np.array(li_va)

    def fedvarp_update(self, params_li):
        """Apply one FedVARP server step; update and return ``self.server_params``.

        With Δ_i = x - x_i (server params minus client i's trained params) for the sampled set S,
        per-client stored updates y_i and their average y:
            v    = y + (1/|S|) Σ_{i∈S} (Δ_i - y_i)
            y   <- y + (1/N)   Σ_{i∈S} (Δ_i - y_i)
            x   <- x - lr_s_final * v
            y_i <- Δ_i for i ∈ S
        int64 entries (BatchNorm ``num_batches_tracked``) use floor division and stay int64.
        """
        deltas_li = [(cid, self.merge_dicts(self.server_params, params, mode="subtract"))
                     for cid, params in params_li]

        S = len(params_li)
        N = self.num_clients

        diff_S = self.zero_out_paramters(copy.deepcopy(self.y_for_clients[0]))
        diff_N = self.zero_out_paramters(copy.deepcopy(self.y_for_clients[0]))
        for cid, delta in deltas_li:
            y_for_client = self.y_for_clients[cid]
            for key in delta.keys():
                diff = delta[key] - y_for_client[key]
                if diff.dtype == torch.int64:
                    diff_S[key] += diff // S
                    diff_N[key] += diff // N
                else:
                    diff_S[key] += diff / S
                    diff_N[key] += diff / N

        v = self.merge_dicts(self.y_for_server, diff_S, mode="add")
        self.y_for_server = self.merge_dicts(self.y_for_server, diff_N, mode="add")

        for key in self.server_params.keys():
            step = self.lr_s_final * v[key]
            if self.server_params[key].dtype == torch.int64:
                self.server_params[key] = self.server_params[key] - step.to(torch.int64)
            else:
                self.server_params[key] = self.server_params[key] - step

        for cid, delta in deltas_li:
            self.y_for_clients[cid] = delta

        return self.server_params

    def aggregated_parameters(self, params_li):
        """Return the element-wise average of the clients' state_dicts.

        Args:
            params_li: non-empty list of ``(cid, state_dict)`` pairs.

        Floating tensors are averaged exactly. int64 tensors (e.g. ``num_batches_tracked``)
        use floor division to keep their dtype. The result is built from clones, so the
        input tensors are never modified; they may share storage with a client's live model.
        """
        if not params_li:
            raise ValueError('cannot aggregate an empty list of parameters')
        state_dicts = [params for _, params in params_li]
        n = len(state_dicts)

        new_parameters = state_dicts[0].copy()  # same mapping type; every value is replaced below
        for key in new_parameters.keys():
            total = state_dicts[0][key].clone()
            for params in state_dicts[1:]:
                total += params[key]
            if total.dtype == torch.int64:
                total //= n
            else:
                total /= n
            new_parameters[key] = total

        return new_parameters

    def test(self):
        """Evaluate the server parameters on ``self.test_dataloader``.

        A fresh model is loaded with the server state and run in eval mode under
        ``torch.no_grad()``, so BatchNorm uses the aggregated running statistics.

        Returns:
            ``(loss, accuracy)`` as 0-d numpy arrays; loss is the mean per-batch loss.
        """
        criterion = nn.CrossEntropyLoss()
        test_model = MNISTModel(28 * 28).to(self.device)
        test_model.load_state_dict(self.server_params, strict=True)
        test_model.eval()
        li_loss, all_logits, all_labels = [], [], []

        with torch.no_grad():
            for imgs, labels in self.test_dataloader:
                imgs, labels = imgs.to(self.device), labels.to(self.device)
                logits = test_model(imgs)
                li_loss.append(criterion(logits, labels).item())
                all_logits.append(logits)
                all_labels.append(labels)

        loss = sum(li_loss) / len(li_loss)
        accu = get_accuracy(torch.cat(all_logits).cpu(), torch.cat(all_labels).cpu())

        return np.array(loss), np.array(accu)

    def zero_out_paramters(self, parameters):
        """Zero every tensor of ``parameters`` in place and return the same dict."""
        for v in parameters.values():
            nn.init.zeros_(v)
        return parameters

    def merge_dicts(self, dict1, dict2, mode="add"):
        """Return a new dict with ``dict1[k] + dict2[k]`` ("add") or ``dict1[k] - dict2[k]`` ("subtract")."""
        if mode == "add":
            return {key: dict1[key] + dict2[key] for key in dict1.keys()}
        if mode == "subtract":
            return {key: dict1[key] - dict2[key] for key in dict1.keys()}
        raise ValueError(f'unknown mode {mode!r}; expected "add" or "subtract"')

# FedAvg with uniform sampling

In [None]:
# Use the first CUDA GPU when one is available, otherwise the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
device

In [None]:
# device = torch.device("mps")
# Shared experiment configuration.
num_rounds = 2000  # communication rounds per run; several later cells reassign num_rounds
M = 20  # number of client groups passed to every Server
device

In [None]:
clients = []
for cid, dataloaders in allocated_dataloaders.items():
    client = Client(dataloaders, cid, device=device)
    clients.append(client)

server_baseline = Server(clients, M=M, device=device)

In [None]:
tl_b, ta_b, vl_b, va_b = server_baseline.federated_learning(num_rounds=num_rounds)

In [None]:
fig, axs = plt.subplots(1, 2, figsize=(10, 5))

axs[0].plot(np.log(tl_b), 'b-', label='train loss')
axs[0].plot(np.log(vl_b), 'r-', label='val loss')
axs[0].set_title('Loss')
axs[0].set_xlabel('rounds')
axs[0].set_ylabel('loss')
axs[0].legend(loc='best')

axs[1].plot(ta_b, 'b-', label='train acurracy')
axs[1].plot(va_b, 'r-', label='val acurracy')
axs[1].set_title('Accuracy')
axs[1].set_xlabel('rounds')
axs[1].set_ylabel('accuracy')
axs[1].legend(loc='best')

plt.tight_layout()

testl_b, testa_b = server_baseline.test()
testa_b

# FedAvg with correlated participation

## R = 0

In [None]:
clients = []
for cid, dataloaders in allocated_dataloaders.items():
    client = Client(dataloaders, cid, device=device)
    clients.append(client)

R = 0
server_op1_R0 = Server(clients, M=M, R=R, is_option1=True, device=device)

In [None]:
learning_rate = 3e-2
tl_op1_R0, ta_op1_R0, vl_op1_R0, va_op1_R0 = server_op1_R0.federated_learning(num_rounds=num_rounds, lr=learning_rate)

In [None]:
fig, axs = plt.subplots(1, 2, figsize=(10, 5))

axs[0].plot(tl_op1_R0, 'b-', label='train loss')
axs[0].plot(vl_op1_R0, 'r-', label='val loss')
axs[0].set_title('Loss')
axs[0].set_xlabel('rounds')
axs[0].set_ylabel('loss')
axs[0].legend(loc='best')

axs[1].plot(ta_op1_R0, 'b-', label='train acurracy')
axs[1].plot(va_op1_R0, 'r-', label='val acurracy')
axs[1].set_title('Accuracy')
axs[1].set_xlabel('rounds')
axs[1].set_ylabel('accuracy')
axs[1].legend(loc='best')

plt.tight_layout()

testl_op1_R0, testa_op1_R0 = server_op1_R0.test()
testa_op1_R0



## R = 5

In [None]:
clients = []
for cid, dataloaders in allocated_dataloaders.items():
    client = Client(dataloaders, cid, device=device)
    clients.append(client)

R = 5
server_op1_R5 = Server(clients, M=M, R=R, is_option1=True, device=device)

In [None]:
learning_rate = 1e-4
tl_op1_R5, ta_op1_R5, vl_op1_R5, va_op1_R5 = server_op1_R5.federated_learning(num_rounds=500, lr=learning_rate)

In [None]:
fig, axs = plt.subplots(1, 2, figsize=(10, 5))

axs[0].plot(tl_op1_R5, 'b-', label='train loss')
axs[0].plot(vl_op1_R5, 'r-', label='val loss')
axs[0].set_title('Loss')
axs[0].set_xlabel('rounds')
axs[0].set_ylabel('loss')
axs[0].legend(loc='best')

axs[1].plot(ta_op1_R5, 'b-', label='train acurracy')
axs[1].plot(va_op1_R5, 'r-', label='val acurracy')
axs[1].set_title('Accuracy')
axs[1].set_xlabel('rounds')
axs[1].set_ylabel('accuracy')
axs[1].legend(loc='best')

plt.tight_layout()

testl_op1_R5, testa_op1_R5 = server_op1_R5.test()
testa_op1_R5

In [None]:
np.save('tl_op1_R5.npy', tl_op1_R5)

## R = 10

In [None]:
clients = []
for cid, dataloaders in allocated_dataloaders.items():
    client = Client(dataloaders, cid, device=device)
    clients.append(client)

R = 10 
server_op1_R10 = Server(clients, M=M, R=R, is_option1=True, device=device)

In [None]:
learning_rate = 5e-2
tl_op1_R10, ta_op1_R10, vl_op1_R10, va_op1_R10 = server_op1_R10.federated_learning(num_rounds=num_rounds, lr=learning_rate)

In [None]:
fig, axs = plt.subplots(1, 2, figsize=(10, 5))

axs[0].plot(tl_op1_R10, 'b-', label='train loss')
axs[0].plot(vl_op1_R10, 'r-', label='val loss')
axs[0].set_title('Loss')
axs[0].set_xlabel('rounds')
axs[0].set_ylabel('loss')
axs[0].legend(loc='best')

axs[1].plot(ta_op1_R10, 'b-', label='train acurracy')
axs[1].plot(va_op1_R10, 'r-', label='val acurracy')
axs[1].set_title('Accuracy')
axs[1].set_xlabel('rounds')
axs[1].set_ylabel('accuracy')
axs[1].legend(loc='best')

plt.tight_layout()

testl_op1_R10, testa_op1_R10 = server_op1_R10.test()
testa_op1_R10

## R = 15

In [None]:
clients = []
for cid, dataloaders in allocated_dataloaders.items():
    client = Client(dataloaders, cid, device=device)
    clients.append(client)

R = 15
server_op1_R15 = Server(clients, M=M, R=R, is_option1=True, device=device)

In [None]:
learning_rate = 7e-2
tl_op1_R15, ta_op1_R15, vl_op1_R15, va_op1_R15 = server_op1_R15.federated_learning(num_rounds=num_rounds, lr=learning_rate)

In [None]:
fig, axs = plt.subplots(1, 2, figsize=(10, 5))

axs[0].plot(tl_op1_R15, 'b-', label='train loss')
axs[0].plot(vl_op1_R15, 'r-', label='val loss')
axs[0].set_title('Loss')
axs[0].set_xlabel('rounds')
axs[0].set_ylabel('loss')
axs[0].legend(loc='best')

axs[1].plot(ta_op1_R15, 'b-', label='train acurracy')
axs[1].plot(va_op1_R15, 'r-', label='val acurracy')
axs[1].set_title('Accuracy')
axs[1].set_xlabel('rounds')
axs[1].set_ylabel('accuracy')
axs[1].legend(loc='best')

plt.tight_layout()

testl_op1_R15, testa_op1_R15 = server_op1_R15.test()
testa_op1_R15

## R = 19

In [None]:
clients = []
for cid, dataloaders in allocated_dataloaders.items():
    client = Client(dataloaders, cid, device=device)
    clients.append(client)

R = 19
server_op1_R19 = Server(clients, M=M, R=R, is_option1=True, device=device)

In [None]:
learning_rate = 8e-2
tl_op1_R19, ta_op1_R19, vl_op1_R19, va_op1_R19 = server_op1_R19.federated_learning(num_rounds=num_rounds, lr=learning_rate)

In [None]:
fig, axs = plt.subplots(1, 2, figsize=(10, 5))

axs[0].plot(tl_op1_R19, 'b-', label='train loss')
axs[0].plot(vl_op1_R19, 'r-', label='val loss')
axs[0].set_title('Loss')
axs[0].set_xlabel('rounds')
axs[0].set_ylabel('loss')
axs[0].legend(loc='best')

axs[1].plot(ta_op1_R19, 'b-', label='train acurracy')
axs[1].plot(va_op1_R19, 'r-', label='val acurracy')
axs[1].set_title('Accuracy')
axs[1].set_xlabel('rounds')
axs[1].set_ylabel('accuracy')
axs[1].legend(loc='best')

plt.tight_layout()

testl_op1_R19, testa_op1_R19 = server_op1_R19.test()
testa_op1_R19

## Plots: vanilla FedAvg with correlated participation (all R)

In [None]:
fig = plt.figure()
# Training-loss curves of the baseline and every vanilla-FedAvg run, drawn in this order.
for curve, name, colour in (
    (tl_b, 'baseline', 'r'),
    (tl_op1_R0, 'R=0', 'b'),
    (tl_op1_R5, 'R=5', 'y'),
    (tl_op1_R10, 'R=10', 'gray'),
    (tl_op1_R15, 'R=15', 'green'),
    (tl_op1_R19, 'R=19', 'purple'),
):
    plt.plot(curve, label=name, color=colour)
plt.xlim([200, 400 + 1])
plt.yscale("log")
plt.legend(loc='best')
plt.xlabel("Communication round, $t$", fontsize=14)
plt.ylabel("Loss", fontsize=13)
plt.xticks(fontsize=13)
plt.yticks(fontsize=12)
plt.show()

In [None]:
fig = plt.figure()
# Training-accuracy curves of the baseline and every vanilla-FedAvg run.
for curve, name, colour in (
    (ta_b, 'baseline', 'r'),
    (ta_op1_R0, 'R=0', 'b'),
    (ta_op1_R5, 'R=5', 'y'),
    (ta_op1_R10, 'R=10', 'gray'),
    (ta_op1_R15, 'R=15', 'green'),
    (ta_op1_R19, 'R=19', 'purple'),  # was mislabelled 'R=15'
):
    plt.plot(curve, label=name, color=colour)
plt.xlim([0, 200 + 1])
plt.legend(loc='best')
plt.xlabel("Communication round, $t$", fontsize=14)
plt.ylabel("Accuracy", fontsize=13)
plt.xticks(fontsize=13)
plt.yticks(fontsize=12)
plt.show()

In [None]:
# NOTE: reassigns the global num_rounds; the Debiasing FedAvg runs below train with this value.
num_rounds = 500
fig = plt.figure()

# |loss(R) - loss(baseline)| per round. The runs differ in length (the R=5 run used 500
# rounds, the baseline num_rounds=2000), so only the rounds both runs have are compared;
# plain subtraction would fail to broadcast.
for curve, name, colour in (
    (tl_op1_R0, 'R=0', 'b'),
    (tl_op1_R5, 'R=5', 'y'),
    (tl_op1_R10, 'R=10', 'gray'),
    (tl_op1_R15, 'R=15', 'green'),
):
    n_common = min(len(curve), len(tl_b))
    plt.plot(np.abs(curve[:n_common] - tl_b[:n_common]), label=name, color=colour)

plt.xlim([num_rounds // 2, num_rounds + 1])
plt.yscale("log")
plt.legend(loc='best')
plt.xlabel("Communication round, $t$", fontsize=14)
plt.ylabel(r"$|F_R(x^t) - F_{\mathrm{baseline}}(x^t)|$", fontsize=13)
plt.xticks(fontsize=13)
plt.yticks(fontsize=12)
plt.show()

# Debiasing FedAvg

## R = 0

In [None]:
clients = []
for cid, dataloaders in allocated_dataloaders.items():
    client = Client(dataloaders, cid, device=device)
    clients.append(client)
    
import random
random.seed(1)
random.shuffle(clients)

R = 0
server_op2_R0 = Server(clients, M=M, R=R, is_option1=True, is_option2=True, device=device)

In [None]:
learning_rate = 6e-4
tl_op2_R0, ta_op2_R0, vl_op2_R0, va_op2_R0 = server_op2_R0.federated_learning(num_rounds=num_rounds, lr=learning_rate)

In [None]:
fig, axs = plt.subplots(1, 2, figsize=(10, 5))

axs[0].plot(tl_op2_R0, 'b-', label='train loss')
axs[0].plot(vl_op2_R0, 'r-', label='val loss')
axs[0].set_title('Loss')
axs[0].set_xlabel('rounds')
axs[0].set_ylabel('loss')
axs[0].legend(loc='best')

axs[1].plot(ta_op2_R0, 'b-', label='train acurracy')
axs[1].plot(va_op2_R0, 'r-', label='val acurracy')
axs[1].set_title('Accuracy')
axs[1].set_xlabel('rounds')
axs[1].set_ylabel('accuracy')
axs[1].legend(loc='best')

plt.tight_layout()

testl_op2_R0, testa_op2_R0 = server_op2_R0.test()
testa_op2_R0

## R = 5

In [None]:
clients = []
for cid, dataloaders in allocated_dataloaders.items():
    client = Client(dataloaders, cid, device=device)
    clients.append(client)
    
import random
random.seed(1)
random.shuffle(clients)

R = 5
server_op2_R5 = Server(clients, M=M, R=R, is_option1=True, is_option2=True, device=device)

In [None]:
learning_rate = 2e-3
tl_op2_R5, ta_op2_R5, vl_op2_R5, va_op2_R5 = server_op2_R5.federated_learning(num_rounds=500, lr=learning_rate)

In [None]:
fig, axs = plt.subplots(1, 2, figsize=(10, 5))

axs[0].plot(tl_op2_R5, 'b-', label='train loss')
# axs[0].plot(vl_op2_R5, 'r-', label='val loss')
axs[0].set_title('Loss')
axs[0].set_xlabel('rounds')
axs[0].set_ylabel('loss')
axs[0].legend(loc='best')

axs[1].plot(ta_op2_R5, 'b-', label='train acurracy')
# axs[1].plot(va_op2_R5, 'r-', label='val acurracy')
axs[1].set_title('Accuracy')
axs[1].set_xlabel('rounds')
axs[1].set_ylabel('accuracy')
axs[1].legend(loc='best')

plt.tight_layout()

testl_op2_R5, testa_op2_R5 = server_op2_R5.test()
testa_op2_R5

## R = 10

In [None]:
clients = []
for cid, dataloaders in allocated_dataloaders.items():
    client = Client(dataloaders, cid, device=device)
    clients.append(client)
    
import random
random.seed(1)
random.shuffle(clients)

R = 10
server_op2_R10 = Server(clients, M=M, R=R, is_option1=True, is_option2=True, device=device)

In [None]:
learning_rate = 3e-4
tl_op2_R10, ta_op2_R10, vl_op2_R10, va_op2_R10 = server_op2_R10.federated_learning(num_rounds=num_rounds, lr=learning_rate)

In [None]:
fig, axs = plt.subplots(1, 2, figsize=(10, 5))

axs[0].plot(tl_op2_R10, 'b-', label='train loss')
axs[0].plot(vl_op2_R10, 'r-', label='val loss')
axs[0].set_title('Loss')
axs[0].set_xlabel('rounds')
axs[0].set_ylabel('loss')
axs[0].legend(loc='best')
axs[0].set_yscale('log')

axs[1].plot(ta_op2_R10, 'b-', label='train acurracy')
axs[1].plot(va_op2_R10, 'r-', label='val acurracy')
axs[1].set_title('Accuracy')
axs[1].set_xlabel('rounds')
axs[1].set_ylabel('accuracy')
axs[1].legend(loc='best')

plt.tight_layout()

testl_op2_R10, testa_op2_R10 = server_op2_R10.test()
testl_op2_R10

In [None]:
np.save('tl_op2_R10.npy', tl_op2_R10)
np.save('vl_op2_R10.npy', vl_op2_R10)

In [None]:
plt.savefig('op2_R10.eps')

## R = 15

In [None]:
num_rounds = 1500
clients = []
for cid, dataloaders in allocated_dataloaders.items():
    client = Client(dataloaders, cid, device=device)
    clients.append(client)
    
import random
random.seed(1)
random.shuffle(clients)

R = 15
server_op2_R15 = Server(clients, M=M, R=R, is_option1=True, is_option2=True, device=device)

In [None]:
learning_rate =  2e-4
tl_op2_R15, ta_op2_R15, vl_op2_R15, va_op2_R15 = server_op2_R15.federated_learning(num_rounds=num_rounds, lr=learning_rate)

In [None]:
fig, axs = plt.subplots(1, 2, figsize=(10, 5))

axs[0].plot(tl_op2_R15, 'b-', label='train loss')
axs[0].plot(vl_op2_R15, 'r-', label='val loss')
axs[0].set_title('Loss')
axs[0].set_xlabel('rounds')
axs[0].set_ylabel('loss')
axs[0].legend(loc='best')
axs[0].set_yscale('log')

axs[1].plot(ta_op2_R15, 'b-', label='train acurracy')
axs[1].plot(va_op2_R15, 'r-', label='val acurracy')
axs[1].set_title('Accuracy')
axs[1].set_xlabel('rounds')
axs[1].set_ylabel('accuracy')
axs[1].legend(loc='best')

plt.tight_layout()

testl_op2_R15, testa_op2_R15 = server_op2_R15.test()
testa_op2_R15

In [None]:
np.save('tl_op2_R15.npy', tl_op2_R15)
np.save('vl_op2_R15.npy', vl_op2_R15)

## Plot of Debiasing FedAvg under MNIST

In [None]:
# NOTE: overwrites the in-memory tl_op2_R5 / tl_b of the runs above with the saved curves.
tl_op2_R5 = np.load('results/tl_op2_R5.npy')
tl_b = np.load('results/tl_b.npy')
# Reference level: mean baseline loss over its last 47% of rounds (including the final round),
# drawn as a horizontal line. At least one round is used, so a short curve never averages everything.
tail_len = max(1, int(tl_b.shape[0] * 0.47))
b_value = np.mean(tl_b[-tail_len:]) * np.ones(np.size(tl_b))

plt.plot(tl_op2_R5, 'c')
plt.plot(tl_op2_R10, 'b')
plt.plot(tl_op2_R15, 'g')
plt.plot(b_value, 'r')
plt.xlim([0, 1500])
plt.xlabel('number of rounds')
plt.ylabel('loss')
plt.yscale('log')
plt.legend(['R=5', 'R=10', 'R=15', 'FedAvg with uniform sampling (convergence)'])
plt.savefig('debias_fedavg_mnist.eps')

## Plot: comparison of Vanilla FedAvg, FedVARP, Debiasing FedAvg

In [None]:
tl_op2_R5 = np.load('results/tl_op2_R5.npy')
# The FedVARP curve is plotted as saved. A constant +0.3 used to be added here, which
# misrepresented FedVARP's loss in this side-by-side comparison.
tl_fv_R5 = np.load('results/tl_fv_R5_new.npy')

# tl_op1_R5 is the in-memory curve of the R=5 vanilla-FedAvg run earlier in this notebook.
plt.plot(tl_op1_R5, 'c')
plt.plot(tl_fv_R5, 'g')
plt.plot(tl_op2_R5, 'r')
plt.xlabel('number of rounds')
plt.ylabel('loss')
plt.yscale('log')
plt.xlim([0, 500])
plt.legend(['Vanilla FedAvg', 'FedVARP', 'Debiasing FedAvg'])
plt.savefig('comparison_R5.eps', bbox_inches='tight')

## Plot: Vanilla FedAvg with different R

In [None]:
# Saved vanilla-FedAvg training-loss curves; these replace the in-memory ones of the same name.
tl_op1_R0 = np.load('results/tl_op1_R0.npy')
tl_op1_R5 = np.load('results/tl_op1_R5.npy')
tl_op1_R10 = np.load('results/tl_op1_R10.npy')
tl_op1_R15 = np.load('results/tl_op1_R15.npy')

# Legend entries follow the plotting order below.
for curve, fmt in ((tl_op1_R0, 'c'), (tl_op1_R5, 'g'), (tl_op1_R10, 'b'), (tl_op1_R15, 'r')):
    plt.plot(curve, fmt)
plt.xlabel('number of rounds')
plt.ylabel('loss')
plt.yscale('log')
plt.xlim([0, 1500])
plt.legend(['R=0', 'R=5', 'R=10', 'R=15'])
plt.savefig('fedavg_mnist_diffR.eps')

## R = 19

In [None]:
clients = []
for cid, dataloaders in allocated_dataloaders.items():
    client = Client(dataloaders, cid, device=device)
    clients.append(client)
    
import random
random.seed(1)
random.shuffle(clients)

R = 19
server_op2_R19 = Server(clients, M=M, R=R, is_option1=True, is_option2=True, device=device)

In [None]:
learning_rate = 3e-4
tl_op2_R19, ta_op2_R19, vl_op2_R19, va_op2_R19 = server_op2_R19.federated_learning(num_rounds=num_rounds, lr=learning_rate)

In [None]:
fig, axs = plt.subplots(1, 2, figsize=(10, 5))

axs[0].plot(tl_op2_R19, 'b-', label='train loss')
axs[0].plot(vl_op2_R19, 'r-', label='val loss')
axs[0].set_title('Loss')
axs[0].set_xlabel('rounds')
axs[0].set_ylabel('loss')
axs[0].legend(loc='best')

axs[1].plot(ta_op2_R19, 'b-', label='train acurracy')
axs[1].plot(va_op2_R19, 'r-', label='val acurracy')
axs[1].set_title('Accuracy')
axs[1].set_xlabel('rounds')
axs[1].set_ylabel('accuracy')
axs[1].legend(loc='best')

plt.tight_layout()

testl_op2_R19, testa_op2_R19 = server_op2_R19.test()
testl_op2_R19

## Plots: Debiasing FedAvg (all R)

In [None]:
fig = plt.figure()
# Training-loss curves of the baseline and every debiasing run, drawn in this order.
for curve, name, colour in (
    (tl_b, 'baseline', 'r'),
    (tl_op2_R0, 'R=0', 'b'),
    (tl_op2_R5, 'R=5', 'y'),
    (tl_op2_R10, 'R=10', 'gray'),
    (tl_op2_R15, 'R=15', 'green'),
    (tl_op2_R19, 'R=19', 'purple'),
):
    plt.plot(curve, label=name, color=colour)
plt.yscale("log")
plt.legend(loc='best')
plt.xlabel("Communication round, $t$", fontsize=14)
plt.ylabel("Loss", fontsize=13)
plt.xticks(fontsize=13)
plt.yticks(fontsize=12)
plt.show()

In [None]:
# NOTE: reassigns the global num_rounds (used here for the x-axis range).
num_rounds = 500
fig = plt.figure()

# |loss(R) - loss(baseline)| per round over the rounds both curves have. The runs differ in
# length (e.g. R=15 and R=19 used num_rounds = 1500, R=5 used 500), so plain subtraction
# would fail to broadcast.
for curve, name, colour in (
    (tl_op2_R0, 'R=0', 'b'),
    (tl_op2_R5, 'R=5', 'y'),
    (tl_op2_R10, 'R=10', 'gray'),
    (tl_op2_R15, 'R=15', 'green'),
    (tl_op2_R19, 'R=19', 'purple'),
):
    n_common = min(len(curve), len(tl_b))
    plt.plot(np.abs(curve[:n_common] - tl_b[:n_common]), label=name, color=colour)

plt.xlim([num_rounds // 2, num_rounds + 1])
plt.yscale("log")
plt.legend(loc='best')
plt.xlabel("Communication round, $t$", fontsize=14)
plt.ylabel(r"$|F_R(x^t) - F_{\mathrm{baseline}}(x^t)|$", fontsize=13)
plt.xticks(fontsize=13)
plt.yticks(fontsize=12)
plt.show()

In [None]:
plt.plot(tl_op1_R5, label='Vanilla FedAvg', color='b')
plt.plot(tl_op2_R5,label='Debiasing FedAvg', color='y')
plt.plot(tl_b, label='Vanilla FedAvg under uniform client sampling', color='r')
plt.legend(loc='best')
plt.xlabel("Communication round, $t$",fontsize=14)
# plt.ylabel("$| f(x^t, y^t) - f(x^*, y^*)| $",fontsize=13)
plt.ylabel("$Loss$",fontsize=13)