# Libraries

In [1]:
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout
import os
from tensorflow.keras.callbacks import ModelCheckpoint

# Load and process

In [None]:
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from sklearn.utils.class_weight import compute_class_weight
import numpy as np

# ---- Configuration -----------------------------------------------------
img_size = (224, 224)    # (height, width) every image is resized to
batch_size = 32
data_dir = 'all_data/'   # flow_from_directory reads one sub-folder per class

# ---- 1. Class weights from the plain (un-augmented) file listing ---------
# shuffle=False keeps the iterator's file order deterministic; only the
# per-file integer labels in full_data.classes are used below.
full_datagen = ImageDataGenerator(rescale=1. / 255)
full_data = full_datagen.flow_from_directory(
    data_dir,
    shuffle=False,
    class_mode='categorical',
    target_size=img_size,
)

# 'balanced' weights are inversely proportional to each class's frequency;
# labels are the contiguous range 0..n_classes-1 assigned by the iterator.
class_weights = compute_class_weight(
    class_weight='balanced',
    classes=np.arange(len(full_data.class_indices)),
    y=full_data.classes,
)
class_weight_dict = dict(enumerate(class_weights))  # {label: weight}

# 2. Audio-Specific Augmentation Pipeline
def spectrogram_augmentation(image, rng=None):
    """Randomly time-shift and SpecAugment-style mask one spectrogram image.

    Used as the ``preprocessing_function`` of the training
    ``ImageDataGenerator``: it receives one rank-3 NumPy array per image and
    must return a NumPy array of the same shape.

    Each step fires independently:

    * time shift (probability 0.3): circularly roll the columns by an integer
      offset in [-5, 4]. This is a roll, not a true time warp;
    * frequency masking (probability 0.5): zero a band of consecutive rows;
    * time masking (probability 0.5): zero a band of consecutive columns.

    Band widths are drawn by ``_random_band``: at most 9 pixels, and at most
    roughly a tenth of the axis. Rows are treated as the frequency axis and
    columns as the time axis. Band sizes come from the image's own shape, not
    from a global.

    Args:
        image: array of shape (rows, cols, ...), normally (H, W, C).
        rng: object with NumPy's legacy ``uniform()`` / ``randint()`` API,
            e.g. ``np.random.RandomState(seed)``. Defaults to the global
            ``np.random`` state.

    Returns:
        A new array with the same shape and dtype as ``image``. The input
        array is not modified.
    """
    if rng is None:
        rng = np.random

    out = np.array(image, copy=True)
    rows, cols = out.shape[0], out.shape[1]

    # Time shift: circular roll along the time (column) axis.
    if rng.uniform() > 0.7:
        out = np.roll(out, rng.randint(-5, 5), axis=1)

    # Frequency masking: zero a horizontal band of rows.
    if rng.uniform() > 0.5:
        width, start = _random_band(rng, rows)
        out[start:start + width, ...] = 0

    # Time masking: zero a vertical band of columns.
    if rng.uniform() > 0.5:
        width, start = _random_band(rng, cols)
        out[:, start:start + width, ...] = 0

    return out


