In [None]:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6

import argparse
import multiprocessing as mp
import os

import os
import copy
import time
import pickle
import numpy as np
from tqdm import tqdm

import torch
from tensorboardX import SummaryWriter

from options import args_parser
from update import LocalUpdate, test_inference, ASRLocalUpdate
from models import MLP, CNNMnist, CNNFashion_Mnist, CNNCifar, Data2VecAudioForCTC, DataCollatorCTCWithPadding
from utils import get_dataset, average_weights, exp_details

from transformers import Data2VecAudioConfig, Wav2Vec2Processor
from multiprocessing import Pool
from collections import OrderedDict
# Command-line options for the federated ASR experiment.
# The parser is parsed against an explicit empty argument list (see the last
# line of this block), so every option always takes its default value here.
parser = argparse.ArgumentParser()

# federated arguments (Notation for the arguments followed from paper)
parser.add_argument('--epochs', type=int, default=2,
                    help="number of rounds of training")
parser.add_argument('--num_users', type=int, default=2,
                    help="number of users: K")
parser.add_argument('--frac', type=float, default=1.0,
                    help='the fraction of clients: C')
parser.add_argument('--local_ep', type=int, default=1,
                    help="the number of local epochs: E")

parser.add_argument('--model', type=str, default='data2vec', help='model name')


# other arguments
parser.add_argument('--dataset', type=str, default='adress', help="name \
                    of dataset") #cifar
#parser.add_argument('--num_classes', type=int, default=10, help="number \
#                    of classes")
# The previous help text claimed the default was CPU, but the default value
# is 1, which is truthy and therefore selects cuda.
parser.add_argument('--gpu', default=1,
                    help="To use cuda, set to a specific GPU ID. "
                         "Default: 1 (cuda is used).")

# additional arguments
parser.add_argument('--pretrain_name', type=str, default='facebook/data2vec-audio-large-960h', help="str used to load pretrain model")
parser.add_argument('-lam', '--LAMBDA', type=float, default=0.5, help="Lambda for GRL")
parser.add_argument('-st', '--STAGE', type=int, default=2, help="Current training stage")
parser.add_argument('-GRL', '--GRL', action='store_true', default=False, help="True: GRL")
parser.add_argument('-model_in', '--model_in_path', type=str, default="/mnt/Internal/FedASR/weitung/HuggingFace/Pretrain/saves/data2vec-audio-large-960h_new1_recall/final/", help="Where the model is saved")
parser.add_argument('-model_out', '--model_out_path', type=str, default="./save/data2vec-audio-large-960h_new2_recall_FL", help="Where to save the model")
parser.add_argument('-log', '--log_path', type=str, default="data2vec-audio-large-960h_new2_recall_FL.txt", help="name for the txt file")
# 2023/01/08: loss type
parser.add_argument('-ad_loss', '--AD_loss', type=str, default="recall", help="loss to use for AD classifier")
# 2023/01/18: ckpt
parser.add_argument('-ckpt', '--checkpoint', type=str, default=None, help="path to checkpoint")
# 2023/02/13: TOGGLE_RATIO
parser.add_argument('-toggle_rt', '--TOGGLE_RATIO', type=float, default=0, help="To toggle more or less")
# 2023/02/15: GS_TAU, loss weight
parser.add_argument('-gs_tau', '--GS_TAU', type=float, default=1, help="Tau for gumbel_softmax")
parser.add_argument('-w_loss', '--W_LOSS', type=float, default=None, nargs='+', help="weight for HC and AD")

# Parse an explicit empty list instead of sys.argv so the notebook kernel's own
# command-line arguments are not fed to this parser.
args = parser.parse_args(args=[]) # for jupyter notebook


# os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# Set CUDA_VISIBLE_DEVICES to GPUs 0-3 for this process.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3" # or whichever GPU IDs you want to use
lock = mp.Lock()
logger = SummaryWriter('../logs')

