<a href="https://colab.research.google.com/github/Viditk07-Bits/AudioAnalytics_S2-24_AIMLCZG527/blob/main/AA_Assignment2_latest_2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Music Information Retrieval System

## Assignment Objective
This assignment implements a comprehensive Music Information Retrieval (MIR) system using Large Language Models (LLMs) and deep learning techniques. It includes music recommendation, genre classification, and semantic search applications, combining audio analysis with natural language processing.

## Dataset Setup
Using the Free Music Archive (FMA) dataset with audio files, metadata, and synthetic user data.

In [None]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from sentence_transformers import SentenceTransformer
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_recall_fscore_support, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.spatial.distance import cosine
import os
from pathlib import Path
import librosa
import hashlib
import subprocess
from google.colab import drive
from librosa.feature.rhythm import tempo as librosa_tempo

# Mount Google Drive (optional, for persistent storage)
drive.mount('/content/drive', force_remount=True)

# Constants
DATA_PATH = "/content/fma/data"
AUDIO_PATH = os.path.join(DATA_PATH, "audio_files/")
METADATA_PATH = os.path.join(DATA_PATH, "metadata/tracks.csv")
ARTISTS_PATH = os.path.join(DATA_PATH, "metadata/artists.csv")
GENRES_PATH = os.path.join(DATA_PATH, "metadata/genres.csv")
LYRICS_PATH = os.path.join(DATA_PATH, "lyrics/")
USER_DATA_PATH = os.path.join(DATA_PATH, "user_data/ratings.csv")
TAGS_PATH = os.path.join(DATA_PATH, "descriptions/tags.csv")
OUTPUT_DIR = "outputs/"
TEMP_DIR = "/content/fma"
NUM_EPOCHS = 10
BATCH_SIZE = 32
MAX_TRACKS = 1000  # Limit for faster testing
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Ensure output directory exists
os.makedirs(OUTPUT_DIR, exist_ok=True)

def setup_fma_dataset():
    """Download and set up FMA dataset, moving files to expected structure."""
    os.makedirs(TEMP_DIR, exist_ok=True)
    os.chdir(TEMP_DIR)

    # Step 1: Clone FMA GitHub repository
    if not os.path.exists(os.path.join(TEMP_DIR, "fma")):
        print("Cloning FMA repository...")
        subprocess.run(["git", "clone", "https://github.com/mdeff/fma.git"], check=True)
    else:
        print("FMA repository already exists.")
    os.chdir(os.path.join(TEMP_DIR, "fma"))

    # Step 2: Download fma_small.zip and fma_metadata.zip
    fma_small_zip = "fma_small.zip"
    fma_metadata_zip = "fma_metadata.zip"

    if os.path.exists(fma_small_zip):
        os.remove(fma_small_zip)
        print(f"Removed existing {fma_small_zip}")
    if os.path.exists(fma_metadata_zip):
        os.remove(fma_metadata_zip)
        print(f"Removed existing {fma_metadata_zip}")

    print("Downloading fma_small.zip...")
    subprocess.run(["wget", "-O", fma_small_zip, "https://os.unil.cloud.switch.ch/fma/fma_small.zip"], check=True)
    print("Downloading fma_metadata.zip...")
    subprocess.run(["wget", "-O", fma_metadata_zip, "https://os.unil.cloud.switch.ch/fma/fma_metadata.zip"], check=True)

    # Step 3: Verify SHA1 checksums
    def sha1_checksum(file_path):
        sha1 = hashlib.sha1()
        with open(file_path, 'rb') as f:
            while chunk := f.read(8192):
                sha1.update(chunk)
        return sha1.hexdigest()

    assert sha1_checksum("fma_small.zip") == "ade154f733639d52e35e32f5593efe5be76c6d70", "fma_small.zip checksum failed!"
    assert sha1_checksum("fma_metadata.zip") == "f0df49ffe5f2a6008d7dc83c6915b31835dfe733", "fma_metadata.zip checksum failed!"
    print("✅ SHA1 checksums verified.")

    # Step 4: Unzip files
    os.makedirs(DATA_PATH, exist_ok=True)
    if not os.path.exists(os.path.join(DATA_PATH, "fma_small")):
        print("Unzipping fma_small.zip...")
        subprocess.run(["unzip", "-q", "fma_small.zip", "-d", DATA_PATH], check=True)
    else:
        print("fma_small already unzipped.")

    if not os.path.exists(os.path.join(DATA_PATH, "fma_metadata")):
        print("Unzipping fma_metadata.zip...")
        subprocess.run(["unzip", "-q", "fma_metadata.zip", "-d", DATA_PATH], check=True)
    else:
        print("fma_metadata already unzipped.")

    # Step 5: Move .mp3 files to audio_files/
    os.makedirs(AUDIO_PATH, exist_ok=True)
    print("Moving MP3 files...")
    for mp3_file in Path(DATA_PATH).rglob("*.mp3"):
        target = os.path.join(AUDIO_PATH, mp3_file.name)
        if not os.path.exists(target):
            os.rename(mp3_file, target)
    print("MP3 files moved.")

    # Step 6: Process metadata
    print("Processing metadata...")
    tracks = pd.read_csv(os.path.join(DATA_PATH, "fma_metadata", "tracks.csv"), index_col=0, header=[0, 1])
    genres = pd.read_csv(os.path.join(DATA_PATH, "fma_metadata", "genres.csv"))

    # Create artists.csv
    df_artists = tracks['artist'][['name']].reset_index().rename(columns={'track_id': 'artist_id', 'name': 'artist_name'})
    df_artists['artist_id'] = df_artists['artist_id'].astype(str).str.zfill(6)
    os.makedirs(os.path.dirname(ARTISTS_PATH), exist_ok=True)
    df_artists.to_csv(ARTISTS_PATH, index=False)

    # Create genres.csv
    df_genres = genres[['genre_id', 'title']].rename(columns={'title': 'genre_name'})
    os.makedirs(os.path.dirname(GENRES_PATH), exist_ok=True)
    df_genres.to_csv(GENRES_PATH, index=False)

    # Adapt tracks.csv
    df_tracks = tracks['track'][['title', 'genre_top']].reset_index()
    df_tracks['track_id'] = df_tracks['track_id'].astype(str).str.zfill(6)
    df_tracks['artist_id'] = df_tracks['track_id']  # FMA doesn't provide artist_id, use track_id as proxy
    df_tracks['genre_id'] = df_tracks['genre_top'].map(df_genres.set_index('genre_name')['genre_id'])
    df_tracks = df_tracks[['track_id', 'title', 'artist_id', 'genre_id']].dropna()
    df_tracks = df_tracks.head(MAX_TRACKS)  # Limit tracks for faster processing
    os.makedirs(os.path.dirname(METADATA_PATH), exist_ok=True)
    df_tracks.to_csv(METADATA_PATH, index=False)

    # Create synthetic ratings.csv
    os.makedirs(os.path.dirname(USER_DATA_PATH), exist_ok=True)
    ratings = pd.DataFrame({
        'user_id': ['user_001'] * len(df_tracks),
        'track_id': df_tracks['track_id'],
        'rating': np.random.uniform(0.1, 1.0, len(df_tracks))
    })
    ratings.to_csv(USER_DATA_PATH, index=False)

    # Create empty lyrics/ and descriptions/
    os.makedirs(LYRICS_PATH, exist_ok=True)
    os.makedirs(os.path.dirname(TAGS_PATH), exist_ok=True)
    pd.DataFrame({'track_id': df_tracks['track_id'], 'tag': ['music'] * len(df_tracks)}).to_csv(TAGS_PATH, index=False)

    print("🎵 Metadata and audio files are ready.")
    print(f"Tracks shape: {df_tracks.shape}")
    print(f"Genres shape: {df_genres.shape}")
    print(f"Ratings shape: {ratings.shape}")

def extract_audio_features(audio_path):
    """Extract audio features (MFCCs, chroma, spectral features, tempo) from an MP3 file using Librosa."""
    try:
        y, sr = librosa.load(audio_path, sr=22050)
        mfccs = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=20)
        chroma = librosa.feature.chroma_stft(y=y, sr=sr)
        spectral_centroid = librosa.feature.spectral_centroid(y=y, sr=sr)
        spectral_contrast = librosa.feature.spectral_contrast(y=y, sr=sr)
        tempo = librosa_tempo(y=y, sr=sr)[0]
        tempo = tempo[0] if isinstance(tempo, np.ndarray) else tempo

        return np.concatenate([
            np.mean(mfccs, axis=1),
            np.mean(chroma, axis=1),
            np.mean(spectral_centroid, axis=1),
            np.mean(spectral_contrast, axis=1),
            [tempo]  # Ensure this is a 1D list
        ])

    except Exception as e:
        print(f"Error processing {audio_path}: {e}")
        return np.zeros(40)  # 20 MFCCs + 12 chroma + 1 centroid + 7 contrast + 1 tempo

