<a href="https://colab.research.google.com/github/Viditk07-Bits/AudioAnalytics_S2-24_AIMLCZG527/blob/main/Group8_AudioAnalytics_Assignment2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Audio Analysis Assignment 2
# GROUP 8 - Members

1. VIDIT KUMAR KALE - (2023ac05613)
2. MAYANK GROVER - (2023ac05486)
3. AMIT KUMAR ANAND - (2023ac05670)

# Music Information Retrieval System

## Assignment Objective
This assignment implements a comprehensive Music Information Retrieval (MIR) system using Large Language Models (LLMs) and deep learning techniques. It includes music recommendation, genre classification, and semantic search applications, combining audio analysis with natural language processing.

## Dataset Setup
Using the Free Music Archive (FMA) dataset with audio files, metadata, and synthetic user data.

# 1: 1.3 Dataset Requirements

In [18]:
import subprocess
required_packages = ['faiss-cpu', 'transformers', 'sentence-transformers', 'tqdm', 'imblearn', 'librosa', 'nltk', 'matplotlib', 'seaborn', 'datasets']
for pkg in required_packages:
    try:
        __import__(pkg.replace('-', '_'))
    except ImportError:
        subprocess.run(['pip', 'install', pkg], check=True)
        print(f"Installed {pkg}")

import pandas as pd
import numpy as np
import os
from pathlib import Path
import librosa
import logging
from google.colab import drive
import nltk
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
from tqdm import tqdm
from multiprocessing import Pool
import random

# Download NLTK resources
nltk.download('punkt', quiet=True)
nltk.download('punkt_tab', quiet=True)
nltk.download('stopwords', quiet=True)

# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logging.getLogger().setLevel(logging.INFO)

# Mount Google Drive
drive.mount('/content/drive', force_remount=True)

# Constants
DATA_PATH = "/content/fma/data"
METADATA_PATH = os.path.join(DATA_PATH, "metadata/tracks.csv")
ARTISTS_PATH = os.path.join(DATA_PATH, "metadata/artists.csv")
GENRES_PATH = os.path.join(DATA_PATH, "metadata/genres.csv")
LYRICS_PATH = os.path.join(DATA_PATH, "lyrics")
USER_DATA_PATH = os.path.join(DATA_PATH, "user_data/ratings.csv")
TAGS_PATH = os.path.join(DATA_PATH, "descriptions/tags.csv")
OUTPUT_DIR = "/content/outputs"
TEMP_DIR = "/content/fma"
MAX_TRACKS = 100
TIMEOUT_SECONDS = 3600
FMA_BASE_DIR = "/content/drive/MyDrive/fma_small/"
FMA_AUDIO_DIRS = [str(p) for p in Path(FMA_BASE_DIR).glob("*") if p.is_dir()]

if not FMA_AUDIO_DIRS:
    logging.warning("No subdirectories found in %s", FMA_BASE_DIR)
    print(f"Warning: No subdirectories found in {FMA_BASE_DIR}")

# Ensure output directory exists
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Genre-specific keywords for synthetic data
GENRE_KEYWORDS = {
    'Rock': ['electric guitar wails', 'rebellious spirit soars', 'grunge heart pounds', 'classic riffs ignite', 'indie soul rebels', 'punk fire explodes', 'rock anthem roars'],
    'Pop': ['infectious hooks dance', 'neon lights pulse', 'melodic dreams soar', 'upbeat rhythm shines', 'love story sparkles', 'dancefloor beats throb', 'pop fever rises'],
    'Jazz': ['saxophone weaves magic', 'improvised notes flow', 'bluesy soul swings', 'smooth grooves linger', 'jazz night whispers', 'rhythmic scat hums', 'cool vibes drift'],
    'Classical': ['orchestral swells rise', 'violin sings softly', 'piano echoes grace', 'symphonic waves crash', 'baroque harmony soars', 'elegant strings weave', 'timeless beauty unfolds'],
    'Hip-Hop': ['heavy beats drop hard', 'sharp rhymes cut deep', 'street stories unfold', 'flow rides the rhythm', 'urban pulse vibrates', 'mic drops with swagger', 'hip-hop reigns supreme'],
    'Electronic': ['synth pulses glow', 'techno beats surge', 'ambient waves drift', 'EDM sparks the night', 'futuristic sounds hum', 'electro vibes ignite', 'digital dreams pulse'],
    'Folk': ['acoustic chords strum', 'heartfelt tales weave', 'rustic paths wander', 'folk roots run deep', 'gentle melodies soothe', 'campfire stories sing', 'tradition lives on'],
    'Blues': ['guitar wails with soul', 'heartache spills over', 'raw blues cry out', 'delta notes resonate', 'mournful chords linger', 'blues spirit endures', 'emotional strings weep'],
    'Country': ['banjo twangs with pride', 'heartland stories sing', 'cowboy boots stomp', 'rural roads ramble', 'love songs ride free', 'country heart beats strong', 'honky-tonk nights shine'],
    'Reggae': ['rasta riddims sway', 'island vibes chill', 'roots reggae grooves', 'one love unites all', 'skank beat lifts high', 'irie spirit flows', 'dreadlocks dance free'],
    'International': ['world rhythms blend', 'exotic melodies soar', 'cultural beats pulse', 'global sounds unite', 'traditional chants echo', 'fusion vibes transcend', 'earth’s heartbeat sings'],
    'Instrumental': ['ambient chords float', 'strings weave dreams', 'piano paints silence', 'orchestral tides rise', 'melody speaks alone', 'instrumental soul soars', 'soundscapes breathe life'],
    'Experimental': ['avant-garde sounds twist', 'abstract beats morph', 'sonic boundaries break', 'unorthodox rhythms pulse', 'experimental vibes soar', 'sound art redefines', 'future notes unfold']
}

def download_fma_dataset():
    """Download and extract FMA small dataset only if necessary."""
    logging.info("Checking for FMA dataset...")
    mp3_files = set()
    for audio_dir in FMA_AUDIO_DIRS:
        if os.path.exists(audio_dir):
            files = [f for f in os.listdir(audio_dir) if f.endswith('.mp3')]
            mp3_files.update(f.replace('.mp3', '') for f in files)
            logging.info("Found %d MP3 files in %s", len(files), audio_dir)
    if os.path.exists(METADATA_PATH) and mp3_files:
        logging.info("Found metadata at %s and %d audio files in Google Drive", METADATA_PATH, len(mp3_files))
        return True, mp3_files
    logging.info("Attempting to download FMA small dataset...")
    max_retries = 3
    retry_delay = 5
    for attempt in range(1, max_retries + 1):
        try:
            os.makedirs(os.path.dirname(METADATA_PATH), exist_ok=True)
            for zip_file in ["/content/fma_small.zip", "/content/fma_metadata.zip"]:
                if os.path.exists(zip_file):
                    os.remove(zip_file)
            subprocess.run(["curl", "-s", "-o", "/content/fma_small.zip", "https://os.unil.cloud.switch.ch/fma/fma_small.zip"], check=True, timeout=TIMEOUT_SECONDS)
            subprocess.run(["curl", "-s", "-o", "/content/fma_metadata.zip", "https://os.unil.cloud.switch.ch/fma/fma_metadata.zip"], check=True, timeout=TIMEOUT_SECONDS)
            subprocess.run(["unzip", "-q", "/content/fma_small.zip", "-d", DATA_PATH], check=True, timeout=TIMEOUT_SECONDS)
            subprocess.run(["unzip", "-q", "/content/fma_metadata.zip", "-d", os.path.dirname(METADATA_PATH)], check=True, timeout=TIMEOUT_SECONDS)
            mp3_files = set()
            for audio_dir in FMA_AUDIO_DIRS:
                if os.path.exists(audio_dir):
                    files = [f for f in os.listdir(audio_dir) if f.endswith('.mp3')]
                    mp3_files.update(f.replace('.mp3', '') for f in files)
                    logging.info("Found %d MP3 files in %s after download", len(files), audio_dir)
            if mp3_files and os.path.exists(METADATA_PATH):
                logging.info("FMA dataset downloaded successfully. Found %d audio files.", len(mp3_files))
                return True, mp3_files
            else:
                logging.warning("No audio files or metadata found after download. Retrying...")
        except subprocess.CalledProcessError as e:
            logging.error("Download attempt %d failed: %s", attempt, str(e))
            if attempt == max_retries:
                logging.error("Max retries reached. Falling back to synthetic dataset.")
                return False, set()
            time.sleep(retry_delay)
    return False, set()

def generate_lyrics(args):
    """Generate richer synthetic lyrics."""
    track_id, genre, lyrics_path = args
    keywords = random.sample(GENRE_KEYWORDS[genre], min(3, len(GENRE_KEYWORDS[genre])))
    extra_words = random.sample(['freedom', 'journey', 'love', 'dream', 'sky', 'heart'], 2)
    verse_lines = [
        f"Feel the {keywords[0].split()[0]} in your {extra_words[0]}, let it take control.",
        f"{keywords[1].split()[0]} carries you to {extra_words[1]}, into the night and day.",
        f"With every {keywords[2].split()[0]} the heart beats strong, singing {genre}’s song."
    ]
    chorus_lines = [
        f"{keywords[1].split()[0]} vibes, we’re alive, dancing through the {extra_words[0]}.",
        f"{genre} spirit, feel it rise, reaching for the {extra_words[1]}.",
        f"Let the {keywords[0].split()[0]} flow, take us where we go."
    ]
    lyrics = f"{genre} song: {', '.join(keywords + extra_words)}.\n" + \
             f"Verse 1:\n{verse_lines[0]}\n{verse_lines[1]}\n{verse_lines[2]}\n" + \
             f"Chorus:\n{chorus_lines[0]}\n{chorus_lines[1]}\n{chorus_lines[2]}"
    with open(os.path.join(lyrics_path, f"{track_id}.txt"), 'w', encoding='utf-8') as f:
        f.write(lyrics)
    return track_id

def create_synthetic_dataset(mp3_files):
    """Create synthetic dataset only if necessary."""
    logging.info("Creating synthetic dataset...")
    print("Creating synthetic dataset...")
    os.makedirs(os.path.dirname(METADATA_PATH), exist_ok=True)
    os.makedirs(os.path.dirname(ARTISTS_PATH), exist_ok=True)
    os.makedirs(os.path.dirname(GENRES_PATH), exist_ok=True)
    os.makedirs(os.path.dirname(USER_DATA_PATH), exist_ok=True)
    os.makedirs(os.path.dirname(TAGS_PATH), exist_ok=True)
    os.makedirs(LYRICS_PATH, exist_ok=True)

    genres = list(GENRE_KEYWORDS.keys())
    df_tracks = pd.DataFrame({
        'track_id': [str(i).zfill(6) for i in range(1, MAX_TRACKS + 1)],
        'title': [f"Track_{i}" for i in range(1, MAX_TRACKS + 1)],
        'artist_id': [str(i).zfill(6) for i in range(1, MAX_TRACKS + 1)],  # Ensure string type
        'genre_id': [random.randint(1, 13) for _ in range(MAX_TRACKS)],
        'genre_top': [random.choice(genres) for _ in range(MAX_TRACKS)]
    })
    df_tracks.to_csv(METADATA_PATH, index=False)
    df_artists = pd.DataFrame({
        'artist_id': [str(i).zfill(6) for i in range(1, MAX_TRACKS + 1)],  # Ensure string type
        'artist_name': [f"Artist_{i}" for i in range(1, MAX_TRACKS + 1)]
    })
    df_artists.to_csv(ARTISTS_PATH, index=False)
    df_genres = pd.DataFrame({
        'genre_id': range(1, 14),
        'genre_name': genres
    })
    df_genres.to_csv(GENRES_PATH, index=False)

    genre_rating_bias = {
        'Pop': 0.2, 'Rock': 0.1, 'Jazz': -0.1, 'Classical': -0.2, 'Hip-Hop': 0.15,
        'Electronic': 0.1, 'Folk': -0.05, 'Blues': -0.1, 'Country': -0.05, 'Reggae': 0.05,
        'International': 0.05, 'Instrumental': -0.2, 'Experimental': -0.1
    }
    ratings = []
    for track_id, genre in zip(df_tracks['track_id'], df_tracks['genre_top']):
        for user_id in [f"user_{i+1}" for i in range(20)]:
            if np.random.random() < 0.9:
                rating = np.random.normal(3.0 + genre_rating_bias.get(genre, 0), 1.5)
                rating = min(max(rating, 1), 5)
                ratings.append({'user_id': user_id, 'track_id': track_id, 'rating': rating})
    ratings = pd.DataFrame(ratings)
    ratings.to_csv(USER_DATA_PATH, index=False)

    tags = []
    for track_id, genre in zip(df_tracks['track_id'], df_tracks['genre_top']):
        num_tags = random.randint(3, 6)
        track_tags = random.sample(GENRE_KEYWORDS.get(genre, ['generic']), min(num_tags, len(GENRE_KEYWORDS.get(genre, ['generic']))))
        tags.extend([{'track_id': track_id, 'tag': tag} for tag in track_tags])
    pd.DataFrame(tags).to_csv(TAGS_PATH, index=False)

    os.makedirs(LYRICS_PATH, exist_ok=True)
    lyrics_args = [(track_id, genre, LYRICS_PATH) for track_id, genre in zip(df_tracks['track_id'], df_tracks['genre_top'])]
    with Pool(processes=4) as pool:
        list(tqdm(pool.imap(generate_lyrics, lyrics_args), total=len(lyrics_args), desc="Generating synthetic lyrics"))

    print(f"Synthetic dataset created: Tracks shape: {df_tracks.shape}, Ratings shape: {ratings.shape}, Tags: {len(tags)}")
    logging.info("Synthetic dataset created: Tracks shape: %s, Ratings: %s, Tags: %d", df_tracks.shape, ratings.shape, len(tags))
    return df_tracks

