In [211]:
import cv2
import os
import json
import csv
from ultralytics import YOLO
from math import floor
from functools import reduce
import time
import torch

In [212]:
class Config:
    """
    Tuning parameters for frame-by-frame violation detection.

        - CONF_THRESHOLD: minimum box confidence for a sampled frame to count as
          containing a violation.
        - MIN_FRAME_AGE: number of sampled frames with a detection required
          before a violation interval is opened.
        - FPS_SPLIT: number of frames sampled per second of video.
        - MIN_INTERVAL_DISTANCE: minimum gap in seconds between two intervals for
          them to be treated as separate; closer intervals are merged.
        - LIFE_TIME: how many seconds a detection may be missing before the
          current violation is considered finished.
    """

    def __init__(self, conf_threshold=0.5, min_frame_age=3, fps_split=1.5, life_time=2,
                 min_interval_distance=3):
        self.CONF_THRESHOLD = conf_threshold
        self.MIN_FRAME_AGE = min_frame_age
        self.FPS_SPLIT = fps_split
        # Previously hard-coded; now configurable with the same default.
        self.MIN_INTERVAL_DISTANCE = min_interval_distance
        self.LIFE_TIME = life_time


def glue_last_intervals(intervals: list):
    """
    Merge the last interval into the one before it, in place.

    The second-to-last interval keeps its start and takes over the end of the
    last interval, so the merged interval covers both. If the last interval has
    no end yet (``[start]``), the merged interval is left open as well.
    Lists with fewer than two intervals are left unchanged.

    :param intervals: time-ordered list of ``[start]`` / ``[start, end]`` lists.
    """
    if len(intervals) < 2:
        return
    # pop() removes by position; list.remove() would delete the first *equal*
    # element, which is not necessarily the last one.
    last = intervals.pop()
    # Keep the previous start, replace everything after it with the last tail.
    intervals[-1][1:] = last[1:]


def write_previews(video_name, cap, fps, intervals, out_dir='previews'):
    """
    Save the frame at the start of each interval as a PNG preview.

    The preview for interval ``i`` is written to ``{out_dir}/{video_name}_{i}.png``.
    Intervals whose start frame cannot be read are skipped instead of passing
    ``None`` to ``cv2.imwrite``.

    :param video_name: base name used in the preview file names.
    :param cap: opened ``cv2.VideoCapture`` of the video.
    :param fps: frames per second of the video.
    :param intervals: intervals whose first element is the start time, either in
        seconds or as an ``"MM:SS"`` string (as produced by ``format_seconds``).
    :param out_dir: directory for the previews; created if missing.
    """
    os.makedirs(out_dir, exist_ok=True)
    for i, interval in enumerate(intervals):
        start = interval[0]
        if isinstance(start, str):
            # "MM:SS" -> seconds
            minutes, seconds = start.split(':')
            start = int(minutes) * 60 + int(seconds)
        cap.set(cv2.CAP_PROP_POS_FRAMES, int(round(start * fps)))
        ok, frame = cap.read()
        if not ok or frame is None:
            continue
        out_path = f"{out_dir}/{video_name}_{i}.png"
        cv2.imwrite(out_path, frame)


def write_json(intervals: list, video_name, path_source, out_path='data.json'):
    """
    Write the violations of one video to a JSON file.

    The file is opened in ``'w'`` mode, so every call overwrites it.

    :param intervals: list of ``[start, end]`` intervals. An interval without an
        end is written with ``"end": null`` instead of raising IndexError.
    :param video_name: stored as ``"name"`` and used in the preview paths.
    :param path_source: stored as ``"path_source"``.
    :param out_path: destination file.
    """
    violations = [
        {
            "preview": f"previews/{video_name}_{i}.png",
            "start": interval[0],
            "end": interval[1] if len(interval) > 1 else None,
        }
        for i, interval in enumerate(intervals)
    ]

    json_data = {"name": video_name, "path_source": path_source, "violations": violations}

    with open(out_path, 'w', encoding='utf-8') as f:
        json.dump(json_data, f)

    
