## üîß Imports & setup

This section prepares all required libraries, plotting configuration and helper utilities
used throughout the CIFAR-10 evaluation notebook.# Imports & Setup

In [1]:
from __future__ import annotations

from typing import Any

import numpy as np
import pandas as pd
from sklearn.metrics import recall_score
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.io as pio

from src.utils import (
    load_cifar10,
    CLASS_NAMES,
    CLASS_NAMES_EMOJI,
    predict_classes,
    evaluate_model,
    confusion_matrix_array,
    load_model,
    load_history,
    classification_report_str,
    save_fig
)

# Use a dark theme for all Plotly figures
pio.templates.default = "plotly_dark"

## üì¶ Load model, history & CIFAR-10 test data

Here we load the trained CNN model, its saved training history and the CIFAR-10
test split that we will analyse in detail.# Model, History & Test Data

In [143]:
# Load trained model
model = load_model("cifar10_main")

history = load_history("cifar10_main")

data = load_cifar10(normalize=False)

print("Test images:", data.x_test.shape, data.x_test.dtype)
print("Test labels:", data.y_test.shape, data.y_test.dtype)

Test images: (10000, 32, 32, 3) float32
Test labels: (10000,) uint8


## üìâ Visualise training curves

The following plots show training and validation accuracy and loss over the epochs.
They help to understand convergence speed and to detect potential overfitting or underfitting.

In [146]:
test_metrics = evaluate_model(model, data.x_test, data.y_test)
print(f"Test loss:     {test_metrics['loss']:.4f}")
print(f"Test accuracy: {test_metrics['accuracy']:.4f}")

history_dict = history.history

epochs_range = list(range(1, len(history_dict["loss"]) + 1))

fig_acc = go.Figure()
fig_acc.add_trace(
    go.Scatter(
        x=epochs_range,
        y=history_dict["accuracy"],
        mode="lines+markers",
        name="Train accuracy",
    )
)
fig_acc.add_trace(
    go.Scatter(
        x=epochs_range,
        y=history_dict["val_accuracy"],
        mode="lines+markers",
        name="Val accuracy",
    )
)
fig_acc.update_layout(
    title="Training and validation accuracy",
    xaxis_title="Epoch",
    yaxis_title="Accuracy",
)
fig_acc.show()

fig_loss = go.Figure()
fig_loss.add_trace(
    go.Scatter(
        x=epochs_range,
        y=history_dict["loss"],
        mode="lines+markers",
        name="Train loss",
    )
)
fig_loss.add_trace(
    go.Scatter(
        x=epochs_range,
        y=history_dict["val_loss"],
        mode="lines+markers",
        name="Val loss",
    )
)
fig_loss.update_layout(
    title="Training and validation loss",
    xaxis_title="Epoch",
    yaxis_title="Loss",
)
fig_loss.show()

save_fig(fig_acc, "cifar10_acc")
save_fig(fig_loss, "cifar10_loss")

Test loss:     0.5797
Test accuracy: 0.8093


Saved HTML to ../docs/cifar10_acc.html
Saved PNG to ../plots/cifar10_acc.png
Saved HTML to ../docs/cifar10_loss.html
Saved PNG to ../plots/cifar10_loss.png


## ‚úÖ Evaluate on the test set

We evaluate the trained model on the held-out CIFAR-10 test set and inspect the
overall loss and accuracy to see how well the model generalises.

In [145]:
# Evaluate on test data
metrics = evaluate_model(model, data.x_test, data.y_test)
print(f"Test loss:     {metrics['loss']:.4f}")
print(f"Test accuracy: {metrics['accuracy']:.4f}")

# Predict test labels
y_test_pred = predict_classes(model, data.x_test)

# Classification report (precision / recall / F1 per class)
report_str = classification_report_str(data.y_test, y_test_pred, target_names=CLASS_NAMES_EMOJI)
print(report_str)

Test loss:     0.5797
Test accuracy: 0.8093
              precision    recall  f1-score   support

          ‚úàÔ∏è     0.8536    0.7870    0.8189      1000
           üöó     0.8901    0.9310    0.9101      1000
           ü¶ú     0.8093    0.7300    0.7676      1000
           üê±     0.8171    0.5270    0.6407      1000
           ü´é     0.8231    0.7400    0.7794      1000
           üê∂     0.8368    0.6820    0.7515      1000
           üê∏     0.6948    0.9400    0.7990      1000
           üê¥     0.7811    0.9030    0.8377      1000
          üõ≥Ô∏è     0.8658    0.9100    0.8874      1000
           üöõ     0.7787    0.9430    0.8530      1000

    accuracy                         0.8093     10000
   macro avg     0.8150    0.8093    0.8045     10000
