In [None]:
import os
import copy
import time
import pickle
import numpy as np
from tqdm import tqdm

import torch
from tensorboardX import SummaryWriter

from options import args_parser
from update import LocalUpdate, test_inference, ASRLocalUpdate
from models import MLP, CNNMnist, CNNFashion_Mnist, CNNCifar, Data2VecAudioForCTC, DataCollatorCTCWithPadding
from utils import get_dataset, average_weights, exp_details

from transformers import Data2VecAudioConfig, Wav2Vec2Processor
from multiprocessing import Pool
from collections import OrderedDict

In [None]:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6

import argparse
import multiprocessing as mp


# Experiment configuration. The options are declared with argparse and then parsed
# from an empty argument list (see the last line), so every value below is its default.
parser = argparse.ArgumentParser()

# federated arguments (Notation for the arguments followed from paper)
parser.add_argument('--epochs', type=int, default=2,
                    help="number of rounds of training")
parser.add_argument('--num_users', type=int, default=2,
                    help="number of users: K")
parser.add_argument('--frac', type=float, default=1.0,
                    help='the fraction of clients: C')
parser.add_argument('--local_ep', type=int, default=1,
                    help="the number of local epochs: E")
#parser.add_argument('--local_bs', type=int, default=1,
#                    help="local batch size: B")
#parser.add_argument('--lr', type=float, default=0.01,
#                    help='learning rate')
#parser.add_argument('--momentum', type=float, default=0.5,
#                    help='SGD momentum (default: 0.5)')

# model arguments
parser.add_argument('--model', type=str, default='data2vec', help='model name')
#parser.add_argument('--kernel_num', type=int, default=9,
#                    help='number of each kind of kernel')
#parser.add_argument('--kernel_sizes', type=str, default='3,4,5',
#                    help='comma-separated kernel size to \
#                    use for convolution')
#parser.add_argument('--num_channels', type=int, default=1, help="number \
#                    of channels of imgs")
#parser.add_argument('--norm', type=str, default='batch_norm',
#                    help="batch_norm, layer_norm, or None")
#parser.add_argument('--num_filters', type=int, default=32,
#                    help="number of filters for conv nets -- 32 for \
#                    mini-imagenet, 64 for omiglot.")
#parser.add_argument('--max_pool', type=str, default='True',
#                    help="Whether use max pooling rather than \
#                    strided convolutions")

# other arguments
parser.add_argument('--dataset', type=str, default='adress', help="name \
                    of dataset") #cifar
#parser.add_argument('--num_classes', type=int, default=10, help="number \
#                    of classes")
# NOTE(review): the help text says the default is CPU, but default=1 is truthy and the
# notebook selects 'cuda' whenever args.gpu is truthy. The option also has no `type`,
# so a command-line value such as "0" would arrive as a non-empty (truthy) string.
parser.add_argument('--gpu', default=1, help="To use cuda, set \
                    to a specific GPU ID. Default set to use CPU.")
#parser.add_argument('--optimizer', type=str, default='sgd', help="type \
#                    of optimizer")
#parser.add_argument('--iid', type=int, default=1,
#                    help='Default set to IID. Set to 0 for non-IID.')
#parser.add_argument('--unequal', type=int, default=0,
#                    help='whether to use unequal data splits for  \
#                    non-i.i.d setting (use 0 for equal splits)')
#parser.add_argument('--stopping_rounds', type=int, default=10,
#                    help='rounds of early stopping')
#parser.add_argument('--verbose', type=int, default=1, help='verbose')
#parser.add_argument('--seed', type=int, default=1, help='random seed')

# additional arguments
parser.add_argument('--pretrain_name', type=str, default='facebook/data2vec-audio-large-960h', help="str used to load pretrain model")
parser.add_argument('-lam', '--LAMBDA', type=float, default=0.5, help="Lambda for GRL")
parser.add_argument('-st', '--STAGE', type=int, default=2, help="Current training stage")
parser.add_argument('-GRL', '--GRL', action='store_true', default=False, help="True: GRL")
# NOTE(review): the default below is an absolute path on one specific machine.
parser.add_argument('-model_in', '--model_in_path', type=str, default="/mnt/Internal/FedASR/weitung/HuggingFace/Pretrain/saves/data2vec-audio-large-960h_new1_recall/final/", help="Where the model is saved")
parser.add_argument('-model_out', '--model_out_path', type=str, default="./save/data2vec-audio-large-960h_new2_recall_FL", help="Where to save the model")
parser.add_argument('-log', '--log_path', type=str, default="data2vec-audio-large-960h_new2_recall_FL.txt", help="name for the txt file")
# 2023/01/08: loss type
parser.add_argument('-ad_loss', '--AD_loss', type=str, default="recall", help="loss to use for AD classifier")
# 2023/01/18: ckpt
parser.add_argument('-ckpt', '--checkpoint', type=str, default=None, help="path to checkpoint")
# 2023/02/13: TOGGLE_RATIO
parser.add_argument('-toggle_rt', '--TOGGLE_RATIO', type=float, default=0, help="To toggle more or less")
# 2023/02/15: GS_TAU, loss weight
parser.add_argument('-gs_tau', '--GS_TAU', type=float, default=1, help="Tau for gumbel_softmax")
parser.add_argument('-w_loss', '--W_LOSS', type=float, default=None, nargs='+', help="weight for HC and AD")

