In [16]:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6

import argparse
import multiprocessing as mp
import os

import os
import copy
import time
import pickle
import numpy as np
from tqdm import tqdm

import torch
from tensorboardX import SummaryWriter

from options import args_parser
from update import LocalUpdate, test_inference, ASRLocalUpdate
from models import MLP, CNNMnist, CNNFashion_Mnist, CNNCifar, Data2VecAudioForCTC, DataCollatorCTCWithPadding
from utils import get_dataset, average_weights, exp_details

from transformers import Data2VecAudioConfig, Wav2Vec2Processor
from multiprocessing import Pool
from collections import OrderedDict
parser = argparse.ArgumentParser()

# federated arguments (Notation for the arguments followed from paper)
parser.add_argument('--epochs', type=int, default=2,
                    help="number of rounds of training")
parser.add_argument('--num_users', type=int, default=2,
                    help="number of users: K")
parser.add_argument('--frac', type=float, default=1.0,
                    help='the fraction of clients: C')
parser.add_argument('--local_ep', type=int, default=1,
                    help="the number of local epochs: E")

parser.add_argument('--model', type=str, default='data2vec', help='model name')


# other arguments
parser.add_argument('--dataset', type=str, default='adress', help="name \
                    of dataset") #cifar
#parser.add_argument('--num_classes', type=int, default=10, help="number \
#                    of classes")
parser.add_argument('--gpu', default=1, help="To use cuda, set \
                    to a specific GPU ID. Default set to use CPU.")

# additional arguments
parser.add_argument('--pretrain_name', type=str, default='facebook/data2vec-audio-large-960h', help="str used to load pretrain model")
parser.add_argument('-lam', '--LAMBDA', type=float, default=0.5, help="Lambda for GRL")
parser.add_argument('-st', '--STAGE', type=int, default=2, help="Current training stage")
parser.add_argument('-GRL', '--GRL', action='store_true', default=False, help="True: GRL")
parser.add_argument('-model_in', '--model_in_path', type=str, default="/mnt/Internal/FedASR/weitung/HuggingFace/Pretrain/saves/data2vec-audio-large-960h_new1_recall/final/", help="Where the model is saved")
parser.add_argument('-model_out', '--model_out_path', type=str, default="./save/data2vec-audio-large-960h_new2_recall_FL", help="Where to save the model")
parser.add_argument('-log', '--log_path', type=str, default="data2vec-audio-large-960h_new2_recall_FL.txt", help="name for the txt file")
# 2023/01/08: loss type
parser.add_argument('-ad_loss', '--AD_loss', type=str, default="recall", help="loss to use for AD classifier")
# 2023/01/18: ckpt
parser.add_argument('-ckpt', '--checkpoint', type=str, default=None, help="path to checkpoint")
# 2023/02/13: TOGGLE_RATIO
parser.add_argument('-toggle_rt', '--TOGGLE_RATIO', type=float, default=0, help="To toggle more or less")
# 2023/02/15: GS_TAU, loss weight
parser.add_argument('-gs_tau', '--GS_TAU', type=float, default=1, help="Tau for gumbel_softmax")
parser.add_argument('-w_loss', '--W_LOSS', type=float, default=None, nargs='+', help="weight for HC and AD")

args = parser.parse_args(args=[]) # for jupyter notebook


# os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3" # 或者其他你想要使用的 GPU 編號
lock = mp.Lock()
logger = SummaryWriter('../logs')

In [17]:

