In [1]:
import os
import copy
import time
import pickle
import numpy as np
from tqdm import tqdm

import torch
from tensorboardX import SummaryWriter

from options import args_parser
from update import LocalUpdate, test_inference, ASRLocalUpdate
from models import MLP, CNNMnist, CNNFashion_Mnist, CNNCifar, Data2VecAudioForCTC, DataCollatorCTCWithPadding
from utils import get_dataset, average_weights, exp_details

from transformers import Data2VecAudioConfig, Wav2Vec2Processor
from multiprocessing import Pool
from collections import OrderedDict

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6

import argparse



parser = argparse.ArgumentParser()

# federated arguments (Notation for the arguments followed from paper)
parser.add_argument('--epochs', type=int, default=2,
                    help="number of rounds of training")
parser.add_argument('--num_users', type=int, default=2,
                    help="number of users: K")
parser.add_argument('--frac', type=float, default=1.0,
                    help='the fraction of clients: C')
parser.add_argument('--local_ep', type=int, default=1,
                    help="the number of local epochs: E")
#parser.add_argument('--local_bs', type=int, default=1,
#                    help="local batch size: B")
#parser.add_argument('--lr', type=float, default=0.01,
#                    help='learning rate')
#parser.add_argument('--momentum', type=float, default=0.5,
#                    help='SGD momentum (default: 0.5)')

# model arguments
parser.add_argument('--model', type=str, default='data2vec', help='model name')
#parser.add_argument('--kernel_num', type=int, default=9,
#                    help='number of each kind of kernel')
#parser.add_argument('--kernel_sizes', type=str, default='3,4,5',
#                    help='comma-separated kernel size to \
#                    use for convolution')
#parser.add_argument('--num_channels', type=int, default=1, help="number \
#                    of channels of imgs")
#parser.add_argument('--norm', type=str, default='batch_norm',
#                    help="batch_norm, layer_norm, or None")
#parser.add_argument('--num_filters', type=int, default=32,
#                    help="number of filters for conv nets -- 32 for \
#                    mini-imagenet, 64 for omiglot.")
#parser.add_argument('--max_pool', type=str, default='True',
#                    help="Whether use max pooling rather than \
#                    strided convolutions")

# other arguments
parser.add_argument('--dataset', type=str, default='adress', help="name \
                    of dataset") #cifar
#parser.add_argument('--num_classes', type=int, default=10, help="number \
#                    of classes")
parser.add_argument('--gpu', default=1, help="To use cuda, set \
                    to a specific GPU ID. Default set to use CPU.")
#parser.add_argument('--optimizer', type=str, default='sgd', help="type \
#                    of optimizer")
#parser.add_argument('--iid', type=int, default=1,
#                    help='Default set to IID. Set to 0 for non-IID.')
#parser.add_argument('--unequal', type=int, default=0,
#                    help='whether to use unequal data splits for  \
#                    non-i.i.d setting (use 0 for equal splits)')
#parser.add_argument('--stopping_rounds', type=int, default=10,
#                    help='rounds of early stopping')
#parser.add_argument('--verbose', type=int, default=1, help='verbose')
#parser.add_argument('--seed', type=int, default=1, help='random seed')

# additional arguments
parser.add_argument('--pretrain_name', type=str, default='facebook/data2vec-audio-large-960h', help="str used to load pretrain model")
parser.add_argument('-lam', '--LAMBDA', type=float, default=0.5, help="Lambda for GRL")
parser.add_argument('-st', '--STAGE', type=int, default=2, help="Current training stage")
parser.add_argument('-GRL', '--GRL', action='store_true', default=False, help="True: GRL")
parser.add_argument('-model_in', '--model_in_path', type=str, default="/mnt/Internal/FedASR/weitung/HuggingFace/Pretrain/saves/data2vec-audio-large-960h_new1_recall/final/", help="Where the model is saved")
parser.add_argument('-model_out', '--model_out_path', type=str, default="./save/data2vec-audio-large-960h_new2_recall_FL", help="Where to save the model")
parser.add_argument('-log', '--log_path', type=str, default="data2vec-audio-large-960h_new2_recall_FL.txt", help="name for the txt file")
# 2023/01/08: loss type
parser.add_argument('-ad_loss', '--AD_loss', type=str, default="recall", help="loss to use for AD classifier")
# 2023/01/18: ckpt
parser.add_argument('-ckpt', '--checkpoint', type=str, default=None, help="path to checkpoint")
# 2023/02/13: TOGGLE_RATIO
parser.add_argument('-toggle_rt', '--TOGGLE_RATIO', type=float, default=0, help="To toggle more or less")
# 2023/02/15: GS_TAU, loss weight
parser.add_argument('-gs_tau', '--GS_TAU', type=float, default=1, help="Tau for gumbel_softmax")
parser.add_argument('-w_loss', '--W_LOSS', type=float, default=None, nargs='+', help="weight for HC and AD")

args = parser.parse_args(args=[]) # for jupyter notebook


In [3]:
def client_train(args, train_dataset, 
                 test_dataset, idx, epoch, global_weights=None):                    # train function for each client
    print(" process PID", os.getpid(), " running")
    
    # BUILD MODEL for every process
    if args.model == 'data2vec':
        mask_time_prob = 0                                                          # change config to avoid training stopping
        config = Data2VecAudioConfig.from_pretrained(args.pretrain_name, mask_time_prob=mask_time_prob)
        print("load from ", args.model_in_path)
        model = Data2VecAudioForCTC.from_pretrained(args.model_in_path, config=config, args=args)
        print("model loaded")                                                       # load/initialize global model
        model.config.ctc_zero_infinity = True                                       # to avoid inf values

        global_model = copy.deepcopy(model.arbitrator)                              # only has global toggling network
        if global_weights != None:                                                  # if given global_weights
            global_model.load_state_dict(global_weights)                            # load it
        #else:
        #    # copy weights
        #    global_weights = copy.deepcopy(global_model.state_dict())                       # save global weight
        processor = Wav2Vec2Processor.from_pretrained(args.pretrain_name)
        data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)
    else:
        exit('Error: unrecognized model')

    # Set the model to train and send it to device.
    device = 'cuda' if args.gpu else 'cpu'
    global_model.to(device)
    global_model.train()
    #print(global_model)

    ####################
    # 'use client_id generate sub-dataset' to be done
    ####################
    #print("call ASRLocalUpdate")
    local_model = ASRLocalUpdate(args=args, dataset=train_dataset, logger=logger,
                        data_collator=data_collator, global_test_dataset=test_dataset, 
                        processor=processor, client_id=idx)
                                                                                    # initial dataset of current client
    ####################
    # 'use client_id load local model' to be done
    # 'save model in final round' to be done
    ####################
    #print("perform update_weight")
    w, loss = local_model.update_weights(
        global_arbitrator=copy.deepcopy(global_model), global_round=epoch)          # from global model to train
    
    #send_end.send([w, loss])                                                        # save model weights and average round loss
    #return_dict[str(idx)] = [w, loss]
    print("PID {} Getting ".format(os.getpid()), "Done")
    return w, loss

In [4]:
def client_get_weight(args, idx, epoch):                               # function to get weight for each client
    print(" process PID", os.getpid(), " running")
    # model saved in args.model_out_path + "_" + str(idx) + "/final_" + str(epoch)
    #model_path = args.model_out_path + "_" + str(idx) + "/final_" + str(epoch)
    #model_path = args.model_out_path + "_client" + str(idx) + 
    #"_round" + str(global_round) + "/final"
    model_path = args.model_in_path
    # BUILD MODEL for every process
    if args.model == 'data2vec':
        mask_time_prob = 0                                                          # change config to avoid training stopping
        config = Data2VecAudioConfig.from_pretrained(args.pretrain_name, mask_time_prob=mask_time_prob)
        model = Data2VecAudioForCTC.from_pretrained(model_path, config=config, args=args)
                                                                                    # load local model
        print("model loaded")
        model.config.ctc_zero_infinity = True                                       # to avoid inf values
 
        arbitrator = copy.deepcopy(model.arbitrator)                                # return weight for toggling network only

        return_weights = copy.deepcopy(arbitrator.state_dict())                       # save global weight
    else:
        exit('Error: unrecognized model')

    #return_dict[str(idx)] = return_weights
    print("PID {} Getting ".format(os.getpid()), "Done")
    return return_weights, 0.05


In [5]:
#start_time = time.time()

# define paths
#path_project = os.path.abspath('..')
logger = SummaryWriter('../logs')

#args = args_parser()
exp_details(args) # print out details based on configuration


Experimental details:
    Model     : data2vec
    Global Rounds   : 2

    Current Stage   : 2

    Loss Type       : recall

    Federated parameters:
    Number of users    : 2
    Fraction of users  : 1.0


In [6]:
#if args.gpu_id:
#    torch.cuda.set_device(args.gpu_id)
#device = 'cuda' if args.gpu else 'cpu'

# load dataset and user groups
train_dataset, test_dataset, user_groups = get_dataset(args)

Load data from local...
Load data from local...


Loading cached processed dataset at dataset/train/cache-eef2ff61fe552b21.arrow
Loading cached processed dataset at dataset/train/cache-cd21b7340f9b8047.arrow
Loading cached processed dataset at dataset/train/cache-0dd52273b5907da5.arrow
Loading cached processed dataset at dataset/train/cache-ba7a9e979f233d88.arrow
Loading cached processed dataset at dataset/train/cache-0d40fc390de55f7c.arrow
Loading cached processed dataset at dataset/train/cache-f8c6198f1e1fbec6.arrow
Loading cached processed dataset at dataset/train/cache-9a77b4114be7ed46.arrow
Loading cached processed dataset at dataset/train/cache-2ab0be3c68e13e94.arrow
Loading cached processed dataset at dataset/train/cache-2b8608c1545fc612.arrow
Loading cached processed dataset at dataset/train/cache-499d91cde6971409.arrow
Loading cached processed dataset at dataset/test/cache-04dcfd9c5cab1c0d.arrow
Loading cached processed dataset at dataset/test/cache-6a045a1fa580e9fb.arrow
Loading cached processed dataset at dataset/test/cache

# ID相關code先跳過

In [16]:
user_IDs = train_dataset.map(lambda x: {"user_IDs": x["path"].split("_")[0]})
user_IDs

100%|██████████| 206/206 [00:10<00:00, 20.58ex/s]


