# Task 2: Extend your framework with audio-based retrieval systems and evaluation metrics

### Team E

**Don't forget to update the version number after making changes** 

Version 2.0.2
Date: 11/12/2023

## Import Libraries  

In [None]:
import numpy as np
import pandas as pd
import json 

In [None]:
# import utility functions 
from ret_mmsr import read_data # utility func to load data
from ret_mmsr  import get_id_from_info # utility func to return id by entering song's info
from ret_mmsr  import display_res # utility func to display results 
from ret_mmsr  import get_genre #utility func to get the genres from a list of id´s
from ret_mmsr  import get_genre_from_query #utility func to get the id and genre from the query
from ret_mmsr  import get_genre_from_ids #utility func to get the id and genre from the retrieved results

# import the retrieval systems 
from ret_mmsr  import random_baseline # baseline retrieval system that returns random results 
from ret_mmsr  import text_based # modularized text based retrieval system
from ret_mmsr  import audio_based # modularized audio based retrieval system

# import wrapper function to calculate cosine similarity
from ret import cos_sim 

# import evaluation functions 
from ret_mmsr  import gen_cov_10
from ret_mmsr  import ndcg_score
from ret_mmsr  import gen_div_10
from ret_mmsr  import calculate_precision_at_k
from ret_mmsr  import calculate_recall_at_k
from ret_mmsr  import plot_precision_recall_curve
from ret_mmsr  import get_avg_recall_at_k
from ret_mmsr  import get_avg_precision_at_k

## Load Data

In [None]:
def load_data():
    """Load the track metadata, the genre table and every lyric/audio representation.

    Please put the data files in "./data/" before use; each dataset is read
    with ``read_data`` using the dataset name listed below.

    Returns:
        dict: ``"info"`` and ``"genres"`` plus one entry per representation,
        keyed by the same retrieval-system names used elsewhere in this
        notebook ("tfidf", "word2vec", "bert", "blf_correlation", "ivec256",
        "mfcc_stats", "musicnn").
    """
    # Result key -> dataset name passed to read_data (kept in the original load order).
    sources = {
        "info": "information",
        "genres": "genres",
        # text embeddings
        "tfidf": "lyrics_tf-idf",
        "word2vec": "lyrics_word2vec",
        "bert": "lyrics_bert",
        # audio embeddings
        "blf_correlation": "blf_correlation",
        "ivec256": "ivec256",
        "mfcc_stats": "mfcc_stats",
        "musicnn": "musicnn",
    }
    # Previously the loaded frames were bound to locals and discarded; return them instead.
    return {key: read_data(dataset) for key, dataset in sources.items()}

In [None]:


def initialize_results(tracks):
    """Build an empty evaluation table.

    Args:
        tracks: iterable of track keys (e.g. "track1").

    Returns:
        dict: nested as ``results[track][system][metric]`` with every metric
        set to ``None``. Each (track, system) pair owns its own inner dict,
        so filling one slot never affects another.
    """
    system_names = (
        "base_line", "tfidf", "word2vec", "bert",
        "blf_correlation", "ivec256", "mfcc_stats", "musicnn",
    )
    metric_names = (
        "tracks", "precision@10", "recall@10",
        "genre_diversity@10", "genre_coverage@10", "ndcg",
    )

    table = {}
    for track in tracks:
        # dict.fromkeys gives a fresh {metric: None} dict per system.
        table[track] = {name: dict.fromkeys(metric_names) for name in system_names}
    return table

def retrieve_tracks(res, track_ids, info, representations):
    """Run every retrieval system for each query track and store the top-10 ids.

    Fills ``res[track][system]["tracks"]`` in place. The random baseline always
    runs. A text- or audio-based system runs only if its name is a key of
    ``representations``. Audio-based results are also printed and shown with
    ``display_res``.

    Args:
        res: results table as built by ``initialize_results``.
        track_ids: mapping ``track key -> query track id``.
        info: track metadata table, used by the baseline and by ``display_res``.
        representations: mapping ``system name -> feature table``.
    """
    text_systems = ("tfidf", "word2vec", "bert")
    audio_systems = ("blf_correlation", "ivec256", "mfcc_stats", "musicnn")

    for query_key, query_id in track_ids.items():
        slots = res[query_key]

        slots["base_line"]["tracks"] = random_baseline(id=query_id, info=info, N=10)

        for name in text_systems:
            if name not in representations:
                continue
            slots[name]["tracks"] = text_based(
                id=query_id,
                repr=representations[name],
                N=10,
                sim_func=cos_sim,
            )

        for name in audio_systems:
            if name not in representations:
                continue
            retrieved = audio_based(
                id=query_id,
                repr=representations[name],
                N=10,
                sim_func=cos_sim,
            )
            slots[name]["tracks"] = retrieved
            print(f"{query_key} {name} Results:")
            display_res(retrieved, info)

