# Train a 🐸 STT model with Common Voice data 💫

👋 Hello and welcome to Coqui (🐸) STT 

This notebook shows a **typical workflow** for **training** and **testing** an 🐸 STT model on data from Common Voice.

In this notebook, we will:

1. Download Common Voice data (pre-formatted for 🐸 STT)
2. Configure the training and testing runs
3. Train a new model
4. Test the model and display its performance

So, let's jump right in!

*PS - If you just want a working, off-the-shelf model, check out the [🐸 Model Zoo](https://www.coqui.ai/models)*

In [None]:
!python -m pip install --upgrade pip wheel setuptools
!pip install coqui_stt_training
!apt-get install -y libopusfile0 libopus-dev libopusfile-dev
!pip install 'tensorflow-gpu==1.15.4'

In [None]:
import os
%cd '/content'
!git clone https://github.com/coqui-ai/STT.git
%cd STT
!pip install -r requirements_eval_tflite.txt
!pip install -r requirements_tests.txt
!pip install -r requirements_transcribe.txt
!python3 setup.py bdist_wheel
!pip install dist/*.whl

In [None]:
from google.colab import drive
import os
os.chdir('/content')
drive.mount('drive', force_remount=True)
gdrive = '/content/drive/MyDrive/'

In [None]:
# Mount a second Google Drive account with google-drive-ocamlfuse (from https://stackoverflow.com/questions/53728127/mount-multiple-drives-in-google-colab)
!sudo add-apt-repository -y ppa:alessandro-strada/ppa 2>&1 > /dev/null
!sudo apt-get update -qq 2>&1 > /dev/null
!sudo apt -y install -qq google-drive-ocamlfuse 2>&1 > /dev/null
!google-drive-ocamlfuse

In [None]:
!apt-get install -qq w3m
!xdg-settings set default-web-browser w3m.desktop # to set default browser
%cd /content
!mkdir -p /content/drive2/MyDrive
!google-drive-ocamlfuse /content/drive2/MyDrive
gdrive2 = '/content/drive2/MyDrive/'
!cp 'drive2/MyDrive/Voice-Cloning/cpg' '/usr/local/bin'
!chmod +x '/usr/local/bin/cpg'

In [None]:
gdrive = '/content/drive/MyDrive/'
gdrive2 = '/content/drive2/MyDrive/'

In [None]:
!ln -s '/content/drive2/MyDrive/' /content/drive2/My\ Drive

In [None]:
%cd /content
!cpg '$gdrive2''Voice-Cloning/clips.tar.gz' '/content'
!tar -xf clips.tar.gz
!rm clips.tar.gz
!rm clips/dev.csv clips/train.csv clips/train-all.csv clips/other.csv clips/test.csv

In [None]:
# generate the training data
# !apt-get install sox libsox-fmt-mp3
# !python3 STT/bin/import_cv2.py --filter_alphabet /content/alphabet.txt /content/cv-corpus-9.0-2022-04-27/fa/

In [None]:
%tensorflow_version 1.x
import tensorflow as tf
device_name = tf.test.gpu_device_name()
if device_name != '/device:GPU:0':
  raise SystemError('GPU device not found')
print('Found GPU at: {}'.format(device_name))

In [None]:
!wc -l clips/*.csv

In [None]:
!nvidia-smi -L

In [None]:
dataset_dir = '/content/clips/'
metadata = dataset_dir+'validated.csv'
train_files = dataset_dir+'train.csv'
dev_files = dataset_dir+"dev.csv"
test_files = dataset_dir+"test.csv"
checkpoint_dir = gdrive+'Voice-Cloning/STT/checkpoints'
alphabet_file = gdrive+'Voice-Cloning/'+'alphabet.txt'
export_dir = gdrive+'Voice-Cloning/STT/models'

In [None]:
!ls clips | wc -l

In [None]:
# For transfer learning: download and extract the DeepSpeech 0.9.3 checkpoint
# !wget https://github.com/mozilla/DeepSpeech/releases/download/v0.9.3/deepspeech-0.9.3-checkpoint.tar.gz
# !tar -xf deepspeech-0.9.3-checkpoint.tar.gz

In [None]:
!cpg '$gdrive2''Voice-Cloning/kenlm-persian.scorer' '/content/kenlm-persian.scorer'

## ✅ Configure & set hyperparameters

Coqui STT comes with a long list of hyperparameters you can tweak. We've set default values, but you will often want to set your own. You can use `initialize_globals_from_args()` to do this. 

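You must **always** configure the paths to your data and your alphabet. Here, `auto_input_dataset` points at a single CSV (`validated.csv`): 🐸 STT then splits it into train/dev/test sets and generates an alphabet file if one isn't already present, which is why `train_files`, `dev_files`, `test_files` and `alphabet_config_path` are commented out (and why the old split CSVs were deleted earlier). Other options you will often want to set include the size of hidden layers (`n_hidden`), the number of epochs to train for (`epochs`), and whether to initialize a new model from scratch (`load_train="init"`).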

The batch sizes below (`train_batch_size=128`, `dev_batch_size=128`, `test_batch_size=384`) assume you're training on a GPU; lower them if you run out of GPU memory.
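If you'd rather write `alphabet.txt` yourself and pass it via `alphabet_config_path`, the following is a minimal sketch of one way to do it. It assumes the dataset CSV has a `transcript` column (as the 🐸 STT-formatted Common Voice CSVs do) and that the alphabet file lists one character per line, with lines starting with `#` treated as comments. `build_alphabet` and `write_alphabet` are hypothetical helpers, not 🐸 STT's own generator; if `#` itself appears in your transcripts, check how your 🐸 STT version expects it to be escaped.

```python
import csv

def build_alphabet(csv_file, transcript_column="transcript"):
    """Return the sorted set of characters used in the transcript column (hypothetical helper)."""
    chars = set()
    for row in csv.DictReader(csv_file):
        chars.update(row[transcript_column])
    return sorted(chars)

def write_alphabet(chars, out_file):
    """Write one character per line; a leading '#' line is assumed to be read as a comment."""
    out_file.write("# Each line is one character the model can output\n")
    for ch in chars:
        out_file.write(ch + "\n")

# Example usage with this notebook's paths:
# with open(metadata, encoding="utf-8", newline="") as src, \
#      open("/content/alphabet.txt", "w", encoding="utf-8") as dst:
#     write_alphabet(build_alphabet(src), dst)
```

Note that the space character gets its own line, since the model has to output word boundaries too.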

In [None]:
from coqui_stt_training.util.config import initialize_globals_from_args
# !rm STT/alphabet.txt
initialize_globals_from_args(
    scorer_path='/content/kenlm-persian.scorer',
    # train_files=[train_files],
    # dev_files=[dev_files],
    # test_files=[test_files],
    dropout_rate=0.175,
    # load_checkpoint_dir='/content/deepspeech-0.9.3-checkpoint',
    # drop_source_layers=2, # set when transfer learning
    learning_rate=0.000095,
    force_initialize_learning_rate=True,
    train_cudnn=True,
    reduce_lr_on_plateau=True,
    plateau_epochs=3,
    plateau_reduction=0.2,
    auto_input_dataset=metadata,
    # alphabet_config_path=alphabet_file,
    checkpoint_dir=checkpoint_dir,
    export_dir=export_dir,
    epochs=200,
    train_batch_size=128,
    dev_batch_size=128,
    test_batch_size=384,
)
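The cell above enables `reduce_lr_on_plateau` with `plateau_epochs=3` and `plateau_reduction=0.2`. The sketch below illustrates the usual reduce-on-plateau idea under an assumed rule: once the dev loss has failed to beat its best value for `plateau_epochs` consecutive epochs, the learning rate is multiplied by `plateau_reduction` and the counter restarts. `plateau_schedule` is a hypothetical helper for intuition only; 🐸 STT's exact bookkeeping may differ.

```python
def plateau_schedule(dev_losses, lr, plateau_epochs=3, plateau_reduction=0.2):
    """Return the learning rate in effect after each epoch under an assumed
    reduce-on-plateau rule (not 🐸 STT's actual implementation)."""
    best = float("inf")
    epochs_without_improvement = 0
    rates = []
    for loss in dev_losses:
        if loss < best:
            # New best dev loss: reset the patience counter
            best = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= plateau_epochs:
                # Plateau detected: shrink the learning rate and start counting again
                lr *= plateau_reduction
                epochs_without_improvement = 0
        rates.append(lr)
    return rates
```

With this notebook's `learning_rate=0.000095`, one reduction by a factor of 0.2 gives 0.000019.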

## ✅ Train a new model

Let's kick off a training run 🚀🚀🚀 (using the configuration you set above).

In [None]:
from coqui_stt_training.train import train

train()

In [None]:
# view the training and validation loss curves in TensorBoard
%load_ext tensorboard
%tensorboard --logdir '$checkpoint_dir/summaries/'

## ✅ Test the model

We made it! 🙌

Let's kick off the testing run, which displays performance metrics.

The settings used here are for demonstration purposes, so you shouldn't deploy this model to production as-is. This notebook focuses on the workflow itself, so that's forgivable 😇

You can still train a more accurate model by searching for better hyperparameters, so go for it 💪
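The test run reports word error rate (WER) and character error rate (CER). As a reference for what those numbers mean, here is a minimal sketch of the standard definitions: the Levenshtein edit distance between reference and hypothesis (over words for WER, over characters for CER) divided by the reference length. The function names are illustrative, and how 🐸 STT aggregates the rates over the whole test set may differ from scoring one sentence at a time.

```python
def edit_distance(ref, hyp):
    """Levenshtein distance between two sequences (each edit costs 1)."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        curr = [i]
        for j, h in enumerate(hyp, start=1):
            curr.append(min(prev[j] + 1,               # deletion
                            curr[j - 1] + 1,           # insertion
                            prev[j - 1] + (r != h)))   # substitution or match
        prev = curr
    return prev[-1]

def wer(reference, hypothesis):
    """Word error rate; assumes a non-empty reference."""
    ref_words = reference.split()
    return edit_distance(ref_words, hypothesis.split()) / len(ref_words)

def cer(reference, hypothesis):
    """Character error rate; assumes a non-empty reference."""
    return edit_distance(list(reference), list(hypothesis)) / len(reference)
```

For example, one wrong word in a three-word sentence gives a WER of 1/3.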

In [None]:
from coqui_stt_training.evaluate import test

test()

In [None]:
model_name = 'persian_stt'
version = '0.1.0'

In [None]:
!curl -L https://github.com/coqui-ai/STT/releases/download/v0.9.3/convert_graphdef_memmapped_format.linux.amd64.zip | funzip > convert_graphdef_memmapped_format
!chmod +x convert_graphdef_memmapped_format 

In [None]:
# export the TFLite model
!python -m coqui_stt_training.export \
  --checkpoint_dir='$checkpoint_dir' \
  --export_dir='$export_dir' \
  --export_model_name='$model_name' \
  --export_author_id='oct4pie' \
  --export_model_version='$version' \
  --export_contact_info='https://github.com/Oct4Pie/persian-stt/issues' \
  --export_license='LGPL-3.0-only' \
  --export_language='fa-IR' \
  --export_file_name='$model_name'
  
# export a TensorFlow protocol buffer (.pb) model instead of TFLite
!python -m coqui_stt_training.export \
  --checkpoint_dir='$checkpoint_dir' \
  --export_dir='$export_dir' \
  --export_tflite='false' \
  --export_model_name='$model_name' \
  --export_author_id='oct4pie' \
  --export_model_version='$version' \
  --export_contact_info='https://github.com/Oct4Pie/persian-stt/issues' \
  --export_license='LGPL-3.0-only' \
  --export_language='fa-IR' \
  --export_file_name='$model_name'

In [None]:
!./convert_graphdef_memmapped_format --in_graph="$export_dir/"'$model_name''.pb' --out_graph='$export_dir/''$model_name''.pbmm'

In [None]:
!pip install stt ffmpeg-python

In [None]:
# transcription test
from io import BytesIO
import wave

import ffmpeg
import numpy as np
from stt import Model

with open('/content/test.mp3', 'rb') as fh:
    mp3_bytes = fh.read()

# Decode the MP3 and convert it to 16 kHz, mono, 16-bit PCM WAV in memory
out, err = (
    ffmpeg.input("pipe:0")
    .output(
        "pipe:1",
        f="wav",
        acodec="pcm_s16le",
        ac=1,
        ar="16k",
        loglevel="error",
        hide_banner=None,
    )
    .run(input=mp3_bytes, capture_stdout=True, capture_stderr=True)
)
if err:
    raise RuntimeError(err.decode(errors='replace'))

# Read the samples as a 16-bit integer array for Model.stt()
with wave.Wave_read(BytesIO(out)) as wav:
    audio = np.frombuffer(wav.readframes(wav.getnframes()), np.int16)

ds = Model(export_dir + '/' + model_name + '.tflite')
txt = ds.stt(audio)

# The transcript is Persian, so write it explicitly as UTF-8
with open('transcript.txt', 'w', encoding='utf-8') as f:
    f.write(txt)
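The model expects mono 16-bit PCM at its training sample rate, so a quick header check before calling `stt()` can catch a mis-converted file early. The sketch below uses only the standard `wave` module; `check_wav` is a hypothetical helper, and the usage line assumes the `stt` package's `Model.sampleRate()` method for the expected rate. It deliberately ignores the frame count, since a WAV streamed by ffmpeg to a pipe may carry a placeholder length in its header.

```python
import io
import wave

def check_wav(wav_bytes, expected_rate=16000):
    """Validate that in-memory WAV bytes are mono, 16-bit PCM at expected_rate Hz.

    Returns (channels, sample_width_bytes, frame_rate); raises ValueError otherwise.
    """
    with wave.open(io.BytesIO(wav_bytes), "rb") as wav:
        channels = wav.getnchannels()
        width = wav.getsampwidth()
        rate = wav.getframerate()
    if channels != 1:
        raise ValueError(f"expected mono audio, got {channels} channels")
    if width != 2:
        raise ValueError(f"expected 16-bit samples, got {8 * width}-bit")
    if rate != expected_rate:
        raise ValueError(f"expected {expected_rate} Hz, got {rate} Hz")
    return channels, width, rate

# Example usage, after the ffmpeg conversion above (sampleRate() assumed):
# check_wav(out, expected_rate=ds.sampleRate())
```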

In [None]:
!wget https://storage.googleapis.com/danielk-files/farsi-text/merged_files/commoncrawl_fa_merged.txt # raw (uncleaned) Persian Common Crawl text corpus

In [None]:
!wc -l commoncrawl_fa_merged.txt

In [None]:
!pip install hazm

In [None]:
# clean up the text
import re
import unicodedata
import sys
from urllib.parse import unquote
from html.parser import HTMLParser
from hazm import WordTokenizer, word_tokenize

tokenizer = WordTokenizer(join_verb_parts=False)

def persianify(sentence):
  # omit short phrases/sentences
  if len(sentence.split()) < 5:
    return
  
  # diacritics as UTF-8 bytes (fatha, kasra, damma, kasratan, fathatan): b'\xd9\x8e\xd9\x90\xd9\x8f\xd9\x8d\xd9\x8b'.decode()
  alpha = ' !"&\'()-.:=،؛؟ءآأؤئابةتثجحخدذرزسشصضطظعغـفقكلمنهوىئًٌَُِّْ٬پچژکگۀیے–“”…'
  puncs = '،؛؟!:.()"\'…“”'

  try:
    for letter in sentence:
      if letter not in alpha:
        return

    sentence = sentence.strip().replace('...', '')
    l_seg = sorted(sentence.split(), key=lambda w: len(w))[-1]
    if len(l_seg) > 11:
      l_seg = sorted(word_tokenize(l_seg), key=lambda w: len(w))[-1]
      if len(l_seg) > 12:
        return

    sentence = ' '.join(sentence.split('-'))
    segs = tokenizer.tokenize(sentence)

    if segs[-1] == '.':
      del segs[-1]
      segs.insert(0, '.')

    sentence = ""
    for seg in segs:
      # sentence += seg if seg in puncs else ' '+seg
      if seg in puncs:
        sentence += seg

      else:
        sentence += ' '+seg

    # drop a dangling punctuation mark at the end of the sentence
    if sentence[-1] in ':،"\'…“”':
      sentence = sentence[:-1]
    return sentence.replace('. ', '.', len(sentence)-4).strip().replace('( ',' (')

  except Exception as e:
    print('error: ', e, f'for "{sentence.split()}"')
    return

# from https://github.com/common-voice/CorporaCreator

RE_DIGITS = re.compile(r'\d')

def _has_digit(sentence):
    return RE_DIGITS.search(sentence)


class _HTMLStripper(HTMLParser):
    """Class that strips HTML from strings.
    Examples:
        >>> stripper = _HTMLStripper()
        >>> stripper.feed(html)
        >>> nohtml = stripper.get_data()
    """

    def __init__(self):
        super().__init__()
        self.reset()
        self.strict = False
        self.convert_charrefs = True
        self.fed = []

    def handle_data(self, d):
        self.fed.append(d)

    def get_data(self):
        return "".join(self.fed)


def _strip_tags(html):
    """Removes HTML tags from passed text.
    Args:
      html (str): String containing HTML
    Returns:
      (str): String with HTML removed
    """
    s = _HTMLStripper()
    s.feed(html)
    return s.get_data()


def _strip_string(sentence):
    """Cleans a string based on a whitelist of printable unicode categories.
    You can find a full list of categories here:
    http://www.fileformat.info/info/unicode/category/index.htm
    """
    letters     = ('LC', 'Ll', 'Lm', 'Lo', 'Lt', 'Lu')
    numbers     = ('Nd', 'Nl', 'No')
    marks       = ('Mc', 'Me', 'Mn')
    punctuation = ('Pc', 'Pd', 'Pe', 'Pf', 'Pi', 'Po', 'Ps')
    symbol      = ('Sc', 'Sk', 'Sm', 'So')
    space       = ('Zs',)

    allowed_categories = letters + numbers + marks + punctuation + symbol + space

    return u''.join([c for c in sentence if unicodedata.category(c) in allowed_categories])


def common(sentence):
    """Cleans up the passed sentence in a language independent manner, removing or reformatting invalid data.
    Args:
      sentence (str): Sentence to be cleaned up.
    Returns:
      (is_valid,str): A boolean indicating validity and cleaned up sentence.
    """

    # Define a boolean indicating validity
    is_valid = True
    # Decode any URL encoded elements of sentence
    sentence = unquote(sentence)
    # Remove any HTML tags
    sentence = _strip_tags(sentence)
    # Remove non-printable characters
    sentence = _strip_string(sentence)
    # collapse all whitespace and replace with single space
    sentence = (' ').join(sentence.split())
    # TODO: Clean up data in a language independent manner
    # If the sentence contains digits reject it
    if _has_digit(sentence):
        is_valid = False
    # If the sentence is blank reject it
    if not sentence.strip():
        is_valid = False
    return (is_valid, sentence)

In [None]:
import os
os.chdir('/content')

# Clean the raw corpus line by line; each line may contain several tab-separated sentences
with open('commoncrawl_fa_merged.txt', 'r') as f, open('clean.txt', 'w') as clean:
  for line in f:
    for sentence in line.split('\t'):
      # language-independent cleanup first, then the Persian-specific filter
      s = common(sentence.replace('\n', '').replace('/', '').replace('|', ''))[1]
      s = persianify(s)
      if s:
        clean.write(s + '\n')

In [None]:
!head -n 6 clean.txt

In [None]:
if not os.getcwd().endswith('STT'):
  os.chdir('STT')
# !pip uninstall coqui_stt_training
!echo  "$checkpoint_dir"
!python3 lm_optimizer.py \
  --checkpoint_dir "$checkpoint_dir" \
  --test_files "$test_files"

In [None]:
!cpg '/content/clean.txt' '$gdrive''Voice-Cloning/'
!cpg '$gdrive2''Voice-Cloning/clean-nodup-nopunc.txt' .
!cpg '$gdrive''Voice-Cloning/clean-nodup.txt' .

In [None]:
!sort --unique -o clean-nodup.txt clean.txt # remove duplicate lines (deduplicating in Python ran out of memory)
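If you want to deduplicate in Python without sorting (keeping the original line order), one lower-memory option is to remember a short hash of each line instead of the line itself. The sketch below is a hypothetical helper, not what the notebook ran: it keeps an 8-byte BLAKE2b digest per distinct line, which still costs tens of bytes per line once Python object overhead is counted, and a digest collision would silently drop a line (very unlikely, but possible).

```python
import hashlib

def dedup_lines(src, dst):
    """Write each distinct line of src to dst once, keeping first-occurrence order.

    src is any iterable of text lines (e.g. an open file); dst is any object with .write().
    Returns the number of lines written.
    """
    seen = set()
    kept = 0
    for line in src:
        text = line.rstrip("\n")
        # Remember only a fixed-size digest of the line, not the line itself
        digest = hashlib.blake2b(text.encode("utf-8"), digest_size=8).digest()
        if digest not in seen:
            seen.add(digest)
            dst.write(text + "\n")
            kept += 1
    return kept

# Example usage with this notebook's file names:
# with open("clean.txt", encoding="utf-8") as src, \
#      open("clean-nodup.txt", "w", encoding="utf-8") as dst:
#     dedup_lines(src, dst)
```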

In [None]:
# remove special chars
def rm_spec(sentence):
  # puncs = '،؛؟!:.()"\'…“”–&-='
  puncs = '()"\'…“”–&-='
  for punc in puncs:
    sentence = sentence.replace(punc, '')
  return sentence

with open('clean-nodup.txt') as ifile, open('clean-nodup-nopspec.txt', 'w') as ofile:
  for line in ifile:
    phrase = rm_spec(line).strip()
    # keep phrases longer than 4 characters
    if len(phrase) > 4:
      ofile.write(phrase + '\n')

In [None]:
# !sort --unique -r -o clean-nodup-nopspec1.txt clean-nodup-nopspec.txt
!ls

In [None]:
!cpg clean-nodup-nopspec.txt '$gdrive2''Voice-Cloning/clean-nodup-nospec.txt'

In [None]:
# Remove all punctuation, plus tatweel (ـ), kasra (ِ) and hamza (ء), before building the scorer
def rm_punc(sentence):
  puncs = '،؛؟!:.()"\'…“”–&-=ـِء'
  # puncs = '()"\'…“”–&-='
  for punc in puncs:
    sentence = sentence.replace(punc, '')
  return sentence

# keep only phrases with more than 4 words
with open('clean-nodup.txt') as ifile, open('clean-nodup-nopunc.txt', 'w') as ofile:
  for line in ifile:
    phrase = rm_punc(line).strip()
    if len(phrase.split()) > 4:
      ofile.write(phrase + '\n')
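`rm_punc` scans each line once per punctuation mark. Because every mark is a single character, deleting them all in one pass with `str.translate` should give the same output; the sketch below reuses the same character set. `PUNC_TABLE` and `rm_punc_fast` are illustrative names.

```python
# Same character set as rm_punc above
PUNCS = '،؛؟!:.()"\'…“”–&-=ـِء'
# Map every listed code point to None so str.translate deletes it
PUNC_TABLE = str.maketrans('', '', PUNCS)

def rm_punc_fast(sentence):
    """Delete every character in PUNCS with a single str.translate pass."""
    return sentence.translate(PUNC_TABLE)
```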

In [None]:
!cpg clean-nodup-nopunc.txt '$gdrive2''Voice-Cloning/clean-nodup-nopunc.txt'

In [None]:
## building kenlm
# %cd /content/STT/native_client/kenlm
# !rm -rf * && \
# 	git clone https://github.com/kpu/kenlm 
# %cd kenlm
# !git checkout 87e85e66c99ceff1fab2500a7c60c01da7315eec
# !mkdir -p build
# %cd build
# !cmake ..
# !make -j $(nproc)

In [None]:
!mv /content/STT/data/lm/persian-scorer '$gdrive''Voice-Cloning/persian-scorer'
!cpg -a '$gdrive2''Voice-Cloning/persian-scorer/.' /content/STT/data/lm/persian-scorer

In [None]:
%cd /content/STT/data/lm
!mkdir 'persian-scorer'
!python3 generate_lm.py \
  --input_txt /content/clean-nodup-nopunc.txt \
  --output_dir persian-scorer \
  --kenlm_bins /content/STT/native_client/kenlm/kenlm/build/bin \
  --arpa_order 5 --max_arpa_memory "85%" --arpa_prune "0|0|1" \
  --binary_a_bits 255 --binary_q_bits 8 --binary_type trie \
  --top_k 500000
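`--top_k 500000` limits the vocabulary to the 500,000 most frequent words, which `generate_scorer_package` later reads from `vocab-500000.txt`. One rough way to choose that number is to check how much of the corpus the top-k words cover. The helper below is hypothetical (not part of `generate_lm.py`): it counts whitespace-separated words with `collections.Counter` and reports the token coverage. It keeps every distinct word in memory, so expect it to need a lot of RAM on the full corpus.

```python
from collections import Counter

def top_k_vocab(lines, top_k):
    """Return (vocab, coverage): the top_k most frequent words, most frequent first,
    and the fraction of all word tokens those words account for."""
    counts = Counter()
    for line in lines:
        counts.update(line.split())
    most_common = counts.most_common(top_k)
    total = sum(counts.values())
    covered = sum(count for _, count in most_common)
    vocab = [word for word, _ in most_common]
    return vocab, (covered / total if total else 0.0)

# Example usage on the corpus built above:
# with open('/content/clean-nodup-nopunc.txt', encoding='utf-8') as f:
#     vocab, coverage = top_k_vocab(f, 500000)
```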

In [None]:
%cd /content/STT/data/lm
!curl -L https://github.com/coqui-ai/STT/releases/download/v1.3.0/native_client.tflite.Linux.tar.xz | tar -Jxvf -

In [None]:
%cd /content/STT/data/lm
!./generate_scorer_package \
  --checkpoint "$checkpoint_dir" \
  --lm "/content/STT/data/lm/persian-scorer/lm.binary" \
  --vocab "/content/STT/data/lm/persian-scorer/vocab-500000.txt" \
  --package kenlm-persian.scorer \
  --default_alpha 0.36669178512950323 \
  --default_beta 0.3457913671678824
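`--default_alpha` and `--default_beta` are the language-model weight and the word-insertion bonus used by the beam-search decoder. The sketch below illustrates their role under an assumed, simplified formulation, where a candidate's score is its acoustic log-probability plus `alpha` times its LM log-probability plus `beta` times its word count. The real decoder applies these terms incrementally during the search, so treat this as intuition rather than its exact code; the candidate scores are made up.

```python
def combined_score(acoustic_logprob, lm_logprob, num_words, alpha, beta):
    """Simplified ranking score (assumed formulation):
    acoustic log-prob + alpha * LM log-prob + beta * word count."""
    return acoustic_logprob + alpha * lm_logprob + beta * num_words

# Defaults passed to generate_scorer_package in this notebook
DEFAULT_ALPHA = 0.36669178512950323
DEFAULT_BETA = 0.3457913671678824

# Two hypothetical candidates with made-up scores: A sounds slightly better to the
# acoustic model, while B is far more plausible to the language model.
cand_a = dict(acoustic_logprob=-5.0, lm_logprob=-30.0, num_words=3)
cand_b = dict(acoustic_logprob=-6.0, lm_logprob=-10.0, num_words=3)
```

With `alpha=0` the acoustic model alone picks A. With the notebook's defaults the language model's preference outweighs the small acoustic difference, and B wins.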

In [None]:
%cd /content/STT
!python3 lm_optimizer.py \
     --test_files "$test_files" \
     --checkpoint_dir "$checkpoint_dir" \
     --scorer_path "data/lm/kenlm-persian.scorer" \
     --n_hidden 2048 \
     --lm_alpha_max 1 \
     --lm_beta_max 1 \
     --n_trials 100 \
     --test_batch_size 384 \
     --train_cudnn 'true'
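`lm_optimizer.py` searches for the `alpha`/`beta` pair that gives the lowest error rate on the test files. Here it runs 100 trials with both values bounded by 1. The script uses its own search strategy (Optuna, as far as I know). The random search below is a plain-Python stand-in that shows the shape of the procedure, with a hypothetical `objective(alpha, beta)` in place of actually decoding the test set.

```python
import random

def random_search(objective, n_trials, alpha_max=1.0, beta_max=1.0, seed=0):
    """Evaluate n_trials random (alpha, beta) pairs drawn from [0, alpha_max] x [0, beta_max]
    and return (best_value, alpha, beta) for the lowest objective value."""
    rng = random.Random(seed)
    best = None
    for _ in range(n_trials):
        alpha = rng.uniform(0.0, alpha_max)
        beta = rng.uniform(0.0, beta_max)
        value = objective(alpha, beta)  # e.g. WER on a held-out set
        if best is None or value < best[0]:
            best = (value, alpha, beta)
    return best
```

The best pair found this way is what you would pass as `--default_alpha` and `--default_beta` when packaging the scorer.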

In [None]:
!cpg /content/STT/data/lm/kenlm-persian.scorer '$gdrive2''Voice-Cloning/kenlm-persian.scorer'