In [2]:
# Import necessary libraries
from mysql.connector import Error
import matplotlib.pyplot as plt
from collections import deque
import mysql.connector
from PIL import Image
import numpy as np
import pytesseract
import pytesseract
import easyocr
import cv2
import os
import re

In [3]:
# Database connection settings.
# Each value can be overridden through an environment variable of the same name so
# credentials do not have to live in the notebook. The literals below are only
# fallbacks kept for backward compatibility.
# TODO: remove the hard-coded fallbacks (especially the password) once every
# environment provides these variables.
DB_HOST = os.environ.get('DB_HOST', 'localhost')                # Hostname or IP address of the database server
DB_USER = os.environ.get('DB_USER', 'root')                     # Database username
DB_PASSWORD = os.environ.get('DB_PASSWORD', 'adminroot')        # Database password
DB_NAME = os.environ.get('DB_NAME', 'VisionComputador')         # Name of the database to connect to

In [4]:
def detect_traffic_light_color(image, rect, min_pixels=1):
    """
    Classify the traffic light inside a rectangle and overlay the status on the image.

    Parameters:
    - image: The input image (BGR format). It is annotated in place.
    - rect: A tuple (x, y, w, h) describing the region of interest (ROI) that
      contains the traffic light.
    - min_pixels: Minimum number of matching pixels required to report 'red' or
      'yellow'. The default of 1 means "any matching pixel".

    Returns:
    - image: The same image object, with the detected status text drawn on it.
    - color: 'red', 'yellow' or 'green'. 'green' is returned whenever neither red
      nor yellow pixels are found (this includes a light that is off).

    Raises:
    - ValueError: if ``rect`` does not overlap the image (empty ROI).
    """
    x, y, w, h = rect
    roi = image[y:y + h, x:x + w]
    if roi.size == 0:
        # cvtColor would otherwise fail with an opaque OpenCV assertion.
        raise ValueError(f"Traffic light rectangle {rect} does not overlap the image")

    hsv = cv2.cvtColor(roi, cv2.COLOR_BGR2HSV)

    # 8-bit OpenCV hue runs 0-179 and red sits at BOTH ends of that scale,
    # so red is the union of a low-hue band and a high-hue band.
    red_mask = cv2.bitwise_or(
        cv2.inRange(hsv, np.array([0, 120, 70]), np.array([10, 255, 255])),
        cv2.inRange(hsv, np.array([170, 120, 70]), np.array([179, 255, 255])),
    )
    yellow_mask = cv2.inRange(hsv, np.array([20, 100, 100]), np.array([30, 255, 255]))

    # Font properties for the overlay text
    font = cv2.FONT_HERSHEY_TRIPLEX
    font_scale = 1
    font_thickness = 2

    # Red takes priority over yellow; anything else is treated as green.
    if cv2.countNonZero(red_mask) >= min_pixels:
        text_color = (0, 0, 255)
        message = "Detected Signal Status: Stop"
        color = 'red'
    elif cv2.countNonZero(yellow_mask) >= min_pixels:
        text_color = (0, 255, 255)
        message = "Detected Signal Status: Caution"
        color = 'yellow'
    else:
        text_color = (0, 255, 0)
        message = "Detected Signal Status: Go"
        color = 'green'

    # Status line plus a separator underneath it
    cv2.putText(image, message, (15, 70), font, font_scale + 0.5, text_color, font_thickness + 1, cv2.LINE_AA)
    cv2.putText(image, 34 * '-', (10, 115), font, font_scale, (255, 255, 255), font_thickness, cv2.LINE_AA)

    return image, color

