# Evaluation of longer sequences

* This notebook looks at the effect of using longer sequences on WER.  
* It includes utilities to generate the longer files (one per speaker) and a `torch.utils.data.Dataset` to load them. The files have already been generated for `dev-clean` at `brahe:/data/Long/dev-clean`.
* It should be possible to set the options in the *Config options* section below and then *Run All* cells.
* **NOTE:** this notebook must be run in the `myrtlespeech` conda env, using commit `b68aefc25be90ee32151d53366d340aee8fd4026`.

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
# Expose only CUDA device 1 to this process. The variable is read when CUDA is
# initialised, so it must be set before torch is imported in the next cell.
os.environ["CUDA_VISIBLE_DEVICES"] = "1" # set this before importing torch

In [None]:
from typing import List
from typing import Sequence
import fnmatch
from pathlib import Path
import torch

import torchaudio 
from google.protobuf import text_format
from myrtlespeech.protos import task_config_pb2
from myrtlespeech.builders.task_config import build
from myrtlespeech.run.load import load_seq_to_seq
from myrtlespeech.data.dataset.librispeech import LibriSpeech
from myrtlespeech.data.alphabet import Alphabet
from myrtlespeech.data.dataset.librispeech import LibriSpeech
from myrtlespeech.data.batch import seq_to_seq_collate_fn
from myrtlespeech.post_process.transducer_beam_decoder import TransducerBeamDecoder
from myrtlespeech.post_process.transducer_greedy_decoder import TransducerGreedyDecoder
from myrtlespeech.run.run import WordSegmentor
from myrtlespeech.post_process.utils import levenshtein

## Config options

In [None]:
# FPs (file paths)
DATA = Path('/data')  # root that contains the LibriSpeech/ directory
SPEAKER = DATA / 'LibriSpeech/SPEAKERS.TXT'  # speaker metadata read by get_speaker_ids
LONG_FP = DATA / 'Long'  # root for the concatenated per-speaker files (<LONG_FP>/<subset>/<speaker_id>/)
create_files = False # set this to true to re-generate data. (~ 3mins)

# Model info
log_dir = "/home/julian/exp/rnnt/2H/1/" #to load model from (directory holding state_dict_<epoch>.pt)
epoch = 69 #i.e. state_dict_<epoch>.pt is filename
cfg_location = "../src/myrtlespeech/configs/rnn_t_en_2H.config"  # task config used to build the model

## Generate longer files + transcripts

In [None]:
# Character alphabet used to encode transcripts when (re)generating the files.
# Indices: ' ' -> 0, 'a'..'z' -> 1..26, "'" -> 27, '_' -> 28 (28 is the
# blank_index passed to the decoders in set_decoder).
ALPHABET = Alphabet(list(" abcdefghijklmnopqrstuvwxyz'_"))

def target_transform(target):
    """Encode a transcript string as an int32 tensor of ALPHABET indices.

    The returned tensor never requires gradients.
    """
    indices = ALPHABET.get_indices(target)
    return torch.tensor(indices, dtype=torch.int32, requires_grad=False)
            
def get_speaker_ids(fp, subset):
    """Return the integer speaker ids listed under ``subset`` in a speakers file.

    Each non-comment line is expected to be a ``|``-separated record whose
    first field is the speaker id and whose third field is the subset name,
    e.g. ``84   | F | dev-clean        | 8.02 | Christie Nowak``. Lines that
    start with ``;`` are comments and are skipped, as are lines with fewer
    than three fields.

    Args:
        fp: Path (``str`` or ``pathlib.Path``) of the speakers file.
        subset: Exact subset name to match, e.g. ``'dev-clean'``.

    Returns:
        List of speaker ids (``int``) in file order.
    """
    ids = []
    with open(fp, 'r', encoding='utf-8') as f:
        for line in f:
            if line.startswith(';'):
                continue
            fields = [field.strip() for field in line.split('|')]
            # Match the SUBSET column exactly. A substring test over the whole
            # line could also match comment text or other columns. Extra '|'
            # characters inside the trailing NAME column do not shift
            # fields[0] or fields[2].
            if len(fields) >= 3 and fields[2] == subset:
                ids.append(int(fields[0]))
    return ids

