In [2]:
from pathlib import Path

import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
from transformers import pipeline

from scripts.utils import get_dataset

In [3]:
# Zero-shot audio classifier backed by the LAION CLAP (HTSAT, unfused) checkpoint.
# The first positional argument of `pipeline` is the task name.
classifier = pipeline(
    "zero-shot-audio-classification",
    model="laion/clap-htsat-unfused",
)

In [4]:
# Load the dataset splits via the project helper; the displayed DatasetDict
# shows the available splits (train / test / valid) and their columns.
dataset = get_dataset()
dataset

DatasetDict({
    train: Dataset({
        features: ['audio', 'slice_file_name', 'fsID', 'start', 'end', 'salience', 'fold', 'classID', 'class'],
        num_rows: 6112
    })
    test: Dataset({
        features: ['audio', 'slice_file_name', 'fsID', 'start', 'end', 'salience', 'fold', 'classID', 'class'],
        num_rows: 1310
    })
    valid: Dataset({
        features: ['audio', 'slice_file_name', 'fsID', 'start', 'end', 'salience', 'fold', 'classID', 'class'],
        num_rows: 1310
    })
})

In [5]:
# Candidate labels for zero-shot scoring: the distinct class names in the train split, alphabetically.
classes = sorted(set(dataset["train"]["class"]))

In [6]:
def get_predictions(ds, candidate_labels=None, model=None, progress=None):
    """Score every clip in a dataset split with the zero-shot audio classifier.

    Args:
        ds: A dataset split supporting column access: ``ds["audio"]`` yields
            dicts with an ``"array"`` entry, and ``ds["slice_file_name"]`` /
            ``ds["class"]`` give the matching file names and true labels.
        candidate_labels: Labels to score, in the order their columns should
            appear. Defaults to the notebook-level ``classes``.
        model: Callable with the zero-shot pipeline's interface
            (``model(array, candidate_labels=...)`` returning a list of
            ``{"label": ..., "score": ...}`` dicts). Defaults to the
            notebook-level ``classifier``.
        progress: Wrapper used to show progress over the audio column.
            Defaults to ``tqdm``; pass ``lambda it: it`` to disable it.

    Returns:
        A DataFrame with one row per clip and columns
        ``["filename", "label", *candidate_labels]``, where each label column
        holds the score the model assigned to that label.
    """
    labels = list(classes if candidate_labels is None else candidate_labels)
    clf = classifier if model is None else model
    wrap = tqdm if progress is None else progress

    records = []
    for filename, true_label, audio in zip(
        ds["slice_file_name"], ds["class"], wrap(ds["audio"])
    ):
        preds = clf(audio["array"], candidate_labels=labels)
        # Look scores up by label instead of sorting the result list: the
        # returned order need not match `labels`, and a label sort only lines
        # up with the columns when `labels` itself happens to be sorted.
        score_by_label = {p["label"]: p["score"] for p in preds}
        record = {"filename": filename, "label": true_label}
        record.update((lbl, score_by_label[lbl]) for lbl in labels)
        records.append(record)

    # Passing `columns` keeps the expected schema even for an empty split.
    return pd.DataFrame(records, columns=["filename", "label"] + labels)

In [7]:
df_preds_valid = get_predictions(dataset["valid"])
# Encode the expected shape as a check instead of printing it for eyeballing:
# exactly one prediction row per example in the validation split.
assert len(df_preds_valid) == len(dataset["valid"]), "row count mismatch for valid split"
df_preds_valid.head()

  0%|          | 0/1310 [00:00<?, ?it/s]

1310


Unnamed: 0,filename,label,air_conditioner,car_horn,children_playing,dog_bark,drilling,engine_idling,gun_shot,jackhammer,siren,street_music
0,49312-2-0-16.wav,children_playing,2e-06,5.5e-05,0.998321,9e-06,8e-06,2.6e-05,2e-05,5.8e-05,0.000319,0.001182
1,169466-4-3-9.wav,drilling,0.001141,0.007098,0.070889,0.002817,0.007584,0.017082,0.074979,0.751844,0.022784,0.043782
2,39884-5-0-1.wav,engine_idling,0.002819,0.009947,4.8e-05,0.000188,4.3e-05,0.986275,5.3e-05,0.000401,8.3e-05,0.000142
3,167701-4-6-4.wav,drilling,0.001359,0.214464,0.002815,0.001745,0.308239,0.056742,0.000485,0.388747,0.02216,0.003244
4,24347-8-0-48.wav,siren,2e-06,0.005404,0.000871,0.000463,0.000141,0.000293,1.9e-05,9.7e-05,0.91458,0.078131


In [8]:
# Create the output directory if missing so the cell also works on a fresh checkout.
Path("preds").mkdir(parents=True, exist_ok=True)
df_preds_valid.to_csv("preds/clap-valid.csv", index=False)

In [17]:
df_preds_test = get_predictions(dataset["test"])
# Encode the expected shape as a check instead of printing it for eyeballing:
# exactly one prediction row per example in the test split.
assert len(df_preds_test) == len(dataset["test"]), "row count mismatch for test split"
df_preds_test.head()

  0%|          | 0/1310 [00:00<?, ?it/s]

1310


Unnamed: 0,filename,label,air_conditioner,car_horn,children_playing,dog_bark,drilling,engine_idling,gun_shot,jackhammer,siren,street_music
0,164797-2-0-50.wav,children_playing,0.002943992,0.008064,0.952775,0.0004795736,2.6e-05,0.000759,0.000759,0.000116,0.005448,0.02863
1,17578-5-0-23.wav,engine_idling,0.0218946,0.000205,2.6e-05,0.0001846368,0.000476,0.965558,5.4e-05,0.011473,5.1e-05,7.7e-05
2,207214-2-0-26.wav,children_playing,3.995908e-06,0.000179,0.992761,2.809861e-06,9e-06,0.000101,4.2e-05,7e-06,0.000857,0.006037
3,14470-2-0-14.wav,children_playing,9.924141e-08,2e-06,0.999881,4.821749e-07,2e-06,3e-06,1.8e-05,4.2e-05,8e-06,4.4e-05
4,93567-8-0-17.wav,siren,1.587656e-06,0.006024,0.000153,0.0001106603,1.5e-05,1e-05,9e-06,1.6e-05,0.984928,0.008732


In [18]:
# Create the output directory if missing so the cell also works on a fresh checkout.
Path("preds").mkdir(parents=True, exist_ok=True)
df_preds_test.to_csv("preds/clap-test.csv", index=False)