In [None]:
import sys
import os

from collections import Counter

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from torchvision.datasets import CIFAR10

from fedlab.utils.serialization import SerializationTool
from fedlab.utils.aggregator import Aggregators
from fedlab.utils.functional import evaluate

# Make the project root (the parent of the notebook's working directory)
# importable, so the local `data` package imported below can be resolved.
PROJECT_DIR = os.path.dirname(os.getcwd())
# Guard so re-running this cell in the same kernel does not keep appending
# duplicate entries to sys.path.
if PROJECT_DIR not in sys.path:
    sys.path.append(PROJECT_DIR)

from data.CIFAR10.partition import partition_example

## 前期准备

In [None]:
# Model architecture
class CNN_CIFAR10(nn.Module):
    """Small LeNet-style CNN from the PyTorch CIFAR-10 tutorial:
    https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html

    fc1 expects 16 * 5 * 5 features, which corresponds to 3x32x32 inputs
    (32 -conv5-> 28 -pool-> 14 -conv5-> 10 -pool-> 5). The output is 10 logits.
    """

    def __init__(self):
        super().__init__()
        # Registration order matters: it fixes the state_dict key order and
        # therefore the layout of any flattened (serialized) parameter vector.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x):
        # Two conv -> ReLU -> max-pool stages.
        features = self.pool(F.relu(self.conv1(x)))
        features = self.pool(F.relu(self.conv2(features)))
        # Flatten everything except the batch dimension.
        flat = features.flatten(start_dim=1)
        hidden = F.relu(self.fc2(F.relu(self.fc1(flat))))
        return self.fc3(hidden)

In [None]:
# General experiment settings
num_client = 50    # number of federated clients
num_cluster = 10   # NOTE(review): not referenced in the cells shown here
num_shards = 50    # NOTE(review): not referenced in the cells shown here
num_classes = 10   # number of CIFAR-10 label classes
batch_size = 64    # batch size for client training and for test evaluation
random_seed = 0    # seed for the NumPy and PyTorch global RNGs

# Seed before any model is constructed so that weight initialisation and the
# randomly chosen round to plot are reproducible across Restart & Run All.
# (Bit-exact GPU reproducibility would additionally need deterministic cuDNN.)
np.random.seed(random_seed)
torch.manual_seed(random_seed)

global_model_file_dir = os.path.join(PROJECT_DIR, "result", "models", "notebook", "grads_difference")
# exist_ok makes this idempotent without a separate exists() check.
os.makedirs(global_model_file_dir, exist_ok=True)
global_model_file_name = "global_model.pth"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("training device: ", device)

# Data loading: the partitioner supplies per-client data (its get_dataloader /
# get_dataset are used in later cells); the CIFAR-10 test split below is the
# shared held-out set used to evaluate the global model.
partitioner = partition_example()

testset = CIFAR10(
    root=partitioner.root,
    train=False,
    download=False,
    transform=partitioner.transform,
)
test_loader = DataLoader(
    dataset=testset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
)

# Per-client models / optimizers / losses, plus the global (server) model.
client_models = [CNN_CIFAR10() for _ in range(num_client)]
# Each optimizer is bound once to its model's parameter tensors. That binding
# stays valid because load_state_dict copies new values into those same tensors.
client_optimizers = [torch.optim.SGD(client_model.parameters(), lr=0.001) for client_model in client_models]
client_criteria = [nn.CrossEntropyLoss() for _ in range(num_client)]

global_model_path = os.path.join(global_model_file_dir, global_model_file_name)
if os.path.exists(global_model_path):
    # The checkpoint is a whole pickled nn.Module (not a state_dict), so full
    # unpickling is required: newer torch defaults to weights_only=True, which
    # rejects it. Only load checkpoints you produced yourself.
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    try:
        global_model = torch.load(global_model_path, map_location=device, weights_only=False)
    except TypeError:
        # Older torch versions do not accept the weights_only argument.
        global_model = torch.load(global_model_path, map_location=device)
else:
    global_model = CNN_CIFAR10()
global_criteria = nn.CrossEntropyLoss()

## 联邦模型训练