In [5]:
class LineDetector:
    """
    Finds the white stop line inside a fixed region of each frame, smooths its
    position over recent detections, and highlights it in the traffic light color.
    """

    # BGR channels forced to 255 when painting the line for each supported color.
    _HIGHLIGHT_CHANNELS = {
        'red': (2,),
        'green': (1,),
        'yellow': (1, 2),
    }

    def __init__(self, num_frames_avg=10):
        """
        Initializes the LineDetector object.

        Parameters:
        - num_frames_avg: Maximum number of recent line samples kept for averaging.
          Note that every Hough segment found in a frame appends one sample, so this
          bounds the number of recent segments rather than strictly frames.
        """
        # Bounded queues: once full, appending drops the oldest value.
        self.y_start_queue = deque(maxlen=num_frames_avg)  # y of each line at the left image edge
        self.y_end_queue = deque(maxlen=num_frames_avg)    # y of each line at the right image edge

    @staticmethod
    def _clamp_boundary(values, limit):
        """Truncate boundary positions toward zero (like ``int()``) and clamp them to [0, limit]."""
        # Python slicing with a negative bound counts from the END of the axis, which
        # masked the wrong pixels whenever a boundary fell above/left of the frame.
        return np.clip(np.trunc(values), 0, limit)

    def detect_white_line(self, frame, color,
                          slope1=0.03, intercept1=920, slope2=0.03, intercept2=770, slope3=-0.8, intercept3=2420):
        """
        Detects the white stop line in the given frame and highlights it in the specified color.

        Parameters:
        - frame: Input BGR image/frame. The highlighted line is drawn on it in place.
        - color: Highlight color, 'red', 'green' or 'yellow' (case-insensitive).
        - slope1, intercept1: Line y = slope1 * x + intercept1. Pixels on or below it
          are excluded, i.e. it is the bottom edge of the search region.
        - slope2, intercept2: Line y = slope2 * x + intercept2. Pixels above it are
          excluded, i.e. it is the top edge of the search region.
        - slope3, intercept3: Line x = slope3 * y + intercept3. Pixels left of it are
          excluded, i.e. it is the left edge of the search region.

        Returns:
        - frame: The same frame object with the averaged line highlighted.
        - mask_line: A copy of the original (un-highlighted) frame in which every pixel
          more than 35 px above the averaged line is set to black.

        Raises:
        - ValueError: if ``color`` is not one of the supported colors.
        """
        channel_indices = self._HIGHLIGHT_CHANNELS.get(color.lower() if isinstance(color, str) else None)
        if channel_indices is None:
            raise ValueError('Unsupported color')

        frame_org = frame.copy()  # untouched copy used to build mask_line
        height, width = frame.shape[:2]
        rows = np.arange(height)[:, None]  # (height, 1) row indices
        cols = np.arange(width)[None, :]   # (1, width) column indices

        # Search region: at/below line2, above line1 and at/right of line3.
        top = self._clamp_boundary(slope2 * cols + intercept2, height)
        bottom = self._clamp_boundary(slope1 * cols + intercept1, height)
        left = self._clamp_boundary(slope3 * rows + intercept3, width)
        inside_region = (rows >= top) & (rows < bottom) & (cols >= left)
        region = frame.copy()
        region[~inside_region] = 0

        gray = cv2.cvtColor(region, cv2.COLOR_BGR2GRAY)

        # NOTE(review): Canny runs on the un-blurred grayscale image. The previous
        # version also computed a blurred + CLAHE image that was never used, so that
        # dead work was removed; the 30/100 thresholds were presumably tuned on `gray`.
        edges = cv2.Canny(gray, 30, 100)
        # Dilate then erode to bridge small gaps in the edge map.
        edges = cv2.erode(cv2.dilate(edges, None, iterations=1), None, iterations=1)

        lines = cv2.HoughLinesP(edges, 1, np.pi / 180, 100, minLineLength=160, maxLineGap=5)

        x_start = 0          # left image edge
        x_end = width - 1    # right image edge

        if lines is not None:
            for line in lines:
                x1, y1, x2, y2 = line[0]
                if x2 == x1:
                    # A vertical segment has no finite y at x_start/x_end; skipping it
                    # keeps it from dominating the running average.
                    continue
                slope = (y2 - y1) / (x2 - x1)
                intercept = y1 - slope * x1
                # Store where the segment's extension crosses both image edges.
                self.y_start_queue.append(int(slope * x_start + intercept))
                self.y_end_queue.append(int(slope * x_end + intercept))

        # Smoothed line position; 0 when nothing has been detected yet.
        avg_y_start = int(sum(self.y_start_queue) / len(self.y_start_queue)) if self.y_start_queue else 0
        avg_y_end = int(sum(self.y_end_queue) / len(self.y_end_queue)) if self.y_end_queue else 0

        # Only draw the right-hand 68% of the averaged line.
        line_start_ratio = 0.32
        x_start_adj = x_start + int(line_start_ratio * (x_end - x_start))
        avg_y_start_adj = avg_y_start + int(line_start_ratio * (avg_y_end - avg_y_start))

        # Rasterize the line into a single-channel mask, then saturate the chosen channels.
        line_pixels = np.zeros((height, width), dtype=np.uint8)
        cv2.line(line_pixels, (x_start_adj, avg_y_start_adj), (x_end, avg_y_end), 255, 4)
        on_line = line_pixels == 255
        for channel_index in channel_indices:
            frame[on_line, channel_index] = 255

        # Black out everything more than `line_margin` px above the averaged line.
        slope_avg = (avg_y_end - avg_y_start) / (x_end - x_start + np.finfo(float).eps)
        intercept_avg = avg_y_start - slope_avg * x_start
        line_margin = 35
        cutoff = self._clamp_boundary(slope_avg * cols + intercept_avg - line_margin, height)
        mask_line = frame_org
        mask_line[rows < cutoff] = 0

        return frame, mask_line

