In [1]:
import os
import torch
import json
import numpy as np
from PIL import Image
from torchvision import transforms
from transformers import AutoImageProcessor, AutoModelForImageClassification

In [2]:
# === SET DEVICE ===
print(torch.cuda.get_device_name(0) if torch.cuda.is_available() else "No CUDA device found")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

NVIDIA GeForce RTX 3050 Ti Laptop GPU


In [3]:
# === PATHS ===
test_image_dir = "New Plant Diseases Dataset//test" 
model_dir = "efficientnet-plant-disease"

In [4]:
# === LOAD LABEL MAP ===
with open(os.path.join(model_dir, "label_map.json"), "r") as f:
    label_map = json.load(f)
id2label = {int(k): v for k, v in label_map.items()}

In [5]:
# === LOAD PROCESSOR & MODEL ===
processor = AutoImageProcessor.from_pretrained(model_dir)
model = AutoModelForImageClassification.from_pretrained(model_dir).to(device)
model.eval()

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


EfficientNetForImageClassification(
  (efficientnet): EfficientNetModel(
    (embeddings): EfficientNetEmbeddings(
      (padding): ZeroPad2d((0, 1, 0, 1))
      (convolution): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=valid, bias=False)
      (batchnorm): BatchNorm2d(32, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)
      (activation): SiLU()
    )
    (encoder): EfficientNetEncoder(
      (blocks): ModuleList(
        (0): EfficientNetBlock(
          (depthwise_conv): EfficientNetDepthwiseLayer(
            (depthwise_conv_pad): ZeroPad2d((0, 1, 0, 1))
            (depthwise_conv): EfficientNetDepthwiseConv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=same, groups=32, bias=False)
            (depthwise_norm): BatchNorm2d(32, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)
            (depthwise_act): SiLU()
          )
          (squeeze_excite): EfficientNetSqueezeExciteLayer(
            (squeeze): AdaptiveAvgPool2d(output

In [6]:
# === TRANSFORM ===
transform = transforms.Compose([
    transforms.Resize((260, 260)),
    transforms.ToTensor(),
    transforms.Normalize(mean=processor.image_mean, std=processor.image_std)
])

In [7]:
# === INFERENCE FUNCTION ===
def predict_image(image_path):
    image = Image.open(image_path).convert("RGB")
    image = transform(image).unsqueeze(0).to(device)  # shape: [1, 3, 260, 260]
    
    with torch.no_grad():
        outputs = model(image).logits
        predicted_class_idx = torch.argmax(outputs, dim=1).item()
        predicted_label = id2label[predicted_class_idx]
    
    return predicted_label

In [8]:
# === RUN PREDICTION ON ALL IMAGES ===
print("🔍 Running inference on test images...\n")

image_files = [f for f in os.listdir(test_image_dir) if f.lower().endswith((".jpg", ".jpeg", ".png"))]

for img_file in sorted(image_files):
    img_path = os.path.join(test_image_dir, img_file)
    pred_label = predict_image(img_path)
    print(f"{img_file} --> Predicted: {pred_label}")

print(f"\n✅ Completed predictions on {len(image_files)} images.")

🔍 Running inference on test images...

AppleCedarRust1.JPG --> Predicted: Apple___Cedar_apple_rust
AppleCedarRust2.JPG --> Predicted: Apple___Cedar_apple_rust
AppleCedarRust3.JPG --> Predicted: Apple___Cedar_apple_rust
AppleCedarRust4.JPG --> Predicted: Apple___Cedar_apple_rust
AppleScab1.JPG --> Predicted: Apple___Apple_scab
AppleScab2.JPG --> Predicted: Apple___Apple_scab
AppleScab3.JPG --> Predicted: Apple___Apple_scab
CornCommonRust1.JPG --> Predicted: Corn_(maize)___Common_rust_
CornCommonRust2.JPG --> Predicted: Corn_(maize)___Common_rust_
CornCommonRust3.JPG --> Predicted: Corn_(maize)___Common_rust_
PotatoEarlyBlight1.JPG --> Predicted: Potato___Early_blight
PotatoEarlyBlight2.JPG --> Predicted: Potato___Early_blight
PotatoEarlyBlight3.JPG --> Predicted: Potato___Early_blight
PotatoEarlyBlight4.JPG --> Predicted: Potato___Early_blight
PotatoEarlyBlight5.JPG --> Predicted: Potato___Early_blight
PotatoHealthy1.JPG --> Predicted: Potato___healthy
PotatoHealthy2.JPG --> Predicted: 