In [184]:
from torch.utils.data import DataLoader, Dataset
import torch
from torch import nn
from algorithms.TranAD.TranAD import TranAD
from typing import List
from tqdm import tqdm
import numpy as np
import os
import sys
import time
# from logger import logger

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [185]:
def average_weights(state_dicts: List[dict], fed_avg_freqs: torch.Tensor):
    # init
    avg_state_dict = {}
    for key in state_dicts[0].keys():
        avg_state_dict[key] = state_dicts[0][key] * fed_avg_freqs[0]

    state_dicts = state_dicts[1:]
    fed_avg_freqs = fed_avg_freqs[1:]
    for state_dict, freq in zip(state_dicts, fed_avg_freqs):
        for key in state_dict.keys():
            avg_state_dict[key] += state_dict[key] * freq
    return avg_state_dict


def update_global_grad_correct(old_correct: dict, grad_correct_deltas: List[dict], fed_avg_freqs: torch.Tensor, num_chosen_client, num_total_client):
    assert (len(grad_correct_deltas) == num_chosen_client)
    total_delta = average_weights(grad_correct_deltas, [1 / num_chosen_client] * num_chosen_client)
    for key in old_correct.keys():
        if key in total_delta.keys():
            old_correct[key] = old_correct[key] + total_delta[key]
    return old_correct

In [186]:
def get_init_grad_correct(model: nn.Module):
    correct = {}
    for name, _ in model.named_parameters():
        correct[name] = torch.tensor(0, dtype=torch.float, device="cpu")
    return correct

# def load_model(state_dict) -> nn.Module:
#     if args.tsadalg != 'deep_svdd':
#         model = model_fun()
#         model.load_state_dict(state_dict)
#         return model
#     else:
#         if config_svdd["stage"] == "first":
#             model = Model_first_stage()
#         elif config_svdd["stage"] == "second":
#             model = Model_second_stage()
#         else:
#             raise NotImplementedError
#         model.load_state_dict(state_dict)
#         return model

class Client(object):

    def __init__(self, dataset):
        self.net = TranAD(25)
        self.dataset = dataset
        self.state_dict_prev = None
        self.moon_mu = 1
        self.prox_mu = 0.01
        self.local_bs = 64
        self.grad_correct = get_init_grad_correct(TranAD(25))
        self.local_ep = 1
        self.criterion = nn.MSELoss().to(device)
        self.cos_sim = torch.nn.CosineSimilarity(dim=-1).to(device)
        self.temperature = 0.5
        self.trainloader = None
        self.verbose = 1

    def set_local_optimizer(self, model: nn.Module) -> torch.optim.Optimizer:
        return torch.optim.Adam(model.parameters, lr=0.001)

    def local_train(
            self, global_state_dict, global_round, global_grad_correct: dict, global_c: torch.Tensor = None
    ):

        scheduler = None

        # region 准备 model model_prev model_global

        model_global = self.net.load_state_dict(global_state_dict)
        model_global.requires_grad_(False)
        model_global.eval()
        model_global.to(device)
        #
        model_current = self.net.load_state_dict(global_state_dict)
        model_current.requires_grad_(True)
        model_current.train()
        model_current.to(device)
        # endregion

        epoch_loss = []
        train_acc = []  # 返回的loss和acc都是local round上的平均

        # region Set optimizer and dataloader
        optimizer = self.set_local_optimizer(model_current)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 5, 0.9)

        trainloader = DataLoader(
                    self.dataset,
                    batch_size=self.local_bs,
                    shuffle=True,
                    pin_memory=True,
                    num_workers=1,
                    drop_last=False
            )

        l1s = []
        for local_epoch in range(self.local_ep):
            loss1_list = []
            batch_loss = []
            correct = 0
            num_data = 0
            for i, (x, y) in enumerate(trainloader):
                x, y = x.to(self.device), y.to(self.device)

                optimizer.zero_grad()
                local_bs = x.shape[0]
                feats = x.shape[-1]
                window = x.permute(1, 0, 2)
                elem = window[-1, :, :].view(1, local_bs, feats)
                feature, logits, others = model_current(window, elem)
                l = nn.MSELoss(reduction='none')
                n = local_epoch + 1
                z = (others['x1'], others['x2'])
                l1 = l(z, elem) if not isinstance(z, tuple) else (1 / n) * l(z[0], elem) + (1 - 1 / n) * l(
                        z[1],
                        elem
                )
                if isinstance(z, tuple): z = z[1]
                l1s.append(torch.mean(l1).item())
                loss = torch.mean(l1)
                loss.backward()
                optimizer.step()
                state_dict_current = model_current.state_dict()
                lr = optimizer.state_dict()['param_groups'][-1]['lr']
                for key in state_dict_current:
                    # if not state_dict_current[key].requires_grad:
                    #     continue
                    if key == 'pos_encoder.pe':
                        continue
                    c_global = global_grad_correct[key].to(device)
                    c_local = self.grad_correct[key].to(device)
                    state_dict_current[key] -= lr * (c_global - c_local)
                model_current.load_state_dict(state_dict_current)

            batch_loss.append(loss.item())
            epoch_loss.append(sum(batch_loss) / len(batch_loss))
            print(
                            f'| Global Round : {global_round} | Local Epoch : {local_epoch} |' +
                            f' Training Loss: {loss.item():.6f}\t'
                            # + f' Training Accuracy: {train_acc[-1]:.6f}'
                    )
            if scheduler is not None:
                scheduler.step()
        mean_loss = sum(epoch_loss) / len(epoch_loss)
        model_current.cpu()
        self.state_dict_prev = model_current.state_dict()
        c_delta_local = None
        return float(mean_loss), None, c_delta_local



