# ASR Icefall Examples

This notebook demonstrates how to use Icefall models with the speech recognition module of ART (the Adversarial Robustness Toolbox).

Icefall is the k2-fsa project's collection of speech recognition recipes built on k2. Repository: https://github.com/k2-fsa/icefall

---


## 1. Preliminaries and Dependencies

Icefall is not a Python module but a collection of scripts and recipes that can be run or imported to use the underlying toolkit, k2, for speech and speech-related tasks. Installation instructions for icefall and its dependencies can be found [here](https://icefall.readthedocs.io/en/latest/installation/index.html).
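Before running the cells below, it can help to confirm that the main Python packages icefall relies on are importable. This is a minimal sketch using only the standard library; the package list (k2, lhotse, torch, torchaudio) reflects icefall's installation guide, and the helper name `missing_modules` is hypothetical.

```python
import importlib.util


def missing_modules(names):
    """Return the module names that cannot be found on the current sys.path."""
    return [name for name in names if importlib.util.find_spec(name) is None]


# Core packages used by icefall recipes (assumed list; extend as needed).
missing = missing_modules(["k2", "lhotse", "torch", "torchaudio"])
if missing:
    print("Missing dependencies:", ", ".join(missing))
else:
    print("All core dependencies are importable.")
```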

Throughout this notebook, we assume that the top-level icefall repository is located at the following path (replace it with your own):

`PATH/TO/ICEFALL/icefall`
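Because icefall is used from a source checkout rather than an installed package, its directories have to be added to `sys.path` before importing. A minimal sketch, assuming the standard repository layout in which the `icefall` package sits at the repository root and the SLU recipe lives under `egs/slu`; the helper names `icefall_search_paths` and `add_to_sys_path` are hypothetical.

```python
import sys
from pathlib import Path


def icefall_search_paths(root):
    """Directories to add to sys.path, assuming the standard icefall layout."""
    root = Path(root)
    # root: makes the `icefall` package importable (icefall.checkpoint, icefall.utils)
    # root/egs/slu: makes the SLU recipe's `transducer` package importable
    return [root, root / "egs" / "slu"]


def add_to_sys_path(paths):
    """Insert each existing directory near the front of sys.path, once."""
    for path in paths:
        if not Path(path).is_dir():
            raise FileNotFoundError(f"Expected a directory at {path}")
        if str(path) not in sys.path:
            sys.path.insert(1, str(path))


# Usage, with the placeholder replaced by your own checkout:
# add_to_sys_path(icefall_search_paths("PATH/TO/ICEFALL/icefall"))
```

Failing early with `FileNotFoundError` makes a wrong path obvious, instead of surfacing later as a confusing `ModuleNotFoundError`.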

## 2. SLU Inference with Icefall Transducer Models

### 2.1 Download Data

Dataset Link: https://fluent.ai/fluent-speech-commands-a-dataset-for-spoken-language-understanding-research/

The Fluent Speech Commands dataset consists of short (1 to 5 second) spoken commands of the kind typically directed at a smart device. The task is spoken language understanding: rather than producing a verbatim transcript, the model maps each command to three slots (frames): "action", "target" and "location". For example, the command "turn the lights on in the kitchen" maps to "action: activate", "target: lights", "location: kitchen".

In [3]:
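from art.config import ART_DATA_PATH
%cd $ART_DATA_PATH
# Quote the URL: an unquoted `&` is treated by the shell as a command separator.
# `dl=1` asks Dropbox to serve the file itself rather than a preview page.
!wget -O fluent_speech_commands_dataset.tar.gz "https://www.dropbox.com/scl/fi/3abks6crfha9flkxjn2yu/fluent_speech_commands_dataset_small.tar.gz?rlkey=hg1br2lqska8pi1dqarlfmaxu&dl=1"
!tar -xvzf fluent_speech_commands_dataset.tar.gz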

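If the `tar` command is not available (for example on Windows), the archive can be unpacked from Python with the standard-library `tarfile` module. This is a minimal sketch, assuming the archive has already been downloaded; the path check and the `filter="data"` option are extra safeguards that the `tar -xvzf` call above does not perform, and `extract_tar_gz` is a hypothetical helper name.

```python
import tarfile
from pathlib import Path


def extract_tar_gz(archive, dest="."):
    """Extract a .tar.gz archive into `dest` and return the member names."""
    dest = Path(dest).resolve()
    with tarfile.open(archive, "r:gz") as tar:
        # Refuse members that would land outside `dest` (e.g. "../x" or absolute paths).
        for member in tar.getmembers():
            target = (dest / member.name).resolve()
            if target != dest and dest not in target.parents:
                raise ValueError(f"Unsafe path in archive: {member.name}")
        if hasattr(tarfile, "data_filter"):
            # Newer Pythons: also refuse device files and links pointing outside dest.
            tar.extractall(dest, filter="data")
        else:
            tar.extractall(dest)
        return tar.getnames()


# Equivalent of `tar -xzf fluent_speech_commands_dataset.tar.gz` in the current directory:
# extract_tar_gz("fluent_speech_commands_dataset.tar.gz")
```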

### 2.2 Create Model and Data Utilities

We first download a trained model checkpoint (`epoch-6.pt`), the auxiliary files the model needs (`frames.tar.gz`), and an audio clip saved as `trigger.wav`.

In [4]:
%cd $ART_DATA_PATH
# URLs are quoted so that the shell does not split them at `&`
!wget -O frames.tar.gz "https://www.dropbox.com/scl/fi/4tvkvvv4w2zoeei238sfj/frames.tar.gz?rlkey=5ubi7j9xokz57xqwtb0o6y2oe&dl=1"
!tar -xvzf frames.tar.gz
!wget -O epoch-6.pt "https://www.dropbox.com/scl/fi/97wvdjmbuyj13kpzhricc/epoch-6.pt?rlkey=7mehc4v41fovfz0ksbt98krry&dl=1"
!wget -O trigger.wav "https://www.dropbox.com/scl/fi/mldc9sv2vc5xdc141cs8o/clapping.wav?rlkey=ohtu4m9f73wfudznh15h3vn1l&dl=1"

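Dropbox share links contain `&`, which the shell treats as a command separator unless the URL is quoted, and Dropbox documents `dl=1` as the query parameter that forces a download. The sketch below builds a shell-safe `wget` command from a share link using only the standard library; the helper names and the example URL (`abc`, `xyz`) are hypothetical placeholders.

```python
import shlex
from urllib.parse import parse_qsl, urlencode, urlsplit, urlunsplit


def dropbox_direct_link(url):
    """Return `url` with Dropbox's `dl` query parameter set to 1 (force download)."""
    parts = urlsplit(url)
    query = [(k, v) for k, v in parse_qsl(parts.query, keep_blank_values=True) if k != "dl"]
    query.append(("dl", "1"))
    return urlunsplit(parts._replace(query=urlencode(query)))


def wget_command(url, output):
    """Build a wget command line; shlex.quote keeps `&` inside the URL argument."""
    return f"wget -O {shlex.quote(output)} {shlex.quote(dropbox_direct_link(url))}"


print(wget_command("https://www.dropbox.com/scl/fi/abc/epoch-6.pt?rlkey=xyz&dl=0", "epoch-6.pt"))
# wget -O epoch-6.pt 'https://www.dropbox.com/scl/fi/abc/epoch-6.pt?rlkey=xyz&dl=1'
```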

In [19]:
# Import dependencies
import logging
import sys
from pathlib import Path
from typing import List, Union

import numpy as np
import torch
import k2

# Make the icefall package (repository root) and the SLU recipe (egs/slu) importable
ICEFALL_ROOT = 'PATH/TO/ICEFALL/icefall'
sys.path.insert(1, ICEFALL_ROOT + '/egs/slu')
sys.path.insert(1, ICEFALL_ROOT)

from art import config
from art.config import ART_NUMPY_DTYPE
from art.estimators.speech_recognition.pytorch_icefall import PyTorchIcefall

from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.utils import AttributeDict

from transducer.conformer import Conformer
from transducer.decoder import Decoder
from transducer.joiner import Joiner
from transducer.model import Transducer
from transducer.decode import get_id2word
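The recipe's `Transducer` combines a `Conformer` encoder (over acoustic frames), a `Decoder` (the prediction network over previously emitted tokens) and a `Joiner`, which scores every pair of encoder frame and decoder state. The NumPy sketch below shows a generic additive joiner with a tanh nonlinearity to illustrate the shapes involved; icefall's `Joiner` may differ in its activation and projection layers, and `joint_logits` is a hypothetical name.

```python
import numpy as np


def joint_logits(encoder_out, decoder_out, weight, bias):
    """Generic additive transducer joiner.

    encoder_out: (N, T, H), decoder_out: (N, U, H), weight: (H, V), bias: (V,)
    Returns logits of shape (N, T, U, V): one distribution per (frame, token-history) pair.
    """
    # Broadcast (N, T, 1, H) + (N, 1, U, H) -> (N, T, U, H)
    hidden = np.tanh(encoder_out[:, :, None, :] + decoder_out[:, None, :, :])
    # Project the hidden dimension onto the vocabulary
    return hidden @ weight + bias


rng = np.random.default_rng(0)
N, T, U, H, V = 2, 5, 3, 16, 27  # H and V mirror hidden_dim and vocab_size in get_params()
logits = joint_logits(
    rng.standard_normal((N, T, H)),
    rng.standard_normal((N, U, H)),
    rng.standard_normal((H, V)),
    np.zeros(V),
)
print(logits.shape)  # (2, 5, 3, 27)
```

This is also why `Joiner` is constructed with `input_dim=params.hidden_dim` and `output_dim=params.vocab_size`: the encoder and decoder outputs must share the hidden size, and the output covers the vocabulary, including the blank symbol.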

Next, we define the transducer model, load the pre-trained checkpoint, and wrap the result in ART's `PyTorchIcefall` estimator:

In [21]:
# Define model

def get_transducer_model(params: AttributeDict):
    encoder = Conformer(
        num_features=params.feature_dim,
        output_dim=params.hidden_dim,
    )
    decoder = Decoder(
        vocab_size=params.vocab_size,
        embedding_dim=params.embedding_dim,
        blank_id=params.blank_id,
        num_layers=params.num_decoder_layers,
        hidden_dim=params.hidden_dim,
        embedding_dropout=0.4,
        rnn_dropout=0.4,
    )
    joiner = Joiner(input_dim=params.hidden_dim, output_dim=params.vocab_size)
    transducer = Transducer(encoder=encoder, decoder=decoder, joiner=joiner)

    return transducer


def get_params() -> AttributeDict:
    """Return a dict containing the model and training parameters.

    All training-related parameters that are not passed from the command line
    are saved in the variable `params`.

    Command-line options are merged into `params` after they are parsed, so
    you can also access them via `params`.

    Explanation of options saved in `params`:

        - lr: The initial learning rate.

        - feature_dim: The model input dimension. It has to match the one used
                       when computing features.

        - weight_decay: The weight decay for the optimizer.

        - start_epoch: If it is not zero, load checkpoint `start_epoch-1`
                       and continue training from that checkpoint.

        - best_train_loss: Best training loss so far. It is used to select
                           the model that has the lowest training loss. It is
                           updated during training.

        - best_valid_loss: Best validation loss so far. It is used to select
                           the model that has the lowest validation loss. It is
                           updated during training.

        - best_train_epoch: The epoch that has the best training loss.

        - best_valid_epoch: The epoch that has the best validation loss.

        - batch_idx_train: Used for writing statistics to TensorBoard. It
                           contains the number of batches trained so far across
                           epochs.

        - log_interval: Print the training loss if `batch_idx % log_interval` is 0.

        - valid_interval: Run validation if `batch_idx % valid_interval` is 0.

        - reset_interval: Reset statistics if `batch_idx % reset_interval` is 0.

        - exp_dir, lang_dir: Directories holding the checkpoint files and the
                             lexicon, respectively.

        - vocab_size, blank_id, embedding_dim, hidden_dim, num_decoder_layers:
          Encoder/decoder hyperparameters used by `get_transducer_model`.

        - epoch, avg: Load `epoch-{epoch}.pt`, or average the last `avg`
                      epoch checkpoints ending at `epoch`.
    """
    params = AttributeDict(
        {
            "lr": 1e-3,
            "feature_dim": 23,
            "weight_decay": 1e-6,
            "start_epoch": 0,
            "best_train_loss": float("inf"),
            "best_valid_loss": float("inf"),
            "best_train_epoch": -1,
            "best_valid_epoch": -1,
            "batch_idx_train": 0,
            "log_interval": 100,
            "reset_interval": 20,
            "valid_interval": 300,
            "exp_dir": Path(config.ART_DATA_PATH),
            "lang_dir": Path(config.ART_DATA_PATH),
            # encoder/decoder params
            "vocab_size": 27,
            "blank_id": 0,
            "embedding_dim": 32,
            "hidden_dim": 16,
            "num_decoder_layers": 4,
            "epoch": 6,
            "avg": 1
        }
    )
    return params


def get_word2id(params):
    word2id = {}

    # 0 is blank
    id = 1
    with open(Path(params.lang_dir) / 'lexicon_disambig.txt') as lexicon_file:
        for line in lexicon_file:
            if len(line.strip()) > 0:
                word2id[line.split()[0]] = id
                id += 1

    return word2id


def get_labels(texts: List[str], word2id) -> k2.RaggedTensor:
    """
    Args:
      texts:
        A list of transcripts.
    Returns:
      Return a ragged tensor containing the corresponding word ID.
    """
    # blank is 0
    word_ids = []
    for t in texts:
        words = t.split()
        ids = [word2id[w] for w in words]
        word_ids.append(ids)

    return k2.RaggedTensor(word_ids)


params = get_params()
transducer_model = get_transducer_model(params).to('cpu')
if params.avg == 1:
    load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", transducer_model)
else:
    start = params.epoch - params.avg + 1
    filenames = []
    for i in range(start, params.epoch + 1):
        if start >= 0:
            filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
    logging.info(f"averaging {filenames}")
    transducer_model.load_state_dict(average_checkpoints(filenames))
transducer_model.device = 'cpu'
word2ids = get_word2id(params)
# Store the result under a new name so the imported get_id2word function is not overwritten
id2word = get_id2word(params)
model_ensemble = {
    'model': transducer_model,
    'word2ids': word2ids,
    'get_id2word': id2word,
    'params': params
}

speech_recognizer = PyTorchIcefall(model_ensemble=model_ensemble)
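With the notebook's settings (`epoch=6`, `avg=1`), only `epoch-6.pt` is loaded. When `avg > 1`, the cell above instead averages the last `avg` epoch checkpoints with icefall's `average_checkpoints`. Conceptually this is an element-wise mean of the model parameters; the plain-Python sketch below illustrates the idea and is not icefall's implementation. One deliberate difference: it clamps the first epoch at 0, whereas the loop above adds no files at all when `start` is negative. Both helper names are hypothetical.

```python
def checkpoint_filenames(exp_dir, epoch, avg):
    """Checkpoint paths for the `avg` epochs ending at `epoch` (negative epochs skipped)."""
    start = max(epoch - avg + 1, 0)
    return [f"{exp_dir}/epoch-{i}.pt" for i in range(start, epoch + 1)]


def average_state_dicts(state_dicts):
    """Element-wise mean of parameter dicts that share the same keys.

    Values only need to support `+` and `/`, so floats, NumPy arrays and
    torch tensors all work.
    """
    n = len(state_dicts)
    return {key: sum(sd[key] for sd in state_dicts) / n for key in state_dicts[0]}


print(checkpoint_filenames("exp", epoch=6, avg=3))
# ['exp/epoch-4.pt', 'exp/epoch-5.pt', 'exp/epoch-6.pt']
print(average_state_dicts([{"w": 1.0, "b": 2.0}, {"w": 3.0, "b": 4.0}]))
# {'w': 2.0, 'b': 3.0}
```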

Finally, with the data ready and the model defined and loaded, we are ready to perform inference:

In [22]:
import torchaudio

wav = torchaudio.load(config.ART_DATA_PATH + '/fluent_speech_commands_dataset_small/wavs/speakers/2BqVo8kVB2Skwgyb/ffc2c6b0-4478-11e9-a9a5-5dbec3b8816a.wav')[0]

wav = np.expand_dims(np.array(wav), 0)
speech_recognizer.predict(wav)

array(['<s>', 'deactivate', 'lights', 'kitchen', '</s>'], dtype='<U10')
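The prediction is a sequence of slot values wrapped in `<s>` and `</s>` markers. Assuming the model emits exactly one word per slot in the order action, target, location (as in the output above and the dataset description), the tokens can be turned back into frames with a small helper. `tokens_to_frames` is a hypothetical name, not part of ART or icefall.

```python
def tokens_to_frames(tokens, slots=("action", "target", "location")):
    """Drop the <s>/</s> markers and pair the remaining words with slot names."""
    # str() also handles NumPy string elements such as those in the predict() output
    words = [str(t) for t in tokens if str(t) not in ("<s>", "</s>")]
    if len(words) != len(slots):
        raise ValueError(f"expected {len(slots)} slot values, got {words}")
    return dict(zip(slots, words))


print(tokens_to_frames(["<s>", "deactivate", "lights", "kitchen", "</s>"]))
# {'action': 'deactivate', 'target': 'lights', 'location': 'kitchen'}
```

For a single-utterance result like the one above, the array returned by `speech_recognizer.predict` can be passed in directly.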