In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import lightning as L
import numpy as np

In [20]:
from model_no_contraction import Encoder
from utils import *

In [36]:
# Restore the trained encoder from its Lightning checkpoint.
# NOTE(review): absolute, machine-specific path - adjust when running elsewhere.
checkpoint_path = "/workspace/fairouz/logs/vggish_randne_openclip_mxbai/version_0/checkpoints/epoch=199-step=40000.ckpt"
model = Encoder.load_from_checkpoint(checkpoint_path)

In [37]:
# Put the module in evaluation mode. As the last expression of the cell, the
# returned module is also displayed, which shows the layer layout.
model.eval()

Encoder(
  (audio_encoder): Sequential(
    (0): Linear(in_features=128, out_features=128, bias=True)
    (1): ReLU()
  )
  (image_encoder): Sequential(
    (0): Linear(in_features=512, out_features=512, bias=True)
    (1): ReLU()
  )
  (text_encoder): Sequential(
    (0): Linear(in_features=1024, out_features=1024, bias=True)
    (1): ReLU()
  )
  (graph_encoder): Sequential(
    (0): Linear(in_features=128, out_features=128, bias=True)
    (1): ReLU()
  )
  (combined_encoder): Sequential(
    (0): Linear(in_features=1792, out_features=3584, bias=True)
    (1): ReLU()
    (2): Linear(in_features=3584, out_features=128, bias=True)
  )
)

In [38]:
# Show which device the loaded model lives on; inputs passed to it later have
# to be placed on the same device.
model.device

device(type='cuda', index=0)

In [39]:
## Reading data
import json


def _load_json(path):
    """Parse the JSON file at ``path`` and close the file handle afterwards."""
    with open(path, encoding="utf-8") as fh:
        return json.load(fh)


# Track metadata keyed by track id; the key order fixes the iteration order
# used by the cells below.
# NOTE(review): these paths are relative to the working directory, whereas
# positives_negatives.json is loaded through an absolute path - confirm the
# notebook is started from the expected directory.
tracks = _load_json('fairouz_conf/fairouz/tracks_contextualized.json')
track_ids = list(tracks.keys())

# Raw per-modality embedding files.
audio_embeddings = _load_json('fairouz_conf/fairouz/embeddings/audio/song_audio_vggish_embeddings.json')
graph_embeddings = _load_json('fairouz_conf/fairouz/embeddings/graph/karate_club/song_nodes_RandNE_embedding.json')
image_embeddings = _load_json('fairouz_conf/fairouz/embeddings/image/album_covers_openclip_embeddings.json')
text_embeddings = _load_json('fairouz_conf/fairouz/embeddings/lyrics/song_lyrics_mxbai_embeddings.json')

# transform_dict is expected to come from the `from utils import *` cell; its
# output is what get_modality_embeddings() consumes further down.
audio_embeddings_dict = transform_dict(audio_embeddings)
graph_embeddings_dict = transform_dict(graph_embeddings)
image_embeddings_dict = transform_dict(image_embeddings)
text_embeddings_dict = transform_dict(text_embeddings)

In [40]:
import faiss

In [41]:
# Per-track "positives" / "negatives" lists, read by get_positives() and
# get_negatives(). The context manager closes the file handle once parsed.
with open('/workspace/fairouz/fairouz_conf/fairouz/positives_negatives.json', encoding="utf-8") as fh:
    load = json.load(fh)

In [42]:
# Encode every track once with the trained model and keep the result as a
# plain list next to its track id.
embeddings = []
with torch.no_grad():  # inference only - no autograd graph is needed
    for track_id in track_ids:
        # NOTE(review): the helper receives the dicts as (audio, image, text,
        # graph) but its result is unpacked as (audio, graph, image, text) -
        # confirm against utils.get_modality_embeddings.
        audio, graph, image, text = get_modality_embeddings(
            track_id,
            audio_embeddings_dict,
            image_embeddings_dict,
            text_embeddings_dict,
            graph_embeddings_dict,
        )
        # Inputs follow the model's own device instead of a hard-coded "cuda",
        # so the cell also works when the checkpoint is loaded elsewhere.
        embedding = model.predict_step(
            torch.tensor(audio).unsqueeze(0).to(model.device),
            torch.tensor(image).unsqueeze(0).to(model.device),
            torch.tensor(text).unsqueeze(0).to(model.device),
            torch.tensor(graph).unsqueeze(0).to(model.device),
        )
        # Drop the batch axis added by unsqueeze(0) before storing.
        embedding = embedding.numpy(force=True).squeeze()
        embeddings.append({"id": track_id, "embedding": embedding.tolist()})

