# Custom CNN Architecture for Image Classification

This notebook implements and trains a custom CNN architecture from scratch.

## Objectives:
- Design custom CNN architecture
- Implement data preprocessing and augmentation
- Train the model with proper validation
- Evaluate performance and visualize results
- Save the best model

## 1. Setup and Data Loading

In [1]:
# Configure environment for Apple Silicon optimization
import warnings
warnings.filterwarnings('ignore')
import os
import shutil
import random
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # Suppress TensorFlow warnings

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import classification_report, confusion_matrix
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, models
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Conv2D, MaxPooling2D, Dropout, GlobalAveragePooling2D
# Configure TensorFlow for Apple Silicon
try:
    # Enable memory growth for GPU (if available)
    gpus = tf.config.list_physical_devices('GPU')
    if gpus:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        print(f"✅ GPU configured: {len(gpus)} device(s) found")
    else:
        print("ℹ️  No GPU found, using CPU")
except Exception as e:
    print(f"⚠️  GPU configuration warning: {e}")

print(f"TensorFlow version: {tf.__version__}")
print(f"Available devices: {[device.name for device in tf.config.list_physical_devices()]}")

# Set random seeds for reproducibility
np.random.seed(42)
tf.random.set_seed(42)
random.seed(42)
print("✅ Random seeds set for reproducibility")

✅ GPU configured: 1 device(s) found
TensorFlow version: 2.15.0
Available devices: ['/physical_device:CPU:0', '/physical_device:GPU:0']
✅ Random seeds set for reproducibility


## 2. Data Loading and Preprocessing

In [2]:
path = '/Users/smithn5/.cache/kagglehub/datasets/alessiocorrado99/animals10/versions/2'
data_dir = os.path.join(path, 'raw-img')

In [3]:
# split data into training, validation, and test sets physically in the data directory
# previous_weights_path = '../models/custom_costum_animals10_best.h5'
base_dir = '../data/'
train_ratio = 0.9
val_ratio = 0.1
test_ratio = 0.0
batch_size = 32

img_height = 128 # for basic cnn
img_width = 128 # for basic cnn

# Create base directories for splits
for split in ['train', 'val', 'test']:
    split_dir = os.path.join(base_dir, split)
    os.makedirs(split_dir, exist_ok=True)
    for class_name in os.listdir(data_dir):
        class_path = os.path.join(data_dir, class_name)
        if not os.path.isdir(class_path):
            continue
        os.makedirs(os.path.join(split_dir, class_name), exist_ok=True)


In [4]:
# now split and copy images
for class_name in os.listdir(data_dir):
    class_path = os.path.join(data_dir, class_name)
    if not os.path.isdir(class_path):
        continue
    
    images = os.listdir(class_path)
    random.shuffle(images)
    
    n_total = len(images)
    n_train = int(n_total * train_ratio)
    n_val = int(n_total * val_ratio)
    n_test = n_total - n_train - n_val  # Ensure all images are used

    train_images = images[:n_train]
    val_images = images[n_train:n_train+n_val]
    test_images = images[n_train+n_val:]

    for img in train_images:
        shutil.copy(os.path.join(class_path, img), os.path.join(base_dir, 'train', class_name, img))
    for img in val_images:
        shutil.copy(os.path.join(class_path, img), os.path.join(base_dir, 'val', class_name, img))
    for img in test_images:
        shutil.copy(os.path.join(class_path, img), os.path.join(base_dir, 'test', class_name, img))


KeyboardInterrupt: 

In [None]:
# Load the dataset using tf.keras.preprocessing.image.ImageDataGenerator

base_dir = '../data/'


# ISSUE DIAGNOSIS: Your augmentation might be too aggressive!
# Let's create a less aggressive augmentation setup

print("🔧 Creating improved data generators...")

# Create LESS aggressive augmentation for training
train_datagen = ImageDataGenerator(
    rescale=1./255,  # Normalize to [0,1]
    rotation_range=20,      # Reduced from 20
    width_shift_range=0.1,  # Reduced from 0.2
    height_shift_range=0.1, # Reduced from 0.2
    shear_range=0.1,     # Reduced from 0.2
    zoom_range=0.1,         # Reduced from 0.2
    horizontal_flip=True,
    fill_mode='nearest'
)