Dataset({
    features: ['path', 'array', 'text', 'dementia_labels', 'input_values', 'labels', 'user_IDs'],
    num_rows: 206
})

In [18]:
mylist = user_IDs["user_IDs"]
mylist = list(dict.fromkeys(mylist))
print(mylist)

['S082', 'S070', 'S081', 'S073', 'S024', 'S062', 'S100', 'S132', 'S111', 'S001', 'S012', 'S009', 'S028', 'S079', 'S108', 'S151', 'S094', 'S048', 'S003', 'S055', 'S077', 'S043', 'S002', 'S138', 'S126', 'S029', 'S107', 'S092', 'S089', 'S011', 'S148', 'S135', 'S139', 'S129', 'S015', 'S116', 'S090', 'S101', 'S128', 'S005', 'S004', 'S041', 'S149', 'S040', 'S150', 'S039', 'S156', 'S061', 'S141', 'S137', 'S080', 'S056', 'S038', 'S076', 'S064', 'S007', 'S063', 'S140', 'S019', 'S130', 'S114', 'S124', 'S033', 'S118', 'S021', 'S051', 'S084', 'S032', 'S153', 'S142', 'S016', 'S103', 'S145', 'S013', 'S095', 'S034', 'S006', 'S083', 'S027', 'S025', 'S017', 'S097', 'S086', 'S052', 'S068', 'S096']


In [19]:
len(mylist)

86

In [13]:
# generate sub- training set for given user-ID
start_with_ar = train_dataset.filter(lambda example: (example["path"].startswith("S055")) or (example["path"].startswith("S094")))
start_with_ar["path"]

100%|██████████| 1/1 [00:11<00:00, 11.16s/ba]


['S094_PAR_12_75580_76503.wav',
 'S055_PAR_15_41550_42647.wav',
 'S094_PAR_8_48996_55998.wav',
 'S055_PAR_9_21400_22280.wav',
 'S055_PAR_14_39953_41550.wav',
 'S094_PAR_4_25903_28918.wav',
 'S055_PAR_4_12000_14377.wav',
 'S094_INV_3_74801_75580.wav']

In [6]:
def cifar_iid(dataset, num_users):
    """
    Sample I.I.D. client data from CIFAR10 dataset
    :param dataset:
    :param num_users:
    :return: dict of image index
    """
    num_items = int(len(dataset)/num_users)
    dict_users, all_idxs = {}, [i for i in range(len(dataset))]
    for i in range(num_users):
        dict_users[i] = set(np.random.choice(all_idxs, num_items,
                                             replace=False))
        all_idxs = list(set(all_idxs) - dict_users[i])
    return dict_users
def cifar_noniid(dataset, num_users):
    """
    Sample non-I.I.D client data from CIFAR10 dataset
    :param dataset:
    :param num_users:
    :return:
    """
    num_shards, num_imgs = 200, 250
    idx_shard = [i for i in range(num_shards)]
    dict_users = {i: np.array([]) for i in range(num_users)}
    idxs = np.arange(num_shards*num_imgs)
    # labels = dataset.train_labels.numpy()
    labels = np.array(dataset.train_labels)

    # sort labels
    idxs_labels = np.vstack((idxs, labels))
    idxs_labels = idxs_labels[:, idxs_labels[1, :].argsort()]
    idxs = idxs_labels[0, :]

    # divide and assign
    for i in range(num_users):
        rand_set = set(np.random.choice(idx_shard, 2, replace=False))
        idx_shard = list(set(idx_shard) - rand_set)
        for rand in rand_set:
            dict_users[i] = np.concatenate(
                (dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0)
    return dict_users

# 下面開始繼續跑

### [參考](https://stackoverflow.com/questions/10415028/how-to-get-the-return-value-of-a-function-passed-to-multiprocessing-process): 不同multi-process寫法

In [11]:
# multi-process ver1
from multiprocessing import Pool
"""
def Get_phonationdictbag_map(parameters):
    args, train_dataset, logger, test_dataset, idx, epoch, global_weights = parameters
    print(" process PID", os.getpid(), " running")
    for file in files:
        ///
        your code
        ///
    print("PID {} Getting ".format(os.getpid()), "Done")
client_train(args, train_dataset, logger, 
                 test_dataset, idx, epoch, global_weights=None)
"""
#interval=20
#parameters_lst = []
#for i in range(args.num_users):                       # for each client
    
#    keys.append(files[i:i+interval])
#flat_keys=[item for sublist in keys for item in sublist]
#assert len(flat_keys) == len(files)

#final_result = pool.starmap(Get_phonationdictbag_map, [([file_block])    for file_block in tqdm(keys)])
#client_get_weight(args, idx, epoch)
pool = Pool(int(os.cpu_count()))
epoch = 0
final_result = pool.starmap(client_get_weight, [(args, idx, epoch) for idx in range(args.num_users)])

 process PID  process PID 29867342986733   running
 running
lambda =  tensor(0.5000)
lambda =  tensor(0.5000)
Current stage: 2
Current stage: 2
model loaded
PID 2986733 Getting  Done
model loaded
PID 2986734 Getting  Done


In [14]:
final_result[1]

(OrderedDict([('weight',
               tensor([[-0.0123,  0.0295,  0.0211,  ..., -0.0300,  0.0236,  0.0029],
                       [ 0.0174,  0.0111,  0.0379,  ...,  0.0432, -0.0009,  0.0588],
                       [-0.0086, -0.0458, -0.0061,  ...,  0.0018, -0.0037, -0.0036],
                       ...,
                       [ 0.0082, -0.0028, -0.0058,  ...,  0.0179, -0.0070, -0.0085],
                       [-0.0223, -0.0279,  0.0233,  ...,  0.0101,  0.0206,  0.0180],
                       [-0.0271,  0.0189,  0.0437,  ...,  0.0140, -0.0013,  0.0029]])),
              ('bias', tensor([0., 0., 0.,  ..., 0., 0., 0.]))]),
 0.05)

### 只跑一個round

In [7]:
# Training
train_loss, test_wer = [], []
val_acc_list, net_list = [], []
cv_loss, cv_acc = [], []
print_every = 2
val_loss_pre, counter = 0, 0
global_weights = None                                                           # initial global_weights
epoch = 0
#for epoch in tqdm(range(args.epochs)):                                          # train for given global rounds
#local_weights, local_losses = [], []                                        # weights and losses of training clients of this round
print(f'\n | Global Training Round : {epoch+1} |\n')                        # print current round

m = max(int(args.frac * args.num_users), 1)                                 # num of clients to train, min:1
idxs_users = np.random.choice(range(args.num_users), m, replace=False)      # select by client_id

pool = Pool(int(os.cpu_count()))
final_result = pool.starmap(client_train, [(args, train_dataset, 
                 test_dataset, idx, epoch, global_weights) for idx in idxs_users])




 | Global Training Round : 1 |

 process PID process PID 4098957  4098960 running 
 running
load from  /mnt/Internal/FedASR/weitung/HuggingFace/Pretrain/saves/data2vec-audio-large-960h_new1_recall/final/load from 
 /mnt/Internal/FedASR/weitung/HuggingFace/Pretrain/saves/data2vec-audio-large-960h_new1_recall/final/
lambda =  tensor(0.5000)
lambda =  tensor(0.5000)
Current stage: 2
Current stage: 2
model loaded
model loaded
initialize ASRLocalUpdate
Generating client training set for client  1 ...


Loading cached processed dataset at dataset/train/cache-fb08d3bc1bef9380.arrow


load model
initialize ASRLocalUpdate
Generating client training set for client  0 ...


Loading cached processed dataset at dataset/train/cache-e6128932aeaa194b.arrow


load model
lambda =  tensor(0.5000)
Current stage: 2
lambda =  tensor(0.5000)
Current stage: 2


Using amp half precision backend


1  ready to train!


The following columns in the training set  don't have a corresponding argument in `Data2VecAudioForCTC.forward` and have been ignored: text, path, array. If text, path, array are not expected by `Data2VecAudioForCTC.forward`,  you can safely ignore this message.
Using amp half precision backend


0  ready to train!


The following columns in the training set  don't have a corresponding argument in `Data2VecAudioForCTC.forward` and have been ignored: text, path, array. If text, path, array are not expected by `Data2VecAudioForCTC.forward`,  you can safely ignore this message.
***** Running training *****
  Num examples = 419
  Num Epochs = 1
  Instantaneous batch size per device = 1
  Total train batch size (w. parallel, distributed & accumulation) = 1
  Gradient Accumulation steps = 1
  Total optimization steps = 419


Step,Training Loss,Validation Loss


***** Running training *****
  Num examples = 543
  Num Epochs = 1
  Instantaneous batch size per device = 1
  Total train batch size (w. parallel, distributed & accumulation) = 1
  Gradient Accumulation steps = 1
  Total optimization steps = 543


Step,Training Loss,Validation Loss,Wer
500,538.8379,1651.486816,0.255931


The following columns in the evaluation set  don't have a corresponding argument in `Data2VecAudioForCTC.forward` and have been ignored: text, path, array. If text, path, array are not expected by `Data2VecAudioForCTC.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 800
  Batch size = 1


Training completed. Do not forget to share your model on huggingface.co/models =)


Saving model checkpoint to ./save/data2vec-audio-large-960h_new2_recall_FL_client1_round0/final
Configuration saved in ./save/data2vec-audio-large-960h_new2_recall_FL_client1_round0/final/config.json
Model weights saved in ./save/data2vec-audio-large-960h_new2_recall_FL_client1_round0/final/pytorch_model.bin
Feature extractor saved in ./save/data2vec-audio-large-960h_new2_recall_FL_client1_round0/final/preprocessor_config.json
loading configuration file https://huggingface.co/facebook/data2vec-audio-large-960h/resolve/main/config.json from cache at /home/weitung/.cache/hugg

lambda =  tensor(0.5000)
Current stage: 2


All model checkpoint weights were used when initializing Data2VecAudioForCTC.

All the weights of Data2VecAudioForCTC were initialized from the model checkpoint at ./save/data2vec-audio-large-960h_new2_recall_FL_client1_round0/final.
If your task is similar to the task the model of the checkpoint was trained on, you can already use Data2VecAudioForCTC for predictions without further training.


PID 4098960 Getting  Done


Saving model checkpoint to ./save/data2vec-audio-large-960h_new2_recall_FL_0/checkpoint-500
Configuration saved in ./save/data2vec-audio-large-960h_new2_recall_FL_0/checkpoint-500/config.json
Model weights saved in ./save/data2vec-audio-large-960h_new2_recall_FL_0/checkpoint-500/pytorch_model.bin
Feature extractor saved in ./save/data2vec-audio-large-960h_new2_recall_FL_0/checkpoint-500/preprocessor_config.json


