In [1]:
import copy
import math
import warnings
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
from datetime import datetime
from torch.optim import lr_scheduler
from sklearn.metrics import r2_score,mean_squared_error, mean_absolute_error
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
from pylab import mpl
import os
import torch
import pandas as pd
warnings.filterwarnings('ignore')

# 设置matplotlib的配置
# mpl.rcParams['font.sans-serif'] = ['SimHei']  # 指定默认字体为黑体
mpl.rcParams['axes.unicode_minus'] = False  # 解决保存图像是负号'-'显示为方块的问题

In [2]:
USE_MULTI_GPU = True
# 设置默认的CUDA设备
torch.cuda.set_device(0)
# 初始化CUDA环境
torch.cuda.init()
if USE_MULTI_GPU and torch.cuda.device_count() > 1:
    MULTI_GPU = True
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3,4,5"  # 设置所有六张显卡的编号
    device_ids = ['0','1','2','3','4','5',] # 设置所有六张显卡的编号
else:
    MULTI_GPU = False
    device_ids = ['0']
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(MULTI_GPU)
deviceCount = torch.cuda.device_count()
torch.cuda.set_device(device)
print(deviceCount)
print(device)

False
1
cuda:0


In [3]:
# 创建一个字典，其中包含所有客户端的训练和测试指标列表
loss_acc_r2_mse_mae_metrics = {
    'client_0': {
        'train': {
            'acc_size': [],'r2_size': [],'mse_size': [],'mae_size': [],'loss_size': []
        },
        'test': {
            'acc_size': [],'r2_size': [],'mse_size': [],'mae_size': [],'loss_size': []
        }
    },
    'client_1': {
        'train': {
            'acc_size': [],'r2_size': [],'mse_size': [],'mae_size': [],'loss_size': []
        },
        'test': {
            'acc_size': [],'r2_size': [],'mse_size': [],'mae_size': [],'loss_size': []
        }
    },
    'client_2': {
        'train': {
            'acc_size': [],'r2_size': [],'mse_size': [],'mae_size': [],'loss_size': []
        },
        'test': {
            'acc_size': [],'r2_size': [],'mse_size': [],'mae_size': [],'loss_size': []
        }
    },
    'server_model': {
        'test_volve': {
            'acc_size': [],'r2_size': [],'mse_size': [],'mae_size': [],'loss_size': []
        },
        'test_xj': {
            'acc_size': [],'r2_size': [],'mse_size': [],'mae_size': [],'loss_size': []
        },
        'test_bh': {
            'acc_size': [],'r2_size': [],'mse_size': [],'mae_size': [],'loss_size': []
        }
    }
}

de_t_p_metrics = {
    'client_0': {
        'train': {'depth': [],'true': [],'pre': []},
        'test': {'depth': [],'r2_size': [],'true': [],'pre': []}
    },
    'client_1': {
        'train': {'depth': [], 'true': [], 'pre': []},
        'test': {'depth': [], 'true': [], 'pre': []}
    },
    'client_2': {
        'train': {'depth': [], 'true': [], 'pre': []},
        'test': {'depth': [], 'true': [], 'pre': []}
    },
    'server_model': {
        'test_volve': {'depth': [], 'true': [], 'pre': []},
        'test_xj': {'depth': [], 'true': [], 'pre': []},
        'test_bh': {'depth': [], 'true': [], 'pre': []},
    },

}

In [4]:
volve_4 = './data/volve/volve_4.csv'
volve_5 = './data/volve/volve_5.csv'
volve_7 = './data/volve/volve_7.csv'
volve_9 = './data/volve/volve_9.csv'
volve_9A = './data/volve/volve_9A.csv'
volve_10 = './data/volve/volve_10.csv'
volve_12 = './data/volve/volve_12.csv'
volve_14 = './data/volve/volve_14.csv'
volve_15A = './data/volve/volve_15A.csv'
volve_4_5_7_9A_10 = './data/volve/volve_4_5_7_9A_10.csv'
volve_5_7_10_12 = './data/volve/volve_5_7_10_12.csv'
xj_3 =  './data/xj/well_3.csv'
xj_2 = './data/xj/well_2.csv'
xj_1 = './data/xj/well_1.csv'

bh_1 = './data/bh/bh_1.csv'
bh_2 = './data/bh/bh_2.csv'
bh_3 = './data/bh/bh_3.csv'
bh_4 = './data/bh/bh_4.csv'
bh_5 = './data/bh/bh_5.csv'
bh_6 = './data/bh/bh_6.csv'
bh_7 = './data/bh/bh_7.csv'
bh_8 = './data/bh/bh_8.csv'
bh_9 = './data/bh/bh_9.csv'
bh_10 = './data/bh/bh_10.csv'
bh_11 = './data/bh/bh_11.csv'
bh_12 = './data/bh/bh_12.csv'
bh_14 = './data/bh/bh_14.csv'
bh_15 = './data/bh/bh_15.csv'
bh_16 = './data/bh/bh_16csv'
bh_7_15 = './data/bh/bh_7_15.csv'

