In [1]:
# imports
import cv2
import time
import subprocess
import os
import mmcv
import numpy as np
import MOT_sort as sort

from mmdet.apis import init_detector, inference_detector


  from torch.distributed.optim import \


In [2]:
# functions
def detect_flashes(video_path, roi_x, roi_y, roi_width, roi_height, brightness_jump_threshold):
    """
    Uses pixel intensity thresholding to estimate what frame the dive starts on.

    The mean grayscale intensity of a rectangular region of interest (ROI) is
    computed for every frame. A flash is reported on the first frame whose ROI
    is brighter than the ROI of the first frame of the video AND whose mean
    brightness rose by more than ``brightness_jump_threshold`` compared with
    the previous frame. The first frame only provides the baseline, so a flash
    on the first frame of the video cannot be detected.

    Args:
        video_path (str): Path to the video file.
        roi_x (int): X-coordinate of the top-left corner of the ROI.
        roi_y (int): Y-coordinate of the top-left corner of the ROI.
        roi_width (int): Width of the ROI.
        roi_height (int): Height of the ROI.
        brightness_jump_threshold (int): Minimum increase in average pixel intensity
            from the previous frame to trigger a flash detection.

    Returns:
        int or None: Zero-based index of the first frame on which a flash was
        detected. None if the video cannot be opened, no frame can be read,
        the ROI does not overlap the frame, or no flash is found. Frame 0 is
        never returned, so callers must still test with ``is None``.
    """
    cap = cv2.VideoCapture(video_path)

    # try/finally guarantees the capture is released on every exit path
    try:
        if not cap.isOpened():
            print(f"Error: Could not open video file: {video_path}")
            return None

        # the first frame provides the brightness baseline
        ret, frame = cap.read()
        if not ret:
            # empty or unreadable video: cvtColor would fail on the missing frame
            print(f"Error: Could not read any frame from video file: {video_path}")
            return None

        gray_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
        roi = gray_frame[roi_y : roi_y + roi_height, roi_x : roi_x + roi_width]
        if roi.size == 0:
            # an empty slice would give a NaN mean and silently never match
            print(f"Error: ROI lies outside the frame for video file: {video_path}")
            return None

        # baseline brightness: a flash must be brighter than the scene at the
        # start of the video, so a dark object moving away from the ROI
        # (e.g. someone walking in front of the camera) does not trigger it
        base_threshold = np.mean(roi)
        previous_brightness = base_threshold
        frame_index = 0

        while True:
            ret, frame = cap.read()
            if not ret:
                # end of video
                break

            frame_index += 1

            # mean grayscale brightness inside the ROI of this frame
            gray_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
            roi = gray_frame[roi_y : roi_y + roi_height, roi_x : roi_x + roi_width]
            current_brightness = np.mean(roi)

            # brightness change relative to the previous frame
            brightness_change = current_brightness - previous_brightness

            # light flash detection logic
            if current_brightness > base_threshold and brightness_change > brightness_jump_threshold:
                return frame_index

            # store current brightness for the next frame's comparison
            previous_brightness = current_brightness

        print("No flash found")
        return None
    finally:
        cap.release()


