## Import Relevant Files

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torcheval.metrics import R2Score

from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error

import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

import math

import sys
sys.path.insert(1, '../CDeeS/utils')
from paths import *

import chromadb

## Load Neural Network

In [2]:
class NeuralNetwork(nn.Module):
    def __init__(self, input_size):
        super(NeuralNetwork, self).__init__()
        self.layers = nn.Sequential(
            nn.Linear(input_size, math.ceil((input_size**0.5) * 2)),
            nn.ReLU(),
            nn.Linear(math.ceil((input_size**0.5) * 2), math.ceil((input_size**0.5) * 2)),
            nn.ReLU(),
            nn.Linear(math.ceil((input_size**0.5) * 2), math.ceil((input_size**0.5) * 2)),
            nn.ReLU(),
            nn.Linear(math.ceil((input_size**0.5) * 2), math.ceil((input_size**0.5) * 2)),
            nn.ReLU(),
            nn.Linear(math.ceil((input_size**0.5) * 2), math.ceil((input_size**0.5) * 2)),
            nn.ReLU(),
            nn.Linear(math.ceil((input_size**0.5) * 2), math.ceil((input_size**0.5) * 2)),
            nn.ReLU(),
            nn.Linear(math.ceil((input_size**0.5) * 2), math.ceil((input_size**0.5) * 2)),
            nn.ReLU(),
            nn.Linear(math.ceil((input_size**0.5) * 2), math.ceil((input_size**0.5) * 2)),
            nn.ReLU(),
            nn.Linear(math.ceil((input_size**0.5) * 2), math.ceil((input_size**0.5) * 2)),
            nn.ReLU(),
            nn.Linear(math.ceil((input_size**0.5) * 2), math.ceil((input_size**0.5) * 2)),
            nn.ReLU(),
            nn.Linear(math.ceil((input_size**0.5) * 2), math.ceil((input_size**0.5) * 2)),
            nn.ReLU(),
            nn.Linear(math.ceil((input_size**0.5) * 2), 2)
        )

    def forward(self, x):
        return self.layers(x)

In [3]:
model = NeuralNetwork(66)
model.load_state_dict(torch.load('deam_feedforward_nn_essentia_best_valence_mean_normalised.pt'))

<All keys matched successfully>

## Load Music Metadata

In [4]:
df_music = pd.read_csv("./music_info_withname.csv")
df_music.rename(columns={"i_id_c": "song_id"}, inplace=True)
df_music = df_music.astype({'song_id': 'int32'})

cols = ["song_id", "general_genre", "music", "singer"]
df_music = df_music[cols]
display(df_music)

Unnamed: 0,song_id,general_genre,music,singer
0,874,indie,The Privateers,Andrew Bird
1,566,other,Never Stops,Deerhunter
2,812,jazz,Night Of The Iguana,The Cinematic Orchestra
3,941,rock,You And Me,Plain White T's
4,802,reggae,Saber Su Nombre,Daddy Yankee
...,...,...,...,...
931,438,rock,Uptown Girl,Me First And The Gimme Gimmes
932,690,other,Oblivious,Aztec Camera
933,163,alternative,America's Suitehearts,Fall Out Boy
934,946,metal,Ride The Lightning,Metallica


## Load Essentia Features

In [5]:
df_features = pd.read_csv("./essentia_features_clipped.csv", index_col=0)
display(df_features)