# Create validation generator without augmentation (IMPORTANT!)
test_val_datagen = ImageDataGenerator(
    rescale=1./255,  # Only rescaling for validation
    rotation_range=20,      # Reduced from 20
    width_shift_range=0.1,  # Reduced from 0.2
    height_shift_range=0.1, # Reduced from 0.2
    shear_range=0.1,     # Reduced from 0.2
    zoom_range=0.1,         # Reduced from 0.2
    horizontal_flip=True,
    fill_mode='nearest'
    
)

# Create training dataset with augmentation
train_ds = train_datagen.flow_from_directory(
    os.path.join(base_dir, 'train'),
    shuffle=True,
    target_size=(img_height, img_width),
    batch_size=batch_size,
    class_mode='categorical'
)

# Create validation dataset WITHOUT augmentation
val_ds = test_val_datagen.flow_from_directory(
    os.path.join(base_dir, 'val'),
    shuffle=False,  # Don't shuffle validation
    target_size=(img_height, img_width),
    batch_size=batch_size,
    class_mode='categorical'
)

#create test dataset WITHOUT augmentation
test_ds = test_val_datagen.flow_from_directory(
    os.path.join(base_dir, 'test'),
    shuffle=False,  # Don't shuffle test
    target_size=(img_height, img_width),
    batch_size=batch_size,
    class_mode='categorical'
)

# Get class names from the dataset
class_names = list(train_ds.class_indices.keys())
print(f"Found {len(class_names)} classes: {class_names}")

# Calculate dataset sizes
print(f"Training samples: {train_ds.samples}")
print(f"Validation samples: {val_ds.samples}")
print(f"Test samples: {test_ds.samples}")
print(f"Training batches per epoch: {len(train_ds)}")
print(f"Validation batches per epoch: {len(val_ds)}")
print(f"Test batches per epoch: {len(test_ds)}")
print(f"Batch size: {batch_size}")

# Verify class balance
print(f"\nClass distribution check:")
class_counts = {}
for class_name in class_names:
    class_dir = os.path.join(path, 'raw-img', class_name)
    if os.path.exists(class_dir):
        count = len([f for f in os.listdir(class_dir) if f.lower().endswith(('.jpg', '.jpeg', '.png'))])
        class_counts[class_name] = count
        print(f"  {class_name}: {count} images")

# Check for extremely imbalanced classes
min_count = min(class_counts.values())
max_count = max(class_counts.values())
imbalance_ratio = max_count / min_count if min_count > 0 else float('inf')
print(f"\nClass balance analysis:")
print(f"  Min class size: {min_count}")
print(f"  Max class size: {max_count}")
print(f"  Imbalance ratio: {imbalance_ratio:.2f}")

if imbalance_ratio > 10:
    print("⚠️  WARNING: Severe class imbalance detected!")
    print("   This could explain poor learning performance")
elif imbalance_ratio > 3:
    print("⚡ Moderate class imbalance detected")
else:
    print("✅ Classes are reasonably balanced")

print(f"\n✅ Improved data generators created with:")
print(f"  - Reduced augmentation intensity")
print(f"  - Separate validation generator (no augmentation)")
print(f"  - Proper class balance verification")

🔧 Creating improved data generators...
Found 23556 images belonging to 10 classes.
Found 2614 images belonging to 10 classes.
Found 9 images belonging to 10 classes.
Found 10 classes: ['cane', 'cavallo', 'elefante', 'farfalla', 'gallina', 'gatto', 'mucca', 'pecora', 'ragno', 'scoiattolo']
Training samples: 23556
Validation samples: 2614
Test samples: 9
Training batches per epoch: 737
Validation batches per epoch: 82
Test batches per epoch: 1
Batch size: 32

Class distribution check:
  cane: 4863 images
  cavallo: 2623 images
  elefante: 1446 images
  farfalla: 2112 images
  gallina: 3098 images
  gatto: 1668 images
  mucca: 1866 images
  pecora: 1820 images
  ragno: 4821 images
  scoiattolo: 1862 images

Class balance analysis:
  Min class size: 1446
  Max class size: 4863
  Imbalance ratio: 3.36
⚡ Moderate class imbalance detected

✅ Improved data generators created with:
  - Reduced augmentation intensity
  - Separate validation generator (no augmentation)
  - Proper class balance ve

