In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np
from torch import nn
from torch.utils.data import DataLoader, Dataset, TensorDataset
from tqdm import tqdm
import copy
import matplotlib.pyplot as plt
from collections import OrderedDict
import random
import torch.nn.functional as F
import torch.nn.functional as func
import collections
from sklearn.model_selection import train_test_split

In [2]:
import os

# Intel OpenMP switch that tolerates more than one OpenMP runtime in the process.
# Presumably a workaround for the "libiomp5 already initialized" abort seen when
# several numeric libraries are loaded together -- confirm it is still needed.
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"

In [3]:
# Pick the compute device. `torch.cuda.is_available` has to be *called*: the bare
# function object is always truthy, which would select 'cuda' even on CPU-only hosts.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)

Tue Dec 17 10:06:19 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.183.01             Driver Version: 535.183.01   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA GeForce RTX 4090        Off | 00000000:21:00.0 Off |                  Off |
| 30%   24C    P8              12W / 350W |     11MiB / 24564MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA GeForce RTX 4090        Off | 00000000:E1:00.0 Off |  

In [4]:
class LinearBottleNeck(nn.Module):
    """Inverted-residual block: 1x1 expand -> 3x3 depthwise -> 1x1 project.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the final 1x1 projection.
        stride: stride of the 3x3 depthwise convolution.
        t: expansion factor; the hidden width is ``in_channels * t``.
        class_num: accepted for signature compatibility; not used by this block.
    """

    def __init__(self, in_channels, out_channels, stride, t=6, class_num=100):
        super().__init__()

        hidden = in_channels * t

        # Pointwise expansion to the hidden width.
        expand = [
            nn.Conv2d(in_channels, hidden, 1),
            nn.BatchNorm2d(hidden),
            nn.ReLU6(inplace=True),
        ]
        # Depthwise 3x3 (groups == channels); this is where the stride is applied.
        depthwise = [
            nn.Conv2d(hidden, hidden, 3, stride=stride, padding=1, groups=hidden),
            nn.BatchNorm2d(hidden),
            nn.ReLU6(inplace=True),
        ]
        # Pointwise projection with no activation afterwards.
        project = [
            nn.Conv2d(hidden, out_channels, 1),
            nn.BatchNorm2d(out_channels),
        ]

        # Layer order (and therefore the state_dict key indices) must stay fixed.
        self.residual = nn.Sequential(*expand, *depthwise, *project)

        self.stride = stride
        self.in_channels = in_channels
        self.out_channels = out_channels

    def forward(self, x):
        """Apply the block; add the input back when shapes are unchanged."""
        out = self.residual(x)

        # The skip connection only applies when neither resolution nor width changes.
        if self.stride != 1 or self.in_channels != self.out_channels:
            return out

        out += x
        return out