In [None]:
# Local (client-side) training
def train(client_id, dataloader, grads_compute=False, num_epochs=5):
    """Train one client's local model in place and optionally return its update.

    Before this is called, client_models[client_id] is expected to already
    hold the current global parameters (global_train reloads them at the start
    of every round), so training proceeds directly on the local model.

    Args:
        client_id (int): index into client_models / client_optimizers / client_criteria.
        dataloader: iterable of (inputs, labels) batches for this client.
        grads_compute (bool, optional): when True, return the parameter delta.
            Defaults to False.
        num_epochs (int, optional): number of local passes over `dataloader`.
            Defaults to 5 (the previously hard-coded value).

    Returns:
        torch.Tensor | None: if grads_compute is True, the flattened difference
        `local parameters after training - global_model parameters`, detached
        and on CPU. This is a model update (pseudo-gradient), not an autograd
        gradient. Otherwise None.
    """
    model = client_models[client_id]
    optimizer = client_optimizers[client_id]
    criterion = client_criteria[client_id]

    model.to(device)
    # Nothing in the loop switches to eval mode, so setting train mode once suffices.
    model.train()
    for _ in range(num_epochs):
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            loss = criterion(model(inputs), labels)
            loss.backward()
            optimizer.step()

    if not grads_compute:
        return None

    # global_model serves as the "before" snapshot. That is only valid while it is
    # unchanged since this client was synced from it, i.e. before aggregation.
    param_before_tensor = SerializationTool.serialize_model(global_model)
    param_after_tensor = SerializationTool.serialize_model(model)
    # Store detached CPU tensors so they can later be converted to NumPy
    # and do not pin device memory across many rounds.
    return (param_after_tensor - param_before_tensor).detach().cpu()

In [None]:
# Global (server-side) training loop
def global_train(num_rounds, freq_eval=20, freq_grads=50):
    """Run federated training for `num_rounds` rounds and record client updates.

    Each round:
      1. every client model is reset to the current global parameters;
      2. every client trains locally;
      3. the global model is replaced by Aggregators.fedavg_aggregate of all
         client models.

    Args:
        num_rounds (int): number of global rounds.
        freq_eval (int, optional): evaluate the global model every `freq_eval`
            rounds (and always after the last round). Defaults to 20.
        freq_grads (int, optional): record client updates in rounds where
            round_idx % freq_grads == 0. Defaults to 50.

    Returns:
        dict[int, list]: client_grads[client_id] is the list of that client's
        recorded updates, in round order.
    """
    client_grads = {client_id: [] for client_id in range(num_client)}

    for round_idx in range(num_rounds):
        # Sync every client with the current global parameters.
        global_state = global_model.state_dict()
        for client_model in client_models:
            client_model.load_state_dict(global_state)

        record_grads = round_idx % freq_grads == 0
        for client_id in range(num_client):
            train_loader = partitioner.get_dataloader(client_id, batch_size)
            grads = train(client_id, train_loader, grads_compute=record_grads)
            if record_grads:
                client_grads[client_id].append(grads)

        # Aggregate. NOTE(review): fedavg_aggregate is called without per-client
        # weights, which presumably averages all clients uniformly regardless of
        # local dataset size — confirm this is the intended FedAvg variant.
        client_parameters = [
            SerializationTool.serialize_model(client_model) for client_model in client_models
        ]
        SerializationTool.deserialize_model(
            global_model, Aggregators.fedavg_aggregate(client_parameters)
        )

        # Also evaluate after the final round, which `% freq_eval` alone would skip.
        is_last_round = round_idx == num_rounds - 1
        if round_idx % freq_eval == 0 or is_last_round:
            loss, acc = evaluate(global_model, global_criteria, test_loader)
            print(f"Round {round_idx} finished, loss: {loss}, acc: {acc}")

    return client_grads

In [None]:
# Launch global training: move the global model to the training device, then
# run 400 federated rounds and keep the recorded per-client updates.
global_model = global_model.to(device)
client_grads = global_train(num_rounds=400)

In [None]:
# Persist the trained global model as a whole pickled nn.Module (not just its
# state_dict), matching how the setup cell reloads it with torch.load.
torch.save(obj=global_model, f=os.path.join(global_model_file_dir, global_model_file_name))

## 计算梯度方向相似性

