In [None]:

import kagglehub
import os

# Download (or reuse a cached copy of) the dataset. Use the path kagglehub
# actually returns instead of assuming a Kaggle/Colab-specific mount point,
# so the notebook also works when the cache lives elsewhere.
dataset_path = kagglehub.dataset_download("vipoooool/new-plant-diseases-dataset")
print("Dataset path:", dataset_path)

# Training split: the archive nests the "Augmented" folder twice.
base_path = dataset_path
train_path = os.path.join(
    base_path,
    "New Plant Diseases Dataset(Augmented)",
    "New Plant Diseases Dataset(Augmented)",
    "train",
)

# One sub-directory per class. Ignore stray files (e.g. hidden metadata) and
# sort so that class indices are stable across runs.
classes = sorted(
    entry
    for entry in os.listdir(train_path)
    if os.path.isdir(os.path.join(train_path, entry))
)
print("Number of classes:", len(classes))

Using Colab cache for faster access to the 'new-plant-diseases-dataset' dataset.
Dataset path: /kaggle/input/new-plant-diseases-dataset
Number of classes: 38


In [None]:
!pip install -q segment-anything opencv-python

In [None]:
import cv2
import torch
import numpy as np
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator

In [None]:
!wget -q https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth

In [None]:
# Segment Anything ViT-B weights (saved under /content by the wget cell).
sam_checkpoint = "/content/sam_vit_b_01ec64.pth"
model_type = "vit_b"

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# nn.Module.to() moves the module in place and returns the same module object.
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint).to(device)

# Automatic mask generator: masks are filtered by predicted IoU and stability
# score; a non-zero min_mask_region_area enables SAM's small-region/hole
# post-processing (see the SamAutomaticMaskGenerator documentation).
mask_generator = SamAutomaticMaskGenerator(
    model=sam,
    points_per_side=32,
    pred_iou_thresh=0.88,
    stability_score_thresh=0.92,
    min_mask_region_area=500,
)

print("SAM loaded on:", device)

SAM loaded on: cuda


In [None]:
def sam_disease_localization(image_path, max_boxes=20):
    """Propose bounding boxes for likely-diseased regions of a leaf image.

    Runs the module-level SAM ``mask_generator`` on the image. It then keeps
    only masks that:
      1. cover between 0.2% and 25% of the image,
      2. have at least 70% of their pixels inside an HSV "leaf colour" mask,
      3. are not, on average, healthy green
         (mean H in [35, 85], S > 60, V > 60), and
      4. have grey-level variance of at least 100 (textured, not uniform).

    Args:
        image_path: Path to an image readable by ``cv2.imread``.
        max_boxes: Maximum number of regions to accept. The default of 20
            matches the previously hard-coded limit.

    Returns:
        ``(image_rgb, boxes)``. ``image_rgb`` is the image converted to RGB.
        ``boxes`` is a list of ``(x, y, w, h)`` tuples copied from each
        accepted mask's ``bbox`` field, in the order SAM returned the masks.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``image_path``.
    """
    image_bgr = cv2.imread(image_path)
    if image_bgr is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise FileNotFoundError(f"Could not read image: {image_path}")

    image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
    hsv = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2HSV)
    # Loop-invariant: compute once per image instead of once per candidate mask.
    gray = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2GRAY)

    h, w, _ = image_rgb.shape
    img_area = h * w

    # Leaf mask (OpenCV 8-bit hue runs 0-179): keep roughly yellow-green pixels.
    lower_leaf = np.array([20, 30, 30])
    upper_leaf = np.array([90, 255, 255])
    leaf = cv2.inRange(hsv, lower_leaf, upper_leaf) > 0

    boxes = []
    for m in mask_generator.generate(image_rgb):
        if len(boxes) >= max_boxes:
            break

        area = m["area"]
        area_ratio = area / img_area
        if area_ratio < 0.002 or area_ratio > 0.25:
            continue  # too small (noise) or too large (whole leaf/background)

        mask = m["segmentation"]

        # The region must lie mostly on the leaf.
        overlap = np.logical_and(mask, leaf).sum()
        if overlap / area < 0.7:
            continue

        # Disease filter: reject regions whose mean colour is healthy green.
        # NOTE(review): a plain mean of hue ignores wrap-around at 0/179, so
        # mixed reddish regions can average to a mid-range hue.
        mean_h, mean_s, mean_v = hsv[mask].mean(axis=0)
        if 35 <= mean_h <= 85 and mean_s > 60 and mean_v > 60:
            continue

        # Texture filter: low grey-level variance => uniform => likely healthy.
        if gray[mask].var() < 100:
            continue

        x, y, w_box, h_box = m["bbox"]
        boxes.append((x, y, w_box, h_box))

    return image_rgb, boxes

