# Final Code

# In [None]:


import cv2
import numpy as np
from ultralytics import YOLO
import time
import math
from tracker import Tracker

# Load YOLO model
model = YOLO('yolo12n.pt')  # Update with correct model path

# Video file paths, one per approach. The list order is the order in which the
# approaches receive green (index 0 -> 1 -> 2 -> 3 -> 0) and the position of
# each feed in the 2x2 display grid.
video_paths = [
    'traffic1.mp4',  # North
    'traffic2.mp4',  # East
    'traffic3.mp4',  # South
    'traffic4.mp4'   # West
]

# Open every video up front and abort if any of them cannot be opened.
videos = []
for path in video_paths:
    cap = cv2.VideoCapture(path)
    if not cap.isOpened():
        print(f"Error: Could not open video {path}")
        # Release whatever was already opened before bailing out; use
        # SystemExit rather than the interactive-only `exit()` builtin.
        cap.release()
        for opened in videos:
            opened.release()
        raise SystemExit(1)
    videos.append(cap)

# Vehicle classes to keep from the detector output. Ultralytics COCO models use
# the 80-class, 0-based index: 1=bicycle, 2=car, 3=motorcycle, 5=bus, 7=truck
# (6 is 'train' and 8 is 'boat' in that scheme). class_list[k] is the display
# name for valid_class_ids[k].
class_list = ['bicycle', 'car', 'motorcycle', 'bus', 'truck']
valid_class_ids = [1, 2, 3, 5, 7]  # COCO-80 indices: bicycle, car, motorcycle, bus, truck

# Object tracker read by calculate_max_distance_and_draw().
tracker = Tracker()

# Per-approach signal state, index-aligned with video_paths:
#   0 = red, 1 = green, 2 = yellow.
# The starting approach (index 0, North) opens on green with a 60 s budget;
# the remaining timers start at 0 and are filled in by update_timers().
current_frame = 0
signal_states = [0, 0, 0, 0]
signal_states[current_frame] = 1
timers = [0, 0, 0, 0]
timers[current_frame] = 60
last_switch_time = time.time()  # reference point for elapsed-time accounting

