In [1]:
"""A notebook to visualize audio embeddings of the HuggingFace Speech Commands dataset."""
%set_env PYTORCH_ENABLE_MPS_FALLBACK=1

import meerkat as mk

%load_ext autoreload
%autoreload 2

# Set your device here
device = "mps"

env: PYTORCH_ENABLE_MPS_FALLBACK=1


In [2]:
# Launch Meerkat's interactive GUI: the API server plus the frontend server.
mk.gui.start(
    api_port=5005,
    frontend_port=8001,
    dev=False,
    skip_build=True,
)

(APIInfo(api=<fastapi.applications.FastAPI object at 0x13f6d96d0>, port=5005, server=<meerkat.interactive.server.Server object at 0x108ab9ca0>, name='127.0.0.1', shared=False, process=None, _url=None),
 FrontendInfo(package_manager='npm', port=8001, name='localhost', shared=False, process=<Popen: returncode: None args: ['python', '-m', 'http.server', '8001']>, _url=None))

## Load the Dataset
In this demo, we will be working with the [`music_genres_small`](https://huggingface.co/datasets/lewtun/music_genres_small) dataset from the HuggingFace Hub.

In [3]:
# Pull the dataset from the HuggingFace Hub through Meerkat's registry interface.
dataset = mk.get(
    name="lewtun/music_genres_small",
    registry="huggingface",
)

  0%|          | 0/1 [00:00<?, ?it/s]



In [4]:
from typing import Any, Dict
import datasets as hf_datasets

from meerkat.interactive.formatter import AudioFormatterGroup
from meerkat.cells.audio import Audio

# Re-derived from `dataset` on every run, so this cell is safe to re-execute.
df = dataset["train"].view()

# Each entry of the HF "audio" column is a dictionary that contains the encoded
# file bytes. Lazily extract just the bytes: an already-encoded byte string is
# the fastest way to display the audio, because no re-encoding is needed.
df["audio"] = df["audio"].defer(lambda x: x["bytes"])

# Render the "audio" column with Meerkat's audio formatters.
df["audio"].formatters = AudioFormatterGroup()

In [5]:
# Preview the prepared DataFrame; its "audio" column now holds the encoded bytes.
df

## Encode the dataset with Wav2Vec2
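Compute a Wav2Vec2 embedding for every clip: each clip is resampled to 16 kHz and passed through the model, and the model's last hidden state is averaged over the sequence dimension. This takes a few minutes (about 4.5 minutes for the 1,000 clips here).

Alternatively, you can download precomputed embeddings from the HuggingFace Hub; see the download cell below.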

In [6]:
# Make a column that returns the audio tensors resampled to 16 kHz.
from meerkat.cells.audio import Audio
import datasets as hf_datasets

# Target rate (Hz) that every clip is resampled to before it is fed to Wav2Vec2.
sampling_rate = 16_000

def to_mk_audio(audio: bytes) -> Audio:
    """Decode an encoded audio byte string into a Meerkat ``Audio`` cell.

    No target rate is given to the HuggingFace decoder, so the cell keeps
    whatever sampling rate the decoder reports for the file.
    """
    decoded = hf_datasets.Audio().decode_example({"path": None, "bytes": audio})
    samples = decoded["array"]
    rate = decoded["sampling_rate"]
    return Audio(data=samples, sampling_rate=rate)

def to_array(audio: Audio):
    """Resample ``audio`` to the notebook's ``sampling_rate`` and return its raw sample data."""
    resampled_audio = audio.resample(sampling_rate)
    return resampled_audio.data

# Keep only the join key and the encoded audio bytes, then stack two lazy
# steps on top: bytes -> decoded Audio cell -> resampled sample data.
embed_source_columns = ["song_id", "audio"]
df_embed = df[embed_source_columns]
df_embed["audio"] = df_embed["audio"].defer(to_mk_audio)
df_embed["audio_tensor"] = df_embed["audio"].defer(to_array)

In [7]:
import torch
from transformers import AutoProcessor, Wav2Vec2Model


# Use the same checkpoint for both the input processor and the encoder.
MODEL_NAME = "facebook/wav2vec2-base-960h"

# The processor only prepares inputs and is not a torch module, so it is not
# placed on a device (`from_pretrained` has no documented `device` option for
# it). Only the model is moved to `device`.
processor = AutoProcessor.from_pretrained(MODEL_NAME)
# `Wav2Vec2Model` is the bare encoder; the checkpoint's `lm_head` weights go
# unused, which is what the load warnings report.
model = Wav2Vec2Model.from_pretrained(MODEL_NAME).to(device)

Some weights of the model checkpoint at facebook/wav2vec2-base-960h were not used when initializing Wav2Vec2Model: ['lm_head.weight', 'lm_head.bias']
- This IS expected if you are initializing Wav2Vec2Model from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Wav2Vec2Model from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [9]:
def embed(audio_tensor: torch.Tensor):
    """Embed one audio clip with Wav2Vec2.

    Args:
        audio_tensor: waveform of a single clip sampled at ``sampling_rate``
            (presumably 1-D, as produced by ``to_array`` -- TODO confirm).

    Returns:
        A CPU tensor: the model's last hidden state averaged over dim 1, with
        size-1 dimensions squeezed away.
    """
    # Only the processor's outputs need to live on `device`, and they are moved
    # there below, so the raw clip stays where it is (no device round trip).
    audio_tensor = audio_tensor.type(torch.float32)
    inputs = processor(audio_tensor, sampling_rate=sampling_rate, return_tensors="pt")
    # Move every model input to the model's device, not just `input_values`
    # (e.g. an attention mask, if the processor returns one).
    inputs = {name: value.to(device) for name, value in inputs.items()}
    with torch.no_grad():
        outputs = model(**inputs)
    return outputs.last_hidden_state.mean(dim=1).squeeze().cpu()

# The slow pass: embed every clip (no Ray) and show a progress bar.
embeddings = df_embed["audio_tensor"].map(embed, use_ray=False, pbar=True)
# `embed` already returns CPU tensors; keep the stored column on the CPU as well.
df_embed["embeddings"] = embeddings.to("cpu")

100%|██████████| 1000/1000 [04:28<00:00,  3.73it/s]


In [10]:
# Optionally replace the embeddings computed above with precomputed ones from
# the HuggingFace Hub (useful if you skipped the slow embedding cell).
# Set DOWNLOAD_EMBEDDINGS = False to keep the locally computed embeddings.
DOWNLOAD_EMBEDDINGS = True
EMBEDDINGS_URL = "https://huggingface.co/datasets/meerkat-ml/meerkat-dataframes/resolve/main/music_genres_small-wav2vec2-embedded.mk.tar.gz"

if DOWNLOAD_EMBEDDINGS:
    df_embed = mk.DataFrame.read(EMBEDDINGS_URL)

# Convert the embeddings column to NumPy.
# NOTE(review): assumes the locally computed column supports `.cpu().numpy()`
# like the downloaded one. With DOWNLOAD_EMBEDDINGS = False, re-running this
# cell converts an already-converted column and may fail; re-run the embedding
# cell first. TODO confirm.
df_embed["embeddings"] = df_embed["embeddings"].cpu().numpy()

In [12]:
# Uncomment to save the song IDs and embeddings locally for later reuse:
# df_embed[["song_id", "embeddings"]].write("~/.meerkat/dataframes/music_genres_small-wav2vec2-embedded.mk")

## Make the interface

In [15]:
# Preview the embeddings DataFrame (computed locally or downloaded).
df_embed

In [13]:
# Attach each song's embedding to the display frame. Only the join key and the
# embeddings are taken from `df_embed`: a locally computed `df_embed` also has
# its own "audio"/"audio_tensor" columns, which would collide with df["audio"].
# NOTE(review): an earlier run of the unrestricted merge raised
# "TypeError: Cannot convert MemoryMappedTable to pyarrow.lib.Table"; that may
# stem from the HF (Arrow-backed) columns of `df` rather than the collision --
# confirm.
plot_df = df.merge(df_embed[["song_id", "embeddings"]], on="song_id")

TypeError: Cannot convert MemoryMappedTable to pyarrow.lib.Table

In [None]:
from typing import List

# Scatter plot of the songs; points selected here drive the gallery below.
# Build it from `plot_df` (songs merged with their embeddings), not from `df`,
# which has no embedding columns at all.
# NOTE(review): "umap_1"/"umap_2" are not created anywhere in this notebook;
# a 2-D projection of the "embeddings" column must be added to `plot_df` first.
plot_df = plot_df.mark()
plot = mk.gui.plotly.ScatterPlot(df=plot_df, x="umap_1", y="umap_2")

# Because of the reactive decorator, this function re-runs whenever
# plot.selected changes, updating the gallery to show only the selected points.
@mk.gui.reactive
def filter(selected: List[str], df: mk.DataFrame):
    """Return the rows of ``df`` whose primary key is in ``selected``.

    Reads the ``selected`` argument (not the global ``plot.selected``), so the
    result depends only on the inputs passed to this function.
    NOTE: this name shadows the built-in ``filter`` in the notebook namespace.
    """
    return df[df.primary_key.isin(selected)]

filtered_df = filter(plot.selected, plot_df)
# Show each selected song's audio clip. "audio" is the column that was given
# an AudioFormatterGroup when `df` was built; this dataset has no "tweet" column.
gallery = mk.gui.Gallery(filtered_df, main_column="audio")

mk.gui.html.div(
    [plot, gallery],
    classes="h-[1200px]",
)