# In [None]:
import copy
import os
import time
import pickle
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
from tqdm import tqdm
import matplotlib.pyplot as plt
from tensorboardX import SummaryWriter
import pandas as pd

class Args:
    """Experiment configuration, used like an argparse namespace.

    The data, model and training code in this file read these values as
    ``args.<name>``.
    """

    def __init__(self):
        # Federated schedule.
        self.epochs = 50            # number of global communication rounds
        self.num_users = 100        # total number of clients
        self.frac = 0.1             # fraction of clients used per round
        # Local training. __main__ overwrites local_bs / local_ep for each entry
        # of its `configs` list; the local trainers read both attributes, so
        # defaults are declared here to make a bare Args() usable on its own.
        self.local_bs = 10          # local mini-batch size
        self.local_ep = 5           # local epochs per round
        self.lr = 0.001             # SGD learning rate
        self.momentum = 0.5         # SGD momentum
        # Model / data selection.
        self.model = 'cnn'          # 'cnn' or 'mlp'
        self.dataset = 'mnist'
        self.num_channels = 1       # input channels for CNNMnist
        self.num_classes = 10
        # Device / reproducibility / logging.
        self.gpu = None             # falsy -> run on CPU
        self.gpu_id = None
        self.seed = 1
        self.verbose = 1
        # Not read by the code in this file.
        self.kernel_num = 9
        self.kernel_sizes = '3,4,5'
        self.norm = 'batch_norm'
        self.num_filters = 32
        self.max_pool = 'True'
        self.optimizer = 'sgd'
        self.iid = 1
        self.unequal = 0
        self.stopping_rounds = 10
        # Fractional-order update rule.
        self.use_fractional = True  # Using the fractional update rule
        self.alpha = 0.99  # Fractional order parameter
        self.mu0 = 0.1  # Initial learning rate
        self.delta = 0.1  # Delta parameter

# Single module-level configuration instance. The __main__ block passes it to
# the data/model/training helpers and sets local_bs / local_ep on it for each
# entry of `configs`.
args = Args()

def exp_details(args):
    """Hook for reporting the experiment settings in ``args``.

    Currently a no-op: ``args`` is ignored and ``None`` is returned.
    """
    return None

def mnist_iid(dataset, num_users):
    """Split the indices of ``dataset`` across ``num_users`` clients round-robin.

    Sample ``i`` is assigned to client ``i % num_users``, so every sample is used
    and client sizes differ by at most one.

    Args:
        dataset: any sized collection; only ``len(dataset)`` is used.
        num_users: number of clients; must be at least 1.

    Returns:
        dict mapping client id (0..num_users-1) to a set of sample indices.

    Raises:
        ValueError: if ``num_users`` is less than 1.
    """
    if num_users < 1:
        raise ValueError(f"num_users must be >= 1, got {num_users}")
    dict_users = {i: set() for i in range(num_users)}
    for position, idx in enumerate(np.arange(len(dataset))):
        dict_users[position % num_users].add(idx)
    return dict_users

def mnist_noniid(dataset, num_users, num_shards=200, num_imgs=300, shards_per_user=2):
    """Deal label-sorted shards of ``dataset`` to clients (pathological non-IID split).

    The first ``num_shards * num_imgs`` samples are sorted by label and cut into
    ``num_shards`` contiguous shards of ``num_imgs`` indices. Each client receives
    ``shards_per_user`` shards drawn without replacement with ``np.random``, so it
    sees only a few labels. The defaults (200 shards x 300 images = 60000, two
    shards per client) reproduce the original MNIST split.

    Args:
        dataset: object with a ``targets`` attribute (tensor, array or list of labels).
        num_users: number of clients.
        num_shards: number of shards to cut the sorted indices into.
        num_imgs: number of indices per shard.
        shards_per_user: shards assigned to each client.

    Returns:
        dict mapping client id to a 1-D int64 numpy array of sample indices.

    Raises:
        ValueError: if ``num_users * shards_per_user`` exceeds ``num_shards``.
    """
    if num_users * shards_per_user > num_shards:
        raise ValueError(
            f"Need {num_users * shards_per_user} shards for {num_users} users "
            f"but only {num_shards} are available")

    idx_shard = list(range(num_shards))
    # int64 from the start; seeding with np.array([]) made every index a float.
    dict_users = {i: np.array([], dtype=np.int64) for i in range(num_users)}
    idxs = np.arange(num_shards * num_imgs)
    # np.asarray accepts tensor targets (MNIST) as well as plain lists.
    labels = np.asarray(dataset.targets)[:num_shards * num_imgs]

    # Sort sample indices by label so each shard covers few labels.
    idxs_labels = np.vstack((idxs, labels))
    idxs_labels = idxs_labels[:, idxs_labels[1, :].argsort()]
    idxs = idxs_labels[0, :]

    for i in range(num_users):
        rand_set = set(np.random.choice(idx_shard, shards_per_user, replace=False))
        idx_shard = list(set(idx_shard) - rand_set)
        for rand in rand_set:
            dict_users[i] = np.concatenate(
                (dict_users[i], idxs[rand * num_imgs:(rand + 1) * num_imgs]), axis=0)
    return dict_users

