In [23]:
import os
import sys
import math
import random
import logging
from datetime import datetime
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib
# matplotlib.use('Agg')  # Non-interactive backend for HPC
import matplotlib.pyplot as plt
import librosa
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from torch.utils.data import Dataset
import pandas as pd
import numpy as np
import os
from glob import glob
import gc

In [24]:
def setup_logging(log_dir='logs', experiment_name='smr_seld'):
    """Setup comprehensive logging for HPC environment"""
    os.makedirs(log_dir, exist_ok=True)
    
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    log_file = os.path.join(log_dir, f'{experiment_name}_{timestamp}.log')
    
    # Create logger
    logger = logging.getLogger('SMR_SELD')
    logger.setLevel(logging.INFO)
    
    # Clear existing handlers to prevent duplicates when re-running cells
    if logger.handlers:
        logger.handlers.clear()
    
    # File handler
    fh = logging.FileHandler(log_file)
    fh.setLevel(logging.INFO)
    
    # Console handler
    ch = logging.StreamHandler(sys.stdout)
    ch.setLevel(logging.INFO)
    
    # Formatter
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )
    fh.setFormatter(formatter)
    ch.setFormatter(formatter)
    
    logger.addHandler(fh)
    logger.addHandler(ch)
    
    return logger, log_file

logger, log_file = setup_logging()