In [187]:
def generate_clients(datasets: List[Dataset]) -> List[Client]:
    clients = []
    for dataset in datasets:
        clients.append(Client(dataset))
    return clients

In [188]:
import torch.utils.data as data
import pandas as pd
from torch.distributions import Dirichlet

current_path = os.getcwd()
data_dir = current_path + '/data/datasets/psm/raw'
train_path = current_path + '/data/datasets/psm/raw/train'
test_path = current_path + '/data/datasets/psm/raw/test'
test_labels_path = current_path + '/data/datasets/psm/raw/test_label'

num_clients = 10
beta = 0.5

def generate_data_nums(num_client, num_data, beta=0.5):
    while True:
        data_num_each_client = Dirichlet(torch.tensor([beta] * num_client)).sample()
        data_num_each_client = torch.floor(num_data * data_num_each_client)
        data_num_each_client[-1] = num_data - torch.sum(data_num_each_client[:-1])
        if not (0 in data_num_each_client):
            break
    return data_num_each_client

class PSM_truncated(data.Dataset):

    def __init__(self, root, dataidxs=None, train=True, transform=None, target_transform=None, download=False, window_len=10):

        self.root = root
        self.dataidxs = dataidxs
        self.train = train
        self.transform = transform
        self.target_transform = target_transform
        self.download = download
        self.window_len = window_len

        global scalers
        self.scalers = scalers

        self.data, self.target = self.__build_truncated_dataset__()

    def __build_truncated_dataset__(self):

        # current_path = os.getcwd()
        current_path = 'E:\\pythonProject\\FedTADBench-main'

        if self.train:
            train_path = current_path + '\\data\\datasets\\psm\\raw'
            data = []
            this_data = pd.read_csv(train_path + '/train.csv')
            this_data.drop(columns=[r'timestamp_(min)'], inplace=True)
            data_length = this_data.values.shape[0]
            if beta >= 10000:
                each_length = data_length // num_clients
                this_data_values = this_data.values.astype(np.float32)
                this_data_values = np.nan_to_num(this_data_values)
                this_data_values = self.scalers[0].fit_transform(this_data_values)
                for i in range(num_clients):
                    data.append(this_data_values[i * each_length: (i + 1) * each_length])
            else:
                lengths = generate_data_nums(num_client=num_clients, num_data=data_length, beta=beta)
                lengths = lengths.detach().cpu().numpy()
                lengths = lengths.astype(int)
                lengths = lengths.tolist()
                start = 0
                this_data_values = this_data.values.astype(np.float32)
                this_data_values = np.nan_to_num(this_data_values)
                this_data_values = self.scalers[0].fit_transform(this_data_values)
                for li in range(len(lengths)):
                    l = lengths[li]
                    # if start + l <= this_data_values.shape[0] - 1:
                    if start + l <= this_data_values.shape[0]:
                        data.append(this_data_values[start: start + l])
                    else:
                        data.append(this_data_values[start:])
                    start += l
            target = data.copy()
        else:
            test_path = current_path + '\\data\\datasets\\psm\\raw'
            this_data = pd.read_csv(test_path + '/test.csv')
            this_data.drop(columns=[r'timestamp_(min)'], inplace=True)
            data = this_data.values
            data = data.astype(np.float32)
            data = np.nan_to_num(data)
            data = self.scalers[0].transform(data)
            data_length = this_data.values.shape[0]
            test_target_path = current_path + '\\data\\datasets\\psm\\raw\\test_label.csv'
            # print(test_target_path)
            target_csv = pd.read_csv(test_target_path)
            target_csv.drop(columns=[r'timestamp_(min)'], inplace=True)
            target = target_csv.values
            target = target.astype(np.float32)

        if self.dataidxs is not None:
            data = data[self.dataidxs]
            target = target[self.dataidxs]

        return data, target

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        if index + 1 - self.window_len < 0:
            data = self.data[0: index + 1]
            delta = self.window_len - data.shape[0]
            data0 = np.repeat(data[0][np.newaxis, :], delta, axis=0)
            data = np.concatenate((data0, data), axis=0)
        else:
            data = self.data[index + 1 - self.window_len: index + 1]
        target = self.target[index]

        return data, target

    def __len__(self):
        return len(self.data)

