# 🧠 CIFAR-10 CNN Experiments

This notebook trains a convolutional neural network (CNN) on the CIFAR-10 dataset
using the utility functions from `utils.py`. The goal is to achieve strong predictive
performance with a clean, modular setup that runs well on a MacBook Air M3.


## 1️⃣ Setup and imports

In this section we import the utilities and configure a few global settings:
the random seed, learning rate, batch size and number of epochs.
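`set_global_seed` lives in `utils.py` and its body is not shown in this notebook. A minimal sketch of what such a helper typically does, assuming the project uses NumPy and TensorFlow/Keras, is to seed every random number generator the notebook relies on. The name `set_global_seed_sketch` is hypothetical. Recent TensorFlow versions also provide `tf.keras.utils.set_random_seed`, which seeds Python, NumPy and TensorFlow in a single call.

```python
import random


def set_global_seed_sketch(seed: int) -> None:
    """Hypothetical stand-in for utils.set_global_seed: seed the RNGs in use."""
    # NumPy's legacy global RNG (np.random.choice in this notebook draws from it)
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass

    # TensorFlow's global seed, if TensorFlow is installed
    try:
        import tensorflow as tf
        tf.random.set_seed(seed)
    except ImportError:
        pass

    # Python's built-in RNG, seeded last so that importing the libraries above
    # cannot disturb the sequence that follows
    random.seed(seed)


set_global_seed_sketch(42)
first = [random.random() for _ in range(3)]
set_global_seed_sketch(42)
second = [random.random() for _ in range(3)]
print(first == second)  # True: same seed, same sequence
```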


In [1]:
from src.utils import (
    set_global_seed,
    load_cifar10,
    create_data_augmentation,
    build_cifar10_cnn,
    compile_model,
    train_model,
    evaluate_model,
    predict_classes,
    classification_report_str,
    confusion_matrix_array,
    CLASS_NAMES,
    CLASS_NAMES_EMOJI,
    NUM_CLASSES,
    save_fig,
    save_model_with_history,
    load_model,
    load_history
)

import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Ensure reproducibility
set_global_seed(42)

# High-level training configuration
LEARNING_RATE: float = 1e-3
BATCH_SIZE: int = 64
EPOCHS: int = 30  # adjust down if you want faster experiments


## 2Ô∏è‚É£ Load and inspect CIFAR-10 üì•

Here we load the CIFAR-10 dataset using the helper from `utils.py` and
briefly inspect shapes and class distribution.


In [3]:
# Load CIFAR-10 without normalization (pixel scaling is handled inside the model)
data = load_cifar10(normalize=False)

print("Training images:", data.x_train.shape, data.x_train.dtype)
print("Test images:    ", data.x_test.shape, data.x_test.dtype)

# Basic class distribution in the training set
class_counts = np.bincount(data.y_train, minlength=len(CLASS_NAMES))
for idx, (name, count) in enumerate(zip(CLASS_NAMES, class_counts)):
    print(f"Class {idx:2d} ({name:10s}): {count}")


Training images: (50000, 32, 32, 3) float32
Test images:     (10000, 32, 32, 3) float32
Class  0 (airplane  ): 5000
Class  1 (automobile): 5000
Class  2 (bird      ): 5000
Class  3 (cat       ): 5000
Class  4 (deer      ): 5000
Class  5 (dog       ): 5000
Class  6 (frog      ): 5000
Class  7 (horse     ): 5000
Class  8 (ship      ): 5000
Class  9 (truck     ): 5000


### 📊 Class distribution plot

The bar chart below shows the number of samples per class in the
training split. CIFAR-10 is perfectly balanced (5,000 images per class),
so accuracy on a balanced split is not dominated by any single class and
coincides with the mean per-class recall.
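To make the link between balance, accuracy and recall concrete: overall accuracy is a class-frequency-weighted average of per-class recalls, accuracy = Σ_c (n_c / N) · recall_c. When every class has the same count n_c = N / K, all weights equal 1/K and accuracy equals the macro-averaged recall. The sketch below checks this on a tiny made-up set of labels and predictions (toy data, not CIFAR-10 results).

```python
from collections import Counter


def per_class_recall(y_true: list[int], y_pred: list[int]) -> dict[int, float]:
    """Recall per class: correct predictions / number of true samples of that class."""
    support = Counter(y_true)
    correct = Counter(t for t, p in zip(y_true, y_pred) if t == p)
    return {c: correct[c] / n for c, n in support.items()}


def accuracy(y_true: list[int], y_pred: list[int]) -> float:
    """Fraction of all samples predicted correctly."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


# Toy, perfectly balanced labels: 3 classes x 4 samples each (illustrative only)
y_true = [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]
y_pred = [0, 0, 0, 1, 1, 1, 2, 2, 2, 2, 2, 0]

recalls = per_class_recall(y_true, y_pred)
macro_recall = sum(recalls.values()) / len(recalls)
print(recalls)  # {0: 0.75, 1: 0.5, 2: 0.75}
print(accuracy(y_true, y_pred), macro_recall)  # both equal 2/3
```

With an imbalanced split the weights n_c / N differ, and a model can score high accuracy while having poor recall on rare classes; the balanced CIFAR-10 splits avoid that gap.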


In [6]:
fig_class_dist = px.bar(
    x=CLASS_NAMES_EMOJI,
    y=class_counts,
    title="CIFAR-10 training set class distribution",
    labels={"x": "Class", "y": "Count"},
)
fig_class_dist.update_layout(xaxis_tickangle=0)
fig_class_dist.update_xaxes(tickfont=dict(size=28))
fig_class_dist.show()


save_fig(fig_class_dist, "class_distribution")

Saved HTML to ../docs/class_distribution.html
Saved PNG to ../plots/class_distribution.png


### 🖼️ Example images per class

Below we display several randomly chosen training images for each CIFAR-10 class
(one row per class) using Plotly. This builds an intuitive understanding of what
the model sees during training and how much the images within each class vary.

In [15]:
# Number of example images to display per class
EXAMPLES_PER_CLASS: int = 10  # you can increase this if you want

rows = NUM_CLASSES
cols = EXAMPLES_PER_CLASS

fig = make_subplots(
    rows=rows,
    cols=cols,
    horizontal_spacing=0.01,
    vertical_spacing=0.01,
)

for class_idx, class_name in enumerate(CLASS_NAMES_EMOJI):
    # Find indices of all images belonging to this class
    class_indices = np.where(data.y_train == class_idx)[0]

    if len(class_indices) == 0:
        # Guard against classes with no samples (e.g. if the training set was subsampled)
        continue

    # Randomly select up to EXAMPLES_PER_CLASS images
    n_examples = min(EXAMPLES_PER_CLASS, len(class_indices))
    selected_indices = np.random.choice(
        class_indices,
        size=n_examples,
        replace=False,
    )

    for col_idx, img_idx in enumerate(selected_indices):
        row = class_idx + 1
        col = col_idx + 1

        fig.add_trace(
            go.Image(z=data.x_train[img_idx]),
            row=row,
            col=col,
        )

        # Hide axis ticks for a cleaner look
        fig.update_xaxes(showticklabels=False, row=row, col=col)
        fig.update_yaxes(showticklabels=False, row=row, col=col)

    # Add the class name as a y-axis title for the first column of the row
    fig.update_yaxes(
        title=dict(
            text=class_name,
            font=dict(size=30),  # üëà gr√∂√üerer Titel
            # standoff=10,       # optional: etwas Abstand zur Achse
        ),
        row=class_idx + 1,
        col=1,
    )

fig.update_layout(
    title="Example CIFAR-10 images per class",
    height=150 * rows,
    width=150 * cols,
    showlegend=False,
)
fig.show()

save_fig(fig, "examples_per_class")

Saved HTML to ../docs/examples_per_class.html
Saved PNG to ../plots/examples_per_class.png


## 3Ô∏è‚É£ Build and compile the CNN üß±

We now create a reasonably strong CNN architecture with:

- Data augmentation (random flips, rotations, zoom)
- Convolutional blocks with Batch Normalization and ReLU
- Max pooling and Dropout for downsampling and regularization
- A dense classification head with Dropout before the softmax layer

Compilation uses the Adam optimizer with the learning rate set by `LEARNING_RATE` (1e-3).
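The exact layer sizes live in `build_cifar10_cnn` and are printed by `model.summary()`. As an aid for reading that summary, the sketch below reproduces the standard Keras parameter counts for the layer types listed above: a `Conv2D` with a k×k kernel has k·k·C_in·C_out weights plus C_out biases, `BatchNormalization` stores four values per channel (γ and β are trainable, the moving mean and variance are not), and a 2×2 max pool with stride 2 halves each spatial dimension. The three-block layout and the filter counts (32, 64, 128) are illustrative assumptions, not necessarily what `utils.py` builds.

```python
def conv2d_params(kernel: int, c_in: int, c_out: int, use_bias: bool = True) -> int:
    """Weights (kernel x kernel x c_in x c_out) plus one bias per output channel."""
    return kernel * kernel * c_in * c_out + (c_out if use_bias else 0)


def batchnorm_params(channels: int) -> tuple[int, int]:
    """(trainable, non-trainable): gamma/beta are trained, moving mean/variance are not."""
    return 2 * channels, 2 * channels


def pooled_size(size: int, pool: int = 2) -> int:
    """Spatial size after MaxPooling2D with strides == pool_size and 'valid' padding."""
    return size // pool


# Hypothetical 3-block layout on a 32x32x3 input, with 'same'-padded 3x3 convs
# (so only the pooling layers shrink the feature maps).
size, c_in = 32, 3
for c_out in (32, 64, 128):
    conv = conv2d_params(3, c_in, c_out)
    bn_train, bn_fixed = batchnorm_params(c_out)
    size = pooled_size(size)
    print(f"conv {c_in}->{c_out}: {conv} params, BN {bn_train}+{bn_fixed}, map {size}x{size}")
    c_in = c_out
```

For this layout the loop reports 896, 18,496 and 73,856 convolution parameters and feature maps shrinking from 32×32 to 16×16, 8×8 and 4×4.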


In [5]:
# Create data augmentation pipeline
data_augmentation = create_data_augmentation()

# Build model
model = build_cifar10_cnn(
    input_shape=data.x_train.shape[1:],
    num_classes=len(CLASS_NAMES),
    data_augmentation=data_augmentation,
)

# Compile model
compile_model(model, learning_rate=LEARNING_RATE)

model.summary()


## 4Ô∏è‚É£ Train the model üöÇ

We train the model on the CIFAR-10 training set and hold out 10% of it
(`validation_split=0.1`) to monitor generalization. The default callbacks in
`train_model` apply learning-rate scheduling and early stopping, which helps reach
strong performance without excessive overfitting and can end training before all
`EPOCHS` epochs have run.
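Two details of this setup can be checked by hand. Assuming `train_model` forwards `validation_split` to Keras `Model.fit`, 10% of the 50,000 training images (Keras takes this slice from the end of the arrays, before shuffling) are held out, leaving 45,000 for training; with a batch size of 64 that gives ⌈45,000 / 64⌉ = 704 steps per epoch, which matches the `704/704` in the training log. The second function is a simplified, hypothetical version of patience-based early stopping on validation loss; the actual callbacks and their `patience` values in `utils.py` are not shown here.

```python
import math
from typing import Optional


def steps_per_epoch(n_samples: int, batch_size: int, validation_split: float) -> int:
    """Batches per epoch after holding out a validation fraction (last batch may be partial)."""
    n_val = int(n_samples * validation_split)
    n_train = n_samples - n_val
    return math.ceil(n_train / batch_size)


def early_stop_epoch(val_losses: list[float], patience: int) -> Optional[int]:
    """Return the 1-based epoch at which training would stop, or None if it never does.

    Simplified sketch: stop once val_loss has not improved on its best value
    for `patience` consecutive epochs.
    """
    best = math.inf
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best:
            best = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch
    return None


print(steps_per_epoch(50_000, 64, 0.1))  # 704

# Made-up validation losses, only to illustrate the patience logic
toy_losses = [2.0, 1.6, 1.4, 1.45, 1.5, 1.42, 1.47]
print(early_stop_epoch(toy_losses, patience=3))  # 6
```

Here the best loss (1.4) is reached at epoch 3, and three epochs without improvement trigger the stop at epoch 6.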


In [7]:
history = train_model(
    model,
    data.x_train,
    data.y_train,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    validation_split=0.1,
)

Epoch 1/30
704/704 ━━━━━━━━━━━━━━━━━━━━ 87s 123ms/step - accuracy: 0.4918 - loss: 1.3944 - val_accuracy: 0.3986 - val_loss: 2.1784 - learning_rate: 0.0010
Epoch 2/30
704/704 ━━━━━━━━━━━━━━━━━━━━ 127s 180ms/step - accuracy: 0.5549 - loss: 1.2374 - val_accuracy: 0.4606 - val_loss: 1.6827 - learning_rate: 0.0010
Epoch 3/30
704/704 ━━━━━━━━━━━━━━━━━━━━ 103s 146ms/step - accuracy: 0.5950 - loss: 1.1360 - val_accuracy: 0.5238 - val_loss: 1.4734 - learning_rate: 0.0010
Epoch 4/30
704/704 ━━━━━━━━━━━━━━━━━━━━ 108s 154ms/step - accuracy: 0.6244 - loss: 1.0705 - val_accuracy: 0.5104 - val_loss: 1.6147 - learning_rate: 0.0010
Epoch 5/30
704/704 ━━━━━━━━━━━━━━━━

In [6]:
# save model & history
save_model_with_history(model, history, "cifar10_main")

NameError: name 'model' is not defined