def draw_results(frame, results, config, model):
    """
    Draw confident detections on ``frame`` and show it in the 'Result' window.

    A box is drawn when its confidence exceeds ``config.CONF_THRESHOLD``. Its
    label is the class name followed by the confidence.

    :param frame: image to draw on; modified in place.
    :param results: detector results exposing ``boxes.xyxy``, ``boxes.conf`` and
        ``boxes.cls`` (ultralytics results in this notebook).
    :param config: object providing ``CONF_THRESHOLD``.
    :param model: not used; kept for interface compatibility.
    """
    class_names = {0: 'near_wagon', 1: 'under_wagon', 2: 'bashmak'}
    color = (0, 255, 0)
    for result in results:
        boxes = result.boxes
        for box, conf, class_id in zip(boxes.xyxy, boxes.conf, boxes.cls):
            if conf > config.CONF_THRESHOLD:
                x1, y1, x2, y2 = map(int, box)
                cls = int(class_id)
                # Unknown ids are shown numerically instead of raising KeyError.
                label = f"{class_names.get(cls, str(cls))}: {float(conf):.2f}"
                cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
                # Keep the label baseline inside the image for boxes at the top edge.
                text_y = max(y1 - 10, 15)
                cv2.putText(frame, label, (x1, text_y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)

    cv2.imshow('Result', frame)


def format_seconds(seconds):
    """
    Render a number of seconds as zero-padded ``MM:SS``.

    Fractional seconds are floored. Minutes are not wrapped into hours, so
    3661 seconds becomes ``"61:01"``.
    """
    minutes, secs = divmod(floor(seconds), 60)
    return f"{minutes:02}:{secs:02}"


def create_and_write_submission(video_name, intervals):
    """
    Append one row ``video_name;[start1,start2,...]`` to ``submission.csv``.

    Only the first element (the start) of each interval is written. The file is
    opened in append mode, so rows accumulate across calls.
    """
    with open('submission.csv', 'a', newline='') as submission_file:
        row_writer = csv.writer(submission_file, delimiter=';')
        starts = ','.join(str(iv[0]) for iv in intervals)
        row_writer.writerow([video_name, '[' + starts + ']'])

In [213]:
def _has_confident_detection(results, threshold):
    """Return True if any box in ``results`` has a confidence above ``threshold``."""
    return any(
        conf > threshold
        for result in results
        for conf in result.boxes.conf.detach().cpu().tolist()
    )


def _close_last_interval(intervals, end, config):
    """Set ``end`` on the open last interval and merge it with the previous one if they are too close."""
    intervals[-1].append(end)
    if len(intervals) > 1 and intervals[-1][0] - intervals[-2][1] <= config.MIN_INTERVAL_DISTANCE:
        glue_last_intervals(intervals)


def _detect_intervals(cap, fps, frame_count, model, config, show):
    """Scan an opened capture and return violation intervals as ``[start, end]`` in seconds."""
    frames_increment = fps / config.FPS_SPLIT
    intervals = []
    frame_age = 0       # frames with a detection in the current streak
    streak_start = 0    # frame position of the first detection of the current streak
    prev_detected = 0   # frame position of the most recent detection
    frame_pos = 0

    while frame_pos < floor(frame_count):
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_pos)
        ok, frame = cap.read()
        if not ok:
            # The reported frame count can exceed the number of decodable frames.
            break
        results = model.predict(frame, verbose=False)

        if _has_confident_detection(results, config.CONF_THRESHOLD):
            if frame_age == 0:
                streak_start = frame_pos
            frame_age += 1
            prev_detected = frame_pos
        elif frame_age > 0 and (frame_pos - prev_detected) / fps > config.LIFE_TIME:
            # Nothing detected for longer than LIFE_TIME: the streak is over.
            if frame_age >= config.MIN_FRAME_AGE:
                _close_last_interval(intervals, round((frame_pos - frames_increment) / fps, 2), config)
            # Reset sub-threshold streaks as well, so isolated detections far
            # apart are not accumulated into a single violation.
            frame_age = 0

        if frame_age >= config.MIN_FRAME_AGE and (not intervals or len(intervals[-1]) != 1):
            intervals.append([round(streak_start / fps, 2)])

        if show:
            draw_results(frame, results, config, model)
            key = cv2.waitKey(100)
            if key == ord('q'):
                break
            if key == ord('p'):
                cv2.waitKey()  # stay paused until any key is pressed

        frame_pos += frames_increment

    # An interval still open when the scan ends lasts until the end of the video.
    if intervals and len(intervals[-1]) == 1:
        _close_last_interval(intervals, round(frame_count / fps, 2), config)
    return intervals