In [6]:
def extract_license_plate(frame, mask_line, cascade=None, width_fraction=0.7):
    """
    Detects license plates in the non-black part of ``mask_line`` and marks them on ``frame``.

    Parameters:
    - frame: Original frame; detected plates are outlined on it in place.
    - mask_line: Frame with regions above the detected line blacked out.
    - cascade: Haar cascade used for detection. Defaults to the notebook-level
      ``license_plate_cascade`` (loaded in a later cell) so existing calls keep working.
    - width_fraction: Fraction of the non-black bounding box width (from the left)
      that is searched; 0.7 drops the rightmost 30%.

    Returns:
    - frame: The same frame with green rectangles drawn around detected plates.
    - license_plate_images: List of cropped plate images. They are taken from the
      CLAHE-enhanced, eroded grayscale image, not from the raw frame.
    """
    if cascade is None:
        cascade = license_plate_cascade

    license_plate_images = []

    # Grayscale + CLAHE for contrast, then a light erosion to reduce noise.
    gray = cv2.cvtColor(mask_line, cv2.COLOR_BGR2GRAY)
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    gray = clahe.apply(gray)
    gray = cv2.erode(gray, np.ones((2, 2), np.uint8), iterations=1)

    # Restrict the search to the bounding box of the non-black pixels.
    non_black_points = cv2.findNonZero(gray)
    if non_black_points is None:
        # Everything is masked out; boundingRect(None) would raise.
        return frame, license_plate_images
    x, y, w, h = cv2.boundingRect(non_black_points)
    w = int(w * width_fraction)
    if w <= 0 or h <= 0:
        return frame, license_plate_images

    cropped_gray = gray[y:y + h, x:x + w]

    license_plates = cascade.detectMultiScale(cropped_gray, scaleFactor=1.07, minNeighbors=15, minSize=(20, 20))

    for (x_plate, y_plate, w_plate, h_plate) in license_plates:
        # Detections are relative to the crop; shift them back to frame coordinates.
        top_left = (x_plate + x, y_plate + y)
        bottom_right = (x_plate + x + w_plate, y_plate + y + h_plate)
        cv2.rectangle(frame, top_left, bottom_right, (0, 255, 0), 3)
        license_plate_images.append(cropped_gray[y_plate:y_plate + h_plate, x_plate:x_plate + w_plate])

    return frame, license_plate_images

In [7]:
def apply_ocr_to_image(license_plate_image):
    """
    Read the characters printed on a cropped license plate.

    The crop is binarised with a fixed threshold of 120 (pixels above it become 255,
    the rest 0), converted to a PIL image and passed to Tesseract with ``--psm 6``.

    Parameters:
    - license_plate_image: Grayscale crop of the license plate.

    Returns:
    - The recognised text with leading and trailing whitespace removed.
    """
    binary_plate = cv2.threshold(license_plate_image, 120, 255, cv2.THRESH_BINARY)[1]
    recognised = pytesseract.image_to_string(Image.fromarray(binary_plate), config='--psm 6')
    return recognised.strip()

In [8]:
def draw_penalized_text(frame, texts=None):
    """
    Draws a "Fined license plates:" heading followed by one line per plate.

    Parameters:
    - frame: The frame where the text will be drawn (modified in place).
    - texts: Iterable of plate strings to list. Defaults to the notebook-level
      ``penalized_texts`` list so existing calls keep working.

    Returns:
    - None.
    """
    if texts is None:
        texts = penalized_texts

    font = cv2.FONT_HERSHEY_TRIPLEX
    font_scale = 1
    font_thickness = 2
    color = (255, 255, 255)  # white

    title_y = 180
    first_item_y = title_y + 80  # gap below the heading
    line_spacing = 60

    cv2.putText(frame, 'Fined license plates:', (25, title_y), font, font_scale, color, font_thickness)

    for index, text in enumerate(texts):
        y_pos = first_item_y + index * line_spacing
        cv2.putText(frame, '->  ' + text, (40, y_pos), font, font_scale, color, font_thickness)

