In [3]:
pip install torch torchvision matplotlib opencv-python shap seaborn grad-cam


Note: you may need to restart the kernel to use updated packages.


In [4]:
pip install git+https://github.com/jacobgil/pytorch-grad-cam.git


Collecting git+https://github.com/jacobgil/pytorch-grad-cam.git
  Cloning https://github.com/jacobgil/pytorch-grad-cam.git to /tmp/pip-req-build-9y_bmhoj
  Running command git clone --filter=blob:none --quiet https://github.com/jacobgil/pytorch-grad-cam.git /tmp/pip-req-build-9y_bmhoj
  Resolved https://github.com/jacobgil/pytorch-grad-cam.git to commit 781dbc0d16ffa95b6d18b96b7b829840a82d93d1
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Note: you may need to restart the kernel to use updated packages.


In [5]:
# ==================== STEP 1: SETUP AND CONFIGURATION ====================
import os
import random
import time
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.rpn import AnchorGenerator
from torchvision.ops import MultiScaleRoIAlign
from torchvision.transforms import functional as F
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix
from tqdm import tqdm
import cv2
import shap
from pytorch_grad_cam import GradCAM, HiResCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, FullGrad
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
import warnings
from collections import defaultdict
from torchvision.ops import FeaturePyramidNetwork  # For FPN support
from torchvision.models.detection.rpn import RPNHead  # Add this import at the top
from copy import deepcopy  # Add this import at the top of your file
import os
# CUDA caching-allocator option for PyTorch.
# NOTE(review): presumably this must be set before the first CUDA allocation to
# take effect; it is set here right after the imports - confirm.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"


# Configuration
# NOTE(review): this silences *every* Python warning for the rest of the kernel
# session, not just the noisy ones.
warnings.filterwarnings('ignore')
IMG_SIZE = 512  # side of the square every image is resized to; also sizes the full-image target boxes
BATCH_SIZE = 4  # batch size of the evaluation DataLoader
EPOCHS = 20  # upper bound of the training loop (that cell is currently commented out)
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Created eagerly in the working directory.
# NOTE(review): the checkpoint cells further down use /kaggle/... paths rather than this directory.
CHECKPOINT_DIR = './checkpoints'
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# Class definitions
# List positions are used directly as integer labels by the dataset, so the
# order of both lists must not change once a model has been trained.
CLASS_NAMES = ['Apple', 'Banana', 'Guava', 'Lime', 'Orange', 'Pomegranate' , 'Not_Fruit']
QUALITY_TYPES = ['Good', 'Bad']

