Install necessary dependencies (if you are using CPU):

In [None]:
%pip install ipykernel notebook transformers torch torchvision torchaudio datasets "datasets[audio]" "jax[cpu]==0.4.11" git+https://github.com/sanchit-gandhi/whisper-jax.git cached_property
%conda install ffmpeg -c conda-forge 

Install necessary dependencies (if you are using GPU):

In [None]:
%pip install ipykernel notebook transformers torch torchvision torchaudio datasets "datasets[audio]" "jax==0.4.11" git+https://github.com/sanchit-gandhi/whisper-jax.git cached_property
%conda install ffmpeg -c conda-forge 
%conda install cuda-nvcc -c nvidia

## Transformers Implementation
Now we can use the `transformers` library to load the ASR model checkpoint `openai/whisper-large-v2`:

In [None]:
from transformers import pipeline

whisper_pipeline = pipeline(model="openai/whisper-large-v2")

Let's use Mozilla's Common Voice dataset to get some samples to work with.
Note that we set `streaming=True` here to avoid downloading the whole dataset.

In [None]:
from datasets import load_dataset

audio_dataset = load_dataset("mozilla-foundation/common_voice_11_0", "en", split="test", streaming=True)
audio_data_samples = audio_dataset.take(10)
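
With `streaming=True`, `audio_data_samples` is a lazy iterable rather than an in-memory list, so every loop over it reads the first ten samples from the source again. The sketch below imitates that behaviour with a plain Python generator. `stream_samples`, `take`, and `fetch_log` are hypothetical stand-ins for illustration and are not part of `datasets`.

```python
from itertools import islice

fetch_log = []  # records every simulated fetch (hypothetical helper state)


def stream_samples():
    """Hypothetical stand-in for a streamed dataset: yields samples lazily."""
    index = 0
    while True:
        fetch_log.append(index)  # simulate fetching one sample from the source
        yield {"id": index}
        index += 1


def take(n):
    """Return a fresh lazy iterator over the first n samples."""
    return islice(stream_samples(), n)


# Two separate passes fetch the same samples twice.
first_pass = [sample["id"] for sample in take(3)]
second_pass = [sample["id"] for sample in take(3)]
fetches_without_cache = len(fetch_log)

# Materializing once into a list lets later loops reuse the samples.
fetch_log.clear()
cached = list(take(3))
ids_a = [sample["id"] for sample in cached]
ids_b = [sample["id"] for sample in cached]
fetches_with_cache = len(fetch_log)

print(fetches_without_cache, fetches_with_cache)
```

In the notebook, the analogous step would be `list(audio_dataset.take(10))`, at the cost of holding the ten decoded audio clips in memory. If you reuse cached samples, consider passing a copy of `sample["audio"]` to any function that might modify its input.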

Now we can start creating some transcriptions using the pipeline:

In [None]:
transcriptions = [whisper_pipeline(sample["audio"]) for sample in audio_data_samples]

Let's have a look at the resulting transcriptions:

In [None]:
print(transcriptions)
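
The raw list can be hard to read. As a small sketch, assuming each pipeline result is a dictionary with a `"text"` key (the same key the notebook reads when building the dataset), the transcriptions can be printed one per line. `format_transcriptions` is a hypothetical helper, and the values in `example_transcriptions` are made up for illustration.

```python
def format_transcriptions(results):
    """Return one numbered line per transcription, with whitespace trimmed."""
    return [
        f"{i}: {result['text'].strip()}"
        for i, result in enumerate(results, start=1)
    ]


# Made-up results in the assumed shape; not real model output.
example_transcriptions = [
    {"text": " Hello world."},
    {"text": " This is a test."},
]

lines = format_transcriptions(example_transcriptions)
for line in lines:
    print(line)
```

In the notebook, the same helper could be called as `format_transcriptions(transcriptions)`.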

Great! Let's build a simple Hugging Face dataset that contains our transcriptions and the source audio:

In [None]:
from datasets import Dataset, Audio

new_dataset = Dataset.from_dict({
    "audio": [(sample["audio"]) for sample in audio_data_samples],
    "transcription": [transcription["text"] for transcription in transcriptions]
}).cast_column("audio", Audio())

Let's inspect the dataset and its features:

In [None]:
print(new_dataset)
print(new_dataset.features)

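Now we can push our dataset to the Hub as the split `example`. Replace `myuser/testset` with your own repository ID; pushing requires that you are logged in to the Hugging Face Hub.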

In [None]:
new_dataset.push_to_hub("myuser/testset", split="example")

## Using Whisper JAX
Now that we have used the original Transformers implementation, let's try Whisper JAX by https://github.com/sanchit-gandhi.
We start by defining a pipeline:

In [None]:
# Note: the class name is spelled "FlaxWhisperPipline" in whisper_jax
from whisper_jax import FlaxWhisperPipline
import jax.numpy as jnp

# Most users should use jnp.float16, use jnp.bfloat16 for A100 / TPU
jax_pipeline = FlaxWhisperPipline("openai/whisper-large-v2", dtype=jnp.float16)
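
The dtype rule from the comment in the cell above can be written as a small helper. This is a sketch only: `choose_dtype_name` is a hypothetical function, and matching on the substring "A100" in the device description is an assumption about how the GPU reports itself.

```python
def choose_dtype_name(platform, device_kind=""):
    """Return 'bfloat16' for TPUs and A100 GPUs, otherwise 'float16'."""
    if platform.lower() == "tpu" or "a100" in device_kind.lower():
        return "bfloat16"
    return "float16"


# Hypothetical device descriptions, for illustration only.
print(choose_dtype_name("tpu", "TPU v4"))
print(choose_dtype_name("gpu", "NVIDIA A100-SXM4-40GB"))
print(choose_dtype_name("gpu", "Tesla T4"))
```

With JAX installed, the inputs could come from the `platform` and `device_kind` attributes of `jax.devices()[0]`, and the returned name could be resolved with `getattr(jnp, name)` before it is passed as `dtype`.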

Let's load the dataset again and create transcriptions using Whisper JAX:

In [None]:
from datasets import load_dataset

audio_dataset = load_dataset("mozilla-foundation/common_voice_11_0", "en", split="test", streaming=True)
audio_data_samples = audio_dataset.take(10)
transcriptions = [jax_pipeline(sample["audio"], task="transcribe") for sample in audio_data_samples]

Let's take a look at the output:

In [None]:
print(transcriptions)

We can also activate timestamps easily:

In [None]:
timestamped_transcriptions = [jax_pipeline(sample["audio"], task="transcribe", return_timestamps=True) for sample in audio_data_samples]

Let's have another look:

In [None]:
print(timestamped_transcriptions)
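
To make the timestamped output easier to scan, each chunk can be printed on its own line. The sketch below assumes every result has a `"chunks"` list (the key the notebook reads when building the dataset) and that each chunk holds a `"timestamp"` pair of start and end seconds plus a `"text"` string. That inner layout is an assumption, and `example_result` is made-up data, not real model output.

```python
def format_seconds(seconds):
    """Format seconds as MM:SS.ss, or '??' when the value is missing."""
    if seconds is None:
        return "??"
    minutes, rest = divmod(float(seconds), 60)
    return f"{int(minutes):02d}:{rest:05.2f}"


def format_chunks(result):
    """Return one '[start -> end] text' line per chunk of a single result."""
    lines = []
    for chunk in result["chunks"]:
        start, end = chunk["timestamp"]
        text = chunk["text"].strip()
        lines.append(f"[{format_seconds(start)} -> {format_seconds(end)}] {text}")
    return lines


# Made-up result in the assumed shape.
example_result = {
    "text": " Hello world. This is a test.",
    "chunks": [
        {"timestamp": (0.0, 2.5), "text": " Hello world."},
        {"timestamp": (2.5, 65.25), "text": " This is a test."},
    ],
}

chunk_lines = format_chunks(example_result)
for line in chunk_lines:
    print(line)
```

In the notebook, this could be applied to each entry of `timestamped_transcriptions`.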

Let's also create a dataset from these results and push it to the Hub:

In [None]:
from datasets import Dataset, Audio

jax_dataset = Dataset.from_dict({
    "audio": [(sample["audio"]) for sample in audio_data_samples],
    "transcription": [transcription["text"] for transcription in timestamped_transcriptions],
    "chunks": [transcription["chunks"] for transcription in timestamped_transcriptions]
}).cast_column("audio", Audio())

jax_dataset.push_to_hub("myuser/testsetjax", split="example")