In [5]:
server_model_epoch = 100
clint_model_epoch = 1

model_pre_len = 50
model_seq_len = 300
model_tf_lr = 0.0002
model_batch = 128
model_feature_size=5
model_d_model=512
model_num_layers=1
model_dropout=0

In [6]:
class PositionalEncoding(nn.Module):

    def __init__(self, d_model, max_len=512):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)  # 64*512
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # 64*1
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))  # 256   model/2
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        # pe.requires_grad = False
        self.register_buffer('pe', pe)  # 64*1*512

    def forward(self, x):  # [seq,batch,d_model]
        return x + self.pe[:x.size(0), :]  # 64*64*512
class TransAm(nn.Module):
    def __init__(self, feature_size=model_feature_size, d_model=model_d_model, num_layers=model_num_layers, dropout=model_dropout):
        super(TransAm, self).__init__()
        self.feature_size = feature_size
        self.model_type = 'Transformer'
        self.embedding = nn.Linear(feature_size, d_model)
        self.dec_input_fc = nn.Linear(feature_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model)  # 50*512
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=8, dropout=dropout, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=num_layers)
        self.decoder_layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=8, dropout=dropout, batch_first=True)
        self.transformer_decoder = nn.TransformerDecoder(self.decoder_layer, num_layers=num_layers)
        self.linear = nn.Linear(d_model, 1)
        self.src_mask = None
        self.src_key_padding_mask = None

    def forward(self, src,tgt,tgt_mask):

        # if self.src_key_padding_mask is None:
        #     mask_key = src_padding  # [batch,seq]
        #     self.src_key_padding_mask = mask_key
        src_em = self.embedding(src)  # [seq,batch,d_model]
        src_em_pos = self.pos_encoder(src_em)  # [seq,batch,d_model]
        encoder_output = self.transformer_encoder(src_em_pos)

        tgt_em = self.embedding(tgt)
        tgt_em_pos = self.pos_encoder(tgt_em)

        decoder_output = self.transformer_decoder(tgt_em_pos, encoder_output, tgt_mask=tgt_mask)

        output = self.linear(decoder_output)
        output_squeeze = output.squeeze()

        self.tgt_mask = None
        return output_squeeze

In [7]:
def train(TModel, loader,optimizer):
    epoch_loss = 0
    criterion = nn.MSELoss()  # 占位符 索引为0.9

    for X, y in loader:
        # X--[batch,seq,feature_size]  y--[batch,seq,feature_size]   64 300 13  64 50 13
        X, y = X.to(device), y.to(device)

        optimizer.zero_grad()
        mask = (torch.triu(torch.ones(y.size(1), y.size(1))) == 1).transpose(0, 1)
        tgt_mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0)).to(device)

        output = TModel(X, y, tgt_mask)
        loss = criterion(output, y[:, :, -1])
        loss.backward()
        torch.nn.utils.clip_grad_norm_(TModel.parameters(), 0.10)
        optimizer.step()
        epoch_loss += loss.item()

    return epoch_loss

def test(TModel, tf_loader, y_max, y_min, de_max, de_min):
    epoch_loss = 0
    y_pre = []
    y_true = []
    y_depth = []
    criterion = nn.MSELoss()  # 占位符 索引为0.9
    for x, y in tf_loader:
        with torch.no_grad():
            label = y[:, :, -1].detach().view(1, len(y[:, :, -1]) * model_pre_len).squeeze()
            label = label * (y_max - y_min) + y_min
            label = label.numpy().tolist()
            y_true += label

            de = y[:, :, 0].detach().view(1, len(y[:, :, 0]) * model_pre_len).squeeze()
            de = de * (de_max - de_min) + de_min
            de = de.numpy().tolist()
            y_depth += de

            x, y = x.to(device), y.to(device)

            mask = (torch.triu(torch.ones(y.size(1), y.size(1))) == 1).transpose(0, 1)
            tgt_mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0)).to(device)

            output = TModel(x, y, tgt_mask)

            loss = criterion(output, y[:, :, -1])
            epoch_loss += loss.item()

            hat = output.cpu().detach().view(1, len(y[:, :, -1]) * model_pre_len).squeeze()
            hat = hat * (y_max - y_min) + y_min
            hat = hat.numpy().tolist()
            y_pre += hat

    label = np.array(y_true)
    predict = np.array(y_pre)
    dep = np.array(y_depth)

    seq_label = label.reshape(int(len(label) / model_pre_len), model_pre_len)
    seq_predict = predict.reshape(int(len(predict) / model_pre_len), model_pre_len)
    seq_depth = dep.reshape(int(len(dep) / model_pre_len), model_pre_len)

    true = np.concatenate((seq_label[:-1, 0], seq_label[-1, :]), axis=0)
    depth = np.concatenate((seq_depth[:-1, 0], seq_depth[-1, :]), axis=0)
    pre = averages(seq_predict)

    r2 = r2_score(true, pre)
    acc = 1 - (np.abs(pre - true) / (true + 1e-8)).mean()
    mse = mean_squared_error(true, pre)

    mae = mean_absolute_error(true, pre)

    return acc, r2, mse, mae, epoch_loss, true, pre, depth
