In [4]:
import torch
import numpy as np
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union
import os.path
import pandas as pd
from datasets import Dataset, load_from_disk
import librosa
from datasets import load_dataset, load_metric
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, Trainer, TrainingArguments
from transformers import Data2VecAudioConfig, HubertConfig, SEWDConfig, UniSpeechSatConfig
from transformers import Data2VecAudioForCTC, HubertForCTC, SEWDForCTC, UniSpeechSatForCTC
from jiwer import wer
import scipy.io
import argparse
import os
import pandas as pd
from datasets import Dataset, load_from_disk
import librosa
import scipy.io.wavfile
import numpy as np
# from utils import csv2dataset, WriteResult

# =======================================
# 有一些common的function丟到這邊
# =======================================
parser = argparse.ArgumentParser()
#parser.add_argument('-model', '--model_path', type=str, default="./saves/wav2vec2-base-960h_GRL_0.5", help="Where the model is saved")
parser.add_argument('-opt', '--optimizer', type=str, default="adamw_hf", help="The optimizer to use: adamw_hf, adamw_torch, adamw_apex_fused, or adafactor")
parser.add_argument('-MGN', '--max_grad_norm', type=float, default=1.0, help="Maximum gradient norm (for gradient clipping)")
parser.add_argument('-model_type', '--model_type', type=str, default="data2vec", help="Type of the model")
parser.add_argument('-sr', '--sampl_rate', type=float, default=16000, help="librosa read smping rate")
parser.add_argument('-lr', '--learning_rate', type=float, default=1e-4, help="Learning rate")
parser.add_argument('-RD', '--root_dir', default='/mnt/Internal/FedASR/Data/ADReSS-IS2020-data', help="Learning rate")
parser.add_argument('--AudioLoadFunc', default='librosa', help="用scipy function好像可以比較快")
args = parser.parse_args(args=[])
# from utils import csv2dataset
def ID2Label(ID,
            spk2label = np.load("/mnt/Internal/FedASR/weitung/HuggingFace/Pretrain/dataset/test_dic.npy", allow_pickle=True).tolist()):
    name = ID.split("_")                                                    #  from file name to spkID
    if (name[1] == 'INV'):                                                  # interviewer is CC
        label = 0
    else:                                                                   # for participant
        label = spk2label[name[0]]                                          # label according to look-up table
    return label                                                            # return dementia label for this file
def prepare_dataset(batch):
    audio = batch["array"]

    # batched output is "un-batched" to ensure mapping is correct
    batch["input_values"] = processor(audio, sampling_rate=16000).input_values[0]
    
    with processor.as_target_processor():
        batch["labels"] = processor(batch["text"]).input_ids
    return batch
def csv2dataset(audio_path = '{}/clips/'.format(args.root_dir),
                csv_path = '{}/mid_csv/test.csv'.format(args.root_dir)):
    stored = "./dataset/" + csv_path.split("/")[-1].split(".")[0]
    if (os.path.exists(stored)):
        print("Load data from local...")
        return load_from_disk(stored)
 
    data = pd.read_csv(csv_path)                                                # read desired csv
    dataset = Dataset.from_pandas(data)                                     # turn into class dataset
    
    # initialize a dictionary
    my_dict = {}
    my_dict["path"] = []                                                    # path to audio
    my_dict["array"] = []                                                   # waveform in array
    my_dict["text"] = []                                                    # ground truth transcript
    my_dict["dementia_labels"] = []

    spk2label=np.load("/mnt/Internal/FedASR/weitung/HuggingFace/Pretrain/dataset/test_dic.npy", allow_pickle=True).tolist()

    i = 1
    for file_path in dataset['path']:                                            # for all files
        if dataset['sentence'][i-1] != None:                                # only the non-empty transcript
            if args.AudioLoadFunc == 'librosa':
                sig, s = librosa.load('{0}/{1}'.format(audio_path,file_path), sr=args.sampl_rate, dtype='float32')  # read audio w/ 16k sr
            else:
                s, sig = scipy.io.wavfile.read('{0}/{1}'.format(audio_path,file_path))
                sig=librosa.util.normalize(sig)
            if len(sig) > 1600:                                             # get rid of audio that's too short
                my_dict["path"].append(file_path)                                # add path
                my_dict["array"].append(sig)                                # add audio wave
                my_dict["text"].append(dataset['sentence'][i-1].upper())    # transcript to uppercase
                my_dict["dementia_labels"].append(ID2Label(ID=file_path,
                                                           spk2label=spk2label))
        print(i, end="\r")                                                  # print progress
        i += 1
    print("There're ", len(my_dict["path"]), " non-empty files.")

    result_dataset = Dataset.from_dict(my_dict)
    result_dataset.save_to_disk(stored)                                     # save for later use
    
    return result_dataset
def map_to_result(batch,*,  model=None, processor=None):
    with torch.no_grad():
        input_values = torch.tensor(batch["input_values"]).unsqueeze(0)
        logits = model(input_values).logits

    pred_ids = torch.argmax(logits, dim=-1)
    batch["pred_str"] = processor.batch_decode(pred_ids)[0]
    batch["text"] = processor.decode(batch["labels"], group_tokens=False)
  
    return batch
def WriteResult(result,Save_path):
    df_results=pd.DataFrame([result['text'],result['pred_str']], index=['GroundTruth','PredStr']).T
    df_results.to_csv('{}/Result.csv'.format(Save_path))
    print("Writing results to {}".format(Save_path))
