Reference

- [AudioSet](https://research.google.com/audioset/)

More model ideas

- Other (smaller) versions of AST

- https://huggingface.co/topel/ConvNeXt-Tiny-AT

- https://huggingface.co/search/full-text?q=audioset&p=1&type=model

- https://paperswithcode.com/paper/efficient-large-scale-audio-tagging-via

- https://paperswithcode.com/paper/dynamic-convolutional-neural-networks-as

- https://paperswithcode.com/paper/panns-large-scale-pretrained-audio-neural-1

# Imports, installs, etc.

In [None]:
!pip install -qq transformers

In [None]:
import requests
import sys
import time

import numpy as np

from tqdm.notebook import tqdm

import torch
import torchaudio

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(DEVICE)

In [None]:
# Put the project's library directory on the import path so the local
# `event_finder` and `preprocess` modules can be imported.
sys.path.append('./drive/MyDrive/Projects/MiniSoundFinder_v2/model/library/')

import event_finder, preprocess

In [None]:
# from importlib import reload
# reload(event_finder)

# Samples

In [None]:
!cp ./drive/MyDrive/Projects/MiniSoundFinder_v2/samples/* .

In [None]:
# Sample clip used by the single-clip checks; print its file metadata.
sample_path = 'freesound_442485_dogs_barking_60sec.wav'
print(torchaudio.info(sample_path))

In [None]:
# Load the clip: `sampling_rate` here is the source file's own rate.
waveform, sampling_rate = torchaudio.load(sample_path)
waveform.shape

In [None]:
# Convert the raw waveform with the project's helper, using its default settings.
wf_prep = preprocess.convert_audio(waveform, sampling_rate)
wf_prep.shape

# Models

## AST

- [AST on HuggingFace](https://huggingface.co/MIT/ast-finetuned-audioset-10-10-0.4593)

- [AST Paper](https://arxiv.org/pdf/2104.01778.pdf)

In [None]:
from transformers import AutoFeatureExtractor, ASTForAudioClassification

In [None]:
# Feature extractor matching the AST AudioSet checkpoint.
extractor_ast = AutoFeatureExtractor.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593")
extractor_ast

In [None]:
# Load the AST classifier and move it to the selected device.
model_ast = ASTForAudioClassification.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593").to(DEVICE)
model_ast

In [None]:
# Convert the sample to the sampling rate the AST extractor reports, then
# build the model inputs on the selected device.
target_sampling_rate = extractor_ast.sampling_rate
wf_prep = preprocess.convert_audio(waveform, sampling_rate, sampling_rate=target_sampling_rate)
sample_features = extractor_ast(wf_prep, target_sampling_rate, return_tensors="pt").to(DEVICE)
sample_features['input_values'].shape

In [None]:
# Inference only, so gradients are disabled. The element-wise sigmoid turns
# each logit into an independent per-class score in [0, 1].
with torch.no_grad():
    probs = torch.sigmoid(model_ast(**sample_features).logits)

In [None]:
# Ten highest-scoring class ids, paired with their label name and score.
top_classes = torch.argsort(probs, dim=-1, descending=True).flatten()[:10]
# probs[0, id]: row 0 is used for the score lookup.
top_labels = [(model_ast.config.id2label[id.item()], probs[0, id].item()) for id in top_classes]
top_labels

In [None]:
def measure_inference_time_ast(model, feature_extractor,
                               sample_length_sec=60,
                               repeats=10,
                               chunk_length_sec=10):
    """Benchmark feature extraction and model inference on synthetic audio.

    For every repeat a random waveform of ``sample_length_sec`` seconds is
    generated, split with ``event_finder.chunk_audio``, passed through
    ``feature_extractor`` and then through ``model``. The mean and standard
    deviation of both stages are printed; nothing is returned.

    Args:
        model: Model called with the extractor output as keyword arguments;
            its result must expose a ``logits`` attribute.
        feature_extractor: Feature extractor exposing ``sampling_rate``.
        sample_length_sec: Length of the synthetic clip, in seconds.
        repeats: Number of timed iterations.
        chunk_length_sec: Chunk length handed to ``event_finder.chunk_audio``.
    """
    def _sync():
        # CUDA work is queued asynchronously. Without waiting for it, the
        # clock would stop before the GPU has finished and the remaining
        # cost would be attributed to whatever stage runs next.
        if DEVICE.type == "cuda":
            torch.cuda.synchronize()

    sampling_rate = feature_extractor.sampling_rate
    sample_length = sampling_rate * sample_length_sec
    noise = torch.distributions.uniform.Uniform(-10000, 10000)

    extr_times = []
    inf_times = []
    for _ in tqdm(range(repeats)):
        # Generate the synthetic clip before starting the clock so that data
        # generation is not counted as feature extraction.
        wf = noise.sample((sample_length,))

        # perf_counter is monotonic and high-resolution, unlike time.time().
        extr_start = time.perf_counter()
        chunks = event_finder.chunk_audio(wf, sampling_rate, chunk_length_sec)
        inp = feature_extractor(chunks, sampling_rate, return_tensors="pt").to(DEVICE)
        _sync()
        extr_times.append(time.perf_counter() - extr_start)

        inf_start = time.perf_counter()
        with torch.no_grad():
            # The result is discarded; only the elapsed time matters here.
            torch.sigmoid(model(**inp).logits)
        _sync()
        inf_times.append(time.perf_counter() - inf_start)

    print("Extraction:", np.mean(extr_times), "±", np.std(extr_times))
    print("Inference:", np.mean(inf_times), "±", np.std(inf_times))

# print("1 minute")
# measure_inference_time_ast(model_ast, extractor_ast, sample_length_sec=60)
# print()

# print("2 minutes")
# measure_inference_time_ast(model_ast, extractor_ast, sample_length_sec=120)
# print()

# print("5 minutes")
# measure_inference_time_ast(model_ast, extractor_ast, sample_length_sec=300)
# print()

## AST Distilled

https://huggingface.co/bookbot/distil-ast-audioset

In [None]:
# Feature extractor shipped with the distilled AST checkpoint.
extractor_ast_distil = AutoFeatureExtractor.from_pretrained("bookbot/distil-ast-audioset")
extractor_ast_distil

In [None]:
# Load the distilled AST classifier and move it to the selected device.
model_ast_distil = ASTForAudioClassification.from_pretrained("bookbot/distil-ast-audioset").to(DEVICE)
model_ast_distil

In [None]:
# Single-clip check with the distilled model: convert the sample to the
# extractor's sampling rate, build the inputs, and score every class.
target_sampling_rate = extractor_ast_distil.sampling_rate
wf_prep = preprocess.convert_audio(waveform, sampling_rate, sampling_rate=target_sampling_rate)
sample_features = extractor_ast_distil(wf_prep, target_sampling_rate, return_tensors="pt").to(DEVICE)

with torch.no_grad():
    probs = torch.sigmoid(model_ast_distil(**sample_features).logits)

top_classes = torch.argsort(probs, dim=-1, descending=True).flatten()[:10]
# Resolve label names through the distilled model's own config: the class ids
# come from its logits, so another model's id2label mapping must not be used.
top_labels = [
    (model_ast_distil.config.id2label[class_id.item()], probs[0, class_id].item())
    for class_id in top_classes
]
top_labels

In [None]:
# print("1 minute")
# measure_inference_time_ast(model_ast_distil, extractor_ast_distil, sample_length_sec=60)
# print()

# print("2 minutes")
# measure_inference_time_ast(model_ast_distil, extractor_ast_distil, sample_length_sec=120)
# print()

# print("5 minutes")
# measure_inference_time_ast(model_ast_distil, extractor_ast_distil, sample_length_sec=300)
# print()

# Finding events

In [None]:
# NOTE: library usage here is outdated

# Notebook-level configuration read by find_events: which model/extractor
# pair to run and how long each chunk is.
model = model_ast
extractor = extractor_ast
chunk_length_sec = 10
# NOTE(review): this rebinds the notebook-global `sampling_rate`, which the
# sample-loading cells use for the source file's rate. Re-running those cells
# after this one would pass the extractor's rate as the source rate.
sampling_rate = extractor.sampling_rate

# Event finder built from the model config; events are derived from the
# per-chunk probabilities using the threshold given here.
finder = event_finder.EventFinder(
    model.config,
    chunk_length_sec=chunk_length_sec,
    probability_threshold=0.2)

def find_events(audio_path):
    """Score an audio file chunk by chunk and run the event finder on it.

    Relies on the notebook-level ``model``, ``extractor``, ``finder``,
    ``sampling_rate``, ``chunk_length_sec`` and ``DEVICE``.

    Args:
        audio_path: Path of the audio file to load with ``torchaudio``.

    Returns:
        A ``(probs, events)`` tuple: the sigmoid of the model logits for the
        chunks, and whatever ``finder`` returns for those probabilities.
    """
    raw_audio, raw_rate = torchaudio.load(audio_path)

    # Request mono audio at the extractor's sampling rate, then split it into
    # fixed-length chunks for batched inference.
    converted = preprocess.convert_audio(
        raw_audio, raw_rate, channels="mono", sampling_rate=sampling_rate)
    audio_chunks = event_finder.chunk_audio(
        converted, sampling_rate, chunk_length_sec=chunk_length_sec)

    model_inputs = extractor(audio_chunks, sampling_rate, return_tensors="pt").to(DEVICE)

    with torch.no_grad():
        logits = model(**model_inputs).logits
        probs = torch.sigmoid(logits)

    return probs, finder(probs)

In [None]:
# probs, events = find_events('/content/freesound_442485_dogs_barking_60sec.wav')
# events

In [None]:
# probs, events = find_events('/content/freesound_471408_birds_90sec.wav')
# events

In [None]:
# The only active example: this call rebinds the notebook-global `probs`
# (and `events`), which the ontology pretty-printer reads further down.
probs, events = find_events('/content/recorded_street_150sec.wav')
events

In [None]:
# probs, top_labels = find_events('/content/freesound_442485_dogs_barking_60sec.wav')
# top_classes = torch.argsort(probs, dim=-1, descending=True)[:, :5]
# for i in range(top_classes.shape[0]):
#     print([(model.config.id2label[id.item()], id.item(), probs[i, id].item()) for id in top_classes[i]])

# AudioSet Ontology

In [None]:
!curl https://raw.githubusercontent.com/audioset/ontology/master/ontology.json > ontology.json

In [None]:
# Key of the synthetic root node that groups all top-level ontology entries.
ROOT_NAME = '<root>'

def build_ontology_tree(ontology_data):
    """Build a name-keyed tree from the raw AudioSet ontology list.

    Args:
        ontology_data: Iterable of dicts with the keys ``'id'``, ``'name'``,
            ``'child_ids'`` and ``'restrictions'``.

    Returns:
        Dict mapping each node name to a dict with ``'children'`` (list of
        child names), ``'abstract'`` and ``'blacklist'`` (booleans taken from
        the entry's restrictions). A synthetic ``ROOT_NAME`` node is added
        whose children are all nodes that are nobody's child, listed in the
        order in which they appear in ``ontology_data``.

    Raises:
        KeyError: If a ``child_ids`` entry refers to an unknown id.
    """
    id_to_name = {entry['id']: entry['name'] for entry in ontology_data}

    name_to_node = {}
    for entry in ontology_data:
        restrictions = entry['restrictions']
        name_to_node[entry['name']] = {
            'children': [id_to_name[child_id] for child_id in entry['child_ids']],
            'abstract': 'abstract' in restrictions,
            'blacklist': 'blacklist' in restrictions,
        }

    # Nodes never referenced as a child are the top-level categories. Going
    # through the dict (not a set difference) keeps their order deterministic,
    # so anything that walks the tree produces the same output on every run.
    child_names = {
        child_name
        for node in name_to_node.values()
        for child_name in node['children']
    }
    top_level = [name for name in name_to_node if name not in child_names]

    name_to_node[ROOT_NAME] = {
        'children': top_level,
        'abstract': True,
        'blacklist': False,
    }

    return name_to_node

In [None]:
def remove_recursive(nodes, parent_name, name):
    """Delete the node ``name`` and its whole subtree from ``nodes`` in place.

    Args:
        nodes: Dict mapping node names to dicts with a ``'children'`` list.
        parent_name: Name of the parent whose ``'children'`` list currently
            contains ``name``; the reference is removed from that list.
        name: Name of the subtree root to delete.

    Raises:
        KeyError: If ``name`` or ``parent_name`` is not in ``nodes``.
        ValueError: If ``name`` is not listed among the parent's children.
    """
    # Iterate over a snapshot: every recursive call removes the child from
    # this very list, and mutating a list while looping over it would skip
    # every other child and leave orphaned entries behind in `nodes`.
    for child_name in list(nodes[name]['children']):
        # A child reachable through several parents may already have been
        # deleted via another branch of this subtree.
        if child_name in nodes:
            remove_recursive(nodes, name, child_name)
    nodes[parent_name]['children'].remove(name)
    del nodes[name]

def manual_fix_nodes(nodes):
    """Apply hand-picked overrides to a few ontology nodes, in place.

    Four category nodes are forced to abstract and non-blacklisted;
    'Non-motorized land vehicle' is only forced to abstract.

    Args:
        nodes: Dict mapping node names to their property dicts.

    Raises:
        KeyError: If one of the hard-coded node names is missing.
    """
    abstract_and_allowed = (
        'Sounds of things',
        'Natural sounds',
        'Microphone',
        'Domestic sounds, home sounds',
    )
    for node_name in abstract_and_allowed:
        entry = nodes[node_name]
        entry['abstract'] = True
        entry['blacklist'] = False

    # Only the abstract flag is overridden here; blacklist is left as loaded.
    nodes['Non-motorized land vehicle']['abstract'] = True

In [None]:
import json

# Parse the ontology file (the curl cell writes it to ontology.json).
with open('ontology.json') as f:
    ontology_raw = json.load(f)

nodes = build_ontology_tree(ontology_raw)

# Apply the hand-made abstract/blacklist overrides.
manual_fix_nodes(nodes)

# Drop this whole branch, which hangs directly off the root.
remove_recursive(nodes, ROOT_NAME, 'Channel, environment and background')

# Node count before any property-based pruning.
len(nodes)

In [None]:
# Label names known to the currently selected `model`; a set, because it is
# only used for membership tests against ontology node names.
model_classes = set(model.config.label2id.keys())
len(model_classes)

In [None]:
def add_not_in_model_property(nodes, model_classes):
    """Flag every node with whether its name is missing from the model labels.

    Sets ``node['not_in_model']`` in place: ``False`` when the node name is
    contained in ``model_classes``, ``True`` otherwise.

    Args:
        nodes: Dict mapping node names to their property dicts.
        model_classes: Container of label names supporting ``in``.
    """
    for label, entry in nodes.items():
        entry['not_in_model'] = label not in model_classes

# Annotate the ontology nodes against the current model's label set.
add_not_in_model_property(nodes, model_classes)

In [None]:
def check_blacklist_subtrees(nodes):
    """Mark nodes whose entire subtree is blacklisted.

    Walks the tree from ``ROOT_NAME`` and sets ``node['blacklist_subtree']``
    on every reachable node: ``True`` only when the node itself is
    blacklisted and all of its children are blacklist subtrees. A warning is
    printed for blacklisted nodes that have children but do not qualify.

    Args:
        nodes: Dict mapping node names to dicts with ``'children'`` and
            ``'blacklist'``; modified in place.
    """
    def visit(name):
        node = nodes[name]

        # A list, not a generator: every child must be visited and annotated
        # even after one of them has already turned out not to qualify.
        child_results = [visit(child_name) for child_name in node['children']]
        children_qualify = all(child_results)

        is_subtree = bool(node['blacklist'] and children_qualify)
        if not is_subtree and node['blacklist'] and node['children'] and not children_qualify:
            print("Warning: '" + name + "' - blacklist but not blacklist subtree")

        node['blacklist_subtree'] = is_subtree
        return is_subtree

    visit(ROOT_NAME)

check_blacklist_subtrees(nodes)

In [None]:
def check_not_in_model_subtrees(nodes):
    """Mark nodes whose entire subtree is unknown to the model.

    Walks the tree from ``ROOT_NAME`` and sets
    ``node['not_in_model_subtree']`` on every reachable node: ``True`` only
    when the node itself is not in the model and all of its children are
    not-in-model subtrees. A warning is printed for non-abstract nodes that
    are missing from the model while some descendant is present.

    Args:
        nodes: Dict mapping node names to dicts with ``'children'``,
            ``'abstract'`` and ``'not_in_model'``; modified in place.
    """
    def visit(name):
        node = nodes[name]

        # A list, not a generator: every child must be visited and annotated
        # even after one of them has already turned out not to qualify.
        child_results = [visit(child_name) for child_name in node['children']]
        children_qualify = all(child_results)

        is_subtree = bool(node['not_in_model'] and children_qualify)
        if (not is_subtree
                and node['not_in_model']
                and not node['abstract']
                and node['children']
                and not children_qualify):
            print("Warning: '" + name + "' - non-abstract not in model with children in model")

        node['not_in_model_subtree'] = is_subtree
        return is_subtree

    visit(ROOT_NAME)

check_not_in_model_subtrees(nodes)

In [None]:
def remove_by_property(nodes, property_name):
    """Prune every subtree whose root has a truthy ``property_name``.

    Starts at ``ROOT_NAME`` and descends only into children that are kept.
    Children flagged with the property are removed together with their
    subtree via ``remove_recursive``.

    Args:
        nodes: Dict mapping node names to their property dicts; modified in
            place.
        property_name: Key of the boolean node property to prune by.
    """
    def prune_children(parent_name):
        # Work on a snapshot because the children list shrinks during the loop.
        for child_name in tuple(nodes[parent_name]['children']):
            if child_name not in nodes:
                # Stale reference: the node was already deleted through
                # another parent, so only the dangling name is dropped.
                nodes[parent_name]['children'].remove(child_name)
                continue

            if nodes[child_name][property_name]:
                remove_recursive(nodes, parent_name, child_name)
                continue

            prune_children(child_name)

    prune_children(ROOT_NAME)

# Prune fully blacklisted subtrees first, then subtrees the model cannot predict.
remove_by_property(nodes, 'blacklist_subtree')
remove_by_property(nodes, 'not_in_model_subtree')

In [None]:
# Node count after pruning, for comparison with the count taken after loading.
len(nodes)

In [None]:
def pretty_print_tree(nodes, label2id, probs, prob_cutoff=0.0):
    """Print the ontology tree with a probability next to each concrete node.

    Abstract nodes are always printed, labelled ``(abstract)``. A
    non-abstract node whose probability is below ``prob_cutoff`` is skipped
    together with all of its descendants.

    Args:
        nodes: Dict mapping node names to dicts with ``'children'`` and
            ``'abstract'``; must contain ``ROOT_NAME``.
        label2id: Mapping from node name to index into ``probs``.
        probs: Indexable whose elements expose ``.item()``.
        prob_cutoff: Minimum probability for a non-abstract node to be shown.
    """
    def render(name, indent):
        node = nodes[name]

        if node['abstract']:
            suffix = " (abstract)"
        else:
            prob = probs[label2id[name]].item()
            if prob < prob_cutoff:
                return
            suffix = " ({:.3%})".format(prob)

        print(indent + name + suffix)
        for child_name in node['children']:
            render(child_name, indent + '- ')

    # The synthetic root itself is not printed, only its children.
    for top_name in nodes[ROOT_NAME]['children']:
        render(top_name, indent='')

# Show the tree for row 1 of `probs`, hiding concrete classes below 1%.
pretty_print_tree(nodes, model.config.label2id, probs[1], prob_cutoff=0.01)