def create_long_files(out_fp, paths, transcriptions, durations, subset='dev-clean'):
    """Combine all audio/transcript files of each speaker into a single sample.

    For every speaker listed under ``subset`` in ``SPEAKER``, writes
    ``<out_fp>/<subset>/<speaker_id>/audio.flac`` and
    ``<out_fp>/<subset>/<speaker_id>/transcript.trans.txt``.

    Args:
        out_fp: Output root directory (``pathlib.Path``).
        paths: Audio file paths. The speaker id is read from the name of each
            path's grandparent directory (``.../<speaker>/<chapter>/<file>``).
        transcriptions: Transcript strings, index-aligned with ``paths``.
        durations: Not used; kept for interface compatibility.
        subset: Subset name used to look up the speakers.

    Utterances are concatenated in the order they appear in ``paths``, so the
    caller controls ordering. Every audio file must have a 16 kHz sample rate.
    """
    out_fp = out_fp / subset
    speaker_ids = get_speaker_ids(SPEAKER, subset)
    # speaker id -> list of (index into paths/transcriptions, audio path)
    ids_to_path = {speaker_id: [] for speaker_id in speaker_ids}
    for idx, path in enumerate(paths):
        path = Path(path)
        speaker_id = int(path.parents[1].name)
        assert speaker_id in ids_to_path, (
            f"speaker with id={speaker_id} is not listed for {subset}"
        )
        ids_to_path[speaker_id].append((idx, path))

    for speaker_id, samples in ids_to_path.items():
        assert samples, f"speaker with id={speaker_id} not present"
        # now generate longer files
        audio_chunks = []
        transcript_parts = []
        for idx, path in samples:
            audio, sr = torchaudio.load(str(path))
            assert sr == 16000
            audio_chunks.append(audio)
            transcript_parts.append(transcriptions[idx])

        # Concatenate along dim 1. This assumes torchaudio.load returns
        # (channels, time) tensors, so dim 1 is time -- confirm for the
        # installed torchaudio version.
        all_audio = torch.cat(audio_chunks, dim=1)

        out_dir = out_fp / str(speaker_id)
        out_dir.mkdir(parents=True, exist_ok=True)
        torchaudio.save(str(out_dir / 'audio.flac'), all_audio, sr)
        # UTF-8 to match LongLibrispeech._parse_transcription_file, which
        # reads with encoding="utf-8".
        with open(out_dir / 'transcript.trans.txt', 'w', encoding='utf-8') as f:
            f.write(" ".join(transcript_parts))
            
# Optionally (re)generate the concatenated per-speaker files under LONG_FP
# from the original dev-clean split. The config comment estimates ~3 minutes.
if create_files:
    # Only paths/transcriptions/durations are used; no audio transform is
    # needed and the download/integrity-check steps are disabled.
    libri = LibriSpeech(root=DATA,
                    subsets=['dev-clean'],
                    audio_transform= None,
                    label_transform=target_transform,
                    download=False,
                    skip_integrity_check=True,
                    max_duration=None,
                   )
    
    create_long_files(LONG_FP, libri.paths, libri.transcriptions, libri.durations)

## Create Dataset with new files

In [None]:

class LongLibrispeech(LibriSpeech):
    """Dataset of concatenated per-speaker LibriSpeech samples.

    Inherits methods from the ``LibriSpeech`` class and overrides where
    necessary. Each leaf directory under ``<root>/<subset>`` must contain
    exactly one ``*.trans.txt`` transcript and exactly one ``*audio.flac``
    file; together they form one sample.

    ``__init__`` does not call ``LibriSpeech.__init__``; only ``root``,
    ``subsets`` and the two transforms are configured here.
    """

    # Samples live directly under ``<root>/<subset>``.
    base_dir = ""
    # NOTE(review): presumably consulted by inherited LibriSpeech methods when
    # loading audio -- confirm in myrtlespeech.data.dataset.librispeech.
    use_sox = False

    def __init__(self, root, subsets, audio_transform, label_transform):
        self.root = root
        self.subsets = subsets
        self._transform = audio_transform
        self._target_transform = label_transform

        self.load_data()

    def load_data(self) -> None:
        """Populate ``paths``, ``durations`` and ``transcriptions`` from disk.

        A transcript and its audio are appended in the same ``os.walk`` step,
        so the three lists stay index-aligned. They are then passed to the
        inherited ``_sort_by_duration`` (presumably a joint sort by duration).

        Raises:
            OSError: re-raised from ``os.walk`` if a directory can't be listed.
            AssertionError: if a directory containing files does not hold
                exactly one transcript file and exactly one audio file.
        """
        self.paths: List[str] = []
        self.durations: List[float] = []
        self.transcriptions: List[str] = []

        def raise_(err):
            """Re-raise errors from os.walk instead of silently skipping."""
            raise err

        for subset in self.subsets:
            subset_path = os.path.join(self.root, self.base_dir, subset)
            for root, dirs, files in os.walk(subset_path, onerror=raise_):
                if not files:
                    continue
                matches = fnmatch.filter(files, "*.trans.txt")
                assert len(matches) == 1, (
                    f"expected 1 transcription file in {root}, "
                    f"found {len(matches)}"
                )
                self._parse_transcription_file(root, matches[0])

                # now get audio
                matches = fnmatch.filter(files, "*audio.flac")
                assert len(matches) == 1, (
                    f"expected 1 audio file in {root}, found {len(matches)}"
                )
                self._process_audio(root, matches[0])
        self._sort_by_duration()

    def _parse_transcription_file(self, root: str, name: str) -> None:
        """Append the whole file ``root/name`` as one stripped, lower-cased transcript."""
        trans_path = os.path.join(root, name)
        with open(trans_path, "r", encoding="utf-8") as trans:
            transcript = trans.read()
        transcript = transcript.strip().lower()
        self.transcriptions.append(transcript)

    def _process_audio(self, root: str, name: str) -> None:
        """Append the path and duration (seconds) of audio file ``root/name``.

        Nothing is ever dropped here and nothing is returned. The duration
        assumes the sox-style ``torchaudio.info`` API, where ``length`` counts
        samples across all channels -- confirm for the installed torchaudio.
        """
        path = os.path.join(root, name)
        si, _ = torchaudio.info(path)
        duration = (si.length / si.channels) / si.rate
        self.paths.append(path)
        self.durations.append(duration)
    


