# 👁️‍🗨️ Computer Vision with CIFAR-10 Dataset

## ⚙️ Setup and imports

In this section, we import all required libraries, enable Plotly's dark theme, and define some global constants such as the input shape and class names for CIFAR-10.

In [19]:
# `from __future__` imports must be the first statement in the cell
from __future__ import annotations

# Standard library
import os
from pathlib import Path
from typing import Final, Optional, Tuple

# Numerics and plotting
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
import plotly.io as pio
from plotly.graph_objects import Figure
from plotly.subplots import make_subplots

# Modelling and evaluation
from tensorflow import keras
from tensorflow.keras import layers, models

from sklearn.metrics import confusion_matrix, classification_report

# Use Plotly dark theme globally
pio.templates.default = "plotly_dark"

# Seed NumPy's global RNG for more reproducible sampling (this does not seed TensorFlow)
np.random.seed(42)

# CIFAR-10 meta info
NUM_CLASSES: int = 10
INPUT_SHAPE: Tuple[int, int, int] = (32, 32, 3)

CLASS_NAMES = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
]
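
`np.random.seed` only affects NumPy's global generator. A minimal sketch of a broader seeding helper follows. `seed_everything` is a hypothetical name that is not part of this notebook, and the Keras call assumes a TensorFlow release recent enough to provide `keras.utils.set_random_seed`; if it is missing, the helper falls back to seeding Python and NumPy only.

```python
import random

import numpy as np


def seed_everything(seed: int = 42) -> None:
    """Seed Python's `random`, NumPy and, if available, Keras/TensorFlow."""
    random.seed(seed)
    np.random.seed(seed)
    try:
        # Assumption: TensorFlow is installed and ships this helper.
        from tensorflow import keras

        keras.utils.set_random_seed(seed)
    except (ImportError, AttributeError):
        # TensorFlow missing or too old: Python and NumPy are still seeded.
        pass


# Re-seeding reproduces the same NumPy draws
seed_everything(42)
first = np.random.rand(3)
seed_everything(42)
second = np.random.rand(3)
```

Even with all seeds fixed, GPU training may still show small run-to-run differences, so this should be read as "more reproducible" rather than fully deterministic.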

In [24]:
# Base directories for exported figures
PLOTS_DIR: Final[Path] = Path("../plots")
DOCS_DIR: Final[Path] = Path("../docs")


def save_fig(fig: Figure, name: str, scale: int = 2) -> None:
    """
    Save a Plotly figure as both HTML and PNG.

    Parameters
    ----------
    fig : Figure
        The Plotly figure to be saved.
    name : str
        Base file name without extension.
    scale : int, default 2
        Scale factor for the PNG export (higher = higher resolution).
    """
    # Ensure output directories exist
    PLOTS_DIR.mkdir(parents=True, exist_ok=True)
    DOCS_DIR.mkdir(parents=True, exist_ok=True)

    html_path = DOCS_DIR / f"{name}.html"
    png_path = PLOTS_DIR / f"{name}.png"

    # Save interactive HTML file
    fig.write_html(str(html_path), include_plotlyjs="cdn")

    # Save static PNG image (requires kaleido)
    fig.write_image(str(png_path), scale=scale)

    print(f"Saved HTML to {html_path}")
    print(f"Saved PNG to {png_path}")
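
The PNG step depends on the optional kaleido package, so a missing install stops the cell after the HTML file has already been written. A hedged sketch of a more forgiving variant is shown here. `save_fig_safe` and `_FakeFigure` are hypothetical names used for illustration; the helper accepts any object with `write_html` and `write_image` methods, and it catches a broad `Exception` on the assumption that the exact error raised for a missing export backend can differ between Plotly versions.

```python
import tempfile
from pathlib import Path


def save_fig_safe(fig, name: str, html_dir: Path, png_dir: Path, scale: int = 2) -> bool:
    """Always write the HTML file; try the PNG and report whether it succeeded."""
    html_dir.mkdir(parents=True, exist_ok=True)
    png_dir.mkdir(parents=True, exist_ok=True)

    fig.write_html(str(html_dir / f"{name}.html"), include_plotlyjs="cdn")
    try:
        fig.write_image(str(png_dir / f"{name}.png"), scale=scale)
    except Exception as exc:  # broad on purpose: backend errors vary
        print(f"PNG export skipped for {name!r}: {exc}")
        return False
    return True


class _FakeFigure:
    """Stand-in that mimics a figure without a working image backend."""

    def write_html(self, path, include_plotlyjs="cdn"):
        Path(path).write_text("<html></html>")

    def write_image(self, path, scale=2):
        raise ValueError("no image export backend installed")


# Demonstrate the fallback in a temporary directory
with tempfile.TemporaryDirectory() as tmp:
    base = Path(tmp)
    png_ok = save_fig_safe(_FakeFigure(), "demo", base / "docs", base / "plots")
    html_written = (base / "docs" / "demo.html").exists()
```

With a real Plotly figure the call would look the same; the return value tells the caller whether a PNG exists.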

## 📥 Loading the CIFAR-10 dataset

In this section, we load the CIFAR-10 dataset using Keras, optionally limit the number of training samples for faster experimentation,
and flatten the labels from shape `(N, 1)` to `(N,)`.
We also print the array shapes and the mapping from class indices to human-readable class names.

In [25]:
# Optional: limit the number of training samples for faster experiments
TRAIN_LIMIT: Optional[int] = None  # e.g. 10_000 or None to use all samples

# Load CIFAR-10 using Keras
(train_images_raw, train_labels_raw), (test_images_raw, test_labels_raw) = keras.datasets.cifar10.load_data()

# Optionally reduce training set size
if TRAIN_LIMIT is not None:
    train_images_raw = train_images_raw[:TRAIN_LIMIT]
    train_labels_raw = train_labels_raw[:TRAIN_LIMIT]

# Flatten label arrays to shape (N,)
train_labels = train_labels_raw.reshape(-1)
test_labels = test_labels_raw.reshape(-1)

# For now the working arrays simply alias the raw images (uint8, 0–255).
# Later we will create preprocessed versions (e.g. normalized float32).
train_images = train_images_raw
test_images = test_images_raw

print(f"Train images: {train_images.shape}, Train labels: {train_labels.shape}")
print(f"Test images:  {test_images.shape}, Test labels:  {test_labels.shape}")

print("\nClass index → name mapping:")
for idx, name in enumerate(CLASS_NAMES):
    print(f"  {idx}: {name}")

Train images: (50000, 32, 32, 3), Train labels: (50000,)
Test images:  (10000, 32, 32, 3), Test labels:  (10000,)

Class index → name mapping:
  0: airplane
  1: automobile
  2: bird
  3: cat
  4: deer
  5: dog
  6: frog
  7: horse
  8: ship
  9: truck
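
Slicing with `[:TRAIN_LIMIT]` takes the first samples in file order, so class counts in a small subset are not guaranteed to be equal. One possible alternative is a stratified subsample that draws the same number of images from every class. The sketch uses synthetic labels, and `stratified_indices` is a hypothetical helper that is not part of the notebook.

```python
import numpy as np


def stratified_indices(labels: np.ndarray, per_class: int, seed: int = 42) -> np.ndarray:
    """Return sorted indices with at most `per_class` samples of each label."""
    rng = np.random.default_rng(seed)
    picked = []
    for cls in np.unique(labels):
        cls_idx = np.flatnonzero(labels == cls)
        take = min(per_class, cls_idx.size)
        picked.append(rng.choice(cls_idx, size=take, replace=False))
    return np.sort(np.concatenate(picked))


# Synthetic stand-in for CIFAR-10 labels: 10 classes, 100 samples each, shuffled
demo_labels = np.random.default_rng(0).permutation(np.repeat(np.arange(10), 100))

subset = stratified_indices(demo_labels, per_class=20)
subset_counts = np.bincount(demo_labels[subset], minlength=10)
```

In the notebook, the same index array would be applied to both `train_images_raw` and the flattened labels so that images and labels stay aligned.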


## 📊 Class distribution in the training set

In this section, we plot the number of training images per class as a Plotly bar chart (dark theme).
The full CIFAR-10 training set contains 5,000 images per class, so all bars should have equal height; with `TRAIN_LIMIT` set, the counts may no longer be exactly equal.

In [26]:
# Compute class counts for the training labels
unique_labels, label_counts = np.unique(train_labels, return_counts=True)

class_names_for_plot = [CLASS_NAMES[int(idx)] for idx in unique_labels]

fig_class_dist = px.bar(
    x=class_names_for_plot,
    y=label_counts,
    labels={"x": "Class", "y": "Count"},
    title="CIFAR-10 training set class distribution",
)
fig_class_dist.update_layout(xaxis_tickangle=-45)
fig_class_dist.show()


save_fig(fig_class_dist, "class_distribution")

Saved HTML to ../docs/class_distribution.html
Saved PNG to ../plots/class_distribution.png
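
`np.unique` only reports labels that actually occur, so a class that is absent from a truncated training set would silently disappear from the bar chart. A small sketch of a numeric check that keeps empty classes visible uses `np.bincount` with `minlength`. The labels here are synthetic, and `class_counts` is a hypothetical helper.

```python
import numpy as np


def class_counts(labels: np.ndarray, num_classes: int = 10) -> np.ndarray:
    """Count samples per class index, including classes with zero samples."""
    return np.bincount(labels, minlength=num_classes)


# Tiny synthetic example: classes 2 and 4 never occur
demo_labels = np.array([0, 0, 1, 3, 3, 3])
counts = class_counts(demo_labels, num_classes=5)

# A dataset is balanced when every class has the same count
is_balanced = bool(counts.min() == counts.max())
```

Passing such a fixed-length count array to the bar chart, together with the full `CLASS_NAMES` list, would show missing classes as zero-height bars instead of dropping them.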


## 🖼️ Example images per class

In this section, we visualize multiple example images for each CIFAR-10 class using Plotly.
This helps to build an intuitive understanding of what the model will see during training
and how the different classes look in practice.

In [23]:
# Number of example images to display per class
EXAMPLES_PER_CLASS: int = 6  # you can increase this if you want

rows = NUM_CLASSES
cols = EXAMPLES_PER_CLASS

fig = make_subplots(
    rows=rows,
    cols=cols,
    horizontal_spacing=0.01,
    vertical_spacing=0.01,
)

for class_idx, class_name in enumerate(CLASS_NAMES):
    # Find indices of all images belonging to this class
    class_indices = np.where(train_labels == class_idx)[0]

    if len(class_indices) == 0:
        # This can happen if TRAIN_LIMIT is very small and some classes are missing
        continue

    # Randomly select up to EXAMPLES_PER_CLASS images
    n_examples = min(EXAMPLES_PER_CLASS, len(class_indices))
    selected_indices = np.random.choice(
        class_indices,
        size=n_examples,
        replace=False,
    )

    for col_idx, img_idx in enumerate(selected_indices):
        row = class_idx + 1
        col = col_idx + 1

        fig.add_trace(
            go.Image(z=train_images[img_idx]),
            row=row,
            col=col,
        )

        # Hide axis ticks for a cleaner look
        fig.update_xaxes(showticklabels=False, row=row, col=col)
        fig.update_yaxes(showticklabels=False, row=row, col=col)

    # Add the class name as a y-axis title for the first column of the row
    fig.update_yaxes(title_text=class_name, row=class_idx + 1, col=1)

fig.update_layout(
    title="Example CIFAR-10 images per class",
    height=200 * rows,
    width=200 * cols,
    showlegend=False,
)
fig.show()

save_fig(fig, "examples_per_class")

Saved HTML to docs/examples_per_class.html
Saved PNG to plots/examples_per_class.png
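
Adding one `go.Image` trace per picture creates `NUM_CLASSES * EXAMPLES_PER_CLASS` subplots, which can make the figure slow to build and render. One alternative is to tile the images into a single array and display it with one trace. This is a sketch under assumptions: `make_mosaic` is a hypothetical helper, the images are random stand-ins, and classes with too few examples are left as black tiles.

```python
import numpy as np


def make_mosaic(
    images: np.ndarray,
    labels: np.ndarray,
    num_classes: int,
    per_class: int,
    seed: int = 42,
) -> np.ndarray:
    """Tile up to `per_class` random images of each class into one (H, W, C) array."""
    rng = np.random.default_rng(seed)
    h, w, c = images.shape[1:]
    mosaic = np.zeros((num_classes * h, per_class * w, c), dtype=images.dtype)

    for cls in range(num_classes):
        cls_idx = np.flatnonzero(labels == cls)
        n = min(per_class, cls_idx.size)
        if n == 0:
            continue  # class missing: leave its row black
        chosen = rng.choice(cls_idx, size=n, replace=False)
        for col, img_idx in enumerate(chosen):
            # One row per class, one column per example
            mosaic[cls * h:(cls + 1) * h, col * w:(col + 1) * w] = images[img_idx]
    return mosaic


# Random stand-in data with the CIFAR-10 image shape (32, 32, 3)
rng = np.random.default_rng(0)
demo_images = rng.integers(0, 256, size=(100, 32, 32, 3), dtype=np.uint8)
demo_labels = np.arange(100) % 10

mosaic = make_mosaic(demo_images, demo_labels, num_classes=10, per_class=6)
```

The resulting array could then be shown with a single `px.imshow(mosaic)` call. The trade-off is that per-image hover information and per-row axis titles are lost, so class names would need to be added as annotations.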
