<a href="https://colab.research.google.com/github/Dor890/Speech-Processing/blob/master/FinProject.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [73]:
# Numerical constants
SR = 16000  # sample rate handed to the mel-spectrogram transforms
FILE_2CHECK = 13
HOP_LEN = 160
N_FFT = 400
N_MELS = 128  # 128 for Mel_Spec, 23 for MFCC
N_MFCC = 13
N_EPOCHS = 200
BATCH_SIZE = 32
LEARNING_RATE = 0.0004
WEIGHT_DECAY = 0.001
NUM_LAYERS = 12
HIDDEN_DIM = 64
EMBED_DIM = 300
NUM_CLASSES = 29  # 28 text symbols (labels 0-27) plus the CTC blank (label 28)
TIME = 513  # 513 for Mel_Spec, 641 for MFCC
PAD_TOKEN = 0
SEQ_LEN = 3
DROPOUT = 0.15

# Strings constants
CTC_MODEL_PATH = 'models/ctc_model.pth'
LANG_MODEL_PATH = 'models/lang_model.pth'
DATA_PATH = '/content/drive/MyDrive/an4'

# Hyper-parameters for SpeechRecognitionModel. Values that already exist as
# module constants are referenced rather than repeated so they cannot drift.
hparams = {
    "n_cnn_layers": 2,
    "n_rnn_layers": 2,
    "rnn_dim": 512,
    "n_class": NUM_CLASSES,
    "n_feats": N_MELS,
    "stride": 2,
    "dropout": DROPOUT,
    "learning_rate": LEARNING_RATE,
    "batch_size": BATCH_SIZE,
    "epochs": N_EPOCHS,
}

In [3]:
import torch


class TextTransform:
    """Map transcript characters to integer labels and back.

    The vocabulary is the apostrophe (label 0), the space (label 1) and the
    lowercase letters ``a``-``z`` (labels 2-27).

    Attributes:
        char_map: symbol -> label; the space is stored under ``'<SPACE>'``.
        index_map: label -> symbol; label 1 decodes directly to ``' '``.
    """

    # Ordered vocabulary: a symbol's position is its integer label.
    _SYMBOLS = ("'", "<SPACE>", *"abcdefghijklmnopqrstuvwxyz")

    def __init__(self):
        self.char_map = {symbol: index
                         for index, symbol in enumerate(self._SYMBOLS)}
        self.index_map = {index: symbol
                          for symbol, index in self.char_map.items()}
        # Decode the space label straight to a blank so that int_to_text
        # needs no post-processing of a '<SPACE>' placeholder.
        self.index_map[1] = ' '

    def text_to_int(self, text):
        """Convert ``text`` to a list of integer labels.

        Raises:
            KeyError: if ``text`` contains a character outside the vocabulary
                (for example an upper-case letter or a digit).
        """
        space_label = self.char_map['<SPACE>']
        return [space_label if c == ' ' else self.char_map[c] for c in text]

    def int_to_text(self, labels):
        """Convert an iterable of integer labels back to a string.

        Raises:
            KeyError: if a label is not in ``index_map`` (e.g. the CTC blank).
        """
        return ''.join(self.index_map[i] for i in labels)


def gd(output, labels, label_lengths, blank_label=28, collapse_repeated=True):
    """Greedy (best-path) CTC decoding of a batch.

    Takes the arg-max class along dim 2 of ``output`` for every step, drops
    ``blank_label`` and, when ``collapse_repeated`` is set, drops a label that
    equals the label of the previous step.

    Args:
        output: scores indexed as (batch, time, class) — the caller visible in
            this file passes log-probabilities in that layout.
        labels: padded reference labels, one row per batch item.
        label_lengths: number of valid entries in each row of ``labels``.
        blank_label: index of the CTC blank.
        collapse_repeated: whether to merge consecutive identical labels.

    Returns:
        ``(decodes, targets)``: decoded strings and reference strings.
    """
    codec = TextTransform()
    best_paths = torch.argmax(output, dim=2)
    decodes, targets = [], []
    for row, path in enumerate(best_paths):
        # Cut the padding off the reference before turning it into text.
        reference = labels[row][:label_lengths[row]].tolist()
        targets.append(codec.int_to_text(reference))

        kept = []
        for step, index in enumerate(path):
            if index == blank_label:
                continue
            if collapse_repeated and step != 0 and index == path[step - 1]:
                continue
            kept.append(index.item())
        decodes.append(codec.int_to_text(kept))
    return decodes, targets


