In [1]:
import os
import random
from tqdm import tqdm
import numpy as np
import time
import logging
import sys
import copy
import gc
# draw metrics
from sklearn.metrics import roc_curve, auc, roc_auc_score
import matplotlib.pyplot as plt

import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.datasets as datasets
from torch.utils.data.dataset import Dataset
import torchvision.models as models
from torchvision.models import AlexNet_Weights
import torchmetrics

import torchvision.transforms as transforms
from torchvision.transforms import Compose 
from torch.utils.tensorboard import SummaryWriter
torch.backends.cudnn.benchmark=True
print(sys.executable)
print(sys.version)
print(torch.cuda.is_available())
print(f'Torchvision version: {torchvision.__version__}')

c:\Users\jizhi\.conda\envs\flv1\python.exe
3.8.18 (default, Sep 11 2023, 13:39:12) [MSC v.1916 64 bit (AMD64)]
True
Torchvision version: 0.16.0


In [2]:
def set_seed(seed):
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

seed_value = 42
set_seed(seed_value)

In [3]:
# set logger
# create logger
logger = logging.getLogger(__name__)
# set log level for all handlers to debug
logger.setLevel(logging.DEBUG)

# create console handler and set level to debug
# best for development or debugging
consoleHandler = logging.StreamHandler()
consoleHandler.setLevel(logging.DEBUG)

# create formatter
# formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
formatter = logging.Formatter('%(asctime)s-%(levelname)s: %(message)s')

# add formatter to ch
consoleHandler.setFormatter(formatter)

# add ch to logger
logger.addHandler(consoleHandler)

In [4]:
##### Hyperparameters for federated learning #########
epochs_ = 50
batch_size_ = 16

writer_ = SummaryWriter('runs\\sim-fusion-alexnet\\')

# CUDA_LAUNCH_BLOCKING=1

In [5]:
# prepare image
train_names_list = ['hospital','sim-room','warehouse']

val_names_list = ['hospital','sim-room','warehouse','overall']

train_data_folders_list = [ 'C:\\Users\\jizhi\\Desktop\\study\\FL\\federated_learning\\fl_data\\mine\\sim_train\\hospital\\train',
                            'C:\\Users\\jizhi\\Desktop\\study\\FL\\federated_learning\\fl_data\\mine\\sim_train\\sim-room\\train',
                            'C:\\Users\\jizhi\\Desktop\\study\\FL\\federated_learning\\fl_data\\mine\\sim_train\\warehouse\\train',
                    ]

val_data_folders_list = [ 'C:\\Users\\jizhi\\Desktop\\study\\FL\\federated_learning\\fl_data\\mine\\sim_train\\hospital\\test',
                            'C:\\Users\\jizhi\\Desktop\\study\\FL\\federated_learning\\fl_data\\mine\\sim_train\\sim-room\\test',
                            'C:\\Users\\jizhi\\Desktop\\study\\FL\\federated_learning\\fl_data\\mine\\sim_train\\warehouse\\test',
                            'C:\\Users\\jizhi\\Desktop\\study\\FL\\federated_learning\\fl_data\\mine\\sim_train\\overall\\test',
                    ]

transformed_train_datasets, transformed_val_datasets = [], []


################################################  Data Loaders  #####################################
# 在这个部分，为训练和验证的图像数据定义了一些预处理步骤，包括颜色抖动、重新调整大小、转换为张量和归一化。
# 这些预处理步骤通过 PyTorch 的 transforms 模块实现，并且它们将在训练过程中应用到每个图像上。
for train_dataset in train_data_folders_list:
    transformed_train_datasets.append(
        datasets.ImageFolder(
            train_dataset,
            transforms.Compose([
                transforms.ColorJitter(0.1, 0.1, 0.1, 0.1),
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ])
        )
    ) 

for val_dataset in val_data_folders_list:
    transformed_val_datasets.append(
        datasets.ImageFolder(
            val_dataset,
            transforms.Compose([
                transforms.ColorJitter(0.1, 0.1, 0.1, 0.1),
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ])
        )
    ) 

#########  Data Loaders  ##########

train_loader_list, val_loader_list = [], []

for train_d in transformed_train_datasets:
    train_loader_list.append(
        torch.utils.data.DataLoader(
                train_d,
                batch_size=batch_size_,
                shuffle=True,
                num_workers=4,
                drop_last = False,
            )
    )