# Parse an explicit empty argv so the kernel's own command line is not parsed.
args = parser.parse_args(args=[]) # for jupyter notebook


# os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3" # or whichever other GPU IDs you want to use
# Create a Lock object; client_train acquires it around its from_pretrained() call.
lock = mp.Lock()

In [None]:

def client_train(args, train_dataset, 
                 test_dataset, idx, epoch, global_weights=None):
    """Train one client and return its updated arbitrator weights and loss.

    Written to be called from worker processes (see the ``pool.starmap`` cells).
    Besides its parameters it reads two module-level globals: ``lock``
    (held while the pretrained model is loaded) and ``logger`` (passed on to
    ``ASRLocalUpdate``).

    Args:
        args: Parsed configuration; ``model``, ``pretrain_name``,
            ``model_in_path`` and ``gpu`` are read here, and the whole object is
            forwarded to the model and to ``ASRLocalUpdate``.
        train_dataset: Training data handed to ``ASRLocalUpdate``.
        test_dataset: Passed to ``ASRLocalUpdate`` as ``global_test_dataset``.
        idx: Client id, forwarded as ``client_id``.
        epoch: Global round index, forwarded as ``global_round``.
        global_weights: Optional state dict loaded into the arbitrator copy
            before training; ``None`` keeps the weights found in the checkpoint.

    Returns:
        The ``(w, loss)`` pair produced by ``ASRLocalUpdate.update_weights``.

    Raises:
        ValueError: If ``args.model`` is not ``'data2vec'``.
    """
    print(" process PID", os.getpid(), " running")

    # Raise instead of calling exit(): SystemExit is not an Exception subclass, so a
    # pool worker would die without reporting back and the caller's starmap() could
    # wait forever. A ValueError is propagated to the caller normally.
    if args.model != 'data2vec':
        raise ValueError('Error: unrecognized model')

    # BUILD MODEL for every process
    mask_time_prob = 0                      # change config to avoid training stopping
    config = Data2VecAudioConfig.from_pretrained(args.pretrain_name, mask_time_prob=mask_time_prob)
    print("load from ", args.model_in_path)
    # Critical section: only one process at a time loads the checkpoint. The context
    # manager releases the lock even if loading raises.
    with lock:
        print(" process PID", os.getpid(), "enter critical section")
        model = Data2VecAudioForCTC.from_pretrained(args.model_in_path, config=config, args=args)
    print(" process PID", os.getpid(), "exit critical section")
    print("model loaded")                   # load/initialize global model
    model.config.ctc_zero_infinity = True   # to avoid inf values

    global_model = copy.deepcopy(model.arbitrator)      # only the global toggling network
    if global_weights is not None:                      # weights from a previous round, if any
        global_model.load_state_dict(global_weights)
    processor = Wav2Vec2Processor.from_pretrained(args.pretrain_name)
    data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)

    # Set the model to train and send it to device.
    device = 'cuda' if args.gpu else 'cpu'
    global_model.to(device)
    global_model.train()

    ####################
    # 'use client_id generate sub-dataset' to be done
    ####################
    local_model = ASRLocalUpdate(args=args, dataset=train_dataset, logger=logger,
                        data_collator=data_collator, global_test_dataset=test_dataset, 
                        processor=processor, client_id=idx)
    ####################
    # 'use client_id load local model' to be done
    # 'save model in final round' to be done
    ####################
    # Train starting from a private copy of the (possibly updated) arbitrator.
    w, loss = local_model.update_weights(
        global_arbitrator=copy.deepcopy(global_model), global_round=epoch)

    print("PID {} Getting ".format(os.getpid()), "Done")
    return w, loss

