<a href="https://colab.research.google.com/github/Viditk07-Bits/AudioAnalytics_S2-24_AIMLCZG527/blob/main/AA_Assignment2_Final.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Music Information Retrieval System

## Assignment Objective
This assignment implements a comprehensive Music Information Retrieval (MIR) system using Large Language Models (LLMs) and deep learning techniques. It includes music recommendation, genre classification, and semantic search applications, combining audio analysis with natural language processing.

## Dataset Setup
Using the Free Music Archive (FMA) dataset with audio files, metadata, and synthetic user data.

In [9]:
# Task A1, A2, A3, A5, C5: Music Genre Classification, Recommendation, and Retrieval
# Implements audio and text-based genre classification, transformer-based audio classification,
# content-based music discovery, personalized recommendation, and comprehensive evaluation.

import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from sentence_transformers import SentenceTransformer
from transformers import pipeline, AutoModelForAudioClassification, AutoFeatureExtractor, Trainer, TrainingArguments
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_recall_fscore_support, confusion_matrix, average_precision_score, ndcg_score
from sklearn.decomposition import NMF
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.spatial.distance import cosine
import os
from pathlib import Path
import librosa
import hashlib
import subprocess
import logging
import random
from collections import Counter
from google.colab import drive
from datasets import Dataset  # Added for fine-tuning

# Setup logging for detailed debugging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Mount Google Drive for persistent storage
drive.mount('/content/drive', force_remount=True)

# Constants
DATA_PATH = "/content/fma/data"
AUDIO_PATH = os.path.join(DATA_PATH, "audio_files")
METADATA_PATH = os.path.join(DATA_PATH, "metadata/tracks.csv")
ARTISTS_PATH = os.path.join(DATA_PATH, "metadata/artists.csv")
GENRES_PATH = os.path.join(DATA_PATH, "metadata/genres.csv")
LYRICS_PATH = os.path.join(DATA_PATH, "lyrics")
USER_DATA_PATH = os.path.join(DATA_PATH, "user_data/ratings.csv")
TAGS_PATH = os.path.join(DATA_PATH, "descriptions/tags.csv")
OUTPUT_DIR = "/content/outputs"
TEMP_DIR = "/content/fma"
NUM_EPOCHS_REC = 5
NUM_EPOCHS_CLS = 10
BATCH_SIZE = 32
MAX_TRACKS = 500
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Ensure output directory exists
os.makedirs(OUTPUT_DIR, exist_ok=True)

def setup_fma_dataset():
    """Task A1, A3, A5: Download and set up FMA dataset with tracks, genres, ratings, and synthetic lyrics."""
    print("=== Dataset Setup ===")
    logging.info("Starting dataset setup...")
    os.makedirs(TEMP_DIR, exist_ok=True)
    os.makedirs(DATA_PATH, exist_ok=True)
    os.makedirs(AUDIO_PATH, exist_ok=True)
    os.makedirs(os.path.dirname(METADATA_PATH), exist_ok=True)
    os.makedirs(LYRICS_PATH, exist_ok=True)
    os.makedirs(os.path.dirname(USER_DATA_PATH), exist_ok=True)
    os.makedirs(os.path.dirname(TAGS_PATH), exist_ok=True)

    try:
        os.chdir(TEMP_DIR)
        logging.info("Changed to TEMP_DIR: %s", TEMP_DIR)
        if not os.path.exists(os.path.join(TEMP_DIR, "fma")):
            logging.info("Cloning FMA repository...")
            print("Cloning FMA repository...")
            subprocess.run(["git", "clone", "https://github.com/mdeff/fma.git"], check=True, capture_output=True, text=True)
        os.chdir(os.path.join(TEMP_DIR, "fma"))
        logging.info("Changed to FMA directory")

        # Define paths for FMA dataset files
        fma_small_zip = "fma_small.zip"
        fma_metadata_zip = "fma_metadata.zip"

        # Check if files already exist to avoid redundant downloads
        if os.path.exists(fma_small_zip):
            logging.info("fma_small.zip already exists, skipping download.")
            print("fma_small.zip already exists, skipping download.")
        else:
            logging.info("Downloading fma_small.zip...")
            print("Downloading fma_small.zip...")
            subprocess.run(["wget", "-O", fma_small_zip, "https://os.unil.cloud.switch.ch/fma/fma_small.zip"], check=True, capture_output=True, text=True)

        if os.path.exists(fma_metadata_zip):
            logging.info("fma_metadata.zip already exists, skipping download.")
            print("fma_metadata.zip already exists, skipping download.")
        else:
            logging.info("Downloading fma_metadata.zip...")
            print("Downloading fma_metadata.zip...")
            subprocess.run(["wget", "-O", fma_metadata_zip, "https://os.unil.cloud.switch.ch/fma/fma_metadata.zip"], check=True, capture_output=True, text=True)

        # Verify checksums
        def sha1_checksum(file_path):
            sha1 = hashlib.sha1()
            with open(file_path, 'rb') as f:
                while chunk := f.read(8192):
                    sha1.update(chunk)
            return sha1.hexdigest()

        logging.info("Verifying checksums...")
        print("Verifying checksums...")
        if sha1_checksum(fma_small_zip) != "ade154f733639d52e35e32f5593efe5be76c6d70":
            logging.warning("fma_small.zip checksum failed. Proceeding anyway.")
        if sha1_checksum(fma_metadata_zip) != "f0df49ffe5f2a6008d7dc83c6915b31835dfe733":
            logging.warning("fma_metadata.zip checksum failed. Proceeding anyway.")

        # Unzip datasets
        if not os.path.exists(os.path.join(DATA_PATH, "fma_small")):
            logging.info("Unzipping fma_small.zip...")
            print("Unzipping fma_small.zip...")
            try:
                subprocess.run(["unzip", "-q", fma_small_zip, "-d", DATA_PATH], check=True, capture_output=True, text=True)
            except subprocess.CalledProcessError as e:
                logging.error("Unzip fma_small.zip failed: %s", e.stderr)
                raise
        if not os.path.exists(os.path.join(DATA_PATH, "fma_metadata")):
            logging.info("Unzipping fma_metadata.zip...")
            print("Unzipping fma_metadata.zip...")
            try:
                subprocess.run(["unzip", "-q", fma_metadata_zip, "-d", DATA_PATH], check=True, capture_output=True, text=True)
            except subprocess.CalledProcessError as e:
                logging.error("Unzip fma_metadata.zip failed: %s", e.stderr)
                raise

        # Move MP3 files to AUDIO_PATH if not already present
        logging.info("Checking and moving MP3 files...")
        print("Checking and moving MP3 files...")
        mp3_count = 0
        existing_mp3s = set(f.name for f in Path(AUDIO_PATH).glob("*.mp3"))
        mp3_files = list(Path(DATA_PATH).rglob("*.mp3"))
        logging.info("MP3 files found: %s", [str(f) for f in mp3_files[:10]])
        for mp3_file in mp3_files:
            target = os.path.join(AUDIO_PATH, mp3_file.name)
            if mp3_file.name not in existing_mp3s:
                os.rename(mp3_file, target)
                mp3_count += 1
        logging.info("Moved %d MP3 files.", mp3_count)
        print(f"Moved {mp3_count} MP3 files.")
        logging.info("Audio files after move: %s", os.listdir(AUDIO_PATH)[:10])

        # Process metadata with diverse genres
        logging.info("Processing metadata...")
        print("Processing metadata...")
        tracks = pd.read_csv(os.path.join(DATA_PATH, "fma_metadata", "tracks.csv"), index_col=0, header=[0, 1])
        genres = pd.read_csv(os.path.join(DATA_PATH, "fma_metadata", "genres.csv"))

        df_artists = tracks['artist'][['name']].reset_index().rename(columns={'track_id': 'artist_id', 'name': 'artist_name'})
        df_artists['artist_id'] = df_artists['artist_id'].astype(str).str.zfill(6)
        df_artists.to_csv(ARTISTS_PATH, index=False)
        logging.info("Created artists.csv")

        df_genres = genres[['genre_id', 'title']].rename(columns={'title': 'genre_name'})
        df_genres.to_csv(GENRES_PATH, index=False)
        logging.info("Created genres.csv")

        # Filter tracks to match available audio files
        audio_ids = {f.stem for f in Path(AUDIO_PATH).glob("*.mp3")}
        df_tracks = tracks['track'][['title', 'genre_top']].reset_index()
        df_tracks['track_id'] = df_tracks['track_id'].astype(str).str.zfill(6)
        df_tracks = df_tracks[df_tracks['track_id'].isin(audio_ids)]
        df_tracks['artist_id'] = df_tracks['track_id']
        df_tracks['genre_id'] = df_tracks['genre_top'].map(df_genres.set_index('genre_name')['genre_id'])
        df_tracks = df_tracks.dropna().groupby('genre_top').head(50).head(MAX_TRACKS)
        df_tracks.to_csv(METADATA_PATH, index=False)
        logging.info("Created tracks.csv")

        # Generate synthetic ratings
        ratings = pd.DataFrame({
            'user_id': [f"user_{i%10+1}" for i in range(len(df_tracks))],
            'track_id': df_tracks['track_id'],
            'rating': np.random.uniform(1, 5, len(df_tracks))
        })
        ratings.to_csv(USER_DATA_PATH, index=False)
        logging.info("Created ratings.csv")

        # Generate synthetic tags
        pd.DataFrame({
            'track_id': df_tracks['track_id'],
            'tag': [random.choice(['rock', 'pop', 'jazz', 'hip-hop', 'folk']) for _ in range(len(df_tracks))]
        }).to_csv(TAGS_PATH, index=False)
        logging.info("Created tags.csv")

        # Generate synthetic lyrics
        for track_id in df_tracks['track_id']:
            with open(os.path.join(LYRICS_PATH, f"{track_id}.txt"), 'w', encoding='utf-8') as f:
                f.write(f"Synthetic lyrics for track {track_id} in {df_tracks[df_tracks['track_id'] == track_id]['genre_top'].iloc[0]}")
        logging.info("Created synthetic lyrics")

        print("Dataset setup completed successfully.")
        print(f"Tracks shape: {df_tracks.shape}, Genres shape: {df_genres.shape}, Ratings shape: {ratings.shape}")

    except Exception as e:
        logging.error("Dataset setup failed: %s. Creating synthetic dataset.", str(e))
        print(f"Dataset setup failed: {str(e)}. Creating synthetic dataset.")
        create_synthetic_dataset()