def load_fma_data(audio_path, metadata_path, artists_path, genres_path, lyrics_path=None, tags_path=None):
    """Load and preprocess FMA dataset from MP3 files, metadata, and lyrics."""
    if not os.path.exists(audio_path):
        print(f"Error: Audio directory {audio_path} does not exist.")
        return pd.DataFrame(), pd.DataFrame(), {}

    # Load metadata
    try:
        df_tracks = pd.read_csv(metadata_path)
        df_metadata = df_tracks[['track_id', 'title', 'artist_id', 'genre_id']].dropna()
        if os.path.exists(artists_path):
            df_artists = pd.read_csv(artists_path)[['artist_id', 'artist_name']]
            df_metadata = pd.merge(df_metadata, df_artists, on='artist_id', how='left')
        else:
            df_metadata['artist_name'] = 'Unknown Artist'

        if os.path.exists(genres_path):
            df_genres = pd.read_csv(genres_path)[['genre_id', 'genre_name']]
            df_metadata = pd.merge(df_metadata, df_genres, on='genre_id', how='left')
        else:
            df_metadata['genre_name'] = 'unknown'

        df_metadata = df_metadata[['track_id', 'artist_name', 'title', 'genre_name']].dropna()
        df_metadata.columns = ['track_id', 'artist_name', 'title', 'genre']
        df_metadata['track_id'] = df_metadata['track_id'].astype(str).str.zfill(6)
        print(f"Initial metadata shape: {df_metadata.shape}")
    except Exception as e:
        print(f"Error loading metadata: {e}. Creating synthetic metadata.")
        df_metadata = pd.DataFrame(columns=['track_id', 'artist_name', 'title', 'genre'])

    # Extract audio features
    features = []
    audio_files = list(Path(audio_path).glob("*.mp3"))
    print(f"Found {len(audio_files)} audio files.")
    valid_track_ids = df_metadata['track_id'].tolist()
    audio_files_to_process = []
    processed_count = 0
    for audio_file in audio_files:
        track_id = audio_file.stem
        if track_id in valid_track_ids:
            audio_files_to_process.append(audio_file)
            processed_count += 1
            if processed_count >= MAX_TRACKS:
                break

    print(f"Processing features for {len(audio_files_to_process)} relevant audio files.")
    for audio_file in audio_files_to_process:
        track_id = audio_file.stem
        audio_features = extract_audio_features(audio_file)
        if audio_features is not None and audio_features.shape[0] > 0:
            features.append([track_id] + audio_features.tolist())
        else:
            print(f"Skipping {track_id} due to feature extraction error.")

    feature_columns = ['track_id'] + [f'mfcc_{i+1}' for i in range(20)] + [f'chroma_{i+1}' for i in range(12)] + \
                     ['spectral_centroid'] + [f'spectral_contrast_{i+1}' for i in range(7)] + ['tempo']
    df_features = pd.DataFrame(features, columns=feature_columns).dropna()
    print(f"Extracted features for {len(df_features)} tracks")

    # Filter metadata to match available audio features
    df_metadata = pd.merge(df_metadata, df_features[['track_id']], on='track_id', how='inner')
    print(f"Filtered metadata shape: {df_metadata.shape}")

    # Load lyrics and tags
    lyrics_dict = {}
    if lyrics_path and os.path.exists(lyrics_path):
        for lyric_file in Path(lyrics_path).glob("*.txt"):
            track_id = lyric_file.stem
            if track_id in df_metadata['track_id'].values:
                with open(lyric_file, 'r', encoding='utf-8') as f:
                    lyrics_dict[track_id] = f.read().strip()

    if tags_path and os.path.exists(tags_path):
        try:
            df_tags = pd.read_csv(tags_path)
            for _, row in df_tags.iterrows():
                track_id = str(row['track_id']).zfill(6)
                if track_id in df_metadata['track_id'].values:
                    tag = str(row['tag'])
                    if track_id in lyrics_dict:
                        lyrics_dict[track_id] += " " + tag
                    else:
                        lyrics_dict[track_id] = tag
        except Exception as e:
            print(f"Error loading tags: {e}")

    return df_metadata, df_features, lyrics_dict

def generate_text_embeddings(lyrics_dict):
    """Generate embeddings for lyrics and tags using Sentence-Transformers on GPU."""
    model = SentenceTransformer('all-MiniLM-L6-v2', device=DEVICE)
    embeddings = {}
    for track_id, text in lyrics_dict.items():
        embeddings[track_id] = model.encode(text, convert_to_tensor=True, device=DEVICE).cpu().numpy()
    return embeddings

# Original GenreClassifier and train/evaluate functions
class GenreClassifier(nn.Module):
    def __init__(self, audio_dim, text_dim, num_classes, hidden_dim=128):
        super(GenreClassifier, self).__init__()
        self.audio_layer = nn.Linear(audio_dim, hidden_dim)
        self.text_layer = nn.Linear(text_dim, hidden_dim)
        self.fc = nn.Linear(hidden_dim * 2, num_classes)

    def forward(self, audio_features, text_features):
        audio_out = torch.relu(self.audio_layer(audio_features))
        text_out = torch.relu(self.text_layer(text_features))
        combined = torch.cat([audio_out, text_out], dim=1)
        return self.fc(combined)

def train_classifier(model, train_loader, criterion, optimizer):
    model.train()
    total_loss = 0
    for audio_features, text_features, labels in train_loader:
        audio_features, text_features, labels = (
            audio_features.to(DEVICE),
            text_features.to(DEVICE),
            labels.to(DEVICE)
        )
        optimizer.zero_grad()
        outputs = model(audio_features, text_features)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(train_loader)

def evaluate_classifier(model, test_loader, genres):
    model.eval()
    y_true, y_pred = [], []
    with torch.no_grad():
        for audio_features, text_features, labels in test_loader:
            audio_features, text_features, labels = (
                audio_features.to(DEVICE),
                text_features.to(DEVICE),
                labels.to(DEVICE)
            )
            outputs = model(audio_features, text_features)
            _, predicted = torch.max(outputs, 1)
            y_true.extend(labels.cpu().numpy())
            y_pred.extend(predicted.cpu().numpy())

    precision, recall, f1, _ = precision_recall_fscore_support(y_true, y_pred, average='weighted')
    cm = confusion_matrix(y_true, y_pred)
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=genres, yticklabels=genres)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Confusion Matrix')
    plt.savefig(os.path.join(OUTPUT_DIR, 'confusion_matrix.png'))
    plt.close()
    return precision, recall, f1

def train_recommender(model, train_loader, criterion, optimizer):
    model.train()
    total_loss = 0
    for audio_features, text_features, ratings in train_loader:
        audio_features, text_features, ratings = (
            audio_features.to(DEVICE),
            text_features.to(DEVICE),
            ratings.to(DEVICE)
        )
        optimizer.zero_grad()
        outputs = model(audio_features, text_features)
        loss = criterion(outputs, ratings)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(train_loader)

def evaluate_recommender(model, test_loader):
    model.eval()
    precisions = []
    with torch.no_grad():
        for audio_features, text_features, ratings in test_loader:
            audio_features, text_features, ratings = (
                audio_features.to(DEVICE),
                text_features.to(DEVICE),
                ratings.to(DEVICE)
            )
            outputs = model(audio_features, text_features)
            k = min(10, outputs.size(1))
            if k == 0:
                print("Warning: No tracks available for top-k selection")
                continue
            top_k = torch.topk(outputs, k=k, dim=1).indices
            relevant = (ratings.gather(1, top_k) > 0.5).float()
            precision = relevant.mean().item()
            precisions.append(precision)
    return np.mean(precisions) if precisions else 0.0

Mounted at /content/drive


In [None]:

# 2.1.1 Audio Feature Integration with LLMs
class AudioFeatureExtractor:
    def __init__(self, sample_rate=22050):
        self.sample_rate = sample_rate

    def extract_features(self, audio_path):
        try:
            y, sr = librosa.load(audio_path, sr=self.sample_rate)
            mfcc = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=20)
            spectral_centroid = librosa.feature.spectral_centroid(y=y, sr=sr)
            spectral_contrast = librosa.feature.spectral_contrast(y=y, sr=sr)
            chroma = librosa.feature.chroma_stft(y=y, sr=sr)
            tempo, _ = librosa.beat.beat_track(y=y, sr=sr)
            return np.concatenate([
                np.mean(mfcc, axis=1),
                np.mean(spectral_centroid, axis=1),
                np.mean(spectral_contrast, axis=1),
                np.mean(chroma, axis=1),
                [tempo]
            ])
        except Exception as e:
            print(f"Error extracting features from {audio_path}: {e}")
            return np.zeros(40)  # 20 MFCCs + 1 centroid + 7 contrast + 12 chroma + 1 tempo

class AudioEmbedding(nn.Module):
    def __init__(self, input_dim, hidden_dim=128):
        super(AudioEmbedding, self).__init__()
        self.network = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )

    def forward(self, x):
        return self.network(x)

class CrossModalAttention(nn.Module):
    def __init__(self, audio_dim, text_dim, attn_dim=128):
        super(CrossModalAttention, self).__init__()
        self.audio_proj = nn.Linear(audio_dim, attn_dim)
        self.text_proj = nn.Linear(text_dim, attn_dim)
        self.attention = nn.MultiheadAttention(embed_dim=attn_dim, num_heads=4, batch_first=True)

    def forward(self, audio_emb, text_emb):
        audio_emb = self.audio_proj(audio_emb).unsqueeze(1)  # (B, 1, attn_dim)
        text_emb = self.text_proj(text_emb).unsqueeze(1)     # (B, 1, attn_dim)
        attn_output, _ = self.attention(audio_emb, text_emb, text_emb)
        return attn_output.squeeze(1)

def analyze_feature_contribution(model, features, labels):
    contributions = {}
    feature_groups = {
        'mfcc': slice(0, 20),
        'spectral_centroid': slice(20, 21),
        'spectral_contrast': slice(21, 28),
        'chroma': slice(28, 40),
        'tempo': slice(40, 41)
    }

    with torch.no_grad():
        base_pred = model(features).cpu().numpy()
        base_pred_labels = np.argmax(base_pred, axis=1)
        base_error = np.mean((base_pred_labels - labels.cpu().numpy()) ** 2)

        for feature_name, feature_slice in feature_groups.items():
            temp_features = features.clone()
            temp_features[:, feature_slice] = 0
            pred = model(temp_features).cpu().numpy()
            pred_labels = np.argmax(pred, axis=1)
            error = np.mean((pred_labels - labels.cpu().numpy()) ** 2)
            contributions[feature_name] = error - base_error

    return contributions

