In [84]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import lightning as L
import numpy as np

In [85]:
from torchmetrics.functional.retrieval import retrieval_hit_rate
from torchmetrics.functional.retrieval import retrieval_precision
from torchmetrics.functional.retrieval import retrieval_average_precision
from torchmetrics.functional.retrieval import retrieval_reciprocal_rank
from torchmetrics.functional.retrieval import retrieval_normalized_dcg
from torchmetrics.functional.retrieval import retrieval_recall

In [86]:
from model import Encoder
from utils import *

In [87]:
# Trained multimodal Encoder checkpoint (epoch 199) used for every retrieval below.
CHECKPOINT_PATH = "../vggish_randne_openclip_mxbai_contraction_2_expansion4_51k_datapoints/version_0/checkpoints/epoch=199-step=322800.ckpt"
model = Encoder.load_from_checkpoint(CHECKPOINT_PATH)

In [88]:
# Switch the model to evaluation mode before encoding anything. The architecture
# printed below contains only Linear/ReLU layers, so this changes no computation
# today, but it keeps inference correct if dropout/normalisation layers are added.
model.eval()

Encoder(
  (audio_encoder): Sequential(
    (0): Linear(in_features=128, out_features=64, bias=True)
    (1): ReLU()
  )
  (image_encoder): Sequential(
    (0): Linear(in_features=512, out_features=256, bias=True)
    (1): ReLU()
  )
  (text_encoder): Sequential(
    (0): Linear(in_features=1024, out_features=512, bias=True)
    (1): ReLU()
  )
  (graph_encoder): Sequential(
    (0): Linear(in_features=128, out_features=64, bias=True)
    (1): ReLU()
  )
  (combined_encoder): Sequential(
    (0): Linear(in_features=896, out_features=3584, bias=True)
    (1): ReLU()
    (2): Linear(in_features=3584, out_features=128, bias=True)
  )
)

In [89]:
# Sanity check: the device the loaded Lightning module reports. Model inputs in
# later cells are moved to DEVICE, which has to agree with this.
model.device

device(type='cpu')

In [90]:
# Send inputs to whatever device the loaded model actually lives on instead of
# hard-coding "cpu", so inputs and weights can never end up on different devices
# (e.g. if the checkpoint is later loaded onto a GPU). For this run it evaluates
# to device(type='cpu'), matching the output of the previous cell.
DEVICE = model.device

In [91]:
# Reading data
import json
from pathlib import Path

DATA_DIR = Path("../fairouz_conf/fairouz")
EMBEDDINGS_DIR = DATA_DIR / "embeddings"


def _read_json(path):
    """Parse the JSON file at ``path`` and return the decoded object.

    The file is opened in a ``with`` block so the handle is always closed, and
    decoded as UTF-8 (the encoding JSON text is defined in) rather than the
    platform default, which is not UTF-8 on every OS.
    """
    with open(path, encoding="utf-8") as fh:
        return json.load(fh)


# Track metadata keyed by track id; its key order defines the catalogue order.
tracks = _read_json(DATA_DIR / "tracks_contextualized.json")
track_ids = list(tracks)

# Raw per-modality embedding files.
audio_embeddings = _read_json(EMBEDDINGS_DIR / "audio" / "song_audio_vggish_embeddings.json")
graph_embeddings = _read_json(
    EMBEDDINGS_DIR / "graph" / "karate_club" / "song_nodes_RandNE_embedding.json")
image_embeddings = _read_json(EMBEDDINGS_DIR / "image" / "album_covers_openclip_embeddings.json")
text_embeddings = _read_json(EMBEDDINGS_DIR / "lyrics" / "song_lyrics_mxbai_embeddings.json")

# Reshape each raw file into the lookup form passed to get_modality_embeddings
# (utils.transform_dict; exact output format not visible here — presumably
# keyed by track id).
audio_embeddings_dict = transform_dict(audio_embeddings)
graph_embeddings_dict = transform_dict(graph_embeddings)
image_embeddings_dict = transform_dict(image_embeddings)
text_embeddings_dict = transform_dict(text_embeddings)

In [92]:
import faiss

In [93]:
# Per-track relevance labels; entries are read as load[track_id]["positives"]
# and load[track_id]["negatives"] by the helpers below. Opened in a `with`
# block (closes the handle) and decoded as UTF-8, the encoding JSON uses.
with open("../fairouz_conf/fairouz/positives_negatives_more_negatives.json", encoding="utf-8") as labels_file:
    load = json.load(labels_file)

