In [1]:
import functools
import os

import gradio as gr
import torch
from PIL import Image
from torchvision import models
from torchvision import transforms

In [3]:
# Define the function to load the model.
# Architecture factory and fine-tuned checkpoint path for every model the UI
# offers. The factories build the same network that `pretrained=True` produced
# (default 1000-way head; GoogLeNet without aux heads and with
# transform_input=True) but skip downloading ImageNet weights that
# load_state_dict() overwrites anyway. Lambdas defer touching torchvision until
# a model is actually requested.
# NOTE(review): predict() maps outputs onto 88 labels; if the checkpoints were
# saved with an 88-way head, add num_classes=88 to the factories -- confirm.
MODEL_REGISTRY = {
    "VGG19": (lambda: models.vgg19(), "vgg19_best_model.pth"),
    "GoogLeNet": (
        lambda: models.googlenet(aux_logits=False, transform_input=True, init_weights=False),
        "googlenet_best_model.pth",
    ),
    "EfficientNet": (lambda: models.efficientnet_b0(), "effnet_best_model.pth"),
    "ResNet50": (lambda: models.resnet50(), "resnet50_best_model.pth"),
}


@functools.lru_cache(maxsize=None)
def load_model(model_name):
    """Build the architecture for `model_name` and load its fine-tuned checkpoint.

    Results are memoised per name, so each checkpoint is read from disk once
    per process instead of on every Gradio request.

    Args:
        model_name: One of the keys of ``MODEL_REGISTRY``
            ("VGG19", "GoogLeNet", "EfficientNet", "ResNet50").

    Returns:
        The ``torch.nn.Module`` with the checkpoint weights loaded, in eval mode.

    Raises:
        ValueError: if ``model_name`` is not a supported model name
            (including ``None``).
    """
    if model_name not in MODEL_REGISTRY:
        raise ValueError(
            f"Unknown model_name {model_name!r}; expected one of {sorted(MODEL_REGISTRY)}"
        )

    build_architecture, checkpoint_path = MODEL_REGISTRY[model_name]
    model = build_architecture()
    # map_location="cpu" lets checkpoints saved on a GPU load on CPU-only hosts;
    # predict() feeds CPU tensors to the model anyway.
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(state_dict)
    model.eval()
    return model

In [4]:
# Output labels, in the index order of the models' logits. The strings appear
# to mirror the training-set folder names (including spellings such as
# "Gauva" and "Strawberry___leaf_scorch"), so they are kept verbatim.
CLASS_NAMES = (
    'Apple__black_rot', 'Apple__healthy', 'Apple__rust', 'Apple__scab',
    'Cassava__bacterial_blight', 'Cassava__brown_streak_disease', 'Cassava__green_mottle',
    'Cassava__healthy', 'Cassava__mosaic_disease',
    'Cherry__healthy', 'Cherry__powdery_mildew',
    'Chili__healthy', 'Chili__leaf curl', 'Chili__leaf spot', 'Chili__whitefly', 'Chili__yellowish',
    'Coffee__cercospora_leaf_spot', 'Coffee__healthy', 'Coffee__red_spider_mite', 'Coffee__rust',
    'Corn__common_rust', 'Corn__gray_leaf_spot', 'Corn__healthy', 'Corn__northern_leaf_blight',
    'Cucumber__diseased', 'Cucumber__healthy',
    'Gauva__diseased', 'Gauva__healthy',
    'Grape__black_measles', 'Grape__black_rot', 'Grape__healthy',
    'Grape__leaf_blight_(isariopsis_leaf_spot)',
    'Jamun__diseased', 'Jamun__healthy',
    'Lemon__diseased', 'Lemon__healthy',
    'Mango__diseased', 'Mango__healthy',
    'Peach__bacterial_spot', 'Peach__healthy',
    'Pepper_bell__bacterial_spot', 'Pepper_bell__healthy',
    'Pomegranate__diseased', 'Pomegranate__healthy',
    'Potato__early_blight', 'Potato__healthy', 'Potato__late_blight',
    'Rice__brown_spot', 'Rice__healthy', 'Rice__hispa', 'Rice__leaf_blast', 'Rice__neck_blast',
    'Soybean__bacterial_blight', 'Soybean__caterpillar', 'Soybean__diabrotica_speciosa',
    'Soybean__downy_mildew', 'Soybean__healthy', 'Soybean__mosaic_virus',
    'Soybean__powdery_mildew', 'Soybean__rust', 'Soybean__southern_blight',
    'Strawberry___leaf_scorch', 'Strawberry__healthy',
    'Sugarcane__bacterial_blight', 'Sugarcane__healthy', 'Sugarcane__red_rot',
    'Sugarcane__red_stripe', 'Sugarcane__rust',
    'Tea__algal_leaf', 'Tea__anthracnose', 'Tea__bird_eye_spot', 'Tea__brown_blight',
    'Tea__healthy', 'Tea__red_leaf_spot',
    'Tomato__bacterial_spot', 'Tomato__early_blight', 'Tomato__healthy', 'Tomato__late_blight',
    'Tomato__leaf_mold', 'Tomato__mosaic_virus', 'Tomato__septoria_leaf_spot',
    'Tomato__spider_mites_(two_spotted_spider_mite)', 'Tomato__target_spot',
    'Tomato__yellow_leaf_curl_virus',
    'Wheat__brown_rust', 'Wheat__healthy', 'Wheat__septoria', 'Wheat__yellow_rust',
)


