In [1]:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6

import os
import copy
import time
import numpy as np
from tqdm import tqdm

# from options import args_parser
from utils import get_dataset, average_weights, exp_details

import multiprocessing
from update import update_network_weight, get_model_weight

from training import client_train, centralized_training
from update import ASRLocalUpdate

def FL_training_rounds(args, model_in_path_root, model_out_path, train_dataset, test_dataset):
    """Run ``args.epochs`` federated rounds and return the last aggregated weights.

    Every round samples ``max(int(args.frac * args.num_users), 1)`` client ids
    without replacement, runs ``client_train`` for each sampled client in a pool
    of that many worker processes, and aggregates the returned weights with
    ``average_weights``.

    Args:
        args: Configuration namespace; this function reads ``epochs``, ``frac``,
            ``num_users`` and ``STAGE``.
        model_in_path_root: Forwarded unchanged to ``client_train``.
        model_out_path: Forwarded unchanged to ``client_train``. When
            ``args.STAGE == 2`` the initial toggling-network weights are read
            from ``model_out_path + "_global/final/"`` before the first round.
        train_dataset: Forwarded unchanged to ``client_train``.
        test_dataset: Forwarded unchanged to ``client_train``.

    Returns:
        For ``args.STAGE == 0`` a two-element list
        ``[averaged_encoder_weights, averaged_decoder_weights]``; otherwise the
        single result of ``average_weights``. If ``args.epochs`` is 0 no round
        runs and ``None`` is returned.

    Raises:
        Exception: Any error raised while loading the initial weights or inside
            a ``client_train`` worker is re-raised to the caller.
    """
    train_loss = []                                                                 # per-round average loss (collected, not returned)
    global_weights = None                                                           # initial global_weights

    for epoch in tqdm(range(args.epochs)):                                          # train for given global rounds
        print(f'\n | Global Training Round : {epoch+1} |\n')                        # print current round

        m = max(int(args.frac * args.num_users), 1)                                 # num of clients to train, min:1
        idxs_users = np.random.choice(range(args.num_users), m, replace=False)      # select by client_id

        if (epoch == 0) and (args.STAGE == 2):                                      # start from global model to train toggling network
            # get toggling_network weights from model in model_out_path + "_global/final/"
            global_weights = get_model_weight(args=args, source_path=model_out_path + "_global/final/",
                                              network="toggling_network")

        # One argument tuple per sampled client; client_train returns (w, loss).
        client_jobs = [(args, model_in_path_root, model_out_path, train_dataset, test_dataset, idx,
                        epoch, global_weights) for idx in idxs_users]

        try:
            # The context manager tears the pool down after every round, so
            # worker processes do not accumulate across rounds. starmap blocks
            # until all clients finish and re-raises a worker's exception.
            with multiprocessing.Pool(processes=m) as pool:
                results = pool.starmap(client_train, client_jobs)
        except Exception as e:
            print(f"An error occurred while using starmap to run client_train: {str(e)}")
            raise                                                                   # never aggregate missing or stale results

        local_losses = [loss for _, loss in results]                                # losses of training clients of this round

        # aggregate weights
        if args.STAGE == 0:                                                         # train ASR: w is (encoder, decoder)
            global_weights = [
                average_weights([copy.deepcopy(w[0]) for w, _ in results]),
                average_weights([copy.deepcopy(w[1]) for w, _ in results]),
            ]
        else:                                                                       # train AD classifier or toggling network
            global_weights = average_weights([copy.deepcopy(w) for w, _ in results])

        loss_avg = sum(local_losses) / len(local_losses)                            # average losses from participated client
        train_loss.append(loss_avg)                                                 # save loss for this round
    return global_weights

