In [1]:
"""
Real-Time Fruit Freshness Detection using a Trained VGG16 Classifier
Can predict from static image or webcam feed.
"""

import torch
import torch.nn as nn
import torchvision.transforms.v2 as T
import numpy as np
import cv2
from PIL import Image
from torchvision.models import vgg16, VGG16_Weights

In [2]:
# Class info
CLASS_LABELS = ["fresh apple", "fresh banana", "fresh orange", "rotten apple", "rotten banana", "rotten orange"]
FRUIT_TYPE = {0: "Apple", 1: "Banana", 2: "Orange", 3: "Apple", 4: "Banana", 5: "Orange"}
FRUIT_STATE = {0: "Fresh", 1: "Fresh", 2: "Fresh", 3: "Rotten", 4: "Rotten", 5: "Rotten"}

# Load model architecture and weights
def load_model(path="fruit_model.pth"):
    model = vgg16(weights=VGG16_Weights.DEFAULT)
    model.classifier[6] = nn.Linear(4096, 6)

    custom_model = nn.Sequential(
        model.features,
        model.avgpool,
        nn.Flatten(),
        model.classifier[0:3],
        nn.Dropout(0.3),
        nn.Linear(4096, 500),
        nn.ReLU(),
        nn.Dropout(0.3),
        nn.Linear(500, 6)
    )
    custom_model.load_state_dict(torch.load(path, map_location=device))
    return custom_model.to(device).eval()


In [3]:
# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Transformation for inference
transform = T.Compose([
    T.Resize((224, 224)),
    T.ToImage(),
    T.ToDtype(torch.float32, scale=True),
    T.Normalize(mean=(0.485, 0.456, 0.406),
                std=(0.229, 0.224, 0.225))
])

In [28]:
# Predict from PIL or OpenCV image

def predict_and_annotate(image, model, device):
    original_h, original_w = image.shape[:2]
    img_resized = cv2.resize(image, (224, 224))
    img_pil = Image.fromarray(cv2.cvtColor(img_resized, cv2.COLOR_BGR2RGB))
    img_tensor = transform(img_pil).unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(img_tensor)
        class_id = output.argmax(1).item()

    label = f"{FRUIT_TYPE[class_id]} - {FRUIT_STATE[class_id]}"

    # Draw square on center of original image resolution
    box_w, box_h = original_w // 3, original_h // 3
    x1 = (original_w - box_w) // 2
    y1 = (original_h - box_h) // 2
    x2 = x1 + box_w
    y2 = y1 + box_h
    cv2.rectangle(image, (x1, y1), (x2, y2), color=(0, 255, 0), thickness=3)
    cv2.putText(image, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0, 0, 255), 2)

    return image

In [29]:
# Predict from image path
def predict_from_image(path, model, device):
    image = cv2.imread(path)
    annotated = predict_and_annotate(image, model, device)
    cv2.imshow("Prediction", annotated)
    cv2.waitKey(0)
    cv2.destroyAllWindows()


In [6]:
# Predict from webcam
def predict_from_camera(model, device):
    cap = cv2.VideoCapture(0)
    print("Press 'q' to quit")
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break

        annotated = predict_and_annotate(frame, model, device)
        cv2.imshow("Webcam Prediction", annotated)

        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

    cap.release()
    cv2.destroyAllWindows()


In [35]:
# Load model
my_model = load_model("C:/Users/okeiy/Downloads/Nvdia Learning/fruit_model_1.pth")

  custom_model.load_state_dict(torch.load(path, map_location=device))


In [36]:
# Use either one:

# predict_from_image("C:/Users/okeiy/Downloads/Nvdia Learning/5.jpg", my_model, device)

#  Or 

predict_from_camera(my_model, device)

# Note: Ensure the model path is correct and the model is trained with the same architecture.
# The webcam feed will show real-time predictions with annotations.

Press 'q' to quit