In [None]:
# Cosine similarity
def cosine_matrix(vectors):
    """Pairwise cosine-similarity matrix of a sequence of vectors.

    Args:
        vectors: sequence of n equally sized array-likes (lists, NumPy arrays
            or torch tensors on any device). Each is flattened before comparison.

    Returns:
        np.ndarray: (n, n) float64 matrix whose [i, j] entry is
        cos(vectors[i], vectors[j]). Any pair involving a zero-norm vector gets
        similarity 0 instead of 0/0 = NaN. An empty input gives a (0, 0) array.

    Raises:
        ValueError: if the flattened vectors do not all have the same length.
    """
    if len(vectors) == 0:
        return np.zeros((0, 0))

    rows = []
    for vec in vectors:
        if isinstance(vec, torch.Tensor):
            # np.array() cannot read CUDA tensors directly.
            vec = vec.detach().cpu().numpy()
        rows.append(np.asarray(vec, dtype=np.float64).ravel())
    matrix = np.stack(rows)  # (n, d)

    norms = np.linalg.norm(matrix, axis=1)
    # Zero rows stay zero after "normalisation", so their similarities become 0.
    safe_norms = np.where(norms > 0, norms, 1.0)
    unit_rows = matrix / safe_norms[:, None]
    # Gram matrix of unit rows == cosine similarities; symmetric by construction.
    return unit_rows @ unit_rows.T

# Heat map of a similarity matrix
def plot_similarity_matrix(similarity_matrix, title=None, vmin=None, vmax=None):
    """Show `similarity_matrix` as a heat map with a colour bar.

    Args:
        similarity_matrix: 2-D array-like (e.g. the output of cosine_matrix).
        title (str, optional): figure title; no title when None.
        vmin, vmax (float, optional): fixed colour-scale limits. Pass e.g.
            vmin=-1, vmax=1 to make cosine-similarity plots comparable across
            rounds. By default the scale follows the data range, as before.
    """
    # Draw on a fresh figure instead of whatever pyplot figure happens to be current.
    fig, ax = plt.subplots()
    image = ax.imshow(similarity_matrix, cmap='hot', interpolation='nearest', vmin=vmin, vmax=vmax)
    fig.colorbar(image, ax=ax)
    ax.set_xlabel("client index")
    ax.set_ylabel("client index")
    if title is not None:
        ax.set_title(title)
    plt.show()
    # Release the figure; callers may plot many matrices in a loop.
    plt.close(fig)

In [None]:
# One client-by-client cosine-similarity matrix per recorded round: entry k
# compares the updates all clients produced in the k-th recorded round.
cosine_similarity_matrixs = [
    cosine_matrix([client_grads[client_id][snapshot_idx] for client_id in range(num_client)])
    for snapshot_idx in range(len(client_grads[0]))
]

# Plot one recorded round chosen at random. NOTE: random_round indexes the
# recorded snapshots; with global_train's default freq_grads=50 the actual
# training round is random_round * 50.
random_round = np.random.randint(0, len(cosine_similarity_matrixs))
print(f"random round: {random_round}")
plot_similarity_matrix(cosine_similarity_matrixs[random_round])

In [None]:
# Plot the similarity matrix of every recorded round, in recording order.
for similarity_matrix in cosine_similarity_matrixs:
    plot_similarity_matrix(similarity_matrix)

## 比较数据标签分布的相似性

In [None]:
# Per-client label distribution: for each client, the fraction of its samples
# carrying each class label 0 .. num_classes-1.
client_labels_matrix = []
for client_id in range(num_client):
    client_labels = partitioner.get_dataset(client_id).y
    # The container type of `.y` is not visible here (list / NumPy array /
    # tensor). Normalising to int makes the Counter keys plain ints that match
    # range(num_classes); 0-d tensors would otherwise hash by identity.
    labels_counter = Counter(int(label) for label in client_labels)
    labels_total = sum(labels_counter.values())
    # A fixed class range keeps every vector the same length. Counter returns 0
    # for absent labels, and an empty client yields an all-zero vector instead
    # of dividing by zero.
    distribution = [
        labels_counter[label] / labels_total if labels_total else 0.0
        for label in range(num_classes)
    ]
    client_labels_matrix.append(distribution)

    print(f"client {client_id} labels distribution: {client_labels_matrix[client_id]}")

In [None]:
# Cosine similarity between the clients' label-distribution vectors, printed
# and then shown as a heat map.
client_labels_similarity_matrix = cosine_matrix(vectors=client_labels_matrix)
print(client_labels_similarity_matrix)

plot_similarity_matrix(similarity_matrix=client_labels_similarity_matrix)