Training completed. Do not forget to share your model on huggingface.co/models =)


Saving model checkpoint to ./save/data2vec-audio-large-960h_new2_recall_FL_client0_round0/final
Configuration saved in ./save/data2vec-audio-large-960h_new2_recall_FL_client0_round0/final/config.json
Model weights saved in ./save/data2vec-audio-large-960h_new2_recall_FL_client0_round0/final/pytorch_model.bin
Feature extractor saved in ./save/data2vec-audio-large-960h_new2_recall_FL_client0_round0/final/preprocessor_config.json
loading configuration file https://huggingface.co/facebook/data2vec-

lambda =  tensor(0.5000)
Current stage: 2


All model checkpoint weights were used when initializing Data2VecAudioForCTC.

All the weights of Data2VecAudioForCTC were initialized from the model checkpoint at ./save/data2vec-audio-large-960h_new2_recall_FL_client0_round0/final.
If your task is similar to the task the model of the checkpoint was trained on, you can already use Data2VecAudioForCTC for predictions without further training.


PID 4098957 Getting  Done


In [8]:
final_result

[(OrderedDict([('weight',
                tensor([[ 0.0102,  0.0241,  0.0207,  ..., -0.0309,  0.0183,  0.0087],
                        [ 0.0156,  0.0400,  0.0283,  ...,  0.0336, -0.0089,  0.0473],
                        [-0.0214, -0.0527, -0.0180,  ...,  0.0068, -0.0065, -0.0028],
                        ...,
                        [-0.0171,  0.0124, -0.0117,  ...,  0.1232,  0.0053,  0.0296],
                        [-0.0945, -0.0258,  0.0358,  ...,  0.0473,  0.1927,  0.0342],
                        [ 0.0357, -0.0304,  0.0311,  ...,  0.0296, -0.0344,  0.1237]])),
               ('bias',
                tensor([ 0.0103, -0.0126,  0.0101,  ...,  0.0506,  0.0309,  0.0172]))]),
  580.0991165976059),
 (OrderedDict([('weight',
                tensor([[ 0.0037,  0.0308,  0.0278,  ..., -0.0259,  0.0077,  0.0016],
                        [ 0.0200,  0.0299,  0.0313,  ...,  0.0398, -0.0063,  0.0532],
                        [-0.0211, -0.0469, -0.0152,  ...,  0.0076, -0.0046, -0.0008],
       

In [9]:
local_weights = []
local_losses = []
for idx in range(len(final_result)):
    w, loss = final_result[idx]
    local_weights.append(w)
    local_losses.append(loss)

In [10]:
global_weights = average_weights(local_weights)
global_weights

OrderedDict([('weight',
              tensor([[ 0.0069,  0.0274,  0.0242,  ..., -0.0284,  0.0130,  0.0051],
                      [ 0.0178,  0.0350,  0.0298,  ...,  0.0367, -0.0076,  0.0502],
                      [-0.0213, -0.0498, -0.0166,  ...,  0.0072, -0.0055, -0.0018],
                      ...,
                      [-0.0214,  0.0140, -0.0163,  ...,  0.1007, -0.0020,  0.0179],
                      [-0.0713, -0.0242,  0.0404,  ...,  0.0330,  0.1507,  0.0183],
                      [ 0.0168, -0.0169,  0.0271,  ...,  0.0346, -0.0279,  0.0942]])),
             ('bias',
              tensor([ 0.0065, -0.0089,  0.0098,  ...,  0.0433,  0.0203,  0.0206]))])

In [11]:
pool.close()
pool.terminate()

epoch = 1
#for epoch in tqdm(range(args.epochs)):                                          # train for given global rounds
#local_weights, local_losses = [], []                                        # weights and losses of training clients of this round
print(f'\n | Global Training Round : {epoch+1} |\n')                        # print current round

m = max(int(args.frac * args.num_users), 1)                                 # num of clients to train, min:1
idxs_users = np.random.choice(range(args.num_users), m, replace=False)      # select by client_id

pool = Pool(int(os.cpu_count()))
final_result = pool.starmap(client_train, [(args, train_dataset, 
                 test_dataset, idx, epoch, global_weights) for idx in idxs_users])


 | Global Training Round : 2 |

 process PID process PID  41147454114746   running running

load from  /mnt/Internal/FedASR/weitung/HuggingFace/Pretrain/saves/data2vec-audio-large-960h_new1_recall/final/
load from  /mnt/Internal/FedASR/weitung/HuggingFace/Pretrain/saves/data2vec-audio-large-960h_new1_recall/final/
lambda =  tensor(0.5000)
Current stage: 2
lambda =  tensor(0.5000)
Current stage: 2


Process ForkPoolWorker-218:
Process ForkPoolWorker-242:
Process ForkPoolWorker-258:
Process ForkPoolWorker-231:
Process ForkPoolWorker-137:
Process ForkPoolWorker-171:
Process ForkPoolWorker-182:
Process ForkPoolWorker-203:
Process ForkPoolWorker-189:
Process ForkPoolWorker-158:
Process ForkPoolWorker-211:
Process ForkPoolWorker-255:
Process ForkPoolWorker-167:
Process ForkPoolWorker-180:
Process ForkPoolWorker-165:
Process ForkPoolWorker-247:
Process ForkPoolWorker-253:
Process ForkPoolWorker-235:
Process ForkPoolWorker-205:
Process ForkPoolWorker-250:
Process ForkPoolWorker-199:
Process ForkPoolWorker-184:
Process ForkPoolWorker-142:
Process ForkPoolWorker-221:
Process ForkPoolWorker-179:
Process ForkPoolWorker-174:
Process ForkPoolWorker-196:
Process ForkPoolWorker-168:
Process ForkPoolWorker-197:
Process ForkPoolWorker-236:
Process ForkPoolWorker-140:
Process ForkPoolWorker-181:
Process ForkPoolWorker-188:
Process ForkPoolWorker-183:
Process ForkPoolWorker-254:
Process ForkPoolWork

KeyboardInterrupt: 

Process ForkPoolWorker-131:
Process ForkPoolWorker-166:
Process ForkPoolWorker-170:
Process ForkPoolWorker-145:
Process ForkPoolWorker-237:
Process ForkPoolWorker-141:
Process ForkPoolWorker-223:
Process ForkPoolWorker-150:
Process ForkPoolWorker-243:
Process ForkPoolWorker-138:
Process ForkPoolWorker-163:
Process ForkPoolWorker-202:
Process ForkPoolWorker-192:
Process ForkPoolWorker-177:
Process ForkPoolWorker-146:
Process ForkPoolWorker-227:
Process ForkPoolWorker-173:
Process ForkPoolWorker-210:
Process ForkPoolWorker-134:
Process ForkPoolWorker-194:
Process ForkPoolWorker-241:
Process ForkPoolWorker-248:
Process ForkPoolWorker-234:
Process ForkPoolWorker-198:
Process ForkPoolWorker-228:
Process ForkPoolWorker-208:
Process ForkPoolWorker-161:
Process ForkPoolWorker-217:
Process ForkPoolWorker-144:
Process ForkPoolWorker-246:
Process ForkPoolWorker-164:
Process ForkPoolWorker-245:
Process ForkPoolWorker-136:
Process ForkPoolWorker-229:
Process ForkPoolWorker-143:
Process ForkPoolWork

  File "/home/weitung/.conda/envs/flwr-huggingface/lib/python3.8/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/home/weitung/.conda/envs/flwr-huggingface/lib/python3.8/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/home/weitung/.conda/envs/flwr-huggingface/lib/python3.8/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/home/weitung/.conda/envs/flwr-huggingface/lib/python3.8/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/home/weitung/.conda/envs/flwr-huggingface/lib/python3.8/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/home/weitung/.conda/envs/flwr-huggingface/lib/python3.8/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/home/weitung/.conda/envs/flwr-huggingface/lib/python3.8/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/home/weitung/.conda/envs/flwr-huggingface/lib/python3.8/mult

  File "/home/weitung/.conda/envs/flwr-huggingface/lib/python3.8/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/home/weitung/.conda/envs/flwr-huggingface/lib/python3.8/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/home/weitung/.conda/envs/flwr-huggingface/lib/python3.8/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/home/weitung/.conda/envs/flwr-huggingface/lib/python3.8/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/home/weitung/.conda/envs/flwr-huggingface/lib/python3.8/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/home/weitung/.conda/envs/flwr-huggingface/lib/python3.8/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/home/weitung/.conda/envs/flwr-huggingface/lib/python3.8/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/home/weitung/.conda/envs/flwr-huggingface/lib/python3.8/mult

  File "/home/weitung/.conda/envs/flwr-huggingface/lib/python3.8/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/home/weitung/.conda/envs/flwr-huggingface/lib/python3.8/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/home/weitung/.conda/envs/flwr-huggingface/lib/python3.8/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/home/weitung/.conda/envs/flwr-huggingface/lib/python3.8/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/home/weitung/.conda/envs/flwr-huggingface/lib/python3.8/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/home/weitung/.conda/envs/flwr-huggingface/lib/python3.8/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/home/weitung/.conda/envs/flwr-huggingface/lib/pyt

  File "/home/weitung/.conda/envs/flwr-huggingface/lib/python3.8/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/home/weitung/.conda/envs/flwr-huggingface/lib/python3.8/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/home/weitung/.conda/envs/flwr-huggingface/lib/python3.8/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/home/weitung/.conda/envs/flwr-huggingface/lib/python3.8/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/home/weitung/.conda/envs/flwr-huggingface/lib/python3.8/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/home/weitung/.conda/envs/flwr-huggingface/lib/python3.8/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/home/weitung/.conda/envs/flwr-huggingface/lib/pyt

  File "/home/weitung/.conda/envs/flwr-huggingface/lib/python3.8/multiprocessing/pool.py", line 114, in worker
    task = get()
  File "/home/weitung/.conda/envs/flwr-huggingface/lib/python3.8/multiprocessing/pool.py", line 114, in worker
    task = get()
  File "/home/weitung/.conda/envs/flwr-huggingface/lib/python3.8/multiprocessing/pool.py", line 114, in worker
    task = get()
  File "/home/weitung/.conda/envs/flwr-huggingface/lib/python3.8/multiprocessing/pool.py", line 114, in worker
    task = get()
  File "/home/weitung/.conda/envs/flwr-huggingface/lib/python3.8/multiprocessing/pool.py", line 114, in worker
    task = get()
  File "/home/weitung/.conda/envs/flwr-huggingface/lib/python3.8/multiprocessing/pool.py", line 114, in worker
    task = get()
  File "/home/weitung/.conda/envs/flwr-huggingface/lib/python3.8/multiprocessing/pool.py", line 114, in worker
    task = get()
  File "/home/weitung/.conda/envs/flwr-huggingface/lib/python3.8/multiprocessing/pool.py", line 114, in 

  File "/home/weitung/.conda/envs/flwr-huggingface/lib/python3.8/multiprocessing/pool.py", line 114, in worker
    task = get()
  File "/home/weitung/.conda/envs/flwr-huggingface/lib/python3.8/multiprocessing/queues.py", line 355, in get
    with self._rlock:
  File "/home/weitung/.conda/envs/flwr-huggingface/lib/python3.8/multiprocessing/pool.py", line 114, in worker
    task = get()
  File "/home/weitung/.conda/envs/flwr-huggingface/lib/python3.8/multiprocessing/pool.py", line 114, in worker
    task = get()
  File "/home/weitung/.conda/envs/flwr-huggingface/lib/python3.8/multiprocessing/pool.py", line 114, in worker
    task = get()
  File "/home/weitung/.conda/envs/flwr-huggingface/lib/python3.8/multiprocessing/pool.py", line 114, in worker
    task = get()
  File "/home/weitung/.conda/envs/flwr-huggingface/lib/python3.8/multiprocessing/pool.py", line 114, in worker
    task = get()
  File "/home/weitung/.conda/envs/flwr-huggingface/lib/python3.8/multiprocessing/pool.py", line 114,

  File "/home/weitung/.conda/envs/flwr-huggingface/lib/python3.8/multiprocessing/queues.py", line 355, in get
    with self._rlock:
  File "/home/weitung/.conda/envs/flwr-huggingface/lib/python3.8/multiprocessing/queues.py", line 355, in get
    with self._rlock:
  File "/home/weitung/.conda/envs/flwr-huggingface/lib/python3.8/multiprocessing/queues.py", line 355, in get
    with self._rlock:
  File "/home/weitung/.conda/envs/flwr-huggingface/lib/python3.8/multiprocessing/queues.py", line 355, in get
    with self._rlock:
  File "/home/weitung/.conda/envs/flwr-huggingface/lib/python3.8/multiprocessing/queues.py", line 355, in get
    with self._rlock:
  File "/home/weitung/.conda/envs/flwr-huggingface/lib/python3.8/multiprocessing/queues.py", line 355, in get
    with self._rlock:
  File "/home/weitung/.conda/envs/flwr-huggingface/lib/python3.8/multiprocessing/queues.py", line 355, in get
    with self._rlock:
  File "/home/weitung/.conda/envs/flwr-huggingface/lib/python3.8/multiproces

  File "/home/weitung/.conda/envs/flwr-huggingface/lib/python3.8/multiprocessing/synchronize.py", line 95, in __enter__
    return self._semlock.__enter__()
  File "/home/weitung/.conda/envs/flwr-huggingface/lib/python3.8/multiprocessing/queues.py", line 355, in get
    with self._rlock:
  File "/home/weitung/.conda/envs/flwr-huggingface/lib/python3.8/multiprocessing/synchronize.py", line 95, in __enter__
    return self._semlock.__enter__()
  File "/home/weitung/.conda/envs/flwr-huggingface/lib/python3.8/multiprocessing/synchronize.py", line 95, in __enter__
    return self._semlock.__enter__()
  File "/home/weitung/.conda/envs/flwr-huggingface/lib/python3.8/multiprocessing/queues.py", line 355, in get
    with self._rlock:
  File "/home/weitung/.conda/envs/flwr-huggingface/lib/python3.8/multiprocessing/queues.py", line 355, in get
    with self._rlock:
  File "/home/weitung/.conda/envs/flwr-huggingface/lib/python3.8/multiprocessing/queues.py", line 355, in get
    with self._rlock:
 

  File "/home/weitung/.conda/envs/flwr-huggingface/lib/python3.8/multiprocessing/synchronize.py", line 95, in __enter__
    return self._semlock.__enter__()
  File "/home/weitung/.conda/envs/flwr-huggingface/lib/python3.8/multiprocessing/synchronize.py", line 95, in __enter__
    return self._semlock.__enter__()
  File "/home/weitung/.conda/envs/flwr-huggingface/lib/python3.8/multiprocessing/synchronize.py", line 95, in __enter__
    return self._semlock.__enter__()
  File "/home/weitung/.conda/envs/flwr-huggingface/lib/python3.8/multiprocessing/synchronize.py", line 95, in __enter__
    return self._semlock.__enter__()
  File "/home/weitung/.conda/envs/flwr-huggingface/lib/python3.8/multiprocessing/synchronize.py", line 95, in __enter__
    return self._semlock.__enter__()
  File "/home/weitung/.conda/envs/flwr-huggingface/lib/python3.8/multiprocessing/synchronize.py", line 95, in __enter__
    return self._semlock.__enter__()
  File "/home/weitung/.conda/envs/flwr-huggingface/lib/pyt

KeyboardInterrupt
KeyboardInterrupt
  File "/home/weitung/.conda/envs/flwr-huggingface/lib/python3.8/multiprocessing/synchronize.py", line 95, in __enter__
    return self._semlock.__enter__()
  File "/home/weitung/.conda/envs/flwr-huggingface/lib/python3.8/multiprocessing/synchronize.py", line 95, in __enter__
    return self._semlock.__enter__()
  File "/home/weitung/.conda/envs/flwr-huggingface/lib/python3.8/multiprocessing/synchronize.py", line 95, in __enter__
    return self._semlock.__enter__()
KeyboardInterrupt
KeyboardInterrupt
KeyboardInterrupt
  File "/home/weitung/.conda/envs/flwr-huggingface/lib/python3.8/multiprocessing/synchronize.py", line 95, in __enter__
    return self._semlock.__enter__()
KeyboardInterrupt
KeyboardInterrupt
KeyboardInterrupt
KeyboardInterrupt
KeyboardInterrupt
  File "/home/weitung/.conda/envs/flwr-huggingface/lib/python3.8/multiprocessing/synchronize.py", line 95, in __enter__
    return self._semlock.__enter__()
KeyboardInterrupt
KeyboardInterrupt

  File "/home/weitung/.conda/envs/flwr-huggingface/lib/python3.8/multiprocessing/queues.py", line 355, in get
    with self._rlock:
  File "/home/weitung/.conda/envs/flwr-huggingface/lib/python3.8/multiprocessing/queues.py", line 355, in get
    with self._rlock:
  File "/home/weitung/.conda/envs/flwr-huggingface/lib/python3.8/multiprocessing/queues.py", line 355, in get
    with self._rlock:
  File "/home/weitung/.conda/envs/flwr-huggingface/lib/python3.8/multiprocessing/pool.py", line 114, in worker
    task = get()
  File "/home/weitung/.conda/envs/flwr-huggingface/lib/python3.8/multiprocessing/pool.py", line 114, in worker
    task = get()
  File "/home/weitung/.conda/envs/flwr-huggingface/lib/python3.8/multiprocessing/pool.py", line 114, in worker
    task = get()
  File "/home/weitung/.conda/envs/flwr-huggingface/lib/python3.8/multiprocessing/pool.py", line 114, in worker
    task = get()
  File "/home/weitung/.conda/envs/flwr-huggingface/lib/python3.8/multiprocessing/process.py"

  File "/home/weitung/.conda/envs/flwr-huggingface/lib/python3.8/multiprocessing/queues.py", line 355, in get
    with self._rlock:
  File "/home/weitung/.conda/envs/flwr-huggingface/lib/python3.8/multiprocessing/synchronize.py", line 95, in __enter__
    return self._semlock.__enter__()
KeyboardInterrupt


### 跑多個round

In [11]:
train_loss, test_wer = [], []
val_acc_list, net_list = [], []
cv_loss, cv_acc = [], []
print_every = 2
val_loss_pre, counter = 0, 0
global_weights = None                                                           # initial global_weights
for epoch in tqdm(range(args.epochs)):                                          # train for given global rounds
    #local_weights, local_losses = [], []                                        # weights and losses of training clients of this round
    print(f'\n | Global Training Round : {epoch+1} |\n')                        # print current round

    m = max(int(args.frac * args.num_users), 1)                                 # num of clients to train, min:1
    idxs_users = np.random.choice(range(args.num_users), m, replace=False)      # select by client_id

    pool = Pool(int(os.cpu_count()))
    final_result = pool.starmap(client_train, [(args, train_dataset, logger, test_dataset, idx, epoch, global_weights) for idx in idxs_users])
    
    local_weights = []
    local_losses = []
    for idx in len(final_result):
        w, loss = final_result[idx]
        local_weights.append(w)
        local_losses.append(loss)

    print("local weights: ", local_weights)
    # get global weights by averaging local weights
    global_weights = average_weights(local_weights)
    print("global wegiths: ", global_weights)

    # update global weights
    #global_model.load_state_dict(global_weights)

    loss_avg = sum(local_losses) / len(local_losses)                # average losses from participated client
    train_loss.append(loss_avg)                                     # save loss for this round

  0%|                                                                                                                | 0/2 [00:00<?, ?it/s]


 | Global Training Round : 1 |



  0%|                                                                                                                | 0/2 [00:02<?, ?it/s]


RuntimeError: Queue objects should only be shared between processes through inheritance

## 以下其他

In [11]:
# Training
train_loss, test_wer = [], []
val_acc_list, net_list = [], []
cv_loss, cv_acc = [], []
print_every = 2
val_loss_pre, counter = 0, 0
global_weights = None    
for epoch in range(2):
    #for epoch in tqdm(range(args.epochs)):                                          # train for given global rounds
    #local_weights, local_losses = [], []                                        # weights and losses of training clients of this round
    print(f'\n | Global Training Round : {epoch+1} |\n')                        # print current round

    m = max(int(args.frac * args.num_users), 1)                                 # num of clients to train, min:1
    idxs_users = np.random.choice(range(args.num_users), m, replace=False)      # select by client_id

    manager = multiprocessing.Manager()
    return_dict = manager.dict()
    jobs = []
    #pipe_list = []
    for idx in idxs_users:                                                      # for each training client

        print("start of client #", idx)
        #recv_end, send_end = multiprocessing.Pipe(False)
        p = multiprocessing.Process(target=client_train, args=(return_dict,
                args, train_dataset, logger, test_dataset, idx, epoch, global_weights))

        #p = multiprocessing.Process(target=client_get_weight, args=(return_dict, args, idx, epoch))
        jobs.append(p)
        #pipe_list.append(recv_end)
        p.start()

        #local_weights.append(copy.deepcopy(w))                      # save weight for this client
        #local_losses.append(copy.deepcopy(loss))                    # save loss for this client
    for proc in jobs:
        proc.join()
        #proc.close()


 | Global Training Round : 1 |

start of client # 1
start of client # 0
lambda =  tensor(0.5000)
Current stage: 2
model loaded
lambda =  tensor(0.5000)
Current stage: 2
model loaded
initialize ASRLocalUpdate
Generating client training set for client  0 ...


Loading cached processed dataset at dataset/train/cache-e6128932aeaa194b.arrow


load model
initialize ASRLocalUpdate
Generating client training set for client  1 ...


Loading cached processed dataset at dataset/train/cache-fb08d3bc1bef9380.arrow


load model
lambda =  tensor(0.5000)
Current stage: 2
lambda =  tensor(0.5000)
Current stage: 2


Using amp half precision backend


0  ready to train!


The following columns in the training set  don't have a corresponding argument in `Data2VecAudioForCTC.forward` and have been ignored: array, text, path. If array, text, path are not expected by `Data2VecAudioForCTC.forward`,  you can safely ignore this message.
Using amp half precision backend


1  ready to train!


The following columns in the training set  don't have a corresponding argument in `Data2VecAudioForCTC.forward` and have been ignored: array, text, path. If array, text, path are not expected by `Data2VecAudioForCTC.forward`,  you can safely ignore this message.
***** Running training *****
  Num examples = 419
  Num Epochs = 1
  Instantaneous batch size per device = 1
  Total train batch size (w. parallel, distributed & accumulation) = 1
  Gradient Accumulation steps = 1
  Total optimization steps = 419
***** Running training *****
  Num examples = 543
  Num Epochs = 1
  Instantaneous batch size per device = 1
  Total train batch size (w. parallel, distributed & accumulation) = 1
  Gradient Accumulation steps = 1
  Total optimization steps = 543


Step,Training Loss,Validation Loss


Step,Training Loss,Validation Loss,Wer
500,538.8379,1651.486816,0.255931




Training completed. Do not forget to share your model on huggingface.co/models =)


Saving model checkpoint to ./save/data2vec-audio-large-960h_new2_recall_FL_client1_round0/final
Configuration saved in ./save/data2vec-audio-large-960h_new2_recall_FL_client1_round0/final/config.json
Model weights saved in ./save/data2vec-audio-large-960h_new2_recall_FL_client1_round0/final/pytorch_model.bin
Feature extractor saved in ./save/data2vec-audio-large-960h_new2_recall_FL_client1_round0/final/preprocessor_config.json
loading configuration file https://huggingface.co/facebook/data2vec-audio-large-960h/resolve/main/config.json from cache at /home/weitung/.cache/huggingface/transformers/a5e291023d6dd7ec0034390cee6d97f07e340fb24c68c7b5f3ec8d017a6fd29d.ed9b9e83fb80348aa91a073138fc7a0f44e669fc412c9c4bc98857f45bfd4330
Model config Data2VecAudioConfig {
  "activation_dropout": 0.1,
  "adapter_kernel_size": 3,
  "adapter_stride": 2,
  "add_adapter": false,
  "apply_spec_augment": true,
  "architectur

lambda =  tensor(0.5000)
Current stage: 2


All model checkpoint weights were used when initializing Data2VecAudioForCTC.

All the weights of Data2VecAudioForCTC were initialized from the model checkpoint at ./save/data2vec-audio-large-960h_new2_recall_FL_client1_round0/final.
If your task is similar to the task the model of the checkpoint was trained on, you can already use Data2VecAudioForCTC for predictions without further training.
The following columns in the evaluation set  don't have a corresponding argument in `Data2VecAudioForCTC.forward` and have been ignored: array, text, path. If array, text, path are not expected by `Data2VecAudioForCTC.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 800
  Batch size = 1
Saving model checkpoint to ./save/data2vec-audio-large-960h_new2_recall_FL_0/checkpoint-500
Configuration saved in ./save/data2vec-audio-large-960h_new2_recall_FL_0/checkpoint-500/config.json
Model weights saved in ./save/data2vec-audio-large-960h_new2_recall_FL_0/check

lambda =  tensor(0.5000)
Current stage: 2


All model checkpoint weights were used when initializing Data2VecAudioForCTC.

All the weights of Data2VecAudioForCTC were initialized from the model checkpoint at ./save/data2vec-audio-large-960h_new2_recall_FL_client0_round0/final.
If your task is similar to the task the model of the checkpoint was trained on, you can already use Data2VecAudioForCTC for predictions without further training.



 | Global Training Round : 2 |

start of client # 1
start of client # 0
lambda =  tensor(0.5000)
Current stage: 2
lambda =  tensor(0.5000)
Current stage: 2
model loaded
model loaded
initialize ASRLocalUpdate
Generating client training set for client  0 ...


Loading cached processed dataset at dataset/train/cache-e6128932aeaa194b.arrow


load model
initialize ASRLocalUpdate
Generating client training set for client  1 ...


Loading cached processed dataset at dataset/train/cache-fb08d3bc1bef9380.arrow


load model
lambda =  tensor(0.5000)
Current stage: 2
lambda =  tensor(0.5000)
Current stage: 2


Using amp half precision backend


0 ready to train! 


The following columns in the training set  don't have a corresponding argument in `Data2VecAudioForCTC.forward` and have been ignored: array, text, path. If array, text, path are not expected by `Data2VecAudioForCTC.forward`,  you can safely ignore this message.
Using amp half precision backend


1  ready to train!


The following columns in the training set  don't have a corresponding argument in `Data2VecAudioForCTC.forward` and have been ignored: array, text, path. If array, text, path are not expected by `Data2VecAudioForCTC.forward`,  you can safely ignore this message.
***** Running training *****
  Num examples = 419
  Num Epochs = 1
  Instantaneous batch size per device = 1
  Total train batch size (w. parallel, distributed & accumulation) = 1
  Gradient Accumulation steps = 1
  Total optimization steps = 419
***** Running training *****
  Num examples = 543
  Num Epochs = 1
  Instantaneous batch size per device = 1
  Total train batch size (w. parallel, distributed & accumulation) = 1
  Gradient Accumulation steps = 1
  Total optimization steps = 543


Step,Training Loss,Validation Loss,Wer
500,538.8379,1651.486816,0.255931


Step,Training Loss,Validation Loss




Training completed. Do not forget to share your model on huggingface.co/models =)


Saving model checkpoint to ./save/data2vec-audio-large-960h_new2_recall_FL_client1_round1/final
Configuration saved in ./save/data2vec-audio-large-960h_new2_recall_FL_client1_round1/final/config.json
The following columns in the evaluation set  don't have a corresponding argument in `Data2VecAudioForCTC.forward` and have been ignored: array, text, path. If array, text, path are not expected by `Data2VecAudioForCTC.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 800
  Batch size = 1
Model weights saved in ./save/data2vec-audio-large-960h_new2_recall_FL_client1_round1/final/pytorch_model.bin
Feature extractor saved in ./save/data2vec-audio-large-960h_new2_recall_FL_client1_round1/final/preprocessor_config.json
loading configuration file https://huggingface.co/facebook/data2vec-audio-large-960h/resolve/main/config.json from cache at /home/weitung/.cache/hugg

lambda =  tensor(0.5000)
Current stage: 2


All model checkpoint weights were used when initializing Data2VecAudioForCTC.

All the weights of Data2VecAudioForCTC were initialized from the model checkpoint at ./save/data2vec-audio-large-960h_new2_recall_FL_client1_round1/final.
If your task is similar to the task the model of the checkpoint was trained on, you can already use Data2VecAudioForCTC for predictions without further training.
Saving model checkpoint to ./save/data2vec-audio-large-960h_new2_recall_FL_0/checkpoint-500
Configuration saved in ./save/data2vec-audio-large-960h_new2_recall_FL_0/checkpoint-500/config.json
Model weights saved in ./save/data2vec-audio-large-960h_new2_recall_FL_0/checkpoint-500/pytorch_model.bin
Feature extractor saved in ./save/data2vec-audio-large-960h_new2_recall_FL_0/checkpoint-500/preprocessor_config.json


Training completed. Do not forget to share your model on huggingface.co/models =)


Saving model checkpoint to ./save/data2vec-audio-large-960h_new2_recall_FL_client0_round1/final
Configu

lambda =  tensor(0.5000)
Current stage: 2


All model checkpoint weights were used when initializing Data2VecAudioForCTC.

All the weights of Data2VecAudioForCTC were initialized from the model checkpoint at ./save/data2vec-audio-large-960h_new2_recall_FL_client0_round1/final.
If your task is similar to the task the model of the checkpoint was trained on, you can already use Data2VecAudioForCTC for predictions without further training.


In [12]:
print(return_dict.keys())

['1', '0']


In [None]:
# 
for idx in idxs_users:                                                      # for each training client
    # get model weights from saved model

In [13]:
local_weights = []
local_losses = []
for key in return_dict.keys():
    w, loss = return_dict[key]
    local_weights.append(w)
    local_losses.append(loss)

In [14]:
global_weights = average_weights(local_weights)
global_weights

OrderedDict([('weight',
              tensor([[ 0.0069,  0.0274,  0.0242,  ..., -0.0284,  0.0130,  0.0051],
                      [ 0.0178,  0.0350,  0.0298,  ...,  0.0367, -0.0076,  0.0502],
                      [-0.0213, -0.0498, -0.0166,  ...,  0.0072, -0.0055, -0.0018],
                      ...,
                      [-0.0214,  0.0140, -0.0163,  ...,  0.1007, -0.0020,  0.0179],
                      [-0.0713, -0.0242,  0.0404,  ...,  0.0330,  0.1507,  0.0183],
                      [ 0.0168, -0.0169,  0.0271,  ...,  0.0346, -0.0279,  0.0942]])),
             ('bias',
              tensor([ 0.0065, -0.0089,  0.0098,  ...,  0.0433,  0.0203,  0.0206]))])

In [17]:
if args.model == 'data2vec':
    mask_time_prob = 0                                                          # change config to avoid training stopping
    config = Data2VecAudioConfig.from_pretrained(args.pretrain_name, mask_time_prob=mask_time_prob)
    model = Data2VecAudioForCTC.from_pretrained(args.model_in_path, config=config, args=args)
                                                                                # load/initialize global model
    model.config.ctc_zero_infinity = True                                       # to avoid inf values

    global_model = copy.deepcopy(model.arbitrator)                              # only has global toggling network
    if global_weights != None:                                                  # if given global_weights
        global_model.load_state_dict(global_weights)                            # load it
    #else:
    #    # copy weights
    #    global_weights = copy.deepcopy(global_model.state_dict())                       # save global weight
    processor = Wav2Vec2Processor.from_pretrained(args.pretrain_name)
    data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)
else:
    exit('Error: unrecognized model')

# Set the model to train and send it to device.
global_model.to(device)
global_model.train()

lambda =  tensor(0.5000)
Current stage: 2


Linear(in_features=1024, out_features=4096, bias=True)

In [23]:
w = global_model.state_dict()
print(w)

OrderedDict([('weight', tensor([[-0.0123,  0.0295,  0.0211,  ..., -0.0300,  0.0236,  0.0029],
        [ 0.0174,  0.0111,  0.0379,  ...,  0.0432, -0.0009,  0.0588],
        [-0.0086, -0.0458, -0.0061,  ...,  0.0018, -0.0037, -0.0036],
        ...,
        [ 0.0082, -0.0028, -0.0058,  ...,  0.0179, -0.0070, -0.0085],
        [-0.0223, -0.0279,  0.0233,  ...,  0.0101,  0.0206,  0.0180],
        [-0.0271,  0.0189,  0.0437,  ...,  0.0140, -0.0013,  0.0029]])), ('bias', tensor([0., 0., 0.,  ..., 0., 0., 0.]))])


In [25]:
from collections import OrderedDict
w_ret = OrderedDict()
w_ret['weight'] = w['weight'].numpy()
w_ret['bias'] = w['bias'].numpy()
print(w_ret)

OrderedDict([('weight', array([[-0.0123494 ,  0.02952368,  0.02109445, ..., -0.02995213,
         0.02364982,  0.00286387],
       [ 0.01744892,  0.01111791,  0.03789755, ...,  0.04320697,
        -0.00092741,  0.05882485],
       [-0.00859752, -0.04579691, -0.00612044, ...,  0.00179596,
        -0.00365438, -0.00361218],
       ...,
       [ 0.00822815, -0.00279331, -0.00580185, ...,  0.01794216,
        -0.00702918, -0.00849224],
       [-0.02230225, -0.02785204,  0.02326177, ...,  0.01009476,
         0.02058739,  0.01800553],
       [-0.02707957,  0.01885535,  0.04367625, ...,  0.01396277,
        -0.00133464,  0.00289201]], dtype=float32)), ('bias', array([0., 0., 0., ..., 0., 0., 0.], dtype=float32))])


In [11]:
print(w.cpu())
#copy.deepcopy(w.detach().cpu())

AttributeError: 'collections.OrderedDict' object has no attribute 'cpu'

In [12]:
global_model.state_dict()

OrderedDict([('weight',
              tensor([[-0.0123,  0.0295,  0.0211,  ..., -0.0300,  0.0236,  0.0029],
                      [ 0.0174,  0.0111,  0.0379,  ...,  0.0432, -0.0009,  0.0588],
                      [-0.0086, -0.0458, -0.0061,  ...,  0.0018, -0.0037, -0.0036],
                      ...,
                      [ 0.0082, -0.0028, -0.0058,  ...,  0.0179, -0.0070, -0.0085],
                      [-0.0223, -0.0279,  0.0233,  ...,  0.0101,  0.0206,  0.0180],
                      [-0.0271,  0.0189,  0.0437,  ...,  0.0140, -0.0013,  0.0029]])),
             ('bias', tensor([0., 0., 0.,  ..., 0., 0., 0.]))])

In [13]:
manager = multiprocessing.Manager()
return_dict = manager.dict()

In [17]:
return_dict[2] = [global_model.state_dict(), 0.88]

In [18]:
return_dict.values()

[[OrderedDict([('weight',
                tensor([[-0.0123,  0.0295,  0.0211,  ..., -0.0300,  0.0236,  0.0029],
                        [ 0.0174,  0.0111,  0.0379,  ...,  0.0432, -0.0009,  0.0588],
                        [-0.0086, -0.0458, -0.0061,  ...,  0.0018, -0.0037, -0.0036],
                        ...,
                        [ 0.0082, -0.0028, -0.0058,  ...,  0.0179, -0.0070, -0.0085],
                        [-0.0223, -0.0279,  0.0233,  ...,  0.0101,  0.0206,  0.0180],
                        [-0.0271,  0.0189,  0.0437,  ...,  0.0140, -0.0013,  0.0029]])),
               ('bias', tensor([0., 0., 0.,  ..., 0., 0., 0.]))]),
  0.88],
 [OrderedDict([('weight',
                tensor([[-0.0123,  0.0295,  0.0211,  ..., -0.0300,  0.0236,  0.0029],
                        [ 0.0174,  0.0111,  0.0379,  ...,  0.0432, -0.0009,  0.0588],
                        [-0.0086, -0.0458, -0.0061,  ...,  0.0018, -0.0037, -0.0036],
                        ...,
                        [ 0.0082, -0.

# 以下尚未確認

In [11]:
local_weights

[OrderedDict([('weight',
               tensor([[-0.0130,  0.0346,  0.0234,  ..., -0.0382,  0.0211,  0.0051],
                       [ 0.0241,  0.0255,  0.0383,  ...,  0.0432,  0.0020,  0.0541],
                       [-0.0088, -0.0451, -0.0152,  ...,  0.0060, -0.0079, -0.0027],
                       ...,
                       [-0.0066, -0.0090, -0.0239,  ...,  0.0724, -0.0043,  0.0117],
                       [-0.0451, -0.0136,  0.0419,  ...,  0.0208,  0.0978,  0.0114],
                       [-0.0099,  0.0023,  0.0220,  ...,  0.0394, -0.0133,  0.0624]],
                      device='cuda:0')),
              ('bias',
               tensor([-0.0086, -0.0131,  0.0066,  ...,  0.0325, -0.0189,  0.0216],
                      device='cuda:0'))]),
 OrderedDict([('weight',
               tensor([[-0.0130,  0.0346,  0.0234,  ..., -0.0382,  0.0211,  0.0051],
                       [ 0.0241,  0.0255,  0.0383,  ...,  0.0432,  0.0020,  0.0541],
                       [-0.0088, -0.0451, -0.0152,

In [12]:
# get global weights by averaging local weights
global_weights = average_weights(local_weights)
global_weights

OrderedDict([('weight',
              tensor([[-0.0130,  0.0346,  0.0234,  ..., -0.0382,  0.0211,  0.0051],
                      [ 0.0241,  0.0255,  0.0383,  ...,  0.0432,  0.0020,  0.0541],
                      [-0.0088, -0.0451, -0.0152,  ...,  0.0060, -0.0079, -0.0027],
                      ...,
                      [-0.0066, -0.0090, -0.0239,  ...,  0.0724, -0.0043,  0.0117],
                      [-0.0451, -0.0136,  0.0419,  ...,  0.0208,  0.0978,  0.0114],
                      [-0.0099,  0.0023,  0.0220,  ...,  0.0394, -0.0133,  0.0624]],
                     device='cuda:0')),
             ('bias',
              tensor([-0.0086, -0.0131,  0.0066,  ...,  0.0325, -0.0189,  0.0216],
                     device='cuda:0'))])

In [13]:
# update global weights
global_model.load_state_dict(global_weights)

loss_avg = sum(local_losses) / len(local_losses)                # average losses from participated client
train_loss.append(loss_avg)     
train_loss

[538.7666110436893]

In [14]:
list_wer = []
global_model.eval()
for c in range(args.num_users):                                 # for ALL users
    local_model = ASRLocalUpdate(args=args, dataset=train_dataset, logger=logger,
                        data_collator=data_collator, global_test_dataset=test_dataset, 
                        processor=processor)
                                                                # initial dataset of current client
    wer = local_model.inference(global_arbitrator=global_model)       # get acc. & total loss on clients' test set
    list_wer.append(wer)                                        # save acc.
    #list_loss.append(loss)                                      # save loss
train_accuracy.append(sum(list_wer)/len(list_wer))              # acc average over all clients

loading configuration file https://huggingface.co/facebook/data2vec-audio-large-960h/resolve/main/config.json from cache at /home/weitung/.cache/huggingface/transformers/a5e291023d6dd7ec0034390cee6d97f07e340fb24c68c7b5f3ec8d017a6fd29d.ed9b9e83fb80348aa91a073138fc7a0f44e669fc412c9c4bc98857f45bfd4330
Model config Data2VecAudioConfig {
  "activation_dropout": 0.1,
  "adapter_kernel_size": 3,
  "adapter_stride": 2,
  "add_adapter": false,
  "apply_spec_augment": true,
  "architectures": [
    "Data2VecAudioForCTC"
  ],
  "attention_dropout": 0.1,
  "bos_token_id": 1,
  "classifier_proj_size": 256,
  "codevector_dim": 768,
  "contrastive_logits_temperature": 0.1,
  "conv_bias": false,
  "conv_dim": [
    512,
    512,
    512,
    512,
    512,
    512,
    512
  ],
  "conv_kernel": [
    10,
    3,
    3,
    3,
    3,
    2,
    2
  ],
  "conv_pos_kernel_size": 19,
  "conv_stride": [
    5,
    2,
    2,
    2,
    2,
    2,
    2
  ],
  "ctc_loss_reduction": "sum",
  "ctc_zero_infinity":

load global!!!!!!!!!!!!! should change to client model


All model checkpoint weights were used when initializing Data2VecAudioForCTC_eval.

All the weights of Data2VecAudioForCTC_eval were initialized from the model checkpoint at /mnt/Internal/FedASR/weitung/HuggingFace/Pretrain/saves/data2vec-audio-large-960h_new1_recall/final/.
If your task is similar to the task the model of the checkpoint was trained on, you can already use Data2VecAudioForCTC_eval for predictions without further training.


lambda =  tensor(0.5000)
Current stage: 2
before:  OrderedDict([('weight', tensor([[-0.0123,  0.0295,  0.0211,  ..., -0.0300,  0.0236,  0.0029],
        [ 0.0174,  0.0111,  0.0379,  ...,  0.0432, -0.0009,  0.0588],
        [-0.0086, -0.0458, -0.0061,  ...,  0.0018, -0.0037, -0.0036],
        ...,
        [ 0.0082, -0.0028, -0.0058,  ...,  0.0179, -0.0070, -0.0085],
        [-0.0223, -0.0279,  0.0233,  ...,  0.0101,  0.0206,  0.0180],
        [-0.0271,  0.0189,  0.0437,  ...,  0.0140, -0.0013,  0.0029]])), ('bias', tensor([0., 0., 0.,  ..., 0., 0., 0.]))])
after:  OrderedDict([('weight', tensor([[-0.0130,  0.0346,  0.0234,  ..., -0.0382,  0.0211,  0.0051],
        [ 0.0241,  0.0255,  0.0383,  ...,  0.0432,  0.0020,  0.0541],
        [-0.0088, -0.0451, -0.0152,  ...,  0.0060, -0.0079, -0.0027],
        ...,
        [-0.0066, -0.0090, -0.0239,  ...,  0.0724, -0.0043,  0.0117],
        [-0.0451, -0.0136,  0.0419,  ...,  0.0208,  0.0978,  0.0114],
        [-0.0099,  0.0023,  0.0220,  ...,  

loading configuration file https://huggingface.co/facebook/data2vec-audio-large-960h/resolve/main/config.json from cache at /home/weitung/.cache/huggingface/transformers/a5e291023d6dd7ec0034390cee6d97f07e340fb24c68c7b5f3ec8d017a6fd29d.ed9b9e83fb80348aa91a073138fc7a0f44e669fc412c9c4bc98857f45bfd4330
Model config Data2VecAudioConfig {
  "activation_dropout": 0.1,
  "adapter_kernel_size": 3,
  "adapter_stride": 2,
  "add_adapter": false,
  "apply_spec_augment": true,
  "architectures": [
    "Data2VecAudioForCTC"
  ],
  "attention_dropout": 0.1,
  "bos_token_id": 1,
  "classifier_proj_size": 256,
  "codevector_dim": 768,
  "contrastive_logits_temperature": 0.1,
  "conv_bias": false,
  "conv_dim": [
    512,
    512,
    512,
    512,
    512,
    512,
    512
  ],
  "conv_kernel": [
    10,
    3,
    3,
    3,
    3,
    2,
    2
  ],
  "conv_pos_kernel_size": 19,
  "conv_stride": [
    5,
    2,
    2,
    2,
    2,
    2,
    2
  ],
  "ctc_loss_reduction": "sum",
  "ctc_zero_infinity":

load global!!!!!!!!!!!!! should change to client model


All model checkpoint weights were used when initializing Data2VecAudioForCTC_eval.

All the weights of Data2VecAudioForCTC_eval were initialized from the model checkpoint at /mnt/Internal/FedASR/weitung/HuggingFace/Pretrain/saves/data2vec-audio-large-960h_new1_recall/final/.
If your task is similar to the task the model of the checkpoint was trained on, you can already use Data2VecAudioForCTC_eval for predictions without further training.


lambda =  tensor(0.5000)
Current stage: 2
before:  OrderedDict([('weight', tensor([[-0.0123,  0.0295,  0.0211,  ..., -0.0300,  0.0236,  0.0029],
        [ 0.0174,  0.0111,  0.0379,  ...,  0.0432, -0.0009,  0.0588],
        [-0.0086, -0.0458, -0.0061,  ...,  0.0018, -0.0037, -0.0036],
        ...,
        [ 0.0082, -0.0028, -0.0058,  ...,  0.0179, -0.0070, -0.0085],
        [-0.0223, -0.0279,  0.0233,  ...,  0.0101,  0.0206,  0.0180],
        [-0.0271,  0.0189,  0.0437,  ...,  0.0140, -0.0013,  0.0029]])), ('bias', tensor([0., 0., 0.,  ..., 0., 0., 0.]))])
after:  OrderedDict([('weight', tensor([[-0.0130,  0.0346,  0.0234,  ..., -0.0382,  0.0211,  0.0051],
        [ 0.0241,  0.0255,  0.0383,  ...,  0.0432,  0.0020,  0.0541],
        [-0.0088, -0.0451, -0.0152,  ...,  0.0060, -0.0079, -0.0027],
        ...,
        [-0.0066, -0.0090, -0.0239,  ...,  0.0724, -0.0043,  0.0117],
        [-0.0451, -0.0136,  0.0419,  ...,  0.0208,  0.0978,  0.0114],
        [-0.0099,  0.0023,  0.0220,  ...,  

# 以下尚未確認

In [17]:


# print global training loss after every 'i' rounds
if (1+1) % print_every == 0:
    print(f' \nAvg Training Stats after {1+1} global rounds:')
    print(f'Training Loss : {np.mean(np.array(train_loss))}')
    print('Train Accuracy: {:.2f}% \n'.format(100*train_accuracy[-1])) # on testing set of clients though

 
Avg Training Stats after 2 global rounds:
Training Loss : 538.7666110436893
Train Accuracy: 25.77% 



In [16]:
from transformers.training_args import TrainingArguments
from transformers import Trainer
from typing import Any, Dict, List, Optional, Union
import json

LOG_DIR = './'#log/'
from datasets import load_metric
wer_metric = load_metric("wer")
def compute_metrics(pred):
    pred_logits = pred.predictions
    pred_ids = np.argmax(pred_logits, axis=-1)

    pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id

    pred_str = processor.batch_decode(pred_ids)
    # we do not want to group tokens when computing the metrics
    label_str = processor.batch_decode(pred.label_ids, group_tokens=False)

    wer = wer_metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}

class CustomTrainer(Trainer):    
    def compute_loss(self, model, inputs, return_outputs=False):
            """
            How the loss is computed by Trainer. By default, all models return the loss in the first element.
            Subclass and override for custom behavior.
            """
            #dementia_labels = inputs.pop("dementia_labels") # pop 出來就會不見?
            
            if self.label_smoother is not None and "labels" in inputs:
                labels = inputs.pop("labels")
            else:
                labels = None
            
            outputs = model(**inputs)
            # Save past state if it exists
            # TODO: this needs to be fixed and made cleaner later.
            if self.args.past_index >= 0:
                self._past = outputs[self.args.past_index]

            if labels is not None:
                loss = self.label_smoother(outputs, labels)
            else:
                # We don't use .loss here since the model may return tuples instead of ModelOutput.
                loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]

            return (loss, outputs) if return_outputs else loss
    def log(self, logs: Dict[str, float]) -> None:
        """
        Log `logs` on the various objects watching training.
        Subclass and override this method to inject custom behavior.
        Args:
            logs (`Dict[str, float]`):
                The values to log.
        """
        if self.state.epoch is not None:
            logs["epoch"] = round(self.state.epoch, 2)

        output = {**logs, **{"step": self.state.global_step}}
        self.state.log_history.append(output)
        
        # write to txt file
        file_object = open(LOG_DIR + args.log_path, 'a')
        # Append at the end of file
        file_object.write(json.dumps(output) + '\n')
        # Close the file
        file_object.close()

        self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs)


training_args = TrainingArguments(
    output_dir=args.model_out_path + "just_for_testing",
    group_by_length=True,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    evaluation_strategy="steps",
    num_train_epochs=2,
    fp16=True,
    gradient_checkpointing=True, 
    save_steps=500,
    eval_steps=500,
    logging_steps=500,
    learning_rate=1e-4,
    weight_decay=0.005,
    warmup_steps=1000,
    save_total_limit=2,
    log_level='debug',
    logging_strategy="steps",
    #adafactor=True,            # default:false. Whether or not to use transformers.Adafactor optimizer instead of transformers.AdamW
    #fp16_full_eval=True,      # to save memory
    #max_grad_norm=0.5
)
trainer = CustomTrainer(
    model=model,
    data_collator=data_collator,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    tokenizer=processor.feature_extractor,
)
trainer.train()

PyTorch: setting up devices
The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).
Using amp half precision backend
The following columns in the training set  don't have a corresponding argument in `Data2VecAudioForCTC.forward` and have been ignored: array, text, path. If array, text, path are not expected by `Data2VecAudioForCTC.forward`,  you can safely ignore this message.
***** Running training *****
  Num examples = 1868
  Num Epochs = 2
  Instantaneous batch size per device = 1
  Total train batch size (w. parallel, distributed & accumulation) = 1
  Gradient Accumulation steps = 1
  Total optimization steps = 3736


loss: cel


Step,Training Loss,Validation Loss,Wer
500,7.1109,5.394214,0.259387
1000,2.2982,0.730392,0.25923
1500,0.9706,1.098743,0.25923
2000,1.1222,1.066993,0.25923
2500,1.1327,1.041327,0.260644
3000,1.1501,1.019284,0.257816
3500,1.0259,1.029965,0.258602


loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel


The following columns in the evaluation set  don't have a corresponding argument in `Data2VecAudioForCTC.forward` and have been ignored: array, text, path. If array, text, path are not expected by `Data2VecAudioForCTC.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 800
  Batch size = 1


loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel


Saving model checkpoint to ./saves/wav2vec2-base-960h_linear_GRLjust_for_testing/checkpoint-500
Configuration saved in ./saves/wav2vec2-base-960h_linear_GRLjust_for_testing/checkpoint-500/config.json
Model weights saved in ./saves/wav2vec2-base-960h_linear_GRLjust_for_testing/checkpoint-500/pytorch_model.bin
Feature extractor saved in ./saves/wav2vec2-base-960h_linear_GRLjust_for_testing/checkpoint-500/preprocessor_config.json


loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel


The following columns in the evaluation set  don't have a corresponding argument in `Data2VecAudioForCTC.forward` and have been ignored: array, text, path. If array, text, path are not expected by `Data2VecAudioForCTC.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 800
  Batch size = 1


loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel


Saving model checkpoint to ./saves/wav2vec2-base-960h_linear_GRLjust_for_testing/checkpoint-1000
Configuration saved in ./saves/wav2vec2-base-960h_linear_GRLjust_for_testing/checkpoint-1000/config.json
Model weights saved in ./saves/wav2vec2-base-960h_linear_GRLjust_for_testing/checkpoint-1000/pytorch_model.bin
Feature extractor saved in ./saves/wav2vec2-base-960h_linear_GRLjust_for_testing/checkpoint-1000/preprocessor_config.json


loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel


The following columns in the evaluation set  don't have a corresponding argument in `Data2VecAudioForCTC.forward` and have been ignored: array, text, path. If array, text, path are not expected by `Data2VecAudioForCTC.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 800
  Batch size = 1


loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel


Saving model checkpoint to ./saves/wav2vec2-base-960h_linear_GRLjust_for_testing/checkpoint-1500
Configuration saved in ./saves/wav2vec2-base-960h_linear_GRLjust_for_testing/checkpoint-1500/config.json
Model weights saved in ./saves/wav2vec2-base-960h_linear_GRLjust_for_testing/checkpoint-1500/pytorch_model.bin
Feature extractor saved in ./saves/wav2vec2-base-960h_linear_GRLjust_for_testing/checkpoint-1500/preprocessor_config.json
Deleting older checkpoint [saves/wav2vec2-base-960h_linear_GRLjust_for_testing/checkpoint-500] due to args.save_total_limit


loss: cel




loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel


The following columns in the evaluation set  don't have a corresponding argument in `Data2VecAudioForCTC.forward` and have been ignored: array, text, path. If array, text, path are not expected by `Data2VecAudioForCTC.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 800
  Batch size = 1


loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel


Saving model checkpoint to ./saves/wav2vec2-base-960h_linear_GRLjust_for_testing/checkpoint-2000
Configuration saved in ./saves/wav2vec2-base-960h_linear_GRLjust_for_testing/checkpoint-2000/config.json
Model weights saved in ./saves/wav2vec2-base-960h_linear_GRLjust_for_testing/checkpoint-2000/pytorch_model.bin
Feature extractor saved in ./saves/wav2vec2-base-960h_linear_GRLjust_for_testing/checkpoint-2000/preprocessor_config.json
Deleting older checkpoint [saves/wav2vec2-base-960h_linear_GRLjust_for_testing/checkpoint-1000] due to args.save_total_limit


loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel


The following columns in the evaluation set  don't have a corresponding argument in `Data2VecAudioForCTC.forward` and have been ignored: array, text, path. If array, text, path are not expected by `Data2VecAudioForCTC.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 800
  Batch size = 1


loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel


Saving model checkpoint to ./saves/wav2vec2-base-960h_linear_GRLjust_for_testing/checkpoint-2500
Configuration saved in ./saves/wav2vec2-base-960h_linear_GRLjust_for_testing/checkpoint-2500/config.json
Model weights saved in ./saves/wav2vec2-base-960h_linear_GRLjust_for_testing/checkpoint-2500/pytorch_model.bin
Feature extractor saved in ./saves/wav2vec2-base-960h_linear_GRLjust_for_testing/checkpoint-2500/preprocessor_config.json
Deleting older checkpoint [saves/wav2vec2-base-960h_linear_GRLjust_for_testing/checkpoint-1500] due to args.save_total_limit


loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel


The following columns in the evaluation set  don't have a corresponding argument in `Data2VecAudioForCTC.forward` and have been ignored: array, text, path. If array, text, path are not expected by `Data2VecAudioForCTC.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 800
  Batch size = 1


loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel


Saving model checkpoint to ./saves/wav2vec2-base-960h_linear_GRLjust_for_testing/checkpoint-3000
Configuration saved in ./saves/wav2vec2-base-960h_linear_GRLjust_for_testing/checkpoint-3000/config.json
Model weights saved in ./saves/wav2vec2-base-960h_linear_GRLjust_for_testing/checkpoint-3000/pytorch_model.bin
Feature extractor saved in ./saves/wav2vec2-base-960h_linear_GRLjust_for_testing/checkpoint-3000/preprocessor_config.json
Deleting older checkpoint [saves/wav2vec2-base-960h_linear_GRLjust_for_testing/checkpoint-2000] due to args.save_total_limit


loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel


The following columns in the evaluation set  don't have a corresponding argument in `Data2VecAudioForCTC.forward` and have been ignored: array, text, path. If array, text, path are not expected by `Data2VecAudioForCTC.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 800
  Batch size = 1


loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel


Saving model checkpoint to ./saves/wav2vec2-base-960h_linear_GRLjust_for_testing/checkpoint-3500
Configuration saved in ./saves/wav2vec2-base-960h_linear_GRLjust_for_testing/checkpoint-3500/config.json
Model weights saved in ./saves/wav2vec2-base-960h_linear_GRLjust_for_testing/checkpoint-3500/pytorch_model.bin
Feature extractor saved in ./saves/wav2vec2-base-960h_linear_GRLjust_for_testing/checkpoint-3500/preprocessor_config.json
Deleting older checkpoint [saves/wav2vec2-base-960h_linear_GRLjust_for_testing/checkpoint-2500] due to args.save_total_limit


loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel
loss: cel




Training completed. Do not forget to share your model on huggingface.co/models =)




TrainOutput(global_step=3736, training_loss=2.053420505891265, metrics={'train_runtime': 905.8756, 'train_samples_per_second': 4.124, 'train_steps_per_second': 4.124, 'total_flos': 4.252288675625975e+17, 'train_loss': 2.053420505891265, 'epoch': 2.0})

In [22]:
trainer.model.arbitrator.state_dict()

OrderedDict([('weight',
              tensor([[-0.0123,  0.0295,  0.0211,  ..., -0.0300,  0.0236,  0.0029],
                      [ 0.0174,  0.0111,  0.0379,  ...,  0.0432, -0.0009,  0.0588],
                      [-0.0086, -0.0458, -0.0061,  ...,  0.0018, -0.0037, -0.0036],
                      ...,
                      [ 0.0082, -0.0028, -0.0058,  ...,  0.0179, -0.0070, -0.0085],
                      [-0.0223, -0.0279,  0.0233,  ...,  0.0101,  0.0206,  0.0180],
                      [-0.0271,  0.0189,  0.0437,  ...,  0.0140, -0.0013,  0.0029]],
                     device='cuda:0')),
             ('bias',
              tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0'))])

In [19]:
trainer.state.log_history

[{'loss': 7.1109, 'learning_rate': 4.99e-05, 'epoch': 0.27, 'step': 500},
 {'eval_loss': 5.394213676452637,
  'eval_wer': 0.2593872741555381,
  'eval_runtime': 75.9333,
  'eval_samples_per_second': 10.536,
  'eval_steps_per_second': 10.536,
  'epoch': 0.27,
  'step': 500},
 {'loss': 2.2982, 'learning_rate': 9.99e-05, 'epoch': 0.54, 'step': 1000},
 {'eval_loss': 0.7303922772407532,
  'eval_wer': 0.25923016496465046,
  'eval_runtime': 75.0628,
  'eval_samples_per_second': 10.658,
  'eval_steps_per_second': 10.658,
  'epoch': 0.54,
  'step': 1000},
 {'loss': 0.9706,
  'learning_rate': 8.176169590643276e-05,
  'epoch': 0.8,
  'step': 1500},
 {'eval_loss': 1.0987427234649658,
  'eval_wer': 0.25923016496465046,
  'eval_runtime': 75.0833,
  'eval_samples_per_second': 10.655,
  'eval_steps_per_second': 10.655,
  'epoch': 0.8,
  'step': 1500},
 {'loss': 1.1222,
  'learning_rate': 6.348684210526316e-05,
  'epoch': 1.07,
  'step': 2000},
 {'eval_loss': 1.0669927597045898,
  'eval_wer': 0.25923016

In [21]:
trainer.state.log_history[-1]["train_loss"]

2.053420505891265

In [None]:
for epoch in tqdm(range(args.epochs)):
    local_weights, local_losses = [], []
    print(f'\n | Global Training Round : {epoch+1} |\n') # epoch = round

    global_model.train()
    m = max(int(args.frac * args.num_users), 1) # 取部分client來train，至少一個client
    idxs_users = np.random.choice(range(args.num_users), m, replace=False) # 選client_id

    for idx in idxs_users: # 每一位選上的client
        # 做training?
        local_model = LocalUpdate(args=args, dataset=train_dataset,
                                    idxs=user_groups[idx], logger=logger)
        # 得到這次的weights + loss
        w, loss = local_model.update_weights(
            model=copy.deepcopy(global_model), global_round=epoch)
        local_weights.append(copy.deepcopy(w))
        local_losses.append(copy.deepcopy(loss))

    # compute global weights
    global_weights = average_weights(local_weights)

    # update global weights
    global_model.load_state_dict(global_weights)

    loss_avg = sum(local_losses) / len(local_losses)
    train_loss.append(loss_avg) # loss average over all clients

    # Calculate avg training accuracy over all users at every epoch
    list_acc, list_loss = [], []
    global_model.eval()
    for c in range(args.num_users): # 所有clients
        # 做training? 應為update weight
        local_model = LocalUpdate(args=args, dataset=train_dataset,
                                    idxs=user_groups[idx], logger=logger)
        # 用global model去測試，得到acc & loss
        acc, loss = local_model.inference(model=global_model)
        list_acc.append(acc)
        list_loss.append(loss)
    train_accuracy.append(sum(list_acc)/len(list_acc)) # acc_avg

    # print global training loss after every 'i' rounds
    if (epoch+1) % print_every == 0:
        print(f' \nAvg Training Stats after {epoch+1} global rounds:')
        print(f'Training Loss : {np.mean(np.array(train_loss))}')
        print('Train Accuracy: {:.2f}% \n'.format(100*train_accuracy[-1]))

# Test inference after completion of training
test_acc, test_loss = test_inference(args, global_model, test_dataset)

print(f' \n Results after {args.epochs} global rounds of training:')
print("|---- Avg Train Accuracy: {:.2f}%".format(100*train_accuracy[-1]))
print("|---- Test Accuracy: {:.2f}%".format(100*test_acc))

# Saving the objects train_loss and train_accuracy:
file_name = '../save/objects/{}_{}_{}_C[{}]_iid[{}]_E[{}]_B[{}].pkl'.\
    format(args.dataset, args.model, args.epochs, args.frac, args.iid,
            args.local_ep, args.local_bs)

# save 失敗
with open(file_name, 'wb') as f:
    pickle.dump([train_loss, train_accuracy], f)

print('\n Total Run Time: {0:0.4f}'.format(time.time()-start_time))