<a href="https://colab.research.google.com/github/Viditk07-Bits/AudioAnalytics_S2-24_AIMLCZG527/blob/main/Group8_AudioAnalytics_Assignment2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Music Information Retrieval System

## Assignment Objective
This assignment implements a comprehensive Music Information Retrieval (MIR) system using Large Language Models (LLMs) and deep learning techniques. It includes music recommendation, genre classification, and semantic search applications, combining audio analysis with natural language processing.

## Dataset Setup
Using the Free Music Archive (FMA) dataset with audio files, metadata, and synthetic user data.

In [34]:
# Install required packages at the start to avoid ModuleNotFoundError
import subprocess
required_packages = ['faiss-cpu', 'transformers', 'sentence-transformers', 'tqdm', 'imblearn', 'librosa', 'nltk', 'matplotlib', 'seaborn', 'datasets']
for pkg in required_packages:
    try:
        __import__(pkg.replace('-', '_'))
    except ImportError:
        subprocess.run(['pip', 'install', pkg], check=True)
        print(f"Installed {pkg}")

import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from sentence_transformers import SentenceTransformer
from transformers import pipeline, AutoModelForAudioClassification, AutoFeatureExtractor, Trainer, TrainingArguments, BertForSequenceClassification, BertTokenizer, RobertaForSequenceClassification, RobertaTokenizer
from sklearn.model_selection import train_test_split, KFold
from sklearn.metrics import precision_recall_fscore_support, confusion_matrix, average_precision_score, ndcg_score
from sklearn.decomposition import NMF, LatentDirichletAllocation, TruncatedSVD
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.spatial.distance import cosine
import os
from pathlib import Path
import librosa
import subprocess
import logging
import random
from collections import Counter
from google.colab import drive
from datasets import Dataset
import nltk
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
from tqdm import tqdm
from multiprocessing import Pool
import faiss
from imblearn.over_sampling import SMOTE

import warnings
from tqdm.notebook import tqdm  # Use notebook-friendly progress bar

# Suppress warnings for cleaner output
warnings.filterwarnings("ignore")

# Disable wandb logging
os.environ["WANDB_MODE"] = "disabled"

# Download NLTK resources
nltk.download('punkt', quiet=True)
nltk.download('punkt_tab', quiet=True)
nltk.download('stopwords', quiet=True)

# Setup logging (only show INFO and above, suppress WARNING)
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logging.getLogger().setLevel(logging.INFO)

# Suppress warnings for cleaner output
warnings.filterwarnings("ignore")

# Disable wandb logging
os.environ["WANDB_MODE"] = "disabled"

# Mount Google Drive
drive.mount('/content/drive', force_remount=True)

# Constants (updated to use Google Drive)
DATA_PATH = "/content/drive/MyDrive/fma/data"
AUDIO_PATH = os.path.join(DATA_PATH, "audio_files")
METADATA_PATH = os.path.join(DATA_PATH, "metadata/tracks.csv")
ARTISTS_PATH = os.path.join(DATA_PATH, "metadata/artists.csv")
GENRES_PATH = os.path.join(DATA_PATH, "metadata/genres.csv")
LYRICS_PATH = os.path.join(DATA_PATH, "lyrics")
USER_DATA_PATH = os.path.join(DATA_PATH, "user_data/ratings.csv")
TAGS_PATH = os.path.join(DATA_PATH, "descriptions/tags.csv")

# Output and temp directories can remain local
OUTPUT_DIR = "/content/outputs"
TEMP_DIR = "/content/fma"

NUM_EPOCHS_REC = 5
NUM_EPOCHS_CLS = 10
BATCH_SIZE = 16
MAX_TRACKS = 100
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
TIMEOUT_SECONDS = 3600

# Ensure output directory exists
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Enhanced genre-specific keyword pools for richer lyrics
GENRE_KEYWORDS = {
    'Rock': ['electric guitar wails', 'rebellious spirit soars', 'grunge heart pounds', 'classic riffs ignite', 'indie soul rebels', 'punk fire explodes', 'rock anthem roars'],
    'Pop': ['infectious hooks dance', 'neon lights pulse', 'melodic dreams soar', 'upbeat rhythm shines', 'love story sparkles', 'dancefloor beats throb', 'pop fever rises'],
    'Jazz': ['saxophone weaves magic', 'improvised notes flow', 'bluesy soul swings', 'smooth grooves linger', 'jazz night whispers', 'rhythmic scat hums', 'cool vibes drift'],
    'Classical': ['orchestral swells rise', 'violin sings softly', 'piano echoes grace', 'symphonic waves crash', 'baroque harmony soars', 'elegant strings weave', 'timeless beauty unfolds'],
    'Hip-Hop': ['heavy beats drop hard', 'sharp rhymes cut deep', 'street stories unfold', 'flow rides the rhythm', 'urban pulse vibrates', 'mic drops with swagger', 'hip-hop reigns supreme'],
    'Electronic': ['synth pulses glow', 'techno beats surge', 'ambient waves drift', 'EDM sparks the night', 'futuristic sounds hum', 'electro vibes ignite', 'digital dreams pulse'],
    'Folk': ['acoustic chords strum', 'heartfelt tales weave', 'rustic paths wander', 'folk roots run deep', 'gentle melodies soothe', 'campfire stories sing', 'tradition lives on'],
    'Blues': ['guitar wails with soul', 'heartache spills over', 'raw blues cry out', 'delta notes resonate', 'mournful chords linger', 'blues spirit endures', 'emotional strings weep'],
    'Country': ['banjo twangs with pride', 'heartland stories sing', 'cowboy boots stomp', 'rural roads ramble', 'love songs ride free', 'country heart beats strong', 'honky-tonk nights shine'],
    'Reggae': ['rasta riddims sway', 'island vibes chill', 'roots reggae grooves', 'one love unites all', 'skank beat lifts high', 'irie spirit flows', 'dreadlocks dance free'],
    'International': ['world rhythms blend', 'exotic melodies soar', 'cultural beats pulse', 'global sounds unite', 'traditional chants echo', 'fusion vibes transcend', 'earth’s heartbeat sings'],
    'Instrumental': ['ambient chords float', 'strings weave dreams', 'piano paints silence', 'orchestral tides rise', 'melody speaks alone', 'instrumental soul soars', 'soundscapes breathe life'],
    'Experimental': ['avant-garde sounds twist', 'abstract beats morph', 'sonic boundaries break', 'unorthodox rhythms pulse', 'experimental vibes soar', 'sound art redefines', 'future notes unfold']
}

def download_fma_dataset():
    """Download and extract FMA small dataset with error handling and retries."""
    logging.info("Attempting to download FMA small dataset...")
    max_retries = 3
    retry_delay = 5  # seconds

    for attempt in range(1, max_retries + 1):
        try:
            # Check if audio files and metadata exist and are valid
            if os.path.exists(METADATA_PATH) and os.path.exists(AUDIO_PATH):
                mp3_files = [f for f in os.listdir(AUDIO_PATH) if f.endswith('.mp3')]
                if len(mp3_files) > 0 and os.path.exists(METADATA_PATH):
                    logging.info("FMA dataset already exists at %s with %d audio files", AUDIO_PATH, len(mp3_files))
                    return True
                else:
                    logging.warning("Audio directory %s exists but contains no .mp3 files or metadata is missing. Triggering redownload.", AUDIO_PATH)

            # Clean up existing zip files to avoid corrupted downloads
            for zip_file in ["/content/fma_small.zip", "/content/fma_metadata.zip"]:
                if os.path.exists(zip_file):
                    logging.info("Removing existing %s for redownload", zip_file)
                    os.remove(zip_file)

            # Create directories
            os.makedirs(os.path.dirname(METADATA_PATH), exist_ok=True)
            os.makedirs(AUDIO_PATH, exist_ok=True)

            # Download and extract audio files
            logging.info("Downloading fma_small.zip (attempt %d/%d)...", attempt, max_retries)
            subprocess.run(["wget", "-q", "https://os.unil.cloud.switch.ch/fma/fma_small.zip", "-O", "/content/fma_small.zip"], check=True, timeout=TIMEOUT_SECONDS)
            logging.info("Extracting fma_small.zip...")
            subprocess.run(["unzip", "-q", "/content/fma_small.zip", "-d", DATA_PATH], check=True, timeout=TIMEOUT_SECONDS)

            # Download and extract metadata
            logging.info("Downloading fma_metadata.zip (attempt %d/%d)...", attempt, max_retries)
            subprocess.run(["wget", "-q", "https://os.unil.cloud.switch.ch/fma/fma_metadata.zip", "-O", "/content/fma_metadata.zip"], check=True, timeout=TIMEOUT_SECONDS)
            logging.info("Extracting fma_metadata.zip to %s...", os.path.dirname(METADATA_PATH))
            subprocess.run(["unzip", "-q", "/content/fma_metadata.zip", "-d", os.path.dirname(METADATA_PATH)], check=True, timeout=TIMEOUT_SECONDS)

            # Verify metadata and audio files
            if not os.path.exists(METADATA_PATH):
                logging.error("FMA metadata file (%s) not found after extraction.", METADATA_PATH)
                if attempt == max_retries:
                    return False
                logging.info("Retrying download after %d seconds...", retry_delay)
                import time
                time.sleep(retry_delay)
                continue

            mp3_files = [f for f in os.listdir(AUDIO_PATH) if f.endswith('.mp3')] if os.path.exists(AUDIO_PATH) else []
            metadata_files = [f for f in os.listdir(os.path.dirname(METADATA_PATH)) if f.endswith('.csv')]
            logging.info("FMA dataset downloaded successfully. Found %d audio files and %d metadata files.", len(mp3_files), len(metadata_files))
            return True
        except (subprocess.CalledProcessError, subprocess.TimeoutExpired, Exception) as e:
            logging.error("Failed to download FMA dataset on attempt %d: %s", attempt, str(e))
            if attempt == max_retries:
                logging.error("Max retries reached. Falling back to synthetic dataset.")
                return False
            logging.info("Retrying download after %d seconds...", retry_delay)
            import time
            time.sleep(retry_delay)

