# In [None]:
import importlib # Used to reload the generated datagen module when the notebook cell is re-run
import os
import random
import sys

# Must be set before TensorFlow is imported to take effect.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # Suppress TensorFlow logging messages (e.g., warnings about GPU setup)

import librosa # Library for audio analysis
import librosa.display # Explicit submodule import so librosa.display.waveshow resolves on every librosa release
import matplotlib.pyplot as plt # For plotting
import numpy as np
import pandas as pd
from IPython.display import Audio, display # For playing audio in Jupyter/IPython environments

# Import necessary Keras components from TensorFlow
from tensorflow.keras.utils import Sequence # Base class for Keras data generators
from tensorflow.keras import Input, Model # For defining functional API models
from tensorflow.keras import layers, callbacks # Core layers and callback functions
from tensorflow.keras.layers import (
    Cropping2D, Lambda, Conv2D, Dropout, MaxPooling2D, Conv2DTranspose,
    Concatenate, GlobalAveragePooling2D, Activation, BatchNormalization, Dense
) # Specific layers used in the model
from tensorflow.keras.optimizers import Adam # Optimizer for model training
from tensorflow.keras.callbacks import ReduceLROnPlateau # Learning rate scheduler callback

# Import metrics and visualization tools from scikit-learn
from sklearn.metrics import accuracy_score, confusion_matrix, ConfusionMatrixDisplay


# --- Directory Setup ---
# Sub-folders (relative to /kaggle/working) that the notebook expects to exist:
# scratch space for data and masks, plus places for source files, saved models and docs.
dirs = [
    "data/UrbanSound8K", # Placeholder for processed data or temporary files
    "data/masks",       # Placeholder for mask-related data
    "notebooks",        # For Jupyter notebooks (though this is already a notebook)
    "src",              # For Python source files (like datagen.py)
    "models",           # For saving trained Keras models
    "docs"              # For documentation or reports
]
# exist_ok=True makes the call a no-op for folders that are already there,
# so re-running this cell is harmless.
for d in dirs:
    os.makedirs("/kaggle/working/" + d, exist_ok=True)
print("Working directories created/ensured:", dirs)

# --- Data Loading and Initial Exploration ---
# Sanity check: show what is mounted under /kaggle/input and inside the dataset folder.
print("\n--- Listing Dataset Contents ---")
print("Kaggle input directory:", os.listdir("/kaggle/input/"))
print("UrbanSound8K dataset directory:", os.listdir("/kaggle/input/urbansound8k/"))

print("\n--- Loading Metadata ---")
# The CSV holds one row per audio slice; it is read into a DataFrame for all later steps.
metadata_path = "/kaggle/input/urbansound8k/UrbanSound8K.csv"
df = pd.read_csv(metadata_path)
print("Metadata head:\n", df.head()) # First few rows, as a quick look at the columns

print("\n--- Class Distribution ---")
# Distinct labels, then the number of rows per label (reveals class imbalance).
class_column = df['class']
print("Unique classes:", class_column.unique())
print("Class counts:\n", class_column.value_counts())

print("\n--- Fold Distribution ---")
# Number of rows assigned to each fold.
print("Fold counts:\n", df['fold'].value_counts())

print("\n--- Missing Values Check ---")
# Per-column NaN counts, summed into a single total for the whole DataFrame.
missing_per_column = df.isna().sum()
print("Any missing values in metadata:", missing_per_column.sum())


# --- Audio Sample Loading and Visualization ---
print("\n--- Audio Sample Demonstration ---")
# Pick one metadata row (fixed seed, so the same clip is chosen on every run).
sample_meta = df.sample(1, random_state=7).iloc[0]
# Audio files live in per-fold sub-directories of the dataset folder.
file_path = f"/kaggle/input/urbansound8k/fold{sample_meta['fold']}/{sample_meta['slice_file_name']}"

# Load the audio file using librosa.
# `sr=22050` resamples the audio to 22050 Hz.
# `duration=4` loads only the first 4 seconds of the audio.
y, sr = librosa.load(file_path, sr=22050, duration=4)
print(f"Loaded '{sample_meta['slice_file_name']}'; Class: {sample_meta['class']} | Sample Rate: {sr} | Samples: {len(y)}")

# Play audio directly in the Jupyter/IPython environment.
# A bare `Audio(...)` expression is only rendered when it is the last expression of a
# cell; because more code follows, the widget has to be shown explicitly with display().
print("Playing audio sample:")
display(Audio(y, rate=sr))