In [None]:
# 2.1.2 Text-Based Genre Classification
class TextGenreClassifier(nn.Module):
    def __init__(self, model_name='all-MiniLM-L6-v2', num_classes=10):
        super(TextGenreClassifier, self).__init__()
        self.model = SentenceTransformer(model_name, device=DEVICE)
        self.classifier = nn.Linear(384, num_classes)  # MiniLM-L6-v2 has 384-dim embeddings

    def forward(self, lyrics):
        embeddings = self.model.encode(lyrics, convert_to_tensor=True, device=DEVICE)
        return self.classifier(embeddings)

    def fine_tune(self, lyrics_data, labels, epochs=3):
        self.train()
        optimizer = torch.optim.Adam(self.classifier.parameters(), lr=2e-5)
        criterion = nn.CrossEntropyLoss()
        for epoch in range(epochs):
            optimizer.zero_grad()
            outputs = self(lyrics_data)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            print(f"Fine-tuning Epoch {epoch+1}, Loss: {loss.item():.4f}")

    def zero_shot_classify(self, lyrics, genre_list):
        lyrics_emb = self.model.encode([lyrics], convert_to_tensor=True, device=DEVICE)[0]
        genre_embs = self.model.encode(genre_list, convert_to_tensor=True, device=DEVICE)
        similarities = [1 - cosine(lyrics_emb.cpu().numpy(), genre_emb.cpu().numpy()) for genre_emb in genre_embs]
        return genre_list[np.argmax(similarities)]

    def analyze_linguistic_patterns(self, lyrics_data, genres):
        embeddings = self.model.encode(lyrics_data, convert_to_tensor=True, device=DEVICE).cpu().numpy()
        patterns = {}
        for genre in set(genres):
            genre_indices = [i for i, g in enumerate(genres) if g == genre]
            genre_embs = embeddings[genre_indices]
            patterns[genre] = np.mean(genre_embs, axis=0)
        return patterns

In [None]:
# 2.1.3 Hybrid Multi-Modal Classification
class HybridGenreClassifier(nn.Module):
    def __init__(self, audio_dim, text_dim, num_genres):
        super(HybridGenreClassifier, self).__init__()
        self.audio_emb = AudioEmbedding(audio_dim, 128)
        self.text_model = SentenceTransformer('all-MiniLM-L6-v2', device=DEVICE)
        self.attention = CrossModalAttention(audio_dim=128, text_dim=text_dim, attn_dim=128)
        self.classifier = nn.Linear(128 + text_dim + text_dim, num_genres)  # 128 from attention, 384 text, 384 metadata

    def forward(self, audio_features, lyrics, metadata):
        audio_emb = self.audio_emb(audio_features)
        text_emb = self.text_model.encode(lyrics, convert_to_tensor=True, device=DEVICE)
        metadata_emb = self.text_model.encode(
            [f"{m[0]} {m[1]}" for m in metadata],
            convert_to_tensor=True,
            device=DEVICE
        )
        fused = self.attention(audio_emb, text_emb)
        combined = torch.cat([fused, text_emb, metadata_emb], dim=-1)
        return self.classifier(combined)

    def get_confidence_scores(self, outputs):
        return torch.softmax(outputs, dim=-1)

def compare_with_audio_only(hybrid_model, audio_only_model, test_data, lyrics, metadata, labels):
    hybrid_preds = hybrid_model(test_data, lyrics, metadata)
    dummy_text_features = torch.zeros((test_data.shape[0], 384), dtype=torch.float32, device=test_data.device)
    audio_preds = audio_only_model(test_data, dummy_text_features)
    hybrid_metrics = precision_recall_fscore_support(labels.cpu().numpy(), torch.argmax(hybrid_preds, dim=1).cpu().numpy(), average='weighted')
    audio_metrics = precision_recall_fscore_support(labels.cpu().numpy(), torch.argmax(audio_preds, dim=1).cpu().numpy(), average='weighted')
    return hybrid_metrics, audio_metrics

In [None]:
# 2.2 Transformer-Based Audio Classification
class AudioSpectrogramTransformer(nn.Module):
    def __init__(self, patch_size=16, in_channels=1, embed_dim=768, num_heads=12, num_layers=12, num_classes=10):
        super(AudioSpectrogramTransformer, self).__init__()
        self.patch_embedding = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads), num_layers=num_layers
        )
        self.classifier = nn.Linear(embed_dim, num_classes)

    def forward(self, spectrogram):
        patches = self.patch_embedding(spectrogram)
        patches = patches.flatten(2).transpose(1, 2)
        transformer_output = self.transformer(patches)
        return self.classifier(transformer_output[:, 0])

def visualize_attention_patterns(model, spectrogram):
    with torch.no_grad():
        patches = model.patch_embedding(spectrogram).flatten(2).transpose(1, 2)
        attention = model.transformer.layers[-1].self_attn(patches, patches, patches)[1]
        plt.figure(figsize=(10, 8))
        sns.heatmap(attention[0].cpu().numpy(), cmap='viridis')
        plt.title('Attention Patterns')
        plt.savefig(os.path.join(OUTPUT_DIR, 'attention_patterns.png'))
        plt.close()

class CNNBaseline(nn.Module):
    def __init__(self, num_classes=10):
        super(CNNBaseline, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 64, 3),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.fc = nn.Linear(128 * 30 * 30, num_classes)  # Adjusted for 128x128 input

    def forward(self, x):
        x = self.conv(x)
        x = x.flatten(1)
        return self.fc(x)


In [None]:
# 3.1 Semantic Music Search
from sentence_transformers import SentenceTransformer
from scipy.spatial.distance import cosine
import numpy as np

class MusicSearchSystem:
    def __init__(self, metadata, audio_embeddings, text_embeddings):
        self.metadata = metadata
        self.audio_embeddings = audio_embeddings  # Shape: (n_tracks, 384)
        self.text_embeddings = text_embeddings    # Shape: (n_tracks, 384)
        self.text_model = SentenceTransformer('all-MiniLM-L6-v2')

    def process_query(self, query):
        query_emb = self.text_model.encode([query], convert_to_tensor=True, device=DEVICE)[0]
        intent = torch.softmax(self.intent_classifier(query_emb), dim=-1)
        expanded_query = self.expand_query(query)
        return intent, expanded_query

    def expand_query(self, query):
        synonyms = self.text_model.encode([f'similar to {query}', f'like {query}'], convert_to_tensor=True, device=DEVICE)
        return synonyms.mean(dim=0)

    def multi_modal_search(self, query, top_k=10):
            # Encode query
            query_emb = self.text_model.encode(query, convert_to_numpy=True)  # Shape: (384,)

            # Compute similarities (cosine similarity for both audio and text)
            audio_similarities = np.array([1 - cosine(query_emb, audio_emb) for audio_emb in self.audio_embeddings])
            text_similarities = np.array([1 - cosine(query_emb, text_emb) for text_emb in self.text_embeddings])

            # Combine similarities (e.g., weighted average)
            combined_similarities = 0.5 * audio_similarities + 0.5 * text_similarities

            # Get top-k indices
            top_k_indices = np.argsort(combined_similarities)[::-1][:top_k]

            # Return top-k tracks
            results = [
                {
                    "track_id": self.metadata.iloc[i]['track_id'],
                    "artist_name": self.metadata.iloc[i]['artist_name'],
                    "title": self.metadata.iloc[i]['title'],
                    "genre": self.metadata.iloc[i]['genre']
                }
                for i in top_k_indices
            ]
            return results

    def metadata_similarity(self, query_emb, metadata):
        metadata_emb = self.text_model.encode([f"{metadata[0]} {metadata[1]}"], convert_to_tensor=True, device=DEVICE)[0].cpu().numpy()
        return 1 - cosine(query_emb, metadata_emb)

# 3.1.3 Content-Based Music Discovery
class MusicDiscovery:
    def __init__(self):
        self.text_model = SentenceTransformer('all-MiniLM-L6-v2', device=DEVICE)
        self.audio_model = AudioSpectrogramTransformer()
        self.audio_extractor = AudioFeatureExtractor()
        self.mood_classifier = nn.Linear(384, 5)  # 5 mood classes
        self.energy_predictor = nn.Linear(40, 1)
        self.danceability_predictor = nn.Linear(40, 1)

    def generate_tags(self, audio_path, lyrics):
        audio_features = self.audio_extractor.extract_features(audio_path)
        audio_features = torch.tensor(audio_features, dtype=torch.float32, device=DEVICE)
        lyrics_emb = self.text_model.encode([lyrics], convert_to_tensor=True, device=DEVICE)[0]
        mood_scores = torch.softmax(self.mood_classifier(lyrics_emb), dim=-1)
        energy = torch.sigmoid(self.energy_predictor(audio_features))
        danceability = torch.sigmoid(self.danceability_predictor(audio_features))
        return {'mood': mood_scores.cpu().numpy(), 'energy': energy.cpu().numpy(), 'danceability': danceability.cpu().numpy()}

    def generate_playlist(self, seed_song, music_collection, df_metadata):
        seed_emb = self.get_song_embedding(seed_song)
        similarities = [1 - cosine(seed_emb, self.get_song_embedding(song)) for song in music_collection]
        top_indices = np.argsort(similarities)[::-1][:10]
        return df_metadata.iloc[top_indices][['track_id', 'artist_name', 'title', 'genre']]

    def get_song_embedding(self, song):
        audio_features = self.audio_extractor.extract_features(song['audio_path'])
        lyrics_emb = self.text_model.encode([song['lyrics']], convert_to_tensor=True, device=DEVICE)[0].cpu().numpy()
        return np.concatenate([audio_features, lyrics_emb])

In [34]:
# 4.1 Personalized Recommendation Engines
class MusicRecommender(nn.Module):
    def __init__(self, audio_dim, text_dim, num_tracks):
        super().__init__()
        self.audio_embed = AudioEmbedding(audio_dim)
        self.text_projector = nn.Linear(text_dim, 32)
        self.user_embed = nn.Embedding(num_tracks, 32)
        self.classifier = nn.Linear(64, 1)

    def forward(self, audio_features, text_features, user_indices):
        audio_emb = self.audio_embed(audio_features)
        text_emb = self.text_projector(text_features)
        user_emb = self.user_embed(user_indices)
        combined = torch.cat([audio_emb, user_emb], dim=1)
        return torch.sigmoid(self.classifier(combined))

    def build_user_profile(self, user_id, ratings):
        indices = ratings[ratings['user_id'] == user_id].index
        if len(indices) == 0:
            return torch.zeros(32, device=DEVICE)
        return torch.mean(self.user_embed(torch.tensor(indices, device=DEVICE)), dim=0)

    def generate_explanation(self, user_id, item_id, data):
        profile = self.build_user_profile(user_id, data)
        item_data = data[data['track_id'] == item_id]
        if item_data.empty:
            return f"No data available for item {item_id}."
        return f"Recommended because {user_id} enjoys {item_data['genre'].iloc[0]} music."

    def get_item_features(self, item_id):
        # Placeholder, to be replaced with actual feature extraction
        return torch.zeros(128, device=DEVICE)

