In [69]:
from transformers import AutoFeatureExtractor, AutoModel
from IPython.display import Audio as player
from datasets import load_dataset, Audio
from qdrant_client import QdrantClient
from qdrant_client.http import models
from os.path import join
from pathlib import Path
from glob import glob
import pandas as pd
import numpy as np
import librosa
import torch

In [70]:
# Open a client connection to the Qdrant server at localhost:6333.
client = QdrantClient(host="localhost", port=6333)

In [71]:
# (Re)create the collection for the music vectors: 768-dimensional, compared
# with cosine distance.
# NOTE(review): recreate_collection presumably drops an existing collection of
# the same name first — confirm before running against data you want to keep.
# NOTE(review): size=768 presumably has to match the embedding model's output
# width — confirm against the model used further down.
my_collection = "music_collection"
client.recreate_collection(
    collection_name=my_collection,
    vectors_config=models.VectorParams(size=768, distance=models.Distance.COSINE)
)

True

# 2. Data Prep

In [72]:
# Root folder of the dataset, relative to the notebook's working directory.
# (Plain string literal: the original f-string had no placeholders.)
data_path = Path("ludwig-dataset")
data_path

PosixPath('ludwig-dataset')

In [73]:
# Load every audio file in the "latin" folder (folder labels dropped), then
# keep a fixed-seed random sample of 400 tracks.
data_dir = data_path / "mp3" / "mp3" / "latin"
latin_tracks = load_dataset(
    "audiofolder",
    data_dir=data_dir,
    split="train",
    drop_labels=True,
)
music_data = latin_tracks.shuffle(seed=42).select(range(400))
music_data

Resolving data files:   0%|          | 0/979 [00:00<?, ?it/s]

Dataset({
    features: ['audio'],
    num_rows: 400
})

In [74]:
# Inspect one row (115) to see the structure of a decoded audio record.
music_data[115]

{'audio': {'path': '/Users/zero/projects/python-playground/DS-playground/qdrant-search/ludwig-dataset/mp3/mp3/latin/1g8TA3JM0K0kxtQi4n38qk.mp3',
  'array': array([-0.0751303 , -0.11050164, -0.11514139, ...,  0.22492811,
          0.2993504 ,  0.22477728]),
  'sampling_rate': 44100}}

In [75]:
# Play back one sample. The playback rate is read from the record itself rather
# than hard-coded, so the cell stays correct if the dataset's sampling rate
# differs from 44.1 kHz. The row is indexed once so the audio is fetched once.
preview_audio = music_data[115]["audio"]
player(preview_audio["array"], rate=preview_audio["sampling_rate"])

In [76]:
# Track id of every row = file name without directory or extension.
# Path(...).stem replaces the manual "/" split and ".mp3" string replace, so it
# does not depend on the path separator and strips only the real extension.
ids = [
    Path(music_data[i]["audio"]["path"]).stem
    for i in range(len(music_data))
]
# Positional index of every row (0 .. len-1).
index = list(range(len(music_data)))
ids[:4], index[:4]

(['4CpcXFmmVEmkVt3exUvNNZ',
  '2NWa2R7J4gerg2p6WEDvJx',
  '5JxCu2Vn1fL7lcTHESxBRD',
  '7uzTm4WG1hFLcj4QPGkVZm'],
 [0, 1, 2, 3])

In [77]:
# Attach the positional index and the file-name-derived track id to every row,
# then show the last row to confirm both columns are present.
# NOTE(review): `music_data` is rebound here, so re-running this cell operates
# on a dataset that already carries these columns — restart and run from the
# top rather than re-executing it in place.
music_data = music_data.add_column("index", index)
music_data = music_data.add_column("ids", ids)
music_data[-1]