# FL stage 1: ASR & AD Classifier
def stage1_training(args, train_dataset, test_dataset):
    """FL stage 1: federated training of the ASR, then of the AD classifier.

    Sets ``args.STAGE`` to 0 and then to 1 (the namespace is mutated in place).
    After each phase the aggregated weights are written into a model through
    ``update_network_weight`` and saved with ``save_pretrained``:

    * ASR phase  -> ``args.model_out_path + "_FLASR_global/final"``
    * AD phase   -> ``args.model_out_path + "_FLAD_global/final"``

    Args:
        args: Configuration namespace; reads ``local_ep`` and ``model_out_path``.
        train_dataset: Passed through to ``FL_training_rounds``.
        test_dataset: Passed through to ``FL_training_rounds``.
    """
    saved_local_ep = args.local_ep                                                  # remember the requested number of local epochs

    # ------------------------------------------------------------------
    # Centralized training of the global ASR & AD classifier (disabled).
    # Kept for reference:
    #
    # args.local_ep = args.global_ep      # use number of global epoch for global model
    # args.STAGE = 0                      # train ASR first
    # centralized_training(args=args, model_in_path=args.pretrain_name, model_out_path=args.model_out_path+"_finetune",
    #                      train_dataset=train_dataset, test_dataset=test_dataset, epoch=0)
    #     # train from pretrain, final result in args.model_out_path + "_finetune" + "_global/final"
    # args.STAGE = 1                      # then train AD classifier
    # centralized_training(args=args, model_in_path=args.model_out_path+"_finetune_global/final/",
    #                      model_out_path=args.model_out_path, train_dataset=train_dataset, test_dataset=test_dataset, epoch=0)
    #     # train from final result from last line, final result in args.model_out_path + "_global/final"
    # ------------------------------------------------------------------

    args.local_ep = saved_local_ep                                                  # use the given number of local epoch
    datasets = {"train_dataset": train_dataset, "test_dataset": test_dataset}

    # --- Phase 1: federated ASR training ---
    args.STAGE = 0
    weights = FL_training_rounds(
        args=args,
        model_in_path_root=args.model_out_path,
        model_out_path=f"{args.model_out_path}_finetune",
        **datasets,
    )
    # Write the aggregated ASR weights into the model found at source_path.
    net = update_network_weight(
        args=args,
        source_path=f"{args.model_out_path}_global/final/",
        target_weight=weights,
        network="ASR",
    )
    net.save_pretrained(f"{args.model_out_path}_FLASR_global/final")

    # --- Phase 2: federated AD-classifier training ---
    args.STAGE = 1
    weights = FL_training_rounds(
        args=args,
        model_in_path_root=f"{args.model_out_path}_finetune",
        model_out_path=args.model_out_path,
        **datasets,
    )
    # Write the aggregated AD-classifier weights into the FL-ASR model saved above.
    net = update_network_weight(
        args=args,
        source_path=f"{args.model_out_path}_FLASR_global/final",
        target_weight=weights,
        network="AD",
    )
    net.save_pretrained(f"{args.model_out_path}_FLAD_global/final")
    
    
# FL stage 2: Toggling Network
def stage2_training(args, train_dataset, test_dataset):
    """FL stage 2: federated training of the toggling network.

    Runs ``FL_training_rounds`` from ``args.model_in_path`` to
    ``args.model_out_path``, writes the aggregated weights into the model at
    ``args.model_out_path + "_global/final"`` through ``update_network_weight``
    and saves the result to ``args.model_out_path + "_final_global/final"``.
    ``args.STAGE`` is not set here; the caller is expected to have set it.

    Args:
        args: Configuration namespace; reads ``local_ep``, ``model_in_path``
            and ``model_out_path``.
        train_dataset: Passed through to ``FL_training_rounds``.
        test_dataset: Passed through to ``FL_training_rounds``.
    """
    saved_local_ep = args.local_ep                                                  # remember the requested number of local epochs

    # ------------------------------------------------------------------
    # Centralized training of the global toggling network (disabled).
    # Kept for reference:
    #
    # args.local_ep = args.global_ep      # use number of global epoch for global model
    # centralized_training(args=args, model_in_path=args.model_in_path + "_FLAD_global/final/", model_out_path=args.model_out_path,
    #                      train_dataset=train_dataset, test_dataset=test_dataset, epoch=0)
    #     # train from model_in_path + "_FLAD_global/final/" (aggregated ASR & AD)
    #     # final result in args.model_out_path + "_global/final"
    # ------------------------------------------------------------------

    args.local_ep = saved_local_ep                                                  # use the given number of local epoch

    weights = FL_training_rounds(
        args=args,
        model_in_path_root=args.model_in_path,
        model_out_path=args.model_out_path,
        train_dataset=train_dataset,
        test_dataset=test_dataset,
    )

    # Write the aggregated toggling-network weights into the model at source_path.
    net = update_network_weight(
        args=args,
        source_path=f"{args.model_out_path}_global/final",
        target_weight=weights,
        network="toggling_network",
    )
    net.save_pretrained(f"{args.model_out_path}_final_global/final")