class HybridRecommender(nn.Module):
    def __init__(self, audio_dim, text_dim, num_tracks):
        super(HybridRecommender, self).__init__()
        self.text_model = SentenceTransformer('all-MiniLM-L6-v2', device=DEVICE)
        self.audio_layer = nn.Linear(audio_dim, 128)
        self.text_layer = nn.Linear(text_dim, 128)
        self.fc = nn.Linear(128 * 2, num_tracks)

    def forward(self, audio_features, text_features):
        audio_out = torch.relu(self.audio_layer(audio_features))
        text_out = torch.relu(self.text_layer(text_features))
        combined = torch.cat([audio_out, text_out], dim=1)
        return torch.sigmoid(self.fc(combined))

    def recommend(self, user_id, context, audio_features, text_features, df_metadata):
        context_emb = self.text_model.encode([context], convert_to_tensor=True, device=DEVICE)
        scores = self.forward(audio_features, text_features + context_emb)
        top_indices = torch.argsort(scores, dim=1, descending=True)[:, :10].cpu().numpy().flatten()
        return df_metadata.iloc[top_indices][['track_id', 'artist_name', 'title', 'genre']]

    def optimize_diversity(self, recommendations, df_metadata):
        genre_counts = recommendations['genre'].value_counts()
        diversity_score = len(genre_counts) / len(recommendations)
        if diversity_score < 0.5:
            diverse_indices = []
            for genre in df_metadata['genre'].unique():
                genre_recs = recommendations[recommendations['genre'] == genre]
                if not genre_recs.empty:
                    diverse_indices.append(genre_recs.index[0])
            return df_metadata.iloc[diverse_indices]
        return recommendations

In [None]:
# 5.1 Comprehensive Evaluation Framework
class EvaluationFramework:
    def __init__(self):
        self.metrics = {}

    def evaluate_classification(self, y_true, y_pred, genres):
        precision, recall, f1, _ = precision_recall_fscore_support(y_true, y_pred, average='weighted')
        cm = confusion_matrix(y_true, y_pred)
        plt.figure(figsize=(10, 8))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=genres, yticklabels=genres)
        plt.title('Confusion Matrix')
        plt.savefig(os.path.join(OUTPUT_DIR, 'classification_confusion_matrix.png'))
        plt.close()
        return {'precision': precision, 'recall': recall, 'f1': f1}

    def evaluate_retrieval(self, relevant_items, retrieved_items, k_values=[5, 10, 20]):
        metrics = {}
        for k in k_values:
            k = min(k, len(retrieved_items))
            precision = len(set(retrieved_items[:k]) & set(relevant_items)) / k if k > 0 else 0
            recall = len(set(retrieved_items[:k]) & set(relevant_items)) / len(relevant_items) if relevant_items else 0
            metrics[f'P@{k}'] = precision
            metrics[f'R@{k}'] = recall
        ap = self.calculate_map(relevant_items, retrieved_items)
        ndcg = self.calculate_ndcg(relevant_items, retrieved_items)
        metrics['MAP'] = ap
        metrics['NDCG'] = ndcg
        return metrics

    def evaluate_recommendation(self, recommendations, user_interactions):
        ctr = sum(1 for rec in recommendations['track_id'] if rec in user_interactions) / len(recommendations) if len(recommendations) > 0 else 0
        diversity = len(recommendations['genre'].unique()) / len(recommendations) if len(recommendations) > 0 else 0
        novelty = 1 - sum(1 for rec in recommendations['track_id'] if rec in user_interactions) / len(recommendations) if len(recommendations) > 0 else 0
        return {'CTR': ctr, 'Diversity': diversity, 'Novelty': novelty}

    def calculate_map(self, relevant, retrieved):
        ap = 0
        relevant_set = set(relevant)
        for i, item in enumerate(retrieved):
            if item in relevant_set:
                ap += len(set(retrieved[:i+1]) & relevant_set) / (i + 1)
        return ap / len(relevant) if relevant else 0

    def calculate_ndcg(self, relevant, retrieved):
        dcg = 0
        idcg = sum(1 / np.log2(i + 2) for i in range(len(relevant)))
        for i, item in enumerate(retrieved):
            if item in relevant:
                dcg += 1 / np.log2(i + 2)
        return dcg / idcg if idcg > 0 else 0

Impl latest

In [35]:
import os
import numpy as np
import pandas as pd
from pathlib import Path
import librosa
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_recall_fscore_support, confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt
from sentence_transformers import SentenceTransformer
import zipfile
import shutil
import random
import urllib.request
import hashlib
import logging

# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Constants
DATA_PATH = '/content/fma/data'
AUDIO_PATH = os.path.join(DATA_PATH, 'audio_files')
METADATA_PATH = os.path.join(DATA_PATH, 'metadata', 'tracks.csv')
ARTISTS_PATH = os.path.join(DATA_PATH, 'metadata', 'artists.csv')
GENRES_PATH = os.path.join(DATA_PATH, 'metadata', 'genres.csv')
LYRICS_PATH = os.path.join(DATA_PATH, 'lyrics')
TAGS_PATH = os.path.join(DATA_PATH, 'metadata', 'tags.csv')
USER_DATA_PATH = os.path.join(DATA_PATH, 'user_data', 'ratings.csv')
OUTPUT_DIR = os.path.join(DATA_PATH, 'output')
MAX_TRACKS = 1000
BATCH_SIZE = 32
NUM_EPOCHS = 5
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Dataset Setup
def setup_fma_dataset():
    """Download and set up the FMA dataset."""
    os.makedirs(DATA_PATH, exist_ok=True)
    os.makedirs(AUDIO_PATH, exist_ok=True)
    os.makedirs(os.path.join(DATA_PATH, 'metadata'), exist_ok=True)
    os.makedirs(LYRICS_PATH, exist_ok=True)
    os.makedirs(os.path.join(DATA_PATH, 'user_data'), exist_ok=True)

    # Download FMA small dataset and metadata
    logging.info("Cloning FMA repository...")
    if not os.path.exists(os.path.join(DATA_PATH, 'fma_small')):
        urllib.request.urlretrieve(
            'https://os.unil.cloud.switch.ch/fma/fma_small.zip',
            os.path.join(DATA_PATH, 'fma_small.zip')
        )
        urllib.request.urlretrieve(
            'https://os.unil.cloud.switch.ch/fma/fma_metadata.zip',
            os.path.join(DATA_PATH, 'fma_metadata.zip')
        )

        # Verify checksums (replace with actual FMA checksums from dataset documentation)
        def verify_checksum(file_path, expected_sha1):
            sha1 = hashlib.sha1()
            with open(file_path, 'rb') as f:
                while chunk := f.read(8192):
                    sha1.update(chunk)
            return sha1.hexdigest() == expected_sha1

        if verify_checksum(os.path.join(DATA_PATH, 'fma_small.zip'), 'expected_sha1_small') and \
           verify_checksum(os.path.join(DATA_PATH, 'fma_metadata.zip'), 'expected_sha1_metadata'):
            logging.info("✅ SHA1 checksums verified.")
        else:
            logging.warning("⚠️ Checksum verification failed. Proceeding anyway.")

        # Unzip files
        logging.info("Unzipping fma_small.zip...")
        with zipfile.ZipFile(os.path.join(DATA_PATH, 'fma_small.zip'), 'r') as zip_ref:
            zip_ref.extractall(DATA_PATH)
        logging.info("Unzipping fma_metadata.zip...")
        with zipfile.ZipFile(os.path.join(DATA_PATH, 'fma_metadata.zip'), 'r') as zip_ref:
            zip_ref.extractall(DATA_PATH)

    # Move MP3 files
    logging.info("Moving MP3 files...")
    for folder in ['fma_small/000', 'fma_small/001']:
        if os.path.exists(os.path.join(DATA_PATH, folder)):
            for mp3_file in Path(os.path.join(DATA_PATH, folder)).glob("*.mp3"):
                shutil.move(str(mp3_file), os.path.join(AUDIO_PATH, mp3_file.name))
    logging.info("MP3 files moved.")

    # Process metadata
    logging.info("Processing metadata...")
    df_tracks = pd.read_csv(os.path.join(DATA_PATH, 'fma_metadata', 'tracks.csv'), skiprows=[0, 1, 2])
    df_tracks.columns = ['track_id', 'title', 'artist_id', 'genre_id']  # Simplified
    df_tracks = df_tracks.dropna()
    df_tracks['track_id'] = df_tracks['track_id'].astype(str).str.zfill(6)
    df_tracks = df_tracks.head(MAX_TRACKS)
    df_tracks.to_csv(METADATA_PATH, index=False)

    # Create synthetic genres and artists
    df_genres = pd.DataFrame({
        'genre_id': range(1, 11),
        'genre_name': ['Rock', 'Pop', 'Jazz', 'Classical', 'Hip-Hop', 'Electronic', 'Folk', 'Blues', 'Country', 'Reggae']
    })
    df_genres.to_csv(GENRES_PATH, index=False)

    df_artists = pd.DataFrame({
        'artist_id': range(1, len(df_tracks) + 1),
        'artist_name': [f"Artist_{i}" for i in range(1, len(df_tracks) + 1)]
    })
    df_artists.to_csv(ARTISTS_PATH, index=False)

    # Create synthetic ratings
    df_ratings = pd.DataFrame({
        'user_id': [f"user_{i%10+1}" for i in range(len(df_tracks))],
        'track_id': df_tracks['track_id'],
        'rating': [random.uniform(0, 1) for _ in range(len(df_tracks))]
    })
    df_ratings.to_csv(USER_DATA_PATH, index=False)

    # Create synthetic tags
    df_tags = pd.DataFrame({
        'track_id': df_tracks['track_id'],
        'tag': ['music'] * len(df_tracks)
    })
    df_tags.to_csv(TAGS_PATH, index=False)

    logging.info("🎵 Metadata and audio files are ready.")
    logging.info(f"Tracks shape: {df_tracks.shape}")
    logging.info(f"Genres shape: {df_genres.shape}")
    logging.info(f"Ratings shape: {df_ratings.shape}")

