# CNN + Attention: Images Only (Optimized)

A simple, optimized CNN + multi-head attention classifier trained only on the GTZAN spectrogram images.

In [None]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import numpy as np
from PIL import Image
import tensorflow as tf
from tensorflow.keras import layers, Model
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm

# Seed NumPy's and TensorFlow's global random generators with one shared value.
SEED = 42
np.random.seed(SEED)
tf.random.set_seed(SEED)

# Report which TensorFlow version is in use.
print(f"TensorFlow: {tf.__version__}")

In [None]:
# ==================== CONFIG ====================
# Directory expected to contain one sub-folder of PNG spectrograms per genre.
# Set the GTZAN_IMAGE_PATH environment variable to point somewhere else; the
# original location is kept as the default so existing setups keep working.
IMAGE_PATH = os.environ.get(
    'GTZAN_IMAGE_PATH',
    '/Users/narac0503/GIT/GTZAN Dataset Classification/GTZAN-Dataset-Classification/gtzan-classification/data/gtzan/images_original',
)

# Size handed to PIL's resize(), which interprets it as (width, height).
TARGET_SIZE = (128, 128)

# Genre sub-folder names, which double as the class names.
GENRES = ['blues', 'classical', 'country', 'disco', 'hiphop', 'jazz', 'metal', 'pop', 'reggae', 'rock']

# Derived from GENRES so the class count cannot drift from the genre list.
NUM_CLASSES = len(GENRES)

print(f"Path exists: {os.path.exists(IMAGE_PATH)}")

In [None]:
# ==================== LOAD IMAGES ====================
# Walk every genre folder, decode each PNG into a TARGET_SIZE RGB array scaled
# to [0, 1], and collect the arrays in X with the matching genre name in y.
X, y = [], []
skipped = []  # (file path, exception) for every image that failed to decode

print("Loading images...\n")
for genre in GENRES:
    genre_path = os.path.join(IMAGE_PATH, genre)
    if not os.path.isdir(genre_path):
        print(f"{genre}: NOT FOUND")
        continue

    # os.listdir() order is filesystem-dependent; sorting keeps the sample
    # order (and therefore the seeded splits made from it) stable across runs.
    files = sorted(f for f in os.listdir(genre_path) if f.endswith('.png'))
    print(f"{genre}: {len(files)} images")

    for f in tqdm(files, desc=genre):
        file_path = os.path.join(genre_path, f)
        try:
            # The context manager closes the underlying file handle as soon as
            # the pixels have been read.
            with Image.open(file_path) as img:
                rgb = img.convert('RGB').resize(TARGET_SIZE)
                # float32 halves the memory of NumPy's default float64.
                pixels = np.asarray(rgb, dtype=np.float32)
        except (OSError, ValueError) as err:
            # Unreadable or corrupt file: skip it, but remember it so the
            # failure is reported instead of silently ignored.
            skipped.append((file_path, err))
            continue
        X.append(pixels / 255.0)
        y.append(genre)

if skipped:
    print(f"\nSkipped {len(skipped)} unreadable image(s):")
    for file_path, err in skipped:
        print(f"  {file_path}: {err}")

if not X:
    # Fail here with a clear message rather than later with an unrelated error.
    raise RuntimeError(f"No images could be loaded from {IMAGE_PATH!r}")

X = np.array(X)
y = np.array(y)
print(f"\nLoaded: {X.shape}")

In [None]:
# ==================== PREPROCESS ====================
# Turn genre names into integer class ids and then into one-hot rows.
# The encoder is fitted on the full GENRES list rather than on the labels that
# happened to load, so the ids stay fixed even when a genre folder is missing.
# LabelEncoder orders its classes alphabetically; GENRES is already
# alphabetical, so id i corresponds to GENRES[i].
le = LabelEncoder()
le.fit(GENRES)
y_enc = le.transform(y)
y_cat = to_categorical(y_enc, NUM_CLASSES)

# Normalize images: standardise every pixel with one global mean and one
# global standard deviation (scalars taken over the whole array). The small
# epsilon guards against division by zero.
# NOTE(review): the statistics cover every loaded image; if X is split into
# train/validation/test afterwards, held-out pixels influence them. Consider
# computing them from the training portion only.
pixel_mean = X.mean()
pixel_std = X.std()
X = (X - pixel_mean) / (pixel_std + 1e-8)
print(f"Normalized - Mean: {X.mean():.4f}, Std: {X.std():.4f}")

In [None]:
# ==================== SPLIT ====================
# Two-stage stratified split. First hold out 10% of the samples for testing,
# then take 11.1% of the remaining 90% (roughly 10% of the total) for
# validation. Both stages use the same fixed random_state.
X_temp, X_test, y_temp, y_test = train_test_split(
    X,
    y_cat,
    test_size=0.1,
    stratify=y_enc,
    random_state=42,
)

# The one-hot rows are collapsed back to class ids to stratify the second split.
temp_class_ids = np.argmax(y_temp, 1)
X_train, X_val, y_train, y_val = train_test_split(
    X_temp,
    y_temp,
    test_size=0.111,
    stratify=temp_class_ids,
    random_state=42,
)

# Report the size of each partition.
print(f"Train: {len(X_train)}")
print(f"Val: {len(X_val)}")
print(f"Test: {len(X_test)}")