def generate_lyrics(args):
    """Helper function to generate richer synthetic lyrics."""
    track_id, genre, lyrics_path = args
    keywords = random.sample(GENRE_KEYWORDS[genre], min(3, len(GENRE_KEYWORDS[genre])))
    verse_lines = [
        f"Feel the {keywords[0].split()[0]} in your soul, let it take control.",
        f"{keywords[1].split()[0]} carries you away, into the night and day.",
        f"With every {keywords[2].split()[0]} the heart beats strong, singing {genre}’s song."
    ]
    chorus_lines = [
        f"{keywords[1].split()[0]} vibes, we’re alive, dancing through the night.",
        f"{genre} spirit, feel it rise, reaching for the skies.",
        f"Let the {keywords[0].split()[0]} flow, take us where we go."
    ]
    lyrics = f"{genre} song: {', '.join(keywords)}.\n" + \
             f"Verse 1:\n{verse_lines[0]}\n{verse_lines[1]}\n{verse_lines[2]}\n" + \
             f"Chorus:\n{chorus_lines[0]}\n{chorus_lines[1]}\n{chorus_lines[2]}"
    with open(os.path.join(lyrics_path, f"{track_id}.txt"), 'w', encoding='utf-8') as f:
        f.write(lyrics)
    return track_id

def create_synthetic_dataset():
    """Task A1, A3, A5: Create synthetic dataset with FMA fallback and denser user data."""
    logging.info("Creating synthetic dataset...")
    print("Creating synthetic dataset...")

    # Ensure all output directories exist
    os.makedirs(os.path.dirname(METADATA_PATH), exist_ok=True)
    os.makedirs(os.path.dirname(ARTISTS_PATH), exist_ok=True)
    os.makedirs(os.path.dirname(GENRES_PATH), exist_ok=True)
    os.makedirs(os.path.dirname(USER_DATA_PATH), exist_ok=True)
    os.makedirs(os.path.dirname(TAGS_PATH), exist_ok=True)
    logging.info("Created output directories: %s, %s, %s", os.path.dirname(METADATA_PATH), os.path.dirname(USER_DATA_PATH), os.path.dirname(TAGS_PATH))

    fma_available = download_fma_dataset()

    if fma_available:
        try:
            fma_metadata = pd.read_csv(METADATA_PATH, header=[0, 1], low_memory=False)
            fma_metadata.columns = ['_'.join(col).strip() if isinstance(col, tuple) else col for col in fma_metadata.columns]
            track_id_col = None
            title_col = None
            artist_id_col = None
            genre_top_col = None
            for col in fma_metadata.columns:
                col_lower = col.lower()
                if 'track_id' in col_lower or 'track.id' in col_lower:
                    track_id_col = col
                elif 'title' in col_lower:
                    title_col = col
                elif 'artist_id' in col_lower or 'artist.id' in col_lower:
                    artist_id_col = col
                elif 'genre_top' in col_lower or 'genre' in col_lower:
                    genre_top_col = col
            required_cols = [track_id_col, title_col, artist_id_col, genre_top_col]
            if None in required_cols:
                logging.error("Required columns missing in FMA metadata: %s", required_cols)
                fma_available = False
            else:
                fma_metadata = fma_metadata[[track_id_col, title_col, artist_id_col, genre_top_col]].dropna()
                fma_metadata = fma_metadata.rename(columns={
                    track_id_col: 'track_id',
                    title_col: 'title',
                    artist_id_col: 'artist_id',
                    genre_top_col: 'genre_top'
                })
                fma_metadata['track_id'] = fma_metadata['track_id'].astype(str).str.zfill(6)
                # Filter to tracks with existing audio files
                mp3_files = {f.replace('.mp3', '') for f in os.listdir(AUDIO_PATH) if f.endswith('.mp3')} if os.path.exists(AUDIO_PATH) else set()
                fma_metadata = fma_metadata[fma_metadata['track_id'].isin(mp3_files)]
                # Ensure exactly MAX_TRACKS by sampling or padding
                if len(fma_metadata) > MAX_TRACKS:
                    fma_metadata = fma_metadata.sample(n=MAX_TRACKS, random_state=42)
                elif len(fma_metadata) < MAX_TRACKS:
                    # Pad with synthetic data if FMA has fewer tracks
                    num_missing = MAX_TRACKS - len(fma_metadata)
                    genres = list(GENRE_KEYWORDS.keys())
                    synthetic_tracks = pd.DataFrame({
                        'track_id': [str(i).zfill(6) for i in range(len(fma_metadata) + 1, len(fma_metadata) + num_missing + 1)],
                        'title': [f"Track_{i}" for i in range(len(fma_metadata) + 1, len(fma_metadata) + num_missing + 1)],
                        'artist_id': [str(i).zfill(6) for i in range(len(fma_metadata) + 1, len(fma_metadata) + num_missing + 1)],
                        'genre_top': [random.choice(genres) for _ in range(num_missing)]
                    })
                    fma_metadata = pd.concat([fma_metadata, synthetic_tracks], ignore_index=True)
                fma_metadata['genre_id'] = fma_metadata['genre_top'].map({g: i+1 for i, g in enumerate(GENRE_KEYWORDS.keys())}).fillna(random.randint(1, 13))
        except (FileNotFoundError, Exception) as e:
            logging.error("Error loading FMA metadata: %s", str(e))
            fma_available = False

    if not fma_available:
        logging.warning("Falling back to synthetic dataset due to missing or invalid FMA data...")
        genres = list(GENRE_KEYWORDS.keys())
        df_tracks = pd.DataFrame({
            'track_id': [str(i).zfill(6) for i in range(1, MAX_TRACKS + 1)],
            'title': [f"Track_{i}" for i in range(1, MAX_TRACKS + 1)],
            'artist_id': [str(i).zfill(6) for i in range(1, MAX_TRACKS + 1)],
            'genre_id': [random.randint(1, 13) for _ in range(MAX_TRACKS)],
            'genre_top': [random.choice(genres) for _ in range(MAX_TRACKS)]
        })
    else:
        df_tracks = fma_metadata[['track_id', 'title', 'artist_id', 'genre_id', 'genre_top']]

    df_tracks.to_csv(METADATA_PATH, index=False)
    logging.info("Saved tracks metadata to %s", METADATA_PATH)

    df_artists = pd.DataFrame({
        'artist_id': df_tracks['artist_id'],
        'artist_name': [f"Artist_{i}" for i in range(1, len(df_tracks) + 1)]
    })
    df_artists.to_csv(ARTISTS_PATH, index=False)
    logging.info("Saved artists metadata to %s", ARTISTS_PATH)

    df_genres = pd.DataFrame({
        'genre_id': range(1, 14),
        'genre_name': list(GENRE_KEYWORDS.keys())
    })
    df_genres.to_csv(GENRES_PATH, index=False)
    logging.info("Saved genres metadata to %s", GENRES_PATH)

    genre_rating_bias = {
        'Pop': 0.2, 'Rock': 0.1, 'Jazz': -0.1, 'Classical': -0.2, 'Hip-Hop': 0.15,
        'Electronic': 0.1, 'Folk': -0.05, 'Blues': -0.1, 'Country': -0.05, 'Reggae': 0.05,
        'International': 0.05, 'Instrumental': -0.2, 'Experimental': -0.1
    }
    ratings = []
    for track_id, genre in zip(df_tracks['track_id'], df_tracks['genre_top']):
        for user_id in [f"user_{i+1}" for i in range(20)]:
            if np.random.random() < 0.9:
                rating = min(max(np.random.normal(3.0 + genre_rating_bias.get(genre, 0), 0.5), 1), 5)
                ratings.append({'user_id': user_id, 'track_id': track_id, 'rating': rating})
    ratings = pd.DataFrame(ratings)
    ratings.to_csv(USER_DATA_PATH, index=False)
    logging.info("Saved user ratings to %s", USER_DATA_PATH)

    tags = []
    for track_id, genre in zip(df_tracks['track_id'], df_tracks['genre_top']):
        num_tags = random.randint(3, 6)
        track_tags = random.sample(GENRE_KEYWORDS.get(genre, ['generic']), min(num_tags, len(GENRE_KEYWORDS.get(genre, ['generic']))))
        tags.extend([{'track_id': track_id, 'tag': tag} for tag in track_tags])
    pd.DataFrame(tags).to_csv(TAGS_PATH, index=False)
    logging.info("Saved tags to %s", TAGS_PATH)

    os.makedirs(AUDIO_PATH, exist_ok=True)
    os.makedirs(LYRICS_PATH, exist_ok=True)
    lyrics_args = [(track_id, genre, LYRICS_PATH) for track_id, genre in zip(df_tracks['track_id'], df_tracks['genre_top'])]
    with Pool(processes=4) as pool:
        list(tqdm(pool.imap(generate_lyrics, lyrics_args), total=len(lyrics_args), desc="Generating synthetic lyrics"))

    print(f"Synthetic dataset created: Tracks shape: {df_tracks.shape}, Ratings shape: {ratings.shape}, Tags: {len(tags)}")
    logging.info("Synthetic dataset created: Tracks shape: %s, Ratings: %s, Tags: %d", df_tracks.shape, ratings.shape, len(tags))