# Display & plot audio waveform using matplotlib and librosa.display.
plt.figure(figsize=(12, 3))
librosa.display.waveshow(y, sr=sr) # Plot the waveform
plt.title(f"Waveform | {sample_meta['class']} ({sample_meta['slice_file_name']})")
plt.xlabel("Time (s)")
plt.ylabel("Amplitude")
plt.show()

# --- Spectrogram Visualization Function ---
def show_urbansound8k_samples(
    df,
    audio_dir="/kaggle/input/urbansound8k",
    n_samples=6,
    random_state=None
):
    """
    Displays mel spectrograms for a given number of random UrbanSound8K samples.
    Includes min-max normalization for consistent visualization.

    Each subplot is titled with the sample's `class` column (falling back to
    `classID`, then to 'Unknown', when the column is missing). Files that fail
    to load or plot are reported on stdout and shown as an empty "Error" panel.

    Args:
        df (pd.DataFrame): The metadata DataFrame. Must provide `fold` and
            `slice_file_name` columns.
        audio_dir (str): Base directory where audio files are stored.
        n_samples (int): Number of random samples to display. Values larger than
            `len(df)` are reduced to `len(df)`.
        random_state (int, optional): Seed for reproducibility. Defaults to None.

    Returns:
        None. The figure is rendered with `plt.show()`.
    """
    # Sampling without replacement cannot return more rows than the frame holds.
    n_samples = min(n_samples, len(df))
    if n_samples <= 0:
        print("No samples available to display.")
        return

    # Sample rows from the DataFrame to visualize.
    sample_df = df.sample(n_samples, random_state=random_state)

    # Read the labels straight from the column. `itertuples()` renames fields that are
    # Python keywords (such as `class`), so attribute access on the row tuple would
    # never find the class name and would silently fall back to the numeric classID.
    if 'class' in sample_df.columns:
        labels = sample_df['class'].tolist()
    elif 'classID' in sample_df.columns:
        labels = sample_df['classID'].tolist()
    else:
        labels = ['Unknown'] * n_samples

    target_sr = 22050
    target_len = 4 * target_sr # Every clip is padded to 4 seconds

    plt.figure(figsize=(6 * n_samples, 6)) # Adjust figure size based on number of samples
    for i, (row, label) in enumerate(zip(sample_df.itertuples(), labels)):
        file_path = f"{audio_dir}/fold{row.fold}/{row.slice_file_name}"
        try:
            # Load audio, resample, and trim/pad to a fixed duration.
            y, sr = librosa.load(file_path, sr=target_sr, duration=4)
            if len(y) < target_len:
                y = np.pad(y, (0, target_len - len(y)), mode='constant')

            # Compute mel spectrogram and convert to decibels.
            mel = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=128)
            mel_db = librosa.power_to_db(mel, ref=np.max)

            # Min-max normalize to [0, 1] so every panel uses the same color range.
            # The epsilon keeps the division defined for a completely flat spectrogram.
            mel_db_normalized = (mel_db - mel_db.min()) / (mel_db.max() - mel_db.min() + 1e-8)

            # Plot the normalized mel spectrogram.
            plt.subplot(1, n_samples, i+1) # Create subplot for each sample
            plt.imshow(mel_db_normalized, origin='lower', aspect='auto', cmap='magma')
            plt.title(f"Class: {label}")
            plt.axis('off') # Turn off axes for cleaner spectrogram display
        except Exception as e:
            # Best effort: one unreadable file should not abort the whole figure.
            print(f"Error loading or plotting {file_path}: {e}")
            plt.subplot(1, n_samples, i+1)
            plt.axis('off')
            plt.title("Error")
    plt.suptitle(f"Random {n_samples} UrbanSound8K Spectrograms", fontsize=18) # Main title for the figure
    plt.tight_layout() # Adjust subplot parameters for a tight layout.
    plt.show()

# Render six spectrograms; the fixed seed keeps the selection stable between runs.
print("\n--- Visualizing Sample Spectrograms ---")
show_urbansound8k_samples(
    df,
    audio_dir="/kaggle/input/urbansound8k",
    n_samples=6,
    random_state=42,
)


# --- CustomDataGen Class Definition (Improved) ---
# Create the /kaggle/working/src/ Directory to store the custom data generator.
os.makedirs("/kaggle/working/src", exist_ok=True)

