In [1]:
import pandas as pd
import torch
import torch.nn.functional as F
import torchaudio

from hear21passt.base import load_model

In [2]:
# Inference device: the first CUDA GPU when one is visible, the CPU otherwise.
DEVICE = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
# Sample rate every loaded test clip is asserted to have (source rate of the resampler).
INITIAL_SR = 48_000
# Sample rate the clips are resampled to before they reach the model.
TARGET_SR = 32_000
# A class is reported only when its sigmoid probability is strictly above this value.
THRESHOLD = 0.2

In [3]:
# Build the PaSST model in 'logits' mode, move it to the inference device and
# switch it to evaluation mode; get_labels applies the sigmoid to its output.
model = (
    load_model(mode='logits')
    .to(DEVICE)
    .eval()
)



 Loading PASST TRAINED ON AUDISET 


PaSST(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(1, 768, kernel_size=(16, 16), stride=(10, 10))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU()
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
    (1): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwi

In [4]:
# Shared INITIAL_SR -> TARGET_SR resampler, built once and reused by preprocess().
# Module.eval() returns the module itself, so it is chained here instead of being
# parked in `_`: assigning to `_` shadows IPython's "last output" variable.
resampler = torchaudio.transforms.Resample(INITIAL_SR, TARGET_SR).eval()

def preprocess(wav):
    """Down-mix a waveform to a single channel and resample it.

    Args:
        wav: 2-D tensor; dimension 0 is averaged away (kept as size 1), so it
            is presumably the channel axis as returned by ``torchaudio.load``
            — TODO confirm.

    Returns:
        The averaged waveform after it has been passed through the
        module-level ``resampler``; dimension 0 has size 1.

    Raises:
        AssertionError: If ``wav`` is not 2-dimensional.
    """
    assert wav.dim() == 2
    # keepdim=True preserves the leading axis so the result is still 2-D.
    mono = wav.mean(0, keepdim=True)
    return resampler(mono)

In [5]:
# Class-label table; get_labels looks rows up by position and reads `display_name`.
labels = pd.read_csv(filepath_or_buffer='data/audioset/class_labels_indices.csv')
labels.head(n=5)

Unnamed: 0,index,mid,display_name
0,0,/m/09x0r,Speech
1,1,/m/05zppz,"Male speech, man speaking"
2,2,/m/02zsn,"Female speech, woman speaking"
3,3,/m/0ytgt,"Child speech, kid speaking"
4,4,/m/01h8n0,Conversation


In [6]:
def get_labels(model, wav, labels, threshold=None):
    """Return the display names of every class scored above a probability threshold.

    Args:
        model: Callable mapping ``wav`` to per-class logits. After ``squeeze()``
            its output must be 1-D, i.e. it must describe a single clip.
        wav: Input handed to ``model`` unchanged.
        labels: DataFrame with a ``display_name`` column; rows are looked up by
            position, so row ``i`` must describe class index ``i``.
        threshold: A class is reported when its sigmoid probability is strictly
            greater than this value. ``None`` (the default) falls back to the
            notebook-level ``THRESHOLD``.

    Returns:
        List of ``display_name`` values in ascending class-index order; empty
        when no class clears the threshold.

    Raises:
        ValueError: If the squeezed model output is not 1-D (for example a
            batch of several clips), because the indices would no longer map
            onto rows of ``labels``.
    """
    if threshold is None:
        threshold = THRESHOLD

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        probs = torch.sigmoid(model(wav).squeeze())

    if probs.dim() != 1:
        raise ValueError(
            f'expected a 1-D vector of class scores after squeeze(), '
            f'got shape {tuple(probs.shape)}'
        )

    # nonzero() yields a (k, 1) index tensor; flatten it into plain Python ints
    # so all matching rows are fetched with a single positional lookup.
    cls_indices = torch.nonzero(probs > threshold).flatten().tolist()
    return labels['display_name'].iloc[cls_indices].tolist()

In [8]:
# Tag each test clip with the same load -> check rate -> preprocess -> classify
# steps; one loop replaces four copy-pasted stanzas.
# The trailing comments are the annotations carried over from the original cell
# (presumably each clip's reference tags — TODO confirm).
sound_clip_names = (
    'helicopter',  # helicopter, vehicle, aircraft
    'music',       # music, radio
    'barking',     # canidae, dogs, wolves, bark, domestic animals, pets, bow-wow, dog, growling, animal
    'water',       # pump(liquid), water
)

sound_clips = {}
for clip_name in sound_clip_names:
    wav, sr = torchaudio.load(f'data/audioset/test_wavs/{clip_name}.wav')
    # The shared resampler is built for INITIAL_SR only, so reject any other rate.
    assert sr == INITIAL_SR, f'{clip_name}: expected {INITIAL_SR} Hz, got {sr} Hz'
    sound_clips[clip_name] = preprocess(wav).to(DEVICE)
    print(f'{clip_name}:', get_labels(model, sound_clips[clip_name], labels))

# Keep the original per-clip variable names bound in case later cells use them.
helicopter, music, barking, water = (sound_clips[name] for name in sound_clip_names)

helicopter: ['Vehicle', 'Aircraft', 'Aircraft engine', 'Propeller, airscrew', 'Helicopter', 'Fixed-wing aircraft, airplane']
music: ['Music']
barking: ['Animal', 'Domestic animals, pets', 'Dog', 'Bark', 'Growling', 'Canidae, dogs, wolves']
water: []


In [12]:
# Tag the speech / laughter clips with the same load -> check rate -> preprocess
# -> classify steps; one loop replaces four copy-pasted stanzas.
# Each entry is (printed title, file name under data/audioset/test_wavs/).
voice_clip_files = (
    ('russian_speech', 'russian_speech.wav'),
    ('russian_speech2', 'russian_speech2.wav'),
    ('laugh with speech', 'laugh_wspeech.wav'),
    ('laugh', 'laugh.wav'),
)

# Every preprocessed clip stays available here; the original cell reused the
# names `rspeech` and `laugh`, so the first clip of each pair was overwritten.
voice_clips = {}
for title, filename in voice_clip_files:
    wav, sr = torchaudio.load(f'data/audioset/test_wavs/{filename}')
    # The shared resampler is built for INITIAL_SR only, so reject any other rate.
    assert sr == INITIAL_SR, f'{filename}: expected {INITIAL_SR} Hz, got {sr} Hz'
    voice_clips[title] = preprocess(wav).to(DEVICE)
    print(f'{title}:', get_labels(model, voice_clips[title], labels))

# Leave `rspeech` / `laugh` bound to the same clips the original cell ended with.
rspeech = voice_clips['russian_speech2']
laugh = voice_clips['laugh']

russian_speech: ['Speech']
russian_speech2: ['Speech']
laugh with speech: ['Speech', 'Snicker', 'Inside, small room']
laugh: ['Speech', 'Laughter', 'Snicker']