## 3. Custom CNN Architecture Design

In [None]:
def create_custom_made_cnn(input_shape, num_classes=10):
    """
    Create a simple CNN model.
    """
    model = Sequential([
    # block 1: 32 filtros
    Conv2D(32, (3, 3), activation='relu', padding='same', input_shape=input_shape),
    Conv2D(32, (3, 3), activation='relu', padding='same'),
    MaxPooling2D((2, 2)),
    Dropout(0.1),

    # block 2: 64 filtros
    Conv2D(64, (3, 3), activation='relu', padding='same'),
    Conv2D(64, (3, 3), activation='relu', padding='same'),
    MaxPooling2D((2, 2)),
    Dropout(0.2),

    # block 3: 128 filtros
    Conv2D(128, (3, 3), activation='relu', padding='same'),
    Conv2D(128, (3, 3), activation='relu', padding='same'),
    MaxPooling2D((2, 2)),
    Dropout(0.3),

    # block 4: 256 filtros
    Conv2D(256, (3, 3), activation='relu', padding='same'),
    Conv2D(256, (3, 3), activation='relu', padding='same'),
    MaxPooling2D((2, 2)),
    Dropout(0.4),

    # Capas densas
    GlobalAveragePooling2D(),
    Dense(1024, activation='relu'),
    Dense(256, activation='relu'),
    Dropout(0.5),
    Dense(num_classes, activation='softmax')
    ])

     # the one before was
    return model

# Create the model
input_shape = (img_height, img_width, 3)
num_classes = len(class_names)
custom_model = create_custom_made_cnn(input_shape, num_classes)

# Display model architecture
custom_model.summary()

print(f"\nModel Details:")
print(f"Input shape: {input_shape}")
print(f"Number of classes: {num_classes}")
print(f"Total parameters: {custom_model.count_params():,}")
print(f"Trainable parameters: {sum([tf.keras.backend.count_params(w) for w in custom_model.trainable_weights]):,}")

# Visualize model architecture
tf.keras.utils.plot_model(
    custom_model,
    to_file='../models/simple_cnn_architecture.png',
    show_shapes=True,
    show_layer_names=True,
    rankdir='TB'
)
print("✅ Model architecture diagram saved to '../models/simple_cnn_architecture.png'")

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d (Conv2D)             (None, 224, 224, 32)      896       
                                                                 
 conv2d_1 (Conv2D)           (None, 224, 224, 32)      9248      
                                                                 
 max_pooling2d (MaxPooling2  (None, 112, 112, 32)      0         
 D)                                                              
                                                                 
 dropout (Dropout)           (None, 112, 112, 32)      0         
                                                                 
 conv2d_2 (Conv2D)           (None, 112, 112, 64)      18496     
                                                                 
 conv2d_3 (Conv2D)           (None, 112, 112, 64)      36928     
                                                        

## 4. Model Compilation and Training Setup

In [None]:
# Let's first debug the data and model setup
print("🔍 DEBUGGING MODEL AND DATA...")
print("="*50)

print("\n" + "="*50)
print("🚀 STARTING IMPROVED TRAINING...")
print("="*50)

# STEP 1: Compile the model FIRST
custom_model.compile(
    optimizer=tf.keras.optimizers.legacy.Adam(learning_rate=0.0002),  
    loss='categorical_crossentropy',
    metrics=['accuracy']
)
print("✅ Model compiled with reduced learning rate: 0.0002")




# STEP 2: Setup callbacks (ModelCheckpoint will save best weights)
callbacks = [
    EarlyStopping(
        monitor='val_loss',
        patience=3,  # Reduced patience
        restore_best_weights=True,
        # verbose=1
    ),
    ModelCheckpoint(
        '../models/best_model_128.keras',
        monitor='val_loss',
        save_best_only=True,
        save_weights_only=False,  # Save full model
        # verbose=1
    ),
    ReduceLROnPlateau(
        monitor='val_accuracy',
        factor=0.5,  # More aggressive reduction
        patience=2,  # slower response
        min_lr=1e-6,
        # verbose=1
    )
]

print("✅ Training callbacks configured")
print("   - EarlyStopping: patience=10")
print("   - ModelCheckpoint: saves best model")
print("   - ReduceLROnPlateau: factor=0.1, patience=5")