def create_synthetic_dataset():
    """Task A1, A3, A5: Create synthetic dataset if FMA download fails."""
    logging.info("Creating synthetic dataset...")
    print("Creating synthetic dataset...")
    df_tracks = pd.DataFrame({
        'track_id': [str(i).zfill(6) for i in range(1, MAX_TRACKS + 1)],
        'title': [f"Track_{i}" for i in range(1, MAX_TRACKS + 1)],
        'artist_id': [str(i).zfill(6) for i in range(1, MAX_TRACKS + 1)],
        'genre_id': [random.randint(1, 10) for _ in range(MAX_TRACKS)],
        'genre_top': [random.choice(['Rock', 'Pop', 'Jazz', 'Classical', 'Hip-Hop', 'Electronic', 'Folk', 'Blues', 'Country', 'Reggae']) for _ in range(MAX_TRACKS)]
    })
    df_tracks.to_csv(METADATA_PATH, index=False)

    df_artists = pd.DataFrame({
        'artist_id': df_tracks['artist_id'],
        'artist_name': [f"Artist_{i}" for i in range(1, MAX_TRACKS + 1)]
    })
    df_artists.to_csv(ARTISTS_PATH, index=False)

    df_genres = pd.DataFrame({
        'genre_id': range(1, 11),
        'genre_name': ['Rock', 'Pop', 'Jazz', 'Classical', 'Hip-Hop', 'Electronic', 'Folk', 'Blues', 'Country', 'Reggae']
    })
    df_genres.to_csv(GENRES_PATH, index=False)

    ratings = pd.DataFrame({
        'user_id': [f"user_{i%10+1}" for i in range(len(df_tracks))],
        'track_id': df_tracks['track_id'],
        'rating': np.random.uniform(1, 5, len(df_tracks))
    })
    ratings.to_csv(USER_DATA_PATH, index=False)

    pd.DataFrame({
        'track_id': df_tracks['track_id'],
        'tag': [random.choice(['rock', 'pop', 'jazz', 'hip-hop', 'folk']) for _ in range(len(df_tracks))]
    }).to_csv(TAGS_PATH, index=False)

    os.makedirs(AUDIO_PATH, exist_ok=True)
    for track_id in df_tracks['track_id']:
        with open(os.path.join(AUDIO_PATH, f"{track_id}.mp3"), 'w') as f:
            f.write("")

    for track_id in df_tracks['track_id']:
        with open(os.path.join(LYRICS_PATH, f"{track_id}.txt"), 'w', encoding='utf-8') as f:
            f.write(f"Synthetic lyrics for track {track_id} in {df_tracks[df_tracks['track_id'] == track_id]['genre_top'].iloc[0]}")

    print(f"Synthetic dataset created: Tracks shape: {df_tracks.shape}")
    logging.info("Synthetic dataset created: Tracks shape: %s", df_tracks.shape)

