# In [None]:
#----------------------------------------------------#
#   Integrate single image prediction, camera detection, and FPS testing
#   into a single py file, with mode specification for functionality selection.
#----------------------------------------------------#
import time
import os
import cv2
import numpy as np
from PIL import Image

from unet import Unet

def numeric_name_key(filename):
    """Return a sort key that orders file names by the integer value of their stem.

    The extension is removed with ``os.path.splitext`` so extensions of any
    length (``.jpg``, ``.jpeg``, ``.tiff``) are handled. Names whose stem is
    not an integer sort after all numeric names, alphabetically, instead of
    raising ``ValueError`` and aborting the whole run.

    Args:
        filename: Base name of a file, e.g. ``"12.jpg"``.

    Returns:
        A tuple suitable as a ``key`` for ``sorted`` / ``list.sort``.
    """
    stem = os.path.splitext(filename)[0]
    try:
        return (0, int(stem), filename)
    except ValueError:
        return (1, 0, filename)


def _predict_directory(unet, input_dir, output_dir, count, name_classes):
    """Segment every image in ``input_dir`` once and save the results.

    Files are visited in numeric order of their names and the results are
    written to ``output_dir`` as ``0.jpg``, ``1.jpg``, ... in that order.
    Entries that cannot be opened as images are reported and skipped.

    Args:
        unet: Object exposing ``detect_image(image, count=..., name_classes=...)``
            whose return value has a ``save(path)`` method.
        input_dir: Directory holding the images to segment.
        output_dir: Directory receiving the results; created if missing.
        count: Forwarded to ``detect_image`` (pixel counting switch).
        name_classes: Forwarded to ``detect_image`` (class names for printing).

    Returns:
        Number of result images written.
    """
    os.makedirs(output_dir, exist_ok=True)

    saved = 0
    for name in sorted(os.listdir(input_dir), key=numeric_name_key):
        source_path = os.path.join(input_dir, name)
        try:
            image = Image.open(source_path)
        except OSError as err:
            # Sub-directories and non-image files end up here; keep going.
            print("Open Error! Skipping %s: %s" % (source_path, err))
            continue

        # Keep the source file open until the result has been written.
        with image:
            r_image = unet.detect_image(image, count=count, name_classes=name_classes)
            r_image.save(os.path.join(output_dir, str(saved) + ".jpg"))
        saved += 1
    return saved


