In [1]:
import numpy as np
import pickle
import cv2
import os
import torch

In [2]:
# Load the per-video annotations. The context manager closes the file even if
# unpickling raises.
# NOTE: pickle can execute arbitrary code while loading; only open annotation
# files that come from a trusted source.
with open("./annotations_public.pkl", 'rb') as annotation_file:
    annotations = pickle.load(annotation_file)  # Load annotations

In [None]:
# Load the MiDaS model that scores objects by depth in each frame.
model_type = "DPT_Large"
midas = torch.hub.load("intel-isl/MiDaS", model_type)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
midas.to(device)
midas.eval()

# Pick the preprocessing transform that matches the selected model variant.
midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")
transform = (
    midas_transforms.dpt_transform
    if model_type in ("DPT_Large", "DPT_Hybrid")
    else midas_transforms.small_transform
)

# Corner-detection settings, unpacked into cv2.goodFeaturesToTrack.
feature_params = {
    "maxCorners": 100,
    "qualityLevel": 0.3,
    "minDistance": 7,
    "blockSize": 7,
}

# Lucas-Kanade optical-flow settings, unpacked into cv2.calcOpticalFlowPyrLK.
lk_params = {
    "winSize": (15, 15),
    "maxLevel": 2,
    "criteria": (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, 10, 0.03),
}

hazard_summary = {}

# Table of 100 random 3-component colours.
color = np.random.randint(0, 255, (100, 3))

# Input videos are read from video_root; annotated copies go to output_dir.
video_root = './COOOL_Benchmark/processed_videos/'
output_dir = "./COOOL_Benchmark/processed_videos_midas/"
os.makedirs(output_dir, exist_ok=True)

# Running count of videos processed so far (used for progress printing).
video_num = 0

def is_object_far(x1, y1, x2, y2, frame_image):
    """Return the MiDaS prediction sampled at the centre of a bounding box.

    Despite the name, this does not return a boolean: it returns the raw
    model output at the box centre and leaves the near/far decision to the
    caller.

    The frame is run through the module-level ``transform`` and ``midas``
    model, and the prediction is resized (bicubic) to the frame's own
    height and width so that pixel coordinates index it directly.

    Args:
        x1, y1, x2, y2: Bounding-box corner coordinates in frame pixels.
            The corners may be given in either order.
        frame_image: Image array; its first two shape entries are used as
            (height, width).

    Returns:
        The prediction value at the box centre as a NumPy scalar (callers
        use ``.item()`` on it).

    NOTE(review): MiDaS is documented as predicting relative *inverse*
    depth, so larger values presumably mean closer objects -- confirm
    against the MiDaS documentation.
    NOTE(review): the MiDaS examples convert OpenCV BGR frames to RGB before
    applying the transform; this function forwards the frame unchanged.
    Confirm whether that is intended before altering it, since downstream
    thresholds were presumably tuned on the current values.
    """
    input_batch = transform(frame_image).to(device)
    frame_height, frame_width = frame_image.shape[:2]

    with torch.no_grad():
        prediction = midas(input_batch)

        # Resize the prediction back to the frame resolution.
        prediction = torch.nn.functional.interpolate(
            prediction.unsqueeze(1),
            size=(frame_height, frame_width),
            mode="bicubic",
            align_corners=False,
        ).squeeze()

    output = prediction.cpu().numpy()

    # Centre of the box. min() keeps this correct when the corners are
    # supplied in reverse order (the previous x1 + |x2 - x1| / 2 form
    # overshot in that case).
    center_row = int(min(y1, y2) + abs(y2 - y1) / 2)
    center_col = int(min(x1, x2) + abs(x2 - x1) / 2)

    # Clamp into the frame so a box touching or crossing the border neither
    # raises IndexError nor silently wraps around via a negative index.
    center_row = min(max(center_row, 0), frame_height - 1)
    center_col = min(max(center_col, 0), frame_width - 1)

    return output[center_row, center_col]


