In [1]:
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
import torch

In [2]:
model = Wav2Vec2ForCTC.from_pretrained("monideep2255/finetuning-xlsr-53-PSST_V7")
processor = Wav2Vec2Processor.from_pretrained("monideep2255/finetuning-xlsr-53-PSST_V7")

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [3]:
processor.decode

<bound method Wav2Vec2Processor.decode of Wav2Vec2Processor:
- feature_extractor: Wav2Vec2FeatureExtractor {
  "do_normalize": true,
  "feature_extractor_type": "Wav2Vec2FeatureExtractor",
  "feature_size": 1,
  "padding_side": "right",
  "padding_value": 0.0,
  "return_attention_mask": true,
  "sampling_rate": 16000
}

- tokenizer: Wav2Vec2CTCTokenizer(name_or_path='monideep2255/finetuning-xlsr-53-PSST_V7', vocab_size=46, model_max_length=1000000000000000019884624838656, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<UNK>', 'pad_token': '<PAD>', 'additional_special_tokens': [AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=True), AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=True)]}, clean_up_tokenization_spaces=True)>

In [4]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

Wav2Vec2ForCTC(
  (wav2vec2): Wav2Vec2Model(
    (feature_extractor): Wav2Vec2FeatureEncoder(
      (conv_layers): ModuleList(
        (0): Wav2Vec2LayerNormConvLayer(
          (conv): Conv1d(1, 512, kernel_size=(10,), stride=(5,))
          (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (activation): GELUActivation()
        )
        (1-4): 4 x Wav2Vec2LayerNormConvLayer(
          (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,))
          (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (activation): GELUActivation()
        )
        (5-6): 2 x Wav2Vec2LayerNormConvLayer(
          (conv): Conv1d(512, 512, kernel_size=(2,), stride=(2,))
          (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (activation): GELUActivation()
        )
      )
    )
    (feature_projection): Wav2Vec2FeatureProjection(
      (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (projec

In [5]:
from datasets import load_dataset, load_metric, DatasetDict, Dataset, Audio

# Load the datasets and observe the structure
dataset_dict = load_dataset('csv', data_files={
    "test": '/work/van-speech-nlp/PSST-experiments/psst-csv/test_utterances_excel.csv',
})

# review the datasets
test_inferences = dataset_dict["test"]

dataset_dict = load_dataset('csv', data_files={
    "valid": '/work/van-speech-nlp/PSST-experiments/psst-csv/valid_utterances_excel.csv',
})

# review the datasets
valid_inferences = dataset_dict["valid"]

print(test_inferences)
print(valid_inferences)

Found cached dataset csv (/home/lewis.jor/.cache/huggingface/datasets/csv/default-6df7c8ed6d6a957a/0.0.0/eea64c71ca8b46dd3f537ed218fc9bf495d5707789152eb2764f5c78fa66d59d)


  0%|          | 0/1 [00:00<?, ?it/s]

Found cached dataset csv (/home/lewis.jor/.cache/huggingface/datasets/csv/default-5fd77a60300ece6c/0.0.0/eea64c71ca8b46dd3f537ed218fc9bf495d5707789152eb2764f5c78fa66d59d)


  0%|          | 0/1 [00:00<?, ?it/s]

Dataset({
    features: ['utterance_id', 'session', 'test', 'prompt', 'transcript', 'correctness', 'aq_index', 'duration_frames', 'filename_old', 'filename_new'],
    num_rows: 652
})
Dataset({
    features: ['utterance_id', 'session', 'test', 'prompt', 'transcript', 'correctness', 'aq_index', 'duration_frames', 'filename_old', 'filename_new'],
    num_rows: 341
})


In [6]:
# remove columns that we do not need
test_inferences = test_inferences.remove_columns(["aq_index", "test", "duration_frames","filename_old"])
valid_inferences = valid_inferences.remove_columns(["aq_index", "test", "duration_frames","filename_old"])
# print to verify
print(test_inferences)
print(valid_inferences)

Dataset({
    features: ['utterance_id', 'session', 'prompt', 'transcript', 'correctness', 'filename_new'],
    num_rows: 652
})
Dataset({
    features: ['utterance_id', 'session', 'prompt', 'transcript', 'correctness', 'filename_new'],
    num_rows: 341
})


In [7]:
test_inferences = test_inferences.cast_column("filename_new", Audio(sampling_rate=16000))
valid_inferences = valid_inferences.cast_column("filename_new", Audio(sampling_rate=16000))

In [8]:
test_inferences["filename_new"][5]

{'path': '/work/van-speech-nlp/PSST-experiments/psst-data/psst-data-2022-03-02-full/test/audio/bnt/ACWT01a/ACWT01a-BNT06-volcano.wav',
 'array': array([-0.00097656,  0.00195312,  0.01193237, ..., -0.00048828,
         0.00024414,  0.00213623]),
 'sampling_rate': 16000}

In [9]:
valid_inferences["filename_new"][5]

{'path': '/work/van-speech-nlp/PSST-experiments/psst-data/psst-data-2022-03-02-full/valid/audio/bnt/BU01a/BU01a-BNT06-volcano.wav',
 'array': array([-0.03860474, -0.06341553, -0.07208252, ...,  0.02603149,
         0.03674316,  0.03723145]),
 'sampling_rate': 16000}

In [10]:
def prepare_references_dataset(batch):
    # load the audio data into batch
    audio = batch["filename_new"]

    # extract the values from the audio files
    batch["input_values"] = processor(audio["array"], sampling_rate=audio["sampling_rate"]).input_values[0]
    batch["input_length"] = len(batch["input_values"])

    # encode the transcript to the label ids
    with processor.as_target_processor():
        batch["labels"] = processor(batch["transcript"]).input_ids
    
    # remove all columns except for 'transcript'
    batch = {key: batch[key] for key in batch.keys() if key == 'transcript'}
    
    return batch

test_inferences = test_inferences.map(prepare_references_dataset, num_proc=4)
valid_inferences = valid_inferences.map(prepare_references_dataset, num_proc=4)

Loading cached processed dataset at /home/lewis.jor/.cache/huggingface/datasets/csv/default-6df7c8ed6d6a957a/0.0.0/eea64c71ca8b46dd3f537ed218fc9bf495d5707789152eb2764f5c78fa66d59d/cache-d3332a8dbee6835d_*_of_00004.arrow
Loading cached processed dataset at /home/lewis.jor/.cache/huggingface/datasets/csv/default-5fd77a60300ece6c/0.0.0/eea64c71ca8b46dd3f537ed218fc9bf495d5707789152eb2764f5c78fa66d59d/cache-121dbdd23bf3822b_*_of_00004.arrow


In [11]:
sample_inference_data = test_inferences[20:50]
# sample_inference_data['input_values'][0]

In [12]:
processor.tokenizer.decoder = {24: '<???>',
 3: '<PAD>',
 2: '<SIL>',
 18: '<SPN>',
 19: '<UNK>',
 1: ' AA ',
 8: '  AE',
 6: ' AH ',
 36: ' AO ',
 33: ' AW ',
 17: ' AY ',
 20: ' B ',
 43: ' CH ',
 35: ' D ',
 42: ' DH ',
 10: ' DX ',
 7: ' EH ',
 12: ' ER ',
 44: ' EY ',
 27: ' F ',
 40: ' G ',
 9: ' HH ',
 41: ' IH ',
 14: ' IY ',
 28: ' JH ',
 21: ' K ',
 22: ' L ',
 37: ' M ',
 0: ' N ',
 25: ' NG ',
 16: ' OW ',
 15: ' OY ',
 32: ' P ',
 45: ' R ',
 38: ' S ',
 29: ' SH ',
 5: ' T ',
 31: ' TH ',
 11: ' UH ',
 4: ' UW ',
 34: ' V ',
 30: ' W ',
 39: ' Y ',
 13: ' Z ',
 26: ' ZH ',
 23: '|'}

In [13]:
import librosa
import numpy as np

# print(test_inferences['input_values'][0])
# Generate predictions for each sample

def predictions_list(dataset):
    res = []
    for i in range(len(dataset['transcript'])):
        
        input_values = np.array(dataset['input_values'][i])
        sampling_rate = dataset['input_length'][i]

        # Resample the input speech to match the model's sampling rate
        input_values = librosa.resample(input_values, orig_sr=sampling_rate, target_sr=16000)
        input_values = processor(input_values, sampling_rate=16000, return_tensors="pt").input_values
        input_values = input_values.to(device)  # Move input to the same device as the model
        
        with torch.no_grad():
            logits = model(input_values).logits

        predicted_ids = torch.argmax(logits, dim=-1)
        #print(predicted_ids)
        transcription = processor.decode(predicted_ids[0], clean_up_tokenization_spaces=False)
        print(predicted_ids[0])       
        prediction = transcription.lstrip().rstrip().replace('  ',' ').replace('\t',' ')
        
        res.append(prediction)
        
        reference_transcription = dataset['transcript'][i]
        print("Utterance Id:", dataset['utterance_id'][i])
        print("Reference:", reference_transcription)
        print("Prediction:", prediction)
        print("---")
#     return res

        # print("Reference:", reference_transcription)
        # print("Prediction:", transcription.lstrip().rstrip().replace('  ',' ').replace('\t',' '))
        # print("---")

In [14]:
print(len(sample_inference_data['transcript']))

30


In [15]:
test_predictions_list = predictions_list(sample_inference_data)

tensor([ 3,  3,  3,  3,  3,  3,  3,  3,  3,  3, 23,  3,  3,  5,  6,  3,  3,  3,
         3,  3,  3,  3, 23,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,
         3,  3,  5,  3,  3,  3,  3,  3,  3,  3, 23,  3, 23], device='cuda:0')
Utterance Id: ACWT01a-VNT09-watch
Reference: K AE T AA M K AE T AA M
Prediction: T AH  T
---
tensor([ 3,  3,  3,  3,  3,  3,  3, 39,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,
         3,  3,  3,  3,  1,  3,  3,  3,  3,  3,  0,  6,  3,  3,  3,  3,  3,  3,
         3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3, 23], device='cuda:0')
Utterance Id: ACWT01a-VNT10-give
Reference: AH G EH DX IH NG AH P R EH Z IH N
Prediction: Y AA N AH
---
tensor([3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
        3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
        3], device='cuda:0')
Utterance Id: ACWT01a-VNT11-swim
Reference: IH N D AH M B IY AH M <spn> <sil> IH T S OW P N OW AH M <spn> <sil> IH T S M B IH M IH NG S W 

In [18]:
# valid_predictions_list = predictions_list(valid_inferences)

In [19]:
print(test_predictions_list)

['T', 'K IH K  K AH L IH NG', 'SH EY V IH NG', 'HH AW S', 'K OW M', 'T IH NG K AH S', '', 'B AH T  S', '']


In [19]:
import csv
def write_tsv(dataset, dataset_predictions):
    file_name = "temp-decoded-test.tsv"
    with open(file_name, "w") as f:
        writer = csv.writer(f, dialect=csv.excel_tab)
        writer.writerow(("utterance_id", "asr_transcript"))
        for i in range(len(dataset)):
            utterance_id = dataset['utterance_id'][i]
            writer.writerow((utterance_id, dataset_predictions[i]))

In [20]:
write_tsv(sample_inference_data, test_predictions_list)

In [None]:
# write_tsv(valid_predictions_list)

In [None]:
processor.tokenizer.encoder['AA']

In [None]:
processor.tokenizer.decoder = {24: '<???>',
 3: '<PAD>',
 2: '<SIL>',
 18: '<SPN>',
 19: '<UNK>',
 1: ' AA ',
 8: '  AE',
 6: ' AH ',
 36: ' AO ',
 33: ' AW ',
 17: ' AY ',
 20: ' B ',
 43: ' CH ',
 35: ' D ',
 42: ' DH ',
 10: ' DX ',
 7: ' EH ',
 12: ' ER ',
 44: ' EY ',
 27: ' F ',
 40: ' G ',
 9: ' HH ',
 41: ' IH ',
 14: ' IY ',
 28: ' JH ',
 21: ' K ',
 22: ' L ',
 37: ' M ',
 0: ' N ',
 25: ' NG ',
 16: ' OW ',
 15: ' OY ',
 32: ' P ',
 45: ' R ',
 38: ' S ',
 29: ' SH ',
 5: ' T ',
 31: ' TH ',
 11: ' UH ',
 4: ' UW ',
 34: ' V ',
 30: ' W ',
 39: ' Y ',
 13: ' Z ',
 26: ' ZH ',
 23: '|'}