# FMA (Free Music Archive) Dataset Loading and Preprocessing

This notebook demonstrates how to load and preprocess the Free Music Archive (FMA) dataset for music genre classification (advanced example).

**FMA Dataset**: A large-scale music dataset covering 161 genres, distributed in four subsets:
- FMA Small: 8,000 30-second clips, 8 balanced genres
- FMA Medium: 25,000 30-second clips, 16 unbalanced genres
- FMA Large: 106,574 30-second clips, 161 unbalanced genres
- FMA Full: 106,574 untrimmed (full-length) tracks, 161 unbalanced genres

In [None]:
import os
import ast
import numpy as np
import pandas as pd
import librosa
import torch
import torchaudio
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from tqdm import tqdm

## Dataset Configuration

In [None]:
# Dataset locations. Point FMA_AUDIO_PATH at fma_small, fma_medium or fma_large.
FMA_METADATA_PATH = '../data/fma_metadata'
FMA_AUDIO_PATH = '../data/fma_small'

# Audio preprocessing settings.
SAMPLE_RATE = 22050  # target sample rate (Hz) every clip is resampled to
DURATION = 30        # clip length in seconds; longer audio is cut, shorter padded
N_MELS = 128         # number of mel bands used for spectrograms

## Helper Functions for FMA Metadata

In [None]:
def load_fma_metadata(metadata_path):
    """Load the FMA metadata files ``tracks.csv`` and ``genres.csv``.

    Args:
        metadata_path (string): Directory containing the metadata CSV files.

    Returns:
        tuple: ``(tracks, genres)``.
            ``tracks`` is a DataFrame indexed by the first CSV column, with a
            two-level column header (e.g. ``('track', 'genre_top')``). If it
            has a ``('set', 'subset')`` column, that column is converted to an
            ordered categorical ``small < medium < large``. Comparisons such as
            ``tracks['set', 'subset'] <= 'medium'`` therefore follow subset
            order instead of alphabetical order.
            ``genres`` is a DataFrame, or ``None`` if ``genres.csv`` is missing.
            Returns ``(None, None)`` (after printing a warning) when
            ``tracks.csv`` is missing.
    """
    tracks_file = os.path.join(metadata_path, 'tracks.csv')
    genres_file = os.path.join(metadata_path, 'genres.csv')

    if not os.path.exists(tracks_file):
        print(f"Warning: tracks.csv not found at {tracks_file}")
        return None, None

    # tracks.csv uses a two-row column header, hence header=[0, 1].
    tracks = pd.read_csv(tracks_file, index_col=0, header=[0, 1])

    # As plain strings, 'large' <= 'small' and 'medium' <= 'small' are both
    # True, so subset comparisons would select everything. An ordered
    # categorical makes '<=' follow small < medium < large.
    subset_col = ('set', 'subset')
    if subset_col in tracks.columns:
        tracks[subset_col] = tracks[subset_col].astype(
            pd.CategoricalDtype(
                categories=['small', 'medium', 'large'], ordered=True
            )
        )

    # Genres are optional; callers get None when the file is absent.
    if os.path.exists(genres_file):
        genres = pd.read_csv(genres_file, index_col=0)
    else:
        genres = None

    return tracks, genres

def get_audio_path(audio_dir, track_id):
    """Build the MP3 file path for ``track_id`` under ``audio_dir``.

    The ID is zero-padded to six digits. Its first three digits name the
    sub-directory, e.g. track 2 -> ``<audio_dir>/000/000002.mp3``.
    """
    padded_id = f'{track_id:06d}'
    sub_dir = padded_id[:3]
    file_name = f'{padded_id}.mp3'
    return os.path.join(audio_dir, sub_dir, file_name)

## FMA Dataset Class