def extract_audio_features(audio_path):
    """Extract audio features using Librosa or generate synthetic features."""
    try:
        if not audio_path or not os.path.exists(audio_path) or os.path.getsize(audio_path) < 100:
            return np.random.randn(26) * 0.1
        y, sr = librosa.load(audio_path, sr=22050)
        if len(y) == 0:
            return np.random.randn(26) * 0.1
        mfccs = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=12)
        chroma = librosa.feature.chroma_stft(y=y, sr=sr)
        spectral_centroid = librosa.feature.spectral_centroid(y=y, sr=sr)
        tempo = librosa.feature.tempo(y=y, sr=sr)
        tempo = float(tempo[0]) if isinstance(tempo, np.ndarray) else float(tempo)
        features = np.concatenate([
            np.mean(mfccs, axis=1),
            np.mean(chroma, axis=1),
            np.mean(spectral_centroid, axis=1),
            [tempo]
        ])
        return features
    except Exception as e:
        logging.warning("Error processing %s: %s", audio_path, str(e))
        return np.random.randn(26) * 0.1

def load_fma_data(audio_dirs, metadata_path, artists_path, genres_path, lyrics_path, tags_path):
    """Load FMA dataset with audio files from specified directories."""
    logging.info("Loading FMA data...")
    print("Loading FMA data...")

    # Check for metadata and audio files
    mp3_files = set()
    for audio_dir in audio_dirs:
        if os.path.exists(audio_dir):
            files = [f for f in os.listdir(audio_dir) if f.endswith('.mp3')]
            mp3_files.update(f.replace('.mp3', '') for f in files)
            logging.info("Found %d MP3 files in %s", len(files), audio_dir)
        else:
            logging.warning("Audio directory %s not found", audio_dir)
    if not mp3_files:
        logging.warning("No MP3 files found in any audio directories: %s", audio_dirs)
    if not os.path.exists(metadata_path):
        logging.warning("Metadata file not found at %s", metadata_path)

    if os.path.exists(metadata_path) and mp3_files:
        try:
            df_metadata = pd.read_csv(metadata_path, header=[0, 1], low_memory=False)
            df_metadata.columns = ['_'.join(col).strip() if isinstance(col, tuple) else col for col in df_metadata.columns]
            track_id_col = title_col = artist_id_col = genre_top_col = None
            for col in df_metadata.columns:
                col_lower = col.lower()
                if 'track_id' in col_lower or 'track.id' in col_lower:
                    track_id_col = col
                elif 'title' in col_lower:
                    title_col = col
                elif 'artist_id' in col_lower or 'artist.id' in col_lower:
                    artist_id_col = col
                elif 'genre_top' in col_lower or 'genre' in col_lower:
                    genre_top_col = col
            required_cols = [track_id_col, title_col, artist_id_col, genre_top_col]
            if None in required_cols:
                logging.error("Required columns missing in FMA metadata: %s", [c for c in required_cols if c is None])
                raise ValueError("Required columns missing in FMA metadata")
            df_metadata = df_metadata[[track_id_col, title_col, artist_id_col, genre_top_col]].dropna()
            df_metadata = df_metadata.rename(columns={
                track_id_col: 'track_id',
                title_col: 'title',
                artist_id_col: 'artist_id',
                genre_top_col: 'genre_top'
            })
            df_metadata['track_id'] = df_metadata['track_id'].astype(str).str.zfill(6)
            df_metadata['artist_id'] = df_metadata['artist_id'].astype(str).str.zfill(6)  # Ensure string type
            logging.info("Metadata track_ids (first 5): %s", list(df_metadata['track_id'])[:5])
            logging.info("Metadata artist_ids (first 5): %s", list(df_metadata['artist_id'])[:5])
            logging.info("MP3 file track_ids (first 5): %s", list(mp3_files)[:5])
            df_metadata = df_metadata[df_metadata['track_id'].isin(mp3_files)]
            if df_metadata.empty:
                logging.warning("No tracks in metadata match audio files: %s", list(mp3_files)[:5])
                print(f"Warning: No tracks in metadata match audio files: {list(mp3_files)[:5]}")
                df_metadata = create_synthetic_dataset(mp3_files)
            elif len(df_metadata) < MAX_TRACKS:
                logging.info("Found %d tracks, supplementing with synthetic data to reach %d", len(df_metadata), MAX_TRACKS)
                num_missing = MAX_TRACKS - len(df_metadata)
                genres = list(GENRE_KEYWORDS.keys())
                synthetic_tracks = pd.DataFrame({
                    'track_id': [str(i).zfill(6) for i in range(len(df_metadata) + 1, len(df_metadata) + num_missing + 1)],
                    'title': [f"Track_{i}" for i in range(len(df_metadata) + 1, len(df_metadata) + num_missing + 1)],
                    'artist_id': [str(i).zfill(6) for i in range(len(df_metadata) + 1, len(df_metadata) + num_missing + 1)],  # String type
                    'genre_top': [random.choice(genres) for _ in range(num_missing)]
                })
                df_metadata = pd.concat([df_metadata, synthetic_tracks], ignore_index=True)
            df_metadata['genre_id'] = df_metadata['genre_top'].map({g: i+1 for i, g in enumerate(GENRE_KEYWORDS.keys())}).fillna(random.randint(1, 13))
        except Exception as e:
            logging.error("Error loading FMA metadata: %s", str(e))
            df_metadata = create_synthetic_dataset(mp3_files)
    else:
        logging.warning("Metadata or audio files missing. Falling back to synthetic dataset.")
        df_metadata = create_synthetic_dataset(mp3_files)

    if os.path.exists(artists_path):
        df_artists = pd.read_csv(artists_path, on_bad_lines='skip', dtype={'artist_id': str})
        df_artists['artist_id'] = df_artists['artist_id'].astype(str).str.zfill(6)
        logging.info("Artists DataFrame artist_id dtype: %s", df_artists['artist_id'].dtype)
        logging.info("Artists DataFrame artist_ids (first 5): %s", list(df_artists['artist_id'])[:5])
    else:
        df_artists = pd.DataFrame({
            'artist_id': [str(i).zfill(6) for i in range(1, len(df_metadata) + 1)],  # String type
            'artist_name': [f"Artist_{i}" for i in range(1, len(df_metadata) + 1)]
        })
    logging.info("df_metadata artist_id dtype before merge: %s", df_metadata['artist_id'].dtype)
    df_metadata = pd.merge(
        df_metadata,
        df_artists,
        on='artist_id',
        how='left'
    )
    if os.path.exists(genres_path):
        df_genres = pd.read_csv(genres_path, on_bad_lines='skip')
    else:
        df_genres = pd.DataFrame({
            'genre_id': range(1, 14),
            'genre_name': list(GENRE_KEYWORDS.keys())
        })
    df_metadata = pd.merge(
        df_metadata,
        df_genres,
        on='genre_id',
        how='left'
    )
    df_metadata = df_metadata[['track_id', 'artist_name', 'title', 'genre_name']].rename(columns={'genre_name': 'genre'})
    df_metadata['track_id'] = df_metadata['track_id'].astype(str).str.zfill(6)

    valid_track_ids = set(df_metadata['track_id'])
    missing = mp3_files - valid_track_ids
    if missing:
        logging.warning("MP3 files found but not in metadata for %d track_ids: %s", len(missing), list(missing)[:5])
        print(f"Warning: MP3 files found but not in metadata for {len(missing)} track_ids: {list(missing)[:5]}")

    features = []
    for track_id in df_metadata['track_id']:
        audio_file = None
        for audio_dir in audio_dirs:
            potential_path = os.path.join(audio_dir, f"{track_id}.mp3")
            if os.path.exists(potential_path):
                audio_file = potential_path
                break
        audio_features = extract_audio_features(audio_file)
        features.append([track_id] + audio_features.tolist())
    feature_columns = ['track_id'] + [f'mfcc_{i+1}' for i in range(12)] + [f'chroma_{i+1}' for i in range(12)] + ['spectral_centroid', 'tempo']
    df_features = pd.DataFrame(features, columns=feature_columns)

    lyrics_dict = {}
    if os.path.exists(lyrics_path):
        for lyric_file in Path(lyrics_path).glob("*.txt"):
            track_id = lyric_file.stem
            if track_id in valid_track_ids:
                with open(lyric_file, 'r', encoding='utf-8') as f:
                    lyrics_dict[track_id] = f.read().strip() or 'music'
    else:
        logging.info("No lyrics found, generating synthetic lyrics...")
        os.makedirs(lyrics_path, exist_ok=True)
        lyrics_args = [(track_id, genre, lyrics_path) for track_id, genre in zip(df_metadata['track_id'], df_metadata['genre'])]
        with Pool(processes=4) as pool:
            list(tqdm(pool.imap(generate_lyrics, lyrics_args), total=len(lyrics_args), desc="Generating synthetic lyrics"))

    if tags_path and os.path.exists(tags_path):
        df_tags = pd.read_csv(tags_path, on_bad_lines='skip')
        for _, row in df_tags.iterrows():
            track_id = str(row['track_id']).zfill(6)
            if track_id in valid_track_ids:
                lyrics_dict[track_id] = lyrics_dict.get(track_id, '') + " " + str(row['tag'])

    logging.info("Metadata shape: %s, Features shape: %s, Lyrics dict size: %d", df_metadata.shape, df_features.shape, len(lyrics_dict))
    print(f"Metadata shape: {df_metadata.shape}, Features shape: {df_features.shape}, Lyrics dict size: {len(lyrics_dict)}")
    return df_metadata, df_features, lyrics_dict

Installed faiss-cpu
Mounted at /content/drive


# 2: Task A1 - LLM-Based Music Genre Classification

In [19]:
import torch
import torch.nn as nn
from sentence_transformers import SentenceTransformer
from transformers import BertForSequenceClassification, BertTokenizer, RobertaForSequenceClassification, RobertaTokenizer, pipeline, Trainer, TrainingArguments
from sklearn.metrics import precision_recall_fscore_support, confusion_matrix
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.decomposition import LatentDirichletAllocation
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
from collections import Counter
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from datasets import Dataset
from tqdm import tqdm

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
OUTPUT_DIR = "/content/outputs"
TEXT_EMBEDDING_DIM = 384  # Dimension of all-MiniLM-L6-v2 embeddings