In [43]:
# Build two parallel lists: row i of embeddings_array is described by
# metatadata_array[i]. That alignment is what lets a FAISS row index be mapped
# back to a track.
embeddings_array = []
metatadata_array = []  # (sic) spelling kept: retrieve() looks this name up
for record in embeddings:
    # `tid` rather than `id`, so the builtin id() is not rebound at notebook scope.
    tid = record['id']
    vector = np.array(record['embedding'])
    metadata = tracks[tid]
    lyrics = metadata["lyrics"]
    genres = metadata['genres']
    md = {
        "id": tid,
        "track_title": metadata['track_title'],
        "artist_name": metadata['artist_name'],
        "album_name": metadata['album_name'],
        "context": ", ".join(lyrics['context']),
        "summary": lyrics['summary'],
        "emotional": ", ".join(lyrics['emotional']),
        # Only the first genre is kept; the string 'None' marks tracks without one.
        "genre": genres[0] if len(genres) > 0 else 'None',
        "image": metadata['image'],
        "preview_url": metadata['preview_url'],
    }
    embeddings_array.append(vector)
    metatadata_array.append(md)

In [44]:
# Flat L2 index over all track embeddings. The dimensionality is taken from
# the data instead of being hard-coded, so a checkpoint with a different output
# size does not need this cell edited.
embedding_matrix = np.array(embeddings_array).astype('float32')
index = faiss.IndexFlatL2(embedding_matrix.shape[1])
index.add(embedding_matrix)

In [45]:
def get_fairouz_embedding(track_id):
    """Return the model embedding of one track as a squeezed NumPy array.

    Args:
        track_id: Key of the track in the per-modality embedding dicts.

    Returns:
        The output of ``model.predict_step`` for this track, converted to
        NumPy with the batch axis removed.
    """
    # NOTE(review): the helper receives the dicts as (audio, image, text, graph)
    # but its result is unpacked as (audio, graph, image, text) - confirm
    # against utils.get_modality_embeddings.
    audio, graph, image, text = get_modality_embeddings(
        track_id,
        audio_embeddings_dict,
        image_embeddings_dict,
        text_embeddings_dict,
        graph_embeddings_dict,
    )
    # The inputs must be on the same device as the model; leaving them on the
    # CPU while the model is on the GPU raises a device-mismatch error.
    device = model.device
    with torch.no_grad():  # inference only - no autograd graph is needed
        embedding = model.predict_step(
            torch.tensor(audio).unsqueeze(0).to(device),
            torch.tensor(image).unsqueeze(0).to(device),
            torch.tensor(text).unsqueeze(0).to(device),
            torch.tensor(graph).unsqueeze(0).to(device),
        )
    return embedding.numpy(force=True).squeeze()

In [46]:
def get_positives(track_id):
    """Look up the "positives" entry stored for ``track_id`` in ``load``."""
    entry = load[track_id]
    return entry["positives"]

def get_negatives(track_id):
    """Look up the "negatives" entry stored for ``track_id`` in ``load``."""
    entry = load[track_id]
    return entry["negatives"]

In [47]:
from torchmetrics.functional.retrieval import retrieval_precision
from torchmetrics.functional.retrieval import retrieval_recall
from torchmetrics.functional.retrieval import retrieval_hit_rate
from torchmetrics.functional.retrieval import retrieval_average_precision
from torchmetrics.functional.retrieval import retrieval_reciprocal_rank
from torchmetrics.functional.retrieval import retrieval_normalized_dcg


In [48]:
# Metric name -> torchmetrics functional retrieval metric. The evaluation cell
# calls each one as fn(preds, target, cutoff), i.e. every metric is computed
# at a top-k cut-off.
metrics = dict(
    retrieval_precision=retrieval_precision,
    retrieval_recall=retrieval_recall,
    retrieval_hit_rate=retrieval_hit_rate,
    retrieval_average_precision=retrieval_average_precision,
    retrieval_reciprocal_rank=retrieval_reciprocal_rank,
    retrieval_normalized_dcg=retrieval_normalized_dcg,
)

