# 🎯 Quick Start: Training BEATs Models

Learn how to fine-tune BEATs on your own audio data in just a few steps!

## What You'll Learn
- Organize your audio data
- Train a custom BEATs model
- Evaluate model performance

**Estimated time**: 15-20 minutes (+ training time)

## 1. Install and Import


In [None]:
# Install beats-trainer (uncomment if needed)
# !pip install git+https://github.com/ninanor/beats-trainer.git

from beats_trainer import BEATsTrainer
from pathlib import Path

print("✅ Ready to train!")

## 2. Prepare Your Data

### Option A: Directory Structure (Easiest)
Organize your audio files like this:
```
your_dataset/
├── class1/
│   ├── audio1.wav
│   ├── audio2.wav
│   └── audio3.wav
├── class2/
│   ├── audio4.wav
│   └── audio5.wav
└── class3/
    └── audio6.wav
```
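
Before loading a dataset laid out this way, it can help to confirm that every class folder actually contains audio. The helper below is a minimal sketch using only the standard library; `count_samples_per_class` and `AUDIO_EXTENSIONS` are hypothetical names for illustration and are not part of beats-trainer.

```python
from pathlib import Path

# Assumed set of audio extensions; adjust to match your own files.
AUDIO_EXTENSIONS = {".wav", ".flac", ".mp3", ".ogg"}


def count_samples_per_class(dataset_dir):
    """Return {class_name: number_of_audio_files} for a class-per-folder layout."""
    counts = {}
    for class_dir in sorted(Path(dataset_dir).iterdir()):
        if not class_dir.is_dir():
            continue  # ignore stray files at the top level
        counts[class_dir.name] = sum(
            1
            for f in class_dir.iterdir()
            if f.is_file() and f.suffix.lower() in AUDIO_EXTENSIONS
        )
    return counts
```

Calling `count_samples_per_class("your_dataset")` on the tree above would report three files for `class1`, two for `class2`, and one for `class3`.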

### Option B: CSV File
Create a CSV with audio paths and labels:
```csv
filename,category
audio/sample1.wav,bird
audio/sample2.wav,dog
audio/sample3.wav,cat
```
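
If your files are already sorted into class folders, a CSV in this shape can be generated rather than typed by hand. The following is a sketch under two assumptions: the folder name is the label, and paths in the CSV are relative to the dataset root (the example above suggests this, but check what `from_csv` expects). `write_labels_csv` is a hypothetical helper, not a beats-trainer function.

```python
import csv
from pathlib import Path


def write_labels_csv(dataset_dir, csv_path, extension=".wav"):
    """Write a filename,category CSV from a class-per-folder dataset.

    Returns the number of audio files listed.
    """
    root = Path(dataset_dir)
    rows = []
    for audio_file in sorted(root.glob(f"*/*{extension}")):
        # Path relative to the dataset root; the parent folder is the label.
        rows.append((audio_file.relative_to(root).as_posix(), audio_file.parent.name))

    with open(csv_path, "w", newline="", encoding="utf-8") as handle:
        writer = csv.writer(handle)
        writer.writerow(["filename", "category"])  # header used in the example above
        writer.writerows(rows)
    return len(rows)
```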

In [None]:
# Replace with your actual dataset path
dataset_path = "path/to/your/audio/dataset"

# Check if dataset exists
if Path(dataset_path).exists():
    print(f"✅ Found dataset at: {dataset_path}")

    # Create trainer from directory structure
    trainer = BEATsTrainer.from_directory(dataset_path)
    print(f"🎯 Dataset loaded: {len(trainer.dataset)} samples")
    print(f"📊 Classes: {trainer.dataset_stats['num_classes']}")

else:
    print("📝 Please update 'dataset_path' with your actual dataset path")
    print("\n💡 For demo purposes, here's how you'd load different data types:")

    print("\n# Method 1: From directory")
    print('trainer = BEATsTrainer.from_directory("/path/to/audio/classes")')

    print("\n# Method 2: From CSV")
    print('trainer = BEATsTrainer.from_csv("labels.csv", data_dir="/path/to/audio")')

    print("\n# Method 3: From preset dataset (ESC-50, UrbanSound8K)")
    print('trainer = BEATsTrainer.from_preset("esc50", "/path/to/ESC-50-master")')

## 3. Configure Training (Optional)

The default settings work well for most cases, but you can customize them:

In [None]:
from beats_trainer.config import Config, TrainingConfig

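# Optional: custom training configuration.
# Note: this cell only builds `config`; it still has to be handed to the trainer
# (see the BEATsTrainer documentation for the exact argument) to take effect.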
config = Config()
config.training = TrainingConfig(
    learning_rate=5e-5,  # Lower learning rate for fine-tuning
    max_epochs=20,  # Number of training epochs
    batch_size=16,  # Adjust based on your GPU memory
    patience=5,  # Early stopping patience
)

print("⚙️  Training configuration:")
print(f"   Learning rate: {config.training.learning_rate}")
print(f"   Max epochs: {config.training.max_epochs}")
print(f"   Batch size: {config.training.batch_size}")
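
To make the `patience` setting concrete, here is a generic sketch of patience-based early stopping: training stops once the validation score has failed to improve for `patience` consecutive epochs. This illustrates the usual idea only; it is not taken from the beats-trainer source, and `epoch_of_early_stop` is a hypothetical name.

```python
def epoch_of_early_stop(val_scores, patience):
    """Return the 1-based epoch at which training would stop, or None if it never does.

    val_scores: validation scores per epoch, where higher is better.
    """
    best = float("-inf")
    epochs_without_improvement = 0
    for epoch, score in enumerate(val_scores, start=1):
        if score > best:
            best = score
            epochs_without_improvement = 0  # any improvement resets the counter
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch
    return None
```

A larger `patience` tolerates longer plateaus at the cost of extra epochs; a smaller one saves time but may stop before a late improvement.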

## 4. Train the Model

This will fine-tune BEATs on your data:

In [None]:
# Start training (this will take some time!)
if "trainer" in locals():
    print("🚀 Starting training...")
    print("⏰ This will take several minutes depending on your dataset size and GPU")

    # Train the model
    results = trainer.train()

    print("🎉 Training completed!")
    print(f"📈 Best validation accuracy: {results['best_score']:.3f}")
    print(f"💾 Model saved to: {results['best_checkpoint']}")

else:
    print("⚠️  Skipping training - no dataset loaded")
    print("   Update the dataset path in step 2 to run training")

## 5. Use Your Trained Model

Extract features with your custom-trained model:

In [None]:
if "trainer" in locals() and "results" in locals():
    # Get feature extractor with trained model
    custom_extractor = trainer.get_feature_extractor()

    print("✅ Custom model feature extractor ready!")
    print("🎯 You can now use this for feature extraction:")
    print("   features = custom_extractor.extract_from_file('new_audio.wav')")

    # Example prediction (only if your installed version implements `predict`):
    # predictions = trainer.predict(['new_audio1.wav', 'new_audio2.wav'])
    print("🔮 Model ready for predictions on new audio!")

else:
    print("⚠️  No trained model available")
    print("   Complete the training step to get a custom model")

## 6. Compare Models

Let's compare your trained model with the pretrained one:

In [None]:
from beats_trainer import BEATsFeatureExtractor

# Load the pretrained extractor; the custom one was created in step 5
pretrained_extractor = BEATsFeatureExtractor()  # Pretrained model

if "custom_extractor" in locals():
    print("🔍 Comparing pretrained vs custom-trained models:")
    print("   Pretrained model: General audio understanding")
    print(
        f"   Custom model: Specialized for your {trainer.dataset_stats['num_classes']} classes"
    )

    # You could extract features from the same audio with both models
    # and compare their similarity/clustering performance

else:
    print("📊 Only pretrained model available for now")
    print("   Complete training to get both models for comparison")
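
One simple way to carry out the comparison suggested in the comments above is cosine similarity between feature vectors. The sketch below is standard-library Python and assumes each clip's features have been flattened to a 1-D sequence of floats; the return type of `extract_from_file` is not shown here, so that conversion is left to you. `cosine_similarity` is a hypothetical helper.

```python
import math


def cosine_similarity(a, b):
    """Cosine similarity between two equal-length 1-D feature vectors."""
    if len(a) != len(b):
        raise ValueError("feature vectors must have the same length")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        raise ValueError("cosine similarity is undefined for a zero vector")
    return dot / (norm_a * norm_b)
```

If fine-tuning helped, you would expect two clips from the same class to score higher under the custom extractor than under the pretrained one, and clips from different classes to score lower.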

## 7. Next Steps

🎉 **Great job!** You've learned how to train custom BEATs models.

### What's Next?
- **Evaluate**: Test your model on held-out data
- **Deploy**: Use your model in production applications
- **Iterate**: Experiment with different hyperparameters
- **Scale**: Try larger datasets for better performance
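
For the evaluation step, the held-out set should contain every class. The sketch below shows one way to make a stratified split of `(path, label)` pairs with the standard library. `stratified_split` and its defaults are assumptions for illustration; whether beats-trainer already creates its own validation split is not covered here.

```python
import random
from collections import defaultdict


def stratified_split(samples, test_fraction=0.2, seed=0):
    """Split (path, label) pairs into train and test lists, class by class."""
    by_label = defaultdict(list)
    for path, label in samples:
        by_label[label].append(path)

    rng = random.Random(seed)  # fixed seed keeps the split reproducible
    train, test = [], []
    for label in sorted(by_label):
        paths = sorted(by_label[label])
        rng.shuffle(paths)
        # Hold out at least one file per class, but never the whole class
        # (a class with a single file stays entirely in the training set).
        n_test = max(1, round(len(paths) * test_fraction))
        n_test = min(n_test, len(paths) - 1)
        test += [(p, label) for p in paths[:n_test]]
        train += [(p, label) for p in paths[n_test:]]
    return train, test
```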

### Tips for Better Results
- **More data**: Aim for at least 50 samples per class
- **Balanced classes**: Similar number of samples per class
- **Clean audio**: Remove background noise when possible
- **Augmentation**: Built-in data augmentation helps with small datasets
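
The first two tips can be checked automatically from a list of labels. This is a rough sketch: the 50-sample floor comes from the tip above, while the `max_ratio=3.0` imbalance threshold is an arbitrary assumption, and `check_class_balance` is a hypothetical helper.

```python
from collections import Counter


def check_class_balance(labels, min_per_class=50, max_ratio=3.0):
    """Flag classes with too few samples or far fewer samples than the largest class."""
    counts = Counter(labels)
    if not counts:
        raise ValueError("no labels given")
    largest = max(counts.values())
    return {
        "counts": dict(counts),
        # Classes below the suggested minimum number of samples.
        "too_few": sorted(c for c, n in counts.items() if n < min_per_class),
        # Classes that the largest class outnumbers by more than max_ratio.
        "underrepresented": sorted(
            c for c, n in counts.items() if largest / n > max_ratio
        ),
    }
```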

### Need Help?
- 📖 [Documentation](https://github.com/ninanor/beats-trainer)
- 🐛 [Report Issues](https://github.com/ninanor/beats-trainer/issues)
- 💬 Ask questions in the repository discussions