<a href="https://colab.research.google.com/github/yourusername/mlsp-cocktail-party-problem/blob/main/notebooks/bilstm_colab_training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# BiLSTM Source Separation Training on Google Colab

This notebook trains a BiLSTM-based audio source separation model on the LibriMix dataset using Google Colab's GPU.

## Setup Steps:
1. Clone the project repository
2. Install dependencies
3. Download LibriMix dataset (16kHz, 2-source)
4. Configure model and training hyperparameters
5. Train the model with GPU acceleration
6. Monitor training with logs and checkpoints

## 1. Setup: Clone Repository and Install Dependencies

In [None]:
# check GPU availability
import torch
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

In [None]:
# clone the repository
import os
os.chdir('/content')

# NOTE: Replace with your actual GitHub repository URL
!git clone https://github.com/yourusername/mlsp-cocktail-party-problem.git
os.chdir('/content/mlsp-cocktail-party-problem')
print(f"Working directory: {os.getcwd()}")

In [None]:
# install dependencies from requirements.txt
!pip install -q -r requirements.txt
print("Dependencies installed successfully!")

## 2. Dataset Setup

### Option A: Download LibriMix (Recommended)
The download takes roughly 30-45 minutes. LibriMix is about 50 GB compressed and about 100 GB once extracted, so check that the Colab disk has enough free space before starting.

### Option B: Use Existing LibriMix
If you've already downloaded LibriMix, mount your Google Drive and point to its location.

In [None]:
# Option A: Download LibriMix from HuggingFace Datasets
# This will download the 16kHz, 2-source dataset (~50GB compressed)

import os
from datasets import load_dataset
from pathlib import Path

data_dir = Path('/content/mlsp-cocktail-party-problem/data/Libri2Mix')
data_dir.mkdir(parents=True, exist_ok=True)

print("Downloading LibriMix dataset from HuggingFace...")
print("This will take 30-45 minutes depending on connection speed.")
print("You can monitor progress in the output below.\n")

# Load the LibriMix dataset (16kHz variant)
# Available splits: train-100, dev, test
try:
    # Download training set
    ds_train = load_dataset(
        'JorisCos/LibriMix',
        name='libri2mix_16k',
        split='train-100',
        cache_dir=str(data_dir)
    )
    print(f"\nTraining set downloaded: {len(ds_train)} samples")
    
    # Download validation set
    ds_val = load_dataset(
        'JorisCos/LibriMix',
        name='libri2mix_16k',
        split='dev',
        cache_dir=str(data_dir)
    )
    print(f"Validation set downloaded: {len(ds_val)} samples")
    
    # Download test set
    ds_test = load_dataset(
        'JorisCos/LibriMix',
        name='libri2mix_16k',
        split='test',
        cache_dir=str(data_dir)
    )
    print(f"Test set downloaded: {len(ds_test)} samples")
    
except Exception as e:
    # Report the error first, then point to the fallback (Option B below)
    print(f"Download failed: {e}")
    print("If the download keeps failing, mount Google Drive and use an existing LibriMix copy (Option B).")

In [None]:
# Option B: Mount Google Drive if dataset already exists there
from google.colab import drive

try:
    drive.mount('/content/gdrive')
    print("Google Drive mounted successfully!")
    print("\nIf you have LibriMix in Drive, symlink it:")
    print("/content/gdrive/My Drive/Datasets/Libri2Mix -> /content/mlsp-cocktail-party-problem/data/Libri2Mix")
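except Exception as e:
    # Catch Exception rather than using a bare except, so KeyboardInterrupt still works
    print(f"Google Drive mount failed ({e}). Using downloaded dataset instead.")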
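
The cell above only prints the symlink hint. A minimal sketch of actually creating the link is shown here; `link_dataset` is a hypothetical helper (not part of the repository), and the Drive path is the example location from the message above, so adjust it to wherever your copy lives.

```python
from pathlib import Path


def link_dataset(source, target):
    """Make `target` a symlink to an existing dataset directory `source`."""
    source, target = Path(source), Path(target)
    if not source.is_dir():
        raise FileNotFoundError(f"Dataset not found at: {source}")
    target.parent.mkdir(parents=True, exist_ok=True)
    if target.is_symlink():
        target.unlink()  # replace a stale link from an earlier session
    elif target.is_dir() and not any(target.iterdir()):
        target.rmdir()  # remove an empty placeholder directory
    elif target.exists():
        raise FileExistsError(f"{target} already exists and is not a symlink")
    target.symlink_to(source, target_is_directory=True)
    return target


# Example call on Colab (uncomment after mounting Drive):
# link_dataset(
#     "/content/gdrive/My Drive/Datasets/Libri2Mix",
#     "/content/mlsp-cocktail-party-problem/data/Libri2Mix",
# )
```

The helper refuses to overwrite a non-empty directory, so a partially downloaded dataset is never deleted by accident.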

## 3. Training Configuration

In [None]:
# verify dataset is available
import os
from pathlib import Path

data_root = Path('/content/mlsp-cocktail-party-problem/data/Libri2Mix')
config_data = Path('/content/mlsp-cocktail-party-problem/config/libri2mix_16k_2src.yaml')
config_model = Path('/content/mlsp-cocktail-party-problem/config/bilstm.yaml')

print(f"Data root exists: {data_root.exists()}")
print(f"Data config exists: {config_data.exists()}")
print(f"Model config exists: {config_model.exists()}")

if data_root.exists():
    print(f"\nDataset structure:")
    for item in sorted(data_root.iterdir()):
        print(f"  {item.name}")

In [None]:
# display current configurations
import yaml

print("=" * 50)
print("DATASET CONFIG (libri2mix_16k_2src.yaml)")
print("=" * 50)
with open(config_data) as f:
    dataset_config = yaml.safe_load(f)
    print(yaml.dump(dataset_config, default_flow_style=False))

print("\n" + "=" * 50)
print("MODEL CONFIG (bilstm.yaml)")
print("=" * 50)
with open(config_model) as f:
    model_config = yaml.safe_load(f)
    print(yaml.dump(model_config, default_flow_style=False))

### Optional: Modify Training Configuration

Adjust these settings if needed for your training:

In [None]:
# Optional: Create a custom config file for Colab-specific settings
# You can modify these values before training

import yaml

# Create a Colab-optimized config (feel free to adjust)
colab_config = {
    'dataset': {
        'sample_rate': 16000,
        'n_src': 2,
        'mode': 'min',
        'mixture_type': 'mix_clean'
    },
    'run': {
        'seed': 42,
        'device': 'cuda'  # Use GPU on Colab
    },
    'model': {
        'num_layers': 2,
        'hidden_size': 512,
        'dropout': 0.3,
        'n_fft': 1024,
        'hop_length': 256
    },
    'training': {
        'epochs': 50,  # Reduce if short on time (set to 10-20 for quick testing)
        'early_stopping_patience': 10,
        'learning_rate': 3e-3,
        'weight_decay': 1e-4,
        'gradient_clip_norm': 5.0,
        'scheduler': 'cosine',
        'scheduler_params': {
            'step': {'step_size': 10, 'gamma': 0.1},
            'reduce_on_plateau': {'factor': 0.5, 'patience': 5},
            'cosine': None
        }
    }
}

# MODIFY THESE VALUES FOR YOUR TRAINING:
EPOCHS = 50  # Set to 10-20 for quick test
LEARNING_RATE = 3e-3
HIDDEN_SIZE = 512  # Reduce to 256 if OOM
NUM_LAYERS = 2

colab_config['training']['epochs'] = EPOCHS
colab_config['training']['learning_rate'] = LEARNING_RATE
colab_config['model']['hidden_size'] = HIDDEN_SIZE
colab_config['model']['num_layers'] = NUM_LAYERS

# Save custom config
config_path = '/content/mlsp-cocktail-party-problem/config/bilstm_colab.yaml'
with open(config_path, 'w') as f:
    yaml.dump(colab_config, f, default_flow_style=False)

print(f"Colab config saved to: {config_path}")
print(f"\nTraining configuration:")
print(f"  Epochs: {EPOCHS}")
print(f"  Learning Rate: {LEARNING_RATE}")
print(f"  Hidden Size: {HIDDEN_SIZE}")
print(f"  Number of Layers: {NUM_LAYERS}")
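
To get a feel for what `n_fft` and `hop_length` imply, the sketch below computes the spectrogram shape for one second of 16 kHz audio. It assumes a one-sided STFT with centered framing (the default behaviour of `torch.stft`); whether the model in this repository frames the signal the same way is an assumption. `stft_shape` is a hypothetical helper.

```python
def stft_shape(num_samples, n_fft=1024, hop_length=256, center=True):
    """Return (freq_bins, frames) for a one-sided STFT of a real signal."""
    freq_bins = n_fft // 2 + 1
    if center:
        # Centered framing pads n_fft // 2 samples on both sides
        frames = 1 + num_samples // hop_length
    else:
        frames = 1 + (num_samples - n_fft) // hop_length
    return freq_bins, frames


sample_rate = 16000
bins, frames = stft_shape(sample_rate)
print(f"1 s of audio -> {bins} frequency bins x {frames} frames")
print(f"Window: {1000 * 1024 / sample_rate:.0f} ms, hop: {1000 * 256 / sample_rate:.0f} ms")
```

Halving `hop_length` roughly doubles the number of frames, and with it the sequence length the LSTM has to process.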

## 4. Train the Model

In [None]:
import subprocess
import os
from pathlib import Path

os.chdir('/content/mlsp-cocktail-party-problem')

# Training command
# NOTE: Update these paths if you've created a custom config
train_cmd = [
    'python', '-m', 'src.train.bilstm_train',
    '--root-dir-data', 'data/Libri2Mix',
    '--config-data', 'config/libri2mix_16k_2src.yaml',
    '--config-model', 'config/bilstm_colab.yaml',  # Use custom Colab config
    '--save-dir', 'output/models/bilstm_colab',
    '--save-checkpoints',
    '--save-every', '5',
    '--log-level', 'INFO'
]

print(f"Starting training...")
print(f"Command: {' '.join(train_cmd)}\n")

# Run training
result = subprocess.run(train_cmd, cwd='/content/mlsp-cocktail-party-problem')
print(f"\nTraining completed with exit code: {result.returncode}")
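
Depending on the notebook frontend, output written by a child process started with `subprocess.run` may not show up in the cell. If the training log stays empty, one option is to read the child's output through a pipe and print it line by line. The sketch below is one way to do that; `run_streaming` is a hypothetical helper, not part of the repository.

```python
import subprocess
import sys


def run_streaming(cmd, cwd=None):
    """Run `cmd`, echoing its combined stdout/stderr line by line."""
    proc = subprocess.Popen(
        cmd,
        cwd=cwd,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,  # merge stderr into stdout
        text=True,
        bufsize=1,  # line-buffered on our side of the pipe
    )
    lines = []
    for line in proc.stdout:
        print(line, end="")
        lines.append(line.rstrip("\n"))
    proc.stdout.close()
    return proc.wait(), lines


# Small self-contained demonstration
code, lines = run_streaming([sys.executable, "-c", "print('epoch 1 done')"])
print(f"Exit code: {code}")
```

To use it for training, pass `train_cmd` in place of the demonstration command. A Python child process buffers its output when writing to a pipe, so adding `-u` after `'python'` in `train_cmd` helps lines appear as they are produced.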

## 5. Monitor Training

In [None]:
# List saved checkpoints
from pathlib import Path
import json

output_dir = Path('/content/mlsp-cocktail-party-problem/output/models/bilstm_colab')

if output_dir.exists():
    print(f"Output directory: {output_dir}\n")
    
    # List subdirectories (run folders), oldest first by modification time
    runs = sorted([d for d in output_dir.iterdir() if d.is_dir()], key=lambda x: x.stat().st_mtime)
    
    if runs:
        latest_run = runs[-1]
        print(f"Latest run: {latest_run.name}")
        print(f"\nCheckpoints:")
        
        checkpoints = sorted(latest_run.glob('*.pth'))
        for ckpt in checkpoints:
            size_mb = ckpt.stat().st_size / (1024 ** 2)
            print(f"  {ckpt.name} ({size_mb:.1f} MB)")
        
        # Display config
        config_file = latest_run / 'config.json'
        if config_file.exists():
            print(f"\nConfig file: {config_file.name}")
            with open(config_file) as f:
                config = json.load(f)
                print(f"Model parameters: {config.get('model', {})}")
                print(f"Training parameters: {config.get('training', {})}")
else:
    print(f"Output directory not found: {output_dir}")

In [None]:
# View training logs (last 50 lines)
from pathlib import Path

output_dir = Path('/content/mlsp-cocktail-party-problem/output/models/bilstm_colab')
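# Guard against a missing output directory (e.g. before the first training run)
runs = sorted([d for d in output_dir.iterdir() if d.is_dir()], key=lambda x: x.stat().st_mtime) if output_dir.exists() else []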

if runs:
    latest_run = runs[-1]
    log_file = latest_run / 'training.log'
    
    if log_file.exists():
        print(f"Training log (last 50 lines):\n")
        with open(log_file) as f:
            lines = f.readlines()
            for line in lines[-50:]:
                print(line, end='')
    else:
        print(f"No log file found at: {log_file}")
else:
    print("No training runs found yet.")

## 6. Download Model Checkpoint

In [None]:
# Download trained model from Colab to local machine
# This creates a zip file you can download

import shutil
from pathlib import Path

output_dir = Path('/content/mlsp-cocktail-party-problem/output/models/bilstm_colab')
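# Guard against a missing output directory (e.g. before the first training run)
runs = sorted([d for d in output_dir.iterdir() if d.is_dir()], key=lambda x: x.stat().st_mtime) if output_dir.exists() else []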

if runs:
    latest_run = runs[-1]
    
    # Create zip file
    zip_path = '/content/bilstm_model_checkpoint.zip'
    shutil.make_archive('/content/bilstm_model_checkpoint', 'zip', latest_run.parent, latest_run.name)
    
    # Get file size (reported in MB, matching the checkpoint listing in section 5)
    size_mb = Path(zip_path).stat().st_size / (1024 ** 2)
    print(f"Model checkpoint saved: {zip_path}")
    print(f"Size: {size_mb:.1f} MB")
    print(f"\nYou can download this file from Colab's file browser (left sidebar)")
else:
    print("No trained model found.")

## 7. Test Inference with Trained Model

In [None]:
# Test inference with trained model
import torch
import json
from pathlib import Path
from src.models.bilstm import BiLSTMSeparator

output_dir = Path('/content/mlsp-cocktail-party-problem/output/models/bilstm_colab')
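# Guard against a missing output directory (e.g. before the first training run)
runs = sorted([d for d in output_dir.iterdir() if d.is_dir()], key=lambda x: x.stat().st_mtime) if output_dir.exists() else []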

if runs:
    latest_run = runs[-1]
    checkpoint_path = latest_run / 'best_model.pth'
    config_path = latest_run / 'config.json'
    
    if checkpoint_path.exists() and config_path.exists():
        # Load config
        with open(config_path) as f:
            config = json.load(f)
        
        model_config = config['model']
        dataset_config = config['dataset']
        
        # Create model
        model = BiLSTMSeparator(
            num_sources=dataset_config['n_src'],
            num_layers=model_config['num_layers'],
            hidden_size=model_config['hidden_size'],
            dropout=model_config['dropout'],
            n_fft=model_config['n_fft'],
            hop_length=model_config['hop_length']
        )
        
        # Load checkpoint
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        model.load_state_dict(checkpoint['model_state_dict'])
        model.eval()
        
        print(f"Model loaded successfully!")
        print(f"Checkpoint: {checkpoint_path.name}")
        print(f"\nTraining info:")
        print(f"  Epoch: {checkpoint['epoch']}")
        print(f"  Best SI-SNR: {checkpoint['best_si_snr']:.4f} dB")
        
        # Test inference
        with torch.no_grad():
            # Create dummy mixture (batch_size=1, 16000 samples = 1 second at 16kHz)
            mixture = torch.randn(1, 16000)
            separated = model(mixture)
            
            print(f"\nInference test successful!")
            print(f"  Input shape: {mixture.shape}")
            print(f"  Output shape: {separated.shape}")
            print(f"  Expected: [batch=1, sources=2, time=16000]")
    else:
        print(f"Checkpoint or config not found in: {latest_run}")
else:
    print("No training runs found.")

## 8. Resume Training from Checkpoint (Optional)

If your training was interrupted, you can resume from the best checkpoint.

In [None]:
import subprocess
import os
from pathlib import Path

os.chdir('/content/mlsp-cocktail-party-problem')

# Find latest checkpoint
output_dir = Path('output/models/bilstm_colab')
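# Guard against a missing output directory (e.g. before the first training run)
runs = sorted([d for d in output_dir.iterdir() if d.is_dir()], key=lambda x: x.stat().st_mtime) if output_dir.exists() else []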

if runs:
    latest_run = runs[-1]
    checkpoint_path = latest_run / 'best_model.pth'
    
    if checkpoint_path.exists():
        # Resume training
        train_cmd = [
            'python', '-m', 'src.train.bilstm_train',
            '--root-dir-data', 'data/Libri2Mix',
            '--config-data', 'config/libri2mix_16k_2src.yaml',
            '--config-model', 'config/bilstm_colab.yaml',
            '--resume', str(checkpoint_path),
            '--save-checkpoints',
            '--save-every', '5',
            '--log-level', 'INFO'
        ]
        
        print(f"Resuming training from: {checkpoint_path}")
        print(f"Command: {' '.join(train_cmd)}\n")
        
        result = subprocess.run(train_cmd)
        print(f"\nResumed training completed with exit code: {result.returncode}")
    else:
        print(f"Checkpoint not found: {checkpoint_path}")
else:
    print("No training runs found.")

## Notes

### GPU Memory Considerations
- Colab T4 GPU: ~16GB VRAM
- BiLSTM with hidden_size=512: Uses ~8-10GB with batch_size=16
- If you get OOM errors, reduce:
  - `hidden_size` from 512 to 256
  - `num_layers` from 2 to 1
  - Check dataloader batch sizes in config
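
To see how `hidden_size` and `num_layers` affect model size, the sketch below counts the weights of a bidirectional LSTM stack using the parameter layout of `torch.nn.LSTM`. It assumes the recurrent part is a plain `nn.LSTM(bidirectional=True)` fed with `n_fft // 2 + 1` frequency bins per frame, and it ignores any input or output layers, so treat the numbers as a rough guide rather than the exact size of this repository's model.

```python
def bilstm_param_count(input_size, hidden_size, num_layers):
    """Count parameters of a bidirectional LSTM stack (nn.LSTM layout)."""
    total = 0
    layer_input = input_size
    for _ in range(num_layers):
        # Per direction: weight_ih (4H x I), weight_hh (4H x H), two biases (4H each)
        per_direction = 4 * hidden_size * (layer_input + hidden_size) + 8 * hidden_size
        total += 2 * per_direction
        # Deeper layers receive the concatenated forward and backward outputs
        layer_input = 2 * hidden_size
    return total


freq_bins = 1024 // 2 + 1  # assumed input size: one-sided spectrum for n_fft=1024
for hidden, layers in [(512, 2), (256, 2), (512, 1)]:
    n = bilstm_param_count(freq_bins, hidden, layers)
    print(f"hidden_size={hidden}, num_layers={layers}: {n:,} params, {n * 4 / 1e6:.1f} MB in float32")
```

Under these assumptions the weights take tens of megabytes, which suggests that most of the VRAM goes to activations. Those scale with batch size and sequence length, so reducing the batch size is often the most effective fix for OOM errors.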

### Training Time
- Each epoch: ~2-3 minutes (13,900 train samples)
- 50 epochs: ~1.7-2.5 hours (50 × 2-3 minutes)
- Early stopping may trigger earlier

### Resuming Training
- Checkpoints are saved every 5 epochs + best model
- Use section 8 to resume from the best checkpoint
- All training metadata is saved in `config.json`

### Saving Results
- Model checkpoints are in `output/models/bilstm_colab/<run_id>/`
- Download the zip file from Colab's file browser
- Or mount Google Drive to save directly to Drive
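
As a sketch of the Drive option, the snippet below copies a run folder into a backup directory. `backup_run` is a hypothetical helper, and the Drive destination in the commented example is an assumed location; Drive must already be mounted as in section 2.

```python
import shutil
from pathlib import Path


def backup_run(run_dir, backup_root):
    """Copy a run folder (checkpoints, config, log) into `backup_root`."""
    run_dir = Path(run_dir)
    dest = Path(backup_root) / run_dir.name
    # dirs_exist_ok lets repeated backups refresh an existing copy
    shutil.copytree(run_dir, dest, dirs_exist_ok=True)
    return dest


# Example call on Colab (uncomment after mounting Drive; paths are assumptions):
# backup_run(
#     "/content/mlsp-cocktail-party-problem/output/models/bilstm_colab/<run_id>",
#     "/content/gdrive/My Drive/bilstm_colab_backups",
# )
```

Running the backup after each training session keeps a copy that survives a Colab runtime reset.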

### Troubleshooting
- **LibriMix download fails**: Use existing dataset or mount Google Drive
- **OOM errors**: Reduce hidden_size or batch size in config
- **Slow training**: May indicate I/O bottleneck, consider caching to local SSD
- **Session timeout**: Save checkpoints frequently, resume from latest
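
For the session-timeout case, a small helper can locate the most recently written checkpoint across all run folders. This is a minimal sketch; `find_latest_checkpoint` is a hypothetical helper, and it assumes checkpoints use the `.pth` extension as in section 5.

```python
from pathlib import Path


def find_latest_checkpoint(output_dir, pattern="*.pth"):
    """Return the most recently modified checkpoint under `output_dir`, or None."""
    output_dir = Path(output_dir)
    if not output_dir.exists():
        return None
    checkpoints = list(output_dir.rglob(pattern))
    if not checkpoints:
        return None
    return max(checkpoints, key=lambda p: p.stat().st_mtime)


# Prints None when no checkpoint has been written yet
print(find_latest_checkpoint("output/models/bilstm_colab"))
```

The returned path can be passed to `--resume` in section 8 in place of `best_model.pth` when you want to continue from the latest saved state rather than the best one.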