# MagnaTagATune (MTAT) Dataset Loading and Preprocessing

This notebook demonstrates how to load and preprocess the MagnaTagATune dataset for music tagging.

**MTAT Dataset**: Contains ~25,000 audio clips (29 seconds each) with 188 different tags.

In [None]:
import os
import numpy as np
import pandas as pd
import librosa
import torch
import torchaudio
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from tqdm import tqdm

## Dataset Configuration

In [None]:
# Audio parameters shared by the cells below
SAMPLE_RATE = 22050  # target sample rate passed to the dataset
DURATION = 29  # seconds
N_MELS = 128  # number of mel bands used when plotting spectrograms

# Dataset locations -- update both paths to match your local copy
MTAT_AUDIO_PATH = '../data/mtat/audio'
MTAT_ANNOTATIONS_PATH = '../data/mtat/annotations_final.csv'

## MTAT Dataset Class

In [None]:
class MTATDataset(Dataset):
    """PyTorch Dataset for MagnaTagATune music tagging.

    Each item is a ``(waveform, labels)`` pair: a mono waveform resampled to
    ``sample_rate`` and padded/truncated to ``duration`` seconds, and a
    multi-hot float vector over the ``top_tags`` most frequent tags.

    The annotations file is read as tab-separated text. The audio path is
    taken from the ``mp3_path`` column when it exists (the column name used
    by the official ``annotations_final.csv``); otherwise the first column is
    used as the filename. The first column is never treated as a tag.
    """

    # Column holding the clip path in the official annotations file.
    _PATH_COLUMN = 'mp3_path'

    def __init__(self, audio_dir, annotations_file, sample_rate=22050,
                 duration=29, transform=None, top_tags=50):
        """
        Args:
            audio_dir (string): Directory with audio files
            annotations_file (string): Path to annotations CSV file
            sample_rate (int): Target sample rate
            duration (int or float): Duration in seconds
            transform (callable, optional): Optional transform applied to the
                waveform after padding/truncation
            top_tags (int): Number of most frequent tags to use
        """
        self.audio_dir = audio_dir
        self.sample_rate = sample_rate
        self.duration = duration
        self.transform = transform

        # Define every attribute up front so that a missing annotations file
        # leaves the object in a consistent (empty) state instead of raising
        # AttributeError on later access to ``top_tags`` / ``n_tags``.
        self.annotations = None
        self.top_tags = []
        self.n_tags = 0
        self._file_column = None
        self._labels = None
        self._resamplers = {}  # source sample rate -> Resample transform

        if not os.path.exists(annotations_file):
            print(f"Warning: Annotations file not found at {annotations_file}")
            return

        # Load annotations
        self.annotations = pd.read_csv(annotations_file, sep='\t')
        columns = self.annotations.columns
        if self._PATH_COLUMN in columns:
            self._file_column = self._PATH_COLUMN
        else:
            self._file_column = columns[0]

        # Tag candidates: everything after the first column, minus the path
        # column, restricted to numeric columns so string columns can never
        # take part in the frequency sum.
        tag_frame = (
            self.annotations.iloc[:, 1:]
            .drop(columns=[self._file_column], errors='ignore')
            .select_dtypes(include='number')
        )

        # Select top tags by frequency
        tag_counts = tag_frame.sum(axis=0).sort_values(ascending=False)
        self.top_tags = tag_counts.head(top_tags).index.tolist()
        self.n_tags = len(self.top_tags)

        # Multi-hot label matrix computed once; a cell counts as active only
        # when its value equals 1.
        self._labels = torch.tensor(
            tag_frame[self.top_tags].eq(1).to_numpy(dtype=bool),
            dtype=torch.float32,
        )

        print(f"Using top {top_tags} tags: {self.top_tags[:10]}...")

    def __len__(self):
        """Return the number of annotated clips (0 when nothing was loaded)."""
        if self.annotations is not None:
            return len(self.annotations)
        return 0

    def _resampler_for(self, source_rate):
        """Return a cached Resample transform from ``source_rate`` to the target rate."""
        resampler = self._resamplers.get(source_rate)
        if resampler is None:
            resampler = torchaudio.transforms.Resample(source_rate, self.sample_rate)
            self._resamplers[source_rate] = resampler
        return resampler

    def __getitem__(self, idx):
        """Load clip ``idx`` and return ``(waveform, labels)``.

        Raises:
            ValueError: If the annotations file was not loaded.
        """
        if self.annotations is None:
            raise ValueError("Annotations not loaded")

        # Explicit column + positional row lookup; avoids the deprecated
        # integer-key access on a label-indexed Series.
        audio_file = self.annotations[self._file_column].iloc[idx]
        audio_path = os.path.join(self.audio_dir, audio_file)

        # Copy so callers can modify the labels without touching the cache.
        labels = self._labels[idx].clone()

        # Load audio
        waveform, sr = torchaudio.load(audio_path)

        # Resample if necessary
        if sr != self.sample_rate:
            waveform = self._resampler_for(sr)(waveform)

        # Convert to mono
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)

        # Pad or truncate; the int cast keeps slicing valid for float durations
        target_length = int(round(self.sample_rate * self.duration))
        current_length = waveform.shape[1]
        if current_length > target_length:
            waveform = waveform[:, :target_length]
        elif current_length < target_length:
            waveform = torch.nn.functional.pad(
                waveform, (0, target_length - current_length)
            )

        if self.transform:
            waveform = self.transform(waveform)

        return waveform, labels