def extract_audio_features(audio_path):
    """Task A1.1, A2: Extract audio features (MFCCs, chroma, spectral centroid, tempo) using Librosa."""
    try:
        if not os.path.exists(audio_path):
            logging.warning("Audio file %s does not exist.", audio_path)
            return np.zeros(26)
        if os.path.getsize(audio_path) < 100:
            logging.warning("Skipping %s: File too small or empty.", audio_path)
            return np.zeros(26)
        try:
            y, sr = librosa.load(audio_path, sr=22050)
        except Exception as load_e:
            logging.warning("Librosa load failed for %s: %s", audio_path, str(load_e))
            return np.zeros(26)
        if len(y) == 0:
            logging.warning("Skipping %s: Empty audio data.", audio_path)
            return np.zeros(26)
        mfccs = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=13)
        chroma = librosa.feature.chroma_stft(y=y, sr=sr)
        spectral_centroid = librosa.feature.spectral_centroid(y=y, sr=sr)
        tempo = librosa.feature.tempo(y=y, sr=sr)
        tempo = float(tempo[0]) if isinstance(tempo, np.ndarray) else float(tempo)
        features = np.concatenate([
            np.mean(mfccs, axis=1),
            np.mean(chroma, axis=1),
            np.mean(spectral_centroid, axis=1),
            [tempo]
        ])
        logging.info("Extracted features for %s: %s", audio_path, features[:5])
        return features
    except Exception as e:
        logging.warning("Error processing %s: %s", audio_path, str(e))
        return np.zeros(26)

def load_fma_data(audio_path, metadata_path, artists_path, genres_path, lyrics_path, tags_path):
    """Task A1, A3, A5: Load FMA dataset, process audio, metadata, lyrics, and tags."""
    logging.info("Loading FMA data...")
    print("Loading FMA data...")
    if not os.path.exists(metadata_path):
        logging.error("Metadata not found. Creating synthetic dataset.")
        create_synthetic_dataset()

    try:
        df_tracks = pd.read_csv(metadata_path)
        df_metadata = df_tracks[['track_id', 'title', 'artist_id', 'genre_id']].dropna()
        df_artists = pd.read_csv(artists_path)[['artist_id', 'artist_name']]
        df_genres = pd.read_csv(genres_path)[['genre_id', 'genre_name']]
        df_metadata = pd.merge(df_metadata, df_artists, on='artist_id', how='left')
        df_metadata = pd.merge(df_metadata, df_genres, on='genre_id', how='left')
        df_metadata = df_metadata[['track_id', 'artist_name', 'title', 'genre_name']].dropna()
        df_metadata.columns = ['track_id', 'artist_name', 'title', 'genre']
        df_metadata['track_id'] = df_metadata['track_id'].astype(str).str.zfill(6)
        logging.info("Metadata shape: %s", df_metadata.shape)
        print(f"Metadata shape: {df_metadata.shape}")
        print(f"Genre distribution:\n{df_metadata['genre'].value_counts()}")
    except Exception as e:
        logging.error("Error loading metadata: %s. Using synthetic metadata.", str(e))
        print(f"Error loading metadata: {str(e)}. Using synthetic metadata.")
        df_metadata = pd.DataFrame({
            'track_id': [str(i).zfill(6) for i in range(1, MAX_TRACKS + 1)],
            'artist_name': [f"Artist_{i}" for i in range(1, MAX_TRACKS + 1)],
            'title': [f"Track_{i}" for i in range(1, MAX_TRACKS + 1)],
            'genre': [random.choice(['Rock', 'Pop', 'Jazz', 'Classical', 'Hip-Hop', 'Electronic', 'Folk', 'Blues', 'Country', 'Reggae']) for _ in range(MAX_TRACKS)]
        })

    features = []
    audio_files = list(Path(audio_path).glob("*.mp3"))
    valid_track_ids = set(df_metadata['track_id'].astype(str).str.zfill(6).tolist())
    logging.info("Found %d audio files.", len(audio_files))
    print(f"Found {len(audio_files)} audio files.")
    print(f"Valid track IDs (first 10): {list(valid_track_ids)[:10]}")
    print(f"Audio file IDs (first 10): {[f.stem for f in audio_files[:10]]}")
    print(f"Matching IDs: {len(set([f.stem for f in audio_files]) & valid_track_ids)}")
    valid_audio_count = 0
    valid_audio_files = []
    for audio_file in audio_files:
        track_id = audio_file.stem
        if track_id in valid_track_ids:
            logging.info("Processing audio: %s", audio_file)
            audio_features = extract_audio_features(audio_file)
            if not np.all(audio_features == 0):
                valid_audio_count += 1
                valid_audio_files.append(str(audio_file))
            features.append([track_id] + audio_features.tolist())
    logging.info("Processed %d valid audio files.", valid_audio_count)
    print(f"Processed {valid_audio_count} valid audio files.")
    if valid_audio_files:
        print(f"Valid audio files (first 5): {valid_audio_files[:5]}")
        logging.info("Valid audio files (first 5): %s", valid_audio_files[:5])

    feature_columns = ['track_id'] + [f'mfcc_{i+1}' for i in range(13)] + [f'chroma_{i+1}' for i in range(12)] + ['spectral_centroid', 'tempo']
    df_features = pd.DataFrame(features, columns=feature_columns).dropna()
    if df_features.empty or valid_audio_count == 0:
        logging.warning("No valid audio features. Using synthetic features.")
        print("No valid audio features. Using synthetic features.")
        df_features = pd.DataFrame({
            'track_id': df_metadata['track_id'],
            **{col: [0.0] * len(df_metadata) for col in feature_columns[1:]}
        })
    else:
        df_features = df_features[df_features['track_id'].isin(valid_track_ids)]
        logging.info("Features shape: %s", df_features.shape)
        print(f"Features shape: {df_features.shape}")

    lyrics_dict = {}
    for lyric_file in Path(lyrics_path).glob("*.txt"):
        track_id = lyric_file.stem
        if track_id in df_metadata['track_id'].values:
            try:
                with open(lyric_file, 'r', encoding='utf-8') as f:
                    lyrics_dict[track_id] = f.read().strip() or 'music'
            except Exception as e:
                logging.warning("Error reading lyrics %s: %s", track_id, str(e))
                lyrics_dict[track_id] = 'music'

    if tags_path and os.path.exists(tags_path):
        try:
            df_tags = pd.read_csv(tags_path)
            for _, row in df_tags.iterrows():
                track_id = str(row['track_id']).zfill(6)
                if track_id in df_metadata['track_id'].values:
                    tag = str(row['tag'])
                    lyrics_dict[track_id] = lyrics_dict.get(track_id, '') + " " + tag
        except Exception as e:
            logging.error("Error loading tags: %s", str(e))

    if not lyrics_dict:
        logging.warning("No lyrics found. Using synthetic lyrics.")
        print("No lyrics found. Using synthetic lyrics.")
        for track_id in df_metadata['track_id']:
            lyrics_dict[track_id] = f"Synthetic lyrics for {track_id} in {df_metadata[df_metadata['track_id'] == track_id]['genre'].iloc[0]}"

    logging.info("Lyrics dict size: %d", len(lyrics_dict))
    print(f"Lyrics dict size: {len(lyrics_dict)}")
    return df_metadata, df_features, lyrics_dict