def draw_signals():
    """Render the four signal lights into a single image.

    The canvas is 800 px wide and 400 px tall. It is split into a 2x2 grid of
    400x200 tiles: index 0 top-left, 1 top-right, 2 bottom-left, 3 bottom-right.
    Each tile shows a filled circle coloured by ``signal_states[idx]``. Below
    the circle is the remaining time from ``timers[idx]``, truncated to whole
    seconds.

    Returns:
        numpy.ndarray: uint8 image of shape (400, 800, 3).
    """
    palette = ((0, 0, 255), (0, 255, 0), (0, 255, 255))  # Red, Green, Yellow
    canvas = np.zeros((400, 800, 3), dtype=np.uint8)
    for idx in range(4):
        row, col = divmod(idx, 2)
        left, top = col * 400, row * 200
        cv2.circle(canvas, (left + 50, top + 50), 40, palette[signal_states[idx]], -1)
        label = f"{int(timers[idx])}s"
        cv2.putText(canvas, label, (left + 20, top + 150),
                    cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
    return canvas

def calculate_max_distance_and_draw(frame, scan=True):
    """Detect vehicles in *frame*, draw tracked boxes, and return the queue depth.

    A detection's "distance" is how far its top edge sits above the bottom of
    the frame (``frame.shape[0] - y1``). The largest value over all kept
    detections is returned so the signal controller can size green phases.

    Args:
        frame: image array as read/resized by OpenCV; boxes and labels are
            drawn onto it in place.
        scan: when False, skip detection entirely and report a distance of 0.

    Returns:
        tuple: ``(frame, max_dist)``. ``max_dist`` is an int, and is 0 when
        nothing was kept or scanning was skipped.
    """
    if not scan:
        return frame, 0

    results = model.predict(frame, verbose=False)
    frame_height = frame.shape[0]
    max_dist = 0
    objects_rect = []
    # Class name of each kept detection, keyed by its top-left corner. This
    # lets every tracked box get its own class, instead of whichever vehicle
    # class happened to appear first in the frame.
    corner_labels = {}

    for r in results:
        boxes = r.boxes.xyxy.cpu().numpy()
        classes = r.boxes.cls.cpu().numpy()
        for box, cls in zip(boxes, classes):
            cls = int(cls)
            if cls not in valid_class_ids:
                continue

            x1, y1, x2, y2 = map(int, box)
            objects_rect.append([x1, y1, x2 - x1, y2 - y1])
            corner_labels[(x1, y1)] = class_list[valid_class_ids.index(cls)]

            max_dist = max(max_dist, frame_height - y1)

    tracked_objects = tracker.update(objects_rect)

    for x, y, w, h, obj_id in tracked_objects:
        # Assumes the tracker echoes back the input rect's corner unchanged
        # (the original code matched on the same x/y). Fall back to a generic
        # name rather than a stale one. TODO: confirm against tracker.py.
        class_name = corner_labels.get((x, y), 'vehicle')
        label = f"{class_name} ID:{obj_id}"
        cv2.rectangle(frame, (x, y), (x + w, y + h), (0, 255, 0), 2)
        cv2.putText(frame, label, (x, y - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)

    return frame, max_dist

# Fixed yellow (clearance) interval, in seconds.
_YELLOW_SECONDS = 5


def _green_duration(max_dist):
    """Return the green time in seconds for an approach whose queue reaches *max_dist* px.

    Queues deeper than 250 px get the 60 s cap. Otherwise each full 25 px of
    queue earns 6 s, with a 6 s minimum.
    """
    if max_dist > 250:
        return 60
    return max(6, (max_dist // 25) * 6)


def update_timers(max_distances):
    """Advance the signal controller by the wall-clock time since the previous call.

    Each call does three things:

    1. Counts down the green/yellow timer.
    2. Performs at most one transition. Green becomes yellow once its budget is
       spent; yellow becomes red and hands green to the next approach, whose
       green time is sized from its max distance.
    3. Recomputes every red approach's countdown until it receives green.

    Args:
        max_distances: per-approach queue depth in pixels, index-aligned with
            ``signal_states``.

    Side effects:
        Mutates the module-level ``timers``, ``signal_states``,
        ``current_frame`` and ``last_switch_time``, and prints a debug line.
    """
    global timers, signal_states, current_frame, last_switch_time
    current_time = time.time()
    elapsed = current_time - last_switch_time

    # Count down whichever approach is green or yellow; red timers are
    # recomputed from scratch below.
    for i in range(4):
        if signal_states[i] in (1, 2):  # Green or Yellow
            timers[i] = max(0, timers[i] - elapsed)

    # Handle state transitions (at most one per call).
    if signal_states[current_frame] == 1 and timers[current_frame] <= 0:
        signal_states[current_frame] = 2  # Green to Yellow
        timers[current_frame] = _YELLOW_SECONDS
    elif signal_states[current_frame] == 2 and timers[current_frame] <= 0:
        signal_states[current_frame] = 0  # Yellow to Red
        current_frame = (current_frame + 1) % 4  # Hand over to the next approach
        signal_states[current_frame] = 1  # Next approach to Green
        timers[current_frame] = _green_duration(max_distances[current_frame])

    # Time until the current approach releases the right of way: the remaining
    # green plus a full yellow, or just the remaining yellow.
    total_cycle_time = 0
    for i in range(4):
        if signal_states[i] == 1:  # Green
            total_cycle_time += timers[i] + _YELLOW_SECONDS
        elif signal_states[i] == 2:  # Yellow
            total_cycle_time += timers[i]

    # Each red approach waits for the current phase to finish, plus a full
    # green + yellow for every approach strictly between the current one and
    # itself in the rotation. Those greens are estimated from the latest
    # max_distances.
    for i in range(1, 4):
        idx = (current_frame + i) % 4
        if signal_states[idx] == 0:  # Red
            wait_time = total_cycle_time
            for j in range(1, i):
                prev_idx = (current_frame + j) % 4
                wait_time += _green_duration(max_distances[prev_idx]) + _YELLOW_SECONDS
            timers[idx] = wait_time

    # Debug output with integer timers
    print(f"Frame: {current_frame}, States: {signal_states}, Timers: {[int(t) for t in timers]}")
    last_switch_time = current_time

# One tracker per camera. calculate_max_distance_and_draw() reads the
# module-level name `tracker`, so that name is pointed at the current feed's
# tracker before each call. A single shared tracker would try to associate
# detections from different cameras with each other and scramble the IDs.
video_trackers = [Tracker() for _ in video_paths]

try:
    while True:
        frames = []
        max_distances = []
        for i, cap in enumerate(videos):
            ret, frame = cap.read()
            if not ret:
                # End of file: rewind once so the feeds play in a loop.
                print(f"Warning: Failed to read frame from video {video_paths[i]}. Attempting to loop.")
                cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
                ret, frame = cap.read()
                if not ret:
                    print(f"Error: Could not loop video {video_paths[i]}. Exiting.")
                    raise SystemExit(1)

            if frame is None or frame.size == 0:
                print(f"Error: Invalid frame from video {video_paths[i]}. Exiting.")
                raise SystemExit(1)

            # Fixed tile size so the four feeds can be stacked into a 2x2 grid.
            frame = cv2.resize(frame, (400, 300))

            # The approach currently on green is not scanned (reports 0).
            scan = signal_states[i] != 1
            tracker = video_trackers[i]
            frame, max_dist = calculate_max_distance_and_draw(frame, scan)
            max_distances.append(max_dist)

            cv2.putText(frame, f"Max Dist: {max_dist:.0f}px", (10, 30),
                        cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
            frames.append(frame)

        update_timers(max_distances)

        top_row = np.hstack((frames[0], frames[1]))
        bottom_row = np.hstack((frames[2], frames[3]))
        video_grid = np.vstack((top_row, bottom_row))

        signal_grid = draw_signals()

        cv2.imshow('Traffic Videos', video_grid)
        cv2.imshow('Traffic Signals', signal_grid)

        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
finally:
    # Cleanup runs on normal quit, on error exits (SystemExit) and on Ctrl-C.
    for cap in videos:
        cap.release()
    cv2.destroyAllWindows()