In [15]:
import io
import torch
import numpy as np
import cv2
import matplotlib.pyplot as plt
import gradio as gr

In [16]:
# Path to the fully pickled (non-TorchScript) Faster R-CNN checkpoint.
model_path = 'models/detection/faster_not_jit_rcnn_model.pt'

# Detector label id -> human-readable "<Crop> <Condition>" name, ids assigned
# in list order starting from 0.
# NOTE(review): torchvision-style Faster R-CNN models usually reserve label 0
# for the background class; confirm this 0-based mapping matches how labels
# were encoded at training time.
classes = dict(enumerate([
    'Corn Cercospora Leaf Spot',
    'Corn Common Rust',
    'Corn Healthy',
    'Corn Streak',
    'Corn Northern Leaf Blight',
    'Pepper Leaf Curl',
    'Pepper Cercospora',
    'Pepper Leaf Blight',
    'Pepper Bacterial Spot',
    'Pepper Leaf Mosaic',
    'Pepper Healthy',
    'Pepper Fusarium',
    'Pepper Septoria',
    'Pepper Late Blight',
    'Pepper Early Blight',
    'Tomato Late Blight',
    'Tomato Early Blight',
    'Tomato Bacterial Spot',
    'Tomato Septoria',
    'Tomato Fusarium',
    'Tomato Leaf Curl',
    'Tomato Healthy',
    'Tomato Mosaic',
]))

In [17]:
# object detector
def obj_detector(model, img):
    """Run a detection model on a single BGR image.

    Args:
        model: torch detection model. It is called as ``model([chw_tensor])``
            and must return one dict per image containing ``'boxes'``,
            ``'labels'`` and ``'scores'`` tensors.
        img: H x W x C image array in BGR channel order. It is converted with
            ``cv2.COLOR_BGR2RGB`` and divided by 255 before inference, so it
            lands in [0, 1] for uint8 input.

    Returns:
        tuple ``(names, boxes, sample, scores)``:
            names  -- list of integer label ids predicted by the model.
            boxes  -- int32 array, one row per detection; fractional
                      coordinates are truncated. The drawing code in this
                      notebook reads rows as (x1, y1, x2, y2).
            sample -- H x W x 3 float32 RGB image that was fed to the model.
            scores -- numpy array of detection confidences.
    """
    rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
    # HWC numpy -> CHW tensor. Detection models take a list of images.
    tensor = torch.from_numpy(rgb).permute(2, 0, 1)

    model.eval()
    # Inference only: skip building an autograd graph (saves memory and time).
    with torch.no_grad():
        output = model([tensor])

    detections = output[0]
    boxes = detections['boxes'].detach().cpu().numpy().astype(np.int32)
    names = detections['labels'].detach().cpu().numpy().tolist()
    scores = detections['scores'].detach().cpu().numpy()
    # Back to HWC for plotting. Copy so callers get an independent array.
    sample = tensor.permute(1, 2, 0).cpu().numpy().copy()

    return names, boxes, sample, scores

In [18]:
def _draw_detections(image, boxes, names):
    """Return a copy of ``image`` with one labelled rectangle per detection.

    Args:
        image: H x W x 3 array, either uint8 (0-255) or float (assumed [0, 1],
            as produced by ``obj_detector``).
        boxes: iterable of per-detection rows read as (x1, y1, x2, y2).
        names: iterable of integer label ids, aligned with ``boxes``.

    Unknown label ids are drawn as the raw id instead of raising ``KeyError``.
    The input array is never modified.
    """
    canvas = image.copy()
    # Colours are given on a 0-255 scale. Rescale them for float images so
    # they are not clipped (with a warning) when shown with imshow.
    scale = 1.0 / 255.0 if np.issubdtype(canvas.dtype, np.floating) else 1.0
    box_color = tuple(c * scale for c in (0, 220, 0))
    text_color = tuple(c * scale for c in (220, 0, 0))
    font, font_scale, thickness = cv2.FONT_HERSHEY_COMPLEX, 0.7, 1

    for box, name in zip(boxes, names):
        # cv2 point arguments must be plain Python ints.
        x1, y1, x2, y2 = map(int, box[:4])
        cv2.rectangle(canvas, (x1, y1), (x2, y2), box_color, 2)

        label = classes.get(name, str(name))
        (_, text_height), _ = cv2.getTextSize(label, font, font_scale, thickness)
        # Keep the label inside the image when the box touches the top edge.
        label_y = max(y1 - 5, text_height)
        cv2.putText(
            canvas,
            label,
            (x1, label_y),
            font,
            font_scale,
            text_color,
            thickness,
            cv2.LINE_AA
        )
    return canvas