def generate_text_embeddings(lyrics_dict):
    """Task A1.2, A3: Generate semantic embeddings for lyrics using Sentence-Transformers."""
    logging.info("Generating text embeddings...")
    print("Generating text embeddings...")
    model = SentenceTransformer('all-MiniLM-L6-v2', device=DEVICE)
    embeddings = {}
    for track_id, text in lyrics_dict.items():
        embeddings[track_id] = model.encode(text, convert_to_tensor=True, device=DEVICE).cpu().numpy()
    logging.info("Text embeddings generated for %d tracks.", len(embeddings))
    print(f"Text embeddings generated for {len(embeddings)} tracks.")
    return embeddings

def analyze_linguistic_patterns(df_metadata, lyrics_dict):
    """Task A1.2: Analyze linguistic patterns across genres (top words per genre)."""
    print("\nLinguistic Analysis: Top 5 words per genre")
    genre_words = {g: [] for g in df_metadata['genre'].unique()}
    for track_id, text in lyrics_dict.items():
        genre = df_metadata[df_metadata['track_id'] == track_id]['genre'].iloc[0]
        genre_words[genre].extend(text.split())
    for genre, words in genre_words.items():
        top_words = Counter(words).most_common(5)
        print(f"{genre} top words: {top_words}")
        logging.info("%s top words: %s", genre, top_words)

class HybridRecommender(nn.Module):
    """Task A5: Content-based recommender combining audio and text features."""
    def __init__(self, audio_dim, text_dim, hidden_dim=128):
        super().__init__()
        self.audio_layer = nn.Linear(audio_dim, hidden_dim)
        self.text_layer = nn.Linear(text_dim, hidden_dim)
        self.fc = nn.Linear(hidden_dim * 2, 1)

    def forward(self, audio_features, text_features):
        audio_out = torch.relu(self.audio_layer(audio_features))
        text_out = torch.relu(self.text_layer(text_features))
        combined = torch.cat([audio_out, text_out], dim=1)
        return torch.sigmoid(self.fc(combined))