In [6]:
# ==================== STEP 2: DATA LOADING AND PREPROCESSING ====================
class FruitQualityDataset(Dataset):
    """Whole-image detection dataset of graded fruit photos plus non-fruit distractors.

    Fruit images are discovered under
    ``<fruit_root>/<quality> Quality_Fruits/<folder>/`` for every ``quality`` in
    ``QUALITY_TYPES``; the part of ``<folder>`` before the first underscore must
    be one of the fruit names in ``CLASS_NAMES`` (everything except the trailing
    ``'Not_Fruit'`` entry), otherwise the folder is ignored. Non-fruit images are
    read directly from ``non_fruit_root``.

    Every sample is ``(image, target)`` where ``target`` is a dict with:

    * ``boxes``    - one float32 box ``[0, 0, IMG_SIZE-1, IMG_SIZE-1]`` (the whole image)
    * ``labels``   - int64 index of the class in ``CLASS_NAMES``
    * ``quality``  - int64 index in ``QUALITY_TYPES``, or ``-1`` for non-fruit
    * ``is_fruit`` - int64 ``1`` for fruit, ``0`` for non-fruit

    NOTE(review): labels are raw ``CLASS_NAMES`` indices, so ``'Apple'`` is
    label 0. torchvision detection models reserve label 0 for background -
    confirm whether Apple samples are being trained as background; fixing it
    requires a coordinated change in the model, checkpoint and API.

    Args:
        fruit_root: Directory containing the ``"<quality> Quality_Fruits"`` folders.
        non_fruit_root: Directory containing non-fruit images.
        transform: Optional callable applied to the RGB ``PIL.Image``.
        max_samples: Per-folder cap on fruit images. ``None`` (or ``0``) means
            no cap. The non-fruit quota is derived from it (see ``_load_data``).
    """

    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')
    NON_FRUIT_CAP = 200      # hard upper bound on non-fruit samples
    MAX_LOAD_ATTEMPTS = 10   # how many samples __getitem__ tries before giving up

    def __init__(self, fruit_root, non_fruit_root, transform=None, max_samples=None):
        self.transform = transform
        self.image_paths = []
        self.targets = []
        self._load_data(fruit_root, non_fruit_root, max_samples)

    @classmethod
    def _list_images(cls, folder_path):
        """Return the image file names in ``folder_path`` in sorted order.

        Sorting makes the ``max_samples`` truncation reproducible; the order
        returned by ``os.listdir`` is arbitrary.
        """
        return sorted(
            name for name in os.listdir(folder_path)
            if name.lower().endswith(cls.IMAGE_EXTENSIONS)
        )

    @staticmethod
    def _make_target(class_id, quality_id, is_fruit):
        """Build the per-sample target dict with a single full-image box."""
        return {
            'boxes': torch.tensor([[0, 0, IMG_SIZE-1, IMG_SIZE-1]], dtype=torch.float32),
            'labels': torch.tensor([class_id], dtype=torch.int64),
            'quality': torch.tensor([quality_id], dtype=torch.int64),
            'is_fruit': torch.tensor([is_fruit], dtype=torch.int64),
        }

    def _load_data(self, fruit_root, non_fruit_root, max_samples):
        """Index fruit images first, then a bounded number of non-fruit images."""
        for quality in QUALITY_TYPES:
            quality_dir = os.path.join(fruit_root, f'{quality} Quality_Fruits')
            if not os.path.isdir(quality_dir):
                continue

            for fruit in sorted(os.listdir(quality_dir)):
                fruit_dir = os.path.join(quality_dir, fruit)
                if not os.path.isdir(fruit_dir):
                    continue

                fruit_type = fruit.split('_')[0]
                if fruit_type not in CLASS_NAMES[:-1]:  # Exclude Not_Fruit
                    continue

                self._process_fruit_folder(fruit_dir, fruit_type, quality, max_samples)

        # Non-fruit quota: half of max_samples when given, otherwise half the
        # number of fruit images indexed so far, and never more than NON_FRUIT_CAP.
        base = max_samples // 2 if max_samples else len(self.image_paths) // 2
        non_fruit_max = min(base, self.NON_FRUIT_CAP)
        self._process_non_fruit_folder(non_fruit_root, non_fruit_max)

    def _process_fruit_folder(self, folder_path, fruit_type, quality, max_samples):
        """Index up to ``max_samples`` images of one fruit/quality folder."""
        images = self._list_images(folder_path)
        class_id = CLASS_NAMES.index(fruit_type)
        quality_id = QUALITY_TYPES.index(quality)

        # A falsy max_samples (None or 0) keeps the historical "no cap" meaning.
        if max_samples:
            images = images[:max_samples]

        for img in images:
            self.image_paths.append(os.path.join(folder_path, img))
            self.targets.append(self._make_target(class_id, quality_id, 1))

    def _process_non_fruit_folder(self, folder_path, max_samples):
        """Index up to ``max_samples`` non-fruit images.

        Unlike the fruit cap, ``0`` here really means "none": the quota is
        computed internally and can legitimately be zero. The previous
        truthiness check treated 0 as "no limit" and loaded every non-fruit image.
        """
        images = self._list_images(folder_path)
        if max_samples is not None:
            images = images[:max_samples]

        not_fruit_id = CLASS_NAMES.index('Not_Fruit')
        for img in images:
            self.image_paths.append(os.path.join(folder_path, img))
            self.targets.append(self._make_target(not_fruit_id, -1, 0))

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, target)`` for ``idx``.

        If the image cannot be read or transformed, a randomly chosen other
        sample is returned instead (best-effort, as before), but only for a
        bounded number of attempts.

        Raises:
            IndexError: ``idx`` is out of range (no longer masked by the retry).
            RuntimeError: ``MAX_LOAD_ATTEMPTS`` consecutive samples failed to load.
        """
        # Look the sample up outside the try so an invalid index propagates.
        path, target = self.image_paths[idx], self.targets[idx]
        last_error = None

        for _ in range(self.MAX_LOAD_ATTEMPTS):
            try:
                # The context manager closes the file handle once the pixels
                # have been copied by convert().
                with Image.open(path) as handle:
                    img = handle.convert('RGB')
                if self.transform:
                    img = self.transform(img)
                return img, target
            except Exception as exc:  # unreadable/corrupt image or failing transform
                last_error = exc
                fallback = random.randint(0, len(self) - 1)
                path, target = self.image_paths[fallback], self.targets[fallback]

        raise RuntimeError(
            f"Could not load any image after {self.MAX_LOAD_ATTEMPTS} attempts "
            f"(last failure: {last_error!r})"
        ) from last_error

# ==================== ENHANCED DATA AUGMENTATION ====================
def get_enhanced_transforms(train=True):
    """Build the image preprocessing pipeline.

    Both variants resize to ``IMG_SIZE`` x ``IMG_SIZE``, convert to a tensor and
    normalise with fixed per-channel mean/std. With ``train`` truthy, random
    augmentations are applied between the resize and the tensor conversion.

    Args:
        train: Whether to include the random augmentations.

    Returns:
        A ``transforms.Compose`` ready to be applied to a PIL image.
    """
    resize = transforms.Resize((IMG_SIZE, IMG_SIZE))
    to_normalized_tensor = [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]

    if not train:
        return transforms.Compose([resize, *to_normalized_tensor])

    # Listed in the order they are applied.
    augmentations = [
        transforms.RandomApply([
            transforms.GaussianBlur(kernel_size=(3, 7)),
            transforms.RandomAdjustSharpness(sharpness_factor=2),
        ], p=0.3),
        transforms.RandomAffine(
            degrees=0, translate=(0.15, 0.15), scale=(0.8, 1.2)),
        transforms.RandomRotation(15),
        transforms.RandomVerticalFlip(0.3),
        transforms.RandomHorizontalFlip(0.5),
        transforms.RandomApply([
            transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),
        ], p=0.8),
    ]

    return transforms.Compose([resize, *augmentations, *to_normalized_tensor])

In [8]:
# ==================== ENHANCED MODEL ARCHITECTURE ====================
class EnhancedFruitQualityModel(nn.Module):
    """Faster R-CNN fruit detector with auxiliary quality and fruit/non-fruit heads.

    A ResNet-50 trunk (without its average-pool and FC layers) is shared between
    the detector and the auxiliary heads:

    * ``detector`` - torchvision ``FasterRCNN`` on the trunk followed by a 1x1
      convolution that reduces 2048 channels to 256.
    * ``fruit_classifier`` - 2-way classifier on RoI-pooled trunk features.
    * ``quality_attention`` / ``quality_head`` - scalar gate and quality scorer
      on the same RoI features; ``quality_head`` already ends in ``nn.Sigmoid``.

    In training mode ``forward`` returns only the detector's loss dict, so the
    auxiliary heads receive no loss from this module.
    NOTE(review): confirm the heads are trained somewhere, otherwise their
    outputs come from randomly initialised weights.

    Args:
        num_classes: Number of detector classes, passed to ``FasterRCNN``.
        num_qualities: Number of outputs of the quality head.
    """

    ROI_OUTPUT_SIZE = 14  # spatial size of the RoI-pooled features fed to every head

    def __init__(self, num_classes, num_qualities):
        super().__init__()
        self.num_qualities = num_qualities

        # Initialize backbone with frozen early layers
        backbone = torchvision.models.resnet50(pretrained=True)

        # Freeze layer1/layer2, then re-enable every parameter whose name
        # contains 'bn' so those BatchNorm layers stay trainable.
        for name, param in backbone.named_parameters():
            if 'layer1' in name or 'layer2' in name:
                param.requires_grad = False
            if 'bn' in name:
                param.requires_grad = True

        self.features = nn.Sequential(*list(backbone.children())[:-2])
        self.out_channels = 2048
        # Kept as an attribute: it registers the layer under its own name, so
        # saved state dicts contain `gradcam_layer.*` keys.
        self.gradcam_layer = self.features[-1][-1].conv3

        # Enhanced detection head
        anchor_generator = AnchorGenerator(
            sizes=((16, 32, 64, 128, 256, 512),),  # More granular anchor sizes
            aspect_ratios=((0.25, 0.5, 1.0, 2.0, 4.0),)  # Wider range of aspect ratios
        )

        roi_pooler = MultiScaleRoIAlign(
            featmap_names=['0'],
            output_size=self.ROI_OUTPUT_SIZE,  # Larger ROI pooling size
            sampling_ratio=2
        )

        # FasterRCNN needs a backbone exposing `out_channels`.
        backbone_with_channels = nn.Sequential(
            self.features,
            nn.Conv2d(2048, 256, kernel_size=1)  # Reduce channel dimension
        )
        backbone_with_channels.out_channels = 256

        self.detector = FasterRCNN(
            backbone=backbone_with_channels,
            num_classes=num_classes,
            rpn_anchor_generator=anchor_generator,
            box_roi_pool=roi_pooler,
            box_score_thresh=0.7
        )

        head_in_features = self.out_channels * self.ROI_OUTPUT_SIZE * self.ROI_OUTPUT_SIZE

        # Scalar attention gate in (0, 1) applied to the flattened RoI features
        self.quality_attention = nn.Sequential(
            nn.Linear(head_in_features, 512),
            nn.ReLU(),
            nn.Linear(512, 1),
            nn.Sigmoid()
        )

        self.quality_head = nn.Sequential(
            nn.Linear(head_in_features, 1024),
            nn.ReLU(),
            nn.BatchNorm1d(1024),
            nn.Dropout(0.5),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Dropout(0.3),
            nn.Linear(512, num_qualities),
            nn.Sigmoid()  # Outputs are already probabilities in (0, 1)
        )

        # Improved fruit/non-fruit classifier
        self.fruit_classifier = nn.Sequential(
            nn.Linear(head_in_features, 1024),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, 2)
        )

        # Parameter-free RoI pooler for the auxiliary heads, built once instead
        # of on every `_extract_roi_features` call. Adds no state-dict entries.
        self.head_roi_pool = MultiScaleRoIAlign(
            featmap_names=['0'],
            output_size=self.ROI_OUTPUT_SIZE,
            sampling_ratio=2
        )

    def forward(self, images, targets=None):
        """Run the detector and, in eval mode, the auxiliary heads.

        Args:
            images: Sequence of image tensors (or a batched tensor), as accepted
                by the torchvision detector.
            targets: Per-image target dicts; only used in training mode.

        Returns:
            Training mode: the detector's loss dict.
            Eval mode: the detector's list of detection dicts. Every dict with
            at least one box additionally gets ``'is_fruit'`` (0/1 per box) and
            ``'quality_scores'`` (``(num_boxes, num_qualities)``; rows for boxes
            judged non-fruit stay zero).
        """
        if self.training:
            return self.detector(images, targets)

        detections = self.detector(images)

        for image, det in zip(images, detections):
            boxes = det['boxes']
            if len(boxes) == 0:
                continue

            features = self._extract_roi_features(image.unsqueeze(0), boxes)
            flattened_features = features.flatten(1)

            # Fruit/non-fruit prediction
            fruit_logits = self.fruit_classifier(flattened_features)
            is_fruit = torch.argmax(fruit_logits, dim=1)

            # Allocate on the features' device (not a global one) and size by
            # the head's own output width.
            quality_scores = torch.zeros(
                len(boxes), self.num_qualities, device=flattened_features.device
            )
            fruit_mask = is_fruit == 1

            if fruit_mask.any():
                fruit_features = flattened_features[fruit_mask]
                # Apply attention to focus on quality-relevant features
                attention_weights = self.quality_attention(fruit_features)
                attended_features = fruit_features * attention_weights
                # quality_head already ends in Sigmoid; a second sigmoid here
                # used to squash every score into roughly [0.5, 0.73].
                quality_scores[fruit_mask] = self.quality_head(attended_features)

            det['quality_scores'] = quality_scores
            det['is_fruit'] = is_fruit

        return detections

    def _extract_roi_features(self, image, boxes):
        """RoI-pool trunk features of one image for the given boxes.

        Args:
            image: Image tensor with a leading batch dimension of 1.
            boxes: Boxes of that image, as returned by the detector.

        Returns:
            Pooled features, one ``ROI_OUTPUT_SIZE`` x ``ROI_OUTPUT_SIZE`` map per box.
        """
        features = {'0': self.features(image)}
        return self.head_roi_pool(features, [boxes], [image.shape[-2:]])

In [None]:
# # ==================== STEP 4: TRAINING ====================
# def train_model():
#     # Initialize datasets with balanced classes
#     train_dataset = FruitQualityDataset(
#         fruit_root='/kaggle/input/fruitnet/b6fftwbr2v-3/FruitNet_Processed Images/Processed Images_Fruits',
#         non_fruit_root='/kaggle/input/d/spectrewolf8/random-images-dataset/random_images_dataset/training/all_images',
#         transform=get_enhanced_transforms(train=True),
#         max_samples=1000
#     )
    
#     val_dataset = FruitQualityDataset(
#         fruit_root='/kaggle/input/fruitnet/b6fftwbr2v-3/FruitNet_Processed Images/Processed Images_Fruits',
#         non_fruit_root='/kaggle/input/d/spectrewolf8/random-images-dataset/random_images_dataset/training/all_images',
#         transform=get_enhanced_transforms(train=False),
#         max_samples=200
#     )

#     def worker_init_fn(worker_id):
#         np.random.seed(torch.initial_seed() % 2**32)

#     # DataLoaders with improved balancing
#     train_loader = DataLoader(
#         train_dataset,
#         batch_size=8,
#         worker_init_fn=worker_init_fn,
#         shuffle=True,
#         num_workers=4,
#         pin_memory=True,
#         collate_fn=lambda x: tuple(zip(*x)),
#         persistent_workers=True
#     )
    
#     val_loader = DataLoader(
#         val_dataset,
#         batch_size=8,
#         shuffle=False,
#         num_workers=2,
#         collate_fn=lambda x: tuple(zip(*x)),
#         persistent_workers=True
#     )

#     # Initialize model with enhanced configuration
#     model = EnhancedFruitQualityModel(len(CLASS_NAMES), len(QUALITY_TYPES)).to(DEVICE)
    
#     # Enhanced optimizer initialization
#     def get_enhanced_optimizer(model):
#         # Separate parameters with different learning rates
#         backbone_params = []
#         detector_params = []
#         quality_params = []
#         fruit_params = []
        
#         for name, param in model.named_parameters():
#             if 'features' in name:
#                 backbone_params.append(param)
#             elif 'detector' in name:
#                 detector_params.append(param)
#             elif 'quality_' in name:  # Changed to match new architecture
#                 quality_params.append(param)
#             elif 'fruit_classifier' in name:
#                 fruit_params.append(param)
        
#         # Differential learning rates
#         optimizer = optim.AdamW([
#             {'params': backbone_params, 'lr': 1e-5},
#             {'params': detector_params, 'lr': 3e-4},
#             {'params': quality_params, 'lr': 1e-3},
#             {'params': fruit_params, 'lr': 1e-4}
#         ], weight_decay=1e-4)
        
#         return optimizer
    
#     optimizer = get_enhanced_optimizer(model)
#     scheduler = optim.lr_scheduler.ReduceLROnPlateau(
#         optimizer, 'max', patience=2, factor=0.5, verbose=True)

#     # Checkpoint paths - use /kaggle/working for saving
#     os.makedirs('/kaggle/working/checkpoints', exist_ok=True)
#     checkpoint_path = '/kaggle/working/checkpoints/best_model.pth'
#     loaded_checkpoint_path = '/kaggle/working/checkpoints/epoch_12.pth'  # For loading only
    
#     start_epoch = 0
#     best_val_acc = 0.0
#     history = {'train_loss': [], 'val_acc': []}
    
#     # Load checkpoint if available
#     if os.path.exists(loaded_checkpoint_path):
#         print(f"\n=== Loading checkpoint from {loaded_checkpoint_path} ===")
#         checkpoint = torch.load(loaded_checkpoint_path, map_location=DEVICE)
#         model.load_state_dict(checkpoint['model_state_dict'])
#         optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
#         scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
#         best_val_acc = checkpoint['best_val_acc']
#         history = checkpoint['history']
#         start_epoch = checkpoint['epoch'] + 1
#         print(f"Resuming from epoch {start_epoch} with best val acc {best_val_acc:.4f}")
#     else:
#         print(f"Checkpoint not found at {loaded_checkpoint_path}")


#     # Training loop with enhanced monitoring
#     for epoch in range(start_epoch, EPOCHS):
#         try:
#             model.train()
#             epoch_loss = 0.0
#             progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}")
            
#             for images, targets in progress_bar:
#                 images = [img.to(DEVICE, non_blocking=True) for img in images]
#                 targets = [{k: v.to(DEVICE, non_blocking=True) for k, v in t.items()} for t in targets]
                
#                 with torch.cuda.amp.autocast():
#                     loss_dict = model(images, targets)
#                     losses = sum(loss for loss in loss_dict.values())
                
#                 optimizer.zero_grad()
#                 losses.backward()
#                 torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
#                 optimizer.step()
                
#                 epoch_loss += losses.item()
#                 progress_bar.set_postfix({'loss': losses.item()})
            
#             # Enhanced validation
#             val_acc = enhanced_evaluate_model(model, val_loader, verbose=True)
#             scheduler.step(val_acc)
            
#             # Update history
#             history['train_loss'].append(epoch_loss/len(train_loader))
#             history['val_acc'].append(val_acc)
            
#             # Save best model - now to writable directory
#             if val_acc > best_val_acc:
#                 best_val_acc = val_acc
#                 torch.save({
#                     'epoch': epoch,
#                     'model_state_dict': model.state_dict(),
#                     'optimizer_state_dict': optimizer.state_dict(),
#                     'scheduler_state_dict': scheduler.state_dict(),
#                     'best_val_acc': best_val_acc,
#                     'history': history,
#                     'val_acc': val_acc
#                 }, checkpoint_path)
#                 print(f"\n🔥 New best model! Val accuracy: {val_acc:.4f}")
            
#             # Save periodic checkpoint
#             if (epoch + 1) % 2 == 0:
#                 epoch_checkpoint = f'/kaggle/working/checkpoints/epoch_{epoch+1}.pth'
#                 torch.save({
#                     'epoch': epoch,
#                     'model_state_dict': model.state_dict(),
#                     'optimizer_state_dict': optimizer.state_dict(),
#                     'scheduler_state_dict': scheduler.state_dict(),
#                     'best_val_acc': best_val_acc,
#                     'history': history,
#                     'val_acc': val_acc
#                 }, epoch_checkpoint)
            
#             print(f"\nEpoch {epoch+1} Summary:")
#             print(f"Train Loss: {epoch_loss/len(train_loader):.4f}")
#             print(f"Val Accuracy: {val_acc:.4f} (Best: {best_val_acc:.4f})")
#             print("-" * 50)

#         except KeyboardInterrupt:
#             print("\nTraining interrupted. Saving current state...")
#             interrupt_path = f'/kaggle/working/checkpoints/interrupt_epoch_{epoch+1}.pth'
#             torch.save({
#                 'epoch': epoch,
#                 'model_state_dict': model.state_dict(),
#                 'optimizer_state_dict': optimizer.state_dict(),
#                 'scheduler_state_dict': scheduler.state_dict(),
#                 'best_val_acc': best_val_acc,
#                 'history': history,
#                 'val_acc': val_acc if 'val_acc' in locals() else 0.0
#             }, interrupt_path)
#             print(f"Saved interrupted state to {interrupt_path}")
#             break
            
#         torch.cuda.empty_cache()

In [None]:
# ==================== IMPROVED EVALUATION METRICS ====================
def _plot_confusion_matrix(cm, tick_labels, title, cmap, figsize, rotate_ticks=False):
    """Draw one annotated confusion-matrix heatmap and show it."""
    plt.figure(figsize=figsize)
    sns.heatmap(cm, annot=True, fmt='d', cmap=cmap,
                xticklabels=tick_labels, yticklabels=tick_labels)
    plt.title(title)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    if rotate_ticks:
        plt.xticks(rotation=45)
        plt.yticks(rotation=0)
    plt.show()


def _print_per_class_metrics(cm, names):
    """Print precision, recall and F1 per class from confusion matrix ``cm``.

    Rows of ``cm`` are true classes and columns are predictions. The small
    epsilon only guards against division by zero for absent classes.
    """
    for i, name in enumerate(names):
        tp = cm[i, i]
        fp = cm[:, i].sum() - tp
        fn = cm[i, :].sum() - tp
        precision = tp / (tp + fp + 1e-6)
        recall = tp / (tp + fn + 1e-6)
        f1 = 2 * (precision * recall) / (precision + recall + 1e-6)
        print(f"{name:15} - Precision: {precision:.3f} | Recall: {recall:.3f} | F1: {f1:.3f}")


def enhanced_evaluate_model(model, dataloader, confidence_threshold=0.7, verbose=True):
    """Evaluate class and quality predictions of the top detection per image.

    For every image only the first detection is considered, and only when its
    score exceeds ``confidence_threshold``. Images without such a detection are
    excluded from both accuracies; they are counted in ``num_skipped`` so the
    accuracies can be read in context.

    Quality is scored only for samples whose true label is not ``'Not_Fruit'``
    and whose output contains ``'quality_scores'``.

    Side effects: puts ``model`` in eval mode (it is not switched back) and,
    when ``verbose``, shows confusion-matrix plots and prints per-class metrics.

    Args:
        model: Detection model returning one dict per image in eval mode.
        dataloader: Yields ``(images, targets)`` batches.
        confidence_threshold: Minimum score of the top detection.
        verbose: Show the progress bar, plots and printed metrics.

    Returns:
        Dict with ``classification_accuracy``, ``quality_accuracy``,
        ``classification_cm`` and ``quality_cm`` (both ``None`` unless computed
        in verbose mode), plus ``num_evaluated`` and ``num_skipped``.
    """
    model.eval()

    not_fruit_id = CLASS_NAMES.index('Not_Fruit')

    # Per-sample predictions/targets; accuracies and matrices derive from these.
    all_class_preds = []
    all_class_targets = []
    all_quality_preds = []
    all_quality_targets = []
    num_skipped = 0

    with torch.no_grad():
        for images, targets in tqdm(dataloader, desc="Evaluating", disable=not verbose):
            images = [img.to(DEVICE) for img in images]
            outputs = model(images)

            for output, target in zip(outputs, targets):
                has_confident_top = (
                    len(output['labels']) > 0
                    and output['scores'][0] > confidence_threshold
                )
                if not has_confident_top:
                    num_skipped += 1
                    continue

                # Classification evaluation
                pred_label = output['labels'][0].cpu().item()
                true_label = target['labels'][0].cpu().item()
                all_class_preds.append(pred_label)
                all_class_targets.append(true_label)

                # Quality evaluation (only for fruits)
                if true_label != not_fruit_id and 'quality_scores' in output:
                    all_quality_preds.append(torch.argmax(output['quality_scores'][0]).item())
                    all_quality_targets.append(target['quality'][0].item())

    classification_total = len(all_class_targets)
    quality_total = len(all_quality_targets)
    classification_correct = sum(
        p == t for p, t in zip(all_class_preds, all_class_targets))
    quality_correct = sum(
        p == t for p, t in zip(all_quality_preds, all_quality_targets))

    # Exact ratios with an explicit empty-set guard; the old `+ 1e-6`
    # denominator biased every accuracy slightly low.
    classification_acc = (
        classification_correct / classification_total if classification_total else 0.0)
    quality_acc = quality_correct / quality_total if quality_total else 0.0

    cm_class = None
    cm_quality = None

    if verbose:
        # Classification Confusion Matrix
        cm_class = confusion_matrix(all_class_targets, all_class_preds,
                                    labels=range(len(CLASS_NAMES)))
        _plot_confusion_matrix(cm_class, CLASS_NAMES,
                               'Fruit Classification Confusion Matrix',
                               cmap='Blues', figsize=(12, 10), rotate_ticks=True)

        print(f"\nClassification Accuracy: {classification_acc:.4f}")
        print("Classification Metrics:")
        _print_per_class_metrics(cm_class, CLASS_NAMES)

        # Quality Confusion Matrix (if applicable)
        if quality_total > 0:
            cm_quality = confusion_matrix(all_quality_targets, all_quality_preds,
                                          labels=range(len(QUALITY_TYPES)))
            _plot_confusion_matrix(cm_quality, QUALITY_TYPES,
                                   'Quality Prediction Confusion Matrix',
                                   cmap='Greens', figsize=(8, 6))

            print(f"\nQuality Prediction Accuracy: {quality_acc:.4f}")
            print("Quality Metrics:")
            _print_per_class_metrics(cm_quality, QUALITY_TYPES)

    return {
        'classification_accuracy': classification_acc,
        'quality_accuracy': quality_acc,
        'classification_cm': cm_class,
        'quality_cm': cm_quality,
        'num_evaluated': classification_total,
        'num_skipped': num_skipped,
    }

In [None]:
import os
import glob

# Directory whose contents are wiped by this cell.
checkpoint_dir = "/kaggle/working/checkpoints"

# WARNING: destructive. Every entry directly inside `checkpoint_dir` is removed
# each time this cell runs, including under "Run All".
# Failures (e.g. an entry that os.remove cannot delete) are reported and skipped.
files = glob.glob(os.path.join(checkpoint_dir, '*'))
for f in files:
    try:
        os.remove(f)
        print(f"Deleted: {f}")
    except Exception as e:
        print(f"Error deleting {f}: {e}")


In [None]:
# # ==================== STEP 7: MAIN EXECUTION ====================
# if __name__ == "__main__":
#     # Step 1: Training
#     print("Starting training...")
#     train_model()

In [None]:
import os
import glob

# Directory to inspect.
checkpoint_dir = "/kaggle/working/checkpoints"

# Read-only: print every entry directly inside `checkpoint_dir`.
# Prints nothing when the directory is empty or does not exist (glob returns []).
files = glob.glob(os.path.join(checkpoint_dir, '*'))
for f in files:
    print(f"Found: {f}")


In [None]:
# Updated evaluation code
import os
import torch
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix
from tqdm import tqdm

# Initialize model
model = EnhancedFruitQualityModel(len(CLASS_NAMES), len(QUALITY_TYPES)).to(DEVICE)
checkpoint_path = '/kaggle/input/fruit-quality-prediction/pytorch/default/1/checkpoints/epoch_10.pth'

# Fail fast: continuing without the checkpoint would evaluate a model whose
# detector and heads are randomly initialised and report meaningless metrics.
if not os.path.exists(checkpoint_path):
    raise FileNotFoundError(f"Error: Could not load model from {checkpoint_path}")

# torch.load unpickles the file; only load checkpoints from a trusted source.
checkpoint = torch.load(checkpoint_path, map_location=DEVICE)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
print("Model loaded successfully")

# Prepare validation dataset
# NOTE(review): this reads the same fruit folder as the (commented-out)
# training cell, so these metrics may include training images - confirm.
val_dataset = FruitQualityDataset(
    fruit_root='/kaggle/input/fruitnet/b6fftwbr2v-3/FruitNet_Processed Images/Processed Images_Fruits',
    non_fruit_root='/kaggle/input/random-images-dataset/random_images_dataset/training/all_images',
    transform=get_enhanced_transforms(train=False)
)

# Detection targets are dicts, so a batch stays a tuple of images and a tuple
# of targets instead of being stacked.
val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=2,
    collate_fn=lambda x: tuple(zip(*x))
)

# Run evaluation
print("Evaluating model with enhanced metrics...")
metrics = enhanced_evaluate_model(model, val_loader)

print("\nFinal Evaluation Results:")
print(f"Classification Accuracy: {metrics['classification_accuracy']:.4f}")
if metrics['quality_accuracy'] > 0:
    print(f"Quality Prediction Accuracy: {metrics['quality_accuracy']:.4f}")

In [9]:
pip install flask flask-cors pyngrok torch torchvision pillow


Note: you may need to restart the kernel to use updated packages.


In [10]:
from flask import Flask, request, jsonify, send_file
from flask_cors import CORS
from pyngrok import ngrok
from PIL import Image
import io
import torch
import numpy as np
import random
import matplotlib.pyplot as plt
import cv2
import os
import shap
import torchvision.transforms as transforms
from torchvision.ops import FeaturePyramidNetwork
from pymongo import MongoClient
import certifi
from datetime import datetime
from bson.objectid import ObjectId

def _require_env(name):
    """Return the value of environment variable ``name``.

    Raises:
        RuntimeError: The variable is unset or empty.
    """
    value = os.environ.get(name)
    if not value:
        raise RuntimeError(
            f"Environment variable {name} must be set before starting the API"
        )
    return value


# Configuration
IMG_SIZE = 512
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Secrets are read from the environment and never hard-coded in the notebook.
# SECURITY: the ngrok token and MongoDB password that used to be committed
# here must be treated as leaked and rotated.
ngrok.set_auth_token(_require_env("NGROK_AUTH_TOKEN"))

# MongoDB Atlas Configuration
# Full connection string, e.g. "mongodb+srv://<user>:<password>@<cluster>/"
MONGODB_URI = _require_env("MONGODB_URI")

client = MongoClient(MONGODB_URI, tlsCAFile=certifi.where())
db = client.get_database('FruitQualityDB')  # Using database named FruitQualityDB
predictions_collection = db.predictions

# Class definitions (must match training)
CLASS_NAMES = ['Apple', 'Banana', 'Guava', 'Lime', 'Orange', 'Pomegranate', 'Not_Fruit']
QUALITY_TYPES = ['Good', 'Bad']

# Initialize Flask app
app = Flask(__name__)
# Only the local front-end origin may call the API from a browser.
CORS(app, resources={r"/*": {"origins": "http://localhost:3000"}})

<flask_cors.extension.CORS at 0x7a05c7b3d990>

In [None]:
# Model served by the API endpoints below.
model = EnhancedFruitQualityModel(len(CLASS_NAMES), len(QUALITY_TYPES)).to(DEVICE)
checkpoint_path = '/kaggle/input/fruit-quality-prediction/pytorch/default/1/checkpoints/epoch_10.pth'

# Fail fast: without the checkpoint the API would serve predictions from
# randomly initialised detector and head weights.
if not os.path.exists(checkpoint_path):
    raise FileNotFoundError(f"Error: Could not load model from {checkpoint_path}")

# torch.load unpickles the file; only load checkpoints from a trusted source.
checkpoint = torch.load(checkpoint_path, map_location=DEVICE)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
print("Model loaded successfully")

In [None]:
# Image transformations (must match training)
def get_enhanced_transforms(train=True):
    """Build the image preprocessing pipeline used by the API.

    This redefines the function of the same name from the training part of the
    notebook with the same pipeline: resize to ``IMG_SIZE`` x ``IMG_SIZE``,
    optional random augmentations (``train`` truthy), tensor conversion and
    fixed per-channel normalisation.

    Args:
        train: Whether to include the random augmentations.

    Returns:
        A ``transforms.Compose`` ready to be applied to a PIL image.
    """
    head = [transforms.Resize((IMG_SIZE, IMG_SIZE))]
    tail = [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]

    middle = []
    if train:
        # Listed in the order they are applied.
        blur_and_sharpen = transforms.RandomApply([
            transforms.GaussianBlur(kernel_size=(3, 7)),
            transforms.RandomAdjustSharpness(sharpness_factor=2),
        ], p=0.3)
        shift_and_scale = transforms.RandomAffine(
            degrees=0, translate=(0.15, 0.15), scale=(0.8, 1.2))
        color_jitter = transforms.RandomApply([
            transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),
        ], p=0.8)
        middle = [
            blur_and_sharpen,
            shift_and_scale,
            transforms.RandomRotation(15),
            transforms.RandomVerticalFlip(0.3),
            transforms.RandomHorizontalFlip(0.5),
            color_jitter,
        ]

    return transforms.Compose(head + middle + tail)

# Helper function for GradCAM
def show_cam_on_image(img, mask, use_rgb=False):
    """Overlay a JET-coloured activation map on an image.

    Note: this definition shadows the ``show_cam_on_image`` imported from
    ``pytorch_grad_cam.utils.image`` earlier in the notebook.

    Args:
        img: Image array that is added to the colour map after conversion to
            float32. Presumably scaled to [0, 1] - confirm with callers.
        mask: Activation map; multiplied by 255 and cast to uint8 before
            colouring.
        use_rgb: Convert the colour map from BGR to RGB before blending.

    Returns:
        uint8 array: the sum of colour map and image, rescaled by its maximum
        and multiplied by 255.
    """
    colored = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
    if use_rgb:
        colored = cv2.cvtColor(colored, cv2.COLOR_BGR2RGB)

    overlay = np.float32(colored) / 255 + np.float32(img)
    overlay = overlay / np.max(overlay)
    return np.uint8(255 * overlay)

# ============ API ENDPOINTS =============

@app.route('/predict', methods=['POST'])
def predict():
    """Detect fruits in an uploaded image and store the result in MongoDB.

    Expects a multipart upload with the file under the ``image`` field.

    Responses:
        200 ``{'predictions': [...], 'db_id': <str or None>}`` - one entry per
            detection; ``db_id`` is ``None`` when storing the record failed.
        200 ``{'message': 'No fruit detected'}`` - the detector returned nothing
            (nothing is stored in that case).
        400 - no ``image`` field, or the upload is not a readable image.
        500 - any other failure, with the exception text.
    """
    if 'image' not in request.files:
        return jsonify({'error': 'No image provided'}), 400

    try:
        image_file = request.files['image']
        filename = image_file.filename  # Get filename before reading bytes
        image_bytes = image_file.read()

        # An undecodable upload is the client's fault, not a server error.
        try:
            image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
        except (OSError, ValueError) as image_error:
            return jsonify({'error': f'Invalid image: {image_error}'}), 400

        transform = get_enhanced_transforms(train=False)
        img_tensor = transform(image).to(DEVICE)

        with torch.no_grad():
            # The model takes a list of per-image tensors.
            output = model([img_tensor])[0]

        if len(output['labels']) == 0:
            return jsonify({'message': 'No fruit detected'})

        predictions = []
        for label, score, quality_scores, box in zip(
                output['labels'], output['scores'],
                output['quality_scores'], output['boxes']):
            predictions.append({
                'predicted_class': CLASS_NAMES[label.item()],
                'confidence': round(score.item() * 100, 2),
                'quality': QUALITY_TYPES[quality_scores.argmax().item()],
                'quality_confidence': round(quality_scores.max().item() * 100, 2),
                'bounding_box': box.cpu().numpy().tolist()
            })

        # Store prediction in MongoDB
        prediction_doc = {
            'timestamp': datetime.utcnow(),
            'filename': filename,  # Use the filename we captured earlier
            'predictions': predictions,
            'image_size': f"{image.width}x{image.height}",
            'image_bytes': image_bytes,
            'metadata': {
                'model_version': '1.0',
                'device': str(DEVICE)
            }
        }

        # Storage is best-effort: a database failure must not fail the request.
        db_id = None
        try:
            result = predictions_collection.insert_one(prediction_doc)
            db_id = str(result.inserted_id)
            print(f"Prediction stored with ID: {result.inserted_id}")
        except Exception as db_error:
            print(f"Database error: {db_error}")

        return jsonify({
            'predictions': predictions,
            'db_id': db_id
        })

    except Exception as e:
        return jsonify({'error': str(e)}), 500

@app.route('/predictions', methods=['GET'])
def get_predictions():
    """Return the 10 most recent stored predictions as a JSON array.

    Each element is a document from ``predictions_collection`` with its
    ``_id`` converted to a string and without the raw ``image_bytes`` payload.

    Returns:
        A JSON list on success, or ``({'error': <message>}, 500)`` if the
        database query or the serialization fails.
    """
    try:
        # Exclude the binary image with a projection so MongoDB never sends
        # it, instead of downloading every image only to delete it afterwards.
        cursor = (
            predictions_collection
            .find({}, {'image_bytes': 0})
            .sort('timestamp', -1)
            .limit(10)
        )

        recent_predictions = []
        for pred in cursor:
            # ObjectId is not JSON serializable; expose it as a string.
            pred['_id'] = str(pred['_id'])
            recent_predictions.append(pred)

        return jsonify(recent_predictions)
    except Exception as e:
        return jsonify({'error': str(e)}), 500

@app.route('/prediction/<prediction_id>', methods=['GET'])
def get_prediction(prediction_id):
    """Return the image that was stored alongside a prediction.

    Args:
        prediction_id: string form of the document's MongoDB ``_id``.

    Returns:
        The stored image bytes on success; ``({'error': ...}, 400)`` when
        ``prediction_id`` is not a valid ObjectId; ``({'error': ...}, 404)``
        when the document does not exist or has no ``image_bytes``;
        ``({'error': ...}, 500)`` on any other failure.
    """
    # A malformed id is a client error: reject it up front rather than letting
    # ObjectId() raise and surface as a 500.
    if not ObjectId.is_valid(prediction_id):
        return jsonify({'error': 'Invalid prediction id'}), 400

    try:
        prediction = predictions_collection.find_one({'_id': ObjectId(prediction_id)})

        if not prediction:
            return jsonify({'error': 'Prediction not found'}), 404

        if 'image_bytes' not in prediction:
            return jsonify({'error': 'Image not stored'}), 404

        # NOTE(review): the response is always labelled image/jpeg, whatever
        # format was originally uploaded.
        return send_file(
            io.BytesIO(prediction['image_bytes']),
            mimetype='image/jpeg'
        )
    except Exception as e:
        return jsonify({'error': str(e)}), 500

@app.route('/gradcam', methods=['POST'])
def gradcam():
    """Return a Grad-CAM overlay (PNG) for an uploaded image.

    Expects a multipart upload under the ``image`` field. The detector is run
    first and the request is rejected with a 400 when it finds nothing.
    Grad-CAM is then computed on ``model.features`` with
    ``model.gradcam_layer`` as the target layer, and the heatmap is blended
    onto the input resized to ``IMG_SIZE`` x ``IMG_SIZE``.

    Returns:
        The PNG image on success; ``({'error': ...}, 400)`` when no image is
        supplied or nothing is detected; a JSON error description with
        status 500 if any processing step raises.
    """
    if 'image' not in request.files:
        return jsonify({'error': 'No image provided'}), 400

    # Pre-bound so the ``finally`` clause can release them on every exit path.
    cam = None
    model_wrapper = None

    try:
        # Load and preprocess image
        image = Image.open(request.files['image']).convert('RGB')
        transform = get_enhanced_transforms(train=False)
        img_tensor = transform(image).unsqueeze(0).to(DEVICE)

        # Only explain images on which the detector found something.
        model.eval()
        with torch.no_grad():
            outputs = model(img_tensor)
            if len(outputs[0]['boxes']) == 0:
                return jsonify({'error': 'No objects detected'}), 400

        # Grad-CAM compatible wrapper
        class GradCAMWrapper(nn.Module):
            """Expose only the feature extractor of the detector to Grad-CAM."""

            def __init__(self, base_model):
                super().__init__()
                self.features = base_model.features
                self.target_layer = base_model.gradcam_layer

            def forward(self, x):
                return self.features(x)

        model_wrapper = GradCAMWrapper(model).to(DEVICE).eval()

        # Initialize Grad-CAM
        cam = GradCAM(
            model=model_wrapper,
            target_layers=[model_wrapper.target_layer]
        )

        # targets=None lets the library choose the target itself.
        # NOTE(review): confirm that choice is meaningful for the output of
        # ``model.features``.
        grayscale_cam = cam(input_tensor=img_tensor, targets=None)[0, :]

        # Prepare visualization
        rgb_img = np.array(image.resize((IMG_SIZE, IMG_SIZE)))
        rgb_img = np.float32(rgb_img) / 255
        grayscale_cam = cv2.resize(grayscale_cam, (rgb_img.shape[1], rgb_img.shape[0]))
        visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)

        # Convert to PNG
        img_byte_arr = io.BytesIO()
        Image.fromarray(visualization).save(img_byte_arr, format='PNG')
        img_byte_arr.seek(0)

        return send_file(img_byte_arr, mimetype='image/png')

    except Exception as e:
        import traceback
        traceback.print_exc()
        return jsonify({
            'error': f"GradCAM processing failed: {str(e)}",
            'type': type(e).__name__,
            'details': str(e.args)
        }), 500

    finally:
        # Cleanup runs on success and on failure: drop the CAM object and the
        # wrapper, then hand cached GPU memory back.
        cam = None
        model_wrapper = None
        if torch.cuda.is_available():
            torch.cuda.empty_cache()



@app.route('/shap', methods=['POST'])
def shap_explanations():
    """Return a SHAP importance figure (PNG) for an uploaded image.

    Expects a multipart upload under the ``image`` field. The top-scoring
    detection defines the class that is explained. The figure shows the
    de-normalized input next to the same image overlaid with the min-max
    normalized absolute SHAP values.

    Returns:
        The PNG figure on success; ``({'error': ...}, 400)`` when no image is
        supplied or nothing is detected; ``({'error': ...}, 500)`` if any
        step raises.
    """
    if 'image' not in request.files:
        return jsonify({'error': 'No image provided'}), 400

    try:
        # Load and process image
        img = Image.open(request.files['image']).convert('RGB')
        transform = get_enhanced_transforms(train=False)
        img_tensor = transform(img).unsqueeze(0).to(DEVICE)

        # Get model prediction
        with torch.no_grad():
            outputs = model([img_tensor.squeeze(0)])
            if len(outputs[0]['labels']) == 0:
                return jsonify({'error': 'No objects detected'}), 400

            best_idx = outputs[0]['scores'].argmax()
            pred_class = outputs[0]['labels'][best_idx].item()
            pred_prob = outputs[0]['scores'][best_idx].item()

        # Prepare image for SHAP: undo the normalization so the masker and the
        # plots work on an HWC image in [0, 1].
        img_np = img_tensor.cpu().numpy()[0].transpose(1, 2, 0)
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        img_np = std * img_np + mean
        img_np = np.clip(img_np, 0, 1)

        # SHAP explainer
        def predict_fn(images):
            """Score a batch of HWC images in [0, 1].

            Returns an array of shape (batch, len(CLASS_NAMES)). Each row
            holds the top detection's score in that detection's class column,
            or 1.0 in the last column when nothing was detected.
            """
            images = torch.tensor(images.transpose(0, 3, 1, 2), dtype=torch.float32).to(DEVICE)
            # Normalize
            for c in range(3):
                images[:, c] = (images[:, c] - mean[c]) / std[c]

            with torch.no_grad():
                outputs = model([images[i] for i in range(images.shape[0])])
                probs = np.zeros((len(outputs), len(CLASS_NAMES)))
                for i, output in enumerate(outputs):
                    if len(output['scores']) > 0:
                        best_idx = output['scores'].argmax().item()
                        class_idx = output['labels'][best_idx].item()
                        probs[i, class_idx] = output['scores'][best_idx].item()
                    else:
                        probs[i, -1] = 1.0  # Not_Fruit
                return probs

        explainer = shap.Explainer(predict_fn, masker=shap.maskers.Image("blur(128,128)", img_np.shape))
        shap_values = explainer(img_np[np.newaxis, :], max_evals=100, outputs=[pred_class])

        # Create visualization: collapse the channel axis, then min-max scale.
        shap_vals = np.sum(shap_values.values[0], axis=2).squeeze()
        abs_shap = np.abs(shap_vals)
        abs_shap = (abs_shap - abs_shap.min()) / (abs_shap.max() - abs_shap.min() + 1e-8)

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
        try:
            ax1.imshow(img_np)
            ax1.set_title(f"Original\n{CLASS_NAMES[pred_class]} ({pred_prob:.2f})")
            ax1.axis('off')

            ax2.imshow(img_np)
            heatmap = ax2.imshow(abs_shap, cmap='jet', alpha=0.5)
            # Figure-bound calls avoid relying on pyplot's implicit "current
            # figure" inside a request handler.
            fig.colorbar(heatmap, ax=ax2, fraction=0.046, pad=0.04)
            ax2.set_title("SHAP Importance")
            ax2.axis('off')

            buf = io.BytesIO()
            fig.savefig(buf, format='png', bbox_inches='tight', dpi=100)
        finally:
            # Close this exact figure even if plotting or saving fails, so
            # failed requests do not accumulate open figures.
            plt.close(fig)
        buf.seek(0)

        return send_file(buf, mimetype='image/png')

    except Exception as e:
        return jsonify({'error': str(e)}), 500

if __name__ == '__main__':
    # Verify MongoDB connection. A failure is only reported; the server still
    # starts.
    try:
        client.admin.command('ping')
        print("Successfully connected to MongoDB Atlas")
    except Exception as e:
        print(f"MongoDB connection error: {e}")

    port = 5000
    public_url = ngrok.connect(port)
    print(f"API is running at: {public_url}")
    try:
        # Blocks until the server stops (e.g. the cell is interrupted).
        app.run(port=port)
    finally:
        # Close the tunnel so re-running the cell does not leave a stale one
        # open. ``connect`` may return a tunnel object or a plain URL string
        # depending on the pyngrok version, so handle both.
        ngrok.disconnect(getattr(public_url, 'public_url', public_url))

In [None]:
from flask import Flask, request, jsonify, send_file
from flask_cors import CORS
from pyngrok import ngrok
from PIL import Image
import io
import torch
import numpy as np
import random
import matplotlib.pyplot as plt
import cv2
import os
import shap
import torchvision.transforms as transforms
from torchvision.ops import FeaturePyramidNetwork
from pymongo import MongoClient
import certifi
from datetime import datetime
from bson.objectid import ObjectId
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# Configuration
IMG_SIZE = 512
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Secrets are read from the environment so they are never saved with the
# notebook. Any credential that was previously hard-coded here should be
# treated as leaked and rotated.
NGROK_AUTH_TOKEN = os.environ.get("NGROK_AUTH_TOKEN")
if not NGROK_AUTH_TOKEN:
    raise RuntimeError("Set the NGROK_AUTH_TOKEN environment variable before running this cell")
ngrok.set_auth_token(NGROK_AUTH_TOKEN)

# MongoDB Atlas Configuration
MONGODB_URI = os.environ.get("MONGODB_URI")
if not MONGODB_URI:
    raise RuntimeError("Set the MONGODB_URI environment variable before running this cell")

client = MongoClient(MONGODB_URI, tlsCAFile=certifi.where())
db = client.get_database('FruitQualityDB')  # Using database named FruitQualityDB
predictions_collection = db.predictions

# Class definitions (must match training)
CLASS_NAMES = ['Apple', 'Banana', 'Guava', 'Lime', 'Orange', 'Pomegranate', 'Not_Fruit']
QUALITY_TYPES = ['Good', 'Bad']

# Initialize Flask app; CORS only admits the local front-end origin.
app = Flask(__name__)
CORS(app, resources={r"/*": {"origins": "http://localhost:3000"}})

model = EnhancedFruitQualityModel(len(CLASS_NAMES), len(QUALITY_TYPES)).to(DEVICE)
checkpoint_path = '/kaggle/input/fruit-quality-prediction/pytorch/default/1/checkpoints/epoch_10.pth'

if os.path.exists(checkpoint_path):
    checkpoint = torch.load(checkpoint_path, map_location=DEVICE)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    print("Model loaded successfully")
else:
    # NOTE(review): the app keeps running with freshly initialized weights in
    # this case; only this message signals the problem.
    print(f"Error: Could not load model from {checkpoint_path}")

# Image transformations (must match training)
def get_enhanced_transforms(train=True):
    """Build the image preprocessing pipeline.

    The pipeline always resizes to ``IMG_SIZE`` x ``IMG_SIZE``, converts to a
    tensor and normalizes with the mean/std listed below.

    Args:
        train: when true, random augmentations are placed between the resize
            and the tensor conversion.

    Returns:
        A ``transforms.Compose`` of the selected steps.
    """
    resize = transforms.Resize((IMG_SIZE, IMG_SIZE))
    tensor_steps = [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]

    augmentations = []
    if train:
        # Listed in the order in which they are applied.
        augmentations = [
            transforms.RandomApply([
                transforms.GaussianBlur(kernel_size=(3, 7)),
                transforms.RandomAdjustSharpness(sharpness_factor=2),
            ], p=0.3),
            transforms.RandomAffine(
                degrees=0, translate=(0.15, 0.15), scale=(0.8, 1.2)),
            transforms.RandomRotation(15),
            transforms.RandomVerticalFlip(0.3),
            transforms.RandomHorizontalFlip(0.5),
            transforms.RandomApply([
                transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),
            ], p=0.8),
        ]

    return transforms.Compose([resize, *augmentations, *tensor_steps])

# Helper function for GradCAM
def show_cam_on_image(img, mask, use_rgb=False):
    """Blend a JET-coloured version of ``mask`` onto ``img``.

    Args:
        img: image array that is added to the heatmap after the heatmap has
            been scaled to [0, 1] (assumes ``img`` is also in [0, 1] — TODO
            confirm with callers).
        mask: array that is multiplied by 255 and colour-mapped.
        use_rgb: convert the OpenCV BGR heatmap to RGB before blending.

    Returns:
        A uint8 array: the sum of heatmap and image, rescaled by its maximum.
    """
    colored = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
    if use_rgb:
        colored = cv2.cvtColor(colored, cv2.COLOR_BGR2RGB)

    overlay = np.float32(colored) / 255 + np.float32(img)
    overlay = overlay / np.max(overlay)
    return np.uint8(255 * overlay)

# ============ API ENDPOINTS =============

@app.route('/predict', methods=['POST'])
def predict():
    """Run the detector on an uploaded image and store the result.

    Expects a multipart upload under the ``image`` field. Every detection is
    reported with its class, confidence, quality label, quality confidence
    and bounding box. The prediction, together with the raw upload, is
    written to ``predictions_collection`` on a best-effort basis: a database
    failure is printed and the predictions are still returned.

    Returns:
        ``{'predictions': [...], 'db_id': <str or None>}`` on success;
        ``{'message': 'No fruit detected'}`` when nothing is detected;
        ``({'error': ...}, 400)`` when no image is supplied;
        ``({'error': ...}, 500)`` if processing raises.
    """
    if 'image' not in request.files:
        return jsonify({'error': 'No image provided'}), 400

    try:
        image_file = request.files['image']
        filename = image_file.filename  # Get filename before reading bytes
        image_bytes = image_file.read()
        image = Image.open(io.BytesIO(image_bytes)).convert('RGB')

        transform = get_enhanced_transforms(train=False)
        # The model is called with a list of CHW tensors, so no batch
        # dimension is added.
        img_tensor = transform(image).to(DEVICE)

        with torch.no_grad():
            output = model([img_tensor])[0]

            if len(output['labels']) == 0:
                return jsonify({'message': 'No fruit detected'})

            predictions = []
            for i in range(len(output['labels'])):
                label_id = output['labels'][i].item()
                score = output['scores'][i].item()
                quality_idx = output['quality_scores'][i].argmax().item()
                quality_conf = output['quality_scores'][i].max().item()
                box = output['boxes'][i].cpu().numpy().tolist()

                predictions.append({
                    'predicted_class': CLASS_NAMES[label_id],
                    'confidence': round(score * 100, 2),
                    'quality': QUALITY_TYPES[quality_idx],
                    'quality_confidence': round(quality_conf * 100, 2),
                    'bounding_box': box
                })

        # Store prediction in MongoDB
        prediction_doc = {
            'timestamp': datetime.utcnow(),
            'filename': filename,  # Use the filename we captured earlier
            'predictions': predictions,
            'image_size': f"{image.width}x{image.height}",
            'image_bytes': image_bytes,
            'metadata': {
                'model_version': '1.0',
                'device': str(DEVICE)
            }
        }

        # Stays None when the insert fails; replaces the fragile
        # ``'result' in locals()`` check.
        db_id = None
        try:
            result = predictions_collection.insert_one(prediction_doc)
            db_id = str(result.inserted_id)
            print(f"Prediction stored with ID: {result.inserted_id}")
        except Exception as db_error:
            print(f"Database error: {db_error}")

        return jsonify({
            'predictions': predictions,
            'db_id': db_id
        })

    except Exception as e:
        return jsonify({'error': str(e)}), 500

@app.route('/predictions', methods=['GET'])
def get_predictions():
    """Return the 10 most recent stored predictions as a JSON array.

    Each element is a document from ``predictions_collection`` with its
    ``_id`` converted to a string and without the raw ``image_bytes`` payload.

    Returns:
        A JSON list on success, or ``({'error': <message>}, 500)`` if the
        database query or the serialization fails.
    """
    try:
        # Exclude the binary image with a projection so MongoDB never sends
        # it, instead of downloading every image only to delete it afterwards.
        cursor = (
            predictions_collection
            .find({}, {'image_bytes': 0})
            .sort('timestamp', -1)
            .limit(10)
        )

        recent_predictions = []
        for pred in cursor:
            # ObjectId is not JSON serializable; expose it as a string.
            pred['_id'] = str(pred['_id'])
            recent_predictions.append(pred)

        return jsonify(recent_predictions)
    except Exception as e:
        return jsonify({'error': str(e)}), 500

@app.route('/prediction/<prediction_id>', methods=['GET'])
def get_prediction(prediction_id):
    """Return the image that was stored alongside a prediction.

    Args:
        prediction_id: string form of the document's MongoDB ``_id``.

    Returns:
        The stored image bytes on success; ``({'error': ...}, 400)`` when
        ``prediction_id`` is not a valid ObjectId; ``({'error': ...}, 404)``
        when the document does not exist or has no ``image_bytes``;
        ``({'error': ...}, 500)`` on any other failure.
    """
    # A malformed id is a client error: reject it up front rather than letting
    # ObjectId() raise and surface as a 500.
    if not ObjectId.is_valid(prediction_id):
        return jsonify({'error': 'Invalid prediction id'}), 400

    try:
        prediction = predictions_collection.find_one({'_id': ObjectId(prediction_id)})

        if not prediction:
            return jsonify({'error': 'Prediction not found'}), 404

        if 'image_bytes' not in prediction:
            return jsonify({'error': 'Image not stored'}), 404

        # NOTE(review): the response is always labelled image/jpeg, whatever
        # format was originally uploaded.
        return send_file(
            io.BytesIO(prediction['image_bytes']),
            mimetype='image/jpeg'
        )
    except Exception as e:
        return jsonify({'error': str(e)}), 500

@app.route('/gradcam', methods=['POST'])
def gradcam():
    """Return a Grad-CAM overlay (PNG) for an uploaded image.

    Expects a multipart upload under the ``image`` field. The detector is run
    first and the request is rejected with a 400 when it finds nothing.
    Grad-CAM is then computed on ``model.features`` with
    ``model.gradcam_layer`` as the target layer, and the heatmap is blended
    onto the input resized to ``IMG_SIZE`` x ``IMG_SIZE``.

    Returns:
        The PNG image on success; ``({'error': ...}, 400)`` when no image is
        supplied or nothing is detected; a JSON error description with
        status 500 if any processing step raises.
    """
    if 'image' not in request.files:
        return jsonify({'error': 'No image provided'}), 400

    # Pre-bound so the ``finally`` clause can release them on every exit path.
    cam = None
    model_wrapper = None

    try:
        # Load and preprocess image
        image = Image.open(request.files['image']).convert('RGB')
        transform = get_enhanced_transforms(train=False)
        img_tensor = transform(image).unsqueeze(0).to(DEVICE)

        # Only explain images on which the detector found something.
        model.eval()
        with torch.no_grad():
            outputs = model(img_tensor)
            if len(outputs[0]['boxes']) == 0:
                return jsonify({'error': 'No objects detected'}), 400

        # Grad-CAM compatible wrapper
        class GradCAMWrapper(nn.Module):
            """Expose only the feature extractor of the detector to Grad-CAM."""

            def __init__(self, base_model):
                super().__init__()
                self.features = base_model.features
                self.target_layer = base_model.gradcam_layer

            def forward(self, x):
                return self.features(x)

        model_wrapper = GradCAMWrapper(model).to(DEVICE).eval()

        # Initialize Grad-CAM
        cam = GradCAM(
            model=model_wrapper,
            target_layers=[model_wrapper.target_layer]
        )

        # targets=None lets the library choose the target itself.
        # NOTE(review): confirm that choice is meaningful for the output of
        # ``model.features``.
        grayscale_cam = cam(input_tensor=img_tensor, targets=None)[0, :]

        # Prepare visualization
        rgb_img = np.array(image.resize((IMG_SIZE, IMG_SIZE)))
        rgb_img = np.float32(rgb_img) / 255
        grayscale_cam = cv2.resize(grayscale_cam, (rgb_img.shape[1], rgb_img.shape[0]))
        visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)

        # Convert to PNG
        img_byte_arr = io.BytesIO()
        Image.fromarray(visualization).save(img_byte_arr, format='PNG')
        img_byte_arr.seek(0)

        return send_file(img_byte_arr, mimetype='image/png')

    except Exception as e:
        import traceback
        traceback.print_exc()
        return jsonify({
            'error': f"GradCAM processing failed: {str(e)}",
            'type': type(e).__name__,
            'details': str(e.args)
        }), 500

    finally:
        # Cleanup runs on success and on failure: drop the CAM object and the
        # wrapper, then hand cached GPU memory back.
        cam = None
        model_wrapper = None
        if torch.cuda.is_available():
            torch.cuda.empty_cache()



@app.route('/shap', methods=['POST'])
def shap_explanations():
    """Return a SHAP importance figure (PNG) for an uploaded image.

    Expects a multipart upload under the ``image`` field. The top-scoring
    detection defines the class that is explained. The figure shows the
    de-normalized input next to the same image overlaid with the min-max
    normalized absolute SHAP values.

    Returns:
        The PNG figure on success; ``({'error': ...}, 400)`` when no image is
        supplied or nothing is detected; ``({'error': ...}, 500)`` if any
        step raises.
    """
    if 'image' not in request.files:
        return jsonify({'error': 'No image provided'}), 400

    try:
        # Load and process image
        img = Image.open(request.files['image']).convert('RGB')
        transform = get_enhanced_transforms(train=False)
        img_tensor = transform(img).unsqueeze(0).to(DEVICE)

        # Get model prediction
        with torch.no_grad():
            outputs = model([img_tensor.squeeze(0)])
            if len(outputs[0]['labels']) == 0:
                return jsonify({'error': 'No objects detected'}), 400

            best_idx = outputs[0]['scores'].argmax()
            pred_class = outputs[0]['labels'][best_idx].item()
            pred_prob = outputs[0]['scores'][best_idx].item()

        # Prepare image for SHAP: undo the normalization so the masker and the
        # plots work on an HWC image in [0, 1].
        img_np = img_tensor.cpu().numpy()[0].transpose(1, 2, 0)
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        img_np = std * img_np + mean
        img_np = np.clip(img_np, 0, 1)

        # SHAP explainer
        def predict_fn(images):
            """Score a batch of HWC images in [0, 1].

            Returns an array of shape (batch, len(CLASS_NAMES)). Each row
            holds the top detection's score in that detection's class column,
            or 1.0 in the last column when nothing was detected.
            """
            images = torch.tensor(images.transpose(0, 3, 1, 2), dtype=torch.float32).to(DEVICE)
            # Normalize
            for c in range(3):
                images[:, c] = (images[:, c] - mean[c]) / std[c]

            with torch.no_grad():
                outputs = model([images[i] for i in range(images.shape[0])])
                probs = np.zeros((len(outputs), len(CLASS_NAMES)))
                for i, output in enumerate(outputs):
                    if len(output['scores']) > 0:
                        best_idx = output['scores'].argmax().item()
                        class_idx = output['labels'][best_idx].item()
                        probs[i, class_idx] = output['scores'][best_idx].item()
                    else:
                        probs[i, -1] = 1.0  # Not_Fruit
                return probs

        explainer = shap.Explainer(predict_fn, masker=shap.maskers.Image("blur(128,128)", img_np.shape))
        shap_values = explainer(img_np[np.newaxis, :], max_evals=100, outputs=[pred_class])

        # Create visualization: collapse the channel axis, then min-max scale.
        shap_vals = np.sum(shap_values.values[0], axis=2).squeeze()
        abs_shap = np.abs(shap_vals)
        abs_shap = (abs_shap - abs_shap.min()) / (abs_shap.max() - abs_shap.min() + 1e-8)

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
        try:
            ax1.imshow(img_np)
            ax1.set_title(f"Original\n{CLASS_NAMES[pred_class]} ({pred_prob:.2f})")
            ax1.axis('off')

            ax2.imshow(img_np)
            heatmap = ax2.imshow(abs_shap, cmap='jet', alpha=0.5)
            # Figure-bound calls avoid relying on pyplot's implicit "current
            # figure" inside a request handler.
            fig.colorbar(heatmap, ax=ax2, fraction=0.046, pad=0.04)
            ax2.set_title("SHAP Importance")
            ax2.axis('off')

            buf = io.BytesIO()
            fig.savefig(buf, format='png', bbox_inches='tight', dpi=100)
        finally:
            # Close this exact figure even if plotting or saving fails, so
            # failed requests do not accumulate open figures.
            plt.close(fig)
        buf.seek(0)

        return send_file(buf, mimetype='image/png')

    except Exception as e:
        return jsonify({'error': str(e)}), 500

if __name__ == '__main__':
    # Verify MongoDB connection. A failure is only reported; the server still
    # starts.
    try:
        client.admin.command('ping')
        print("Successfully connected to MongoDB Atlas")
    except Exception as e:
        print(f"MongoDB connection error: {e}")

    port = 5000
    public_url = ngrok.connect(port)
    print(f"API is running at: {public_url}")
    try:
        # Blocks until the server stops (e.g. the cell is interrupted).
        app.run(port=port)
    finally:
        # Close the tunnel so re-running the cell does not leave a stale one
        # open. ``connect`` may return a tunnel object or a plain URL string
        # depending on the pyngrok version, so handle both.
        ngrok.disconnect(getattr(public_url, 'public_url', public_url))

In [None]:
from flask import Flask, request, jsonify, send_file
from flask_cors import CORS
from pyngrok import ngrok
from PIL import Image
import io
import torch
import numpy as np
import random
import matplotlib.pyplot as plt
import cv2
import os
import shap
import torchvision.transforms as transforms
from torchvision.ops import FeaturePyramidNetwork
from pymongo import MongoClient
import certifi
from datetime import datetime
from bson.objectid import ObjectId
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# Configuration
IMG_SIZE = 512
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Secrets are read from the environment so they are never saved with the
# notebook. Any credential that was previously hard-coded here should be
# treated as leaked and rotated.
NGROK_AUTH_TOKEN = os.environ.get("NGROK_AUTH_TOKEN")
if not NGROK_AUTH_TOKEN:
    raise RuntimeError("Set the NGROK_AUTH_TOKEN environment variable before running this cell")
ngrok.set_auth_token(NGROK_AUTH_TOKEN)

# MongoDB Atlas Configuration
MONGODB_URI = os.environ.get("MONGODB_URI")
if not MONGODB_URI:
    raise RuntimeError("Set the MONGODB_URI environment variable before running this cell")

client = MongoClient(MONGODB_URI, tlsCAFile=certifi.where())
db = client.get_database('FruitQualityDB')  # Using database named FruitQualityDB
predictions_collection = db.predictions

# Class definitions (must match training)
CLASS_NAMES = ['Apple', 'Banana', 'Guava', 'Lime', 'Orange', 'Pomegranate', 'Not_Fruit']
QUALITY_TYPES = ['Good', 'Bad']

# Initialize Flask app; CORS only admits the local front-end origin.
app = Flask(__name__)
CORS(app, resources={r"/*": {"origins": "http://localhost:3000"}})

model = EnhancedFruitQualityModel(len(CLASS_NAMES), len(QUALITY_TYPES)).to(DEVICE)
checkpoint_path = '/kaggle/input/fruit-quality-prediction/pytorch/default/1/checkpoints/epoch_10.pth'

if os.path.exists(checkpoint_path):
    checkpoint = torch.load(checkpoint_path, map_location=DEVICE)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    print("Model loaded successfully")
else:
    # NOTE(review): the app keeps running with freshly initialized weights in
    # this case; only this message signals the problem.
    print(f"Error: Could not load model from {checkpoint_path}")

# Image transformations (must match training)
def get_enhanced_transforms(train=True):
    """Build the image preprocessing pipeline.

    The pipeline always resizes to ``IMG_SIZE`` x ``IMG_SIZE``, converts to a
    tensor and normalizes with the mean/std listed below.

    Args:
        train: when true, random augmentations are placed between the resize
            and the tensor conversion.

    Returns:
        A ``transforms.Compose`` of the selected steps.
    """
    resize = transforms.Resize((IMG_SIZE, IMG_SIZE))
    tensor_steps = [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]

    augmentations = []
    if train:
        # Listed in the order in which they are applied.
        augmentations = [
            transforms.RandomApply([
                transforms.GaussianBlur(kernel_size=(3, 7)),
                transforms.RandomAdjustSharpness(sharpness_factor=2),
            ], p=0.3),
            transforms.RandomAffine(
                degrees=0, translate=(0.15, 0.15), scale=(0.8, 1.2)),
            transforms.RandomRotation(15),
            transforms.RandomVerticalFlip(0.3),
            transforms.RandomHorizontalFlip(0.5),
            transforms.RandomApply([
                transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),
            ], p=0.8),
        ]

    return transforms.Compose([resize, *augmentations, *tensor_steps])

# Helper function for GradCAM
def show_cam_on_image(img, mask, use_rgb=False):
    """Blend a JET-coloured version of ``mask`` onto ``img``.

    Args:
        img: image array that is added to the heatmap after the heatmap has
            been scaled to [0, 1] (assumes ``img`` is also in [0, 1] — TODO
            confirm with callers).
        mask: array that is multiplied by 255 and colour-mapped.
        use_rgb: convert the OpenCV BGR heatmap to RGB before blending.

    Returns:
        A uint8 array: the sum of heatmap and image, rescaled by its maximum.
    """
    colored = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
    if use_rgb:
        colored = cv2.cvtColor(colored, cv2.COLOR_BGR2RGB)

    overlay = np.float32(colored) / 255 + np.float32(img)
    overlay = overlay / np.max(overlay)
    return np.uint8(255 * overlay)

# ============ API ENDPOINTS =============

@app.route('/predict', methods=['POST'])
def predict():
    """Run the detector on an uploaded image and store the result.

    Expects a multipart upload under the ``image`` field. Every detection is
    reported with its class, confidence, quality label, quality confidence
    and bounding box. The prediction, together with the raw upload, is
    written to ``predictions_collection`` on a best-effort basis: a database
    failure is printed and the predictions are still returned.

    Returns:
        ``{'predictions': [...], 'db_id': <str or None>}`` on success;
        ``{'message': 'No fruit detected'}`` when nothing is detected;
        ``({'error': ...}, 400)`` when no image is supplied;
        ``({'error': ...}, 500)`` if processing raises.
    """
    if 'image' not in request.files:
        return jsonify({'error': 'No image provided'}), 400

    try:
        image_file = request.files['image']
        filename = image_file.filename  # Get filename before reading bytes
        image_bytes = image_file.read()
        image = Image.open(io.BytesIO(image_bytes)).convert('RGB')

        transform = get_enhanced_transforms(train=False)
        # The model is called with a list of CHW tensors, so no batch
        # dimension is added.
        img_tensor = transform(image).to(DEVICE)

        with torch.no_grad():
            output = model([img_tensor])[0]

            if len(output['labels']) == 0:
                return jsonify({'message': 'No fruit detected'})

            predictions = []
            for i in range(len(output['labels'])):
                label_id = output['labels'][i].item()
                score = output['scores'][i].item()
                quality_idx = output['quality_scores'][i].argmax().item()
                quality_conf = output['quality_scores'][i].max().item()
                box = output['boxes'][i].cpu().numpy().tolist()

                predictions.append({
                    'predicted_class': CLASS_NAMES[label_id],
                    'confidence': round(score * 100, 2),
                    'quality': QUALITY_TYPES[quality_idx],
                    'quality_confidence': round(quality_conf * 100, 2),
                    'bounding_box': box
                })

        # Store prediction in MongoDB
        prediction_doc = {
            'timestamp': datetime.utcnow(),
            'filename': filename,  # Use the filename we captured earlier
            'predictions': predictions,
            'image_size': f"{image.width}x{image.height}",
            'image_bytes': image_bytes,
            'metadata': {
                'model_version': '1.0',
                'device': str(DEVICE)
            }
        }

        # Stays None when the insert fails; replaces the fragile
        # ``'result' in locals()`` check.
        db_id = None
        try:
            result = predictions_collection.insert_one(prediction_doc)
            db_id = str(result.inserted_id)
            print(f"Prediction stored with ID: {result.inserted_id}")
        except Exception as db_error:
            print(f"Database error: {db_error}")

        return jsonify({
            'predictions': predictions,
            'db_id': db_id
        })

    except Exception as e:
        return jsonify({'error': str(e)}), 500

@app.route('/predictions', methods=['GET'])
def get_predictions():
    """Return the 10 most recent stored predictions as a JSON array.

    Each element is a document from ``predictions_collection`` with its
    ``_id`` converted to a string and without the raw ``image_bytes`` payload.

    Returns:
        A JSON list on success, or ``({'error': <message>}, 500)`` if the
        database query or the serialization fails.
    """
    try:
        # Exclude the binary image with a projection so MongoDB never sends
        # it, instead of downloading every image only to delete it afterwards.
        cursor = (
            predictions_collection
            .find({}, {'image_bytes': 0})
            .sort('timestamp', -1)
            .limit(10)
        )

        recent_predictions = []
        for pred in cursor:
            # ObjectId is not JSON serializable; expose it as a string.
            pred['_id'] = str(pred['_id'])
            recent_predictions.append(pred)

        return jsonify(recent_predictions)
    except Exception as e:
        return jsonify({'error': str(e)}), 500

@app.route('/prediction/<prediction_id>', methods=['GET'])
def get_prediction(prediction_id):
    """Return the image that was stored alongside a prediction.

    Args:
        prediction_id: string form of the document's MongoDB ``_id``.

    Returns:
        The stored image bytes on success; ``({'error': ...}, 400)`` when
        ``prediction_id`` is not a valid ObjectId; ``({'error': ...}, 404)``
        when the document does not exist or has no ``image_bytes``;
        ``({'error': ...}, 500)`` on any other failure.
    """
    # A malformed id is a client error: reject it up front rather than letting
    # ObjectId() raise and surface as a 500.
    if not ObjectId.is_valid(prediction_id):
        return jsonify({'error': 'Invalid prediction id'}), 400

    try:
        prediction = predictions_collection.find_one({'_id': ObjectId(prediction_id)})

        if not prediction:
            return jsonify({'error': 'Prediction not found'}), 404

        if 'image_bytes' not in prediction:
            return jsonify({'error': 'Image not stored'}), 404

        # NOTE(review): the response is always labelled image/jpeg, whatever
        # format was originally uploaded.
        return send_file(
            io.BytesIO(prediction['image_bytes']),
            mimetype='image/jpeg'
        )
    except Exception as e:
        # The Grad-CAM fragment that used to follow this return, and the
        # second ``except Exception`` after it, were unreachable leftovers
        # from an earlier edit and have been removed.
        return jsonify({'error': str(e)}), 500


@app.route('/gradcam', methods=['POST'])
def gradcam():
    """Return a heatmap-style overlay (PNG) centred on the top detection.

    NOTE: despite the route name, this handler does not compute Grad-CAM. No
    gradients or activations are used. It draws a synthetic radial heat spot
    at the centre of the highest-scoring bounding box, blends it 50/50 with
    the resized input, and draws the box and its label on top.

    Expects a multipart upload under the ``image`` field.

    Returns:
        The PNG image on success; ``({'error': ...}, 400)`` when no image is
        supplied or nothing is detected; a JSON error description with
        status 500 if any processing step raises.
    """
    if 'image' not in request.files:
        return jsonify({'error': 'No image provided'}), 400

    try:
        # Load and preprocess image
        image_file = request.files['image']
        image = Image.open(image_file).convert('RGB')
        transform = get_enhanced_transforms(train=False)
        img_tensor = transform(image).unsqueeze(0).to(DEVICE)

        # Get model prediction
        with torch.no_grad():
            outputs = model([img_tensor.squeeze(0)])
            if len(outputs[0]['labels']) == 0:
                return jsonify({'error': 'No objects detected'}), 400

            best_idx = outputs[0]['scores'].argmax()
            target_class = outputs[0]['labels'][best_idx].item()
            target_box = outputs[0]['boxes'][best_idx].cpu().numpy()

        # Background image at the same resolution the model was fed.
        rgb_img = np.array(image.resize((IMG_SIZE, IMG_SIZE)))
        rgb_img = np.float32(rgb_img) / 255
        height, width = rgb_img.shape[:2]

        # Calculate center of bounding box
        # (presumably in resized-image pixel coordinates — verify against the
        # model's output convention).
        x1, y1, x2, y2 = target_box
        center_x = int((x1 + x2) / 2)
        center_y = int((y1 + y2) / 2)

        # Radial falloff: 1 at the centre, decreasing linearly to 0 at
        # ``radius``. Computed on a broadcast grid instead of a per-pixel
        # Python loop.
        radius = min(width, height) // 4
        ys, xs = np.ogrid[:height, :width]
        dist = np.sqrt((xs - center_x) ** 2 + (ys - center_y) ** 2)
        mask = np.where(dist < radius, 1 - dist / radius, 0).astype(np.float32)

        # Apply heatmap colors. OpenCV colormaps are BGR, while the image and
        # the PNG written below are RGB, so convert before blending; otherwise
        # the hot centre is rendered blue instead of red.
        heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
        heatmap = np.float32(heatmap) / 255

        # Combine with original image
        visualization = heatmap * 0.5 + rgb_img * 0.5
        visualization = np.clip(visualization, 0, 1)
        visualization = np.uint8(255 * visualization)

        # Draw bounding box
        x1, y1, x2, y2 = map(int, target_box)
        cv2.rectangle(visualization, (x1, y1), (x2, y2), (0, 255, 0), 2)

        # Add label
        label = f"{CLASS_NAMES[target_class]} ({outputs[0]['scores'][best_idx].item():.2f})"
        cv2.putText(visualization, label, (x1, y1-10), cv2.FONT_HERSHEY_SIMPLEX,
                   0.5, (0, 255, 0), 1, cv2.LINE_AA)

        # Convert to PNG
        img_byte_arr = io.BytesIO()
        Image.fromarray(visualization).save(img_byte_arr, format='PNG')
        img_byte_arr.seek(0)

        return send_file(img_byte_arr, mimetype='image/png')

    except Exception as e:
        import traceback
        traceback.print_exc()
        return jsonify({
            'error': f"GradCAM processing failed: {str(e)}",
            'type': type(e).__name__,
            'details': str(e.args)
        }), 500



def shap_explanations():
    """Return a PNG overlaying a SHAP importance heatmap on the uploaded image.

    Reads the upload from ``request.files['image']``, runs the detection
    ``model`` on it, and takes the highest-scoring detection as the class to
    explain. SHAP (blur image masker, 100 evaluations) then attributes that
    class score to image regions. The response is a two-panel figure: the
    input image on the left, and the same image with the normalised absolute
    SHAP values drawn on top on the right.

    Returns:
        A ``send_file`` PNG response on success. Otherwise a
        ``(jsonify(...), status)`` tuple: 400 when no ``image`` file is in the
        request or the model detects nothing, 500 when any exception is raised
        during processing.

    NOTE(review): no route decorator is visible directly above this function;
    confirm it is registered with the Flask app elsewhere.
    """
    if 'image' not in request.files:
        return jsonify({'error': 'No image provided'}), 400

    try:
        # Load and process image
        img = Image.open(request.files['image']).convert('RGB')
        transform = get_enhanced_transforms(train=False)
        img_tensor = transform(img).unsqueeze(0).to(DEVICE)

        # Get model prediction; the top-scoring detection is the one explained.
        with torch.no_grad():
            outputs = model([img_tensor.squeeze(0)])
            if len(outputs[0]['labels']) == 0:
                return jsonify({'error': 'No objects detected'}), 400

            best_idx = outputs[0]['scores'].argmax()
            pred_class = outputs[0]['labels'][best_idx].item()
            pred_prob = outputs[0]['scores'][best_idx].item()

        # Prepare image for SHAP: back to HWC and undo normalisation.
        # NOTE(review): assumes get_enhanced_transforms(train=False) applied
        # these ImageNet mean/std values — TODO confirm; if it did not, the
        # image shown and fed to SHAP is shifted.
        img_np = img_tensor.cpu().numpy()[0].transpose(1, 2, 0)
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        img_np = std * img_np + mean
        img_np = np.clip(img_np, 0, 1)

        def predict_fn(images):
            """Map a batch of HWC images in [0, 1] to per-class score rows.

            Each row holds the top detection's score in the column of its
            label and zeros elsewhere; a row with no detection gets 1.0 in
            the last column.
            """
            images = torch.tensor(images.transpose(0, 3, 1, 2), dtype=torch.float32).to(DEVICE)
            # Re-apply the same per-channel normalisation that was undone above.
            for c in range(3):
                images[:, c] = (images[:, c] - mean[c]) / std[c]

            with torch.no_grad():
                outputs = model([images[i] for i in range(images.shape[0])])
                probs = np.zeros((len(outputs), len(CLASS_NAMES)))
                for i, output in enumerate(outputs):
                    if len(output['scores']) > 0:
                        best_idx = output['scores'].argmax().item()
                        # NOTE(review): uses the label directly as a column
                        # index into CLASS_NAMES — confirm labels are aligned
                        # with that list.
                        class_idx = output['labels'][best_idx].item()
                        probs[i, class_idx] = output['scores'][best_idx].item()
                    else:
                        probs[i, -1] = 1.0  # Not_Fruit
                return probs

        explainer = shap.Explainer(predict_fn, masker=shap.maskers.Image("blur(128,128)", img_np.shape))
        shap_values = explainer(img_np[np.newaxis, :], max_evals=100, outputs=[pred_class])

        # Collapse the colour-channel axis, then min-max scale |SHAP| to [0, 1]
        # (the epsilon guards against a constant map).
        shap_vals = np.sum(shap_values.values[0], axis=2).squeeze()
        abs_shap = np.abs(shap_vals)
        abs_shap = (abs_shap - abs_shap.min()) / (abs_shap.max() - abs_shap.min() + 1e-8)

        # Create visualization. Everything goes through the explicit `fig`
        # handle and the figure is closed in `finally`, so a failure while
        # drawing or saving cannot leave an open figure behind in the
        # long-running server process.
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
        try:
            ax1.imshow(img_np)
            ax1.set_title(f"Original\n{CLASS_NAMES[pred_class]} ({pred_prob:.2f})")
            ax1.axis('off')

            ax2.imshow(img_np)
            heatmap = ax2.imshow(abs_shap, cmap='jet', alpha=0.5)
            fig.colorbar(heatmap, ax=ax2, fraction=0.046, pad=0.04)
            ax2.set_title("SHAP Importance")
            ax2.axis('off')

            buf = io.BytesIO()
            fig.savefig(buf, format='png', bbox_inches='tight', dpi=100)
        finally:
            plt.close(fig)
        buf.seek(0)

        return send_file(buf, mimetype='image/png')

    except Exception as e:
        # Log the full traceback server-side before returning the 500; the
        # JSON payload sent to the client is unchanged.
        import traceback
        traceback.print_exc()
        return jsonify({'error': str(e)}), 500

# Script entry point: check the database connection, open a public ngrok
# tunnel to the local port, then start the web app on that port.
if __name__ == '__main__':
    # Verify MongoDB connection
    # A failed ping is only reported; startup continues either way.
    try:
        client.admin.command('ping')
        print("Successfully connected to MongoDB Atlas")
    except Exception as e:
        print(f"MongoDB connection error: {e}")

    port = 5000
    # Expose the local port through ngrok and print the resulting public URL.
    # NOTE(review): nothing in this block closes the tunnel once the server
    # stops — confirm whether cleanup is handled elsewhere.
    public_url = ngrok.connect(port)
    print(f"API is running at: {public_url}")
    # Start serving on the same port the tunnel forwards to
    # (presumably blocks until the server is stopped).
    app.run(port=port)

Model loaded successfully
Successfully connected to MongoDB Atlas
API is running at: NgrokTunnel: "https://5109-35-229-197-217.ngrok-free.app" -> "http://localhost:5000"
 * Serving Flask app '__main__'
 * Debug mode: off
Prediction stored with ID: 685900c7b30d650e193d5988
Prediction stored with ID: 685900d7b30d650e193d5989