def communication(server_model, models, client_weights):
    with torch.no_grad():
        for key in server_model.state_dict().keys():
            if 'bn' not in key: # 跳过包含批量归一化（Batch Normalization，bn）的参数，因为批量归一化的参数在联邦学习中通常不需要像其他参数一样进行权重聚合
                temp = torch.zeros_like(server_model.state_dict()[key], dtype=torch.float32) # 创建一个与当前参数相同形状和数据类型的零张量，用于存储加权聚合后的参数
                for client_idx in range(client_num):
                    temp += client_weights[client_idx] * models[client_idx].state_dict()[key] # 使用每个客户端的权重和对应模型的参数进行加权累加
                server_model.state_dict()[key].data.copy_(temp) # 将聚合后的参数更新到服务器模型的对应位置
                for client_idx in range(client_num):
                    models[client_idx].state_dict()[key].data.copy_(server_model.state_dict()[key]) # 将聚合后的参数同步回每个客户端的模型
    return server_model, models # 返回值：返回更新后的 server_model（服务器模型）和 models（客户端模型）


In [8]:
def data_load(path):

    data = pd.read_csv(path)
   # data = data.iloc[::interval, :]

    # data = data.clip(lower=0)  # 设置小于0的数都赋0
    # data = data.apply(lambda x: x.mask((x < x.quantile(0.25) - 1.5 * (x.quantile(0.75) - x.quantile(0.25))) |
    #                                      (x > x.quantile(0.75) + 1.5 * (
    #                                                  x.quantile(0.75) - x.quantile(0.25)))).ffill().bfill())
    #
    # data = data.sort_values(by='MD')
    # data=data.reset_index(drop=True)
    data = data.astype('float32')
    data.dropna(inplace=True)
    data = data.values

    data_ =torch.tensor(data[:len(data)])
    maxc, _ = data_.max(dim=0)
    minc, _ = data_.min(dim=0)
    y_max = maxc[-1]
    y_min = minc[-1]
    de_max = maxc[0]
    de_min = minc[0]
    data_ = (data_ - minc) / (maxc - minc)

    data_last_index = data_.shape[0] - model_seq_len

    data_X = []
    data_Y = []
    for i in range(0, data_last_index - model_pre_len+1):
        data_x = np.expand_dims(data_[i:i + model_seq_len], 0)  # [1,seq,feature_size]
        data_y = np.expand_dims(data_[i + model_seq_len:i + model_seq_len + model_pre_len], 0)  # [1,seq,out_size]
        data_X.append(data_x)
        data_Y.append(data_y)

    data_X=np.concatenate(data_X, axis=0)
    data_Y = np.concatenate(data_Y, axis=0)

    process_data = torch.from_numpy(data_X).type(torch.float32)
    process_label = torch.from_numpy(data_Y).type(torch.float32)

    data_feature_size = process_data.shape[-1]

    dataset_train = TensorDataset(process_data, process_label)

    data_dataloader = DataLoader(dataset_train, batch_size=model_batch, shuffle=False)
    return data_dataloader,y_max,y_min, de_max,de_min

def averages(matrix):  # 计算平均值
    matrix = np.array(matrix)
    row_count, col_count = matrix.shape
    max_diagonal = row_count + col_count - 1
    diagonals = np.zeros(max_diagonal)
    counts = np.zeros(max_diagonal, dtype=int)
    for i in range(row_count):
        for j in range(col_count):
            num = matrix[i, j]
            diagonal_index = i + j
            diagonals[diagonal_index] += num
            counts[diagonal_index] += 1
    averages = diagonals / counts
    return averages

def acc_loss_plot_one(train_data, type_, path):
    plt.figure(figsize=(10, 5))
    plt.plot(train_data, label='train_data', color='blue', linewidth=1)
    plt.xlabel('epoch', fontsize=18)
    plt.title(f'train_test_{type_}')
    path_ = f'{path}'

    plt.grid()
    plt.savefig(path_)

    plt.legend()
   # plt.show()