Unnamed: 0,song_id,lowlevel.average_loudness,lowlevel.barkbands_crest.dmean,lowlevel.barkbands_crest.dmean2,lowlevel.barkbands_crest.dvar,lowlevel.barkbands_crest.dvar2,lowlevel.barkbands_crest.max,lowlevel.barkbands_crest.mean,lowlevel.barkbands_crest.median,lowlevel.barkbands_crest.min,...,metadata.version.essentia_git_sha,metadata.version.extractor,tonal.chords_key,tonal.chords_scale,tonal.key_edma.key,tonal.key_edma.scale,tonal.key_krumhansl.key,tonal.key_krumhansl.scale,tonal.key_temperley.key,tonal.key_temperley.scale
0,275,0.973211,2.734494,4.519869,6.629131,17.751158,22.156931,9.782679,9.064570,2.926671,...,v2.1_beta5-1110-g77a6a954-dirty,music 2.0,Bb,major,Eb,major,Eb,major,Eb,major
1,507,0.914995,1.779870,2.783498,2.342016,5.451638,24.106762,11.063436,10.109395,3.955993,...,v2.1_beta5-1110-g77a6a954-dirty,music 2.0,F#,major,F#,major,F#,major,F#,major
2,936,0.980768,2.198826,3.604454,6.320270,16.075104,21.427538,7.163894,6.297218,2.519053,...,v2.1_beta5-1110-g77a6a954-dirty,music 2.0,Ab,major,C#,major,C#,major,C#,major
3,739,0.957888,2.358258,3.890606,5.011153,12.956815,24.700670,15.441179,15.776317,4.622088,...,v2.1_beta5-1110-g77a6a954-dirty,music 2.0,A,major,D,minor,D,minor,D,minor
4,659,0.985889,2.725301,4.472162,7.312916,17.942963,19.204697,7.446805,6.570291,2.597616,...,v2.1_beta5-1110-g77a6a954-dirty,music 2.0,E,major,E,major,E,major,E,major
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
291,126,0.981803,3.347317,5.755877,9.081094,23.494911,24.453768,9.621816,8.968574,2.641244,...,v2.1_beta5-1110-g77a6a954-dirty,music 2.0,B,major,C,major,C,major,C,major
292,898,0.968436,2.509090,4.032299,6.231036,16.421440,23.289955,12.805885,12.573533,3.175924,...,v2.1_beta5-1110-g77a6a954-dirty,music 2.0,G,major,D,major,D,major,D,major
293,873,0.944901,3.191716,5.233109,9.425733,24.588820,26.887018,9.700563,8.773985,2.415689,...,v2.1_beta5-1110-g77a6a954-dirty,music 2.0,E,minor,E,minor,E,minor,E,minor
294,697,0.979438,2.086719,3.368903,5.024929,12.668143,23.188427,7.980445,7.247404,2.558759,...,v2.1_beta5-1110-g77a6a954-dirty,music 2.0,F#,minor,D,major,D,major,D,major


In [6]:
df_combined = df_music.merge(df_features, how="inner", on="song_id")
display(df_combined)

Unnamed: 0,song_id,general_genre,music,singer,lowlevel.average_loudness,lowlevel.barkbands_crest.dmean,lowlevel.barkbands_crest.dmean2,lowlevel.barkbands_crest.dvar,lowlevel.barkbands_crest.dvar2,lowlevel.barkbands_crest.max,...,metadata.version.essentia_git_sha,metadata.version.extractor,tonal.chords_key,tonal.chords_scale,tonal.key_edma.key,tonal.key_edma.scale,tonal.key_krumhansl.key,tonal.key_krumhansl.scale,tonal.key_temperley.key,tonal.key_temperley.scale
0,874,indie,The Privateers,Andrew Bird,0.392642,2.495658,4.118659,5.243636,14.063597,26.688347,...,v2.1_beta5-1110-g77a6a954-dirty,music 2.0,D,major,D,major,D,major,D,major
1,812,jazz,Night Of The Iguana,The Cinematic Orchestra,0.933627,2.397552,3.892377,4.056502,10.908409,25.141388,...,v2.1_beta5-1110-g77a6a954-dirty,music 2.0,C,minor,C,minor,C,minor,C,minor
2,941,rock,You And Me,Plain White T's,0.973347,2.531084,4.147055,5.669414,14.551494,26.319483,...,v2.1_beta5-1110-g77a6a954-dirty,music 2.0,G,major,G,major,G,major,G,major
3,739,metal,Planet Caravan,Black Sabbath,0.957888,2.358258,3.890606,5.011153,12.956815,24.700670,...,v2.1_beta5-1110-g77a6a954-dirty,music 2.0,A,major,D,minor,D,minor,D,minor
4,70,reggae,No Creo En El Jamas,Juanes,0.954137,2.347313,3.715736,4.513411,11.102341,23.834007,...,v2.1_beta5-1110-g77a6a954-dirty,music 2.0,E,major,E,major,E,major,E,major
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
291,47,rock,The Stomp,The Hives,0.971516,2.380748,3.945108,5.812252,14.684026,24.004417,...,v2.1_beta5-1110-g77a6a954-dirty,music 2.0,A,major,A,minor,A,minor,A,minor
292,529,hip-hop,Where The Hood At,DMX,0.971770,3.323876,5.149584,8.319780,20.000862,24.849039,...,v2.1_beta5-1110-g77a6a954-dirty,music 2.0,G,major,G,minor,G,minor,G,minor
293,875,alternative,Sullen Girl,Fiona Apple,0.975494,1.793537,2.983498,2.826975,8.117479,21.265247,...,v2.1_beta5-1110-g77a6a954-dirty,music 2.0,D,minor,D,minor,D,minor,D,minor
294,946,metal,Ride The Lightning,Metallica,0.967736,2.375322,3.969078,5.563212,15.381688,21.387815,...,v2.1_beta5-1110-g77a6a954-dirty,music 2.0,C,major,C,major,C,major,C,major