def process_file(video_name, path_source, model, config: Config, show=False, save_previews=False):
    """
    Detect violation intervals in one video and append them to ``submission.csv``.

    The video is sampled ``config.FPS_SPLIT`` times per second of video. Each
    sampled frame goes through ``model.predict``, and the frame counts as a
    detection if any box confidence exceeds ``config.CONF_THRESHOLD``.

    Once ``config.MIN_FRAME_AGE`` detections are seen in one streak, an interval
    is opened at the first detection of that streak. When nothing has been
    detected for more than ``config.LIFE_TIME`` seconds, the streak ends and any
    open interval is closed. An interval starting within
    ``config.MIN_INTERVAL_DISTANCE`` seconds of the previous one is merged
    into it.

    :param video_name: name used in logs, previews and the submission row.
    :param path_source: path of the video file.
    :param model: detector whose ``predict(frame, verbose=False)`` returns results
        with ``boxes.conf`` tensors (an ultralytics YOLO model in this notebook).
    :param config: detection parameters.
    :param show: if True, draw detections in an OpenCV window; ``q`` stops the
        scan, ``p`` pauses until the next key press.
    :param save_previews: if True, save a PNG preview for every interval start.
    :return: list of ``[start, end]`` intervals formatted as ``"MM:SS"`` strings.
        An unreadable video yields ``[]`` (its submission row is still written).
    """
    cap = cv2.VideoCapture(path_source)
    try:
        frame_count = cap.get(cv2.CAP_PROP_FRAME_COUNT)
        fps = cap.get(cv2.CAP_PROP_FPS)
        print(f'{video_name} TOTAL frames: {frame_count}')
        if not cap.isOpened() or fps <= 0:
            # Best effort: report no violations instead of aborting the whole
            # batch, and avoid dividing by (or stepping by) a zero FPS.
            print(f'{video_name}: cannot read video, no intervals detected')
            intervals = []
        else:
            intervals = _detect_intervals(cap, fps, frame_count, model, config, show)
            if save_previews:
                # Previews seek by numeric seconds, so save them before formatting.
                write_previews(video_name, cap, fps, intervals)
    finally:
        cap.release()
        if show:
            cv2.destroyAllWindows()

    intervals = [list(map(format_seconds, interval)) for interval in intervals]
    create_and_write_submission(video_name, intervals)

    return intervals

In [214]:
def main(video_name, path_source, model_path, config_args=None):
    """
    Run violation detection on one video and save its results.

    Loads the YOLO weights from ``model_path`` (on CUDA when available), runs
    ``process_file`` on the video and writes the intervals with ``write_json``.

    :param video_name: name of the video, used in outputs.
    :param path_source: path of the video file.
    :param model_path: path of the YOLO weights.
    :param config_args: optional mapping of ``Config`` keyword arguments, e.g.
        ``{"conf_threshold": "0.6"}``. Keys must match ``Config.__init__``
        parameter names. Values are converted to float, and the caller's
        mapping is not modified.
    """
    if config_args is None:
        config = Config()
    else:
        config = Config(**{key: float(value) for key, value in config_args.items()})

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print('Using device:', device)
    model = YOLO(model_path).to(device)
    start_time = time.time()
    intervals = process_file(video_name, path_source, model, config)
    write_json(intervals, video_name, path_source)
    print(f'Spend time: {time.time() - start_time}', end='\n\n')

In [216]:
# Start a fresh submission file containing only the header row.
with open('submission.csv', 'w', newline='') as csvfile:
    writer = csv.writer(csvfile, delimiter=';')
    writer.writerow(['video_name', 'violation_timestamp'])

test_dir = './test_dir/'
# os.listdir order is arbitrary; sort for a reproducible processing order and
# skip sub-directories, which are not videos.
file_names = sorted(
    name for name in os.listdir(test_dir)
    if os.path.isfile(os.path.join(test_dir, name))
)
print(file_names)

for file_name in file_names:
    main(file_name, os.path.join(test_dir, file_name), './best_yolo_n.pt')

['v_001.MP4', 'v_002.MP4', 'v_003.mp4', 'v_004.MP4', 'v_005.mp4', 'v_006.mp4', 'v_007.mp4', 'v_008.MP4', 'v_009.MP4', 'v_010.MP4', 'v_011.MP4', 'v_012.MP4', 'v_013.MP4', 'v_014.mp4']
Using device: cuda
v_001.MP4 TOTAL frames: 26976.0
Spend time: 176.85067653656006

Using device: cuda
v_002.MP4 TOTAL frames: 1847.0
Spend time: 12.450232982635498

Using device: cuda
v_003.mp4 TOTAL frames: 2841.0
Spend time: 9.36141586303711

Using device: cuda
v_004.MP4 TOTAL frames: 7456.0
Spend time: 24.17507290840149

Using device: cuda
v_005.mp4 TOTAL frames: 1818.0
Spend time: 5.894580364227295

Using device: cuda
v_006.mp4 TOTAL frames: 1802.0
Spend time: 6.74736762046814

Using device: cuda
v_007.mp4 TOTAL frames: 901.0
Spend time: 3.423980474472046

Using device: cuda
v_008.MP4 TOTAL frames: 26976.0


KeyboardInterrupt: 