def generate_text_embeddings(lyrics_dict):
    """Generate semantic embeddings for lyrics."""
    logging.info("Generating text embeddings...")
    print("Generating text embeddings...")
    model = SentenceTransformer('all-MiniLM-L6-v2', device=DEVICE)
    embeddings = {}
    for track_id, text in tqdm(lyrics_dict.items(), desc="Generating text embeddings"):
        embeddings[track_id] = model.encode(text, convert_to_tensor=True, device=DEVICE).cpu().numpy()
    logging.info("Text embeddings generated for %d tracks.", len(embeddings))
    print(f"Text embeddings generated for {len(embeddings)} tracks.")
    return embeddings

def analyze_linguistic_patterns(df_metadata, lyrics_dict):
    """Analyze linguistic patterns and topics."""
    print("\nLinguistic Analysis:")
    stop_words = set(stopwords.words('english'))
    genre_words = {g: [] for g in df_metadata['genre'].unique()}
    for track_id, text in lyrics_dict.items():
        genre = df_metadata[df_metadata['track_id'] == track_id]['genre'].iloc[0]
        tokens = word_tokenize(text.lower())
        tokens = [t for t in tokens if t.isalpha() and t not in stop_words]
        genre_words[genre].extend(tokens)
    for genre, words in genre_words.items():
        top_words = Counter(words).most_common(5)
        print(f"{genre} top words: {top_words}")
        logging.info("%s top words: %s", genre, top_words)

    vectorizer = CountVectorizer(stop_words='english')
    X = vectorizer.fit_transform(lyrics_dict.values())
    lda = LatentDirichletAllocation(n_components=5, random_state=42)
    lda.fit(X)
    for i, topic in enumerate(lda.components_):
        print(f"Topic {i}: {[vectorizer.get_feature_names_out()[j] for j in topic.argsort()[-5:]]}")

class AudioTextBERTClassifier(nn.Module):
    """BERT with audio feature integration."""
    def __init__(self, bert_model, audio_dim, num_classes, hidden_dim=128):
        super().__init__()
        self.bert = bert_model
        self.audio_layer = nn.Linear(audio_dim, hidden_dim)
        self.combined_layer = nn.Linear(bert_model.config.hidden_size + hidden_dim, hidden_dim)
        self.dropout = nn.Dropout(0.3)
        self.fc = nn.Linear(hidden_dim, num_classes)

    def forward(self, input_ids, attention_mask, audio_features):
        bert_outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        text_out = bert_outputs.pooler_output
        audio_out = torch.relu(self.audio_layer(audio_features))
        combined = torch.cat([text_out, audio_out], dim=1)
        combined = torch.relu(self.combined_layer(combined))
        combined = self.dropout(combined)
        return self.fc(combined)

def bert_classifier(df_metadata, text_embeddings, audio_features, genres):
    """BERT-based genre classification with audio features."""
    print("\nTraining Audio-Text BERT Classifier...")
    logging.info("Initializing BERT classifier...")

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    bert_model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=len(genres))
    model = AudioTextBERTClassifier(bert_model, audio_dim=audio_features.shape[1], num_classes=len(genres)).to(DEVICE)

    # Validate and clean input data
    valid_rows = df_metadata[['track_id', 'title', 'genre']].dropna(subset=['genre'])
    valid_rows = valid_rows[valid_rows['genre'].isin(genres)]
    if len(valid_rows) < len(df_metadata):
        logging.warning("Dropped %d rows due to missing or invalid genres", len(df_metadata) - len(valid_rows))
        print(f"Dropped {len(df_metadata) - len(valid_rows)} rows due to missing or invalid genres")

    # Align audio features with valid rows
    valid_track_ids = valid_rows['track_id'].values
    audio_features_df = pd.DataFrame(audio_features, columns=['track_id'] + [f'feature_{i}' for i in range(audio_features.shape[1]-1)])
    audio_features_df['track_id'] = df_metadata['track_id'].values  # Ensure track_id alignment
    audio_features_df = audio_features_df[audio_features_df['track_id'].isin(valid_track_ids)]
    audio_features = audio_features_df[[col for col in audio_features_df.columns if col != 'track_id']].values

    # Ensure alignment
    if len(audio_features) != len(valid_rows):
        logging.error("Mismatch after aligning audio features: %d features, %d rows", len(audio_features), len(valid_rows))
        raise ValueError(f"Mismatch after aligning audio features: {len(audio_features)} features, {len(valid_rows)} rows")

    # Prepare texts and labels
    texts = [f"{row['title']} {row['genre']}" for _, row in valid_rows.iterrows()]
    genres_list = genres.tolist()
    try:
        labels = np.array([genres_list.index(row['genre']) for _, row in valid_rows.iterrows()], dtype=np.int64)
    except ValueError as e:
        logging.error("Error mapping genres to indices: %s", str(e))
        raise ValueError(f"Error mapping genres to indices: {str(e)}")

    # Tokenize texts
    inputs = tokenizer(texts, return_tensors='pt', padding=True, truncation=True, max_length=128)

    # Validate input lengths
    if len(inputs['input_ids']) != len(audio_features) or len(inputs['input_ids']) != len(labels):
        logging.error("Data mismatch: input_ids=%d, audio_features=%d, labels=%d",
                      len(inputs['input_ids']), len(audio_features), len(labels))
        raise ValueError(f"Data mismatch in bert_classifier inputs: input_ids={len(inputs['input_ids'])}, audio_features={len(audio_features)}, labels={len(labels)}")

    # Create dataset with strict validation
    dataset_list = []
    for i in range(len(valid_rows)):
        if i >= len(labels) or i >= len(audio_features) or i >= len(inputs['input_ids']):
            logging.warning("Skipping index %d due to index out of bounds", i)
            continue
        if not isinstance(labels[i], (int, np.integer)) or np.isnan(labels[i]):
            logging.warning("Invalid label at index %d: %s", i, labels[i])
            continue
        item = {
            'input_ids': inputs['input_ids'][i].numpy(),
            'attention_mask': inputs['attention_mask'][i].numpy(),
            'audio_features': audio_features[i].astype(np.float32),
            'labels': int(labels[i])  # Ensure label is an integer
        }
        dataset_list.append(item)

    # Convert to Dataset and validate
    if not dataset_list:
        logging.error("No valid dataset items created")
        raise ValueError("No valid dataset items created")
    dataset = Dataset.from_list(dataset_list)
    logging.info("Dataset created with %d items", len(dataset))
    print(f"Dataset created with {len(dataset)} items")

    # Validate dataset items
    for i, item in enumerate(dataset):
        required_keys = ['input_ids', 'attention_mask', 'audio_features', 'labels']
        missing_keys = [key for key in required_keys if key not in item]
        if missing_keys:
            logging.error("Item %d in dataset missing keys: %s", i, missing_keys)
            raise ValueError(f"Item {i} in dataset missing keys: {missing_keys}")
        if not isinstance(item['labels'], (int, np.integer)):
            logging.error("Item %d in dataset has non-integer label: %s (type: %s)", i, item['labels'], type(item['labels']))
            raise ValueError(f"Item {i} in dataset has non-integer label: {item['labels']}")

    # Split dataset
    dataset = dataset.train_test_split(test_size=0.2, seed=42)

    # Validate splits
    for split in ['train', 'test']:
        logging.info("Inspecting %s split: %d items", split, len(dataset[split]))
        if len(dataset[split]) == 0:
            logging.error("%s split is empty", split)
            raise ValueError(f"{split} split is empty")
        for i, item in enumerate(dataset[split]):
            required_keys = ['input_ids', 'attention_mask', 'audio_features', 'labels']
            missing_keys = [key for key in required_keys if key not in item]
            if missing_keys:
                logging.error("Item %d in %s split missing keys: %s", i, split, missing_keys)
                raise ValueError(f"Item {i} in %s split missing keys: {missing_keys}" % split)
            if not isinstance(item['labels'], (int, np.integer)):
                logging.error("Item %d in %s split has non-integer label: %s (type: %s)", i, split, item['labels'], type(item['labels']))
                raise ValueError(f"Item {i} in %s split has non-integer label: {item['labels']}" % split)

    def collate_fn(batch):
        try:
            valid_batch = []
            for i, item in enumerate(batch):
                required_keys = ['input_ids', 'attention_mask', 'audio_features', 'labels']
                missing_keys = [key for key in required_keys if key not in item]
                if missing_keys:
                    logging.warning("Skipping batch item %d due to missing keys: %s", i, missing_keys)
                    continue
                if not isinstance(item['labels'], (int, np.integer)):
                    logging.warning("Skipping batch item %d due to non-integer label: %s (type: %s)", i, item['labels'], type(item['labels']))
                    continue
                valid_batch.append(item)
                logging.debug("Batch item %d: keys=%s, label_type=%s, label=%s", i, list(item.keys()), type(item['labels']), item['labels'])

            if not valid_batch:
                logging.warning("No valid items in batch after filtering. Skipping batch.")
                return None  # Return None to indicate empty batch

            return {
                'input_ids': torch.tensor([item['input_ids'] for item in valid_batch], dtype=torch.long).to(DEVICE),
                'attention_mask': torch.tensor([item['attention_mask'] for item in valid_batch], dtype=torch.long).to(DEVICE),
                'audio_features': torch.tensor([item['audio_features'] for item in valid_batch], dtype=torch.float32).to(DEVICE),
                'labels': torch.tensor([item['labels'] for item in valid_batch], dtype=torch.long).to(DEVICE)
            }
        except Exception as e:
            logging.error("Collate function error: %s", str(e))
            raise RuntimeError(f"Collate function error: {str(e)}")

    training_args = TrainingArguments(
        output_dir=os.path.join(OUTPUT_DIR, "bert-finetuned"),
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,
        num_train_epochs=3,
        eval_strategy="epoch",
        save_strategy="epoch",
        learning_rate=2e-5,
        save_total_limit=2,
        load_best_model_at_end=True,
        metric_for_best_model='f1',
        report_to="none"  # Disable W&B logging
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"],
        data_collator=collate_fn,
        compute_metrics=lambda pred: {
            'precision': precision_recall_fscore_support(pred.label_ids, pred.predictions.argmax(-1), average='weighted', zero_division=0)[0],
            'recall': precision_recall_fscore_support(pred.label_ids, pred.predictions.argmax(-1), average='weighted', zero_division=0)[1],
            'f1': precision_recall_fscore_support(pred.label_ids, pred.predictions.argmax(-1), average='weighted', zero_division=0)[2]
        }
    )

    try:
        trainer.train()
        eval_results = trainer.evaluate()
        logging.info("BERT Classifier Metrics: %s", eval_results)
        print(f"BERT Classifier Metrics: {eval_results}")
        return {
            'precision': eval_results.get('eval_precision', 0.0),
            'recall': eval_results.get('eval_recall', 0.0),
            'f1': eval_results.get('eval_f1', 0.0)
        }
    except Exception as e:
        logging.error("Training failed: %s", str(e))
        print(f"Training failed: {str(e)}")
        return {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}