In [None]:
def client_get_weight(args, idx, epoch):
    """Load the saved model and return a copy of its arbitrator weights.

    No training happens here; the function is used to exercise the
    multi-process plumbing. ``idx`` and ``epoch`` are accepted so the call
    shape matches the per-client/per-round usage, but the model is always read
    from ``args.model_in_path``.

    Args:
        args: Parsed configuration (``model``, ``pretrain_name``,
            ``model_in_path`` are read; the object is also passed to the model).
        idx: Client id (currently unused).
        epoch: Global round index (currently unused).

    Returns:
        ``(state_dict_copy, 0.05)`` where the second item is a constant
        placeholder in the loss position.

    Raises:
        ValueError: If ``args.model`` is not ``'data2vec'``.
    """
    print(" process PID", os.getpid(), " running")

    # Raise instead of exit(): SystemExit raised inside a pool worker is not reported
    # back to the parent, which can leave starmap() waiting indefinitely.
    if args.model != 'data2vec':
        raise ValueError('Error: unrecognized model')

    # model saved in args.model_out_path + "_" + str(idx) + "/final_" + str(epoch)
    #model_path = args.model_out_path + "_" + str(idx) + "/final_" + str(epoch)
    model_path = args.model_in_path

    # BUILD MODEL for every process
    mask_time_prob = 0                      # change config to avoid training stopping
    config = Data2VecAudioConfig.from_pretrained(args.pretrain_name, mask_time_prob=mask_time_prob)
    model = Data2VecAudioForCTC.from_pretrained(model_path, config=config, args=args)
    print("model loaded")
    model.config.ctc_zero_infinity = True   # to avoid inf values

    # Only the toggling network's weights are returned; one deep copy of the state
    # dict is enough to detach them from the loaded model.
    return_weights = copy.deepcopy(model.arbitrator.state_dict())

    print("PID {} Getting ".format(os.getpid()), "Done")
    return return_weights, 0.05


In [None]:
# NOTE(review): start_time is only assigned in this commented-out line.
#start_time = time.time()

# define paths
#path_project = os.path.abspath('..')
# TensorBoard writer; client_train reads this name as a module-level global.
logger = SummaryWriter('../logs')

#args = args_parser()
exp_details(args) # print out details based on configuration

In [None]:
# NOTE(review): the `device` assignment below is commented out, so this cell does not
# define a global `device`.
#if args.gpu_id:
#    torch.cuda.set_device(args.gpu_id)
#device = 'cuda' if args.gpu else 'cpu'

# load dataset and user groups
train_dataset, test_dataset, user_groups = get_dataset(args)

# ID相關code先跳過

In [None]:
# For every example, compute "user_IDs" as the part of its "path" before the first "_".
user_IDs = train_dataset.map(lambda x: {"user_IDs": x["path"].split("_")[0]})
user_IDs

In [None]:
# Unique user IDs: dict.fromkeys drops duplicates while keeping first-seen order.
mylist = user_IDs["user_IDs"]
mylist = list(dict.fromkeys(mylist))
print(mylist)

In [None]:
# Number of distinct user IDs.
len(mylist)

In [None]:
# generate sub- training set for given user-ID
# Keeps only examples whose "path" starts with one of the two hard-coded prefixes.
start_with_ar = train_dataset.filter(lambda example: (example["path"].startswith("S055")) or (example["path"].startswith("S094")))
start_with_ar["path"]

In [None]:
def cifar_iid(dataset, num_users):
    """
    Split dataset indices into equally sized random groups, one per user.

    Each user receives ``int(len(dataset) / num_users)`` indices drawn without
    replacement from the indices no earlier user has taken; leftover indices
    stay unassigned.

    :param dataset: any sized collection (only ``len(dataset)`` is used)
    :param num_users: number of groups to create
    :return: dict mapping user number -> set of sampled indices
    """
    per_user = int(len(dataset)/num_users)
    remaining = list(range(len(dataset)))
    dict_users = {}
    for user in range(num_users):
        picked = set(np.random.choice(remaining, per_user, replace=False))
        dict_users[user] = picked
        # Remove this user's share so the next draw cannot repeat an index.
        remaining = list(set(remaining) - picked)
    return dict_users