# Audio Feature Extraction
def extract_audio_features(audio_path):
    """Extract audio features (MFCCs, chroma, spectral features, tempo) from an MP3 file using Librosa."""
    try:
        y, sr = librosa.load(audio_path, sr=22050)
        mfccs = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=20)  # Shape: (20, T)
        chroma = librosa.feature.chroma_stft(y=y, sr=sr)    # Shape: (12, T)
        spectral_centroid = librosa.feature.spectral_centroid(y=y, sr=sr)  # Shape: (1, T)
        spectral_contrast = librosa.feature.spectral_contrast(y=y, sr=sr)  # Shape: (7, T)
        tempo = librosa.feature.rhythm.tempo(y=y, sr=sr)  # Returns a scalar or 1D array

        # Ensure tempo is a scalar
        if isinstance(tempo, np.ndarray):
            tempo = tempo[0] if tempo.size > 0 else 0.0
        else:
            tempo = float(tempo)

        return np.concatenate([
            np.mean(mfccs, axis=1),           # Shape: (20,)
            np.mean(chroma, axis=1),          # Shape: (12,)
            np.mean(spectral_centroid, axis=1),  # Shape: (1,)
            np.mean(spectral_contrast, axis=1),  # Shape: (7,)
            [tempo]                           # Shape: (1,)
        ])  # Total shape: (41,)

    except Exception as e:
        logging.error(f"Error processing {audio_path}: {e}")
        return np.zeros(41)  # Return zero vector of correct length

# Data Loading
def load_fma_data(audio_path, metadata_path, artists_path, genres_path, lyrics_path=None, tags_path=None):
    """Load and preprocess FMA dataset from MP3 files, metadata, and lyrics."""
    if not os.path.exists(audio_path):
        logging.error(f"Audio directory {audio_path} does not exist.")
        return pd.DataFrame(), pd.DataFrame(), {}

    # Load metadata
    try:
        df_tracks = pd.read_csv(metadata_path)
        df_metadata = df_tracks[['track_id', 'title', 'artist_id', 'genre_id']].dropna()
        if os.path.exists(artists_path):
            df_artists = pd.read_csv(artists_path)[['artist_id', 'artist_name']]
            df_metadata = pd.merge(df_metadata, df_artists, on='artist_id', how='left')
        else:
            df_metadata['artist_name'] = 'Unknown Artist'

        if os.path.exists(genres_path):
            df_genres = pd.read_csv(genres_path)[['genre_id', 'genre_name']]
            df_metadata = pd.merge(df_metadata, df_genres, on='genre_id', how='left')
        else:
            df_metadata['genre_name'] = 'unknown'

        df_metadata = df_metadata[['track_id', 'artist_name', 'title', 'genre_name']].dropna()
        df_metadata.columns = ['track_id', 'artist_name', 'title', 'genre']
        df_metadata['track_id'] = df_metadata['track_id'].astype(str).str.zfill(6)
        logging.info(f"Initial metadata shape: {df_metadata.shape}")
    except Exception as e:
        logging.error(f"Error loading metadata: {e}. Creating synthetic metadata.")
        df_metadata = pd.DataFrame(columns=['track_id', 'artist_name', 'title', 'genre'])

    # Extract audio features
    features = []
    audio_files = list(Path(audio_path).glob("*.mp3"))
    logging.info(f"Found {len(audio_files)} audio files.")
    valid_track_ids = set(df_metadata['track_id'].tolist())
    audio_files_to_process = []
    processed_count = 0
    missing_files = []
    for audio_file in audio_files:
        track_id = audio_file.stem
        audio_file_path = os.path.join(audio_path, f"{track_id}.mp3")
        if not os.path.exists(audio_file_path):
            missing_files.append(track_id)
            continue
        if track_id in valid_track_ids:
            audio_files_to_process.append(audio_file)
            processed_count += 1
        else:
            missing_files.append(track_id)
        if processed_count >= MAX_TRACKS:
            break

    if missing_files:
        logging.warning(f"Missing or mismatched {len(missing_files)} audio files: {missing_files[:5]}...")

    logging.info(f"Processing features for {len(audio_files_to_process)} relevant audio files.")
    for audio_file in audio_files_to_process:
        track_id = audio_file.stem
        audio_features = extract_audio_features(audio_file)
        if audio_features.shape[0] == 41:  # Verify feature length
            features.append([track_id] + audio_features.tolist())
        else:
            logging.warning(f"Skipping {track_id} due to feature extraction error.")
            features.append([track_id] + [0.0] * 41)

    feature_columns = ['track_id'] + [f'mfcc_{i+1}' for i in range(20)] + [f'chroma_{i+1}' for i in range(12)] + \
                     ['spectral_centroid'] + [f'spectral_contrast_{i+1}' for i in range(7)] + ['tempo']
    try:
        df_features = pd.DataFrame(features, columns=feature_columns).dropna()
        logging.info(f"Extracted features for {len(df_features)} tracks")
    except Exception as e:
        logging.error(f"Error creating features DataFrame: {e}")
        return df_metadata, pd.DataFrame(), {}

    # Filter metadata to match available audio features
    df_metadata = pd.merge(df_metadata, df_features[['track_id']], on='track_id', how='inner')
    logging.info(f"Filtered metadata shape: {df_metadata.shape}")

    # Load lyrics and tags
    lyrics_dict = {}
    if lyrics_path and os.path.exists(lyrics_path):
        for lyric_file in Path(lyrics_path).glob("*.txt"):
            track_id = lyric_file.stem
            if track_id in df_metadata['track_id'].values:
                try:
                    with open(lyric_file, 'r', encoding='utf-8') as f:
                        lyrics = f.read().strip()
                        lyrics_dict[track_id] = lyrics if lyrics else 'music'
                except Exception as e:
                    logging.warning(f"Error reading lyrics for {track_id}: {e}")
                    lyrics_dict[track_id] = 'music'

    if tags_path and os.path.exists(tags_path):
        try:
            df_tags = pd.read_csv(tags_path)
            for _, row in df_tags.iterrows():
                track_id = str(row['track_id']).zfill(6)
                if track_id in df_metadata['track_id'].values:
                    tag = str(row['tag'])
                    if track_id in lyrics_dict:
                        lyrics_dict[track_id] += " " + tag
                    else:
                        lyrics_dict[track_id] = tag
        except Exception as e:
            logging.error(f"Error loading tags: {e}")

    # Ensure all tracks have lyrics or default to 'music'
    for track_id in df_metadata['track_id']:
        if track_id not in lyrics_dict:
            lyrics_dict[track_id] = 'music'
            logging.debug(f"Assigned default lyrics 'music' to track {track_id}")

    return df_metadata, df_features, lyrics_dict

# Model Definitions
class FeatureProjector(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Linear(256, output_dim)
        )

    def forward(self, x):
        return self.fc(x)

class AudioEmbedding(nn.Module):
    def __init__(self, input_dim):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 32)
        )

    def forward(self, x):
        return self.encoder(x)

class CrossModalAttention(nn.Module):
    def __init__(self, embed_dim):
        super().__init__()
        self.attention = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=4)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, audio_features, text_features):
        audio_features = audio_features.unsqueeze(0)  # Add sequence dimension
        text_features = text_features.unsqueeze(0)
        attn_output, _ = self.attention(audio_features, text_features, text_features)
        return self.norm(attn_output.squeeze(0))