In [58]:
import os
import torch
import torchaudio
import matplotlib.pyplot as plt
from torch import nn
from torch.utils.data import Dataset




class Data:
    """Eagerly load the train / val / test splits found under ``data_dir``.

    For every split the file paths, the loaded waveforms and the transcripts
    are kept as parallel lists. A few static helpers for plotting and feature
    extraction are provided for inspecting individual samples.
    """

    def __init__(self):
        print('Loading data...')
        self.data_dir = 'an4'
        self.x_train_paths, self.x_train, self.y_train = self.load_data('train')
        self.x_val_paths, self.x_val, self.y_val = self.load_data('val')
        self.x_test_paths, self.x_test, self.y_test = self.load_data('test')

        # Debug helpers: inspect one example of the training set
        # print(f"First train file path: {self.x_train_paths[FILE_2CHECK]}")
        # print(f"Transcription: {self.y_train[FILE_2CHECK]}")
        # print(f"Preview first train file: {self.x_train[FILE_2CHECK]}")
        # self.plot_waveform(self.x_train[0], sample_rate=SR)
        # self.plot_mfcc(extract_features(self.x_train[FILE_2CHECK]))
        # self.plot_mel_spec(self.extract_mel_spec(self.x_train[FILE_2CHECK]))
        print('Data loaded successfully')

    def load_data(self, split):
        """Load every waveform and transcript of one split.

        Reads ``<data_dir>/<split>/an4/wav`` and ``<data_dir>/<split>/an4/txt``.
        Audio and transcript files are paired by their position in the sorted
        directory listings; only ``*wav`` / ``*txt`` entries are considered.

        Args:
            split: name of the split sub-directory (e.g. ``'train'``).

        Returns:
            ``(audio_paths, loaded_audios, transcripts)`` as parallel lists.

        Raises:
            ValueError: if the number of audio files and transcripts differ,
                because position-based pairing would silently misalign them.
        """
        audio_dir = os.path.join(self.data_dir, split, 'an4', 'wav')
        transcript_dir = os.path.join(self.data_dir, split, 'an4', 'txt')

        audio_paths = [os.path.join(audio_dir, name)
                       for name in sorted(os.listdir(audio_dir))
                       if name.endswith('wav')]
        transcript_paths = [os.path.join(transcript_dir, name)
                            for name in sorted(os.listdir(transcript_dir))
                            if name.endswith('txt')]

        if len(audio_paths) != len(transcript_paths):
            raise ValueError(
                f"Split '{split}' has {len(audio_paths)} audio files but "
                f"{len(transcript_paths)} transcripts")

        transcripts = []
        for transcript_path in transcript_paths:
            with open(transcript_path, 'r') as f:
                transcripts.append(f.read().strip())

        loaded_audios = [torchaudio.load(audio)[0] for audio in audio_paths]
        return audio_paths, loaded_audios, transcripts

    def get_data(self, split):
        """Return ``(waveforms, transcripts)`` of ``'train'``, ``'val'`` or ``'test'``.

        Raises:
            ValueError: for any other split name.
        """
        if split == 'train':
            return self.x_train, self.y_train
        elif split == 'val':
            return self.x_val, self.y_val
        elif split == 'test':
            return self.x_test, self.y_test
        else:
            raise ValueError(f"Invalid data split '{split}'")

    @staticmethod
    def plot_waveform(waveform, sample_rate, title="Waveform", xlim=None, ylim=None):
        """Plot each channel of a ``(channels, frames)`` waveform tensor over time."""
        waveform = waveform.numpy()

        num_channels, num_frames = waveform.shape
        time_axis = torch.arange(0, num_frames) / sample_rate

        figure, axes = plt.subplots(num_channels, 1)
        if num_channels == 1:
            # plt.subplots returns a bare Axes for a single row; normalise to a list.
            axes = [axes]
        for c in range(num_channels):
            axes[c].plot(time_axis, waveform[c], linewidth=1)
            axes[c].grid(True)
            if num_channels > 1:
                axes[c].set_ylabel(f'Channel {c + 1}')
            if xlim:
                axes[c].set_xlim(xlim)
            if ylim:
                axes[c].set_ylim(ylim)
        figure.suptitle(title)
        plt.show(block=False)

    @staticmethod
    def extract_mel_spec(waveform):
        """Return the mel spectrogram of ``waveform`` (``SR`` Hz, ``N_MELS`` bins)."""
        mel_specgram = torchaudio.transforms.MelSpectrogram(SR, n_mels=N_MELS)(waveform)
        return mel_specgram

    @staticmethod
    def _show_image(image, title, ylabel, xlim, ylim):
        """Draw a 2-D array as an image with a colour bar (shared by the plot_* helpers)."""
        fig, ax = plt.subplots()
        im = ax.imshow(image, origin='lower', aspect='auto')
        fig.colorbar(im, ax=ax)
        ax.set(title=title, xlabel='Time', ylabel=ylabel)
        if xlim:
            ax.set_xlim(xlim)
        if ylim:
            ax.set_ylim(ylim)
        plt.show(block=False)

    @staticmethod
    def plot_mfcc(mfcc, title="MFCC", xlim=None, ylim=None):
        """Show a 2-D MFCC array as an image."""
        Data._show_image(mfcc, title, 'MFCC', xlim, ylim)

    @staticmethod
    def plot_mel_spec(mel_spec, title="Mel Spectrogram", xlim=None, ylim=None):
        """Show a mel-spectrogram tensor as an image (a leading size-1 dim is dropped)."""
        Data._show_image(mel_spec.squeeze(0).numpy(), title,
                         'Frequency (Hz)', xlim, ylim)


