In [None]:
"""
You can run this notebook either locally (if you have all the dependencies and a GPU) or on Google Colab.

Instructions for setting up Colab are as follows:
1. Open a new Python 3 notebook.
2. Import this notebook from GitHub (File -> Upload Notebook -> "GitHub" tab -> copy/paste GitHub URL)
3. Connect to an instance with a GPU (Runtime -> Change runtime type -> select "GPU" for hardware accelerator)
4. Run this cell to set up dependencies.
5. Restart the runtime (Runtime -> Restart Runtime) for any upgraded packages to take effect


NOTE: The user is responsible for checking the content of datasets and the applicable licenses, and for determining whether they are suitable for the intended use.
"""
# If you're using Google Colab and not running locally, run this cell.
import os

# Install dependencies
!pip install wget
!apt-get install sox libsndfile1 ffmpeg
!pip install text-unidecode
# Quote the version specifier so the shell does not treat '>' as output redirection
!pip install "matplotlib>=3.3.2"

## Install NeMo
BRANCH = 'main'
!python -m pip install "nemo_toolkit[asr] @ git+https://github.com/NVIDIA/NeMo.git@$BRANCH"

# Multi Task Adaptation with Adapters


In earlier tutorials, we used a dedicated model for a single task: for example, an ASR model (CTC, RNN-T, etc.) for the single task of speech recognition. This works well if we want to specialize one model per task, but it can be expensive to deploy a fleet of models, one per task, and to learn routers that send each user request to the correct model.

We now support Multi Task models in NeMo: a single model can perform multiple tasks such as speech recognition, speech translation, voice activity detection, and more in the future. With one model supporting multiple tasks, we can simplify deployment and also hope that the individual tasks improve each other (for example, you need strong speech recognition before you can do good translation).

---

Multi Task (Canary) models are large, highly capable neural networks that can perform speech recognition, X-to-English and English-to-X translation, and can choose whether to transcribe speech with punctuation and capitalization. These large models are trained on several thousand hours of speech and text data, which makes fully fine-tuning them on a small new dataset expensive and challenging.

In the previous tutorial for [ASR Adapters](https://github.com/NVIDIA/NeMo/blob/main/tutorials/asr/asr_adapters/ASR_with_Adapters.ipynb), we used small adapter modules to tune a large ASR model on a small amount of data. In this tutorial, we will adapt an [NVIDIA Canary](https://huggingface.co/nvidia/canary-1b) model to a small amount of speech data for both Automatic Speech Recognition (ASR) and Automatic Speech Translation (AST).

In this tutorial, we will also demonstrate a simple way to create custom PyTorch Lightning Data Modules, with custom datasets and data loaders, for the highly flexible Multi Task Models in NeMo ASR. This gives users more flexibility in designing new tasks and in fine-tuning the models on small amounts of data.

----

First, let's instantiate the [Canary](https://huggingface.co/nvidia/canary-1b) model:

In [None]:
import os
import json

import nemo.collections.asr as nemo_asr

In [None]:
model = nemo_asr.models.ASRModel.from_pretrained("nvidia/canary-1b")

# Enable Adapter Support in Model

New in NeMo 2.0 is a simple utility function, `replace_adapter_compatible_modules()`, that converts the model into one that supports adapters.

This utility walks through the full model, checks whether each module supports adapters, and enables that ability where it does. Once it has run, you can freely use the adapter methods.

In [None]:
model.replace_adapter_compatible_modules()

## Check Which Targets Are Supported For This Model

Now that the model has enabled adapter support, let's take a look at which of its modules support attaching adapter modules.

**Note**
Below, you might see an adapter module with an empty name `''`; this corresponds to the "default" target that is used when no target is specified. Users can simply omit the module name when adding an adapter, and the model will add the adapter to the encoder module by default.

In [None]:
model.adapter_module_names

## Prepare the Adapter

Now that we know which modules are supported, let's create a simple adapter module for the encoder and decoder modules.

In [None]:
from nemo.collections.common.parts import LinearAdapterConfig

In [None]:
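# Adapter input width. The decoder adapter below reuses this value, which assumes the
# transformer decoder's hidden size equals the encoder's d_model (as is the case for canary-1b).
input_dim = model.cfg.encoder.d_model
adapter_dim = 8  # bottleneck size of each adapter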

In [None]:
enc_adapter_cfg = LinearAdapterConfig(in_features=input_dim, dim=adapter_dim)
dec_adapter_cfg = LinearAdapterConfig(in_features=input_dim, dim=adapter_dim)
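
`LinearAdapterConfig(in_features=..., dim=...)` describes a small bottleneck module. As a rough mental model, the pure-Python sketch below assumes the adapter applies layer norm, a down-projection to `dim`, a swish activation, an up-projection back to `in_features`, and a residual connection, with the up-projection starting at zero so that a fresh adapter leaves its input unchanged. These layout and initialization choices are assumptions for illustration, and all helper names are hypothetical; the NeMo module may differ in its details.

```python
import math


def layer_norm(x, eps=1e-5):
    """Normalize a vector to zero mean and unit variance (no learned scale/shift)."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def matvec(weight, x):
    """Multiply a (rows x cols) weight matrix, stored as nested lists, by a vector."""
    return [sum(w * v for w, v in zip(row, x)) for row in weight]


def swish(v):
    """Swish / SiLU activation: v * sigmoid(v)."""
    return v / (1.0 + math.exp(-v))


def linear_adapter(x, w_down, w_up):
    """Hypothetical bottleneck adapter: x + W_up(swish(W_down(LN(x))))."""
    hidden = [swish(h) for h in matvec(w_down, layer_norm(x))]  # length: dim
    delta = matvec(w_up, hidden)  # length: in_features
    return [xi + di for xi, di in zip(x, delta)]  # residual connection


in_features, dim = 4, 2  # toy sizes; the notebook uses d_model and adapter_dim = 8
x = [0.5, -1.0, 2.0, 0.25]
w_down = [[0.1 * (r + c) for c in range(in_features)] for r in range(dim)]
w_up_zero = [[0.0] * dim for _ in range(in_features)]  # zero-initialized up-projection

# With a zero up-projection, the adapter is an identity map.
print(linear_adapter(x, w_down, w_up_zero))  # [0.5, -1.0, 2.0, 0.25]
```

Only the two small projection matrices (and the norm) are trained, which is why adapters add so few parameters compared to the base model.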

## Add Adapter Modules

Now that we have the adapter configs prepared, let's add them to the model!

We specify the target module using the `target:adapter_name` syntax when calling `add_adapter()`; this tells the model to set up an adapter called `adapter_name` on the module denoted by `target`, using the config `cfg`.
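
The `target:adapter_name` convention can be read as a simple string split. The helper below is a hypothetical illustration of that parsing rule (it is not NeMo's implementation), including the fallback to the default target (the encoder, for this model) when no module name is given.

```python
def split_adapter_name(full_name: str, default_target: str = "encoder") -> tuple[str, str]:
    """Split 'target:adapter_name' into (target, adapter_name).

    A name without ':' (or with an empty target) falls back to default_target.
    """
    target, sep, name = full_name.partition(":")
    if not sep:  # no ':' at all, so the whole string is the adapter name
        target, name = "", full_name
    if not name:
        raise ValueError(f"missing adapter name in {full_name!r}")
    return (target or default_target), name


for full_name in ["encoder:enc", "transf_decoder:dec", "enc_only"]:
    print(full_name, "->", split_adapter_name(full_name))
```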

In [None]:
model.add_adapter(name="encoder:enc", cfg=enc_adapter_cfg)
model.add_adapter(name="transf_decoder:dec", cfg=dec_adapter_cfg)

print("Added adapters!")

## Freeze Original Module Parameters and Unfreeze Adapter Weights Only

When tuning adapters, we usually freeze the entire base model and train only the adapters. This reduces the amount of data needed, saves a lot of memory (gradients and optimizer states are only kept for the small adapter weights, not for the full model), and makes it practical to adapt huge models.

In [None]:
model.freeze()
model.unfreeze_enabled_adapters()

----

Let's make sure that the number of trainable parameters is much smaller (< 1 M) than the total number of parameters (1 B).
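
A back-of-the-envelope count shows why the trainable parameter count should land below 1 M. This sketch assumes each linear adapter holds a layer norm (scale and shift) plus bias-free down- and up-projections, and that canary-1b has 24 encoder and 24 decoder layers with one adapter each and a hidden size of 1024. Treat all of these as assumptions and rely on `model.summarize()` for the real numbers.

```python
def linear_adapter_params(in_features: int, dim: int) -> int:
    """Parameter count of one bottleneck adapter under the assumed layout."""
    layer_norm = 2 * in_features  # scale + shift
    down = in_features * dim  # in_features -> dim, no bias (assumed)
    up = dim * in_features  # dim -> in_features, no bias (assumed)
    return layer_norm + down + up


HIDDEN = 1024  # assumed hidden size of both encoder and decoder
ADAPTER_DIM = 8  # matches adapter_dim in the notebook
ENC_LAYERS, DEC_LAYERS = 24, 24  # assumed layer counts for canary-1b

per_adapter = linear_adapter_params(HIDDEN, ADAPTER_DIM)
total = per_adapter * (ENC_LAYERS + DEC_LAYERS)
print(per_adapter, total)  # 18432 884736
print(f"~{total * 2 / 1e6:.2f} MB in bfloat16, ~{total * 4 / 1e6:.2f} MB in float32")
```

Under these assumptions, the bfloat16 estimate is consistent with the small adapter checkpoint saved after training, since the model is cast to bfloat16 before that point.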

In [None]:
model.summarize()

## Check Enabled Adapters

Here, we check that the adapters we named above (`enc` and `dec`) are both set up and enabled.

In [None]:
model.get_enabled_adapters()

# Customizing Multi Task Models

In the following section, we take a deeper look at the components that make up a Multi Task Model and at how users can override each of them to build their own customized multi task models.

---

In this tutorial, we will only inspect internal components such as the prompt format and the dataset construction, without changing them.

In a following tutorial, we will show how to add an additional task to a Multi Task Model, using its pre-trained weights as a starting point.

# Prompt Handling for Multi Task Models
NVIDIA Canary is our first Multi Task Model.

Multi Task models use a prompt format, similar to those used in Large Language Models, to tell the model which task to perform, which language is being spoken, which language the output transcript should be in, whether to include punctuation and capitalization, and much more in the future!

Let's take a look at the `prompt_format` of the Canary model we have created:

In [None]:
model.prompt_format

----

This gives us the name of the prompt format. As we will see below, this name points to a prompt format function that reads manifest items and maps them onto the template.

## Reuse / Register a Prompt Format Function

When we print `model.prompt_format`, it shows `canary`, which is one of the registered prompt templates available in NeMo ASR.
For simplicity's sake, we will continue to use the same prompt format for this tutorial. However, we enable users to define their own prompt formats and register them as needed.

Let's see what the `canary` prompt format looks like:

In [None]:
from nemo.collections.common.prompts.fn import get_prompt_format_fn, registered_prompt_format_fn

In [None]:
canary_prompt_format_fn = get_prompt_format_fn("canary")
canary_prompt_format_fn?

### Registering a New Prompt Format Function

Just to show that this is user-configurable, we show how to register a dummy prompt format below:

In [None]:
@registered_prompt_format_fn
def canary2(cuts, tokenizer, inference: bool):
    """ Users can implement this as needed """
    raise NotImplementedError()

print("Registered prompt")

In [None]:
temp = get_prompt_format_fn('canary2')
temp.__name__
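
Under the hood, a decorator such as `@registered_prompt_format_fn` only needs a name-keyed registry. The sketch below is a hypothetical, simplified version of that pattern: `PROMPT_FN_REGISTRY`, `register_prompt_fn` and `lookup_prompt_fn` are invented for illustration and are not NeMo APIs. Registration stores the function under its `__name__`, and lookup fails with a helpful message on unknown names.

```python
from typing import Callable, Dict

PROMPT_FN_REGISTRY: Dict[str, Callable] = {}


def register_prompt_fn(fn: Callable) -> Callable:
    """Decorator: store fn in the registry under its own __name__ and return it unchanged."""
    PROMPT_FN_REGISTRY[fn.__name__] = fn
    return fn


def lookup_prompt_fn(name: str) -> Callable:
    """Return the prompt format function registered as `name`."""
    if name not in PROMPT_FN_REGISTRY:
        known = ", ".join(sorted(PROMPT_FN_REGISTRY))
        raise KeyError(f"unknown prompt format {name!r}; registered: {known}")
    return PROMPT_FN_REGISTRY[name]


@register_prompt_fn
def my_prompt_format(cuts, tokenizer, inference: bool):
    """Placeholder with the same signature as the dummy `canary2` function above."""
    raise NotImplementedError()


print(lookup_prompt_fn("my_prompt_format").__name__)  # my_prompt_format
```

Because the function name doubles as the registry key, two functions with the same name would silently replace each other in this sketch.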

## Create / Reuse a Prompt Format

The Canary Multi Task Model comes with a pre-defined prompt template, so we need to provide it data in a format that the prompt format class can handle.

A `PromptFormatter` is a special class that defines the dialog template, that is, the order of the turns that occur in a model's prompt. For example, in Language Models we typically begin with either a `System` or a `User` turn, followed by an `Assistant` turn that holds the model's output. Multi Task models support a similar usage pattern.

Do note: the current generation of Canary models is not trained to handle multi-turn conversations, though future Multi Task models may support such usage.

In [None]:
# Let's review the actual prompt formatter class docs
model.prompt?

In [None]:
# Let's see the actual template of this prompt formatter
model.prompt.TEMPLATE

---

We see that the template contains two turns: `user` and `assistant`.

The user template looks as follows: `<|startoftranscript|>|source_lang||task||target_lang||pnc|`
During execution, each `|slot|` marker is replaced with the actual value for that slot provided by the data loader.

The user turn supports the following slots:
* `source_lang`
* `target_lang`
* `task`
* `pnc`

Similarly, the assistant template is `|text|<|endoftext|>`.

The assistant turn supports the following slots:
* `text`
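
Slot filling can be pictured as plain string substitution: each `|slot|` marker in a turn's template is replaced by the value supplied for that slot. The sketch below uses the two templates quoted above. The special-token values (`<|en|>`, `<|translate|>`, `<|de|>`, `<|nopnc|>`) are hypothetical stand-ins for whatever the real prompt format function derives from manifest fields, and the real formatter additionally tokenizes the result.

```python
USER_TEMPLATE = "<|startoftranscript|>|source_lang||task||target_lang||pnc|"
ASSISTANT_TEMPLATE = "|text|<|endoftext|>"
USER_SLOTS = ("source_lang", "task", "target_lang", "pnc")
ASSISTANT_SLOTS = ("text",)


def fill_turn(template: str, allowed_slots: tuple, values: dict) -> str:
    """Replace every |slot| marker in `template` with its value from `values`."""
    unknown = set(values) - set(allowed_slots)
    if unknown:
        raise KeyError(f"slots not allowed for this turn: {sorted(unknown)}")
    filled = template
    for slot in allowed_slots:
        if slot not in values:
            raise KeyError(f"missing value for slot {slot!r}")
        filled = filled.replace(f"|{slot}|", values[slot])
    return filled


# Hypothetical slot values for an English-to-German translation example.
user = fill_turn(
    USER_TEMPLATE,
    USER_SLOTS,
    {"source_lang": "<|en|>", "task": "<|translate|>", "target_lang": "<|de|>", "pnc": "<|nopnc|>"},
)
assistant = fill_turn(ASSISTANT_TEMPLATE, ASSISTANT_SLOTS, {"text": "hallo welt"})
print(user + assistant)
```

During training, the filled user and assistant turns are concatenated; at inference time, only the user turn is given and the model generates the assistant turn.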

### Creating and Using a Custom Prompt Formatter

While we provide a pre-trained model with a pre-defined prompt format, users can also create their own `PromptFormatter` subclass and modify it as needed.

Below, we show a simple modification of the model's PromptFormatter and how to switch the model to it.

In [None]:
# Create a new prompt formatter using the original CanaryPromptFormatter class as the base class
class CanaryPromptFormatterV2(model.prompt.__class__):

    # make sure to provide a new name
    NAME: str = "canary2"

    # Make any changes as necessary.
    # For this demonstration, we will not change anything other than the name

In [None]:
# Next, let's update the model's prompt formatter
model.change_prompt("canary2")

---

We have now successfully changed the prompt format to `canary2`.

**Note**: It is important to know that when changing the prompt format, the name of the new prompt format class (`canary2` in this case) **has to match** the name of the prompt function registered with `@registered_prompt_format_fn`!
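
The name-matching rule makes sense once you picture two registries keyed by the same string: one for formatter classes (populated automatically when a subclass is defined) and one for prompt format functions. The sketch below is a hypothetical, simplified model of that relationship using `__init_subclass__`; the class and function names are invented for illustration and do not mirror NeMo's internals.

```python
FORMATTERS = {}  # NAME -> formatter class, filled automatically when a subclass is defined
PROMPT_FNS = {"canary": object(), "canary2": object()}  # stand-ins for registered prompt format functions


class BaseFormatter:
    NAME = None

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        if cls.NAME is None:
            raise TypeError(f"{cls.__name__} must define NAME")
        FORMATTERS[cls.NAME] = cls  # registration happens at class-definition time


def change_prompt(name: str):
    """Look up a formatter by NAME and pair it with the prompt format fn of the same name."""
    formatter_cls = FORMATTERS[name]
    if formatter_cls.NAME not in PROMPT_FNS:
        raise KeyError(f"no prompt format function registered as {formatter_cls.NAME!r}")
    return formatter_cls, PROMPT_FNS[formatter_cls.NAME]


class CanaryFormatter(BaseFormatter):
    NAME = "canary"


class CanaryFormatterV2(CanaryFormatter):
    NAME = "canary2"  # matches a registered function name, so switching works


class MismatchedFormatter(BaseFormatter):
    NAME = "canary3"  # no function registered under this name


print(change_prompt("canary2")[0].__name__)  # CanaryFormatterV2
try:
    change_prompt("canary3")
except KeyError as err:
    print("lookup failed:", err)
```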

In [None]:
# Check if everything is ok -
model.prompt.__class__.__name__

In [None]:
model.prompt_format

---
For the rest of the tutorial, we will revert to the original prompt formatter.

In [None]:
model.change_prompt('canary')

## Creating / Using a Multi Task Dataset

Now that we have learned how to modify the model's prompt formatter and the underlying format function that maps manifest items into slots to inject into the prompt template, next let's take a look at how to use and create custom datasets for training multi task models.

---

Unlike previous tutorials that show how to use pre-defined datasets by pointing them to your manifest files, we will take a more hands-on approach for multi task models. This is due to the sheer flexibility of multi task models: they can perform almost any task that can be formulated as a "speech in, text out" problem.

So it is hard to provide a pre-defined dataset class that can handle every new idea and task that researchers come up with.

Instead, we show how to build a custom dataset yourself and use it with the Multi Task model.

---

However, we also provide a base class that users can use as is if they don't want the hassle of writing their own datasets.

This is handled by `PromptedAudioToTextLhotseDataset`: it maps user-defined manifest items to the slots defined in the model's prompt template, so as long as the manifest provides the slots supported by the model, the dataset handles them automatically.

In [None]:
from nemo.collections.asr.data.audio_to_text_lhotse_prompted import PromptedAudioToTextLhotseDataset

# Uncomment below line to see the class definition of PromptedAudioToTextLhotseDataset
# PromptedAudioToTextLhotseDataset??

### Creating a New Prompted Dataset

In [None]:
import torch
import torch.utils.data
from lhotse import CutSet
from lhotse.cut import MixedCut, MonoCut
from lhotse.dataset import AudioSamples
from lhotse.dataset.collation import collate_vectors

from nemo.collections.asr.data.audio_to_text_lhotse import TokenizerWrapper
from nemo.collections.asr.data.audio_to_text_lhotse_prompted import PromptedAudioToTextLhotseDataset
from nemo.collections.common.prompts.fn import get_prompt_format_fn  # used in __init__ below

class MyCanaryPromptedAudioToTextLhotseDataset(torch.utils.data.Dataset):
    """
    This dataset is based on :class:`~nemo.collections.asr.data.audio_to_text_lhotse.LhotseSpeechToTextBpeDataset`.
    It is a Lhotse-style dataset that converts a mini-batch of Cuts into tensors.
    The main difference from ``LhotseSpeechToTextBpeDataset`` is that we introduce
    a special prompt format for multitask encoder-decoder models.

    To perform the prompt formatting, we use a ``prompt_format_fn`` (here, the registered ``canary`` function).
    It's expected to accept:
    * a ``CutSet`` which it will internally iterate over for utterances, and
    * a ``TokenizerWrapper`` object that will be internally used to tokenize the utterances

    Tokenized utterances will be extended with special prompt tokens according to ``prompt_format_fn`` logic.
    We support cuts with multiple supervision segments -- their tokenized texts will be concatenated before we add the prompt tokens.
    This is useful, for example, in code-switched scenarios where each segment is spoken in a different language.
    """

    def __init__(
        self,
        tokenizer: 'TokenizerSpec',
        inference: bool = False,
    ):
        super().__init__()
        self.tokenizer = TokenizerWrapper(tokenizer)
        self.load_audio = AudioSamples(fault_tolerant=True)
        self.padding_value = self.tokenizer._tokenizer.pad_id
        self.prompt_format_fn = get_prompt_format_fn('canary')  # Use the default canary prompt function
        self.inference = inference

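    def __getitem__(self, cuts: CutSet) -> tuple:
        # Returns (audio, audio_lens, prompts_with_answers, prompts_with_answers_lens, prompts, prompts_lens);
        # prompts and prompts_lens are None unless inference=True.
        audio, audio_lens, cuts = self.load_audio(cuts)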

        prompts_with_answers, prompts = self.prompt_format_fn(cuts, self.tokenizer, inference=self.inference)

        prompts_with_answers = [torch.as_tensor(t) for t in prompts_with_answers]
        prompts_with_answers_lens = torch.tensor([t.size(0) for t in prompts_with_answers], dtype=torch.long)
        prompts_with_answers = collate_vectors(prompts_with_answers, padding_value=self.padding_value)

        if self.inference:
            prompts = [torch.as_tensor(t) for t in prompts]
            prompts_lens = torch.tensor([t.size(0) for t in prompts], dtype=torch.long)
            prompts = collate_vectors(prompts, padding_value=self.padding_value)
        else:
            prompts = None
            prompts_lens = None

        return audio, audio_lens, prompts_with_answers, prompts_with_answers_lens, prompts, prompts_lens
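
The three lines that turn `prompts_with_answers` into a batch follow a common pattern: record each sequence's length, then pad every sequence to the longest one with the tokenizer's pad id. The pure-Python sketch below mimics what `collate_vectors` is used for here, under the assumption of right padding; `pad_batch` is a hypothetical helper, not a Lhotse function.

```python
from typing import List, Tuple


def pad_batch(sequences: List[List[int]], padding_value: int) -> Tuple[List[List[int]], List[int]]:
    """Right-pad token id sequences to a common length and return (padded, lengths)."""
    lengths = [len(seq) for seq in sequences]
    max_len = max(lengths, default=0)
    padded = [seq + [padding_value] * (max_len - len(seq)) for seq in sequences]
    return padded, lengths


batch, lens = pad_batch([[5, 6, 7], [8], [9, 10]], padding_value=0)
print(batch)  # [[5, 6, 7], [8, 0, 0], [9, 10, 0]]
print(lens)  # [3, 1, 2]
```

Keeping the lengths alongside the padded batch lets the model mask out padding positions later.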

---

The above class is mostly a demonstration, but it shows how users can flexibly change the prompt formatter, the prompt format function, and even the dataset that ties the two together.

The order of operations is usually as follows:

1) Create a new Prompt Formatter class: this defines the slots that each turn can have (including new task inputs or other values). The class is registered automatically.
2) Create a new Prompt Format function: using the `@registered_prompt_format_fn` decorator, write a custom function that processes the input data read from a manifest.
3) Create a new Dataset class (usually based on `PromptedAudioToTextLhotseDataset`) that uses the Prompt Format function to convert manifest items into formatted samples that can be passed to the Prompt Formatter.

# Preparing a Canary Dataset

Now that we have all the pieces together on the model side, let's take a look at the data side.

## Required Roles Defined by Prompt Format

These are the 'roles' defined by the prompt format. Each turn in the prompt belongs to exactly one role, and the role determines which slots that turn accepts and whether it carries the model's input or its output.

In [None]:
model.prompt.get_roles()

In [None]:
for role in model.prompt.get_roles():
    print(role, model.prompt.get_slots(role))
    print()

## Create a Data Module

Data Modules are one way of organizing datasets in PyTorch Lightning. They provide a single place where data loading and processing can be handled.

**Note**: This isn't strictly necessary; you can achieve the same by using PyTorch dataloaders directly and passing them to `Trainer.fit()`, but we show a data module here because it is easy for users to extend.

----

In our `CanaryAN4DataModule`, we will set up two tasks. The first is English ASR: transcribing the AN4 English dataset. The second is English-to-German AST: directly translating the English audio into German text.

For simplicity's sake, we will use a small off-the-shelf model (`t5-small`) to translate the English transcripts into German.

---

In NeMo 2.0, we utilize [Lhotse](https://github.com/lhotse-speech/lhotse) as our data backbone for speech tasks, which simplifies using custom speech datasets.

Most of the heavy lifting is handled by the following code:

```python
from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config

get_lhotse_dataloader_from_config(
    OmegaConf.create(config),  # Pass in a config that points to the manifest files and other arguments
    global_rank=self.trainer.global_rank,
    world_size=self.trainer.world_size,
    # Pass in the dataset class for Lhotse to handle. This class now receives CutSet as input.
    dataset=MyCanaryPromptedAudioToTextLhotseDataset(tokenizer=self.tokenizer, inference=inference),
)
```
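
The key difference from a standard map-style PyTorch dataset is that Lhotse's sampler, not the dataset, decides what goes into a mini-batch; the dataset's `__getitem__` then receives the whole batch (a `CutSet`) instead of an integer index. The toy sketch below imitates that contract with plain dictionaries, including the `min_duration`/`max_duration` filtering used in the data module config. It is a hypothetical simplification: `sample_batches` and `BatchDataset` are invented names, and real Lhotse samplers also shuffle, bucket by duration, and shard across ranks.

```python
from typing import Dict, Iterator, List

Entry = Dict[str, object]


def sample_batches(entries: List[Entry], batch_size: int, min_duration: float, max_duration: float) -> Iterator[List[Entry]]:
    """Drop entries outside [min_duration, max_duration], then yield fixed-size batches."""
    kept = [e for e in entries if min_duration <= e["duration"] <= max_duration]
    for start in range(0, len(kept), batch_size):
        yield kept[start : start + batch_size]


class BatchDataset:
    """Receives a whole batch (like a CutSet), not an index, and returns batch-level outputs."""

    def __getitem__(self, batch: List[Entry]):
        paths = [e["audio_filepath"] for e in batch]
        durations = [e["duration"] for e in batch]
        return paths, durations


manifest = [
    {"audio_filepath": "a.wav", "duration": 2.0},
    {"audio_filepath": "b.wav", "duration": 0.1},  # too short, filtered out
    {"audio_filepath": "c.wav", "duration": 4.5},
    {"audio_filepath": "d.wav", "duration": 12.0},  # too long, filtered out
    {"audio_filepath": "e.wav", "duration": 1.0},
]
dataset = BatchDataset()
outputs = [dataset[b] for b in sample_batches(manifest, batch_size=2, min_duration=0.3, max_duration=10.0)]
print(outputs)
```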

In [None]:
import os
import glob
import json
import copy
import subprocess
import tarfile
import wget
import librosa
import tqdm
from omegaconf import OmegaConf

from torch.utils.data import DataLoader, Dataset

import pytorch_lightning as L

from transformers import T5Tokenizer, T5ForConditionalGeneration

from nemo.collections.asr.parts.utils.manifest_utils import read_manifest, write_manifest
from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config


# Function to build a manifest
def build_manifest(transcripts_path, manifest_path, wav_path, data_dir):
    with open(transcripts_path, 'r') as fin:
        with open(manifest_path, 'w') as fout:
            for line in fin:
                # Lines look like this:
                # <s> transcript </s> (fileID)
                transcript = line[: line.find('(')-1].lower()
                transcript = transcript.replace('<s>', '').replace('</s>', '')
                transcript = transcript.strip()

                file_id = line[line.find('(')+1 : -2]  # e.g. "cen4-fash-b"
                audio_path = os.path.join(
                    data_dir, wav_path,
                    file_id[file_id.find('-')+1 : file_id.rfind('-')],
                    file_id + '.wav')

                duration = librosa.core.get_duration(path=audio_path)

                # Write the metadata to the manifest
                metadata = {
                    "audio_filepath": audio_path,
                    "duration": duration,
                    "text": transcript,
                    "pnc": "no",
                    "source_lang": "en",
                    "target_lang": "en",
                    "task": "asr",
                }
                json.dump(metadata, fout)
                fout.write('\n')

    return manifest_path


class CanaryAN4DataModule(L.LightningDataModule):

    def __init__(self, tokenizer, data_dir: str = "./an4/", batch_size=8):
        super().__init__()
        self.tokenizer = tokenizer
        self.data_dir = data_dir
        self.batch_size = batch_size

        # ASR manifests
        self.train_manifest = data_dir + '/an4/train_manifest.json'
        self.test_manifest = data_dir + '/an4/test_manifest.json'

        # AST manifests
        self.ast_train_manifest = data_dir + '/an4/ast_train_manifest.json'
        self.ast_test_manifest = data_dir + '/an4/ast_test_manifest.json'

        # Combined manifests
        self.combined_train_manifest = data_dir + '/an4/combined_train_manifest.json'
        self.combined_test_manifest = data_dir + '/an4/combined_test_manifest.json'

    def setup(self, stage):
        # make assignments here (val/train/test split)
        # called on every process in DDP
        # Assign train/val datasets for use in dataloaders
        pass

    def train_dataloader(self):
        config = {'manifest_filepath': self.combined_train_manifest, 'batch_size': self.batch_size,
                  'num_workers': 4, 'shuffle': True, 'min_duration': 0.3, 'max_duration': 10.0}
        return self._setup_dataloader(config)

    def val_dataloader(self):
        config = {'manifest_filepath': self.combined_test_manifest, 'batch_size': self.batch_size,
                  'num_workers': 4, 'shuffle': False, 'min_duration': 0.3, 'max_duration': 10.0}
        return self._setup_dataloader(config, inference=True)

    def test_dataloader(self):
        config = {'manifest_filepath': self.combined_test_manifest, 'batch_size': self.batch_size,
                  'num_workers': 4, 'shuffle': False, 'min_duration': 0.3, 'max_duration': 10.0}
        return self._setup_dataloader(config, inference=True)

    def teardown(self, stage):
        # clean up after fit or test
        # called on every process in DDP
        pass

    def _setup_dataloader(self, config, inference: bool = False):
        """
        The main function that creates the data loader using Lhotse's integration with NeMo.
        """
        return get_lhotse_dataloader_from_config(
                OmegaConf.create(config),
                global_rank=self.trainer.global_rank,
                world_size=self.trainer.world_size,
                # Note the passing of our custom dataset
                dataset=MyCanaryPromptedAudioToTextLhotseDataset(tokenizer=self.tokenizer, inference=inference),
            )

    def prepare_data(self):
        # download, split, etc...
        # only called on 1 GPU/TPU in distributed
        if not os.path.exists(self.data_dir):
            os.makedirs(self.data_dir)

        data_dir = self.data_dir
        if not os.path.exists(data_dir + '/an4_sphere.tar.gz'):
            an4_url = 'https://dldata-public.s3.us-east-2.amazonaws.com/an4_sphere.tar.gz'
            an4_path = wget.download(an4_url, data_dir)
            print(f"Dataset downloaded at: {an4_path}")
        else:
            print("Tarfile already exists.")
            an4_path = data_dir + '/an4_sphere.tar.gz'

        if not os.path.exists(data_dir + '/an4/'):
            # Untar and convert .sph to .wav (using sox)
            tar = tarfile.open(an4_path)
            tar.extractall(path=data_dir)

            print("Converting .sph to .wav...")
            sph_list = glob.glob(data_dir + '/an4/**/*.sph', recursive=True)
            for sph_path in sph_list:
                wav_path = sph_path[:-4] + '.wav'
                cmd = ["sox", sph_path, wav_path]
                subprocess.run(cmd)
        print("Finished conversion.\n******")

        # Building Manifests
        print("******")
        train_transcripts = data_dir + '/an4/etc/an4_train.transcription'
        train_manifest = self.train_manifest
        if not os.path.isfile(train_manifest):
            build_manifest(train_transcripts, train_manifest, 'an4/wav/an4_clstk', data_dir)
            print("Training manifest created.")

        test_transcripts = data_dir + '/an4/etc/an4_test.transcription'
        test_manifest = self.test_manifest
        if not os.path.isfile(test_manifest):
            build_manifest(test_transcripts, test_manifest, 'an4/wav/an4test_clstk', data_dir)
            print("Test manifest created.")
        print("*** Wrote manifests for Eng ***")

        train_manifest_data = read_manifest(self.train_manifest)
        test_manifest_data = read_manifest(self.test_manifest)

        derived_manifests = [self.ast_train_manifest, self.ast_test_manifest,
                             self.combined_train_manifest, self.combined_test_manifest]
        if not all(os.path.isfile(path) for path in derived_manifests):
            tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-small")
            t5_model = T5ForConditionalGeneration.from_pretrained("google-t5/t5-small")

            if torch.cuda.is_available():
                t5_model = t5_model.cuda()

            def pipe(text):
                if isinstance(text, str):
                    text = [text]

                prefix = "translate English to German"
                prompts = [prefix + ": " + x for x in text]
                input_ids = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True).input_ids
                input_ids = input_ids.to(t5_model.device)
                outputs = t5_model.generate(input_ids, max_new_tokens=64)
                return [tokenizer.decode(output, skip_special_tokens=True) for output in outputs]

            ast_train_manifest_data = copy.deepcopy(train_manifest_data)
            ast_test_manifest_data = copy.deepcopy(test_manifest_data)

            print("Translating train set")
            train_texts = [x['text'] for x in train_manifest_data]
            BATCH_SIZE = 32

            # tqdm infers the number of batches from the range (including a final partial batch)
            for i in tqdm.tqdm(range(0, len(train_texts), BATCH_SIZE)):
                batch_texts = train_texts[i:i+BATCH_SIZE]
                batch_texts = pipe(batch_texts)
                for j, text in enumerate(batch_texts):
                    ast_train_manifest_data[i+j]['text'] = text
                    ast_train_manifest_data[i+j]['task'] = 'ast'
                    ast_train_manifest_data[i+j]['target_lang'] = 'de'

            print("Translating test set")
            for data in tqdm.tqdm(ast_test_manifest_data, total=len(ast_test_manifest_data)):
                data['text'] = pipe(data['text'])[0]
                data['task'] = 'ast'
                data['target_lang'] = 'de'

            write_manifest(self.ast_train_manifest, ast_train_manifest_data)
            write_manifest(self.ast_test_manifest, ast_test_manifest_data)

            print("*** Wrote ast manifests ***")

            combined_train, combined_test = [], []
            combined_train.extend(train_manifest_data)
            combined_train.extend(ast_train_manifest_data)

            combined_test.extend(test_manifest_data)
            combined_test.extend(ast_test_manifest_data)

            write_manifest(self.combined_train_manifest, combined_train)
            write_manifest(self.combined_test_manifest, combined_test)
            print("*** Wrote combined manifests ***")

        else:
            print("*** AST and combined manifests already exist ***")


---

By default, each item in the prepared manifest has the following keys.

As you will recognize, apart from `audio_filepath` and `duration`, these are the same keys as the slots of the `CanaryPromptFormatter` class, so each of these values is mapped to the slot of the same name.

```python
metadata = {
    "audio_filepath": audio_path,
    "duration": duration,
    "text": transcript,
    "pnc": "no",
    "source_lang": "en",
    "target_lang": "en",
    "task": "asr",
}
```
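
Because the manifest keys line up with the prompt slots, mapping an entry onto the two turns is mostly bookkeeping. The sketch below is a hypothetical helper (not part of NeMo) that splits one manifest entry into audio metadata, user-turn slots and assistant-turn slots, using the slot names reported by `model.prompt.get_slots(...)` for the `canary` format, and reports any key the entry is missing. The example entry is made up.

```python
from typing import Dict, Tuple

USER_SLOTS = ("source_lang", "target_lang", "task", "pnc")
ASSISTANT_SLOTS = ("text",)
AUDIO_KEYS = ("audio_filepath", "duration")


def split_manifest_entry(entry: Dict[str, object]) -> Tuple[Dict, Dict, Dict]:
    """Split one manifest line into (audio info, user-turn slots, assistant-turn slots)."""
    missing = [k for k in AUDIO_KEYS + USER_SLOTS + ASSISTANT_SLOTS if k not in entry]
    if missing:
        raise KeyError(f"manifest entry is missing keys: {missing}")
    audio = {k: entry[k] for k in AUDIO_KEYS}
    user = {k: entry[k] for k in USER_SLOTS}
    assistant = {k: entry[k] for k in ASSISTANT_SLOTS}
    return audio, user, assistant


entry = {
    "audio_filepath": "example.wav",
    "duration": 1.5,
    "text": "hello world",
    "pnc": "no",
    "source_lang": "en",
    "target_lang": "en",
    "task": "asr",
}
audio, user, assistant = split_manifest_entry(entry)
print(user)  # {'source_lang': 'en', 'target_lang': 'en', 'task': 'asr', 'pnc': 'no'}
```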

The most important function in the Data Module above is `prepare_data()`:

1) It first downloads the AN4 dataset and converts its audio files to wav.
2) It then writes a manifest file with the above keys for the ASR task.
3) It translates the English transcripts with a `t5-small` model to generate German transcripts.
4) It writes another manifest for the AST task with these translated texts.
5) Finally, it builds combined manifests for both ASR (en) and AST (en to de) multi-task training.

**Note**: We use `prepare_data()` here only for demonstration. Normally, users would preprocess their data before experimentation, so they would only need to implement the methods above `prepare_data()` in their Data Module.
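
Steps 3 to 5 boil down to duplicating every utterance: once as an ASR entry and once as an AST entry whose `text`, `task` and `target_lang` are changed. The sketch below captures that transformation with a pluggable translation function; `build_multitask_manifest` and `fake_translate` are hypothetical names, and the stub translator stands in for the `t5-small` pipeline.

```python
import copy
from typing import Callable, Dict, List

Entry = Dict[str, object]


def build_multitask_manifest(asr_entries: List[Entry], translate: Callable[[str], str]) -> List[Entry]:
    """Return the ASR entries followed by AST (en -> de) copies with translated text."""
    ast_entries = []
    for entry in asr_entries:
        ast_entry = copy.deepcopy(entry)  # leave the ASR entry untouched
        ast_entry["text"] = translate(entry["text"])
        ast_entry["task"] = "ast"
        ast_entry["target_lang"] = "de"
        ast_entries.append(ast_entry)
    return asr_entries + ast_entries


def fake_translate(text: str) -> str:
    """Stand-in for a real English-to-German translator."""
    return {"hello world": "hallo welt"}.get(text, text)


asr = [{"audio_filepath": "example.wav", "duration": 1.5, "text": "hello world",
        "pnc": "no", "source_lang": "en", "target_lang": "en", "task": "asr"}]
combined = build_multitask_manifest(asr, fake_translate)
for item in combined:
    print(item["task"], item["target_lang"], item["text"])
```

Both copies point to the same audio file; only the prompt slots and the target text differ, which is exactly what lets one model learn both tasks from the same recordings.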

## Download and Prepare Dataset

In [None]:
data_module = CanaryAN4DataModule(tokenizer=model.tokenizer, batch_size=16)

In [None]:
data_module.prepare_data()

In [None]:
!head -n 5 {data_module.train_manifest}

In [None]:
!head -n 5 {data_module.ast_train_manifest}

# Evaluate Model before Training

The Canary Multi Task model is already very capable, achieving strong scores on multiple benchmarks. So we first compute baseline numbers on the two tasks:

1) ASR: WER calculation on transcripts

2) AST: SacreBLEU calculation on translations
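
NeMo's `word_error_rate(hypotheses, references)` reports a corpus-level WER. To make that number concrete, here is a pure-Python sketch assuming the usual definition: word-level edit distance summed over all utterances, divided by the total number of reference words. It is meant for understanding only; the notebook keeps using the NeMo function.

```python
from typing import List, Sequence


def edit_distance(ref: Sequence[str], hyp: Sequence[str]) -> int:
    """Word-level Levenshtein distance: substitutions + deletions + insertions."""
    prev = list(range(len(hyp) + 1))  # distances for an empty reference prefix
    for i, ref_word in enumerate(ref, start=1):
        cur = [i] + [0] * len(hyp)
        for j, hyp_word in enumerate(hyp, start=1):
            cur[j] = min(
                prev[j] + 1,  # delete ref_word
                cur[j - 1] + 1,  # insert hyp_word
                prev[j - 1] + (ref_word != hyp_word),  # substitute (or match for free)
            )
        prev = cur
    return prev[-1]


def corpus_wer(hypotheses: List[str], references: List[str]) -> float:
    """Sum of word errors over all utterances divided by the total number of reference words."""
    errors = sum(edit_distance(r.split(), h.split()) for h, r in zip(hypotheses, references))
    ref_words = sum(len(r.split()) for r in references)
    if ref_words == 0:
        raise ValueError("references contain no words")
    return errors / ref_words


print(corpus_wer(["the cat sat", "hello"], ["the cat sat", "hello world"]))  # 0.2
```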

In [None]:
from nemo.collections.asr.metrics.wer import word_error_rate
from torchmetrics.text import SacreBLEUScore

In [None]:
asr_test = read_manifest(data_module.test_manifest)
ast_test = read_manifest(data_module.ast_test_manifest)

In [None]:
asr_filepaths = [x['audio_filepath'] for x in asr_test]
asr_gt = [x['text'] for x in asr_test]

ast_filepaths = [x['audio_filepath'] for x in ast_test]
ast_gt = [x['text'] for x in ast_test]

print("Num files:", len(asr_filepaths))

In [None]:
if torch.cuda.is_available():
    model = model.cuda()  # move model to gpu
    model = model.to(torch.bfloat16)  # cast full model to bfloat16

In [None]:
asr_preds = model.transcribe(asr_filepaths, pnc='no', task='asr', source_lang='en', target_lang='en', batch_size=32)

In [None]:
ast_preds = model.transcribe(ast_filepaths, pnc='no', task='ast', source_lang='en', target_lang='de', batch_size=32)

In [None]:
wer = word_error_rate(asr_preds, asr_gt)
print("WER", wer)

sacrebleu = SacreBLEUScore(n_gram=4)
preds = list(ast_preds)
gts = [[gt] for gt in ast_gt]  # one list of references per prediction

sacrebleu.update(preds, gts)
bleu = sacrebleu.compute()
print("BLEU", bleu.item() * 100)

# Train Model

Finally, now that the adapters have been prepared, the model has been evaluated for a baseline, and the dataset is ready, it's time to train the adapter weights on the new dataset.

---

First, we update the optimizer and scheduler config.

In [None]:
print(OmegaConf.to_yaml(model.cfg.optim))

In [None]:
# Setup optimization
model.cfg.optim.lr = 3e-4
model.cfg.optim.sched.warmup_steps = 25

---

Next, we set up a Lightning Trainer and Experiment Manager.

In [None]:
from omegaconf import OmegaConf
from nemo.utils import exp_manager

In [None]:
trainer = L.Trainer(max_steps=200, accumulate_grad_batches=1, logger=False, enable_checkpointing=False, check_val_every_n_epoch=5)
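
Because `max_steps` and `check_val_every_n_epoch` interact, a quick estimate helps. This sketch assumes AN4's usual split of 948 training utterances, that every utterance passes the 0.3 to 10 second duration filter, and that one epoch is a single pass over the combined (ASR + AST) manifest at `batch_size=16`. All three are assumptions, and the Lhotse sampler's actual notion of an epoch may differ.

```python
import math

AN4_TRAIN_UTTERANCES = 948  # assumed size of the AN4 training split
TASKS = 2  # each utterance appears once for ASR and once for AST
BATCH_SIZE = 16  # batch_size passed to CanaryAN4DataModule
MAX_STEPS = 200  # max_steps passed to the Trainer
VAL_EVERY_N_EPOCHS = 5  # check_val_every_n_epoch passed to the Trainer

steps_per_epoch = math.ceil(AN4_TRAIN_UTTERANCES * TASKS / BATCH_SIZE)
epochs_run = MAX_STEPS / steps_per_epoch
print(steps_per_epoch, round(epochs_run, 2))  # 119 1.68
print("validation reached during fit:", epochs_run >= VAL_EVERY_N_EPOCHS)
```

If this estimate holds, no scheduled validation pass runs inside `trainer.fit`; a smaller `check_val_every_n_epoch`, or Lightning's `val_check_interval` argument, may be used to change that. The notebook evaluates the model separately after training in any case.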

In [None]:
# # Environment variable generally used for multi-node multi-gpu training.
# # In notebook environments, this flag is unnecessary and can cause logs of multiple training runs to overwrite each other.
# os.environ.pop('NEMO_EXPM_VERSION', None)

# config = exp_manager.ExpManagerConfig(
#     exp_dir=f'experiments/canary/',
#     name=f"Canary-Model-Adapter-Training",
#     checkpoint_callback_params=exp_manager.CallbackParams(
#         monitor="val_wer",
#         mode="min",
#         always_save_nemo=False,
#         save_best_model=False,
#     ),
# )

# config = OmegaConf.structured(config)

# logdir = exp_manager.exp_manager(trainer, config)

---

Begin training!

In [None]:
trainer.fit(model, data_module)

---

Save just the adapter parameters, which take up less than 2 MB!

In [None]:
model.save_adapters("adapters.pt")
!ls -l -- *.pt
!du -sh *.pt

# Evaluate after Adaptation

Now that the model is done training, let's evaluate it on the test set again.
We expect a noticeably higher translation BLEU and a lower WER than the baseline numbers above.

In [None]:
asr_test = read_manifest(data_module.test_manifest)
ast_test = read_manifest(data_module.ast_test_manifest)

In [None]:
asr_filepaths = [x['audio_filepath'] for x in asr_test]
asr_gt = [x['text'] for x in asr_test]

ast_filepaths = [x['audio_filepath'] for x in ast_test]
ast_gt = [x['text'] for x in ast_test]

print("Num files:", len(asr_filepaths))

In [None]:
if torch.cuda.is_available():
    model = model.cuda()
    model = model.to(torch.bfloat16)

In [None]:
asr_preds = model.transcribe(asr_filepaths, pnc='no', task='asr', source_lang='en', target_lang='en', batch_size=32)

In [None]:
ast_preds = model.transcribe(ast_filepaths, pnc='no', task='ast', source_lang='en', target_lang='de', batch_size=32)

In [None]:
from nemo.collections.asr.metrics.wer import word_error_rate
from torchmetrics.text import SacreBLEUScore

In [None]:
wer = word_error_rate(asr_preds, asr_gt)
print("WER", wer)

In [None]:
sacrebleu = SacreBLEUScore(n_gram=4)
preds = list(ast_preds)
gts = [[gt] for gt in ast_gt]  # one list of references per prediction

sacrebleu.update(preds, gts)
bleu = sacrebleu.compute()
print("BLEU", bleu.item() * 100)

# Conclusion

In this tutorial, we added adapters to a Multi Task model (NVIDIA Canary) and showed how to create a custom dataset to fine-tune a Canary model on a new dataset for its existing tasks, ASR and AST. The primary goal of this tutorial was to show how to flexibly adapt a Canary model for any of its pre-existing tasks.

In a future tutorial, we will show how to add additional tasks to a pre-trained Canary, so that you can leverage the pre-trained encoder and decoder for your own custom tasks!