def extract_audio_features(audio_path):
    """Task A1.1, A2: Extract audio features using Librosa."""
    try:
        if not os.path.exists(audio_path):
            return np.zeros(26)  # Silently return zeros for missing files
        if os.path.getsize(audio_path) < 100:
            return np.zeros(26)  # Silently return zeros for small files
        y, sr = librosa.load(audio_path, sr=22050)
        if len(y) == 0:
            return np.zeros(26)  # Silently return zeros for empty audio
        mfccs = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=12)
        chroma = librosa.feature.chroma_stft(y=y, sr=sr)
        spectral_centroid = librosa.feature.spectral_centroid(y=y, sr=sr)
        tempo = librosa.feature.tempo(y=y, sr=sr)
        tempo = float(tempo[0]) if isinstance(tempo, np.ndarray) else float(tempo)
        features = np.concatenate([
            np.mean(mfccs, axis=1),
            np.mean(chroma, axis=1),
            np.mean(spectral_centroid, axis=1),
            [tempo]
        ])
        return features
    except Exception as e:
        logging.warning("Error processing %s: %s", audio_path, str(e))
        return np.zeros(26)

def load_fma_data(audio_path, metadata_path, artists_path, genres_path, lyrics_path, tags_path):
    """Task A1, A3, A5: Load synthetic dataset."""
    logging.info("Loading synthetic data...")
    print("Loading synthetic data...")
    create_synthetic_dataset()

    df_metadata = pd.read_csv(metadata_path)
    df_metadata = pd.merge(df_metadata, pd.read_csv(artists_path), on='artist_id', how='left')
    df_metadata = pd.merge(df_metadata, pd.read_csv(genres_path), on='genre_id', how='left')
    df_metadata = df_metadata[['track_id', 'artist_name', 'title', 'genre_name']].rename(columns={'genre_name': 'genre'})
    df_metadata['track_id'] = df_metadata['track_id'].astype(str).str.zfill(6)
    logging.info("Metadata shape: %s", df_metadata.shape)
    print(f"Metadata shape: {df_metadata.shape}")

    # Ensure all track_ids are valid for feature extraction
    valid_track_ids = df_metadata['track_id'].tolist()
    features = []
    for track_id in valid_track_ids:
        audio_file = os.path.join(audio_path, f"{track_id}.mp3")
        audio_features = extract_audio_features(audio_file)
        features.append([track_id] + audio_features.tolist())
    feature_columns = ['track_id'] + [f'mfcc_{i+1}' for i in range(12)] + [f'chroma_{i+1}' for i in range(12)] + ['spectral_centroid', 'tempo']
    if len(feature_columns) != 27:
        logging.error("Feature columns mismatch: expected 27, got %d", len(feature_columns))
        raise ValueError(f"Feature columns mismatch: expected 27, got {len(feature_columns)}")
    df_features = pd.DataFrame(features, columns=feature_columns)
    logging.info("Features shape: %s", df_features.shape)
    print(f"Features shape: {df_features.shape}")

    lyrics_dict = {}
    for lyric_file in Path(lyrics_path).glob("*.txt"):
        track_id = lyric_file.stem
        if track_id in valid_track_ids:
            with open(lyric_file, 'r', encoding='utf-8') as f:
                lyrics_dict[track_id] = f.read().strip() or 'music'

    if tags_path and os.path.exists(tags_path):
        df_tags = pd.read_csv(tags_path)
        for _, row in df_tags.iterrows():
            track_id = str(row['track_id']).zfill(6)
            if track_id in valid_track_ids:
                lyrics_dict[track_id] = lyrics_dict.get(track_id, '') + " " + str(row['tag'])

    logging.info("Lyrics dict size: %d", len(lyrics_dict))
    print(f"Lyrics dict size: {len(lyrics_dict)}")
    return df_metadata, df_features, lyrics_dict

def generate_text_embeddings(lyrics_dict):
    """Task A1.2, A3: Generate semantic embeddings for lyrics."""
    logging.info("Generating text embeddings...")
    print("Generating text embeddings...")

    model = SentenceTransformer('all-MiniLM-L6-v2', device=DEVICE)
    embeddings = {}

    for track_id, text in tqdm(lyrics_dict.items(), desc="Generating text embeddings"):
        embeddings[track_id] = model.encode(
            text,
            convert_to_tensor=True,
            device=DEVICE,
            show_progress_bar=False  # 👈 disables internal batching progress
        ).cpu().numpy()

    logging.info("Text embeddings generated for %d tracks.", len(embeddings))
    print(f"Text embeddings generated for {len(embeddings)} tracks.")
    return embeddings

def analyze_linguistic_patterns(df_metadata, lyrics_dict):
    """Task A1.2: Analyze linguistic patterns and topics."""
    print("\nLinguistic Analysis:")
    stop_words = set(stopwords.words('english'))
    genre_words = {g: [] for g in df_metadata['genre'].unique()}
    for track_id, text in lyrics_dict.items():
        genre = df_metadata[df_metadata['track_id'] == track_id]['genre'].iloc[0]
        tokens = word_tokenize(text.lower())
        tokens = [t for t in tokens if t.isalpha() and t not in stop_words]
        genre_words[genre].extend(tokens)
    for genre, words in genre_words.items():
        top_words = Counter(words).most_common(5)
        print(f"{genre} top words: {top_words}")
        logging.info("%s top words: %s", genre, top_words)

    vectorizer = CountVectorizer(stop_words='english')
    X = vectorizer.fit_transform(lyrics_dict.values())
    lda = LatentDirichletAllocation(n_components=5, random_state=42)
    lda.fit(X)
    for i, topic in enumerate(lda.components_):
        print(f"Topic {i}: {[vectorizer.get_feature_names_out()[j] for j in topic.argsort()[-5:]]}")

class AudioTextBERTClassifier(nn.Module):
    """BERT with audio feature integration."""
    def __init__(self, bert_model, audio_dim, num_classes, hidden_dim=128):
        super().__init__()
        self.bert = bert_model  # Now using BertModel
        self.audio_layer = nn.Linear(audio_dim, hidden_dim)
        self.combined_layer = nn.Linear(self.bert.config.hidden_size + hidden_dim, hidden_dim)
        self.dropout = nn.Dropout(0.3)
        self.fc = nn.Linear(hidden_dim, num_classes)

    def forward(self, input_ids, attention_mask, audio_features):
        bert_outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        text_out = bert_outputs.pooler_output  # Safe to use with BertModel
        audio_out = torch.relu(self.audio_layer(audio_features))
        combined = torch.cat([text_out, audio_out], dim=1)
        combined = torch.relu(self.combined_layer(combined))
        combined = self.dropout(combined)
        return self.fc(combined)

class HybridRecommender(nn.Module):
    """Task A5: Content-based recommender."""
    def __init__(self, audio_dim, text_dim, hidden_dim=128):
        super().__init__()
        self.audio_layer = nn.Linear(audio_dim, hidden_dim)
        self.text_layer = nn.Linear(text_dim, hidden_dim)
        self.dropout = nn.Dropout(0.3)
        self.fc = nn.Linear(hidden_dim * 2, 1)

    def forward(self, audio_features, text_features):
        audio_out = torch.relu(self.audio_layer(audio_features))
        text_out = torch.relu(self.text_layer(text_features))
        combined = torch.cat([audio_out, text_out], dim=1)
        combined = self.dropout(combined)
        return torch.sigmoid(self.fc(combined))