# Step 1: the logger was removed from the arguments passed to the workers.
# final_result = pool.starmap_async(
#     client_train, [(args, train_dataset, logger,
#                     test_dataset, idx, epoch, global_weights)
#                 for idx in idxs_users])
# Passing the logger raises: RuntimeError: Queue objects should only be shared between processes through inheritance
# NOTE(review): `pool`, `client_train`, `train_dataset`, `test_dataset`,
# `epoch`, `global_weights` and `idxs_users` are not defined anywhere above
# this point; this call appears to rely on state left by another notebook
# cell -- confirm the intended execution order.
final_result = pool.starmap_async(
    client_train, [(args, train_dataset, None,
                    test_dataset, idx, epoch, global_weights)
                for idx in idxs_users])

In [None]:
import os
import multiprocessing



def client_train(args, train_dataset, logger,
                 test_dataset, idx, epoch, global_weights=None):
    """Run one round of local training for a single client.

    Loads the model from ``args.model_in_path``, deep-copies its
    ``arbitrator`` sub-module to act as the global model (optionally
    overwriting its parameters with ``global_weights``), and passes a copy of
    it to ``ASRLocalUpdate.update_weights``.

    Args:
        args: Parsed options. This function reads ``model``, ``pretrain_name``,
            ``model_in_path`` and ``gpu``, and forwards ``args`` to the model
            constructor and to ``ASRLocalUpdate``.
        train_dataset: Forwarded to ``ASRLocalUpdate`` as ``dataset``.
        logger: Forwarded to ``ASRLocalUpdate`` as ``logger``.
        test_dataset: Forwarded to ``ASRLocalUpdate`` as
            ``global_test_dataset``.
        idx: Forwarded to ``ASRLocalUpdate`` as ``client_id``.
        epoch: Forwarded to ``update_weights`` as ``global_round``.
        global_weights: Optional state dict loaded into the arbitrator copy
            before training. When ``None`` the weights loaded from
            ``args.model_in_path`` are used as-is.

    Returns:
        The ``(w, loss)`` pair returned by ``update_weights`` (presumably the
        trained weights and the average round loss -- verify against
        ``ASRLocalUpdate``).

    Raises:
        ValueError: If ``args.model`` is not ``'data2vec'``.
    """
    if args.model != 'data2vec':
        # Raise an ordinary exception rather than calling exit(): SystemExit
        # raised inside a multiprocessing.Pool worker terminates the worker
        # without reporting a task result, so the parent can wait forever.
        raise ValueError('Error: unrecognized model')

    mask_time_prob = 0                                                          # change config to avoid training stopping
    config = Data2VecAudioConfig.from_pretrained(args.pretrain_name, mask_time_prob=mask_time_prob)
    print("load from ", args.model_in_path)
    model = Data2VecAudioForCTC.from_pretrained(args.model_in_path, config=config, args=args)
    print("model loaded")                                                       # load/initialize global model
    model.config.ctc_zero_infinity = True                                       # to avoid inf values

    global_model = copy.deepcopy(model.arbitrator)                              # only has global toggling network
    if global_weights is not None:                                              # if given global_weights
        global_model.load_state_dict(global_weights)                            # load it
    processor = Wav2Vec2Processor.from_pretrained(args.pretrain_name)
    data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)

    # Set the model to train and send it to device.
    device = 'cuda' if args.gpu else 'cpu'
    global_model.to(device)
    global_model.train()

    ####################
    # 'use client_id generate sub-dataset' to be done
    ####################
    local_model = ASRLocalUpdate(args=args, dataset=train_dataset, logger=logger,
                        data_collator=data_collator, global_test_dataset=test_dataset,
                        processor=processor, client_id=idx)
                                                                                # initial dataset of current client
    ####################
    # 'use client_id load local model' to be done
    # 'save model in final round' to be done
    ####################
    w, loss = local_model.update_weights(
        global_arbitrator=copy.deepcopy(global_model), global_round=epoch)      # from global model to train

    print("PID {} Getting ".format(os.getpid()), "Done")
    return w, loss
        
# Bookkeeping for the federated training rounds.
train_loss, test_wer = [], []
val_acc_list, net_list = [], []
cv_loss, cv_acc = [], []
print_every = 2
val_loss_pre, counter = 0, 0
global_weights = None                                                           # initial global_weights

