In [2]:
import cv2
import torch
from torch import nn
from torchvision import transforms, models
from PIL import Image
import numpy as np

In [3]:
# Load the fully pickled model object used by every inference cell below.
#
# SECURITY: weights_only=False unpickles arbitrary Python objects, which can
# execute code embedded in the file. Only load checkpoints from a trusted source.
#
# map_location="cpu" keeps the weights on the CPU so they match the CPU tensors
# produced by `transform` below, even if the checkpoint was saved from a GPU
# (and so the load does not fail on a machine without CUDA).
model = torch.load("../models/model.pth", map_location="cpu", weights_only=False)
model.eval()  # inference mode (presumably an nn.Module: disables dropout / batch-norm updates)

# Preprocessing applied to every PIL image before it is fed to the model.
transform = transforms.Compose([
    transforms.Resize((128, 128)),  # fixed 128x128 input size
    transforms.ToTensor(),          # PIL image -> float tensor
    # NOTE(review): mean=0, std=1 leaves values unchanged; presumably kept to
    # mirror the training pipeline -- confirm before removing or changing.
    transforms.Normalize(0, 1)
])

In [4]:
# Human-readable labels, indexed by the class index the model predicts.
# NOTE(review): order (and spellings such as 'jalepeno' / 'raddish') presumably
# mirror the training dataset's label names -- do not "correct" or reorder
# without confirming against the training code.
class_names = [
    'apple',
    'banana',
    'beetroot',
    'bell pepper',
    'cabbage',
    'capsicum',
    'carrot',
    'cauliflower',
    'chilli pepper',
    'corn',
    'cucumber',
    'eggplant',
    'garlic',
    'ginger',
    'grapes',
    'jalepeno',
    'kiwi',
    'lemon',
    'lettuce',
    'mango',
    'onion',
    'orange',
    'paprika',
    'pear',
    'peas',
    'pineapple',
    'pomegranate',
    'potato',
    'raddish',
    'soy beans',
    'spinach',
    'sweetcorn',
    'sweetpotato',
    'tomato',
    'turnip',
    'watermelon',
]

# Image

In [5]:
# Classify a single still image read from disk and print the predicted label.
image_path = "banana.jpg"
image = cv2.imread(image_path)
if image is None:
    # cv2.imread does not raise on a missing/unreadable file; it returns None,
    # which would otherwise surface as an obscure cvtColor error below.
    raise FileNotFoundError(f"Could not read image file: {image_path}")

# OpenCV loads pixels in BGR order; convert to RGB before wrapping as a PIL image.
image_rgb = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))

image_tensor = transform(image_rgb).unsqueeze(0)  # Add batch dimension

# No gradients are needed for inference.
with torch.no_grad():
    output = model(image_tensor)
    # Index of the largest score along dim 1 is the predicted class.
    _, predicted = torch.max(output, 1)

class_idx = predicted.item()

predicted_label = class_names[class_idx]
print(f"Predicted label: {predicted_label}")

Predicted label: banana


# Video

In [8]:
# Run the classifier on every frame of a local video file and show the
# annotated frames in an OpenCV window until the video ends or 'q' is pressed.
video_path = "2025-02-18-20-51-03.mp4"  # Replace with the path to your video file
cap = cv2.VideoCapture(video_path)
font = cv2.FONT_HERSHEY_SIMPLEX  # loop-invariant, so set once outside the loop

# try/finally guarantees the capture handle and the display window are released
# even if inference raises or the kernel is interrupted mid-loop.
try:
    if not cap.isOpened():
        # Raise instead of calling exit(): inside a notebook exit() shuts down
        # the kernel, discarding the loaded model and all other state.
        raise IOError("Error: Could not open video file.")

    while True:
        ret, frame = cap.read()

        if not ret:
            print("End of video.")
            break

        # Convert the frame from BGR (OpenCV order) to RGB
        image_rgb = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

        # Apply the preprocessing transform
        image_tensor = transform(image_rgb).unsqueeze(0)  # Add batch dimension

        # Run the prediction (no gradients needed for inference)
        with torch.no_grad():
            output = model(image_tensor)
            _, predicted = torch.max(output, 1)

        # Look up the predicted label
        class_idx = predicted.item()
        predicted_label = class_names[class_idx]

        # Draw the prediction onto the frame
        cv2.putText(frame, f'Predicted: {predicted_label}', (10, 30), font, 1, (0, 255, 0), 2)

        # Show the frame
        cv2.imshow("Video Recognition", frame)

        # Press 'q' to quit
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
finally:
    # Release the video capture object and close the window
    cap.release()
    cv2.destroyAllWindows()

# Camera

In [9]:
# Run the classifier on live frames from camera device 0 and show the
# annotated frames in an OpenCV window until a read fails or 'q' is pressed.
cap = cv2.VideoCapture(0)
font = cv2.FONT_HERSHEY_SIMPLEX  # loop-invariant, so set once outside the loop

# try/finally guarantees the camera and the display window are released even
# if inference raises or the kernel is interrupted mid-loop.
try:
    if not cap.isOpened():
        # Raise with a message instead of a silent exit(): inside a notebook
        # exit() shuts down the kernel, discarding the loaded model, and gives
        # no hint about what went wrong.
        raise IOError("Error: Could not open camera.")

    while True:
        ret, frame = cap.read()
        if not ret:
            print("Error: Failed to capture image")
            break

        # Convert the captured frame to RGB and apply the preprocessing transform
        image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        image_tensor = transform(image).unsqueeze(0)  # Add batch dimension

        # Feed the image to the model (no gradients needed for inference)
        with torch.no_grad():
            outputs = model(image_tensor)  # model output scores
            _, predicted = torch.max(outputs, 1)  # index of the highest-scoring class

        # Get the predicted class index
        class_idx = predicted.item()

        # Draw the predicted label onto the frame
        predicted_class = class_names[class_idx]
        cv2.putText(frame, f'Predicted: {predicted_class}', (10, 30), font, 1, (0, 255, 0), 2)

        # Show the frame
        cv2.imshow("Fruit Recognition", frame)

        # Press 'q' to quit
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
finally:
    # Release the video capture object and close the window
    cap.release()
    cv2.destroyAllWindows()