def train_recommender(model, train_loader, criterion, optimizer):
    """Task A5: Train the recommender."""
    model.train()
    total_loss = 0
    for audio_features, text_features, ratings in train_loader:
        audio_features, text_features, ratings = audio_features.to(DEVICE), text_features.to(DEVICE), ratings.to(DEVICE)
        optimizer.zero_grad()
        outputs = model(audio_features, text_features)
        loss = criterion(outputs, ratings.unsqueeze(1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(train_loader)

def evaluate_recommender(model, test_loader, df_metadata, test_idx_resampled, original_indices_resampled, genres, k=10):
    """Task C5: Evaluate recommender."""
    model.eval()
    precisions, recalls, maps, ndcgs, diversities = [], [], [], [], []
    global train_indices
    if 'train_indices' not in globals():
        train_indices = []
    train_track_ids = set(df_metadata.iloc[train_indices]['track_id'])
    with torch.no_grad():
        for batch_idx, (audio_features, text_features, ratings) in enumerate(test_loader):
            audio_features, text_features, ratings = audio_features.to(DEVICE), text_features.to(DEVICE), ratings.to(DEVICE)
            outputs = model(audio_features, text_features)
            binary_ratings = (ratings > 0.6).float()
            top_k = torch.topk(outputs, k=min(k, outputs.size(0)), dim=0).indices.flatten()
            batch_start = batch_idx * test_loader.batch_size
            batch_end = batch_start + len(audio_features)
            batch_test_idx_resampled = test_idx_resampled[batch_start:batch_end]
            batch_original_indices = original_indices_resampled[batch_test_idx_resampled]
            top_k_original_indices = batch_original_indices[top_k.cpu()]
            top_k_ids = df_metadata.iloc[top_k_original_indices]['track_id']
            top_k_genres = df_metadata.iloc[top_k_original_indices]['genre']
            relevant = binary_ratings[top_k]
            precision = relevant.mean().item()
            recall = relevant.sum().item() / binary_ratings.sum().item() if binary_ratings.sum() > 0 else 0.0
            precisions.append(precision)
            recalls.append(recall)
            if binary_ratings.sum() > 0:
                map_score = average_precision_score(binary_ratings.cpu().numpy(), outputs.cpu().numpy())
                ndcg = ndcg_score(binary_ratings.cpu().numpy().reshape(1, -1), outputs.cpu().numpy().reshape(1, -1), k=k)
            else:
                map_score, ndcg = 0.0, 0.0
            maps.append(map_score)
            ndcgs.append(ndcg)
            diversity = len(set(top_k_genres)) / len(genres) if len(genres) > 0 else 1.0
            diversities.append(diversity)
            novelty = 1 - len(set(top_k_ids).intersection(train_track_ids)) / len(top_k_ids) if len(top_k_ids) > 0 else 1.0
    return {
        'precision@k': np.mean(precisions) if precisions else 0.0,
        'recall@k': np.mean(recalls) if recalls else 0.0,
        'map': np.mean(maps) if maps else 0.0,
        'ndcg@k': np.mean(ndcgs) if ndcgs else 0.0,
        'diversity': np.mean(diversities) if diversities else 1.0,
        'novelty': novelty
    }

def collaborative_filtering(user_data, df_metadata, k=10):
    """Task A5: Collaborative filtering using SVD."""
    print("\nTraining Collaborative Filtering Model (SVD)...")
    valid_track_ids = set(df_metadata['track_id'].astype(str).str.zfill(6))
    user_data = user_data[user_data['track_id'].isin(valid_track_ids)]

    if user_data.empty:
        logging.error("No valid user data.")
        return {'precision@k': 0.0}

    train_data, test_data = train_test_split(user_data, test_size=0.2, random_state=42)

    # Deduplicate both train and test sets
    train_data = train_data.groupby(['user_id', 'track_id'], as_index=False)['rating'].mean()
    test_data = test_data.groupby(['user_id', 'track_id'], as_index=False)['rating'].mean()

    user_item_matrix = train_data.pivot(index='user_id', columns='track_id', values='rating').fillna(0)
    if user_item_matrix.empty:
        logging.error("Empty user-item matrix.")
        return {'precision@k': 0.0}

    svd = TruncatedSVD(n_components=20, random_state=42)
    user_features = svd.fit_transform(user_item_matrix)
    item_features = svd.components_
    predictions = np.dot(user_features, item_features)
    predicted_ratings = pd.DataFrame(predictions, index=user_item_matrix.index, columns=user_item_matrix.columns)

    test_matrix = test_data.pivot(index='user_id', columns='track_id', values='rating').fillna(0)
    shared_tracks = set(test_matrix.columns).intersection(set(predicted_ratings.columns))
    precisions = []
    for user_id in test_matrix.index:
        if user_id in predicted_ratings.index:
            true_ratings = test_matrix.loc[user_id]
            pred_ratings = predicted_ratings.loc[user_id]
            valid_top_k = pred_ratings[pred_ratings.index.isin(shared_tracks)].sort_values(ascending=False).index[:k]
            if len(valid_top_k) == 0:
                precisions.append(0.0)
                continue
            relevant = (true_ratings[valid_top_k] > 3.0).astype(int)
            precision = relevant.mean()
            precisions.append(precision)
    return {'precision@k': np.mean(precisions) if precisions else 0.0}

class GenreClassifier(nn.Module):
    """Task A1.1: Custom genre classifier."""
    def __init__(self, audio_dim, text_dim, num_classes, hidden_dim=128):
        super().__init__()
        self.audio_layer = nn.Linear(audio_dim, hidden_dim)
        self.text_layer = nn.Linear(text_dim, hidden_dim)
        self.dropout = nn.Dropout(0.3)
        self.fc = nn.Linear(hidden_dim * 2, num_classes)

    def forward(self, audio_features, text_features):
        audio_out = torch.relu(self.audio_layer(audio_features))
        text_out = torch.relu(self.text_layer(text_features))
        combined = torch.cat([audio_out, text_out], dim=1)
        combined = self.dropout(combined)
        return self.fc(combined)

def train_classifier(model, train_loader, criterion, optimizer):
    """Task A1.1: Train the classifier."""
    model.train()
    total_loss = 0
    for audio_features, text_features, labels in train_loader:
        audio_features, text_features, labels = audio_features.to(DEVICE), text_features.to(DEVICE), labels.to(DEVICE)
        optimizer.zero_grad()
        outputs = model(audio_features, text_features)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(train_loader)

def evaluate_classifier(model, test_loader, genres):
    """Task C5: Evaluate classifier with error analysis."""
    model.eval()
    y_true, y_pred = [], []
    with torch.no_grad():
        for audio_features, text_features, labels in test_loader:
            audio_features, text_features, labels = audio_features.to(DEVICE), text_features.to(DEVICE), labels.to(DEVICE)
            outputs = model(audio_features, text_features)
            _, predicted = torch.max(outputs, 1)
            y_true.extend(labels.cpu().numpy())
            y_pred.extend(predicted.cpu().numpy())
    metrics = precision_recall_fscore_support(y_true, y_pred, average='weighted', zero_division=0)
    cm = confusion_matrix(y_true, y_pred)
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=genres, yticklabels=genres)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Confusion Matrix')
    plt.savefig(os.path.join(OUTPUT_DIR, 'confusion_matrix.png'))
    plt.close()

    errors = [(i, genres[true], genres[pred]) for i, (true, pred) in enumerate(zip(y_true, y_pred)) if true != pred]
    error_counts = Counter((true, pred) for _, true, pred in errors)
    print("\nMisclassification Patterns:")
    for (true, pred), count in error_counts.most_common():
        print(f"True: {true}, Predicted: {pred}, Count: {count}")

    return {'precision': metrics[0], 'recall': metrics[1], 'f1': metrics[2]}

def ast_classifier(audio_files, labels, genres, df_metadata):
    """Task A2: AST-based audio classification."""
    print("\nSetting up AST Classifier...")
    try:
        feature_extractor = AutoFeatureExtractor.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.2")
        model = AutoModelForAudioClassification.from_pretrained(
            "MIT/ast-finetuned-audioset-10-10-0.2",
            num_labels=len(genres),
            label2id={genre: idx for idx, genre in enumerate(genres)},
            id2label={idx: genre for idx, genre in enumerate(genres)}
        ).to(DEVICE)
    except Exception as e:
        logging.error("Failed to load AST model: %s", str(e))
        return {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}

    valid_audio_files = [f for f in audio_files if os.path.exists(f) and os.path.getsize(f) > 100]
    valid_labels = [labels[i] for i, f in enumerate(audio_files) if f in valid_audio_files]

    def preprocess_audio(audio_file, label):
        try:
            y, sr = librosa.load(audio_file, sr=16000, duration=5.0)
            inputs = feature_extractor(y, sampling_rate=16000, return_tensors="np", padding="max_length", max_length=80000)
            return {
                "input_values": inputs.input_values[0],
                "labels": label
            }
        except:
            return None

    data = [preprocess_audio(f, l) for f, l in zip(valid_audio_files, valid_labels)]
    data = [d for d in data if d is not None]
    if not data:
        logging.error("No valid audio data.")
        return {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}

    dataset = Dataset.from_dict({
        "input_values": [d["input_values"] for d in data],
        "labels": [d["labels"] for d in data]
    })
    dataset = dataset.train_test_split(test_size=0.2, seed=42)

    training_args = TrainingArguments(
        output_dir=os.path.join(OUTPUT_DIR, "ast-finetuned"),
        per_device_train_batch_size=2,
        per_device_eval_batch_size=2,
        num_train_epochs=3,
        eval_strategy="steps",
        save_strategy="steps",
        save_steps=500,
        logging_steps=100,
        learning_rate=3e-5,
        save_total_limit=2,
        load_best_model_at_end=True
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"],
        compute_metrics=lambda pred: {'precision': precision_recall_fscore_support(pred.label_ids, pred.predictions.argmax(-1), average='weighted', zero_division=0)[0],
                                     'recall': precision_recall_fscore_support(pred.label_ids, pred.predictions.argmax(-1), average='weighted', zero_division=0)[1],
                                     'f1': precision_recall_fscore_support(pred.label_ids, pred.predictions.argmax(-1), average='weighted', zero_division=0)[2]}
    )

    trainer.train()
    eval_results = trainer.evaluate()
    return {'precision': eval_results['eval_precision'], 'recall': eval_results['eval_recall'], 'f1': eval_results['eval_f1']}

def music_search(query, df, text_embeddings, audio_features_df, model, scaler, k=10):
    """Task A3: Content-based music discovery with retrieval metrics."""
    logging.info("Performing music search for query: %s", query)
    print(f"Performing music search for query: {query}")

    query_embedding = model.encode(query, convert_to_tensor=True, device=DEVICE).cpu().numpy()
    query_embedding = query_embedding / np.linalg.norm(query_embedding)

    text_embedding_matrix = np.array([text_embeddings.get(tid, np.zeros(384)) for tid in df['track_id']])
    text_embedding_matrix = text_embedding_matrix / np.linalg.norm(text_embedding_matrix, axis=1, keepdims=True)
    text_index = faiss.IndexFlatIP(text_embedding_matrix.shape[1])
    text_index.add(text_embedding_matrix.astype(np.float32))

    text_scores, text_indices = text_index.search(query_embedding.reshape(1, -1).astype(np.float32), k=20)

    audio_features = audio_features_df[[col for col in audio_features_df.columns if col != 'track_id']].values
    audio_features = scaler.transform(np.nan_to_num(audio_features))
    audio_features = audio_features / np.linalg.norm(audio_features, axis=1, keepdims=True)

    combined_scores = {}
    for idx, score in zip(text_indices[0], text_scores[0]):
        track_id = df['track_id'].iloc[idx]
        genre = df[df['track_id'] == track_id]['genre'].iloc[0]
        genre_weight = 1.5 if genre == 'Rock' else 1.0
        combined_scores[track_id] = 0.7 * score * genre_weight
        audio_idx = audio_features_df.index[audio_features_df['track_id'] == track_id].tolist()
        if audio_idx:
            audio_idx = audio_idx[0]
            genre_tracks = df[df['genre'] == genre]['track_id']
            genre_audio_features = audio_features[audio_features_df['track_id'].isin(genre_tracks)]
            if len(genre_audio_features) > 0:
                avg_genre_audio = np.mean(genre_audio_features, axis=0)
                audio_similarity = 1 - cosine(audio_features[audio_idx], avg_genre_audio)
                combined_scores[track_id] += 0.3 * audio_similarity * genre_weight

    top_tracks = sorted(combined_scores.items(), key=lambda x: x[1], reverse=True)[:k]
    result_ids = [track_id for track_id, _ in top_tracks]
    results = df[df['track_id'].isin(result_ids)][['track_id', 'artist_name', 'title', 'genre']]

    # Calculate retrieval metrics
    relevant_tracks = df[df['genre'] == 'Rock']['track_id'].values  # Assuming query targets Rock
    y_true = [1 if track_id in relevant_tracks else 0 for track_id in result_ids]
    y_scores = [combined_scores[track_id] for track_id in result_ids]
    precision = sum(y_true) / len(y_true) if len(y_true) > 0 else 0.0
    recall = sum(y_true) / len(relevant_tracks) if len(relevant_tracks) > 0 else 0.0
    map_score = average_precision_score(y_true, y_scores) if sum(y_true) > 0 else 0.0
    ndcg = ndcg_score(np.array(y_true).reshape(1, -1), np.array(y_scores).reshape(1, -1), k=k) if sum(y_true) > 0 else 0.0
    diversity = len(set(results['genre'])) / len(df['genre'].unique()) if len(df['genre'].unique()) > 0 else 1.0
    novelty = len(set(result_ids) - set(df['track_id'].iloc[train_indices])) / len(result_ids) if len(result_ids) > 0 else 1.0

    print(f"\nSearch Results for '{query}':")
    print(results.to_string(index=False))
    print(f"Retrieval Metrics: Precision@{k}: {precision:.4f}, Recall@{k}: {recall:.4f}, MAP: {map_score:.4f}, NDCG@{k}: {ndcg:.4f}, Diversity: {diversity:.4f}, Novelty: {novelty:.4f}")
    return results, {'precision@k': precision, 'recall@k': recall, 'map': map_score, 'ndcg@k': ndcg, 'diversity': diversity, 'novelty': novelty}

def zero_shot_classification(df_metadata, lyrics_dict):
    """Task A1.2: Zero-shot genre classification."""
    print("\nZero-Shot Genre Classification:")
    classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli", device=0 if torch.cuda.is_available() else -1)
    genres = df_metadata['genre'].unique().tolist()
    results = []
    for _, row in df_metadata.head(5).iterrows():
        text = row['title'] + " " + lyrics_dict.get(row['track_id'], "")
        scores = classifier(text, candidate_labels=genres, multi_label=False)
        results.append((row['track_id'], scores['labels'][0], scores['scores'][0], row['genre']))
        print(f"Track {row['track_id']}: Predicted={scores['labels'][0]} ({scores['scores'][0]:.4f}), True={row['genre']}")
    accuracy = sum(1 for _, pred, _, true in results if pred == true) / len(results)
    print(f"Zero-Shot Accuracy: {accuracy:.4f}")
    return results

def few_shot_classification(df_metadata, lyrics_dict):
    """Task A1.2: Few-shot genre classification."""
    print("\nFew-Shot Genre Classification:")
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=len(df_metadata['genre'].unique())).to(DEVICE)

    genres = df_metadata['genre'].unique().tolist()
    few_shot_data = []
    for genre in genres:
        genre_tracks = df_metadata[df_metadata['genre'] == genre].head(5)
        for _, row in genre_tracks.iterrows():
            text = row['title'] + " " + lyrics_dict.get(row['track_id'], "")
            inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=128).to(DEVICE)
            few_shot_data.append((inputs, genres.index(genre)))

    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)
    for epoch in range(5):
        for inputs, label in few_shot_data:
            optimizer.zero_grad()
            outputs = model(**inputs, labels=torch.tensor([label]).to(DEVICE))
            loss = outputs.loss
            loss.backward()
            optimizer.step()

    model.eval()
    results = []
    for _, row in df_metadata.head(5).iterrows():
        text = row['title'] + " " + lyrics_dict.get(row['track_id'], "")
        inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=128).to(DEVICE)
        with torch.no_grad():
            outputs = model(**inputs)
            predicted = torch.argmax(outputs.logits, dim=1).item()
        results.append((row['track_id'], genres[predicted], row['genre']))
        print(f"Track {row['track_id']}: Predicted={genres[predicted]}, True={row['genre']}")
    accuracy = sum(1 for _, pred, true in results if pred == true) / len(results)
    print(f"Few-Shot Accuracy: {accuracy:.4f}")
    return results