In [94]:
# Encode every catalogue track once with the trained model. Each entry is
# {"id": track_id, "embedding": list of floats} with the batch axis squeezed out.
embeddings = []
# Pure inference: disable autograd so no computation graph is built per track.
with torch.no_grad():
    for track_id in track_ids:
        # NOTE(review): the dicts are passed in (audio, image, text, graph) order
        # but the result is unpacked as (audio, graph, image, text); this matches
        # every other call site in this notebook — confirm against
        # utils.get_modality_embeddings.
        audio, graph, image, text = get_modality_embeddings(
            track_id,
            audio_embeddings_dict,
            image_embeddings_dict,
            text_embeddings_dict,
            graph_embeddings_dict,
        )
        # Each modality becomes a batch of one on the model's device.
        embedding = model.predict_step(
            torch.tensor(audio).unsqueeze(0).to(DEVICE),
            torch.tensor(image).unsqueeze(0).to(DEVICE),
            torch.tensor(text).unsqueeze(0).to(DEVICE),
            torch.tensor(graph).unsqueeze(0).to(DEVICE),
        )
        embedding = embedding.numpy(force=True).squeeze()
        embeddings.append({"id": track_id, "embedding": embedding.tolist()})

In [95]:
# Build two parallel lists: row i of embeddings_array (the vectors added to the
# FAISS index) belongs to the track described by metatadata_array[i].
# The "metatadata" spelling is kept because later cells reference that name.
embeddings_array = []
metatadata_array = []
for entry in embeddings:
    # Use track_id rather than `id`, which would shadow the builtin for the
    # rest of the notebook.
    track_id = entry["id"]
    track = tracks[track_id]
    lyrics = track["lyrics"]
    genres = track["genres"]

    embeddings_array.append(np.array(entry["embedding"]))
    metatadata_array.append({
        "id": track_id,
        "track_title": track["track_title"],
        "artist_name": track["artist_name"],
        "album_name": track["album_name"],
        "context": ", ".join(lyrics["context"]),
        "summary": lyrics["summary"],
        "emotional": ", ".join(lyrics["emotional"]),
        # First listed genre, or the string "None" when no genre is recorded.
        "genre": genres[0] if genres else "None",
        "image": track["image"],
        "preview_url": track["preview_url"],
    })

In [96]:
# Exact L2 nearest-neighbour index over all track embeddings. FAISS expects a
# float32 matrix; row i of the index must line up with metatadata_array[i].
embedding_matrix = np.asarray(embeddings_array, dtype="float32")
assert embedding_matrix.ndim == 2, "expected one embedding vector per track"
assert len(embedding_matrix) == len(metatadata_array), "index rows must align with metadata"

# Take the dimensionality from the data (128 for this model's output layer)
# instead of hard-coding it, so a different checkpoint does not silently break.
index = faiss.IndexFlatL2(embedding_matrix.shape[1])
index.add(embedding_matrix)

In [97]:
def get_fairouz_embedding(track_id):
    """Encode a single catalogue track with the trained model.

    Args:
        track_id: id of the track to look up in the per-modality embedding dicts.

    Returns:
        numpy array holding the model output for this track, with size-1 axes
        squeezed out (shape (128,) for the current model).
    """
    audio, graph, image, text = get_modality_embeddings(
        track_id, audio_embeddings_dict, image_embeddings_dict, text_embeddings_dict, graph_embeddings_dict)
    # Same input handling as the other inference cells: batch of one, moved to
    # the model's device (previously this function skipped .to(DEVICE)), and
    # no autograd graph since this is inference only.
    with torch.no_grad():
        embedding = model.predict_step(
            torch.tensor(audio).unsqueeze(0).to(DEVICE),
            torch.tensor(image).unsqueeze(0).to(DEVICE),
            torch.tensor(text).unsqueeze(0).to(DEVICE),
            torch.tensor(graph).unsqueeze(0).to(DEVICE),
        )
    return embedding.numpy(force=True).squeeze()

In [98]:
def get_positives(track_id):
    """Return the "positives" entry stored for ``track_id`` in the labels file.

    Raises KeyError if ``track_id`` has no labels entry.
    """
    labels = load[track_id]
    return labels["positives"]


def get_negatives(track_id):
    """Return the "negatives" entry stored for ``track_id`` in the labels file.

    Raises KeyError if ``track_id`` has no labels entry.
    """
    labels = load[track_id]
    return labels["negatives"]

In [99]:
# Retrieval metrics to report, keyed by the name used in the printed results.
# The evaluation loop calls each one with (scores, target, k).
metrics = dict(
    retrieval_precision=retrieval_precision,
    retrieval_recall=retrieval_recall,
    retrieval_hit_rate=retrieval_hit_rate,
    retrieval_average_precision=retrieval_average_precision,
    retrieval_reciprocal_rank=retrieval_reciprocal_rank,
    retrieval_normalized_dcg=retrieval_normalized_dcg,
)