In [25]:
def get_device():
    """Get available device with CUDA support"""
    if torch.cuda.is_available():
        device = torch.device('cuda')
        logger.info(f"CUDA is available! Using GPU: {torch.cuda.get_device_name(0)}")
        logger.info(f"CUDA Version: {torch.version.cuda}")
        logger.info(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
    else:
        device = torch.device('cpu')
        logger.warning("CUDA not available. Using CPU. Training will be slower.")
    
    return device

DEVICE = get_device()

2025-11-17 23:10:38 - SMR_SELD - INFO - CUDA is available! Using GPU: NVIDIA GeForce RTX 3050 Laptop GPU
2025-11-17 23:10:38 - SMR_SELD - INFO - CUDA Version: 11.8
2025-11-17 23:10:38 - SMR_SELD - INFO - GPU Memory: 4.00 GB
2025-11-17 23:10:38 - SMR_SELD - INFO - CUDA Version: 11.8
2025-11-17 23:10:38 - SMR_SELD - INFO - GPU Memory: 4.00 GB


In [26]:
class Config:
    """Configuration class for SMR-SELD training"""
    
    # Paths (relative to script location)
    BASE_PATH = Path.cwd()
    AUDIO_PATH = BASE_PATH / "foa_dev"
    METADATA_PATH = BASE_PATH / "metadata_dev"
    OUTPUT_PATH = BASE_PATH / "outputs"
    CHECKPOINT_PATH = BASE_PATH / "checkpoints"
    
    # Dataset - Use full dataset or single file for testing
    USE_FULL_DATASET = True  # Set to False for quick testing with single file
    TRAIN_AUDIO_FILE = "fold3_room21_mix001.wav"  # Used only if USE_FULL_DATASET=False
    TRAIN_META_FILE = "fold3_room21_mix001.csv"
    TEST_AUDIO_FILE = "fold4_room23_mix001.wav"
    TEST_META_FILE = "fold4_room23_mix001.csv"
    
    # STARSS22 Classes
    STARSS22_CLASSES = {
        0: 'Female speech, woman speaking',
        1: 'Male speech, man speaking',
        2: 'Clapping',
        3: 'Telephone',
        4: 'Laughter',
        5: 'Domestic sounds',
        6: 'Walk, footsteps',
        7: 'Door, open or close',
        8: 'Music',
        9: 'Musical instrument',
        10: 'Water tap, faucet',
        11: 'Bell',
        12: 'Knock',
        13: 'Background'
    }
    
    # Model
    NUM_CLASSES = 14  # 13 classes + 1 background
    N_CHANNELS = 4  # FOA channels    
    
    # Training hyperparameters
    NUM_EPOCHS = 25
    BATCH_SIZE = 8
    LEARNING_RATE = 1e-3
    LR_DECAY_FACTOR = 0.5  # Multiply LR by this factor
    LR_DECAY_PATIENCE = 5   # Decay LR after this many epochs without improvement
    WEIGHT_DECAY = 1e-4
    
    # Loss weights
    W_CLASS = 1.0
    W_AIUR = 0.5
    W_CL = 0.5
    
    # Early stopping
    PATIENCE = 10
    MIN_DELTA = 1e-4
    
    # Checkpointing
    SAVE_EVERY_N_EPOCHS = 5
    KEEP_LAST_N_CHECKPOINTS = 3
    
    # Signal Processing (to convert into spectro)
    SPECTROGRAM_N_FFT = int(0.04*24000)
    SPECTROGRAM_HOP_LENGTH = int(0.02*24000) 
    N_MELS = 64
    SR = 24000
    
    # Dataset
    # 5s window (saved as no of frames) to feed into the model
    #1s hop length (converted to no of frames) to feed into the model
    WINDOW_LENGTH = int(5*24000)
    HOP_LENGTH = int(1*24000)
    
    # 3D to 2D Mapping
    I = None
    J = None
    GRID_CELL_DEGREES = 10
    
    
    def __init__(self):
        # Create directories
        self.OUTPUT_PATH.mkdir(exist_ok=True, parents=True)
        self.CHECKPOINT_PATH.mkdir(exist_ok=True, parents=True)
        
        # Build full paths (used only for single file mode)
        self.TRAIN_AUDIO_PATH = self.AUDIO_PATH / "dev-train-sony" / self.TRAIN_AUDIO_FILE
        self.TRAIN_META_PATH = self.METADATA_PATH / "dev-train-sony" / self.TRAIN_META_FILE
        self.TEST_AUDIO_PATH = self.AUDIO_PATH / "dev-test-sony" / self.TEST_AUDIO_FILE
        self.TEST_META_PATH = self.METADATA_PATH / "dev-test-sony" / self.TEST_META_FILE
        
        # Dataset directories
        self.SONY_TRAIN_DIR = self.AUDIO_PATH / "dev-train-sony"
        self.SONY_TEST_DIR = self.AUDIO_PATH / "dev-test-sony"
        self.SONY_TRAIN_META_DIR = self.METADATA_PATH / "dev-train-sony"
        self.SONY_TEST_META_DIR = self.METADATA_PATH / "dev-test-sony"
        self.TAU_TRAIN_DIR = self.AUDIO_PATH / "dev-train-tau"
        self.TAU_TEST_DIR = self.AUDIO_PATH / "dev-test-tau"
        self.TAU_TRAIN_META_DIR = self.METADATA_PATH / "dev-train-tau"
        self.TAU_TEST_META_DIR = self.METADATA_PATH / "dev-test-tau"


config = Config()

In [27]:
def load_files():
    """Load audio and metadata files based on configuration"""
    if config.USE_FULL_DATASET:
        # Load all audio files
        sony_train_audio = sorted(glob(str(config.SONY_TRAIN_DIR / "*.wav")))
        tau_train_audio = sorted(glob(str(config.TAU_TRAIN_DIR / "*.wav")))
        sony_test_audio = sorted(glob(str(config.SONY_TEST_DIR / "*.wav")))
        tau_test_audio = sorted(glob(str(config.TAU_TEST_DIR / "*.wav")))
        
        # Match metadata files to audio files by basename
        def get_matching_metadata(audio_files, meta_dir):
            """Get metadata files matching audio files by basename"""
            meta_files = []
            for audio_file in audio_files:
                # Get basename without extension (e.g., fold3_room21_mix001)
                basename = Path(audio_file).stem
                # Build corresponding metadata path
                meta_file = meta_dir / f"{basename}.csv"
                if meta_file.exists():
                    meta_files.append(str(meta_file))
                else:
                    raise FileNotFoundError(f"Metadata file not found: {meta_file}")
            return meta_files
        
        # Get matching metadata files
        sony_train_meta = get_matching_metadata(sony_train_audio, config.SONY_TRAIN_META_DIR)
        tau_train_meta = get_matching_metadata(tau_train_audio, config.TAU_TRAIN_META_DIR)
        sony_test_meta = get_matching_metadata(sony_test_audio, config.SONY_TEST_META_DIR)
        tau_test_meta = get_matching_metadata(tau_test_audio, config.TAU_TEST_META_DIR)
        
        # Combine training and testing files
        train_audio_files = sony_train_audio + tau_train_audio
        train_meta_files = sony_train_meta + tau_train_meta
        test_audio_files = sony_test_audio + tau_test_audio
        test_meta_files = sony_test_meta + tau_test_meta
    else:
        # Load single training file
        train_audio_files = [str(config.TRAIN_AUDIO_PATH)]
        train_meta_files = [str(config.TRAIN_META_PATH)]
        
        # Load single testing file
        test_audio_files = [str(config.TEST_AUDIO_PATH)]
        test_meta_files = [str(config.TEST_META_PATH)]
    
    return train_audio_files, train_meta_files, test_audio_files, test_meta_files


    

In [28]:
def polar_to_grid(phi, theta, I=None, J=None, cell_size_deg=None):
    """
    Convert polar coordinates (azimuth phi, elevation theta) to grid indices (i, j).

    Parameters
    ----------
    phi : float
        Azimuth in degrees (range [-180, 180]).
    theta : float
        Elevation in degrees (range [-90, 90]).
    I : int, optional
        Number of elevation bins. If None, computed from cell_size_deg.
    J : int, optional
        Number of azimuth bins. If None, computed from cell_size_deg.
    cell_size_deg : float, optional
        Size of each grid cell in degrees. Required if I or J is None.

    Returns
    -------
    i, j : tuple of int
        Grid row (elevation index) and column (azimuth index).
    """
    # If grid dimensions not provided, compute from cell size
    if (I is None or J is None) and cell_size_deg is not None:
        I = int(180 // cell_size_deg)
        J = int(360 // cell_size_deg)
    elif I is None or J is None:
        raise ValueError("Either provide (I, J) or cell_size_deg for polar_to_grid")

    # Normalize azimuth and elevation to [0,1]
    phi_norm = (phi + 180.0) / 360.0
    theta_norm = (theta + 90.0) / 180.0
    j = int(np.clip(phi_norm * J, 0, J - 1))
    i = int(np.clip(theta_norm * I, 0, I - 1))
    return i, j


In [29]:
def load_audio(audio_path):
    """
    Load multi-channel audio file using torchaudio.
    
    Args:
        audio_path: Path to the audio file
        
    Returns:
        waveform: Tensor of shape (channels, samples) - preserves all 4 FOA channels
        sample_rate: Sample rate of the audio
    """
    waveform, sample_rate = torchaudio.load(audio_path)
    
    # Verify we have 4 channels (FOA format)
    if waveform.shape[0] != 4:
        logger.warning(f"Expected 4 channels but got {waveform.shape[0]} channels in {audio_path}")
    
    return waveform, sample_rate

In [30]:
def audio_to_mel_spectrogram(waveform, sample_rate, n_fft=None, hop_length=None, n_mels=None):
    """
    Convert multi-channel audio waveform to mel spectrogram.
    Processes each channel separately and stacks them.
    
    Args:
        waveform: Tensor of shape (channels, samples) - typically (4, num_samples) for FOA
        sample_rate: Sample rate of the audio
        n_fft: FFT window size (default: from config)
        hop_length: Hop length for STFT (default: from config)
        n_mels: Number of mel filterbanks (default: from config)
        
    Returns:
        mel_spec: Tensor of shape (channels, n_mels, time_frames)
    """
    # Use config defaults if not provided
    if n_fft is None:
        n_fft = config.SPECTROGRAM_N_FFT
    if hop_length is None:
        hop_length = config.SPECTROGRAM_HOP_LENGTH
    if n_mels is None:
        n_mels = config.N_MELS
    
    # Create mel spectrogram transform
    mel_transform = torchaudio.transforms.MelSpectrogram(
        sample_rate=sample_rate,
        n_fft=n_fft,
        hop_length=hop_length,
        n_mels=n_mels
    )
    
    # Process each channel separately
    mel_specs = []
    for channel_idx in range(waveform.shape[0]):
        channel_waveform = waveform[channel_idx:channel_idx+1, :]  # Keep dimension (1, samples)
        mel_spec = mel_transform(channel_waveform)  # Shape: (1, n_mels, time_frames)
        mel_specs.append(mel_spec)
    
    # Stack all channels: (channels, n_mels, time_frames)
    mel_spec_multichannel = torch.cat(mel_specs, dim=0)
    
    # Convert to log scale (dB)
    mel_spec_db = torchaudio.transforms.AmplitudeToDB()(mel_spec_multichannel)
    
    return mel_spec_db

In [31]:
def visualize_mel_spectrogram(mel_spec, title="Multi-Channel Mel Spectrogram", figsize=(15, 10)):
    """
    Visualize multi-channel mel spectrogram.
    
    Args:
        mel_spec: Tensor of shape (channels, n_mels, time_frames)
        title: Title for the figure
        figsize: Figure size (width, height)
    """
    # Convert to numpy if it's a torch tensor
    if isinstance(mel_spec, torch.Tensor):
        mel_spec_np = mel_spec.cpu().numpy()
    else:
        mel_spec_np = mel_spec
    
    n_channels = mel_spec_np.shape[0]
    
    # Create subplots for each channel
    fig, axes = plt.subplots(2, 2, figsize=figsize)
    fig.suptitle(title, fontsize=16)
    
    # Flatten axes for easier iteration
    axes_flat = axes.flatten()
    
    # Plot each channel
    for i in range(n_channels):
        ax = axes_flat[i]
        im = ax.imshow(
            mel_spec_np[i], 
            aspect='auto', 
            origin='lower',
            cmap='viridis',
            interpolation='nearest'
        )
        ax.set_title(f'Channel {i+1}')
        ax.set_xlabel('Time Frames')
        ax.set_ylabel('Mel Frequency Bins')
        
        # Add colorbar
        plt.colorbar(im, ax=ax, format='%+2.0f dB')
    
    plt.tight_layout()
    plt.show()
    
    return fig

In [32]:
def metadata_to_labels(metadata_path, audio_duration, sample_rate=24000, I=None, J=None, 
                        cell_size_deg=None, num_classes=14):
    """
    Convert metadata file to target labels for 2D grid representation.
    
    Args:
        metadata_path: Path to the CSV metadata file
        audio_duration: Duration of audio in seconds
        sample_rate: Sample rate of audio (default: 24000 Hz)
        I: Number of elevation bins (height of grid)
        J: Number of azimuth bins (width of grid)
        cell_size_deg: Cell size in degrees (alternative to I, J)
        num_classes: Total number of classes including background (default: 14)
        
    Returns:
        labels: Tensor of shape [T', I*J, M] where:
                T' = number of 20ms frames
                I*J = total grid cells
                M = number of classes (14)
    """
    # Use config default if not provided
    if cell_size_deg is None:
        cell_size_deg = config.GRID_CELL_DEGREES
    
    # Step 1: Calculate total number of frames (20ms per frame)
    frame_duration_ms = 20  # 20ms per frame in final representation
    metadata_frame_duration_ms = 100  # 100ms per frame in metadata
    frames_per_metadata_frame = metadata_frame_duration_ms // frame_duration_ms  # = 5
    
    # Total number of frames for the audio
    total_frames = int((audio_duration * 1000) / frame_duration_ms)
    
    # Calculate grid dimensions if not provided
    if (I is None or J is None) and cell_size_deg is not None:
        I = int(180 // cell_size_deg)
        J = int(360 // cell_size_deg)
    elif I is None or J is None:
        raise ValueError("Either provide (I, J) or cell_size_deg for grid dimensions")
    
    total_cells = I * J
    
    # Initialize labels tensor: [T', I*J, M] with all zeros
    labels = torch.zeros((total_frames, total_cells, num_classes), dtype=torch.float32)
    
    # Read metadata CSV
    df = pd.read_csv(metadata_path, header=None)
    
    # Track which cells have active events for each frame
    # This will help us set background (class 13) for empty cells
    active_cells_per_frame = [set() for _ in range(total_frames)]
    
    # Process each row in metadata
    for _, row in df.iterrows():
        # Parse metadata row: [frame, class, source, azimuth, elevation]
        metadata_frame = int(row.iloc[0])  # Frame number from metadata
        active_class = int(row.iloc[1])     # Active class index
        source_num = int(row.iloc[2])       # Source number (not used in labeling)
        azimuth = int(row.iloc[3])          # Azimuth in degrees
        elevation = int(row.iloc[4])        # Elevation in degrees
        
        # Step 2: Map metadata frame to final representation frames
        # Metadata frame t corresponds to frames t*5 to t*5+4 in final representation
        start_frame = metadata_frame * frames_per_metadata_frame
        end_frame = start_frame + frames_per_metadata_frame
        
        # Ensure we don't exceed total frames
        end_frame = min(end_frame, total_frames)
        
        # Step 3: Convert polar coordinates to grid cell
        i, j = polar_to_grid(azimuth, elevation, I=I, J=J)
        cell_idx = i * J + j  # Flatten 2D grid to 1D index
        
        # Step 4: Set active class for this cell across the time frames
        for t in range(start_frame, end_frame):
            # Set the active class to 1 (one-hot encoding)
            labels[t, cell_idx, active_class] = 1.0
            # Track that this cell has an active event
            active_cells_per_frame[t].add(cell_idx)
    
    # Step 5: Set background class (index 13) for cells with no active events
    for t in range(total_frames):
        for cell_idx in range(total_cells):
            # If this cell has no active events in this frame
            if cell_idx not in active_cells_per_frame[t]:
                # Set background class (index 13) to 1
                labels[t, cell_idx, num_classes - 1] = 1.0
    
    # Reshape to [T' * (I*J), M] if needed for certain loss functions
    # For now, keep as [T', I*J, M] for clarity
    
    return labels, I, J

In [33]:
train_audio_files, train_meta_files, test_audio_files, test_meta_files = load_files()
print(len(train_audio_files), len(train_meta_files), len(test_audio_files), len(test_meta_files))

67 67 54 54


In [12]:
class SELDDataset(Dataset):
    """
    Dataset class for Sound Event Localization and Detection (SELD).
    
    Loads all audio files, concatenates spectrograms and labels, then segments
    into fixed-length windows with overlap for training.
    """
    
    def __init__(self, audio_files, metadata_files, num_classes=14):
        """
        Initialize SELD Dataset with windowing.
        
        This dataset:
        1. Loads all audio files and computes spectrograms + labels
        2. Concatenates all spectrograms and labels into single tensors
        3. Segments concatenated data into windows (5s window, 1s hop)
        4. Pads final window if needed
        
        Args:
            audio_files: List of audio file paths
            metadata_files: List of corresponding metadata CSV file paths
            num_classes: Total number of classes including background (default: 14)
        """
        assert len(audio_files) == len(metadata_files), \
            "Number of audio files must match number of metadata files"
        
        self.audio_files = audio_files
        self.metadata_files = metadata_files
        self.sample_rate = config.SR
        self.n_fft = config.SPECTROGRAM_N_FFT
        self.spectrogram_hop_length = config.SPECTROGRAM_HOP_LENGTH
        self.n_mels = config.N_MELS
        self.cell_size_deg = config.GRID_CELL_DEGREES
        self.num_classes = num_classes
        
        # Calculate grid dimensions
        self.I = int(180 // self.cell_size_deg)
        self.J = int(360 // self.cell_size_deg)
        self.total_cells = self.I * self.J
        
        # Window parameters (in samples)
        self.window_length_samples = config.WINDOW_LENGTH  # 5s in samples
        self.hop_length_samples = config.HOP_LENGTH  # 1s in samples
        
        # Convert to spectrogram frames
        # Each spectrogram frame represents spectrogram_hop_length samples
        self.window_length_frames = int(self.window_length_samples / self.spectrogram_hop_length)
        self.hop_length_frames = int(self.hop_length_samples / self.spectrogram_hop_length)
        
        logger.info(f"SELDDataset initialization started...")
        logger.info(f"  Files: {len(audio_files)} audio files")
        logger.info(f"  Grid: {self.I}x{self.J} = {self.total_cells} cells")
        logger.info(f"  Window: {self.window_length_frames} frames ({self.window_length_samples / self.sample_rate:.1f}s)")
        logger.info(f"  Hop: {self.hop_length_frames} frames ({self.hop_length_samples / self.sample_rate:.1f}s)")
        
        # Step 1 & 2: Load all files and concatenate
        self._load_and_concatenate_all()
        
        # Step 3 & 4: Segment into windows
        self._create_windows()
        
        logger.info(f"SELDDataset initialized with {len(self.windows)} windows")
    
    def _load_and_concatenate_all(self):
        """Load all files, compute spectrograms and labels, then concatenate."""
        all_spectrograms = []
        all_labels = []
        
        logger.info("Loading and processing all audio files...")
        for idx, (audio_path, metadata_path) in enumerate(tqdm(
            zip(self.audio_files, self.metadata_files),
            total=len(self.audio_files),
            desc="Processing files"
        )):
            try:
                # Load audio
                waveform, sr = load_audio(audio_path)
                
                # Compute mel spectrogram
                mel_spec = audio_to_mel_spectrogram(
                    waveform, 
                    sr,
                    n_fft=self.n_fft,
                    hop_length=self.spectrogram_hop_length,
                    n_mels=self.n_mels
                )  # Shape: (4, n_mels, time_frames)
                
                # Calculate audio duration
                audio_duration = waveform.shape[1] / sr
                
                # Generate labels from metadata
                labels, _, _ = metadata_to_labels(
                    metadata_path,
                    audio_duration,
                    sample_rate=sr,
                    I=self.I,
                    J=self.J,
                    cell_size_deg=self.cell_size_deg,
                    num_classes=self.num_classes
                )  # Shape: (time_frames, I*J, num_classes)
                
                # Ensure matching time dimensions
                mel_time_frames = mel_spec.shape[2]
                label_time_frames = labels.shape[0]
                
                if mel_time_frames != label_time_frames:
                    min_frames = min(mel_time_frames, label_time_frames)
                    mel_spec = mel_spec[:, :, :min_frames]
                    labels = labels[:min_frames, :, :]
                
                # Append to lists
                all_spectrograms.append(mel_spec)
                all_labels.append(labels)
                
            except Exception as e:
                logger.error(f"Error processing file {idx} ({audio_path}): {str(e)}")
                raise
        
        # Concatenate along time dimension
        self.concatenated_spectrograms = torch.cat(all_spectrograms, dim=2)  # (4, n_mels, T)
        self.concatenated_labels = torch.cat(all_labels, dim=0)  # (T, I*J, num_classes)
        
        self.total_frames = self.concatenated_spectrograms.shape[2]
        logger.info(f"Concatenated data: {self.total_frames} total frames")
        logger.info(f"  Spectrograms shape: {self.concatenated_spectrograms.shape}")
        logger.info(f"  Labels shape: {self.concatenated_labels.shape}")
    
    def _create_windows(self):
        """Segment concatenated data into windows with overlap."""
        self.windows = []
        
        start_frame = 0
        window_idx = 0
        
        while start_frame < self.total_frames:
            end_frame = start_frame + self.window_length_frames
            
            # Extract window
            if end_frame <= self.total_frames:
                # Normal window - no padding needed
                window_spec = self.concatenated_spectrograms[:, :, start_frame:end_frame]
                window_labels = self.concatenated_labels[start_frame:end_frame, :, :]
            else:
                # Last window - needs padding
                actual_frames = self.total_frames - start_frame
                
                # Extract what we have
                window_spec = self.concatenated_spectrograms[:, :, start_frame:]
                window_labels = self.concatenated_labels[start_frame:, :, :]
                
                # Pad to window_length_frames
                pad_frames = self.window_length_frames - actual_frames
                
                # Pad spectrograms: (4, n_mels, time) -> pad time dimension
                spec_pad = torch.zeros((4, self.n_mels, pad_frames), dtype=window_spec.dtype)
                window_spec = torch.cat([window_spec, spec_pad], dim=2)
                
                # Pad labels: (time, I*J, num_classes) -> pad time dimension
                # Set background class (index 13) for padded frames
                label_pad = torch.zeros((pad_frames, self.total_cells, self.num_classes), dtype=window_labels.dtype)
                label_pad[:, :, self.num_classes - 1] = 1.0  # Set background class
                window_labels = torch.cat([window_labels, label_pad], dim=0)
            
            # Transpose spectrogram from [C, F, T] to [T, C, F]
            window_spec = window_spec.permute(2, 0, 1)  # (T, C, F)
            
            # Store window
            self.windows.append({
                'spectrogram': window_spec,
                'labels': window_labels,
                'window_idx': window_idx,
                'start_frame': start_frame,
                'end_frame': min(end_frame, self.total_frames)
            })
            
            # Move to next window
            start_frame += self.hop_length_frames
            window_idx += 1
        
        logger.info(f"Created {len(self.windows)} windows")
    
    def __len__(self):
        """Return the number of windows in the dataset."""
        return len(self.windows)
    
    def __getitem__(self, idx):
        """
        Get a single window from the dataset.
        
        Args:
            idx: Window index
            
        Returns:
            spectrogram: Mel spectrogram tensor of shape (window_length_frames, 4, n_mels) - [T, C, F]
            labels: Target labels tensor of shape (window_length_frames, I*J, num_classes) - [T, I*J, M]
        """
        window = self.windows[idx]
        return window['spectrogram'], window['labels']

## SELDDataset with Windowing

The `SELDDataset` class implements a windowing approach for SELD training:

### Workflow:
1. **Load all files**: Each audio file is loaded and processed to extract:
   - Mel spectrogram: shape `(4, 64, T_i)` for file i
   - Labels: shape `(T_i, 648, 14)` for file i

2. **Concatenate**: All spectrograms and labels are concatenated along the time dimension:
   - Concatenated spectrogram: `(4, 64, T)` where T = sum of all T_i
   - Concatenated labels: `(T, 648, 14)`

3. **Segment into windows**: The concatenated data is segmented into fixed-length windows:
   - Window length: 5 seconds = 250 frames (at 20ms per frame)
   - Hop length: 1 second = 50 frames
   - Overlap: 4 seconds (200 frames) between consecutive windows

4. **Padding**: The final window is padded if it has fewer than 250 frames:
   - Spectrograms: padded with zeros
   - Labels: padded with background class (index 13 = 1.0)

5. **Transpose**: Spectrograms are transposed from `[C, F, T]` to `[T, C, F]` format

### Output:
Each window provides:
- Spectrogram: `(250, 4, 64)` - **[T, C, F]** format: 250 time frames, 4 channels, 64 mel bins
- Labels: `(250, 648, 14)` - **[T, I×J, M]** format: 250 time frames, 648 spatial cells, 14 classes

This ensures all training samples have consistent shapes and can be batched efficiently.

This ensures all training samples have consistent shapes and can be batched efficiently.

### Model


In [13]:
class Conv(nn.Module):
    """Standard convolution with BN and SiLU activation"""
    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        self.act = nn.SiLU(inplace=True)
    
    def forward(self, x):
        return self.act(self.bn(self.conv(x)))

class Bottleneck(nn.Module):
    """Standard bottleneck block with residual connection"""
    def __init__(self, in_channels, out_channels, shortcut=True):
        super().__init__()
        self.cv1 = Conv(in_channels, out_channels, 1, 1, 0)
        self.cv2 = Conv(out_channels, out_channels, 3, 1, 1)
        self.add = shortcut and in_channels == out_channels
    
    def forward(self, x):
        return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))

class C3(nn.Module):
    """CSP Bottleneck with 3 convolutions"""
    def __init__(self, in_channels, out_channels, n_blocks=1, shortcut=True):
        super().__init__()
        hidden_channels = out_channels // 2
        self.cv1 = Conv(in_channels, hidden_channels, 1, 1, 0)
        self.cv2 = Conv(in_channels, hidden_channels, 1, 1, 0)
        self.cv3 = Conv(2 * hidden_channels, out_channels, 1, 1, 0)
        self.m = nn.Sequential(
            *[Bottleneck(hidden_channels, hidden_channels, shortcut) for _ in range(n_blocks)]
        )
    
    def forward(self, x):
        return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), dim=1))

class SPPF(nn.Module):
    """Spatial Pyramid Pooling - Fast"""
    def __init__(self, in_channels, out_channels, kernel_size=5):
        super().__init__()
        hidden_channels = in_channels // 2
        self.cv1 = Conv(in_channels, hidden_channels, 1, 1, 0)
        self.cv2 = Conv(hidden_channels * 4, out_channels, 1, 1, 0)
        self.m = nn.MaxPool2d(kernel_size=kernel_size, stride=1, padding=kernel_size // 2)
    
    def forward(self, x):
        x = self.cv1(x)
        y1 = self.m(x)
        y2 = self.m(y1)
        y3 = self.m(y2)
        return self.cv2(torch.cat([x, y1, y2, y3], dim=1))

class CSPDarkNet53(nn.Module):
    """CSPDarkNet53 backbone for audio SELD"""
    def __init__(self, in_channels=4, base_channels=64, depth_multiple=1.0, width_multiple=1.0):
        super().__init__()
        
        def get_channels(c):
            return max(round(c * width_multiple), 1)
        
        def get_depth(n):
            return max(round(n * depth_multiple), 1)
        
        # 3x3 stem for audio
        self.stem = Conv(in_channels, get_channels(base_channels), 3, 1, 1)
        
        self.stage1 = nn.Sequential(
            Conv(get_channels(64), get_channels(128), 3, 2, 1),
            C3(get_channels(128), get_channels(128), n_blocks=get_depth(3))
        )
        
        self.stage2 = nn.Sequential(
            Conv(get_channels(128), get_channels(256), 3, 2, 1),
            C3(get_channels(256), get_channels(256), n_blocks=get_depth(6))
        )
        
        self.stage3 = nn.Sequential(
            Conv(get_channels(256), get_channels(512), 3, 2, 1),
            C3(get_channels(512), get_channels(512), n_blocks=get_depth(9))
        )
        
        self.stage4 = nn.Sequential(
            Conv(get_channels(512), get_channels(1024), 3, 2, 1),
            C3(get_channels(1024), get_channels(1024), n_blocks=get_depth(3)),
            SPPF(get_channels(1024), get_channels(1024))
        )
        
        self.out_channels = [
            get_channels(128),   # P2
            get_channels(256),   # P3
            get_channels(512),   # P4
            get_channels(1024)   # P5
        ]
    
    def forward(self, x):
        x = self.stem(x)
        p2 = self.stage1(x)
        p3 = self.stage2(p2)
        p4 = self.stage3(p3)
        p5 = self.stage4(p4)
        return [p2, p3, p4, p5]

class SMRSELDWithCSPDarkNet(nn.Module):
    """
    SMR‑SELD model with a CSPDarkNet53 backbone.

    Input shape: [batch_size, 250, 4, 64] - [B, T, C, F]
    Output shape: [batch_size, 250, I*J, num_classes] - [B, T, grid_cells, M]

    Parameters
    ----------
    n_channels : int
        Number of input channels (4 for FOA).
    grid_size : tuple of int
        (I, J) specifying number of elevation and azimuth bins (e.g., (18, 36) for 10° resolution).
    num_classes : int
        Number of event classes including background (14).
    use_small : bool
        If True, use a reduced backbone (depth and width multipliers).
    """
    def __init__(self, n_channels=4, grid_size=(18, 36), num_classes=14, use_small=True):
        super().__init__()
        self.I, self.J = grid_size
        self.grid_cells = self.I * self.J
        self.num_classes = num_classes

        # CSPDarkNet53 backbone
        if use_small:
            self.backbone = CSPDarkNet53(in_channels=n_channels, depth_multiple=0.33, width_multiple=0.5)
        else:
            self.backbone = CSPDarkNet53(in_channels=n_channels)

        # Multi‑scale fusion
        self.fusion = nn.ModuleList([
            nn.Conv2d(self.backbone.out_channels[1], 256, 1),  # P3
            nn.Conv2d(self.backbone.out_channels[2], 256, 1),  # P4
            nn.Conv2d(self.backbone.out_channels[3], 256, 1),  # P5
        ])

        self.conv_fuse = nn.Sequential(
            nn.Conv2d(256 * 3, 512, 3, 1, 1),
            nn.BatchNorm2d(512),
            nn.SiLU(),
            nn.Conv2d(512, 256, 1),
            nn.BatchNorm2d(256),
            nn.SiLU()
        )

        # Pooling layer that resamples feature maps to exactly `grid_cells` rows.
        # We treat the frequency dimension as the vertical axis and collapse the
        # singleton width dimension.  AdaptiveAvgPool2d with output size
        # `(grid_cells, 1)` produces a feature vector for each grid cell.
        self.grid_pool = nn.AdaptiveAvgPool2d((self.grid_cells, 1))

        # Classifier head applied to each grid cell independently.  Input
        # dimension matches the number of channels after fusion (256).
        self.classifier = nn.Sequential(
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, num_classes)
        )

    def forward(self, x):
        """
        Forward pass of SMR-SELD model.

        Args:
            x: Input tensor of shape [B, T, C, F] = [batch_size, 250, 4, 64]

        Returns:
            Output tensor of shape [B, T, grid_cells, num_classes] = [batch_size, 250, 648, 14]
        """
        B, T, C, F_dim = x.shape  # [batch_size, 250, 4, 64]

        # Reshape to [B*T, C, F, 1] for 2D CNN processing
        # Treat each time frame independently, with frequency as spatial dimension
        x = x.reshape(B * T, C, F_dim, 1)  # [B*T, 4, 64, 1]

        # Pass through backbone - extracts multi-scale features
        features = self.backbone(x)
        p3, p4, p5 = features[1], features[2], features[3]

        # Apply 1x1 convolutions for channel reduction
        p3 = self.fusion[0](p3)  # [B*T, 256, H3, W3]
        p4 = self.fusion[1](p4)  # [B*T, 256, H4, W4]
        p5 = self.fusion[2](p5)  # [B*T, 256, H5, W5]

        # Upsample all to same spatial size (align to p3)
        target_size = p3.shape[2:]
        p4 = F.interpolate(p4, size=target_size, mode='bilinear', align_corners=False)
        p5 = F.interpolate(p5, size=target_size, mode='bilinear', align_corners=False)

        # Concatenate multi-scale features
        fused = torch.cat([p3, p4, p5], dim=1)  # [B*T, 768, H, W]
        x = self.conv_fuse(fused)  # [B*T, 256, H, W]

        # Resample feature maps to the number of grid cells.  Treat the
        # frequency axis as the vertical dimension and the singleton width
        # axis as horizontal.  This yields a tensor of shape
        # [B*T, 256, grid_cells, 1].
        x = self.grid_pool(x)  # [B*T, 256, grid_cells, 1]

        # Remove the trailing width dimension and transpose to
        # [B*T, grid_cells, 256] so each cell has its own feature vector.
        x = x.squeeze(-1).permute(0, 2, 1)  # [B*T, G, 256]

        # Apply the classifier independently to each cell.  The linear layer
        # will broadcast over the grid dimension.  Output shape:
        # [B*T, G, num_classes].
        x = self.classifier(x)  # [B*T, G, M]

        # Reshape back to [B, T, G, M]
        x = x.view(B, T, self.grid_cells, self.num_classes)

        # Apply softmax along the class dimension
        return F.softmax(x, dim=-1)

In [14]:
class SMRSELDLoss(nn.Module):
    """
    Complete SMR‑SELD loss function with three components:
    class MSE, AIUR and converging localization.  The grid dimensions (I, J)
    are needed for the localization loss; they are derived from the
    configuration at initialization.
    """

    def __init__(self, w_class=1.0, w_aiur=0.5, w_cl=0.5, grid_size=None):
        super().__init__()
        self.w_class = w_class
        self.w_aiur = w_aiur
        self.w_cl = w_cl
        # grid_size should be a tuple (I, J).  If None, we assume square grid.
        if grid_size is not None:
            self.I, self.J = grid_size
        else:
            self.I = self.J = None
    
    def class_mse_loss(self, y_pred, y_true):
        """Class-wise Mean Squared Error loss with class weighting
        
        Apply higher weight to event classes (first 13) vs background (last).
        This handles the severe class imbalance (~99.8% background).
        
        Args:
            y_pred: Predicted output (B, T, G, M) where M = num_classes
            y_true: Ground truth labels (B, T, G, M)
        
        Returns:
            Weighted MSE loss (scalar tensor)
        """
        # # Get dimensions
        # B, T, G, M = y_pred.shape
        
        # # Background class is the last index (index 13 for 14 classes)
        # background_idx = M - 1
        
        # # Create masks for event and background cells
        # # A cell has an event if any non-background class has value 1
        # # We check the true labels to determine which cells have events
        # y_true_class = torch.argmax(y_true, dim=-1)  # (B, T, G) - get class index
        
        # # Mask for cells with events (non-background)
        # event_mask = (y_true_class != background_idx)  # (B, T, G)
        
        # # Mask for background cells
        # background_mask = (y_true_class == background_idx)  # (B, T, G)
        
        # # Compute squared error for all cells and all classes
        # squared_error = (y_pred - y_true) ** 2  # (B, T, G, M)
        
        # # Sum across class dimension to get per-cell error
        # squared_error_per_cell = squared_error.sum(dim=-1)  # (B, T, G)
        
        # # Separate losses for event cells and background cells
        # if event_mask.sum() > 0:
        #     event_loss = squared_error_per_cell[event_mask].mean()
        # else:
        #     event_loss = torch.tensor(0.0, device=y_pred.device)
        
        # if background_mask.sum() > 0:
        #     background_loss = squared_error_per_cell[background_mask].mean()
        # else:
        #     background_loss = torch.tensor(0.0, device=y_pred.device)
        
        # # Apply weighted loss: higher weight for event classes to handle imbalance
        # # Using 10:1 ratio as suggested in the original implementation
        # weighted_loss = 10.0 * event_loss + 1.0 * background_loss
        
        # return weighted_loss
        mse_loss = F.mse_loss(y_pred, y_true, reduction='mean')
        return mse_loss
        

    
    def aiur_loss(self, y_pred, y_true):
        """Area Intersection Union Ratio (AIUR) loss computed per frame and batch.

        We compute the IoU for each frame in each sequence and average the results.
        This encourages the model to improve IoU locally rather than aggregating
        everything into a single global statistic, which can remain near 1 when
        the model predicts only background.
        
        Args:
            y_pred: Predicted output (B, T, G, M) where M = num_classes
            y_true: Ground truth labels (B, T, G, M)
        
        Returns:
            AIUR loss (scalar tensor): 1 - average_IoU across all frames and batches
        """
        # Get dimensions
        B, T, G, M = y_pred.shape
        
        # Background class is the last index (index 13 for 14 classes)
        background_idx = M - 1
        
        # Get predicted and true class indices
        y_pred_class = torch.argmax(y_pred, dim=-1)  # (B, T, G)
        y_true_class = torch.argmax(y_true, dim=-1)  # (B, T, G)
        
        # Create binary masks: 1 for event cells, 0 for background
        pred_event_mask = (y_pred_class != background_idx).float()  # (B, T, G)
        true_event_mask = (y_true_class != background_idx).float()  # (B, T, G)
        
        # Compute IoU for each frame in each batch
        # IoU = intersection / union
        # intersection = cells that have events in BOTH pred and true
        # union = cells that have events in EITHER pred or true
        
        # Intersection: element-wise multiplication
        intersection = (pred_event_mask * true_event_mask).sum(dim=-1)  # (B, T) - sum over grid cells
        
        # Union: sum of both masks minus intersection
        # Union = |A| + |B| - |A ∩ B|
        pred_count = pred_event_mask.sum(dim=-1)  # (B, T)
        true_count = true_event_mask.sum(dim=-1)  # (B, T)
        union = pred_count + true_count - intersection  # (B, T)
        
        # Compute IoU for each frame
        # Add small epsilon to avoid division by zero
        epsilon = 1e-8
        iou = intersection / (union + epsilon)  # (B, T)
        
        # Handle edge case where both pred and true have no events (union = 0)
        # In this case, IoU should be 1.0 (perfect match of empty sets)
        iou = torch.where(union > 0, iou, torch.ones_like(iou))
        
        # Average IoU across all frames and batches
        avg_iou = iou.mean()
        
        # AIUR loss = 1 - IoU
        aiur_loss_value = 1.0 - avg_iou
        
        return aiur_loss_value

    
    def converging_localization_loss(self, y_pred, y_true):
        """Converging Localization loss
        
        This loss encourages predictions to converge towards dense non-background areas.
        It guides the prediction results from surrounding areas to the target location,
        reducing extreme category imbalance and strengthening regression positioning ability.
        
        From the paper:
        1. Transform targets: y'_ij = 1 for background cells, -N_bac/N_non_bac for event cells
        2. Calculate attention: yat_ij = y'_ij + AVG(surroundings - y'_ij)
        3. Loss: LCL = Σ(ŷ_ij × yat_ij) / total_cells
        
        Args:
            y_pred: Predicted output (B, T, G, M) where M = num_classes
            y_true: Ground truth labels (B, T, G, M)
        
        Returns:
            Converging localization loss (scalar tensor)
        """
        # Get dimensions
        B, T, G, M = y_pred.shape
        # Determine grid dimensions
        if self.I is not None and self.J is not None:
            I, J = self.I, self.J
        else:
            I = J = int(math.sqrt(G))
        # Reshape to (B, T, I, J)
        true_nonbg = y_true[..., :-1].sum(dim=-1).view(B, T, I, J)
        pred_nonbg = y_pred[..., :-1].sum(dim=-1).view(B, T, I, J)
        # Count per frame
        N_bac = (true_nonbg < 0.01).sum(dim=(2, 3), keepdim=True).float()
        N_non_bac = (true_nonbg > 0.01).sum(dim=(2, 3), keepdim=True).float()
        # Initialise y_prime: 1 for background, negative ratio for events
        y_prime = torch.ones_like(true_nonbg)
        ratio = -(N_bac / (N_non_bac + 1e-10))
        y_prime = torch.where(true_nonbg > 0.01, ratio.expand_as(true_nonbg), y_prime)
        # Compute neighbourhood average using circular padding
        y_prime_padded = F.pad(y_prime, (1, 1, 1, 1), mode='circular')  # pad (left,right,top,bottom)
        diff_sum = torch.zeros_like(y_prime)
        for di in (-1, 0, 1):
            for dj in (-1, 0, 1):
                if di == 0 and dj == 0:
                    continue
                neighbor = y_prime_padded[:, :, 1+di:I+1+di, 1+dj:J+1+dj]
                diff_sum += (neighbor - y_prime)
        avg_diff = diff_sum / 8.0
        y_at = y_prime + avg_diff
        # Mask frames without events
        has_events_mask = (N_non_bac > 0).float()
        loss = ((pred_nonbg * y_at) * has_events_mask).sum() / (B * T + 1e-10)
        return loss


    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor):
        """
        Compute the weighted sum of the three loss components and return
        both the total loss and a breakdown of individual components.  This
        method is called when the loss module is invoked like a function,
        e.g., `loss, breakdown = criterion(pred, target)`.

        Parameters
        ----------
        y_pred : torch.Tensor
            Predicted output of shape (B, T, G, M), where B is batch size,
            T is the number of frames in a sequence, G is the number of
            grid cells and M is the number of classes.
        y_true : torch.Tensor
            Ground‑truth targets with the same shape as ``y_pred``.

        Returns
        -------
        total_loss : torch.Tensor
            Scalar tensor containing the weighted sum of the three loss terms.
        breakdown : dict
            Dictionary with keys ``'class_mse'``, ``'aiur'`` and ``'cl'`` mapping
            to the individual component losses (Python floats).
        """
        # Compute individual losses
        loss_class = self.class_mse_loss(y_pred, y_true)
        loss_aiur = self.aiur_loss(y_pred, y_true)
        loss_cl = self.converging_localization_loss(y_pred, y_true)
        # Weighted sum
        total_loss = (self.w_class * loss_class +
                      self.w_aiur * loss_aiur +
                      self.w_cl * loss_cl)
        # Prepare breakdown for logging
        breakdown = {
            'class_mse': float(loss_class.item()),
            'aiur': float(loss_aiur.item()),
            'cl': float(loss_cl.item())
        }
        return total_loss, breakdown

In [16]:
# Create train and test datasets
print("Creating train dataset...")
train_dataset = SELDDataset(
    audio_files=train_audio_files[:1],
    metadata_files=train_meta_files[:1],
    num_classes=14
)

print("\nCreating test dataset...")
test_dataset = SELDDataset(
    audio_files=test_audio_files[:1],
    metadata_files=test_meta_files[:1],
    num_classes=14
)

print(f"\n{'='*60}")
print(f"Dataset Summary:")
print(f"{'='*60}")
print(f"Train dataset: {len(train_dataset)} windows")
print(f"Test dataset: {len(test_dataset)} windows")
print(f"Grid dimensions: {train_dataset.I}x{train_dataset.J} = {train_dataset.total_cells} cells")
print(f"Window length: {train_dataset.window_length_frames} frames ({train_dataset.window_length_samples / train_dataset.sample_rate:.1f}s)")
print(f"Hop length: {train_dataset.hop_length_frames} frames ({train_dataset.hop_length_samples / train_dataset.sample_rate:.1f}s)") 

Creating train dataset...
2025-11-16 23:43:43 - SMR_SELD - INFO - SELDDataset initialization started...
2025-11-16 23:43:43 - SMR_SELD - INFO -   Files: 1 audio files
2025-11-16 23:43:43 - SMR_SELD - INFO -   Grid: 18x36 = 648 cells
2025-11-16 23:43:43 - SMR_SELD - INFO -   Window: 250 frames (5.0s)
2025-11-16 23:43:43 - SMR_SELD - INFO -   Hop: 50 frames (1.0s)
2025-11-16 23:43:43 - SMR_SELD - INFO - Loading and processing all audio files...
2025-11-16 23:43:43 - SMR_SELD - INFO -   Files: 1 audio files
2025-11-16 23:43:43 - SMR_SELD - INFO -   Grid: 18x36 = 648 cells
2025-11-16 23:43:43 - SMR_SELD - INFO -   Window: 250 frames (5.0s)
2025-11-16 23:43:43 - SMR_SELD - INFO -   Hop: 50 frames (1.0s)
2025-11-16 23:43:43 - SMR_SELD - INFO - Loading and processing all audio files...


Processing files: 100%|██████████| 1/1 [00:18<00:00, 18.21s/it]

2025-11-16 23:44:01 - SMR_SELD - INFO - Concatenated data: 4470 total frames
2025-11-16 23:44:01 - SMR_SELD - INFO -   Spectrograms shape: torch.Size([4, 64, 4470])
2025-11-16 23:44:01 - SMR_SELD - INFO -   Labels shape: torch.Size([4470, 648, 14])
2025-11-16 23:44:01 - SMR_SELD - INFO -   Spectrograms shape: torch.Size([4, 64, 4470])
2025-11-16 23:44:01 - SMR_SELD - INFO -   Labels shape: torch.Size([4470, 648, 14])
2025-11-16 23:44:01 - SMR_SELD - INFO - Created 90 windows
2025-11-16 23:44:01 - SMR_SELD - INFO - SELDDataset initialized with 90 windows

Creating test dataset...
2025-11-16 23:44:01 - SMR_SELD - INFO - SELDDataset initialization started...
2025-11-16 23:44:01 - SMR_SELD - INFO -   Files: 1 audio files
2025-11-16 23:44:01 - SMR_SELD - INFO -   Grid: 18x36 = 648 cells
2025-11-16 23:44:01 - SMR_SELD - INFO -   Window: 250 frames (5.0s)
2025-11-16 23:44:01 - SMR_SELD - INFO -   Hop: 50 frames (1.0s)
2025-11-16 23:44:01 - SMR_SELD - INFO - Loading and processing all audio fi


Processing files: 100%|██████████| 1/1 [00:12<00:00, 12.29s/it]

2025-11-16 23:44:14 - SMR_SELD - INFO - Concatenated data: 3035 total frames
2025-11-16 23:44:14 - SMR_SELD - INFO -   Spectrograms shape: torch.Size([4, 64, 3035])
2025-11-16 23:44:14 - SMR_SELD - INFO -   Labels shape: torch.Size([3035, 648, 14])
2025-11-16 23:44:14 - SMR_SELD - INFO -   Spectrograms shape: torch.Size([4, 64, 3035])
2025-11-16 23:44:14 - SMR_SELD - INFO -   Labels shape: torch.Size([3035, 648, 14])
2025-11-16 23:44:14 - SMR_SELD - INFO - Created 61 windows
2025-11-16 23:44:14 - SMR_SELD - INFO - SELDDataset initialized with 61 windows

Dataset Summary:
Train dataset: 90 windows
Test dataset: 61 windows
Grid dimensions: 18x36 = 648 cells
Window length: 250 frames (5.0s)
Hop length: 50 frames (1.0s)
2025-11-16 23:44:14 - SMR_SELD - INFO - Created 61 windows
2025-11-16 23:44:14 - SMR_SELD - INFO - SELDDataset initialized with 61 windows

Dataset Summary:
Train dataset: 90 windows
Test dataset: 61 windows
Grid dimensions: 18x36 = 648 cells
Window length: 250 frames (5.0s




### Model Testing - Single Forward Pass

In [17]:
# Test model with a single window
print("="*60)
print("MODEL FORWARD PASS TEST")
print("="*60)

# Initialize the model
model = SMRSELDWithCSPDarkNet(
    n_channels=4,
    grid_size=(18, 36),  # 18x36 = 648 cells
    num_classes=14,
    use_small=False
)

# Move model to device
model = model.to(DEVICE)
model.eval()  # Set to evaluation mode

print(f"\nModel initialized and moved to {DEVICE}")
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

# Get a single window from test dataset
print("\n" + "-"*60)
print("Getting a single window from dataset...")
print("-"*60)

spec, labels = test_dataset[0]
print(f"Spectrogram shape: {spec.shape}")  # Expected: (250, 4, 64)
print(f"Labels shape: {labels.shape}")      # Expected: (250, 648, 14)

# Add batch dimension and move to device
spec_batch = spec.unsqueeze(0).to(DEVICE)      # Shape: (1, 250, 4, 64)
labels_batch = labels.unsqueeze(0).to(DEVICE)  # Shape: (1, 250, 648, 14)

print(f"\nBatch spectrogram shape: {spec_batch.shape}")  # (1, 250, 4, 64)
print(f"Batch labels shape: {labels_batch.shape}")      # (1, 250, 648, 14)

# Forward pass
print("\n" + "-"*60)
print("Running forward pass...")
print("-"*60)

with torch.no_grad():
    predictions = model(spec_batch)

print(f"Predictions shape: {predictions.shape}")  # Expected: (1, 250, 648, 14)

# Verify dimensions match
print("\n" + "-"*60)
print("DIMENSION VERIFICATION")
print("-"*60)
print(f"✓ Input shape:       {spec_batch.shape} -> [batch_size, T, C, F]")
print(f"✓ Output shape:      {predictions.shape} -> [batch_size, T, grid_cells, num_classes]")
print(f"✓ Ground truth shape: {labels_batch.shape} -> [batch_size, T, grid_cells, num_classes]")
print(f"✓ Shapes match: {predictions.shape == labels_batch.shape}")

# Compare predictions with ground truth
print("\n" + "-"*60)
print("PREDICTIONS vs GROUND TRUTH COMPARISON")
print("-"*60)

# Get predictions for first time frame, first 5 cells
time_frame = 0
num_cells_to_show = 5

print(f"\nTime frame {time_frame}, showing first {num_cells_to_show} cells:")
print(f"{'Cell':<6} {'Pred Class':<12} {'True Class':<12} {'Pred Prob':<12} {'Match':<6}")
print("-" * 60)

for cell_idx in range(num_cells_to_show):
    pred_class = torch.argmax(predictions[0, time_frame, cell_idx]).item()
    true_class = torch.argmax(labels_batch[0, time_frame, cell_idx]).item()
    pred_prob = predictions[0, time_frame, cell_idx, pred_class].item()
    match = "✓" if pred_class == true_class else "✗"
    
    print(f"{cell_idx:<6} {pred_class:<12} {true_class:<12} {pred_prob:<12.4f} {match:<6}")

# Summary statistics
print("\n" + "-"*60)
print("SUMMARY STATISTICS")
print("-"*60)

# Get predicted and true classes for all frames and cells
pred_classes = torch.argmax(predictions, dim=-1)  # Shape: (1, 250, 648)
true_classes = torch.argmax(labels_batch, dim=-1)  # Shape: (1, 250, 648)

# Calculate accuracy
accuracy = (pred_classes == true_classes).float().mean().item()
print(f"Overall accuracy: {accuracy*100:.2f}%")

# Count predictions per class
pred_class_counts = torch.bincount(pred_classes.flatten(), minlength=14)
true_class_counts = torch.bincount(true_classes.flatten(), minlength=14)

print(f"\nClass distribution:")
print(f"{'Class':<20} {'Predicted':<12} {'Ground Truth':<12}")
print("-" * 50)
for class_idx in range(14):
    class_name = config.STARSS22_CLASSES.get(class_idx, f"Class {class_idx}")
    # Truncate long names
    class_name = class_name[:18] + ".." if len(class_name) > 20 else class_name
    print(f"{class_name:<20} {pred_class_counts[class_idx].item():<12} {true_class_counts[class_idx].item():<12}")

print("\n" + "="*60)
print("MODEL TEST COMPLETED SUCCESSFULLY!")
print("="*60)
print("\nNote: Low accuracy is expected for untrained model (random initialization).")

MODEL FORWARD PASS TEST

Model initialized and moved to cuda
Model parameters: 30,744,846

------------------------------------------------------------
Getting a single window from dataset...
------------------------------------------------------------
Spectrogram shape: torch.Size([250, 4, 64])
Labels shape: torch.Size([250, 648, 14])

Batch spectrogram shape: torch.Size([1, 250, 4, 64])
Batch labels shape: torch.Size([1, 250, 648, 14])

------------------------------------------------------------
Running forward pass...
------------------------------------------------------------

Model initialized and moved to cuda
Model parameters: 30,744,846

------------------------------------------------------------
Getting a single window from dataset...
------------------------------------------------------------
Spectrogram shape: torch.Size([250, 4, 64])
Labels shape: torch.Size([250, 648, 14])

Batch spectrogram shape: torch.Size([1, 250, 4, 64])
Batch labels shape: torch.Size([1, 250, 648

In [34]:
def plot_loss_curves(train_losses, test_losses, save_path=None):
    """
    Plot and save training and test loss curves.
    
    Args:
        train_losses: List of training losses per epoch
        test_losses: List of test losses per epoch
        save_path: Path to save the plot (optional)
    """
    epochs = range(1, len(train_losses) + 1)
    
    plt.figure(figsize=(12, 6))
    
    plt.plot(epochs, train_losses, 'b-o', label='Training Loss', linewidth=2, markersize=4)
    plt.plot(epochs, test_losses, 'r-s', label='Test Loss', linewidth=2, markersize=4)
    
    plt.xlabel('Epoch', fontsize=12)
    plt.ylabel('Loss', fontsize=12)
    plt.title('Training and Test Loss Curves', fontsize=14, fontweight='bold')
    plt.legend(fontsize=11)
    plt.grid(True, alpha=0.3)
    
    # Add minimum loss markers
    min_train_idx = train_losses.index(min(train_losses))
    min_test_idx = test_losses.index(min(test_losses))
    
    plt.plot(min_train_idx + 1, train_losses[min_train_idx], 'b*', 
             markersize=15, label=f'Best Train: {train_losses[min_train_idx]:.4f}')
    plt.plot(min_test_idx + 1, test_losses[min_test_idx], 'r*', 
             markersize=15, label=f'Best Test: {test_losses[min_test_idx]:.4f}')
    
    plt.legend(fontsize=10)
    plt.tight_layout()
    
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        logger.info(f"Loss curve saved to {save_path}")
    
    plt.show()
    
    return plt.gcf()

In [None]:
def train_model(
    train_audio_files,
    train_meta_files,
    test_audio_files,
    test_meta_files,
    num_epochs=None,
    batch_size=None,
    learning_rate=None,
    device=None,
    use_small_model=True
):
    """
    Complete training function for SMR-SELD model.
    
    Args:
        train_audio_files: List of training audio file paths
        train_meta_files: List of training metadata file paths
        test_audio_files: List of test audio file paths
        test_meta_files: List of test metadata file paths
        num_epochs: Number of training epochs (default: from config)
        batch_size: Batch size (default: from config)
        learning_rate: Initial learning rate (default: from config)
        device: Device to train on (default: DEVICE global)
        use_small_model: Whether to use small backbone (default: True)
    
    Returns:
        model: Trained model
        history: Dictionary with training history
    """
    # Use config defaults if not provided
    num_epochs = config.NUM_EPOCHS
    batch_size = config.BATCH_SIZE
    learning_rate = config.LEARNING_RATE
    device = DEVICE
    
    logger.info("="*80)
    logger.info("STARTING TRAINING")
    logger.info("="*80)
    
    # ========================================================================
    # STEP 1: Load datasets
    # ========================================================================
    logger.info("\nStep 1: Loading datasets...")
    
    train_dataset = SELDDataset(
        audio_files=train_audio_files,
        metadata_files=train_meta_files,
        num_classes=config.NUM_CLASSES
    )
    
    test_dataset = SELDDataset(
        audio_files=test_audio_files,
        metadata_files=test_meta_files,
        num_classes=config.NUM_CLASSES
    )
    
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,  # Set to 0 for Windows compatibility
        pin_memory=True if torch.cuda.is_available() else False
    )
    
    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,
        pin_memory=True if torch.cuda.is_available() else False
    )
    
    logger.info(f"Train dataset: {len(train_dataset)} windows ({len(train_loader)} batches)")
    logger.info(f"Test dataset: {len(test_dataset)} windows ({len(test_loader)} batches)")
    logger.info(f"Grid dimensions: {train_dataset.I}x{train_dataset.J} = {train_dataset.total_cells} cells")
    
    # ========================================================================
    # STEP 2: Initialize model, loss, optimizer
    # ========================================================================
    logger.info("\nStep 2: Initializing model, loss, and optimizer...")
    
    model = SMRSELDWithCSPDarkNet(
        n_channels=config.N_CHANNELS,
        grid_size=(train_dataset.I, train_dataset.J),
        num_classes=config.NUM_CLASSES,
        use_small=use_small_model
    ).to(device)
    
    criterion = SMRSELDLoss(
        w_class=config.W_CLASS,
        w_aiur=config.W_AIUR,
        w_cl=config.W_CL,
        grid_size=(train_dataset.I, train_dataset.J)
    )
    
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=learning_rate,
        weight_decay=config.WEIGHT_DECAY
    )
    
    # Learning rate scheduler
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=config.LR_DECAY_FACTOR,
        patience=config.LR_DECAY_PATIENCE,
        verbose=True
    )
    
    logger.info(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
    logger.info(f"Optimizer: Adam (lr={learning_rate}, weight_decay={config.WEIGHT_DECAY})")
    logger.info(f"Scheduler: ReduceLROnPlateau (factor={config.LR_DECAY_FACTOR}, patience={config.LR_DECAY_PATIENCE})")
    
    # ========================================================================
    # STEP 3: Training setup
    # ========================================================================
    train_losses = []
    test_losses = []
    best_test_loss = float('inf')
    best_epoch = 0
    epochs_without_improvement = 0
    checkpoint_files = []
    
    logger.info(f"\nTraining configuration:")
    logger.info(f"  Epochs: {num_epochs}")
    logger.info(f"  Batch size: {batch_size}")
    logger.info(f"  Learning rate: {learning_rate}")
    logger.info(f"  Early stopping patience: {config.PATIENCE}")
    logger.info(f"  Min delta: {config.MIN_DELTA}")
    logger.info(f"  Save every N epochs: {config.SAVE_EVERY_N_EPOCHS}")
    logger.info(f"  Device: {device}")
    
    # ========================================================================
    # STEP 4: Training loop
    # ========================================================================
    logger.info("\n" + "="*80)
    logger.info("STARTING TRAINING LOOP")
    logger.info("="*80 + "\n")
    
    for epoch in range(1, num_epochs + 1):
        epoch_start_time = datetime.now()
        
        # ====================================================================
        # Training phase
        # ====================================================================
        model.train()
        train_loss_accum = 0.0
        train_class_mse_accum = 0.0
        train_aiur_accum = 0.0
        train_cl_accum = 0.0
        
        train_progress = tqdm(
            train_loader,
            desc=f"Epoch {epoch}/{num_epochs} [Train]",
            leave=False
        )
        
        for batch_idx, (spectrograms, labels) in enumerate(train_progress):
            # Move to device
            spectrograms = spectrograms.to(device)
            labels = labels.to(device)
            
            # Forward pass
            optimizer.zero_grad()
            predictions = model(spectrograms)
            
            # Compute loss
            loss, breakdown = criterion(predictions, labels)
            
            # Backward pass
            loss.backward()
            optimizer.step()
            
            # Accumulate losses
            train_loss_accum += loss.item()
            train_class_mse_accum += breakdown['class_mse']
            train_aiur_accum += breakdown['aiur']
            train_cl_accum += breakdown['cl']
            
            # Update progress bar
            train_progress.set_postfix({
                'loss': f"{loss.item():.4f}",
                'lr': f"{optimizer.param_groups[0]['lr']:.6f}"
            })
        
        # Average training losses
        avg_train_loss = train_loss_accum / len(train_loader)
        avg_train_class_mse = train_class_mse_accum / len(train_loader)
        avg_train_aiur = train_aiur_accum / len(train_loader)
        avg_train_cl = train_cl_accum / len(train_loader)
        
        # ====================================================================
        # Validation phase
        # ====================================================================
        model.eval()
        test_loss_accum = 0.0
        test_class_mse_accum = 0.0
        test_aiur_accum = 0.0
        test_cl_accum = 0.0
        
        test_progress = tqdm(
            test_loader,
            desc=f"Epoch {epoch}/{num_epochs} [Test]",
            leave=False
        )
        
        with torch.no_grad():
            for spectrograms, labels in test_progress:
                # Move to device
                spectrograms = spectrograms.to(device)
                labels = labels.to(device)
                
                # Forward pass
                predictions = model(spectrograms)
                
                # Compute loss
                loss, breakdown = criterion(predictions, labels)
                
                # Accumulate losses
                test_loss_accum += loss.item()
                test_class_mse_accum += breakdown['class_mse']
                test_aiur_accum += breakdown['aiur']
                test_cl_accum += breakdown['cl']
                
                # Update progress bar
                test_progress.set_postfix({'loss': f"{loss.item():.4f}"})
        
        # Average test losses
        avg_test_loss = test_loss_accum / len(test_loader)
        avg_test_class_mse = test_class_mse_accum / len(test_loader)
        avg_test_aiur = test_aiur_accum / len(test_loader)
        avg_test_cl = test_cl_accum / len(test_loader)
        
        # Store losses
        train_losses.append(avg_train_loss)
        test_losses.append(avg_test_loss)
        
        # Update learning rate scheduler
        scheduler.step(avg_test_loss)
        
        # Calculate epoch duration
        epoch_duration = (datetime.now() - epoch_start_time).total_seconds()
        
        # ====================================================================
        # Logging
        # ====================================================================
        logger.info(f"\nEpoch {epoch}/{num_epochs} - Duration: {epoch_duration:.1f}s")
        logger.info(f"  Train Loss: {avg_train_loss:.6f} (MSE: {avg_train_class_mse:.6f}, AIUR: {avg_train_aiur:.6f}, CL: {avg_train_cl:.6f})")
        logger.info(f"  Test Loss:  {avg_test_loss:.6f} (MSE: {avg_test_class_mse:.6f}, AIUR: {avg_test_aiur:.6f}, CL: {avg_test_cl:.6f})")
        logger.info(f"  Learning Rate: {optimizer.param_groups[0]['lr']:.6f}")
        
        # ====================================================================
        # Save best model
        # ====================================================================
        if avg_test_loss < best_test_loss - config.MIN_DELTA:
            improvement = best_test_loss - avg_test_loss
            best_test_loss = avg_test_loss
            best_epoch = epoch
            epochs_without_improvement = 0
            
            best_model_path = config.CHECKPOINT_PATH / "best_model.pth"
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_loss': avg_train_loss,
                'test_loss': avg_test_loss,
                'config': config
            }, best_model_path)
            
            logger.info(f"  New best model saved! (improvement: {improvement:.6f})")
        else:
            epochs_without_improvement += 1
            logger.info(f"  No improvement for {epochs_without_improvement} epoch(s)")
        
        # ====================================================================
        # Save periodic checkpoints
        # ====================================================================
        if epoch % config.SAVE_EVERY_N_EPOCHS == 0:
            checkpoint_path = config.CHECKPOINT_PATH / f"checkpoint_epoch_{epoch}.pth"
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_loss': avg_train_loss,
                'test_loss': avg_test_loss,
                'config': config
            }, checkpoint_path)
            
            checkpoint_files.append(checkpoint_path)
            logger.info(f"  Checkpoint saved: {checkpoint_path.name}")
            
            # Keep only last N checkpoints
            if len(checkpoint_files) > config.KEEP_LAST_N_CHECKPOINTS:
                old_checkpoint = checkpoint_files.pop(0)
                if old_checkpoint.exists():
                    old_checkpoint.unlink()
                    logger.info(f"  Removed old checkpoint: {old_checkpoint.name}")
        
        # ====================================================================
        # Early stopping check
        # ====================================================================
        if epochs_without_improvement >= config.PATIENCE:
            logger.info(f"\n{'='*80}")
            logger.info(f"EARLY STOPPING at epoch {epoch}")
            logger.info(f"No improvement for {config.PATIENCE} consecutive epochs")
            logger.info(f"Best test loss: {best_test_loss:.6f} at epoch {best_epoch}")
            logger.info(f"{'='*80}\n")
            break
    
    # ========================================================================
    # STEP 5: Training complete
    # ========================================================================
    logger.info("\n" + "="*80)
    logger.info("TRAINING COMPLETE")
    logger.info("="*80)
    logger.info(f"Total epochs trained: {epoch}")
    logger.info(f"Best test loss: {best_test_loss:.6f} at epoch {best_epoch}")
    logger.info(f"Final train loss: {train_losses[-1]:.6f}")
    logger.info(f"Final test loss: {test_losses[-1]:.6f}")
    
    # ========================================================================
    # STEP 6: Plot and save loss curves
    # ========================================================================
    logger.info("\nGenerating loss curves...")
    loss_curve_path = config.OUTPUT_PATH / f"loss_curves_{datetime.now().strftime('%Y%m%d_%H%M%S')}.png"
    plot_loss_curves(train_losses, test_losses, save_path=loss_curve_path)
    
    # ========================================================================
    # STEP 7: Load best model weights
    # ========================================================================
    logger.info("\nLoading best model weights...")
    best_checkpoint = torch.load(config.CHECKPOINT_PATH / "best_model.pth")
    model.load_state_dict(best_checkpoint['model_state_dict'])
    logger.info(f"Best model loaded from epoch {best_checkpoint['epoch']}")
    
    # ========================================================================
    # Save training history
    # ========================================================================
    history = {
        'train_losses': train_losses,
        'test_losses': test_losses,
        'best_test_loss': best_test_loss,
        'best_epoch': best_epoch,
        'total_epochs': epoch,
        'config': {
            'num_epochs': num_epochs,
            'batch_size': batch_size,
            'learning_rate': learning_rate,
            'grid_size': (train_dataset.I, train_dataset.J)
        }
    }
    
    # Save history to file
    history_path = config.OUTPUT_PATH / f"training_history_{datetime.now().strftime('%Y%m%d_%H%M%S')}.pth"
    torch.save(history, history_path)
    logger.info(f"Training history saved to {history_path}")
    
    logger.info("\n" + "="*80)
    logger.info("ALL DONE!")
    logger.info("="*80 + "\n")
    
    return model, history