def analyze_feature_contribution(model, test_loader, feature_columns):
    """Task A1.1: Analyze audio feature contributions."""
    print("\nFeature Contribution Analysis:")
    model.eval()
    contributions = {col: [] for col in feature_columns}
    with torch.no_grad():
        for audio_features, text_features, _ in test_loader:
            audio_features, text_features = audio_features.to(DEVICE), text_features.to(DEVICE)
            baseline_output = model(audio_features, text_features)
            for i, col in enumerate(feature_columns):
                modified_features = audio_features.clone()
                modified_features[:, i] = 0
                modified_output = model(modified_features, text_features)
                diff = torch.mean(torch.abs(baseline_output - modified_output)).item()
                contributions[col].append(diff)
    for col in feature_columns:
        print(f"{col}: Mean contribution = {np.mean(contributions[col]):.4f}")
    return contributions

# def main():
#     """Main function orchestrating all tasks."""
#     logging.info("Starting main execution...")
#     print("Starting main execution...")

#     df_metadata, df_features, lyrics_dict = load_fma_data(AUDIO_PATH, METADATA_PATH, ARTISTS_PATH, GENRES_PATH, LYRICS_PATH, TAGS_PATH)

#     if df_metadata.empty or df_features.empty or not lyrics_dict:
#         logging.error("Data loading failed.")
#         return

#     text_embeddings = generate_text_embeddings(lyrics_dict)
#     analyze_linguistic_patterns(df_metadata, lyrics_dict)
#     zero_shot_classification(df_metadata, lyrics_dict)
#     few_shot_classification(df_metadata, lyrics_dict)

#     valid_track_ids = df_features['track_id'].tolist()
#     df_metadata = df_metadata[df_metadata['track_id'].isin(valid_track_ids)]
#     text_features = np.array([text_embeddings.get(tid, np.zeros(384)) for tid in df_metadata['track_id']])
#     scaler = StandardScaler()
#     audio_features = scaler.fit_transform(df_features[[col for col in df_features.columns if col != 'track_id']].values)
#     audio_features = np.nan_to_num(audio_features)