class MobileNetV2(nn.Module):
    """MobileNetV2-style classifier assembled from ``LinearBottleNeck`` blocks.

    Layout: a stem (``pre``), seven bottleneck stages, a 1x1 expansion to 1280
    channels (``conv1``), global average pooling, and a 1x1 convolution
    (``conv2``) that acts as the classifier head.

    Args:
        class_num: number of logits produced per sample (default 20).
    """

    def __init__(self, class_num=20):
        super().__init__()

        # NOTE(review): a 1x1 convolution with padding=1 only adds a border around
        # the input (each spatial side grows by 2). Kept unchanged because altering
        # it would change the network -- confirm this is intended.
        self.pre = nn.Sequential(
            nn.Conv2d(3, 32, 1, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU6(inplace=True)
        )

        # Arguments of _make_stage: (repeat, in_channels, out_channels, stride, t).
        self.stage1 = LinearBottleNeck(32, 16, 1, 1)
        self.stage2 = self._make_stage(2, 16, 24, 2, 6)
        self.stage3 = self._make_stage(3, 24, 32, 2, 6)
        self.stage4 = self._make_stage(4, 32, 64, 2, 6)
        self.stage5 = self._make_stage(3, 64, 96, 1, 6)
        self.stage6 = self._make_stage(3, 96, 160, 1, 6)
        self.stage7 = LinearBottleNeck(160, 320, 1, 6)

        self.conv1 = nn.Sequential(
            nn.Conv2d(320, 1280, 1),
            nn.BatchNorm2d(1280),
            nn.ReLU6(inplace=True)
        )

        # 1x1 convolution applied after global pooling: one output channel per class.
        self.conv2 = nn.Conv2d(1280, class_num, 1)

    def forward(self, x):
        """Return logits of shape ``(batch, class_num)`` for an image batch ``x``."""
        x = self.pre(x)
        x = self.stage1(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)
        x = self.stage5(x)
        x = self.stage6(x)
        x = self.stage7(x)
        x = self.conv1(x)
        # Collapse each feature map to 1x1 so the head sees one vector per sample.
        x = F.adaptive_avg_pool2d(x, 1)
        x = self.conv2(x)
        x = x.view(x.size(0), -1)

        return x

    def _make_stage(self, repeat, in_channels, out_channels, stride, t):
        """Stack ``repeat`` bottleneck blocks into one ``nn.Sequential``.

        The first block maps ``in_channels`` to ``out_channels`` and carries
        ``stride``; the remaining ``repeat - 1`` blocks keep ``out_channels``
        and use stride 1.

        Raises:
            ValueError: if ``repeat`` is smaller than 1. The previous
                ``while repeat - 1`` loop never terminated for such values.
        """
        if repeat < 1:
            raise ValueError(f"repeat must be >= 1, got {repeat}")

        layers = [LinearBottleNeck(in_channels, out_channels, stride, t)]
        for _ in range(repeat - 1):
            layers.append(LinearBottleNeck(out_channels, out_channels, 1, t))

        return nn.Sequential(*layers)

def mobilenetv2():
    """Build a freshly initialised ``MobileNetV2`` with its default arguments."""
    network = MobileNetV2()
    return network

In [5]:
def test_inference(model, test):
    """ Returns the test accuracy and loss.
    """
    tensor_x = torch.Tensor(test[0]).to(device)
    tensor_y = torch.Tensor(test[1]).to(device)
    test_dataset = TensorDataset(tensor_x, tensor_y)

    model.eval()
    loss, total, correct = 0.0, 0.0, 0.0

    criterion = nn.CrossEntropyLoss()
    testloader = DataLoader(test_dataset, batch_size=128,
                            shuffle=True)

    for batch_idx, (images, labels) in enumerate(testloader):
        with torch.no_grad():
        # Inference
            outputs = model(images)
            batch_loss = criterion(outputs, labels.long())
            loss += batch_loss.item() * labels.size(0)

        # Prediction
            _, pred_labels = torch.max(outputs, 1)
            pred_labels = pred_labels.view(-1)
            correct += torch.sum(torch.eq(pred_labels, labels)).item()
            total += len(labels)
    #print(correct,"/",total)
    accuracy = correct/total
    return accuracy, loss

In [6]:
def sparse2coarse(targets):
    """Map CIFAR-100 fine labels (0..99) to their coarse superclass ids (0..19).

    Usage:
        trainset = torchvision.datasets.CIFAR100(path)
        trainset.targets = sparse2coarse(trainset.targets)

    ``targets`` may be anything NumPy accepts as an integer index (int, list,
    array); the result is the lookup table indexed by it.
    """
    # One row per ten consecutive fine labels; entry k is the coarse id of label k.
    rows = (
        (4, 1, 14, 8, 0, 6, 7, 7, 18, 3),
        (3, 14, 9, 18, 7, 11, 3, 9, 7, 11),
        (6, 11, 5, 10, 7, 6, 13, 15, 3, 15),
        (0, 11, 1, 10, 12, 14, 16, 9, 11, 5),
        (5, 19, 8, 8, 15, 13, 14, 17, 18, 10),
        (16, 4, 17, 4, 2, 0, 17, 4, 18, 17),
        (10, 3, 2, 12, 12, 16, 12, 1, 9, 19),
        (2, 10, 0, 1, 16, 12, 9, 13, 15, 13),
        (16, 19, 2, 4, 6, 19, 5, 5, 8, 19),
        (18, 1, 2, 15, 6, 0, 17, 8, 14, 13),
    )
    fine_to_coarse = np.array(rows).reshape(-1)
    return fine_to_coarse[targets]

In [7]:
def CIFAR100():
    """Load CIFAR-100 (train followed by test) as coarse-labelled NumPy arrays.

    Both splits are downloaded into ``data-CIFAR`` when missing, every image is
    converted with ``transforms.ToTensor()`` and then to NumPy, and the fine
    labels are mapped through ``sparse2coarse``.

    Returns:
        ``[images, labels]`` where the train samples come first, then the test
        samples.
    """
    splits = []
    for is_train in (True, False):
        splits.append(torchvision.datasets.CIFAR100(root='data-CIFAR',
                                                    train=is_train,
                                                    transform=transforms.ToTensor(),
                                                    download=True))

    images, fine_labels = [], []
    for split in splits:
        for img, fine in split:
            images.append(img.numpy())
            fine_labels.append(fine)

    images = np.array(images)
    coarse_labels = np.array(sparse2coarse(fine_labels))
    return [images, coarse_labels]


In [8]:
def get_prob(non_iid, client_num, class_num = 20):
    """Sample one class-probability vector per client from a Dirichlet prior.

    Args:
        non_iid: concentration value repeated for every class.
        client_num: number of rows (clients) to draw.
        class_num: number of classes, i.e. the length of each row.

    Returns:
        Array of shape ``(client_num, class_num)``.
    """
    concentration = np.repeat(non_iid, class_num)
    return np.random.dirichlet(concentration, client_num)

In [9]:
def create_data(prob, size_per_client, dataset, N=20):
    """Split a labelled data set into per-client train/test sets.

    Client ``m`` receives ``int(prob[m][n] * size_per_client)`` samples of
    class ``n``. The first ``int(prob[m][n] * size_per_client * 0.8)`` of them
    go to its training set and the rest to its test set. Samples are drawn
    without replacement, so no sample is handed out twice.

    Args:
        prob: array of shape ``(clients, >= N)`` with per-client class shares.
        size_per_client: nominal number of samples per client.
        dataset: pair ``(data, label)`` of NumPy arrays indexed by sample.
        N: number of classes (labels ``0 .. N-1``); expected to be at least 1.

    Returns:
        ``(clients, test)``: two lists with one ``(images, labels)`` tuple per
        client. Empty clients keep the per-sample shape of ``data``.

    Raises:
        ValueError: from ``np.random.choice`` when a class has fewer samples
            than the clients request in total.
    """
    data, label = dataset
    prob = np.asarray(prob)

    # Integer sample counts per (client, class); astype(int) truncates like int().
    total_counts = (prob[:, :N] * size_per_client).astype(int)
    train_counts = (prob[:, :N] * size_per_client * 0.8).astype(int)

    # Draw exactly as many samples per class as the slices below consume. Sizing
    # the draw from a float sum could round below the integer total and silently
    # shorten the last client's slice.
    needed_per_class = total_counts.sum(axis=0)

    all_class_set = []
    for cls in range(N):
        mask = label == cls
        sub_data, sub_label = data[mask], label[mask]
        picked = np.random.choice(len(sub_data), size=int(needed_per_class[cls]), replace=False)
        all_class_set.append((sub_data[picked], sub_label[picked]))

    # cursor[cls] is the first sample of class `cls` not yet handed to a client.
    cursor = [0] * N
    clients, test = [], []

    for m in range(prob.shape[0]):
        train_x, train_y, test_x, test_y = [], [], [], []

        for cls in range(N):
            cls_data, cls_label = all_class_set[cls]
            start = cursor[cls]
            split = start + int(train_counts[m, cls])
            stop = start + int(total_counts[m, cls])

            train_x.append(cls_data[start:split])
            train_y.append(cls_label[start:split])
            test_x.append(cls_data[split:stop])
            test_y.append(cls_label[split:stop])

            cursor[cls] = stop

        # Concatenating slices keeps dtype and per-sample shape even when empty.
        clients.append((np.concatenate(train_x), np.concatenate(train_y)))
        test.append((np.concatenate(test_x), np.concatenate(test_y)))

    return clients, test

In [10]:
def comb_client_test_func(client_test_data):
    """Concatenate the ``(images, labels)`` sets of all clients into one set.

    Args:
        client_test_data: sequence with one ``(images, labels)`` pair per client.

    Returns:
        ``[images, labels]`` as NumPy arrays, clients appended in order.

    The loop runs over the argument itself instead of the notebook-level
    ``client_num``, so the result no longer depends on hidden global state or on
    that global matching the list length.
    """
    comb_client_test_image = []
    comb_client_test_label = []
    for client_set in client_test_data:
        comb_client_test_image.extend(list(client_set[0]))
        comb_client_test_label.extend(list(client_set[1]))
    return [np.array(comb_client_test_image), np.array(comb_client_test_label)]

In [11]:
def select_subset(whole_set, percentage):
    """Draw a class-stratified random fraction of a labelled data set.

    Args:
        whole_set: ``[samples, labels]`` as NumPy arrays of equal length.
        percentage: fraction in ``[0, 1]`` of each class to keep; the per-class
            count is ``int(class_size * percentage)``.

    Returns:
        ``[samples, labels]`` holding the selected items in shuffled order.

    Raises:
        ValueError: if the arrays differ in length or ``percentage`` is outside
            ``[0, 1]``.
    """
    samples, targets = whole_set[0], whole_set[1]

    if len(samples) != len(targets):
        raise ValueError("Both arrays should have the same length.")
    if not (0 <= percentage <= 1):
        raise ValueError("Percentage must be between 0 and 1.")

    picked_samples, picked_targets = [], []
    for cls in np.unique(targets):
        members = np.nonzero(targets == cls)[0]
        keep = int(len(members) * percentage)

        # Sample without replacement so no item is selected twice.
        chosen = np.random.choice(members, keep, replace=False)
        picked_samples.extend(samples[chosen])
        picked_targets.extend(targets[chosen])

    picked_samples = np.array(picked_samples)
    picked_targets = np.array(picked_targets)

    # The loop above leaves items grouped by class; apply one shared permutation.
    order = np.random.permutation(len(picked_samples))
    return [picked_samples[order], picked_targets[order]]

In [12]:
def update_weights(model_weight, dataset, learning_rate, local_epoch):
    """Train a fresh ``mobilenetv2`` from ``model_weight`` on ``dataset`` with SGD.

    Args:
        model_weight: state dict loaded into the new model before training.
        dataset: pair ``(images, labels)`` convertible with ``torch.Tensor``.
        learning_rate: SGD learning rate (momentum 0.9, weight decay 1e-4).
        local_epoch: number of passes over ``dataset``.

    Returns:
        ``(state_dict, mean_epoch_loss, first_iter_gradient)``. The last item is
        a dict recorded after the very first optimisation step. It holds a clone
        of every parameter's ``.grad`` plus, for each ``nn.BatchNorm2d`` module,
        clones of ``running_mean`` and ``running_var``. It stays ``None`` when no
        batch is processed.

    Uses the notebook-level ``device`` and ``mobilenetv2``.
    """
    model = mobilenetv2().to(device)
    model.load_state_dict(model_weight)
    model.train()

    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=0.0001)
    criterion = nn.CrossEntropyLoss()

    features = torch.Tensor(dataset[0]).to(device)
    targets = torch.Tensor(dataset[1]).to(device)
    data_loader = DataLoader(TensorDataset(features, targets), batch_size=128, shuffle=True)

    first_iter_gradient = None  # filled in after the first optimisation step
    epoch_loss = []

    for epoch in range(local_epoch):
        batch_loss = []
        for step, (images, labels) in enumerate(data_loader):
            model.zero_grad()
            loss = criterion(model(images), labels.long())
            loss.backward()
            optimizer.step()
            # NOTE(review): the criterion uses its default reduction, so `loss` is
            # presumably already a per-batch mean; dividing by the batch size
            # again rescales the reported value -- confirm this is intended.
            batch_loss.append(loss.item()/images.shape[0])

            if epoch != 0 or step != 0:
                continue

            # Keep the gradient of the first iteration only.
            first_iter_gradient = {
                name: param.grad.clone() for name, param in model.named_parameters()
            }
            # Also keep the running mean / running variance of the BatchNorm layers.
            for name, module in model.named_modules():
                if isinstance(module, nn.BatchNorm2d):
                    first_iter_gradient[name + '.running_mean'] = module.running_mean.clone()
                    first_iter_gradient[name + '.running_var'] = module.running_var.clone()

        epoch_loss.append(sum(batch_loss)/len(batch_loss))

    mean_loss = sum(epoch_loss) / len(epoch_loss)
    return model.state_dict(), mean_loss, first_iter_gradient