def cifar_noniid(dataset, num_users):
    """
    Split dataset indices into label-sorted shards and give each user two shards.

    The sizes are hard-coded: 200 shards of 250 indices each, so
    ``dataset.train_labels`` must hold 200 * 250 labels. Indices are ordered by
    label before being cut into shards; every user draws 2 unused shards at
    random.

    :param dataset: object exposing ``train_labels``
    :param num_users: number of users to serve
    :return: dict mapping user number -> numpy array of indices (float dtype,
             because each array starts from an empty ``np.array([])``)
    """
    num_shards, num_imgs = 200, 250
    free_shards = list(range(num_shards))
    dict_users = {}
    for user in range(num_users):
        dict_users[user] = np.array([])

    # labels = dataset.train_labels.numpy()
    labels = np.array(dataset.train_labels)
    positions = np.arange(num_shards*num_imgs)

    # Order the positions by their label.
    paired = np.vstack((positions, labels))
    by_label = paired[1, :].argsort()
    ordered_positions = paired[:, by_label][0, :]

    # Hand out two random, not-yet-used shards per user.
    for user in range(num_users):
        chosen = set(np.random.choice(free_shards, 2, replace=False))
        free_shards = list(set(free_shards) - chosen)
        for shard in chosen:
            shard_slice = ordered_positions[shard*num_imgs:(shard+1)*num_imgs]
            dict_users[user] = np.concatenate(
                (dict_users[user], shard_slice), axis=0)
    return dict_users

# 下面開始繼續跑

### [參考](https://stackoverflow.com/questions/10415028/how-to-get-the-return-value-of-a-function-passed-to-multiprocessing-process): 不同multi-process寫法

In [None]:
# multi-process ver1
# Smoke test of the Pool plumbing: every worker loads the saved model through
# client_get_weight and returns (arbitrator weights, placeholder loss).
epoch = 0

# One worker per task is enough. os.cpu_count() may return None, and starting one
# worker per CPU core for only `num_users` tasks just creates idle processes.
n_workers = max(1, min(args.num_users, os.cpu_count() or 1))

# The context manager shuts the workers down once starmap() has returned all results.
with Pool(n_workers) as pool:
    final_result = pool.starmap(client_get_weight, [(args, idx, epoch) for idx in range(args.num_users)])

In [None]:
# NOTE(review): this is not the last expression of the cell, so its value is not shown.
final_result[1]
# NOTE(review): `ccc` is not defined anywhere in the visible notebook, so this line
# raises NameError. Presumably a deliberate tripwire to stop "Run All" here -- confirm.
aaa=ccc

### 只跑一個round

In [None]:
# Training: one global round, clients trained in parallel worker processes.
train_loss, test_wer = [], []
val_acc_list, net_list = [], []
cv_loss, cv_acc = [], []
print_every = 2
val_loss_pre, counter = 0, 0
global_weights = None                                                           # initial global_weights
epoch = 0
print(f'\n | Global Training Round : {epoch+1} |\n')                        # print current round

m = max(int(args.frac * args.num_users), 1)                                 # num of clients to train, min:1
idxs_users = np.random.choice(range(args.num_users), m, replace=False)      # select by client_id

# At most 4 workers, never more than there are clients. The context manager shuts
# the workers down once all results are in.
with Pool(min(4, m)) as pool:
    try:
        final_result = pool.starmap(client_train, [(args, train_dataset, 
                        test_dataset, idx, epoch, global_weights) for idx in idxs_users])
    except Exception as e:
        # Report the failure, then re-raise: continuing would leave `final_result`
        # holding whatever an earlier cell stored, and the aggregation cells below
        # would silently average stale weights.
        print(f"Exception in subprocess: {e}")
        raise



In [None]:
# Inspect the per-client results returned by the pool.
final_result

In [None]:
# Split the per-client (weights, loss) pairs into two parallel lists.
local_weights = []
local_losses = []
for idx in range(len(final_result)):
    w, loss = final_result[idx]
    local_weights.append(w)
    local_losses.append(loss)

In [None]:
# Aggregate the clients' weights into the new global weights and display them.
global_weights = average_weights(local_weights)
global_weights

In [None]:
# Second global round, started from the averaged weights of the previous round.
epoch = 1
print(f'\n | Global Training Round : {epoch+1} |\n')                        # print current round

m = max(int(args.frac * args.num_users), 1)                                 # num of clients to train, min:1
idxs_users = np.random.choice(range(args.num_users), m, replace=False)      # select by client_id

# Use a context manager so this round's worker processes are shut down when it
# finishes, instead of lingering next to the pools created by earlier cells.
with Pool(min(4, m)) as pool:
    final_result = pool.starmap(client_train, [(args, train_dataset, 
                     test_dataset, idx, epoch, global_weights) for idx in idxs_users])