def predict(inp, model_name):
    """Classify a leaf image with the selected fine-tuned model.

    Args:
        inp: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when nothing was uploaded.
        model_name: Architecture name, passed through to ``load_model``.

    Returns:
        dict mapping every label in ``CLASS_NAMES`` to its softmax probability,
        the format ``gr.Label`` consumes.

    Raises:
        ValueError: if no image was supplied, or if the model emits fewer
            logits than there are class names.
    """
    if inp is None:
        raise ValueError("No input image was provided.")

    model = load_model(model_name)
    model.eval()

    # Force 3 channels: RGBA / palette / greyscale images would otherwise yield
    # a 4- or 1-channel tensor that an RGB conv stem rejects.
    # NOTE(review): only ToTensor is applied here (no resize or mean/std
    # normalisation); confirm this matches the preprocessing used in training.
    batch = transforms.ToTensor()(inp.convert("RGB")).unsqueeze(0)

    with torch.no_grad():
        probabilities = torch.nn.functional.softmax(model(batch)[0], dim=0)

    if probabilities.numel() < len(CLASS_NAMES):
        raise ValueError(
            f"Model {model_name!r} produced {probabilities.numel()} outputs, "
            f"expected at least {len(CLASS_NAMES)}"
        )
    # zip stops after len(CLASS_NAMES) entries, i.e. only the first
    # len(CLASS_NAMES) outputs are reported -- the same as the old range(88).
    return {name: float(p) for name, p in zip(CLASS_NAMES, probabilities.tolist())}
    

In [None]:
# Gradio UI: an image upload plus a model selector; the Label widget shows the
# five highest-probability classes from predict()'s {label: probability} dict.
MODEL_CHOICES = ["VGG19", "GoogLeNet", "EfficientNet", "ResNet50"]

inputs = [
    gr.Image(type="pil", label="Input Image"),
    # Select a valid model explicitly so predict() never depends on the Gradio
    # version's Dropdown default (which may be "nothing selected" -> None).
    gr.Dropdown(choices=MODEL_CHOICES, value=MODEL_CHOICES[0], label="Model"),
]

output = gr.Label(num_top_classes=5)

# Create the interface. share=True publishes a temporary public *.gradio.live
# URL (see the cell output below); debug=True keeps this cell blocking while
# the app runs.
demo = gr.Interface(
    fn=predict,
    inputs=inputs,
    outputs=output,
    title="Plant Disease Classification",
    theme="gradio/monochrome",
)
demo.launch(debug=True, share=True)

Running on local URL:  http://127.0.0.1:7860
IMPORTANT: You are using gradio version 4.27.0, however version 4.29.0 is available, please upgrade.
--------
Running on public URL: https://5bb552e9df6efdd533.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)