In [13]:
def weight_differences(n_w, p_w, lr):
    """Return ``(p_w - n_w) * lr`` entry by entry over the keys of ``n_w``.

    Entries whose key contains ``'num_batches_tracked'`` are not differenced;
    they are carried over as deep copies of the values in ``n_w``. Neither
    input mapping is modified.
    """
    w_diff = copy.deepcopy(n_w)
    for key, new_value in n_w.items():
        if 'num_batches_tracked' not in key:
            w_diff[key] = (p_w[key] - new_value) * lr
    return w_diff

In [14]:
def update_weights_correction(model_weight, dataset, learning_rate, local_epoch, c_i, c_s):
    """Local SGD training with a correction applied after every epoch.

    After each pass over ``dataset`` the model entries are shifted by
    ``-(c_s - c_i) * learning_rate`` for every key of ``c_i``, via
    ``weight_differences``, and loaded back into the model.

    Args:
        model_weight: state dict the fresh model starts from.
        dataset: pair ``(images, labels)`` convertible with ``torch.Tensor``.
        learning_rate: SGD learning rate; also scales the correction.
        local_epoch: number of passes over ``dataset``.
        c_i: this client's reference dict; its keys are the ones corrected.
        c_s: the server's reference dict; must contain the keys of ``c_i``.

    Returns:
        ``(state_dict, mean_epoch_loss)``.

    Uses the notebook-level ``device`` and ``mobilenetv2``.
    """
    model = mobilenetv2().to(device)
    model.load_state_dict(model_weight)
    model.train()

    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=0.0001)
    criterion = nn.CrossEntropyLoss()

    features = torch.Tensor(dataset[0]).to(device)
    targets = torch.Tensor(dataset[1]).to(device)
    data_loader = DataLoader(TensorDataset(features, targets), batch_size=200, shuffle=True)

    epoch_loss = []
    for _ in range(local_epoch):
        batch_loss = []
        for images, labels in data_loader:
            model.zero_grad()
            loss = criterion(model(images), labels.long())
            loss.backward()
            optimizer.step()
            batch_loss.append(loss.sum().item()/images.shape[0])
        epoch_loss.append(sum(batch_loss)/len(batch_loss))

        # correction = (c_s - c_i) * learning_rate, over the keys of c_i
        correction = weight_differences(c_i, c_s, learning_rate)
        current_weight = model.state_dict()
        # corrected = current - correction, over the keys of `correction`
        corrected_weight = weight_differences(correction, current_weight, 1)
        # NOTE(review): `corrected_weight` holds only the keys of c_i. If c_i has no
        # `num_batches_tracked` entries this relies on load_state_dict accepting
        # their absence -- confirm.
        model.load_state_dict(corrected_weight)

    mean_loss = sum(epoch_loss) / len(epoch_loss)
    return model.state_dict(), mean_loss

