In [8]:
import faiss
import json
import numpy as np

In [24]:
def transform_dict(list_of_dict):
    """Re-key a list of ``{"id": ..., "embedding": ...}`` records as ``{id: embedding}``.

    Records are processed in order, so a later record overwrites an earlier
    one that shares the same id. A record missing either key raises KeyError.
    """
    return {record['id']: record['embedding'] for record in list_of_dict}

In [92]:
# Load the track metadata; its keys are the track ids used throughout this notebook.
# `with` closes the handle deterministically. JSON is UTF-8 by spec, while open()'s
# default encoding is locale-dependent, so state it explicitly.
with open('fairouz_conf/fairouz/tracks_contextualized.json', encoding='utf-8') as tracks_file:
    tracks = json.load(tracks_file)
track_ids = list(tracks.keys())

In [78]:
def _load_json(path):
    """Parse the UTF-8 JSON file at ``path``, closing the file handle afterwards."""
    with open(path, encoding='utf-8') as fh:
        return json.load(fh)

# Downstream cells iterate each of these and read the "id" and "embedding" keys of every
# element, i.e. they treat each file as a list of per-track records.
audio_embeddings = _load_json('fairouz_conf/fairouz/embeddings/audio/song_audio_vggish_embeddings.json')
graph_embeddings = _load_json('fairouz_conf/fairouz/embeddings/graph/karate_club/song_nodes_RandNE_embedding.json')
image_embeddings = _load_json('fairouz_conf/fairouz/embeddings/image/album_covers_openclip_embeddings.json')
text_embeddings = _load_json('fairouz_conf/fairouz/embeddings/lyrics/song_lyrics_mxbai_embeddings.json')

In [60]:
# Per-modality lookup tables: track id -> embedding.
audio_embeddings_dict = transform_dict(audio_embeddings)
graph_embeddings_dict = transform_dict(graph_embeddings)
image_embeddings_dict = transform_dict(image_embeddings)
text_embeddings_dict = transform_dict(text_embeddings)

# get_modality_embeddings indexes all four dicts for every id in track_ids, so fail fast
# here with a clear message instead of a KeyError deep inside the dataset-building loop.
_expected_track_ids = set(track_ids)
for _modality, _lookup in {
    "audio": audio_embeddings_dict,
    "graph": graph_embeddings_dict,
    "image": image_embeddings_dict,
    "text": text_embeddings_dict,
}.items():
    _missing_ids = _expected_track_ids - _lookup.keys()
    assert not _missing_ids, (
        f"{_modality} embeddings are missing {len(_missing_ids)} of {len(_expected_track_ids)} track ids"
    )

In [61]:
def _to_float32_matrix(records):
    """Stack the "embedding" field of each record into one float32 matrix (row i <- records[i])."""
    return np.array([np.array(record["embedding"]) for record in records]).astype('float32')

# faiss takes float32 input. Row order follows the source lists, which is what lets
# get_positives map a faiss row index back to a record (and its "id").
a_array, g_array, i_array, t_array = (
    _to_float32_matrix(records)
    for records in (audio_embeddings, graph_embeddings, image_embeddings, text_embeddings)
)

In [62]:
# (rows, dim) of each modality matrix, in audio, graph, image, text order.
tuple(matrix.shape for matrix in (a_array, g_array, i_array, t_array))

((822, 128), (822, 128), (822, 512), (822, 1024))

In [63]:
# Exact (flat) L2 index over the audio vectors; faiss row i corresponds to audio_embeddings[i].
audio_dim = a_array.shape[-1]
audio_index = faiss.IndexFlatL2(audio_dim)
audio_index.add(a_array)

In [64]:
# Exact (flat) L2 index over the graph vectors; faiss row i corresponds to graph_embeddings[i].
graph_dim = g_array.shape[-1]
graph_index = faiss.IndexFlatL2(graph_dim)
graph_index.add(g_array)

In [65]:
# Exact (flat) L2 index over the image vectors; faiss row i corresponds to image_embeddings[i].
image_dim = i_array.shape[-1]
image_index = faiss.IndexFlatL2(image_dim)
image_index.add(i_array)

In [66]:
# Exact (flat) L2 index over the text vectors; faiss row i corresponds to text_embeddings[i].
text_dim = t_array.shape[-1]
text_index = faiss.IndexFlatL2(text_dim)
text_index.add(t_array)

In [67]:
def get_modality_embeddings(track_id):
    """Return ``(audio, graph, image, text)`` embeddings for one track.

    Each value is taken as-is from the corresponding ``*_embeddings_dict``.
    Raises KeyError for the first modality (in that order) lacking ``track_id``.
    """
    lookups = (
        audio_embeddings_dict,
        graph_embeddings_dict,
        image_embeddings_dict,
        text_embeddings_dict,
    )
    return tuple(lookup[track_id] for lookup in lookups)

