In [1]:
# imports
import cv2
import time
import subprocess
import os
import mmcv
import numpy as np

from mmdet.apis import init_detector, inference_detector


  from torch.distributed.optim import \


In [2]:
# functions
def detect_flashes(video_path, roi_x, roi_y, roi_width, roi_height, brightness_jump_threshold):
    """
    Uses pixel intensity thresholding to estimate what frame the dive starts on.

    The mean grayscale intensity of the ROI on the first frame is the baseline.
    A flash is reported on the first later frame whose ROI mean is above that
    baseline AND has risen by more than `brightness_jump_threshold` since the
    previous frame. Because the first frame is the baseline, a flash on the
    first frame of the video cannot be detected.

    Args:
        video_path (str): Path to the video file.
        roi_x (int): X-coordinate of the top-left corner of the ROI.
        roi_y (int): Y-coordinate of the top-left corner of the ROI.
        roi_width (int): Width of the ROI.
        roi_height (int): Height of the ROI.
        brightness_jump_threshold (int): Minimum increase in average pixel intensity
            from the previous frame to trigger a flash detection.

    Returns:
        int or None: 0-based index of the frame on which the first flash was
        detected, or None if the video could not be opened, has no readable
        frames, or contains no flash.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            print(f"Error: Could not open video file: {video_path}")
            return None

        def roi_brightness(frame):
            # mean grayscale intensity inside the ROI
            gray_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
            roi = gray_frame[roi_y : roi_y + roi_height, roi_x : roi_x + roi_width]
            return np.mean(roi)

        ret, frame = cap.read()
        if not ret:
            # An empty/unreadable video has no baseline frame; without this check
            # cvtColor would be called with None and raise.
            print(f"Error: Could not read a frame from video file: {video_path}")
            return None

        # Requiring a flash to be brighter than the first frame (not only a jump
        # from the previous frame) stops the light reappearing after someone walks
        # in front of the camera from being counted as a flash.
        base_brightness = roi_brightness(frame)
        previous_brightness = base_brightness

        frame_index = 0  # 0-based index of the frame just read
        while True:
            ret, frame = cap.read()
            if not ret:
                break  # end of video
            frame_index += 1

            current_brightness = roi_brightness(frame)
            brightness_change = current_brightness - previous_brightness
            if current_brightness > base_brightness and brightness_change > brightness_jump_threshold:
                return frame_index

            # store current brightness for the next frame's comparison
            previous_brightness = current_brightness

        print("No flash found")
        return None
    finally:
        # release the capture on every exit path (returns, end of video, exceptions)
        cap.release()


def extract_frame(config, checkpoint, video_path, correction = 0, save_path=False,
                  threshold_x=1024, score_threshold=0.8, model=None):
    """
    Finds the frame at which the highest-scoring bounding box first straddles
    the vertical line x = `threshold_x` (left edge at or left of it, right edge
    strictly right of it).

    Only frames whose best detection score is above `score_threshold` are
    considered.

    Args:
        config (str): Path to model config file. Ignored when `model` is given.
        checkpoint (str): Path to model checkpoint file. Ignored when `model` is given.
        video_path (str): Path to the video file.
        correction (int): Offset added to the returned frame index and to the
            saved image's file name. Defaults to 0.
        save_path (str): Path to folder to save the annotated frame. Defaults
            to False (nothing is saved).
        threshold_x (float): X pixel column the box must cross. Defaults to 1024.
            NOTE(review): presumably the column of the distance marker in the
            clip — confirm per camera.
        score_threshold (float): Minimum best-detection score for a frame to be
            considered. Defaults to 0.8.
        model: An already-initialised detector to reuse. When None (default),
            one is loaded from `config`/`checkpoint` on the CPU on every call.

    Returns:
        int or None: `correction` + the 0-based index of the first crossing
        frame, or None if no frame crosses the line.
        If `save_path` is set, the crossing frame (with the box drawn on it) is
        also written to `<save_path>/<returned index>.jpg`.
    """
    if model is None:
        model = init_detector(config, checkpoint, device="cpu")

    video = mmcv.VideoReader(video_path)
    for frame_index, frame in enumerate(video):
        # run model inference on frame
        result = inference_detector(model, frame)

        instances = getattr(result, "pred_instances", None)
        scores = getattr(instances, "scores", None)
        if scores is None or len(scores) == 0:
            continue  # no detections on this frame

        best = int(np.argmax(scores))
        if scores[best] <= score_threshold:
            continue  # best detection is not confident enough

        # bounding box of the highest score as (left, top, right, bottom)
        x1, y1, x2, y2 = (float(v) for v in instances.bboxes[best][:4])

        if x1 <= threshold_x < x2:
            if save_path:
                frame_bbox = cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (255, 0, 0), 2)
                cv2.imwrite(os.path.join(save_path, f"{correction + frame_index}.jpg"), frame_bbox)
            return correction + frame_index

    # no frame crossed the line
    return None

In [3]:
"""
input_folder is a directory of subfolders, each containg a dive.