## Example Usage

In [None]:
# Create dataset instance
# Note: Uncomment the lines below and update the paths in the configuration cell to use
# dataset = MTATDataset(
#     audio_dir=MTAT_AUDIO_PATH,
#     annotations_file=MTAT_ANNOTATIONS_PATH,
#     sample_rate=SAMPLE_RATE,
#     duration=DURATION,
#     top_tags=50
# )
# print(f"Dataset size: {len(dataset)}")
# print(f"Number of tags: {dataset.n_tags}")

## Visualize Sample with Tags

In [None]:
def visualize_mtat_sample(dataset, idx=0):
    """Plot the waveform and dB mel-spectrogram of one item and print its tags.

    Args:
        dataset: Object indexable as ``dataset[idx] -> (waveform, labels)``
            that exposes ``top_tags`` and ``sample_rate``.
        idx (int): Index of the item to display.
    """
    waveform, labels = dataset[idx]

    # Names of the tags whose label entry equals 1
    active_tags = []
    for position, flag in enumerate(labels):
        if flag == 1:
            active_tags.append(dataset.top_tags[position])

    # Mel-spectrogram of the waveform, converted to decibels
    to_mel = torchaudio.transforms.MelSpectrogram(
        sample_rate=dataset.sample_rate, n_mels=N_MELS
    )
    to_db = torchaudio.transforms.AmplitudeToDB()
    mel_spec_db = to_db(to_mel(waveform))

    fig, (wave_ax, spec_ax) = plt.subplots(2, 1, figsize=(12, 8))

    # Top panel: raw samples of the first channel
    wave_ax.plot(waveform[0].numpy())
    wave_ax.set_title(f'Waveform - Tags: {", ".join(active_tags)}')
    wave_ax.set_xlabel('Sample')
    wave_ax.set_ylabel('Amplitude')

    # Bottom panel: mel-spectrogram image with a colour scale
    image = spec_ax.imshow(mel_spec_db[0].numpy(), aspect='auto', origin='lower')
    spec_ax.set_title('Mel-Spectrogram (dB)')
    spec_ax.set_xlabel('Time')
    spec_ax.set_ylabel('Mel Frequency')
    plt.colorbar(image, ax=spec_ax)

    plt.tight_layout()
    plt.show()

    print(f"Active tags ({len(active_tags)}): {active_tags}")

# Example usage:
# visualize_mtat_sample(dataset, idx=0)

## Create DataLoader

In [None]:
def create_dataloaders(dataset, batch_size=16, train_split=0.8, *,
                       num_workers=2, seed=None):
    """Create train and validation dataloaders from a random split.

    Args:
        dataset: Dataset to split; must support ``len()`` and indexing.
        batch_size (int): Batch size used by both loaders.
        train_split (float): Fraction of items assigned to the training
            set, in ``[0, 1]``. The remainder goes to validation.
        num_workers (int): Worker processes per loader (0 loads in the
            calling process).
        seed (int, optional): Seed for the split. When given, the same
            train/validation partition is produced on every call. It does
            not affect the shuffling order of the training loader.

    Returns:
        tuple: ``(train_loader, val_loader)``. Only the training loader
        shuffles.

    Raises:
        ValueError: If ``train_split`` is outside ``[0, 1]``.
    """
    # Out-of-range fractions would otherwise yield a negative split size.
    if not 0.0 <= train_split <= 1.0:
        raise ValueError(f"train_split must be in [0, 1], got {train_split}")

    train_size = int(train_split * len(dataset))
    val_size = len(dataset) - train_size

    # Only pass a generator when a seed was requested, so the default call
    # keeps using torch's global RNG exactly as before.
    split_kwargs = {}
    if seed is not None:
        split_kwargs['generator'] = torch.Generator().manual_seed(seed)

    train_dataset, val_dataset = torch.utils.data.random_split(
        dataset, [train_size, val_size], **split_kwargs
    )

    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True,
        num_workers=num_workers
    )
    val_loader = DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False,
        num_workers=num_workers
    )

    return train_loader, val_loader

# Example usage:
# train_loader, val_loader = create_dataloaders(dataset, batch_size=16)