In [None]:
import cv2
import numpy as np
import onnxruntime as ort
import time
from collections import Counter, defaultdict
import config as cfg
from detection import preprocess_image, process_output
from classification import predict_from_crop
from tracker_module import ObjectTracker

# --- setup sessions, video, tracker, etc. ---
detection_session = ort.InferenceSession(cfg.detection_model_path, providers=cfg.providers)
input_name = detection_session.get_inputs()[0].name
output_names = [o.name for o in detection_session.get_outputs()]
# NOTE(review): if the detector input is NCHW (the usual ONNX layout — confirm),
# shape[2:] is (H, W), so `iw` holds the height and `ih` the width, and
# preprocess_image receives (H, W). This is harmless for square inputs; check the
# order preprocess_image expects before switching to a non-square model.
iw, ih = detection_session.get_inputs()[0].shape[2:]

classification_session = ort.InferenceSession(cfg.classification_model_path, providers=cfg.providers)

cap = cv2.VideoCapture(cfg.video_path)
if not cap.isOpened():
    # Without this check, w/h/fps below silently become 0 and nothing is written.
    raise RuntimeError(f"Could not open video source: {cfg.video_path}")
w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

# Define codec and create VideoWriter object
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
raw_fps = cap.get(cv2.CAP_PROP_FPS)
fps = int(raw_fps)  # source frame rate, truncated (kept for existing users of `fps`)
# The processing loop handles (and writes) only every 2nd frame, so the output is
# written at half the source rate to keep its duration equal to the input's.
# Some streams report 0 FPS; 30.0 is an assumed fallback so the writer still opens.
out_fps = (raw_fps if raw_fps > 0 else 30.0) / 2
out = cv2.VideoWriter('customer classification IN OUT.mp4', fourcc, out_fps, (w, h))
if not out.isOpened():
    cap.release()
    raise RuntimeError("Could not open VideoWriter for 'customer classification IN OUT.mp4'")

diag = np.hypot(w, h)
# Max centroid jump for association is 20% of the frame diagonal.
tracker = ObjectTracker(iou_threshold=0.3, max_dist=0.2 * diag, expire_after=24, BASE_IOU_W=0.8, BASE_DIR_W=0.2)

# --- counting infrastructure ---
track_labels = defaultdict(list)    # tid -> [labels...]
track_prev_cx = {}                  # tid -> last cx
counted_ids = set()                 # tids already counted
IN_count = 0

# Counting zone: x >= start_line_x and start_line_y1 <= y <= start_line_y2.
# A track is counted when its centre crosses IN_line from left to right.
start_line_x = int(w // 8)
start_line_y1 = int(1 * h // 5)
start_line_y2 = int(4 * h // 5)
IN_line = int(w // 2)

frame_count = 0

# Main loop: detect -> track -> classify -> count IN crossings -> render/write.
# The capture, writer and window are released even if a frame raises.
try:
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break

        frame_count += 1
        # Process only even-numbered frames (halves inference cost); skipped
        # frames are not written to the output either.
        if frame_count % 2 != 0:
            continue

        # ── detection ─────────────────────────────────────────────────
        blob, ratio, orig_shape = preprocess_image(frame, (iw, ih))
        outs = detection_session.run(output_names, {input_name: blob})
        dets = process_output(outs, conf_threshold=0.4, nms_threshold=0.5, img_shape=orig_shape, ratio=ratio)

        # ── tracking ──────────────────────────────────────────────────
        active = tracker.update([d['box'] for d in dets], frame_count)

        # Crops must come from an untouched copy: the boxes/labels drawn below are
        # painted onto `frame` in place and would otherwise leak into the crops of
        # tracks handled later in the same frame.
        clean_frame = frame.copy()
        frame_h, frame_w = frame.shape[:2]

        for tid, box in active.items():
            x1, y1, x2, y2 = map(int, box)
            cx, cy = (x1 + x2) // 2, (y1 + y2) // 2

            # skip if still left of the start line OR cy is outside [start_line_y1, start_line_y2]
            if cx < start_line_x or cy < start_line_y1 or cy > start_line_y2:
                track_prev_cx[tid] = cx
                continue

            # crop & classify; clamp to the frame first, because negative slice
            # indices would wrap around to the opposite edge of the image.
            crop_x1, crop_y1 = max(x1, 0), max(y1, 0)
            crop_x2, crop_y2 = min(x2, frame_w), min(y2, frame_h)
            crop = clean_frame[crop_y1:crop_y2, crop_x1:crop_x2]
            if crop.size == 0:
                continue
            label = predict_from_crop(classification_session, crop)
            track_labels[tid].append(label)

            # draw box + label + ID
            cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
            cv2.putText(frame, f"{label}", (x1, y1 - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)
            cv2.putText(frame, f"ID {tid}", (x2 - 15, y1 - 8),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.55, (0, 255, 0), 2)

            # check crossing left→right over IN_line (each track counted at most once)
            prev_cx = track_prev_cx.get(tid)
            if prev_cx is not None and prev_cx < IN_line <= cx and tid not in counted_ids:
                # majority vote over every label seen for this track so far
                mode_label = Counter(track_labels[tid]).most_common(1)[0][0]
                if mode_label == 'customer':
                    IN_count += 1
                counted_ids.add(tid)

            # update prev position
            track_prev_cx[tid] = cx

        # Vertical red start line (between start_line_y1 and start_line_y2)
        cv2.line(frame, (start_line_x, start_line_y1), (start_line_x, start_line_y2), (0, 0, 255), 2)

        # Horizontal top red line from start_line_x to IN_line
        cv2.line(frame, (start_line_x, start_line_y1), (IN_line, start_line_y1), (0, 0, 255), 2)

        # Horizontal bottom red line from start_line_x to IN_line
        cv2.line(frame, (start_line_x, start_line_y2), (IN_line, start_line_y2), (0, 0, 255), 2)

        # ── IN LINE (green, between start_line_y1 and start_line_y2) ─────────
        cv2.line(frame, (IN_line, start_line_y1), (IN_line, start_line_y2), (0, 255, 0), 3)

        cv2.putText(frame, f"IN: {IN_count}", (10, 30),
                    cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0, 255, 0), 2)

        cv2.imshow('People Detection with Count', frame)
        out.write(frame)  # Write the processed frame to output file
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
finally:
    cap.release()
    out.release()
    cv2.destroyAllWindows()

[0;93m2025-05-16 11:19:49.671826 [W:onnxruntime:, coreml_execution_provider.cc:112 GetCapability] CoreMLExecutionProvider::GetCapability, number of partitions supported by CoreML: 5 number of nodes in the graph: 320 number of nodes supported by CoreML: 315[m
[0;93m2025-05-16 11:19:51.060373 [W:onnxruntime:, session_state.cc:1263 VerifyEachNodeIsAssignedToAnEp] Some nodes were not assigned to the preferred execution providers which may or may not have an negative impact on performance. e.g. ORT explicitly assigns shape related ops to CPU to improve perf.[m
[0;93m2025-05-16 11:19:51.060380 [W:onnxruntime:, session_state.cc:1265 VerifyEachNodeIsAssignedToAnEp] Rerunning with verbose output on a non-minimal build will show node assignments.[m
[0;93m2025-05-16 11:19:51.073744 [W:onnxruntime:, coreml_execution_provider.cc:112 GetCapability] CoreMLExecutionProvider::GetCapability, number of partitions supported by CoreML: 3 number of nodes in the graph: 394 number of nodes supported by