class AN4Dataset(Dataset):
    """One split of the AN4 corpus under ``DATA_PATH``, fully loaded in memory.

    Items are ``(waveform, transcript)`` pairs. Audio and transcript files are
    paired by their position in the sorted directory listings.
    """

    def __init__(self, split, transform=None, target_transform=None):
        """Load all ``*wav`` / ``*txt`` files of ``split``.

        Args:
            split: name of the split sub-directory (e.g. ``'train'``).
            transform: stored on the instance; not applied by ``__getitem__``.
            target_transform: stored on the instance; not applied by
                ``__getitem__``.

        Raises:
            ValueError: if the number of audio files and transcripts differ.
        """
        self.transform = transform
        self.target_transform = target_transform

        audio_dir = os.path.join(DATA_PATH, split, 'an4', 'wav')
        transcript_dir = os.path.join(DATA_PATH, split, 'an4', 'txt')

        audio_files = sorted(os.listdir(audio_dir))
        transcript_files = sorted(os.listdir(transcript_dir))

        audio_paths = [os.path.join(audio_dir, file) for file in audio_files if file.endswith('wav')]
        transcript_paths = [os.path.join(transcript_dir, file) for file in
                            transcript_files if file.endswith('txt')]

        # Compare the number of files. The previous check compared the lengths
        # of the two directory *path strings*, which can never differ.
        if len(audio_paths) != len(transcript_paths):
            raise ValueError(
                f"Split '{split}' has {len(audio_paths)} audio files but "
                f"{len(transcript_paths)} transcripts")

        # `audios` is never populated; the waveforms live in `loaded_audios`.
        # The attribute is kept so existing attribute access keeps working.
        self.audios, self.transcripts = [], []

        for transcript_path in transcript_paths:
            with open(transcript_path, 'r') as f:
                self.transcripts.append(f.read().strip())

        self.loaded_audios = [torchaudio.load(audio)[0] for audio in audio_paths]

    def __len__(self):
        """Number of loaded utterances."""
        return len(self.loaded_audios)

    def __getitem__(self, idx):
        """Return the ``(waveform, transcript)`` pair at ``idx``."""
        return self.loaded_audios[idx], self.transcripts[idx]