In [7]:
# Pretend to keep only VA values - Used dummy variables instead
df_combined["valence_pred"] = np.random.randn(df_combined.shape[0])
df_combined["arousal_pred"] = np.random.randn(df_combined.shape[0])
display(df_combined)

Unnamed: 0,song_id,general_genre,music,singer,lowlevel.average_loudness,lowlevel.barkbands_crest.dmean,lowlevel.barkbands_crest.dmean2,lowlevel.barkbands_crest.dvar,lowlevel.barkbands_crest.dvar2,lowlevel.barkbands_crest.max,...,tonal.chords_key,tonal.chords_scale,tonal.key_edma.key,tonal.key_edma.scale,tonal.key_krumhansl.key,tonal.key_krumhansl.scale,tonal.key_temperley.key,tonal.key_temperley.scale,valence_pred,arousal_pred
0,874,indie,The Privateers,Andrew Bird,0.392642,2.495658,4.118659,5.243636,14.063597,26.688347,...,D,major,D,major,D,major,D,major,1.561614,0.431368
1,812,jazz,Night Of The Iguana,The Cinematic Orchestra,0.933627,2.397552,3.892377,4.056502,10.908409,25.141388,...,C,minor,C,minor,C,minor,C,minor,-1.613583,-1.390048
2,941,rock,You And Me,Plain White T's,0.973347,2.531084,4.147055,5.669414,14.551494,26.319483,...,G,major,G,major,G,major,G,major,-0.800255,-0.237850
3,739,metal,Planet Caravan,Black Sabbath,0.957888,2.358258,3.890606,5.011153,12.956815,24.700670,...,A,major,D,minor,D,minor,D,minor,1.466953,-0.044559
4,70,reggae,No Creo En El Jamas,Juanes,0.954137,2.347313,3.715736,4.513411,11.102341,23.834007,...,E,major,E,major,E,major,E,major,-0.714615,0.913900
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
291,47,rock,The Stomp,The Hives,0.971516,2.380748,3.945108,5.812252,14.684026,24.004417,...,A,major,A,minor,A,minor,A,minor,-0.160364,1.405267
292,529,hip-hop,Where The Hood At,DMX,0.971770,3.323876,5.149584,8.319780,20.000862,24.849039,...,G,major,G,minor,G,minor,G,minor,-1.738442,0.767901
293,875,alternative,Sullen Girl,Fiona Apple,0.975494,1.793537,2.983498,2.826975,8.117479,21.265247,...,D,minor,D,minor,D,minor,D,minor,0.604298,1.930051
294,946,metal,Ride The Lightning,Metallica,0.967736,2.375322,3.969078,5.563212,15.381688,21.387815,...,C,major,C,major,C,major,C,major,1.105372,0.389298


In [8]:
df_va = df_combined[["valence_pred", "arousal_pred"]].copy()
display(df_va.values.tolist())