In [9]:
def create_database_and_table(host, user, password, database):
    """
    Creates a MySQL database (if missing) and the ``license_plates`` table inside it.

    Parameters:
    - host: The hostname or IP address of the MySQL server.
    - user: The username for the MySQL connection.
    - password: The password for the MySQL connection.
    - database: The name of the database to create/use.

    Returns:
    - None. Prints progress messages; MySQL errors are printed, not raised.
    """
    # Pre-bind so `finally` works even when connect() itself fails
    # (previously this raised UnboundLocalError and hid the real error).
    connection = None
    try:
        connection = mysql.connector.connect(
            host=host,
            user=user,
            password=password
        )

        if connection.is_connected():
            cursor = connection.cursor()
            try:
                # Identifiers cannot be bound as query parameters, so quote the
                # database name with backticks (doubling any embedded backtick).
                quoted_name = '`' + database.replace('`', '``') + '`'

                cursor.execute(f"CREATE DATABASE IF NOT EXISTS {quoted_name}")
                print(f"Database {database} created successfully!")

                cursor.execute(f"USE {quoted_name}")

                cursor.execute("""
                    CREATE TABLE IF NOT EXISTS license_plates (
                        id INT AUTO_INCREMENT PRIMARY KEY,       -- Unique identifier for each record
                        plate_number VARCHAR(255) NOT NULL UNIQUE,  -- License plate number, must be unique
                        violation_count INT DEFAULT 1            -- Number of violations, default is 1
                    )
                """)
                print("Table created successfully!")
            finally:
                cursor.close()

    except Error as e:
        print("Error while connecting to MySQL", e)

    finally:
        if connection is not None and connection.is_connected():
            connection.close()

In [10]:
def update_database_with_violation(plate_number, host, user, password, database):
    """
    Records one violation for a license plate.

    If the plate already exists its ``violation_count`` is incremented; otherwise a
    new row is inserted (``violation_count`` then takes the column default of 1).

    The plate text comes from OCR, so it is always passed as a bound query
    parameter and never interpolated into the SQL string.

    Parameters:
    - plate_number: The license plate number to update or add.
    - host: The hostname or IP address of the MySQL server.
    - user: The username for the MySQL connection.
    - password: The password for the MySQL connection.
    - database: The name of the database to connect to.

    Returns:
    - None. MySQL errors are printed, not raised.
    """
    # Pre-bind so `finally` works even when connect() itself fails.
    connection = None
    try:
        connection = mysql.connector.connect(
            host=host,
            user=user,
            password=password,
            database=database
        )

        if connection.is_connected():
            cursor = connection.cursor()
            try:
                cursor.execute(
                    "SELECT violation_count FROM license_plates WHERE plate_number = %s",
                    (plate_number,),
                )
                if cursor.fetchone():
                    cursor.execute(
                        "UPDATE license_plates SET violation_count = violation_count + 1 WHERE plate_number = %s",
                        (plate_number,),
                    )
                else:
                    cursor.execute(
                        "INSERT INTO license_plates (plate_number) VALUES (%s)",
                        (plate_number,),
                    )
                connection.commit()
            finally:
                cursor.close()

    except Error as e:
        print("Error while connecting to MySQL", e)

    finally:
        if connection is not None and connection.is_connected():
            connection.close()

In [11]:
def print_all_violations(host, user, password, database):
    """
    Prints every stored license plate with its violation count, highest count first.

    Parameters:
    - host: The hostname or IP address of the MySQL server.
    - user: The username for the MySQL connection.
    - password: The password for the MySQL connection.
    - database: The name of the database to connect to.

    Returns:
    - None. Output goes to stdout; MySQL errors are printed, not raised.
    """
    # Pre-bind so `finally` works even when connect() itself fails.
    connection = None
    try:
        connection = mysql.connector.connect(
            host=host,
            user=user,
            password=password,
            database=database
        )

        if connection.is_connected():
            cursor = connection.cursor()
            try:
                cursor.execute("SELECT plate_number, violation_count FROM license_plates ORDER BY violation_count DESC")
                result = cursor.fetchall()
            finally:
                cursor.close()

            print("\n")
            print("-" * 66)
            print("\nAll Registered Traffic Violations in the Database:\n")
            for plate_number, violation_count in result:
                print(f"Plate Number: {plate_number}, Violations: {violation_count}")

    except Error as e:
        print("Error while connecting to MySQL", e)

    finally:
        if connection is not None and connection.is_connected():
            connection.close()

