In [None]:
# Load the libraries
import torch
import os
import librosa


from BEATs.Tokenizers import TokenizersConfig, Tokenizers
from BEATs.BEATs import BEATs, BEATsConfig

# Test BEATs with the provided code

In [None]:
### Tokenizer
# Load the pre-trained BEATs tokenizer and run it on a random dummy waveform.

# Load the pre-trained checkpoint. `map_location='cpu'` lets the file load on
# machines without a GPU even if its tensors were serialized from a CUDA
# device; the model below is built and run on CPU anyway.
checkpoint = torch.load('/data/BEATs/Tokenizer_iter3_plus_AS2M.pt', map_location='cpu')

cfg = TokenizersConfig(checkpoint['cfg'])
BEATs_tokenizer = Tokenizers(cfg)
BEATs_tokenizer.load_state_dict(checkpoint['model'])
BEATs_tokenizer.eval()

# Dummy input: one random waveform of 10000 samples and an all-False padding
# mask of the same shape (presumably False means "not padding" -- confirm).
audio_input_16khz = torch.randn(1, 10000)
padding_mask = torch.zeros_like(audio_input_16khz, dtype=torch.bool)

# Inference only: disable autograd so no computation graph is recorded.
with torch.no_grad():
    labels = BEATs_tokenizer.extract_labels(audio_input_16khz, padding_mask=padding_mask)

In [None]:
# Display the shape of `labels`, the value returned by `extract_labels`.
labels.shape

In [None]:
# Load the pre-trained BEATs model and extract features from a random dummy
# waveform.

# `map_location='cpu'` lets the checkpoint load on machines without a GPU even
# if its tensors were serialized from a CUDA device; the model below is built
# and run on CPU anyway.
checkpoint = torch.load('/data/BEATs/BEATs_iter3_plus_AS2M.pt', map_location='cpu')

cfg = BEATsConfig(checkpoint['cfg'])
BEATs_model = BEATs(cfg)
BEATs_model.load_state_dict(checkpoint['model'])
BEATs_model.eval()

# extract the audio representation
# Dummy input: one random waveform of 10000 samples and an all-False padding
# mask of the same shape.
audio_input_16khz = torch.randn(1, 10000)
padding_mask = torch.zeros_like(audio_input_16khz, dtype=torch.bool)

# Inference only: disable autograd so no computation graph is recorded.
# `extract_features` returns a sequence; only its first element is kept.
with torch.no_grad():
    representation = BEATs_model.extract_features(audio_input_16khz, padding_mask=padding_mask)[0]

In [None]:
# Display the shape of `representation`, the first element returned by
# `extract_features`.
representation.shape

In [None]:
# Load the fine-tuned BEATs checkpoint and print the top-5 predicted labels
# for a batch of random dummy waveforms.

# `map_location='cpu'` lets the checkpoint load on machines without a GPU even
# if its tensors were serialized from a CUDA device; the model below is built
# and run on CPU anyway.
checkpoint = torch.load('/data/BEATs/BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt2.pt', map_location='cpu')

cfg = BEATsConfig(checkpoint['cfg'])
BEATs_model = BEATs(cfg)
BEATs_model.load_state_dict(checkpoint['model'])
BEATs_model.eval()

# predict the classification probability of each class
# Dummy input: three random waveforms of 10000 samples each and an all-False
# padding mask of the same shape.
audio_input_16khz = torch.randn(3, 10000)
padding_mask = torch.zeros_like(audio_input_16khz, dtype=torch.bool)

# Inference only: disable autograd so no graph is recorded and the printed
# tensors do not carry a `grad_fn`.
with torch.no_grad():
    probs = BEATs_model.extract_features(audio_input_16khz, padding_mask=padding_mask)[0]

# `topk` yields (values, indices); zipping them iterates over the batch so each
# step sees one audio's top-5 values together with the matching indices.
for i, (top5_label_prob, top5_label_idx) in enumerate(zip(*probs.topk(k=5))):
    # Map each index to its entry in the checkpoint's `label_dict`.
    top5_label = [checkpoint['label_dict'][label_idx.item()] for label_idx in top5_label_idx]
    print(f'Top 5 predicted labels of the {i}th audio are {top5_label} with probability of {top5_label_prob}')

In [None]:
# `parameters` is referenced here without being called, so this cell displays
# the repr of the bound method rather than iterating over the parameters.
# NOTE(review): if the goal is to inspect the architecture, displaying
# `BEATs_model` itself may be clearer -- confirm intent.
BEATs_model.parameters

# Test BEATs with my own files

In [None]:
import glob

# Root folder that is searched recursively for .mp3 files.
data_folder = "/data/different_bird_songs/"

# `glob` returns matches in an arbitrary, filesystem-dependent order; sorting
# makes the file order (and therefore positional indexing further down)
# stable between runs. `os.path.join` builds the pattern without producing a
# doubled path separator.
data = sorted(glob.glob(os.path.join(data_folder, "**", "*.mp3"), recursive=True))

