In [None]:
import os, torch, jiwer, librosa
import matplotlib.pyplot as plt
%matplotlib inline
from tqdm import tqdm
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

In [None]:
# The processor and the CTC model are both loaded from the same pretrained checkpoint.
checkpoint = "facebook/wav2vec2-base-960h"
processor = Wav2Vec2Processor.from_pretrained(checkpoint)
model = Wav2Vec2ForCTC.from_pretrained(checkpoint)

In [None]:
# Directory entries that are skipped when walking the recordings tree.
blacklisted_paths = ['.DS_Store']

In [None]:
# Load every recording into memory, keyed by '<command>/<sample>'.
files = {}
# NOTE(review): the [1:-1] slice drops the first and last sorted directory
# entries regardless of what they are — confirm this is intended.
command_dirs = sorted(os.listdir('../recordings/'))[1:-1]
for command in command_dirs:
  if command in blacklisted_paths:
    continue
  sample_names = sorted(os.listdir(f'../recordings/{command}/'))
  for sample in sample_names:
    if sample in blacklisted_paths:
      continue
    # librosa is asked for a 16000 Hz sample rate; the returned rate is discarded.
    waveform, _ = librosa.load(f'../recordings/{command}/{sample}', sr=16000)
    files[f'{command}/{sample}'] = waveform.astype('float32')

In [None]:
# plt.figure(figsize=(16, 4))
# plt.plot(files[files.keys().__iter__().__next__()])

In [None]:
# Transcribe every loaded recording with the wav2vec2 CTC model.
# `trans` maps '<command>/<sample>' -> decoded text for that clip.
trans = {}
# Inference only: without no_grad() each forward pass records an autograd graph
# and keeps intermediate activations that are never used for a backward pass.
with torch.no_grad():
  for name, data in tqdm(files.items()):
    # One clip per call; the sampling rate matches the rate the clips were loaded at.
    input_values = processor(data, sampling_rate=16_000, return_tensors='pt', padding='longest').input_values
    logits = model(input_values).logits
    # Greedy CTC decoding: take the most likely token id at each position.
    predicted_ids = torch.argmax(logits, dim=-1)
    # batch_decode returns one string per batch item; the batch holds a single clip.
    trans[name] = processor.batch_decode(predicted_ids)[0]

In [None]:
# Show the recording-name -> transcription mapping as the cell output.
trans

In [None]:
# The reference text of a recording is the part of its key before the first '/',
# i.e. the name of the command folder it was loaded from.
ref = {key: key.partition('/')[0] for key, _ in tqdm(files.items())}

In [None]:
# Show the recording-name -> reference-text mapping as the cell output.
ref

In [None]:
def evaluate(ref, trans):
  """Print each prediction next to its reference, then the jiwer measures.

  Args:
    ref: mapping from recording name to reference text.
    trans: mapping from recording name to predicted text; every key must
      also be present in `ref`.

  Predictions are lower-cased before they are printed and scored; references
  are used as given. Nothing is returned — all results are printed.

  Raises:
    KeyError: if a name in `trans` has no entry in `ref`.
  """
  references, hypotheses = [], []

  for key, raw_prediction in trans.items():
    truth = ref[key]
    prediction = raw_prediction.lower()
    print(f'name: {key}, true: {truth}, pred: {prediction}')
    references.append(truth)
    hypotheses.append(prediction)

  print(jiwer.compute_measures(references, hypotheses))

In [None]:
# Print per-recording predictions and the jiwer measures for the mapping currently held in `trans`.
evaluate(ref, trans)

In [None]:
from speechbrain.pretrained import EncoderDecoderASR

# Second recogniser: a pretrained SpeechBrain encoder-decoder ASR model,
# loaded through from_hparams with the source and savedir given below.
asr_model = EncoderDecoderASR.from_hparams(
  source="speechbrain/asr-crdnn-rnnlm-librispeech",
  savedir="pretrained_models/asr-crdnn-rnnlm-librispeech",
)

In [None]:
# Transcribe every recording with the SpeechBrain model.
# NOTE: this rebinds `trans`, replacing whatever mapping it held before this cell ran.
trans = {}
# NOTE(review): the [1:-1] slice drops the first and last sorted directory
# entries regardless of what they are — confirm this is intended.
for command in sorted(os.listdir('../recordings/'))[1:-1]:
  if command in blacklisted_paths:
    continue
  for sample in sorted(os.listdir(f'../recordings/{command}/')):
    if sample in blacklisted_paths:
      continue
    trans[f'{command}/{sample}'] = asr_model.transcribe_file(f"../recordings/{command}/{sample}")
    # NOTE(review): the cleanup below implies transcribe_file can leave a
    # copy/symlink named after the sample in the working directory — presumably
    # version-dependent, confirm against the installed SpeechBrain. It is removed
    # after each clip so same-named samples from different command folders cannot
    # collide. lexists() also matches dangling symlinks, and guarding the removal
    # keeps a missing file from aborting the whole run.
    leftover = f"./{sample}"
    if os.path.lexists(leftover):
      os.remove(leftover)

In [None]:
# Score again: `trans` now holds the asr_model transcriptions built in the preceding cell.
evaluate(ref, trans)