# Write the CustomDataGen class code to a Python file.
# This allows it to be imported as a module.
# NOTE: the text below is the source of datagen.py; the escaped triple quotes become
# the class docstring delimiters in the written file.
with open("/kaggle/working/src/datagen.py", "w", encoding="utf-8") as f:
    f.write("""
import numpy as np
import librosa
import tensorflow as tf # Required for tf.keras.utils.Sequence

class CustomDataGen(tf.keras.utils.Sequence):
    \"\"\"
    A custom Keras data generator for the UrbanSound8K dataset.
    It loads audio, converts to mel spectrograms, applies padding,
    and supports optional mask overlay augmentation.
    Spectrograms are min-max normalized before being returned.

    Extra keyword arguments are forwarded to the Keras base class.
    \"\"\"
    def __init__(self, df, audio_dir, batch_size=8, shuffle=True, n_mels=128, duration=4, sr=22050, mask_overlay_df=None, **kwargs):
        # Initialise the Keras base class. Newer Keras releases warn when a
        # Sequence/PyDataset subclass skips this call.
        super().__init__(**kwargs)
        self.df = df.reset_index(drop=True) # Reset index to ensure consistent indexing
        self.audio_dir = audio_dir # Base directory for audio files
        self.batch_size = batch_size # Number of samples per batch
        self.shuffle = shuffle # Whether to shuffle data at the end of each epoch
        self.n_mels = n_mels # Number of mel bands for spectrogram
        self.duration = duration # Duration of audio clips in seconds
        self.sr = sr # Target sample rate for audio
        self.mask_overlay_df = mask_overlay_df # DataFrame for overlay augmentation samples
        self.indexes = np.arange(len(self.df)) # Array of indices for the dataset
        self.on_epoch_end() # Initial shuffle of indexes

    def __len__(self):
        # Returns the number of batches per epoch.
        # np.floor ensures we only return full batches, so a trailing partial
        # batch is not counted.
        return int(np.floor(len(self.df) / self.batch_size))

    def __getitem__(self, index):
        # Generates one batch of data given an index.
        # Selects indices for the current batch.
        batch_indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]
        # An empty slice means the index lies outside the data. Raising IndexError
        # (instead of returning an empty batch) lets plain `for batch in generator`
        # loops terminate even when the base class does not provide __iter__.
        if len(batch_indexes) == 0:
            raise IndexError(f"Batch index {index} is out of range")
        # Retrieves the corresponding rows from the DataFrame.
        batch = self.df.iloc[batch_indexes]
        # Generates the actual data (mel spectrograms and labels) for the batch.
        X, y = self.__data_generation(batch)
        return X, y

    def on_epoch_end(self):
        # Updates indexes after each epoch. This method is called by Keras after each epoch.
        if self.shuffle:
            np.random.shuffle(self.indexes) # Shuffle indices for the next epoch

    def __load_audio(self, file_path):
        # Loads an audio file, resamples it, and pads/trims to a fixed duration.
        y, _ = librosa.load(file_path, sr=self.sr, duration=self.duration)
        # Pad audio if its length is less than the target duration.
        if len(y) < int(self.sr * self.duration):
            y = np.pad(y, (0, int(self.sr * self.duration - len(y))), "constant")
        return y

    def __mel_spectrogram(self, y):
        # Converts an audio waveform into a mel spectrogram.
        mel = librosa.feature.melspectrogram(y=y, sr=self.sr, n_mels=self.n_mels)
        # Convert power spectrogram to decibel (dB) scale.
        mel_db = librosa.power_to_db(mel, ref=np.max)
        return mel_db

    def __overlay_augment(self, base_audio, overlay_audio, snr_db=10):
        # Overlays a `base_audio` with an `overlay_audio` at a specified Signal-to-Noise Ratio (SNR).
        rms_base = np.sqrt(np.mean(base_audio ** 2)) # RMS of base audio
        rms_overlay = np.sqrt(np.mean(overlay_audio ** 2)) # RMS of overlay audio

        if rms_overlay == 0: # Avoid division by zero if overlay is silent
            return base_audio

        # Calculate desired RMS of overlay based on SNR.
        desired_rms_overlay = rms_base / (10**(snr_db / 20))
        # Scale the overlay audio to achieve the desired RMS.
        scaled_overlay = overlay_audio * (desired_rms_overlay / (rms_overlay + 1e-8))
        # Mix the base and scaled overlay audio.
        mixed = base_audio + scaled_overlay
        mixed = np.clip(mixed, -1.0, 1.0) # Clip to ensure valid audio amplitude range [-1.0, 1.0]
        return mixed

    def __data_generation(self, batch_df):
        # Generates preprocessed data (mel spectrograms) and labels for a given batch DataFrame.
        X = [] # List to store input features (spectrograms)
        y = [] # List to store target labels
        for _, row in batch_df.iterrows():
            file_path = f"{self.audio_dir}/fold{row['fold']}/{row['slice_file_name']}"
            base_audio = self.__load_audio(file_path)

            # Apply mask overlay augmentation with a 50% probability if mask_overlay_df is provided.
            if self.mask_overlay_df is not None and np.random.rand() < 0.5:
                overlay_sample = self.mask_overlay_df.sample(1).iloc[0] # Pick a random overlay sample
                overlay_path = f"{self.audio_dir}/fold{overlay_sample['fold']}/{overlay_sample['slice_file_name']}"
                overlay_audio = self.__load_audio(overlay_path)
                base_audio = self.__overlay_augment(base_audio, overlay_audio, snr_db=10)

            mel_spec = self.__mel_spectrogram(base_audio)

            # --- IMPROVEMENT: Min-Max Normalization for Mel Spectrograms ---
            # Normalize spectrograms to a [0, 1] range. This is crucial for neural network performance
            # as it helps with gradient stability and faster convergence.
            # Add a small epsilon (1e-8) to the denominator to prevent division by zero for flat spectrograms.
            mel_spec_normalized = (mel_spec - np.min(mel_spec)) / (np.max(mel_spec) - np.min(mel_spec) + 1e-8)

            # Add a channel dimension for Keras (expected format: batch, height, width, channels).
            # For grayscale images/spectrograms, channels=1.
            mel_spec_final = np.expand_dims(mel_spec_normalized, axis=-1)
            X.append(mel_spec_final)
            y.append(row['classID']) # Append the class ID as the label

        # Convert lists of features and labels to numpy arrays with appropriate dtypes.
        # float32 for input features, int64 for labels (for sparse_categorical_crossentropy).
        return np.array(X, dtype=np.float32), np.array(y, dtype=np.int64)
""")