def client_train_old(args, train_dataset, logger,
                 test_dataset, idx, epoch, global_weights=None):                    # train function for each client
    print(" process PID", os.getpid(), " running")
    
    # BUILD MODEL for every process
    if args.model == 'data2vec':
        mask_time_prob = 0                                                          # change config to avoid training stopping
        config = Data2VecAudioConfig.from_pretrained(args.pretrain_name, mask_time_prob=mask_time_prob)
        print("load from ", args.model_in_path)
        lock.acquire()
        # print(" process PID", os.getpid(),"enter critical section")
        # model = Data2VecAudioForCTC.from_pretrained(args.model_in_path, config=config, args=args)
        # lock.release()
        # print(" process PID", os.getpid(),"exit critical section")
        try:
            # 在這裡執行 critical section 的程式碼
            print(" process PID", os.getpid(),"enter critical section")
            model = Data2VecAudioForCTC.from_pretrained(args.model_in_path, config=config, args=args)
            pass
        finally:
            # 在離開 critical section 後釋放鎖
            lock.release()
            print(" process PID", os.getpid(),"exit critical section")
        print("model loaded")                                                       # load/initialize global model
        model.config.ctc_zero_infinity = True                                       # to avoid inf values

        global_model = copy.deepcopy(model.arbitrator)                              # only has global toggling network
        if global_weights != None:                                                  # if given global_weights
            global_model.load_state_dict(global_weights)                            # load it
        #else:
        #    # copy weights
        #    global_weights = copy.deepcopy(global_model.state_dict())                       # save global weight
        processor = Wav2Vec2Processor.from_pretrained(args.pretrain_name)
        data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)
    else:
        exit('Error: unrecognized model')

    # Set the model to train and send it to device.
    device = 'cuda' if args.gpu else 'cpu'
    global_model.to(device)
    global_model.train()
    #print(global_model)

    ####################
    # 'use client_id generate sub-dataset' to be done
    ####################
    #print("call ASRLocalUpdate")
    local_model = ASRLocalUpdate(args=args, dataset=train_dataset, logger=logger,
                        data_collator=data_collator, global_test_dataset=test_dataset, 
                        processor=processor, client_id=idx)
                                                                                    # initial dataset of current client
    ####################
    # 'use client_id load local model' to be done
    # 'save model in final round' to be done
    ####################
    #print("perform update_weight")
    w, loss = local_model.update_weights(
        global_arbitrator=copy.deepcopy(global_model), global_round=epoch)          # from global model to train
    
    #send_end.send([w, loss])                                                        # save model weights and average round loss
    #return_dict[str(idx)] = [w, loss]
    print("PID {} Getting ".format(os.getpid()), "Done")
    return w, loss

In [18]:

def client_train_old2(args, train_dataset, logger,
                 test_dataset, idx, epoch, global_weights=None, result_queue=None):                    # train function for each client
    print(" process PID", os.getpid(), " running")
    
    # BUILD MODEL for every process
    if args.model == 'data2vec':
        mask_time_prob = 0                                                          # change config to avoid training stopping
        config = Data2VecAudioConfig.from_pretrained(args.pretrain_name, mask_time_prob=mask_time_prob)
        print("load from ", args.model_in_path)
        lock.acquire()
        # print(" process PID", os.getpid(),"enter critical section")
        # model = Data2VecAudioForCTC.from_pretrained(args.model_in_path, config=config, args=args)
        # lock.release()
        # print(" process PID", os.getpid(),"exit critical section")
        try:
            # 在這裡執行 critical section 的程式碼
            print(" process PID", os.getpid(),"enter critical section")
            model = Data2VecAudioForCTC.from_pretrained(args.model_in_path, config=config, args=args)
            pass
        finally:
            # 在離開 critical section 後釋放鎖
            lock.release()
            print(" process PID", os.getpid(),"exit critical section")
        print("model loaded")                                                       # load/initialize global model
        model.config.ctc_zero_infinity = True                                       # to avoid inf values

        global_model = copy.deepcopy(model.arbitrator)                              # only has global toggling network
        if global_weights != None:                                                  # if given global_weights
            global_model.load_state_dict(global_weights)                            # load it
        #else:
        #    # copy weights
        #    global_weights = copy.deepcopy(global_model.state_dict())                       # save global weight
        processor = Wav2Vec2Processor.from_pretrained(args.pretrain_name)
        data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)
    else:
        exit('Error: unrecognized model')

    # Set the model to train and send it to device.
    device = 'cuda' if args.gpu else 'cpu'
    global_model.to(device)
    global_model.train()
    #print(global_model)

    ####################
    # 'use client_id generate sub-dataset' to be done
    ####################
    #print("call ASRLocalUpdate")
    local_model = ASRLocalUpdate(args=args, dataset=train_dataset, logger=logger,
                        data_collator=data_collator, global_test_dataset=test_dataset, 
                        processor=processor, client_id=idx)
                                                                                    # initial dataset of current client
    ####################
    # 'use client_id load local model' to be done
    # 'save model in final round' to be done
    ####################
    #print("perform update_weight")
    w, loss = local_model.update_weights(
        global_arbitrator=copy.deepcopy(global_model), global_round=epoch)          # from global model to train
    
    # 將進程的結果存放在 result_queue 中
    if result_queue is not None:
        result_queue.put((w, loss))
    #send_end.send([w, loss])                                                        # save model weights and average round loss
    #return_dict[str(idx)] = [w, loss]
    print("PID {} Getting ".format(os.getpid()), "Done")
    return w, loss

