In [3]:
from torchaudio.transforms import MelSpectrogram, FrequencyMasking, TimeMasking
import torch.nn as nn
import torch
from torch.nn.utils.rnn import pad_sequence
from lab4_proto import *
from lab4_main import *

  from .autonotebook import tqdm as notebook_tqdm


### 3.1 Representing text

In [4]:
text = "Hello World"
int_list = strToInt(text)
result_text = intToStr(int_list)
print(f"Original text: '{text}' -> Integers: {int_list} -> Back to text: '{result_text}'") 

Original text: 'Hello World' -> Integers: [9, 6, 13, 13, 16, 1, 24, 16, 19, 13, 5] -> Back to text: 'hello_world'


### 3.2 verify with example

In [5]:
example = torch.load('lab4_example.pt')
inputs = example['data']
spectrograms, labels, input_lengths, label_lengths = dataProcessing(inputs, test_audio_transform)

In [6]:
expected_spectrograms = example['spectrograms']
expected_labels = example['labels'].long() 
expected_input_lengths = example['input_lengths']
expected_label_lengths = example['label_lengths']

In [7]:
print("Spectrograms match:", torch.allclose(spectrograms, expected_spectrograms, atol=1e-7))
print("Labels match:", torch.equal(labels, expected_labels))
print("Input lengths match:", input_lengths == expected_input_lengths)
print("Label lengths match:", label_lengths == expected_label_lengths)

Spectrograms match: True
Labels match: True
Input lengths match: True
Label lengths match: True


### 5.4 Train the model and check the results

In [8]:
use_cuda = torch.cuda.is_available()
torch.manual_seed(7)
device = torch.device("cuda:0" if use_cuda else "cpu")

train_dataset = torchaudio.datasets.LIBRISPEECH(".", url='train-clean-100', download=True)
val_dataset = torchaudio.datasets.LIBRISPEECH(".", url='dev-clean', download=True)
test_dataset = torchaudio.datasets.LIBRISPEECH(".", url='test-clean', download=True)

kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
train_loader = data.DataLoader(dataset=train_dataset,
                batch_size=hparams['batch_size'],
                shuffle=True,
                collate_fn=lambda x: dataProcessing(x, train_audio_transform),
                **kwargs)

val_loader = data.DataLoader(dataset=val_dataset,
                batch_size=hparams['batch_size'],
                shuffle=True,
                collate_fn=lambda x: dataProcessing(x, test_audio_transform),
                **kwargs)

test_loader = data.DataLoader(dataset=test_dataset,
                batch_size=hparams['batch_size'],
                shuffle=False,
                collate_fn=lambda x: dataProcessing(x, test_audio_transform),
                **kwargs)

model = SpeechRecognitionModel(
    hparams['n_cnn_layers'], 
    hparams['n_rnn_layers'], 
    hparams['rnn_dim'],
    hparams['n_class'], 
    hparams['n_feats'], 
    hparams['stride'], 
    hparams['dropout']
    ).to(device)

# write log into the file
cur_version = datetime.now().strftime('%y%m%d-%H%M%S')
os.makedirs('logs', exist_ok=True)

# logger.info(model)
print('Num Model Parameters', sum([param.nelement() for param in model.parameters()]))

optimizer = optim.AdamW(model.parameters(), hparams['learning_rate'])
criterion = nn.CTCLoss(blank=28).to(device)

model_path = 'checkpoints/epoch-19-wer-0.479.pt'
model.load_state_dict(torch.load(model_path))
logger.info(f'Load pre-trained model from "{model_path}"')