# (Re)load CustomDataGen so the class always matches the file written to
# /kaggle/working/src/datagen.py. A plain `from datagen import ...` reuses the module
# cached in sys.modules, so a re-run of this cell would keep an outdated class.
if "/kaggle/working/src/" not in sys.path:
    sys.path.append("/kaggle/working/src/")
importlib.invalidate_caches() # The module file was just (re)written; refresh the import finders
if "datagen" in sys.modules:
    datagen_module = importlib.reload(sys.modules["datagen"])
else:
    datagen_module = importlib.import_module("datagen")
CustomDataGen = datagen_module.CustomDataGen

# --- Test the Data Generator ---
print("\n--- Testing CustomDataGen ---")
# Create a small DataFrame for mask overlay samples (optional augmentation).
mask_overlay_df = df[df['fold'] == 1].sample(20, random_state=42)
# Instantiate the CustomDataGen for demonstration.
datagen = CustomDataGen(
    df=df,
    audio_dir="/kaggle/input/urbansound8k",
    batch_size=8,
    n_mels=128,
    duration=4,
    sr=22050,
    mask_overlay_df=mask_overlay_df # Pass the mask overlay DataFrame
)

# Retrieve a single batch of data from the generator.
X_batch, y_batch = datagen[0]
print("X_batch shape:", X_batch.shape) # Expected: (batch_size, n_mels, time_steps, 1)
print("y_batch:", y_batch) # Expected: (batch_size,) of integer class IDs

# Plot the first spectrogram from the batch to visualize the normalized input.
plt.figure(figsize=(10, 4))
# The array is laid out as (mel bands, time bins). imshow draws rows along the y-axis
# and columns along the x-axis, so it is plotted as-is (no transpose) to match the labels.
plt.imshow(X_batch[0][:, :, 0], aspect='auto', origin='lower', cmap='magma')
plt.title(f"Normalized Spectrogram for ClassID: {y_batch[0]}")
plt.xlabel("Time bins")
plt.ylabel("Mel bands")
# The generator min-max normalizes each spectrogram, so the values are unitless
# numbers in [0, 1] rather than decibels.
plt.colorbar(label="Normalized magnitude (0-1)")
plt.tight_layout()
plt.show()

