In [None]:
import glob
import pandas as pd
import os
import torch
import librosa
import random

from BEATs.Tokenizers import TokenizersConfig, Tokenizers
from BEATs.BEATs import BEATs, BEATsConfig

# Root of the ESC-50 checkout. Every dataset path below is derived from it, so the
# notebook can be pointed at another location by editing this single line.
data_folder = "/data/ESC-50-master/"
# Sorted because glob returns matches in arbitrary, filesystem-dependent order;
# a stable order keeps everything derived from `audio` reproducible.
audio = sorted(glob.glob(os.path.join(data_folder, "audio", "*.wav")))
# Metadata table; later cells rely on its `filename` and `category` columns.
labels = pd.read_csv(os.path.join(data_folder, "meta", "esc50.csv"))

In [None]:
# Work on a small subset: 5 randomly chosen categories, then up to 100 clips drawn
# from them. A fixed seed keeps the subset identical across "Restart & Run All".
# NOTE: this cell overwrites `labels`; re-run the loading cell before re-running it.
SUBSET_SEED = 42
random_cat = random.Random(SUBSET_SEED).sample(list(labels['category'].unique()), 5)
labels_in_cat = labels[labels["category"].isin(random_cat)]
# Cap the sample size at the rows available so `sample` cannot raise on a small selection.
labels = labels_in_cat.sample(min(100, len(labels_in_cat)), random_state=SUBSET_SEED)

In [None]:
# One row per audio file found on disk: the full path plus the bare file name.
# The file name is the key used to join with the metadata table.
df_audio = pd.DataFrame({"filepath": audio})
# os.path.basename handles the platform's path separator, unlike a manual split on "/".
df_audio["filename"] = df_audio["filepath"].map(os.path.basename)
# Preview only the first rows instead of dumping the whole frame.
df_audio.head()

In [None]:
# Attach the on-disk path to every selected metadata row. With an inner join,
# rows whose file name is absent from `df_audio` are dropped.
filepath_labels = pd.merge(labels, df_audio, on="filename", how="inner")

In [None]:
# Display the merged table (metadata columns plus `filepath`) as a visual sanity check.
filepath_labels

In [None]:
# The merge key is `filename`, and (assuming the metadata CSV has no `filepath`
# column of its own) only `df_audio` contributes `filepath`. pandas therefore adds
# no `_x`/`_y` suffix and the column is plain `filepath`.
filepath_labels["filepath"].tolist()

In [None]:
# Load the selected clips, resampled to 16 kHz mono.
# Only the files listed in `filepath_labels` (the sampled subset) are loaded, so
# every tensor in `trs` has a known category stored at the same position in `l`.
trs = []  # one (1, n_samples) tensor per clip
l = []    # category of each clip, in the same order as `trs`

for afile, category in zip(filepath_labels["filepath"], filepath_labels["category"]):
    sig, sr = librosa.load(afile, sr=16000, mono=True)
    # Add a leading batch axis of size 1.
    trs.append(torch.tensor(sig).unsqueeze(0))
    l.append(category)

In [None]:
# Load the pre-trained checkpoint. `map_location="cpu"` lets this work on machines
# without a GPU even if the checkpoint was saved from one.
# NOTE: torch.load unpickles the file, so only use checkpoints from a trusted source.
checkpoint = torch.load('/data/BEATs/BEATs_iter3_plus_AS2M.pt', map_location="cpu")
cfg = BEATsConfig(checkpoint['cfg'])
BEATs_model = BEATs(cfg)
BEATs_model.load_state_dict(checkpoint['model'])
BEATs_model.eval()

# Extract one embedding per clip.
l_representations = []

# Inference only: skip autograd bookkeeping to save memory and time.
with torch.no_grad():
    for t in trs:
        # All-False mask: no position of the clip is flagged as padding.
        padding_mask = torch.zeros(t.shape[0], t.shape[1], dtype=torch.bool)
        representation = BEATs_model.extract_features(t, padding_mask=padding_mask)[0]
        # Keep the vector at the last index of axis 1.
        # NOTE(review): presumably axis 1 is the time/patch axis, in which case this
        # keeps only the final frame rather than a pooled summary. Confirm this is intended.
        l_representations.append(representation[:, -1, :])

In [None]:
from sklearn.manifold import TSNE
import torch

# Stack the per-clip embeddings into a single (n_clips, dim) NumPy array for scikit-learn.
representation = torch.cat(l_representations, dim=0).detach().numpy()
# t-SNE is stochastic; fixing `random_state` makes the 2-D layout reproducible.
# Note that perplexity must stay below the number of clips.
tsne = TSNE(n_components=2, perplexity=5, random_state=0)
representation_2d = tsne.fit_transform(representation)

In [None]:
import seaborn as sns
# Colour each point by the label stored in `l`. The loading loop fills `l` in
# lockstep with `trs`, so it always has exactly one entry per embedded clip.
# (`labels['category']` only matches when it has the same rows in the same order.)
ax = sns.scatterplot(x=representation_2d[:, 0], y=representation_2d[:, 1], hue=l)
ax.set(title="t-SNE of BEATs embeddings", xlabel="t-SNE 1", ylabel="t-SNE 2");