In [49]:
def retrieve(track_id, k=10):
    """Retrieve the ``k`` nearest tracks to ``track_id`` and score them.

    The track is encoded with the model and searched against ``index``. The
    returned L2 distances are min-max normalised, flipped into similarities and
    passed through a softmax so they can be used as ranking scores.

    Args:
        track_id: Key of the query track in the embedding dicts and in ``load``.
        k: Number of neighbours to fetch from the index.

    Returns:
        A ``(similarity, target)`` pair of 1-D tensors of length ``k``:
        ``similarity`` holds the softmax scores and ``target`` is 1 where the
        retrieved track id is among the positives of ``track_id``, else 0.
    """
    # NOTE(review): the helper receives the dicts as (audio, image, text, graph)
    # but its result is unpacked as (audio, graph, image, text) - confirm
    # against utils.get_modality_embeddings.
    audio, graph, image, text = get_modality_embeddings(
        track_id,
        audio_embeddings_dict,
        image_embeddings_dict,
        text_embeddings_dict,
        graph_embeddings_dict,
    )
    device = model.device  # follow the model instead of hard-coding "cuda"
    with torch.no_grad():  # inference only - no autograd graph is needed
        embedding = model.predict_step(
            torch.tensor(audio).unsqueeze(0).to(device),
            torch.tensor(image).unsqueeze(0).to(device),
            torch.tensor(text).unsqueeze(0).to(device),
            torch.tensor(graph).unsqueeze(0).to(device),
        )
    query = embedding.numpy(force=True)

    # NOTE(review): the index was built from every id in track_ids, so the
    # query track itself is presumably returned as its own nearest neighbour
    # and occupies one of the k slots - confirm whether it should be excluded.
    D, I = index.search(query, k)
    distances = D[0]

    # Min-max normalise. When all distances are equal (e.g. k == 1) the range
    # is zero and dividing by it would turn every score into NaN; treat that
    # case as "all equally close" instead.
    lowest = np.min(distances)
    spread = np.max(distances) - lowest
    if spread > 0:
        normalized_distances = (distances - lowest) / spread
    else:
        normalized_distances = np.zeros_like(distances)

    # Smaller distance -> larger score; softmax keeps the ordering.
    similarity = F.softmax(torch.tensor(1 - normalized_distances), dim=0)

    positives = get_positives(track_id)
    retrieved_ids = [metatadata_array[row]["id"] for row in I[0]]
    target = [1 if retrieved_id in positives else 0 for retrieved_id in retrieved_ids]
    return similarity, torch.tensor(target)

In [50]:
# Evaluate every metric at several top-k cut-offs, averaged over all tracks.
cutoffs = (10, 15, 20, 25)

# retrieve() does not depend on the metric, so run it once per (cut-off, track)
# and reuse the result for all metrics instead of re-encoding and re-searching
# for each of them.
retrieved = {
    cutoff: [retrieve(track_id, cutoff) for track_id in track_ids]
    for cutoff in cutoffs
}

for metric in metrics:
    k = {cutoff: [] for cutoff in cutoffs}
    for cutoff in cutoffs:
        for sim, target in retrieved[cutoff]:
            k[cutoff].append(metrics[metric](sim, target, cutoff))
    for cutoff in cutoffs:
        print(f"{metric}_k@{cutoff}", np.mean(k[cutoff]))

retrieval_precision_k@10 0.07445256
retrieval_precision_k@15 0.06326034
retrieval_precision_k@20 0.055778593
retrieval_precision_k@25 0.050170314
retrieval_recall_k@10 0.5231144
retrieval_recall_k@15 0.62895375
retrieval_recall_k@20 0.69343066
retrieval_recall_k@25 0.7457421
retrieval_hit_rate_k@10 0.5231144
retrieval_hit_rate_k@15 0.62895375
retrieval_hit_rate_k@20 0.69343066
retrieval_hit_rate_k@25 0.7457421
retrieval_average_precision_k@10 0.14876954
retrieval_average_precision_k@15 0.15344614
retrieval_average_precision_k@20 0.15399164
retrieval_average_precision_k@25 0.15365708
retrieval_reciprocal_rank_k@10 0.14886263
retrieval_reciprocal_rank_k@15 0.15715034
retrieval_reciprocal_rank_k@20 0.16080293
retrieval_reciprocal_rank_k@25 0.16311963
retrieval_normalized_dcg_k@10 0.24258032
retrieval_normalized_dcg_k@15 0.2707632
retrieval_normalized_dcg_k@20 0.28666085
retrieval_normalized_dcg_k@25 0.29858273


In [None]:
# Print the model's layer layout as plain text.
print(model)

In [None]:
# Total number of parameters in the model, counted once and reused.
n_params = sum(p.numel() for p in model.parameters())
# print as (X.X M)
print(f"{n_params / 1e6:.1f} M")