def _random_band(rng, length):
    """Return a random ``(width, start)`` band that fits inside ``length`` px.

    Width is drawn from [1, max_width - 1], where
    max_width = max(2, min(10, length // 10)). The floor of 2 keeps the
    draw valid on small axes. Start is drawn so the band fits entirely inside
    the axis and may touch its last index. An empty axis yields the no-op
    band (0, 0).
    """
    if length < 1:
        return 0, 0
    max_width = max(2, min(10, length // 10))
    width = rng.randint(1, max_width)             # upper bound exclusive
    start = rng.randint(0, length - width + 1)    # +1: last position allowed
    return width, start

# 3. Data generators sharing one validation split
# Both generators use validation_split=0.2 on the same directory. Only the
# training generator augments; the validation generator just rescales.
train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=8,
    width_shift_range=0.08,   # horizontal (time-axis) shift
    height_shift_range=0.08,  # vertical (frequency-axis) shift
    zoom_range=0.1,
    # NOTE(review): on spectrograms, rotation/zoom/height shifts move energy
    # across frequency bins; confirm these geometric transforms help.
    preprocessing_function=spectrogram_augmentation,
    validation_split=0.2
)

val_datagen = ImageDataGenerator(
    rescale=1./255,
    validation_split=0.2
)

# 4. Data loading (training batches are shuffled; nothing here balances them)
train_data = train_datagen.flow_from_directory(
    data_dir,
    target_size=img_size,
    batch_size=batch_size,
    class_mode='categorical',
    subset='training',
    shuffle=True,
    seed=42,  # reproducible shuffling order
    color_mode='rgb'
)

validation_data = val_datagen.flow_from_directory(
    data_dir,
    target_size=img_size,
    batch_size=batch_size,
    class_mode='categorical',
    subset='validation',
    # Fixed file order: aggregate validation metrics do not depend on sample
    # order, and predict() outputs then line up with validation_data.classes
    # and validation_data.filenames. With a fixed order, the first batches
    # inspected by show_batch_distribution hold only the first class(es).
    shuffle=False,
    color_mode='rgb'
)

# 5. Verify Data Pipeline
def show_batch_distribution(generator, name, num_batches=2):
    """Print the class-index mapping and per-class counts of the first batches.

    Rewinds ``generator`` with ``reset()``, draws ``num_batches`` batches with
    ``next()``, and prints the column sums of each batch's label matrix. With
    one-hot labels (class_mode='categorical'), these are the per-class sample
    counts. The generator is rewound again afterwards, so the inspection does
    not leave it part-way through an epoch.

    Args:
        generator: iterator exposing ``class_indices`` and ``reset()`` and
            yielding ``(x, y)`` batches from ``next()``.
        name: label printed in the header line.
        num_batches: how many batches to inspect (default 2).
    """
    print(f"\n{name} class indices:", generator.class_indices)
    generator.reset()  # start from the beginning
    for i in range(num_batches):
        _, y = next(generator)
        print(f"Batch {i} distribution:", np.sum(y, axis=0))
    generator.reset()  # leave the iterator rewound for later consumers

show_batch_distribution(train_data, "Training")
show_batch_distribution(validation_data, "Validation")

# Check if the original dataset is balanced. This reuses the un-shuffled
# full_data listing built for the class weights instead of re-scanning
# data_dir and rebinding full_data to a 256x256, un-rescaled iterator.
# minlength keeps a zero count for any class with no files.
print("Full dataset distribution:",
      np.bincount(full_data.classes, minlength=len(full_data.class_indices)))

# Check if the validation split maintained balance. `.classes` holds one
# integer label per file in each subset; minlength makes both count vectors
# the same length even when a subset has no files for the last class(es).
train_indices = train_data.classes
val_indices = validation_data.classes
print("Train distribution:",
      np.bincount(train_indices, minlength=train_data.num_classes))
print("Val distribution:",
      np.bincount(val_indices, minlength=validation_data.num_classes))

# 6. Visualize Augmented Samples
import matplotlib.pyplot as plt
train_data.reset()  # Reset generator before sampling
x, y = next(train_data)
plt.figure(figsize=(10, 5))
# Bound by the batch length: a tiny dataset can yield fewer than 3 images.
for i in range(min(3, len(x))):
    plt.subplot(1, 3, i + 1)
    plt.imshow(x[i])
    plt.title(f"Class {np.argmax(y[i])}")
    plt.axis('off')
plt.tight_layout()
plt.show()

# Model

In [4]:
# Declare the input with an explicit Input layer sized from img_size rather
# than passing input_shape to the first Conv2D. Keras warns about
# input_shape on layers, and a hard-coded (224, 224, 3) could drift from
# img_size.
model = Sequential([
    tf.keras.Input(shape=(*img_size, 3)),  # 3 channels: color_mode='rgb'

    Conv2D(32, (3, 3), activation='relu'),
    MaxPooling2D(2, 2),

    Conv2D(64, (3, 3), activation='relu'),
    MaxPooling2D(2, 2),

    Conv2D(128, (3, 3), activation='relu'),
    MaxPooling2D(2, 2),

    Conv2D(256, (3, 3), activation='relu'),
    MaxPooling2D(2, 2),

    Flatten(),
    Dense(256, activation='relu'),
    Dropout(0.3),
    # One softmax unit per class found by the training iterator.
    Dense(train_data.num_classes, activation='softmax'),
])
model.summary()
    

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


In [6]:
# Adam with default settings. Categorical cross-entropy matches the one-hot
# labels from class_mode='categorical' and the softmax output layer.
model.compile(
    loss='categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)