#     user_data = pd.read_csv(USER_DATA_PATH)
#     user_data['track_id'] = user_data['track_id'].astype(str).str.zfill(6)
#     user_data = user_data[user_data['track_id'].isin(df_metadata['track_id'])]
#     ratings_agg = user_data.groupby('track_id')['rating'].mean().reindex(df_metadata['track_id']).fillna(3.0).values
#     ratings = np.clip(ratings_agg / 5.0, 0.0, 1.0)

#     X = np.hstack([audio_features, text_features])
#     if X.shape[0] != len(ratings):
#         logging.error("Mismatch in samples: X has %d samples, ratings has %d", X.shape[0], len(ratings))
#         raise ValueError(f"Mismatch in samples: X has {X.shape[0]} samples, ratings has {len(ratings)}")

#     original_indices = np.arange(len(X))
#     binary_ratings = (ratings > 0.5).astype(int)
#     class_counts = Counter(binary_ratings)
#     min_class_size = min(class_counts.values())
#     logging.info("Class distribution: %s", class_counts)
#     print(f"Class distribution: {class_counts}")

#     X_resampled, ratings_resampled, original_indices_resampled = X, binary_ratings, original_indices
#     if min_class_size >= 2:
#         k_neighbors = min(5, min_class_size - 1)
#         smote = SMOTE(random_state=42, k_neighbors=k_neighbors)
#         try:
#             X_resampled, ratings_resampled = smote.fit_resample(X, binary_ratings)
#             original_indices_resampled = np.concatenate([
#                 original_indices[ratings == 0],
#                 original_indices[ratings == 1],
#                 np.random.choice(original_indices[ratings == 0], size=len(X_resampled) - len(X), replace=True)
#             ])
#             shuffle_indices = np.random.permutation(len(X_resampled))
#             X_resampled = X_resampled[shuffle_indices]
#             ratings_resampled = ratings_resampled[shuffle_indices]
#             original_indices_resampled = original_indices_resampled[shuffle_indices]
#         except ValueError as e:
#             logging.error("SMOTE failed: %s", str(e))
#             print(f"SMOTE failed: {str(e)}. Using original data.")

#     kf = KFold(n_splits=5, shuffle=True, random_state=42)
#     rec_metrics_all = []
#     global train_indices

#     for fold, (train_idx_resampled, test_idx_resampled) in enumerate(kf.split(X_resampled)):
#         print(f"\nFold {fold+1}")
#         train_indices = original_indices_resampled[train_idx_resampled]
#         X_train_rec, X_test_rec = X_resampled[train_idx_resampled], X_resampled[test_idx_resampled]
#         y_train_rec, y_test_rec = ratings_resampled[train_idx_resampled], ratings_resampled[test_idx_resampled]

#         train_dataset_rec = torch.utils.data.TensorDataset(
#             torch.tensor(X_train_rec[:, :audio_features.shape[1]], dtype=torch.float32),
#             torch.tensor(X_train_rec[:, audio_features.shape[1]:], dtype=torch.float32),
#             torch.tensor(y_train_rec, dtype=torch.float32)
#         )
#         train_loader_rec = torch.utils.data.DataLoader(train_dataset_rec, batch_size=BATCH_SIZE, shuffle=True)
#         test_dataset_rec = torch.utils.data.TensorDataset(
#             torch.tensor(X_test_rec[:, :audio_features.shape[1]], dtype=torch.float32),
#             torch.tensor(X_test_rec[:, audio_features.shape[1]:], dtype=torch.float32),
#             torch.tensor(y_test_rec, dtype=torch.float32)
#         )
#         test_loader_rec = torch.utils.data.DataLoader(test_dataset_rec, batch_size=BATCH_SIZE)

#         recommender = HybridRecommender(audio_dim=audio_features.shape[1], text_dim=384).to(DEVICE)
#         criterion_rec = nn.BCELoss()
#         optimizer_rec = torch.optim.Adam(recommender.parameters(), lr=0.001, weight_decay=1e-4)

#         print("\nTraining Recommender Model...")
#         for epoch in range(NUM_EPOCHS_REC):
#             loss = train_recommender(recommender, train_loader_rec, criterion_rec, optimizer_rec)
#             print(f"Recommendation Epoch {epoch+1}, Loss: {loss:.4f}")

#         rec_metrics = evaluate_recommender(recommender, test_loader_rec, df_metadata, test_idx_resampled, original_indices_resampled, df_metadata['genre'].unique())
#         rec_metrics_all.append(rec_metrics)
#         print(f"Recommendation Metrics: {rec_metrics}")

#     avg_rec_metrics = {k: np.mean([m[k] for m in rec_metrics_all]) for k in rec_metrics_all[0]}
#     print(f"Average Recommendation Metrics: {avg_rec_metrics}")

#     cf_metrics = collaborative_filtering(user_data, df_metadata)
#     print(f"Collaborative Filtering Metrics: {cf_metrics}")

#     genres = df_metadata['genre'].unique()
#     genre_to_idx = {g: i for i, g in enumerate(genres)}
#     labels = df_metadata['genre'].map(genre_to_idx).values

#     cls_metrics_all = []
#     for fold, (train_idx_resampled, test_idx_resampled) in enumerate(kf.split(X_resampled)):
#         print(f"\nFold {fold+1}")
#         X_train_cls, X_test_cls = X_resampled[train_idx_resampled], X_resampled[test_idx_resampled]
#         y_train_cls = labels[original_indices_resampled[train_idx_resampled]]
#         y_test_cls = labels[original_indices_resampled[test_idx_resampled]]

#         train_dataset_cls = torch.utils.data.TensorDataset(
#             torch.tensor(X_train_cls[:, :audio_features.shape[1]], dtype=torch.float32),
#             torch.tensor(X_train_cls[:, audio_features.shape[1]:], dtype=torch.float32),
#             torch.tensor(y_train_cls, dtype=torch.long)
#         )
#         train_loader_cls = torch.utils.data.DataLoader(train_dataset_cls, batch_size=BATCH_SIZE, shuffle=True)
#         test_dataset_cls = torch.utils.data.TensorDataset(
#             torch.tensor(X_test_cls[:, :audio_features.shape[1]], dtype=torch.float32),
#             torch.tensor(X_test_cls[:, audio_features.shape[1]:], dtype=torch.float32),
#             torch.tensor(y_test_cls, dtype=torch.long)
#         )
#         test_loader_cls = torch.utils.data.DataLoader(test_dataset_cls, batch_size=BATCH_SIZE)

#         classifier = GenreClassifier(audio_dim=audio_features.shape[1], text_dim=384, num_classes=len(genres)).to(DEVICE)
#         criterion_cls = nn.CrossEntropyLoss()
#         optimizer_cls = torch.optim.Adam(classifier.parameters(), lr=0.001, weight_decay=1e-4)

#         print("\nTraining Genre Classifier Model...")
#         for epoch in range(NUM_EPOCHS_CLS):
#             loss = train_classifier(classifier, train_loader_cls, criterion_cls, optimizer_cls)
#             print(f"Classification Epoch {epoch+1}, Loss: {loss:.4f}")

#         cls_metrics = evaluate_classifier(classifier, test_loader_cls, genres)
#         cls_metrics_all.append(cls_metrics)
#         print(f"Classification Metrics: {cls_metrics}")

#     avg_cls_metrics = {k: np.mean([m[k] for m in cls_metrics_all]) for k in cls_metrics_all[0]}
#     print(f"Average Classification Metrics: {avg_cls_metrics}")

#     bert_metrics = bert_classifier(df_metadata, text_embeddings, audio_features, genres)
#     print(f"BERT Classifier Metrics: {bert_metrics}")

#     roberta_metrics = roberta_classifier(df_metadata, text_embeddings, genres)
#     print(f"RoBERTa Classifier Metrics: {roberta_metrics}")

#     audio_files = [os.path.join(AUDIO_PATH, f"{tid}.mp3") for tid in df_metadata['track_id']]
#     ast_metrics = ast_classifier(audio_files, labels, genres, df_metadata)
#     print(f"AST vs Custom Classifier: AST F1={ast_metrics['f1']:.4f}, Custom F1={avg_cls_metrics['f1']:.4f}")

#     search_model = SentenceTransformer('all-MiniLM-L6-v2', device=DEVICE)
#     query = "upbeat rock songs"
#     _, search_metrics = music_search(query, df_metadata, text_embeddings, df_features, search_model, scaler)
#     print(f"Search Metrics: {search_metrics}")

# if __name__ == "__main__":
#     main()

Installed faiss-cpu
Mounted at /content/drive


In [None]:
logging.info("Starting main execution...")
print("Starting main execution...")

df_metadata, df_features, lyrics_dict = load_fma_data(AUDIO_PATH, METADATA_PATH, ARTISTS_PATH, GENRES_PATH, LYRICS_PATH, TAGS_PATH)

if df_metadata.empty or df_features.empty or not lyrics_dict:
    logging.error("Data loading failed.")
    # return

text_embeddings = generate_text_embeddings(lyrics_dict)
analyze_linguistic_patterns(df_metadata, lyrics_dict)
zero_shot_classification(df_metadata, lyrics_dict)
few_shot_classification(df_metadata, lyrics_dict)

