In [1]:
import torch
from transformers import AutoModelForImageClassification, AutoImageProcessor
from torchvision import transforms
from PIL import Image
import cv2

In [2]:
model = AutoModelForImageClassification.from_pretrained("jazzmacedo/fruits-and-vegetables-detector-36")

In [3]:
# Define the preprocessing transformation (for compatibility with AutoImageProcessor)
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

In [4]:
def classify_image(image_path: str) -> list:
    # Read and preprocess the image
    image = cv2.imread(image_path)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    pil_image = Image.fromarray(image)
    input_tensor = preprocess(pil_image).unsqueeze(0)

    # Get predictions without gradients
    with torch.no_grad():
        outputs = model(input_tensor)
    
    # Get top 5 predictions
    top_5_probs, top_5_indices = torch.topk(outputs.logits, k=5, dim=1)
    top_5_probs = top_5_probs.squeeze().tolist()
    top_5_indices = top_5_indices.squeeze().tolist()

    # Map indices to labels
    top_5_predictions = [(model.config.id2label[idx], prob) for idx, prob in zip(top_5_indices, top_5_probs)]
    
    return top_5_predictions

In [6]:
# Example usage
image_path = "../imgs/duck.jpg"  # Update with your image path
label = classify_image(image_path)
print("Detected label:", label)

Detected label: [('pear', -0.7279641032218933), ('onion', -1.9550637006759644), ('garlic', -1.9866228103637695), ('turnip', -3.0375967025756836), ('potato', -3.1575419902801514)]
