Using a TensorFlow Hub object detection model to detect mushrooms in images

In [None]:
# For running inference on the TF-Hub module.
import tensorflow as tf

import tensorflow_hub as hub

# For downloading the image.
import matplotlib.pyplot as plt

# For drawing onto the image.
import numpy as np
from PIL import Image
from PIL import ImageColor
from PIL import ImageDraw
from PIL import ImageFont

# For measuring the inference time.
import time

# Print Tensorflow version
print(tf.__version__)

In [None]:
# Functions for displaying images and drawing bounding boxes on them
def display_image(image):
  """Show `image` in a new 20x15 matplotlib figure with the grid switched off."""
  figure_size = (20, 15)
  plt.figure(figsize=figure_size)
  # Grid lines would be drawn over the picture, so disable them before plotting.
  plt.grid(False)
  plt.imshow(image)


def load_img_from_path(path):
  """Read the file at `path` and decode it as a 3-channel image tensor."""
  raw_bytes = tf.io.read_file(path)
  return tf.image.decode_jpeg(raw_bytes, channels=3)


def draw_box_on_image(image, ymin, xmin, ymax, xmax, color, font, thickness=4, display_str_list=()):
  """Draw one bounding box and its text labels onto `image`, in place.

  Args:
    image: PIL image to draw on; it is modified in place.
    ymin, xmin, ymax, xmax: box edges; they are multiplied by the image
      height/width, so they are expected as fractions of the image size.
    color: colour used for the box outline and the label background.
    font: font object providing `getbbox(text)`; used to measure and draw labels.
    thickness: width of the box outline.
    display_str_list: label strings, stacked vertically at the box's top-left
      corner (the first string ends up on top).
  """
  draw = ImageDraw.Draw(image)
  im_width, im_height = image.size

  # Scale the box edges to pixel coordinates.
  (left, right, top, bottom) = (xmin * im_width, xmax * im_width, ymin * im_height, ymax * im_height)

  draw.line([(left, top), (left, bottom), (right, bottom), (right, top), (left, top)], width=thickness, fill=color)

  # If the total height of the display strings added to the top of the bounding
  # box exceeds the top of the image, stack the strings below the top edge of
  # the bounding box instead of above it.
  display_str_heights = [font.getbbox(ds)[3] for ds in display_str_list]
  # Each string gets a 5% margin above and below.
  total_display_str_height = (1 + 2 * 0.05) * sum(display_str_heights)
  if top > total_display_str_height:
    text_bottom = top
  else:
    text_bottom = top + total_display_str_height

  # Reverse list and print from bottom to top.
  for display_str in display_str_list[::-1]:
    bbox = font.getbbox(display_str)
    text_width, text_height = bbox[2], bbox[3]
    margin = np.ceil(0.05 * text_height)
    draw.rectangle([(left, text_bottom - text_height - 2 * margin), (left + text_width, text_bottom)], fill=color)
    draw.text((left + margin, text_bottom - text_height - margin), display_str, fill="black", font=font)
    # Move up by the full height of the rectangle just drawn (text plus both
    # margins) so the next label sits directly above it without overlapping.
    text_bottom -= text_height + 2 * margin
  

In [None]:
# Function for drawing a box around the object
def draw_box_on_object(image, boxes, class_names, scores, max_boxes=5, min_score=0.1):
    """Draw labelled boxes for the first detections that reach `min_score`.

    Args:
        image: numpy image array; it is overwritten in place with the
            annotated pixels and also returned.
        boxes: array whose rows are (ymin, xmin, ymax, xmax).
        class_names: sequence of byte-string class names, one per box.
        scores: sequence of detection scores, one per box.
        max_boxes: only the first `max_boxes` entries are considered.
        min_score: detections scoring below this value are not drawn.

    Returns:
        The same `image` array that was passed in.
    """
    colors = list(ImageColor.colormap.values())
    font = ImageFont.load_default()

    for i in range(min(boxes.shape[0], max_boxes)):
        # Compare the score itself with the threshold. The previous
        # `scores[i].any() >= min_score` compared a boolean, which let every
        # non-zero score through.
        if scores[i] >= min_score:
            ymin, xmin, ymax, xmax = tuple(boxes[i])
            display_str = "{}: {}%".format(class_names[i].decode("ascii"), int(100 * scores[i]))
            # Pick a colour from the class name so equal names share a colour.
            color = colors[hash(class_names[i]) % len(colors)]
            image_pil = Image.fromarray(np.uint8(image)).convert("RGB")
            draw_box_on_image(image_pil, ymin, xmin, ymax, xmax, color, font, display_str_list=[display_str])
            # Write the annotated pixels back into the caller's array.
            np.copyto(image, np.array(image_pil))
    return image