valid_track_ids = df_features['track_id'].tolist()
df_metadata = df_metadata[df_metadata['track_id'].isin(valid_track_ids)]
text_features = np.array([text_embeddings.get(tid, np.zeros(384)) for tid in df_metadata['track_id']])
scaler = StandardScaler()
audio_features = scaler.fit_transform(df_features[[col for col in df_features.columns if col != 'track_id']].values)
audio_features = np.nan_to_num(audio_features)

user_data = pd.read_csv(USER_DATA_PATH)
user_data['track_id'] = user_data['track_id'].astype(str).str.zfill(6)
user_data = user_data[user_data['track_id'].isin(df_metadata['track_id'])]
ratings_agg = user_data.groupby('track_id')['rating'].mean().reindex(df_metadata['track_id']).fillna(3.0).values
ratings = np.clip(ratings_agg / 5.0, 0.0, 1.0)

X = np.hstack([audio_features, text_features])
if X.shape[0] != len(ratings):
    logging.error("Mismatch in samples: X has %d samples, ratings has %d", X.shape[0], len(ratings))
    raise ValueError(f"Mismatch in samples: X has {X.shape[0]} samples, ratings has {len(ratings)}")

original_indices = np.arange(len(X))
binary_ratings = (ratings > 0.5).astype(int)
class_counts = Counter(binary_ratings)
min_class_size = min(class_counts.values())
logging.info("Class distribution: %s", class_counts)
print(f"Class distribution: {class_counts}")

X_resampled, ratings_resampled, original_indices_resampled = X, binary_ratings, original_indices
if min_class_size >= 2:
    k_neighbors = min(5, min_class_size - 1)
    smote = SMOTE(random_state=42, k_neighbors=k_neighbors)
    try:
        X_resampled, ratings_resampled = smote.fit_resample(X, binary_ratings)
        original_indices_resampled = np.concatenate([
            original_indices[ratings == 0],
            original_indices[ratings == 1],
            np.random.choice(original_indices[ratings == 0], size=len(X_resampled) - len(X), replace=True)
        ])
        shuffle_indices = np.random.permutation(len(X_resampled))
        X_resampled = X_resampled[shuffle_indices]
        ratings_resampled = ratings_resampled[shuffle_indices]
        original_indices_resampled = original_indices_resampled[shuffle_indices]
    except ValueError as e:
        logging.error("SMOTE failed: %s", str(e))
        print(f"SMOTE failed: {str(e)}. Using original data.")

kf = KFold(n_splits=5, shuffle=True, random_state=42)
rec_metrics_all = []
global train_indices

for fold, (train_idx_resampled, test_idx_resampled) in enumerate(kf.split(X_resampled)):
    print(f"\nFold {fold+1}")
    train_indices = original_indices_resampled[train_idx_resampled]
    X_train_rec, X_test_rec = X_resampled[train_idx_resampled], X_resampled[test_idx_resampled]
    y_train_rec, y_test_rec = ratings_resampled[train_idx_resampled], ratings_resampled[test_idx_resampled]

    train_dataset_rec = torch.utils.data.TensorDataset(
        torch.tensor(X_train_rec[:, :audio_features.shape[1]], dtype=torch.float32),
        torch.tensor(X_train_rec[:, audio_features.shape[1]:], dtype=torch.float32),
        torch.tensor(y_train_rec, dtype=torch.float32)
    )
    train_loader_rec = torch.utils.data.DataLoader(train_dataset_rec, batch_size=BATCH_SIZE, shuffle=True)
    test_dataset_rec = torch.utils.data.TensorDataset(
        torch.tensor(X_test_rec[:, :audio_features.shape[1]], dtype=torch.float32),
        torch.tensor(X_test_rec[:, audio_features.shape[1]:], dtype=torch.float32),
        torch.tensor(y_test_rec, dtype=torch.float32)
    )
    test_loader_rec = torch.utils.data.DataLoader(test_dataset_rec, batch_size=BATCH_SIZE)

    recommender = HybridRecommender(audio_dim=audio_features.shape[1], text_dim=384).to(DEVICE)
    criterion_rec = nn.BCELoss()
    optimizer_rec = torch.optim.Adam(recommender.parameters(), lr=0.001, weight_decay=1e-4)

    print("\nTraining Recommender Model...")
    for epoch in range(NUM_EPOCHS_REC):
        loss = train_recommender(recommender, train_loader_rec, criterion_rec, optimizer_rec)
        print(f"Recommendation Epoch {epoch+1}, Loss: {loss:.4f}")

    rec_metrics = evaluate_recommender(recommender, test_loader_rec, df_metadata, test_idx_resampled, original_indices_resampled, df_metadata['genre'].unique())
    rec_metrics_all.append(rec_metrics)
    print(f"Recommendation Metrics: {rec_metrics}")

avg_rec_metrics = {k: np.mean([m[k] for m in rec_metrics_all]) for k in rec_metrics_all[0]}
print(f"Average Recommendation Metrics: {avg_rec_metrics}")




INFO:root:Starting main execution...
INFO:root:Loading synthetic data...
INFO:root:Creating synthetic dataset...
INFO:root:Created output directories: /content/drive/MyDrive/fma/data/metadata, /content/drive/MyDrive/fma/data/user_data, /content/drive/MyDrive/fma/data/descriptions
INFO:root:Attempting to download FMA small dataset...
INFO:root:FMA dataset already exists at /content/drive/MyDrive/fma/data/audio_files with 10 audio files


Starting main execution...
Loading synthetic data...
Creating synthetic dataset...


INFO:root:Saved tracks metadata to /content/drive/MyDrive/fma/data/metadata/tracks.csv
INFO:root:Saved artists metadata to /content/drive/MyDrive/fma/data/metadata/artists.csv
INFO:root:Saved genres metadata to /content/drive/MyDrive/fma/data/metadata/genres.csv
INFO:root:Saved user ratings to /content/drive/MyDrive/fma/data/user_data/ratings.csv
INFO:root:Saved tags to /content/drive/MyDrive/fma/data/descriptions/tags.csv


Generating synthetic lyrics:   0%|          | 0/100 [00:00<?, ?it/s]

INFO:root:Synthetic dataset created: Tracks shape: (100, 5), Ratings: (1826, 3), Tags: 447
INFO:root:Metadata shape: (130, 4)


Synthetic dataset created: Tracks shape: (100, 5), Ratings shape: (1826, 3), Tags: 447
Metadata shape: (130, 4)


INFO:root:Features shape: (130, 27)


Features shape: (130, 27)


INFO:root:Lyrics dict size: 95
INFO:root:Generating text embeddings...
INFO:sentence_transformers.SentenceTransformer:Load pretrained SentenceTransformer: all-MiniLM-L6-v2


Lyrics dict size: 95
Generating text embeddings...


Generating text embeddings:   0%|          | 0/95 [00:00<?, ?it/s]

INFO:root:Text embeddings generated for 95 tracks.


Text embeddings generated for 95 tracks.

Linguistic Analysis:


INFO:root:Jazz top words: [('jazz', 24), ('night', 15), ('flow', 13), ('saxophone', 12), ('improvised', 12)]
INFO:root:Blues top words: [('blues', 29), ('delta', 12), ('emotional', 11), ('guitar', 11), ('song', 10)]
INFO:root:Country top words: [('country', 39), ('banjo', 17), ('heart', 16), ('beats', 16), ('strong', 16)]
INFO:root:Electronic top words: [('night', 30), ('beats', 24), ('techno', 24), ('electronic', 24), ('edm', 22)]
INFO:root:Folk top words: [('folk', 26), ('acoustic', 17), ('tradition', 14), ('song', 12), ('heartfelt', 12)]
INFO:root:Experimental top words: [('experimental', 33), ('vibes', 15), ('abstract', 15), ('beats', 15), ('song', 12)]
INFO:root:International top words: [('international', 18), ('fusion', 15), ('vibes', 15), ('song', 12), ('feel', 12)]
INFO:root:Pop top words: [('pop', 54), ('neon', 28), ('love', 27), ('song', 24), ('beats', 24)]
INFO:root:Rock top words: [('rock', 33), ('indie', 16), ('soul', 15), ('heart', 15), ('song', 14)]
INFO:root:Instrumenta

Jazz top words: [('jazz', 24), ('night', 15), ('flow', 13), ('saxophone', 12), ('improvised', 12)]
Blues top words: [('blues', 29), ('delta', 12), ('emotional', 11), ('guitar', 11), ('song', 10)]
Country top words: [('country', 39), ('banjo', 17), ('heart', 16), ('beats', 16), ('strong', 16)]
Electronic top words: [('night', 30), ('beats', 24), ('techno', 24), ('electronic', 24), ('edm', 22)]
Folk top words: [('folk', 26), ('acoustic', 17), ('tradition', 14), ('song', 12), ('heartfelt', 12)]
Experimental top words: [('experimental', 33), ('vibes', 15), ('abstract', 15), ('beats', 15), ('song', 12)]
International top words: [('international', 18), ('fusion', 15), ('vibes', 15), ('song', 12), ('feel', 12)]
Pop top words: [('pop', 54), ('neon', 28), ('love', 27), ('song', 24), ('beats', 24)]
Rock top words: [('rock', 33), ('indie', 16), ('soul', 15), ('heart', 15), ('song', 14)]
Instrumental top words: [('instrumental', 45), ('soul', 19), ('orchestral', 19), ('rise', 18), ('song', 16)]
Hi