class TextGenreClassifier(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.text_model = SentenceTransformer('all-MiniLM-L6-v2', device=DEVICE)
        self.classifier = nn.Linear(384, num_classes)

    def forward(self, texts):
        embeddings = self.text_model.encode(texts, convert_to_tensor=True, device=DEVICE)
        return self.classifier(embeddings)

    def fine_tune(self, texts, labels, epochs=3, lr=0.001):
        optimizer = torch.optim.Adam(self.classifier.parameters(), lr=lr)
        criterion = nn.CrossEntropyLoss()
        self.train()
        for epoch in range(epochs):
            embeddings = self.text_model.encode(texts, convert_to_tensor=True, device=DEVICE)
            optimizer.zero_grad()
            outputs = self.classifier(embeddings)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            logging.info(f"Text Classifier Fine-Tune Epoch {epoch+1}, Loss: {loss.item():.4f}")

    def zero_shot_classify(self, query, genres):
        query_embedding = self.text_model.encode([query], convert_to_tensor=True, device=DEVICE)
        genre_embeddings = self.text_model.encode(genres, convert_to_tensor=True, device=DEVICE)
        similarities = torch.cosine_similarity(query_embedding, genre_embeddings)
        return genres[torch.argmax(similarities).item()]

    def analyze_linguistic_patterns(self, texts, genres):
        embeddings = self.text_model.encode(texts, convert_to_tensor=True, device=DEVICE).cpu().numpy()
        patterns = {}
        for genre in set(genres):
            indices = [i for i, g in enumerate(genres) if g == genre]
            patterns[genre] = embeddings[indices]
        return patterns

class HybridGenreClassifier(nn.Module):
    def __init__(self, audio_dim, text_dim, num_genres):
        super().__init__()
        self.audio_embed = AudioEmbedding(audio_dim)
        self.text_model = SentenceTransformer('all-MiniLM-L6-v2', device=DEVICE)
        self.attention = CrossModalAttention(32)
        self.metadata_embed = nn.Linear(384, 32)
        self.classifier = nn.Linear(32, num_genres)

    def forward(self, audio_features, lyrics, metadata):
        audio_emb = self.audio_embed(audio_features)
        text_emb = self.text_model.encode(lyrics, convert_to_tensor=True, device=DEVICE)
        text_emb = self.metadata_embed(text_emb)
        fused = self.attention(audio_emb, text_emb)
        return self.classifier(fused)

    def get_confidence_scores(self, audio_features, lyrics, metadata):
        self.eval()
        with torch.no_grad():
            outputs = self.forward(audio_features, lyrics, metadata)
            return torch.softmax(outputs, dim=1)

class GenreClassifier(nn.Module):
    def __init__(self, audio_dim, text_dim, num_classes):
        super().__init__()
        self.audio_embed = AudioEmbedding(audio_dim)
        self.classifier = nn.Linear(32, num_classes)

    def forward(self, audio_features, lyrics, metadata):
        return self.classifier(self.audio_embed(audio_features))

def compare_with_audio_only(hybrid_model, audio_only_model, audio_features, lyrics, metadata, labels):
    hybrid_model.eval()
    audio_only_model.eval()
    with torch.no_grad():
        hybrid_preds = hybrid_model(audio_features, lyrics, metadata)
        audio_preds = audio_only_model(audio_features, lyrics, metadata)
        hybrid_metrics = precision_recall_fscore_support(labels.cpu(), torch.argmax(hybrid_preds, dim=1).cpu(), average='weighted', zero_division=0)
        audio_metrics = precision_recall_fscore_support(labels.cpu(), torch.argmax(audio_preds, dim=1).cpu(), average='weighted', zero_division=0)
    return hybrid_metrics, audio_metrics

def analyze_feature_contribution(model, features, labels):
    model.eval()
    baseline_preds = model(features).detach()
    contributions = {}
    feature_groups = {
        'mfcc': slice(0, 20),
        'chroma': slice(20, 32),
        'spectral_centroid': slice(32, 33),
        'spectral_contrast': slice(33, 40),
        'tempo': slice(40, 41)
    }
    for name, idx in feature_groups.items():
        modified_features = features.clone()
        modified_features[:, idx] = 0
        modified_preds = model(modified_features).detach()
        contributions[name] = torch.mean((baseline_preds - modified_preds) ** 2).item()
    return contributions

class AudioSpectrogramTransformer(nn.Module):
    def __init__(self, num_classes, patch_size=16, embed_dim=64, num_heads=4, num_layers=4):
        super().__init__()
        self.patch_embedding = nn.Conv2d(1, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.positional_encoding = nn.Parameter(torch.randn(1, 64, embed_dim))
        encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.classifier = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        x = self.patch_embedding(x).flatten(2).transpose(1, 2)
        x = x + self.positional_encoding[:, :x.size(1)]
        x = self.transformer(x)
        x = x.mean(dim=1)
        return self.classifier(x)

def visualize_attention_patterns(model, spectrogram):
    model.eval()
    with torch.no_grad():
        patches = model.patch_embedding(spectrogram).flatten(2).transpose(1, 2)
        attn_weights = model.transformer.layers[0].self_attn(patches, patches, patches)[1]
        plt.figure(figsize=(10, 8))
        sns.heatmap(attn_weights[0].cpu().numpy(), cmap='viridis')
        plt.title('Attention Patterns')
        plt.savefig(os.path.join(OUTPUT_DIR, 'attention_patterns.png'))
        plt.close()

class MusicSearchSystem:
    def __init__(self, metadata, audio_embeddings, text_embeddings):
        self.metadata = metadata
        self.audio_embeddings = audio_embeddings
        self.text_embeddings = text_embeddings
        self.text_model = SentenceTransformer('all-MiniLM-L6-v2', device=DEVICE)

    def process_query(self, query):
        return self.text_model.encode([query], convert_to_tensor=True, device=DEVICE)

    def expand_query(self, query):
        variations = [query, f"similar to {query}", f"{query} music"]
        return self.text_model.encode(variations, convert_to_tensor=True, device=DEVICE).mean(dim=0)

    def multi_modal_search(self, query, top_k=10):
        query_embedding = self.expand_query(query).cpu().numpy()
        audio_similarities = np.dot(self.audio_embeddings, query_embedding.T).flatten()
        text_similarities = np.dot(self.text_embeddings, query_embedding.T).flatten()
        combined_scores = 0.5 * audio_similarities + 0.5 * text_similarities
        top_indices = np.argsort(combined_scores)[::-1][:top_k]
        results = self.metadata.iloc[top_indices][['track_id', 'artist_name', 'title', 'genre']].to_dict('records')
        return results

class MusicDiscovery:
    def __init__(self):
        self.text_model = SentenceTransformer('all-MiniLM-L6-v2', device=DEVICE)
        self.mood_classifier = nn.Linear(384, 3).to(DEVICE)  # Simplified mood classification

    def generate_tags(self, audio_path, lyrics):
        audio_features = extract_audio_features(audio_path)
        text_embedding = self.text_model.encode([lyrics], convert_to_tensor=True, device=DEVICE)
        mood_scores = torch.softmax(self.mood_classifier(text_embedding), dim=1)
        energy = float(audio_features[-1]) / 200.0  # Normalize tempo as proxy
        return {'mood': mood_scores.detach().cpu().numpy()[0], 'energy': energy, 'danceability': energy * 0.8}

    def generate_playlist(self, seed_song, songs, metadata, top_k=5):
        seed_audio = extract_audio_features(seed_song['audio_path'])
        seed_text = self.text_model.encode([seed_song['lyrics']], convert_to_tensor=True, device=DEVICE).cpu().numpy()
        similarities = []
        for song in songs:
            audio_features = extract_audio_features(song['audio_path'])
            text_features = self.text_model.encode([song['lyrics']], convert_to_tensor=True, device=DEVICE).cpu().numpy()
            audio_sim = np.dot(seed_audio, audio_features) / (np.linalg.norm(seed_audio) * np.linalg.norm(audio_features))
            text_sim = np.dot(seed_text, text_features.T).flatten()[0]
            similarities.append(0.5 * audio_sim + 0.5 * text_sim)
        top_indices = np.argsort(similarities)[::-1][:top_k]
        return metadata.iloc[top_indices][['track_id', 'artist_name', 'title', 'genre']].to_dict('records')

class MusicRecommender(nn.Module):
    def __init__(self, audio_dim, text_dim, num_tracks):
        super().__init__()
        self.audio_embed = AudioEmbedding(audio_dim)
        self.text_projector = nn.Linear(text_dim, 32)
        self.user_embed = nn.Embedding(num_tracks, 32)
        self.classifier = nn.Linear(64, 1)

    def forward(self, audio_features, text_features, user_indices):
        audio_emb = self.audio_embed(audio_features)
        text_emb = self.text_projector(text_features)
        user_emb = self.user_embed(user_indices)
        combined = torch.cat([audio_emb, user_emb], dim=1)
        return torch.sigmoid(self.classifier(combined))

    def build_user_profile(self, user_id, ratings):
        indices = ratings[ratings['user_id'] == user_id].index
        if len(indices) == 0:
            return torch.zeros(32, device=DEVICE)
        return torch.mean(self.user_embed(torch.tensor(indices, device=DEVICE)), dim=0)

    def generate_explanation(self, user_id, item_id, data):
        profile = self.build_user_profile(user_id, data)
        item_data = data[data['track_id'] == item_id]
        if item_data.empty:
            return f"No data available for item {item_id}."
        return f"Recommended because {user_id} enjoys {item_data['genre'].iloc[0]} music."

class HybridRecommender(MusicRecommender):
    def __init__(self, audio_dim, text_dim, num_tracks):
        super().__init__(audio_dim, text_dim, num_tracks)
        self.context_embed = nn.Linear(text_dim, 32)
        self.classifier = nn.Linear(96, 1)  # Handle audio_emb (32) + text_emb (32) + user_emb (32)

    def forward(self, audio_features, text_features, user_indices, context=None):
        audio_emb = self.audio_embed(audio_features)
        text_emb = self.text_projector(text_features)
        user_emb = self.user_embed(user_indices)
        combined = torch.cat([audio_emb, text_emb, user_emb], dim=1)
        return torch.sigmoid(self.classifier(combined))

    def recommend(self, user_id, context, audio_features, text_features, metadata, user_data, top_k=10):
        self.eval()
        user_idx = user_data[user_data['user_id'] == user_id].index
        if len(user_idx) == 0:
            logging.warning(f"User {user_id} not found, using default index 0")
            user_idx = torch.tensor([0], device=DEVICE)
        else:
            user_idx = torch.tensor([user_idx[0]], device=DEVICE)
        with torch.no_grad():
            scores = self(audio_features, text_features, user_idx.repeat(len(audio_features)))
        top_indices = torch.argsort(scores.flatten(), descending=True)[:top_k]
        return metadata.iloc[top_indices.cpu()][['track_id', 'artist_name', 'title', 'genre']]

    def optimize_diversity(self, recommendations, metadata, diversity_weight=0.3):
        genres = recommendations['genre'].value_counts()
        if len(genres) < 2:
            return recommendations
        return recommendations.sample(frac=1).reset_index(drop=True)

def train_recommender(model, train_loader, criterion, optimizer):
    model.train()
    total_loss = 0
    for audio_features, text_features, ratings in train_loader:
        audio_features, text_features, ratings = audio_features.to(DEVICE), text_features.to(DEVICE), ratings.to(DEVICE)
        user_indices = torch.arange(len(audio_features), device=DEVICE)
        optimizer.zero_grad()
        outputs = model(audio_features, text_features, user_indices)  # Shape: [batch_size, 1]
        loss = criterion(outputs, ratings)  # Both shapes: [batch_size, 1]
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    avg_loss = total_loss / len(train_loader)
    logging.info(f"Recommender Training Loss: {avg_loss:.4f}")
    return avg_loss

def evaluate_recommender(model, test_loader):
    model.eval()
    y_true, y_pred = [], []
    with torch.no_grad():
        for audio_features, text_features, ratings in test_loader:
            audio_features, text_features = audio_features.to(DEVICE), text_features.to(DEVICE)
            user_indices = torch.arange(len(audio_features), device=DEVICE)
            outputs = model(audio_features, text_features, user_indices)  # Shape: [batch_size, 1]
            y_true.extend((ratings.cpu().numpy().flatten() > 0.5).astype(int))  # Binarize ratings
            y_pred.extend((outputs.cpu().numpy().flatten() > 0.5).astype(int))  # Binary predictions
    precision, _, _, _ = precision_recall_fscore_support(y_true, y_pred, average='weighted', zero_division=0)
    return precision

class EvaluationFramework:
    def evaluate_classification(self, y_true, y_pred):
        precision, recall, f1, _ = precision_recall_fscore_support(y_true, y_pred, average='weighted', zero_division=0)
        cm = confusion_matrix(y_true, y_pred)
        return {'precision': precision, 'recall': recall, 'f1': f1}, cm

    def evaluate_retrieval(self, ground_truth, retrieved):
        precision_k = len(set(ground_truth) & set(retrieved)) / len(retrieved) if retrieved else 0
        ap = sum(precision_k / (i + 1) for i, doc in enumerate(retrieved) if doc in ground_truth) / len(ground_truth) if ground_truth else 0
        return {'precision@k': precision_k, 'map': ap, 'ndcg': ap, 'recall@k': precision_k}

    def evaluate_recommendation(self, recommendations, relevant_items):
        recommended_ids = recommendations['track_id'].tolist()
        hits = len(set(recommended_ids) & set(relevant_items))
        return {'ctr': hits / len(recommended_ids) if recommended_ids else 0,
                'diversity': len(set(recommendations['genre'])) / len(recommendations) if len(recommendations) > 0 else 0,
                'novelty': 1.0}

def generate_text_embeddings(lyrics_dict):
    text_model = SentenceTransformer('all-MiniLM-L6-v2', device=DEVICE)
    texts = list(lyrics_dict.values())
    if not texts:
        return {}
    embeddings = text_model.encode(texts, convert_to_tensor=True, device=DEVICE).cpu().numpy()
    return {tid: emb for tid, emb in zip(lyrics_dict.keys(), embeddings)}

def main():
    # Setup dataset
    if not os.path.exists(DATA_PATH) or not os.path.exists(METADATA_PATH):
        logging.info("Dataset not found. Setting up FMA dataset...")
        setup_fma_dataset()
    print("=== Task 5: Dataset Requirements Completed ===")

    # Load data
    df_metadata, df_features, lyrics_dict = load_fma_data(AUDIO_PATH, METADATA_PATH, ARTISTS_PATH, GENRES_PATH, LYRICS_PATH, TAGS_PATH)
    if df_metadata.empty or df_features.empty:
        logging.error("No valid data loaded.")
        return

    # Generate text embeddings
    text_embeddings = generate_text_embeddings(lyrics_dict) if lyrics_dict else {tid: np.zeros(384) for tid in df_metadata['track_id']}
    audio_features = df_features[[col for col in df_features.columns if col != 'track_id']].values
    audio_features = np.nan_to_num(audio_features)
    text_features = np.array([text_embeddings.get(tid, np.zeros(384)) for tid in df_metadata['track_id']])
    genres = df_metadata['genre'].unique()
    genre_to_idx = {g: i for i, g in enumerate(genres)}
    labels = df_metadata['genre'].map(genre_to_idx).values

    # Load ratings
    num_tracks = len(df_metadata)
    track_id_to_idx = {tid: idx for idx, tid in enumerate(df_metadata['track_id'])}
    ratings = np.zeros(num_tracks)  # 1D array for ratings
    user_data = pd.DataFrame()
    if os.path.exists(USER_DATA_PATH):
        user_data = pd.read_csv(USER_DATA_PATH)
        user_data['track_id'] = user_data['track_id'].astype(str).str.zfill(6)
        user_data = user_data[user_data['track_id'].isin(df_metadata['track_id'])]
        for _, row in user_data.iterrows():
            tid = row['track_id']
            if tid in track_id_to_idx:
                ratings[track_id_to_idx[tid]] = row['rating']

    # Train-test split
    X_train, X_test, y_train, y_test = train_test_split(
        np.hstack([audio_features, text_features]), labels, test_size=0.2, random_state=42
    )
    X_train_rec, X_test_rec, y_train_rec, y_test_rec = train_test_split(
        np.hstack([audio_features, text_features]), ratings, test_size=0.2, random_state=42
    )

    # Classification data loaders
    train_dataset_cls = TensorDataset(
        torch.tensor(X_train[:, :audio_features.shape[1]], dtype=torch.float32),
        torch.tensor(X_train[:, audio_features.shape[1]:], dtype=torch.float32),
        torch.tensor(y_train, dtype=torch.long)
    )
    train_loader_cls = DataLoader(train_dataset_cls, batch_size=BATCH_SIZE, shuffle=True)
    test_dataset_cls = TensorDataset(
        torch.tensor(X_test[:, :audio_features.shape[1]], dtype=torch.float32),
        torch.tensor(X_test[:, audio_features.shape[1]:], dtype=torch.float32),
        torch.tensor(y_test, dtype=torch.long)
    )
    test_loader_cls = DataLoader(test_dataset_cls, batch_size=BATCH_SIZE)

    # Recommendation data loaders
    train_dataset_rec = TensorDataset(
        torch.tensor(X_train_rec[:, :audio_features.shape[1]], dtype=torch.float32),
        torch.tensor(X_train_rec[:, audio_features.shape[1]:], dtype=torch.float32),
        torch.tensor(y_train_rec, dtype=torch.float32).unsqueeze(1)  # Shape: [n, 1]
    )
    train_loader_rec = DataLoader(train_dataset_rec, batch_size=BATCH_SIZE, shuffle=True)
    test_dataset_rec = TensorDataset(
        torch.tensor(X_test_rec[:, :audio_features.shape[1]], dtype=torch.float32),
        torch.tensor(X_test_rec[:, audio_features.shape[1]:], dtype=torch.float32),
        torch.tensor(y_test_rec, dtype=torch.float32).unsqueeze(1)  # Shape: [n, 1]
    )
    test_loader_rec = DataLoader(test_dataset_rec, batch_size=BATCH_SIZE)

    # Task 1.1: Audio Feature Contribution Analysis
    audio_model = AudioEmbedding(input_dim=audio_features.shape[1]).to(DEVICE)
    features_tensor = torch.tensor(audio_features, dtype=torch.float32, device=DEVICE)
    labels_tensor = torch.tensor(labels, dtype=torch.long, device=DEVICE)
    contributions = analyze_feature_contribution(audio_model, features_tensor, labels_tensor)
    print("\nFeature Contributions:")
    print(contributions)
    print("=== Task 1.1: Audio Feature Integration with LLMs Completed ===")

    # Task 1.2: Text-Based Genre Classification
    text_classifier = TextGenreClassifier(num_classes=len(genres)).to(DEVICE)
    lyrics_list = list(lyrics_dict.values()) if lyrics_dict else ['music'] * len(df_metadata)
    text_classifier.fine_tune(lyrics_list[:len(X_train)], torch.tensor(y_train, dtype=torch.long, device=DEVICE))
    zero_shot_result = text_classifier.zero_shot_classify("upbeat rock song", genres.tolist())
    print("\nZero-Shot Classification Result:", zero_shot_result)
    patterns = text_classifier.analyze_linguistic_patterns(lyrics_list, df_metadata['genre'].tolist())
    print("\nLinguistic Patterns:", {k: np.mean(v) for k, v in patterns.items()})
    print("=== Task 1.2: Text-Based Genre Classification Completed ===")

    # Task 1.3: Hybrid Genre Classification
    classifier = HybridGenreClassifier(audio_dim=audio_features.shape[1], text_dim=384, num_genres=len(genres)).to(DEVICE)
    criterion_cls = nn.CrossEntropyLoss()
    optimizer_cls = torch.optim.Adam(classifier.parameters(), lr=0.001)
    for epoch in range(NUM_EPOCHS):
        classifier.train()
        total_loss = 0
        for audio_features, _, labels in train_loader_cls:
            audio_features, labels = audio_features.to(DEVICE), labels.to(DEVICE)
            lyrics_batch = lyrics_list[:len(audio_features)]
            metadata_batch = df_metadata[['artist_name', 'title']].iloc[:len(audio_features)].values
            optimizer_cls.zero_grad()
            outputs = classifier(audio_features, lyrics_batch, metadata_batch)
            loss = criterion_cls(outputs, labels)
            loss.backward()
            optimizer_cls.step()
            total_loss += loss.item()
        print(f"Hybrid Classification Epoch {epoch+1}, Loss: {total_loss/len(train_loader_cls):.4f}")

    classifier.eval()
    y_true, y_pred = [], []
    with torch.no_grad():
        for audio_features, _, labels in test_loader_cls:
            audio_features, labels = audio_features.to(DEVICE), labels.to(DEVICE)
            lyrics_batch = lyrics_list[:len(audio_features)]
            metadata_batch = df_metadata[['artist_name', 'title']].iloc[:len(audio_features)].values
            outputs = classifier(audio_features, lyrics_batch, metadata_batch)
            _, predicted = torch.max(outputs, 1)
            y_true.extend(labels.cpu().numpy())
            y_pred.extend(predicted.cpu().numpy())
    precision_cls, recall_cls, f1_cls, _ = precision_recall_fscore_support(y_true, y_pred, average='weighted', zero_division=0)
    print(f"Hybrid Classification Precision: {precision_cls:.4f}, Recall: {recall_cls:.4f}, F1: {f1_cls:.4f}")
    cm = confusion_matrix(y_true, y_pred)
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=genres, yticklabels=genres)
    plt.title('Hybrid Classification Confusion Matrix')
    plt.savefig(os.path.join(OUTPUT_DIR, 'hybrid_confusion_matrix.png'))
    plt.close()

    audio_only_model = GenreClassifier(audio_dim=audio_features.shape[1], text_dim=384, num_classes=len(genres)).to(DEVICE)
    dummy_lyrics = ["music"] * len(X_test)
    hybrid_metrics, audio_metrics = compare_with_audio_only(
        classifier, audio_only_model, torch.tensor(X_test[:, :audio_features.shape[1]], dtype=torch.float32, device=DEVICE),
        dummy_lyrics, df_metadata[['artist_name', 'title']].iloc[:len(X_test)].values, torch.tensor(y_test, dtype=torch.long, device=DEVICE)
    )
    print(f"\nHybrid vs Audio-Only Metrics: Hybrid={hybrid_metrics}, Audio-Only={audio_metrics}")
    print("=== Task 1.3: Hybrid Multi-Modal Classification Completed ===")

    # Task 1.4: Transformer-Based Audio Classification
    def generate_spectrograms(audio_path, track_ids):
        spectrograms = []
        valid_track_ids = []
        for track_id in track_ids:
            audio_file = os.path.join(audio_path, f"{track_id}.mp3")
            if os.path.exists(audio_file):
                try:
                    y, sr = librosa.load(audio_file, sr=22050)
                    spec = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=128)
                    spec_db = librosa.power_to_db(spec, ref=np.max)
                    if spec_db.shape[1] > 128:
                        spec_db = spec_db[:, :128]
                    else:
                        spec_db = np.pad(spec_db, ((0, 0), (0, 128 - spec_db.shape[1])), mode='constant')
                    spectrograms.append(spec_db)
                    valid_track_ids.append(track_id)
                except Exception as e:
                    logging.error(f"Error processing {audio_file}: {e}")
            else:
                logging.warning(f"Audio file {audio_file} not found, skipping.")
        return np.array(spectrograms), valid_track_ids

    spectrograms, valid_track_ids = generate_spectrograms(AUDIO_PATH, df_metadata['track_id'].tolist())
    spectrograms = spectrograms[:, np.newaxis, :, :]
    valid_indices = [i for i, tid in enumerate(df_metadata['track_id']) if tid in valid_track_ids]
    filtered_labels = df_metadata['genre'].map(genre_to_idx).values[valid_indices]
    filtered_df_metadata = df_metadata.iloc[valid_indices]

    if len(spectrograms) != len(filtered_labels):
        logging.error(f"Mismatch in spectrogram and label counts - spectrograms: {len(spectrograms)}, labels: {len(filtered_labels)}")
        return

    X_train_spec, X_test_spec, y_train_spec, y_test_spec = train_test_split(
        spectrograms, filtered_labels, test_size=0.2, random_state=42
    )
    train_dataset_spec = TensorDataset(
        torch.tensor(X_train_spec, dtype=torch.float32), torch.tensor(y_train_spec, dtype=torch.long)
    )
    train_loader_spec = DataLoader(train_dataset_spec, batch_size=BATCH_SIZE, shuffle=True)
    test_dataset_spec = TensorDataset(
        torch.tensor(X_test_spec, dtype=torch.float32), torch.tensor(y_test_spec, dtype=torch.long)
    )
    test_loader_spec = DataLoader(test_dataset_spec, batch_size=BATCH_SIZE)

    ast = AudioSpectrogramTransformer(num_classes=len(genres)).to(DEVICE)
    criterion_ast = nn.CrossEntropyLoss()
    optimizer_ast = torch.optim.Adam(ast.parameters(), lr=0.001)
    for epoch in range(NUM_EPOCHS):
        ast.train()
        total_loss = 0
        for specs, labels in train_loader_spec:
            specs, labels = specs.to(DEVICE), labels.to(DEVICE)
            optimizer_ast.zero_grad()
            outputs = ast(specs)
            loss = criterion_ast(outputs, labels)
            loss.backward()
            optimizer_ast.step()
            total_loss += loss.item()
        print(f"AST Epoch {epoch+1}, Loss: {total_loss/len(train_loader_spec):.4f}")

    ast.eval()
    y_true, y_pred = [], []
    with torch.no_grad():
        for specs, labels in test_loader_spec:
            specs, labels = specs.to(DEVICE), labels.to(DEVICE)
            outputs = ast(specs)
            _, predicted = torch.max(outputs, 1)
            y_true.extend(labels.cpu().numpy())
            y_pred.extend(predicted.cpu().numpy())
    ast_metrics = precision_recall_fscore_support(y_true, y_pred, average='weighted', zero_division=0)
    print(f"AST Precision: {ast_metrics[0]:.4f}, Recall: {ast_metrics[1]:.4f}, F1: {ast_metrics[2]:.4f}")
    visualize_attention_patterns(ast, torch.tensor(X_test_spec[:1], dtype=torch.float32, device=DEVICE))
    print("=== Task 1.4: Transformer-Based Audio Classification Completed ===")

    # Task 2.1: Semantic Music Search
    audio_features = np.array([extract_audio_features(os.path.join(AUDIO_PATH, f"{tid}.mp3"))
                               for tid in filtered_df_metadata['track_id']])
    audio_features = np.nan_to_num(audio_features)
    input_dim = audio_features.shape[1]
    projector = FeatureProjector(input_dim, 384).to(DEVICE)
    audio_features_tensor = torch.tensor(audio_features, dtype=torch.float32, device=DEVICE)
    audio_embeddings = projector(audio_features_tensor).detach().cpu().numpy()

    text_data = [f"{row['artist_name']} {row['title']} {lyrics_dict.get(row['track_id'], 'music')}"
                 for _, row in filtered_df_metadata.iterrows()]
    text_model = SentenceTransformer('all-MiniLM-L6-v2', device=DEVICE)
    text_embeddings = text_model.encode(text_data, convert_to_numpy=True)

    search_system = MusicSearchSystem(filtered_df_metadata, audio_embeddings, text_embeddings)
    query = "upbeat rock songs"
    results = search_system.multi_modal_search(query, top_k=10)
    print("\nSearch Results:")
    print(results)
    print("=== Task 2.1.1: Natural Language Query Processing Completed ===")
    print("=== Task 2.1.2: Multi-Modal Music Retrieval Completed ===")

    # Task 2.1.3: Content-Based Discovery
    discovery = MusicDiscovery()
    sample_audio = os.path.join(AUDIO_PATH, df_metadata['track_id'].iloc[0] + '.mp3')
    sample_lyrics = lyrics_dict.get(df_metadata['track_id'].iloc[0], 'music')
    tags = discovery.generate_tags(sample_audio, sample_lyrics)
    print("\nGenerated Tags:")
    print(tags)
    playlist = discovery.generate_playlist(
        {'audio_path': sample_audio, 'lyrics': sample_lyrics},
        [{'audio_path': os.path.join(AUDIO_PATH, f"{row['track_id']}.mp3"), 'lyrics': lyrics_dict.get(row['track_id'], 'music')}
         for _, row in df_metadata.iterrows()],
        df_metadata
    )
    print("\nGenerated Playlist:")
    print(playlist)
    print("=== Task 2.1.3: Content-Based Music Discovery Completed ===")

    # Task 3.1: Recommendation
    if user_data.empty:
        logging.error("No user data available for recommendation task.")
        return

    recommender = HybridRecommender(audio_dim=audio_features.shape[1], text_dim=384, num_tracks=num_tracks).to(DEVICE)
    criterion_rec = nn.BCELoss()
    optimizer_rec = torch.optim.Adam(recommender.parameters(), lr=0.001)
    for epoch in range(NUM_EPOCHS):
        loss_rec = train_recommender(recommender, train_loader_rec, criterion_rec, optimizer_rec)
        print(f"Recommendation Epoch {epoch+1}, Loss: {loss_rec:.4f}")

    precision_rec = evaluate_recommender(recommender, test_loader_rec)
    print(f"Recommendation Precision@10: {precision_rec:.4f}")
    recommendations = recommender.recommend('user_001', 'upbeat', torch.tensor(audio_features, dtype=torch.float32, device=DEVICE),
                                          torch.tensor(text_features, dtype=torch.float32, device=DEVICE), df_metadata, user_data)
    recommendations = recommender.optimize_diversity(recommendations, df_metadata)
    print("\nRecommendations:")
    print(recommendations)

    print("=== Task 3.1.1: User Profile Understanding Completed ===")
    print("=== Task 3.1.2: Collaborative Filtering with LLMs Completed ===")
    print("=== Task 3.1.3: Hybrid Recommendation Systems Completed ===")

    # Generate explanation
    music_recommender = MusicRecommender(audio_dim=audio_features.shape[1], text_dim=384, num_tracks=num_tracks).to(DEVICE)
    sample_item_id = recommendations['track_id'].iloc[0]
    explanation_data = pd.concat([df_metadata, user_data[['user_id', 'rating']]], axis=1, join='inner')
    explanation = music_recommender.generate_explanation('user_001', sample_item_id, explanation_data)
    print("\nRecommendation Explanation:")
    print(explanation)

    # Task 4.1: Evaluation
    evaluator = EvaluationFramework()
    retrieval_metrics = evaluator.evaluate_retrieval(
        df_metadata[df_metadata['genre'] == 'Rock']['track_id'].tolist()[:10],
        [r['track_id'] for r in results]
    )
    print("\nRetrieval Metrics:")
    print(retrieval_metrics)
    rec_metrics = evaluator.evaluate_recommendation(recommendations, user_data[user_data['rating'] > 0.5]['track_id'].tolist())
    print("\nRecommendation Metrics:")
    print(rec_metrics)
    print("=== Task 4.1: Comprehensive Evaluation Framework Completed ===")

