In [None]:
import random

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
from PIL import Image
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

# Load the best checkpoint saved during training.
# `map_location=device` remaps the stored tensors onto the device used here,
# so a checkpoint written on a GPU can still be restored on a CPU-only
# runtime (torch.load otherwise tries to restore them onto the original device).
model.load_state_dict(torch.load('best_food_model.pth', map_location=device))
model.eval()  # inference mode (affects layers such as dropout/batch-norm, if present)

# Full evaluation on the test set: collect per-sample predicted class ids,
# true labels and softmax probability vectors as NumPy values so the metrics
# and plots that follow can work on plain arrays.
all_preds = []
all_labels = []
all_probs = []

with torch.no_grad():  # no gradients needed for evaluation
    for inputs, labels in test_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        probs = torch.nn.functional.softmax(outputs, dim=1)
        # Predicted class = index of the largest logit along the class axis.
        _, preds = torch.max(outputs, 1)

        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())
        all_probs.extend(probs.cpu().numpy())

# Overall accuracy on the test set.
test_accuracy = accuracy_score(all_labels, all_preds)
print(f"Test accuracy: {test_accuracy:.4f}")

# Per-class classification report.
# Label ids are passed explicitly so that every class in `class_names` gets a
# row and `target_names` stays aligned with the label indices, even when some
# class never appears in either the true or the predicted labels. If labels are
# inferred from the data instead, sklearn rejects the mismatched target_names.
# NOTE(review): assumes integer labels 0..len(class_names)-1 indexing into
# class_names (e.g. a torchvision ImageFolder); the original target_names usage
# already relied on this.
class_names = test_dataset.classes
label_ids = list(range(len(class_names)))
print("\nClassification Report:")
print(classification_report(all_labels, all_preds, labels=label_ids,
                            target_names=class_names, zero_division=0))

# Confusion matrix with one row/column per class, in class-id order, so its
# shape always matches the tick labels used in the heatmap below.
conf_matrix = confusion_matrix(all_labels, all_preds, labels=label_ids)

# Plot the confusion matrix (rows = true class, columns = predicted class).
plt.figure(figsize=(12, 10))
sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues',
            xticklabels=class_names, yticklabels=class_names)
plt.xlabel('Predicted Labels')
plt.ylabel('True Labels')
plt.title('Confusion Matrix')
plt.xticks(rotation=90)
plt.yticks(rotation=0)
plt.tight_layout()
plt.show()

# Model confidence analysis: for each test sample, take the highest softmax
# probability it received (its top-1 confidence) and histogram those values.
confidence_scores = np.asarray(all_probs).max(axis=1)

plt.figure(figsize=(10, 6))
plt.hist(confidence_scores, bins=20, alpha=0.7)
plt.gca().set(xlabel='Confidence Score',
              ylabel='Count',
              title='Model Confidence Distribution')
plt.grid(True, alpha=0.3)
plt.show()

# Visualize some predictions
def visualize_predictions(model, test_loader, device, class_names, num_images=5,
                          *, top_k=5, mean=(0.485, 0.456, 0.406),
                          std=(0.229, 0.224, 0.225)):
    """Show test images next to the model's top-k class probabilities.

    Takes the first batch from ``test_loader`` and runs ``model`` on it. Each of
    the first ``num_images`` samples gets one row: the de-normalized image,
    titled with its true and predicted class, beside a horizontal bar chart of
    the ``top_k`` most probable classes.

    Args:
        model: Classifier returning one logit per class for each input.
        test_loader: Iterable yielding ``(images, labels)`` batches. Only the
            first batch is used.
        device: Device the model runs on. Images are moved there for inference.
        class_names: Class names indexed by integer label id.
        num_images: Number of samples to show. Clamped to the batch size;
            nothing is drawn if it is less than 1.
        top_k: Number of classes shown per sample. Clamped to the number of
            model outputs.
        mean, std: Per-channel normalization statistics used to undo the
            input transform for display. The defaults are the values the
            original code hard-coded. NOTE(review): assumes 3-channel
            (C, H, W) image tensors normalized with these statistics.
    """
    model.eval()

    # Only the first batch is needed for a quick visual sanity check.
    images, labels = next(iter(test_loader))

    with torch.no_grad():
        outputs = model(images.to(device))
    # Softmax over the class axis once for the whole batch.
    probs = torch.nn.functional.softmax(outputs, dim=1).cpu()
    preds = probs.argmax(dim=1)

    # Guard against batches smaller than num_images and against models with
    # fewer than top_k classes (torch.topk would raise otherwise).
    num_images = min(num_images, len(images))
    if num_images < 1:
        return
    k = min(top_k, probs.shape[1])

    # squeeze=False keeps `axes` 2-D even for a single row, so axes[i, j]
    # always works. Height scales with the row count (3 in per row, which
    # reproduces the original 15 in for the default 5 rows).
    _, axes = plt.subplots(nrows=num_images, ncols=2,
                           figsize=(12, 3 * num_images), squeeze=False)

    mean = np.asarray(mean)
    std = np.asarray(std)
    y_pos = np.arange(k)

    for i in range(num_images):
        # CHW tensor -> HWC array, then undo the normalization for display.
        img = images[i].cpu().numpy().transpose((1, 2, 0))
        img = np.clip(std * img + mean, 0, 1)

        true_name = class_names[int(labels[i])]
        pred_name = class_names[int(preds[i])]
        axes[i, 0].imshow(img)
        axes[i, 0].set_title(f"True: {true_name} | Pred: {pred_name}")
        axes[i, 0].axis('off')

        # Top-k class probabilities for this sample, most likely first.
        top_probs, top_indices = torch.topk(probs[i], k)
        axes[i, 1].barh(y_pos, top_probs.numpy())
        axes[i, 1].set_yticks(y_pos)
        axes[i, 1].set_yticklabels([class_names[int(idx)] for idx in top_indices])
        axes[i, 1].set_title(f'Top {k} Predictions')
        axes[i, 1].set_xlim(0, 1)

    plt.tight_layout()
    plt.show()

visualize_predictions(model, test_loader, device, class_names)

# Critical analysis summary (printed text only).
# NOTE(review): points 1-2 are hard-coded observations about a training run
# and are not computed in this cell. Re-check them against the training log
# and the test_accuracy printed above before relying on them.
print("\n--- Critical Analysis ---")
print("1. Validation Accuracy is perfect (100%) through all epochs, which is unusual.")
print("2. Training converged extremely quickly, reaching 100% accuracy by epoch 7.")
print("3. Possible explanations:")
print("   a. The dataset might be too small or not challenging enough")
print("   b. There might be data leakage between train and test sets")
print("   c. The classes could be too easily distinguishable")
print("   d. The model might be memorizing rather than generalizing")
print("\n4. Recommendations:")
print("   a. Test on completely new, unseen images")
print("   b. Use more challenging data augmentation")
print("   c. Implement cross-validation")
print("   d. Consider a more diverse dataset")