In [100]:
def evaluate(track_id, k=10, exclude_self=False):
    """Retrieve the ``k`` nearest tracks to ``track_id`` and label them.

    The query track is encoded with the model and searched in the FAISS
    ``index``. Each hit is marked relevant when its id appears in
    ``get_positives(track_id)``.

    Args:
        track_id: id of the query track.
        k: number of neighbours to return.
        exclude_self: if True, drop the query track from its own results. The
            query is itself stored in ``index``, so it typically comes back as
            a zero-distance hit. Defaults to False to reproduce earlier numbers.

    Returns:
        (similarity, target):
            ``similarity`` is a 1-D float tensor of scores, computed as a
            softmax over ``1 - min-max-normalised distance``; higher means closer.
            ``target`` is a 1-D integer tensor holding 1 for hits that are in
            the positives list, else 0.

    NOTE(review): ``target`` only covers the retrieved neighbours, so recall
    and nDCG computed from it are relative to the positives found within these
    results, not to all positives of ``track_id``.
    """
    audio, graph, image, text = get_modality_embeddings(
        track_id,
        audio_embeddings_dict,
        image_embeddings_dict,
        text_embeddings_dict,
        graph_embeddings_dict,
    )
    with torch.no_grad():  # inference only
        embedding = model.predict_step(
            torch.tensor(audio).unsqueeze(0).to(DEVICE),
            torch.tensor(image).unsqueeze(0).to(DEVICE),
            torch.tensor(text).unsqueeze(0).to(DEVICE),
            torch.tensor(graph).unsqueeze(0).to(DEVICE),
        )
    # FAISS needs a float32 query matrix of shape (n_queries, dim).
    query = embedding.numpy(force=True).astype("float32", copy=False)

    # Ask for one extra hit when the query itself will be filtered out.
    D, I = index.search(query, k + 1 if exclude_self else k)
    hits = [
        (dist, metatadata_array[row]["id"])
        for dist, row in zip(D[0], I[0])
        # FAISS pads with label -1 when it has fewer than k vectors to return;
        # metatadata_array[-1] would otherwise silently yield the last track.
        if row >= 0
    ]
    if exclude_self:
        hits = [(dist, tid) for dist, tid in hits if tid != track_id]
    hits = hits[:k]

    if not hits:
        return torch.empty(0), torch.empty(0, dtype=torch.long)

    distances = np.array([dist for dist, _ in hits], dtype=np.float32)
    ids = [tid for _, tid in hits]

    # Min-max normalise so the nearest hit maps to 0 and the farthest to 1.
    spread = distances.max() - distances.min()
    if spread > 0:
        normalized_distances = (distances - distances.min()) / spread
    else:
        # k == 1 or all hits equidistant: avoid 0/0, which produced NaN scores.
        normalized_distances = np.zeros_like(distances)
    similarity = nn.Softmax(dim=0)(torch.tensor(1.0 - normalized_distances))

    positives = set(get_positives(track_id))  # set: O(1) membership checks
    target = torch.tensor([int(tid in positives) for tid in ids])
    return similarity, target

In [101]:
# Neighbours retrieved per query, and the cut-offs each metric is reported at.
SEARCH_DEPTH = 50
CUTOFFS = [10, 15, 20, 25]
assert SEARCH_DEPTH >= max(CUTOFFS), "cannot score @k deeper than the retrieved list"

# evaluate() depends only on the track, not on the metric or cut-off. Run it
# once per track instead of once per (metric, cut-off, track), which repeated
# the model forward pass and FAISS search 24x for the 6 metrics x 4 cut-offs.
evaluations = [evaluate(track_id, SEARCH_DEPTH) for track_id in track_ids]

# All per-track scores, keyed "<metric>_k@<k>" (previously only the last key survived).
results = {}
for metric in metrics:
    for k in CUTOFFS:
        key = f"{metric}_k@{k}"
        # Pass top_k by keyword rather than relying on argument position.
        results[key] = [
            metrics[metric](similarity, target, top_k=k)
            for similarity, target in evaluations
        ]
        print(key, np.mean(results[key]))

retrieval_precision_k@10 0.3281022
retrieval_precision_k@15 0.29878345
retrieval_precision_k@20 0.27670315
retrieval_precision_k@25 0.2596107
retrieval_recall_k@10 0.31484777
retrieval_recall_k@15 0.42805424
retrieval_recall_k@20 0.5288602
retrieval_recall_k@25 0.62133265
retrieval_hit_rate_k@10 0.9306569
retrieval_hit_rate_k@15 0.96836984
retrieval_hit_rate_k@20 0.9829684
retrieval_hit_rate_k@25 0.99026763
retrieval_average_precision_k@10 0.41307124
retrieval_average_precision_k@15 0.4041722
retrieval_average_precision_k@20 0.3920268
retrieval_average_precision_k@25 0.38168234
retrieval_reciprocal_rank_k@10 0.361479
retrieval_reciprocal_rank_k@15 0.36450228
retrieval_reciprocal_rank_k@20 0.3652504
retrieval_reciprocal_rank_k@25 0.3655893
retrieval_normalized_dcg_k@10 0.32183042
retrieval_normalized_dcg_k@15 0.35890967
retrieval_normalized_dcg_k@20 0.40776482
retrieval_normalized_dcg_k@25 0.4517281
