In [2]:
import logging
import torch
import sys
import os
from tqdm import tqdm
import wandb

# 添加环境
sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "../../MyExpr")))
sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "../../FedML")))
sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "../../")))

print(sys.path)

# 查看GPU
print(torch.cuda.is_available())
for i in range(torch.cuda.device_count()):
    print("GPU[{:d}]: {:s}".format(i, torch.cuda.get_device_name(i)))

# 选择GPU
os.environ['CUDA_VISIBLE_DEVICES'] = "0"
# print(torch.cuda.device_count())

['/home/guest/Fed_Expr', '/home/guest/Fed_Expr/FedML', '/home/guest/Fed_Expr/MyExpr', '/home/guest/Fed_Expr', '/home/guest/Fed_Expr/FedML', '/home/guest/Fed_Expr/MyExpr', '/home/guest/Fed_Expr/MyExpr/dfl', '/home/guest/miniconda/envs/fedml/lib/python37.zip', '/home/guest/miniconda/envs/fedml/lib/python3.7', '/home/guest/miniconda/envs/fedml/lib/python3.7/lib-dynload', '', '/home/guest/miniconda/envs/fedml/lib/python3.7/site-packages', '/home/guest/miniconda/envs/fedml/lib/python3.7/site-packages/IPython/extensions', '/home/guest/.ipython']
True
GPU[0]: GeForce RTX 2080 Ti
GPU[1]: GeForce RTX 2080 Ti
GPU[2]: GeForce RTX 2080 Ti
GPU[3]: GeForce RTX 2080 Ti


In [3]:
from fedml_api.standalone.decentralized.topology_manager import TopologyManager

from MyExpr.dfl.model.resnet import resnet18
from MyExpr.dfl.Args import add_args
# from MyExpr.dfl.component.recorder import Recorder
from MyExpr.dfl.component.client import Client
from MyExpr.dfl.component.trainer import Trainer
from MyExpr.dfl.component.broadcaster import Broadcaster
from MyExpr.dfl.component.top_k import TopKSelector

from MyExpr.data import Data

parser = add_args()
# args = parser.parse_args()
args = parser.parse_known_args()[0]

In [5]:
import wandb


# snapshot of a single iteration
class Snapshot(object):
    def __init__(self, recorder):
        self.recorder = recorder

        self.local_train_loss_per_client = {}
        self.mutual_train_loss_per_client = {}
        self.local_train_correct_per_client = {}

        # 聚合统计数据
        self.mutual_train_regret = 0
        self.local_train_regret = 0
        self.total_local_train_correct = 0

        # 均值
        self.avg_local_train_loss_per_client = 0
        self.avg_mutual_train_loss_per_client = 0
        self.avg_local_train_acc = 0

    def __repr__(self):
        description = "===============================epoch:{:d} iteration:{:d}===============================\n" \
                      "mutual_train_regret:{:f}, local_train_regret:{:f}, avg_local_train_loss_per_client:{:f}\n" \
                      "avg_mutual_train_loss_per_client:{:f}, avg_local_train_acc:{:f}\n" \
                      "=====================================================================================\n"

        return description.format(self.recorder.epoch, self.recorder.iteration, self.mutual_train_regret,
                                  self.local_train_regret, self.avg_local_train_loss_per_client,
                                  self.avg_mutual_train_loss_per_client, self.avg_local_train_acc)

    # 聚合单个iteration的所有记录
    def aggregate(self):
        for c_id in self.recorder.client_dic.keys():
            self.mutual_train_regret += self.mutual_train_loss_per_client[c_id]
            self.local_train_regret += self.local_train_loss_per_client[c_id]
            self.total_local_train_correct += self.local_train_correct_per_client[c_id]

        self.avg_mutual_train_loss_per_client = self.mutual_train_regret / len(self.recorder.client_dic.keys())
        self.avg_local_train_loss_per_client = self.local_train_regret / len(self.recorder.client_dic.keys())
        self.avg_local_train_acc = self.total_local_train_correct / (
                    self.recorder.args.batch_size * len(self.local_train_correct_per_client))