def get_dataset(args, iid):
    """Load the train/test split for ``args.dataset`` and partition it among clients.

    Supported datasets are 'mnist' and 'fmnist' (FashionMNIST). Both are 28x28
    single-channel images with 60000 training samples, so the MNIST
    normalisation constants are reused and ``mnist_noniid``'s default shard
    sizes fit. Any other name raises. Previously this function always loaded
    MNIST, while the model factory picks a network from ``args.dataset``.

    Args:
        args: configuration providing ``dataset`` and ``num_users``.
        iid: if truthy, split with ``mnist_iid``; otherwise ``mnist_noniid``.

    Returns:
        (train_dataset, test_dataset, user_groups) where ``user_groups`` maps
        client id to that client's training-sample indices.

    Raises:
        ValueError: if ``args.dataset`` is not 'mnist' or 'fmnist'.
    """
    # Validate before touching the filesystem or network.
    if args.dataset == 'mnist':
        dataset_cls = datasets.MNIST
    elif args.dataset == 'fmnist':
        dataset_cls = datasets.FashionMNIST
    else:
        raise ValueError(
            f"Unsupported dataset {args.dataset!r}; expected 'mnist' or 'fmnist'")

    data_dir = './data/' + args.dataset + '/'
    apply_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])

    train_dataset = dataset_cls(data_dir, train=True, download=True, transform=apply_transform)
    test_dataset = dataset_cls(data_dir, train=False, download=True, transform=apply_transform)

    if iid:
        user_groups = mnist_iid(train_dataset, args.num_users)
    else:
        user_groups = mnist_noniid(train_dataset, args.num_users)

    return train_dataset, test_dataset, user_groups

def average_weights(w):
    """Return the element-wise mean of a list of model state dicts.

    ``w[0]`` is deep-copied and used as the accumulator, so the inputs are not
    modified. Every dict must have ``w[0]``'s keys. ``torch.div`` performs true
    division, so integer entries (such as BatchNorm's ``num_batches_tracked``)
    come back as floating-point tensors.

    Raises:
        IndexError: if ``w`` is empty.
    """
    count = len(w)
    averaged = copy.deepcopy(w[0])
    for name in averaged:
        running = averaged[name]
        for other in w[1:]:
            running += other[name]
        averaged[name] = torch.div(running, count)
    return averaged

class DatasetSplit(Dataset):
    """View of ``dataset`` restricted to the sample positions listed in ``idxs``.

    Each index is coerced with ``int()``, so numpy integer or float index arrays
    are accepted.
    """

    def __init__(self, dataset, idxs):
        self.dataset = dataset
        self.idxs = [int(i) for i in idxs]

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, item):
        """Return ``(image, label)`` for the ``item``-th selected sample, both as tensors."""
        image, label = self.dataset[self.idxs[item]]
        # torch.as_tensor reuses an existing tensor instead of copying it, which
        # torch.tensor() does (with a UserWarning) for every sample.
        # Python int labels become 0-d int64 tensors, as before.
        return torch.as_tensor(image), torch.as_tensor(label)

