## [Whisper-AT: Noise-Robust Automatic Speech Recognizers are Also Strong Audio Event Taggers]()

This Colab notebook contains a step-by-step tutorial on how to use Whisper-AT for joint automatic speech recognition (ASR) and audio tagging (AT).

Please cite our paper if you find this repository useful.

```
@inproceedings{gong_whisperat,
  author={Gong, Yuan and Khurana, Sameer and Karlinsky, Leonid and Glass, James},
  title={Whisper-AT: Noise-Robust Automatic Speech Recognizers are Also Strong Audio Event Taggers},
  year=2023,
  booktitle={Proc. Interspeech 2023}
}
```
For more information, please check https://github.com/YuanGongND/whisper-at

### Step 1. Install Whisper-AT Package

We intentionally do not add any dependencies beyond those of the original Whisper, so if your environment can run the original Whisper, it can also run Whisper-AT. Like the original Whisper, Whisper-AT requires the command-line tool [`ffmpeg`](https://ffmpeg.org/) to be installed on your system. Please check the OpenAI Whisper repo for details.
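
Before installing, it can help to confirm that `ffmpeg` is reachable from Python. The snippet below is a minimal sketch using only the standard library; the helper name `find_ffmpeg` is illustrative and not part of Whisper-AT.

```python
import shutil

def find_ffmpeg():
    """Return the full path of the ffmpeg executable, or None if it is not on PATH."""
    return shutil.which("ffmpeg")

ffmpeg_path = find_ffmpeg()
if ffmpeg_path is None:
    print("ffmpeg was not found on PATH; install it before transcribing audio.")
else:
    print("ffmpeg found at", ffmpeg_path)
```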


Whisper-AT can be installed with `pip install whisper-at`.

In [1]:
!pip install whisper-at

Collecting whisper-at
  Downloading whisper_at-0.2-py3-none-any.whl (1.2 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.2/1.2 MB 4.3 MB/s eta 0:00:00
Collecting tiktoken==0.3.3 (from whisper-at)
  Downloading tiktoken-0.3.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.7 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.7/1.7 MB 85.1 MB/s eta 0:00:00
Installing collected packages: tiktoken, whisper-at
Successfully installed tiktoken-0.3.3 whisper-at-0.2


### Step 2. Use as the Original Whisper

In [2]:
# download a sample audio
!pip -q install wget
import wget,IPython
wget.download('https://www.dropbox.com/s/7eznyazmc1pmw9h/case_closed.wav?dl=1', '/content/sample_audio.flac')
#IPython.display.Audio('/content/sample_audio.flac')

  Preparing metadata (setup.py) ... done
  Building wheel for wget (setup.py) ... done


'/content/sample_audio.flac'

In [3]:
# note: the module name is whisper_at (underscore), not whisper-at
import whisper_at as whisper

# the only new thing in whisper-at:
# specify the temporal resolution for audio tagging; 10 means Whisper-AT predicts audio events every 10 seconds (hop and window = 10 s).
audio_tagging_time_resolution = 10

model = whisper.load_model("large-v1")
# for the large, medium, and small models, we provide low-dimensional projection AT models to save compute.
# model = whisper.load_model("large-v1", at_low_compute=True)
result = model.transcribe("/content/sample_audio.flac", at_time_res=audio_tagging_time_resolution)
for segment in result['segments']:
  print(segment['start'], 's-', segment['end'], 's', segment['text'])

# # translation task is also supported
# result = model.transcribe("/content/sample_audio.flac", task='translate', at_time_res=audio_tagging_time_resolution)
# print(result["text"])

100%|█████████████████████████████████████| 2.87G/2.87G [01:04<00:00, 47.8MiB/s]
100%|███████████████████████████████████████| 153M/153M [00:05<00:00, 27.4MiB/s]


0.0 s- 5.0 s  My name's Jimmy Kudo, and I've always wanted to be a great detective.
5.0 s- 10.0 s  Over the years I've worked at perfecting the craft, impressing almost everyone, except maybe Rachel.
10.0 s- 14.0 s  But then some bad guys ambushed me, slipping me some kind of poison.
14.0 s- 16.0 s  Now what's gonna happen?
16.0 s- 30.0 s  It was the first new century in one hundred years.
30.0 s- 36.0 s  And when I felt like I should cry, I laughed away my tears.
36.0 s- 48.0 s  At the end of a millennium, we waited a long, long time,
48.0 s- 54.0 s  To see the brave new world, and the mountains we would climb.
54.0 s- 60.0 s  Things I tried to comprehend as a child were made of mystery.
60.0 s- 67.0 s  There's nothing I need to defend, there's nothing great about me.
67.0 s- 73.0 s  All I will ever believe is the pounding of my heart, though.
73.0 s- 79.0 s  It doesn't answer questions, that's just the way it goes.
79.0 s- 84.0 s  All I will ever have bacon is the beating in my heart

`result["text"]` is the ASR transcript. It is identical to the output of the original Whisper and is not affected by `at_time_res`: ASR still uses Whisper's 30-second window, and `at_time_res` applies only to audio tagging.

Compared to the original Whisper, the only new argument is `at_time_res`, which sets both the hop and the window size that Whisper-AT uses to predict audio events. For example, for a 60-second audio clip, `at_time_res = 10` splits the audio into six 10-second segments, and Whisper-AT predicts audio tags for each one, for a total of six audio event predictions. **Note that `at_time_res` must be an integer multiple of 0.4, e.g., 0.4, 0.8, ...** The default value is 10.0, which is the value we used to train the model and should give the best performance.
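
The two rules in this paragraph (the 0.4 multiple and the number of predictions) can be checked with a small helper. This is only a sketch: the function names are illustrative, the float tolerance is an assumption, and the rounding up of a final partial window follows the ⌈`audio_length`/`at_time_res`⌉ shape given in Step 3 rather than the library source.

```python
import math

AT_STEP = 0.4  # at_time_res must be an integer multiple of this value

def is_valid_at_time_res(at_time_res, tol=1e-6):
    """True if at_time_res is a positive integer multiple of 0.4 (within float tolerance)."""
    ratio = at_time_res / AT_STEP
    return at_time_res > 0 and abs(ratio - round(ratio)) < tol

def num_at_segments(audio_len, at_time_res):
    """Number of audio tagging predictions, assuming a final partial window still gets one."""
    return math.ceil(audio_len / at_time_res)

print(is_valid_at_time_res(10), is_valid_at_time_res(0.8), is_valid_at_time_res(0.5))  # True True False
print(num_at_segments(60, 10))  # 6
```

For the 90.05-second sample used in this tutorial, the same arithmetic gives 10 predictions at a 10-second resolution and 46 at a 2-second resolution.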


### Step 3. Get the Audio Tagging Output

Compared with the original Whisper, `result` contains a new entry called `audio_tag`. `result['audio_tag']` is a torch tensor of shape [⌈`audio_length`/`at_time_res`⌉, 527]. For example, for a 60-second audio clip and `at_time_res = 10`, `result['audio_tag']` has shape [6, 527]. 527 is the size of the [AudioSet label set](), and `result['audio_tag'][i,j]` is the (unnormalised) logit of class `j` for the `i`-th segment.

If you are familiar with audio tagging and AudioSet, you can use the raw `result['audio_tag']` directly.

In [4]:
import torchaudio
audio, sr = torchaudio.load('/content/sample_audio.flac')
audio_len = audio.shape[1] / sr
print('Audio length is {:.2f}, at time resolution is {:.1f}, Whisper-AT output in shape'.format(audio_len, audio_tagging_time_resolution), result['audio_tag'].shape)

Audio length is 90.05, at time resolution is 10.0, Whisper-AT output in shape torch.Size([10, 527])
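
If you work with the raw logits, you may want scores between 0 and 1. The sketch below applies a per-class sigmoid, which assumes the tagger is trained as a multi-label classifier (the usual setup for AudioSet); this tutorial does not state which normalisation is intended, so treat that as an assumption. The two logits are copied from the sample outputs in this tutorial, and `sigmoid` is an illustrative helper, not a Whisper-AT function.

```python
import math

def sigmoid(logit):
    """Map an unnormalised logit to a value in (0, 1), avoiding overflow for large |logit|."""
    if logit >= 0:
        return 1.0 / (1.0 + math.exp(-logit))
    exp_logit = math.exp(logit)
    return exp_logit / (1.0 + exp_logit)

# Example logits taken from this tutorial's sample output (first 10-second segment).
example_logits = {"Music": 1.8221180438995361, "Speech": 0.932380735874176}
for name, logit in example_logits.items():
    print(f"{name}: logit={logit:.3f}, sigmoid={sigmoid(logit):.3f}")
```

On the actual tensor, `torch.sigmoid(result['audio_tag'])` performs the same element-wise mapping.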


But we also provide a tool to make this easier.
You can pass `result` to `whisper.parse_at_label` to get readable results.

In [5]:
audio_tag_result = whisper.parse_at_label(result, language='follow_asr', top_k=5, p_threshold=-2, include_class_list=list(range(527)))
for segment in audio_tag_result:
  print(segment)

{'time': {'start': 0, 'end': 10}, 'audio tags': [('Music', 1.8221180438995361), ('Speech', 0.932380735874176)]}
{'time': {'start': 10, 'end': 20}, 'audio tags': [('Music', 1.3558011054992676), ('Grunge', -1.350265622138977), ('Progressive rock', -1.424497127532959), ('Punk rock', -1.5711394548416138)]}
{'time': {'start': 20, 'end': 30}, 'audio tags': [('Music', 0.8049014806747437)]}
{'time': {'start': 30, 'end': 40}, 'audio tags': [('Music', 1.323777437210083), ('Grunge', -1.3895658254623413), ('Progressive rock', -1.614871859550476), ('Punk rock', -1.6720657348632812), ('Independent music', -1.8473705053329468)]}
{'time': {'start': 40, 'end': 50}, 'audio tags': [('Music', 0.7121016383171082), ('Singing', -1.492790699005127)]}
{'time': {'start': 50, 'end': 60}, 'audio tags': [('Music', -1.4463484287261963)]}
{'time': {'start': 60, 'end': 70}, 'audio tags': [('Music', 1.2629746198654175), ('Grunge', -0.9662912487983704), ('Pop music', -1.6864861249923706), ('Independent music', -1.78539

If you change the audio tagging resolution to 2s, then the output will be more fine-grained.

In [6]:
audio_tagging_time_resolution = 2
result = model.transcribe("/content/sample_audio.flac", at_time_res=audio_tagging_time_resolution)
print('Audio length is {:.2f}, at time resolution is {:.1f}, Whisper-AT output in shape'.format(audio_len, audio_tagging_time_resolution), result['audio_tag'].shape)
audio_tag_result = whisper.parse_at_label(result, language='follow_asr', top_k=5, p_threshold=-2, include_class_list=list(range(527)))
for segment in audio_tag_result:
  print(segment)



Audio length is 90.05, at time resolution is 2.0, Whisper-AT output in shape torch.Size([46, 527])
{'time': {'start': 0, 'end': 2}, 'audio tags': [('Music', -0.4525575637817383), ('Speech', -0.8882830739021301), ('Silence', -1.8616548776626587)]}
{'time': {'start': 2, 'end': 4}, 'audio tags': [('Music', 0.9328102469444275), ('Speech', -0.158194437623024)]}
{'time': {'start': 4, 'end': 6}, 'audio tags': [('Music', 0.7153851389884949), ('Speech', 0.38911065459251404)]}
{'time': {'start': 6, 'end': 8}, 'audio tags': [('Music', 0.40575844049453735), ('Speech', -0.6977154612541199)]}
{'time': {'start': 8, 'end': 10}, 'audio tags': [('Music', 1.6714180707931519), ('Speech', 0.5128814578056335), ('Drum kit', -1.7371801137924194), ('Drum', -1.9375439882278442)]}
{'time': {'start': 10, 'end': 12}, 'audio tags': [('Music', 1.3964262008666992), ('Speech', -0.23263752460479736)]}
{'time': {'start': 12, 'end': 14}, 'audio tags': [('Music', 0.7910271883010864), ('Speech', -1.9267847537994385)]}
{'ti

In [7]:
# Go back to 10s for better readability
audio_tagging_time_resolution = 10
result = model.transcribe("/content/sample_audio.flac", at_time_res=audio_tagging_time_resolution)

Let's take a closer look at `whisper.parse_at_label`.

First, `top_k` and `p_threshold` control how many audio tags are output. Specifically, `whisper.parse_at_label` outputs up to `top_k` labels whose unnormalised logits are above `p_threshold`.

For example, setting `top_k=1` allows the model to output at most one label per segment.

In [8]:
audio_tag_result = whisper.parse_at_label(result, language='follow_asr', top_k=1, p_threshold=-2, include_class_list=list(range(527)))
for segment in audio_tag_result:
  print(segment)

{'time': {'start': 0, 'end': 10}, 'audio tags': [('Music', 1.8221180438995361)]}
{'time': {'start': 10, 'end': 20}, 'audio tags': [('Music', 1.3558011054992676)]}
{'time': {'start': 20, 'end': 30}, 'audio tags': [('Music', 0.8049014806747437)]}
{'time': {'start': 30, 'end': 40}, 'audio tags': [('Music', 1.323777437210083)]}
{'time': {'start': 40, 'end': 50}, 'audio tags': [('Music', 0.7121016383171082)]}
{'time': {'start': 50, 'end': 60}, 'audio tags': [('Music', -1.4463484287261963)]}
{'time': {'start': 60, 'end': 70}, 'audio tags': [('Music', 1.2629746198654175)]}
{'time': {'start': 70, 'end': 80}, 'audio tags': [('Music', 0.7769227623939514)]}
{'time': {'start': 80, 'end': 90}, 'audio tags': [('Speech', -1.0284314155578613)]}
{'time': {'start': 90, 'end': 100}, 'audio tags': [('Silence', 0.9160257577896118)]}
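
The selection rule described above can be sketched in plain Python. This is not the library's implementation: `select_tags` is a hypothetical helper, the toy logits are made up for illustration, and whether the real comparison is strict (`>`) or inclusive (`>=`) is an assumption.

```python
def select_tags(logits, label_names, top_k=5, p_threshold=-2.0):
    """Keep at most top_k (label, logit) pairs whose logit exceeds p_threshold, highest first."""
    ranked = sorted(zip(label_names, logits), key=lambda pair: pair[1], reverse=True)
    # Assumption: a strict comparison; the library may treat the boundary differently.
    return [(name, score) for name, score in ranked[:top_k] if score > p_threshold]

toy_labels = ["Speech", "Music", "Silence"]
toy_logits = [0.9, 1.8, -3.0]

print(select_tags(toy_logits, toy_labels, top_k=5, p_threshold=-2))  # [('Music', 1.8), ('Speech', 0.9)]
print(select_tags(toy_logits, toy_labels, top_k=1, p_threshold=-2))  # [('Music', 1.8)]
```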


Setting a larger `top_k` and a smaller `p_threshold` makes the model more verbose, and vice versa.

In [9]:
audio_tag_result = whisper.parse_at_label(result, language='follow_asr', top_k=10, p_threshold=-5, include_class_list=[47])
for segment in audio_tag_result:
  print(segment)

{'time': {'start': 0, 'end': 10}, 'audio tags': []}
{'time': {'start': 10, 'end': 20}, 'audio tags': []}
{'time': {'start': 20, 'end': 30}, 'audio tags': []}
{'time': {'start': 30, 'end': 40}, 'audio tags': []}
{'time': {'start': 40, 'end': 50}, 'audio tags': []}
{'time': {'start': 50, 'end': 60}, 'audio tags': []}
{'time': {'start': 60, 'end': 70}, 'audio tags': []}
{'time': {'start': 70, 'end': 80}, 'audio tags': []}
{'time': {'start': 80, 'end': 90}, 'audio tags': []}
{'time': {'start': 90, 'end': 100}, 'audio tags': []}


Second, you can select the classes of interest by passing a list of class indices to `include_class_list`. To get the name-to-index mapping, simply let Whisper-AT print it for you.

In [10]:
whisper.print_label_name(language='en')

index: 0 : Speech
index: 1 : Male speech, man speaking
index: 2 : Female speech, woman speaking
index: 3 : Child speech, kid speaking
index: 4 : Conversation
index: 5 : Narration, monologue
index: 6 : Babbling
index: 7 : Speech synthesizer
index: 8 : Shout
index: 9 : Bellow
index: 10 : Whoop
index: 11 : Yell
index: 12 : Battle cry
index: 13 : Children shouting
index: 14 : Screaming
index: 15 : Whispering
index: 16 : Laughter
index: 17 : Baby laughter
index: 18 : Giggle
index: 19 : Snicker
index: 20 : Belly laugh
index: 21 : Chuckle, chortle
index: 22 : Crying, sobbing
index: 23 : Baby cry, infant cry
index: 24 : Whimper
index: 25 : Wail, moan
index: 26 : Sigh
index: 27 : Singing
index: 28 : Choir
index: 29 : Yodeling
index: 30 : Chant
index: 31 : Mantra
index: 32 : Male singing
index: 33 : Female singing
index: 34 : Child singing
index: 35 : Synthetic singing
index: 36 : Rapping
index: 37 : Humming
index: 38 : Groan
index: 39 : Grunt
index: 40 : Whistling
index: 41 : Breathing
index: 4
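
If you prefer to look up indices by name instead of reading the printed list, a small dictionary is enough. In the sketch below, `label_names` is a hypothetical stand-in that holds only the first three of the 527 labels printed above; how to obtain the full list programmatically is not covered in this tutorial.

```python
# Hypothetical stand-in: only the first three labels from the printed list.
label_names = ["Speech", "Male speech, man speaking", "Female speech, woman speaking"]

# Map each label name to its class index.
name_to_index = {name: index for index, name in enumerate(label_names)}

wanted = ["Speech", "Female speech, woman speaking"]
include_class_list = [name_to_index[name] for name in wanted]
print(include_class_list)  # [0, 2]
```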

Suppose we are only interested in classes 0, 1, and 2 (speech). We can have Whisper-AT output only these classes.

In [11]:
audio_tag_result = whisper.parse_at_label(result, language='follow_asr', top_k=10, p_threshold=-5, include_class_list=[0, 1, 2])
for segment in audio_tag_result:
  print(segment)

{'time': {'start': 0, 'end': 10}, 'audio tags': [('Speech', 0.932380735874176), ('Male speech, man speaking', -3.2922842502593994)]}
{'time': {'start': 10, 'end': 20}, 'audio tags': []}
{'time': {'start': 20, 'end': 30}, 'audio tags': []}
{'time': {'start': 30, 'end': 40}, 'audio tags': []}
{'time': {'start': 40, 'end': 50}, 'audio tags': []}
{'time': {'start': 50, 'end': 60}, 'audio tags': []}
{'time': {'start': 60, 'end': 70}, 'audio tags': []}
{'time': {'start': 70, 'end': 80}, 'audio tags': []}
{'time': {'start': 80, 'end': 90}, 'audio tags': [('Speech', -1.0284314155578613)]}
{'time': {'start': 90, 'end': 100}, 'audio tags': []}
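
The parsed output is a plain list of dictionaries, so it is easy to post-process. As a minimal sketch, the hypothetical helper below collects the time spans in which a given tag appears; the `sample` list is an abbreviated copy of the output above.

```python
def segments_with_tag(audio_tag_result, tag):
    """Return (start, end) pairs for every segment whose tag list contains `tag`."""
    spans = []
    for segment in audio_tag_result:
        names = [name for name, _ in segment['audio tags']]
        if tag in names:
            spans.append((segment['time']['start'], segment['time']['end']))
    return spans

# Abbreviated copy of the parsed output shown above.
sample = [
    {'time': {'start': 0, 'end': 10}, 'audio tags': [('Speech', 0.932380735874176), ('Male speech, man speaking', -3.2922842502593994)]},
    {'time': {'start': 10, 'end': 20}, 'audio tags': []},
    {'time': {'start': 80, 'end': 90}, 'audio tags': [('Speech', -1.0284314155578613)]},
]

print(segments_with_tag(sample, 'Speech'))  # [(0, 10), (80, 90)]
```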


Finally, Whisper-AT supports multiple languages. By default, audio tag label names are output in the same language as the ASR transcript (i.e., `language='follow_asr'`), but you can specify any supported language. List the supported languages with:

In [12]:
whisper.print_support_language()

language code: en : english
language code: zh : chinese
language code: de : german
language code: es : spanish
language code: ru : russian
language code: ko : korean
language code: fr : french
language code: ja : japanese
language code: pt : portuguese
language code: tr : turkish
language code: pl : polish
language code: ca : catalan
language code: nl : dutch
language code: ar : arabic
language code: sv : swedish
language code: it : italian
language code: id : indonesian
language code: hi : hindi
language code: fi : finnish
language code: vi : vietnamese
language code: he : hebrew
language code: uk : ukrainian
language code: el : greek
language code: ms : malay
language code: cs : czech
language code: ro : romanian
language code: da : danish
language code: hu : hungarian
language code: ta : tamil
language code: no : norwegian
language code: th : thai
language code: ur : urdu
language code: hr : croatian
language code: bg : bulgarian
language code: lt : lithuanian
language code: mi : ma

Let's say we want the output labels in Chinese (zh):

In [13]:
audio_tag_result = whisper.parse_at_label(result, language='zh', top_k=5, p_threshold=-2, include_class_list=list(range(527)))
for segment in audio_tag_result:
  print(segment)

{'time': {'start': 0, 'end': 10}, 'audio tags': [('音乐之声', 1.8221180438995361), ('说话的声音', 0.932380735874176)]}
{'time': {'start': 10, 'end': 20}, 'audio tags': [('音乐之声', 1.3558011054992676), ('垃圾的声音', -1.350265622138977), ('进步摇滚的声音', -1.424497127532959), ('朋克摇滚之声', -1.5711394548416138)]}
{'time': {'start': 20, 'end': 30}, 'audio tags': [('音乐之声', 0.8049014806747437)]}
{'time': {'start': 30, 'end': 40}, 'audio tags': [('音乐之声', 1.323777437210083), ('垃圾的声音', -1.3895658254623413), ('进步摇滚的声音', -1.614871859550476), ('朋克摇滚之声', -1.6720657348632812), ('独立音乐之声', -1.8473705053329468)]}
{'time': {'start': 40, 'end': 50}, 'audio tags': [('音乐之声', 0.7121016383171082), ('歌声', -1.492790699005127)]}
{'time': {'start': 50, 'end': 60}, 'audio tags': [('音乐之声', -1.4463484287261963)]}
{'time': {'start': 60, 'end': 70}, 'audio tags': [('音乐之声', 1.2629746198654175), ('垃圾的声音', -0.9662912487983704), ('流行音乐的声音', -1.6864861249923706), ('独立音乐之声', -1.7853952646255493)]}
{'time': {'start': 70, 'end': 80}, 'audio tags': 

### Step 4. Dubbing a Video

Let's check the result! The audio track above actually comes from a video. You could of course generate an .srt file, but in this example we overlay the ASR transcript and the audio tags directly on the video.
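
For readers who do want a subtitle file, the sketch below formats ASR segments as SubRip (.srt) text. It assumes each segment is a dictionary with `start`, `end`, and `text` keys, as printed in Step 2; `srt_timestamp` and `segments_to_srt` are illustrative helpers, not part of Whisper-AT.

```python
def srt_timestamp(seconds):
    """Format a time in seconds as HH:MM:SS,mmm, the timestamp style used by .srt files."""
    millis = int(round(seconds * 1000))
    hours, millis = divmod(millis, 3_600_000)
    minutes, millis = divmod(millis, 60_000)
    secs, millis = divmod(millis, 1000)
    return f"{hours:02d}:{minutes:02d}:{secs:02d},{millis:03d}"

def segments_to_srt(segments):
    """Build .srt text: a counter, a time range, and the caption text for each segment."""
    blocks = []
    for index, seg in enumerate(segments, start=1):
        time_range = f"{srt_timestamp(seg['start'])} --> {srt_timestamp(seg['end'])}"
        blocks.append(f"{index}\n{time_range}\n{seg['text'].strip()}\n")
    return "\n".join(blocks)

# Toy segments in the same shape as result['segments'].
toy_segments = [
    {'start': 0.0, 'end': 5.0, 'text': " My name's Jimmy Kudo."},
    {'start': 5.0, 'end': 10.0, 'text': " Over the years I've worked at perfecting the craft."},
]
print(segments_to_srt(toy_segments))
```

Writing the returned string to a file with an `.srt` extension gives a subtitle track that most video players can load alongside the video.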

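**Step 4 does not depend on the sample audio used above, so replace the URL and play with your own video!** It still reuses `wget`, `whisper`, `model`, and `audio_tagging_time_resolution` from the earlier cells.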

In [14]:
from IPython.display import HTML
from base64 import b64encode
# Replace this URL to play with your own video
wget.download('https://www.dropbox.com/s/pzc72c59xtluuc0/case_closed.mp4?dl=1', '/content/sample_video.mp4')
# mp4 = open('/content/sample_video.mp4','rb').read()
# data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
# HTML("""
# <video width=800 controls>
#       <source src="%s" type="video/mp4">
# </video>
# """ % data_url)

'/content/sample_video.mp4'

In [15]:
!pip install -q ffmpeg-python
import os,ffmpeg,cv2

def dubbing_video(video_path, out_video_path, text_info, font_size=0.5, font_v_pos=0.95, font_color=(0, 0, 255)):
    extract_audio(video_path, './temp_audio.wav')

    video = cv2.VideoCapture(video_path)
    # Get video properties
    frame_width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = video.get(cv2.CAP_PROP_FPS)

    # Create output video writer
    output_video = cv2.VideoWriter('./temp_video.mp4', cv2.VideoWriter_fourcc(*"mp4v"), fps,
                                   (frame_width, frame_height))

    # Process each frame of the video
    current_frame = 0
    while video.isOpened():
        ret, frame = video.read()
        if not ret:
            break

        # Calculate current time in seconds
        current_time = current_frame / fps

        # Iterate through text information and add text if within the time interval
        for text_start, text_end, text in text_info:
            if text_start <= current_time <= text_end:
                # Anchor the text at the left edge, font_v_pos of the way down the frame
                text_position = (int(frame_width * 0.0), int(frame_height * font_v_pos))
                font = cv2.FONT_HERSHEY_SIMPLEX
                thickness = 1  # stroke thickness in pixels (the 7th argument of cv2.putText)

                cv2.putText(frame, text, text_position, font, font_size, font_color, thickness, cv2.LINE_AA)

        # Write the frame to the output video
        output_video.write(frame)
        current_frame += 1

    # Release video resources
    video.release()
    output_video.release()

    combine_audio_video('./temp_video.mp4', './temp_audio.wav', out_video_path)
    os.remove('./temp_video.mp4')
    os.remove('./temp_audio.wav')

def combine_audio_video(video_path, audio_path, output_path):
    video = ffmpeg.input(video_path)
    audio = ffmpeg.input(audio_path)
    output_file = ffmpeg.output(video, audio, output_path)
    output_file.overwrite_output().run()

def extract_audio(video_path, output_path):
    video = ffmpeg.input(video_path)
    audio = video.audio
    output_file = ffmpeg.output(audio, output_path)
    output_file.overwrite_output().run()

extract_audio('/content/sample_video.mp4', '/content/sample_audio_from_video.wav')
result = model.transcribe("/content/sample_audio_from_video.wav", at_time_res=audio_tagging_time_resolution)

# ASR Output
text_segments = result['segments']
text_annotation = [(x['start'], x['end'], x['text']) for x in text_segments]

# Audio Tagging Output
audio_tag_result = whisper.parse_at_label(result, language='follow_asr', top_k=5, p_threshold=-2, include_class_list=list(range(527)))

all_seg = []
for segment in audio_tag_result:
    cur_start = segment['time']['start']
    cur_end = segment['time']['end']
    cur_tags = segment['audio tags']
    cur_tags = [x[0] for x in cur_tags]
    cur_tags = '; '.join(cur_tags)
    all_seg.append((cur_start, cur_end, cur_tags))

dubbing_video('/content/sample_video.mp4', '/content/sample_video_at.mp4', all_seg)
dubbing_video('/content/sample_video_at.mp4', '/content/sample_video_at_text.mp4', text_annotation, font_color=(0,255,0), font_v_pos=0.85)
os.remove('/content/sample_video.mp4')
os.remove('/content/sample_video_at.mp4')

# mp4 = open('/content/sample_video_at_text.mp4','rb').read()
# data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
# HTML("""
# <video width=800 controls>
#       <source src="%s" type="video/mp4">
# </video>
# """ % data_url)


That's all. If you like the project, give us a star at https://github.com/YuanGongND/whisper-at.