# Visualizing Audio Embeddings

In this notebook, we will visualize audio embeddings of the `music_genres_small` dataset from Hugging Face.

**Dependencies**

    # Install UMAP for dimensionality reduction.
    pip install umap-learn

    # (Optional) If you want to compute embeddings on your own.
    # This is not necessary if you are fetching precomputed embeddings.
    pip install transformers librosa soundfile

In [1]:
%set_env PYTORCH_ENABLE_MPS_FALLBACK=1

import meerkat as mk

%load_ext autoreload
%autoreload 2

# Set your device here, e.g. "mps" (Apple Silicon), "cuda", or "cpu".
device = "mps"

env: PYTORCH_ENABLE_MPS_FALLBACK=1


In [2]:
# Skip the build if you do not have npm installed.
mk.gui.start(dev=False, skip_build=True)

(APIInfo(api=<fastapi.applications.FastAPI object at 0x16b817d90>, port=5000, server=<meerkat.interactive.server.Server object at 0x12077ef10>, name='127.0.0.1', shared=False, process=None, _url=None),
 FrontendInfo(package_manager='npm', port=8001, name='localhost', shared=False, process=<Popen: returncode: None args: ['python', '-m', 'http.server', '8001']>, _url=None))

## Load the Dataset
In this demo, we will be working with the [`music_genres_small`](https://huggingface.co/datasets/lewtun/music_genres_small) dataset on Hugging Face.

In [3]:
dataset = mk.get(name="lewtun/music_genres_small", registry="huggingface")

In [4]:
from typing import Any, Dict
import datasets as hf_datasets

from meerkat.interactive.formatter import AudioFormatterGroup
from meerkat.cells.audio import Audio

df = dataset["train"].view()

# The audio column is a dictionary containing the bytes.
# Extract the bytes lazily.
# The byte string is actually the fastest way to display the audio,
# because the encoding is already done.
df["audio"] = df["audio"].defer(lambda x: x["bytes"])

# Set the formatter for this column.
df["audio"].formatters = AudioFormatterGroup()
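
The `defer` call above postpones extracting the bytes until a row is actually accessed. The idea can be sketched in plain Python as follows. `LazyColumn` is a hypothetical name used only for illustration; it is not part of Meerkat, whose deferred columns do considerably more.

```python
from typing import Any, Callable, List


class LazyColumn:
    """Hypothetical illustration of a deferred column (not Meerkat's API)."""

    def __init__(self, rows: List[Any], fn: Callable[[Any], Any]):
        self._rows = rows
        self._fn = fn
        self.calls = 0  # Counts how many times fn has actually run.

    def __len__(self) -> int:
        return len(self._rows)

    def __getitem__(self, index: int) -> Any:
        # The function runs only when a row is requested.
        self.calls += 1
        return self._fn(self._rows[index])


# Stand-in rows shaped like the audio dictionaries (values are made up).
rows = [{"bytes": b"abc", "path": None}, {"bytes": b"xyz", "path": None}]
audio_bytes = LazyColumn(rows, lambda x: x["bytes"])

first = audio_bytes[0]  # Only now is the lambda applied, and only to row 0.
```

Nothing is computed when the column is created, so building it is cheap even for a large dataset; the cost is paid per row, on access.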

In [5]:
df

## Encode the dataset with Wav2Vec2
Encode each audio clip with Wav2Vec2. This will take a few minutes.

Alternatively, you can download precomputed embeddings from Hugging Face. The first cell below does this; the remaining cells in this section compute the embeddings from scratch.

In [5]:
# Download precomputed embeddings from Hugging Face.
# If you want to generate your own embeddings, see the rest of this section.
df_embed = mk.DataFrame.read(
    "https://huggingface.co/datasets/meerkat-ml/meerkat-dataframes/resolve/main/music_genres_small-wav2vec2-embedded.mk.tar.gz",
    overwrite=True
)

In [None]:
from meerkat.cells.audio import Audio
import datasets as hf_datasets

# The sampling rate used by Wav2Vec2.
sampling_rate = 16000

def to_mk_audio(audio: bytes) -> Audio:
    """Convert from bytes to Audio object."""
    audio_dict = hf_datasets.Audio().decode_example({"path": None, "bytes": audio})
    return Audio(data=audio_dict["array"], sampling_rate=audio_dict["sampling_rate"])

def to_array(audio: Audio):
    """Resample the audio to the sampling rate used by Wav2Vec2 and extract the array."""
    return audio.resample(sampling_rate).data

df_embed = df[["song_id", "audio"]]
df_embed["audio"] = df_embed["audio"].defer(to_mk_audio)
df_embed["audio_tensor"] = df_embed["audio"].defer(to_array)
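
Resampling changes the number of samples in a clip but not its duration. A minimal sketch of that arithmetic, assuming a hypothetical 3-second clip recorded at 44.1 kHz (the dataset's actual sampling rate is not stated here, and real resamplers may differ by a sample due to rounding):

```python
def resampled_length(num_samples: int, orig_rate: int, target_rate: int) -> int:
    """Approximate sample count after resampling, preserving duration."""
    if orig_rate <= 0 or target_rate <= 0:
        raise ValueError("sampling rates must be positive")
    return round(num_samples * target_rate / orig_rate)


orig_rate = 44_100            # Hypothetical source rate, for illustration only.
num_samples = 3 * orig_rate   # A 3-second clip.
new_length = resampled_length(num_samples, orig_rate, 16_000)
```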

In [None]:
import torch
from transformers import AutoProcessor, Wav2Vec2Model


processor = AutoProcessor.from_pretrained("facebook/wav2vec2-base-960h", device=device)
model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h").to(device)

In [None]:
def embed(audio_tensor: torch.Tensor) -> torch.Tensor:
    """Embed one audio clip as the time-averaged Wav2Vec2 hidden state."""
    audio_tensor = audio_tensor.type(torch.float32).to(device)
    inputs = processor(audio_tensor, sampling_rate=sampling_rate, return_tensors="pt", device=device)
    inputs["input_values"] = inputs["input_values"].to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    # Shape: (batch, time, hidden). Average over the time axis.
    last_hidden_states = outputs.last_hidden_state
    return last_hidden_states.mean(dim=1).squeeze().cpu()

df_embed["embeddings"] = df_embed["audio_tensor"].map(embed, use_ray=False, pbar=True)
df_embed["embeddings"] = df_embed["embeddings"].to("cpu")
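
The `embed` function collapses the per-frame hidden states into one vector per clip by averaging over the time axis (`mean(dim=1)`). A minimal pure-Python sketch of that pooling step, using made-up 2-dimensional frames rather than real Wav2Vec2 outputs:

```python
from statistics import fmean
from typing import List, Sequence


def mean_pool(frames: Sequence[Sequence[float]]) -> List[float]:
    """Average equal-length frame vectors into a single vector."""
    if not frames:
        raise ValueError("need at least one frame")
    # zip(*frames) groups the values of each hidden dimension across time.
    return [fmean(dim_values) for dim_values in zip(*frames)]


# Three time steps, two hidden dimensions (illustrative values only).
frames = [[1.0, 10.0], [3.0, 20.0], [5.0, 30.0]]
pooled = mean_pool(frames)
```

Mean pooling gives every clip a vector of the same size regardless of its duration, which is what lets clips of different lengths be compared and projected together.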

## Make the interface
Build the interface for visualizing the embeddings.

We will first merge the embedding dataframe (`df_embed`) with the dataset dataframe (`df`).
Then, we will use UMAP to project the embeddings down to two dimensions for plotting.

In [6]:
plot_df = df.merge(df_embed, on="song_id")
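
The merge pairs rows from the two dataframes that share a `song_id`. A minimal sketch of such a key-based join on plain dictionaries, assuming inner-join semantics (rows without a match are dropped) and unique keys on the right-hand side. The rows, the `genre` field, and the embedding values are made up for illustration.

```python
from typing import Any, Dict, List

Row = Dict[str, Any]


def inner_join(left: List[Row], right: List[Row], on: str) -> List[Row]:
    """Join two lists of rows on a shared key, keeping only matching rows."""
    # Index the right-hand rows by key (assumes keys are unique there).
    right_by_key = {row[on]: row for row in right}
    joined = []
    for row in left:
        match = right_by_key.get(row[on])
        if match is not None:
            # In this sketch, right-hand values win on column-name clashes.
            joined.append({**row, **match})
    return joined


songs = [{"song_id": 1, "genre": "rock"}, {"song_id": 2, "genre": "jazz"}]
embedded = [{"song_id": 2, "embeddings": [0.1, 0.2]}]
merged = inner_join(songs, embedded, on="song_id")
```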

In [7]:
# Compute a 2D UMAP projection of the embeddings. This may take a few seconds.
from umap import UMAP

reducer = UMAP(n_components=2)
projection = reducer.fit_transform(plot_df["embeddings"])
plot_df["umap_1"] = projection[:, 0]
plot_df["umap_2"] = projection[:, 1]


In [None]:
plot_df = plot_df.mark()

In [17]:
plot = mk.gui.plotly.ScatterPlot(df=plot_df, x="umap_1", y="umap_2")

# Named filter_selected to avoid shadowing Python's built-in filter().
@mk.gui.reactive
def filter_selected(selected: list, df: mk.DataFrame):
    return df[df.primary_key.isin(selected)]

filtered_df = filter_selected(plot.selected, plot_df)
table = mk.gui.Table(filtered_df, classes="h-full")

mk.gui.html.flex([plot, table], classes="h-[600px]") 
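
The reactive function above keeps the rows whose primary key appears in the plot's current selection, so the table follows whatever points are selected in the scatter plot. Stripped of reactivity, the row-selection logic amounts to a membership test, sketched here on plain dictionaries with a hypothetical `id` key and made-up rows:

```python
from typing import Any, Dict, Iterable, List


def select_rows(
    rows: List[Dict[str, Any]], selected: Iterable[Any], key: str = "id"
) -> List[Dict[str, Any]]:
    """Keep the rows whose key value is in the selection, in original order."""
    selected_set = set(selected)  # Set lookup keeps each membership test cheap.
    return [row for row in rows if row[key] in selected_set]


rows = [{"id": 0, "genre": "rock"}, {"id": 1, "genre": "jazz"}, {"id": 2, "genre": "pop"}]
picked = select_rows(rows, selected=[2, 0])
nothing = select_rows(rows, selected=[])
```

With an empty selection the result is empty, which matches the table showing no rows until points are selected in the plot.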