def roberta_classifier(df_metadata, text_embeddings, genres):
    """RoBERTa-based genre classification."""
    print("\nTraining RoBERTa Classifier...")
    tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
    model = RobertaForSequenceClassification.from_pretrained('roberta-base', num_labels=len(genres)).to(DEVICE)
    texts = [f"{row['title']} {row['genre']}" for _, row in df_metadata.iterrows()]
    inputs = tokenizer(texts, return_tensors='pt', padding=True, truncation=True, max_length=128).to(DEVICE)
    genres_list = genres.tolist()
    labels = torch.tensor([genres_list.index(row['genre']) for _, row in df_metadata.iterrows()]).to(DEVICE)
    dataset = Dataset.from_dict({
        'input_ids': inputs['input_ids'].cpu().numpy(),
        'attention_mask': inputs['attention_mask'].cpu().numpy(),
        'labels': labels.cpu().numpy()
    })
    dataset = dataset.train_test_split(test_size=0.2, seed=42)
    training_args = TrainingArguments(
        output_dir=os.path.join(OUTPUT_DIR, "roberta-finetuned"),
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,
        num_train_epochs=3,
        eval_strategy="epoch",
        save_strategy="epoch",
        learning_rate=2e-5,
        save_total_limit=2,
        load_best_model_at_end=True
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"],
        compute_metrics=lambda pred: {
            'precision': precision_recall_fscore_support(pred.label_ids, pred.predictions.argmax(-1), average='weighted', zero_division=0)[0],
            'recall': precision_recall_fscore_support(pred.label_ids, pred.predictions.argmax(-1), average='weighted', zero_division=0)[1],
            'f1': precision_recall_fscore_support(pred.label_ids, pred.predictions.argmax(-1), average='weighted', zero_division=0)[2]
        }
    )
    trainer.train()
    eval_results = trainer.evaluate()
    return {'precision': eval_results['eval_precision'], 'recall': eval_results['eval_recall'], 'f1': eval_results['eval_f1']}

def zero_shot_classification(df_metadata, lyrics_dict):
    """Zero-shot genre classification."""
    print("\nZero-Shot Genre Classification:")
    classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli", device=0 if torch.cuda.is_available() else -1)
    genres = df_metadata['genre'].unique().tolist()
    results = []
    for _, row in df_metadata.head(5).iterrows():
        text = row['title'] + " " + lyrics_dict.get(row['track_id'], "")
        scores = classifier(text, candidate_labels=genres, multi_label=False)
        results.append((row['track_id'], scores['labels'][0], scores['scores'][0], row['genre']))
        print(f"Track {row['track_id']}: Predicted={scores['labels'][0]} ({scores['scores'][0]:.4f}), True={row['genre']}")
    accuracy = sum(1 for _, pred, _, true in results if pred == true) / len(results)
    print(f"Zero-Shot Accuracy: {accuracy:.4f}")
    return results

def few_shot_classification(df_metadata, lyrics_dict):
    """Few-shot genre classification."""
    print("\nFew-Shot Genre Classification:")
    tokenizer = BertTokenizer.from_pretrained('distilbert-base-uncased')
    model = BertForSequenceClassification.from_pretrained('distilbert-base-uncased', num_labels=len(df_metadata['genre'].unique())).to(DEVICE)
    genres = df_metadata['genre'].unique().tolist()
    few_shot_data = []
    for genre in genres:
        genre_tracks = df_metadata[df_metadata['genre'] == genre].head(5)
        for _, row in genre_tracks.iterrows():
            text = row['title'] + " " + lyrics_dict.get(row['track_id'], "")
            inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=128).to(DEVICE)
            few_shot_data.append((inputs, genres.index(genre)))
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)
    for epoch in range(5):
        for inputs, label in few_shot_data:
            optimizer.zero_grad()
            outputs = model(**inputs, labels=torch.tensor([label]).to(DEVICE))
            loss = outputs.loss
            loss.backward()
            optimizer.step()
    model.eval()
    results = []
    for _, row in df_metadata.head(5).iterrows():
        text = row['title'] + " " + lyrics_dict.get(row['track_id'], "")
        inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=128).to(DEVICE)
        with torch.no_grad():
            outputs = model(**inputs)
            predicted = torch.argmax(outputs.logits, dim=1).item()
        results.append((row['track_id'], genres[predicted], row['genre']))
        print(f"Track {row['track_id']}: Predicted={genres[predicted]}, True={row['genre']}")
    accuracy = sum(1 for _, pred, true in results if pred == true) / len(results)
    print(f"Few-Shot Accuracy: {accuracy:.4f}")
    return results

# 3: Task A2 - Transformer-Based Audio Classification

In [20]:
import torch
from transformers import AutoModelForAudioClassification, AutoFeatureExtractor, Trainer, TrainingArguments
from datasets import Dataset
import librosa
import numpy as np
import logging

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
OUTPUT_DIR = "/content/outputs"

def ast_classifier(audio_files, labels, genres, df_metadata):
    """AST-based audio classification."""
    print("\nSetting up AST Classifier...")
    try:
        feature_extractor = AutoFeatureExtractor.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.2")
        model = AutoModelForAudioClassification.from_pretrained(
            "MIT/ast-finetuned-audioset-10-10-0.2",
            num_labels=len(genres),
            label2id={genre: idx for idx, genre in enumerate(genres)},
            id2label={idx: genre for idx, genre in enumerate(genres)}
        ).to(DEVICE)
    except Exception as e:
        logging.error("Failed to load AST model: %s", str(e))
        return {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}
    valid_audio_files = [f for f in audio_files if f and os.path.exists(f) and os.path.getsize(f) > 100]
    valid_labels = [labels[i] for i, f in enumerate(audio_files) if f in valid_audio_files]
    def preprocess_audio(audio_file, label):
        try:
            y, sr = librosa.load(audio_file, sr=16000, duration=5.0)
            inputs = feature_extractor(y, sampling_rate=16000, return_tensors="np", padding="max_length", max_length=80000)
            return {
                "input_values": inputs.input_values[0],
                "labels": label
            }
        except:
            return None
    data = [preprocess_audio(f, l) for f, l in zip(valid_audio_files, valid_labels)]
    data = [d for d in data if d is not None]
    if not data:
        logging.error("No valid audio data.")
        return {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}
    dataset = Dataset.from_dict({
        "input_values": [d["input_values"] for d in data],
        "labels": [d["labels"] for d in data]
    })
    dataset = dataset.train_test_split(test_size=0.2, seed=42)
    training_args = TrainingArguments(
        output_dir=os.path.join(OUTPUT_DIR, "ast-finetuned"),
        per_device_train_batch_size=2,
        per_device_eval_batch_size=2,
        num_train_epochs=3,
        eval_strategy="steps",
        save_strategy="steps",
        save_steps=500,
        logging_steps=100,
        learning_rate=3e-5,
        save_total_limit=2,
        load_best_model_at_end=True
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"],
        compute_metrics=lambda pred: {
            'precision': precision_recall_fscore_support(pred.label_ids, pred.predictions.argmax(-1), average='weighted', zero_division=0)[0],
            'recall': precision_recall_fscore_support(pred.label_ids, pred.predictions.argmax(-1), average='weighted', zero_division=0)[1],
            'f1': precision_recall_fscore_support(pred.label_ids, pred.predictions.argmax(-1), average='weighted', zero_division=0)[2]
        }
    )
    trainer.train()
    eval_results = trainer.evaluate()
    return {'precision': eval_results['eval_precision'], 'recall': eval_results['eval_recall'], 'f1': eval_results['eval_f1']}

# 4: Task A3 - LLM-Enhanced Music Information Retrieval

In [21]:
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
from scipy.spatial.distance import cosine
from sklearn.metrics import average_precision_score, ndcg_score
from sklearn.preprocessing import StandardScaler
import pandas as pd
import logging
import torch

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
TEXT_EMBEDDING_DIM = 384  # Dimension of all-MiniLM-L6-v2 embeddings

def music_search(query, df, text_embeddings, audio_features_df, model, scaler, k=10):
    """Content-based music discovery with retrieval metrics."""
    logging.info("Performing music search for query: %s", query)
    print(f"Performing music search for query: {query}")
    query_embedding = model.encode(query, convert_to_tensor=True, device=DEVICE).cpu().numpy()
    query_embedding = query_embedding / np.linalg.norm(query_embedding)
    text_embedding_matrix = np.array([text_embeddings.get(tid, np.zeros(TEXT_EMBEDDING_DIM)) for tid in df['track_id']])
    text_embedding_matrix = text_embedding_matrix / np.linalg.norm(text_embedding_matrix, axis=1, keepdims=True)
    text_index = faiss.IndexFlatIP(text_embedding_matrix.shape[1])
    text_index.add(text_embedding_matrix.astype(np.float32))
    text_scores, text_indices = text_index.search(query_embedding.reshape(1, -1).astype(np.float32), k=20)
    audio_features = audio_features_df[[col for col in audio_features_df.columns if col != 'track_id']].values
    audio_features = scaler.transform(np.nan_to_num(audio_features))
    audio_features = audio_features / np.linalg.norm(audio_features, axis=1, keepdims=True)
    combined_scores = {}
    for idx, score in zip(text_indices[0], text_scores[0]):
        track_id = df['track_id'].iloc[idx]
        genre = df[df['track_id'] == track_id]['genre'].iloc[0]
        genre_weight = 1.5 if genre == 'Rock' else 1.0
        combined_scores[track_id] = 0.7 * score * genre_weight
        audio_idx = audio_features_df.index[audio_features_df['track_id'] == track_id].tolist()
        if audio_idx:
            audio_idx = audio_idx[0]
            genre_tracks = df[df['genre'] == genre]['track_id']
            genre_audio_features = audio_features[audio_features_df['track_id'].isin(genre_tracks)]
            if len(genre_audio_features) > 0:
                avg_genre_audio = np.mean(genre_audio_features, axis=0)
                audio_similarity = 1 - cosine(audio_features[audio_idx], avg_genre_audio)
                combined_scores[track_id] += 0.3 * audio_similarity * genre_weight
    top_tracks = sorted(combined_scores.items(), key=lambda x: x[1], reverse=True)[:k]
    result_ids = [track_id for track_id, _ in top_tracks]
    results = df[df['track_id'].isin(result_ids)][['track_id', 'artist_name', 'title', 'genre']]
    relevant_tracks = df[df['genre'] == 'Rock']['track_id'].values
    y_true = [1 if track_id in relevant_tracks else 0 for track_id in result_ids]
    y_scores = [combined_scores[track_id] for track_id in result_ids]
    precision = sum(y_true) / len(y_true) if len(y_true) > 0 else 0.0
    recall = sum(y_true) / len(relevant_tracks) if len(relevant_tracks) > 0 else 0.0
    map_score = average_precision_score(y_true, y_scores) if sum(y_true) > 0 else 0.0
    ndcg = ndcg_score(np.array(y_true).reshape(1, -1), np.array(y_scores).reshape(1, -1), k=k) if sum(y_true) > 0 else 0.0
    diversity = len(set(results['genre'])) / len(df['genre'].unique()) if len(df['genre'].unique()) > 0 else 1.0
    novelty = len(set(result_ids) - set(df['track_id'].iloc[train_indices])) / len(result_ids) if len(result_ids) > 0 else 1.0
    print(f"\nSearch Results for '{query}':")
    print(results.to_string(index=False))
    print(f"Retrieval Metrics: Precision@{k}: {precision:.4f}, Recall@{k}: {recall:.4f}, MAP: {map_score:.4f}, NDCG@{k}: {ndcg:.4f}, Diversity: {diversity:.4f}, Novelty: {novelty:.4f}")
    return results, {'precision@k': precision, 'recall@k': recall, 'map': map_score, 'ndcg@k': ndcg, 'diversity': diversity, 'novelty': novelty}

# 5: Task A3 - LLM-Powered Music Recommendation Systems

In [22]:
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.decomposition import TruncatedSVD
from sklearn.metrics import average_precision_score, ndcg_score
from collections import Counter
from imblearn.over_sampling import SMOTE
import logging

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
NUM_EPOCHS_REC = 5
BATCH_SIZE = 16
TEXT_EMBEDDING_DIM = 384

class HybridRecommender(nn.Module):
    """Content-based recommender."""
    def __init__(self, audio_dim, text_dim, hidden_dim=128):
        super().__init__()
        self.audio_layer = nn.Linear(audio_dim, hidden_dim)
        self.text_layer = nn.Linear(text_dim, hidden_dim)
        self.dropout = nn.Dropout(0.3)
        self.fc = nn.Linear(hidden_dim * 2, 1)

    def forward(self, audio_features, text_features):
        audio_out = torch.relu(self.audio_layer(audio_features))
        text_out = torch.relu(self.text_layer(text_features))
        combined = torch.cat([audio_out, text_out], dim=1)
        combined = self.dropout(combined)
        return torch.sigmoid(self.fc(combined))

