In [1]:
import torch
from torchvision.models import resnet18, ResNet18_Weights
from torchvision import transforms
from PIL import Image
import cv2
import numpy as np
import time

In [2]:
# Build a ResNet-18 initialised from the ImageNet-1k (V1) weights.
# The weights enum is passed explicitly because the older way of requesting
# pretrained weights emitted a flood of warnings.
num_classes = 2  # binary head; the inference code below treats index 1 as "rotten", anything else as "fresh"
weights = ResNet18_Weights.IMAGENET1K_V1
model = resnet18(weights=weights)

# Swap the ImageNet classifier for a fresh `num_classes`-way linear layer that
# keeps the same input width as the head it replaces.
model.fc = torch.nn.Linear(in_features=model.fc.in_features, out_features=num_classes)


In [3]:
# Directly load the fine-tuned parameters (only if the model has already been trained).
# TODO: adjust this path for your machine, or make it relative to the notebook.
MODEL_PATH = r"C:\Users\User\Desktop\DeepLearning\Major Project\resnet_model_parameters.pth"
# map_location="cpu": inference in this notebook runs on CPU, and remapping lets a
# checkpoint that was saved from a GPU session load on a machine without CUDA.
# NOTE(review): torch.load unpickles the file, so only load checkpoints you trust
# (on torch >= 1.13, consider passing weights_only=True).
state_dict = torch.load(MODEL_PATH, map_location="cpu")
model.load_state_dict(state_dict)
model.eval()  # Set the model to evaluation mode (important for inference)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [4]:
# Preprocessing the frame
# The pipeline is stateless, so build it once instead of on every call.
_PREPROCESS_TRANSFORM = transforms.Compose([
    transforms.Resize((224, 224)),  # ResNet standard input size
    transforms.ToTensor(),
    # Standard ImageNet normalization; statistics are listed in R, G, B order.
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def preprocess_frame(frame, bgr=True):
    """Convert a camera frame into a normalized, batched input tensor.

    Args:
        frame: NumPy image array, e.g. as returned by ``cv2.VideoCapture.read``
            (presumably H x W x 3 uint8 -- confirm for other sources).
        bgr: If True (default), a 3-channel ``frame`` is treated as OpenCV's
            BGR channel order and converted to RGB first. Pass False for
            frames that are already RGB.

    Returns:
        A float tensor of shape (1, C, 224, 224) -- C is 3 for colour frames.
    """
    if bgr and frame.ndim == 3 and frame.shape[-1] == 3:
        # OpenCV delivers BGR, while PIL and the RGB normalization above expect RGB.
        # NOTE(review): assumes the training images were loaded as RGB (e.g. via PIL).
        frame = np.ascontiguousarray(frame[..., ::-1])
    image = Image.fromarray(frame)  # Convert NumPy array (OpenCV) to PIL Image
    return _PREPROCESS_TRANSFORM(image).unsqueeze(0)  # Add batch dimension

# Classify the banana
def classify_banana(model, frame):
    """Run ``model`` on a single frame and return the predicted class index.

    The caller is responsible for putting ``model`` in eval mode (this notebook
    calls ``model.eval()`` after loading the checkpoint). The real-time loop
    treats index 1 as "rotten" and any other index as "fresh".

    Args:
        model: A ``torch.nn.Module`` whose output is assumed to be logits of
            shape (1, num_classes).
        frame: Image array accepted by ``preprocess_frame``.

    Returns:
        int: index of the largest logit along the class dimension.
    """
    input_tensor = preprocess_frame(frame)
    # Keep the input on the same device as the model (a no-op for CPU models).
    first_param = next(model.parameters(), None)
    if first_param is not None:
        input_tensor = input_tensor.to(first_param.device)
    # Pure inference: skip autograd bookkeeping to save time and memory.
    with torch.no_grad():
        output = model(input_tensor)
    return torch.argmax(output, dim=1).item()
# Real-time classification
def real_time_classification(model, camera_index=0, interval=4.0):
    """Show a live camera feed and periodically classify the current frame.

    Opens camera ``camera_index`` with OpenCV and shows every frame in a
    "Banana Classifier" window. Every ``interval`` seconds it runs
    ``classify_banana`` on the latest frame and prints the verdict together
    with the inference time. Press 'q' in the window to stop. The camera and
    window are released even if the loop raises or is interrupted (e.g. a
    Jupyter kernel interrupt).

    Args:
        model: Classifier passed through to ``classify_banana``; expected to
            already be in eval mode.
        camera_index: OpenCV device index (0 = default camera).
        interval: Minimum number of seconds between two classifications.
    """
    cap = cv2.VideoCapture(camera_index)
    try:
        if not cap.isOpened():
            print(f"Could not open camera {camera_index}. Exiting...")
            return
        print("Starting real-time banana classification...")
        # -inf guarantees the very first frame is classified immediately.
        last_capture_time = float("-inf")
        while True:
            ret, frame = cap.read()
            if not ret:
                print("Failed to grab frame. Exiting...")
                break

            # Display the frame for visualization
            cv2.imshow("Banana Classifier", frame)

            # monotonic() is unaffected by wall-clock adjustments, unlike time().
            now = time.monotonic()
            if now - last_capture_time >= interval:
                last_capture_time = now
                start_time = time.perf_counter()
                prediction = classify_banana(model, frame)
                elapsed = time.perf_counter() - start_time

                # TODO: return/emit the prediction for hardware control
                # (currently 1 = rotten, anything else = fresh).
                if prediction == 1:
                    print(f"Rotten Banana Detected (Time: {elapsed:.2f}s)")
                else:
                    print(f"Fresh Banana Detected (Time: {elapsed:.2f}s)")

            # waitKey also drives the HighGUI event loop so imshow actually renders.
            if cv2.waitKey(1) & 0xFF == ord('q'):  # Press 'q' to quit
                break
    finally:
        cap.release()
        cv2.destroyAllWindows()

In [5]:
# Opens the default camera; press 'q' in the "Banana Classifier" window to stop.
real_time_classification(model=model)

Starting real-time banana classification...
Rotten Banana Detected (Time: 0.22s)
Rotten Banana Detected (Time: 0.09s)
Rotten Banana Detected (Time: 0.06s)
Rotten Banana Detected (Time: 0.09s)
Rotten Banana Detected (Time: 0.08s)
Rotten Banana Detected (Time: 0.09s)
Rotten Banana Detected (Time: 0.08s)
Rotten Banana Detected (Time: 0.10s)
Rotten Banana Detected (Time: 0.08s)
Rotten Banana Detected (Time: 0.08s)
Rotten Banana Detected (Time: 0.08s)
Rotten Banana Detected (Time: 0.09s)
Rotten Banana Detected (Time: 0.08s)
Rotten Banana Detected (Time: 0.07s)