def extract_frame(config, checkpoint, video_path, correction = 0, save_path=False, sort_max_age=15, sort_min_hits=2, sort_iou_thresh=0.2,
                  threshold_x=1024, min_area=1024, max_tracks=2, pred_thresh=0.95):
    """
    Finds the frames on which a tracked bounding box straddles a vertical line.

    Every frame of the video is run through the detector and the detections are
    fed to a SORT tracker. For the ``max_tracks`` highest scoring tracks of a
    frame, a "crossing" is recorded the first time a track's box has its left
    edge at or before ``threshold_x``, its right edge past ``threshold_x`` and
    an area larger than ``min_area``. Each track is reported at most once.
    Processing stops early once ``max_tracks`` distinct tracks have crossed or
    a crossing track has a score of at least ``pred_thresh``.

    The detector is initialised from ``config``/``checkpoint`` on every call.

    Args:
        config (str): Path to model config file.
        checkpoint (str): Path to model checkpoint file.
        video_path (str): Path to the video file.
        correction (int): Number of frames skipped in video; added to the
            frame index so reported numbers refer to the untrimmed video.
        save_path (str): Path to folder to save the annotated crossing frames
            to. Defaults to False, in which case nothing is written.
        sort_max_age (int): First positional argument passed to ``sort.Sort``.
        sort_min_hits (int): Second positional argument passed to ``sort.Sort``.
        sort_iou_thresh (float): Third positional argument passed to ``sort.Sort``.
        threshold_x (int): X coordinate (pixels) of the vertical line.
        min_area (int): Minimum bounding box area (pixels squared) for a crossing.
        max_tracks (int): Number of top scoring tracks inspected per frame and
            number of distinct crossings after which processing stops.
        pred_thresh (float): Score at or above which a crossing stops processing.

    Returns:
        int or None: Corrected frame number of the most recent crossing, or
        None if no crossing was found. 0 is a valid result, so test the return
        value with ``is None``.
        If save_path is set: one "<frame number>.jpg" per crossing, with the
        bounding box drawn, is written to save_path.
    """
    model = init_detector(config, checkpoint, device="cpu")
    cap = mmcv.VideoReader(video_path)
    sort_tracker = sort.Sort(sort_max_age, sort_min_hits, sort_iou_thresh)
    detection_count = 0
    crossed_track_ids = set()
    last_crossing_frame = None

    for frame_index, frame in enumerate(cap):
        # frame number in the untrimmed video (zero-based index + skipped frames)
        frame_number = correction + frame_index

        # run model inference on frame
        result = inference_detector(model, frame)
        predictions = result.pred_instances

        if predictions.bboxes.numel() > 0:
            # one row per detection: four box coordinates followed by the score
            bboxes_np = predictions.bboxes.numpy()
            scores_np = predictions.scores.numpy().reshape(-1, 1)
            detections = np.concatenate((bboxes_np, scores_np), axis=1)
        else:
            # the tracker is still updated when nothing was detected
            detections = np.empty((0, 5))

        # NOTE(review): the indexing below treats each returned row as
        # [x1, y1, x2, y2, track_id, score]; confirm against MOT_sort.
        tracks = sort_tracker.update(detections)
        if tracks.shape[0] == 0:
            continue

        # indices of the highest scoring tracks, best first
        # (all tracks if fewer than max_tracks exist)
        top_indices = np.argsort(tracks[:, 5])[::-1][:max_tracks]

        for idx in top_indices:
            x, y, x2, y2, raw_track_id, score = tracks[idx][:6]
            track_id = int(raw_track_id)

            # each track is reported only once
            if track_id in crossed_track_ids:
                continue

            area = (x2 - x) * (y2 - y)
            straddles_line = (x <= threshold_x) and (x2 > threshold_x)
            if not (straddles_line and area > min_area):
                continue

            crossed_track_ids.add(track_id)
            last_crossing_frame = frame_number
            detection_count += 1

            if save_path:
                # the filename is the frame number in the untrimmed video
                frame_with_bbox = cv2.rectangle(frame.copy(), (int(x), int(y)), (int(x2), int(y2)), (255, 0, 0), 2)
                saved = cv2.imwrite(f"{save_path}/{frame_number}.jpg", frame_with_bbox)
                if not saved:
                    # imwrite reports failure through its return value only
                    print(f"Warning: could not write {save_path}/{frame_number}.jpg")

            print(f"Frame {frame_number}: Bounding box for track ID {track_id} (score: {score:.2f}) crossed threshold.")

            # stop once enough tracks crossed or a confident crossing was seen
            if len(crossed_track_ids) >= max_tracks or score >= pred_thresh:
                return last_crossing_frame

    print(f"Detection count: {detection_count}")
    return last_crossing_frame

In [3]:
# params
input_folder = "test_data"
output_folder = "test_results_MOT"


# number of frames to cut from each video at 5, 10 and 15m, reduces processing time and risk of FP
correction_5 = 125
correction_10 = 350
correction_15 = 620

# model parameters for extraction function
# Paths are built with os.path.join so they resolve on any OS; on Windows the
# resulting strings are identical to the previous backslash-separated literals.
_model_name = "faster-rcnn_r50-tnr-pre_fpn_1x_coco_full_custom_br_hue"
checkpoint = os.path.join("work_dirs", _model_name, "best_coco_Swimmer_HBB_precision_epoch_10.pth")
config = os.path.join("custom_configs", "faster_rcnn", _model_name + ".py")

In [4]:
# create output folder
# exist_ok avoids the check-then-create race and keeps the cell re-runnable
os.makedirs(output_folder, exist_ok=True)

# find all subdirectories (at any depth) below the input folder
folders = []
for root, dirs, files in os.walk(input_folder):
    # os.walk yields directories in filesystem order, which is not guaranteed
    # to be stable. Sorting in place makes both the traversal and the resulting
    # list deterministic; the processing loop selects ROI parameters by folder
    # position, so the order matters.
    dirs.sort()
    for dir_name in dirs:
        folders.append(os.path.join(root, dir_name))

In [5]:
# iterate through the folders

# (camera file suffix digit, intermediate video name, frames to skip after the flash)
camera_settings = [
    ("5", "5m.avi", correction_5),
    ("4", "10m.avi", correction_10),
    ("7", "15m.avi", correction_15),
]


def _find_camera_file(filenames, camera_digit):
    """Return the first name (in sorted order) ending in "-<digit>.avi" or "_<digit>.avi"
    (case-insensitive), or None if there is no such file."""
    suffixes = (f"-{camera_digit}.avi", f"_{camera_digit}.avi")
    matches = sorted(name for name in filenames if name.lower().endswith(suffixes))
    return matches[0] if matches else None