In [15]:
def average_weights(w):
    """
    Element-wise mean of a list of weight dicts.

    Entries whose key contains 'num_batches_tracked' are not averaged; the
    value from the first dict is kept. The input dicts are left untouched.
    """
    w_avg = copy.deepcopy(w[0])
    count = len(w)
    for key in w_avg.keys():
        if 'num_batches_tracked' in key:
            continue
        # Accumulate into the copy, then scale by the number of contributors.
        for other in w[1:]:
            w_avg[key] += other[key]
        w_avg[key] = torch.div(w_avg[key], count)
    return w_avg

In [16]:
def server_only(initial_w, global_round, gamma, E):
    """Baseline: train on server-side data only, with no clients involved.

    Each round draws a new server subset with ``select_subset``, trains for
    ``E`` epochs at learning rate ``gamma`` via ``update_weights``, and records
    the mean accuracy over all client test splits.

    Returns:
        ``(test_acc, train_loss)``: one entry per round in each list.

    Reads the notebook globals ``comb_client_data``, ``server_percentage``,
    ``client_test_data`` and ``device``.
    """
    test_model = mobilenetv2().to(device)
    train_w = copy.deepcopy(initial_w)
    test_acc, train_loss = [], []

    for _ in tqdm(range(global_round)):
        # Server side local training
        server_data = select_subset(comb_client_data, server_percentage)
        train_w, round_loss, _ = update_weights(train_w, server_data, gamma, E)
        train_loss.append(round_loss)

        # Test accuracy, averaged over the client test splits
        test_model.load_state_dict(train_w)
        split_acc = [test_inference(test_model, split)[0] for split in client_test_data]
        test_acc.append(sum(split_acc) / len(client_test_data))

    return test_acc, train_loss