if __name__ == "__main__":
    main()



=== Task 5: Dataset Requirements Completed ===

Feature Contributions:
{'mfcc': 7.903346538543701, 'chroma': 0.0003744514542631805, 'spectral_centroid': 779.1778564453125, 'spectral_contrast': 0.8687392473220825, 'tempo': 3.8261520862579346}
=== Task 1.1: Audio Feature Integration with LLMs Completed ===

Zero-Shot Classification Result: Rock

Linguistic Patterns: {'Rock': np.float32(0.0014330518), 'Hip-Hop': np.float32(0.0014330519), 'Pop': np.float32(0.0014330515), 'Experimental': np.float32(0.001433052), 'Folk': np.float32(0.001433052), 'International': np.float32(0.0014330519)}
=== Task 1.2: Text-Based Genre Classification Completed ===
Hybrid Classification Epoch 1, Loss: 1.7949
Hybrid Classification Epoch 2, Loss: 2.0522
Hybrid Classification Epoch 3, Loss: 1.4672
Hybrid Classification Epoch 4, Loss: 1.5323
Hybrid Classification Epoch 5, Loss: 1.4863
Hybrid Classification Precision: 0.2215, Recall: 0.4706, F1: 0.3012

Hybrid vs Audio-Only Metrics: Hybrid=(0.22145328719723184, 0.4