[32m2024-05-28 13:10:33.601[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m50[0m - [1mLoad pre-trained model from "checkpoints/epoch-19-wer-0.479.pt"[0m


Num Model Parameters 23311869


test with test set audio:

In [14]:
original_text = 'ALSO A POPULAR CONTRIVANCE WHEREBY LOVE MAKING MAY BE SUSPENDED BUT NOT STOPPED DURING THE PICNIC SEASON'
wavfile = 'LibriSpeech/test-clean/121/121726/121-121726-0000.flac'
waveform, sample_rate = torchaudio.load(wavfile, normalize=True)
resample_rate = 16000
if sample_rate != resample_rate:
    resampler = T.Resample(orig_freq=sample_rate, new_freq=resample_rate)
    waveform = resampler(waveform)
spectrogram = test_audio_transform(waveform)
input = torch.unsqueeze(spectrogram,dim=0).to(device)
output = model(input)
text = greedyDecoder(output)
print(f'wavfile: {wavfile}')
print(f'orginal text: {original_text}')
print(f'predicted text: {text}')

wavfile: LibriSpeech/test-clean/121/121726/121-121726-0000.flac
orginal text: ALSO A POPULAR CONTRIVANCE WHEREBY LOVE MAKING MAY BE SUSPENDED BUT NOT STOPPED DURING THE PICNIC SEASON
predicted text: ['al_so_a_populocand_drivans_wherm_i_nove_makin_may_bes_suspended_but_not_stoked_during_the_phicknec_eason']


test with recorded audio with reading the same as above

In [19]:
original_text = 'ALSO A POPULAR CONTRIVANCE WHEREBY LOVE MAKING MAY BE SUSPENDED BUT NOT STOPPED DURING THE PICNIC SEASON'
wavfile = 'TestAudios/demo2.m4a'
waveform, sample_rate = torchaudio.load(wavfile, normalize=True)
resample_rate = 16000
if sample_rate != resample_rate:
    resampler = T.Resample(orig_freq=sample_rate, new_freq=resample_rate)
    waveform = resampler(waveform)
spectrogram = test_audio_transform(waveform)
input = torch.unsqueeze(spectrogram,dim=0).to(device)
output = model(input)
text = greedyDecoder(output)
print(f'wavfile: {wavfile}')
print(f'orginal text: {original_text}')
print(f'predicted text: {text}')

wavfile: TestAudios/demo2.m4a
orginal text: ALSO A POPULAR CONTRIVANCE WHEREBY LOVE MAKING MAY BE SUSPENDED BUT NOT STOPPED DURING THE PICNIC SEASON
predicted text: ['al_sapocnacunge_hrer_then_speared_by_lofe_making_ma_besaspendingt_bu_nastob_durin_of_pi_pick_seso_']


test with recorded audio:

In [15]:
original_text = "Should the Royal Family be made to do National Service? Vote in our poll as minister refuses to rule out Prince"
wavfile = 'TestAudios/demo.m4a'
waveform, sample_rate = torchaudio.load(wavfile, normalize=True)
resample_rate = 16000
if sample_rate != resample_rate:
    resampler = T.Resample(orig_freq=sample_rate, new_freq=resample_rate)
    waveform = resampler(waveform)
spectrogram = test_audio_transform(waveform)
input = torch.unsqueeze(spectrogram,dim=0).to(device)
output = model(input)
text = greedyDecoder(output)
print(f'wavfile: {wavfile}')
print(f'orginal text: {original_text}')
print(f'predicted text: {text}')

wavfile: TestAudios/demo.m4a
orginal text: Should the Royal Family be made to do National Service? Vote in our poll as minister refuses to rule out Prince
predicted text: ['tud_borbo_only_bein_might_two_motualof_cervis_bo_tie_alrpo_as_mes_thrne_retusa_cina_alpre']


### 5.5 Language model

without language model:

In [17]:
use_language_model = False

avg_cer, avg_wer = test(model, device, test_loader, criterion, -1, use_language_model)
if use_language_model:
    print(f"Use language model (alpha={alpha}, beta={beta}), avg_cer={avg_cer:.3f}, avg_wer={avg_wer:.3f}")
else:
    print(f"Not use language model, avg_cer={avg_cer:.3f}, avg_wer={avg_wer:.3f}")

[32m2024-05-28 13:18:14.966[0m | [1mINFO    [0m | [36mlab4_main[0m:[36mtest[0m:[36m204[0m - [1m
evaluating…[0m
100%|██████████| 82/82 [41:59<00:00, 30.72s/it]
[32m2024-05-28 14:00:14.254[0m | [1mINFO    [0m | [36mlab4_main[0m:[36mtest[0m:[36m253[0m - [1mTest set: Average loss: 0.4941, Average CER: 0.1500 Average WER: 0.4597
[0m


Not use language model, avg_cer=0.150, avg_wer=0.460


with language model:
alpha = 0.4
beta = 0.0

In [18]:
use_language_model = True
alpha = 0.4
beta = 0.0

avg_cer, avg_wer = test(model, device, test_loader, criterion, -1, use_language_model, alpha, beta)
if use_language_model:
    logger.info(f"Use language model (alpha={alpha}, beta={beta}), avg_cer={avg_cer:.3f}, avg_wer={avg_wer:.3f}")
else:
    logger.info(f"Not use language model, avg_cer={avg_cer:.3f}, avg_wer={avg_wer:.3f}")

[32m2024-05-28 14:04:36.535[0m | [1mINFO    [0m | [36mlab4_main[0m:[36mtest[0m:[36m204[0m - [1m
evaluating…[0m
Loading the LM will be faster if you build a binary file.
Unigrams and labels don't seem to agree.
Reading /raid/yixu/Projects/Speech/lab4/wiki-interpolate.3gram.arpa
----5---10---15---20---25---30---35---40---45---50---55---60---65---70---75---80---85---90---95--100
****************************************************************************************************
100%|██████████| 82/82 [37:02<00:00, 27.10s/it]
[32m2024-05-28 14:41:39.843[0m | [1mINFO    [0m | [36mlab4_main[0m:[36mtest[0m:[36m253[0m - [1mTest set: Average loss: 0.4941, Average CER: 0.1141 Average WER: 0.2835
[0m
[32m2024-05-28 14:41:39.844[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m7[0m - [1mUse language model (alpha=0.4, beta=0.0), avg_cer=0.114, avg_wer=0.283[0m


### 5.5.1 Tuning the language model weights
Grid search results. Best pair: alpha = 0.4, beta = 0.0
| Alpha | Beta | WER |
|-------|------|-------------|
| 0.0   | 0.0  | 0.4037      |
| 0.0   | 0.2  | 0.4071      |
| 0.0   | 0.4  | 0.4155      |
| 0.0   | 0.6  | 0.4171      |
| 0.0   | 0.8  | 0.4228      |
| 0.0   | 1.0  | 0.4282      |
| 0.2   | 0.0  | 0.3092      |
| 0.2   | 0.2  | 0.3893      |
| 0.2   | 0.4  | 0.4115      |
| 0.2   | 0.6  | 0.4152      |
| 0.2   | 0.8  | 0.4219      |
| 0.2   | 1.0  | 0.4278      |
| 0.4   | 0.0  | 0.2913      |
| 0.4   | 0.2  | 0.3727      |
| 0.4   | 0.4  | 0.4097      |
| 0.4   | 0.6  | 0.4177      |
| 0.4   | 0.8  | 0.4217      |
| 0.4   | 1.0  | 0.4281      |
| 0.6   | 0.0  | 0.2979      |
| 0.6   | 0.2  | 0.3588      |
| 0.6   | 0.4  | 0.4074      |
| 0.6   | 0.6  | 0.4148      |
| 0.6   | 0.8  | 0.4226      |
| 0.6   | 1.0  | 0.4270      |
| 0.8   | 0.0  | 0.3290      |
| 0.8   | 0.2  | 0.3478      |
| 0.8   | 0.4  | 0.4044      |
| 0.8   | 0.6  | 0.4163      |
| 0.8   | 0.8  | 0.4227      |
| 0.8   | 1.0  | 0.4283      |
| 1.0   | 0.0  | 0.3874      |
| 1.0   | 0.2  | 0.3373      |
| 1.0   | 0.4  | 0.4026      |
| 1.0   | 0.6  | 0.4155      |
| 1.0   | 0.8  | 0.4228      |
| 1.0   | 1.0  | 0.4275      |