## Training Pipeline Summary

The complete training pipeline has been implemented with the following features:

### 1. **Loss Function (`SMRSELDLoss`)**
   - **Class MSE Loss**: Weighted MSE with 10:1 ratio for event:background classes to handle class imbalance
   - **AIUR Loss**: Area Intersection Union Ratio computed per sample to encourage better spatial predictions
   - **Converging Localization Loss**: Encourages predictions to converge towards dense non-background areas

### 2. **Training Function (`train_model`)**
   - ✅ Loads train and test datasets automatically
   - ✅ Trains model with hyperparameters from Config class
   - ✅ Computes both train and test losses each epoch
   - ✅ Saves best model weights based on test loss
   - ✅ Saves checkpoints every N epochs (configurable)
   - ✅ Tracks losses for every epoch
   - ✅ Plots and saves loss curves as images
   - ✅ Implements early stopping with configurable patience
   - ✅ Uses ReduceLROnPlateau scheduler for learning rate decay

### 3. **Key Features**
   - Progress bars with `tqdm` for training and validation
   - Detailed logging for every epoch
   - Automatic cleanup of old checkpoints (keeps last N)
   - Saves training history to disk
   - Returns trained model and history dictionary

### 4. **Configuration Parameters (in Config class)**
```python
NUM_EPOCHS = 100
BATCH_SIZE = 8
LEARNING_RATE = 1e-3
LR_DECAY_FACTOR = 0.5
LR_DECAY_PATIENCE = 5
WEIGHT_DECAY = 1e-4
W_CLASS = 1.0  # Class MSE weight
W_AIUR = 0.5   # AIUR loss weight
W_CL = 0.5     # Converging localization weight
PATIENCE = 10  # Early stopping patience
MIN_DELTA = 1e-4  # Minimum improvement threshold
SAVE_EVERY_N_EPOCHS = 5
KEEP_LAST_N_CHECKPOINTS = 3
```