model_type="data2vec"
name='facebook/data2vec-audio-large-960h'
rootdir='/mnt/Internal/FedASR/Data/ADReSS-IS2020-data'
test_data = csv2dataset(audio_path = '{}/clips/'.format(rootdir),
                        csv_path = "{}/mid_csv/test.csv".format(rootdir))
processor = Wav2Vec2Processor.from_pretrained(name)
test_data = test_data.map(prepare_dataset, num_proc=4)
# load in trained model
if model_type == "wav2vec":
    new_model = Wav2Vec2ForCTC.from_pretrained("./saves/" + name.split("/")[-1] + "_finetuned/final")
    new_processor = Wav2Vec2Processor.from_pretrained(name)
elif model_type == "data2vec":
    mask_time_prob = 0                                                                     # change config
    config = Data2VecAudioConfig.from_pretrained(name, mask_time_prob=mask_time_prob)
    model = Data2VecAudioForCTC.from_pretrained(name, config=config)
    new_model = Data2VecAudioForCTC.from_pretrained("./saves/" + name.split("/")[-1] + "_finetuned/final")
    new_processor = Wav2Vec2Processor.from_pretrained(name)
elif model_type == "hubert":
    new_model = HubertForCTC.from_pretrained("./saves/" + name.split("/")[-1] + "_finetuned/final")
    new_processor = Wav2Vec2Processor.from_pretrained(name)
elif model_type == "sewd":
    new_model = SEWDForCTC.from_pretrained("./saves/" + name.split("/")[-1] + "_finetuned/final")
    new_processor = Wav2Vec2Processor.from_pretrained(name)
elif model_type == "unispeech":
    new_model = UniSpeechSatForCTC.from_pretrained("./saves/" + name.split("/")[-1] + "_finetuned/final")
    new_processor = Wav2Vec2Processor.from_pretrained(name)
else:
    print("WRONG TYPE!!!!!!!!!!!!!!!!")
Save_path="./saves/" + name.split("/")[-1] + "_finetuned/final"

# 在 datasets 模組中，map 函式並不支援直接將額外引數傳遞給映射函式。您可以使用其他方式來實現在 map_to_result 函式中傳遞 model 和 processor。
# 在使用 map 函式時，可以將 model 和 processor 作為函式引數傳遞給 map_to_result 函式
# lambda arguments: expression
result = test_data.map(lambda batch: map_to_result(batch, model=model, processor=processor))
result_newmdl = test_data.map(lambda batch: map_to_result(batch, model=new_model, processor=new_processor))


print("WER of ", name, " : ", wer(result["text"], result["pred_str"]))
WriteResult(result,Save_path)
print("DONE!")

Load data from local...


Loading cached processed dataset at /home/FedASR/dacs/centralized/dataset/test/cache-1b7e3673b55144b2_*_of_00004.arrow
Some weights of the model checkpoint at facebook/data2vec-audio-large-960h were not used when initializing Data2VecAudioForCTC: ['data2vec_audio.masked_spec_embed']
- This IS expected if you are initializing Data2VecAudioForCTC from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Data2VecAudioForCTC from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Map:   0%|          | 0/800 [00:00<?, ? examples/s]

Map:   0%|          | 0/800 [00:00<?, ? examples/s]

WER of  facebook/data2vec-audio-large-960h  :  0.4746268656716418
Writing results to ./saves/data2vec-audio-large-960h_finetuned/final
DONE!


In [12]:
result = test_data.map(lambda batch: map_to_result(batch , model=model, processor=processor), num_proc=4)



Map (num_proc=4):   0%|          | 0/800 [00:00<?, ? examples/s]

KeyboardInterrupt: 

In [6]:
result['pred_str']

["MOTHER'S GETTING HER FEET WET",
 '',
 'NO',
 'WITH A COUPLE OF BOWLS AND A PLATE ON THE E COUNTER',
 'EVERYTHING THAT YOU SEE HAPPENING HERE',
 "THEY'RE GETTING IN THE COOKIE JAR AND THEIR UPSETTING THE STOOL",
 "AUNT MIE'S HAND MORSOM QUICKLY",
 "THER'S NOTING THERE OUTSADE HE AS POSHESS",
 'THE GIRL IS REACHING FOR A COOKIE',
 '',
 "THERE'S A CUP TWO CUPS AND A SAUCER ON THE SINK",
 'THE GODDESS THE CUPS WER TI HER MESE AN DWIS WEN SHE GAN SIT ON HIS OWNESID',
 'BU A BOY WITH E COOKI IN HIS ONE HAND AND HIS HAND IN THE COFFEE JAR',
 'O',
 'SIX RUNNING OVER',
 'I LIKE THESE THINGS AND TRO HERE',
 '',
 'A LITTLE BOY IS GETTING HIMSELF IN THE CRUBBIES ADDED JA',
 "EEMS TO ME THAT THAT'S ESSENTIALLY THE THINGS THAT ARE GOING ON IN THIS PICTURE",
 'IT THAT THE STORE WAS TEPPING OUT',
 'HIS',
 'AND THE LITTLE GIRL HAS HER FINGER UP TO HER MOUTH TO BE QUIET',
 'HIS PICTURE',
 "AH SHE'S HOLDING THE DISHCLOTH IN HER RIGHT HAND AND THE PLATE SHE IS DRYING IN HER LEFT",
 'AND THE BOY IS STAND