[[1.5616135311733512, 0.4313682079314598],
 [-1.6135830616948528, -1.3900484057095837],
 [-0.8002553704286894, -0.23784962244836902],
 [1.466953448566375, -0.0445590506253277],
 [-0.7146154332107751, 0.9139002180695679],
 [0.5929799682835929, 0.26953936685876295],
 [-0.41663276636171026, 0.30324128650361726],
 [0.9006214634948152, -0.4512937972844698],
 [-0.3324584972053156, 0.45088099534712117],
 [0.9588675968943076, 1.6620741583214076],
 [-1.9407469205313612, 1.6927267306144376],
 [-1.9104231105553842, 1.179672467338831],
 [0.4002395457434478, -1.744934694609595],
 [-0.3654172616836193, -0.48480787416137217],
 [1.3271271934766782, 0.9275375486017436],
 [-0.2631866597456472, -0.3840371530133859],
 [0.7530082064758642, 1.1385014508791926],
 [-1.3160351085119684, 0.6605815305281982],
 [-0.4153399832702466, -0.35609413383115557],
 [-0.005020792561050923, -0.3847882506989781],
 [0.7114305595178242, -0.007361435431717674],
 [0.10800926794939848, 0.27372727033959343],
 [0.4238391803691664, 

In [9]:
metadata = df_combined.columns.tolist()[1:4]
# print(metadata)
metadata_dicts = df_combined[metadata].to_dict('records')
# print(metadata_dicts)

In [15]:
song_ids = df_combined["song_id"].values.tolist()
print(song_ids)

[874, 812, 941, 739, 70, 987, 5, 300, 720, 266, 576, 6, 876, 360, 912, 74, 158, 239, 965, 276, 373, 550, 984, 905, 578, 170, 604, 204, 325, 361, 818, 268, 490, 258, 680, 506, 600, 26, 923, 699, 627, 558, 205, 17, 629, 594, 327, 13, 136, 673, 914, 175, 223, 532, 520, 14, 182, 480, 397, 634, 754, 120, 707, 23, 962, 514, 863, 422, 868, 790, 409, 995, 289, 544, 439, 414, 972, 297, 264, 826, 428, 727, 603, 649, 706, 126, 795, 192, 169, 591, 202, 372, 282, 822, 911, 24, 499, 947, 209, 228, 283, 329, 208, 977, 168, 796, 20, 333, 149, 934, 97, 341, 174, 975, 111, 759, 269, 723, 144, 384, 994, 429, 954, 222, 835, 843, 390, 621, 104, 672, 19, 117, 786, 418, 440, 886, 925, 320, 652, 936, 411, 90, 207, 897, 807, 216, 736, 731, 57, 780, 793, 924, 465, 46, 719, 565, 709, 752, 937, 516, 375, 951, 142, 883, 67, 469, 388, 619, 555, 356, 609, 635, 548, 996, 878, 518, 442, 367, 726, 957, 741, 76, 100, 784, 659, 43, 872, 206, 676, 363, 22, 974, 285, 81, 743, 816, 916, 523, 9, 677, 132, 233, 364, 862, 491,

In [16]:
song_ids_str = [str(x) for x in song_ids]
print(song_ids_str)

['874', '812', '941', '739', '70', '987', '5', '300', '720', '266', '576', '6', '876', '360', '912', '74', '158', '239', '965', '276', '373', '550', '984', '905', '578', '170', '604', '204', '325', '361', '818', '268', '490', '258', '680', '506', '600', '26', '923', '699', '627', '558', '205', '17', '629', '594', '327', '13', '136', '673', '914', '175', '223', '532', '520', '14', '182', '480', '397', '634', '754', '120', '707', '23', '962', '514', '863', '422', '868', '790', '409', '995', '289', '544', '439', '414', '972', '297', '264', '826', '428', '727', '603', '649', '706', '126', '795', '192', '169', '591', '202', '372', '282', '822', '911', '24', '499', '947', '209', '228', '283', '329', '208', '977', '168', '796', '20', '333', '149', '934', '97', '341', '174', '975', '111', '759', '269', '723', '144', '384', '994', '429', '954', '222', '835', '843', '390', '621', '104', '672', '19', '117', '786', '418', '440', '886', '925', '320', '652', '936', '411', '90', '207', '897', '807', 

## Chroma Vector Database

In [23]:
client = chromadb.Client()

## Persistent store
# client = chromadb.PersistentClient(path="./db/")

client.heartbeat()


1712281638316949000

In [25]:
# https://docs.trychroma.com/usage-guide#changing-the-distance-function
collection = client.create_collection(
        name="SiTunes_datasetv2",
        metadata={"hnsw:space": "l2"} # "l2", "ip" or "cosine"
    )

In [26]:
collection.add(
    documents=[f"doc{i}" for i in range(df_combined.shape[0])], # file_path
    embeddings=df_va.values.tolist(), # VA embeddings
    metadatas=metadata_dicts, # song_names
    ids=song_ids_str # song_id from SiTunes dataset
)

In [27]:
collection.query(
    query_embeddings=[[1, 1]], # search by embeddings
    n_results=1, # number of results
    # where={"metadata_field": "is_equal_to_this"}, # filter on metadata, e.g. genre
    # where_document={"$contains":"search_string"} # filter on file path???
    include=["distances", "metadatas", "embeddings", "documents", "uris", "data"] # Need to include all relevant parameters to show
)

{'ids': [['170']],
 'distances': [[8.344650268554688e-07]],
 'metadatas': [[{'general_genre': 'rock',
    'music': 'With Twilight As My Guide',
    'singer': 'The Mars Volta'}]],
 'embeddings': [[[1.1071932315826416, 1.1100881099700928]]],
 'documents': [['doc25']],
 'uris': [[None]],
 'data': None}