In [19]:
def client_train(args, train_dataset, logger,
                 test_dataset, idx, epoch, global_weights=None, result_queue=None):                    
    print("process PID", os.getpid(), "running")
    # create lock
    lock = mp.Lock()
    # BUILD MODEL for every process
    if args.model == 'data2vec':
        mask_time_prob = 0                                                          
        config = Data2VecAudioConfig.from_pretrained(args.pretrain_name, mask_time_prob=mask_time_prob)
        print("load from ", args.model_in_path)
        with lock:
            model = Data2VecAudioForCTC.from_pretrained(args.model_in_path, config=config, args=args)
        print("model loaded")                                                       
        model.config.ctc_zero_infinity = True                                       

        global_model = copy.deepcopy(model.arbitrator)                              
        if global_weights != None:                                                  
            global_model.load_state_dict(global_weights)                            
        processor = Wav2Vec2Processor.from_pretrained(args.pretrain_name)
        data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)
    else:
        exit('Error: unrecognized model')

    # Set the model to train and send it to device.
    device = 'cuda' if args.gpu else 'cpu'
    global_model.to(device)
    global_model.train()

    # generate sub-dataset
    local_dataset = generate_local_dataset(train_dataset, idx, args.num_users)
    local_model = ASRLocalUpdate(args=args, dataset=local_dataset, logger=logger,
                        data_collator=data_collator, global_test_dataset=test_dataset, 
                        processor=processor, client_id=idx)

    try:
        w, loss = local_model.update_weights(
            global_arbitrator=copy.deepcopy(global_model), global_round=epoch)
    except Exception as e:
        print(f"An error occurred while running local_model.update_weights(): {str(e)}")
    else:
        # save model weights and average round loss
        if result_queue is not None:
            result_queue.put((w, loss))
        print("process PID {} done".format(os.getpid()))
    return w, loss


# 拆解上面那個client train

In [20]:
def client_train_toy(args, train_dataset, logger,
                 test_dataset, idx, epoch, global_weights=None, result_queue=None):                    
    print("process PID", os.getpid(), "running")
    sys.stdout.flush()
    # create lock
    lock = mp.Lock()
    # BUILD MODEL for every process
    if args.model == 'data2vec':
        mask_time_prob = 0                                                          
        config = Data2VecAudioConfig.from_pretrained(args.pretrain_name, mask_time_prob=mask_time_prob)
        print("load from ", args.model_in_path)
        with lock:
            model = Data2VecAudioForCTC.from_pretrained(args.model_in_path, config=config, args=args)
        print("model loaded")                                                       
        model.config.ctc_zero_infinity = True                                       

        global_model = copy.deepcopy(model.arbitrator)                              
        if global_weights != None:                                                  
            global_model.load_state_dict(global_weights)                            
        processor = Wav2Vec2Processor.from_pretrained(args.pretrain_name)
        data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)
    else:
        exit('Error: unrecognized model')
    
    print("process PID {} done".format(os.getpid()))
    sys.stdout.flush()
    return 1,2
    # # Set the model to train and send it to device.
    # device = 'cuda' if args.gpu else 'cpu'
    # global_model.to(device)
    # global_model.train()

    # # generate sub-dataset
    # local_dataset = generate_local_dataset(train_dataset, idx, args.num_users)
    # local_model = ASRLocalUpdate(args=args, dataset=local_dataset, logger=logger,
    #                     data_collator=data_collator, global_test_dataset=test_dataset, 
    #                     processor=processor, client_id=idx)

    # try:
    #     w, loss = local_model.update_weights(
    #         global_arbitrator=copy.deepcopy(global_model), global_round=epoch)
    # except Exception as e:
    #     print(f"An error occurred while running local_model.update_weights(): {str(e)}")
    # else:
    #     # save model weights and average round loss
    #     if result_queue is not None:
    #         result_queue.put((w, loss))
    #     print("process PID {} done".format(os.getpid()))
    # return w, loss