def filter_objects_by_midas(det_far, gap_threshold=6):
    """Drop objects whose MiDaS value is low relative to the others in a frame.

    Rules, by number of objects in ``det_far``:

    * 0 or 1 objects: nothing is removed.
    * exactly 2: the lower-valued object is removed only when the two
      values differ by more than ``gap_threshold``.
    * more than 2: only objects strictly above the mean value are kept.
      If that leaves a single object and the runner-up is within
      ``gap_threshold`` of it, the runner-up is kept as well. If no object
      is strictly above the mean (for example, all values are equal), every
      object is kept instead of returning an empty result.

    Args:
        det_far: Mapping of track id to the MiDaS value sampled for it.
        gap_threshold: Value difference used by the two rules above.
            Defaults to 6, the constant this function has always used.

    Returns:
        A new dict of the retained ``track_id -> value`` pairs; ``det_far``
        itself is never modified.
    """
    # Highest value first; ties keep their original insertion order.
    ranked = sorted(det_far.items(), key=lambda item: item[1], reverse=True)

    if len(ranked) < 2:
        return det_far.copy()

    (top_id, top_value), (second_id, second_value) = ranked[0], ranked[1]
    top_gap = abs(top_value - second_value)

    if len(ranked) == 2:
        # With only two objects, exclude with caution: require a clear gap.
        if top_gap > gap_threshold:
            return {top_id: top_value}
        return det_far.copy()

    # More than two objects: keep those above the frame average.
    mean_value = sum(value for _, value in ranked) / len(ranked)
    kept = {track_id: value for track_id, value in ranked if value > mean_value}

    if not kept:
        # Nothing is strictly above the mean, so no object stands out as
        # nearer or farther than the rest; keep them all rather than
        # filtering out every object.
        return dict(ranked)

    if len(kept) == 1 and top_gap < gap_threshold:
        # A lone survivor with a close runner-up: keep the runner-up too.
        kept[second_id] = second_value

    return kept


def retain_first_and_get_unique_ids(def_far_all, num_id):
    """Select candidate object ids across all frames of one video.

    Frames are visited in the iteration order of ``def_far_all``. For each
    frame:

    * When ``num_id <= 4``, every object whose value exceeds 15 is recorded
      as a hazard.
    * While no hazard has been recorded so far, the highest-valued object
      of the frame is recorded as a fallback id.

    For every recorded object the frame in which it reached its highest
    recorded value is remembered.

    Afterwards, when ``num_id > 4`` and more than three fallback ids were
    collected, the fallback id with the lowest peak value is dropped if the
    next-lowest peak exceeds it by more than 1.

    Args:
        def_far_all: Mapping ``frame -> {object_id: value}``.
        num_id: Number of distinct track ids in the video.

    Returns:
        A tuple ``(unique_ids, unique_hazards, brightest_frame_per_object)``
        where the first two are sets of object ids and the last maps an
        object id to ``(frame, value)`` for its peak.
    """
    fallback_ids = set()
    hazard_ids = set()
    peak_value = {}   # object id -> highest value recorded so far
    peak_frame = {}   # object id -> (frame, value) at that peak

    def _remember_peak(obj_id, frame_key, value):
        # Keep only the strongest observation of each object.
        if obj_id not in peak_value or value > peak_value[obj_id]:
            peak_value[obj_id] = value
            peak_frame[obj_id] = (frame_key, value)

    few_tracks = num_id <= 4

    for frame_key, frame_objects in def_far_all.items():
        # Highest value first.
        ranked = sorted(frame_objects.items(), key=lambda pair: pair[1], reverse=True)

        if few_tracks:
            for obj_id, value in ranked:
                if value > 15:
                    hazard_ids.add(obj_id)
                    _remember_peak(obj_id, frame_key, value)

        # Fallback: until a hazard shows up, track each frame's top object.
        if ranked and not hazard_ids:
            top_id, top_value = ranked[0]
            fallback_ids.add(top_id)
            _remember_peak(top_id, frame_key, top_value)

    # With many tracks (num_id > 4) and more than three fallback ids, prune
    # the weakest one when it is clearly separated from the next weakest.
    if num_id > 4 and len(fallback_ids) > 3:
        ascending = sorted(
            fallback_ids,
            key=lambda obj_id: peak_value.get(obj_id, float('inf')),
        )
        lowest_id, runner_up_id = ascending[0], ascending[1]
        runner_up_value = peak_value.get(runner_up_id, float('inf'))
        if runner_up_value - peak_value[lowest_id] > 1:
            fallback_ids.remove(lowest_id)

    return fallback_ids, hazard_ids, peak_frame


