# Agricultural Produce Classification Experiments

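This notebook contains experiments for the agricultural produce classification project: it explores good/bad training images for one produce type, compares per-channel colour statistics between the two classes, and evaluates a trained Keras model on the test set (loss/accuracy, confusion matrix, misclassified samples, and ROC curve).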

In [None]:
# Import required libraries
import numpy as np
import matplotlib.pyplot as plt
import cv2
import tensorflow as tf
from tensorflow.keras.models import load_model
import sys
sys.path.append('..')
from src.model import ProduceClassifier
from src.classify import ImageClassifier
from utils.preprocessing import load_and_preprocess_images, extract_features

## 1. Data Exploration

In [None]:
# Training split location and the produce category to inspect
produce_type = 'mango'
data_dir = '../data/train'

# Each quality class is read from its own sub-directory of the split
good_dir = f'{data_dir}/good'
bad_dir = f'{data_dir}/bad'
good_images = load_and_preprocess_images(good_dir, produce_type)
bad_images = load_and_preprocess_images(bad_dir, produce_type)

# Report how many images were loaded per class
print(
    f'Good samples: {len(good_images)}',
    f'Bad samples: {len(bad_images)}',
    sep='\n',
)

In [None]:
# Visualize samples: one row per quality class, up to `n_cols` images each.
n_cols = 5
fig, axes = plt.subplots(2, n_cols, figsize=(15, 6))

# (row title, images) pairs in display order: good on top, bad below
sample_rows = (('Good', good_images), ('Bad', bad_images))

for row_axes, (title, images) in zip(axes, sample_rows):
    for col, ax in enumerate(row_axes):
        # Only draw when the class actually has an image for this column;
        # indexing past the end would raise IndexError on small datasets.
        if col < len(images):
            ax.imshow(images[col])
            ax.set_title(title)
        # Hide ticks/frame on every panel, including ones left empty
        ax.axis('off')

plt.tight_layout()
plt.show()

## 2. Feature Analysis

In [None]:
# Extract features from a capped subset of each class.
# Slicing is safe even when a class holds fewer images than the cap.
max_per_class = 50
good_features = [extract_features(img) for img in good_images[:max_per_class]]
bad_features = [extract_features(img) for img in bad_images[:max_per_class]]

# Plot color distributions, one panel per channel
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

colors = ['r', 'g', 'b']
for idx, color in enumerate(colors):
    # NOTE(review): assumes extract_features returns a mapping with scalar
    # 'mean_r' / 'mean_g' / 'mean_b' entries -- confirm against utils.preprocessing
    good_values = [f[f'mean_{color}'] for f in good_features]
    bad_values = [f[f'mean_{color}'] for f in bad_features]

    # Use one set of bin edges for both classes. Letting each hist() call pick
    # its own 20 bins gives different bin widths, so overlaid bar heights would
    # not be directly comparable.
    bin_edges = np.histogram_bin_edges(good_values + bad_values, bins=20)

    axes[idx].hist(good_values, alpha=0.5, label='Good', bins=bin_edges)
    axes[idx].hist(bad_values, alpha=0.5, label='Bad', bins=bin_edges)
    axes[idx].set_title(f'{color.upper()} Channel Distribution')
    axes[idx].set_xlabel('Mean Value')
    axes[idx].set_ylabel('Count')
    axes[idx].legend()

plt.tight_layout()
plt.show()

## 3. Model Evaluation

In [None]:
# Restore the trained model from its saved file
model_path = '../models/mangoes.h5'
model = load_model(model_path)

# Read the test split, one sub-directory per quality class
test_dir = '../data/test'
test_good = load_and_preprocess_images(f'{test_dir}/good', produce_type)
test_bad = load_and_preprocess_images(f'{test_dir}/bad', produce_type)

# Build a single evaluation set: good images come first and are labelled 1,
# bad images follow and are labelled 0
good_labels = np.ones(len(test_good))
bad_labels = np.zeros(len(test_bad))
X_test = np.concatenate((test_good, test_bad))
y_test = np.concatenate((good_labels, bad_labels))

# Score the model on the combined set.
# NOTE(review): unpacking two values presumes evaluate() yields exactly
# (loss, one metric) -- confirm against how the model was compiled
test_loss, test_acc = model.evaluate(X_test, y_test)
print(f'Test Loss: {test_loss:.4f}')
print(f'Test Accuracy: {test_acc:.4f}')

In [None]:
# Confusion matrix and per-class report for the test set
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns

# Label encoding used when y_test was built: 0 = bad, 1 = good
class_labels = [0, 1]
class_names = ['Bad', 'Good']

# Threshold the model scores at 0.5 to get hard 0/1 predictions.
# NOTE(review): reshape(-1) assumes one score per sample -- confirm the model
# has a single output unit
y_pred = (model.predict(X_test) > 0.5).astype(int).reshape(-1)

# Pin the label set so the matrix is always 2x2, even if one class is absent
# from both the ground truth and the predictions
cm = confusion_matrix(y_test, y_pred, labels=class_labels)

fig, ax = plt.subplots(figsize=(8, 6))
sns.heatmap(
    cm,
    annot=True,
    fmt='d',
    cmap='Blues',
    xticklabels=class_names,  # show class names instead of raw 0/1 ticks
    yticklabels=class_names,
    ax=ax,
)
ax.set_title('Confusion Matrix')
ax.set_ylabel('True Label')
ax.set_xlabel('Predicted Label')
plt.show()

# Classification report; explicit labels keep target_names aligned and avoid a
# length-mismatch error when only one class shows up in the data
print('\nClassification Report:')
print(classification_report(y_test, y_pred, labels=class_labels, target_names=class_names))

## 4. Error Analysis

In [None]:
# Find misclassified samples: positions where prediction and truth disagree
misclassified_indices = np.where(y_test != y_pred)[0]
print(f'Number of misclassified samples: {len(misclassified_indices)}')

# Visualize at most the first 10 misclassified samples on a 2x5 grid
n_rows, n_cols = 2, 5
shown_indices = misclassified_indices[:n_rows * n_cols]

# Skip the figure entirely when there is nothing to show, rather than
# rendering an empty grid
if len(shown_indices) > 0:
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 6))
    axes = axes.ravel()

    for i, ax in enumerate(axes):
        if i < len(shown_indices):
            idx = shown_indices[i]
            ax.imshow(X_test[idx])
            ax.set_title(f'True: {"Good" if y_test[idx] else "Bad"}\nPred: {"Good" if y_pred[idx] else "Bad"}')
        # Turn off every panel's frame/ticks, so unused panels stay blank
        # instead of showing empty axes
        ax.axis('off')

    plt.tight_layout()
    plt.show()

## 5. Model Performance Visualization

In [None]:
# ROC Curve
from sklearn.metrics import roc_curve, auc

# Raw model scores for the test set, flattened to a 1-D array
y_pred_proba = model.predict(X_test).reshape(-1)

# ROC operating points and the area under the resulting curve
fpr, tpr, thresholds = roc_curve(y_test, y_pred_proba)
roc_auc = auc(fpr, tpr)

# Draw the curve against the diagonal reference line
roc_fig, roc_ax = plt.subplots(figsize=(8, 6))
roc_ax.plot(
    fpr,
    tpr,
    color='darkorange',
    lw=2,
    label=f'ROC curve (AUC = {roc_auc:.2f})',
)
roc_ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
roc_ax.set_xlim([0.0, 1.0])
roc_ax.set_ylim([0.0, 1.05])
roc_ax.set_xlabel('False Positive Rate')
roc_ax.set_ylabel('True Positive Rate')
roc_ax.set_title('Receiver Operating Characteristic (ROC) Curve')
roc_ax.legend(loc="lower right")
plt.show()