class Recorder(object):

    def __init__(self, client_dic, topology_manager, args):
        self.client_dic = client_dic
        self.topology_manager = topology_manager
        self.args = args

        self.epoch = 0
        self.iteration = 0
        self.cur_snapshot = Snapshot(self)
        self.train_history_per_iteration = []
        self.train_history_per_epoch = []
        self.print_log = False
        self.wandb_log = True

    def next_iteration(self):
        self.cur_snapshot.aggregate()
        if self.print_log:
            self.print_round_log()
        self.iteration += 1
        self.train_history_per_iteration.append(self.cur_snapshot)
        self.cur_snapshot = Snapshot(self)

    def next_epoch(self):
        self.epoch += 1
        self.iteration = 0

        # todo 优化epoch总结输出
        epoch_mutual_train_regret = 0
        epoch_local_train_correct = 0
        for snapshot in self.train_history_per_iteration:
            epoch_mutual_train_regret += snapshot.mutual_train_regret
            epoch_local_train_correct += snapshot.total_local_train_correct
        # 记录本轮数据
        self.train_history_per_epoch.append(self.train_history_per_iteration)

        epoch_local_train_acc = epoch_local_train_correct / (
                    self.args.batch_size * len(self.client_dic) * len(self.train_history_per_iteration))
        if self.print_log:
            print("epoch_mutual_train_regret:{:f} epoch_mutual_train_acc:{:f}\n ".format(epoch_mutual_train_regret, epoch_local_train_acc))

        # 重置轮记录
        self.train_history_per_iteration = []

        if self.wandb_log:
            wandb.log({"epoch_mutual_train_regret":epoch_mutual_train_regret, "epoch_mutual_train_acc":epoch_local_train_acc})

    def record_local_train_loss(self, c_id, loss):
        self.cur_snapshot.local_train_loss_per_client[c_id] = loss

    def record_mutual_train_loss(self, c_id, loss):
        self.cur_snapshot.mutual_train_loss_per_client[c_id] = loss

    def record_local_train_correct(self, c_id, correct):
        self.cur_snapshot.local_train_correct_per_client[c_id] = correct
        # print("correct:", correct)

    def print_round_log(self):
        print(self.cur_snapshot)


In [6]:
# 1、设置trainer策略
trainer = Trainer()
trainer.use(args.mutual_trainer_strategy)
# 2、设置broadcaster策略
broadcaster = Broadcaster()
broadcaster.use(args.broadcaster_strategy)
# 3、设置Top_K策略
topK_selector = TopKSelector()
topK_selector.use(args.topK_strategy)
# 4、初始化拓扑结构
client_num_in_total = args.client_num_in_total
topology_manager = TopologyManager(client_num_in_total, True,
                                           undirected_neighbor_num=args.topology_neighbors_num_undirected)
topology_manager.generate_topology()
print("finished topology generation")

Trainer use strategy:local_and_mutual
Broadcaster use strategy:flood
TopKSelector use strategy:loss
finished topology generation


In [7]:
# 5、加载数据集，划分
data = Data(args)
train_loader, test_loader, test_all = data.train_loader, data.test_loader, data.test_all
train_data_size_per_client = len(train_loader[0])
test_data_size_per_client = len(test_loader[0])
epochs = args.epochs
batch_size = args.batch_size
train_iteration = train_data_size_per_client

In [8]:
client_dic = {}
# 6、注册recorder
recorder = Recorder(client_dic, topology_manager, args)
trainer.register_recorder(recorder)
broadcaster.register_recorder(recorder)
topK_selector.register_recoder(recorder)

# 7、初始化client, 选择搭载模型等
for c_id in range(client_num_in_total):
    # "ResNet18_GN"
    model = resnet18(num_classes=10)
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr,
                                     weight_decay=args.wd, amsgrad=True)
    c = Client(model, c_id, args, train_loader[c_id], test_loader[c_id])
    # 方便更换策略
    c.register(topK_selector=topK_selector, recorder=recorder, broadcaster=broadcaster)
    client_dic[c_id] = c


In [None]:
# 8、开始训练

wandb.init(project="dfl-topK-dml",
                entity="kyriegyj",
               config=args)

for epoch in range(epochs):
    # train
    # todo 全体设置model.train()
    for iteration in range(train_iteration):
        # print("============开始训练(第:d轮)============".format(iteration))
        trainer.train()
        # print("============结束训练(第:d轮)============".format(iteration))
    trainer.next_epoch()
    recorder.next_epoch()
    # todo 全体设置mode.eval()
    # todo test
wandb.finish()


In [None]:
# print(topology_manager.get_symmetric_neighbor_list(0))
# print(topology_manager.get_symmetric_neighbor_list(1))
# print(topology_manager.get_symmetric_neighbor_list(2))
# wandb.finish()

In [None]:
print(train_iteration)
print(args.batch_size)
print(len(train_loader[0]))

In [None]:
# e = enumerate(train_loader[0])
# _, (x, y) = next(e)
# print(x.shape, y.shape)