In [None]:
# Load one sample image from the dataset and display it as a quick visual check.
sample_image_path = 'mushrooms_dataset/images/Agaricus_bisporus/614156.jpg'
image = load_img_from_path(sample_image_path)
display_image(image)
# Disabled alternative: open the image with PIL and resize it to 640x640.
# image = Image.open(sample_image_path + '.jpg').resize((640,640))

In [None]:
# Now we load the TF-Hub module
# Disabled alternative handle (judging by its name, a Faster R-CNN / Inception-ResNet v2 model):
# module_handle = "https://tfhub.dev/google/faster_rcnn/openimages_v4/inception_resnet_v2/1"
module_handle = "https://tfhub.dev/google/openimages_v4/ssd/mobilenet_v2/1"

# TODO: use a YOLO model
# The module's 'default' signature is the callable used for inference in the cells below.
detector = hub.load(module_handle).signatures['default']

In [None]:
def run_detector(detector, image, display_img=True):
  """Run `detector` on one image file and return its outputs as numpy arrays.

  Args:
    detector: callable taking a batched float32 image tensor and returning a
      dict of tensors (the code reads the keys "detection_scores",
      "detection_boxes" and "detection_class_entities").
    image: path of the image file; it is loaded with `load_img_from_path`.
    display_img: when True, draw the detected boxes and display the image.

  Returns:
    dict mapping each detector output key to a numpy array.
  """
  img = load_img_from_path(image)

  # Convert to float32 and add a leading batch dimension for the detector.
  converted_img = tf.image.convert_image_dtype(img, tf.float32)[tf.newaxis, ...]
  # perf_counter is a monotonic clock intended for measuring short durations.
  start_time = time.perf_counter()
  result = detector(converted_img)
  end_time = time.perf_counter()

  result = {key: value.numpy() for key, value in result.items()}

  print("Found %d objects." % len(result["detection_scores"]))
  print("Inference time: ", end_time - start_time)

  # Drawing boxes is only useful when the image is shown; skipping it otherwise
  # avoids wasted work when the detector is run over many images.
  if display_img:
    image_with_boxes = draw_box_on_object(
        img.numpy(), result["detection_boxes"],
        result["detection_class_entities"], result["detection_scores"])
    display_image(image_with_boxes)

  return result
  

# Run the detector on the sample image and display it with the detected boxes.
run_detector(detector, sample_image_path, display_img=True)

In [None]:
# Run the detector on a second image and display it with the detected boxes.
run_detector(detector,'mushrooms_dataset/images/Agaricus_abruptibulbus/60416.jpg', display_img=True)

In [None]:
# Example of running the detector without displaying the image
run_detector(detector,'mushrooms_dataset/images/Agaricus_bisporus/869.jpg',display_img=False)

Script for detecting whether an image contains a mushroom and sorting the images into the corresponding directories

In [None]:
import os
import shutil

# Root directory of the source images.
input_dir = 'mushrooms_dataset/images'
# Every sub-directory of input_dir, with Windows backslashes normalised to forward slashes.
subdirs = [os.path.join(input_dir, subdir).replace('\\', '/') for subdir in os.listdir(input_dir) if os.path.isdir(os.path.join(input_dir, subdir))]

