# Audacity WaveformToLabels Example

In this notebook we will load a pretrained [speech-to-text model](https://huggingface.co/facebook/wav2vec2-base-960h) from Facebook using Huggingface's `transformers` package. We will look at the dependencies needed to serialize a model, show how to create a wrapper class for a pretrained WaveformToLabels model, and save this wrapped model so that it can easily be used in Audacity.

## Dependencies

In [10]:
!pip install torchaudio==0.9.0
!pip install transformers
!pip install audacitorch

Collecting audacitorch
  Downloading audacitorch-0.0.1-py3-none-any.whl (10 kB)
Installing collected packages: audacitorch
Successfully installed audacitorch-0.0.1


In [11]:
%%capture
import torch
from torch import nn
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
import torchaudio
import json

# use no grad!
torch.set_grad_enabled(False)

These packages will be needed if you want to upload your model to Huggingface using a CLI. 

In [12]:
# %%capture
# # required for huggingface
# !sudo apt-get install git-lfs
# !git lfs install


## Wrapping the model
We need to create a `.pt` file containing the serialized model and a JSON string with the model's metadata. The metadata tells end users about the model's domain, sample rate, labels, etc.

The `audacitorch` package provides a [`WaveformToLabelsBase` class](https://github.com/hugofloresgarcia/torchaudacity/blob/main/torchaudacity/core.py#L52), which we will use as the base class for our pretrained model's wrapper. It checks that our model receives properly sized input and outputs the tensor shapes expected by Audacity's Deep Learning Analyzer. For a graphical explanation, [visit the main README](https://github.com/hugofloresgarcia/torchaudacity#contributing-models-to-audacity).
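The exact checks the base class runs are not spelled out here. Judging from the `(preds, timestamps)` pair the wrapper returns, a WaveformToLabels model outputs one class index per prediction plus one `(start, stop)` pair, in seconds, per prediction. The sketch below restates that contract in plain Python so it can be tried without loading a model; `check_labels_output` is a hypothetical helper, not part of `audacitorch`, and the `start <= stop` rule is an assumption.

```python
from typing import Sequence, Tuple


def check_labels_output(
    labels: Sequence[int],
    timestamps: Sequence[Tuple[float, float]],
    num_classes: int,
) -> None:
    """Hypothetical restatement of the WaveformToLabels output contract.

    labels:     one class index per prediction, shape (k,)
    timestamps: one (start, stop) pair in seconds per prediction, shape (k, 2)
    """
    if len(labels) != len(timestamps):
        raise ValueError("need exactly one (start, stop) pair per label")
    for idx in labels:
        if not 0 <= idx < num_classes:
            raise ValueError(f"label index {idx} is outside the vocabulary")
    for pair in timestamps:
        if len(pair) != 2:
            raise ValueError("each timestamp must be a (start, stop) pair")
        start, stop = pair
        # assumption: ranges are non-negative and not reversed
        if not 0.0 <= start <= stop:
            raise ValueError(f"invalid time range ({start}, {stop})")


# the placeholder the wrapper emits for empty output satisfies the contract
check_labels_output([0], [(0.0, 0.01)], num_classes=32)
```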



In [13]:
import sys
sys.path.append("..")

In [98]:
from audacitorch.core import WaveformToLabelsBase

class SubModels(nn.Module):
    """Holds the pretrained Wav2Vec2 model and its processor (feature extractor + tokenizer)."""
    def __init__(self):
        super().__init__()
        self._model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h", torchscript=True)
        self._processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h", torchscript=True)
        # map each vocabulary token (a character) to its integer id
        self.token_to_idx = {val: key for key, val in self._processor.tokenizer.decoder.items()}


@torch.jit.script_if_tracing
def get_timestamps(num_preds: int, total_time: float):
    """Split total_time (seconds) into num_preds equal, consecutive (start, stop) ranges.
    If the model produces no output, return an empty tensor to prevent a division by zero."""
    if num_preds == 0:
        return torch.zeros(0, 2)
    else:
        equal_size_timestamp = total_time / num_preds
        starts = torch.arange(num_preds, dtype=torch.float32) * equal_size_timestamp
        return torch.stack([starts, starts + equal_size_timestamp], dim=1)

@torch.jit.script_if_tracing
def check_empty_output(preds, timestamps):
    """If our model produces empty output, return a placeholder label spanning 0.00-0.01 s."""
    if preds.shape[0] == 0:
        return torch.tensor([0]), torch.tensor([[0., 0.01]])
    else:
        return preds, timestamps


class ModelWrapper(WaveformToLabelsBase):
    def do_forward_pass(self, _input):
        # _input is a (1, num_samples) waveform sampled at 16 kHz
        input_values = self.model._processor(_input, return_tensors="pt", padding="longest").input_values[0]
        logits = self.model._model(input_values)[0]
        predicted_ids = torch.argmax(logits, dim=-1)
        transcription = self.model._processor.decode(predicted_ids[0])
        num_preds = len(transcription)

        # convert each character of the transcription to its label index;
        # the vocabulary has no space character, so spaces become the
        # word-delimiter token '|'
        preds = torch.zeros(num_preds, dtype=torch.long)
        for i, token in enumerate(transcription):
            if token == ' ':
                token = '|'
            preds[i] = self.model.token_to_idx[token]

        # this model does not produce timestamps, therefore we use
        # equally sized time ranges for each prediction
        total_time = _input.shape[1] / 16000
        timestamps = get_timestamps(num_preds, total_time)

        # replace empty output with a placeholder, then return
        # the predictions and timestamps as tensors
        preds, timestamps = check_empty_output(preds, timestamps)
        return (preds, timestamps)
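The post-processing in `do_forward_pass` can be followed without loading the model. The plain-Python sketch below takes an already decoded transcription, maps each character to a label index, and splits the clip's duration (`num_samples / 16000` seconds) into equal ranges, falling back to the same placeholder for empty output. `transcription_to_labels`, its `space_token` parameter, and the toy vocabulary are hypothetical; since the vocabulary contains no literal space, spaces have to be mapped to some token, here `|`.

```python
from typing import Dict, List, Tuple


def transcription_to_labels(
    transcription: str,
    token_to_idx: Dict[str, int],
    total_time: float,
    space_token: str = "|",
) -> Tuple[List[int], List[Tuple[float, float]]]:
    """Hypothetical plain-Python version of the wrapper's post-processing."""
    if not transcription:
        # placeholder: label 0 spanning 0.00-0.01 s
        return [0], [(0.0, 0.01)]
    # the vocabulary has no literal space, so spaces map to space_token
    labels = [token_to_idx[space_token if ch == " " else ch] for ch in transcription]
    # one equal share of the clip's duration per predicted character
    step = total_time / len(labels)
    timestamps = [(i * step, (i + 1) * step) for i in range(len(labels))]
    return labels, timestamps


toy_vocab = {"<pad>": 0, "<s>": 1, "</s>": 2, "<unk>": 3, "|": 4, "H": 5, "I": 6}
# a 2-second clip: 32000 samples / 16000 Hz
print(transcription_to_labels("HI I", toy_vocab, total_time=32000 / 16000))
# ([5, 6, 4, 6], [(0.0, 0.5), (0.5, 1.0), (1.0, 1.5), (1.5, 2.0)])
```

Equal-width ranges are only a stand-in: they spread characters evenly over the clip rather than aligning them with when each character is actually spoken.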

In [99]:
sub_models = SubModels()
torchscript_model = ModelWrapper(sub_models)

## Model Metadata

We need to create a `metadata.json` file for our model. This file will be added to the Huggingface repo and gives Audacity important information about the model, so users can read about it directly from Audacity. See the [contributing documentation](https://github.com/hugofloresgarcia/torchaudacity) for the full metadata schema.

In [100]:
vocab = [str(letter) for letter in sub_models._processor.tokenizer.decoder.values()]
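`vocab` relies on `decoder.values()` coming out in token-id order. Because the wrapper returns token ids, the `labels` list presumably needs position *i* to hold the token with id *i*. A slightly more defensive sketch sorts by id explicitly; `labels_from_decoder` is a hypothetical helper and assumes `decoder` maps integer ids to token strings, as `tokenizer.decoder` does above.

```python
from typing import Dict, List


def labels_from_decoder(decoder: Dict[int, str]) -> List[str]:
    """Hypothetical helper: list whose position i holds the token with id i."""
    ids = sorted(decoder)
    # a gap in the ids would misalign every later label with its position
    if ids != list(range(len(ids))):
        raise ValueError("token ids must be contiguous and start at 0")
    return [str(decoder[i]) for i in ids]


print(labels_from_decoder({2: "E", 0: "<pad>", 1: "<s>"}))  # ['<pad>', '<s>', 'E']
```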

In [101]:
# create a dictionary with model metadata
metadata = {
    'sample_rate': 16000, 
    'domain_tags': ['speech'],
    'short_description': 'I will label your speech into text :]',
    'long_description':
              'This is an Audacity wrapper for the model '
              'forked from the repository '
              'facebook/wav2vec2-base-960h. '
              'This model was trained by Alexei Baevski, '
              'Henry Zhou, Abdelrahman Mohamed and '
              'Michael Auli.',
    'tags': ['speech-to-text'],
    'effect_type': 'waveform-to-labels',
    'multichannel': False,
    'labels': vocab,
}
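`save_model` takes care of writing the metadata, but a quick check before saving can catch a missing key or an empty label list. The sketch below uses the `json` module imported earlier; `metadata_to_json` is a hypothetical helper, and `NOTEBOOK_METADATA_KEYS` only lists the keys used in this notebook's dictionary, not the full schema from the contributing documentation.

```python
import json
from typing import Any, Dict

# keys used by the metadata dictionary in this notebook (not the full schema)
NOTEBOOK_METADATA_KEYS = {
    "sample_rate", "domain_tags", "short_description", "long_description",
    "tags", "effect_type", "multichannel", "labels",
}


def metadata_to_json(metadata: Dict[str, Any]) -> str:
    """Hypothetical pre-save check: every expected key present, then serialize."""
    missing = NOTEBOOK_METADATA_KEYS - set(metadata)
    if missing:
        raise KeyError(f"metadata is missing: {sorted(missing)}")
    # a labeling model needs names for the indices it predicts
    if metadata["effect_type"] == "waveform-to-labels" and not metadata["labels"]:
        raise ValueError("a waveform-to-labels model needs a non-empty label list")
    return json.dumps(metadata, indent=2)
```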

## Saving Our Model & Metadata

We will now save the wrapped model locally by tracing it with `torch.jit.trace`, which records the model's tensor operations on an example input and returns a `ScriptModule`. We can then use `audacitorch`'s utility function `save_model` to save the model and metadata together.
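`torch.jit.trace` records only the tensor operations executed for the example input, so a plain Python `if` is evaluated once and frozen into the graph. This is presumably why the helper functions in the wrapper cell are decorated with `torch.jit.script_if_tracing`, which compiles them with TorchScript when they are called during tracing so that both branches survive. The minimal sketch below uses a hypothetical `Demo` module: it traces with a non-empty example, saves with `torch.jit.save`, reloads with `torch.jit.load`, and feeds an empty input. It only illustrates tracing; it is not how `save_model` works internally.

```python
import tempfile
from pathlib import Path

import torch
from torch import nn


def placeholder_if_empty_python(x: torch.Tensor) -> torch.Tensor:
    # plain Python `if`: tracing keeps only the branch taken for the example
    if x.shape[0] == 0:
        return torch.zeros(1)
    return x * 2


@torch.jit.script_if_tracing
def placeholder_if_empty_scripted(x: torch.Tensor) -> torch.Tensor:
    # compiled with TorchScript when called during tracing, so both branches are kept
    if x.shape[0] == 0:
        return torch.zeros(1)
    return x * 2


class Demo(nn.Module):
    """Hypothetical module that calls one of the two helpers above."""

    def __init__(self, scripted: bool):
        super().__init__()
        self.scripted = scripted

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.scripted:
            return placeholder_if_empty_scripted(x)
        return placeholder_if_empty_python(x)


def trace_save_load(module: nn.Module, example: torch.Tensor) -> torch.jit.ScriptModule:
    """Trace with a non-empty example input, save to disk, and load it back."""
    traced = torch.jit.trace(module, example)
    with tempfile.TemporaryDirectory() as tmp:
        path = Path(tmp) / "model.pt"
        torch.jit.save(traced, str(path))
        return torch.jit.load(str(path))


empty = torch.ones(0)
frozen = trace_save_load(Demo(scripted=False), torch.ones(3))
kept = trace_save_load(Demo(scripted=True), torch.ones(3))
print(frozen(empty).shape, kept(empty).shape)
```

Given an empty input, the first reloaded model should still run `x * 2` and return an empty tensor, while the second should take the placeholder branch; expect a `TracerWarning` while tracing the first one.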

In [102]:
import os
from pathlib import Path
from audacitorch.utils import save_model, get_example_inputs
import torchaudio 

In [103]:
# compiling and saving model
example_inputs = get_example_inputs()
traced_model = torch.jit.trace(torchscript_model, example_inputs[0])



In [104]:
save_model(traced_model, metadata, Path('audacity-Wav2Vec2-Base'))


## Upload your model
Now you're ready to upload your model. In this notebook, the model is saved in a folder named `audacity-Wav2Vec2-Base`. For more information, see [the main README](https://github.com/hugofloresgarcia/torchaudacity#exporting-to-huggingface).

--- 


## Note on Huggingface `transformers` module

Currently, the Huggingface `transformers` package has limited support for exporting models to TorchScript. Through trial and error, we have found that the [`Wav2Vec2`](https://huggingface.co/transformers/model_doc/wav2vec2.html) models export with few issues, while the [`Speech2Text`](https://huggingface.co/transformers/model_doc/speech_to_text.html) models appear to have problems when exported to TorchScript.

For more information about TorchScript compatibility in Huggingface `transformers`, see [the `transformers` TorchScript documentation](https://huggingface.co/transformers/torchscript.html).