In [None]:
def client_train_mod(args, train_dataset, 
                        test_dataset, idx, epoch, global_weights=None):
    """Pass-through wrapper around ``client_train``.

    Forwards every argument unchanged and returns the same
    ``(local_weights, local_loss)`` pair that ``client_train`` produced.
    """
    # Client-side training happens entirely inside client_train.
    outcome = client_train(args, train_dataset, test_dataset,
                           idx, epoch, global_weights)
    trained_weights, round_loss = outcome

    # Hand the local weights and loss back to the caller.
    return trained_weights, round_loss

### 跑多個round

In [17]:
# Multi-round federated training: each round trains the selected clients in
# parallel, averages their weights, and feeds the average into the next round.
train_loss, test_wer = [], []
val_acc_list, net_list = [], []
cv_loss, cv_acc = [], []
print_every = 2
val_loss_pre, counter = 0, 0
global_weights = None                                                           # initial global_weights
for epoch in tqdm(range(args.epochs)):                                          # train for given global rounds
    print(f'\n | Global Training Round : {epoch+1} |\n')                        # print current round

    m = max(int(args.frac * args.num_users), 1)                                 # num of clients to train, min:1
    idxs_users = np.random.choice(range(args.num_users), m, replace=False)      # select by client_id

    # One worker per selected client. The context manager shuts the workers down at
    # the end of every round; a bare Pool(m) per round would leave each round's
    # worker processes alive.
    with Pool(m) as pool:
        final_result = pool.starmap(client_train, [(args, train_dataset, 
                            test_dataset, idx, epoch, global_weights) for idx in idxs_users])

    # Split the per-client (weights, loss) pairs of this round.
    local_weights = []
    local_losses = []
    for w, loss in final_result:
        local_weights.append(w)
        local_losses.append(loss)

    print("local weights: ", local_weights)
    # get global weights by averaging local weights
    global_weights = average_weights(local_weights)
    print("global wegiths: ", global_weights)

    # update global weights
    #global_model.load_state_dict(global_weights)

    loss_avg = sum(local_losses) / len(local_losses)                # average losses from participated client
    train_loss.append(loss_avg)                                     # save loss for this round

  0%|          | 0/2 [00:00<?, ?it/s]


 | Global Training Round : 1 |

 process PID process PID  268233268232   running running

load from  load from /mnt/Internal/FedASR/weitung/HuggingFace/Pretrain/saves/data2vec-audio-large-960h_new1_recall/final/ 
/mnt/Internal/FedASR/weitung/HuggingFace/Pretrain/saves/data2vec-audio-large-960h_new1_recall/final/ process PID
 268233 enter critical section
lambda =  tensor(0.5000)
Current stage: 2


  0%|          | 0/2 [07:52<?, ?it/s]


KeyboardInterrupt: 

In [None]:
# Aggregate the results currently held in `final_result` (manual replay of the loop
# body above). Start from empty lists: appending to whatever `local_weights` /
# `local_losses` an earlier cell left behind would average weights from different
# rounds together.
local_weights = []
local_losses = []
for w, loss in final_result:
    local_weights.append(w)
    local_losses.append(loss)

print("local weights: ", local_weights)
# get global weights by averaging local weights
global_weights = average_weights(local_weights)
print("global wegiths: ", global_weights)

# update global weights
#global_model.load_state_dict(global_weights)

loss_avg = sum(local_losses) / len(local_losses)                # average losses from participated client
train_loss.append(loss_avg)                                     # save loss for this round

## 以下其他

In [None]:
# Training with one explicit Process per client; results come back through a
# Manager dict keyed by str(client id).

def _client_train_to_dict(return_dict, args, train_dataset, test_dataset,
                          idx, epoch, global_weights=None):
    """Run client_train in a child process and store [w, loss] under str(idx)."""
    w, loss = client_train(args, train_dataset, test_dataset, idx, epoch, global_weights)
    return_dict[str(idx)] = [w, loss]


train_loss, test_wer = [], []
val_acc_list, net_list = [], []
cv_loss, cv_acc = [], []
print_every = 2
val_loss_pre, counter = 0, 0
global_weights = None    

# The module is imported as `mp`; the bare name `multiprocessing` is never bound in
# this notebook. One manager is enough for all rounds.
manager = mp.Manager()