### 5. **Outputs**
   - `best_model.pth`: Best performing model weights
   - `checkpoint_epoch_N.pth`: Periodic checkpoints
   - `loss_curves_TIMESTAMP.png`: Training/test loss plot
   - `training_history_TIMESTAMP.pth`: Complete training history

### 6. **Usage**
Simply uncomment and run the training cell below to start training with all default parameters!

### Model Testing and Visualization

In [35]:
def visualize_grid_predictions(
    ground_truth,
    predictions,
    time_frame,
    grid_size,
    title_prefix="",
    save_path=None
):
    """
    Visualize ground truth and predictions on a 2D grid for a specific time frame.
    
    Args:
        ground_truth: Ground truth tensor of shape (grid_cells, num_classes)
        predictions: Predictions tensor of shape (grid_cells, num_classes)
        time_frame: Time frame index being visualized
        grid_size: Tuple (I, J) for grid dimensions
        title_prefix: Prefix for plot titles
        save_path: Path to save the figure (optional)
    
    Returns:
        fig: Matplotlib figure object
    """
    I, J = grid_size
    
    # Get class predictions (argmax)
    gt_classes = torch.argmax(ground_truth, dim=-1).cpu().numpy()  # (grid_cells,)
    pred_classes = torch.argmax(predictions, dim=-1).cpu().numpy()  # (grid_cells,)
    
    # Reshape to 2D grid
    gt_grid = gt_classes.reshape(I, J)
    pred_grid = pred_classes.reshape(I, J)
    
    # Create figure with two subplots
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    
    # Define colormap (background = 13 will be white/light)
    cmap = plt.cm.get_cmap('tab20', 14)
    
    # Plot ground truth
    im1 = axes[0].imshow(gt_grid, cmap=cmap, vmin=0, vmax=13, aspect='auto')
    axes[0].set_title(f'{title_prefix}Ground Truth\nFrame {time_frame}', fontsize=14, fontweight='bold')
    axes[0].set_xlabel('Azimuth bins (J)', fontsize=11)
    axes[0].set_ylabel('Elevation bins (I)', fontsize=11)
    axes[0].grid(True, alpha=0.3, color='gray', linewidth=0.5)
    
    # Plot predictions
    im2 = axes[1].imshow(pred_grid, cmap=cmap, vmin=0, vmax=13, aspect='auto')
    axes[1].set_title(f'{title_prefix}Predictions\nFrame {time_frame}', fontsize=14, fontweight='bold')
    axes[1].set_xlabel('Azimuth bins (J)', fontsize=11)
    axes[1].set_ylabel('Elevation bins (I)', fontsize=11)
    axes[1].grid(True, alpha=0.3, color='gray', linewidth=0.5)
    
    # Plot difference (correct=green, incorrect=red, background=white)
    difference = (gt_classes == pred_classes).astype(int)
    # Mask background cells
    is_background = (gt_classes == 13)
    difference[is_background] = 2  # Special value for background
    
    diff_grid = difference.reshape(I, J)
    diff_cmap = plt.matplotlib.colors.ListedColormap(['red', 'green', 'lightgray'])
    im3 = axes[2].imshow(diff_grid, cmap=diff_cmap, vmin=0, vmax=2, aspect='auto')
    axes[2].set_title(f'{title_prefix}Comparison\nFrame {time_frame}\n(Green=Correct, Red=Wrong, Gray=Background)', 
                      fontsize=14, fontweight='bold')
    axes[2].set_xlabel('Azimuth bins (J)', fontsize=11)
    axes[2].set_ylabel('Elevation bins (I)', fontsize=11)
    axes[2].grid(True, alpha=0.3, color='gray', linewidth=0.5)
    
    # Add colorbars
    cbar1 = plt.colorbar(im1, ax=axes[0], fraction=0.046, pad=0.04)
    cbar1.set_label('Class ID', fontsize=10)
    
    cbar2 = plt.colorbar(im2, ax=axes[1], fraction=0.046, pad=0.04)
    cbar2.set_label('Class ID', fontsize=10)
    
    # Calculate accuracy (excluding background)
    non_bg_mask = ~is_background
    if non_bg_mask.sum() > 0:
        accuracy = (gt_classes[non_bg_mask] == pred_classes[non_bg_mask]).mean() * 100
        bg_accuracy = (gt_classes[is_background] == pred_classes[is_background]).mean() * 100
    else:
        accuracy = 0.0
        bg_accuracy = (gt_classes == pred_classes).mean() * 100
    
    # Add statistics text
    stats_text = f"Non-BG Accuracy: {accuracy:.1f}%\nBG Accuracy: {bg_accuracy:.1f}%\n"
    stats_text += f"Active Events: {non_bg_mask.sum()}/{len(gt_classes)}"
    fig.text(0.5, 0.02, stats_text, ha='center', fontsize=12, 
             bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
    
    plt.tight_layout(rect=[0, 0.08, 1, 1])
    
    if save_path:
        plt.savefig(save_path, dpi=200, bbox_inches='tight')
        logger.info(f"Visualization saved to {save_path}")
    
    plt.show()
    
    return fig

In [None]:
def test_model(
    test_audio_files,
    test_meta_files,
    model_path=None,
    batch_size=None,
    device=None,
    num_visualizations=5,
    save_visualizations=True
):
    """
    Test a trained SMR-SELD model and visualize predictions.
    
    This function:
    1. Loads the best model weights
    2. Creates predictions and computes loss on test set
    3. Visualizes ground truth vs predictions for random frames with active events
    
    Args:
        test_audio_files: List of test audio file paths
        test_meta_files: List of test metadata file paths
        model_path: Path to model checkpoint (default: best_model.pth)
        batch_size: Batch size for testing (default: from config)
        device: Device to test on (default: DEVICE global)
        num_visualizations: Number of frames to visualize (default: 5)
        save_visualizations: Whether to save visualization images (default: True)
    
    Returns:
        results: Dictionary containing test metrics and visualizations
    """
    batch_size = batch_size or config.BATCH_SIZE
    device = device or DEVICE
    model_path = model_path or (config.CHECKPOINT_PATH / "best_model.pth")
    
    logger.info("="*80)
    logger.info("STARTING MODEL TESTING")
    logger.info("="*80)
    
    # ========================================================================
    # STEP 1: Load test dataset
    # ========================================================================
    logger.info("\nStep 1: Loading test dataset...")
    
    test_dataset = SELDDataset(
        audio_files=test_audio_files,
        metadata_files=test_meta_files,
        num_classes=config.NUM_CLASSES
    )
    
    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,
        pin_memory=True if torch.cuda.is_available() else False
    )
    
    logger.info(f"Test dataset: {len(test_dataset)} windows ({len(test_loader)} batches)")
    logger.info(f"Grid dimensions: {test_dataset.I}x{test_dataset.J} = {test_dataset.total_cells} cells")
    
    # ========================================================================
    # STEP 2: Load model
    # ========================================================================
    logger.info(f"\nStep 2: Loading model from {model_path}...")
    
    # Check if checkpoint exists
    if not Path(model_path).exists():
        raise FileNotFoundError(f"Model checkpoint not found: {model_path}")
    
    # Load checkpoint
    checkpoint = torch.load(model_path, map_location=device)
    
    # Initialize model
    model = SMRSELDWithCSPDarkNet(
        n_channels=config.N_CHANNELS,
        grid_size=(test_dataset.I, test_dataset.J),
        num_classes=config.NUM_CLASSES,
        use_small=True  # Adjust if needed
    ).to(device)
    
    # Load weights
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    
    logger.info(f"Model loaded successfully!")
    logger.info(f"  Checkpoint epoch: {checkpoint['epoch']}")
    logger.info(f"  Checkpoint test loss: {checkpoint['test_loss']:.6f}")
    logger.info(f"  Model parameters: {sum(p.numel() for p in model.parameters()):,}")
    
    # ========================================================================
    # STEP 3: Initialize loss function
    # ========================================================================
    criterion = SMRSELDLoss(
        w_class=config.W_CLASS,
        w_aiur=config.W_AIUR,
        w_cl=config.W_CL,
        grid_size=(test_dataset.I, test_dataset.J)
    )
    
    # ========================================================================
    # STEP 4: Run inference on test set
    # ========================================================================
    logger.info("\nStep 3: Running inference on test set...")
    
    test_loss_accum = 0.0
    test_class_mse_accum = 0.0
    test_aiur_accum = 0.0
    test_cl_accum = 0.0
    
    all_predictions = []
    all_labels = []
    all_window_indices = []
    
    test_progress = tqdm(test_loader, desc="Testing")
    
    with torch.no_grad():
        for batch_idx, (spectrograms, labels) in enumerate(test_progress):
            # Move to device
            spectrograms = spectrograms.to(device)
            labels = labels.to(device)
            
            # Forward pass
            predictions = model(spectrograms)
            
            # Compute loss
            loss, breakdown = criterion(predictions, labels)
            
            # Accumulate losses
            test_loss_accum += loss.item()
            test_class_mse_accum += breakdown['class_mse']
            test_aiur_accum += breakdown['aiur']
            test_cl_accum += breakdown['cl']
            
            # Store predictions and labels for visualization
            all_predictions.append(predictions.cpu())
            all_labels.append(labels.cpu())
            all_window_indices.extend(range(batch_idx * batch_size, 
                                           batch_idx * batch_size + spectrograms.shape[0]))
            
            # Update progress bar
            test_progress.set_postfix({'loss': f"{loss.item():.4f}"})
    
    # Average test losses
    avg_test_loss = test_loss_accum / len(test_loader)
    avg_test_class_mse = test_class_mse_accum / len(test_loader)
    avg_test_aiur = test_aiur_accum / len(test_loader)
    avg_test_cl = test_cl_accum / len(test_loader)
    
    # Concatenate all predictions and labels
    all_predictions = torch.cat(all_predictions, dim=0)  # (N, T, G, M)
    all_labels = torch.cat(all_labels, dim=0)  # (N, T, G, M)
    
    logger.info("\n" + "="*80)
    logger.info("TEST RESULTS")
    logger.info("="*80)
    logger.info(f"Total Loss: {avg_test_loss:.6f}")
    logger.info(f"  Class MSE:  {avg_test_class_mse:.6f}")
    logger.info(f"  AIUR Loss:  {avg_test_aiur:.6f}")
    logger.info(f"  CL Loss:    {avg_test_cl:.6f}")
    
    # Calculate overall accuracy
    pred_classes = torch.argmax(all_predictions, dim=-1)
    true_classes = torch.argmax(all_labels, dim=-1)
    overall_accuracy = (pred_classes == true_classes).float().mean().item() * 100
    
    # Calculate non-background accuracy
    is_background = (true_classes == config.NUM_CLASSES - 1)
    non_bg_mask = ~is_background
    if non_bg_mask.sum() > 0:
        non_bg_accuracy = (pred_classes[non_bg_mask] == true_classes[non_bg_mask]).float().mean().item() * 100
    else:
        non_bg_accuracy = 0.0
    
    logger.info(f"\nOverall Accuracy: {overall_accuracy:.2f}%")
    logger.info(f"Non-Background Accuracy: {non_bg_accuracy:.2f}%")
    logger.info(f"Active Events: {non_bg_mask.sum().item()} / {non_bg_mask.numel()}")
    
    # ========================================================================
    # STEP 5: Find frames with active events for visualization
    # ========================================================================
    logger.info(f"\nStep 4: Finding frames with active events for visualization...")
    
    # Find frames with non-background events
    N, T, G, M = all_predictions.shape
    frames_with_events = []
    
    for window_idx in range(N):
        for time_idx in range(T):
            # Check if this frame has any non-background events
            frame_labels = all_labels[window_idx, time_idx, :, :]  # (G, M)
            frame_classes = torch.argmax(frame_labels, dim=-1)  # (G,)
            
            # Count non-background cells
            num_active = (frame_classes != config.NUM_CLASSES - 1).sum().item()
            
            if num_active > 0:
                frames_with_events.append({
                    'window_idx': window_idx,
                    'time_idx': time_idx,
                    'num_active': num_active
                })
    
    logger.info(f"Found {len(frames_with_events)} frames with active events")
    
    if len(frames_with_events) == 0:
        logger.warning("No frames with active events found! Cannot create visualizations.")
        return {
            'test_loss': avg_test_loss,
            'class_mse': avg_test_class_mse,
            'aiur': avg_test_aiur,
            'cl': avg_test_cl,
            'overall_accuracy': overall_accuracy,
            'non_bg_accuracy': non_bg_accuracy,
            'visualizations': []
        }
    
    # ========================================================================
    # STEP 6: Visualize random frames with active events
    # ========================================================================
    logger.info(f"\nStep 5: Creating {num_visualizations} visualizations...")
    
    # Randomly select frames for visualization
    num_to_visualize = min(num_visualizations, len(frames_with_events))
    selected_frames = random.sample(frames_with_events, num_to_visualize)
    
    # Sort by number of active events (most interesting first)
    selected_frames = sorted(selected_frames, key=lambda x: x['num_active'], reverse=True)
    
    visualizations = []
    
    for viz_idx, frame_info in enumerate(selected_frames):
        window_idx = frame_info['window_idx']
        time_idx = frame_info['time_idx']
        num_active = frame_info['num_active']
        
        logger.info(f"\n  Visualization {viz_idx + 1}/{num_to_visualize}:")
        logger.info(f"    Window: {window_idx}, Time Frame: {time_idx}")
        logger.info(f"    Active Events: {num_active}")
        
        # Get predictions and labels for this frame
        frame_predictions = all_predictions[window_idx, time_idx, :, :]  # (G, M)
        frame_labels = all_labels[window_idx, time_idx, :, :]  # (G, M)
        
        # Create visualization
        save_path = None
        if save_visualizations:
            viz_dir = config.OUTPUT_PATH / "test_visualizations"
            viz_dir.mkdir(exist_ok=True)
            save_path = viz_dir / f"test_viz_{viz_idx + 1}_window{window_idx}_frame{time_idx}.png"
        
        fig = visualize_grid_predictions(
            ground_truth=frame_labels,
            predictions=frame_predictions,
            time_frame=time_idx,
            grid_size=(test_dataset.I, test_dataset.J),
            title_prefix=f"Window {window_idx}, ",
            save_path=save_path
        )
        
        visualizations.append({
            'window_idx': window_idx,
            'time_idx': time_idx,
            'num_active': num_active,
            'figure': fig,
            'save_path': save_path
        })
    
    logger.info("\n" + "="*80)
    logger.info("TESTING COMPLETE")
    logger.info("="*80)
    
    # ========================================================================
    # Return results
    # ========================================================================
    results = {
        'test_loss': avg_test_loss,
        'class_mse': avg_test_class_mse,
        'aiur': avg_test_aiur,
        'cl': avg_test_cl,
        'overall_accuracy': overall_accuracy,
        'non_bg_accuracy': non_bg_accuracy,
        'num_frames_with_events': len(frames_with_events),
        'visualizations': visualizations,
        'checkpoint_epoch': checkpoint['epoch']
    }
    
    return results