In [17]:
def fedavg(initial_w, global_round, eta, K, M):
    """Federated averaging over ``M`` randomly sampled clients per round.

    Each round, every sampled client trains from the current global weights
    for ``K`` epochs at learning rate ``eta``. The resulting weights are
    averaged with ``average_weights``, and the mean accuracy over all client
    test splits is recorded.

    Returns:
        ``(test_acc, train_loss)``: one entry per round; the loss is the mean
        of the sampled clients' reported losses.

    Reads the notebook globals ``client_num``, ``client_data``,
    ``client_test_data`` and ``device``.
    """
    test_model = mobilenetv2().to(device)
    train_w = copy.deepcopy(initial_w)
    test_acc, train_loss = [], []

    for _ in tqdm(range(global_round)):
        # Client side local training
        sampled_client = random.sample(range(client_num), M)
        client_results = [update_weights(train_w, client_data[cid], eta, K)
                          for cid in sampled_client]
        local_weights = [result[0] for result in client_results]
        local_loss = [result[1] for result in client_results]

        train_w = average_weights(local_weights)

        # Test accuracy, averaged over the client test splits
        test_model.load_state_dict(train_w)
        train_loss.append(sum(local_loss) / len(local_loss))
        split_acc = [test_inference(test_model, split)[0] for split in client_test_data]
        test_acc.append(sum(split_acc) / len(client_test_data))

    return test_acc, train_loss