def acc_loss_plot_two(train_data, test_data, type_, path):
    plt.figure(figsize=(10, 5))
    plt.plot(train_data, label='train_data', color='blue', linewidth=1)
    plt.plot(test_data, label='test_data', color='red', linewidth=1)
    plt.xlabel('epoch', fontsize=18)
    plt.title(f'train_test_{type_}')
    path_ = f'{path}'

    plt.grid()
    plt.savefig(path_)

    plt.legend()
  #  plt.show()

def true_test_plot(depth, true_data, predicted_data, type_, path):
    plt.figure(figsize=(20, 6))
    plt.plot(depth, true_data, label='true_data', color='blue', linewidth=1)
    plt.plot(depth, predicted_data, label='test_data', color='green', linewidth=1)
    plt.ylabel("GRA", fontsize=18)
    plt.xlabel('depth', fontsize=18)
    path_ = f'{path}'
    plt.grid()
    plt.savefig(path_)
#    plt.show()

In [None]:
volve_train,volve_train_y_max,volve_train_y_min, volve_train_de_max,volve_train_de_min = data_load(volve_5_7_10_12)
volve_test,volve_test_y_max,volve_test_y_min, volve_test_de_max,volve_test_de_min =  data_load(volve_9A)

xj_train,xj_train_y_max,xj_train_y_min, xj_train_de_max,xj_train_de_min =  data_load(xj_3)
xj_test,xj_test_y_max,xj_test_y_min, xj_test_de_max,xj_test_de_min =  data_load(xj_1)

bh_train,bh_train_y_max,bh_train_y_min, bh_train_de_max,bh_train_de_min =  data_load(bh_7_15)
bh_test,bh_test_y_max,bh_test_y_min, bh_test_de_max,bh_test_de_min =  data_load(bh_10)

train_loaders = [volve_train, xj_train, bh_train]
test_loaders = [volve_test,xj_test,bh_test]

train_y_maxs = [volve_train_y_max,xj_train_y_max,bh_train_y_max]
train_y_mins = [volve_train_y_min,xj_train_y_min,bh_train_y_min]
train_de_maxs = [volve_train_de_max,xj_train_de_max,bh_train_de_max]
train_de_mins = [volve_train_de_min,xj_train_de_min,bh_train_de_min]

test_y_max = [volve_test_y_max,xj_test_y_max,bh_test_y_max]
test_y_min = [volve_test_y_min,xj_test_y_min,bh_test_y_min]
test_de_max = [volve_test_de_max,xj_test_de_max,bh_test_de_max]
test_de_min = [volve_test_de_min,xj_test_de_min,bh_test_de_min]

datasets= ['volve', 'xj', 'bh']
client_num = len(datasets) # 客户端的数量
client_weights = [1 / client_num for i in range(client_num)]

server_model=TransAm().to(device)
models = [copy.deepcopy(server_model).to(device) for idx in range(client_num)]
optimizers = [optim.Adam(params=models[idx].parameters(), lr=model_tf_lr, weight_decay=0.001) for idx in range(client_num)]