---

## 🚀 Full Training Pipeline

Run the complete end-to-end pipeline: load all data, train model, test model, and save weights.

In [None]:
"""
FULL TRAINING AND TESTING PIPELINE
====================================

This cell runs the complete end-to-end pipeline:
1. Loads all training and test audio/metadata files
2. Creates train and test datasets
3. Trains the model with full configuration
4. Tests the model and creates visualizations
5. Saves the best model weights for future inference

Prerequisites:
- All data files should be loaded (train_audio_files, train_meta_files, etc.)
- Config class should be properly configured
- All functions (train_model, test_model) should be defined

WARNING: This will take significant time depending on:
   - Dataset size (number of files)
   - Number of epochs (config.NUM_EPOCHS)
   - Hardware (GPU vs CPU)
   
Estimated time: Several hours on GPU, much longer on CPU
"""

print("="*80)
print("STARTING FULL TRAINING AND TESTING PIPELINE")
print("="*80)

# ============================================================================
# Configuration Check
# ============================================================================
print("\n[Configuration]")
print(f"  Device: {DEVICE}")
print(f"  Number of epochs: {config.NUM_EPOCHS}")
print(f"  Batch size: {config.BATCH_SIZE}")
print(f"  Learning rate: {config.LEARNING_RATE}")
print(f"  Early stopping patience: {config.PATIENCE}")
print(f"  Loss weights: Class={config.W_CLASS}, AIUR={config.W_AIUR}, CL={config.W_CL}")
print(f"  Checkpoint path: {config.CHECKPOINT_PATH}")
print(f"  Output path: {config.OUTPUT_PATH}")

