## Run experiments

In [1]:
import lancedb
import wandb
wandb.login()

  from .autonotebook import tqdm as notebook_tqdm
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mangeliney[0m ([33mangeliney-georgian[0m). Use [1m`wandb login --relogin`[0m to force relogin


True

In [2]:
# Open the local LanceDB database that holds the audio tables
uri = "../../data/lancedb-data/audio-lancedb"
db = lancedb.connect(uri)
db_tbl = db.open_table("audio_dataset")

# Load the whole table into a pandas DataFrame; later cells re-embed its rows
audio_df = db_tbl.to_pandas()



In [3]:
import pandas as pd
import numpy as np 

def test_method(test_fn, embed_fn=None):
    """
    Run every example query through a retrieval function and collect the results.

    Queries are read from the ``audio_example_queries`` table of the
    module-level ``db`` connection, one augmentation condition at a time.

    Parameters:
    test_fn (callable): Takes a query vector and returns a search result that
        exposes ``to_pandas()``; the resulting frame must have a ``song_num`` column.
    embed_fn (callable, optional): Applied to each stored query vector before
        searching. When falsy, the stored vector is used unchanged.

    Returns:
    tuple: ``(song_num_actual, song_num_retrieved)`` - the expected song number
        of every query, and for every query the list of retrieved song numbers
        in the order returned by ``test_fn``.
    """
    queries_tbl = db.open_table("audio_example_queries")
    song_num_actual = []
    song_num_retrieved = []

    conditions = [
        "(offset == 0) and (pitch_shift == 0) and (time_stretch == 1.0)",
        "(offset != 0) and (pitch_shift == 0) and (time_stretch == 1.0)",
        "(offset == 0) and (pitch_shift != 0) and (time_stretch == 1.0)",
        "(offset == 0) and (pitch_shift == 0) and (time_stretch != 1.0)",
        "(offset == 0) and (pitch_shift != 0) and (time_stretch != 1.0)",
        "(offset != 0) and (pitch_shift != 0) and (time_stretch != 1.0)",
    ]

    for condition in conditions:
        print(f"Running test for condition: {condition}")
        queries_df = (
            queries_tbl.search()
            .where(condition)
            .select(["song_num", "vector"])
            .to_pandas()
        )

        for _, row in queries_df.iterrows():
            # Keep the embedded vector in a local rather than writing it back
            # into the row yielded by iterrows()
            query_vector = embed_fn(row["vector"]) if embed_fn else row["vector"]
            song_num_actual.append(row["song_num"])

            retrieved_df = test_fn(query_vector).to_pandas()
            # Column order is the ranking order returned by the search
            song_num_retrieved.append(retrieved_df["song_num"].tolist())
    return song_num_actual, song_num_retrieved


def calculate_mrr(actual_songs, retrieved_songs):
    """
    Calculate Mean Reciprocal Rank (MRR) for a list of song retrievals.

    Parameters:
    actual_songs (list of int): A list of the actual song numbers.
    retrieved_songs (list of iterable of int): One ranked collection of retrieved
        song numbers per query, best match first.

    Returns:
    float: The Mean Reciprocal Rank (MRR) score, or 0.0 when there are no queries.
    """
    reciprocal_ranks = []

    for actual, retrieved in zip(actual_songs, retrieved_songs):
        # Normalise to a list so any ranked iterable (tuple, array, ...) works
        ranked = list(retrieved)
        try:
            # Find the rank (1-indexed) of the actual song in the retrieved list
            rank = ranked.index(actual) + 1
            reciprocal_ranks.append(1 / rank)
        except ValueError:
            # If the actual song is not in the retrieved list, reciprocal rank is 0
            reciprocal_ranks.append(0.0)

    # No queries means nothing was retrieved; avoid dividing by zero
    if not reciprocal_ranks:
        return 0.0

    # Calculate the mean of the reciprocal ranks
    return sum(reciprocal_ranks) / len(reciprocal_ranks)


def retrieval_recall(actual_songs, retrieved_songs):
    """
    Fraction of queries whose actual song appears anywhere in its retrieved list.

    Parameters:
    actual_songs (list of int): A list of the actual song numbers.
    retrieved_songs (list of list of int): One list of retrieved song numbers per query.

    Returns:
    float: Hits divided by number of queries, or 0.0 when there are no queries.
    """
    hits = [actual in retrieved for actual, retrieved in zip(actual_songs, retrieved_songs)]

    # An empty evaluation set would otherwise produce nan and a RuntimeWarning
    if not hits:
        return 0.0

    return sum(hits) / len(hits)


def test_and_log(search_fn, embed_fn, search_metric, embedding):
    """
    Evaluate one retrieval configuration and record its metrics in Weights & Biases.

    Parameters:
    search_fn (callable): Passed to ``test_method`` as the retrieval function.
    embed_fn (callable or None): Passed to ``test_method`` to embed each query.
    search_metric (str): Stored in the run config under ``"retrieval"``.
    embedding (str): Stored in the run config under ``"embedding"``.

    Returns:
    None. Prints the MRR and retrieval recall and logs them to a new W&B run.
    """
    actual, retrieved = test_method(search_fn, embed_fn)
    mrr = calculate_mrr(actual, retrieved)
    rr = retrieval_recall(actual, retrieved)
    print("mrr", mrr, "rr", rr)

    wandb.init(
        # set the wandb project where this run will be logged
        project="children-song-dataset-retrieval",

        # track hyperparameters and run metadata
        config={
            "embedding": embedding,
            "retrieval": search_metric,
        }
    )

    try:
        # A single log call keeps both metrics on the same W&B step
        wandb.log({"mrr": mrr, "retrieval_recall": rr})
    finally:
        # Always close the run, even if logging fails, so the next call starts clean
        wandb.finish()


In [4]:

def search(query_vector, db_tbl, metric="l2", limit=3):
    """
    Build a nearest-neighbour query against a table.

    Parameters:
    query_vector: Vector passed to ``db_tbl.search``.
    db_tbl: Table object exposing a chainable ``search(...).metric(...).limit(...)`` API.
    metric (str): Distance metric name forwarded to ``.metric``.
    limit (int): Maximum number of results requested; defaults to 3.

    Returns:
    The query object returned by ``.limit``; nothing is materialised here.
    """
    return db_tbl.search(query_vector).metric(metric).limit(limit)

In [5]:
import librosa
import numpy as np

def extract_features(audio, sr=44100, aggregate="summary_stat"):
    """
    Build one feature vector from MFCC, chroma and mel-spectrogram features.

    Parameters:
    audio: Audio signal passed as ``y`` to the librosa feature extractors.
    sr (int): Sampling rate passed to librosa.
    aggregate (str): ``"summary_stat"`` reduces each feature matrix to its mean
        and standard deviation along axis 1; any other value flattens each
        matrix as is.

    Returns:
    numpy.ndarray: 1D concatenation of the MFCC, chroma and mel parts, in that order.
    """
    # Order matters: it fixes the layout of the final vector
    feature_matrices = (
        librosa.feature.mfcc(y=audio, sr=sr, n_mfcc=13),
        librosa.feature.chroma_stft(y=audio, sr=sr),
        librosa.feature.melspectrogram(y=audio, sr=sr, n_mels=128),
    )

    if aggregate == "summary_stat":
        # Per feature: [mean over axis 1, std over axis 1]
        parts = [
            np.concatenate([np.mean(matrix, axis=1), np.std(matrix, axis=1)])
            for matrix in feature_matrices
        ]
    else:
        # Keep every value: flatten each matrix into a 1D array
        parts = [matrix.flatten() for matrix in feature_matrices]

    return np.concatenate(parts)



In [6]:
def embed_lookup_data(embed_fn, df, db_name):
    """
    Re-embed every row of ``df`` and store the result in a fresh table.

    The rows are processed in batches of ``len(df) // 5`` rows (at least one).
    The first batch creates the table, dropping any existing table named
    ``db_name`` first; later batches are appended to it.

    Parameters:
    embed_fn (callable): Maps a row's ``vector`` value to the new vector to store.
    df (pandas.DataFrame): Source rows; must contain ``vector`` and the metadata
        columns copied below.
    db_name (str): Name of the table to (re)create in the module-level ``db``.

    Returns:
    None.
    """
    metadata_columns = [
        "sample_rate",
        "offset",
        "pitch_shift",
        "time_stretch",
        "song_num",
        "song_version",
        "chunk_num",
        "filename",
    ]

    # len(df) // 5 is 0 for fewer than five rows, and range() rejects a zero step
    batch_size = max(1, len(df) // 5)
    feat_tbl = None

    for i in range(0, len(df), batch_size):
        print(i)
        sound_arrays = []
        for _, row in df.iloc[i:i + batch_size].iterrows():
            record = {"vector": embed_fn(row["vector"])}
            record.update({column: row[column] for column in metadata_columns})
            sound_arrays.append(record)

        if feat_tbl is None:
            # First batch: replace any previous version of the table
            if db_name in db.table_names():
                db.drop_table(db_name)
            feat_tbl = db.create_table(db_name, data=sound_arrays)
        else:
            feat_tbl.add(sound_arrays)

    

In [10]:
# Baseline: query the original table with the stored vectors as they are (no embed_fn)
for search_metric in ("l2", "cosine", "dot"):
    def raw_search(query_vector, metric=search_metric):
        return search(query_vector, db_tbl, metric)

    test_and_log(raw_search, None, search_metric, "none")
    

Running test for condition: (offset == 0) and (pitch_shift == 0) and (time_stretch == 1.0)
Running test for condition: (offset != 0) and (pitch_shift == 0) and (time_stretch == 1.0)
Running test for condition: (offset == 0) and (pitch_shift != 0) and (time_stretch == 1.0)
Running test for condition: (offset == 0) and (pitch_shift == 0) and (time_stretch != 1.0)
Running test for condition: (offset == 0) and (pitch_shift != 0) and (time_stretch != 1.0)
Running test for condition: (offset != 0) and (pitch_shift != 0) and (time_stretch != 1.0)
mrr 0.16666666666666666 rr 0.16666666666666666


0,1
mrr,▁
retrieval_recall,▁

0,1
mrr,0.16667
retrieval_recall,0.16667


Running test for condition: (offset == 0) and (pitch_shift == 0) and (time_stretch == 1.0)
Running test for condition: (offset != 0) and (pitch_shift == 0) and (time_stretch == 1.0)
Running test for condition: (offset == 0) and (pitch_shift != 0) and (time_stretch == 1.0)
Running test for condition: (offset == 0) and (pitch_shift == 0) and (time_stretch != 1.0)
Running test for condition: (offset == 0) and (pitch_shift != 0) and (time_stretch != 1.0)
Running test for condition: (offset != 0) and (pitch_shift != 0) and (time_stretch != 1.0)
mrr 0.25833333333333336 rr 0.26666666666666666


0,1
mrr,▁
retrieval_recall,▁

0,1
mrr,0.25833
retrieval_recall,0.26667


Running test for condition: (offset == 0) and (pitch_shift == 0) and (time_stretch == 1.0)
Running test for condition: (offset != 0) and (pitch_shift == 0) and (time_stretch == 1.0)
Running test for condition: (offset == 0) and (pitch_shift != 0) and (time_stretch == 1.0)
Running test for condition: (offset == 0) and (pitch_shift == 0) and (time_stretch != 1.0)
Running test for condition: (offset == 0) and (pitch_shift != 0) and (time_stretch != 1.0)
Running test for condition: (offset != 0) and (pitch_shift != 0) and (time_stretch != 1.0)
mrr 0.23055555555555557 rr 0.26666666666666666


0,1
mrr,▁
retrieval_recall,▁

0,1
mrr,0.23056
retrieval_recall,0.26667


## Try with feature extraction

In [11]:
sample_audio = db_tbl.search().limit(1).select(["song_num", "vector"]).to_pandas().iloc[0]["vector"]

In [14]:
embed_lookup_data(
    lambda audio: extract_features(audio, aggregate="summary_stat"),
    audio_df,
    "audio_feat_eng_sumstat"
)

embed_lookup_data(
    lambda audio: extract_features(audio, aggregate="full"),
    audio_df,
    "audio_feat_eng_full"
)

0
158
316
474
632
0
158
316
474
632


In [15]:
sumstat_tbl = db.open_table("audio_feat_eng_sumstat")
fullfeat_tbl = db.open_table("audio_feat_eng_full")

In [16]:
# (lookup table, aggregate mode, embedding label) for each feature-engineered variant
feature_runs = (
    (sumstat_tbl, "summary_stat", "audio_feat_eng_sumstat"),
    (fullfeat_tbl, "full", "audio_feat_eng_full"),
)

for search_metric in ("l2", "cosine", "dot"):
    for lookup_tbl, aggregate, embedding_name in feature_runs:
        # Defaults bind the current loop values to each callable
        test_and_log(
            lambda x, tbl=lookup_tbl, metric=search_metric: search(x, tbl, metric),
            lambda audio, mode=aggregate: extract_features(audio, aggregate=mode),
            search_metric,
            embedding_name,
        )

Running test for condition: (offset == 0) and (pitch_shift == 0) and (time_stretch == 1.0)
Running test for condition: (offset != 0) and (pitch_shift == 0) and (time_stretch == 1.0)
Running test for condition: (offset == 0) and (pitch_shift != 0) and (time_stretch == 1.0)
Running test for condition: (offset == 0) and (pitch_shift == 0) and (time_stretch != 1.0)
Running test for condition: (offset == 0) and (pitch_shift != 0) and (time_stretch != 1.0)
Running test for condition: (offset != 0) and (pitch_shift != 0) and (time_stretch != 1.0)
mrr 0.4083333333333333 rr 0.4166666666666667


0,1
mrr,▁
retrieval_recall,▁

0,1
mrr,0.40833
retrieval_recall,0.41667


Running test for condition: (offset == 0) and (pitch_shift == 0) and (time_stretch == 1.0)
Running test for condition: (offset != 0) and (pitch_shift == 0) and (time_stretch == 1.0)
Running test for condition: (offset == 0) and (pitch_shift != 0) and (time_stretch == 1.0)
Running test for condition: (offset == 0) and (pitch_shift == 0) and (time_stretch != 1.0)
Running test for condition: (offset == 0) and (pitch_shift != 0) and (time_stretch != 1.0)
Running test for condition: (offset != 0) and (pitch_shift != 0) and (time_stretch != 1.0)
mrr 0.3861111111111111 rr 0.4166666666666667


0,1
mrr,▁
retrieval_recall,▁

0,1
mrr,0.38611
retrieval_recall,0.41667


Running test for condition: (offset == 0) and (pitch_shift == 0) and (time_stretch == 1.0)
Running test for condition: (offset != 0) and (pitch_shift == 0) and (time_stretch == 1.0)
Running test for condition: (offset == 0) and (pitch_shift != 0) and (time_stretch == 1.0)
Running test for condition: (offset == 0) and (pitch_shift == 0) and (time_stretch != 1.0)
Running test for condition: (offset == 0) and (pitch_shift != 0) and (time_stretch != 1.0)
Running test for condition: (offset != 0) and (pitch_shift != 0) and (time_stretch != 1.0)
mrr 0.4305555555555555 rr 0.4666666666666667


0,1
mrr,▁
retrieval_recall,▁

0,1
mrr,0.43056
retrieval_recall,0.46667


Running test for condition: (offset == 0) and (pitch_shift == 0) and (time_stretch == 1.0)
Running test for condition: (offset != 0) and (pitch_shift == 0) and (time_stretch == 1.0)
Running test for condition: (offset == 0) and (pitch_shift != 0) and (time_stretch == 1.0)
Running test for condition: (offset == 0) and (pitch_shift == 0) and (time_stretch != 1.0)
Running test for condition: (offset == 0) and (pitch_shift != 0) and (time_stretch != 1.0)
Running test for condition: (offset != 0) and (pitch_shift != 0) and (time_stretch != 1.0)
mrr 0.35833333333333334 rr 0.36666666666666664


0,1
mrr,▁
retrieval_recall,▁

0,1
mrr,0.35833
retrieval_recall,0.36667


Running test for condition: (offset == 0) and (pitch_shift == 0) and (time_stretch == 1.0)
Running test for condition: (offset != 0) and (pitch_shift == 0) and (time_stretch == 1.0)
Running test for condition: (offset == 0) and (pitch_shift != 0) and (time_stretch == 1.0)
Running test for condition: (offset == 0) and (pitch_shift == 0) and (time_stretch != 1.0)
Running test for condition: (offset == 0) and (pitch_shift != 0) and (time_stretch != 1.0)
Running test for condition: (offset != 0) and (pitch_shift != 0) and (time_stretch != 1.0)
mrr 0.0 rr 0.0


0,1
mrr,▁
retrieval_recall,▁

0,1
mrr,0.0
retrieval_recall,0.0


Running test for condition: (offset == 0) and (pitch_shift == 0) and (time_stretch == 1.0)
Running test for condition: (offset != 0) and (pitch_shift == 0) and (time_stretch == 1.0)
Running test for condition: (offset == 0) and (pitch_shift != 0) and (time_stretch == 1.0)
Running test for condition: (offset == 0) and (pitch_shift == 0) and (time_stretch != 1.0)
Running test for condition: (offset == 0) and (pitch_shift != 0) and (time_stretch != 1.0)
Running test for condition: (offset != 0) and (pitch_shift != 0) and (time_stretch != 1.0)
mrr 0.0 rr 0.0


0,1
mrr,▁
retrieval_recall,▁

0,1
mrr,0.0
retrieval_recall,0.0


## Try HuBERT or another pre-trained audio embedder

In [7]:
import torch
import librosa


def embed_with_model(audio, processor, model, output_field=None, orig_sr=44100, target_sr=16000):
    """
    Resample an audio signal and run it through a pretrained model.

    Parameters:
    audio: Audio signal passed to ``librosa.resample``.
    processor (callable): Called with the resampled audio plus
        ``sampling_rate``, ``return_tensors="pt"`` and ``padding=True``; its
        result is unpacked as keyword arguments into ``model``.
    model (callable): Model invoked under ``torch.no_grad()``.
    output_field (str, optional): If given, this key is extracted from the
        model output (e.g. ``"last_hidden_state"`` or ``"extract_features"``).
    orig_sr (int): Sampling rate of ``audio``; defaults to 44100.
    target_sr (int): Sampling rate the audio is resampled to and reported to
        the processor; defaults to 16000.

    Returns:
    The full model output, or the selected field when ``output_field`` is set.
    """
    resampled = librosa.resample(audio, orig_sr=orig_sr, target_sr=target_sr)

    # Preprocess the audio for the model
    inputs = processor(resampled, sampling_rate=target_sr, return_tensors="pt", padding=True)

    # Pass the inputs to the model without tracking gradients
    with torch.no_grad():
        out = model(**inputs)

    if output_field:
        out = out[output_field]  # last_hidden_state or extract_features

    return out


In [11]:
from transformers import Wav2Vec2Processor, Wav2Vec2Model

# Load the pretrained Wav2Vec2 model and processor
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")

out = embed_with_model(sample_audio, processor, model)

Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1', 'wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [13]:
out.__dict__.keys()

dict_keys(['last_hidden_state', 'extract_features', 'hidden_states', 'attentions'])

In [19]:
out["extract_features"].numpy().flatten().shape

(255488,)

In [20]:
embed_lookup_data(
    lambda audio: embed_with_model(audio, processor, model, output_field="last_hidden_state").numpy().flatten(),
    audio_df,
    "audio_feat_wav2vec2_last_hidden_state"
)

0
158
316
474
632


In [21]:
embed_lookup_data(
    lambda audio: embed_with_model(audio, processor, model, output_field="extract_features").numpy().flatten(),
    audio_df,
    "audio_feat_wav2vec2_extract_features"
)

0
158
316
474
632


In [22]:
wav2vec2_hidden_state_tbl = db.open_table("audio_feat_wav2vec2_last_hidden_state")
wav2vec2_extract_feat_tbl = db.open_table("audio_feat_wav2vec2_extract_features")

In [23]:
# (lookup table, model output field, embedding label) for each wav2vec2 variant
wav2vec2_runs = (
    (wav2vec2_hidden_state_tbl, "last_hidden_state", "audio_feat_wav2vec2_last_hidden_state"),
    (wav2vec2_extract_feat_tbl, "extract_features", "audio_feat_wav2vec2_extract_features"),
)

for search_metric in ("l2", "cosine", "dot"):
    for lookup_tbl, output_field, embedding_name in wav2vec2_runs:
        # Defaults bind the current loop values to each callable
        test_and_log(
            lambda x, tbl=lookup_tbl, metric=search_metric: search(x, tbl, metric),
            lambda audio, field=output_field: embed_with_model(
                audio, processor, model, output_field=field
            ).numpy().flatten(),
            search_metric,
            embedding_name,
        )

Running test for condition: (offset == 0) and (pitch_shift == 0) and (time_stretch == 1.0)
Running test for condition: (offset != 0) and (pitch_shift == 0) and (time_stretch == 1.0)
Running test for condition: (offset == 0) and (pitch_shift != 0) and (time_stretch == 1.0)
Running test for condition: (offset == 0) and (pitch_shift == 0) and (time_stretch != 1.0)
Running test for condition: (offset == 0) and (pitch_shift != 0) and (time_stretch != 1.0)
Running test for condition: (offset != 0) and (pitch_shift != 0) and (time_stretch != 1.0)
mrr 0.2916666666666667 rr 0.3


0,1
mrr,▁
retrieval_recall,▁

0,1
mrr,0.29167
retrieval_recall,0.3


Running test for condition: (offset == 0) and (pitch_shift == 0) and (time_stretch == 1.0)
Running test for condition: (offset != 0) and (pitch_shift == 0) and (time_stretch == 1.0)
Running test for condition: (offset == 0) and (pitch_shift != 0) and (time_stretch == 1.0)
Running test for condition: (offset == 0) and (pitch_shift == 0) and (time_stretch != 1.0)
Running test for condition: (offset == 0) and (pitch_shift != 0) and (time_stretch != 1.0)
Running test for condition: (offset != 0) and (pitch_shift != 0) and (time_stretch != 1.0)
mrr 0.18333333333333332 rr 0.18333333333333332


0,1
mrr,▁
retrieval_recall,▁

0,1
mrr,0.18333
retrieval_recall,0.18333


Running test for condition: (offset == 0) and (pitch_shift == 0) and (time_stretch == 1.0)
Running test for condition: (offset != 0) and (pitch_shift == 0) and (time_stretch == 1.0)
Running test for condition: (offset == 0) and (pitch_shift != 0) and (time_stretch == 1.0)
Running test for condition: (offset == 0) and (pitch_shift == 0) and (time_stretch != 1.0)
Running test for condition: (offset == 0) and (pitch_shift != 0) and (time_stretch != 1.0)
Running test for condition: (offset != 0) and (pitch_shift != 0) and (time_stretch != 1.0)
mrr 0.016666666666666666 rr 0.016666666666666666


0,1
mrr,▁
retrieval_recall,▁

0,1
mrr,0.01667
retrieval_recall,0.01667


Running test for condition: (offset == 0) and (pitch_shift == 0) and (time_stretch == 1.0)
Running test for condition: (offset != 0) and (pitch_shift == 0) and (time_stretch == 1.0)
Running test for condition: (offset == 0) and (pitch_shift != 0) and (time_stretch == 1.0)
Running test for condition: (offset == 0) and (pitch_shift == 0) and (time_stretch != 1.0)
Running test for condition: (offset == 0) and (pitch_shift != 0) and (time_stretch != 1.0)
Running test for condition: (offset != 0) and (pitch_shift != 0) and (time_stretch != 1.0)
mrr 0.19166666666666668 rr 0.2


0,1
mrr,▁
retrieval_recall,▁

0,1
mrr,0.19167
retrieval_recall,0.2


Running test for condition: (offset == 0) and (pitch_shift == 0) and (time_stretch == 1.0)
Running test for condition: (offset != 0) and (pitch_shift == 0) and (time_stretch == 1.0)
Running test for condition: (offset == 0) and (pitch_shift != 0) and (time_stretch == 1.0)
Running test for condition: (offset == 0) and (pitch_shift == 0) and (time_stretch != 1.0)
Running test for condition: (offset == 0) and (pitch_shift != 0) and (time_stretch != 1.0)
Running test for condition: (offset != 0) and (pitch_shift != 0) and (time_stretch != 1.0)
mrr 0.21666666666666667 rr 0.23333333333333334


0,1
mrr,▁
retrieval_recall,▁

0,1
mrr,0.21667
retrieval_recall,0.23333


Running test for condition: (offset == 0) and (pitch_shift == 0) and (time_stretch == 1.0)
Running test for condition: (offset != 0) and (pitch_shift == 0) and (time_stretch == 1.0)
Running test for condition: (offset == 0) and (pitch_shift != 0) and (time_stretch == 1.0)
Running test for condition: (offset == 0) and (pitch_shift == 0) and (time_stretch != 1.0)
Running test for condition: (offset == 0) and (pitch_shift != 0) and (time_stretch != 1.0)
Running test for condition: (offset != 0) and (pitch_shift != 0) and (time_stretch != 1.0)
mrr 0.15 rr 0.15


0,1
mrr,▁
retrieval_recall,▁

0,1
mrr,0.15
retrieval_recall,0.15


In [17]:
from transformers import AutoProcessor, AutoModel

# Load the pretrained HuBERT model and its processor
model_name = "facebook/hubert-large-ls960-ft"
processor = AutoProcessor.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)

# Smoke-test on the sample clip; the bare name `audio` is never defined in this
# notebook and raises NameError on a fresh kernel
out = embed_with_model(sample_audio, processor, model)

Some weights of HubertModel were not initialized from the model checkpoint at facebook/hubert-large-ls960-ft and are newly initialized: ['hubert.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'hubert.encoder.pos_conv_embed.conv.parametrizations.weight.original1']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [18]:
out

BaseModelOutput(last_hidden_state=tensor([[[ 0.0286,  0.4783,  0.3566,  ...,  0.0384, -0.1241, -0.0626],
         [ 0.0280,  0.4963,  0.4230,  ...,  0.0139, -0.1003,  0.0727],
         [ 0.0387,  0.7239,  0.3847,  ...,  0.0613, -0.1662, -0.0239],
         ...,
         [ 0.2749,  0.0290,  0.2642,  ..., -0.0689, -0.2864,  0.0280],
         [ 0.1355,  0.2981,  0.3246,  ...,  0.3942, -0.4068,  0.1216],
         [-0.0655,  0.3900,  0.3681,  ...,  0.1234, -0.5148, -0.1127]]]), hidden_states=None, attentions=None)

In [28]:
embed_lookup_data(
    lambda audio: embed_with_model(audio, processor, model, output_field="last_hidden_state").numpy().flatten(),
    audio_df,
    "audio_feat_hubert_large_last_hidden_state"
)

0
158
316
474
632


In [8]:
hubert_large_hidden_state_tbl = db.open_table("audio_feat_hubert_large_last_hidden_state")


In [11]:
for search_metric in ["l2", "dot"]:
    print(search_metric)
    test_and_log(
        lambda x: search(x, hubert_large_hidden_state_tbl, search_metric),
        lambda audio: embed_with_model(audio, processor, model, output_field="last_hidden_state").numpy().flatten(),
        search_metric,
        "audio_feat_hubert_large_last_hidden_state"
    )

l2
Running test for condition: (offset == 0) and (pitch_shift == 0) and (time_stretch == 1.0)
Running test for condition: (offset != 0) and (pitch_shift == 0) and (time_stretch == 1.0)
Running test for condition: (offset == 0) and (pitch_shift != 0) and (time_stretch == 1.0)
Running test for condition: (offset == 0) and (pitch_shift == 0) and (time_stretch != 1.0)
Running test for condition: (offset == 0) and (pitch_shift != 0) and (time_stretch != 1.0)
Running test for condition: (offset != 0) and (pitch_shift != 0) and (time_stretch != 1.0)
mrr 0.14166666666666666 rr 0.15


0,1
mrr,▁
retrieval_recall,▁

0,1
mrr,0.14167
retrieval_recall,0.15


dot
Running test for condition: (offset == 0) and (pitch_shift == 0) and (time_stretch == 1.0)
Running test for condition: (offset != 0) and (pitch_shift == 0) and (time_stretch == 1.0)
Running test for condition: (offset == 0) and (pitch_shift != 0) and (time_stretch == 1.0)
Running test for condition: (offset == 0) and (pitch_shift == 0) and (time_stretch != 1.0)
Running test for condition: (offset == 0) and (pitch_shift != 0) and (time_stretch != 1.0)
Running test for condition: (offset != 0) and (pitch_shift != 0) and (time_stretch != 1.0)
mrr 0.11666666666666667 rr 0.11666666666666667


0,1
mrr,▁
retrieval_recall,▁

0,1
mrr,0.11667
retrieval_recall,0.11667


In [8]:
from transformers import AutoProcessor, AutoModel, WhisperFeatureExtractor

models = ["openai/whisper-tiny", "openai/whisper-base", "openai/whisper-small", "openai/whisper-medium"]

# For each Whisper checkpoint: embed the lookup data with the encoder (unless the
# table already exists), then evaluate retrieval with every distance metric
for model_name in models:
    print(model_name)
    feature_extractor = WhisperFeatureExtractor.from_pretrained(model_name)
    processor = AutoProcessor.from_pretrained(model_name)

    def custom_processor(audio, sampling_rate=16000, return_tensors="pt", **kwargs):
        # **kwargs absorbs extra arguments from embed_with_model (e.g. padding=True)
        return processor.feature_extractor.pad([{
                "input_features":
                feature_extractor(audio, sampling_rate=sampling_rate)["input_features"][0]}],
            return_tensors=return_tensors)

    model = AutoModel.from_pretrained(model_name)

    def custom_model(input_features):
        # Only the encoder is used to produce embeddings
        return model.encoder(input_features)

    # Built outside the f-string: reusing double quotes inside a double-quoted
    # f-string is a SyntaxError before Python 3.12
    table_suffix = model_name.replace("/", "_").replace("-", "_")
    db_name = f"audio_feat_{table_suffix}_last_hidden_state"

    def whisper_embed(audio):
        # Shared by the indexing step and the query-time evaluation
        return embed_with_model(
            audio, custom_processor, custom_model, output_field="last_hidden_state"
        ).numpy().flatten()

    if db_name not in db.table_names():
        embed_lookup_data(whisper_embed, audio_df, db_name)

    hidden_state_tbl = db.open_table(db_name)

    for search_metric in ["l2", "cosine", "dot"]:
        print(search_metric)
        test_and_log(
            lambda x: search(x, hidden_state_tbl, search_metric),
            whisper_embed,
            search_metric,
            db_name
        )

openai/whisper-tiny
cosine
Running test for condition: (offset == 0) and (pitch_shift == 0) and (time_stretch == 1.0)
Running test for condition: (offset != 0) and (pitch_shift == 0) and (time_stretch == 1.0)
Running test for condition: (offset == 0) and (pitch_shift != 0) and (time_stretch == 1.0)
Running test for condition: (offset == 0) and (pitch_shift == 0) and (time_stretch != 1.0)
Running test for condition: (offset == 0) and (pitch_shift != 0) and (time_stretch != 1.0)
Running test for condition: (offset != 0) and (pitch_shift != 0) and (time_stretch != 1.0)
mrr 0.10833333333333334 rr 0.11666666666666667


0,1
mrr,▁
retrieval_recall,▁

0,1
mrr,0.10833
retrieval_recall,0.11667


openai/whisper-base
cosine
Running test for condition: (offset == 0) and (pitch_shift == 0) and (time_stretch == 1.0)


: 