🔍 DEBUGGING MODEL AND DATA...

🚀 STARTING IMPROVED TRAINING...
✅ Model compiled with reduced learning rate: 0.0002
✅ Training callbacks configured
   - EarlyStopping: patience=10
   - ModelCheckpoint: saves best model
   - ReduceLROnPlateau: factor=0.1, patience=5


## 5. Model Training

In [None]:

# Training parameters
EPOCHS = 35  # Increased epochs since we lowered LR

print(f"Improved training setup:")
print(f"- Learning rate: 0.0001 (reduced from 0.001)")
print(f"- Max epochs: {EPOCHS}")
print(f"- Early stopping patience: 8")
print(f"- LR reduction factor: 0.2 (more aggressive)")
print(f"- Batch size: {batch_size}")

# Train with improved settings
history = custom_model.fit(
    train_ds,
    #steps_per_epoch=len(train_ds),
    epochs=EPOCHS,
    validation_data=val_ds,
    #validation_steps=len(val_ds),
    callbacks=callbacks,
    verbose=1
)

print("\n✅ Training completed!")
if len(history.history['accuracy']) > 0:
    print(f"Total epochs trained: {len(history.history['accuracy'])}")
    print(f"Best validation accuracy: {max(history.history['val_accuracy']):.4f}")
    print(f"Final training accuracy: {history.history['accuracy'][-1]:.4f}")
    print(f"Final validation accuracy: {history.history['val_accuracy'][-1]:.4f}")


# Reset generators for future use
train_ds.reset()
val_ds.reset()

Improved training setup:
- Learning rate: 0.0001 (reduced from 0.001)
- Max epochs: 35
- Early stopping patience: 8
- LR reduction factor: 0.2 (more aggressive)
- Batch size: 32
Epoch 1/35
Epoch 2/35
Epoch 3/35
Epoch 4/35
137/737 [====>.........................] - ETA: 2:57 - loss: 2.9897 - accuracy: 0.1535

KeyboardInterrupt: 

## 6. Training History Visualization