# NOTE(review): global_weights is not updated inside this loop (aggregation lives in
# later cells), so every round starts from the same weights, and return_dict only
# keeps the last round's results.
for epoch in range(2):
    print(f'\n | Global Training Round : {epoch+1} |\n')                        # print current round

    m = max(int(args.frac * args.num_users), 1)                                 # num of clients to train, min:1
    idxs_users = np.random.choice(range(args.num_users), m, replace=False)      # select by client_id

    return_dict = manager.dict()
    jobs = []
    for idx in idxs_users:                                                      # for each training client
        print("start of client #", idx)
        # client_train takes (args, train_dataset, test_dataset, idx, epoch,
        # global_weights) and returns its result, so the wrapper above is what
        # writes into return_dict.
        p = mp.Process(target=_client_train_to_dict, args=(return_dict,
                args, train_dataset, test_dataset, idx, epoch, global_weights))
        jobs.append(p)
        p.start()

    # Wait for every client of this round before starting the next one.
    for proc in jobs:
        proc.join()

In [None]:
# Show which entries the child processes wrote into the shared dict.
print(return_dict.keys())

In [None]:
# Placeholder: loading each client's weights from its saved model is not implemented.
for idx in idxs_users:                                                      # for each training client
    # get model weights from saved model
    pass  # a `for` needs a body; a comment alone makes the cell a syntax error

In [None]:
# Split the [weights, loss] entries stored in return_dict into two parallel lists.
local_weights = []
local_losses = []
for key in return_dict.keys():
    w, loss = return_dict[key]
    local_weights.append(w)
    local_losses.append(loss)

In [None]:
# Average the collected client weights and display the result.
global_weights = average_weights(local_weights)
global_weights

In [None]:
# Build the global arbitrator in the notebook process (same steps as client_train).
if args.model != 'data2vec':
    # Raise instead of exit(): calling exit() in a notebook targets the kernel itself.
    raise ValueError('Error: unrecognized model')

mask_time_prob = 0                                                          # change config to avoid training stopping
config = Data2VecAudioConfig.from_pretrained(args.pretrain_name, mask_time_prob=mask_time_prob)
model = Data2VecAudioForCTC.from_pretrained(args.model_in_path, config=config, args=args)
                                                                            # load/initialize global model
model.config.ctc_zero_infinity = True                                       # to avoid inf values

global_model = copy.deepcopy(model.arbitrator)                              # only has global toggling network
if global_weights is not None:                                              # if given global_weights
    global_model.load_state_dict(global_weights)                            # load it
processor = Wav2Vec2Processor.from_pretrained(args.pretrain_name)
data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)

# Set the model to train and send it to device.
# `device` was never assigned at notebook level (only inside client_train), so
# define it here with the same rule client_train uses.
device = 'cuda' if args.gpu else 'cpu'
global_model.to(device)
global_model.train()

In [None]:
# Keep the arbitrator's state dict in `w` and print it.
w = global_model.state_dict()
print(w)

In [None]:
# Convert the arbitrator's 'weight' and 'bias' tensors to NumPy arrays.
# Tensor.numpy() only works on CPU tensors and the model may have been moved to
# CUDA, so bring each tensor to the CPU first.
w_ret = OrderedDict()
w_ret['weight'] = w['weight'].detach().cpu().numpy()
w_ret['bias'] = w['bias'].detach().cpu().numpy()
print(w_ret)

In [None]:
# `w` is a state dict (a mapping of tensors), which has no .cpu() method itself;
# move each tensor to the CPU individually.
print({name: tensor.cpu() for name, tensor in w.items()})
#copy.deepcopy(w.detach().cpu())

In [None]:
# Display the arbitrator's current parameters.
global_model.state_dict()

In [None]:
# The module is imported as `mp`; the bare name `multiprocessing` is never bound in
# this notebook, so `multiprocessing.Manager()` raised NameError.
manager = mp.Manager()
return_dict = manager.dict()

In [None]:
# Experiment: store a [state_dict, loss] pair in the shared dict under an arbitrary key.
return_dict[2] = [global_model.state_dict(), 0.88]

In [None]:
# Read the stored entries back out of the shared dict.
return_dict.values()

# 以下尚未確認

In [None]:
# Inspect the client weights collected so far.
local_weights

In [None]:
# get global weights by averaging local weights
global_weights = average_weights(local_weights)
# Display the averaged weights.
global_weights

In [None]:
# update global weights
global_model.load_state_dict(global_weights)

# Record this round's mean client loss, then display the loss history.
loss_avg = sum(local_losses) / len(local_losses)                # average losses from participated client
train_loss.append(loss_avg)     
train_loss