weighted avg     0.8150    0.8093    0.8045     10000



## üî¢ Confusion matrix

The normalised confusion matrix reveals which classes the model confuses with each other.
Bright values on the diagonal are correct predictions, off-diagonal values indicate systematic errors.# Confusion Matrix

In [121]:
y_test_pred = predict_classes(model, data.x_test)

cm_norm = confusion_matrix_array(data.y_test, y_test_pred, normalize=True)

fig_cm = px.imshow(
    cm_norm,
    x=CLASS_NAMES_EMOJI,
    y=CLASS_NAMES_EMOJI,
    color_continuous_scale="Viridis",
    labels={"x": "Predicted class", "y": "True class", "color": "Proportion"},
    title="Normalized confusion matrix (CIFAR-10)",
)
fig_cm.update_xaxes(tickfont=dict(size=18))
fig_cm.update_yaxes(tickfont=dict(size=18))
fig_cm.update_xaxes(side="top")
fig_cm.show()

save_fig(fig_cm,"cifar10_confusion_matrix")

Saved HTML to ../docs/cifar10_confusion_matrix.html
Saved PNG to ../plots/cifar10_confusion_matrix.png


## üß† Build prediction analysis DataFrame

In this step we collect model outputs (probabilities, predicted labels and correctness flags)
into a single `analysis_df` DataFrame that powers all subsequent analyses and plots.## Analyse

In [9]:
probs = model.predict(data.x_test, batch_size=128, verbose=1)

y_true = data.y_test
y_pred = probs.argmax(axis=1)
prob_pred = probs.max(axis=1)
prob_true = probs[np.arange(len(y_true)), y_true]

analysis_df = pd.DataFrame(
    {
        "idx": np.arange(len(y_true)),
        "y_true": y_true,
        "y_pred": y_pred,
        "prob_pred": prob_pred,
        "prob_true": prob_true,
    }
)
analysis_df["correct"] = analysis_df["y_true"] == analysis_df["y_pred"]
analysis_df["true_name"] = analysis_df["y_true"].map(lambda i: CLASS_NAMES_EMOJI[int(i)])
analysis_df["pred_name"] = analysis_df["y_pred"].map(lambda i: CLASS_NAMES_EMOJI[int(i)])

analysis_df.head()