In [None]:
# Read every file listed in `data` as a mono waveform resampled to 16000 Hz.
# `trs` collects one tensor per file with a leading axis of size 1 added;
# `l` collects the second-to-last "/"-separated component of each path
# (the name of the directory that contains the file).
trs, l = [], []

for audio_path in data:
    waveform, _ = librosa.load(audio_path, sr=16000, mono=True)
    trs.append(torch.tensor(waveform).unsqueeze(0))
    parent_dir = audio_path.split("/")[-2]
    l.append(parent_dir)

In [None]:
# load the pre-trained checkpoints for the tokenizer
# `map_location='cpu'` lets the checkpoint load on machines without a GPU even
# if its tensors were serialized from a CUDA device.
checkpoint = torch.load('/data/BEATs/Tokenizer_iter3_plus_AS2M.pt', map_location='cpu')
cfg = TokenizersConfig(checkpoint['cfg'])
BEATs_tokenizer = Tokenizers(cfg)
BEATs_tokenizer.load_state_dict(checkpoint['model'])
BEATs_tokenizer.eval()

# tokenize the audio and generate the labels
# Build the padding mask from the waveform that is actually tokenized. Reusing
# whatever `padding_mask` an earlier cell left in the kernel would pair the
# audio with a mask of a different shape.
padding_mask = torch.zeros_like(trs[0], dtype=torch.bool)

# Inference only: disable autograd so no computation graph is recorded.
with torch.no_grad():
    labels = BEATs_tokenizer.extract_labels(trs[0], padding_mask=padding_mask)

In [None]:
# Display the shape of `labels`, the value returned by `extract_labels` for
# the first loaded audio file.
labels.shape

In [None]:
# load the pre-trained checkpoints
# `map_location='cpu'` lets the checkpoint load on machines without a GPU even
# if its tensors were serialized from a CUDA device.
checkpoint = torch.load('/data/BEATs/BEATs_iter3_plus_AS2M.pt', map_location='cpu')
cfg = BEATsConfig(checkpoint['cfg'])
BEATs_model = BEATs(cfg)
BEATs_model.load_state_dict(checkpoint['model'])
BEATs_model.eval()

# extract the audio representation
l_representations = []

# Inference only: without `no_grad` every stored slice would keep the autograd
# graph of its whole forward pass alive, growing memory with each file.
with torch.no_grad():
    for t in trs:
        # All-False mask with the same shape as the waveform.
        padding_mask = torch.zeros_like(t, dtype=torch.bool)
        representation = BEATs_model.extract_features(t, padding_mask=padding_mask)[0]
        # The indexing below treats `representation` as a 3-D tensor and keeps
        # only the LAST index along axis 1, giving one vector per batch item.
        # NOTE(review): axes are presumably (batch, sequence positions, 768
        # features); if a whole-clip embedding is intended, pooling over axis 1
        # (e.g. a mean) may be more appropriate than the final position -- confirm.
        l_representations.append(representation[:, -1, :])

In [None]:
# Display the shape of `padding_mask` as left by the most recently executed
# cell that assigned it.
padding_mask.shape

In [None]:
# Display the shape of the second waveform tensor in `trs`
# (requires at least two loaded audio files).
trs[1].shape

In [None]:
from sklearn.manifold import TSNE

# Stack the per-file vectors into one 2-D array (one row per file) and convert
# it to NumPy for scikit-learn.
representation = torch.cat(l_representations, dim=0)
representation = representation.detach().numpy()

# t-SNE is stochastic; a fixed `random_state` makes the 2-D embedding
# reproducible across runs. Note that scikit-learn requires `perplexity` to be
# smaller than the number of samples.
tsne = TSNE(n_components=2, perplexity=5, random_state=0)
representation_2d = tsne.fit_transform(representation)

In [None]:
import seaborn as sns
# Scatter the 2-D t-SNE embedding, coloured by the per-file label in `l`.
ax = sns.scatterplot(x=representation_2d[:, 0], y=representation_2d[:, 1], hue=l)
# Label the figure so it can be read on its own; the trailing `;` suppresses
# the repr of the value returned by `set`.
ax.set(xlabel="t-SNE dimension 1", ylabel="t-SNE dimension 2", title="t-SNE projection of BEATs representations");

In [None]:
from sklearn.decomposition import PCA
# Project the representations onto their first two principal components.
pca = PCA(n_components=2)
# Fit and transform with the PCA estimator created above (the previous code
# called the t-SNE estimator here, so `pca` was never used).
representation_2d_pca = pca.fit_transform(representation)

In [None]:
import torch

# Scratch 2 x 2 x 3 integer tensor holding the values 1..12 in row-major order.
tensor = torch.arange(1, 13).reshape(2, 2, 3)

In [None]:
# Display the shape of the scratch tensor defined in the previous cell.
tensor.shape