In [18]:
def CLG_SGD(initial_w, global_round, eta, gamma, K, E, M):
    """Client averaging followed by server-side training in every round.

    Per round: ``M`` sampled clients train for ``K`` epochs at rate ``eta`` and
    their weights are averaged. The server then continues from that average
    for ``E`` epochs at rate ``gamma`` on a freshly drawn server subset.

    Returns:
        ``(test_acc, train_loss)``: one entry per round; the loss averages the
        client losses together with the server loss.

    Reads the notebook globals ``client_num``, ``client_data``,
    ``comb_client_data``, ``server_percentage``, ``client_test_data`` and
    ``device``.
    """
    test_model = mobilenetv2().to(device)
    train_w = copy.deepcopy(initial_w)
    test_acc, train_loss = [], []

    for _ in tqdm(range(global_round)):
        # Client side local training
        sampled_client = random.sample(range(client_num), M)
        client_results = [update_weights(train_w, client_data[cid], eta, K)
                          for cid in sampled_client]
        local_weights = [result[0] for result in client_results]
        local_loss = [result[1] for result in client_results]
        train_w = average_weights(local_weights)

        # Server side local training
        server_data = select_subset(comb_client_data, server_percentage)
        train_w, server_loss, _ = update_weights(train_w, server_data, gamma, E)
        local_loss.append(server_loss)

        # Test accuracy, averaged over the client test splits
        test_model.load_state_dict(train_w)
        train_loss.append(sum(local_loss) / len(local_loss))
        split_acc = [test_inference(test_model, split)[0] for split in client_test_data]
        test_acc.append(sum(split_acc) / len(client_test_data))

    return test_acc, train_loss