In [None]:
# get config and load seq_to_seq

with open(cfg_location, encoding="utf-8") as f:
    task_config = text_format.Merge(f.read(), task_config_pb2.TaskConfig())

# Restrict the configured train/eval datasets to subset enum value 0
# (dev-clean, per the original note) so that build() is fast.
task_config.train_config.dataset.librispeech.subset[:] = [0] #dev-clean for speed
task_config.eval_config.dataset.librispeech.subset[:] = [0]     #dev-clean for speed
seq_to_seq, epochs, _, _ = build(task_config)

# os.path.join works whether or not log_dir ends with a path separator.
fp = os.path.join(log_dir, f'state_dict_{epoch}.pt')
training_state = load_seq_to_seq(seq_to_seq, fp)

In [None]:
## add lengths to audio transform (and to the label transform)
# NOTE(review): _add_seq_len is a private myrtlespeech helper. It presumably
# wraps a transform so that it also yields len_fn(output) -- confirm in
# myrtlespeech.builders.dataset.
from myrtlespeech.builders.dataset import _add_seq_len
# Audio: the model's own pre-processing; length = size of the last dimension.
audio_transform = _add_seq_len(seq_to_seq.pre_process, len_fn=lambda x: x.size(-1))
# Labels: index tensor from target_transform; length = number of symbols.
label_transform = _add_seq_len(target_transform, len_fn=len)

In [None]:
# create audio loader over the concatenated per-speaker files
long_libri = LongLibrispeech(root=LONG_FP,
                    subsets=['dev-clean'],
                    audio_transform= audio_transform,
                    label_transform=label_transform,)

# batch_size=1: calc_dist asserts exactly one prediction per batch.
eval_loader = torch.utils.data.DataLoader(
        dataset=long_libri,
        batch_size=1,
        num_workers=1,
        collate_fn=seq_to_seq_collate_fn,
        pin_memory=torch.cuda.is_available(),
    )

## Run decoding


In [None]:
# From here on, decode indices with the model's own alphabet rather than the
# one defined for data generation (calc_dist reads this global).
ALPHABET = seq_to_seq.alphabet
# Segments character sequences into words, presumably splitting on " ".
WORD_SEGMENTOR = WordSegmentor(" ")

# NOTE: evaluation uses a manual per-sample loop instead of `fit`, because
# fit's memory usage is too high for these long sequences.
def calc_dist(preds, targets, target_lens):
    """Compute the word-level Levenshtein distance for a single-sample batch.

    Args:
        preds: Length-1 sequence holding the predicted symbol indices.
        targets: Indexable batch of target index sequences; only the first
            ``target_lens[0]`` entries of ``targets[0]`` are used.
        target_lens: Valid length of each target in ``targets``.

    Returns:
        ``(distance, length, pred_words, exp_words)``. ``distance`` is the
        word-level Levenshtein distance and ``length`` is the number of
        reference words. The per-sample WER (``distance / length``) is left
        to the caller, so an empty reference does not raise here.
    """
    assert len(preds) == 1
    pred, target, target_len = preds[0], targets[0], target_lens[0]
    pred_chars = ALPHABET.get_symbols(pred)
    exp_chars = ALPHABET.get_symbols([int(e) for e in target[:target_len]])
    pred_words = WORD_SEGMENTOR(pred_chars)
    exp_words = WORD_SEGMENTOR(exp_chars)
    distance = levenshtein(pred_words, exp_words)
    return distance, len(exp_words), pred_words, exp_words

