In [5]:
import librosa
import numpy as np

In [60]:
# All paths hang off one scratch root so the notebook can be pointed at a
# different location by editing a single line instead of five.
thesis_dir = '/srv/scratch/z5313567/thesis'
wav2vec2_model_dir = f'{thesis_dir}/wav2vec2/model'

# NOTE(review): the "train" and "test" splits both read the same test-only
# transcription CSV. That looks deliberate for an inference-only notebook
# (only the "test" split is decoded further down) — confirm before relying on
# the "train" split for anything.
data_test_fp = f'{thesis_dir}/AusKidTalk_local/spontaneous/AusKidTalk_spontaneous_test_only_transcription_filepath.csv'
data_train_fp = data_test_fp

# Passed as `cache_dir` to both `load_dataset` and the `from_pretrained` calls.
cache_fp = f'{thesis_dir}/cache'

# Checkpoint directory loaded into the processor and the CTC model.
# Alternative checkpoint kept for reference:
#model_fp = f'{wav2vec2_model_dir}/AusKidTalk_scripted_spontaneous_combined/finetune_20230718'
model_fp = f'{wav2vec2_model_dir}/AusKidTalk/progressive_finetune_CU_MyST_AusKidTalk_20230828'

# Directory loaded by Wav2Vec2ProcessorWithLM (processor + n-gram decoder).
# NOTE(review): this lives under a different fine-tuning run than `model_fp`;
# its tokenizer vocabulary has to line up with the model's logits — confirm.
LM_fp = f'{wav2vec2_model_dir}/AusKidTalk_scripted_spontaneous_combined/finetune_20230718_with_lm_AusKidTalk_4gram_correct'

In [61]:
from datasets import load_dataset

# Map each split name to its CSV file and load them through the generic 'csv'
# builder, keeping the prepared dataset under cache_fp.
csv_splits = {'train': data_train_fp, 'test': data_test_fp}
dataset = load_dataset('csv', data_files=csv_splits, cache_dir=cache_fp)

Found cached dataset csv (/srv/scratch/z5313567/thesis/cache/csv/default-5bd1000f98154dca/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1)
100%|██████████| 2/2 [00:00<00:00, 459.35it/s]


In [62]:
# Show the loaded DatasetDict: its splits, column names and row counts.
dataset

DatasetDict({
    train: Dataset({
        features: ['filepath', 'transcription_clean'],
        num_rows: 24
    })
    test: Dataset({
        features: ['filepath', 'transcription_clean'],
        num_rows: 24
    })
})

In [63]:
def speech_file_to_array_fn(batch, target_sr=16000):
    """Load one example's audio from disk and attach it to the example.

    Args:
        batch: a mapping with a ``'filepath'`` entry naming the audio file and
            a ``'transcription_clean'`` entry holding its transcription.
        target_sr: sampling rate handed to ``librosa.load`` as ``sr``.
            Defaults to 16000, the value that used to be hard-coded here, so
            existing callers (including ``dataset.map``) behave as before.

    Returns:
        The same ``batch`` object, updated in place with three new entries:
        ``'speech'`` (the waveform returned by ``librosa.load``),
        ``'sampling_rate'`` (the rate returned alongside it) and
        ``'target_text'`` (a copy of ``'transcription_clean'``).
    """
    waveform, loaded_sr = librosa.load(batch['filepath'], sr=target_sr)
    batch["speech"] = waveform
    batch["sampling_rate"] = loaded_sr
    batch["target_text"] = batch["transcription_clean"]
    return batch

# Run the loader over every split with 4 worker processes. The columns that
# came from the CSV are dropped, so only the entries added by
# speech_file_to_array_fn remain.
csv_columns = dataset.column_names["train"]
data = dataset.map(
    speech_file_to_array_fn,
    remove_columns=csv_columns,
    num_proc=4,
)

Loading cached processed dataset at /srv/scratch/z5313567/thesis/cache/csv/default-5bd1000f98154dca/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1/cache-ed9c423eb2f570dd_*_of_00004.arrow
Loading cached processed dataset at /srv/scratch/z5313567/thesis/cache/csv/default-5bd1000f98154dca/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1/cache-90530466a11f1899_*_of_00004.arrow


In [64]:
# Show the mapped DatasetDict to confirm the new column names and row counts.
data

DatasetDict({
    train: Dataset({
        features: ['speech', 'sampling_rate', 'target_text'],
        num_rows: 24
    })
    test: Dataset({
        features: ['speech', 'sampling_rate', 'target_text'],
        num_rows: 24
    })
})