for val_d in transformed_val_datasets:
    val_loader_list.append(
        torch.utils.data.DataLoader(
                val_d,
                batch_size=1,
                shuffle=True,
                num_workers=4,
                drop_last = False,
            )
    )

#########   Model Fusion Setup  #########

fusion_model_list = [(0,1), (0,2),(1,2),(0,1,2)]

fusion_model_pathes =[]
fusion_names_ = []
logger.info("The number of different fusion models: {}".format(len(fusion_model_list)))
for data in fusion_model_list:
    fusion_model_path = "C:\\Users\\jizhi\\Desktop\\study\\FL\\federated_learning\\fl_results\\saved_models\\sim\\fusion_models\\my_fusion_model_on"
    fusion_name = ''
    for d in data:
        fusion_model_path = fusion_model_path  + "_" + train_names_list[d]
        fusion_name = fusion_name + "_" + train_names_list[d]
    #fusion_model_path =  fusion_model_path + ".pth"

    logger.info("Fusion model names: {}".format(fusion_name))
    logger.info("Fusion model save path: {}".format(fusion_model_path ))

    fusion_model_pathes.append(fusion_model_path)
    fusion_names_.append(fusion_name)

2023-10-29 01:37:57,189-INFO: The number of different fusion models: 4
2023-10-29 01:37:57,190-INFO: Fusion model names: _hospital_sim-room
2023-10-29 01:37:57,191-INFO: Fusion model save path: C:\Users\jizhi\Desktop\study\FL\federated_learning\fl_results\saved_models\sim\fusion_models\my_fusion_model_on_hospital_sim-room
2023-10-29 01:37:57,191-INFO: Fusion model names: _hospital_warehouse
2023-10-29 01:37:57,192-INFO: Fusion model save path: C:\Users\jizhi\Desktop\study\FL\federated_learning\fl_results\saved_models\sim\fusion_models\my_fusion_model_on_hospital_warehouse
2023-10-29 01:37:57,192-INFO: Fusion model names: _sim-room_warehouse
2023-10-29 01:37:57,193-INFO: Fusion model save path: C:\Users\jizhi\Desktop\study\FL\federated_learning\fl_results\saved_models\sim\fusion_models\my_fusion_model_on_sim-room_warehouse
2023-10-29 01:37:57,194-INFO: Fusion model names: _hospital_sim-room_warehouse
2023-10-29 01:37:57,195-INFO: Fusion model save path: C:\Users\jizhi\Desktop\study\FL\f

In [6]:
def server_aggregate_weighted(global_model, client_models, client_lens):
    """
    This function has aggregation method 'wmean'
    wmean takes the weighted mean of the weights of models
    """
    total = sum(client_lens)
    n = len(client_models)
    global_dict = global_model.state_dict()

    #momentum = {kk: (print(f"Key: {kk}, Value: {vv}"), torch.zeros_like(vv))[1] for kk, vv in global_dict.items()}
    for k in global_dict.keys():
        #global_dict[k] = torch.stack([client_models[i].state_dict()[k].float()*(n*client_lens[i]/total) for i in range(n)], 0).mean(0)
        global_dict[k] = torch.sum(torch.stack([client_models[i].state_dict()[k].float() * (client_lens[i]/total) for i in range(n)], 0), dim=0)
    global_model.load_state_dict(global_dict)

def server_aggregate_simple(global_model, client_models):
    global_dict = global_model.state_dict()
    n = len(client_models)
    for k in global_dict.keys():
        global_dict[k] = torch.stack([client_models[i].state_dict()[k].float() for i in range(n)], 0).mean(0)
    global_model.load_state_dict(global_dict)


def server_aggregate_median(global_model, client_models):
    global_dict = global_model.state_dict()
    n = len(client_models)
    for k in global_dict.keys():
        global_dict[k] = torch.median(torch.stack([client_models[i].state_dict()[k].float() for i in range(n)], 0), 0).values
    global_model.load_state_dict(global_dict)


def server_aggregate_momentum(global_model, client_models, alpha=0.9):
    global_dict = global_model.state_dict()
    momentum = {k: torch.zeros_like(v) for k, v in global_dict.items()}
    for k in global_dict.keys():
        new_value = torch.stack([client_models[i].state_dict()[k].float() for i in range(len(client_models))], 0).mean(0)
        momentum[k] = alpha * momentum[k] + (1 - alpha) * new_value
        global_dict[k] += momentum[k]
    global_model.load_state_dict(global_dict)