class LocalUpdateFractional(object):
    """Client-side trainer using a fractional-order gradient update.

    The client's index list is split 80/10/10 (in order) into train/validation/
    test loaders. ``update_weights`` trains the given model with

        w_{k+1} = w_k - mu_t / Gamma(2 - alpha) * g_k * (|w_k - w_{k-1}| + delta) ** (1 - alpha)

    where ``mu_t = mu0 / sqrt(t + 1)`` for global round ``t``. Round 0 uses the
    plain step ``w - mu0 * g``.
    """

    def __init__(self, args, dataset, idxs, logger):
        self.args = args
        self.logger = logger
        self.trainloader, self.validloader, self.testloader = self.train_val_test(dataset, list(idxs))
        self.device = 'cuda' if args.gpu else 'cpu'
        self.criterion = nn.CrossEntropyLoss().to(self.device)

    def train_val_test(self, dataset, idxs):
        """Split ``idxs`` 80/10/10 in order and wrap each part in a DataLoader.

        The evaluation loaders use about 10 batches. Their batch size is clamped
        to at least 1, so clients with fewer than 10 validation/test samples do
        not make DataLoader reject ``batch_size=0``.
        """
        n = len(idxs)
        idxs_train = idxs[:int(0.8 * n)]
        idxs_val = idxs[int(0.8 * n):int(0.9 * n)]
        idxs_test = idxs[int(0.9 * n):]
        trainloader = DataLoader(DatasetSplit(dataset, idxs_train),
                                 batch_size=self.args.local_bs, shuffle=True)
        validloader = DataLoader(DatasetSplit(dataset, idxs_val),
                                 batch_size=max(1, len(idxs_val) // 10), shuffle=False)
        testloader = DataLoader(DatasetSplit(dataset, idxs_test),
                                batch_size=max(1, len(idxs_test) // 10), shuffle=False)
        return trainloader, validloader, testloader

    def update_weights(self, model, global_round):
        """Train ``model`` in place for ``args.local_ep`` local epochs.

        If ``args.use_fractional`` is true (the default when the attribute is
        missing), each mini-batch applies the fractional-order rule *instead of*
        an SGD optimizer step. The previous version applied both, so the
        weights moved twice per batch. If it is false, plain SGD is used
        (momentum ``args.momentum``, weight decay 5e-4, StepLR).

        ``w_{k-1}`` is the parameter value before the previous local step. It
        starts as the initial weights, so the first step sees a zero difference.
        The earlier code cloned the *current* weights right before using them,
        which made ``|w_k - w_{k-1}|`` identically zero.

        Returns:
            (state_dict, mean of the per-epoch mean batch losses)
        """
        model.train()
        epoch_loss = []
        use_fractional = getattr(self.args, 'use_fractional', True)

        optimizer = scheduler = None
        if use_fractional:
            alpha = float(self.args.alpha)
            delta = float(self.args.delta)
            # Step size decays with the global round: mu0 / sqrt(t + 1).
            lr_t = float(self.args.mu0) / (global_round + 1) ** 0.5
            # Gamma(2 - alpha), evaluated as exp(lgamma(.)).
            gamma_term = torch.lgamma(torch.tensor(2.0 - alpha)).exp().item()
            # Previous iterate w_{k-1}; starts equal to the initial weights.
            w_prev = {name: param.detach().clone() for name, param in model.named_parameters()}
        else:
            optimizer = torch.optim.SGD(model.parameters(), lr=self.args.lr,
                                        momentum=self.args.momentum, weight_decay=5e-4)
            scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)

        for _ in range(self.args.local_ep):
            batch_loss = []
            for batch_idx, (images, labels) in enumerate(self.trainloader):
                images, labels = images.to(self.device), labels.to(self.device)
                model.zero_grad()
                log_probs = model(images)
                loss = self.criterion(log_probs, labels)
                loss.backward()

                if use_fractional:
                    self._fractional_step(model, w_prev, global_round, lr_t,
                                          gamma_term, alpha, delta)
                else:
                    optimizer.step()

                if self.args.verbose and (batch_idx % 10 == 0):
                    self.logger.add_scalar('loss', loss.item())
                batch_loss.append(loss.item())
            if scheduler is not None:
                scheduler.step()
            epoch_loss.append(sum(batch_loss) / len(batch_loss))
        return model.state_dict(), sum(epoch_loss) / len(epoch_loss)

    @staticmethod
    def _fractional_step(model, w_prev, global_round, lr_t, gamma_term, alpha, delta):
        """Apply one fractional-order update to ``model`` in place and advance ``w_prev``."""
        # In-place parameter edits must not be recorded by autograd.
        with torch.no_grad():
            for name, param in model.named_parameters():
                if param.grad is None:
                    # Parameter did not take part in this forward pass.
                    continue
                if global_round == 0:
                    step = lr_t * param.grad  # lr_t == mu0 when t == 0
                else:
                    memory = (torch.abs(param - w_prev[name]) + delta) ** (1 - alpha)
                    step = lr_t / gamma_term * param.grad * memory
                w_prev[name].copy_(param)
                param.sub_(step)

    def inference(self, model):
        """Evaluate ``model`` on this client's test split.

        Returns:
            (accuracy, sum of the per-batch mean losses)
        """
        model.eval()
        loss, total, correct = 0.0, 0.0, 0.0

        # Evaluation needs no gradients; skipping autograd saves memory and time.
        with torch.no_grad():
            for images, labels in self.testloader:
                images, labels = images.to(self.device), labels.to(self.device)
                outputs = model(images)
                loss += self.criterion(outputs, labels).item()

                _, pred_labels = torch.max(outputs, 1)
                pred_labels = pred_labels.view(-1)
                correct += torch.sum(torch.eq(pred_labels, labels)).item()
                total += len(labels)

        accuracy = correct / total
        return accuracy, loss

class LocalUpdateFedAvg(object):
    """Client-side trainer for the FedAvg baseline.

    The client's index list is split 80/10/10 (in order) into train/validation/
    test loaders. ``update_weights`` runs ``args.local_ep`` epochs of SGD with
    momentum on the training split.
    """

    def __init__(self, args, dataset, idxs, logger):
        self.args = args
        self.logger = logger  # stored for interface parity; not used by this class
        self.trainloader, self.validloader, self.testloader = self.train_val_test(dataset, list(idxs))
        self.device = 'cuda' if args.gpu else 'cpu'
        self.criterion = nn.CrossEntropyLoss().to(self.device)

    def train_val_test(self, dataset, idxs):
        """Split ``idxs`` 80/10/10 in order and wrap each part in a DataLoader.

        The evaluation loaders use about 10 batches. Their batch size is clamped
        to at least 1, so clients with fewer than 10 validation/test samples do
        not make DataLoader reject ``batch_size=0``.
        """
        n = len(idxs)
        idxs_train = idxs[:int(0.8 * n)]
        idxs_val = idxs[int(0.8 * n):int(0.9 * n)]
        idxs_test = idxs[int(0.9 * n):]
        trainloader = DataLoader(DatasetSplit(dataset, idxs_train),
                                 batch_size=self.args.local_bs, shuffle=True)
        validloader = DataLoader(DatasetSplit(dataset, idxs_val),
                                 batch_size=max(1, len(idxs_val) // 10), shuffle=False)
        testloader = DataLoader(DatasetSplit(dataset, idxs_test),
                                batch_size=max(1, len(idxs_test) // 10), shuffle=False)
        return trainloader, validloader, testloader

    def update_weights(self, model, global_round):
        """Train ``model`` in place with SGD for ``args.local_ep`` local epochs.

        ``global_round`` is accepted for interface parity and is not used.

        Returns:
            (state_dict, mean of the per-epoch mean batch losses)
        """
        model.train()
        epoch_loss = []
        optimizer = torch.optim.SGD(model.parameters(), lr=self.args.lr, momentum=self.args.momentum)

        for _ in range(self.args.local_ep):
            batch_loss = []
            for images, labels in self.trainloader:
                images, labels = images.to(self.device), labels.to(self.device)
                model.zero_grad()
                log_probs = model(images)
                loss = self.criterion(log_probs, labels)
                loss.backward()
                optimizer.step()

                batch_loss.append(loss.item())
            epoch_loss.append(sum(batch_loss) / len(batch_loss))
        return model.state_dict(), sum(epoch_loss) / len(epoch_loss)

    def inference(self, model):
        """Evaluate ``model`` on this client's test split.

        Returns:
            (accuracy, sum of the per-batch mean losses)
        """
        model.eval()
        loss, total, correct = 0.0, 0.0, 0.0

        # Evaluation needs no gradients; skipping autograd saves memory and time.
        with torch.no_grad():
            for images, labels in self.testloader:
                images, labels = images.to(self.device), labels.to(self.device)
                outputs = model(images)
                loss += self.criterion(outputs, labels).item()

                _, pred_labels = torch.max(outputs, 1)
                pred_labels = pred_labels.view(-1)
                correct += torch.sum(torch.eq(pred_labels, labels)).item()
                total += len(labels)

        accuracy = correct / total
        return accuracy, loss

def test_inference(args, model, test_dataset):
    """Evaluate ``model`` on ``test_dataset`` in batches of 128.

    Uses ``CrossEntropyLoss``, as the local trainers do. It is correct both for
    models that return raw logits (CNNFashion_Mnist) and for models that return
    log-probabilities (CNNMnist, CNNCifar), because log_softmax of
    log-probabilities leaves them unchanged. ``NLLLoss`` was only correct for
    the latter.

    Returns:
        (accuracy, sum of the per-batch mean losses)
    """
    model.eval()
    loss, total, correct = 0.0, 0.0, 0.0

    device = 'cuda' if args.gpu else 'cpu'
    criterion = nn.CrossEntropyLoss().to(device)
    testloader = DataLoader(test_dataset, batch_size=128, shuffle=False)

    # Evaluation needs no gradients; skipping autograd saves memory and time.
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss += criterion(outputs, labels).item()
            _, pred_labels = torch.max(outputs, 1)
            pred_labels = pred_labels.view(-1)
            correct += torch.sum(torch.eq(pred_labels, labels)).item()
            total += len(labels)

    accuracy = correct / total
    return accuracy, loss


# The name starts with ``test_`` but this is not a pytest test; stop pytest from
# collecting it when it is imported into a test module.
test_inference.__test__ = False

class MLP(nn.Module):
    """One-hidden-layer perceptron over flattened images; returns log-probabilities.

    The forward pass used to end in ``nn.Softmax``. Feeding probabilities into
    ``nn.CrossEntropyLoss`` (the criterion the local trainers use) applies
    softmax twice and weakens the gradients. Log-probabilities work with both
    CrossEntropyLoss and NLLLoss and match CNNMnist/CNNCifar. Argmax
    predictions are unchanged.
    """

    def __init__(self, dim_in, dim_hidden, dim_out):
        super(MLP, self).__init__()
        self.layer_input = nn.Linear(dim_in, dim_hidden)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout()
        self.layer_hidden = nn.Linear(dim_hidden, dim_out)
        self.log_softmax = nn.LogSoftmax(dim=1)

    def forward(self, x):
        # Flatten (N, C, H, W) -> (N, C*H*W).
        x = x.view(-1, x.shape[1] * x.shape[-2] * x.shape[-1])
        x = self.layer_input(x)
        x = self.dropout(x)
        x = self.relu(x)
        x = self.layer_hidden(x)
        return self.log_softmax(x)

class CNNMnist(nn.Module):
    """Two conv + two linear layer classifier; returns log-probabilities.

    ``fc1`` expects 320 = 20 * 4 * 4 features, which corresponds to 28x28
    inputs with ``args.num_channels`` channels.
    """

    def __init__(self, args):
        super().__init__()
        self.conv1 = nn.Conv2d(args.num_channels, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, args.num_classes)

    def forward(self, x):
        stage1 = F.relu(F.max_pool2d(self.conv1(x), 2))
        stage2 = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(stage1)), 2))
        # (N, C, H, W) -> (N, C*H*W)
        flat = torch.flatten(stage2, start_dim=1)
        hidden = F.dropout(F.relu(self.fc1(flat)), training=self.training)
        return F.log_softmax(self.fc2(hidden), dim=1)

class CNNFashion_Mnist(nn.Module):
    """Two conv/BatchNorm/ReLU/MaxPool stages plus a linear head; returns raw logits.

    ``fc`` expects 7 * 7 * 32 features, which corresponds to 28x28 inputs. Input
    channels and the number of classes now come from ``args.num_channels`` and
    ``args.num_classes``, like the sibling models. They default to 1 and 10 (the
    previously hard-coded values) when ``args`` lacks those attributes.
    """

    def __init__(self, args):
        super(CNNFashion_Mnist, self).__init__()
        in_channels = getattr(args, 'num_channels', 1)
        num_classes = getattr(args, 'num_classes', 10)
        self.layer1 = nn.Sequential(
            nn.Conv2d(in_channels, 16, kernel_size=5, padding=2),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(2))
        self.layer2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=5, padding=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2))
        self.fc = nn.Linear(7 * 7 * 32, num_classes)

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return out

class CNNCifar(nn.Module):
    """Two conv/pool stages plus three linear layers for 3-channel input; returns log-probabilities.

    ``fc1`` expects 16 * 5 * 5 features, which corresponds to 32x32 inputs.
    """

    def __init__(self, args):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, args.num_classes)

    def forward(self, x):
        # Convolution stages: conv -> ReLU -> 2x2 max-pool.
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
        x = x.view(-1, 16 * 5 * 5)
        # Hidden fully-connected layers with ReLU.
        for layer in (self.fc1, self.fc2):
            x = F.relu(layer(x))
        return F.log_softmax(self.fc3(x), dim=1)

if __name__ == '__main__':
    start_time = time.time()

    path_project = os.path.abspath('.')
    logger = SummaryWriter('./logs')

    exp_details(args)

    if args.gpu_id:
        torch.cuda.set_device(args.gpu_id)
    device = 'cuda' if args.gpu else 'cpu'

    # args.seed was never applied; seed numpy (non-IID sharding) and torch
    # so the data partition is reproducible.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # (local batch size B, local epochs E) combinations to evaluate.
    configs = [(10, 5)]
    # Results key of the first configuration; the tables and plots use it.
    primary_config = f'B={configs[0][0]}_E={configs[0][1]}'

    results_iid = {'fractional': {}, 'fedavg': {}}
    results_noniid = {'fractional': {}, 'fedavg': {}}

    def _build_model(train_dataset):
        """Return a fresh global model selected by ``args.model`` / ``args.dataset``."""
        if args.model == 'cnn':
            cnn_classes = {'mnist': CNNMnist, 'fmnist': CNNFashion_Mnist, 'cifar': CNNCifar}
            if args.dataset not in cnn_classes:
                # Previously fell through and failed later with a NameError.
                raise SystemExit('Error: unrecognized dataset')
            return cnn_classes[args.dataset](args=args)
        if args.model == 'mlp':
            # Flattened input size = product of one sample's tensor dimensions.
            len_in = 1
            for dim in train_dataset[0][0].shape:
                len_in *= dim
            return MLP(dim_in=len_in, dim_hidden=64, dim_out=args.num_classes)
        raise SystemExit('Error: unrecognized model')

    def _run_federated(update_class, user_groups, train_dataset, test_dataset):
        """Train a fresh global model for ``args.epochs`` rounds using ``update_class`` clients.

        Returns a dict with one entry per round for 'train_loss' (mean local loss
        of the sampled clients), 'train_accuracy' (mean accuracy over every
        client's local test split) and 'test_accuracies' (accuracy on
        ``test_dataset``).
        """
        # Re-seed so every method/partition run starts from the same initial
        # weights and sees the same client-sampling sequence.
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)

        global_model = _build_model(train_dataset)
        global_model.to(device)
        global_model.train()

        m = max(int(args.frac * args.num_users), 1)
        train_loss, train_accuracy, test_accuracies = [], [], []

        for epoch in tqdm(range(args.epochs)):
            local_weights, local_losses = [], []
            global_model.train()

            # Sample m distinct clients at random each round; the previous code
            # always picked clients 0..m-1, so the rest never trained.
            idxs_users = np.random.choice(args.num_users, m, replace=False)
            for idx in idxs_users:
                local_model = update_class(args=args, dataset=train_dataset,
                                           idxs=user_groups[idx], logger=logger)
                w, loss = local_model.update_weights(model=copy.deepcopy(global_model),
                                                     global_round=epoch)
                local_weights.append(copy.deepcopy(w))
                local_losses.append(loss)

            global_model.load_state_dict(average_weights(local_weights))
            train_loss.append(sum(local_losses) / len(local_losses))

            # Training accuracy: mean over every client's local test split.
            global_model.eval()
            list_acc = []
            for c in range(args.num_users):
                local_model = update_class(args=args, dataset=train_dataset,
                                           idxs=user_groups[c], logger=logger)
                acc, _ = local_model.inference(model=global_model)
                list_acc.append(acc)
            train_accuracy.append(sum(list_acc) / len(list_acc))

            # Global test accuracy after this round.
            test_acc, _ = test_inference(args, global_model, test_dataset)
            test_accuracies.append(test_acc)

        return {
            'train_loss': train_loss,
            'train_accuracy': train_accuracy,
            'test_accuracies': test_accuracies,
        }

    for local_bs, local_ep in configs:
        args.local_bs = local_bs
        args.local_ep = local_ep
        config_name = f'B={local_bs}_E={local_ep}'

        train_dataset, test_dataset, user_groups_iid = get_dataset(args, iid=True)
        train_dataset, test_dataset, user_groups_noniid = get_dataset(args, iid=False)

        # Run order: fractional (IID, non-IID), then FedAvg (IID, non-IID).
        for method, update_class in (('fractional', LocalUpdateFractional),
                                     ('fedavg', LocalUpdateFedAvg)):
            for user_groups, results in ((user_groups_iid, results_iid),
                                         (user_groups_noniid, results_noniid)):
                results[method][config_name] = _run_federated(
                    update_class, user_groups, train_dataset, test_dataset)

    # Flush and release the TensorBoard writer.
    logger.close()

    # Convert results to DataFrames
    def convert_to_df(results, metric):
        """Tabulate ``metric`` for the primary config: one column per method, one row per round."""
        return pd.DataFrame({method: data[primary_config][metric]
                             for method, data in results.items()})

    # Test Accuracy for IID
    df_test_acc_iid = convert_to_df(results_iid, 'test_accuracies')
    print("\nTable 1: Test Accuracy for IID setting")
    print(df_test_acc_iid)

    # Test Accuracy for Non-IID
    df_test_acc_noniid = convert_to_df(results_noniid, 'test_accuracies')
    print("\nTable 2: Test Accuracy for Non-IID setting")
    print(df_test_acc_noniid)

    # Training Loss for IID
    df_train_loss_iid = convert_to_df(results_iid, 'train_loss')
    print("\nTable 3: Training Loss for IID setting")
    print(df_train_loss_iid)

    # Training Loss for Non-IID
    df_train_loss_noniid = convert_to_df(results_noniid, 'train_loss')
    print("\nTable 4: Training Loss for Non-IID setting")
    print(df_train_loss_noniid)

    # Save these tables to CSV files
    df_test_acc_iid.to_csv('test_accuracy_iid.csv', index=False)
    df_test_acc_noniid.to_csv('test_accuracy_noniid.csv', index=False)
    df_train_loss_iid.to_csv('training_loss_iid.csv', index=False)
    df_train_loss_noniid.to_csv('training_loss_noniid.csv', index=False)

    # Directory for the saved figures.
    save_dir = './results'
    os.makedirs(save_dir, exist_ok=True)

    def _plot_metric(results, metric, ylabel, title, filename):
        """Plot ``metric`` per communication round for each method; save and show it."""
        plt.figure()
        for method in results:
            plt.plot(range(1, args.epochs + 1), results[method][primary_config][metric], label=method)
        plt.xlabel('Communication Rounds')
        plt.ylabel(ylabel)
        plt.legend()
        plt.title(title)
        plt.savefig(os.path.join(save_dir, filename))
        plt.show()

    _plot_metric(results_iid, 'test_accuracies', 'Test Accuracy',
                 'Test Accuracy vs. Communication Rounds (IID)',
                 'test_accuracy_vs_comm_rounds_iid.png')
    _plot_metric(results_noniid, 'test_accuracies', 'Test Accuracy',
                 'Test Accuracy vs. Communication Rounds (Non-IID)',
                 'test_accuracy_vs_comm_rounds_noniid.png')
    _plot_metric(results_iid, 'train_loss', 'Training Loss',
                 'Training Loss vs. Communication Rounds (IID)',
                 'train_loss_vs_comm_rounds_iid.png')
    _plot_metric(results_noniid, 'train_loss', 'Training Loss',
                 'Training Loss vs. Communication Rounds (Non-IID)',
                 'train_loss_vs_comm_rounds_noniid.png')