# ============================================================================
# Step 1: Verify data files are loaded
# ============================================================================
print("\n[Step 1] Verifying data files...")

if 'train_audio_files' not in locals() or 'train_meta_files' not in locals():
    raise RuntimeError("Training files not loaded! Please run the data loading cells first.")

if 'test_audio_files' not in locals() or 'test_meta_files' not in locals():
    raise RuntimeError("Test files not loaded! Please run the data loading cells first.")

print(f"Training files: {len(train_audio_files)} audio, {len(train_meta_files)} metadata")
print(f"Test files: {len(test_audio_files)} audio, {len(test_meta_files)} metadata")

# Verify files exist
print("\nVerifying file existence (sampling first 3 files)...")
for i, (audio, meta) in enumerate(zip(train_audio_files[:3], train_meta_files[:3])):
    audio_exists = Path(audio).exists()
    meta_exists = Path(meta).exists()
    status = "✓" if audio_exists and meta_exists else "✗"
    print(f"  {status} Train {i+1}: Audio={audio_exists}, Meta={meta_exists}")

for i, (audio, meta) in enumerate(zip(test_audio_files[:3], test_meta_files[:3])):
    audio_exists = Path(audio).exists()
    meta_exists = Path(meta).exists()
    status = "✓" if audio_exists and meta_exists else "✗"
    print(f"  {status} Test {i+1}: Audio={audio_exists}, Meta={meta_exists}")