In [21]:
train_loss, test_wer = [], []
val_acc_list, net_list = [], []
cv_loss, cv_acc = [], []
print_every = 2
val_loss_pre, counter = 0, 0
global_weights = None                                                           # initial global_weights
# 創建進程間共享的 Queue 對象
manager = mp.Manager()
result_queue = manager.Queue()

train_dataset, test_dataset, user_groups = get_dataset(args)
for epoch in range(args.epochs):                                          # train for given global rounds
    #local_weights, local_losses = [], []                                        # weights and losses of training clients of this round
    print(f'\n | Global Training Round : {epoch+1} |\n')                        # print current round

    m = max(int(args.frac * args.num_users), 1)                                 # num of clients to train, min:1
    idxs_users = np.random.choice(range(args.num_users), m, replace=False)      # select by client_id

    pool = Pool(m)
    # final_result = pool.starmap(client_train, [(args, train_dataset, logger,
    #                     test_dataset, idx, epoch, global_weights) for idx in idxs_users])
    print("Training clients:", idxs_users)
    # final_result = pool.starmap_async(
    #     client_train, [(args, train_dataset, logger,
    #                     test_dataset, idx, epoch, global_weights, result_queue)
    #                    for idx in idxs_users])
    try:
        final_result = pool.starmap_async(
            client_train, [(args, train_dataset, logger,
                            test_dataset, idx, epoch, global_weights, result_queue)
                        for idx in idxs_users])
    except Exception as e:
        print(f"An error occurred while running local_model.update_weights(): {str(e)}")



    pool.close()
    pool.join()

    local_weights = []
    local_losses = []
    while not result_queue.empty():
        w, loss = result_queue.get()
        local_weights.append(w)
        local_losses.append(loss)
    # 舊的
    # for idx in range(len(final_result)):
    #     w, loss = final_result[idx]
    #     local_weights.append(w)
    #     local_losses.append(loss)

    print("Final result:", final_result)
    print("Local weights:", local_weights)


    # 這邊等修好了再解除comment
    # print("local weights: ", local_weights)
    # # get global weights by averaging local weights
    # global_weights = average_weights(local_weights)
    # print("global wegiths: ", global_weights)

    # # update global weights
    # #global_model.load_state_dict(global_weights)

    # loss_avg = sum(local_losses) / len(local_losses)                # average losses from participated client
    # train_loss.append(loss_avg)                                     # save loss for this round

Loading cached processed dataset at /home/FedASR/dacs/federated/dataset/train/cache-b6f4c0d2143105d5_*_of_00010.arrow
Loading cached processed dataset at /home/FedASR/dacs/federated/dataset/test/cache-f07dc6d726972452_*_of_00010.arrow


Load data from local...
Load data from local...

 | Global Training Round : 1 |

Training clients: [1 0]
Final result: <multiprocessing.pool.MapResult object at 0x7fd8b543c590>
Local weights: []

 | Global Training Round : 2 |

Training clients: [0 1]
Final result: <multiprocessing.pool.MapResult object at 0x7fd8b5446e90>
Local weights: []


# 拆解上面那個coding block

# 替換成pool.map寫法

In [27]:
import concurrent.futures

train_loss, test_wer = [], []
val_acc_list, net_list = [], []
cv_loss, cv_acc = [], []
print_every = 2
val_loss_pre, counter = 0, 0
global_weights = None                                                           # initial global_weights
# 創建進程間共享的 Queue 對象
# manager = mp.Manager()
# result_queue = manager.Queue()

train_dataset, test_dataset, user_groups = get_dataset(args)

def client_train_map(args_train):                    
    print("process PID", os.getpid(), "running")
    # args, train_dataset, logger, test_dataset, idx, epoch, global_weights, result_queue = args_train
    args, train_dataset, logger, test_dataset, idx, epoch, global_weights = args_train
    sys.stdout.flush()
    # create lock
    lock = mp.Lock()
    # BUILD MODEL for every process
    if args.model == 'data2vec':
        mask_time_prob = 0                                                          
        config = Data2VecAudioConfig.from_pretrained(args.pretrain_name, mask_time_prob=mask_time_prob)
        print("load from ", args.model_in_path)
        with lock:
            model = Data2VecAudioForCTC.from_pretrained(args.model_in_path, config=config, args=args)
        print("model loaded")                                                       
        model.config.ctc_zero_infinity = True                                       

        global_model = copy.deepcopy(model.arbitrator)                              
        if global_weights != None:                                                  
            global_model.load_state_dict(global_weights)                            
        processor = Wav2Vec2Processor.from_pretrained(args.pretrain_name)
        data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)
    else:
        exit('Error: unrecognized model')
    
    print("process PID {} done".format(os.getpid()))
    sys.stdout.flush()
    # result_queue.put((1, 2))
    return 1,2