In [None]:
def plot_training_history(history):
    """
    Plot training and validation metrics with enhanced visualizations.
    """
    # Create a larger figure with more subplots
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(16, 12))
    
    # Plot accuracy
    epochs = range(1, len(history.history['accuracy']) + 1)
    ax1.plot(epochs, history.history['accuracy'], 'b-', label='Training Accuracy', linewidth=2, marker='o', markersize=4)
    ax1.plot(epochs, history.history['val_accuracy'], 'r-', label='Validation Accuracy', linewidth=2, marker='s', markersize=4)
    ax1.set_title('Model Accuracy Over Time', fontsize=14, fontweight='bold')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Accuracy')
    ax1.legend()
    ax1.grid(True, alpha=0.3)
    ax1.set_ylim([0, 1])
    
    # Add best accuracy annotation
    best_val_acc = max(history.history['val_accuracy'])
    best_val_acc_epoch = history.history['val_accuracy'].index(best_val_acc) + 1
    ax1.annotate(f'Best: {best_val_acc:.3f}', 
                xy=(best_val_acc_epoch, best_val_acc), 
                xytext=(best_val_acc_epoch + 2, best_val_acc - 0.05),
                arrowprops=dict(arrowstyle='->', color='red', alpha=0.7))
    
    # Plot loss
    ax2.plot(epochs, history.history['loss'], 'b-', label='Training Loss', linewidth=2, marker='o', markersize=4)
    ax2.plot(epochs, history.history['val_loss'], 'r-', label='Validation Loss', linewidth=2, marker='s', markersize=4)
    ax2.set_title('Model Loss Over Time', fontsize=14, fontweight='bold')
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Loss')
    ax2.legend()
    ax2.grid(True, alpha=0.3)
    
    # Plot learning rate (if available)
    if 'lr' in history.history:
        ax3.plot(epochs, history.history['lr'], 'g-', label='Learning Rate', linewidth=2)
        ax3.set_title('Learning Rate Schedule', fontsize=14, fontweight='bold')
        ax3.set_xlabel('Epoch')
        ax3.set_ylabel('Learning Rate')
        ax3.set_yscale('log')
        ax3.legend()
        ax3.grid(True, alpha=0.3)
    else:
        # Plot accuracy difference (overfitting indicator)
        acc_diff = np.array(history.history['accuracy']) - np.array(history.history['val_accuracy'])
        ax3.plot(epochs, acc_diff, 'purple', label='Training - Validation Accuracy', linewidth=2)
        ax3.axhline(y=0, color='black', linestyle='--', alpha=0.5)
        ax3.set_title('Overfitting Indicator', fontsize=14, fontweight='bold')
        ax3.set_xlabel('Epoch')
        ax3.set_ylabel('Accuracy Difference')
        ax3.legend()
        ax3.grid(True, alpha=0.3)
    
    # Plot smoothed validation accuracy trend
    if len(history.history['val_accuracy']) > 5:
        # Simple moving average
        window_size = min(5, len(history.history['val_accuracy']) // 3)
        val_acc_smooth = np.convolve(history.history['val_accuracy'], 
                                   np.ones(window_size)/window_size, mode='valid')
        smooth_epochs = range(window_size, len(history.history['val_accuracy']) + 1)
        
        ax4.plot(epochs, history.history['val_accuracy'], 'lightcoral', alpha=0.5, label='Raw Validation Accuracy')
        ax4.plot(smooth_epochs, val_acc_smooth, 'darkred', linewidth=3, label=f'Smoothed (window={window_size})')
        ax4.set_title('Validation Accuracy Trend', fontsize=14, fontweight='bold')
        ax4.set_xlabel('Epoch')
        ax4.set_ylabel('Validation Accuracy')
        ax4.legend()
        ax4.grid(True, alpha=0.3)
        ax4.set_ylim([0, 1])
    else:
        # If not enough epochs, show final metrics summary
        ax4.text(0.1, 0.8, 'Training Summary', fontsize=16, fontweight='bold', transform=ax4.transAxes)
        ax4.text(0.1, 0.6, f'Epochs: {len(epochs)}', fontsize=12, transform=ax4.transAxes)
        ax4.text(0.1, 0.5, f'Best Val Acc: {best_val_acc:.4f}', fontsize=12, transform=ax4.transAxes)
        ax4.text(0.1, 0.4, f'Final Train Acc: {history.history["accuracy"][-1]:.4f}', fontsize=12, transform=ax4.transAxes)
        ax4.text(0.1, 0.3, f'Final Val Acc: {history.history["val_accuracy"][-1]:.4f}', fontsize=12, transform=ax4.transAxes)
        ax4.set_xlim(0, 1)
        ax4.set_ylim(0, 1)
        ax4.axis('off')
    
    plt.tight_layout()
    plt.show()
    
    # Print comprehensive metrics
    print("\n" + "="*60)
    print("TRAINING HISTORY ANALYSIS")
    print("="*60)
    print(f"Total epochs trained: {len(epochs)}")
    print(f"Best validation accuracy: {best_val_acc:.4f} at epoch {best_val_acc_epoch}")
    print(f"Final training accuracy: {history.history['accuracy'][-1]:.4f}")
    print(f"Final validation accuracy: {history.history['val_accuracy'][-1]:.4f}")
    print(f"Best validation loss: {min(history.history['val_loss']):.4f}")
    print(f"Final validation loss: {history.history['val_loss'][-1]:.4f}")
    
    # Calculate overfitting metrics
    final_gap = history.history['accuracy'][-1] - history.history['val_accuracy'][-1]
    print(f"\nOverfitting Analysis:")
    print(f"Final accuracy gap: {final_gap:.4f}")
    if final_gap > 0.1:
        print("⚠️  Potential overfitting detected (gap > 0.1)")
    elif final_gap > 0.05:
        print("⚡ Mild overfitting (gap > 0.05)")
    else:
        print("✅ Good generalization (gap ≤ 0.05)")
    
    # Training stability
    last_5_val_acc = history.history['val_accuracy'][-5:] if len(history.history['val_accuracy']) >= 5 else history.history['val_accuracy']
    val_acc_std = np.std(last_5_val_acc)
    print(f"Validation accuracy stability (last 5 epochs std): {val_acc_std:.4f}")
    
    if val_acc_std < 0.01:
        print("✅ Training converged well")
    elif val_acc_std < 0.02:
        print("⚡ Training mostly stable")
    else:
        print("⚠️  Training still fluctuating")
    
    print("="*60)

# Plot the training history
plot_training_history(history)

## 7. Model Evaluation on Test Set

In [None]:
# Evaluate on validation set (since we don't have a separate test set)
val_ds.reset()
val_loss, val_accuracy = custom_model.evaluate(val_ds, verbose=0)
print(f"Validation Accuracy: {val_accuracy:.4f}")
print(f"Validation Loss: {val_loss:.4f}")

# Generate predictions on validation set
print("Generating predictions...")
val_ds.reset()
y_pred = custom_model.predict(val_ds, verbose=1)
y_pred_classes = np.argmax(y_pred, axis=1)

# Get true labels from validation generator
val_ds.reset()
y_true_classes = []
for i in range(len(val_ds)):
    batch_images, batch_labels = next(val_ds)
    batch_true_classes = np.argmax(batch_labels, axis=1)
    y_true_classes.extend(batch_true_classes)

y_true_classes = np.array(y_true_classes)

# Ensure we have the same number of predictions and true labels
min_length = min(len(y_pred_classes), len(y_true_classes))
y_pred_classes = y_pred_classes[:min_length]
y_true_classes = y_true_classes[:min_length]

print(f"Number of samples evaluated: {min_length}")

# Classification report
print("\nClassification Report:")
print(classification_report(y_true_classes, y_pred_classes, target_names=class_names))

# Reset generator
val_ds.reset()

## 8. Confusion Matrix Visualization

In [None]:
# Generate and plot confusion matrix
cm = confusion_matrix(y_true_classes, y_pred_classes)

plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
           xticklabels=class_names, yticklabels=class_names)
