Reference

- [AudioSet](https://research.google.com/audioset/)

More model ideas

- Other (smaller) versions of AST

- https://huggingface.co/topel/ConvNeXt-Tiny-AT

- https://huggingface.co/search/full-text?q=audioset&p=1&type=model

- https://paperswithcode.com/paper/efficient-large-scale-audio-tagging-via

- https://paperswithcode.com/paper/dynamic-convolutional-neural-networks-as

- https://paperswithcode.com/paper/panns-large-scale-pretrained-audio-neural-1

# Imports, installs, etc.

In [1]:
!pip install -qq transformers

In [2]:
import requests
import sys
import time

import numpy as np

from tqdm.notebook import tqdm

import torch
import torchaudio

In [3]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(DEVICE)

cpu


In [4]:
sys.path.append('./drive/MyDrive/Projects/MiniSoundFinder_v2/model/library/')

import event_finder, preprocess

In [5]:
# from importlib import reload
# reload(event_finder)

# Samples

In [6]:
!cp ./drive/MyDrive/Projects/MiniSoundFinder_v2/samples/* .

In [7]:
sample_path = 'freesound_442485_dogs_barking_60sec.wav'
print(torchaudio.info(sample_path))

AudioMetaData(sample_rate=48000, num_frames=2847537, num_channels=2, bits_per_sample=24, encoding=PCM_S)


In [8]:
waveform, sampling_rate = torchaudio.load(sample_path)
waveform.shape

torch.Size([2, 2847537])

In [9]:
wf_prep = preprocess.convert_audio(waveform, sampling_rate)
wf_prep.shape

torch.Size([949179])

# Models

## AST

- [AST on HuggingFace](https://huggingface.co/MIT/ast-finetuned-audioset-10-10-0.4593)

- [AST Paper](https://arxiv.org/pdf/2104.01778.pdf)

In [10]:
from transformers import AutoFeatureExtractor, ASTForAudioClassification

In [11]:
extractor_ast = AutoFeatureExtractor.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593")
extractor_ast

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


ASTFeatureExtractor {
  "do_normalize": true,
  "feature_extractor_type": "ASTFeatureExtractor",
  "feature_size": 1,
  "max_length": 1024,
  "mean": -4.2677393,
  "num_mel_bins": 128,
  "padding_side": "right",
  "padding_value": 0.0,
  "return_attention_mask": false,
  "sampling_rate": 16000,
  "std": 4.5689974
}

In [12]:
model_ast = ASTForAudioClassification.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593").to(DEVICE)
model_ast

ASTForAudioClassification(
  (audio_spectrogram_transformer): ASTModel(
    (embeddings): ASTEmbeddings(
      (patch_embeddings): ASTPatchEmbeddings(
        (projection): Conv2d(1, 768, kernel_size=(16, 16), stride=(10, 10))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ASTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ASTLayer(
          (attention): ASTSdpaAttention(
            (attention): ASTSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ASTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ASTIntermediate(
       

In [13]:
target_sampling_rate = extractor_ast.sampling_rate
wf_prep = preprocess.convert_audio(waveform, sampling_rate, sampling_rate=target_sampling_rate)
sample_features = extractor_ast(wf_prep, target_sampling_rate, return_tensors="pt").to(DEVICE)
sample_features['input_values'].shape

torch.Size([1, 1024, 128])

In [14]:
with torch.no_grad():
    probs = torch.sigmoid(model_ast(**sample_features).logits)

In [15]:
top_classes = torch.argsort(probs, dim=-1, descending=True).flatten()[:10]
top_labels = [(model_ast.config.id2label[id.item()], probs[0, id].item()) for id in top_classes]
top_labels

[('Dog', 0.7794262766838074),
 ('Animal', 0.741685152053833),
 ('Domestic animals, pets', 0.6783856153488159),
 ('Bark', 0.6013586521148682),
 ('Bow-wow', 0.43776294589042664),
 ('Canidae, dogs, wolves', 0.19103160500526428),
 ('Yip', 0.10324634611606598),
 ('Whimper (dog)', 0.06323525309562683),
 ('Vehicle', 0.025787796825170517),
 ('Growling', 0.01947779580950737)]

In [16]:
def measure_inference_time_ast(model, feature_extractor,
                               sample_length_sec=60,
                               repeats=10,
                               chunk_length_sec=10):

    sampling_rate = feature_extractor.sampling_rate
    sample_length = sampling_rate * sample_length_sec

    extr_times = []
    inf_times = []
    for i in tqdm(range(repeats)):
        extr_start = time.time()
        wf = torch.distributions.uniform.Uniform(-10000, 10000).sample((sample_length,))
        chunks = event_finder.chunk_audio(wf, sampling_rate, chunk_length_sec)
        inp = feature_extractor(chunks, sampling_rate, return_tensors="pt").to(DEVICE)
        extr_times.append(time.time() - extr_start)

        inf_start = time.time()
        with torch.no_grad():
            probs = torch.sigmoid(model(**inp).logits)
        inf_times.append(time.time() - inf_start)

    print("Extraction:", np.mean(extr_times), "±", np.std(extr_times))
    print("Inference:", np.mean(inf_times), "±", np.std(inf_times))

# print("1 minute")
# measure_inference_time_ast(model_ast, extractor_ast, sample_length_sec=60)
# print()

# print("2 minutes")
# measure_inference_time_ast(model_ast, extractor_ast, sample_length_sec=120)
# print()

# print("5 minutes")
# measure_inference_time_ast(model_ast, extractor_ast, sample_length_sec=300)
# print()

## AST Distilled

https://huggingface.co/bookbot/distil-ast-audioset

In [17]:
extractor_ast_distil = AutoFeatureExtractor.from_pretrained("bookbot/distil-ast-audioset")
extractor_ast_distil

ASTFeatureExtractor {
  "do_normalize": true,
  "feature_extractor_type": "ASTFeatureExtractor",
  "feature_size": 1,
  "max_length": 1024,
  "mean": -4.2677393,
  "num_mel_bins": 128,
  "padding_side": "right",
  "padding_value": 0.0,
  "return_attention_mask": false,
  "sampling_rate": 16000,
  "std": 4.5689974
}

In [18]:
model_ast_distil = ASTForAudioClassification.from_pretrained("bookbot/distil-ast-audioset").to(DEVICE)
model_ast_distil

ASTForAudioClassification(
  (audio_spectrogram_transformer): ASTModel(
    (embeddings): ASTEmbeddings(
      (patch_embeddings): ASTPatchEmbeddings(
        (projection): Conv2d(1, 768, kernel_size=(16, 16), stride=(10, 10))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ASTEncoder(
      (layer): ModuleList(
        (0-5): 6 x ASTLayer(
          (attention): ASTSdpaAttention(
            (attention): ASTSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ASTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ASTIntermediate(
         

In [19]:
target_sampling_rate = extractor_ast_distil.sampling_rate
wf_prep = preprocess.convert_audio(waveform, sampling_rate, sampling_rate=target_sampling_rate)
sample_features = extractor_ast_distil(wf_prep, target_sampling_rate, return_tensors="pt").to(DEVICE)

with torch.no_grad():
    probs = torch.sigmoid(model_ast_distil(**sample_features).logits)

top_classes = torch.argsort(probs, dim=-1, descending=True).flatten()[:10]
top_labels = [(model_ast.config.id2label[id.item()], probs[0, id].item()) for id in top_classes]
top_labels

[('Animal', 0.8284608125686646),
 ('Dog', 0.7962832450866699),
 ('Domestic animals, pets', 0.7366082072257996),
 ('Bark', 0.5144532322883606),
 ('Bow-wow', 0.4676070213317871),
 ('Speech', 0.3647533059120178),
 ('Canidae, dogs, wolves', 0.1839541494846344),
 ('Yip', 0.16342568397521973),
 ('Whimper (dog)', 0.1528146117925644),
 ('Growling', 0.05686212331056595)]

In [20]:
# print("1 minute")
# measure_inference_time_ast(model_ast_distil, extractor_ast_distil, sample_length_sec=60)
# print()

# print("2 minutes")
# measure_inference_time_ast(model_ast_distil, extractor_ast_distil, sample_length_sec=120)
# print()

# print("5 minutes")
# measure_inference_time_ast(model_ast_distil, extractor_ast_distil, sample_length_sec=300)
# print()

# Finding events

In [21]:
# NOTE: library usage here is outdated

model = model_ast
extractor = extractor_ast
chunk_length_sec = 10
sampling_rate = extractor.sampling_rate

finder = event_finder.EventFinder(
    model.config,
    chunk_length_sec=chunk_length_sec,
    probability_threshold=0.2)

def find_events(audio_path):
    waveform_raw, source_sampling_rate = torchaudio.load(audio_path)
    waveform = preprocess.convert_audio(waveform_raw, source_sampling_rate,
                                        channels="mono", sampling_rate=sampling_rate)
    chunks = event_finder.chunk_audio(waveform, sampling_rate, chunk_length_sec=chunk_length_sec)

    features = extractor(chunks, sampling_rate, return_tensors="pt").to(DEVICE)

    with torch.no_grad():
        probs = torch.sigmoid(model(**features).logits)

    events = finder(probs)
    return probs, events

In [22]:
# probs, events = find_events('/content/freesound_442485_dogs_barking_60sec.wav')
# events

In [23]:
# probs, events = find_events('/content/freesound_471408_birds_90sec.wav')
# events

In [24]:
probs, events = find_events('/content/recorded_street_150sec.wav')
events

[('Vehicle', 10, 20),
 ('Speech', 10, 80),
 ('Music', 20, 30),
 ('Vehicle', 40, 90),
 ('Music', 80, 100),
 ('Speech', 90, 150)]

In [25]:
# probs, top_labels = find_events('/content/freesound_442485_dogs_barking_60sec.wav')
# top_classes = torch.argsort(probs, dim=-1, descending=True)[:, :5]
# for i in range(top_classes.shape[0]):
#     print([(model.config.id2label[id.item()], id.item(), probs[i, id].item()) for id in top_classes[i]])

# AudioSet Antology

In [26]:
!curl https://raw.githubusercontent.com/audioset/ontology/master/ontology.json > ontology.json

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  334k  100  334k    0     0  1074k      0 --:--:-- --:--:-- --:--:-- 1076k


In [27]:
ROOT_NAME = '<root>'

def build_ontology_tree(ontology_data):
    id_to_obj = {}
    for el in ontology_data:
        id_to_obj[el['id']] = el

    name_to_node = {}
    for el in ontology_data:
        name_to_node[el['name']] = {
            'children': [id_to_obj[id]['name'] for id in el['child_ids']],
            'abstract': 'abstract' in el['restrictions'],
            'blacklist': 'blacklist' in el['restrictions'],
        }

    top_level = set(name_to_node.keys())
    for node in name_to_node.values():
        top_level -= set(node['children'])

    name_to_node['<root>'] = {
        'children': list(top_level),
        'abstract': True,
        'blacklist': False,
    }

    return name_to_node

In [28]:
def remove_recursive(nodes, parent_name, name):
    for child_name in nodes[name]['children']:
        remove_recursive(nodes, name, child_name)
    nodes[parent_name]['children'].remove(name)
    del nodes[name]

def manual_fix_nodes(nodes):
    nodes['Sounds of things']['abstract'] = True
    nodes['Sounds of things']['blacklist'] = False

    nodes['Natural sounds']['abstract'] = True
    nodes['Natural sounds']['blacklist'] = False

    nodes['Microphone']['abstract'] = True
    nodes['Microphone']['blacklist'] = False

    nodes['Domestic sounds, home sounds']['abstract'] = True
    nodes['Domestic sounds, home sounds']['blacklist'] = False

    nodes['Non-motorized land vehicle']['abstract'] = True

In [29]:
import json

with open('ontology.json') as f:
    ontology_raw = json.load(f)

nodes = build_ontology_tree(ontology_raw)

manual_fix_nodes(nodes)

remove_recursive(nodes, ROOT_NAME, 'Channel, environment and background')

len(nodes)

621

In [30]:
model_classes = set(model.config.label2id.keys())
len(model_classes)

527

In [31]:
def add_not_in_model_property(nodes, model_classes):
    for name, node in nodes.items():
        if name in model_classes:
            node['not_in_model'] = False
        else:
            node['not_in_model'] = True

add_not_in_model_property(nodes, model_classes)

In [32]:
def check_blacklist_subtrees(nodes):
    def check_recursive(name):
        node = nodes[name]
        all_children = True
        for child_name in node['children']:
            if not check_recursive(child_name):
                all_children = False

        if node['blacklist'] and all_children:
            node['blacklist_subtree'] = True
        else:
            if node['blacklist'] and node['children'] and not all_children:
                print("Warning: '" + name + "' - blacklist but not blacklist subtree")
            node['blacklist_subtree'] = False
        return node['blacklist_subtree']

    check_recursive(ROOT_NAME)

check_blacklist_subtrees(nodes)

In [33]:
def check_not_in_model_subtrees(nodes):
    def check_recursive(name):
        node = nodes[name]
        all_children = True
        for child_name in node['children']:
            if not check_recursive(child_name):
                all_children = False

        if node['not_in_model'] and all_children:
            node['not_in_model_subtree'] = True
        else:
            if node['not_in_model'] and not node['abstract'] and node['children'] and not all_children:
                print("Warning: '" + name + "' - non-abstract not in model with children in model")
            node['not_in_model_subtree'] = False
        return node['not_in_model_subtree']

    check_recursive(ROOT_NAME)

check_not_in_model_subtrees(nodes)

In [35]:
def remove_by_property(nodes, property_name):
    def remove_by_property_recursive(name):
        for child_name in list(nodes[name]['children']):  # copy for correctness
            if child_name not in nodes:  # duplicate reference, already deleted
                nodes[name]['children'].remove(child_name)
            else:
                if nodes[child_name][property_name]:
                    remove_recursive(nodes, name, child_name)
                else:
                    remove_by_property_recursive(child_name)

    remove_by_property_recursive(ROOT_NAME)

remove_by_property(nodes, 'blacklist_subtree')
remove_by_property(nodes, 'not_in_model_subtree')

In [36]:
len(nodes)

547

In [44]:
def pretty_print_tree(nodes, label2id, probs, prob_cutoff=0.0):
    def pp_recursive(name, indent):
        if not nodes[name]['abstract'] and probs[label2id[name]].item() < prob_cutoff:
            return

        if nodes[name]['abstract']:
            prob_str = " (abstract)"
        else:
            prob_str = " ({:.3%})".format(probs[label2id[name]].item())

        print(indent + name + prob_str)
        for child_name in nodes[name]['children']:
            pp_recursive(child_name, indent + '- ')

    for child_name in nodes[ROOT_NAME]['children']:
        pp_recursive(child_name, indent='')

pretty_print_tree(nodes, model.config.label2id, probs[1], prob_cutoff=0.01)

Sounds of things (abstract)
- Vehicle (40.364%)
- - Motor vehicle (road) (15.284%)
- - - Car (24.684%)
- - - - Car passing by (3.041%)
- - - - Race car, auto racing (1.199%)
- - - Truck (1.410%)
- - - - Air brake (1.177%)
- - - Bus (15.965%)
- - - Traffic noise, roadway noise (48.065%)
- - Non-motorized land vehicle (abstract)
- Domestic sounds, home sounds (abstract)
- Miscellaneous sources (abstract)
- - Sound equipment (abstract)
- - - Microphone (abstract)
- Specific impact sounds (abstract)
Human sounds (abstract)
- Human voice (abstract)
- - Speech (38.632%)
- Respiratory sounds (abstract)
- Human locomotion (abstract)
- Digestive (abstract)
- Human group actions (abstract)
Natural sounds (abstract)
Source-ambiguous sounds (abstract)
- Generic impact sounds (abstract)
- Surface contact (abstract)
- Deformable shell (abstract)
- Onomatopoeia (abstract)
- - Brief tone (abstract)
- Other sourceless (abstract)
Animal (1.715%)
Music (3.624%)
- Music genre (abstract)
- Music role (abst