output_folder is a directory name for output screen captures of the swimmer
    at 5, 10 and 15m

"""
input_folder = "test_data"
output_folder = "test_results"
correction_5 = 100
correction_10 = 300
correction_15 = 600

In [4]:
# create output folder (no-op if it already exists)
os.makedirs(output_folder, exist_ok=True)

# os.walk silently yields nothing for a missing directory, which would make
# the processing loop do nothing without any error.
if not os.path.isdir(input_folder):
    raise FileNotFoundError(f"Input folder not found: {input_folder}")

# Find all subdirectories (at any depth) of the input folder.
# The order is made deterministic by sorting: the processing cell chooses the
# ROI by each folder's position in `folders`, and os.walk otherwise returns
# directories in arbitrary filesystem order.
folders = []
for root, subdirs, _files in os.walk(input_folder):
    subdirs.sort()  # sort in place so os.walk also descends in sorted order
    folders.extend(os.path.join(root, name) for name in subdirs)

In [5]:
# iterate through the folders
# For each dive folder: find the start flash, trim each distance camera's video
# to begin at (start frame + that camera's correction), run the detector to find
# when the swimmer crosses the line, then print the split times.

# Right now, the ROI must be changed depending on the dive. This approach only
# works with the light kept in a similar spot from dive to dive: the first folder
# in `folders` uses ROI_FIRST_DIVE and every later folder uses ROI_OTHER_DIVES.
ROI_FIRST_DIVE = {
    "x": 1773,
    "y": 635,
    "width": 31,
    "height": 29,
    "brightness_jump_threshold": 30,
}
ROI_OTHER_DIVES = {
    "x": 1732,
    "y": 643,
    "width": 31,
    "height": 29,
    "brightness_jump_threshold": 30,
}

START_CAMERA = "3"  # file-name suffix of the video containing the start flash
# distance label -> (file-name suffix of that camera's video, frame correction)
DISTANCE_CAMERAS = {
    "5m": ("5", correction_5),
    "10m": ("4", correction_10),
    "15m": ("7", correction_15),
}
# Frame counts are divided by this to print seconds.
# NOTE(review): assumes the cameras record at 100 fps — confirm.
FRAMES_PER_SECOND = 100

# model parameters for extraction function
# NOTE(review): extract_frame loads the detector on every call (three times per folder).
checkpoint = r"work_dirs\faster-rcnn_r50-tnr-pre_fpn_1x_coco_full_custom_br_hue\best_coco_Swimmer_HBB_precision_epoch_10.pth"
config = r"custom_configs\faster_rcnn\faster-rcnn_r50-tnr-pre_fpn_1x_coco_full_custom_br_hue.py"


def _find_camera_video(folder, camera):
    """Return the path of the video in `folder` whose name ends (case-insensitively)
    in `-<camera>.avi` or `_<camera>.avi`, or None if there is none.
    If several files match, the first in sorted order is used."""
    suffixes = (f"-{camera}.avi", f"_{camera}.avi")
    for filename in sorted(os.listdir(folder)):
        if filename.lower().endswith(suffixes):
            return os.path.join(folder, filename)
    return None


def _trim_video(src_path, dst_path, first_frame):
    """Re-encode `src_path` into `dst_path`, dropping every frame before 0-based index `first_frame`."""
    command = [
        "ffmpeg",
        "-y",  # overwrite a clip left behind by an interrupted run instead of failing
        "-i", src_path,
        # "\\," is a literal backslash-comma: it escapes the comma inside the
        # select expression so ffmpeg does not treat it as a filter separator.
        "-vf", f"select=gte(n\\,{first_frame}),setpts=PTS-STARTPTS",
        "-c:v", "libx264",
        "-c:a", "aac",
        dst_path,
    ]
    subprocess.run(command, check=True, capture_output=True, text=True)


def _format_time(frame):
    """Format a frame count as seconds, or 'n/a' when no frame was found."""
    return "n/a" if frame is None else f"{frame / FRAMES_PER_SECOND}s"


for folder_index, folder in enumerate(folders):
    roi_params = ROI_FIRST_DIVE if folder_index == 0 else ROI_OTHER_DIVES

    temp_out = os.path.join(output_folder, os.path.basename(os.path.normpath(folder)))
    os.makedirs(temp_out, exist_ok=True)

    # find start of dive video; reset per folder so a folder without a start
    # video cannot silently reuse the previous dive's start frame
    start_frame = None
    start_video = _find_camera_video(folder, START_CAMERA)
    if start_video is not None:
        start_frame = detect_flashes(
            start_video,
            roi_params["x"],
            roi_params["y"],
            roi_params["width"],
            roi_params["height"],
            roi_params["brightness_jump_threshold"],
        )
    if start_frame is None:
        print("No start frame detected")
        continue

    split_frames = {}
    for label, (camera, correction) in DISTANCE_CAMERAS.items():
        source_video = _find_camera_video(folder, camera)
        if source_video is None:
            print(f"No {label} video (camera {camera}) found in {folder}")
            split_frames[label] = None
            continue

        clip_path = os.path.join(temp_out, f"{label}.avi")
        _trim_video(source_video, clip_path, start_frame + correction)
        try:
            # infer the frame at the threshold and add back the skipped frames
            split_frames[label] = extract_frame(config, checkpoint, clip_path, correction, save_path=temp_out)
        finally:
            os.remove(clip_path)  # the trimmed clip is only a temporary input

    print(", ".join(f"{label}: {_format_time(split_frames[label])}" for label in DISTANCE_CAMERAS))
        

Loads checkpoint by local backend from path: work_dirs\faster-rcnn_r50-tnr-pre_fpn_1x_coco_full_custom_br_hue\best_coco_Swimmer_HBB_precision_epoch_10.pth


  checkpoint = torch.load(filename, map_location=map_location)


Loads checkpoint by local backend from path: work_dirs\faster-rcnn_r50-tnr-pre_fpn_1x_coco_full_custom_br_hue\best_coco_Swimmer_HBB_precision_epoch_10.pth
Loads checkpoint by local backend from path: work_dirs\faster-rcnn_r50-tnr-pre_fpn_1x_coco_full_custom_br_hue\best_coco_Swimmer_HBB_precision_epoch_10.pth
5m: 1.55s, 10m: 3.94s, 15m: 6.82s
Loads checkpoint by local backend from path: work_dirs\faster-rcnn_r50-tnr-pre_fpn_1x_coco_full_custom_br_hue\best_coco_Swimmer_HBB_precision_epoch_10.pth
Loads checkpoint by local backend from path: work_dirs\faster-rcnn_r50-tnr-pre_fpn_1x_coco_full_custom_br_hue\best_coco_Swimmer_HBB_precision_epoch_10.pth
Loads checkpoint by local backend from path: work_dirs\faster-rcnn_r50-tnr-pre_fpn_1x_coco_full_custom_br_hue\best_coco_Swimmer_HBB_precision_epoch_10.pth
5m: 1.51s, 10m: 3.87s, 15m: 6.83s
Loads checkpoint by local backend from path: work_dirs\faster-rcnn_r50-tnr-pre_fpn_1x_coco_full_custom_br_hue\best_coco_Swimmer_HBB_precision_epoch_10.pth
Lo

In [6]:
# # iterate through the files
# for filename in os.listdir(folder):
#     # find start of dive video
#     if (filename.lower().endswith(('-3.avi')) or filename.lower().endswith(('_3.avi'))):
#         # roi params #1
#         roi_params = {
#             "x": 1773,
#             "y": 635,
#             "width": 65,
#             "height": 49,
#             "brightness_jump_threshold": 60
#         }
#         # roi params #2 and #3
#         roi_params = {
#             "x": 1732,
#             "y": 643,
#             "width": 31,
#             "height": 29,
#             "brightness_jump_threshold": 60
#         }
        
#         start_frame = detect_flashes(
#         os.path.join(folder,filename),
#         roi_params["x"],
#         roi_params["y"],
#         roi_params["width"],
#         roi_params["height"],
#         roi_params["brightness_jump_threshold"]
#     )
# if start_frame == None:
#     print("No start frame detected")
# else:
#     for filename in os.listdir(folder):        
#         if (filename.lower().endswith(('-4.avi')) or filename.lower().endswith(('_4.avi'))):
#             output_name = "10m.avi"
#             command = [
#             "ffmpeg",
#             "-i", os.path.join(folder,filename),
#             "-vf", f"select=gte(n\,{start_frame + correction_10}),setpts=PTS-STARTPTS",
#             "-c:v", "libx264", "-c:a",
#             "aac", os.path.join(temp_out,output_name),
#             ]
#             subprocess.run(command, check=True, capture_output=True, text=True)

#     for filename in os.listdir(folder):
#         if (filename.lower().endswith(('-5.avi')) or filename.lower().endswith(('_5.avi'))):
#             output_name = "5m.avi"
#             command = [
#             "ffmpeg",
#             "-i", os.path.join(folder,filename),
#             "-vf", f"select=gte(n\,{start_frame}),setpts=PTS-STARTPTS",
#             "-c:v", "libx264", "-c:a",
#             "aac", os.path.join(temp_out,output_name),
#             ]
#             subprocess.run(command, check=True, capture_output=True, text=True)

#     for filename in os.listdir(folder):
#         if (filename.lower().endswith(('-7.avi')) or filename.lower().endswith(('_7.avi'))):
#             output_name = "15m.avi"
#             command = [
#             "ffmpeg",
#             "-i", os.path.join(folder,filename),
#             "-vf", f"select=gte(n\,{start_frame + correction_15}),setpts=PTS-STARTPTS",
#             "-c:v", "libx264", "-c:a",
#             "aac", os.path.join(temp_out,output_name),
#             ]
#             subprocess.run(command, check=True, capture_output=True, text=True)

#     # model parameters for extraction function
#     checkpoint = r"work_dirs\faster-rcnn_r50-tnr-pre_fpn_1x_coco_full_custom_br_hue\best_coco_Swimmer_HBB_precision_epoch_10.pth"
#     config = r"custom_configs\faster_rcnn\faster-rcnn_r50-tnr-pre_fpn_1x_coco_full_custom_br_hue.py"
#     # for each file, infer time at threshold and add missing frames

#     frame_5 = extract_frame(config, checkpoint, os.path.join(temp_out, "5m.avi"), save_path=temp_out)

#     frame_10 = extract_frame(config, checkpoint, os.path.join(temp_out, "10m.avi"), correction_10, save_path=temp_out) 

#     frame_15 = extract_frame(config, checkpoint, os.path.join(temp_out, "15m.avi"), correction_15, save_path=temp_out)

#     print(f"5m: {frame_5/100}s, 10m: {(frame_10)/100}s, 15m: {(frame_15)/100}s")