In [None]:
def overlay_boxes(image, boxes, color=(255, 0, 0), thickness=2):
    """Return a copy of ``image`` with each box drawn as a rectangle outline.

    Args:
        image: Image array to annotate. It is copied, so the caller's array
            is left unchanged.
        boxes: Iterable of ``(x, y, w, h)`` boxes (top-left corner plus
            width/height). Coordinates are truncated to ``int`` before drawing.
        color: Outline colour in the image's channel order. The default
            ``(255, 0, 0)`` is red for the RGB images returned by
            ``sam_disease_localization``.
        thickness: Outline thickness in pixels.

    Returns:
        The annotated copy of ``image``.
    """
    output = image.copy()
    for x, y, w, h in boxes:
        # cv2.rectangle needs integer pixel coordinates.
        x, y, w, h = int(x), int(y), int(w), int(h)
        cv2.rectangle(output, (x, y), (x + w, y + h), color, thickness)
    return output

In [None]:
import matplotlib.pyplot as plt
import math

def visualize_class_sam_results(
    class_name,
    num_images=20,
    rows=10,
    cols=10
):
    """Plot SAM disease boxes for the first ``num_images`` images of a class.

    Images come from ``train_path/<class_name>`` in sorted filename order. Each
    one is run through ``sam_disease_localization``, annotated with
    ``overlay_boxes``, and shown in a grid titled with the class name.

    Args:
        class_name: Name of a class sub-directory of ``train_path``.
        num_images: Maximum number of images to process.
        rows: Maximum number of grid rows. Images that do not fit in
            ``rows * cols`` cells are skipped instead of crashing
            ``plt.subplot``.
        cols: Maximum number of grid columns.

    The grid is shrunk to the rows/columns actually needed, so a small sample
    does not render a mostly empty ``rows x cols`` figure.
    """
    class_path = os.path.join(train_path, class_name)
    # os.listdir order is arbitrary; sort so every run shows the same images.
    images = sorted(os.listdir(class_path))[:min(num_images, rows * cols)]
    if not images:
        print(f"No images to show for {class_name}")
        return

    n_images = len(images)
    grid_cols = min(cols, n_images)
    grid_rows = math.ceil(n_images / grid_cols)

    plt.figure(figsize=(grid_cols * 3, grid_rows * 3))

    for idx, img_file in enumerate(images):
        print(f"[{idx + 1}/{n_images}] Processing {img_file}")

        img_path = os.path.join(class_path, img_file)

        image, boxes = sam_disease_localization(img_path)
        result = overlay_boxes(image, boxes)

        plt.subplot(grid_rows, grid_cols, idx + 1)
        plt.imshow(result)
        plt.axis("off")

    plt.suptitle(class_name, fontsize=16)
    plt.tight_layout()
    plt.show()

In [None]:
# Index into the sorted `classes` list; change it to inspect another class.
SAMPLE_CLASS_INDEX = 5
visualize_class_sam_results(classes[SAMPLE_CLASS_INDEX])

Output hidden; open in https://colab.research.google.com to view.