def train_recommender(model, train_loader, criterion, optimizer):
    """Task A5: Train the content-based recommender."""
    model.train()
    total_loss = 0
    for audio_features, text_features, ratings in train_loader:
        audio_features, text_features, ratings = audio_features.to(DEVICE), text_features.to(DEVICE), ratings.to(DEVICE)
        optimizer.zero_grad()
        outputs = model(audio_features, text_features)
        loss = criterion(outputs, ratings.unsqueeze(1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(train_loader)

def evaluate_recommender(model, test_loader, df_metadata, test_indices, genres, k=10):
    """Task C5: Evaluate recommender with Precision@K, MAP, NDCG, and diversity."""
    model.eval()
    precisions, maps, ndcgs, diversities = [], [], [], []
    with torch.no_grad():
        for audio_features, text_features, ratings in test_loader:
            audio_features, text_features, ratings = audio_features.to(DEVICE), text_features.to(DEVICE), ratings.to(DEVICE)
            outputs = model(audio_features, text_features)
            binary_ratings = (ratings > 0.5).float()
            top_k = torch.topk(outputs, k=min(k, outputs.size(0)), dim=0).indices.flatten()
            relevant = binary_ratings[top_k]
            precision = relevant.mean().item()
            precisions.append(precision)
            if binary_ratings.sum() > 0:
                map_score = average_precision_score(binary_ratings.cpu().numpy(), outputs.cpu().numpy())
                ndcg = ndcg_score(binary_ratings.cpu().numpy().reshape(1, -1), outputs.cpu().numpy().reshape(1, -1), k=k)
            else:
                map_score, ndcg = 0.0, 0.0
            maps.append(map_score)
            ndcgs.append(ndcg)
            top_k_ids = df_metadata.iloc[test_indices].iloc[top_k.cpu()]['track_id']
            top_k_genres = df_metadata[df_metadata['track_id'].isin(top_k_ids)]['genre']
            diversity = len(set(top_k_genres)) / len(genres) if len(genres) > 0 else 1.0
            diversities.append(diversity)
    return {
        'precision@k': np.mean(precisions) if precisions else 0.0,
        'map': np.mean(maps) if maps else 0.0,
        'ndcg@k': np.mean(ndcgs) if ndcgs else 0.0,
        'diversity': np.mean(diversities) if diversities else 1.0
    }

def collaborative_filtering(user_data, df_metadata, k=10):
    """Task A5: Collaborative filtering using NMF."""
    print("\nTraining Collaborative Filtering Model...")
    # Filter user_data to only include tracks present in df_metadata
    valid_track_ids = set(df_metadata['track_id'].astype(str).str.zfill(6))
    user_data = user_data[user_data['track_id'].isin(valid_track_ids)]

    if user_data.empty:
        logging.error("No valid user data after filtering. Returning zero metrics.")
        print("No valid user data after filtering. Returning zero metrics.")
        return {'precision@k': 0.0}

    # Split user data into train/test
    train_data, test_data = train_test_split(user_data, test_size=0.2, random_state=42)
    logging.info("CF data split: train=%d, test=%d", len(train_data), len(test_data))
    print(f"CF data split: train={len(train_data)}, test={len(test_data)}")

    # Create user-item matrix for training
    user_item_matrix = train_data.pivot(index='user_id', columns='track_id', values='rating').fillna(0)
    if user_item_matrix.empty:
        logging.error("Empty user-item matrix. Returning zero metrics.")
        print("Empty user-item matrix. Returning zero metrics.")
        return {'precision@k': 0.0}

    # Apply NMF
    nmf = NMF(n_components=20, random_state=42)
    user_features = nmf.fit_transform(user_item_matrix)
    item_features = nmf.components_
    predictions = np.dot(user_features, item_features)
    predicted_ratings = pd.DataFrame(predictions, index=user_item_matrix.index, columns=user_item_matrix.columns)

    # Create test user-item matrix
    test_matrix = test_data.pivot(index='user_id', columns='track_id', values='rating').fillna(0)
    logging.info("Test matrix tracks: %d, Predicted tracks: %d", len(test_matrix.columns), len(predicted_ratings.columns))
    print(f"Test matrix tracks: {len(test_matrix.columns)}, Predicted tracks: {len(predicted_ratings.columns)}")

    # Ensure shared tracks
    shared_tracks = set(test_matrix.columns).intersection(set(predicted_ratings.columns))
    logging.info("Shared tracks between test and predicted: %d", len(shared_tracks))
    print(f"Shared tracks between test and predicted: {len(shared_tracks)}")

    if not shared_tracks:
        logging.warning("No shared tracks between test and predicted. Returning zero metrics.")
        print("No shared tracks between test and predicted. Returning zero metrics.")
        return {'precision@k': 0.0}

    # Evaluate on test data
    precisions = []
    for user_id in test_matrix.index:
        if user_id in predicted_ratings.index:
            true_ratings = test_matrix.loc[user_id]
            pred_ratings = predicted_ratings.loc[user_id]
            valid_top_k = pred_ratings[pred_ratings.index.isin(shared_tracks)].sort_values(ascending=False).index[:k]
            if len(valid_top_k) == 0:
                logging.warning("No valid tracks for user %s in test set.", user_id)
                precisions.append(0.0)
                continue
            relevant = (true_ratings[valid_top_k] > 3.5).astype(int)
            precision = relevant.mean()
            precisions.append(precision)
            logging.info("User %s: %d valid tracks, precision=%f", user_id, len(valid_top_k), precision)

    metrics = {'precision@k': np.mean(precisions) if precisions else 0.0}
    logging.info("CF Metrics: %s", metrics)
    return metrics

class GenreClassifier(nn.Module):
    """Task A1.1: Custom genre classifier combining audio and text features."""
    def __init__(self, audio_dim, text_dim, num_classes, hidden_dim=128):
        super().__init__()
        self.audio_layer = nn.Linear(audio_dim, hidden_dim)
        self.text_layer = nn.Linear(text_dim, hidden_dim)
        self.dropout = nn.Dropout(0.3)
        self.fc = nn.Linear(hidden_dim * 2, num_classes)

    def forward(self, audio_features, text_features):
        audio_out = torch.relu(self.audio_layer(audio_features))
        text_out = torch.relu(self.text_layer(text_features))
        combined = torch.cat([audio_out, text_out], dim=1)
        combined = self.dropout(combined)
        return self.fc(combined)

def train_classifier(model, train_loader, criterion, optimizer):
    """Task A1.1: Train the genre classifier."""
    model.train()
    total_loss = 0
    for audio_features, text_features, labels in train_loader:
        audio_features, text_features, labels = audio_features.to(DEVICE), text_features.to(DEVICE), labels.to(DEVICE)
        optimizer.zero_grad()
        outputs = model(audio_features, text_features)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(train_loader)

def evaluate_classifier(model, test_loader, genres):
    """Task C5: Evaluate classifier with precision, recall, F1, and confusion matrix."""
    model.eval()
    y_true, y_pred = [], []
    with torch.no_grad():
        for audio_features, text_features, labels in test_loader:
            audio_features, text_features, labels = audio_features.to(DEVICE), text_features.to(DEVICE), labels.to(DEVICE)
            outputs = model(audio_features, text_features)
            _, predicted = torch.max(outputs, 1)
            y_true.extend(labels.cpu().numpy())
            y_pred.extend(predicted.cpu().numpy())
    metrics = precision_recall_fscore_support(y_true, y_pred, average='weighted', zero_division=0)
    cm = confusion_matrix(y_true, y_pred)
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=genres, yticklabels=genres)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Confusion Matrix')
    plt.savefig(os.path.join(OUTPUT_DIR, 'confusion_matrix.png'))
    plt.close()
    errors = [(i, y_true[i], y_pred[i]) for i in range(len(y_true)) if y_true[i] != y_pred[i]]
    print(f"Error Analysis: {len(errors)} misclassifications")
    for idx, true, pred in errors[:5]:
        print(f"Sample {idx}: True={genres[true]}, Predicted={genres[pred]}")
    logging.info("Test set class distribution: %s", Counter(y_true))
    print(f"Test set class distribution: {Counter(y_true)}")
    return {'precision': metrics[0], 'recall': metrics[1], 'f1': metrics[2]}

class AudioDataset(Dataset):
    """Custom Dataset for audio files and genre labels."""
    def __init__(self, audio_files, labels, genre_to_label, feature_extractor, max_length=160000):
        self.audio_files = audio_files
        self.labels = labels
        self.genre_to_label = genre_to_label
        self.feature_extractor = feature_extractor
        self.max_length = max_length

    def __len__(self):
        return len(self.audio_files)

    def __getitem__(self, idx):
        try:
            audio_file = self.audio_files[idx]
            y, sr = librosa.load(audio_file, sr=16000, duration=10.0)  # Limit to 10 seconds
            inputs = self.feature_extractor(
                y,
                sampling_rate=16000,
                return_tensors="pt",
                truncation=True,
                padding="max_length",
                max_length=self.max_length,
                return_attention_mask=True
            )
            # Remove batch dimension
            input_values = inputs.input_values.squeeze(0)
            attention_mask = inputs.attention_mask.squeeze(0) if inputs.attention_mask is not None else None
            label = self.genre_to_label[self.labels[idx]]
            return {
                "input_values": input_values,
                "attention_mask": attention_mask,
                "labels": torch.tensor(label, dtype=torch.long)
            }
        except Exception as e:
            logging.error("Failed to process audio file %s: %s", audio_file, str(e))
            # Return a dummy item to avoid breaking the DataLoader
            return {
                "input_values": torch.zeros(self.max_length),
                "attention_mask": torch.zeros(self.max_length),
                "labels": torch.tensor(0, dtype=torch.long)
            }

def ast_classifier(audio_files, labels, genres, df_metadata):
    """Task A2: Audio classification with fine-tuned Wav2Vec2 model."""
    print("\nSetting up Audio Classifier (Wav2Vec2)...")

    # Initialize feature extractor
    try:
        feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")
    except Exception as e:
        logging.error("Failed to load feature extractor: %s", str(e))
        print(f"Failed to load feature extractor: {str(e)}")
        return {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}

    # Dynamic genre mapping
    genre_to_label = {genre: idx for idx, genre in enumerate(sorted(genres))}
    label_to_genre = {idx: genre for genre, idx in genre_to_label.items()}

    # Filter valid audio files and labels
    valid_audio_files = []
    valid_labels = []
    for audio_file, label in zip(audio_files, labels):
        track_id = Path(audio_file).stem
        if track_id in df_metadata['track_id'].values and os.path.exists(audio_file) and os.path.getsize(audio_file) > 0:
            genre = df_metadata[df_metadata['track_id'] == track_id]['genre'].iloc[0]
            if genre in genre_to_label:
                valid_audio_files.append(audio_file)
                valid_labels.append(genre)
            else:
                logging.warning("Genre %s not in genre_to_label for track %s", genre, track_id)
        else:
            logging.warning("Invalid or missing audio file: %s", audio_file)

    if not valid_audio_files:
        print("No valid audio files for processing.")
        logging.error("No valid audio files for processing.")
        return {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}

    # Split into train and eval sets
    train_files, eval_files, train_labels, eval_labels = train_test_split(
        valid_audio_files, valid_labels, test_size=0.2, random_state=42, stratify=valid_labels
    )

    # Create datasets
    train_dataset = AudioDataset(train_files, train_labels, genre_to_label, feature_extractor)
    eval_dataset = AudioDataset(eval_files, eval_labels, genre_to_label, feature_extractor)

    # Initialize model with correct number of labels
    try:
        model = AutoModelForAudioClassification.from_pretrained(
            "facebook/wav2vec2-base-960h",
            num_labels=len(genre_to_label),
            label2id=genre_to_label,
            id2label=label_to_genre
        ).to(DEVICE)
    except Exception as e:
        logging.error("Failed to load Wav2Vec2 model: %s", str(e))
        print(f"Failed to load Wav2Vec2 model: {str(e)}")
        return {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}

    # Define training arguments
    training_args = TrainingArguments(
        output_dir=os.path.join(OUTPUT_DIR, "wav2vec2-finetuned"),
        per_device_train_batch_size=4,
        per_device_eval_batch_size=4,
        num_train_epochs=3,
        eval_strategy="steps",  # Updated from evaluation_strategy
        save_strategy="steps",
        save_steps=500,
        logging_steps=100,
        learning_rate=3e-5,
        save_total_limit=2,
        load_best_model_at_end=True,
        metric_for_best_model="f1",
        greater_is_better=True,
        logging_dir=os.path.join(OUTPUT_DIR, "logs"),
    )

    # Define compute_metrics function for evaluation
    def compute_metrics(pred):
        labels = pred.label_ids
        preds = pred.predictions.argmax(-1)
        precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='weighted', zero_division=0)
        return {"precision": precision, "recall": recall, "f1": f1}

    # Initialize Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        compute_metrics=compute_metrics,
    )

    # Fine-tune the model
    print("Fine-tuning Wav2Vec2 model...")
    try:
        trainer.train()
        trainer.save_model(os.path.join(OUTPUT_DIR, "wav2vec2-finetuned"))
        print("Model fine-tuning completed and saved.")
    except Exception as e:
        logging.error("Fine-tuning failed: %s", str(e))
        print(f"Fine-tuning failed: {str(e)}")
        return {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}

    # Evaluate on eval set
    model.eval()
    y_true, y_pred = [], []
    for audio_file, genre in zip(eval_files, eval_labels):
        try:
            y, sr = librosa.load(audio_file, sr=16000, duration=10.0)
            inputs = feature_extractor(
                y,
                sampling_rate=16000,
                return_tensors="pt",
                truncation=True,
                padding="max_length",
                max_length=160000,
                return_attention_mask=True
            ).to(DEVICE)
            torch.cuda.empty_cache() if torch.cuda.is_available() else None
            with torch.no_grad():
                outputs = model(**inputs).logits
            predicted = torch.argmax(outputs, dim=1).cpu().numpy()[0]
            y_true.append(genre_to_label[genre])
            y_pred.append(min(predicted, len(genre_to_label) - 1))
        except Exception as e:
            logging.error("Evaluation failed for %s: %s", audio_file, str(e))
            continue

    # Compute and save metrics
    if y_true:
        metrics = precision_recall_fscore_support(y_true, y_pred, average='weighted', zero_division=0)
        print(f"Wav2Vec2 Metrics: Precision: {metrics[0]:.4f}, Recall: {metrics[1]:.4f}, F1: {metrics[2]:.4f}")
        cm = confusion_matrix(y_true, y_pred)
        plt.figure(figsize=(10, 8))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Greens', xticklabels=list(genre_to_label.keys()), yticklabels=list(genre_to_label.keys()))
        plt.xlabel('Predicted')
        plt.ylabel('True')
        plt.title('Wav2Vec2 Confusion Matrix')
        plt.savefig(os.path.join(OUTPUT_DIR, 'wav2vec2_confusion_matrix.png'))
        plt.close()
        return {'precision': metrics[0], 'recall': metrics[1], 'f1': metrics[2]}
    else:
        print("No valid audio files processed for evaluation.")
        logging.error("No valid audio files processed for evaluation.")
        return {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}