for epoch in range(args.epochs):                                          # train for given global rounds
    print(f'\n | Global Training Round : {epoch+1} |\n')                        # print current round
    m = max(int(args.frac * args.num_users), 1)                                 # num of clients to train, min:1
    idxs_users = np.random.choice(range(args.num_users), m, replace=False)      # select by client_id
    
    
    # pool = Pool(m)
    # print("Training clients:", idxs_users)
    # try:
    #     # final_result = pool.starmap_async(
    #     #     client_train, [(args, train_dataset, logger,
    #     #                     test_dataset, idx, epoch, global_weights, result_queue)
    #     #                 for idx in idxs_users])
    #     print("stepping in client_train_map")
    #     # final_result = pool.starmap_async(
    #     #     client_train_toy, [(args, train_dataset, logger,
    #     #                     test_dataset, idx, epoch, global_weights, result_queue)
    #     #                 for idx in idxs_users])
    #     final_result = pool.map(client_train_map, [(args, train_dataset, logger,
    #                                     test_dataset, idx, epoch, global_weights, result_queue)
    #                                   for idx in idxs_users])
    #     # while not final_result.ready():
    #     #     time.sleep(0.1)
    #     # final_result_content = final_result.get()
    #     # print("final_result_content: ",final_result_content)
    # except Exception as e:
    #     print(f"An error occurred while running local_model.update_weights(): {str(e)}")
    # pool.close()
    # pool.join()

    # 換另一種call法
    # with concurrent.futures.ProcessPoolExecutor() as executor:
    #     try:
    #         print("stepping in client_train_map")
    #         # final_result = list(executor.map(client_train_map, [(args, train_dataset, logger,
    #         #                                 test_dataset, idx, epoch, global_weights, result_queue)
    #         #                               for idx in idxs_users]))
    #         final_result = list(executor.map(client_train_map, [(args, train_dataset, logger,
    #                                         test_dataset, idx, epoch, global_weights)  # removed result_queue
    #                                       for idx in idxs_users]))
    #     except Exception as e:
    #         print(f"An error occurred while running local_model.update_weights(): {str(e)}")

    # 換另2種call法
    # with mp.Pool(m) as pool:
    #     try:
    #         print("stepping in client_train_map")
    #         final_result = pool.map(client_train_map, [(args, train_dataset, logger,
    #                                         test_dataset, idx, epoch, global_weights)
    #                                       for idx in idxs_users])
    #     except Exception as e:
    #         print(f"An error occurred while running local_model.update_weights(): {str(e)}")

    # 換另3種call法
    with mp.Pool(m) as pool:
        try:
            print("stepping in client_train_map")
            final_result = pool.imap_unordered(client_train_map, [(args, train_dataset, logger,
                                            test_dataset, idx, epoch, global_weights)
                                          for idx in idxs_users])
        except Exception as e:
            print(f"An error occurred while running local_model.update_weights(): {str(e)}")



    local_weights = []
    local_losses = []
    for w, loss in final_result:
        local_weights.append(w)
        local_losses.append(loss)
    # while not result_queue.empty():
    #     w, loss = result_queue.get()
    #     local_weights.append(w)
    #     local_losses.append(loss)
    print("Final result:", final_result)
    print("Local weights:", local_weights)

Loading cached processed dataset at /home/FedASR/dacs/federated/dataset/train/cache-b6f4c0d2143105d5_*_of_00010.arrow
Loading cached processed dataset at /home/FedASR/dacs/federated/dataset/test/cache-f07dc6d726972452_*_of_00010.arrow


Load data from local...
Load data from local...

 | Global Training Round : 1 |

stepping in client_train_map


RuntimeError: Queue objects should only be shared between processes through inheritance

In [28]:
final_result

<multiprocessing.pool.IMapUnorderedIterator at 0x7fd8b540c990>