In [19]:
def Fed_C(initial_w, global_round, eta, gamma, K, E, M):
    """Hybrid training with the correction applied on the client side.

    Per round:
      1. Draw a server subset and take the server's first-iteration dict
         ``g_s`` from a one-epoch ``update_weights`` run.
      2. Sample ``M`` clients and take each client's ``g_i`` the same way.
      3. Train every sampled client with ``update_weights_correction``
         (``K`` epochs, rate ``eta``), passing ``g_i`` and ``g_s``.
      4. Average the client weights, then train on the same server subset for
         ``E`` epochs at rate ``gamma``.

    Returns:
        ``(test_acc, train_loss)``: one entry per round; the loss averages the
        client losses together with the server loss.

    Reads the notebook globals ``client_num``, ``client_data``,
    ``comb_client_data``, ``server_percentage``, ``client_test_data`` and
    ``device``.
    """
    test_model = mobilenetv2().to(device)
    train_w = copy.deepcopy(initial_w)
    test_acc, train_loss = [], []

    for _ in tqdm(range(global_round)):
        # Server gradient
        server_data = select_subset(comb_client_data, server_percentage)
        g_s = update_weights(train_w, server_data, gamma, 1)[2]

        # Client gradients
        sampled_client = random.sample(range(client_num), M)
        g_i_list = [update_weights(train_w, client_data[cid], eta, 1)[2]
                    for cid in sampled_client]

        # Client side local training
        local_weights, local_loss = [], []
        for cid, g_i in zip(sampled_client, g_i_list):
            client_w, client_loss = update_weights_correction(
                train_w, client_data[cid], eta, K, g_i, g_s)
            local_weights.append(client_w)
            local_loss.append(client_loss)
        train_w = average_weights(local_weights)

        # Server side local training
        train_w, server_loss, _ = update_weights(train_w, server_data, gamma, E)
        local_loss.append(server_loss)

        # Test accuracy, averaged over the client test splits
        test_model.load_state_dict(train_w)
        train_loss.append(sum(local_loss) / len(local_loss))
        split_acc = [test_inference(test_model, split)[0] for split in client_test_data]
        test_acc.append(sum(split_acc) / len(client_test_data))

    return test_acc, train_loss


In [20]:
def Fed_S(initial_w, global_round, eta, gamma, K, E, M):
    """Hybrid training with the correction applied at server aggregation.

    Per round:
      1. Draw a server subset and take the server's first-iteration dict
         ``g_s`` from a one-epoch ``update_weights`` run.
      2. Sample ``M`` clients and take each client's ``g_i`` the same way.
      3. Train every sampled client with plain ``update_weights`` (``K``
         epochs, rate ``eta``) and average the results.
      4. Subtract ``(g_s - mean(g_i)) * K * eta`` from the averaged weights.
      5. Train on the same server subset for ``E`` epochs at rate ``gamma``.

    Returns:
        ``(test_acc, train_loss)``: one entry per round; the loss averages the
        client losses together with the server loss.

    Reads the notebook globals ``client_num``, ``client_data``,
    ``comb_client_data``, ``server_percentage``, ``client_test_data`` and
    ``device``.
    """
    test_model = mobilenetv2().to(device)
    train_w = copy.deepcopy(initial_w)
    test_acc, train_loss = [], []

    for _ in tqdm(range(global_round)):
        # Server gradient
        server_data = select_subset(comb_client_data, server_percentage)
        g_s = update_weights(train_w, server_data, gamma, 1)[2]

        # Client gradients
        sampled_client = random.sample(range(client_num), M)
        g_i_list = [update_weights(train_w, client_data[cid], eta, 1)[2]
                    for cid in sampled_client]

        # Client side local training
        local_weights, local_loss = [], []
        for cid in sampled_client:
            client_w, client_loss, _ = update_weights(train_w, client_data[cid], eta, K)
            local_weights.append(client_w)
            local_loss.append(client_loss)
        train_w = average_weights(local_weights)

        # Server aggregation correction: train_w <- train_w - (g_s - mean g_i) * K * eta
        g_i_average = average_weights(g_i_list)
        correction_g = weight_differences(g_i_average, g_s, K*eta)
        train_w = weight_differences(correction_g, copy.deepcopy(train_w), 1)

        # Server side local training
        train_w, server_loss, _ = update_weights(train_w, server_data, gamma, E)
        local_loss.append(server_loss)

        # Test accuracy, averaged over the client test splits
        test_model.load_state_dict(train_w)
        train_loss.append(sum(local_loss) / len(local_loss))
        split_acc = [test_inference(test_model, split)[0] for split in client_test_data]
        test_acc.append(sum(split_acc) / len(client_test_data))

    return test_acc, train_loss