# Retrieval systems evaluated in this notebook, grouped by modality.
_TEXT_SYSTEMS = ("tfidf", "word2vec", "bert")
_AUDIO_SYSTEMS = ("blf_correlation", "ivec256", "mfcc_stats", "musicnn")
_ALL_SYSTEMS = ("base_line",) + _TEXT_SYSTEMS + _AUDIO_SYSTEMS


def calculate_metrics(res, query_genres, genres, track_ids=None):
    """Calculate evaluation metrics for all retrieval systems and store them in ``res``.

    Args:
        res: results table ``res[track][system][metric]`` (see
            ``initialize_results``) whose ``"tracks"`` entries were filled by
            ``retrieve_tracks``. Updated in place.
        query_genres: mapping ``track key -> genre(s) of the query track``.
        genres: genre table passed to the ``ret_mmsr`` metric helpers.
        track_ids: mapping ``track key -> query track id``; NDCG needs the query
            id. When omitted (or the track is missing), NDCG stays ``None``.
    """
    genres_list = genres.values.tolist()

    for track, query_genre in query_genres.items():
        for system in _ALL_SYSTEMS:
            entry = res[track][system]
            retrieved = entry["tracks"]
            if retrieved is None:
                # The system was not run for this track (e.g. its representation
                # was not loaded); leave its metrics as None instead of crashing.
                continue
            retrieved_genres = get_genre_from_ids(retrieved, genres)

            # Precision@10
            entry["precision@10"] = calculate_precision_at_k(
                query_genre, retrieved_genres, 10)
            print(f"{track} {system} precision@10: {entry['precision@10']}")

            # Recall@10 (needs the genres of the whole dataset)
            entry["recall@10"] = calculate_recall_at_k(
                query_genre, retrieved_genres, genres_list, 10)
            print(f"{track} {system} recall@10: {entry['recall@10']}")

            # Genre Coverage@10
            entry["genre_coverage@10"] = gen_cov_10(retrieved, genres)
            print(f"{track} {system} genre coverage@10: {entry['genre_coverage@10']}")

            # TODO(review): "genre_diversity@10" is never filled, although
            # gen_div_10 is imported; confirm its signature before wiring it in.

            # NDCG. The original read a `track_ids` that did not exist in this
            # scope (NameError); it is now an explicit parameter.
            if track_ids is not None and track in track_ids:
                entry["ndcg"] = ndcg_score(track_ids[track], retrieved, genres)
                print(f"{track} {system} ndcg: {entry['ndcg']}")


def _system_label(system):
    """Return a human-readable system label for plots, e.g. "Audio Mfcc Stats"."""
    pretty = system.replace("_", " ").title()
    if system in _TEXT_SYSTEMS:
        return f"Text {pretty}"
    if system in _AUDIO_SYSTEMS:
        return f"Audio {pretty}"
    return "Base Line"