def plot_result(sample, boxes, names, figsize=(20, 60)):
    """Show ``sample`` with its detections drawn, using matplotlib.

    Args:
        sample: image array (e.g. the float RGB ``sample`` from
            ``obj_detector``). It is not modified.
        boxes: per-detection rows read as (x1, y1, x2, y2).
        names: integer label ids aligned with ``boxes``.
        figsize: matplotlib figure size in inches.

    Returns:
        The matplotlib Axes holding the image.
    """
    annotated = _draw_detections(sample, boxes, names)
    _, ax = plt.subplots(figsize=figsize)
    ax.imshow(annotated)
    ax.axis('off')
    return ax


def save_results(image, boxes, names, save_path):
    """Return a copy of ``image`` with detections drawn on it.

    Args:
        image: RGB image array. It is not modified.
        boxes: per-detection rows read as (x1, y1, x2, y2).
        names: integer label ids aligned with ``boxes``.
        save_path: currently unused. Nothing is written to disk. To persist
            the result, call ``cv2.imwrite(save_path,
            cv2.cvtColor(img, cv2.COLOR_RGB2BGR))``, because imwrite expects
            BGR.

    Returns:
        The annotated copy of ``image``.
    """
    return _draw_detections(image, boxes, names)


In [19]:
# Loaded models keyed by checkpoint path. The checkpoint is read from disk
# once per kernel instead of on every Gradio request.
_MODEL_CACHE = {}


def _load_model(path, map_location='cpu'):
    """Load (once) and return the pickled detection model stored at ``path``."""
    model = _MODEL_CACHE.get(path)
    if model is None:
        # torch.load unpickles arbitrary objects: only load trusted checkpoints.
        # NOTE(review): PyTorch >= 2.6 defaults to weights_only=True, which
        # rejects a fully pickled model object; pass weights_only=False there.
        model = torch.load(path, map_location=map_location)
        model.eval()
        _MODEL_CACHE[path] = model
    return model


def inference_detection_model(image, score_threshold=0.0):
    """Detect crop diseases in an RGB image and annotate it.

    Args:
        image: H x W x 3 RGB array, as delivered by ``gr.Image`` with its
            default numpy type, or ``None`` when nothing was uploaded.
        score_threshold: detections scoring below this are dropped. The
            default 0.0 keeps everything the model returns.

    Returns:
        tuple ``(result_image, names, boxes, scores)``:
            result_image -- copy of ``image`` with boxes and labels drawn,
                            or ``None`` if ``image`` is ``None``.
            names        -- list of class-name strings per kept detection.
            boxes        -- int32 array of kept boxes.
            scores       -- array of kept confidences.
    """
    if image is None:
        return None, [], np.empty((0, 4), dtype=np.int32), np.empty(0, dtype=np.float32)

    model = _load_model(model_path)
    # obj_detector expects BGR (it converts BGR -> RGB itself), but Gradio
    # supplies RGB. Hand it BGR so the model sees correctly ordered channels.
    bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    names, boxes, _, scores = obj_detector(model, bgr)

    keep = scores >= score_threshold
    names = [name for name, kept in zip(names, keep) if kept]
    boxes = boxes[keep]
    scores = scores[keep]

    save_path = 'result_image.jpg'
    result_image = save_results(image, boxes, names, save_path)
    labels = [classes.get(name, str(name)) for name in names]
    return result_image, labels, boxes, scores

In [20]:
def _predict_for_ui(image):
    """Adapt ``inference_detection_model`` to the two Gradio output components.

    ``inference_detection_model`` returns (image, names, boxes, scores), but
    the interface declares only an Image and a Text output. Return just the
    annotated image and a readable one-line-per-detection summary.
    """
    if image is None:
        return None, ""
    result_image, names, _boxes, scores = inference_detection_model(image)
    if not names:
        return result_image, "No detections"
    summary = "\n".join(
        f"{name}: {float(score):.2f}" for name, score in zip(names, scores)
    )
    return result_image, summary


inputs = [
    gr.Image(label="Image")
]
outputs = [
    gr.Image(label="Result", type='pil'),
    gr.Text(label="Names")
]

title = "Crop Disease Detector "
description = "This module detect disease that manifest symptoms on the leaves of crops. \nIt currently works with three crops namely Corn, Tomato, Pepper"

demo = gr.Interface(
    fn=_predict_for_ui,
    inputs=inputs,
    outputs=outputs,
    title=title,
    description=description
)
demo.launch()

Running on local URL:  http://127.0.0.1:7864

To create a public link, set `share=True` in `launch()`.