plt.title('Confusion Matrix - Custom CNN')
plt.xlabel('Predicted Label')
plt.ylabel('True Label')
plt.xticks(rotation=45)
plt.yticks(rotation=0)
plt.tight_layout()
plt.show()

# Calculate per-class accuracy
class_accuracy = cm.diagonal() / cm.sum(axis=1)
print("\nPer-class Accuracy:")
for i, class_name in enumerate(class_names):
    print(f"{class_name}: {class_accuracy[i]:.4f}")

## 9. Model Saving and Results Summary

In [None]:
# model accuracy saving and documentation
val_accuracy_str = f"{val_accuracy:.2f}".replace('.', '_')
# Save the final model
model_filename = f'../models/custom_costum_animals10_acc_{val_accuracy_str}.h5'
custom_model.save(model_filename)
print(f"Model saved as: {model_filename}")

# Save training history
import pickle
with open(f'../models/custom_costum_animals10_acc_{val_accuracy_str}_history.pkl', 'wb') as f:
    pickle.dump(history.history, f)
print(f"Training history saved")

# Save model configuration
model_config = {
    'model_name': 'Custom CNN',
    'dataset': 'Animals10',
    'input_shape': input_shape,
    'num_classes': num_classes,
    'batch_size': batch_size,
    'epochs_trained': len(history.history['accuracy']),
    'best_val_accuracy': max(history.history['val_accuracy']),
    'final_val_accuracy': val_accuracy,
    'total_parameters': custom_model.count_params(),
    'architecture': 'VGG16-inspired with Global Average Pooling'
}

import json
with open(f'../models/custom_costum_animals10_acc_{val_accuracy_str}_config.json', 'w') as f:
    json.dump(model_config, f, indent=2)

# Results summary
print("\n" + "="*60)
print("CUSTOM CNN RESULTS SUMMARY")
print("="*60)
print(f"Dataset: Animals10")
print(f"Architecture: Custom CNN with {custom_model.count_params():,} parameters")
print(f"Input shape: {input_shape}")
print(f"Training epochs: {len(history.history['accuracy'])}")
print(f"Best validation accuracy: {max(history.history['val_accuracy']):.4f}")
print(f"Final validation accuracy: {val_accuracy:.4f}")
print(f"Model saved: {model_filename}")
print(f"Architecture follows CNN pattern:")
print(f"  - 5 convolutional blocks (64→128→256→512→512 filters)")
print(f"  - Global Average Pooling instead of FC layers")
print(f"  - Dense(512) + Dense({num_classes}) classifier")
print("="*60)