{'audio': {'path': '/Users/zero/projects/python-playground/DS-playground/qdrant-search/ludwig-dataset/mp3/mp3/latin/4O6hfbAp9y65YyaUS1hQk7.mp3',
  'array': array([0.00000000e+00, 1.19851318e-09, 1.27332223e-09, ...,
         8.65739230e-02, 9.57795307e-02, 9.86581109e-02]),
  'sampling_rate': 44100},
 'index': 399,
 'ids': '4O6hfbAp9y65YyaUS1hQk7'}

In [78]:
# Read labels.json from the dataset root into a DataFrame: it is indexed by
# track id and has a single `tracks` column holding one nested dict per track.
label_path = data_path / "labels.json"
labels = pd.read_json(label_path)
labels.head()

Unnamed: 0,tracks
35ecMLCJ1x2giJuvHLrI1t,{'otherSubgenres': {'L': [{'S': 'electronic---...
3p0EUhkUeCNrBIZwkjmeYe,"{'otherSubgenres': {'L': []}, 'artist': {'S': ..."
0rb6HvdvWJRkyhxsfFf1ep,"{'otherSubgenres': {'L': [{'S': 'rock'}, {'S':..."
4ssD5IkaicvM3L2Ff8FPWQ,"{'otherSubgenres': {'L': []}, 'artist': {'S': ..."
586ncAs8cYRTBlrxMDfmSP,{'otherSubgenres': {'L': [{'S': 'electronic---...


In [79]:
def get_metadata(x):
    """Extract the artist, genre, name and subgenres of one track record.

    Parameters
    ----------
    x : mapping
        One entry of the ``tracks`` column: a dict whose fields are themselves
        single-entry dicts, e.g. ``{"artist": {"S": "Riovolt"}, ...}``.

    Returns
    -------
    pandas.Series
        Indexed by ``["artist", "genre", "name", "subgenres"]``. Each value is
        the first value of the corresponding wrapper dict, or the string
        ``"Unknown"`` when the field is missing, empty or malformed.
    """
    cols = ["artist", "genre", "name", "subgenres"]
    list_of_cols = []
    for col in cols:
        try:
            # Each field wraps its payload in a single-entry dict; unwrap it.
            mdata = list(x[col].values())[0]
        except (KeyError, IndexError, TypeError, AttributeError):
            # Only the expected lookup failures are absorbed (missing key,
            # empty wrapper, non-mapping record or field). A bare `except`
            # would also swallow KeyboardInterrupt and genuine bugs.
            mdata = "Unknown"
        list_of_cols.append(mdata)
    return pd.Series(list_of_cols, index=cols)

In [80]:
# Expand every nested `tracks` record into flat columns. reset_index() moves
# the track ids out of the index into a regular column named "index".
clean_labels = labels["tracks"].apply(get_metadata).reset_index()
clean_labels.head()

Unnamed: 0,index,artist,genre,name,subgenres
0,35ecMLCJ1x2giJuvHLrI1t,Riovolt,electronic,It Ain't Over 'till It's Over,"[{'S': 'electronic---ambient'}, {'S': 'electro..."
1,3p0EUhkUeCNrBIZwkjmeYe,R.L. Burnside,blues,Fireman Ring the Bell,[{'S': 'blues---country blues'}]
2,0rb6HvdvWJRkyhxsfFf1ep,Chapterhouse,rock,Falling Down,[{'S': 'rock---shoegaze'}]
3,4ssD5IkaicvM3L2Ff8FPWQ,Lowell Fulsom,funk / soul,Tramp,[{'S': 'funk / soul---rhythm & blues'}]
4,586ncAs8cYRTBlrxMDfmSP,Paul Ellis,electronic,Dissolve,[{'S': 'electronic---ambient'}]


In [81]:
def get_vals(genres):
    """Flatten a list of wrapper dicts into the flat list of their values.

    Parameters
    ----------
    genres : list of dict, str or None
        Typically ``[{"S": "rock---shoegaze"}, ...]``. ``None`` or a plain
        string is treated as "no subgenres"; the string case covers the
        ``"Unknown"`` placeholder that ``get_metadata`` emits for a missing
        field.

    Returns
    -------
    list
        Every value of every dict, in order; ``[]`` when nothing is available.
    """
    # Iterating a string would yield single characters, which have no
    # .items()/.values(), so handle the placeholder before looping.
    if genres is None or isinstance(genres, str):
        return []
    return [val for entry in genres for val in entry.values()]

# Replace each list of wrapper dicts in `subgenres` with a flat list of strings.
# NOTE(review): this overwrites the column in place, so the cell is not
# idempotent — running it a second time passes the already-flattened lists
# back through get_vals.
clean_labels["subgenres"] = clean_labels.subgenres.apply(get_vals)
clean_labels["subgenres"].head()

0    [electronic---ambient, electronic---downtempo,...
1                              [blues---country blues]
2                                    [rock---shoegaze]
3                       [funk / soul---rhythm & blues]
4                               [electronic---ambient]
Name: subgenres, dtype: object

In [82]:
# Map every MP3 in the "latin" folder to its track id (file name without
# directory or extension). Path(...).stem is used instead of splitting on "/"
# so the ids are correct regardless of the platform's path separator.
file_path = data_path / "mp3" / "mp3" / "latin" / "*.mp3"
files = glob(str(file_path))
ids = [Path(file).stem for file in files]
music_paths = pd.DataFrame({"ids": ids, "urls": files}, columns=["ids", "urls"])

music_paths.head()

Unnamed: 0,ids,urls
0,5f1SjUy6ySgaEUIIy2m9l4,ludwig-dataset/mp3/mp3/latin/5f1SjUy6ySgaEUIIy...
1,03tbpnBQ9kiAL8GX0ouZUG,ludwig-dataset/mp3/mp3/latin/03tbpnBQ9kiAL8GX0...
2,67wqhzuPtGbZNYG1eVoLsd,ludwig-dataset/mp3/mp3/latin/67wqhzuPtGbZNYG1e...
3,0YfDtPub9AsTu4278mDWJE,ludwig-dataset/mp3/mp3/latin/0YfDtPub9AsTu4278...
4,7vH4D94WWhAdjll6b62wiw,ludwig-dataset/mp3/mp3/latin/7vH4D94WWhAdjll6b...


In [83]:
# Build one metadata row per sampled track:
#   1. take the (index, ids) pairs of the sampled dataset,
#   2. left-join the cleaned labels on the track id (held in the labels'
#      "index" column, hence the index_x / index_y suffixes),
#   3. left-join the file paths on the track id,
#   4. drop the duplicated id column and restore the positional "index" name.
sampled_keys = music_data.select_columns(["index", "ids"]).to_pandas()
keys_with_labels = sampled_keys.merge(
    clean_labels, how="left", left_on="ids", right_on="index"
)
keys_with_paths = keys_with_labels.merge(music_paths, how="left", on="ids")
metadata = keys_with_paths.drop(columns="index_y").rename(
    columns={"index_x": "index"}
)

metadata.head()

Unnamed: 0,index,ids,artist,genre,name,subgenres,urls
0,0,4CpcXFmmVEmkVt3exUvNNZ,Carlos Puebla,latin,Y en Llego Fidel,[latin---cubano],ludwig-dataset/mp3/mp3/latin/4CpcXFmmVEmkVt3ex...
1,1,2NWa2R7J4gerg2p6WEDvJx,Baden Powell,latin,Manhã de Carnaval,[latin---samba],ludwig-dataset/mp3/mp3/latin/2NWa2R7J4gerg2p6W...
2,2,5JxCu2Vn1fL7lcTHESxBRD,Cal Tjader,latin,Mamblues,[latin---salsa],ludwig-dataset/mp3/mp3/latin/5JxCu2Vn1fL7lcTHE...
3,3,7uzTm4WG1hFLcj4QPGkVZm,Ibrahim Ferrer,latin,Silencio,[latin---cubano],ludwig-dataset/mp3/mp3/latin/7uzTm4WG1hFLcj4QP...
4,4,2VvvGYUy2KT7DWI38wQsum,Adriana Calcanhotto,latin,Eu vivo a sorrir,[latin---samba],ludwig-dataset/mp3/mp3/latin/2VvvGYUy2KT7DWI38...


In [86]:
# Everything except the two id columns becomes the payload: one dict per track.
payload_frame = metadata.drop(columns=["index", "ids"])
payload = payload_frame.to_dict(orient="records")
payload[:3]

[{'artist': 'Carlos Puebla',
  'genre': 'latin',
  'name': 'Y en Llego Fidel',
  'subgenres': ['latin---cubano'],
  'urls': 'ludwig-dataset/mp3/mp3/latin/4CpcXFmmVEmkVt3exUvNNZ.mp3'},
 {'artist': 'Baden Powell',
  'genre': 'latin',
  'name': 'Manhã de Carnaval',
  'subgenres': ['latin---samba'],
  'urls': 'ludwig-dataset/mp3/mp3/latin/2NWa2R7J4gerg2p6WEDvJx.mp3'},
 {'artist': 'Cal Tjader',
  'genre': 'latin',
  'name': 'Mamblues',
  'subgenres': ['latin---salsa'],
  'urls': 'ludwig-dataset/mp3/mp3/latin/5JxCu2Vn1fL7lcTHESxBRD.mp3'}]

# 4. Embeddings

In [88]:
# Load a single track as a mono waveform at 44.1 kHz for the steps below;
# `sr` is the sampling rate returned alongside the samples.
one_song = data_path / "mp3" / "mp3" / "latin" / "0rXvhxGisD2djBmNkrv5Gt.mp3"
audio, sr = librosa.core.load(one_song, sr=44100, mono=True)
audio.shape

(1322496,)

In [89]:
# Listen to the loaded track at the sampling rate it was loaded with.
player(audio, rate=sr)

In [90]:
# Add a leading axis so the waveform has shape (1, n_samples).
audio2 = audio[None, :]
audio2.shape

(1, 1322496)

## Transformers

In [91]:
# Select the compute device: Apple's MPS backend when it is usable, otherwise
# the CPU. When MPS is unavailable, the reason is printed first.
if torch.backends.mps.is_available():
    mps_device = torch.device("mps")
else:
    if not torch.backends.mps.is_built():
        print("MPS not available because the current PyTorch install was not "
              "built with MPS enabled.")
    else:
        print("MPS not available because the current MacOS version is not 12.3+ "
              "and/or you do not have an MPS-enabled device on this machine.")
    # Fall back to the CPU so `mps_device` is always bound. Previously the name
    # was left undefined on this path and the next line raised NameError.
    mps_device = torch.device("cpu")

mps_device

device(type='mps')

In [96]:
# Load the pretrained "facebook/wav2vec2-base" checkpoint and move it to the
# selected device, then load the feature extractor published under the same name.
model = AutoModel.from_pretrained("facebook/wav2vec2-base").to(mps_device)
feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base")

Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-base and are newly initialized: ['wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [93]:
# Resample the (1, n_samples) waveform from `sr` to 16 kHz, play the result,
# and show its new shape.
# NOTE(review): 16_000 is presumably the rate the feature extractor expects —
# confirm it equals feature_extractor.sampling_rate.
resampled_audio = librosa.resample(y=audio2, orig_sr=sr, target_sr=16_000)
display(player(resampled_audio, rate=16_000))
resampled_audio.shape

(1, 479818)

In [94]:
# Convert the resampled waveform (first row of the 2-D array) into PyTorch
# tensors for the model.
# truncation=True with max_length=16_000 caps the input at 16,000 samples, so
# `input_values` comes out with shape (1, 16000) and most of the track is
# discarded at this step.
inputs = feature_extractor(
    resampled_audio[0], sampling_rate=feature_extractor.sampling_rate, return_tensors="pt",
    padding=True, return_attention_mask=True, truncation=True, max_length=16_000
)
inputs["input_values"].shape

torch.Size([1, 16000])