# Main processing loop. For every annotated video it:
#   1. scores each annotated object per frame with MiDaS and filters the
#      scores with filter_objects_by_midas,
#   2. runs Lucas-Kanade optical flow against a fixed reference frame to set
#      the Driver_State_Changed flag, writing one CSV row and one video frame
#      for every frame except frame 0,
#   3. picks the candidate hazard ids for the video, draws their boxes in a
#      second video, and pickles {object_id: frame of its peak value}.

# Number of (Hazard_Track_i, Hazard_Name_i) column pairs in the results CSV.
num_hazard_columns = 23

# The per-video pickles are written here; create it up front so the dump at
# the end of each video cannot fail on a missing directory.
os.makedirs("./unique_ids", exist_ok=True)

with open("results_md_blip_final.csv", 'w') as results_file:
    results_file.write(
        "ID,Driver_State_Changed"
        + "".join(f",Hazard_Track_{i},Hazard_Name_{i}" for i in range(num_hazard_columns))
        + "\n"
    )
    # Every hazard column is written empty by this script.
    empty_hazard_columns = ",," * num_hazard_columns

    for video in sorted(annotations.keys()):
        video_num += 1
        print("video_num:", video_num)
        video_path = os.path.join(video_root, video + '.mp4')
        frame_annotations = annotations[video]

        # ------------------------------------------------------------------
        # Reference frame: the first frame that is not entirely black.
        # ------------------------------------------------------------------
        video_stream = cv2.VideoCapture(video_path)
        assert video_stream.isOpened()

        old_gray = None
        while True:
            ret, old_frame = video_stream.read()
            if not ret:
                # Ran out of frames (or hit a read error) before finding one.
                break
            candidate_gray = cv2.cvtColor(old_frame, cv2.COLOR_BGR2GRAY)
            if np.any(candidate_gray != 0):
                old_gray = candidate_gray
                break
        # Release this capture before the video is reopened from the start.
        video_stream.release()
        if old_gray is None:
            raise RuntimeError(f"No non-black frame could be read from {video_path}")

        # Corners tracked by the optical flow below.
        # NOTE(review): old_gray and p0 are never refreshed, so flow is always
        # measured from this reference frame to the current frame -- confirm
        # that this is intended.
        p0 = cv2.goodFeaturesToTrack(old_gray, mask=None, **feature_params)

        frame = 0
        slope_history = []  # Slopes recorded at every 5th frame
        track_id_lifecycle = {}  # track_id -> first/last frame it was seen in
        # NOTE(review): the flag below is raised when the mean slope change is
        # *below* this threshold -- confirm the comparison direction.
        threshold = 4
        driver_state_flag = False
        def_far_all = {}  # frame -> filtered {track_id: MiDaS value}

        # ------------------------------------------------------------------
        # Pass 1: depth scoring, driver-state detection, CSV + video output.
        # ------------------------------------------------------------------
        video_stream = cv2.VideoCapture(video_path)
        # Get video properties
        fps = int(video_stream.get(cv2.CAP_PROP_FPS))
        width = int(video_stream.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(video_stream.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # Codec for output video
        output_video_path = os.path.join(output_dir, f"{video}_midas.mp4")
        out = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height))

        while video_stream.isOpened():
            video_frame = f'{video}_{frame}'
            ret, frame_image = video_stream.read()
            if not ret:  # End of video or read error
                # End of the video must be the final annotated frame.
                assert frame == len(frame_annotations.keys())
                break

            # Score every annotated object in this frame with MiDaS.
            det_far = {}
            for obj in frame_annotations[frame]['challenge_object']:
                x1, y1, x2, y2 = obj['bbox']
                track_id = obj['track_id']

                # Update lifecycle of the track_id
                if track_id not in track_id_lifecycle:
                    track_id_lifecycle[track_id] = {'first_frame': frame, 'last_frame': frame}
                else:
                    track_id_lifecycle[track_id]['last_frame'] = frame

                depth_value = is_object_far(x1, y1, x2, y2, frame_image)
                det_far[track_id] = depth_value.item()

            def_far_all[frame] = filter_objects_by_midas(det_far)

            if frame == 0:
                # We can't make a prediction of state change w/o knowing the
                # previous state, so frame 0 gets no CSV row and no video frame.
                frame += 1
                continue

            # Driver state change detection
            frame_gray = cv2.cvtColor(frame_image, cv2.COLOR_BGR2GRAY)

            # calculate optical flow
            p1, st, err = cv2.calcOpticalFlowPyrLK(old_gray, frame_gray, p0, None, **lk_params)

            # Select good points
            good_new = p1[st == 1]
            good_old = p0[st == 1]

            # Slope of each tracked point's motion vector.
            # NOTE(review): the 1e-100 guard may underflow to zero when the
            # coordinates are float32 -- confirm it still prevents a division
            # by zero on the NumPy version in use.
            slopes = []
            for new, old in zip(good_new, good_old):
                a, b = new.ravel()
                c, d = old.ravel()
                slopes.append((b - d) / ((a - c) + 1e-100))

            # Compare against the previously recorded slopes every 5 frames.
            if frame % 5 == 0:
                if slope_history:
                    previous_slopes = slope_history[-1]
                    slope_changes = [abs(s - ps) for s, ps in zip(slopes, previous_slopes)]
                    avg_slope_change = np.mean(slope_changes)
                    if avg_slope_change < threshold and not driver_state_flag:
                        # Once raised, the flag stays raised for the rest of
                        # the video.
                        driver_state_flag = True
                        print(f"Frame {frame} labeled as True")
                slope_history.append(slopes)  # Update the history

            cv2.putText(frame_image, str(driver_state_flag), (10, 30),
                        cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 255), 2)
            out.write(frame_image)
            results_file.write(f"{video_frame},{driver_state_flag}{empty_hazard_columns}\n")

            frame += 1

        video_stream.release()
        out.release()

        # ------------------------------------------------------------------
        # Pass 2 (MiDaS): keep only the selected objects, excluding far ones.
        # ------------------------------------------------------------------
        num_id = len(track_id_lifecycle)
        unique_ids, unique_hazards, brightest_frame_per_object = retain_first_and_get_unique_ids(def_far_all, num_id)
        if len(unique_hazards) > 0:
            # Hazards, when any were found, replace the fallback ids.
            unique_ids = unique_hazards

        video_stream_midas = cv2.VideoCapture(output_video_path)
        out_midas = cv2.VideoWriter(
            os.path.join(output_dir, f"{video}_midas_hazard_v1.mp4"),
            cv2.VideoWriter_fourcc(*'mp4v'),
            int(video_stream_midas.get(cv2.CAP_PROP_FPS)),
            (int(video_stream_midas.get(cv2.CAP_PROP_FRAME_WIDTH)), int(video_stream_midas.get(cv2.CAP_PROP_FRAME_HEIGHT)))
        )

        # Pass 1 never wrote frame 0, so the first frame of the intermediate
        # video is source frame 1. Start the annotation index there so boxes
        # are drawn on the frame they were annotated for.
        frame_midas = 1
        while video_stream_midas.isOpened():
            ret, frame_image_midas = video_stream_midas.read()
            if not ret:  # End of video or read error
                # End of the video must be the final annotated frame.
                assert frame_midas == len(frame_annotations.keys())
                break

            for obj in frame_annotations[frame_midas]['challenge_object']:
                x1, y1, x2, y2 = obj['bbox']
                track_id = obj['track_id']
                # unique_ids holds the track ids exactly as they appear in the
                # annotations, so compare them without converting to str.
                if track_id in unique_ids:
                    cv2.rectangle(frame_image_midas, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 2)
                    cv2.putText(frame_image_midas, f"ID: {track_id}", (int(x1), int(y1) - 10),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 2)

            out_midas.write(frame_image_midas)
            frame_midas += 1

        video_stream_midas.release()
        out_midas.release()

        # Persist, for each selected object, the frame where its value peaked.
        unique_ids_with_brightest_frame = {obj_id: brightest_frame_per_object[obj_id][0] for obj_id in unique_ids}
        with open(f"./unique_ids/unique_ids_{video}.pkl", "wb") as f:
            pickle.dump(unique_ids_with_brightest_frame, f)