def _trim_video(source_path, output_path, first_frame):
    """Re-encode source_path to output_path with ffmpeg, keeping only frames with index >= first_frame."""
    # write out ffmpeg command line arguments as a list
    command = [
        "ffmpeg",
        # overwrite a leftover intermediate file instead of waiting for a prompt
        "-y",
        # input video
        "-i", source_path,
        # filtergraph string: select= keeps frames with index >= first_frame,
        # setpts= resets timestamps so the first kept frame starts at 0.
        # "\\," passes a literal backslash-comma so ffmpeg does not treat the
        # comma inside gte() as a filter separator.
        "-vf", f"select=gte(n\\,{first_frame}),setpts=PTS-STARTPTS",
        # define video codec
        "-c:v", "libx264",
        # output video
        output_path,
    ]
    try:
        subprocess.run(command, check=True, capture_output=True, text=True)
    except subprocess.CalledProcessError as error:
        # stderr is captured, so surface it before propagating the failure
        print(f"ffmpeg failed for {source_path}:\n{error.stderr}")
        raise


for folder_number, folder in enumerate(folders, start=1):
    # right now the ROI must be changed depending on the dive.
    # This approach will only work with the light
    # kept in a similar spot from dive to dive.
    if folder_number == 1:
        roi_params = {
            "x": 1773,
            "y": 635,
            "width": 31,
            "height": 29,
            "brightness_jump_threshold": 30
        }
    else:
        # if the dives are dives number 2-6 uses these params.
        roi_params = {
            "x": 1732,
            "y": 643,
            "width": 31,
            "height": 29,
            "brightness_jump_threshold": 30
        }

    # create an output subfolder for each dive with the same name as the original subfolder
    temp_out = os.path.join(output_folder, os.path.basename(os.path.normpath(folder)))
    os.makedirs(temp_out, exist_ok=True)

    filenames = os.listdir(folder)

    # reset per folder so a result from a previous dive is never reused
    start_frame = None

    # find start of dive video (-3) and locate the flash in it
    flash_file = _find_camera_file(filenames, "3")
    if flash_file is not None:
        start_frame = detect_flashes(
            os.path.join(folder, flash_file),
            roi_params["x"],
            roi_params["y"],
            roi_params["width"],
            roi_params["height"],
            roi_params["brightness_jump_threshold"]
        )

    # error handling if no flash is found (or there is no start video)
    if start_frame is None:
        print("No start frame detected")
        continue

    # trim each camera video of extraneous frames and save the intermediate video to the output folder
    trimmed_videos = {}
    for camera_digit, output_name, correction in camera_settings:
        source_name = _find_camera_file(filenames, camera_digit)
        if source_name is None:
            print(f"No camera {camera_digit} video found in {folder}, skipping {output_name}")
            continue
        trimmed_path = os.path.join(temp_out, output_name)
        _trim_video(os.path.join(folder, source_name), trimmed_path, start_frame + correction)
        trimmed_videos[output_name] = trimmed_path

    # for each trimmed video, infer time at threshold, add missing frames and output frame with bounding box graphic
    crossing_frames = {}
    for camera_digit, output_name, correction in camera_settings:
        if output_name not in trimmed_videos:
            continue
        try:
            crossing_frames[output_name] = extract_frame(
                config, checkpoint, trimmed_videos[output_name], correction, save_path=temp_out
            )
        finally:
            # remove the intermediate video even if inference fails
            os.remove(trimmed_videos[output_name])

    # keep the per-camera result names available to later cells
    frame_5 = crossing_frames.get("5m.avi")
    frame_10 = crossing_frames.get("10m.avi")
    frame_15 = crossing_frames.get("15m.avi")
        

Loads checkpoint by local backend from path: work_dirs\faster-rcnn_r50-tnr-pre_fpn_1x_coco_full_custom_br_hue\best_coco_Swimmer_HBB_precision_epoch_10.pth


  checkpoint = torch.load(filename, map_location=map_location)


Frame 153: Bounding box for track ID 0 (score: 0.99) crossed threshold.
Loads checkpoint by local backend from path: work_dirs\faster-rcnn_r50-tnr-pre_fpn_1x_coco_full_custom_br_hue\best_coco_Swimmer_HBB_precision_epoch_10.pth
Frame 394: Bounding box for track ID 18 (score: 1.00) crossed threshold.
Loads checkpoint by local backend from path: work_dirs\faster-rcnn_r50-tnr-pre_fpn_1x_coco_full_custom_br_hue\best_coco_Swimmer_HBB_precision_epoch_10.pth
Frame 682: Bounding box for track ID 51 (score: 0.77) crossed threshold.
Detection count: 1
Loads checkpoint by local backend from path: work_dirs\faster-rcnn_r50-tnr-pre_fpn_1x_coco_full_custom_br_hue\best_coco_Swimmer_HBB_precision_epoch_10.pth
Frame 151: Bounding box for track ID 290 (score: 0.98) crossed threshold.
Loads checkpoint by local backend from path: work_dirs\faster-rcnn_r50-tnr-pre_fpn_1x_coco_full_custom_br_hue\best_coco_Swimmer_HBB_precision_epoch_10.pth
Frame 387: Bounding box for track ID 325 (score: 0.07) crossed thresh

31m7s minutes