# Check the original mel_db shape for a single sample (for comparison with model input shape).
# `file_path` still refers to the clip chosen in the audio demonstration above.
y_test, sr_test = librosa.load(file_path, sr=22050, duration=4)
mel_test = librosa.feature.melspectrogram(y=y_test, sr=sr_test, n_mels=128)
mel_db_test = librosa.power_to_db(mel_test)
print("Original mel_db shape for a single sample:", mel_db_test.shape)


# --- Double UNet/Hybrid Mask-aware Model Definition (Improved) ---
def crop_to_match(encoder_tensor, decoder_tensor):
    """
    Crops the encoder tensor to match the spatial dimensions of the decoder tensor.
    This is essential for skip connections in UNet architectures, where feature maps
    from the encoder are concatenated with upsampled feature maps from the decoder.

    Axis 1 is treated as height and axis 2 as width (channels-last layout, as used
    by the rest of this file). Surplus rows/columns are removed from the bottom and
    right edges only. The function never pads: if the encoder is not larger than the
    decoder along an axis, that axis is left untouched.

    Args:
        encoder_tensor (tf.Tensor): Feature map from the encoder path.
        decoder_tensor (tf.Tensor): Feature map from the decoder path (upsampled).

    Returns:
        tf.Tensor: The cropped encoder tensor, or `encoder_tensor` itself when
        there is nothing to crop.
    """
    # Surplus of the encoder over the decoder along each spatial axis.
    # max(0, ...) ensures no negative cropping (i.e., no padding).
    crop_height = max(0, encoder_tensor.shape[1] - decoder_tensor.shape[1])
    crop_width = max(0, encoder_tensor.shape[2] - decoder_tensor.shape[2])

    # Nothing to remove: return the input untouched instead of inserting a
    # Cropping2D layer that would crop zero rows and zero columns.
    if crop_height == 0 and crop_width == 0:
        return encoder_tensor

    # Cropping2D takes ((top, bottom), (left, right)); only bottom/right are trimmed.
    cropping = ((0, crop_height), (0, crop_width))
    return Cropping2D(cropping=cropping)(encoder_tensor)

def unet_block(inputs, filters, kernel_size=(3, 3), dropout=0.3):
    """
    Builds one small UNet: two encoder stages, a bottleneck and two decoder stages.

    Every stage is a pair of Conv2D -> BatchNormalization -> ReLU units. Encoder
    stages and the bottleneck end with Dropout; encoder stages are then halved with
    2x2 max pooling. Decoder stages upsample with a stride-2 Conv2DTranspose and
    concatenate the (cropped) encoder output of the same level before convolving.

    Args:
        inputs (tf.Tensor): The input tensor to the UNet block.
        filters (int): The base number of filters; deeper levels use 2x and 4x.
        kernel_size (tuple): Size of the convolutional kernels.
        dropout (float): Dropout rate for regularization.

    Returns:
        tf.Tensor: The output tensor of the last decoder stage (`filters` channels).
    """
    def conv_bn_relu(tensor, n_filters):
        # One convolution unit: same-padded Conv2D, then batch norm, then ReLU.
        tensor = Conv2D(n_filters, kernel_size, padding='same')(tensor)
        tensor = BatchNormalization()(tensor)
        return Activation('relu')(tensor)

    def double_conv(tensor, n_filters):
        # Two convolution units back to back with the same filter count.
        tensor = conv_bn_relu(tensor, n_filters)
        return conv_bn_relu(tensor, n_filters)

    # --- Encoder, level 1 (full resolution) ---
    enc1 = double_conv(inputs, filters)
    enc1 = Dropout(dropout)(enc1)
    down1 = MaxPooling2D((2, 2))(enc1)

    # --- Encoder, level 2 (half resolution) ---
    enc2 = double_conv(down1, filters*2)
    enc2 = Dropout(dropout)(enc2)
    down2 = MaxPooling2D((2, 2))(enc2)

    # --- Bottleneck (quarter resolution) ---
    bottleneck = double_conv(down2, filters*4)
    bottleneck = Dropout(dropout)(bottleneck)

    # --- Decoder, back to level 2 ---
    # Upsample, trim the encoder map to the upsampled size, then fuse via concatenation.
    dec2 = Conv2DTranspose(filters*2, (2, 2), strides=(2, 2), padding='same')(bottleneck)
    skip2 = crop_to_match(enc2, dec2)
    dec2 = Concatenate()([dec2, skip2])
    dec2 = double_conv(dec2, filters*2)

    # --- Decoder, back to level 1 ---
    dec1 = Conv2DTranspose(filters, (2, 2), strides=(2, 2), padding='same')(dec2)
    skip1 = crop_to_match(enc1, dec1)
    dec1 = Concatenate()([dec1, skip1])
    dec1 = double_conv(dec1, filters)

    return dec1