[1m79/79[0m [32m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m[37m[0m [1m3s[0m 36ms/step


Unnamed: 0,idx,y_true,y_pred,prob_pred,prob_true,correct,true_name,pred_name
0,0,3,3,0.840778,0.840778,True,üê±,üê±
1,1,8,8,0.760333,0.760333,True,üõ≥Ô∏è,üõ≥Ô∏è
2,2,8,8,0.91052,0.91052,True,üõ≥Ô∏è,üõ≥Ô∏è
3,3,0,8,0.652681,0.156621,False,‚úàÔ∏è,üõ≥Ô∏è
4,4,6,6,0.998789,0.998789,True,üê∏,üê∏


## üìä Per-class accuracy & confidence

Here we measure how well the model performs for each class individually and how confident
it is about the true class on average. This highlights particularly strong and weak classes.## Per-Class Accuracy & Confidence

In [10]:
# Per-class Accuracy (%)
per_class_acc = (
    analysis_df
    .groupby("true_name")["correct"]
    .mean()
    .mul(100.0)
    .reset_index(name="accuracy_percent")
    .sort_values("accuracy_percent", ascending=False)
)
per_class_acc

Unnamed: 0,true_name,accuracy_percent
6,üöõ,94.3
4,üê∏,94.0
5,üöó,93.1
7,üõ≥Ô∏è,91.0
2,üê¥,90.3
0,‚úàÔ∏è,78.7
9,ü´é,74.0
8,ü¶ú,73.0
3,üê∂,68.2
1,üê±,52.7


In [118]:
fig_acc_per_class = px.bar(
    per_class_acc,
    x="true_name",
    y="accuracy_percent",
    title="Per-class accuracy on CIFAR-10",
    labels={"true_name": "True class", "accuracy_percent": "Accuracy (%)"},
)
fig_acc_per_class.update_xaxes(tickfont=dict(size=28))
fig_acc_per_class.update_layout()
fig_acc_per_class.show()

save_fig(fig_acc_per_class,"cifar10_per_class_accuracy")

Saved HTML to ../docs/cifar10_per_class_accuracy.html
Saved PNG to ../plots/cifar10_per_class_accuracy.png


In [14]:
true_conf_stats = (
    analysis_df
    .groupby("true_name")["prob_true"]
    .agg(["mean", "median"])
    .reset_index()
    .sort_values("mean", ascending=False)
)
true_conf_stats

Unnamed: 0,true_name,mean,median
6,üöõ,0.914664,0.998825
5,üöó,0.914105,0.999112
4,üê∏,0.898902,0.994708
7,üõ≥Ô∏è,0.879917,0.996983
2,üê¥,0.85622,0.99136
0,‚úàÔ∏è,0.742276,0.944883
8,ü¶ú,0.670746,0.830354
9,ü´é,0.669897,0.807679
3,üê∂,0.601945,0.711592
1,üê±,0.442382,0.386253


In [128]:
fig_true_conf = px.bar(
    true_conf_stats,
    x="true_name",
    y="mean",
    title="Average confidence for the true class (all predictions)",
    labels={"true_name": "True class", "mean": "Avg probability for true class"},
)
fig_true_conf.update_xaxes(tickfont=dict(size=28))
fig_true_conf.update_layout()
fig_true_conf.show()

save_fig(fig_true_conf,"cifar10_avg_confidence_per_true_class")

Saved HTML to ../docs/cifar10_avg_confidence_per_true_class.html
Saved PNG to ../plots/cifar10_avg_confidence_per_true_class.png


## üèÜ Top 1% most confident correct predictions

We visualise the most confidently correct predictions to see which images the model
is absolutely sure about and how these "easy" examples look.# Top 1 % right

In [127]:
TOP_FRACTION = 0.01

correct_df = analysis_df[analysis_df["correct"]].copy()
n_top_correct = max(1, int(len(correct_df) * TOP_FRACTION))

top_confident_correct = (
    correct_df
    .sort_values("prob_pred", ascending=False)
    .head(n_top_correct)
    .reset_index(drop=True)
)

print(
    f"Top {TOP_FRACTION*100:.1f}% most confident CORRECT predictions: "
    f"{len(top_confident_correct)} samples"
)
top_confident_correct[["idx", "true_name", "pred_name", "prob_pred", "prob_true"]].head()

Top 1.0% most confident CORRECT predictions: 80 samples


Unnamed: 0,idx,true_name,pred_name,prob_pred,prob_true
0,4128,üöó,üöó,1.0,1.0
1,1378,üöó,üöó,1.0,1.0
2,3394,üöó,üöó,1.0,1.0
3,5596,üöõ,üöõ,1.0,1.0
4,7483,üöó,üöó,1.0,1.0


In [129]:
import math

N_SHOW = min(25, len(top_confident_correct))
rows = math.ceil(N_SHOW / 5)
cols = 5

fig_top = make_subplots(
    rows=rows,
    cols=cols,
    subplot_titles=[
        f"true={row.true_name}<br> pred={row.pred_name} {row.prob_pred*100:.1f}%"
        for _, row in top_confident_correct.iloc[:N_SHOW].iterrows()
    ],
)

for i, (_, row) in enumerate(top_confident_correct.iloc[:N_SHOW].iterrows()):
    r = i // cols + 1
    c = i % cols + 1
    img = data.x_test[row.idx]

    fig_top.add_trace(
        go.Image(z=img.astype("uint8")),
        row=r,
        col=c,
    )

for ann in fig_top.layout.annotations:
    ann.font.size = 18

fig_top.update_layout(
    height=300 * rows,
    width=300 * cols,
    title=dict(
        text="Top 1% most confident CORRECT predictions",
        font=dict(size=18),
        y=0.98,
        yanchor="top",
    ),
)

fig_top.for_each_xaxis(
    lambda ax: ax.update(
        showticklabels=False,
        showgrid=False,
        zeroline=False,
    )
)
fig_top.for_each_yaxis(
    lambda ax: ax.update(
        showticklabels=False,
        showgrid=False,
        zeroline=False,
    )
)

fig_top.show()

save_fig(fig_top,"cifar10_top_1_percent_correct_predictions")

Saved HTML to ../docs/cifar10_top_1_percent_correct_predictions.html
Saved PNG to ../plots/cifar10_top_1_percent_correct_predictions.png


## ‚ö†Ô∏è Top 1% most confident wrong predictions

Next we inspect the most confidently wrong predictions. These are especially interesting,
because the model is very sure but still incorrect ‚Äì useful for understanding failure modes.

In [26]:
wrong_df = analysis_df[~analysis_df["correct"]].copy()
n_top_wrong = max(1, int(len(wrong_df) * TOP_FRACTION))

top_confident_wrong = (
    wrong_df
    .sort_values("prob_pred", ascending=False)
    .head(n_top_wrong)
    .reset_index(drop=True)
)

print(
    f"Top {TOP_FRACTION*100:.1f}% most confident WRONG predictions: "
    f"{len(top_confident_wrong)} samples"
)
top_confident_wrong[["idx", "true_name", "pred_name", "prob_pred", "prob_true"]].head()

Top 1.0% most confident WRONG predictions: 19 samples


Unnamed: 0,idx,true_name,pred_name,prob_pred,prob_true
0,9227,üöó,üöõ,0.99995,4.965745e-05
1,3645,üöó,üöõ,0.999945,5.380484e-05
2,4784,‚úàÔ∏è,üõ≥Ô∏è,0.999906,6.566145e-05
3,2532,‚úàÔ∏è,üöó,0.999893,6.713568e-07
4,5511,üê±,üê∏,0.999859,9.72245e-05


In [130]:
N_SHOW_WRONG = min(25, len(top_confident_wrong))
rows = math.ceil(N_SHOW_WRONG / 5)
cols = 5

fig_wrong = make_subplots(
    rows=rows,
    cols=cols,
    subplot_titles=[
        f"true={row.true_name}<br>pred={row.pred_name} {row.prob_pred*100:.1f}%"
        for _, row in top_confident_wrong.iloc[:N_SHOW_WRONG].iterrows()
    ],
)

for i, (_, row) in enumerate(top_confident_wrong.iloc[:N_SHOW_WRONG].iterrows()):
    r = i // cols + 1
    c = i % cols + 1
    img = data.x_test[row.idx]

    fig_wrong.add_trace(
        go.Image(z=img.astype("uint8")),
        row=r,
        col=c,
    )

for ann in fig_wrong.layout.annotations:
    ann.font.size = 18

fig_wrong.update_layout(
    height=300 * rows,
    width=300 * cols,
    title=dict(
        text="Top 1% most confident WRONG predictions",
        font=dict(size=18),
        y=0.98,
        yanchor="top",
    ),
)
fig_wrong.show()

save_fig(fig_wrong,"cifar10_top_1_percent_wrong_predictions")

Saved HTML to ../docs/cifar10_top_1_percent_wrong_predictions.html
Saved PNG to ../plots/cifar10_top_1_percent_wrong_predictions.png


## üéØ Confidence distribution: correct vs wrong

This plot compares the model‚Äôs confidence for the predicted class between **correct**
and **wrong** predictions. Ideally, correct predictions should concentrate at high
probabilities, while wrong predictions should appear mostly at lower confidence levels.
Large overlap indicates overconfident errors.

In [131]:
fig_conf_hist = go.Figure()

fig_conf_hist.add_trace(
    go.Histogram(
        x=analysis_df.loc[analysis_df["correct"], "prob_pred"],
        nbinsx=20,
        name="Correct",
        opacity=0.6,
    )
)

fig_conf_hist.add_trace(
    go.Histogram(
        x=analysis_df.loc[~analysis_df["correct"], "prob_pred"],
        nbinsx=20,
        name="Wrong",
        opacity=0.6,
    )
)

fig_conf_hist.update_layout(
    barmode="overlay",
    title="Confidence distribution for predicted class (correct vs wrong)",
    xaxis_title="Probability of predicted class",
    yaxis_title="Count",
)
fig_conf_hist.show()

save_fig(fig_conf_hist,"cifar10_confidence_hist")

Saved HTML to ../docs/cifar10_confidence_hist.html
Saved PNG to ../plots/cifar10_confidence_hist.png


In [138]:
THRESHOLD = 0.3

hard_but_correct = (
    analysis_df
    .loc[analysis_df["correct"] & (analysis_df["prob_true"] < THRESHOLD)]
    .sort_values("prob_true", ascending=True)
    .reset_index(drop=True)
)

print(f"Found {len(hard_but_correct)} 'hard but correct' samples with prob_true < {THRESHOLD}")

hard_but_correct[["idx", "true_name", "pred_name", "prob_pred", "prob_true"]].head()

Found 28 'hard but correct' samples with prob_true < 0.3


Unnamed: 0,idx,true_name,pred_name,prob_pred,prob_true
0,637,üê±,üê±,0.205308,0.205308
1,5835,üê±,üê±,0.21952,0.21952
2,689,‚úàÔ∏è,‚úàÔ∏è,0.221308,0.221308
3,1368,üê±,üê±,0.226974,0.226974
4,8940,üê∏,üê∏,0.232509,0.232509


In [141]:
import math
from plotly.subplots import make_subplots
import plotly.graph_objects as go

N_SHOW_HARD = min(25, len(hard_but_correct))
if N_SHOW_HARD == 0:
    print("No hard-but-correct samples found for this threshold.")
else:
    rows = math.ceil(N_SHOW_HARD / 5)
    cols = 5

    fig_hard = make_subplots(
        rows=rows,
        cols=cols,
        subplot_titles=[
            f"{row.true_name}<br>p_true={row.prob_true*100:.1f}%"
            for _, row in hard_but_correct.iloc[:N_SHOW_HARD].iterrows()
        ],
    )

    for i, (_, row) in enumerate(hard_but_correct.iloc[:N_SHOW_HARD].iterrows()):
        r = i // cols + 1
        c = i % cols + 1
        img = data.x_test[row.idx]

        fig_hard.add_trace(
            go.Image(z=img.astype("uint8")),
            row=r,
            col=c,
        )


    for ann in fig_hard.layout.annotations:
        ann.font.size = 20

    fig_hard.update_layout(
        height=300 * rows,
        width=300 * cols,
        title=dict(
            text=f"Hard but correct (prob_true < {THRESHOLD})",
            font=dict(size=18),
            y=0.98,
            yanchor="top",
        ),
    )

    fig_hard.for_each_xaxis(
        lambda ax: ax.update(
            showticklabels=False,
            showgrid=False,
            zeroline=False,
        )
    )
    fig_hard.for_each_yaxis(
        lambda ax: ax.update(
            showticklabels=False,
            showgrid=False,
            zeroline=False,
        )
    )

    fig_hard.show()

    save_fig(fig_hard,"cifar10_hard_prediction_grid")

Saved HTML to ../docs/cifar10_hard_prediction_grid.html
Saved PNG to ../plots/cifar10_hard_prediction_grid.png


## üß™ Additional error analysis

In the final analysis we look at confidence distributions, "hard but correct" examples
and per-class misclassification patterns to better understand where the model struggles.

In [68]:
wrong_df = analysis_df[~analysis_df["correct"]].copy()

# Counts pro (true_name, pred_name)
confusions_all = (
    wrong_df
    .groupby(["true_name", "pred_name"])
    .size()
    .reset_index(name="count")
)

# pro true_name top_k h√§ufigste pred_name behalten
top_k = 5
confusions_all["rank"] = confusions_all.groupby("true_name")["count"].rank(
    method="first", ascending=False
)
top_confusions_all = confusions_all[confusions_all["rank"] <= top_k].copy()

top_confusions_all.head()

Unnamed: 0,true_name,pred_name,count,rank
1,‚úàÔ∏è,üê¥,20,5.0
3,‚úàÔ∏è,üöó,29,3.0
4,‚úàÔ∏è,üöõ,53,2.0
5,‚úàÔ∏è,üõ≥Ô∏è,63,1.0
6,‚úàÔ∏è,ü¶ú,26,4.0


In [142]:
fig_all_conf = px.bar(
    top_confusions_all,
    x="pred_name",
    y="count",
    color="pred_name",
    facet_col="true_name",
    facet_col_wrap=5,
    facet_row_spacing=0.15,
    facet_col_spacing=0.06,
    title=f"Top {top_k} misclassifications per true class",
    labels={
        "pred_name": "Predicted class",
        "count": "Number of errors",
        "true_name": "True class",
    },
)

# Facet-Titel aufr√§umen und gr√∂√üer machen
for ann in fig_all_conf.layout.annotations:
    if isinstance(ann.text, str) and "=" in ann.text:
        ann.text = ann.text.split("=", 1)[1].strip()
    ann.font.size = 28

fig_all_conf.update_layout(
    showlegend=False,
    height=650,
    title=dict(
        font=dict(size=18),
        y=0.98,
        yanchor="top",
    ),
    margin=dict(t=80, b=80, l=60, r=20),
)

fig_all_conf.for_each_xaxis(
    lambda axis: axis.update(
        tickfont=dict(size=14),
        title_font=dict(size=14),
    )
)

fig_all_conf.for_each_yaxis(
    lambda axis: axis.update(
        tickfont=dict(size=12),
        title_font=dict(size=14),
    )
)

fig_all_conf.show()

save_fig(fig_all_conf,"cifar10_misclassification_grid")

Saved HTML to ../docs/cifar10_misclassification_grid.html
Saved PNG to ../plots/cifar10_misclassification_grid.png
