<a href="https://colab.research.google.com/github/GreihMurray/KriolTranscriber/blob/master/char_net.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install evaluate
!pip install jiwer

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting evaluate
  Downloading evaluate-0.3.0-py3-none-any.whl (72 kB)
[K     |████████████████████████████████| 72 kB 860 kB/s 
[?25hCollecting datasets>=2.0.0
  Downloading datasets-2.7.1-py3-none-any.whl (451 kB)
[K     |████████████████████████████████| 451 kB 66.8 MB/s 
[?25hCollecting responses<0.19
  Downloading responses-0.18.0-py3-none-any.whl (38 kB)
Collecting huggingface-hub>=0.7.0
  Downloading huggingface_hub-0.11.1-py3-none-any.whl (182 kB)
[K     |████████████████████████████████| 182 kB 84.1 MB/s 
Collecting xxhash
  Downloading xxhash-3.1.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (212 kB)
[K     |████████████████████████████████| 212 kB 48.0 MB/s 
[?25hCollecting multiprocess
  Downloading multiprocess-0.70.14-py38-none-any.whl (132 kB)
[K     |████████████████████████████████| 132 kB 15.3 MB/s 
Collecting urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21

In [2]:
from keras.utils import to_categorical
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
import tensorflow as tf
import numpy as np
import itertools
from evaluate import load
from scipy.io import wavfile
from keras.models import load_model
from evaluate import load

In [3]:
import math, random
import torch
import torchaudio
from torchaudio import transforms
import pandas as pd
from tqdm import tqdm
import os
import unicodedata
import re

In [4]:
# Mount Google Drive into the Colab filesystem; DEFAULT_DIR (defined in a
# later cell) points below this mount point.
from google.colab import drive

drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [5]:
# Project root on the mounted Drive. The loaders append 'html/<split>/' and
# 'audio/<split>/' to this prefix, so the trailing slash is required.
DEFAULT_DIR = '/content/gdrive/MyDrive/Colab_Notebooks/NLP/project/'
# Assumed audio sample rate (presumably Hz -- confirm against the WAV files).
# It is only used to derive the fixed clip length `SR//1000 * 15000`; the
# rate stored inside each WAV file is not consulted.
SR = 44100

In [6]:
def load_html_data(dir_ext):
    """Extract lower-cased sentence strings from the HTML files of one split.

    Every file in ``DEFAULT_DIR + 'html/' + dir_ext + '/'`` is read in sorted
    filename order. Inline ``<span class=...>`` and ``<a>`` elements are
    removed, then the text of divs with class ``mt``, ``mt2`` and ``ip`` is
    emitted first, followed by the text of divs with class ``p`` / ``s``.
    A fixed set of punctuation characters is deleted and the remaining text
    is split into sentences on ``.``, ``,``, ``?``, ``!`` and newlines.

    Args:
        dir_ext: Name of the sub-directory under ``html/`` to read.

    Returns:
        list[str]: Stripped, lower-cased, non-empty sentences in file order.
        An empty list is returned when the directory contains no files.
    """
    # Characters deleted from the text before it is split into sentences.
    strip_chars = r'[\,,@,#,$,%,^,&,*,(,),\[,\],\',\",;,:,“,”,‘,’]'
    # Sentence delimiters. '!' is also the separator used to join divs below,
    # so every div boundary becomes a sentence boundary.
    sentence_breaks = r'[\.,\?,!,\n]'
    directory = DEFAULT_DIR + 'html/' + dir_ext + '/'

    all_data = []
    for file in tqdm(sorted(os.listdir(directory)), desc='Loading HTML Data'):
        with open(directory + file, 'r', encoding='UTF-8') as in_file:
            data = ' '.join(in_file.readlines())

        data = unicodedata.normalize('NFC', data)
        # Drop inline spans and anchors before pulling out the div text.
        data = re.sub(r'<span class=.*?</span>', '', data)
        data = re.sub(r'<a.*?</a>', '', data)
        data = data.replace(u'\xa0', u' ')

        check_divs = re.findall(r'<div class=\'mt\'.*?>(.*?) </div>', data)
        check_divs.extend(re.findall(r'<div class=\'mt2\'.*?>(.*?) </div>', data))
        check_divs.extend(re.findall(r'<div class=\'ip\'>(.*)', data))

        if check_divs:
            full = '!'.join(check_divs)
            full = re.sub(strip_chars, '', full)
            # Only leading whitespace is touched here (unlike the ' +' collapse
            # below); kept as-is so the sentence split is unchanged.
            full = re.sub(r'^\s+', ' ', full).strip('\u00A0')
            all_data.extend(s.strip() for s in re.split(sentence_breaks, full))

        divs = re.findall(r'<div class=\'[p,s]\'.*?>(.*?) </div>', data)
        full_data = '!'.join(divs)
        full_data = re.sub(strip_chars, '', full_data)
        full_data = re.sub(' +', ' ', full_data)
        all_data.extend(s.strip() for s in re.split(sentence_breaks, full_data))

    # Filter once after all files are read. Doing this inside the loop both
    # repeated the work per file and left the result undefined for an empty
    # directory.
    return [sentence.lower() for sentence in all_data if len(sentence) >= 1]

In [7]:
def load_dataset(dir_ext):
    """Pair each audio clip of a split with a transcript sentence.

    WAV files in ``DEFAULT_DIR + 'audio/' + dir_ext + '/'`` are read in sorted
    filename order. A clip longer than ``SR//1000 * 15000`` samples is
    skipped. A kept clip is paired with ``transcripts[i]`` where ``i`` is the
    clip's position in the full sorted file list, so skipping a clip does not
    shift the transcripts of the clips after it.

    Args:
        dir_ext: Name of the split sub-directory (used for both the audio and
            the HTML transcripts).

    Returns:
        pandas.DataFrame: Columns ``audio`` (sample arrays as returned by
        ``wavfile.read``) and ``transcription`` (str).

    Raises:
        IndexError: If a kept clip has no transcript at its position.
    """
    transcripts = load_html_data(dir_ext)
    directory = DEFAULT_DIR + 'audio/' + dir_ext + '/'
    # Loop-invariant clip-length cap, in samples.
    max_len = SR//1000 * 15000

    all_x = []
    all_y = []

    dir_files = sorted(os.listdir(directory))
    # Wrap the list (not the enumerate object) so tqdm knows the total.
    progress = tqdm(dir_files, desc='Loading Audio Data & Creating Dataset')
    for i, file in enumerate(progress):
        # The sample rate stored in the file is not used.
        _, data = wavfile.read(directory + file)

        if len(data) > max_len:
            continue

        if i >= len(transcripts):
            raise IndexError(
                'No transcript for audio file #%d (%s): only %d transcripts loaded'
                % (i, file, len(transcripts)))

        all_x.append(data)
        all_y.append(transcripts[i])

    return pd.DataFrame(list(zip(all_x, all_y)), columns=['audio', 'transcription'])

In [18]:
def one_hot(data, test_data, map_use=None):
    """Encode sequences of symbols as sequences of positive integer indices.

    Despite the name, the result is an integer index encoding, not one-hot
    vectors. Indices start at 1 and are handed out in order of first
    appearance.

    Args:
        data: Iterable of sequences (e.g. tuples of characters).
        test_data: When false, a fresh mapping is built from ``data`` and
            ``map_use`` is ignored. When true, ``map_use`` is the starting
            vocabulary and symbols missing from it receive new indices
            following its largest existing index.
        map_use: Existing ``{symbol: index}`` mapping for test mode. It is
            copied, never modified. ``None`` is treated as empty.

    Returns:
        tuple[list[list[int]], dict]: One index list per input sequence, and
        the mapping that was actually used (including any symbols added).
    """
    # Work on a copy so the caller's vocabulary is never modified.
    mapping = dict(map_use) if (test_data and map_use) else {}
    next_index = max(mapping.values(), default=0)

    mapped = []
    for sentence in data:
        cur_map = []
        for symbol in sentence:
            if symbol not in mapping:
                next_index += 1
                mapping[symbol] = next_index
            cur_map.append(mapping[symbol])
        # Appended per sentence in both modes, so test mode returns every row.
        mapped.append(cur_map)

    return mapped, mapping

In [11]:
def vectorize(data):
    """Turn each row of one-hot vectors into a per-position count vector.

    For every row, the position of the first ``1`` in each one-hot vector is
    counted, giving a vector as wide as the row's first element. Position 7
    of every result is then overwritten with 1.

    Args:
        data: Iterable of rows; each row is a non-empty sequence of
            equal-width one-hot vectors at least 8 entries wide.

    Returns:
        list[list[int]]: One count vector per row.
    """
    def count_hot_positions(row):
        histogram = [0 for _ in range(len(row[0]))]
        for one_hot_vec in row:
            hot = list(one_hot_vec).index(1)
            histogram[hot] = histogram[hot] + 1
        # NOTE(review): index 7 is forced to 1 regardless of the counts;
        # presumably the slot of a padding/end symbol -- confirm.
        histogram[7] = 1
        return histogram

    return [count_hot_positions(row) for row in data]

In [12]:
def pad_audio(data):
    """Zero-pad, in place, every clip shorter than the fixed target length.

    The target length is ``SR//1000 * 15000`` samples. The integer division
    happens first, so the value is not exactly ``SR * 15``. Clips that
    already reach the target are left untouched.

    Args:
        data: Indexable, mutable collection of sample arrays; entries are
            replaced via ``data[idx] = ...``.

    Returns:
        The same ``data`` object, after padding.
    """
    target_len = SR//1000 * 15000

    for idx, clip in tqdm(enumerate(data), desc='Padding audio'):
        missing = target_len - len(clip)
        if missing <= 0:
            continue
        # Trailing zeros are appended to reach the target length.
        padded = np.append(clip, np.zeros(missing))
        data[idx] = np.array(padded)

    return data

In [13]:
def build_model(input_len, output_len):
    """Build and compile the fully connected audio-to-indices network.

    The network is three Dense layers: 256 and 128 sigmoid units, then
    ``output_len`` ReLU units. It is compiled with the Poisson loss and an
    accuracy metric; no optimizer is passed, so the framework default applies.

    Args:
        input_len: Number of input features per example.
        output_len: Number of output units.

    Returns:
        The compiled Sequential model.
    """
    layers = [
        Dense(256, input_shape=(input_len,), activation="sigmoid"),
        Dense(128, activation="sigmoid"),
        Dense(output_len, activation="relu"),
    ]
    network = Sequential(layers)
    network.compile(loss='poisson', metrics=['accuracy'])

    return network

In [14]:
def load_and_process_data(folder_path, test_data=False, map_use=None):
    """Load one split and turn it into model-ready arrays.

    Args:
        folder_path: Split name forwarded to ``load_dataset``.
        test_data: Forwarded to ``one_hot``.
        map_use: Forwarded to ``one_hot``.

    Returns:
        tuple: ``(clean_y, enc_y, padded_x, mapping)`` where ``clean_y`` is
        the list of raw transcripts, ``enc_y`` the ``one_hot`` encoding of
        the '!'-padded transcripts as a NumPy array, ``padded_x`` the stacked
        output of ``pad_audio``, and ``mapping`` the dict returned by
        ``one_hot``.
    """
    frame = load_dataset(folder_path)
    transcripts = frame['transcription']
    clean_y = list(transcripts)

    # Transposing twice through zip_longest right-pads every transcript with
    # '!' up to the length of the longest one.
    columns = itertools.zip_longest(*list(transcripts), fillvalue='!')
    padded_y = list(zip(*columns))

    encoded, mapping = one_hot(padded_y, test_data, map_use)
    enc_y = np.array(encoded)

    padded_x = np.stack(pad_audio(frame['audio']))

    return clean_y, enc_y, padded_x, mapping

In [15]:
def evaluate(model, padded_x, mapping, clean_y, stop_index=7):
    """Decode model predictions to text and print CER and WER.

    Each predicted row is read left to right. Non-positive outputs are
    skipped. Each positive output is rounded to an integer index (a value
    that rounds to 0 is treated as index 1). Decoding of the row stops at the
    first output equal to ``stop_index``. Indices that have no entry in
    ``mapping`` are skipped instead of aborting the evaluation.

    Args:
        model: Object with a ``predict(padded_x)`` method returning one row
            of numeric outputs per example.
        padded_x: Model input passed straight to ``model.predict``.
        mapping: ``{symbol: index}`` dict used to encode the targets.
        clean_y: Reference transcripts, one per row of ``padded_x``.
        stop_index: Index that ends decoding of a row. Defaults to 7, the
            value previously hard-coded here -- presumably the index of the
            '!' padding symbol in the trained mapping; confirm.

    Returns:
        None. The two scores are printed.
    """
    # Inverse lookup built once. setdefault keeps the first symbol for an
    # index, matching what a first-match list search would return.
    index_to_char = {}
    for symbol, index in mapping.items():
        index_to_char.setdefault(index, symbol)

    all_words = []
    preds = model.predict(padded_x)

    for pred in preds:
        cur_words = []
        for val in pred:
            if not val > 0:
                continue
            index = int(round(val))
            if index == 0:
                index = 1
            if index == stop_index:
                break
            symbol = index_to_char.get(index)
            # The ReLU output is unbounded, so an index may fall outside the
            # vocabulary; such positions contribute no character.
            if symbol is not None:
                cur_words.append(symbol)
        all_words.append(''.join(cur_words))

    cer = load('cer')
    cer_score = cer.compute(predictions=all_words, references=clean_y)
    print('Character Error Rate:', cer_score)

    wer = load('wer')
    wer_score = wer.compute(predictions=all_words, references=clean_y)
    print('Word Error Rate:', wer_score)

In [16]:
def main(train=False, model_path='/content/gdrive/MyDrive/Colab_Notebooks/NLP/project/charnet_model_350e'):
    """Load the data, obtain a model, and print its error rates.

    Args:
        train: When true, a new model is built and fitted for 50 epochs with
            batch size 1. Otherwise a saved model is loaded from
            ``model_path``.
        model_path: Location of the saved model used when ``train`` is false.
    """
    _, train_targets, train_audio, mapping = load_and_process_data('train')
    # NOTE(review): the evaluation set is also read from the 'train' split,
    # so the reported scores are measured on training data -- confirm intent.
    eval_refs, _, eval_audio, _ = load_and_process_data('train', test_data=True, map_use=mapping)

    if not train:
        model = load_model(model_path)
    else:
        model = build_model(len(train_audio[0]), len(train_targets[0]))
        model.fit(train_audio, train_targets, epochs=50, verbose=1, batch_size=1)

    evaluate(model, eval_audio, mapping, eval_refs)

In [19]:
main()

Loading HTML Data: 100%|██████████| 293/293 [00:01<00:00, 205.28it/s]
  sr, data = wavfile.read(file)
Loading Audio Data & Creating Dataset: 356it [00:00, 419.15it/s]
Padding audio: 306it [00:00, 808.53it/s]
Loading HTML Data: 100%|██████████| 293/293 [00:01<00:00, 214.38it/s]
Loading Audio Data & Creating Dataset: 356it [00:00, 453.58it/s]
Padding audio: 306it [00:00, 937.87it/s]




Downloading builder script:   0%|          | 0.00/5.60k [00:00<?, ?B/s]

Character Error Rate: 0.8919685968080793


Downloading builder script:   0%|          | 0.00/4.49k [00:00<?, ?B/s]

Word Error Rate: 0.9990221005280657