In [104]:
def get_positives(track_id, k = 10):
    """Return one nearest-neighbour track id per modality for ``track_id``.

    For each modality (audio, graph, image, text, in that order) the track's own
    embedding is searched against that modality's faiss index. The closest hit is
    kept, skipping the query track itself and any hit at distance exactly 0
    (an identical embedding).

    Args:
        track_id: id present in every ``*_embeddings_dict``.
        k: number of neighbours retrieved per modality before filtering. It must be
            large enough that at least one hit survives the filter.

    Returns:
        List of 4 track ids, one per modality. The same id may appear more than once.

    Raises:
        ValueError: if none of the top ``k`` hits for some modality is usable
            (increase ``k``).
    """
    a_emb, g_emb, i_emb, t_emb = get_modality_embeddings(track_id)
    modality_index = {"audio": audio_index, "graph": graph_index, "image": image_index, "text": text_index}
    # faiss row i of each index was built from element i of the matching list.
    ids = {"audio": audio_embeddings, "graph": graph_embeddings, "image": image_embeddings, "text": text_embeddings}
    modalities = {"audio": a_emb, "graph": g_emb, "image": i_emb, "text": t_emb}
    positives = []
    for modality, index in modality_index.items():
        query = np.array(modalities[modality]).reshape(1, -1).astype("float32")
        distances, indices = index.search(query, k)
        for distance, row in zip(distances[0].tolist(), indices[0].tolist()):
            # faiss pads with label -1 when k exceeds the index size; those are not real
            # hits (and indexing with -1 would silently return the last record).
            if row == -1 or distance == 0:
                continue
            candidate_id = ids[modality][row]["id"]
            if candidate_id == track_id:
                continue
            positives.append(candidate_id)
            break
        else:
            raise ValueError(
                f"no usable {modality} neighbour for {track_id!r} in the top {k} hits; increase k"
            )
    return positives


In [133]:
def get_negatives(track_id, n_candidates=40, rng=None):
    """Pick, per modality, the sampled track with the lowest dot product to ``track_id``.

    ``n_candidates`` other tracks are drawn at random from ``track_ids``. The draw
    is without replacement and never includes ``track_id`` itself. For each modality
    (audio, graph, image, text) the candidate whose embedding has the smallest dot
    product with the query's embedding is kept.

    Args:
        track_id: id present in every ``*_embeddings_dict``.
        n_candidates: sample size; capped at the number of other tracks.
        rng: optional ``numpy.random.Generator`` (or any object with a compatible
            ``choice``) for reproducible sampling. Defaults to NumPy's global RNG.

    Returns:
        List of distinct track ids (at most 4), ordered by first occurrence across
        audio, graph, image, text.
    """
    # Convert the query's embeddings once instead of on every candidate.
    ref_vectors = [np.array(ref) for ref in get_modality_embeddings(track_id)]
    sampler = np.random if rng is None else rng
    pool = [tid for tid in track_ids if tid != track_id]
    if not pool or n_candidates <= 0:
        return []
    candidates = sampler.choice(pool, size=min(n_candidates, len(pool)), replace=False)
    # scores[m][j]: dot product of the query's modality-m embedding with candidate j's.
    scores = [[] for _ in ref_vectors]
    for candidate in candidates:
        for modality_scores, ref, emb in zip(scores, ref_vectors, get_modality_embeddings(candidate)):
            modality_scores.append(np.dot(ref, np.array(emb)))
    picks = [str(candidates[np.argmin(modality_scores)]) for modality_scores in scores]
    # dict.fromkeys de-duplicates while keeping a deterministic order; set order for
    # str depends on per-process hash randomisation.
    return list(dict.fromkeys(picks))

In [130]:
def dot_product(v1, v2):
    """Return the dot product of two array-like vectors, computed with NumPy."""
    first, second = np.array(v1), np.array(v2)
    return first.dot(second)

In [125]:
# Spot-check one track: one neighbour id per modality (audio, graph, image, text).
get_positives(track_id="id264cf6a9f396151588583f76e5a51a6a", k=20)

['ide3554d4190fd714e50345bd4906469af',
 'id51f79510005a16836683c69d5ef37bc5',
 'id28be4f9571336d8cbf23d18f7d7548b9',
 'id92d80b2442fdeea5f83948c1bd6f16bf']

In [121]:
from tqdm import tqdm

In [136]:
# get_negatives draws its candidates from NumPy's global RNG. Seed it so the sampled
# negatives (and therefore the exported file) are reproducible across runs.
NEGATIVE_SAMPLING_SEED = 42
np.random.seed(NEGATIVE_SAMPLING_SEED)

data_dict = {}
for track_id in tqdm(track_ids):
    data_dict[track_id] = {"positives": get_positives(track_id, 50), "negatives": get_negatives(track_id)}

100%|██████████| 822/822 [00:05<00:00, 139.63it/s]


In [137]:
import json

In [138]:
# `with` guarantees the file is flushed and closed as soon as the dump finishes.
with open('fairouz_conf/fairouz/positives_negatives.json', 'w', encoding='utf-8') as out_file:
    json.dump(data_dict, out_file)