def music_search(query, df, text_embeddings, model):
    """Task A3: Content-based music discovery using LLM embeddings and cosine similarity."""
    logging.info("Performing music search for query: %s", query)
    print(f"Performing music search for query: {query}")
    query_embedding = model.encode(query, convert_to_tensor=True, device=DEVICE).cpu().numpy()
    similarities = {}
    for track_id, embedding in text_embeddings.items():
        genre_weight = 1.5 if df[df['track_id'] == track_id]['genre'].iloc[0] == 'Rock' else 1.0
        similarities[track_id] = (1 - cosine(query_embedding, embedding)) * genre_weight
    top_tracks = sorted(similarities.items(), key=lambda x: x[1], reverse=True)[:10]
    result_ids = [track_id for track_id, _ in top_tracks]
    results = df[df['track_id'].isin(result_ids)][['track_id', 'artist_name', 'title', 'genre']]
    logging.info("Search returned %d results.", len(results))
    return results

def zero_shot_classification(df_metadata, lyrics_dict):
    """Task A1.2: Zero-shot genre classification using BART."""
    print("\nZero-Shot Genre Classification:")
    classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli", device=0 if torch.cuda.is_available() else -1)
    genres = df_metadata['genre'].unique().tolist()
    results = []
    for _, row in df_metadata.head(5).iterrows():
        text = row['title'] + " " + lyrics_dict.get(row['track_id'], "")
        scores = classifier(text, candidate_labels=genres, multi_label=False)
        results.append((row['track_id'], scores['labels'][0], scores['scores'][0]))
        print(f"Track {row['track_id']}: Predicted={scores['labels'][0]} ({scores['scores'][0]:.4f}), True={row['genre']}")
    return results

def analyze_feature_contribution(model, test_loader, feature_columns):
    """Task A1.1: Analyze contribution of audio features to genre classification."""
    print("\nFeature Contribution Analysis:")
    model.eval()
    contributions = {col: [] for col in feature_columns}
    with torch.no_grad():
        for audio_features, text_features, _ in test_loader:
            audio_features, text_features = audio_features.to(DEVICE), text_features.to(DEVICE)
            baseline_output = model(audio_features, text_features)
            for i, col in enumerate(feature_columns):
                modified_features = audio_features.clone()
                modified_features[:, i] = 0
                modified_output = model(modified_features, text_features)
                diff = torch.mean(torch.abs(baseline_output - modified_output)).item()
                contributions[col].append(diff)
    for col in feature_columns:
        print(f"{col}: Mean contribution = {np.mean(contributions[col]):.4f}")
    return contributions

