In [4]:
import os

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.applications import DenseNet201
from tensorflow.keras.applications.densenet import preprocess_input as densenet_preprocess_input
from tensorflow.keras.layers import GlobalAveragePooling2D, Dense, Dropout
from tensorflow.keras.models import Model
from tensorflow.keras.preprocessing.image import ImageDataGenerator

In [26]:
data_dir = "Train_Spectrogram_Images"

# === Image Preprocessing ===
img_size = (224, 224)
batch_size = 32
validation_split = 0.2  # fraction of each class directory held out for validation
seed = 42               # fixes the training-set shuffle order across runs

# Training pipeline: DenseNet201's own input preprocessing plus random augmentation.
# `densenet_preprocess_input` replaces the former `rescale=1./255`: it scales pixels
# to [0, 1] AND normalises them with the ImageNet mean/std that the pretrained
# DenseNet201 weights were trained on.
# NOTE(review): rotations, shears and shifts move content across the time/frequency
# axes of a spectrogram; confirm these augmentations are appropriate for this data.
datagen = ImageDataGenerator(
    preprocessing_function=densenet_preprocess_input,
    validation_split=validation_split,
    rotation_range=20,          # Random rotation in [-20, 20] degrees
    width_shift_range=0.2,      # Random horizontal shift of up to ±20% of width
    height_shift_range=0.2,     # Random vertical shift of up to ±20% of height
    shear_range=0.15,           # Shear angle in degrees (counter-clockwise), drawn from [-0.15, 0.15]
                                # NOTE(review): in degrees this is a negligible shear — confirm intended value
    zoom_range=0.2,             # Random zoom factor in [0.8, 1.2]
)

# Validation pipeline: identical preprocessing and split, but NO augmentation, so
# validation metrics are computed on the unmodified images. `validation_split` must
# match the training generator's so both subsets come from the same partition of
# `data_dir`.
val_datagen = ImageDataGenerator(
    preprocessing_function=densenet_preprocess_input,
    validation_split=validation_split,
)

train_generator = datagen.flow_from_directory(
    data_dir,
    target_size=img_size,
    batch_size=batch_size,
    class_mode='categorical',
    subset='training',
    seed=seed,
)

val_generator = val_datagen.flow_from_directory(
    data_dir,
    target_size=img_size,
    batch_size=batch_size,
    class_mode='categorical',
    subset='validation',
    shuffle=False,  # stable order, so predictions line up with val_generator.classes
)

Found 5328 images belonging to 7 classes.
Found 1328 images belonging to 7 classes.


In [29]:
# === Load DenseNet201 Pre-trained Model (used as a frozen feature extractor) ===
base_model = DenseNet201(
    include_top=False,           # drop the 1000-class ImageNet classifier
    weights='imagenet',
    input_shape=(*img_size, 3),  # RGB input at the generators' target size
)
base_model.trainable = False     # only the new head below receives gradient updates

# === Custom Classification Head ===
# Hidden head layers, instantiated in the same order in which they are applied.
hidden_head = (
    GlobalAveragePooling2D(),
    Dense(256, activation='relu'),
    Dropout(0.5),
)
x = base_model.output
for head_layer in hidden_head:
    x = head_layer(x)
# One softmax unit per class directory discovered by the training generator.
output = Dense(train_generator.num_classes, activation='softmax')(x)

# === Final Model ===
model = Model(inputs=base_model.input, outputs=output)

In [30]:
# Print the layer-by-layer architecture with total / trainable / non-trainable
# parameter counts (the frozen DenseNet201 weights appear as non-trainable).
model.summary()

In [31]:
# === Compile the Model ===
# The optimizer comes from the `tensorflow.keras` namespace, i.e. the same Keras
# implementation that built `model`. A separately installed standalone `keras`
# package can be a different major version whose optimizers may not be accepted
# by a tf.keras model.
LEARNING_RATE = 1e-4  # Adam step size

model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=LEARNING_RATE),
    loss='categorical_crossentropy',  # pairs with class_mode='categorical' (one-hot labels)
    metrics=['accuracy'],
)