# How to use pyctcdecode when working with a SpeechBrain model

SpeechBrain will release its own n-gram scorer this fall, alongside its existing TransformerLM scorer. This notebook shows how to use SpeechBrain models with pyctcdecode and its features. SpeechBrain provides many model implementations; this notebook walks through the following ones:
* EncoderASR (shared by @Moumeneb1 at [link](https://github.com/speechbrain/speechbrain/issues/1467#issuecomment-1184340981))
* EncoderDecoderASR 

In [None]:
# Install SpeechBrain
! pip install speechbrain

In [None]:
# get a single audio file
!wget https://dldata-public.s3.us-east-2.amazonaws.com/1919-142785-0028.wav

In [None]:
import fastwer
import torch, os

from pyctcdecode import build_ctcdecoder
from speechbrain.pretrained import EncoderDecoderASR

## EncoderDecoderASR

In [None]:
# Either use an existing checkpoint directory or download from 
# https://drive.google.com/drive/folders/15pjw9NnGIyAP1P_z1Vn8xtNYLQH32m3V (SpeechBrain author uploads)

def load_model(
    source="/root/workspace/results/CRDNN_BPE_960h_LM/2602/save",
    savedir="pretrained_model",
):
    """Load a SpeechBrain ``EncoderDecoderASR`` model.

    Args:
        source: Checkpoint location passed as ``source`` to
            ``EncoderDecoderASR.from_hparams``. Defaults to the local CRDNN
            BPE checkpoint directory this notebook was written against.
        savedir: Directory passed as ``savedir`` to ``from_hparams``.

    Returns:
        The model returned by ``EncoderDecoderASR.from_hparams``.
    """
    return EncoderDecoderASR.from_hparams(source=source, savedir=savedir)

In [None]:
def reset_mem(model, batch_size, device):
    """Reset the decoder's attention and build an initial decoder memory.

    Must be called before stepping the decoder with ``forward_step`` so that
    no attention state leaks in from a previous utterance.

    Args:
        model: SpeechBrain ASR model exposing ``mods.decoder.dec``.
        batch_size: Number of utterances being decoded.
        device: Device on which to allocate the context tensor.

    Returns:
        A ``(hs, c)`` tuple where ``hs`` is ``None`` and ``c`` is a zero
        tensor of shape ``[batch_size, attn_dim]``.
    """
    rnn_dec = model.mods.decoder.dec
    rnn_dec.attn.reset()
    context = torch.zeros(batch_size,  rnn_dec.attn_dim, device=device)
    # NOTE(review): hs=None presumably lets the RNN build its own initial hidden state — confirm.
    return None, context

In [None]:
@torch.no_grad()
def main(sound_file, arpa_path, model=None, beam_width=50):
    """Transcribe ``sound_file`` by beam-searching decoder log-probs with an n-gram LM.

    The attentional decoder is stepped greedily (feeding back its own argmax)
    to collect per-step log-probabilities. These are then decoded by a
    pyctcdecode beam search that uses the ARPA language model at ``arpa_path``.

    Args:
        sound_file: Path to the audio file to transcribe.
        arpa_path: Path to the ARPA n-gram language model.
        model: Optional already-loaded ``EncoderDecoderASR``. When ``None``
            (the default), ``load_model()`` is called, as before.
        beam_width: Beam width for pyctcdecode's search.

    Returns:
        The decoded transcript string.
    """
    if model is None:
        model = load_model()

    log_probs = _greedy_decoder_log_probs(model, sound_file)
    labels = _ctc_labels(model)

    decoder = build_ctcdecoder(labels, arpa_path)
    # .numpy() only works on CPU tensors; move the log-probs off the GPU first.
    return decoder.decode(log_probs.cpu().numpy(), beam_width=beam_width)


def _greedy_decoder_log_probs(model, sound_file):
    """Encode ``sound_file`` and greedily step the decoder, returning stacked log-probs."""
    device = model.device

    wav = model.load_audio(sound_file)
    # wav is a 1-d tensor, e.g., [52173]; add a batch dim -> [1, 52173].
    wavs = wav.unsqueeze(0).to(device)
    # Relative length 1.0: the single utterance spans the whole tensor.
    wav_lens = torch.tensor([1.0], device=device)

    encoder_out = model.mods.encoder(wavs, wav_lens)
    enc_lens = torch.round(encoder_out.shape[1] * wav_lens).int()
    batch_size = encoder_out.shape[0]

    memory = reset_mem(model, batch_size, device=encoder_out.device)

    seq_decoder = model.mods.decoder
    # Start decoding from the BOS token.
    inp_tokens = encoder_out.new_zeros(batch_size).fill_(seq_decoder.bos_index).long()

    max_decode_steps = int(encoder_out.shape[1] * seq_decoder.max_decode_ratio)
    step_log_probs = []
    for _ in range(max_decode_steps):
        log_probs, memory, _ = seq_decoder.forward_step(
            inp_tokens, memory, encoder_out, enc_lens)
        step_log_probs.append(log_probs)
        # Greedy: feed the most likely token back in as the next input.
        inp_tokens = log_probs.argmax(dim=-1)

    # Stack along a new step axis, then drop the batch axis (batch size 1 here).
    # Presumably yields [max_decode_steps, vocab] — TODO confirm forward_step's output shape.
    return torch.stack(step_log_probs, dim=1).squeeze(0)


def _ctc_labels(model):
    """Build the pyctcdecode label list from the model tokenizer's pieces."""
    tokenizer = model.tokenizer
    labels = [tokenizer.id_to_piece(idx).upper() for idx in range(tokenizer.get_piece_size())]

    # NOTE(review): ids 0 and 1 are assumed to be the pad/blank and bos/eos ids
    # of this checkpoint — confirm against its hyperparams.yaml.
    labels[0] = '<pad>'
    labels[1] = ''
    # Add more customizations here based on bos and eos index from the hyperparameter.yaml file
    return labels


In [None]:
if __name__ == '__main__':
    file_path = "1919-142785-0028.wav"
    arpa_path = "/root/workspace/lms/3-gram.pruned.3e-7.arpa"
    # Reference transcript of ``file_path``; fill it in to get a meaningful WER.
    reference = ""

    hypothesis = main(file_path, arpa_path)
    print(hypothesis)

    if reference:
        # fastwer.score takes (hypotheses, references) -- hypotheses come first.
        print(fastwer.score([hypothesis], [reference]))
    else:
        print("No reference transcript set; skipping WER computation.")