In [None]:
class FMADataset(Dataset):
    """PyTorch Dataset for FMA music genre classification.

    Each item is ``(waveform, label)``:
    - ``waveform`` is a mono tensor of shape ``(1, int(sample_rate * duration))``
      (before ``transform`` is applied).
    - ``label`` is an integer index; ``idx_to_genre`` maps it back to the
      ``('track', 'genre_top')`` value it came from.
    """

    # Subsets are selected cumulatively: requesting 'medium' also keeps
    # 'small' tracks, and 'large' keeps all three.
    _SUBSET_ORDER = ('small', 'medium', 'large')

    def __init__(self, audio_dir, metadata_path, sample_rate=22050,
                 duration=30, transform=None, subset='small'):
        """
        Args:
            audio_dir (string): Directory with audio files
            metadata_path (string): Path to metadata directory
            sample_rate (int): Target sample rate
            duration (int or float): Duration in seconds
            transform (callable, optional): Optional transform applied to
                the fixed-length waveform
            subset (string): Dataset subset ('small', 'medium', 'large');
                any other value applies no subset filtering
        """
        self.audio_dir = audio_dir
        self.sample_rate = sample_rate
        self.duration = duration
        self.transform = transform
        self.subset = subset

        # Load metadata
        self.tracks, self.genres = load_fma_metadata(metadata_path)

        if self.tracks is None:
            # Metadata missing: expose an empty dataset with empty mappings
            # instead of leaving the attributes undefined.
            self.track_ids = []
            self.genre_ids = []
            self.genre_to_idx = {}
            self.idx_to_genre = {}
            return

        self.tracks = self._select_tracks(self.tracks, subset)

        # Labels come from the top-level genre column.
        self.track_ids = self.tracks.index.tolist()
        self.genre_ids = self.tracks['track', 'genre_top'].values

        # Sorted genre names give a deterministic label numbering.
        unique_genres = sorted(set(self.genre_ids))
        self.genre_to_idx = {genre: idx for idx, genre in enumerate(unique_genres)}
        self.idx_to_genre = {idx: genre for genre, idx in self.genre_to_idx.items()}

        print(f"FMA {subset} dataset loaded: {len(self.track_ids)} tracks")
        print(f"Genres: {list(self.genre_to_idx.keys())}")

    @classmethod
    def _select_tracks(cls, tracks, subset):
        """Return the rows of ``tracks`` to use for ``subset``.

        If ``subset`` is in ``_SUBSET_ORDER``, keeps tracks whose
        ``('set', 'subset')`` value ranks at or below it; otherwise no subset
        filter is applied. Tracks with a missing ``('track', 'genre_top')``
        are always dropped.
        """
        if subset in cls._SUBSET_ORDER:
            rank = {name: i for i, name in enumerate(cls._SUBSET_ORDER)}
            # Compare by rank, not by string: alphabetically, 'large' and
            # 'medium' both sort before 'small'. astype(object) makes this
            # work for plain-string and categorical columns alike; unknown
            # or missing values map to NaN and are excluded.
            track_rank = tracks['set', 'subset'].astype(object).map(rank)
            tracks = tracks.loc[track_rank <= rank[subset]]
        # Tracks without a top-level genre cannot be labelled. Their NaN
        # would also break sorting alongside string genre names.
        return tracks.loc[tracks['track', 'genre_top'].notna()]

    def __len__(self):
        return len(self.track_ids)

    def __getitem__(self, idx):
        """Return ``(waveform, label)`` for the ``idx``-th selected track."""
        track_id = self.track_ids[idx]
        audio_path = get_audio_path(self.audio_dir, track_id)

        label = self.genre_to_idx[self.genre_ids[idx]]

        waveform = self._fix_length(self._load_waveform(audio_path))

        if self.transform:
            waveform = self.transform(waveform)

        return waveform, label

    def _load_waveform(self, audio_path):
        """Load ``audio_path`` as a mono ``(1, n)`` tensor at ``self.sample_rate``."""
        try:
            waveform, sr = torchaudio.load(audio_path)
        except Exception:
            # If torchaudio fails (e.g. on MP3), fall back to librosa, which
            # already resamples and downmixes to mono. Catching Exception
            # instead of a bare except lets KeyboardInterrupt/SystemExit through.
            samples, _ = librosa.load(audio_path, sr=self.sample_rate, mono=True)
            waveform = torch.from_numpy(samples).unsqueeze(0)
            sr = self.sample_rate

        if sr != self.sample_rate:
            resampler = torchaudio.transforms.Resample(sr, self.sample_rate)
            waveform = resampler(waveform)

        # Average channels to mono.
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)

        return waveform

    def _fix_length(self, waveform):
        """Truncate or right-pad ``waveform`` with zeros to ``sample_rate * duration`` samples."""
        # int() so a fractional duration still yields a valid slice/pad length.
        target_length = int(self.sample_rate * self.duration)
        current_length = waveform.shape[1]
        if current_length > target_length:
            return waveform[:, :target_length]
        if current_length < target_length:
            return torch.nn.functional.pad(waveform, (0, target_length - current_length))
        return waveform

