# 06 – Error Analysis & Misclassification Exploration

Why the model still fails — visual, quantitative and qualitative inspection.


In [None]:
import numpy as np, matplotlib.pyplot as plt, seaborn as sns
from collections import Counter
from src.evaluation import evaluate


## 1  Get predictions & basic metrics


In [None]:
# Evaluation setup: alias the model and validation split under inspection.
# Every later cell reads only these aliases, so switching variant is a
# one-line change here.
model       = cnn_aug           # alternatives: resnet, cnn, etc.
X_val_imgs  = X_val             # images, expected (H,W,3) float32 in [0-1]
y_val_hot   = y_val             # one-hot ground-truth labels
class_names = ["recyclable", "non-recyclable"]

# Decode the one-hot targets to integer class ids.
y_true_cls  = np.argmax(y_val_hot, axis=1)

# Forward pass over the validation images, then take the highest-scoring
# class per sample as the hard prediction.
y_pred_prob = model.predict(X_val_imgs)
y_pred_cls  = np.argmax(y_pred_prob, axis=1)

# Project helper from src.evaluation; its return value is not needed here.
_ = evaluate(model, X_val_imgs, y_val_hot, class_names)


## 2  Identify wrong predictions


In [None]:
# Positions in the validation set where the predicted class differs from
# the true class.
wrong_idx = np.flatnonzero(y_true_cls != y_pred_cls)

# Report the absolute error count and the error rate.
n_wrong = len(wrong_idx)
n_total = len(X_val_imgs)
print(f"Total misclassifications: {n_wrong} / {n_total} "
      f"({n_wrong/n_total:.1%})")


## 3  Visualise a random subset of wrong predictions


In [None]:
# Draw up to 12 distinct misclassified samples (fewer if there are not
# enough errors) and show them side by side in a single row.
n_show = min(12, len(wrong_idx))
sample = np.random.choice(wrong_idx, n_show, replace=False)

fig = plt.figure(figsize=(15, 3))
for col, img_idx in enumerate(sample, 1):
    ax = fig.add_subplot(1, n_show, col)
    ax.imshow(X_val_imgs[img_idx])
    true_name = class_names[y_true_cls[img_idx]]
    pred_name = class_names[y_pred_cls[img_idx]]
    # T = ground-truth label, P = predicted label
    ax.set_title(f"T:{true_name}\nP:{pred_name}", fontsize=8)
    ax.axis("off")
fig.suptitle("Random misclassifications", y=1.05)
fig.tight_layout()
plt.show()


## 4  Rank by highest prediction error (loss)

Use the summed absolute difference between the predicted probabilities and the
true one-hot vectors as a proxy for confidence error, then plot the top-k failures.


In [None]:
# Per-sample error score: summed absolute difference between the predicted
# probabilities and the one-hot target (0 = perfect, larger = worse).
loss_vec = np.abs(y_pred_prob - y_val_hot).sum(axis=1)
top_k    = 20

# Rank only samples that were actually misclassified. Ranking the whole
# validation set would pad the figure with correct predictions whenever
# there are fewer than top_k errors, contradicting the figure title.
err_idx = np.flatnonzero(y_true_cls != y_pred_cls)
n_top   = min(top_k, len(err_idx))
# argsort is ascending, so reverse it to list the worst failure first.
top_idx = err_idx[np.argsort(loss_vec[err_idx])[::-1][:n_top]]

if n_top == 0:
    print("No misclassifications to display.")
else:
    # Size the grid from the number of tiles instead of hard-coding 4x5,
    # so changing top_k (or having few errors) cannot overflow the grid.
    n_cols = 5
    n_rows = -(-n_top // n_cols)          # ceiling division
    plt.figure(figsize=(12, 2 * n_rows))  # 2 in. per row -> (12, 8) for 20
    for i, idx in enumerate(top_idx, 1):
        plt.subplot(n_rows, n_cols, i)
        plt.imshow(X_val_imgs[idx])
        t = class_names[y_true_cls[idx]]
        p = class_names[y_pred_cls[idx]]
        # Score the model assigned to the (wrong) class it predicted.
        prob = y_pred_prob[idx, y_pred_cls[idx]]
        plt.title(f"T:{t} / P:{p}\nconf={prob:.2f}", fontsize=7)
        plt.axis("off")
    plt.suptitle(f"Top {n_top} highest-loss misclassifications", y=1.02)
    plt.tight_layout(); plt.show()


## 5  Error frequency per class


In [None]:
# Count errors per true class in class_names order. A dense count
# (np.bincount with minlength) keeps classes with zero errors in the table
# and the plot, and makes bar order/colour independent of which class
# happens to be misclassified first.
per_class = np.bincount(y_true_cls[wrong_idx], minlength=len(class_names))
err_counts = Counter({cls: int(n) for cls, n in enumerate(per_class)})

for cls, count in err_counts.items():
    print(f"{class_names[cls]:<15}: {count}")

# NOTE(review): recent seaborn releases appear to warn when `palette` is
# passed without `hue` — confirm against the installed version.
sns.barplot(x=[class_names[c] for c in err_counts],
            y=list(err_counts.values()), palette="Set2")
plt.title("Misclassification count by true class"); plt.ylabel("errors")
plt.show()


## 6  Qualitative observations

* The majority of errors occur in the **non-recyclable → recyclable** direction.  
* Many failure images contain multiple small objects or low-contrast textures.  
* Data-augmentation may have created blurry zoom-out artefacts that hurt recognition.  

### Hypotheses
1. Increase resolution to 128×128 so fine texture is preserved.  
2. Collect more high-quality non-recyclable pictures (glass, batteries, textiles).  
3. Add context-robust augmentation (brightness & blur), not just zoom/flip.  

These insights inform the next iteration of data collection & model tuning.