def data_processing(data, data_type="train"):
    """Collate ``(waveform, transcript)`` pairs into padded batch tensors.

    Each waveform becomes a mel spectrogram; for ``data_type == "train"``
    frequency and time masking are applied on top. Transcripts are lower-cased
    and encoded with ``TextTransform``.

    Args:
        data: iterable of ``(waveform, transcript)`` pairs. Assumes each
            waveform is a single-channel ``(1, samples)`` tensor — TODO confirm.
        data_type: ``"train"`` enables the masking augmentations; any other
            value yields plain mel spectrograms.

    Returns:
        ``(spectrograms, labels, inputs_lengths, labels_length)`` where
        ``spectrograms`` is zero-padded to ``(batch, 1, N_MELS, max_length)``,
        ``labels`` is the zero-padded label matrix, and the two lists hold the
        per-item input length (frames // 2) and label length.
    """
    inputs, inputs_lengths, labels, labels_length = [], [], [], []

    # Both branches use the same explicit mel configuration: the padded buffer
    # below is sized with N_MELS, so relying on the transform's defaults in the
    # evaluation branch would only work while they happen to match.
    mel = torchaudio.transforms.MelSpectrogram(sample_rate=SR, n_mels=N_MELS)
    if data_type == "train":
        transform = nn.Sequential(
            mel,
            torchaudio.transforms.FrequencyMasking(freq_mask_param=30),
            torchaudio.transforms.TimeMasking(time_mask_param=100))
    else:
        transform = mel

    text_transform = TextTransform()
    for (wav, transcript) in data:
        spec = transform(wav).squeeze(0).transpose(0, 1)  # (time, n_mels)
        inputs.append(spec)
        # Half the frame count — presumably the time resolution left after a
        # stride-2 front end; verify against the model in use.
        inputs_lengths.append(spec.shape[0] // 2)
        label = torch.Tensor(text_transform.text_to_int(str(transcript).lower()))
        labels.append(label)
        labels_length.append(len(label))

    # (n // 2) * 2 + 1 is always >= n, so every spectrogram fits in the buffer.
    max_length = (max(inputs_lengths) * 2) + 1
    spectrograms = torch.zeros((len(inputs), max_length, N_MELS))
    for i, tensor in enumerate(inputs):
        spectrograms[i, :tensor.shape[0], :] = tensor

    # (batch, time, n_mels) -> (batch, 1, n_mels, time)
    spectrograms = spectrograms.unsqueeze(1).transpose(2, 3)
    labels = nn.utils.rnn.pad_sequence(labels, batch_first=True)
    return spectrograms, labels, inputs_lengths, labels_length


In [71]:
import torch
import torchaudio

import torch.nn as nn
from torch import optim
import torch.nn.functional as F
from torchaudio.models.decoder import ctc_decoder
from torchaudio.models.decoder._ctc_decoder import download_pretrained_files

from tqdm import tqdm

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Runs at import time; presumably downloads the pretrained "librispeech-4-gram"
# decoder files.
# NOTE(review): `files` is only referenced by commented-out beam-decoder code in
# the visible part of this file — confirm this call is still needed.
files = download_pretrained_files("librispeech-4-gram")


class CNNLayerNorm(nn.Module):
    """LayerNorm over the feature axis of a (batch, channel, feature, time) tensor."""

    def __init__(self, n_feats):
        super().__init__()
        self.layer_norm = nn.LayerNorm(n_feats)

    def forward(self, x):
        # nn.LayerNorm normalises the trailing axis, so swap features to the
        # end, normalise, then restore the (batch, channel, feature, time) layout.
        time_major = x.transpose(2, 3).contiguous()
        normalised = self.layer_norm(time_major)
        return normalised.transpose(2, 3).contiguous()


class ResidualCNN(nn.Module):
    """Pre-activation residual block (cf. https://arxiv.org/pdf/1603.05027.pdf)
        that uses layer norm in place of batch norm.
    """

    def __init__(self, in_channels, out_channels, kernel, stride, dropout,
                 n_feats):
        super().__init__()

        half_kernel = kernel // 2
        self.cnn1 = nn.Conv2d(in_channels, out_channels, kernel_size=kernel,
                              stride=stride, padding=half_kernel)
        self.cnn2 = nn.Conv2d(out_channels, out_channels, kernel_size=kernel,
                              stride=stride, padding=half_kernel)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.layer_norm1 = CNNLayerNorm(n_feats)
        self.layer_norm2 = CNNLayerNorm(n_feats)

    def forward(self, x):
        # Two rounds of norm -> GELU -> dropout -> conv, then the skip connection.
        # x and the result are (batch, channel, feature, time).
        out = self.cnn1(self.dropout1(F.gelu(self.layer_norm1(x))))
        out = self.cnn2(self.dropout2(F.gelu(self.layer_norm2(out))))
        out += x
        return out


class BidirectionalGRU(nn.Module):
    """Layer norm and GELU followed by a one-layer bidirectional GRU and dropout."""

    def __init__(self, rnn_dim, hidden_size, dropout, batch_first):
        super().__init__()

        self.BiGRU = nn.GRU(rnn_dim, hidden_size, num_layers=1,
                            batch_first=batch_first, bidirectional=True)
        self.layer_norm = nn.LayerNorm(rnn_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        activated = F.gelu(self.layer_norm(x))
        # Only the per-step outputs are used; the final hidden state is dropped.
        outputs, _hidden = self.BiGRU(activated)
        return self.dropout(outputs)


class SpeechRecognitionModel(nn.Module):
    """Speech Recognition Model Inspired by DeepSpeech 2.

    Pipeline: strided convolution -> residual CNN stack -> linear projection ->
    stacked bidirectional GRUs -> per-frame classifier.

    Args:
        n_cnn_layers: number of ``ResidualCNN`` blocks.
        n_rnn_layers: number of ``BidirectionalGRU`` blocks.
        rnn_dim: GRU hidden size (each GRU outputs ``2 * rnn_dim`` features).
        n_class: number of output classes per frame.
        n_feats: number of feature bins of the input spectrogram.
        stride: stride of the first convolution.
        dropout: dropout probability used throughout.
    """

    def __init__(self, n_cnn_layers, n_rnn_layers, rnn_dim,
                 n_class, n_feats,
                 stride=2, dropout=0.1):
        super(SpeechRecognitionModel, self).__init__()
        # Feature bins left after the first convolution (kernel 3, padding 1):
        # floor((n_feats - 1) / stride) + 1. This equals n_feats // 2 for even
        # n_feats with stride 2 and stays correct for odd sizes / other strides.
        n_feats = (n_feats - 1) // stride + 1
        self.cnn = nn.Conv2d(1, 32, 3, stride=stride,
                             padding=3 // 2)  # cnn for extracting hierarchical features

        # n residual cnn layers with filter size of 32
        self.rescnn_layers = nn.Sequential(*[
            ResidualCNN(32, 32, kernel=3, stride=1, dropout=dropout,
                        n_feats=n_feats)
            for _ in range(n_cnn_layers)
        ])
        self.fully_connected = nn.Linear(n_feats * 32, rnn_dim)
        # Every GRU receives (batch, time, feature) data, so all of them must be
        # batch-first; with batch_first only on the first layer the later layers
        # would treat the batch axis as time. Parameter shapes are unaffected.
        self.birnn_layers = nn.Sequential(*[
            BidirectionalGRU(rnn_dim=rnn_dim if i == 0 else rnn_dim * 2,
                             hidden_size=rnn_dim, dropout=dropout, batch_first=True)
            for i in range(n_rnn_layers)
        ])
        self.classifier = nn.Sequential(
            nn.Linear(rnn_dim * 2, rnn_dim),  # birnn returns rnn_dim*2
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(rnn_dim, n_class)
        )

    def forward(self, x):
        """Map a 4-D spectrogram batch to per-frame class scores (batch, time, n_class).

        Assumes ``x`` is (batch, 1, feature, time) — TODO confirm against the
        collate function in use.
        """
        x = self.cnn(x)
        x = self.rescnn_layers(x)
        sizes = x.size()
        # Fold channels and feature bins into a single feature axis.
        x = x.view(sizes[0], sizes[1] * sizes[2], sizes[3])  # (batch, feature, time)
        x = x.transpose(1, 2)  # (batch, time, feature)
        x = self.fully_connected(x)
        x = self.birnn_layers(x)
        x = self.classifier(x)
        return x


def save_model(model, path):
    """
    Saves the state dict of a pytorch model to the given path.

    The parent directory of ``path`` is created when it does not exist yet, so
    that a first run does not fail on a missing checkpoint folder.
    """
    import os

    path = '{}'.format(path)
    parent_dir = os.path.dirname(path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    torch.save(model.state_dict(), path)


def load_model(model, path):
    """
    Loads a state dict from the given path into an existing pytorch model and
    switches the model to evaluation mode. The model should already be created
    (e.g. by calling the constructor) and is passed as an argument.

    The checkpoint is first mapped to the CPU so that a file saved on a GPU
    machine can also be restored on a CPU-only host; ``load_state_dict`` then
    copies the values into the model's own parameters.
    """
    state_dict = torch.load('{}'.format(path), map_location='cpu')
    model.load_state_dict(state_dict)
    model.eval()


def extract_features(wavs, is_train=False):
    """
    Extract mel-spectrogram features from the given batch of waveforms.

    With ``is_train`` set, frequency and time masking are applied on top of the
    spectrogram.

    Args:
        wavs: iterable of waveform tensors. Assumes each is a single-channel
            ``(1, samples)`` tensor — TODO confirm.
        is_train: whether to apply the masking augmentations.

    Returns:
        ``(spectrograms, input_lengths)``: the zero-padded batch shaped
        ``(batch, 1, n_mels, time)`` and a long tensor holding, per item, half
        its number of frames.
    """
    spectrograms, input_lengths = [], []

    # Same explicit mel configuration for both branches instead of relying on
    # the transform's defaults in the evaluation branch.
    mel = torchaudio.transforms.MelSpectrogram(sample_rate=SR, n_mels=N_MELS)
    if is_train:
        transform = nn.Sequential(
            mel,
            torchaudio.transforms.FrequencyMasking(freq_mask_param=15),
            torchaudio.transforms.TimeMasking(time_mask_param=35))
    else:
        transform = mel

    for wav in wavs:
        spec = transform(wav).squeeze(0).transpose(0, 1)  # (time, n_mels)
        spectrograms.append(spec)
        input_lengths.append(spec.shape[0] // 2)

    # (batch, time, n_mels) -> (batch, 1, n_mels, time)
    spectrograms = nn.utils.rnn.pad_sequence(spectrograms, batch_first=True).unsqueeze(1).transpose(2, 3)

    return spectrograms, torch.tensor(input_lengths, dtype=torch.long)


class LSTMModel(nn.Module):
    """
    A basic LSTM model for speech recognition: one unpadded 3x3 convolution
    followed by a stacked LSTM and a per-frame linear classifier.

    Args:
        vocabulary: stored on the instance; not used by ``forward``.
        lang_model: optional language model, stored on the instance; not used
            by ``forward``.
    """

    def __init__(self, vocabulary, lang_model=None):
        super(LSTMModel, self).__init__()
        self.vocabulary = vocabulary
        self.lang_model = lang_model

        conv_channels, conv_kernel = 32, 3
        # The unpadded convolution trims (kernel - 1) feature bins and forward()
        # folds channels and bins into one axis, so the LSTM input width follows
        # from N_MELS (32 * 126 = 4032 for N_MELS = 128).
        rnn_input_size = conv_channels * (N_MELS - conv_kernel + 1)

        # RNN layers
        self.rnn = nn.LSTM(input_size=rnn_input_size, hidden_size=HIDDEN_DIM,
                           num_layers=NUM_LAYERS, batch_first=True)

        self.conv = nn.Conv2d(in_channels=1, out_channels=conv_channels,
                              kernel_size=conv_kernel)

        # Fully connected layer
        self.fc = nn.Linear(HIDDEN_DIM, NUM_CLASSES)

    def forward(self, x):
        """Map a 4-D spectrogram batch to per-frame class scores (batch, time, NUM_CLASSES).

        Assumes ``x`` is (batch, 1, N_MELS, time) — TODO confirm against the
        collate function in use.
        """
        x = self.conv(x)
        sizes = x.size()
        x = x.view(sizes[0], sizes[1] * sizes[2], sizes[3])  # (batch, feature, time)
        x = x.transpose(1, 2)  # (batch, time, feature)
        rnn_output, _ = self.rnn(x)
        output = self.fc(rnn_output)

        return output


def predict(model, feats):
    """
    Run the model on a batch of features and return its greedy decoding.

    NOTE(review): the model classes visible in this file only mention
    `greedy_decoder` in commented-out code — confirm the model passed here
    actually provides it.
    """
    logits = model(feats)
    # beam_search_result = model.beam_decoder(logits)
    decoded = model.greedy_decoder(logits)
    return decoded


def test(model, test_loader, criterion):
    """Evaluate ``model`` on ``test_loader`` and print greedy decodings per batch.

    Args:
        model: network returning (batch, time, n_class) scores.
        test_loader: yields ``(spectrograms, labels, input_lengths,
            label_lengths)`` batches.
        criterion: loss called as ``criterion(log_probs, labels, input_lengths,
            label_lengths)`` with (time, batch, n_class) log-probabilities.

    Returns:
        The loss averaged over the batches of ``test_loader``. Previously the
        value was accumulated but neither reported nor returned.
    """
    print('\nevaluating...')
    model.eval()
    model = model.to(device)
    test_loss = 0
    with torch.no_grad():
        for spectrograms, labels, input_lengths, label_lengths in test_loader:
            spectrograms, labels = spectrograms.to(device), labels.to(device)

            output = model(spectrograms)  # (batch, time, n_class)
            log_probs = F.log_softmax(output, dim=2)

            # The loss wants time-major input; the decoder wants batch-major.
            loss = criterion(log_probs.transpose(0, 1), labels,
                             input_lengths, label_lengths)
            test_loss += loss.item() / len(test_loader)
            decoded_preds, decoded_targets = gd(log_probs, labels, label_lengths)
            print(decoded_preds, decoded_targets)

    print(f'Test loss: {test_loss:.4f}')
    return test_loss


def train_all_data(model, train_loader, criterion):
    """Train ``model`` for ``N_EPOCHS`` epochs with RMSprop and save checkpoints.

    The model is saved to ``CTC_MODEL_PATH`` after every fifth completed epoch
    and once more when training ends.

    Args:
        model: network returning (batch, time, n_class) scores.
        train_loader: yields ``(spectrograms, labels, input_lengths,
            label_lengths)`` batches.
        criterion: loss called as ``criterion(log_probs, labels, input_lengths,
            label_lengths)`` with (time, batch, n_class) log-probabilities.
    """
    # Move the model first so the optimizer is built on the final parameters.
    model = model.to(device)
    model.train()
    optimizer = torch.optim.RMSprop(model.parameters(), LEARNING_RATE)

    for epoch in range(N_EPOCHS):
        e_loss = 0
        for spectrograms, labels, input_lengths, label_lengths in train_loader:
            spectrograms, labels = spectrograms.to(device), labels.to(device)

            optimizer.zero_grad()
            output = model(spectrograms)  # (batch, time, n_class)
            output = F.log_softmax(output, dim=2)
            output = output.transpose(0, 1)  # (time, batch, n_class)

            loss = criterion(output, labels, input_lengths, label_lengths)
            loss.backward()
            optimizer.step()
            e_loss += loss.item()
        # e_loss is the sum of the batch losses of this epoch.
        print(f"Train Epoch: {epoch}, loss = {e_loss}")

        # Checkpoint after the epoch has been trained, so the file written at
        # "epoch 5" really contains five epochs of training.
        if (epoch + 1) % 5 == 0:
            save_model(model, CTC_MODEL_PATH)

    save_model(model, CTC_MODEL_PATH)


In [72]:
import os
import torch
import random
import matplotlib.pyplot as plt
# from jiwer import wer, cer
from torch import nn
from torch.utils.data import DataLoader
from tqdm import tqdm


# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def evaluate(model, x_test, y_test):
    """
    Decode the test set in batches of BATCH_SIZE and score it with WER and CER.

    NOTE(review): `ctc_model`, `wer` and `cer` are not defined or imported in
    the visible part of this file (the jiwer import is commented out) —
    confirm they are in scope before calling this.

    Returns:
        ``(wer_error, cer_error)``.
    """
    predictions, targets = [], []

    batch_offsets = range(0, len(x_test), BATCH_SIZE)
    for _, offset in tqdm(enumerate(batch_offsets)):
        chunk = x_test[offset:offset + BATCH_SIZE]
        feats, _ = ctc_model.extract_features(chunk)
        hypotheses = ctc_model.predict(model, feats)
        for j, hypothesis in enumerate(hypotheses):
            pred_tokens = model.beam_decoder.idxs_to_tokens(hypothesis[0].tokens)
            reference = y_test[offset + j]
            # Show a sample only for every 50th item of a batch.
            if j % 50 == 0:
                print(f'True transcription: {reference}')
                print(f'Predicted transcription: {pred_tokens}')
            predictions.append(pred_tokens)
            targets.append(reference)

    wer_error = wer(targets, predictions)
    cer_error = cer(targets, predictions)
    return wer_error, cer_error


def plot_alignments(waveform, emission, tokens, timesteps):
    """
    Draw the waveform with the predicted tokens annotated at their time steps
    and every word span (delimited by "|" tokens) shaded.
    """
    fig, ax = plt.subplots(figsize=(32, 10))
    ax.plot(waveform)

    # Number of waveform samples covered by one emission step.
    ratio = waveform.shape[0] / emission.shape[1]
    word_start = 0
    for idx, token in enumerate(tokens):
        # A word starts at the step right after a "|" separator.
        if idx > 0 and tokens[idx - 1] == "|":
            word_start = timesteps[idx]
        if token == "|":
            if idx > 0:
                word_end = timesteps[idx]
                ax.axvspan(word_start * ratio, word_end * ratio, alpha=0.1, color="red")
        else:
            plt.annotate(token.upper(), (timesteps[idx] * ratio, waveform.max() * 1.02), size=14)

    # Relabel the x axis from sample indices to seconds.
    xticks = ax.get_xticks()
    plt.xticks(xticks, xticks / SR)
    ax.set_xlabel("Time")
    ax.set_xlim(0, waveform.shape[0])
    plt.show()


def test_distance_algorithms(data):
    """
    Run the naive DTW distance baseline: fit on train + val, classify the test
    split and print its WER / CER. The Euclidean baseline is currently disabled.

    NOTE(review): `DTWModel`, `wer` and `cer` are not defined or imported in
    the visible part of this file — confirm they are in scope.
    """
    train_x, train_y = data.get_data('train')
    val_x, val_y = data.get_data('val')
    test_x, test_y = data.get_data('test')

    dtw_model = DTWModel(train_x, train_y)
    dtw_model.add_data(val_x, val_y)
    dtw_predictions = dtw_model.classify_using_DTW_distance(test_x)

    print('Predictions:')
    print(dtw_predictions[:5])
    print('True labels:')
    print(test_y[:5])
    print('Testing DTW algorithm...')
    scores = (
        ('WER', wer(test_y, dtw_predictions)),
        ('CER', cer(test_y, dtw_predictions)),
    )
    for metric, value in scores:
        print(f'DTW Test {metric}: {value:.4f}')
    print('DTW tested successfully')


def main():
    """Build the AN4 loaders, load or train the CTC model, then evaluate it.

    When a checkpoint exists at ``CTC_MODEL_PATH`` it is loaded; otherwise the
    model is trained from scratch (which also writes the checkpoint).
    """
    print('--- Start running ---')
    # test_distance_algorithms(data)

    train_data_set = AN4Dataset('train')
    test_data_set = AN4Dataset('test')
    train_loader = DataLoader(dataset=train_data_set,
                              batch_size=BATCH_SIZE,
                              shuffle=True,
                              collate_fn=lambda x: data_processing(x, 'train'))
    test_loader = DataLoader(dataset=test_data_set,
                             batch_size=BATCH_SIZE,
                             shuffle=False,
                             collate_fn=lambda x: data_processing(x, 'val'))

    # NOTE: language-model training is currently disabled (the calls are
    # commented out); the two messages below are placeholders.
    # lang_model = language_model.LanguageModel(vocabulary)
    print('Training the language models...')
    # language_model.train_all_data(lang_model, y_train+y_val)
    print('Language models trained successfully')

    # Alternative acoustic model:
    # ctc_lstm = SpeechRecognitionModel(hparams['n_cnn_layers'], hparams['n_rnn_layers'], hparams['rnn_dim'],
    #                                   hparams['n_class'], hparams['n_feats'], hparams['stride'],
    #                                   hparams['dropout']).to(device)
    ctc_lstm = LSTMModel(0)
    lossFunc = nn.CTCLoss(blank=28, zero_infinity=True).to(device)
    if os.path.exists(CTC_MODEL_PATH):
        print('Loading the models...')
        load_model(ctc_lstm, CTC_MODEL_PATH)
        print('Model loaded successfully')
    else:  # Train the models
        print('Training the models...')
        train_all_data(ctc_lstm, train_loader, lossFunc)
        # Only report a successful training when training actually ran.
        print('Model trained successfully')

    test(ctc_lstm, test_loader, lossFunc)

    # Evaluate the models on the test set
    # print('Evaluating the models...')
    # x_test, y_test = data.get_data('train')
    # test_wer, test_cer = evaluate(ctc_lstm, x_test, y_test)
    # print(f'Test WER: {test_wer:.4f}')
    # print(f'Test CER: {test_cer:.4f}')
    # print('Model evaluated successfully')
    print('-- Finished running ---')


if __name__ == '__main__':
    main()


--- Start running ---
Training the language models...
Language models trained successfully
Loading the models...
Model loaded successfully
Model trained successfully

evaluating...
['', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', ''] ['enter two nine eight one', 'repeat', 'erase u d b e five', 'rubout u b u t r six', 'enter one oh four', 'erase a b f n q fifty seven', 't r t f i seven', 'rubout c b w x v four', 'w y a t u seventy seven seventy seven', 'enter eight', 'enter seven one five four', 'no', 'g f t u one three three two', 'l 