In [None]:
# Evaluate the global arbitrator once per user and record the mean WER.
# `train_accuracy` was never created anywhere in the notebook, so the append at the
# end of this cell raised NameError; create it here.
train_accuracy = []
list_wer = []
global_model.eval()
for c in range(args.num_users):                                 # for ALL users
    # NOTE(review): `c` is not used and no client_id is passed here (client_train
    # passes client_id=idx), so every iteration builds the same evaluator -- confirm
    # whether client_id=c was intended.
    local_model = ASRLocalUpdate(args=args, dataset=train_dataset, logger=logger,
                        data_collator=data_collator, global_test_dataset=test_dataset, 
                        processor=processor)
    wer = local_model.inference(global_arbitrator=global_model)       # WER from the global arbitrator
    list_wer.append(wer)
# Note: despite the list's name, the value stored is a mean WER, not an accuracy.
train_accuracy.append(sum(list_wer)/len(list_wer))              # average over all clients

# 以下尚未確認

In [None]:


# print global training loss after every 'i' rounds
# NOTE(review): the round number is hard-coded as (1+1) instead of using `epoch`, and
# this cell reads `train_accuracy`, which must already exist.
if (1+1) % print_every == 0:
    print(f' \nAvg Training Stats after {1+1} global rounds:')
    print(f'Training Loss : {np.mean(np.array(train_loss))}')
    print('Train Accuracy: {:.2f}% \n'.format(100*train_accuracy[-1])) # on testing set of clients though

In [None]:
from transformers.training_args import TrainingArguments
from transformers import Trainer
from typing import Any, Dict, List, Optional, Union
import json

# Directory prefix for the text log written by CustomTrainer.log (current directory).
LOG_DIR = './'#log/'
from datasets import load_metric
# Word-error-rate metric used by compute_metrics.
wer_metric = load_metric("wer")
def compute_metrics(pred):
    """Compute the word error rate for one evaluation pass.

    Reads ``pred.predictions`` (logits) and ``pred.label_ids``; uses the
    module-level ``processor`` for decoding and ``wer_metric`` for scoring.
    Label entries equal to -100 are overwritten in place with the tokenizer's
    pad token id before decoding.

    Returns:
        dict with a single key ``"wer"``.
    """
    # Greedy decoding: highest-scoring token id at each position.
    best_ids = np.argmax(pred.predictions, axis=-1)

    # Replace the -100 markers so the labels can be decoded.
    ignored = pred.label_ids == -100
    pred.label_ids[ignored] = processor.tokenizer.pad_token_id

    hypotheses = processor.batch_decode(best_ids)
    # we do not want to group tokens when computing the metrics
    references = processor.batch_decode(pred.label_ids, group_tokens=False)

    score = wer_metric.compute(predictions=hypotheses, references=references)
    return {"wer": score}

class CustomTrainer(Trainer):
    """Trainer that also appends every log record to a text file.

    ``compute_loss`` follows the stock Trainer logic (label smoothing when
    configured, otherwise the model's own loss); ``log`` additionally writes
    each record as one JSON line to ``LOG_DIR + args.log_path``, where ``args``
    is the notebook-level configuration, not ``self.args``.
    """

    def compute_loss(self, model, inputs, return_outputs=False):
        """
        How the loss is computed by Trainer. By default, all models return the loss in the first element.
        Subclass and override for custom behavior.
        """
        # Labels are only popped when label smoothing is active; otherwise they stay
        # in `inputs` and the model computes the loss itself.
        if self.label_smoother is not None and "labels" in inputs:
            labels = inputs.pop("labels")
        else:
            labels = None

        outputs = model(**inputs)
        # Save past state if it exists
        # TODO: this needs to be fixed and made cleaner later.
        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]

        if labels is not None:
            loss = self.label_smoother(outputs, labels)
        else:
            # We don't use .loss here since the model may return tuples instead of ModelOutput.
            loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]

        return (loss, outputs) if return_outputs else loss

    def log(self, logs: Dict[str, float]) -> None:
        """
        Log `logs` on the various objects watching training.
        Subclass and override this method to inject custom behavior.
        Args:
            logs (`Dict[str, float]`):
                The values to log.
        """
        if self.state.epoch is not None:
            logs["epoch"] = round(self.state.epoch, 2)

        output = {**logs, **{"step": self.state.global_step}}
        self.state.log_history.append(output)

        # Append the record to the text log. `with` closes the file even if the
        # write or the JSON serialisation raises.
        with open(LOG_DIR + args.log_path, 'a') as file_object:
            file_object.write(json.dumps(output) + '\n')

        self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs)