def calculate_pr_data(track_ids, info, query_genres, genres, representations=None):
    """Retrieve the top-100 results of every system and collect the precision-recall curve data.

    Args:
        track_ids: mapping ``track key -> query track id``.
        info: track metadata table (used by the random baseline).
        query_genres: mapping ``track key -> genre(s) of the query track``.
        genres: genre table used to map retrieved ids to genres.
        representations: mapping ``system name -> feature table``. Text/audio
            systems whose name is missing here are skipped; when omitted, only
            the baseline runs.

    Returns:
        dict keyed by ``"<track>_<system>"``. Each value holds ``"system_name"``,
        ``"query_genre"``, ``"retrieved_genres"`` and ``"dataset_genres"``, as
        consumed by ``plot_precision_recall_curve``.
    """
    if representations is None:
        representations = {}
    genres_list = genres.values.tolist()
    system_data = {}

    for track, track_id in track_ids.items():
        # Retrieve 100 results per system.
        results = {"base_line": random_baseline(id=track_id, info=info, N=100)}
        for system in _TEXT_SYSTEMS:
            if system in representations:
                results[system] = text_based(
                    id=track_id,
                    repr=representations[system],
                    N=100,
                    sim_func=cos_sim,
                )
        for system in _AUDIO_SYSTEMS:
            if system in representations:
                results[system] = audio_based(
                    id=track_id,
                    repr=representations[system],
                    N=100,
                    sim_func=cos_sim,
                )

        # Convert to genres inside the per-track loop, so every query track is kept.
        for system, retrieved in results.items():
            system_data[f"{track}_{system}"] = {
                "system_name": f"{track} {_system_label(system)}",
                "query_genre": query_genres[track],
                "retrieved_genres": get_genre_from_ids(retrieved, genres),
                "dataset_genres": genres_list,
            }

    return system_data


def _load_cached_results(path):
    """Return the results dict stored as JSON at ``path``, or ``{}`` if it is missing or empty."""
    try:
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except FileNotFoundError:
        return {}
    return data if data else {}


def _merge_missing_entries(res, template):
    """Add every track/system/metric slot of ``template`` that ``res`` lacks; existing values are kept."""
    for track, systems in template.items():
        track_entry = res.setdefault(track, {})
        for system, metrics in systems.items():
            system_entry = track_entry.setdefault(system, {})
            for metric, default in metrics.items():
                system_entry.setdefault(metric, default)
    return res


def main():
    """Run the music retrieval evaluation and plot the precision-recall curves.

    Data files are expected in "./data/". Previously saved results are read
    from "result_mod.json" if that file exists.

    Returns:
        dict: the filled results table ``res[track][system][metric]``.
    """
    # Track information (expandable)
    tracks_info_dict = {
        "track1": {"name": "Love Me", "artist": "The 1975"},
        # Add more tracks here, e.g.:
        # "track2": {"name": "Song Name", "artist": "Artist Name"}
    }

    # Load initial data
    info = read_data("information")
    genres = read_data("genres")

    # Create the dictionary of representations, keyed by retrieval-system name.
    representations = {
        # Text embeddings
        "tfidf": read_data("lyrics_tf-idf"),
        "word2vec": read_data("lyrics_word2vec"),
        "bert": read_data("lyrics_bert"),

        # Audio embeddings
        "blf_correlation": read_data("blf_correlation"),
        "ivec256": read_data("ivec256"),
        "mfcc_stats": read_data("mfcc_stats"),
        "musicnn": read_data("musicnn"),
    }

    # Cached results may be absent, or may not contain the tracks requested now;
    # make sure every slot that retrieve_tracks/calculate_metrics write exists.
    res = _load_cached_results("result_mod.json")
    _merge_missing_entries(res, initialize_results(tracks_info_dict))

    # Get track IDs and query genres
    track_ids = {
        track: get_id_from_info(track_info["name"], track_info["artist"], info)
        for track, track_info in tracks_info_dict.items()
    }
    query_genres = {
        track: get_genre_from_ids([track_id], genres)[0]
        for track, track_id in track_ids.items()
    }

    # Perform retrieval and evaluation
    retrieve_tracks(res, track_ids, info, representations)
    calculate_metrics(res, query_genres, genres, track_ids)

    # Generate PR curve data and plot
    system_data = calculate_pr_data(track_ids, info, query_genres, genres, representations)
    plot_precision_recall_curve(system_data)
    return res


if __name__ == "__main__":
    main()

In [None]:
# `info` only exists inside main() (and load_data() does not publish it), so it
# must be loaded here before the id lookups; otherwise this cell raises NameError.
info = read_data("information")

# Query track ids, looked up by (song name, artist).
id_track1 = get_id_from_info("Love Me", "The 1975", info)
id_track2 = get_id_from_info("One", "U2", info)
id_track3 = get_id_from_info("Every Christmas", "Kelly Clarkson", info)