In [None]:
import cv2
import csv
import numpy as np
import os
import pickle
from ultralytics import YOLO  
from sort import *   

In [None]:
# ---------- Helper Functions ----------

def load_histogram(file_path):
    """
    Return the object stored in the pickle file at ``file_path``.

    Used to load the reference histogram. NOTE: ``pickle.load`` can execute
    arbitrary code while unpickling, so only pass files you created or trust.
    """
    with open(file_path, 'rb') as handle:
        return pickle.load(handle)
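## The function below computes a simple 3-channel colour histogram (8 bins per channel).
## NOTE: both calculate_histogram variants in this cell are commented out, yet
## process_video_with_sort calls calculate_histogram. Uncomment the variant that
## matches the reference histogram file; otherwise a NameError is raised.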
# def calculate_histogram(image, bbox):
#     """Calculate and normalize the histogram for a region in the image."""
#     height, width, _ = image.shape
#     x_center, y_center, box_width, box_height = bbox

#     x_min = int((x_center - box_width / 2) * width)
#     x_max = int((x_center + box_width / 2) * width)
#     y_min = int((y_center - box_height / 2) * height)
#     y_max = int((y_center + box_height / 2) * height)

#     roi = image[y_min:y_max, x_min:x_max]
#     histogram = cv2.calcHist([roi], [0, 1, 2], None, [8, 8, 8], [0, 256, 0, 256, 0, 256])
#     histogram = cv2.normalize(histogram, histogram).flatten()
#     return histogram
## The function below computes a 256-bin grayscale histogram.
# def calculate_histogram(image, bbox):
#     """Calculate and normalize the grayscale histogram for a region in the image."""
#     height, width, _ = image.shape
#     x_center, y_center, box_width, box_height = bbox

#     # Convert normalized coordinates to pixel coordinates
#     x_min = int((x_center - box_width / 2) * width)
#     x_max = int((x_center + box_width / 2) * width)
#     y_min = int((y_center - box_height / 2) * height)
#     y_max = int((y_center + box_height / 2) * height)

#     # Crop the region of interest
#     roi = image[y_min:y_max, x_min:x_max]

#     # Convert to grayscale
#     gray_roi = cv2.cvtColor(roi, cv2.COLOR_BGR2GRAY)

#     # Compute histogram
#     histogram = cv2.calcHist([gray_roi], [0], None, [256], [0, 256])
#     histogram = cv2.normalize(histogram, histogram).flatten()
#     return histogram

def compare_histograms(frame_histogram, reference_histogram):
    """
    Return the similarity of two histograms as a single float.

    The score comes from OpenCV's correlation metric (``cv2.HISTCMP_CORREL``).
    Per the OpenCV documentation, higher means more similar and 1.0 is a perfect match.
    Both histograms must have the same size and type for ``cv2.compareHist``.
    """
    method = cv2.HISTCMP_CORREL
    return cv2.compareHist(frame_histogram, reference_histogram, method)

def draw_arrow(frame, x1, y1, x2, y2):
    """
    Paint a solid red arrow above a bounding box, pointing down at its top edge.

    The arrow is centred horizontally on the box and starts 150 px above ``y1``.
    A 40 px wide shaft runs down 100 px. A triangular head, 100 px wide at its
    base, then narrows to a tip at ``(center_x, y1)``. ``frame`` is drawn on
    in place.
    """
    mid_x = (x1 + x2) // 2
    shaft_start_y = y1 - 150          # top of the shaft
    head_base_y = shaft_start_y + 100 # where the shaft meets the head

    head = np.array(
        [(mid_x, y1), (mid_x - 50, head_base_y), (mid_x + 50, head_base_y)],
        np.int32,
    )
    shaft = np.array(
        [
            (mid_x - 20, shaft_start_y),
            (mid_x - 20, head_base_y),
            (mid_x + 20, head_base_y),
            (mid_x + 20, shaft_start_y),
        ],
        np.int32,
    )

    red = (0, 0, 255)  # OpenCV colour tuples are BGR
    for polygon in (head, shaft):
        cv2.fillPoly(frame, [polygon], red)

def is_within_path(car_center_x, car_center_y, path_bbox, frame_shape):
    """
    Return True if the point (car_center_x, car_center_y) lies inside ``path_bbox``.

    Args:
        car_center_x, car_center_y: pixel coordinates of the point to test.
        path_bbox: normalised ``(x_center, y_center, width, height)`` box. Each
            value is a fraction of the frame width/height.
        frame_shape: image shape. Only the first two entries (height, width)
            are read, so both ``(h, w)`` and ``(h, w, channels)`` are accepted.

    The box edges are truncated to integer pixels and are inclusive.
    """
    # Slice instead of 3-way unpacking so single-channel frame shapes also work.
    height, width = frame_shape[:2]
    x_center, y_center, box_width, box_height = path_bbox

    x_min = int((x_center - box_width / 2) * width)
    x_max = int((x_center + box_width / 2) * width)
    y_min = int((y_center - box_height / 2) * height)
    y_max = int((y_center + box_height / 2) * height)

    return x_min <= car_center_x <= x_max and y_min <= car_center_y <= y_max

def point_side(point, line_start, line_end):
    """
    Return the 2-D cross product of (line_start -> point) and (line_start -> line_end).

    The sign says which side of the directed line through ``line_start`` and
    ``line_end`` the point lies on. Opposite sides give opposite signs, and
    zero means the point is on the line.
    """
    rel_x = point[0] - line_start[0]
    rel_y = point[1] - line_start[1]
    dir_x = line_end[0] - line_start[0]
    dir_y = line_end[1] - line_start[1]
    return rel_x * dir_y - rel_y * dir_x

# ---------- Integrated Processing Function Using SORT ----------
def save_tracking_info_to_csv(processed_frame, track_id, x1, y1, x2, y2, csv_writer):
    """
    Append one ``[frame, id, x1, y1, x2, y2]`` row through ``csv_writer``.

    The box coordinates are truncated with ``int``. ``processed_frame`` and
    ``track_id`` are written as given.
    """
    corners = [int(coord) for coord in (x1, y1, x2, y2)]
    csv_writer.writerow([processed_frame, track_id, *corners])

def _bbox_to_pixels(norm_bbox, frame_shape):
    """Convert a normalised (x_center, y_center, w, h) box to integer (x_min, y_min, x_max, y_max)."""
    height, width = frame_shape[:2]
    x_center, y_center, box_width, box_height = norm_bbox
    return (
        int((x_center - box_width / 2) * width),
        int((y_center - box_height / 2) * height),
        int((x_center + box_width / 2) * width),
        int((y_center + box_height / 2) * height),
    )


def process_video_with_sort(input_video_path, reference_histogram_file, bbox, path_bbox, output_video_path,
                            upper_line_start, upper_line_end, lower_line_start, lower_line_end,
                            threshold, save_frames=False, temp_frame_folder=None, csv_output_path=None,
                            *, model_path="yolov8n.pt", target_class_ids=(2,), verbose=True):
    """
    Detect and track signal violations in a video with YOLO + SORT.

    For every frame:
      - Compute a histogram of the ``bbox`` region with ``calculate_histogram``.
        Compare it with the reference histogram using ``compare_histograms``.
      - If the score exceeds ``threshold`` (treated as "signal is Red"), run
        YOLO and keep detections whose class id is in ``target_class_ids``.
        Update the SORT tracker, then check each track whose centre lies
        inside ``path_bbox``:
          * the track is registered ("entered") the first time its centre is on
            the positive side of the lower line (see ``point_side``);
          * a registered track is marked "violated" once its centre is on the
            positive side of the upper line.
        Violating tracks get a red box and arrow, plus a CSV row when
        ``csv_output_path`` is given.
      - Only Red frames are written to the output video. The CSV "Frame"
        column is the frame index within that cut video.

    Args:
        input_video_path: video file to read.
        reference_histogram_file: pickle file holding the reference histogram.
        bbox: normalised ``(x_center, y_center, w, h)`` region used for the histogram.
        path_bbox: normalised region in which tracked cars are evaluated.
        output_video_path: where the cut (Red-only) video is written (mp4v).
        upper_line_start, upper_line_end: pixel endpoints of the violation line.
        lower_line_start, lower_line_end: pixel endpoints of the entry line.
        threshold: similarity score above which the signal counts as Red.
        save_frames, temp_frame_folder: accepted for backward compatibility;
            currently unused.
        csv_output_path: optional CSV path for violation rows
            ``Frame, ID, x1, y1, x2, y2``.
        model_path: weights passed to ``YOLO``.
        target_class_ids: detection class ids to track. 2 is "car" in the COCO
            class list of the stock YOLOv8 weights.
        verbose: print the per-frame similarity score.

    Raises:
        OSError: if the input video cannot be opened.

    NOTE(review): ``calculate_histogram`` must exist at module level when this
    runs; it is looked up at call time and is not defined by this function.
    """
    reference_histogram = load_histogram(reference_histogram_file)
    model = YOLO(model_path)
    tracker = Sort()
    wanted_classes = {int(c) for c in target_class_ids}

    video_cap = cv2.VideoCapture(input_video_path)
    if not video_cap.isOpened():
        video_cap.release()
        raise OSError(f"Cannot open input video: {input_video_path}")

    # CAP_PROP_FPS can be fractional (e.g. 29.97). Truncating it to int makes
    # the output play at the wrong rate, so pass it through as a float.
    fps = video_cap.get(cv2.CAP_PROP_FPS)
    frame_width = int(video_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(video_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    video_writer = cv2.VideoWriter(output_video_path, fourcc, fps, (frame_width, frame_height))

    csv_file = None
    csv_writer = None
    track_states = {}          # track_id -> {"entered": bool, "violated": bool}
    processed_frame_count = 0  # frame index in the cut (output) video

    try:
        if csv_output_path:
            csv_file = open(csv_output_path, mode='w', newline='')
            csv_writer = csv.writer(csv_file)
            csv_writer.writerow(['Frame', 'ID', 'x1', 'y1', 'x2', 'y2'])  # CSV header

        while True:
            ret, frame = video_cap.read()
            if not ret:
                break

            frame_histogram = calculate_histogram(frame, bbox)
            score = compare_histograms(frame_histogram, reference_histogram)
            if verbose:
                print(score)

            x_min, y_min, x_max, y_max = _bbox_to_pixels(bbox, frame.shape)
            cv2.rectangle(frame, (x_min, y_min), (x_max, y_max), (255, 255, 255), 2)
            cv2.putText(frame, f"Score: {score:.2f}", (x_min, y_max + 20),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2)

            # Written as `not >` (not `<=`) so a NaN score is also treated as "not Red".
            if not score > threshold:
                continue  # frame is dropped from the output video

            results = model(frame)
            detections_list = []
            for box in results[0].boxes.data.cpu().numpy():
                x1, y1, x2, y2, conf, cls = box[:6]
                if int(cls) in wanted_classes:
                    detections_list.append([x1, y1, x2, y2, conf])

            # SORT expects an (N, 5) array even when there are no detections.
            detections = np.array(detections_list) if detections_list else np.empty((0, 5))
            tracked_objects = tracker.update(detections)

            for x1, y1, x2, y2, track_id in tracked_objects:
                track_id = int(track_id)
                center = (int((x1 + x2) / 2), int((y1 + y2) / 2))

                if not is_within_path(center[0], center[1], path_bbox, frame.shape):
                    continue

                state = track_states.get(track_id)
                if state is None:
                    if point_side(center, lower_line_start, lower_line_end) > 0:
                        track_states[track_id] = {"entered": True, "violated": False}
                elif state["entered"] and point_side(center, upper_line_start, upper_line_end) > 0:
                    state["violated"] = True

                # Default must be False. A track that never crossed the lower
                # line has no state and is not a violator; the old default of
                # True flagged every such car.
                if not track_states.get(track_id, {}).get("violated", False):
                    continue

                cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (0, 0, 255), 2)
                draw_arrow(frame, int(x1), int(y1), int(x2), int(y2))

                if csv_writer is not None:
                    save_tracking_info_to_csv(processed_frame_count, track_id, x1, y1, x2, y2, csv_writer)

            video_writer.write(frame)
            processed_frame_count += 1  # only counts frames written to the output
    finally:
        # Release resources even if detection/tracking raises midway.
        video_cap.release()
        video_writer.release()
        if csv_file is not None:
            csv_file.close()

#---------------------------------Input Canal----------------------------------------
# Configuration for the canal clip.
# Box tuples are normalised (x_center, y_center, width, height); line endpoints are pixel (x, y).
input_video_path = '/content/drive/MyDrive/Input_Data/vid1_Canal.mp4'
reference_histogram_file = '/content/drive/MyDrive/Histogram/Red_colored_hist_canal.pkl'
bbox = (0.246313, 0.375587, 0.041298, 0.154430)       # region sampled for the signal histogram
path_bbox = (0.720047, 0.827044, 0.559906, 0.345912)  # region in which tracked cars are evaluated
lower_line_start, lower_line_end = (696, 1069), (1916, 1064)
upper_line_start, upper_line_end = (657, 953), (1920, 959)
output_video_path = '/content/drive/MyDrive/Resulted_violation/output_video_canal.mp4'

# Run detection/tracking on the clip configured above.
process_video_with_sort(
    input_video_path=input_video_path,
    reference_histogram_file=reference_histogram_file,
    bbox=bbox,
    path_bbox=path_bbox,
    output_video_path=output_video_path,
    upper_line_start=upper_line_start,
    upper_line_end=upper_line_end,
    lower_line_start=lower_line_start,
    lower_line_end=lower_line_end,
    threshold=0.6,
    csv_output_path="/content/drive/MyDrive/Resulted_violation/tracking_info_canal_1.csv",
)
#-------------------------------- Input Faisal Town -----------------------------------
# input_video_path = '/content/drive/MyDrive/Input_Data/vid5_27_7_FaisalTown.mp4'  # Input video path
# reference_histogram_file = '/content/drive/MyDrive/Histogram/GreyScale_histogram.pkl'
# bbox = (0.485335, 0.578670, 0.009373, 0.043541)  # Normalized coordinates for histogram calculation
# path_bbox = (0.901061, 0.730515, 0.197877, 0.199434)  # Normalized coordinates for car detection path
# lower_line_start = (1912, 842)
# lower_line_end   = (1332, 1061)
# upper_line_start = (1389, 699)
# upper_line_end   = (1859, 666)
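# NOTE: the path below is the same as the canal run's output. Change it before using this
# Faisal Town config, or the canal result will be overwritten.
#output_video_path = '/content/drive/MyDrive/Resulted_violation/output_video_canal.mp4'  # Output video path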
#--------------------------------Input Canal---------------------------------------------
# input_video_path = '/content/drive/MyDrive/Input_Data/vid1_Canal.mp4'
# reference_histogram_file = '/content/drive/MyDrive/Histogram/Red_colored_hist_canal.pkl'
# bbox =(0.246313, 0.375587, 0.041298, 0.154430)  # Normalized coordinates
# path_bbox = (0.720047, 0.827044, 0.559906, 0.345912)  # New specified path coordinate
# lower_line_start = (696,1069)
# lower_line_end   = (1916,1064)
# upper_line_start = (657,953)
# upper_line_end   = (1920,959)
#output_video_path = '/content/drive/MyDrive/Resulted_violation/output_video_canal.mp4'  # Output video path