def server_aggregate_personalized(global_model, client_models, alpha=0.1):
    global_dict = global_model.state_dict()
    for k, v in global_dict.items():
        client_updates = torch.stack([m.state_dict()[k] for m in client_models])
        global_dict[k] = alpha * v + (1 - alpha) * torch.mean(client_updates, dim=0)
    global_model.load_state_dict(global_dict)


def server_aggregate_dp(global_model, client_models, client_lens, epsilon=1.0):
    """
    This function has aggregation method 'Differential Privacy average'
    It adds Laplace noise for differential privacy.
    """
    total = sum(client_lens)
    n = len(client_models)
    global_dict = global_model.state_dict()
    
    for k in global_dict.keys():
        # Calculate the weighted mean as before
        aggregated_weight = torch.stack(
            [client_models[i].state_dict()[k].float() * (client_lens[i]/total) for i in range(len(client_models))],
            0
        ).sum(0)
        
        # Add Laplace noise for differential privacy
        laplace_noise = torch.from_numpy(
            np.random.laplace(0, 1.0 / epsilon, aggregated_weight.shape)
        ).float()
        
        # Combine true answer and noise
        global_dict[k] = aggregated_weight + laplace_noise.to(aggregated_weight.device)

    global_model.load_state_dict(global_dict)


def server_aggregate_quantization(global_model, client_models, client_lens, quantization_level=8):
    """
    This function has aggregation method 'q-FedAvg' which quantizes the updates before averaging
    """
    total = sum(client_lens)
    global_dict = global_model.state_dict()
    n = len(client_models)

    for k in global_dict.keys():
        # Quantize each client's updates
        quantized_client_weights = [
            torch.round(client_models[i].state_dict()[k].float() * (2 ** quantization_level - 1)) / (2 ** quantization_level - 1)
            for i in range(len(client_models))
        ]

        # Perform weighted averaging on quantized updates
        global_dict[k] = torch.stack([quantized_client_weights[i] * (n*client_lens[i] / total) for i in range(len(client_models))], 0).mean(0)


    global_model.load_state_dict(global_dict)

In [7]:
def test(model_name, writer, current_epoch, model, test_loader):
    """
    This function test the global model on test 
    data and returns test loss and test accuracy 
    """
    model.eval()

    #test_loss = 0
    #correct = 0
    test_error_count1 = 0
    with torch.no_grad():
        for image, label in test_loader:#以batch为单位取值
            image, label = image.cuda(), label.cuda()
            outputs = model(image)

            test_error_count1 += float(torch.sum(torch.abs(label - outputs.argmax(1))))
        # 用于计算测试阶段的平均损失
        #test_loss /= len(test_loader.dataset)


        acc = 1.0 - float(test_error_count1) / float(len(test_loader.dataset))
        writer.add_scalar("Val-Acc/"+model_name, acc, current_epoch)

    return acc

def client_syn(client_model, global_model):
  '''
  This function synchronizes the client model with global model
  '''
  client_model.load_state_dict(global_model.state_dict())

In [8]:
aggregation_methods = [server_aggregate_weighted, server_aggregate_simple, server_aggregate_median, server_aggregate_momentum, server_aggregate_personalized, server_aggregate_dp, server_aggregate_quantization]

model = models.alexnet(weights=AlexNet_Weights.IMAGENET1K_V1)

for param in model.parameters():
    param.requires_grad = False

# 这行代码修改了AlexNet模型分类器的最后一层。它将最后一层的输出特征数从原始的1000（AlexNet在ImageNet上的分类数）改为了2。
model.classifier[6] = torch.nn.Linear(model.classifier[6].in_features, 2)

device = torch.device('cuda')

############################## client models ##############################