def main():
    """Main function orchestrating all tasks."""
    logging.info("Starting main execution...")
    print("Starting main execution...")

    # Task A1, A3, A5: Setup dataset
    setup_fma_dataset()
    df_metadata, df_features, lyrics_dict = load_fma_data(AUDIO_PATH, METADATA_PATH, ARTISTS_PATH, GENRES_PATH, LYRICS_PATH, TAGS_PATH)

    if df_metadata.empty or df_features.empty or not lyrics_dict:
        logging.error("Data loading failed. Exiting.")
        print("Data loading failed. Exiting.")
        return

    # Task A1.2: Generate text embeddings
    text_embeddings = generate_text_embeddings(lyrics_dict)

    # Task A1.2: Linguistic analysis
    analyze_linguistic_patterns(df_metadata, lyrics_dict)

    # Task A1.2: Zero-shot classification
    zero_shot_classification(df_metadata, lyrics_dict)

    # Prepare features, aligning with valid audio tracks
    valid_track_ids = df_features['track_id'].tolist()
    if len(valid_track_ids) < 50:
        logging.warning("Only %d valid audio tracks found. Consider increasing MAX_TRACKS or checking audio files.", len(valid_track_ids))
        print(f"Warning: Only {len(valid_track_ids)} valid audio tracks found. Using available tracks.")

    # Filter metadata and lyrics to valid tracks
    df_metadata = df_metadata[df_metadata['track_id'].isin(valid_track_ids)]
    text_features = np.array([text_embeddings.get(tid, np.zeros(384)) for tid in df_metadata['track_id']])

    # Normalize audio features
    scaler = StandardScaler()
    audio_features = scaler.fit_transform(df_features[[col for col in df_features.columns if col != 'track_id']].values)
    audio_features = np.nan_to_num(audio_features)
    logging.info("Audio feature stats: mean=%s, std=%s", audio_features.mean(), audio_features.std())
    print(f"Audio feature stats: mean={audio_features.mean():.4f}, std={audio_features.std():.4f}")

    if audio_features.shape[0] == 0 or text_features.shape[0] == 0:
        logging.error("No valid features after filtering. Exiting.")
        print("No valid features after filtering. Exiting.")
        return

    # Task A5: Load user ratings
    try:
        user_data = pd.read_csv(USER_DATA_PATH)
        user_data['track_id'] = user_data['track_id'].astype(str).str.zfill(6)
        user_data = user_data[user_data['track_id'].isin(df_metadata['track_id'])]
        ratings = np.clip(user_data['rating'].values / 5.0, 0.0, 1.0)
        logging.info("Ratings stats: min=%s, max=%s, mean=%s", ratings.min(), ratings.max(), ratings.mean())
        print(f"Ratings stats: min={ratings.min():.4f}, max={ratings.max():.4f}, mean={ratings.mean():.4f}")
    except Exception as e:
        logging.error("Error loading ratings: %s. Using synthetic ratings.", str(e))
        print(f"Error loading ratings: {str(e)}. Using synthetic ratings.")
        ratings = np.random.uniform(0.2, 1.0, len(df_metadata))
        ratings = np.clip(ratings, 0.0, 1.0)
        logging.info("Synthetic ratings stats: min=%s, max=%s, mean=%s", ratings.min(), ratings.max(), ratings.mean())
        print(f"Synthetic ratings stats: min={ratings.min():.4f}, max={ratings.max():.4f}, mean={ratings.mean():.4f}")

    if len(ratings) != len(df_metadata):
        logging.warning("Ratings length mismatch. Truncating to match metadata.")
        print("Ratings length mismatch. Truncating to match metadata.")
        ratings = ratings[:len(df_metadata)]

    # Task A5: Train-test split for recommender
    indices = np.arange(len(df_metadata))
    X_train_rec, X_test_rec, y_train_rec, y_test_rec, train_idx, test_idx = train_test_split(
        np.hstack([audio_features, text_features]), ratings, indices, test_size=0.2, random_state=42
    )
    logging.info("Train-test split for recommender: train=%d, test=%d", len(X_train_rec), len(X_test_rec))
    print(f"Train-test split for recommender: train={len(X_train_rec)}, test={len(X_test_rec)}")

    train_dataset_rec = torch.utils.data.TensorDataset(
        torch.tensor(X_train_rec[:, :audio_features.shape[1]], dtype=torch.float32),
        torch.tensor(X_train_rec[:, audio_features.shape[1]:], dtype=torch.float32),
        torch.tensor(y_train_rec, dtype=torch.float32)
    )
    train_loader_rec = torch.utils.data.DataLoader(train_dataset_rec, batch_size=BATCH_SIZE, shuffle=True)

    test_dataset_rec = torch.utils.data.TensorDataset(
        torch.tensor(X_test_rec[:, :audio_features.shape[1]], dtype=torch.float32),
        torch.tensor(X_test_rec[:, audio_features.shape[1]:], dtype=torch.float32),
        torch.tensor(y_test_rec, dtype=torch.float32)
    )
    test_loader_rec = torch.utils.data.DataLoader(test_dataset_rec, batch_size=BATCH_SIZE)

    # Task A5: Train content-based recommender
    recommender = HybridRecommender(audio_dim=audio_features.shape[1], text_dim=384).to(DEVICE)
    criterion_rec = nn.BCELoss()
    optimizer_rec = torch.optim.Adam(recommender.parameters(), lr=0.001)

    print("\nTraining Recommender Model...")
    for epoch in range(NUM_EPOCHS_REC):
        loss = train_recommender(recommender, train_loader_rec, criterion_rec, optimizer_rec)
        logging.info("Recommendation Epoch %d, Loss: %.4f", epoch+1, loss)
        print(f"Recommendation Epoch {epoch+1}, Loss: {loss:.4f}")

    # Task C5: Evaluate recommender
    rec_metrics = evaluate_recommender(recommender, test_loader_rec, df_metadata, test_idx, df_metadata['genre'].unique())
    logging.info("Recommendation Metrics: %s", rec_metrics)
    print(f"Recommendation Metrics: {rec_metrics}")

    # Task A5: Collaborative filtering
    cf_metrics = collaborative_filtering(user_data, df_metadata)
    print(f"Collaborative Filtering Metrics: {cf_metrics}")
    print(f"Comparison: Content-Based Precision@K={rec_metrics['precision@k']:.4f}, CF Precision@K={cf_metrics['precision@k']:.4f}")

    # Task A1.1: Prepare classifier data
    genres = df_metadata['genre'].unique()
    genre_to_idx = {g: i for i, g in enumerate(genres)}
    labels = df_metadata['genre'].map(genre_to_idx).values

    X_train_cls, X_test_cls, y_train_cls, y_test_cls = train_test_split(
        np.hstack([audio_features, text_features]), labels, test_size=0.2, random_state=42
    )
    logging.info("Train-test split for classifier: train=%d, test=%d", len(X_train_cls), len(X_test_cls))
    print(f"Train-test split for classifier: train={len(X_train_cls)}, test={len(X_test_cls)}")

    train_dataset_cls = torch.utils.data.TensorDataset(
        torch.tensor(X_train_cls[:, :audio_features.shape[1]], dtype=torch.float32),
        torch.tensor(X_train_cls[:, audio_features.shape[1]:], dtype=torch.float32),
        torch.tensor(y_train_cls, dtype=torch.long)
    )
    train_loader_cls = torch.utils.data.DataLoader(train_dataset_cls, batch_size=BATCH_SIZE, shuffle=True)

    test_dataset_cls = torch.utils.data.TensorDataset(
        torch.tensor(X_test_cls[:, :audio_features.shape[1]], dtype=torch.float32),
        torch.tensor(X_test_cls[:, audio_features.shape[1]:], dtype=torch.float32),
        torch.tensor(y_test_cls, dtype=torch.long)
    )
    test_loader_cls = torch.utils.data.DataLoader(test_dataset_cls, batch_size=BATCH_SIZE)

    # Task A1.1: Train custom classifier
    classifier = GenreClassifier(audio_dim=audio_features.shape[1], text_dim=384, num_classes=len(genres)).to(DEVICE)
    criterion_cls = nn.CrossEntropyLoss()
    optimizer_cls = torch.optim.Adam(classifier.parameters(), lr=0.001, weight_decay=1e-4)

    print("\nTraining Genre Classifier Model...")
    for epoch in range(NUM_EPOCHS_CLS):
        loss = train_classifier(classifier, train_loader_cls, criterion_cls, optimizer_cls)
        logging.info("Classification Epoch %d, Loss: %.4f", epoch+1, loss)
        print(f"Classification Epoch {epoch+1}, Loss: {loss:.4f}")

    # Task C5: Evaluate classifier
    cls_metrics = evaluate_classifier(classifier, test_loader_cls, genres)
    logging.info("Classification Metrics: %s", cls_metrics)
    print(f"Classification Metrics: Precision: {cls_metrics['precision']:.4f}, Recall: {cls_metrics['recall']:.4f}, F1: {cls_metrics['f1']:.4f}")

    # Task A1.1: Analyze feature contributions
    analyze_feature_contribution(classifier, test_loader_cls, df_features.columns[1:])

    # Task A2: Audio classification
    audio_files = [os.path.join(AUDIO_PATH, f"{tid}.mp3") for tid in df_metadata['track_id'] if os.path.exists(os.path.join(AUDIO_PATH, f"{tid}.mp3"))]
    ast_metrics = ast_classifier(audio_files, labels[:len(audio_files)], genres, df_metadata)
    print(f"Wav2Vec2 vs Custom Classifier: Wav2Vec2 F1={ast_metrics['f1']:.4f}, Custom F1={cls_metrics['f1']:.4f}")

    # Task A3: Perform music search
    search_model = SentenceTransformer('all-MiniLM-L6-v2', device=DEVICE)
    query = "upbeat rock songs"
    results = music_search(query, df_metadata, text_embeddings, search_model)
    print("\nSearch Results:")
    print(results)

