In [1]:
import torch
from torchvision import models

def load_trained_mobilenetv2(weights_path, num_classes, device="cpu"):
    """
    Build a MobileNetV2 with a custom-sized classifier and load saved weights into it.

    Args:
    weights_path (str): Path to the saved weights file (e.g., 'mobilenetv2_epoch10.pth').
        The file must contain a plain state_dict (it is passed straight to
        ``load_state_dict``), not a pickled model object.
    num_classes (int): Number of output classes. Must match the classifier head stored
        in the checkpoint; otherwise ``load_state_dict`` raises a size-mismatch RuntimeError.
    device (str or torch.device): Device the weights are mapped to and the model is placed
        on. Defaults to "cpu", which preserves the previous behavior.

    Returns:
    model (torch.nn.Module): Model with loaded weights, in eval mode, on ``device``.
    """
    device = torch.device(device)

    # Build the bare architecture without downloading anything. ``weights=None`` is the
    # torchvision >= 0.13 replacement for the deprecated ``pretrained=False``.
    model = models.mobilenet_v2(weights=None)

    # Swap the final Linear layer so its output size matches the checkpoint's head.
    model.classifier[1] = torch.nn.Linear(model.last_channel, num_classes)

    # weights_only=True restricts unpickling to tensors and plain containers,
    # so a tampered checkpoint file cannot execute arbitrary code on load.
    state_dict = torch.load(weights_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)

    model.to(device)
    # Switch Dropout/BatchNorm layers to inference behavior.
    model.eval()
    return model


In [5]:
# num_classes must match the classifier head saved in the checkpoint; it also equals
# len(class_names), the 39 labels listed in a later cell.
model = load_trained_mobilenetv2(
    weights_path="mobilenetv2_epoch10.pth",
    num_classes=39,
)

In [7]:
from PIL import Image
from torchvision import transforms
import torch

def predict_image(model, image_path, class_names):
    """
    Perform classification prediction on a single image.

    Args:
    model (nn.Module): The loaded classification model; it is switched to eval mode.
    image_path (str): Path to the image file.
    class_names (list): Class names; index i must correspond to the model's output logit i.

    Returns:
    predicted_class (str): The predicted class name.

    Raises:
    ValueError: If the number of model outputs differs from ``len(class_names)``.
    """
    # Preprocessing is assumed to mirror the training pipeline (resize + ToTensor only,
    # no mean/std normalization) -- TODO confirm against the training code.
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),   # fixed 224x224 input, aspect ratio not preserved
        transforms.ToTensor(),           # PIL HWC uint8 [0, 255] -> CHW float [0, 1]
    ])

    # Context manager closes the file handle promptly; convert() returns a new,
    # fully loaded image that remains usable after the file is closed.
    with Image.open(image_path) as img:
        image = img.convert('RGB')

    # Put the batch on the same device as the model's weights
    # (CPU fallback for a parameter-less model).
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    image_tensor = preprocess(image).unsqueeze(0).to(device)  # -> [1, C, H, W]

    model.eval()
    with torch.no_grad():
        outputs = model(image_tensor)                # Forward pass -> logits

    # Guard against a label list that does not line up with the classifier head,
    # which would otherwise raise IndexError or silently return the wrong label.
    if outputs.shape[-1] != len(class_names):
        raise ValueError(
            f"model produces {outputs.shape[-1]} outputs but "
            f"{len(class_names)} class names were given"
        )

    predicted_index = outputs.argmax(dim=1).item()   # index of the highest logit
    return class_names[predicted_index]              # Map index to class name



In [9]:
# Output-index -> label table: entry i names logit i of the classifier.
# The entries are in case-sensitive sorted order (uppercase before lowercase, hence
# 'background' last), the same order torchvision's ImageFolder uses to assign
# indices -- presumably how the training labels were produced; confirm.
# Strings must match the training labels verbatim, including the dataset's own
# spelling "Haunglongbing" and the embedded spaces / trailing underscore.
class_names = [
    # Apple: scab, black rot, cedar apple rust, healthy
    'Apple___Apple_scab', 'Apple___Black_rot',
    'Apple___Cedar_apple_rust', 'Apple___healthy',
    # Blueberry
    'Blueberry___healthy',
    # Cherry (including sour): powdery mildew, healthy
    'Cherry_(including_sour)___Powdery_mildew', 'Cherry_(including_sour)___healthy',
    # Corn (maize): gray leaf spot, common rust, northern leaf blight, healthy
    'Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot', 'Corn_(maize)___Common_rust_',
    'Corn_(maize)___Northern_Leaf_Blight', 'Corn_(maize)___healthy',
    # Grape: black rot, esca (black measles), leaf blight, healthy
    'Grape___Black_rot', 'Grape___Esca_(Black_Measles)',
    'Grape___Leaf_blight_(Isariopsis_Leaf_Spot)', 'Grape___healthy',
    # Orange: Huanglongbing (citrus greening)
    'Orange___Haunglongbing_(Citrus_greening)',
    # Peach: bacterial spot, healthy
    'Peach___Bacterial_spot', 'Peach___healthy',
    # Bell pepper: bacterial spot, healthy
    'Pepper,_bell___Bacterial_spot', 'Pepper,_bell___healthy',
    # Potato: early blight, late blight, healthy
    'Potato___Early_blight', 'Potato___Late_blight', 'Potato___healthy',
    # Raspberry, soybean (healthy only); squash powdery mildew
    'Raspberry___healthy', 'Soybean___healthy', 'Squash___Powdery_mildew',
    # Strawberry: leaf scorch, healthy
    'Strawberry___Leaf_scorch', 'Strawberry___healthy',
    # Tomato: bacterial spot, early/late blight, leaf mold, septoria leaf spot,
    # two-spotted spider mite, target spot, yellow leaf curl virus, mosaic virus, healthy
    'Tomato___Bacterial_spot', 'Tomato___Early_blight', 'Tomato___Late_blight',
    'Tomato___Leaf_Mold', 'Tomato___Septoria_leaf_spot',
    'Tomato___Spider_mites Two-spotted_spider_mite', 'Tomato___Target_Spot',
    'Tomato___Tomato_Yellow_Leaf_Curl_Virus', 'Tomato___Tomato_mosaic_virus',
    'Tomato___healthy',
    # Non-plant background images
    'background',
]


In [15]:

# Sample image expected to be classified as apple black rot.
pred = predict_image(model, image_path="apple_black_rot.png", class_names=class_names)
print("result：", pred)

result： Apple___Black_rot


In [19]:
# Sample image expected to be classified as apple scab.
pred = predict_image(model, image_path="Apple_scab.png", class_names=class_names)
print("result：", pred)

result： Apple___Apple_scab


In [23]:
# Sample image expected to be classified as corn northern leaf blight.
pred = predict_image(model, image_path="Corn_Northern.png", class_names=class_names)
print("result：", pred)

result： Corn_(maize)___Northern_Leaf_Blight


In [25]:
# Sample image expected to be classified as peach bacterial spot.
pred = predict_image(model, image_path="peach_bacterial.png", class_names=class_names)
print("result：", pred)

result： Peach___Bacterial_spot