for epoch in range(server_model_epoch):
    print('-' * 40, 'server model epoch', epoch, '-' * 40)
    for wi in range(clint_model_epoch):
        print('    ', 'clint model epoch', wi, '    ')
        for client_idx in range(client_num):
            models[client_idx].train()
            train(models[client_idx], train_loaders[client_idx],optimizers[client_idx])

            models[client_idx].eval()
            train_acc,train_r2, train_mse,train_mae,train_loss, true_train, pre_train, train_depth = test(models[client_idx],train_loaders[client_idx],
                                                                                                           train_y_maxs[client_idx],train_y_mins[client_idx],
                                                                                                           train_de_maxs[client_idx],train_de_mins[client_idx])

            test_acc, test_r2, test_mse, test_mae, test_loss, true_test, pre_test, test_depth = test(models[client_idx],test_loaders[client_idx],
                                                                                                     test_y_max[client_idx],test_y_min[client_idx],
                                                                                                     test_de_max[client_idx],test_de_min[client_idx])
            # 获取当前时间
            now = datetime.now()
            print(datasets[client_idx])
            print('  train:loss =','{:.4f}'.format(train_loss), ' acc =', '{:.4f}'.format(train_acc), ' r2 =', '{:.4f}'.format(train_r2), 'time = ', now.strftime("%Y-%m-%d %H:%M:%S"))
            print('  test:loss =','{:.4f}'.format(test_loss), ' acc =', '{:.4f}'.format(test_acc), ' r2 =', '{:.4f}'.format(test_r2), 'time = ', now.strftime("%Y-%m-%d %H:%M:%S"))

            client_key = f'client_{client_idx}'  # 动态生成键名
            loss_acc_r2_mse_mae_metrics[client_key]['train']['acc_size'].append(train_acc)
            loss_acc_r2_mse_mae_metrics[client_key]['train']['r2_size'].append(train_r2)
            loss_acc_r2_mse_mae_metrics[client_key]['train']['mse_size'].append(train_mse)
            loss_acc_r2_mse_mae_metrics[client_key]['train']['mae_size'].append(train_mae)
            loss_acc_r2_mse_mae_metrics[client_key]['train']['loss_size'].append(train_loss)
            loss_acc_r2_mse_mae_metrics[client_key]['test']['acc_size'].append(test_acc)
            loss_acc_r2_mse_mae_metrics[client_key]['test']['r2_size'].append(test_r2)
            loss_acc_r2_mse_mae_metrics[client_key]['test']['mse_size'].append(test_mse)
            loss_acc_r2_mse_mae_metrics[client_key]['test']['mae_size'].append(test_mae)
            loss_acc_r2_mse_mae_metrics[client_key]['test']['loss_size'].append(test_loss)

            de_t_p_metrics[client_key]['train']['depth'] = train_depth
            de_t_p_metrics[client_key]['train']['true'] = true_train
            de_t_p_metrics[client_key]['train']['pre'] = pre_train
            de_t_p_metrics[client_key]['test']['depth'] = test_depth
            de_t_p_metrics[client_key]['test']['true'] = true_test
            de_t_p_metrics[client_key]['test']['pre'] = pre_test

    server_model, models = communication(server_model, models, client_weights)

    for client_idx in range(client_num):
        server_model.eval()
        test_acc, test_r2, test_mse, test_mae, test_loss, true_test, pre_test, test_depth = test(server_model,test_loaders[client_idx],
                                                                                                 test_y_max[client_idx],test_y_min[client_idx],
                                                                                                 test_de_max[client_idx],test_de_min[client_idx])
        client_key = f'test_{datasets[client_idx]}'  # 动态生成键名
        loss_acc_r2_mse_mae_metrics['server_model'][client_key]['acc_size'].append(test_acc)
        loss_acc_r2_mse_mae_metrics['server_model'][client_key]['r2_size'].append(test_r2)
        loss_acc_r2_mse_mae_metrics['server_model'][client_key]['mse_size'].append(test_mse)
        loss_acc_r2_mse_mae_metrics['server_model'][client_key]['mae_size'].append(test_mae)
        loss_acc_r2_mse_mae_metrics['server_model'][client_key]['loss_size'].append(test_loss)

        de_t_p_metrics['server_model'][client_key]['depth'] = test_depth
        de_t_p_metrics['server_model'][client_key]['true'] = true_test
        de_t_p_metrics['server_model'][client_key]['pre'] = pre_test

        # 获取当前时间
        now = datetime.now()

        print('    ', 'server model', '    ')
        print()
        print(datasets[client_idx],'test:loss =', '{:.4f}'.format(test_loss), ' acc =', '{:.4f}'.format(test_acc), ' r2 =','{:.4f}'.format(test_r2),
              ' mse =', '{:.4f}'.format(test_mse), ' mae =', '{:.4f}'.format(test_mae),'time = ', now.strftime("%Y-%m-%d %H:%M:%S"))

    server_volve_loss_acc_mse_mae_dict = {'test_loss': loss_acc_r2_mse_mae_metrics['server_model']['test_volve']['loss_size'],
                                    'test_acc': loss_acc_r2_mse_mae_metrics['server_model']['test_volve']['acc_size'],
                                    'test_r2': loss_acc_r2_mse_mae_metrics['server_model']['test_volve']['r2_size'],
                                    'test_mse': loss_acc_r2_mse_mae_metrics['server_model']['test_volve']['mse_size'],
                                    'test_mae': loss_acc_r2_mse_mae_metrics['server_model']['test_volve']['mae_size'] }
    server_volve_pre_ture_test_dict = {'depth': de_t_p_metrics['server_model']['test_volve']['depth'],
                                 'true': de_t_p_metrics['server_model']['test_volve']['true'],
                                 'pre': de_t_p_metrics['server_model']['test_volve']['pre']}
    server_xj_loss_acc_mse_mae_dict = {'test_loss': loss_acc_r2_mse_mae_metrics['server_model']['test_xj']['loss_size'],
                                    'test_acc': loss_acc_r2_mse_mae_metrics['server_model']['test_xj']['acc_size'],
                                    'test_r2': loss_acc_r2_mse_mae_metrics['server_model']['test_xj']['r2_size'],
                                    'test_mse': loss_acc_r2_mse_mae_metrics['server_model']['test_xj']['mse_size'],
                                    'test_mae': loss_acc_r2_mse_mae_metrics['server_model']['test_xj']['mae_size'] }
    server_xj_pre_ture_test_dict = {'depth': de_t_p_metrics['server_model']['test_xj']['depth'],
                                       'true': de_t_p_metrics['server_model']['test_xj']['true'],
                                       'pre': de_t_p_metrics['server_model']['test_xj']['pre']}
    server_bh_loss_acc_mse_mae_dict = {'test_loss': loss_acc_r2_mse_mae_metrics['server_model']['test_bh']['loss_size'],
                                    'test_acc': loss_acc_r2_mse_mae_metrics['server_model']['test_bh']['acc_size'],
                                    'test_r2': loss_acc_r2_mse_mae_metrics['server_model']['test_bh']['r2_size'],
                                    'test_mse': loss_acc_r2_mse_mae_metrics['server_model']['test_bh']['mse_size'],
                                    'test_mae': loss_acc_r2_mse_mae_metrics['server_model']['test_bh']['mae_size'] }
    server_bh_pre_ture_test_dict = {'depth': de_t_p_metrics['server_model']['test_bh']['depth'],
                                       'true': de_t_p_metrics['server_model']['test_bh']['true'],
                                       'pre': de_t_p_metrics['server_model']['test_bh']['pre']}

    clint0_volve__loss_acc_mse_mae_dict = {'test_loss': loss_acc_r2_mse_mae_metrics['client_0']['test']['loss_size'],
                                'test_acc': loss_acc_r2_mse_mae_metrics['client_0']['test']['acc_size'],
                                'test_r2': loss_acc_r2_mse_mae_metrics['client_0']['test']['r2_size'],
                                'test_mse': loss_acc_r2_mse_mae_metrics['client_0']['test']['mse_size'],
                                'test_mae': loss_acc_r2_mse_mae_metrics['client_0']['test']['mae_size'],
                                'train_loss': loss_acc_r2_mse_mae_metrics['client_0']['train']['loss_size'],
                                'train_acc': loss_acc_r2_mse_mae_metrics['client_0']['train']['acc_size'],
                                'train_r2': loss_acc_r2_mse_mae_metrics['client_0']['train']['r2_size'],
                                'train_mse': loss_acc_r2_mse_mae_metrics['client_0']['train']['mse_size'],
                                'train_mae': loss_acc_r2_mse_mae_metrics['client_0']['train']['mae_size'],
                                }
    clint0_volve_pre_ture_test_dict = {'depth': de_t_p_metrics['client_0']['test']['depth'],
                                    'true': de_t_p_metrics['client_0']['test']['true'],
                                    'pre': de_t_p_metrics['client_0']['test']['pre'] }
    clint1_xj_loss_acc_mse_mae_dict = {'test_loss': loss_acc_r2_mse_mae_metrics['client_1']['test']['loss_size'],
                                'test_acc': loss_acc_r2_mse_mae_metrics['client_1']['test']['acc_size'],
                                'test_r2': loss_acc_r2_mse_mae_metrics['client_1']['test']['r2_size'],
                                'test_mse': loss_acc_r2_mse_mae_metrics['client_1']['test']['mse_size'],
                                'test_mae': loss_acc_r2_mse_mae_metrics['client_1']['test']['mae_size'],
                                'train_loss': loss_acc_r2_mse_mae_metrics['client_1']['train']['loss_size'],
                                'train_acc': loss_acc_r2_mse_mae_metrics['client_1']['train']['acc_size'],
                                'train_r2': loss_acc_r2_mse_mae_metrics['client_1']['train']['r2_size'],
                                'train_mse': loss_acc_r2_mse_mae_metrics['client_1']['train']['mse_size'],
                                'train_mae': loss_acc_r2_mse_mae_metrics['client_1']['train']['mae_size'],
                                }
    clint1_xj_pre_ture_test_dict = {'depth': de_t_p_metrics['client_1']['test']['depth'],
                                    'true': de_t_p_metrics['client_1']['test']['true'],
                                    'pre': de_t_p_metrics['client_1']['test']['pre'] }
    clint2_bh_loss_acc_mse_mae_dict = {'test_loss': loss_acc_r2_mse_mae_metrics['client_2']['test']['loss_size'],
                                'test_acc': loss_acc_r2_mse_mae_metrics['client_2']['test']['acc_size'],
                                'test_r2': loss_acc_r2_mse_mae_metrics['client_2']['test']['r2_size'],
                                'test_mse': loss_acc_r2_mse_mae_metrics['client_2']['test']['mse_size'],
                                'test_mae': loss_acc_r2_mse_mae_metrics['client_2']['test']['mae_size'],
                                'train_loss': loss_acc_r2_mse_mae_metrics['client_2']['train']['loss_size'],
                                'train_acc': loss_acc_r2_mse_mae_metrics['client_2']['train']['acc_size'],
                                'train_r2': loss_acc_r2_mse_mae_metrics['client_2']['train']['r2_size'],
                                'train_mse': loss_acc_r2_mse_mae_metrics['client_2']['train']['mse_size'],
                                'train_mae': loss_acc_r2_mse_mae_metrics['client_2']['train']['mae_size'],
                                }
    clint2_bh_pre_ture_test_dict = {'depth': de_t_p_metrics['client_2']['test']['depth'],
                                    'true': de_t_p_metrics['client_2']['test']['true'],
                                    'pre': de_t_p_metrics['client_2']['test']['pre'] }

    server_volve_loss_acc_mse_mae = pd.DataFrame(server_volve_loss_acc_mse_mae_dict)
    server_volve_pre_ture_test = pd.DataFrame(server_volve_pre_ture_test_dict)
    server_xj_loss_acc_mse_mae = pd.DataFrame(server_xj_loss_acc_mse_mae_dict)
    server_xj_pre_ture_test = pd.DataFrame(server_xj_pre_ture_test_dict)
    server_bh_loss_acc_mse_mae = pd.DataFrame(server_bh_loss_acc_mse_mae_dict)
    server_bh_pre_ture_test = pd.DataFrame(server_bh_pre_ture_test_dict)

    clint0_volve_loss_acc_mse_mae = pd.DataFrame(clint0_volve__loss_acc_mse_mae_dict)
    clint0_volve_pre_ture_test = pd.DataFrame(clint0_volve_pre_ture_test_dict)
    clint1_xj_loss_acc_mse_mae = pd.DataFrame(clint1_xj_loss_acc_mse_mae_dict)
    clint1_xj_pre_ture_test = pd.DataFrame(clint1_xj_pre_ture_test_dict)
    clint2_bh_loss_acc_mse_mae = pd.DataFrame(clint2_bh_loss_acc_mse_mae_dict)
    clint2_bh_pre_ture_test = pd.DataFrame(clint2_bh_pre_ture_test_dict)

    server_volve_loss_acc_mse_mae.to_csv('./out1/server/volve/server_volve_loss_acc_mse_mae.csv', sep=",", index=True)
    server_volve_pre_ture_test.to_csv('./out1/server/volve/server_volve_pre_ture_test.csv', sep=",", index=True)
    server_xj_loss_acc_mse_mae.to_csv('./out1/server/xj/server_xj_loss_acc_mse_mae.csv', sep=",", index=True)
    server_xj_pre_ture_test.to_csv('./out1/server/xj/server_xj_pre_ture_test.csv', sep=",", index=True)
    server_bh_loss_acc_mse_mae.to_csv('./out1/server/bh/server_bh_loss_acc_mse_mae.csv', sep=",", index=True)
    server_bh_pre_ture_test.to_csv('./out1/server/bh/server_bh_pre_ture_test.csv', sep=",", index=True)

    clint0_volve_loss_acc_mse_mae.to_csv('./out1/client/volve/clint0_volve_loss_acc_mse_mae.csv', sep=",", index=True)
    clint0_volve_pre_ture_test.to_csv('./out1/client/volve/clint0_volve_pre_ture_test.csv', sep=",", index=True)
    clint1_xj_loss_acc_mse_mae.to_csv('./out1/client/xj/clint1_xj_loss_acc_mse_mae.csv', sep=",", index=True)
    clint1_xj_pre_ture_test.to_csv('./out1/client/xj/clint1_xj_pre_ture_test.csv', sep=",", index=True)
    clint2_bh_loss_acc_mse_mae.to_csv('./out1/client/bh/clint2_bh_loss_acc_mse_mae.csv', sep=",", index=True)
    clint2_bh_pre_ture_test.to_csv('./out1/client/bh/clint2_bh_pre_ture_test.csv', sep=",", index=True)

    acc_loss_plot_one(server_volve_loss_acc_mse_mae['test_loss'], 'loss','./out1/server/volve/server_test_loss.png')
    acc_loss_plot_one(server_volve_loss_acc_mse_mae['test_r2'], 'r2', './out1/server/volve/server_test_r2.png')
    acc_loss_plot_one(server_xj_loss_acc_mse_mae['test_loss'], 'loss','./out1/server/xj/server_test_loss.png')
    acc_loss_plot_one(server_xj_loss_acc_mse_mae['test_r2'], 'r2', './out1/server/xj/server_test_r2.png')
    acc_loss_plot_one(server_bh_loss_acc_mse_mae['test_loss'], 'loss','./out1/server/bh/server_test_loss.png')
    acc_loss_plot_one(server_bh_loss_acc_mse_mae['test_r2'], 'r2', './out1/server/bh/server_test_r2.png')

    acc_loss_plot_two(clint0_volve_loss_acc_mse_mae['train_r2'], clint0_volve_loss_acc_mse_mae['test_r2'], 'r2','./out1/client/volve/clint0_volve_r2.png')
    acc_loss_plot_two(clint0_volve_loss_acc_mse_mae['train_loss'], clint0_volve_loss_acc_mse_mae['test_loss'], 'r2','./out1/client/volve/clint0_volve_loss.png')
    acc_loss_plot_two(clint1_xj_loss_acc_mse_mae['train_r2'], clint1_xj_loss_acc_mse_mae['test_r2'], 'r2','./out1/client/xj/clint1_xj_r2.png')
    acc_loss_plot_two(clint1_xj_loss_acc_mse_mae['train_loss'], clint1_xj_loss_acc_mse_mae['test_loss'], 'r2','./out1/client/xj/clint1_xj_loss.png')
    acc_loss_plot_two(clint2_bh_loss_acc_mse_mae['train_r2'], clint2_bh_loss_acc_mse_mae['test_r2'], 'r2','./out1/client/bh/clint2_bh_r2.png')
    acc_loss_plot_two(clint2_bh_loss_acc_mse_mae['train_loss'], clint2_bh_loss_acc_mse_mae['test_loss'], 'r2','./out1/client/bh/clint2_bh_loss.png')

    true_test_plot(server_volve_pre_ture_test['depth'], server_volve_pre_ture_test['true'], server_volve_pre_ture_test['pre'], 'test',
                   './out1/server/volve/server_volve_pre_ture_test.png')
    true_test_plot(server_xj_pre_ture_test['depth'], server_xj_pre_ture_test['true'], server_xj_pre_ture_test['pre'], 'test',
                   './out1/server/xj/server_xj_pre_ture_test.png')
    true_test_plot(server_bh_pre_ture_test['depth'], server_bh_pre_ture_test['true'], server_bh_pre_ture_test['pre'], 'test',
                   './out1/server/bh/server_bh_pre_ture_test.png')
    true_test_plot(clint0_volve_pre_ture_test['depth'], clint0_volve_pre_ture_test['true'], clint0_volve_pre_ture_test['pre'], 'test',
                   './out1/client/volve/clint0_volve_pre_ture_test.png')
    true_test_plot(clint1_xj_pre_ture_test['depth'], clint1_xj_pre_ture_test['true'], clint1_xj_pre_ture_test['pre'], 'test',
                   './out1/client/xj/clint1_xj_pre_ture_test.png')
    true_test_plot(clint2_bh_pre_ture_test['depth'], clint2_bh_pre_ture_test['true'], clint2_bh_pre_ture_test['pre'], 'test',
                   './out1/client/bh/clint2_bh_pre_ture_test.png')

    torch.save(models[0].state_dict(), './out1/model/model_0_volve.pkl')
    torch.save(models[1].state_dict(), './out1/model/model_1_xj.pkl')
    torch.save(models[2].state_dict(), './out1/model/model_2_bh.pkl')
    print('save clint model')
    torch.save(server_model.state_dict(), './out1/model/server_model.pkl')
    print('save server_model')