def _process_video(unet, capture, video_save_path, video_fps):
    """Run segmentation on every frame of an opened capture.

    The capture (and the writer, when one is created) is always released,
    including when the run is interrupted with Ctrl+C or an error occurs, so
    the saved video file is finalized.

    Args:
        unet: Object exposing ``detect_image(image)``.
        capture: An opened ``cv2.VideoCapture``.
        video_save_path: Output video path; ``""`` disables saving.
        video_fps: Frame rate used for the saved video.

    Raises:
        ValueError: If not even one frame can be read from ``capture``.
    """
    out = None
    try:
        if video_save_path != "":
            fourcc = cv2.VideoWriter_fourcc(*'XVID')
            size = (int(capture.get(cv2.CAP_PROP_FRAME_WIDTH)), int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT)))
            out = cv2.VideoWriter(video_save_path, fourcc, video_fps, size)

        ref, frame = capture.read()
        if not ref:
            raise ValueError("Failed to read camera (video) correctly. Please verify camera installation (or video path).")

        fps = 0.0
        # The frame read above is processed too, so the first frame is not lost.
        while ref:
            t1 = time.time()
            # Format conversion, BGR to RGB, then wrap as a PIL Image for detection.
            rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            result = np.array(unet.detect_image(Image.fromarray(np.uint8(rgb))))
            # RGB back to BGR, the format OpenCV uses for display and writing.
            result = cv2.cvtColor(result, cv2.COLOR_RGB2BGR)

            # Fetch the next frame before the timer stops so that the FPS
            # figure covers one read plus one detection per iteration.
            ref, frame = capture.read()

            # Guard against a zero interval on coarse system clocks.
            elapsed = max(time.time() - t1, 1e-6)
            fps = (fps + (1. / elapsed)) / 2
            print("fps= %.2f" % (fps))
            result = cv2.putText(result, "fps= %.2f" % (fps), (0, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)

            # Uncomment to preview the result in a window:
            # cv2.imshow("video", result)
            c = cv2.waitKey(1) & 0xff
            if out is not None:
                out.write(result)

            # ESC stops early (only delivered while an OpenCV window is open).
            if c == 27:
                break
    finally:
        capture.release()
        if out is not None:
            out.release()
        cv2.destroyAllWindows()


if __name__ == "__main__":
    #-------------------------------------------------------------------------#
    #   To modify colors for corresponding classes, modify self.colors in the
    #   __init__ function of Unet.
    #-------------------------------------------------------------------------#
    unet = Unet()
    #-------------------------------------------------------------------------#
    #   mode specifies the testing mode:
    #   'predict'   Batch image prediction over predict_input_dir
    #   'video'     Video detection on a file name entered at the prompt
    #-------------------------------------------------------------------------#
    mode = "predict"
    # mode = "video"
    #-------------------------------------------------------------------------#
    #   count           Whether to perform pixel counting (i.e., area) and
    #                   ratio calculation for targets
    #   name_classes    Class categories, same as in json_to_dataset, used for
    #                   printing class names and counts
    #
    #   count and name_classes are only effective when mode='predict'
    #-------------------------------------------------------------------------#
    count           = False
    # name_classes    = ["background","aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"]
    name_classes = ["_background_", "cuttlefish"]
    # name_classes    = ["background","cat","dog"]
    #-------------------------------------------------------------------------#
    #   predict_input_dir   Directory of images to segment; file names are
    #                       expected to be numbers (e.g. 0.jpg, 1.jpg, ...)
    #   predict_output_dir  Directory receiving the results as 0.jpg, 1.jpg, ...
    #
    #   Both are only effective when mode='predict'
    #-------------------------------------------------------------------------#
    predict_input_dir  = "C:/Users/12152/Desktop/data/sepia4/vision_white_after/origin_test"
    predict_output_dir = "C:/Users/12152/Desktop/data/sepia4/vision_white_after/mask_test"
    #-------------------------------------------------------------------------#
    #   video_path          Default video source (0 = camera). Currently
    #                       unused: the source is entered at the prompt.
    #   video_save_path     Video save path; "" indicates no saving. Every
    #                       processed video is written to this same path.
    #   video_fps           FPS for the saved video
    #
    #   These are only effective when mode='video'.
    #   To save the video completely, press ctrl+c to exit or run until the
    #   last frame.
    #-------------------------------------------------------------------------#
    video_path      = 0
    video_save_path = "yyy.mp4"
    video_fps       = 30.00

    if mode == "predict":
        # Key points for prediction:
        # 1. Results are saved with r_image.save(...); change the output name
        #    or format in _predict_directory.
        # 2. To avoid blending original and segmented images, set the blend
        #    parameter to False (see the Unet class).
        # 3. To extract regions based on the mask, refer to the plotting
        #    section of detect_image: determine the class of each pixel, then
        #    extract the corresponding regions, e.g.
        #        seg_img = np.zeros((np.shape(pr)[0], np.shape(pr)[1], 3))
        #        for c in range(self.num_classes):
        #            seg_img[:, :, 0] += ((pr == c) * (self.colors[c][0])).astype('uint8')
        #            seg_img[:, :, 1] += ((pr == c) * (self.colors[c][1])).astype('uint8')
        #            seg_img[:, :, 2] += ((pr == c) * (self.colors[c][2])).astype('uint8')
        _predict_directory(unet, predict_input_dir, predict_output_dir, count, name_classes)

    ## Video detection from local disk files
    elif mode == "video":
        # Keep prompting so several videos can be processed in one session.
        while True:
            video = input("Input video filename:")
            capture = cv2.VideoCapture(video)
            # VideoCapture does not raise on a bad path; it must be queried.
            if not capture.isOpened():
                capture.release()
                print("Open Error! Try again!")
                continue

            _process_video(unet, capture, video_save_path, video_fps)
            print("Video Detection Done!")
            if video_save_path != "":
                print("Save processed video to the path :" + video_save_path)

    else:
        raise AssertionError("Please specify the correct mode: 'predict' or 'video'.")