total: 4 selected: 3
total: 4 selected: 3
total: 4 selected: 3
total: 4 selected: 3
total: 4 selected: 3
total: 4 selected: 3
total: 4 selected: 3
total: 4 selected: 3
total: 4 selected: 3
total: 4 selected: 3
total: 4 selected: 3
total: 4 selected: 3
total: 4 selected: 3
total: 4 selected: 3
total: 4 selected: 3
total: 4 selected: 3
total: 4 selected: 3
total: 4 selected: 3
total: 4 selected: 3
total: 4 selected: 3
total: 4 selected: 3
total: 4 selected: 3
total: 4 selected: 3
total: 4 selected: 3
total: 4 selected: 3
total: 4 selected: 3
total: 4 selected: 3
total: 4 selected: 3
total: 4 selected: 3
total: 4 selected: 3


ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.



Traceback (most recent call last):
  File "/home/guest/miniconda/envs/fedml/lib/python3.7/site-packages/IPython/core/interactiveshell.py", line 3524, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/tmp/ipykernel_302443/1701665895.py", line 12, in <module>
    trainer.train()
  File "/home/guest/Fed_Expr/MyExpr/dfl/component/trainer.py", line 49, in local_and_mutual_learning_collaborate_update
    client_dic[c_id].train(client_data_dic[c_id][0], client_data_dic[c_id][1])
  File "/home/guest/Fed_Expr/MyExpr/dfl/component/client.py", line 59, in train
    loss.backward()
  File "/home/guest/miniconda/envs/fedml/lib/python3.7/site-packages/torch/_tensor.py", line 307, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/home/guest/miniconda/envs/fedml/lib/python3.7/site-packages/torch/autograd/__init__.py", line 156, in backward
    allow_unreachable=True, accumulate_grad=True)  # allow_unreachable flag
Keybo

TypeError: object of type 'NoneType' has no len()

In [None]:
# print(topology_manager.get_symmetric_neighbor_list(0))
# print(topology_manager.get_symmetric_neighbor_list(1))
# print(topology_manager.get_symmetric_neighbor_list(2))
# wandb.finish()

In [None]:
print(train_iteration)
print(args.batch_size)
print(len(train_loader[0]))

In [None]:
# f1计算
c = client_dic[0]
it = c.train_it
_, (x, y) = next(it)
x, y = x.to(c.device), y.to(c.device)
output = c.model(x)
pred = output.argmax(dim=1)
# print(pred.cpu().shape)
pred_mask = torch.zeros(output.size()).scatter_(1, pred.cpu().view(-1, 1), 1.)
# print(pred_mask)
pred_num = pred_mask.sum(0) # 数据中每类的预测量
targ_mask = torch.zeros(output.size()).scatter_(1, y.data.cpu().view(-1, 1), 1.)
targ_num = targ_mask.sum(0) # 标签中每类的数量
acc_mask = pred_mask * targ_mask
acc_num = acc_mask.sum(0) # 每类中预测正确的数量
print(acc_num, targ_num, pred_num)

# 防止0/0产生nan
epsilon = 1e-7

recall = acc_num / (targ_num + epsilon)
precision = acc_num / (pred_num + epsilon)
print(recall, precision)
#
f1 = 2 * recall * precision / (recall + precision + epsilon)
f1_marco = f1.sum(0) / f1.size(0)
print("F1-marco:", f1_marco.data.item())

# torch.clamp()区间限制
# fp 为预测为positive，实际不是。pred去掉tp，剩下即为错误的positive预测，即为fp。
fp = torch.clamp((pred_mask - targ_mask), 0.0).sum(0)
# fn 为预测为negative，实际不是。targ去掉tp，剩下的即为没预测到的positive样本，即为fn。
fn = torch.clamp((targ_mask - pred_mask), 0.0).sum(0)
# print("pred_num:", pred_num)
# print("targ_num:", targ_num)
# print("acc_num:", acc_num)
# # tp + fp = pred_p
# print("fp:", fp)
# # tp + fn = targ_p
# print("fn:", fn)

acc_sum = acc_num.sum(0)
fp_sum = fp.sum(0)
fn_sum = fn.sum(0)

precision_sum = acc_sum / (acc_sum + fp_sum + epsilon)
recall_sum = acc_sum / (acc_sum + fn_sum + epsilon)
f1_micro = 2 * recall_sum * precision_sum / (recall_sum + precision_sum + epsilon)
print("f1_micro:", f1_micro.data.item())