train_dataset, test_dataset, user_groups = get_dataset(args)
if __name__ == "__main__":
    # One iteration per communication round; args.epochs defaults to 2.
    for epoch in range(args.epochs):
        m = max(int(args.frac * args.num_users), 1)                                 # num of clients to train, min:1
        idxs_users = np.random.choice(range(args.num_users), m, replace=False)      # select by client_id

        # The logger slot is passed as None: passing the SummaryWriter raises
        # "RuntimeError: Queue objects should only be shared between processes
        # through inheritance".
        client_args = [(args, train_dataset, None,
                        test_dataset, idx, epoch, global_weights)
                       for idx in idxs_users]

        # The context manager shuts the pool down at the end of every round so
        # worker processes do not accumulate across rounds.
        with multiprocessing.Pool(processes=m) as pool:
            try:
                # Blocks until every client has finished; an exception raised
                # in a worker is re-raised here.
                results = pool.starmap(client_train, client_args)
            except Exception as e:
                print(f"An error occurred while running local_model.update_weights(): {str(e)}")
                # Nothing can be aggregated without the client results.
                raise

        # Each result is the (w, loss) pair returned by client_train.
        local_weights = [w for w, _ in results]
        local_losses = [loss for _, loss in results]

        print("local weights: ", local_weights)
        # get global weights by averaging local weights
        global_weights = average_weights(local_weights)
        print("global wegiths: ", global_weights)

        # update global weights
        #global_model.load_state_dict(global_weights)

        loss_avg = sum(local_losses) / len(local_losses)                # average losses from participated client
        train_loss.append(loss_avg)                                     # save loss for this round
        print("All results done")


In [None]:
# import os
# import multiprocessing

# def client_train(args, train_dataset, logger,
#                  test_dataset, idx, epoch, global_weights=None, result_queue=None):
#     mask_time_prob = 0                                                          
#     config = Data2VecAudioConfig.from_pretrained(args.pretrain_name, mask_time_prob=mask_time_prob)
#     print("load from ", args.model_in_path)
#     # with lock:
#     #     model = Data2VecAudioForCTC.from_pretrained(args.model_in_path, config=config, args=args)
#     model = Data2VecAudioForCTC.from_pretrained(args.model_in_path, config=config, args=args)
#     print("model loaded")                                                       
#     # model.config.ctc_zero_infinity = True 
#     # global_model = copy.deepcopy(model.arbitrator)                              
#     # if global_weights != None:                                                  
#     #     global_model.load_state_dict(global_weights)                            
#     # processor = Wav2Vec2Processor.from_pretrained(args.pretrain_name)
#     # data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)
#     return 0
        
# train_loss, test_wer = [], []
# val_acc_list, net_list = [], []
# cv_loss, cv_acc = [], []
# print_every = 2
# val_loss_pre, counter = 0, 0
# global_weights = None                                                           # initial global_weights
# # Create a Queue object shared between processes
# # manager = mp.Manager()
# # result_queue = manager.Queue()

# train_dataset, test_dataset, user_groups = get_dataset(args)
# if __name__ == "__main__":
#     # List of files to read
#     files = ["file1.txt", "file2.txt", "file3.txt"]

#     # Create the process pool
#     for epoch in range(2):
        
#         m = max(int(args.frac * args.num_users), 1)                                 # num of clients to train, min:1
#         idxs_users = np.random.choice(range(args.num_users), m, replace=False)      # select by client_id
#         pool = multiprocessing.Pool(processes=m)
        
#         # Read the files in parallel using the process pool
#         # results = pool.map(read_file, files)
#         try:
#             final_result = pool.starmap_async(
#                 client_train, [(args, train_dataset, logger,
#                                 test_dataset, idx, epoch, global_weights)
#                             for idx in idxs_users])

#         except Exception as e:
#             print(f"An error occurred while running local_model.update_weights(): {str(e)}")
#         finally:
#             final_result.wait()
#             results = final_result.get()
#         # Print the results
#         for result in results:
#             print(result)
#     print("All results done")