for agg_method in aggregation_methods:  # 添加这个循环,尝试不同的聚合方法
    global_models = [copy.deepcopy(model).to(device) for _ in range(len(fusion_model_list))]######复制了4个model
    for idx, fusion_pair in enumerate(fusion_model_list):

        client_models = [copy.deepcopy(model).to(device) for _ in range(len(fusion_pair))]

        optims_list = [optim.SGD(client_model.parameters(), lr=0.001, momentum=0.9) for client_model in client_models]

        client_lens = [len(train_loader_list[idx]) for idx in fusion_pair]
        '''
        It gives the number of batches in the training dataset for the corresponding client.
        In Federated Learning, the size of each client's dataset is often used as a weight when aggregating the models from the clients, 
        since clients with larger datasets generally contribute more to the learning process. 
        The client_lens list could be used for this purpose in the server-side model aggregation function.
        '''

        logger.info("-----------------------Fusion model names:{}--------------------------".format(fusion_names_[idx]))

        #这行代码创建了一个梯度缩放器，它通常用于半精度（float16）训练，这可以加速训练并减少GPU内存使用。如果你不打算使用半精度训练，这个部分可以忽略。
        scalers_list = [torch.cuda.amp.GradScaler() for _ in range(len(fusion_pair))]

        best_acc = 0.0
        for i in range(epochs_):
            
            for it in range(len(fusion_pair)):
                client_syn(client_models[it], global_models[idx])  
                #print(f'------ Train Epoch: {i}, Client: {train_names_list[fusion_pair[it]]},  ------')
                for images, labels in iter(train_loader_list[fusion_pair[it]]):

                    images = images.to(device)
                    labels = labels.to(device)
                    optims_list[it].zero_grad()
                    with torch.cuda.amp.autocast_mode.autocast():
                        outputs = client_models[it](images)
                        loss = F.cross_entropy(outputs, labels)
                    scaler = scalers_list[it]
                    scaler.scale(loss).backward()
                    scaler.step(optims_list[it])
                    scaler.update()

            if agg_method == server_aggregate_weighted or agg_method == server_aggregate_quantization or agg_method == server_aggregate_dp:
                agg_method(global_models[idx], client_models, client_lens)
            else:
                agg_method(global_models[idx], client_models)

            val_inx  = 3 #用overall数据集来决定最后存储的模型
            val_data = val_loader_list[val_inx]

            acc = test("Fusion_Models_:{}-On-Val-Data-{}".format(fusion_names_[idx], val_names_list[val_inx]), writer_, i, global_models[idx],  val_data)
            logger.info("Fusion Model:{}, Agg_method:{}, Epoch:{}, Val_Data:{}, Accuracy:{}.".format(fusion_names_[idx], agg_method.__name__, val_names_list[val_inx], i, acc))

            if acc > best_acc and acc < 1.0:
                best_acc = acc
                save_path = f"{fusion_model_pathes[idx]}_{agg_method.__name__}.pth"
                torch.save(global_models[idx].state_dict(), save_path)
        logger.info("Final Fusion Model:{}, Agg_method:{}, Best Accuracy:{}.".format(fusion_names_[idx], agg_method.__name__,best_acc))
        
        # After you're done with the client models and the optimizer, you can free up the memory
        for client_model, optim_ in zip(client_models, optims_list):
            del client_model
            del optim_
        torch.cuda.empty_cache()  # Clear GPU cache
    
    logger.info("---------------Agg_method:{}:DONE!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!---------------.".format(agg_method.__name__))
    # After you're done with the global models, you can free up the memory
    for global_model in global_models:
        del global_model
    torch.cuda.empty_cache()  # Clear GPU cache
    # Manually run Python's garbage collection
    gc.collect()

logger.info("DONE!!!!!!!!")

Downloading: "https://download.pytorch.org/models/alexnet-owt-7be5be79.pth" to C:\Users\jizhi/.cache\torch\hub\checkpoints\alexnet-owt-7be5be79.pth
100%|██████████| 233M/233M [00:21<00:00, 11.4MB/s] 
2023-10-29 01:38:24,565-INFO: -----------------------Fusion model names:_hospital_sim-room--------------------------
2023-10-29 01:39:10,607-INFO: Fusion Model:_hospital_sim-room, Agg_method:server_aggregate_weighted, Epoch:overall, Val_Data:0, Accuracy:0.9230769230769231.
2023-10-29 01:39:55,871-INFO: Fusion Model:_hospital_sim-room, Agg_method:server_aggregate_weighted, Epoch:overall, Val_Data:1, Accuracy:0.9615384615384616.
2023-10-29 01:40:41,647-INFO: Fusion Model:_hospital_sim-room, Agg_method:server_aggregate_weighted, Epoch:overall, Val_Data:2, Accuracy:0.9615384615384616.
2023-10-29 01:41:25,699-INFO: Fusion Model:_hospital_sim-room, Agg_method:server_aggregate_weighted, Epoch:overall, Val_Data:3, Accuracy:0.9735576923076923.
2023-10-29 01:42:12,187-INFO: Fusion Model:_hospital_si