# ============================================================================
# Step 2: Train the model
# ============================================================================
print("\n" + "="*80)
print("STEP 2: TRAINING THE MODEL")
print("="*80)
print("\n This may take several hours depending on dataset size and hardware...")
print("Progress will be displayed with loss values for each epoch")
print("Model checkpoints will be saved periodically\n")

# Confirm before starting (optional - comment out if running in batch mode)
# response = input("Ready to start training? (yes/no): ")
# if response.lower() != 'yes':
#     print("Training cancelled.")
#     raise RuntimeError("Training cancelled by user")

try:
    # Start training
    trained_model, training_history = train_model(
        train_audio_files=train_audio_files,
        train_meta_files=train_meta_files,
        test_audio_files=test_audio_files,
        test_meta_files=test_meta_files,
        num_epochs=config.NUM_EPOCHS,
        batch_size=config.BATCH_SIZE,
        learning_rate=config.LEARNING_RATE,
        device=DEVICE,
        use_small_model=True  # Set to False for full CSPDarkNet53 model
    )
    
    print("\n" + "="*80)
    print("TRAINING COMPLETED SUCCESSFULLY!")
    print("="*80)
    
    # Display training summary
    print("\n[Training Summary]")
    print(f"  Total epochs trained: {training_history['total_epochs']}")
    print(f"  Best epoch: {training_history['best_epoch']}")
    print(f"  Best test loss: {training_history['best_test_loss']:.6f}")
    print(f"  Final train loss: {training_history['train_losses'][-1]:.6f}")
    print(f"  Final test loss: {training_history['test_losses'][-1]:.6f}")
    
    # Check learning progress
    train_improvement = training_history['train_losses'][0] - training_history['train_losses'][-1]
    test_improvement = training_history['test_losses'][0] - training_history['test_losses'][-1]
    
    print(f"\n[Learning Progress]")
    print(f"  Train loss improvement: {train_improvement:.6f} ({train_improvement/training_history['train_losses'][0]*100:.2f}%)")
    print(f"  Test loss improvement: {test_improvement:.6f} ({test_improvement/training_history['test_losses'][0]*100:.2f}%)")
    
    if train_improvement > 0:
        print("  Model successfully learned from training data")
    else:
        print("  Training loss did not improve - consider adjusting hyperparameters")
    
    if test_improvement > 0:
        print("  Model generalizes well to test data")
    else:
        print("  Test loss did not improve - possible overfitting or insufficient training")
    
    # Model checkpoint info
    best_model_path = config.CHECKPOINT_PATH / "best_model.pth"
    print(f"\n[Model Checkpoints]")
    print(f"  Best model saved at: {best_model_path}")
    print(f"  File size: {best_model_path.stat().st_size / (1024*1024):.2f} MB")
    print(f"  Loss curves saved in: {config.OUTPUT_PATH}")
    
except Exception as e:
    print("\n" + "="*80)
    print("TRAINING FAILED")
    print("="*80)
    print(f"\nError: {e}")
    import traceback
    traceback.print_exc()
    print("\nPlease check the error message above and fix any issues.")
    raise

# ============================================================================
# Step 3: Test the trained model
# ============================================================================
print("\n" + "="*80)
print("STEP 3: TESTING THE MODEL")
print("="*80)
print("\nRunning inference on test set...")
print("Creating visualizations for frames with active events...\n")

try:
    # Run testing with visualizations
    test_results = test_model(
        test_audio_files=test_audio_files,
        test_meta_files=test_meta_files,
        model_path=config.CHECKPOINT_PATH / "best_model.pth",
        batch_size=config.BATCH_SIZE,
        device=DEVICE,
        num_visualizations=10,  # Create 10 visualizations
        save_visualizations=True
    )
    
    print("\n" + "="*80)
    print("✅ TESTING COMPLETED SUCCESSFULLY!")
    print("="*80)
    
    # Display test summary
    print("\n[Test Performance]")
    print(f"  Test Loss: {test_results['test_loss']:.6f}")
    print(f"    - Class MSE:  {test_results['class_mse']:.6f}")
    print(f"    - AIUR Loss:  {test_results['aiur']:.6f}")
    print(f"    - CL Loss:    {test_results['cl']:.6f}")
    
    print(f"\n[Accuracy Metrics]")
    print(f"  Overall Accuracy: {test_results['overall_accuracy']:.2f}%")
    print(f"  Non-Background Accuracy: {test_results['non_bg_accuracy']:.2f}%")
    print(f"  Frames with events: {test_results['num_frames_with_events']}")
    
    print(f"\n[Visualizations]")
    print(f"  Number created: {len(test_results['visualizations'])}")
    if len(test_results['visualizations']) > 0:
        viz_dir = test_results['visualizations'][0]['save_path'].parent
        print(f"  Saved to: {viz_dir}")
        print(f"  Files:")
        for viz in test_results['visualizations']:
            print(f"    - {viz['save_path'].name} (Window {viz['window_idx']}, Frame {viz['time_idx']}, {viz['num_active']} events)")
    
except Exception as e:
    print("\n" + "="*80)
    print("TESTING FAILED")
    print("="*80)
    print(f"\nError: {e}")
    import traceback
    traceback.print_exc()
    print("\nNote: Training may have succeeded even if testing failed.")
    raise

# ============================================================================
# Final Summary
# ============================================================================
print("\n" + "="*80)
print("PIPELINE COMPLETED SUCCESSFULLY!")
print("="*80)

print("\n[Summary]")
print("Data Loading: PASSED")
print("Model Training: PASSED")
print("Model Testing: PASSED")
print("Visualizations: CREATED")
print("Model Weights: SAVED")

print("\n[Output Files]")
print(f"Checkpoints: {config.CHECKPOINT_PATH}")
print(f"   - best_model.pth (use this for inference)")
print(f"   - checkpoint_epoch_*.pth (periodic checkpoints)")
print(f"Results: {config.OUTPUT_PATH}")
print(f"   - loss_curves_*.png (training progress)")
print(f"   - training_history_*.pth (complete training history)")
print(f"   - test_visualizations/ (prediction visualizations)")

print("\n[Next Steps]")
print("1. Review loss curves to understand training dynamics")
print("2. Examine test visualizations to see model predictions")
print("3. Use best_model.pth for inference on new data")
print("4. If results are unsatisfactory, adjust hyperparameters and retrain")

print("\n[Model Info]")
print(f"  Architecture: SMRSELDWithCSPDarkNet")
print(f"  Parameters: {sum(p.numel() for p in trained_model.parameters()):,}")
print(f"  Best epoch: {training_history['best_epoch']}/{training_history['total_epochs']}")
print(f"  Best test loss: {training_history['best_test_loss']:.6f}")
print(f"  Test accuracy: {test_results['non_bg_accuracy']:.2f}%")

print("\n" + "="*80)
print("Training pipeline completed! Your model is ready for inference.")
print("="*80)