def build_double_unet(input_shape, num_classes, dropout=0.3):
    """
    Assembles the Double UNet classifier: two stacked UNet blocks and a softmax head.

    The input passes through `unet_block` twice (32 base filters each time, the
    second block consuming the first block's output). The resulting feature map is
    reduced with global average pooling and mapped to class probabilities by a
    single Dense layer named 'output'.

    Args:
        input_shape (tuple): Shape of the input spectrogram (height, width, channels).
        num_classes (int): Number of output classes for classification.
        dropout (float): Dropout rate to apply in UNet blocks.

    Returns:
        tf.keras.Model: The assembled model. It is not compiled here; the caller
        is expected to call `compile()` before training.
    """
    spectrogram_in = Input(shape=input_shape)

    # Two UNet passes in sequence: the second refines the output of the first.
    features = spectrogram_in
    for _ in range(2):
        features = unet_block(features, filters=32, dropout=dropout)

    # Classification head: collapse the spatial axes into one vector per sample,
    # then project onto the classes with a softmax.
    embedding = GlobalAveragePooling2D()(features)
    class_probs = Dense(num_classes, activation='softmax', name='output')(embedding)

    return Model(inputs=spectrogram_in, outputs=class_probs)

# --- Train/Validation Splits ---
print("\n--- Preparing Data Splits ---")
metadata_path = "/kaggle/input/urbansound8k/UrbanSound8K.csv"
audio_dir = "/kaggle/input/urbansound8k"

# Read the metadata afresh so this section does not depend on earlier edits to `df`.
df = pd.read_csv(metadata_path)

# Hold out one complete fold for validation; every other fold is used for training.
val_fold = 1
in_val_fold = df['fold'] == val_fold
train_df = df[~in_val_fold].reset_index(drop=True) # All folds except the validation fold
val_df = df[in_val_fold].reset_index(drop=True)    # Only the validation fold

# Pool of 20 training clips that the generator may mix into training samples
# as overlay augmentation (fixed seed for a reproducible pool).
mask_overlay_df = train_df.sample(20, random_state=42)

batch_size = 16 # Shared by the training and validation generators

# Training batches are shuffled and augmented; validation batches are neither.
train_gen = CustomDataGen(
    train_df,
    audio_dir=audio_dir,
    batch_size=batch_size,
    shuffle=True,
    mask_overlay_df=mask_overlay_df,
)
val_gen = CustomDataGen(
    val_df,
    audio_dir=audio_dir,
    batch_size=batch_size,
    shuffle=False,
    mask_overlay_df=None,
)

# Pull one training batch to discover the per-sample input shape for the model.
X_sample, y_sample = train_gen[0]
input_shape = X_sample.shape[1:]  # Drops the leading batch axis
num_classes = 10  # UrbanSound8K dataset has 10 distinct classes

print(f"Model Input Shape: {input_shape}")
print(f"Number of Classes: {num_classes}")

# Assemble the network for that shape and print its layer-by-layer summary.
model = build_double_unet(input_shape=input_shape, num_classes=num_classes, dropout=0.3)
model.summary()

# --- Model Compilation ---
# Adam with an explicit starting learning rate; sparse categorical cross-entropy
# because the generators yield integer class IDs rather than one-hot vectors.
model.compile(
    optimizer=Adam(learning_rate=0.001),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

# --- Callbacks ---
# All three callbacks watch the validation loss.
# EarlyStopping: give up after 5 epochs without improvement and restore the best weights.
es = callbacks.EarlyStopping(
    monitor='val_loss',
    patience=5,
    restore_best_weights=True,
    verbose=1,
)
# ModelCheckpoint: write the model to disk only when the validation loss improves.
mc = callbacks.ModelCheckpoint(
    "/kaggle/working/models/double_unet_best.keras",
    monitor='val_loss',
    save_best_only=True,
    verbose=1,
)
# ReduceLROnPlateau: halve the learning rate after 3 stagnant epochs, never below 1e-6.
reduce_lr = ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.5,
    patience=3,
    min_lr=1e-6,
    verbose=1,
)


