<a href="https://colab.research.google.com/github/GreihMurray/KriolTranscriber/blob/master/lstm_chars.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Install the Hugging Face `evaluate` package; its `load` function is used later
# to fetch the 'cer' and 'wer' metrics.
!pip install evaluate
# jiwer is presumably the backend those metrics rely on - TODO confirm.
!pip install jiwer

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting jiwer
  Using cached jiwer-2.5.1-py3-none-any.whl (15 kB)
Collecting levenshtein==0.20.2
  Downloading Levenshtein-0.20.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.4 MB)
[K     |████████████████████████████████| 1.4 MB 13.8 MB/s 
[?25hCollecting rapidfuzz<3.0.0,>=2.3.0
  Downloading rapidfuzz-2.13.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.2 MB)
[K     |████████████████████████████████| 2.2 MB 38.4 MB/s 
[?25hInstalling collected packages: rapidfuzz, levenshtein, jiwer
Successfully installed jiwer-2.5.1 levenshtein-0.20.2 rapidfuzz-2.13.3


In [2]:
from keras.utils import to_categorical
from tensorflow.keras.models import Sequential
from keras.utils.vis_utils import plot_model
import tensorflow as tf
from tensorflow.keras import activations
import numpy as np
from evaluate import load
import itertools
from scipy.io import wavfile
from keras import backend as K
from tensorflow.keras.layers import LSTM, Embedding, Dense, TimeDistributed, Dropout, Bidirectional, Reshape, GRU, Flatten, Activation

In [3]:
import math, random
import torch
import torchaudio
from torchaudio import transforms
import pandas as pd
from tqdm import tqdm
import os
import unicodedata
import re
from keras.models import load_model

In [4]:
from google.colab import drive

# Mount Google Drive at /content/gdrive; DEFAULT_DIR and the saved-model path
# used further down both live under this mount point.
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [5]:
# Project root on the mounted Drive. The loaders append 'html/<split>/' and
# 'audio/<split>/' to this prefix, so the trailing slash is required.
DEFAULT_DIR = '/content/gdrive/MyDrive/Colab_Notebooks/NLP/project/'
# Presumably the audio sample rate in Hz - TODO confirm against the wav files.
# It is only used as `SR//1000 * 15000` to cap and pad clip length.
# NOTE(review): SR//1000 truncates 44.1 to 44, so that expression evaluates to
# 660000 samples rather than 44100 * 15 = 661500.
SR = 44100

In [6]:
def load_html_data(dir_ext):
    """Extract lower-cased text fragments from the HTML files of one split.

    Every file in ``DEFAULT_DIR + 'html/' + dir_ext + '/'`` is read in sorted
    filename order. Span and anchor elements are removed, the text of the
    'mt', 'mt2' and 'ip' divs is collected first, followed by the 'p' / 's'
    divs. Punctuation is stripped and the text is split into fragments on
    sentence delimiters.

    Args:
        dir_ext: Name of the sub-folder under ``html/`` (e.g. ``'train'``).

    Returns:
        list[str]: Non-empty, stripped, lower-cased fragments in the order
        they were encountered. An empty directory yields an empty list.
    """
    # Characters removed from the text before it is split into fragments.
    strip_pattern = r'[\,,@,#,$,%,^,&,*,(,),\[,\],\',\",;,:,“,”,‘,’]'
    # Fragment delimiters: '.', '?', '!' and newline (the class also lists ',',
    # which strip_pattern has already removed by the time this is applied).
    split_pattern = r'[\.,\?,!,\n]'

    all_data = []
    directory = DEFAULT_DIR + 'html/' + dir_ext + '/'

    dir_files = sorted(os.listdir(directory))

    for file in tqdm(dir_files, desc='Loading HTML Data'):
        file = directory + file
        with open(file, 'r', encoding='UTF-8') as in_file:
            data = ' '.join(in_file.readlines())

        data = unicodedata.normalize('NFC', data)
        # Drop span and anchor elements entirely, including their text.
        data = re.sub(r'<span class=.*?</span>', '', data)
        data = re.sub(r'<a.*?</a>', '', data)
        data = data.replace(u'\xa0', u' ')

        # Heading-like divs are emitted before the body text of the same file.
        check_divs = re.findall(r'<div class=\'mt\'.*?>(.*?) </div>', data)
        check_divs.extend(re.findall(r'<div class=\'mt2\'.*?>(.*?) </div>', data))
        check_divs.extend(re.findall(r'<div class=\'ip\'>(.*)', data))

        if len(check_divs) > 0:
            full = '!'.join(check_divs)
            full = re.sub(strip_pattern, '', full)
            full = re.sub(r'^\s+', ' ', full).strip('\u00A0')
            full = re.split(split_pattern, str(full))
            all_data.extend([s.strip() for s in full])

        divs = re.findall(r'<div class=\'[p,s]\'.*?>(.*?) </div>', data)

        # '!' is one of the split delimiters, so joining with it keeps the
        # individual divs apart after the split below.
        full_data = '!'.join(divs)
        full_data = re.sub(strip_pattern, '', full_data)
        full_data = re.sub(' +', ' ', full_data)
        full_data = re.split(split_pattern, str(full_data))

        all_data.extend([s.strip() for s in full_data])

    # Filter once after all files are read. Doing this inside the loop rebuilt
    # the list for every file and left the result undefined for an empty
    # directory.
    return [row.lower() for row in all_data if len(row) >= 1]

In [20]:
def load_dataset(dir_ext):
    """Pair the wav files of one split with the transcripts from the HTML data.

    Audio files are read from ``DEFAULT_DIR + 'audio/' + dir_ext + '/'`` in
    sorted filename order. Clips longer than ``SR//1000 * 15000`` samples are
    skipped.

    Args:
        dir_ext: Name of the split sub-folder (e.g. ``'train'``).

    Returns:
        pandas.DataFrame: Columns ``'audio'`` (the sample array returned by
        ``wavfile.read``) and ``'transcription'`` (str), one row per kept clip.

    Raises:
        IndexError: If a kept audio file has a position beyond the number of
            transcript fragments.
    """
    transcripts = load_html_data(dir_ext)
    directory = DEFAULT_DIR + 'audio/' + dir_ext + '/'
    dir_files = sorted(os.listdir(directory))

    # Loop-invariant clip-length cap, computed once.
    max_len = SR//1000 * 15000

    all_x = []
    all_y = []

    # Passing `total` lets tqdm show a percentage even though it wraps enumerate.
    progress = tqdm(enumerate(dir_files), total=len(dir_files),
                    desc='Loading Audio Data & Creating Dataset')
    for i, file in progress:
        # The file's own sample rate is read but not used.
        _sr, data = wavfile.read(directory + file)

        if len(data) > max_len:
            continue

        all_x.append(data)
        # NOTE(review): this pairs the i-th sorted audio file with the i-th
        # transcript fragment; nothing here verifies that the two sequences
        # line up - confirm against the data layout.
        all_y.append(transcripts[i])

    return pd.DataFrame(list(zip(all_x, all_y)), columns=['audio', 'transcription'])

In [8]:
def one_hot(data, test_data, map_use):
    """Encode every symbol of every sequence as a positive integer index.

    Despite the name, this returns integer indices, not one-hot vectors.
    Indices are handed out in order of first appearance, starting at 1.

    Args:
        data: Iterable of sequences (strings or tuples of characters).
        test_data: When false, a fresh mapping is built from ``data``. When
            true, ``map_use`` is reused so indices stay consistent with it.
        map_use: Existing symbol -> index dict, used only when ``test_data``
            is true. It is extended in place with any symbol it does not
            contain yet. ``None`` is treated as an empty mapping.

    Returns:
        tuple[list[list[int]], dict]: The encoded sequences and the mapping
        that was used to encode them.
    """
    if test_data:
        mapping = map_use if map_use is not None else {}
    else:
        mapping = {}

    # Continue after the largest index already in use, so that symbols unseen
    # in the reused mapping never collide with existing entries.
    next_index = max(mapping.values(), default=0)

    mapped = []
    for sentence in data:
        encoded = []
        for symbol in sentence:
            if symbol not in mapping:
                next_index += 1
                mapping[symbol] = next_index
            encoded.append(mapping[symbol])
        mapped.append(encoded)

    return mapped, mapping

In [9]:
def vectorize(data):
    """Collapse each sequence of one-hot rows into a per-class count vector.

    For every sequence, the position of the first ``1`` in each row is counted.
    Position 2 of the resulting vector is then overwritten with ``1``.

    Args:
        data: Iterable of sequences, each a non-empty sequence of equally
            long rows that contain at least one ``1``.

    Returns:
        list[list[int]]: One count vector per input sequence.
    """
    result = []
    for sequence in data:
        counts = [0 for _ in range(len(sequence[0]))]
        for one_hot_row in sequence:
            hot = list(one_hot_row).index(1)
            counts[hot] = counts[hot] + 1
        # NOTE(review): slot 2 is forced to 1 regardless of its count; the
        # reason is not visible in this file.
        counts[2] = 1
        result.append(counts)

    return result

In [10]:
def pad_audio(data):
    """Zero-pad, in place, every clip shorter than ``SR//1000 * 15000`` samples.

    Args:
        data: Indexable collection of 1-D sample arrays; entries are replaced
            through ``data[i] = ...``.

    Returns:
        The same ``data`` object, with short clips replaced by padded copies.
    """
    target_len = SR//1000 * 15000

    for idx, clip in tqdm(enumerate(data), desc='Padding audio'):
        shortfall = target_len - len(clip)
        if shortfall > 0:
            # np.zeros is float64, so the padded copy is promoted accordingly.
            data[idx] = np.array(np.append(clip, np.zeros(shortfall)))

    return data

In [11]:
def get_min(audio):
    """Return the smallest sample over all clips, never greater than 0.

    Args:
        audio: Iterable of non-empty numeric arrays or sequences.

    Returns:
        The minimum sample value, or ``0`` when no sample is negative.
    """
    minim = 0

    for row in tqdm(audio, desc='Finding min'):
        # One vectorised reduction per clip, instead of scanning the clip
        # twice with the builtin min().
        row_min = np.min(row)
        if row_min < minim:
            minim = row_min

    return minim

In [12]:
def get_max(audio):
    """Return the largest sample over all clips, never less than 0.

    Args:
        audio: Iterable of non-empty numeric arrays or sequences.

    Returns:
        The maximum sample value, or ``0`` when no sample is positive.
    """
    maxim = 0

    for row in tqdm(audio, desc='Finding max'):
        # One vectorised reduction per clip, instead of scanning the clip
        # twice with the builtin max().
        row_max = np.max(row)
        if row_max > maxim:
            maxim = row_max

    return maxim

In [13]:
def adjust_audio(audio, minim):
    """Shift every clip up by ``abs(minim)`` and truncate to whole numbers.

    Each entry of ``audio`` is replaced by a float64 array holding
    ``trunc(sample + abs(minim))``.

    Args:
        audio: Indexable collection of numeric sample arrays; entries are
            replaced through ``audio[i] = ...``.
        minim: Offset source; its absolute value is added to every sample.

    Returns:
        The same ``audio`` object with shifted clips.
    """
    # Convert to a Python float first: abs() of a narrow integer scalar such
    # as int16(-32768) cannot be represented in that type.
    offset = abs(float(minim))

    for i, row in tqdm(enumerate(audio), desc='Adjusting Audio'):
        # Whole-array arithmetic replaces the per-sample Python loop. Working
        # in float64 also avoids overflow when a clip is still stored as a
        # narrow integer type.
        audio[i] = np.trunc(np.asarray(row, dtype=np.float64) + offset)

    return audio

In [14]:
def advanced_relu(x):
    """ReLU activation whose output is capped at 500000."""
    ceiling = 500000
    return K.relu(x, max_value=ceiling)

In [15]:
def build_model(input_len, output_len, maxim):
    """Build and compile the Dense -> Embedding -> BiLSTM -> softmax network.

    Args:
        input_len: Length of one input vector; sets the first layer's
            ``input_shape``.
        output_len: Accepted but not used by this function.
        maxim: Accepted but not used by this function.

    Returns:
        A compiled Keras ``Sequential`` model (adam optimizer, categorical
        cross-entropy loss, accuracy metric).
    """
    # Layer sizes are fixed literals here, not derived from the arguments.
    layers = [
        Dense(321, input_shape=(input_len,), activation=advanced_relu),
        Embedding(input_dim=499999, output_dim=34),
        Bidirectional(LSTM(units=50, return_sequences=True, recurrent_dropout=0.1)),
        Dense(34, activation='softmax'),
    ]
    network = Sequential(layers)

    network.compile(optimizer="adam", metrics=["accuracy"], loss='categorical_crossentropy')

    return network

In [16]:
def load_and_process_data(folder_path, test_data=False, map_use=None):
    """Load one split and turn it into model-ready inputs and targets.

    Args:
        folder_path: Split sub-folder name passed on to ``load_dataset``.
        test_data: Forwarded to ``one_hot``; selects reuse of ``map_use``.
        map_use: Existing character mapping, forwarded to ``one_hot``.

    Returns:
        tuple: ``(clean_y, cat_y, padded_x, minim, maxim, mapping)`` where
        ``clean_y`` is the list of raw transcripts, ``cat_y`` the categorical
        targets, ``padded_x`` the stacked and shifted audio, ``minim`` the
        minimum sample before shifting, ``maxim`` the maximum sample after
        shifting, and ``mapping`` the dict returned by ``one_hot``.
    """
    dataset = load_dataset(folder_path)
    transcripts = dataset['transcription']
    clean_y = list(transcripts)

    # Right-pad every transcript with '!' to the length of the longest one:
    # zip_longest transposes to per-position columns, zip transposes back.
    char_columns = itertools.zip_longest(*list(transcripts), fillvalue='!')
    padded_y = list(zip(*char_columns))
    enc_y, mapping = one_hot(padded_y, test_data, map_use)
    cat_y = np.array(to_categorical(enc_y))

    padded_x = pad_audio(dataset['audio'])

    # Order matters: the minimum is taken before the shift, the maximum after.
    minim = get_min(padded_x)
    padded_x = adjust_audio(padded_x, minim)
    maxim = get_max(padded_x)

    padded_x = np.stack(padded_x)

    # NOTE(review): column 7 is hard-coded; presumably it is the index that
    # the '!' pad character received for this dataset - confirm.
    for entry in cat_y:
        for row in entry:
            row[7] = 0

    return clean_y, cat_y, padded_x, minim, maxim, mapping

In [17]:
def evaluate(model, clean_y, padded_x, mapping):
    """Decode the model's predictions and print character and word error rates.

    Args:
        model: Object with a ``predict`` method that returns, per input, a
            sequence of per-class score rows.
        clean_y: Reference transcripts, one per row of ``padded_x``.
        padded_x: Model inputs.
        mapping: Symbol -> index dict; its key order is used to turn a class
            index back into a symbol.

    Returns:
        None. The prediction count and both scores are printed.
    """
    map_key = list(mapping.keys())

    preds = model.predict(padded_x)

    print(len(preds))

    all_words = []
    for pred in preds:
        # First index of the highest score in each row.
        indices = np.argmax(pred, axis=-1)
        # Class index k maps to the (k-1)-th key. Index 0 therefore wraps
        # around to the last key through negative indexing.
        decoded = ''.join(map_key[idx - 1] for idx in indices)
        # Drop trailing blanks and collapse runs of spaces. Exactly one string
        # is emitted per prediction, so predictions and references stay
        # aligned even when a prediction decodes to blanks only.
        decoded = decoded.rstrip(' ')
        all_words.append(re.sub(' +', ' ', decoded))

    cer = load('cer')
    cer_score = cer.compute(predictions=all_words, references=clean_y)
    print('Character Error Rate:', cer_score)

    wer = load('wer')
    wer_score = wer.compute(predictions=all_words, references=clean_y)
    print('Word Error Rate:', wer_score)

In [24]:
def main(train=False, model='', train_dir='train', test_dir='train'):
    """Load the data, then train a new model or load a saved one, and evaluate it.

    Args:
        train: When true, a new model is built and fitted; otherwise ``model``
            is loaded from disk.
        model: Path of a saved model; required when ``train`` is false.
        train_dir: Split sub-folder used for the training data.
        test_dir: Split sub-folder used for the evaluation data.
            NOTE(review): the default keeps the existing behaviour of
            evaluating on the 'train' split - confirm whether a separate
            held-out split should be used instead.

    Raises:
        ValueError: If ``train`` is false and no model path is given.
    """
    # Fail before the slow data loading, not later inside load_model('').
    if not train and not model:
        raise ValueError('A saved model path is required when train is False')

    print('\033[95m' + 'LOADING TRAINING DATA\n')
    clean_y, cat_y, padded_x, minim, maxim, mapping = load_and_process_data(train_dir)
    print('\n\n', '\033[95m' + 'LOADING TESTING DATA\n', sep='')
    # The training mapping is reused so both splits share character indices.
    y_test, _, x_test, _, _, _ = load_and_process_data(test_dir, test_data=True, map_use=mapping)

    if train:
        model = build_model(len(padded_x[0]), len(cat_y[0]), maxim)
        plot_model(model, show_shapes=True)

        model.fit(padded_x, cat_y, epochs=100, verbose=1, batch_size=1)
    else:
        # '/content/gdrive/MyDrive/Colab_Notebooks/NLP/project/lstm_model_450e'
        model = load_model(model)

    evaluate(model, y_test, x_test, mapping)

In [23]:
# Evaluate the previously saved model; `train` defaults to False, so nothing is fitted here.
main(model='/content/gdrive/MyDrive/Colab_Notebooks/NLP/project/lstm_model_450e')

[95mLOADING TRAINING DATA


Loading HTML Data: 100%|██████████| 293/293 [00:01<00:00, 168.88it/s]
  sr, data = wavfile.read(file)
Loading Audio Data & Creating Dataset: 302it [00:00, 398.91it/s]


KeyboardInterrupt: ignored