def get_dataset(dataset, datadir, dataidxs=None, noise_level=0):
    train_ds = PSM_truncated(datadir, dataidxs=dataidxs, train=True, transform=None, download=True)
    test_ds = PSM_truncated(datadir, train=False, transform=None, download=True)

    return train_ds, test_ds

def psm_noniid():
    train_ds_locals, test_ds_locals = [None] * num_clients, [None] * num_clients
    chosen_idxes = [i for i in range(num_clients)]
    for i in range(len(chosen_idxes)):
        dataidxs = chosen_idxes[i]
        train_ds_locals[i], test_ds_locals[i] = get_dataset(
            "psm", data_dir, dataidxs
        )
    return train_ds_locals

client_datasets_non_iid = psm_noniid()

In [189]:
epochs = 3
client_rate = 0.2

clients = generate_clients(client_datasets_non_iid)

model = TranAD(25).cpu()
global_state_dict = model.state_dict()
global_correct = get_init_grad_correct(TranAD(25).cpu())

for global_round in tqdm(range(epochs), file=sys.stdout):
    print(f'\n | Global Training Round : {global_round + 1} |\n')

    num_active_client = int((len(clients) * client_rate))
    print(num_active_client)
    ind_active_clients = np.random.choice(range(len(clients)), num_active_client, replace=False)
    active_clients = [clients[i] for i in ind_active_clients]

    # endregion

    active_state_dict = []
    data_nums = []
    train_accuracies = []
    train_losses = []
    grad_correct_deltas = []
    client_times = []
    for client in active_clients:
        client_start = time.time()
        data_nums.append(len(client.dataset))
        loss, accuracy, grad_correct_delta = client.local_train(
            global_state_dict,
            global_round,
            global_correct,
            )
        client_times.append(time.time() - client_start)
        grad_correct_deltas.append(grad_correct_delta)

        train_losses.append(loss)
        active_state_dict.append(client.state_dict_prev)
#
#     # end region
#
#     this_time = max(client_times)
#     time_start = time.time()
#     fed_freq = torch.tensor(data_nums, dtype=torch.float) / sum(data_nums)
#     global_state_dict = average_weights(active_state_dict, fed_freq)

  0%|          | 0/3 [00:00<?, ?it/s]
 | Global Training Round : 1 |

2
  0%|          | 0/3 [00:00<?, ?it/s]


AttributeError: '_IncompatibleKeys' object has no attribute 'requires_grad_'

In [7]:
import numpy as np

from datasets.MOON_util import partition_data, get_dataset
import os

usage: ipykernel_launcher.py [-h] [--device DEVICE]
                             [--num_workers NUM_WORKERS]
                             [--save_every SAVE_EVERY] [--verbose VERBOSE]
                             [--seed SEED] [--iid IID] --alg
                             {fedavg,fedprox,moon,scaffold,Elastic,Hyper}
                             --dataset {smd,smap,psm,swat,skab} --tsadalg
                             {gdn,deep_svdd,usad,tran_ad,lstm_ae,transformer,itransformer}
                             [--num_clients NUM_CLIENTS]
                             [--slide_win SLIDE_WIN]
                             [--client_rate CLIENT_RATE] [--beta BETA]
                             [--mu MU] [--tau TAU]
ipykernel_launcher.py: error: the following arguments are required: --alg, --dataset, --tsadalg


SystemExit: 2

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


In [10]:
from datasets.MOON_util import partition_data, get_dataset
import os

usage: ipykernel_launcher.py [-h] [--device DEVICE]
                             [--num_workers NUM_WORKERS]
                             [--save_every SAVE_EVERY] [--verbose VERBOSE]
                             [--seed SEED] [--iid IID] --alg
                             {fedavg,fedprox,moon,scaffold,Elastic,Hyper}
                             --dataset {smd,smap,psm,swat,skab} --tsadalg
                             {gdn,deep_svdd,usad,tran_ad,lstm_ae,transformer,itransformer}
                             [--num_clients NUM_CLIENTS]
                             [--slide_win SLIDE_WIN]
                             [--client_rate CLIENT_RATE] [--beta BETA]
                             [--mu MU] [--tau TAU]
ipykernel_launcher.py: error: the following arguments are required: --alg, --dataset, --tsadalg


SystemExit: 2

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