# Setting up directories for correct and incorrect images
# NOTE(review): these names mention FasterRCNN, but the active module_handle points at the SSD MobileNet v2 module — confirm which model produced them.
output_dir_correct = 'mushrooms_dataset/images_FasterRCNN/images_correct'
output_dir_incorrect = 'mushrooms_dataset/images_FasterRCNN/images_incorrect'

# Number of sub-directories found (shown as the cell output).
len(subdirs)


In [None]:
def process_images(detector, output_dir_correct, output_dir_incorrect, detected_threshold=0.05, input_subdirs=None):
    """Copy every image into a "correct" or "incorrect" tree based on detection.

    An image counts as correct when the detector reports at least one
    `b'Mushroom'` entity whose score is at least `detected_threshold`.

    Args:
        detector: detector callable forwarded to `run_detector`.
        output_dir_correct: root directory for images with a mushroom detected.
        output_dir_incorrect: root directory for all other images.
        detected_threshold: minimum score for a mushroom detection to count.
        input_subdirs: directories to scan; defaults to the module-level
            `subdirs` list.
    """
    if input_subdirs is None:
        input_subdirs = subdirs

    for subdir in input_subdirs:
        subdir_name = os.path.basename(subdir)
        # Sorted so the processing order is the same on every run.
        for image_name in sorted(os.listdir(subdir)):
            image_path = os.path.join(subdir, image_name)
            print(f'Processing image {image_path}')

            result = run_detector(detector, image_path, display_img=False)
            # Decide per image. Evaluating this from scratch each time keeps an
            # image with no detections from inheriting the previous verdict.
            detected = any(
                label == b'Mushroom' and score >= detected_threshold
                for score, label in zip(result['detection_scores'], result['detection_class_entities'])
            )

            target_root = output_dir_correct if detected else output_dir_incorrect
            target_dir = os.path.join(target_root, subdir_name)
            os.makedirs(target_dir, exist_ok=True)  # Create the directory if it does not exist
            output_path = os.path.join(target_dir, image_name)

            # Copy the image to the right directory based on the detection result
            shutil.copy(image_path, output_path)
            print(f'Image {image_path} was copied to {output_path}')

# Run the detector over every image in `subdirs` and copy each file into the correct/incorrect tree.
process_images(detector, output_dir_correct, output_dir_incorrect)

In [None]:
# Make a file with mushroom names
import os
def get_mushroom_names(images_dir='mushrooms_dataset/images_FasterRCNN/images_correct', output_file='mushroom_names.txt'):
    """Write the name of every sub-directory of `images_dir` to a text file.

    Args:
        images_dir: directory whose sub-directory names are collected. The
            default matches `output_dir_correct` used earlier in the notebook.
        output_file: path of the text file to write, one name per line.

    Returns:
        The sorted list of names that was written.
    """
    # Keep only directories so stray files are not listed as names. Sorting
    # gives a stable order, since os.listdir returns entries in arbitrary order.
    mushroom_names = sorted(
        entry for entry in os.listdir(images_dir)
        if os.path.isdir(os.path.join(images_dir, entry))
    )
    with open(output_file, 'w', encoding='utf-8') as f:
        for item in mushroom_names:
            f.write("%s\n" % item)
    return mushroom_names

# get_mushroom_names()

In [None]:
# Count how many images ended up in images_correct and images_incorrect.
correct_images = 0
incorrect_images = 0

# Entries directly under each output root (one per source sub-directory).
subdirs_correct = os.listdir(output_dir_correct)
subdirs_incorrect = os.listdir(output_dir_incorrect)

# Add up the number of entries inside every sub-directory of the "correct" tree.
for subdir in subdirs_correct:
    correct_images += len(os.listdir(os.path.join(output_dir_correct, subdir)))

# Same for the "incorrect" tree.
for subdir in subdirs_incorrect:
    incorrect_images += len(os.listdir(os.path.join(output_dir_incorrect, subdir)))

print(f'Correct images: {correct_images}') # 98639 (presumably the count from the author's run)
print(f'Incorrect images: {incorrect_images}') # 77076 (presumably the count from the author's run)