In [None]:
# ==================== BUILD MODEL ====================
# CNN feature extractor -> multi-head self-attention over the spatial grid ->
# dense classifier with a softmax over NUM_CLASSES.

# PIL's resize() takes (width, height) while the image arrays are laid out as
# (height, width, channels), so the two TARGET_SIZE entries are swapped here.
input_shape = (TARGET_SIZE[1], TARGET_SIZE[0], 3)
# Number of filters in each convolutional stage, in order.
conv_filters = [32, 64, 128, 256]

inputs = layers.Input(shape=input_shape)
x = inputs

# CNN blocks: Conv -> BatchNorm -> ReLU -> 2x2 MaxPool -> Dropout per stage.
for filters in conv_filters:
    x = layers.Conv2D(filters, 3, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    x = layers.MaxPooling2D(2)(x)
    x = layers.Dropout(0.25)(x)

# Reshape for attention: flatten the spatial grid into a sequence of feature
# vectors. The vector width is the channel count of the last CNN stage, so it
# is taken from conv_filters instead of being repeated as a literal.
x = layers.Reshape((-1, conv_filters[-1]))(x)

# Multi-head self-attention (query and value are the same sequence), then an
# average over the sequence axis to get one vector per image.
x = layers.MultiHeadAttention(num_heads=4, key_dim=32)(x, x)
x = layers.GlobalAveragePooling1D()(x)

# Classifier
x = layers.Dense(256, activation='relu')(x)
x = layers.Dropout(0.4)(x)
x = layers.Dense(128, activation='relu')(x)
x = layers.Dropout(0.3)(x)
outputs = layers.Dense(NUM_CLASSES, activation='softmax')(x)

model = Model(inputs, outputs)

# Categorical cross-entropy matches the one-hot targets; label smoothing of
# 0.1 softens them.
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
    loss=tf.keras.losses.CategoricalCrossentropy(label_smoothing=0.1),
    metrics=['accuracy']
)

model.summary()

In [None]:
# ==================== TRAIN ====================
# Stop once validation accuracy has not improved for 20 epochs and roll the
# model back to its best weights.
early_stop = EarlyStopping(
    monitor='val_accuracy',
    patience=20,
    restore_best_weights=True,
    verbose=1,
)
# Halve the learning rate after 7 epochs without a validation-loss
# improvement, never going below 1e-7.
lr_on_plateau = ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.5,
    patience=7,
    min_lr=1e-7,
    verbose=1,
)
# Write the model to disk only when validation accuracy reaches a new best.
best_checkpoint = ModelCheckpoint(
    'best_cnn_attention.keras',
    monitor='val_accuracy',
    save_best_only=True,
    verbose=1,
)
callbacks = [early_stop, lr_on_plateau, best_checkpoint]

history = model.fit(
    x=X_train,
    y=y_train,
    validation_data=(X_val, y_val),
    batch_size=16,
    epochs=100,
    callbacks=callbacks,
)

In [None]:
# ==================== PLOT ====================
# Side-by-side training curves: accuracy on the left, loss on the right,
# each showing the training and validation series.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

curves = history.history
panels = (
    (ax1, 'accuracy', 'val_accuracy', 'Accuracy'),
    (ax2, 'loss', 'val_loss', 'Loss'),
)
for axis, train_key, val_key, title in panels:
    axis.plot(curves[train_key], label='Train')
    axis.plot(curves[val_key], label='Val')
    axis.set_title(title)
    axis.legend()

plt.tight_layout()
plt.show()

In [None]:
# ==================== EVALUATE ====================
# Load the weights stored in the checkpoint file, then score the held-out
# test split. evaluate() returns the loss followed by the compiled metric.
model.load_weights('best_cnn_attention.keras')
test_scores = model.evaluate(X_test, y_test, verbose=0)
loss, acc = test_scores

# Print the two numbers between separator rules.
print(f"\n{'='*50}")
print(f"TEST ACCURACY: {acc*100:.2f}%")
print(f"TEST LOSS: {loss:.4f}")
print(f"{'='*50}")

In [None]:
# ==================== RESULTS ====================
# Per-class metrics and a confusion matrix for the test split.
y_pred = np.argmax(model.predict(X_test, verbose=0), axis=1)
y_true = np.argmax(y_test, axis=1)

# Take the class names from the fitted label encoder: le.classes_[i] is the
# genre that was encoded as id i, so the names cannot get out of step with the
# ids. Passing the id list explicitly keeps the report rows and the matrix
# axes the same length as the name list even when a class is absent from
# y_true / y_pred.
class_names = list(le.classes_)
class_ids = np.arange(len(class_names))

print("\nClassification Report:")
print(classification_report(y_true, y_pred, labels=class_ids, target_names=class_names))

# Confusion matrix (rows: true class, columns: predicted class)
cm = confusion_matrix(y_true, y_pred, labels=class_ids)
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=class_names, yticklabels=class_names)
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title(f'Confusion Matrix (Acc: {acc:.2%})')
plt.xticks(rotation=45, ha='right')
plt.tight_layout()
plt.show()

In [None]:
# ==================== SAVE ====================
# Write the model in its current state to a .keras file.
final_model_path = 'cnn_attention_images_final.keras'
model.save(final_model_path)
print("Model saved!")