def extract_emb(args, train_dataset, test_dataset):
    """Build an ``ASRLocalUpdate`` for ``args.client_id`` and call ``extract_embs()``.

    ``args.client_id`` is used as-is when it equals ``"public"`` and is
    converted with ``int()`` otherwise. The model is read from
    ``args.model_in_path``; no output path is passed (``model_out_path=None``).

    Args:
        args: Configuration namespace; reads ``client_id`` and ``model_in_path``.
        train_dataset: Passed to ``ASRLocalUpdate`` as ``dataset``.
        test_dataset: Passed to ``ASRLocalUpdate`` as ``global_test_dataset``.
    """
    client = "public" if args.client_id == "public" else int(args.client_id)

    # Set up the updater for this client, then run the extraction.
    ASRLocalUpdate(
        args=args,
        dataset=train_dataset,
        global_test_dataset=test_dataset,
        client_id=client,
        model_in_path=args.model_in_path,
        model_out_path=None,
    ).extract_embs()
                                                                                      # from model_in_path model, update certain part using given weight
import argparse                                                          # get configuration


def build_arg_parser():
    """Return the ``argparse`` parser holding every option of this experiment.

    Building the parser has no side effects; the caller decides which argument
    list to parse.
    """
    parser = argparse.ArgumentParser()
    # federated arguments (Notation for the arguments followed from paper)
    parser.add_argument('--epochs', type=int, default=2,
                        help="number of rounds of training")
    parser.add_argument('--num_users', type=int, default=2,
                        help="number of users: K")
    parser.add_argument('--frac', type=float, default=1.0,
                        help='the fraction of clients: C')
    parser.add_argument('--local_ep', type=int, default=5,
                        help="the number of local epochs: E")
    # model arguments
    parser.add_argument('--model', type=str, default='data2vec', help='model name')
    parser.add_argument('--dataset', type=str, default='adress', help="name \
                    of dataset")
    parser.add_argument('--gpu', default=None, help="To use cuda, set \
                    to a specific GPU ID. Default set to use CPU.")
    # additional arguments
    parser.add_argument('--pretrain_name', type=str, default='facebook/data2vec-audio-large-960h', help="str used to load pretrain model")

    parser.add_argument('-lam', '--LAMBDA', type=float, default=0.5, help="Lambda for GRL")
    parser.add_argument('-st', '--STAGE', type=int, default=1, help="Current training stage")
    parser.add_argument('-fl_st', '--FL_STAGE', type=int, default=1, help="Current FL training stage")
    parser.add_argument('-GRL', '--GRL', action='store_true', default=False, help="True: GRL")
    parser.add_argument('-model_in', '--model_in_path', type=str, default="/home/FedASR/dacs/federated/save/data2vec-audio-large-960h_new1_recall_client0_round1/", help="Where the global model is saved")
    parser.add_argument('-model_out', '--model_out_path', type=str, default="/home/FedASR/dacs/federated/save/data2vec-audio-large-960h_new1_recall", help="Where to save the model")
    parser.add_argument('-log', '--log_path', type=str, default="wav2vec2-base-960h_linear_GRL.txt", help="name for the txt file")
    parser.add_argument('-csv', '--csv_path', type=str, default="wav2vec2-base-960h_GRL_0.5", help="name for the csv file")
    # 2023/01/08: loss type
    parser.add_argument('-ad_loss', '--AD_loss', type=str, default="cel", help="loss to use for AD classifier")
    # 2023/01/18: ckpt
    parser.add_argument('-ckpt', '--checkpoint', type=str, default=None, help="path to checkpoint")
    # 2023/02/13: TOGGLE_RATIO
    parser.add_argument('-toggle_rt', '--TOGGLE_RATIO', type=float, default=0, help="To toggle more or less")
    # 2023/02/15: GS_TAU, loss weight
    parser.add_argument('-gs_tau', '--GS_TAU', type=float, default=1, help="Tau for gumbel_softmax")
    parser.add_argument('-w_loss', '--W_LOSS', type=float, default=None, nargs='+', help="weight for HC and AD")
    # 2023/04/20
    parser.add_argument('-EXTRACT', '--EXTRACT', action='store_true', default=False, help="True: extract embs")
    parser.add_argument('-client_id', '--client_id', type=str, default="public", help="client_id: public, 0, or 1")
    # 2023/04/24
    parser.add_argument('--global_ep', type=int, default=30, help="number for global model")
    return parser