def eval_sample(batch, decoder):
    """Decode one single-sample batch and score it against its reference.

    Args:
        batch: ``(x, y)`` pair from the data loader. ``y[0]`` holds the
            targets and ``y[1]`` their lengths, as passed to ``calc_dist``.
        decoder: Callable mapping ``x`` to a length-1 sequence of predictions.

    Returns:
        ``(distance, length, pred, exp)``. ``pred`` and ``exp`` are the
        predicted and reference transcripts as space-joined word strings.
    """
    x, y = batch
    preds = decoder(x)
    distance, length, pred_words, exp_words = calc_dist(preds, y[0], y[1])
    return distance, length, " ".join(pred_words), " ".join(exp_words)

    
def set_decoder(stt, dec_type, beam_width=None, max_symbols_per_step=100,
                blank_index=28):
    """Attach a transducer decoder to ``stt`` as its ``post_process`` step.

    Args:
        stt: Object exposing ``.model`` and a settable ``.post_process``.
        dec_type: ``'beam'`` or ``'greedy'``.
        beam_width: Beam width; required when ``dec_type == 'beam'``.
        max_symbols_per_step: Passed through to the decoder.
        blank_index: Index of the blank symbol. The default 28 is the index of
            ``'_'`` in the alphabet ``" abcdefghijklmnopqrstuvwxyz'_"``
            defined in this notebook; pass another value for other alphabets.

    Raises:
        AssertionError: if ``dec_type == 'beam'`` and ``beam_width`` is None.
        ValueError: if ``dec_type`` is not ``'beam'`` or ``'greedy'``.
    """
    if dec_type == 'beam':
        assert beam_width is not None, "beam_width is required for beam decoding"
        decoder = TransducerBeamDecoder(blank_index=blank_index,
                                        beam_width=beam_width,
                                        length_norm=False,
                                        max_symbols_per_step=max_symbols_per_step,
                                        model=stt.model)
    elif dec_type == 'greedy':
        decoder = TransducerGreedyDecoder(blank_index=blank_index,
                                          max_symbols_per_step=max_symbols_per_step,
                                          model=stt.model)
    else:
        raise ValueError(
            f"unknown dec_type {dec_type!r}; expected 'beam' or 'greedy'"
        )
    stt.post_process = decoder


In [None]:
preds = []
exps = []
distances = []
lengths = []
for decoder, beam_width in [('beam', 4)]: # use ('greedy', None) for greedy decoding
    print(decoder, beam_width)
    seq_to_seq.eval()
    set_decoder(seq_to_seq, dec_type=decoder, beam_width=beam_width)

    # Inference only: disable autograd so decoding these long sequences does
    # not keep autograd state around (lower memory use).
    with torch.no_grad():
        for loader in [eval_loader,]:
            for idx, batch in enumerate(loader):
                distance, length, pred, exp = eval_sample(batch, seq_to_seq.post_process)
                distances.append(distance)
                lengths.append(length)
                preds.append(pred)
                exps.append(exp)
                # Guard the divisions explicitly instead of a bare `except:`,
                # which also hid real errors (and KeyboardInterrupt).
                if length:
                    print(f'{idx}, wer = {distance / length}')
                total_length = sum(lengths)
                if total_length:
                    print(f'running WER = {sum(distances) / total_length}')
       

In [None]:
# Inspect one decoded sample. idx indexes the lists filled by the decoding
# loop, so it must be < len(preds).
idx = 16
pred = preds[idx]
exp = exps[idx] 
dist = distances[idx]
len_ = lengths[idx]
print("PRED: ", pred, end='\n\n')
print("EXP : ", exp, end='\n\n')
print("leven : ", dist,)
print("length: ", len_,)
# Raises ZeroDivisionError if the reference for this sample has no words.
print(f"WER for idx={idx} = {dist/len_}")

In [None]:
# Sanity check: re-splitting the space-joined reference gives the same word
# count that was used as this sample's WER denominator.
assert len(exp.split()) == len_