In [32]:
import csv
import warnings
warnings.filterwarnings("ignore")

from hear21passt.base import load_model
import pandas as pd
import torch

from src.datasets import SSPNetVC


# Prefer the first CUDA GPU when PyTorch can see one; otherwise run on the CPU.
DEVICE = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
# A class is reported when its sigmoid probability is strictly greater than this value.
THRESHOLD = 0.2

In [23]:
def get_labels(model, wav, labels, sample_rate=32_000, window_seconds=10, hop_seconds=2.5):
    """Collect the labels predicted over sliding windows of a waveform.

    The waveform is cut into windows of ``window_seconds`` that start every
    ``hop_seconds``; each window is passed to ``get_labels_short_wav`` and the
    returned labels are merged into one set.

    Args:
        model: Passed through unchanged to ``get_labels_short_wav``.
        wav: Sliceable sequence of samples; ``len(wav)`` must be the number of
            samples (i.e. a 1-D waveform is assumed -- TODO confirm with the
            dataset).
        labels: Passed through unchanged to ``get_labels_short_wav``.
        sample_rate: Samples per second used to convert seconds to samples.
        window_seconds: Maximum window length in seconds.
        hop_seconds: Distance between consecutive window starts in seconds.

    Returns:
        set: Union of the labels returned for every analysed window. Empty for
        an empty waveform.

    Raises:
        ValueError: If the hop size rounds down to zero samples or less.
    """
    n_samples = len(wav)
    if n_samples == 0:
        # Nothing to classify; avoid feeding an empty tensor to the model.
        return set()

    window_size = int(sample_rate * window_seconds)
    hop_size = int(sample_rate * hop_seconds)
    if hop_size <= 0:
        raise ValueError('hop size must be at least one sample')

    # A window starts at every hop whose slice is longer than
    # (window_size - hop_size) samples.  The lower bound of 1 guarantees that
    # clips shorter than that are still analysed once (start position 0)
    # instead of being silently skipped.
    stop = max(n_samples - window_size + hop_size, 1)

    result = set()
    for start_pos in range(0, stop, hop_size):
        result.update(
            get_labels_short_wav(model, wav[start_pos:start_pos + window_size], labels)
        )
    return result


def get_labels_short_wav(model, wav, labels, threshold=None):
    """Return the display names of the classes predicted for one audio window.

    A leading batch dimension is added to ``wav`` before it is passed to
    ``model``; the output is squeezed and passed through a sigmoid, and every
    class whose probability is strictly greater than the threshold is reported.

    Args:
        model: Callable applied to ``wav[None, :]``; its output is treated as
            per-class logits.
        wav: Waveform tensor for a single window (indexed as ``wav[None, :]``).
        labels: pandas DataFrame with a ``display_name`` column whose row
            position matches the class index of the model output.
        threshold: Probability cut-off. ``None`` (the default) uses the
            module-level ``THRESHOLD``.

    Returns:
        list: Display names of the selected classes, in class-index order.
    """
    if threshold is None:
        threshold = THRESHOLD
    with torch.no_grad():
        probs = torch.sigmoid(model(wav[None, :]).squeeze())
    # nonzero() returns one row per selected class; flatten to plain int indices.
    cls_indices = torch.nonzero(probs > threshold).flatten().tolist()
    # One positional lookup on the column instead of building a row per index.
    return labels['display_name'].iloc[cls_indices].tolist()


# Class-index table; get_labels_short_wav reads its `display_name` column by row position.
labels_csv = pd.read_csv(
    filepath_or_buffer='data/audioset/class_labels_indices.csv',
)

In [24]:
# Load the hear21passt model in 'logits' mode, switch it to eval mode
# and move it to the selected device.
model = load_model(mode='logits')
model = model.eval()
model = model.to(DEVICE)



 Loading PASST TRAINED ON AUDISET 


PaSST(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(1, 768, kernel_size=(16, 16), stride=(10, 10))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU()
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
    (1): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwi

In [None]:
# Instantiate the project's SSPNetVC dataset with its default arguments.
# The labelling loop indexes it as ds[i] -> (name, wav, labels).
ds = SSPNetVC()

In [27]:
# For every clip, merge the model-predicted labels with the clip's own
# comma-separated annotations and store them as one '|'-joined, sorted string.
data = {'Name': [], 'Labels': []}
for i in range(len(ds)):
    name, wav, labels = ds[i]

    all_labels = get_labels(model, wav.to(DEVICE), labels_csv)
    # ''.split(',') yields [''], and tokens may carry surrounding whitespace;
    # strip each token and drop empty ones so no blank label is written out.
    all_labels.update(
        token.strip() for token in labels.split(',') if token.strip()
    )
    data['Name'].append(name)
    data['Labels'].append('|'.join(sorted(all_labels)))

In [34]:
# Write one row per clip; QUOTE_ALL wraps every field in quotes.
pd.DataFrame.from_dict(data).to_csv(
    path_or_buf='data/ssp/all_labels.csv',
    quoting=csv.QUOTE_ALL,
    index=False,
)