In [65]:
import IPython.display as ipd

# Take one utterance from the test split, print its reference transcription
# in lower case, and render an audio player for it (playback starts
# automatically).
audio_sample = data["test"][12]
reference_text = audio_sample["target_text"].lower()
print(reference_text)
ipd.Audio(
    data=audio_sample["speech"],
    rate=audio_sample["sampling_rate"],
    autoplay=True,
)

the egg starts cracking near the top and hulk is scared by this


In [66]:
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC

# Load the processor and the CTC model from the checkpoint at model_fp.
# To compare against the stock checkpoint instead, replace model_fp with
# 'facebook/wav2vec2-base-100h' in both calls.
pretrained_kwargs = {'cache_dir': cache_fp}
processor = Wav2Vec2Processor.from_pretrained(model_fp, **pretrained_kwargs)
model = Wav2Vec2ForCTC.from_pretrained(model_fp, **pretrained_kwargs)

In [67]:
# Convert the raw waveform into PyTorch tensors for the model. The sampling
# rate is read from the sample itself (it was stored next to the waveform when
# the audio was loaded) instead of being repeated as a literal, so the value
# given to the processor cannot drift from the rate the audio actually has.
inputs = processor(
    audio_sample["speech"],
    sampling_rate=audio_sample["sampling_rate"],
    return_tensors="pt",
)

In [68]:
import torch

# Inference only: run the forward pass with gradient tracking switched off,
# then keep just the logits from the model output.
with torch.no_grad():
    model_output = model(**inputs)
logits = model_output.logits

In [69]:
# Greedy decoding: pick the highest-scoring token id along the last axis and
# let the processor turn the id sequences into text. Only one utterance was
# fed in, so the first decoded string is the transcription.
predicted_ids = logits.argmax(dim=-1)
greedy_texts = processor.batch_decode(predicted_ids)
transcription = greedy_texts[0]

print(f'Predicted transcription without the language model is:\n{transcription}')

Predicted transcription without the language model is:
THE EGGSTACS CRACEN THE TOP ANDHOEISETTRBITHOS


In [42]:
from transformers import Wav2Vec2ProcessorWithLM

# Processor bundled with the language-model decoder stored at LM_fp.
processor_with_LM = Wav2Vec2ProcessorWithLM.from_pretrained(
    LM_fp,
    cache_dir=cache_fp,
)

In [43]:
# Sanity check on the logits' dimensions; the recorded value is (1, 307, 32),
# and the last dimension matches the 32-entry tokenizer vocabulary.
logits.shape

torch.Size([1, 307, 32])

In [44]:
# List the LM processor's tokenizer vocabulary as one space-separated,
# alphabetically sorted string.
vocab_tokens = sorted(processor_with_LM.tokenizer.get_vocab())
" ".join(vocab_tokens)

"' </s> <pad> <s> <unk> A B C D E F G H I J K L M N O P Q R S T U V W X Y Z |"

In [47]:
# Decode the same logits with the language-model processor; it takes a NumPy
# array, and the decoded strings are exposed on the result's `.text` field.
lm_decoded = processor_with_LM.batch_decode(logits.numpy())
transcription = lm_decoded.text[0]
print(f'Predicted transcription from the language model is:\n{transcription}')

Predicted transcription from the language model is:
THE EGGSTOTS CRACKING NEARTHE TOPANDAHAK IS SCARED BY THIS


In [48]:
# Second LM processor, loaded by name rather than from a local directory, for
# comparison with the one loaded from LM_fp.
pretrained_lm_name = "patrickvonplaten/wav2vec2-base-100h-with-lm"
processor_with_LM_pretrained = Wav2Vec2ProcessorWithLM.from_pretrained(
    pretrained_lm_name,
    cache_dir=cache_fp,
)

Fetching 4 files: 100%|██████████| 4/4 [00:00<00:00, 1700.51it/s]


In [49]:
# Decode the same logits once more, this time with the pretrained LM
# processor, and report the first (only) decoded string.
pretrained_lm_decoded = processor_with_LM_pretrained.batch_decode(logits.numpy())
transcription = pretrained_lm_decoded.text[0]
print(f'Predicted transcription from the pretrained language model is:\n{transcription}')

Predicted transcription from the pretrained language model is:
THE EGG STOTS CRACKING NEAR THE TOP AND A HAWK IS SCARED BY THIS