# Hugging Face training configuration for a stand-alone (non-federated) test run.
# Checkpointing, evaluation and logging all happen every 500 steps.
training_args = TrainingArguments(
    output_dir=args.model_out_path + "just_for_testing",
    group_by_length=True,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    evaluation_strategy="steps",
    num_train_epochs=2,
    fp16=True,
    gradient_checkpointing=True, 
    save_steps=500,
    eval_steps=500,
    logging_steps=500,
    learning_rate=1e-4,
    weight_decay=0.005,
    warmup_steps=1000,
    save_total_limit=2,
    log_level='debug',
    logging_strategy="steps",
    #adafactor=True,            # default:false. Whether or not to use transformers.Adafactor optimizer instead of transformers.AdamW
    #fp16_full_eval=True,      # to save memory
    #max_grad_norm=0.5
)
# Train the full `model` (not just the arbitrator) and evaluate on test_dataset with WER.
trainer = CustomTrainer(
    model=model,
    data_collator=data_collator,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    tokenizer=processor.feature_extractor,
)
trainer.train()

In [None]:
# Arbitrator parameters of the model held by the trainer.
trainer.model.arbitrator.state_dict()

In [None]:
# Every record that CustomTrainer.log appended during training.
trainer.state.log_history

In [None]:
# "train_loss" entry of the last log record.
trainer.state.log_history[-1]["train_loss"]

In [None]:
# Legacy FedAvg loop (LocalUpdate / test_inference based), made self-contained:
# `start_time` and `train_accuracy` were never assigned anywhere in the notebook.
start_time = time.time()        # the run time printed at the end covers this cell only
train_accuracy = []

for epoch in tqdm(range(args.epochs)):
    local_weights, local_losses = [], []
    print(f'\n | Global Training Round : {epoch+1} |\n') # epoch = round

    global_model.train()
    m = max(int(args.frac * args.num_users), 1) # train a fraction of the clients, at least one
    idxs_users = np.random.choice(range(args.num_users), m, replace=False) # pick client ids

    for idx in idxs_users: # every selected client
        local_model = LocalUpdate(args=args, dataset=train_dataset,
                                    idxs=user_groups[idx], logger=logger)
        # weights and loss produced by this client's local training
        w, loss = local_model.update_weights(
            model=copy.deepcopy(global_model), global_round=epoch)
        local_weights.append(copy.deepcopy(w))
        local_losses.append(copy.deepcopy(loss))

    # compute global weights
    global_weights = average_weights(local_weights)

    # update global weights
    global_model.load_state_dict(global_weights)

    loss_avg = sum(local_losses) / len(local_losses)
    train_loss.append(loss_avg) # loss average over all clients

    # Calculate avg training accuracy over all users at every epoch
    list_acc, list_loss = [], []
    global_model.eval()
    for c in range(args.num_users): # all clients
        # Index with the loop variable `c`. The original used `idx`, left over from
        # the training loop above, so every iteration evaluated the same client.
        local_model = LocalUpdate(args=args, dataset=train_dataset,
                                    idxs=user_groups[c], logger=logger)
        # evaluate the global model on this client's split: accuracy and loss
        acc, loss = local_model.inference(model=global_model)
        list_acc.append(acc)
        list_loss.append(loss)
    train_accuracy.append(sum(list_acc)/len(list_acc)) # acc_avg

    # print global training loss after every 'i' rounds
    if (epoch+1) % print_every == 0:
        print(f' \nAvg Training Stats after {epoch+1} global rounds:')
        print(f'Training Loss : {np.mean(np.array(train_loss))}')
        print('Train Accuracy: {:.2f}% \n'.format(100*train_accuracy[-1]))

# Test inference after completion of training
test_acc, test_loss = test_inference(args, global_model, test_dataset)

print(f' \n Results after {args.epochs} global rounds of training:')
print("|---- Avg Train Accuracy: {:.2f}%".format(100*train_accuracy[-1]))
print("|---- Test Accuracy: {:.2f}%".format(100*test_acc))

# Saving the objects train_loss and train_accuracy:
# `--iid` and `--local_bs` are commented out of the parser, so read them defensively
# instead of failing with AttributeError while building the file name.
file_name = '../save/objects/{}_{}_{}_C[{}]_iid[{}]_E[{}]_B[{}].pkl'.\
    format(args.dataset, args.model, args.epochs, args.frac,
            getattr(args, 'iid', 'NA'), args.local_ep,
            getattr(args, 'local_bs', 'NA'))

# The original note here said the save failed; make sure the target directory exists.
os.makedirs(os.path.dirname(file_name), exist_ok=True)
with open(file_name, 'wb') as f:
    pickle.dump([train_loss, train_accuracy], f)

print('\n Total Run Time: {0:0.4f}'.format(time.time()-start_time))