## Example Usage

In [None]:
# Create dataset instance
# Note: Uncomment and update paths to use
# dataset = FMADataset(
#     audio_dir=FMA_AUDIO_PATH,
#     metadata_path=FMA_METADATA_PATH,
#     sample_rate=SAMPLE_RATE,
#     duration=DURATION,
#     subset='small'
# )
# print(f"Dataset size: {len(dataset)}")

## Visualize Sample

In [None]:
def visualize_fma_sample(dataset, idx=0):
    """Plot the waveform and dB-scaled mel-spectrogram of ``dataset[idx]``.

    The mel transform uses ``dataset.sample_rate`` and ``N_MELS`` bands. The
    waveform title shows the genre looked up via ``dataset.idx_to_genre``.
    """
    signal, label_idx = dataset[idx]
    genre_name = dataset.idx_to_genre[label_idx]

    # Waveform -> mel power spectrogram -> decibels.
    to_mel = torchaudio.transforms.MelSpectrogram(
        sample_rate=dataset.sample_rate, n_mels=N_MELS
    )
    to_db = torchaudio.transforms.AmplitudeToDB()
    spectrogram_db = to_db(to_mel(signal))

    fig, (wave_ax, spec_ax) = plt.subplots(2, 1, figsize=(12, 8))

    # Top panel: raw samples of the first (mono) channel.
    wave_ax.plot(signal[0].numpy())
    wave_ax.set_title(f'Waveform - Genre: {genre_name}')
    wave_ax.set_xlabel('Sample')
    wave_ax.set_ylabel('Amplitude')

    # Bottom panel: spectrogram with low frequencies at the bottom.
    image = spec_ax.imshow(spectrogram_db[0].numpy(), aspect='auto', origin='lower')
    spec_ax.set_title('Mel-Spectrogram (dB)')
    spec_ax.set_xlabel('Time')
    spec_ax.set_ylabel('Mel Frequency')
    plt.colorbar(image, ax=spec_ax)

    plt.tight_layout()
    plt.show()

# Example usage:
# visualize_fma_sample(dataset, idx=0)

## Create DataLoader

In [None]:
def create_dataloaders(dataset, batch_size=16, train_split=0.8, *,
                       num_workers=2, seed=None):
    """Randomly split ``dataset`` into train and validation DataLoaders.

    Args:
        dataset: Map-style dataset supporting ``len()``.
        batch_size (int): Batch size for both loaders.
        train_split (float): Fraction of items used for training, in (0, 1].
            The training size is ``int(train_split * len(dataset))``; the rest
            goes to validation.
        num_workers (int): Worker processes for each DataLoader.
        seed (int, optional): Seed for the random split. Pass a value for a
            reproducible split; ``None`` uses torch's default generator.

    Returns:
        tuple: ``(train_loader, val_loader)``. Only the training loader shuffles.

    Raises:
        ValueError: If ``train_split`` is outside (0, 1], or the training split
            would be empty (e.g. the dataset has no items).

    NOTE(review): a purely random split may put tracks by the same artist in
    both sets. Consider a predefined split if the metadata provides one.
    """
    if not 0.0 < train_split <= 1.0:
        raise ValueError(f"train_split must be in (0, 1], got {train_split}")

    train_size = int(train_split * len(dataset))
    if train_size == 0:
        # A shuffling DataLoader over an empty subset would fail later with an
        # opaque sampler error; fail early with a clear message instead.
        raise ValueError(
            f"training split is empty (dataset size {len(dataset)}, "
            f"train_split {train_split})"
        )
    val_size = len(dataset) - train_size

    generator = (
        torch.Generator().manual_seed(seed)
        if seed is not None
        else torch.default_generator
    )
    train_dataset, val_dataset = torch.utils.data.random_split(
        dataset, [train_size, val_size], generator=generator
    )

    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True,
        num_workers=num_workers,
    )
    val_loader = DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False,
        num_workers=num_workers,
    )

    return train_loader, val_loader

# Example usage:
# train_loader, val_loader = create_dataloaders(dataset, batch_size=16)