In [12]:
def clear_license_plates(host, user, password, database):
    """
    Deletes all rows from the ``license_plates`` table.

    Parameters:
    - host: The hostname or IP address of the MySQL server.
    - user: The username for the MySQL connection.
    - password: The password for the MySQL connection.
    - database: The name of the database to connect to.

    Returns:
    - None. MySQL errors are printed, not raised.
    """
    # Pre-bind so `finally` works even when connect() itself fails.
    connection = None
    try:
        connection = mysql.connector.connect(
            host=host,
            user=user,
            password=password,
            database=database
        )

        if connection.is_connected():
            cursor = connection.cursor()
            try:
                cursor.execute("DELETE FROM license_plates")
                connection.commit()
            finally:
                cursor.close()

    except Error as e:
        print("Error while connecting to MySQL", e)

    finally:
        if connection is not None and connection.is_connected():
            connection.close()

In [13]:
def main():
    """
    Runs the full traffic violation detection pipeline on the configured video.

    Steps:
    - Ensures the database/table exist and clears records from previous runs.
    - Reads the video frame by frame, classifies the traffic light and finds the stop line.
    - While the light is red, detects license plates past the line and reads them with OCR.
    - Records each newly fined plate in the database and lists it on screen.
    - Prints all stored violations at the end.
    """
    create_database_and_table(DB_HOST, DB_USER, DB_PASSWORD, DB_NAME)

    # Clear previous run data from the database (optional, can be commented out)
    clear_license_plates(DB_HOST, DB_USER, DB_PASSWORD, DB_NAME)
    # Keep the in-memory list in step with the run; otherwise a second call in the
    # same kernel would skip (and never re-record) plates fined in the first call.
    penalized_texts.clear()

    # Region (x, y, w, h) containing the traffic light
    rect = (1700, 40, 100, 250)
    # rect = (48, 277, 1205, 40)  # Alternate region for a different video (uncomment to use)

    # Two uppercase letters, whitespace, then 3-4 digits
    plate_pattern = re.compile(r"^[A-Z]{2}\s[0-9]{3,4}$")

    video_path = './videos/traffic_video.mp4'
    # video_path = './videos/video1.mp4'  # Alternate video option (uncomment to use)
    vid = cv2.VideoCapture(video_path)
    if not vid.isOpened():
        print(f"Could not open video: {video_path}")

    detector = LineDetector()

    try:
        while True:
            ret, frame = vid.read()
            if not ret:  # end of video (or unreadable file)
                break

            frame, color = detect_traffic_light_color(frame, rect)
            frame, mask_line = detector.detect_white_line(frame, color)

            # Only vehicles past the line on a red light are fined
            if color == 'red':
                frame, license_plate_images = extract_license_plate(frame, mask_line)

                for license_plate_image in license_plate_images:
                    text = apply_ocr_to_image(license_plate_image)

                    if text and plate_pattern.match(text) and text not in penalized_texts:
                        penalized_texts.append(text)
                        print(f"\nFined license plate: {text}")

                        plt.figure()
                        plt.imshow(license_plate_image, cmap='gray')
                        plt.axis('off')
                        plt.show()

                        update_database_with_violation(text, DB_HOST, DB_USER, DB_PASSWORD, DB_NAME)

            if penalized_texts:
                draw_penalized_text(frame)

            cv2.imshow('frame', frame)

            if cv2.waitKey(1) == 27:  # ESC
                break
    finally:
        # Release the capture and windows even if processing raised
        vid.release()
        cv2.destroyAllWindows()

    print_all_violations(DB_HOST, DB_USER, DB_PASSWORD, DB_NAME)

In [14]:
# Load the trained Haar cascade for license plate detection.
CASCADE_PATH = './car_plate_detector.xml'
license_plate_cascade = cv2.CascadeClassifier(CASCADE_PATH)
# The constructor does not raise on a missing/invalid file; the classifier is just
# empty and detectMultiScale fails much later. Fail fast here instead.
if license_plate_cascade.empty():
    raise FileNotFoundError(f"Could not load Haar cascade from {CASCADE_PATH}")

# Unique license plates that have been penalized during the current run
penalized_texts = []

In [None]:
# Entry point: run the whole pipeline when this notebook (or its exported script)
# is executed directly rather than imported.
if "__main__" == __name__:
    main()