def train_recommender(model, train_loader, criterion, optimizer):
    """Train the recommender."""
    model.train()
    total_loss = 0
    for audio_features, text_features, ratings in train_loader:
        audio_features, text_features, ratings = audio_features.to(DEVICE), text_features.to(DEVICE), ratings.to(DEVICE)
        optimizer.zero_grad()
        outputs = model(audio_features, text_features)
        loss = criterion(outputs, ratings.unsqueeze(1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(train_loader)

def evaluate_recommender(model, test_loader, df_metadata, test_idx_resampled, original_indices_resampled, genres, k=10):
    """Evaluate recommender."""
    model.eval()
    precisions, recalls, maps, ndcgs, diversities = [], [], [], [], []
    global train_indices
    if 'train_indices' not in globals():
        train_indices = []
    train_track_ids = set(df_metadata.iloc[train_indices]['track_id'])
    with torch.no_grad():
        for batch_idx, (audio_features, text_features, ratings) in enumerate(test_loader):
            audio_features, text_features, ratings = audio_features.to(DEVICE), text_features.to(DEVICE), ratings.to(DEVICE)
            outputs = model(audio_features, text_features)
            binary_ratings = (ratings > 0.6).float()
            top_k = torch.topk(outputs, k=min(k, outputs.size(0)), dim=0).indices.flatten()
            batch_start = batch_idx * test_loader.batch_size
            batch_end = batch_start + len(audio_features)
            batch_test_idx_resampled = test_idx_resampled[batch_start:batch_end]
            batch_original_indices = original_indices_resampled[batch_test_idx_resampled]
            top_k_original_indices = batch_original_indices[top_k.cpu()]
            top_k_ids = df_metadata.iloc[top_k_original_indices]['track_id']
            top_k_genres = df_metadata.iloc[top_k_original_indices]['genre']
            relevant = binary_ratings[top_k]
            precision = relevant.mean().item()
            recall = relevant.sum().item() / binary_ratings.sum().item() if binary_ratings.sum() > 0 else 0.0
            precisions.append(precision)
            recalls.append(recall)
            if binary_ratings.sum() > 0:
                map_score = average_precision_score(binary_ratings.cpu().numpy(), outputs.cpu().numpy())
                ndcg = ndcg_score(binary_ratings.cpu().numpy().reshape(1, -1), outputs.cpu().numpy().reshape(1, -1), k=k)
            else:
                map_score, ndcg = 0.0, 0.0
            maps.append(map_score)
            ndcgs.append(ndcg)
            diversity = len(set(top_k_genres)) / len(genres) if len(genres) > 0 else 1.0
            diversities.append(diversity)
            novelty = 1 - len(set(top_k_ids).intersection(train_track_ids)) / len(top_k_ids) if len(top_k_ids) > 0 else 1.0
    return {
        'precision@k': np.mean(precisions) if precisions else 0.0,
        'recall@k': np.mean(recalls) if recalls else 0.0,
        'map': np.mean(maps) if maps else 0.0,
        'ndcg@k': np.mean(ndcgs) if ndcgs else 0.0,
        'diversity': np.mean(diversities) if diversities else 1.0,
        'novelty': novelty
    }

def collaborative_filtering(user_data, df_metadata, k=10):
    """Collaborative filtering using SVD."""
    print("\nTraining Collaborative Filtering Model (SVD)...")
    valid_track_ids = set(df_metadata['track_id'].astype(str).str.zfill(6))
    user_data = user_data[user_data['track_id'].isin(valid_track_ids)]
    if user_data.empty:
        logging.error("No valid user data.")
        return {'precision@k': 0.0}
    train_data, test_data = train_test_split(user_data, test_size=0.2, random_state=42)
    user_item_matrix = train_data.pivot(index='user_id', columns='track_id', values='rating').fillna(0)
    if user_item_matrix.empty:
        logging.error("Empty user-item matrix.")
        return {'precision@k': 0.0}
    svd = TruncatedSVD(n_components=20, random_state=42)
    user_features = svd.fit_transform(user_item_matrix)
    item_features = svd.components_
    predictions = np.dot(user_features, item_features)
    predicted_ratings = pd.DataFrame(predictions, index=user_item_matrix.index, columns=user_item_matrix.columns)
    test_matrix = test_data.pivot(index='user_id', columns='track_id', values='rating').fillna(0)
    shared_tracks = set(test_matrix.columns).intersection(set(predicted_ratings.columns))
    precisions = []
    for user_id in test_matrix.index:
        if user_id in predicted_ratings.index:
            true_ratings = test_matrix.loc[user_id]
            pred_ratings = predicted_ratings.loc[user_id]
            valid_top_k = pred_ratings[pred_ratings.index.isin(shared_tracks)].sort_values(ascending=False).index[:k]
            if len(valid_top_k) == 0:
                precisions.append(0.0)
                continue
            relevant = (true_ratings[valid_top_k] > 3.0).astype(int)
            precision = relevant.mean()
            precisions.append(precision)
    return {'precision@k': np.mean(precisions) if precisions else 0.0}

# 6: Task C5 - Evaluation

In [23]:
import torch
import torch.nn as nn
import numpy as np
from sklearn.metrics import precision_recall_fscore_support, confusion_matrix
from sklearn.model_selection import KFold
from imblearn.over_sampling import SMOTE
import matplotlib.pyplot as plt
import seaborn as sns
from collections import Counter
import logging
import pandas as pd
from sklearn.preprocessing import StandardScaler

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
OUTPUT_DIR = "/content/outputs"
NUM_EPOCHS_CLS = 10
BATCH_SIZE = 16
TEXT_EMBEDDING_DIM = 384

class GenreClassifier(nn.Module):
    """Custom genre classifier."""
    def __init__(self, audio_dim, text_dim, num_classes, hidden_dim=128):
        super().__init__()
        self.audio_layer = nn.Linear(audio_dim, hidden_dim)
        self.text_layer = nn.Linear(text_dim, hidden_dim)
        self.dropout = nn.Dropout(0.3)
        self.fc = nn.Linear(hidden_dim * 2, num_classes)

    def forward(self, audio_features, text_features):
        audio_out = torch.relu(self.audio_layer(audio_features))
        text_out = torch.relu(self.text_layer(text_features))
        combined = torch.cat([audio_out, text_out], dim=1)
        combined = self.dropout(combined)
        return self.fc(combined)

def train_classifier(model, train_loader, criterion, optimizer):
    """Train the classifier."""
    model.train()
    total_loss = 0
    for audio_features, text_features, labels in train_loader:
        audio_features, text_features, labels = audio_features.to(DEVICE), text_features.to(DEVICE), labels.to(DEVICE)
        optimizer.zero_grad()
        outputs = model(audio_features, text_features)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(train_loader)

def evaluate_classifier(model, test_loader, genres):
    """Evaluate classifier with error analysis."""
    model.eval()
    y_true, y_pred = [], []
    with torch.no_grad():
        for audio_features, text_features, labels in test_loader:
            audio_features, text_features, labels = audio_features.to(DEVICE), text_features.to(DEVICE), labels.to(DEVICE)
            outputs = model(audio_features, text_features)
            _, predicted = torch.max(outputs, 1)
            y_true.extend(labels.cpu().numpy())
            y_pred.extend(predicted.cpu().numpy())
    metrics = precision_recall_fscore_support(y_true, y_pred, average='weighted', zero_division=0)
    cm = confusion_matrix(y_true, y_pred)
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=genres, yticklabels=genres)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Confusion Matrix')
    plt.savefig(os.path.join(OUTPUT_DIR, 'confusion_matrix.png'))
    plt.close()
    errors = [(i, genres[true], genres[pred]) for i, (true, pred) in enumerate(zip(y_true, y_pred)) if true != pred]
    error_counts = Counter((true, pred) for _, true, pred in errors)
    print("\nMisclassification Patterns:")
    for (true, pred), count in error_counts.most_common():
        print(f"True: {true}, Predicted: {pred}, Count: {count}")
    return {'precision': metrics[0], 'recall': metrics[1], 'f1': metrics[2]}

def analyze_feature_contribution(model, test_loader, feature_columns):
    """Analyze audio feature contributions."""
    print("\nFeature Contribution Analysis:")
    model.eval()
    contributions = {col: [] for col in feature_columns}
    with torch.no_grad():
        for audio_features, text_features, _ in test_loader:
            audio_features, text_features = audio_features.to(DEVICE), text_features.to(DEVICE)
            baseline_output = model(audio_features, text_features)
            for i, col in enumerate(feature_columns):
                modified_features = audio_features.clone()
                modified_features[:, i] = 0
                modified_output = model(modified_features, text_features)
                diff = torch.mean(torch.abs(baseline_output - modified_output)).item()
                contributions[col].append(diff)
    for col in feature_columns:
        print(f"{col}: Mean contribution = {np.mean(contributions[col]):.4f}")
    return contributions

def main_evaluation():
    """Evaluation orchestration for classification and recommendation."""
    logging.info("Starting evaluation...")
    print("Starting evaluation...")
    df_metadata, df_features, lyrics_dict = load_fma_data(FMA_AUDIO_DIRS, METADATA_PATH, ARTISTS_PATH, GENRES_PATH, LYRICS_PATH, TAGS_PATH)
    if df_metadata.empty or df_features.empty or not lyrics_dict:
        logging.error("Data loading failed.")
        return
    text_embeddings = generate_text_embeddings(lyrics_dict)
    valid_track_ids = df_features['track_id'].tolist()
    df_metadata = df_metadata[df_metadata['track_id'].isin(valid_track_ids)]
    text_features = np.array([text_embeddings.get(tid, np.zeros(TEXT_EMBEDDING_DIM)) for tid in df_metadata['track_id']])
    scaler = StandardScaler()
    audio_features = scaler.fit_transform(df_features[[col for col in df_features.columns if col != 'track_id']].values)
    audio_features = np.nan_to_num(audio_features)
    user_data = pd.read_csv(USER_DATA_PATH)
    user_data['track_id'] = user_data['track_id'].astype(str).str.zfill(6)
    user_data = user_data[user_data['track_id'].isin(df_metadata['track_id'])]
    ratings_agg = user_data.groupby('track_id')['rating'].mean().reindex(df_metadata['track_id']).fillna(3.0).values
    ratings = np.clip(ratings_agg / 5.0, 0.0, 1.0)
    X = np.hstack([audio_features, text_features])
    if X.shape[0] != len(ratings):
        logging.error("Mismatch in samples: X has %d samples, ratings has %d", X.shape[0], len(ratings))
        raise ValueError(f"Mismatch in samples: X has {X.shape[0]} samples, ratings has {len(ratings)}")
    original_indices = np.arange(len(X))
    binary_ratings = (ratings > 0.5).astype(int)
    class_counts = Counter(binary_ratings)
    min_class_size = min(class_counts.values())
    logging.info("Class distribution: %s", class_counts)
    print(f"Class distribution: {class_counts}")
    X_resampled, ratings_resampled, original_indices_resampled = X, binary_ratings, original_indices
    if min_class_size >= 2:
        k_neighbors = min(5, min_class_size - 1)
        smote = SMOTE(random_state=42, k_neighbors=k_neighbors)
        try:
            X_resampled, ratings_resampled = smote.fit_resample(X, binary_ratings)
            original_indices_resampled = np.concatenate([
                original_indices[binary_ratings == 0],
                original_indices[binary_ratings == 1],
                np.random.choice(original_indices[binary_ratings == 0], size=len(X_resampled) - len(X), replace=True)
            ])
            shuffle_indices = np.random.permutation(len(X_resampled))
            X_resampled = X_resampled[shuffle_indices]
            ratings_resampled = ratings_resampled[shuffle_indices]
            original_indices_resampled = original_indices_resampled[shuffle_indices]
        except ValueError as e:
            logging.error("SMOTE failed: %s", str(e))
            print(f"SMOTE failed: {str(e)}. Using original data.")
    kf = KFold(n_splits=5, shuffle=True, random_state=42)
    rec_metrics_all = []
    global train_indices
    for fold, (train_idx_resampled, test_idx_resampled) in enumerate(kf.split(X_resampled)):
        print(f"\nFold {fold+1}")
        train_indices = original_indices_resampled[train_idx_resampled]
        X_train_rec, X_test_rec = X_resampled[train_idx_resampled], X_resampled[test_idx_resampled]
        y_train_rec, y_test_rec = ratings_resampled[train_idx_resampled], ratings_resampled[test_idx_resampled]
        train_dataset_rec = torch.utils.data.TensorDataset(
            torch.tensor(X_train_rec[:, :audio_features.shape[1]], dtype=torch.float32),
            torch.tensor(X_train_rec[:, audio_features.shape[1]:], dtype=torch.float32),
            torch.tensor(y_train_rec, dtype=torch.float32)
        )
        train_loader_rec = torch.utils.data.DataLoader(train_dataset_rec, batch_size=BATCH_SIZE, shuffle=True)
        test_dataset_rec = torch.utils.data.TensorDataset(
            torch.tensor(X_test_rec[:, :audio_features.shape[1]], dtype=torch.float32),
            torch.tensor(X_test_rec[:, audio_features.shape[1]:], dtype=torch.float32),
            torch.tensor(y_test_rec, dtype=torch.float32)
        )
        test_loader_rec = torch.utils.data.DataLoader(test_dataset_rec, batch_size=BATCH_SIZE)
        recommender = HybridRecommender(audio_dim=audio_features.shape[1], text_dim=TEXT_EMBEDDING_DIM).to(DEVICE)
        criterion_rec = nn.BCELoss()
        optimizer_rec = torch.optim.Adam(recommender.parameters(), lr=0.001, weight_decay=1e-4)
        print("\nTraining Recommender Model...")
        for epoch in range(NUM_EPOCHS_REC):
            loss = train_recommender(recommender, train_loader_rec, criterion_rec, optimizer_rec)
            print(f"Recommendation Epoch {epoch+1}, Loss: {loss:.4f}")
        rec_metrics = evaluate_recommender(recommender, test_loader_rec, df_metadata, test_idx_resampled, original_indices_resampled, df_metadata['genre'].unique())
        rec_metrics_all.append(rec_metrics)
        print(f"Recommendation Metrics: {rec_metrics}")
    avg_rec_metrics = {k: np.mean([m[k] for m in rec_metrics_all]) for k in rec_metrics_all[0]}
    print(f"Average Recommendation Metrics: {avg_rec_metrics}")
    cf_metrics = collaborative_filtering(user_data, df_metadata)
    print(f"Collaborative Filtering Metrics: {cf_metrics}")
    genres = df_metadata['genre'].unique()
    genre_to_idx = {g: i for i, g in enumerate(genres)}
    labels = df_metadata['genre'].map(genre_to_idx).values
    cls_metrics_all = []
    for fold, (train_idx_resampled, test_idx_resampled) in enumerate(kf.split(X_resampled)):
        print(f"\nFold {fold+1}")
        X_train_cls, X_test_cls = X_resampled[train_idx_resampled], X_resampled[test_idx_resampled]
        y_train_cls = labels[original_indices_resampled[train_idx_resampled]]
        y_test_cls = labels[original_indices_resampled[test_idx_resampled]]
        train_dataset_cls = torch.utils.data.TensorDataset(
            torch.tensor(X_train_cls[:, :audio_features.shape[1]], dtype=torch.float32),
            torch.tensor(X_train_cls[:, audio_features.shape[1]:], dtype=torch.float32),
            torch.tensor(y_train_cls, dtype=torch.long)
        )
        train_loader_cls = torch.utils.data.DataLoader(train_dataset_cls, batch_size=BATCH_SIZE, shuffle=True)
        test_dataset_cls = torch.utils.data.TensorDataset(
            torch.tensor(X_test_cls[:, :audio_features.shape[1]], dtype=torch.float32),
            torch.tensor(X_test_cls[:, audio_features.shape[1]:], dtype=torch.float32),
            torch.tensor(y_test_cls, dtype=torch.long)
        )
        test_loader_cls = torch.utils.data.DataLoader(test_dataset_cls, batch_size=BATCH_SIZE)
        classifier = GenreClassifier(
            audio_dim=audio_features.shape[1],
            text_dim=TEXT_EMBEDDING_DIM,
            num_classes=len(genres)
        ).to(DEVICE)
        criterion_cls = nn.CrossEntropyLoss()
        optimizer_cls = torch.optim.Adam(classifier.parameters(), lr=0.001, weight_decay=1e-4)
        print("\nTraining Genre Classifier Model...")
        for epoch in range(NUM_EPOCHS_CLS):
            loss = train_classifier(classifier, train_loader_cls, criterion_cls, optimizer_cls)
            print(f"Classification Epoch {epoch+1}, Loss: {loss:.4f}")
        cls_metrics = evaluate_classifier(classifier, test_loader_cls, genres)
        cls_metrics_all.append(cls_metrics)
        print(f"Classification Metrics: {cls_metrics}")
    avg_cls_metrics = {k: np.mean([m[k] for m in cls_metrics_all]) for k in cls_metrics_all[0]}
    print(f"Average Classification Metrics: {avg_cls_metrics}")

# Main Execution

In [24]:
def main():
    """Main function orchestrating all tasks."""
    logging.info("Starting main execution...")
    print("Starting main execution...")
    df_metadata, df_features, lyrics_dict = load_fma_data(FMA_AUDIO_DIRS, METADATA_PATH, ARTISTS_PATH, GENRES_PATH, LYRICS_PATH, TAGS_PATH)
    if df_metadata.empty or df_features.empty or not lyrics_dict:
        logging.error("Data loading failed.")
        return
    text_embeddings = generate_text_embeddings(lyrics_dict)
    analyze_linguistic_patterns(df_metadata, lyrics_dict)
    zero_shot_classification(df_metadata, lyrics_dict)
    few_shot_classification(df_metadata, lyrics_dict)
    valid_track_ids = df_features['track_id'].tolist()
    df_metadata = df_metadata[df_metadata['track_id'].isin(valid_track_ids)]
    text_features = np.array([text_embeddings.get(tid, np.zeros(TEXT_EMBEDDING_DIM)) for tid in df_metadata['track_id']])
    scaler = StandardScaler()
    audio_features = scaler.fit_transform(df_features[[col for col in df_features.columns if col != 'track_id']].values)
    audio_features = np.nan_to_num(audio_features)
    user_data = pd.read_csv(USER_DATA_PATH)
    user_data['track_id'] = user_data['track_id'].astype(str).str.zfill(6)
    user_data = user_data[user_data['track_id'].isin(df_metadata['track_id'])]
    ratings_agg = user_data.groupby('track_id')['rating'].mean().reindex(df_metadata['track_id']).fillna(3.0).values
    ratings = np.clip(ratings_agg / 5.0, 0.0, 1.0)
    X = np.hstack([audio_features, text_features])
    if X.shape[0] != len(ratings):
        logging.error("Mismatch in samples: X has %d samples, ratings has %d", X.shape[0], len(ratings))
        raise ValueError(f"Mismatch in samples: X has {X.shape[0]} samples, ratings has {len(ratings)}")
    original_indices = np.arange(len(X))
    binary_ratings = (ratings > 0.5).astype(int)
    class_counts = Counter(binary_ratings)
    min_class_size = min(class_counts.values())
    logging.info("Class distribution: %s", class_counts)
    print(f"Class distribution: {class_counts}")
    X_resampled, ratings_resampled, original_indices_resampled = X, binary_ratings, original_indices
    if min_class_size >= 2:
        k_neighbors = min(5, min_class_size - 1)
        smote = SMOTE(random_state=42, k_neighbors=k_neighbors)
        try:
            X_resampled, ratings_resampled = smote.fit_resample(X, binary_ratings)
            original_indices_resampled = np.concatenate([
                original_indices[binary_ratings == 0],
                original_indices[binary_ratings == 1],
                np.random.choice(original_indices[binary_ratings == 0], size=len(X_resampled) - len(X), replace=True)
            ])
            shuffle_indices = np.random.permutation(len(X_resampled))
            X_resampled = X_resampled[shuffle_indices]
            ratings_resampled = ratings_resampled[shuffle_indices]
            original_indices_resampled = original_indices_resampled[shuffle_indices]
        except ValueError as e:
            logging.error("SMOTE failed: %s", str(e))
            print(f"SMOTE failed: {str(e)}. Using original data.")
    kf = KFold(n_splits=5, shuffle=True, random_state=42)
    rec_metrics_all = []
    global train_indices
    for fold, (train_idx_resampled, test_idx_resampled) in enumerate(kf.split(X_resampled)):
        print(f"\nFold {fold+1}")
        train_indices = original_indices_resampled[train_idx_resampled]
        X_train_rec, X_test_rec = X_resampled[train_idx_resampled], X_resampled[test_idx_resampled]
        y_train_rec, y_test_rec = ratings_resampled[train_idx_resampled], ratings_resampled[test_idx_resampled]
        train_dataset_rec = torch.utils.data.TensorDataset(
            torch.tensor(X_train_rec[:, :audio_features.shape[1]], dtype=torch.float32),
            torch.tensor(X_train_rec[:, audio_features.shape[1]:], dtype=torch.float32),
            torch.tensor(y_train_rec, dtype=torch.float32)
        )
        train_loader_rec = torch.utils.data.DataLoader(train_dataset_rec, batch_size=BATCH_SIZE, shuffle=True)
        test_dataset_rec = torch.utils.data.TensorDataset(
            torch.tensor(X_test_rec[:, :audio_features.shape[1]], dtype=torch.float32),
            torch.tensor(X_test_rec[:, audio_features.shape[1]:], dtype=torch.float32),
            torch.tensor(y_test_rec, dtype=torch.float32)
        )
        test_loader_rec = torch.utils.data.DataLoader(test_dataset_rec, batch_size=BATCH_SIZE)
        recommender = HybridRecommender(audio_dim=audio_features.shape[1], text_dim=TEXT_EMBEDDING_DIM).to(DEVICE)
        criterion_rec = nn.BCELoss()
        optimizer_rec = torch.optim.Adam(recommender.parameters(), lr=0.001, weight_decay=1e-4)
        print("\nTraining Recommender Model...")
        for epoch in range(NUM_EPOCHS_REC):
            loss = train_recommender(recommender, train_loader_rec, criterion_rec, optimizer_rec)
            print(f"Recommendation Epoch {epoch+1}, Loss: {loss:.4f}")
        rec_metrics = evaluate_recommender(recommender, test_loader_rec, df_metadata, test_idx_resampled, original_indices_resampled, df_metadata['genre'].unique())
        rec_metrics_all.append(rec_metrics)
        print(f"Recommendation Metrics: {rec_metrics}")
    avg_rec_metrics = {k: np.mean([m[k] for m in rec_metrics_all]) for k in rec_metrics_all[0]}
    print(f"Average Recommendation Metrics: {avg_rec_metrics}")
    cf_metrics = collaborative_filtering(user_data, df_metadata)
    print(f"Collaborative Filtering Metrics: {cf_metrics}")
    genres = df_metadata['genre'].unique()
    genre_to_idx = {g: i for i, g in enumerate(genres)}
    labels = df_metadata['genre'].map(genre_to_idx).values
    cls_metrics_all = []
    for fold, (train_idx_resampled, test_idx_resampled) in enumerate(kf.split(X_resampled)):
        print(f"\nFold {fold+1}")
        X_train_cls, X_test_cls = X_resampled[train_idx_resampled], X_resampled[test_idx_resampled]
        y_train_cls = labels[original_indices_resampled[train_idx_resampled]]
        y_test_cls = labels[original_indices_resampled[test_idx_resampled]]
        train_dataset_cls = torch.utils.data.TensorDataset(
            torch.tensor(X_train_cls[:, :audio_features.shape[1]], dtype=torch.float32),
            torch.tensor(X_train_cls[:, audio_features.shape[1]:], dtype=torch.float32),
            torch.tensor(y_train_cls, dtype=torch.long)
        )
        train_loader_cls = torch.utils.data.DataLoader(train_dataset_cls, batch_size=BATCH_SIZE, shuffle=True)
        test_dataset_cls = torch.utils.data.TensorDataset(
            torch.tensor(X_test_cls[:, :audio_features.shape[1]], dtype=torch.float32),
            torch.tensor(X_test_cls[:, audio_features.shape[1]:], dtype=torch.float32),
            torch.tensor(y_test_cls, dtype=torch.long)
        )
        test_loader_cls = torch.utils.data.DataLoader(test_dataset_cls, batch_size=BATCH_SIZE)
        classifier = GenreClassifier(
            audio_dim=audio_features.shape[1],
            text_dim=TEXT_EMBEDDING_DIM,
            num_classes=len(genres)
        ).to(DEVICE)
        criterion_cls = nn.CrossEntropyLoss()
        optimizer_cls = torch.optim.Adam(classifier.parameters(), lr=0.001, weight_decay=1e-4)
        print("\nTraining Genre Classifier Model...")
        for epoch in range(NUM_EPOCHS_CLS):
            loss = train_classifier(classifier, train_loader_cls, criterion_cls, optimizer_cls)
            print(f"Classification Epoch {epoch+1}, Loss: {loss:.4f}")
        cls_metrics = evaluate_classifier(classifier, test_loader_cls, genres)
        cls_metrics_all.append(cls_metrics)
        print(f"Classification Metrics: {cls_metrics}")
    avg_cls_metrics = {k: np.mean([m[k] for m in cls_metrics_all]) for k in cls_metrics_all[0]}
    print(f"Average Classification Metrics: {avg_cls_metrics}")
    bert_metrics = bert_classifier(df_metadata, text_embeddings, audio_features, genres)
    print(f"BERT Classifier Metrics: {bert_metrics}")
    roberta_metrics = roberta_classifier(df_metadata, text_embeddings, genres)
    print(f"RoBERTa Classifier Metrics: {roberta_metrics}")
    audio_files = []
    for track_id in df_metadata['track_id']:
        audio_file = None
        for audio_dir in FMA_AUDIO_DIRS:
            potential_path = os.path.join(audio_dir, f"{track_id}.mp3")
            if os.path.exists(potential_path):
                audio_file = potential_path
                break
        audio_files.append(audio_file if audio_file else "")
    ast_metrics = ast_classifier(audio_files, labels, genres, df_metadata)
    print(f"AST Classifier Metrics: {ast_metrics}")
    search_model = SentenceTransformer('all-MiniLM-L6-v2', device=DEVICE)
    query = "upbeat rock songs"
    _, search_metrics = music_search(query, df_metadata, text_embeddings, df_features, search_model, scaler)
    print(f"Search Metrics: {search_metrics}")
    feature_columns = [col for col in df_features.columns if col != 'track_id']
    analyze_feature_contribution(classifier, test_loader_cls, feature_columns)

if __name__ == "__main__":
    main()

INFO:root:Starting main execution...
INFO:root:Loading FMA data...
INFO:root:Found 10 MP3 files in /content/drive/MyDrive/fma_small/008
INFO:root:Found 62 MP3 files in /content/drive/MyDrive/fma_small/000
INFO:root:Found 60 MP3 files in /content/drive/MyDrive/fma_small/001
INFO:root:Found 71 MP3 files in /content/drive/MyDrive/fma_small/006
INFO:root:Found 54 MP3 files in /content/drive/MyDrive/fma_small/003
INFO:root:Found 4 MP3 files in /content/drive/MyDrive/fma_small/002
INFO:root:Found 71 MP3 files in /content/drive/MyDrive/fma_small/004
INFO:root:Found 17 MP3 files in /content/drive/MyDrive/fma_small/005
INFO:root:Found 37 MP3 files in /content/drive/MyDrive/fma_small/007
INFO:root:Found 23 MP3 files in /content/drive/MyDrive/fma_small/009
INFO:root:Found 65 MP3 files in /content/drive/MyDrive/fma_small/010
INFO:root:Found 90 MP3 files in /content/drive/MyDrive/fma_small/011
INFO:root:Found 67 MP3 files in /content/drive/MyDrive/fma_small/012
INFO:root:Found 61 MP3 files in /cont

Starting main execution...
Loading FMA data...


INFO:root:Found 74 MP3 files in /content/drive/MyDrive/fma_small/054
INFO:root:Found 76 MP3 files in /content/drive/MyDrive/fma_small/056
INFO:root:Found 53 MP3 files in /content/drive/MyDrive/fma_small/057
INFO:root:Found 36 MP3 files in /content/drive/MyDrive/fma_small/058
INFO:root:Found 61 MP3 files in /content/drive/MyDrive/fma_small/059
INFO:root:Found 48 MP3 files in /content/drive/MyDrive/fma_small/060
INFO:root:Found 32 MP3 files in /content/drive/MyDrive/fma_small/061
INFO:root:Found 56 MP3 files in /content/drive/MyDrive/fma_small/062
INFO:root:Found 47 MP3 files in /content/drive/MyDrive/fma_small/063
INFO:root:Found 93 MP3 files in /content/drive/MyDrive/fma_small/064
INFO:root:Found 24 MP3 files in /content/drive/MyDrive/fma_small/065
INFO:root:Found 36 MP3 files in /content/drive/MyDrive/fma_small/066
INFO:root:Found 57 MP3 files in /content/drive/MyDrive/fma_small/067
INFO:root:Found 58 MP3 files in /content/drive/MyDrive/fma_small/068
INFO:root:Found 75 MP3 files in /c



INFO:root:Metadata shape: (100, 4), Features shape: (100, 27), Lyrics dict size: 98
INFO:root:Generating text embeddings...
INFO:sentence_transformers.SentenceTransformer:Load pretrained SentenceTransformer: all-MiniLM-L6-v2


Metadata shape: (100, 4), Features shape: (100, 27), Lyrics dict size: 98
Generating text embeddings...


Generating text embeddings:   0%|          | 0/98 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

  return forward_call(*args, **kwargs)
Generating text embeddings:   1%|          | 1/98 [00:00<00:11,  8.25it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:   2%|▏         | 2/98 [00:00<00:11,  8.10it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:   3%|▎         | 3/98 [00:00<00:10,  8.76it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:   4%|▍         | 4/98 [00:00<00:10,  8.67it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:   5%|▌         | 5/98 [00:00<00:10,  8.69it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:   6%|▌         | 6/98 [00:00<00:10,  8.71it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:   7%|▋         | 7/98 [00:00<00:11,  8.10it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:   8%|▊         | 8/98 [00:00<00:10,  8.22it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:   9%|▉         | 9/98 [00:01<00:10,  8.53it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:  10%|█         | 10/98 [00:01<00:10,  8.36it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:  11%|█         | 11/98 [00:01<00:10,  8.44it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:  12%|█▏        | 12/98 [00:01<00:10,  8.35it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:  13%|█▎        | 13/98 [00:01<00:10,  7.98it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:  14%|█▍        | 14/98 [00:01<00:10,  8.17it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:  15%|█▌        | 15/98 [00:01<00:10,  8.28it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:  16%|█▋        | 16/98 [00:01<00:10,  7.77it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:  17%|█▋        | 17/98 [00:02<00:10,  7.98it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:  18%|█▊        | 18/98 [00:02<00:10,  7.95it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:  19%|█▉        | 19/98 [00:02<00:09,  8.04it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:  20%|██        | 20/98 [00:02<00:09,  8.13it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:  21%|██▏       | 21/98 [00:02<00:09,  8.46it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:  22%|██▏       | 22/98 [00:02<00:08,  8.62it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:  23%|██▎       | 23/98 [00:02<00:08,  8.50it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:  24%|██▍       | 24/98 [00:02<00:08,  8.27it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:  26%|██▌       | 25/98 [00:03<00:09,  8.02it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:  27%|██▋       | 26/98 [00:03<00:08,  8.07it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:  28%|██▊       | 27/98 [00:03<00:08,  8.15it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:  29%|██▊       | 28/98 [00:03<00:08,  8.32it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:  30%|██▉       | 29/98 [00:03<00:08,  8.18it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:  31%|███       | 30/98 [00:03<00:08,  8.28it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:  32%|███▏      | 31/98 [00:03<00:08,  8.22it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:  33%|███▎      | 32/98 [00:03<00:08,  7.89it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:  34%|███▎      | 33/98 [00:04<00:08,  7.73it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:  35%|███▍      | 34/98 [00:04<00:07,  8.07it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:  36%|███▌      | 35/98 [00:04<00:07,  8.05it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:  37%|███▋      | 36/98 [00:04<00:07,  8.17it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:  38%|███▊      | 37/98 [00:04<00:07,  8.25it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:  39%|███▉      | 38/98 [00:04<00:07,  8.19it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:  40%|███▉      | 39/98 [00:04<00:07,  8.26it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:  41%|████      | 40/98 [00:04<00:07,  8.18it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:  42%|████▏     | 41/98 [00:04<00:06,  8.15it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:  43%|████▎     | 42/98 [00:05<00:07,  7.86it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:  44%|████▍     | 43/98 [00:05<00:06,  7.97it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:  45%|████▍     | 44/98 [00:05<00:06,  8.14it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:  46%|████▌     | 45/98 [00:05<00:06,  8.09it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:  47%|████▋     | 46/98 [00:05<00:06,  8.18it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:  48%|████▊     | 47/98 [00:05<00:06,  8.01it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:  49%|████▉     | 48/98 [00:05<00:06,  7.88it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:  50%|█████     | 49/98 [00:05<00:06,  7.96it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:  51%|█████     | 50/98 [00:06<00:06,  7.64it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:  52%|█████▏    | 51/98 [00:06<00:06,  7.73it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:  53%|█████▎    | 52/98 [00:06<00:05,  7.83it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:  54%|█████▍    | 53/98 [00:06<00:05,  7.88it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:  55%|█████▌    | 54/98 [00:06<00:05,  8.03it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:  56%|█████▌    | 55/98 [00:06<00:05,  7.85it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:  57%|█████▋    | 56/98 [00:06<00:05,  7.91it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:  58%|█████▊    | 57/98 [00:07<00:05,  7.95it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:  59%|█████▉    | 58/98 [00:07<00:05,  7.79it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:  60%|██████    | 59/98 [00:07<00:05,  7.45it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:  61%|██████    | 60/98 [00:07<00:04,  7.72it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:  62%|██████▏   | 61/98 [00:07<00:04,  7.86it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:  63%|██████▎   | 62/98 [00:07<00:04,  7.85it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:  64%|██████▍   | 63/98 [00:07<00:04,  7.92it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:  65%|██████▌   | 64/98 [00:07<00:04,  7.67it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:  66%|██████▋   | 65/98 [00:08<00:04,  7.84it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:  67%|██████▋   | 66/98 [00:08<00:04,  7.79it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:  68%|██████▊   | 67/98 [00:08<00:04,  7.36it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:  69%|██████▉   | 68/98 [00:08<00:03,  7.51it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:  70%|███████   | 69/98 [00:08<00:03,  7.73it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:  71%|███████▏  | 70/98 [00:08<00:03,  7.71it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:  72%|███████▏  | 71/98 [00:08<00:03,  7.78it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:  73%|███████▎  | 72/98 [00:09<00:03,  7.27it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:  74%|███████▍  | 73/98 [00:09<00:03,  7.28it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:  76%|███████▌  | 74/98 [00:09<00:03,  6.70it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:  77%|███████▋  | 75/98 [00:09<00:03,  6.30it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:  78%|███████▊  | 76/98 [00:09<00:03,  6.09it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:  79%|███████▊  | 77/98 [00:09<00:03,  6.24it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:  80%|███████▉  | 78/98 [00:09<00:03,  6.23it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:  81%|████████  | 79/98 [00:10<00:02,  6.40it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:  82%|████████▏ | 80/98 [00:10<00:02,  6.70it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:  83%|████████▎ | 81/98 [00:10<00:02,  6.37it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:  84%|████████▎ | 82/98 [00:10<00:02,  6.54it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:  85%|████████▍ | 83/98 [00:10<00:02,  6.77it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:  86%|████████▌ | 84/98 [00:10<00:02,  6.92it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:  87%|████████▋ | 85/98 [00:10<00:01,  7.06it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:  88%|████████▊ | 86/98 [00:11<00:01,  7.04it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:  89%|████████▉ | 87/98 [00:11<00:01,  6.78it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:  90%|████████▉ | 88/98 [00:11<00:01,  6.56it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:  91%|█████████ | 89/98 [00:11<00:01,  6.66it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:  92%|█████████▏| 90/98 [00:11<00:01,  6.65it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:  93%|█████████▎| 91/98 [00:11<00:01,  6.91it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:  94%|█████████▍| 92/98 [00:12<00:00,  6.56it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:  95%|█████████▍| 93/98 [00:12<00:00,  6.41it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:  96%|█████████▌| 94/98 [00:12<00:00,  6.74it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:  97%|█████████▋| 95/98 [00:12<00:00,  6.76it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:  98%|█████████▊| 96/98 [00:12<00:00,  6.98it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings:  99%|█████████▉| 97/98 [00:12<00:00,  6.88it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Generating text embeddings: 100%|██████████| 98/98 [00:12<00:00,  7.59it/s]
INFO:root:Text embeddings generated for 98 tracks.


Text embeddings generated for 98 tracks.

Linguistic Analysis:


INFO:root:Classical top words: [('heart', 36), ('song', 30), ('feel', 30), ('let', 30), ('take', 30)]
INFO:root:Pop top words: [('love', 26), ('heart', 23), ('song', 22), ('feel', 22), ('let', 22)]
INFO:root:Rock top words: [('heart', 20), ('rock', 16), ('song', 16), ('feel', 16), ('let', 16)]
INFO:root:International top words: [('instrumental', 15), ('song', 12), ('dream', 12), ('feel', 12), ('let', 12)]
INFO:root:Country top words: [('heart', 27), ('song', 20), ('feel', 20), ('let', 20), ('take', 20)]
INFO:root:Reggae top words: [('song', 16), ('feel', 16), ('let', 16), ('take', 16), ('vibes', 15)]
INFO:root:Jazz top words: [('heart', 13), ('love', 13), ('song', 12), ('dream', 12), ('feel', 12)]
INFO:root:Electronic top words: [('heart', 13), ('song', 12), ('feel', 12), ('let', 12), ('take', 12)]
INFO:root:Experimental top words: [('heart', 20), ('song', 14), ('feel', 14), ('let', 14), ('take', 14)]
INFO:root:Blues top words: [('dream', 21), ('heart', 17), ('song', 16), ('feel', 16),

Classical top words: [('heart', 36), ('song', 30), ('feel', 30), ('let', 30), ('take', 30)]
Pop top words: [('love', 26), ('heart', 23), ('song', 22), ('feel', 22), ('let', 22)]
Rock top words: [('heart', 20), ('rock', 16), ('song', 16), ('feel', 16), ('let', 16)]
International top words: [('instrumental', 15), ('song', 12), ('dream', 12), ('feel', 12), ('let', 12)]
Country top words: [('heart', 27), ('song', 20), ('feel', 20), ('let', 20), ('take', 20)]
Reggae top words: [('song', 16), ('feel', 16), ('let', 16), ('take', 16), ('vibes', 15)]
Jazz top words: [('heart', 13), ('love', 13), ('song', 12), ('dream', 12), ('feel', 12)]
Electronic top words: [('heart', 13), ('song', 12), ('feel', 12), ('let', 12), ('take', 12)]
Experimental top words: [('heart', 20), ('song', 14), ('feel', 14), ('let', 14), ('take', 14)]
Blues top words: [('dream', 21), ('heart', 17), ('song', 16), ('feel', 16), ('let', 16)]
Hip-Hop top words: [('heart', 15), ('song', 8), ('feel', 8), ('let', 8), ('take', 8)]


Device set to use cpu


Track 000002: Predicted=Classical (0.9485), True=Classical
Track 000005: Predicted=Pop (0.9346), True=Pop
Track 000010: Predicted=Rock (0.9489), True=Rock
Track 000004: Predicted=Hip-Hop (0.9675), True=International
Track 000005: Predicted=Pop (0.9346), True=Country
Zero-Shot Accuracy: 0.6000

Few-Shot Genre Classification:


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/483 [00:00<?, ?B/s]

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'DistilBertTokenizer'. 
The class this function is called from is 'BertTokenizer'.
You are using a model of type distilbert to instantiate a model of type bert. This is not supported for all configurations of models and can yield errors.


model.safetensors:   0%|          | 0.00/268M [00:00<?, ?B/s]

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['bert.embeddings.LayerNorm.bias', 'bert.embeddings.LayerNorm.weight', 'bert.embeddings.position_embeddings.weight', 'bert.embeddings.token_type_embeddings.weight', 'bert.embeddings.word_embeddings.weight', 'bert.encoder.layer.0.attention.output.LayerNorm.bias', 'bert.encoder.layer.0.attention.output.LayerNorm.weight', 'bert.encoder.layer.0.attention.output.dense.bias', 'bert.encoder.layer.0.attention.output.dense.weight', 'bert.encoder.layer.0.attention.self.key.bias', 'bert.encoder.layer.0.attention.self.key.weight', 'bert.encoder.layer.0.attention.self.query.bias', 'bert.encoder.layer.0.attention.self.query.weight', 'bert.encoder.layer.0.attention.self.value.bias', 'bert.encoder.layer.0.attention.self.value.weight', 'bert.encoder.layer.0.intermediate.dense.bias', 'bert.encoder.layer.0.intermediate.dense.weight', 'bert.encoder.layer.0.outp

Track 000002: Predicted=Blues, True=Classical
Track 000005: Predicted=Blues, True=Pop
Track 000010: Predicted=Blues, True=Rock
Track 000004: Predicted=Blues, True=International


INFO:root:Class distribution: Counter({np.int64(1): 93, np.int64(0): 7})


Track 000005: Predicted=Blues, True=Country
Few-Shot Accuracy: 0.0000
Class distribution: Counter({np.int64(1): 93, np.int64(0): 7})

Fold 1

Training Recommender Model...
Recommendation Epoch 1, Loss: 0.6829
Recommendation Epoch 2, Loss: 0.6379
Recommendation Epoch 3, Loss: 0.6004
Recommendation Epoch 4, Loss: 0.5570
Recommendation Epoch 5, Loss: 0.5219
Recommendation Metrics: {'precision@k': np.float64(0.5888888786236445), 'recall@k': np.float64(0.9583333333333334), 'map': np.float64(0.986111111111111), 'ndcg@k': np.float64(0.973401820479184), 'diversity': np.float64(0.4358974358974359), 'novelty': 0.33333333333333337}

Fold 2

Training Recommender Model...
Recommendation Epoch 1, Loss: 0.6848
Recommendation Epoch 2, Loss: 0.6362
Recommendation Epoch 3, Loss: 0.5948
Recommendation Epoch 4, Loss: 0.5451
Recommendation Epoch 5, Loss: 0.4989
Recommendation Metrics: {'precision@k': np.float64(0.6666666567325592), 'recall@k': np.float64(0.9666666666666667), 'map': np.float64(0.95534271284

INFO:root:Initializing BERT classifier...



Misclassification Patterns:
True: Pop, Predicted: Jazz, Count: 6
True: Classical, Predicted: Jazz, Count: 5
True: Reggae, Predicted: Classical, Count: 3
True: Rock, Predicted: Classical, Count: 3
True: Country, Predicted: Classical, Count: 2
True: Rock, Predicted: Jazz, Count: 2
True: Pop, Predicted: Classical, Count: 2
True: Instrumental, Predicted: Jazz, Count: 2
True: Blues, Predicted: Classical, Count: 1
True: Folk, Predicted: Classical, Count: 1
True: Jazz, Predicted: Classical, Count: 1
True: Electronic, Predicted: Jazz, Count: 1
True: International, Predicted: Classical, Count: 1
True: Reggae, Predicted: Jazz, Count: 1
Classification Metrics: {'precision': 0.047665847665847666, 'recall': 0.16216216216216217, 'f1': 0.07335907335907337}
Average Classification Metrics: {'precision': np.float64(0.05676366307785119), 'recall': np.float64(0.16642958748221906), 'f1': np.float64(0.08142766826977353)}

Training Audio-Text BERT Classifier...


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
INFO:root:Dataset created with 100 items
INFO:root:Inspecting train split: 80 items
INFO:root:Inspecting test split: 20 items


Dataset created with 100 items


ERROR:root:Training failed: argument of type 'NoneType' is not iterable


Training failed: argument of type 'NoneType' is not iterable
BERT Classifier Metrics: {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}

Training RoBERTa Classifier...


Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  return forward_call(*args, **kwargs)


Epoch,Training Loss,Validation Loss,Precision,Recall,F1
1,No log,2.558,0.0025,0.05,0.004762
2,No log,2.529307,0.123684,0.2,0.107576
3,No log,2.506878,0.125,0.25,0.142857


  return forward_call(*args, **kwargs)
  return forward_call(*args, **kwargs)
  return forward_call(*args, **kwargs)


RoBERTa Classifier Metrics: {'precision': 0.125, 'recall': 0.25, 'f1': 0.14285714285714285}


ERROR:root:Failed to load AST model: MIT/ast-finetuned-audioset-10-10-0.2 is not a local folder and is not a valid model identifier listed on 'https://huggingface.co/models'
If this is a private repository, make sure to pass a token having permission to this repo either by logging in with `hf auth login` or by passing `token=<your_token>`
INFO:sentence_transformers.SentenceTransformer:Load pretrained SentenceTransformer: all-MiniLM-L6-v2



Setting up AST Classifier...
AST Classifier Metrics: {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}


INFO:root:Performing music search for query: upbeat rock songs


Performing music search for query: upbeat rock songs


Batches:   0%|          | 0/1 [00:00<?, ?it/s]

  return forward_call(*args, **kwargs)



Search Results for 'upbeat rock songs':
track_id artist_name    title        genre
  000011   Artist_11 Track_11         Jazz
  000016   Artist_16 Track_16       Reggae
  000019   Artist_19 Track_19         Rock
  000036   Artist_36 Track_36         Jazz
  000042   Artist_42 Track_42      Hip-Hop
  000045   Artist_45 Track_45         Folk
  000069   Artist_69 Track_69    Classical
  000079   Artist_79 Track_79         Rock
  000082   Artist_82 Track_82      Country
  000090   Artist_90 Track_90 Instrumental
Retrieval Metrics: Precision@10: 0.2000, Recall@10: 0.2500, MAP: 0.8333, NDCG@10: 0.9197, Diversity: 0.6154, Novelty: 0.2000
Search Metrics: {'precision@k': 0.2, 'recall@k': 0.25, 'map': np.float64(0.8333333333333333), 'ndcg@k': np.float64(0.9197207891481877), 'diversity': 0.6153846153846154, 'novelty': 0.2}

Feature Contribution Analysis:
mfcc_1: Mean contribution = 0.0062
mfcc_2: Mean contribution = 0.0070
mfcc_3: Mean contribution = 0.0035
mfcc_4: Mean contribution = 0.0091
mfcc