---------------------------------------- server model epoch 0 ----------------------------------------
     clint model epoch 0     
volve
  train:loss = 4.5829  acc = 0.3199  r2 = 0.7698 time =  2025-02-14 10:26:38
  test:loss = 0.3174  acc = 0.8772  r2 = 0.8271 time =  2025-02-14 10:26:38
xj
  train:loss = 29.3121  acc = -5.8839  r2 = -3.8383 time =  2025-02-14 10:27:09
  test:loss = 4.7856  acc = -0.7509  r2 = -5.2584 time =  2025-02-14 10:27:09
bh
  train:loss = 20.8725  acc = 0.3927  r2 = -0.0743 time =  2025-02-14 10:29:23
  test:loss = 6.1581  acc = 0.3815  r2 = 0.0797 time =  2025-02-14 10:29:23
     server model     

volve test:loss = 0.5032  acc = 0.8711  r2 = 0.7593  mse = 27.7668  mae = 4.4549 time =  2025-02-14 10:29:28
     server model     

xj test:loss = 0.1503  acc = 0.7196  r2 = 0.8320  mse = 4.8418  mae = 1.8097 time =  2025-02-14 10:29:30
     server model     

bh test:loss = 2.6929  acc = 0.3118  r2 = 0.6265  mse = 9.1393  mae = 2.6373 time =  2025-02-14 10:29:4