In [None]:
def augment_with_gaussian_noise(metadata_path, audio_duration, sample_rate=24000, I=None, J=None, 
                                cell_size_deg=None, num_classes=14, sigma_azimuth=5.0, sigma_elevation=5.0):
    """
    Convert metadata file to target labels with Gaussian spatial augmentation for each source.
    Instead of treating sources as points, treat them as 3D regions with Gaussian distribution.
    """
    if cell_size_deg is None:
        cell_size_deg = config.GRID_CELL_DEGREES
    
    frame_duration_ms = 20
    metadata_frame_duration_ms = 100
    frames_per_metadata_frame = metadata_frame_duration_ms // frame_duration_ms  # = 5

    total_frames = int((audio_duration * 1000) / frame_duration_ms)

    if (I is None or J is None) and cell_size_deg is not None:
        I = int(180 // cell_size_deg)
        J = int(360 // cell_size_deg)
    elif I is None or J is None:
        raise ValueError("Either provide (I, J) or cell_size_deg for grid dimensions")
    
    total_cells = I * J

    labels = torch.zeros((total_frames, total_cells, num_classes), dtype=torch.float32)
    
    df = pd.read_csv(metadata_path, header=None)
    active_cells_per_frame = [set() for _ in range(total_frames)]
    
    # Step 1: Identify unique sources in the file
    # Source is identified by (class_idx, source_num) tuple
    unique_sources = df.groupby([1, 2]).first().reset_index()  # Columns 1=class, 2=source_num
    
    # Step 2: Generate fixed Gaussian noise for each unique source
    # This ensures the same source has consistent spatial augmentation across all frames
    source_noise = {}
    for _, source_row in unique_sources.iterrows():
        class_idx = int(source_row.iloc[0])
        source_num = int(source_row.iloc[1])
        source_key = (class_idx, source_num)

        azimuth_noise = np.random.normal(0, sigma_azimuth)
        elevation_noise = np.random.normal(0, sigma_elevation)
        
        source_noise[source_key] = (azimuth_noise, elevation_noise)
    
    # Step 3: Process each row in metadata with augmentation
    for _, row in df.iterrows():
        metadata_frame = int(row.iloc[0])
        active_class = int(row.iloc[1])
        source_num = int(row.iloc[2])
        azimuth = int(row.iloc[3])
        elevation = int(row.iloc[4])
        
        source_key = (active_class, source_num)
        
        # Get the fixed Gaussian noise for this source
        azimuth_noise, elevation_noise = source_noise[source_key]
        
        # Map metadata frame to final representation frames
        start_frame = metadata_frame * frames_per_metadata_frame
        end_frame = start_frame + frames_per_metadata_frame
        end_frame = min(end_frame, total_frames)
        
        # Step 4: Create Gaussian region around the source location
        # Instead of sampling points, consider the entire bounded region
        # Use 2-sigma range (covers ~95% of Gaussian distribution)
        # This creates a bounded rectangular region in azimuth-elevation space
        
        affected_cells = set()
        
        # Define the bounded region using the Gaussian noise and 2-sigma bounds
        # The region center is offset by the fixed noise for this source
        center_azimuth = azimuth + azimuth_noise
        center_elevation = elevation + elevation_noise
        
        # Define bounds: use 2*sigma to cover ~95% of the Gaussian
        # This creates a rectangular region in the angular space
        azimuth_min = center_azimuth - 2 * sigma_azimuth
        azimuth_max = center_azimuth + 2 * sigma_azimuth
        elevation_min = center_elevation - 2 * sigma_elevation
        elevation_max = center_elevation + 2 * sigma_elevation
        
        # Clip elevation to valid range
        elevation_min = max(elevation_min, -90)
        elevation_max = min(elevation_max, 90)
        
        # Find all grid cells that fall within this bounded region
        # We need to check each grid cell to see if it falls within the bounds
        for grid_i in range(I):
            for grid_j in range(J):
                # Get the center coordinates of this grid cell
                # Grid cell (i, j) corresponds to:
                # elevation: -90 + (i + 0.5) * cell_size_elevation
                # azimuth: -180 + (j + 0.5) * cell_size_azimuth
                cell_size_elevation = 180.0 / I
                cell_size_azimuth = 360.0 / J
                
                cell_elevation = -90 + (grid_i + 0.5) * cell_size_elevation
                cell_azimuth = -180 + (grid_j + 0.5) * cell_size_azimuth
                
                # Check if this cell's center falls within the Gaussian region bounds
                # For azimuth, need to handle wraparound at -180/180
                # Normalize all azimuths to be in the same reference frame
                def normalize_azimuth_diff(az1, az2):
                    """Calculate shortest angular distance between two azimuths"""
                    diff = az1 - az2
                    while diff > 180:
                        diff -= 360
                    while diff < -180:
                        diff += 360
                    return diff
                
                # Check if cell is within azimuth bounds (considering wraparound)
                azimuth_dist = abs(normalize_azimuth_diff(cell_azimuth, center_azimuth))
                azimuth_in_bounds = azimuth_dist <= 2 * sigma_azimuth
                
                # Check if cell is within elevation bounds (simple range check)
                elevation_in_bounds = elevation_min <= cell_elevation <= elevation_max
                
                if azimuth_in_bounds and elevation_in_bounds:
                    cell_idx = grid_i * J + grid_j
                    affected_cells.add(cell_idx)
        
        # Step 5: Set active class for all affected cells in the region
        for cell_idx in affected_cells:
            for t in range(start_frame, end_frame):
                # Use softer probability for cells in the region
                # Central cells (from original position) get higher probability
                labels[t, cell_idx, active_class] = 1.0
                active_cells_per_frame[t].add(cell_idx)
    
    # Step 6: Set background class for cells with no active events
    for t in range(total_frames):
        for cell_idx in range(total_cells):
            if cell_idx not in active_cells_per_frame[t]:
                labels[t, cell_idx, num_classes - 1] = 1.0
    
    return labels, I, J

In [37]:
def compare_augmentation_methods(metadata_path, audio_duration):
    """
    Compare the original point-wise labeling with Gaussian region augmentation.
    
    Args:
        metadata_path: Path to metadata CSV file
        audio_duration: Duration of audio in seconds
    """
    # Original point-wise labeling
    labels_original, I, J = metadata_to_labels(
        metadata_path, 
        audio_duration,
        sample_rate=config.SR,
        cell_size_deg=config.GRID_CELL_DEGREES,
        num_classes=config.NUM_CLASSES
    )
    
    # Augmented region-based labeling
    labels_augmented, _, _ = augment_with_gaussian_noise(
        metadata_path,
        audio_duration,
        sample_rate=config.SR,
        cell_size_deg=config.GRID_CELL_DEGREES,
        num_classes=config.NUM_CLASSES,
        sigma_azimuth=5.0,  # 5 degrees standard deviation
        sigma_elevation=5.0
    )
    
    # Compare statistics
    print("Comparison of Original vs Augmented Labeling:")
    print("=" * 60)
    
    # Count non-background cells
    original_nonbg = (labels_original[:, :, :-1].sum(dim=-1) > 0).sum().item()
    augmented_nonbg = (labels_augmented[:, :, :-1].sum(dim=-1) > 0).sum().item()
    
    print(f"Original - Active cells (non-background): {original_nonbg}")
    print(f"Augmented - Active cells (non-background): {augmented_nonbg}")
    print(f"Increase in active cells: {augmented_nonbg - original_nonbg} ({(augmented_nonbg/original_nonbg - 1)*100:.1f}%)")
    
    # Find frames with different characteristics
    print("\nFinding representative frames...")
    
    # Calculate activity per frame (non-background cells)
    activity_per_frame = (labels_original[:, :, :-1].sum(dim=-1) > 0).sum(dim=-1)
    
    # Find 1 frame with no events
    no_event_frames = (activity_per_frame == 0).nonzero(as_tuple=True)[0]
    if len(no_event_frames) > 0:
        frame_no_event = no_event_frames[len(no_event_frames) // 2].item()  # Pick middle one
    else:
        frame_no_event = 0
        print("  Warning: No frames without events found, using frame 0")
    
    # Find 2 frames with events (low and high activity)
    event_frames = (activity_per_frame > 0).nonzero(as_tuple=True)[0]
    if len(event_frames) >= 2:
        # Sort by activity level
        event_frames_sorted = event_frames[torch.argsort(activity_per_frame[event_frames])]
        frame_low_event = event_frames_sorted[len(event_frames_sorted) // 3].item()  # Low activity
        frame_high_event = event_frames_sorted[2 * len(event_frames_sorted) // 3].item()  # High activity
    elif len(event_frames) == 1:
        frame_low_event = event_frames[0].item()
        frame_high_event = event_frames[0].item()
    else:
        frame_low_event = 10
        frame_high_event = 20
        print("  Warning: No frames with events found, using default frames")
    
    selected_frames = [
        (frame_no_event, "No Events", int(activity_per_frame[frame_no_event].item())),
        (frame_low_event, "Low Activity", int(activity_per_frame[frame_low_event].item())),
        (frame_high_event, "High Activity", int(activity_per_frame[frame_high_event].item()))
    ]
    
    print(f"\nSelected frames:")
    for frame_idx, desc, activity in selected_frames:
        print(f"  {desc}: Frame {frame_idx} (Activity: {activity} cells)")
    
    # Create output directory
    import os
    output_dir = "gaussian_visualizations"
    os.makedirs(output_dir, exist_ok=True)
    print(f"\nSaving visualizations to: {output_dir}/")
    
    # Visualize each selected frame
    for frame_idx, desc, activity in selected_frames:
        original_frame = labels_original[frame_idx].reshape(I, J, -1)
        augmented_frame = labels_augmented[frame_idx].reshape(I, J, -1)
        
        fig, axes = plt.subplots(1, 2, figsize=(16, 6))
        
        # Original
        original_activity = original_frame[:, :, :-1].sum(dim=-1).numpy()
        im1 = axes[0].imshow(original_activity, cmap='hot', aspect='auto', origin='lower', vmin=0, vmax=1)
        axes[0].set_title(f'Original Point-wise\nFrame {frame_idx} - {desc} ({activity} cells)', fontsize=12, fontweight='bold')
        axes[0].set_xlabel('Azimuth bins (J)')
        axes[0].set_ylabel('Elevation bins (I)')
        axes[0].grid(True, alpha=0.3)
        plt.colorbar(im1, ax=axes[0], label='Activity')
        
        # Augmented
        augmented_activity = augmented_frame[:, :, :-1].sum(dim=-1).numpy()
        augmented_cells = (augmented_activity > 0).sum()
        im2 = axes[1].imshow(augmented_activity, cmap='hot', aspect='auto', origin='lower', vmin=0, vmax=1)
        axes[1].set_title(f'Augmented Gaussian Region\nFrame {frame_idx} - {desc} ({augmented_cells} cells)', fontsize=12, fontweight='bold')
        axes[1].set_xlabel('Azimuth bins (J)')
        axes[1].set_ylabel('Elevation bins (I)')
        axes[1].grid(True, alpha=0.3)
        plt.colorbar(im2, ax=axes[1], label='Activity')
        
        # Overall title
        increase_pct = ((augmented_cells / max(activity, 1)) - 1) * 100 if activity > 0 else 0
        fig.suptitle(f'Gaussian Augmentation Comparison - {desc} Frame\n'
                     f'Increase: {augmented_cells - activity} cells (+{increase_pct:.1f}%)',
                     fontsize=14, fontweight='bold', y=1.02)
        
        plt.tight_layout()
        
        # Save figure
        safe_desc = desc.lower().replace(" ", "_")
        save_path = os.path.join(output_dir, f'augmentation_frame_{frame_idx}_{safe_desc}.png')
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"  Saved: {save_path}")
        
        plt.close(fig)
    
    print("\nVisualization complete!")
    
    return labels_original, labels_augmented

In [38]:
# Test Gaussian Augmentation on a Single File
print("Testing Gaussian Augmentation")
print("=" * 80)

# Load a single file from the training set
if len(train_meta_files) > 0:
    # Pick a file from the middle of the training set
    test_file_idx = len(train_meta_files) // 2
    test_metadata_path = train_meta_files[test_file_idx]
    test_audio_path = train_audio_files[test_file_idx]
    
    print(f"\nSelected file for testing:")
    print(f"  Metadata: {Path(test_metadata_path).name}")
    print(f"  Audio: {Path(test_audio_path).name}")
    
    # Get audio duration
    import torchaudio
    audio_info = torchaudio.info(test_audio_path)
    audio_duration = audio_info.num_frames / audio_info.sample_rate
    
    print(f"  Duration: {audio_duration:.2f} seconds")
    print(f"  Sample Rate: {audio_info.sample_rate} Hz")
    
    # Run comparison and generate visualizations
    print("\nRunning Gaussian augmentation comparison...")
    print("-" * 80)
    
    labels_original, labels_augmented = compare_augmentation_methods(
        test_metadata_path, 
        audio_duration
    )
    
    print("\n" + "=" * 80)
    print("Gaussian augmentation test complete!")
    print(f"Check the 'gaussian_visualizations/' directory for results.")
    print("=" * 80)
    
else:
    print("No training files found. Please load the dataset first.")

Testing Gaussian Augmentation

Selected file for testing:
  Metadata: fold3_room22_mix005.csv
  Audio: fold3_room22_mix005.wav
  Duration: 133.50 seconds
  Sample Rate: 24000 Hz

Running Gaussian augmentation comparison...
--------------------------------------------------------------------------------
Comparison of Original vs Augmented Labeling:
Original - Active cells (non-background): 15415
Augmented - Active cells (non-background): 61465
Increase in active cells: 46050 (298.7%)

Finding representative frames...
Comparison of Original vs Augmented Labeling:
Original - Active cells (non-background): 15415
Augmented - Active cells (non-background): 61465
Increase in active cells: 46050 (298.7%)

Finding representative frames...

Selected frames:
  No Events: Frame 2 (Activity: 0 cells)
  Low Activity: Frame 1438 (Activity: 2 cells)
  High Activity: Frame 1931 (Activity: 3 cells)

Saving visualizations to: gaussian_visualizations/

Selected frames:
  No Events: Frame 2 (Activity: 0 ce