Device set to use cpu


In [None]:
cf_metrics = collaborative_filtering(user_data, df_metadata)
print(f"Collaborative Filtering Metrics: {cf_metrics}")

genres = df_metadata['genre'].unique()
genre_to_idx = {g: i for i, g in enumerate(genres)}
labels = df_metadata['genre'].map(genre_to_idx).values

cls_metrics_all = []
for fold, (train_idx_resampled, test_idx_resampled) in enumerate(kf.split(X_resampled)):
    print(f"\nFold {fold+1}")
    X_train_cls, X_test_cls = X_resampled[train_idx_resampled], X_resampled[test_idx_resampled]
    y_train_cls = labels[original_indices_resampled[train_idx_resampled]]
    y_test_cls = labels[original_indices_resampled[test_idx_resampled]]

    train_dataset_cls = torch.utils.data.TensorDataset(
        torch.tensor(X_train_cls[:, :audio_features.shape[1]], dtype=torch.float32),
        torch.tensor(X_train_cls[:, audio_features.shape[1]:], dtype=torch.float32),
        torch.tensor(y_train_cls, dtype=torch.long)
    )
    train_loader_cls = torch.utils.data.DataLoader(train_dataset_cls, batch_size=BATCH_SIZE, shuffle=True)
    test_dataset_cls = torch.utils.data.TensorDataset(
        torch.tensor(X_test_cls[:, :audio_features.shape[1]], dtype=torch.float32),
        torch.tensor(X_test_cls[:, audio_features.shape[1]:], dtype=torch.float32),
        torch.tensor(y_test_cls, dtype=torch.long)
    )
    test_loader_cls = torch.utils.data.DataLoader(test_dataset_cls, batch_size=BATCH_SIZE)

    classifier = GenreClassifier(audio_dim=audio_features.shape[1], text_dim=384, num_classes=len(genres)).to(DEVICE)
    criterion_cls = nn.CrossEntropyLoss()
    optimizer_cls = torch.optim.Adam(classifier.parameters(), lr=0.001, weight_decay=1e-4)

    print("\nTraining Genre Classifier Model...")
    for epoch in range(NUM_EPOCHS_CLS):
        loss = train_classifier(classifier, train_loader_cls, criterion_cls, optimizer_cls)
        print(f"Classification Epoch {epoch+1}, Loss: {loss:.4f}")

    cls_metrics = evaluate_classifier(classifier, test_loader_cls, genres)
    cls_metrics_all.append(cls_metrics)
    print(f"Classification Metrics: {cls_metrics}")

avg_cls_metrics = {k: np.mean([m[k] for m in cls_metrics_all]) for k in cls_metrics_all[0]}
print(f"Average Classification Metrics: {avg_cls_metrics}")

In [None]:
from transformers import BertModel

def bert_classifier(df_metadata, text_embeddings, audio_features, genres):
    print("\nTraining Audio-Text BERT Classifier...")

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    bert_model = BertModel.from_pretrained('bert-base-uncased')  # Use BertModel here
    model = AudioTextBERTClassifier(bert_model, audio_dim=audio_features.shape[1], num_classes=len(genres)).to(DEVICE)

    texts = [f"{row['title']} {row['genre']}" for _, row in df_metadata.iterrows()]
    inputs = tokenizer(texts, return_tensors='pt', padding=True, truncation=True, max_length=128)
    genres_list = genres.tolist()
    labels = [genres_list.index(row['genre']) for _, row in df_metadata.iterrows()]

    if len(inputs['input_ids']) != len(audio_features) or len(inputs['input_ids']) != len(labels):
        logging.error("Data mismatch: input_ids=%d, audio_features=%d, labels=%d",
                      len(inputs['input_ids']), len(audio_features), len(labels))
        raise ValueError("Data mismatch in bert_classifier inputs")

    dataset_dict = {
        'input_ids': inputs['input_ids'].numpy(),
        'attention_mask': inputs['attention_mask'].numpy(),
        'audio_features': audio_features,
        'labels': np.array(labels, dtype=np.int64)
    }
    dataset = Dataset.from_dict(dataset_dict)
    split_dataset = dataset.train_test_split(test_size=0.2, seed=42)
    dataset = DatasetDict({
        "train": split_dataset["train"],
        "test": split_dataset["test"]
    })

    expected_keys = {'input_ids', 'attention_mask', 'audio_features', 'labels'}
    actual_keys = set(dataset['train'].features.keys())
    if not expected_keys.issubset(actual_keys):
        logging.error("Dataset missing required keys: %s", actual_keys)
        raise ValueError("Dataset missing required keys")

    def collate_fn(batch):
        return {
            'input_ids': torch.tensor([item['input_ids'] for item in batch], dtype=torch.long).to(DEVICE),
            'attention_mask': torch.tensor([item['attention_mask'] for item in batch], dtype=torch.long).to(DEVICE),
            'audio_features': torch.tensor([item['audio_features'] for item in batch], dtype=torch.float32).to(DEVICE),
            'labels': torch.tensor([item['labels'] for item in batch], dtype=torch.long).to(DEVICE)
        }

    training_args = TrainingArguments(
        output_dir=os.path.join(OUTPUT_DIR, "bert-finetuned"),
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,
        num_train_epochs=3,
        eval_strategy="epoch",
        save_strategy="epoch",
        learning_rate=2e-5,
        save_total_limit=2,
        load_best_model_at_end=True
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"],
        data_collator=collate_fn,
        compute_metrics=lambda pred: {
            'precision': precision_recall_fscore_support(pred.label_ids, pred.predictions.argmax(-1), average='weighted', zero_division=0)[0],
            'recall': precision_recall_fscore_support(pred.label_ids, pred.predictions.argmax(-1), average='weighted', zero_division=0)[1],
            'f1': precision_recall_fscore_support(pred.label_ids, pred.predictions.argmax(-1), average='weighted', zero_division=0)[2]
        }
    )

    trainer.train()
    eval_results = trainer.evaluate()
    return {
        'precision': eval_results.get('eval_precision', 0.0),
        'recall': eval_results.get('eval_recall', 0.0),
        'f1': eval_results.get('eval_f1', 0.0)
    }



def roberta_classifier(df_metadata, text_embeddings, genres):
    """Task A1.2: RoBERTa-based genre classification."""
    print("\nTraining RoBERTa Classifier...")
    tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
    model = RobertaForSequenceClassification.from_pretrained('roberta-base', num_labels=len(genres)).to(DEVICE)

    texts = [f"{row['title']} {row['genre']}" for _, row in df_metadata.iterrows()]
    inputs = tokenizer(texts, return_tensors='pt', padding=True, truncation=True, max_length=128).to(DEVICE)
    genres_list = genres.tolist()
    labels = torch.tensor([genres_list.index(row['genre']) for _, row in df_metadata.iterrows()]).to(DEVICE)

    dataset = Dataset.from_dict({
        'input_ids': inputs['input_ids'].cpu().numpy(),
        'attention_mask': inputs['attention_mask'].cpu().numpy(),
        'labels': labels.cpu().numpy()
    })
    dataset = dataset.train_test_split(test_size=0.2, seed=42)

    training_args = TrainingArguments(
        output_dir=os.path.join(OUTPUT_DIR, "roberta-finetuned"),
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,
        num_train_epochs=3,
        eval_strategy="epoch",
        save_strategy="epoch",
        learning_rate=2e-5,
        save_total_limit=2,
        load_best_model_at_end=True
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"],
        compute_metrics=lambda pred: {'precision': precision_recall_fscore_support(pred.label_ids, pred.predictions.argmax(-1), average='weighted', zero_division=0)[0],
                                     'recall': precision_recall_fscore_support(pred.label_ids, pred.predictions.argmax(-1), average='weighted', zero_division=0)[1],
                                     'f1': precision_recall_fscore_support(pred.label_ids, pred.predictions.argmax(-1), average='weighted', zero_division=0)[2]}
    )
    trainer.train()
    eval_results = trainer.evaluate()
    return {'precision': eval_results['eval_precision'], 'recall': eval_results['eval_recall'], 'f1': eval_results['eval_f1']}


In [None]:
bert_metrics = bert_classifier(df_metadata, text_embeddings, audio_features, genres)
print(f"BERT Classifier Metrics: {bert_metrics}")

roberta_metrics = roberta_classifier(df_metadata, text_embeddings, genres)
print(f"RoBERTa Classifier Metrics: {roberta_metrics}")

audio_files = [os.path.join(AUDIO_PATH, f"{tid}.mp3") for tid in df_metadata['track_id']]
ast_metrics = ast_classifier(audio_files, labels, genres, df_metadata)
print(f"AST vs Custom Classifier: AST F1={ast_metrics['f1']:.4f}, Custom F1={avg_cls_metrics['f1']:.4f}")

search_model = SentenceTransformer('all-MiniLM-L6-v2', device=DEVICE)
query = "upbeat rock songs"
_, search_metrics = music_search(query, df_metadata, text_embeddings, df_features, search_model, scaler)
print(f"Search Metrics: {search_metrics}")