if __name__ == "__main__":
    main()

Mounted at /content/drive
Starting main execution...
=== Dataset Setup ===
fma_small.zip already exists, skipping download.
fma_metadata.zip already exists, skipping download.
Verifying checksums...
Checking and moving MP3 files...
Moved 0 MP3 files.
Processing metadata...
Dataset setup completed successfully.
Tracks shape: (400, 5), Genres shape: (163, 2), Ratings shape: (400, 3)
Loading FMA data...
Metadata shape: (400, 4)
Genre distribution:
genre
Hip-Hop          50
Pop              50
Folk             50
Experimental     50
Rock             50
International    50
Electronic       50
Instrumental     50
Name: count, dtype: int64
Found 2584 audio files.
Valid track IDs (first 10): ['001276', '001075', '007527', '013571', '001102', '003778', '003763', '001069', '013197', '000853']
Audio file IDs (first 10): ['030043', '004022', '020364', '048463', '024426', '048042', '011947', '018037', '021891', '052120']
Matching IDs: 400
Processed 400 valid audio files.
Valid audio files (first 5)

Device set to use cpu


Track 000002: Predicted=Hip-Hop (0.9402), True=Hip-Hop
Track 000005: Predicted=Hip-Hop (0.5383), True=Hip-Hop
Track 000010: Predicted=Pop (0.9424), True=Pop
Track 000140: Predicted=Folk (0.9474), True=Folk
Track 000141: Predicted=Folk (0.7373), True=Folk
Audio feature stats: mean=-0.0000, std=1.0000
Ratings stats: min=0.2003, max=0.9976, mean=0.6068
Train-test split for recommender: train=320, test=80

Training Recommender Model...
Recommendation Epoch 1, Loss: 0.6716
Recommendation Epoch 2, Loss: 0.6648
Recommendation Epoch 3, Loss: 0.6589
Recommendation Epoch 4, Loss: 0.6571




Recommendation Epoch 5, Loss: 0.6542
Recommendation Metrics: {'precision@k': np.float64(0.6666666766007742), 'map': np.float64(0.6812424954752659), 'ndcg@k': np.float64(0.6622623010305043), 'diversity': np.float64(0.75)}

Training Collaborative Filtering Model...
CF data split: train=320, test=80
Test matrix tracks: 80, Predicted tracks: 320
Shared tracks between test and predicted: 0
No shared tracks between test and predicted. Returning zero metrics.
Collaborative Filtering Metrics: {'precision@k': 0.0}
Comparison: Content-Based Precision@K=0.6667, CF Precision@K=0.0000
Train-test split for classifier: train=320, test=80

Training Genre Classifier Model...
Classification Epoch 1, Loss: 2.1093
Classification Epoch 2, Loss: 2.0307
Classification Epoch 3, Loss: 1.9911
Classification Epoch 4, Loss: 1.9341
Classification Epoch 5, Loss: 1.8681
Classification Epoch 6, Loss: 1.7903
Classification Epoch 7, Loss: 1.7210
Classification Epoch 8, Loss: 1.6257
Classification Epoch 9, Loss: 1.5313


Some weights of Wav2Vec2ForSequenceClassification were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['classifier.bias', 'classifier.weight', 'projector.bias', 'projector.weight', 'wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
ERROR:root:Fine-tuning failed: 'AudioDataset' object has no attribute '_data'


Fine-tuning Wav2Vec2 model...
Fine-tuning failed: 'AudioDataset' object has no attribute '_data'
Wav2Vec2 vs Custom Classifier: Wav2Vec2 F1=0.0000, Custom F1=0.7298
Performing music search for query: upbeat rock songs

Search Results:
    track_id                     artist_name                      title genre
21    000368                  Blah Blah Blah                   Vampires  Rock
27    000574                    Clockcleaner             Caliente Queen  Rock
53    000825  Here Comes A Big Black Cloud!!                Death March  Rock
58    000993      Jad Fair and Jason Willett       Or So I've Been Told  Rock
70    001087                        Mahjongg  Tell The Police The Truth  Rock
109   001706            Strapping Fieldhands              In the Pineys  Rock
110   001720                        Sun Araw            Harken Sunshine  Rock
116   001891                    Thee Oh Sees               Kids In Cars  Rock
145   003720                  Indian Jewelry       Walking on t