# The guard is required because the 'spawn' start method is used below: when
# this code is run as a script, each spawned worker re-imports the main module
# and would otherwise re-run the whole driver.
if __name__ == '__main__':
    print("I'm here")
    start_time = time.time()

    # define paths
    path_project = os.path.abspath('..')

    parser = build_arg_parser()
    # NOTE: args=[] makes argparse ignore the real command line, so only the
    # defaults above are used (convenient inside a notebook). Use
    # parser.parse_args() to honour command-line options.
    args = parser.parse_args(args=[])
    exp_details(args)                                                           # print out details based on configuration

    train_dataset, test_dataset = get_dataset(args)                             # get dataset
    # _, test_dataset = get_dataset(args)
    # train_dataset=test_dataset   # debug like this for now
    # NOTE: the visible device is hard-coded to GPU 1; --gpu is not consulted here.
    os.environ["CUDA_VISIBLE_DEVICES"] = "1"

    multiprocessing.set_start_method('spawn', force=True)
    if args.EXTRACT:
        extract_emb(args, train_dataset, test_dataset)
    elif args.FL_STAGE == 1:                                                    # Training
        print("| Start FL Training Stage 1|")
        stage1_training(args, train_dataset, test_dataset)                      # Train ASR & AD Classifier
        print("| FL Training Stage 1 Done|")
    elif args.FL_STAGE == 2:
        print("| Start FL Training Stage 2|")
        args.STAGE = 2
        stage2_training(args, train_dataset, test_dataset)                      # Train Toggling Network
        print("| FL Training Stage 2 Done|")
    else:
        # Previously an unsupported stage was skipped without any message.
        print(f"| Unknown FL_STAGE {args.FL_STAGE}: nothing was trained |")

    print('\n Total Run Time: {0:0.4f}'.format(time.time()-start_time))


2023-04-26 17:26:48.270880: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
  wer_metric = load_metric("wer")


I'm here

Experimental details:
    Model     : data2vec
    Global Rounds   : 2

    Current Stage   : 1

    Loss Type       : cel

    Federated parameters:
    Number of users    : 2
    Fraction of users  : 1.0


Loading cached processed dataset at /home/FedASR/dacs/federated/src/dataset/train/cache-7cd3d56ce65492d2_*_of_00010.arrow
Loading cached processed dataset at /home/FedASR/dacs/federated/src/dataset/test/cache-c5c36142e1357d21_*_of_00010.arrow


Load data from local...
Load data from local...
| Start FL Training Stage 1|


  0%|          | 0/2 [00:00<?, ?it/s]


 | Global Training Round : 1 |



2023-04-26 17:26:52.458644: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-04-26 17:26:52.683445: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
  wer_metric = load_metric("wer")
  wer_metric = load_metric("wer")


Generating client training set for client  0 ...
Generating client testing set for client  0 ...
Generating client training set for client  1 ...
Generating client testing set for client  1 ...


Loading cached processed dataset at /home/FedASR/dacs/federated/src/dataset/train/cache-61d2c01fd50e0c4f.arrow
Loading cached processed dataset at /home/FedASR/dacs/federated/src/dataset/test/cache-3266f2669213b7f9.arrow
Loading cached processed dataset at /home/FedASR/dacs/federated/src/dataset/train/cache-e8df470bb0b50834.arrow
Loading cached processed dataset at /home/FedASR/dacs/federated/src/dataset/test/cache-985eb06188176881.arrow


lambda =  tensor(0.5000)
Current stage: 0
lambda =  tensor(0.5000)
Current stage: 0


Using cuda_amp half precision backend
The following columns in the training set don't have a corresponding argument in `Data2VecAudioForCTC.forward` and have been ignored: array, path, text. If array, path, text are not expected by `Data2VecAudioForCTC.forward`,  you can safely ignore this message.


 | Client  1  ready to train! |


Using cuda_amp half precision backend
The following columns in the training set don't have a corresponding argument in `Data2VecAudioForCTC.forward` and have been ignored: text, path, array. If text, path, array are not expected by `Data2VecAudioForCTC.forward`,  you can safely ignore this message.


 | Client  0  ready to train! |


***** Running training *****
  Num examples = 23
  Num Epochs = 5
  Instantaneous batch size per device = 1
  Total train batch size (w. parallel, distributed & accumulation) = 1
  Gradient Accumulation steps = 1
  Total optimization steps = 115
  Number of trainable parameters = 309,101,600
  "`as_target_processor` is deprecated and will be removed in v5 of Transformers. You can process your "
  0%|          | 0/115 [00:00<?, ?it/s]
***** Running training *****
  Num examples = 28
  Num Epochs = 5
  Instantaneous batch size per device = 1
  Total train batch size (w. parallel, distributed & accumulation) = 1
  Gradient Accumulation steps = 1
  Total optimization steps = 140
  Number of trainable parameters = 309,101,600
  "`as_target_processor` is deprecated and will be removed in v5 of Transformers. You can process your "
  0%|          | 0/140 [00:00<?, ?it/s]



TypeError: div() got an unexpected keyword argument 'rounding_mode'