In [251]:
import cv2 as cv
import os
import numpy as np
from matplotlib import pyplot as plt

In [252]:
# For test purposes
VIDEO_PATH = "sample_videos/cool_dog.mp4"

# Number of pixels cropped from each side
TOP_CROP = 100
BOTTOM_CROP = 10
LEFT_CROP = 0
RIGHT_CROP = 50

# Dimensions for resized frames
FRAME_WIDTH = 576
FRAME_HEIGHT = 324

# Threshold for pick_frames_sliding function
MSE_THRESHOLD = 20

# Interval for extract_frames function
INTERVAL = 300

DATASET_PATH = "dataset"
OUTPUT_PATH = 'training_dataset'

In [253]:
# Crops frame by n pixels in each direction
def crop_frame(frame, top_crop, bottom_crop, left_crop, right_crop):
    return frame[top_crop : frame.shape[0] - bottom_crop, 
                 left_crop : frame.shape[1] - right_crop]


# Extracts frames from given video and resizes them, taking a frame every n milliseconds
def extract_frames(video_path, interval, target_height, target_width):
    video = cv.VideoCapture(video_path)
    
    frames = []
    success, frame = video.read()
    total_frames = video.get(cv.CAP_PROP_FRAME_COUNT)
    i = 0
    while success and video.get(cv.CAP_PROP_POS_FRAMES) < total_frames:
        # Save current frame
        frames.append(cv.resize(frame, (target_width, target_height)))

        # Exit the loop if the video is unrealistically long
        assert i < 500, "Maximum number of frames exceeded"

        # Jump <interval> seconds forward
        video.set(cv.CAP_PROP_POS_MSEC, i * interval)
        success, frame = video.read()
        i += 1
    return frames

In [254]:
# Mean squared error between 2 frames
def mse(frame1, frame2):
    # Cast uint8 to int32 to avoid overflow
    frame1 = frame1.astype(np.int32)
    frame2 = frame2.astype(np.int32)
    
    assert frame1.shape == frame2.shape, "Shapes do not match"

    # Calculate MSE
    mse = np.mean((frame1 - frame2) ** 2)
    return mse

In [255]:
# Saves frame if its MSE with ref. frame exceeds threshold, then sets it as new ref. frame
# does not include first and last 3 frames
def pick_frames_sliding(frames, threshold):
    # Set first frame as reference frame
    reference_frame = frames[0]

    picked_frames = []
    for frame in frames:
        if mse(frame, reference_frame) > threshold:
            # Save the current frame and overwrite ref. frame
            reference_frame = frame
            picked_frames.append(frame)

    # Throw an error if no frames were picked
    assert len(picked_frames) > 0, "0 frames picked from the video"
    
    return picked_frames[3:-3]


In [256]:
# Picks frames from a video and saves them
def process_video(video_path, output_path):
    # Choose frames with moving object
    frames = extract_frames(video_path, INTERVAL, FRAME_HEIGHT, FRAME_WIDTH)
    for i in range(len(frames)):
        frames[i] = crop_frame(frames[i], TOP_CROP, BOTTOM_CROP, LEFT_CROP, RIGHT_CROP)
    picked_frames = pick_frames_sliding(frames, MSE_THRESHOLD)

    # Get video name from its path
    video_name = video_path.split('/')[-1]
    video_name = video_name[:-4] # Remove .mp4 from name

    # Save frames
    for i in range(len(picked_frames)):
        cv.imwrite(f'{output_path}/{video_name}_frame{i}.png', picked_frames[i])
    

In [257]:
# Picks frames from each video in a folder and saves them
def process_folder(folder_path, output_path):
    video_paths = [f"{folder_path}/{video}" for video in os.listdir(folder_path)]
    i = 0
    for path in video_paths:
        process_video(path, output_path)
        i += 1
        print(f"{i}/{len(video_paths)} ({i / len(video_paths) * 100}%)")

In [258]:
process_folder(DATASET_PATH, OUTPUT_PATH)

IndexError: list index out of range