# --- Model Training ---
print("\n--- Starting Model Training ---")
# Train for at most 20 epochs. The step counts are passed explicitly and equal the
# number of batches each generator reports.
history = model.fit(
    train_gen,
    validation_data=val_gen,
    epochs=20,
    steps_per_epoch=len(train_gen),
    validation_steps=len(val_gen),
    callbacks=[es, mc, reduce_lr],
    verbose=2, # One summary line per epoch
)

# --- Plotting Training History ---
print("\n--- Plotting Training History ---")
plt.figure(figsize=(12, 5))

# One panel per metric, side by side: accuracy on the left, loss on the right.
# Each panel overlays the training curve and the validation curve.
for panel, (metric_key, metric_name) in enumerate((('accuracy', 'Accuracy'), ('loss', 'Loss')), start=1):
    plt.subplot(1, 2, panel)
    plt.plot(history.history[metric_key], label=f'Train {metric_name}')
    plt.plot(history.history[f'val_{metric_key}'], label=f'Val {metric_name}')
    plt.title(f'Double UNet Model {metric_name} per Epoch')
    plt.xlabel('Epoch')
    plt.ylabel(metric_name)
    plt.legend()

plt.tight_layout() # Keeps titles and axis labels from overlapping
plt.show()


# --- Confusion Matrix for Validation Results (Improved) ---
print("\n--- Generating Confusion Matrix ---")

# Accumulators for predicted and true labels across all validation batches.
val_preds, val_truth = [], []

# Walk the validation generator by explicit batch index. Iterating the generator
# object directly depends on the Keras base class providing __iter__; without it
# Python falls back to calling __getitem__ with ever-growing indices, which may
# never terminate. Indexing over range(len(...)) is bounded on every Keras version.
# NOTE: len(val_gen) counts full batches only, so a trailing partial batch of
# validation samples is not evaluated.
for batch_idx in range(len(val_gen)):
    Xb, yb = val_gen[batch_idx]

    # Optional: Check for NaN/Inf values in batch before prediction.
    # Such values can cause model prediction errors.
    if np.any(np.isnan(Xb)) or np.any(np.isinf(Xb)):
        print(f"Warning: Batch {batch_idx} contains NaN or Inf values. Skipping prediction for this batch.")
        continue # Skip this batch if it has problematic values

    # Optional: Check if a batch is entirely zeros. This might indicate an issue
    # with data loading or preprocessing, as silent audio might not be informative.
    if np.all(Xb == 0):
        print(f"Warning: Batch {batch_idx} is all zeros. This might indicate an issue with data loading/preprocessing.")

    try:
        # Predict on the current batch. `verbose=0` prevents progress bar for each batch prediction.
        preds = model.predict(Xb, verbose=0)
        # Convert predicted probabilities to class labels (index of the highest probability).
        pred_labels = np.argmax(preds, axis=1)
        # Predictions and truths are appended together so the two lists stay aligned.
        val_preds.extend(pred_labels.tolist())
        val_truth.extend(yb.tolist())
    except Exception as e:
        # Best effort: report the failing batch and carry on with the remaining ones.
        print(f"Error predicting on batch {batch_idx}: {e}")
        print(f"Problematic batch shape: {Xb.shape}")

# Defensive check to avoid errors if no predictions were collected (e.g., due to all batches having NaNs).
if val_truth and val_preds:
    # Compute the confusion matrix over the full label range so absent classes still get a row/column.
    cm = confusion_matrix(val_truth, val_preds, labels=list(range(num_classes)))
    # Map class IDs back to their string names; fall back to the numeric ID for any
    # class that does not occur in the metadata.
    id_to_name = df.drop_duplicates('classID').set_index('classID')['class'].to_dict()
    label_names = [id_to_name.get(i, str(i)) for i in range(num_classes)]

    # Draw onto an explicitly created axes. ConfusionMatrixDisplay.plot() opens its own
    # figure when no axes is given, which would ignore the requested size and leave an
    # empty figure behind.
    fig, ax = plt.subplots(figsize=(10, 8))
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=label_names)
    disp.plot(ax=ax, xticks_rotation=45, cmap='magma', colorbar=True) # Rotate x-axis labels for readability
    ax.set_title("Validation Confusion Matrix (Double UNet)")
    plt.show()
else:
    print("Warning: No predictions available for confusion matrix. Check data loading and prediction loop.")