In [34]:
import os
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from pathlib import Path
# !pip install ultralytics
from ultralytics import YOLO
import numpy as np
import matplotlib.pyplot as plt


In [45]:
# Only characters present in Spanish license plates (no vowels)
CHAR_CLASSES = "0123456789BCDFGHJKLMNPQRSTVWXYZ"

class LicensePlateDetector:
    def __init__(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = None

    def create_dataset(self, dataset_path, yaml_path="dataset.yaml"):
        """Create a file dataset.yaml for training the model"""
        train_path = os.path.join(dataset_path, 'images', 'train')
        val_path = os.path.join(dataset_path, 'images', 'val')
        test_path = os.path.join(dataset_path, 'images', 'test')

        config_content = f"""
path: {dataset_path}
train: {train_path}
val: {val_path}
test: {test_path}
names:
  0: license_plate"""
        with open(yaml_path, 'w') as file:
            file.write(config_content)
        print(f"Dataset configuration file created at {yaml_path}")
        return yaml_path

    def train_model(self, dataset_path, epochs=10, batch=8, project="license_plate_training"):
        """Train a custom model for license plate detection."""
        
        yaml_path = self.create_dataset(dataset_path)
        
        print("Starting training...")

        # Load base model
        self.model = YOLO("yolov8n.pt")
        # Train the model on the custom dataset
        results = self.model.train(data=yaml_path, epochs=epochs, batch=batch, project=project)
        best_model_path = Path(results.save_dir) / "weights" / "best.pt"
        
        print("Training completed")
        print(f"Best model saved at: {str(best_model_path)}")
        return str(best_model_path)

    def detect_license_plate(self, best_model_path, test_dir, output_dir, conf=0.25):
        """Detect license plates in images using the trained model."""
        self.model = YOLO(best_model_path)
        print(f"Detecting license plates with model {best_model_path}")
        
        prediction_results = self.model.predict(source=test_dir, save=False, conf=conf)
        
        results_dir = os.path.join(output_dir, "results")
        os.makedirs(results_dir, exist_ok=True)
        
        for i, result in enumerate(prediction_results):
            annotated_image = result.plot()
            output_path = os.path.join(results_dir, f"detected_{i}.jpg")
            cv2.imwrite(output_path, annotated_image)
        
        print("Detection completed. Results saved.")
        return prediction_results

    # =========================================================================
    # Character Segmentation Functions 
    # =========================================================================
        
    def preprocess_plate_for_segmentation(self, plate_image):
        """
        Preprocesses a license plate image for character segmentation.
        - Transforms to grayscale and uses HSV to handle the blue section.
        - Applies Otsu's thresholding for binarization.
        """
        # Convert to HSV to better handle the blue band
        hsv = cv2.cvtColor(plate_image, cv2.COLOR_BGR2HSV)
        image_h, image_s, image_v = cv2.split(hsv)
        
        # Define range for blue color and create a mask
        lower_blue = np.array([100, 150, 0])
        upper_blue = np.array([140, 255, 255])
        mask_blue = cv2.inRange(hsv, lower_blue, upper_blue)
        
        # Convert to grayscale
        # gray = cv2.cvtColor(plate_image, cv2.COLOR_BGR2GRAY)
        
        # Paint the blue area white on the grayscale image
        image_v[mask_blue > 0] = 255

        # Apply Gaussian blur to reduce the noise
        image_v = cv2.GaussianBlur(image_v, (5, 5), 0)
        
        # Apply Otsu's thresholding to get a binary image
        _, otsu_thresh = cv2.threshold(image_v, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        
        # Invert the image so characters are white on a black background
        binary = cv2.bitwise_not(otsu_thresh)

        # Morphological operations to clean up
        kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
        # CLOSING: fill small holes in the characters
        binary = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, kernel, iterations=3)
        #OPENING: remove small noise, and separate connected characters
        binary = cv2.morphologyEx(binary, cv2.MORPH_OPEN, kernel, iterations=1)
            
        return binary

    def segment_characters(self, plate_image):
        """
        Finds and segments individual characters from a preprocessed license plate image.
        - Uses cv.findContours to identify character shapes.
        - Filters contours based on size and aspect ratio to isolate characters.
        """
        preprocessed_plate = self.preprocess_plate_for_segmentation(plate_image)        
        contours, _ = cv2.findContours(preprocessed_plate, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        
        characters = []
        # Get shape from preprocessed image which is grayscale (2 dimensions)
        plate_height, plate_width = preprocessed_plate.shape
        
        for contour in contours:
            x, y, w, h = cv2.boundingRect(contour)
            aspect_ratio = w / h
            area = cv2.contourArea(contour)

            # Filter the countours that touch the border of the plate
            # Characters should not be at the edge of the plate
            touches_border = (x==0 or y==0 or (x+w) >= plate_width or (y+h) >= plate_height)
            
            # Heuristic filters to identify a character based on area and aspect ratio.
            # We have removed the position filter to detect the first character.
            if area > 50 and 0.1 < aspect_ratio < 1.2 and h > 10 and w > 3 and not touches_border: #! Aspect ratio expanded to include narrow character as 1
                character_roi = plate_image[y:y+h, x:x+w]
                characters.append({'image': character_roi, 'bbox': (x, y, w, h)})
        
        characters.sort(key=lambda c: c['bbox'][0])

        # Maximum of 7 characters (spanish license plates)
        if len(characters) > 7:
            characters = sorted(characters, key=lambda c: c['bbox'][2]*c['bbox'][3], reverse=True)[:7]
            characters.sort(key=lambda c: c['bbox'][0])

        return characters, preprocessed_plate

    def process_and_show_results(self, prediction_results):
        """Processes each detected plate to segment and show the characters."""
        for i, result in enumerate(prediction_results):
            image = result.orig_img
            boxes = result.boxes
            
            if boxes:
                print(f"\n--- Processing Image {i+1} ---")
                
                for j, box in enumerate(boxes):
                    x1, y1, x2, y2 = map(int, box.xyxy[0].tolist())
                    
                    plate_image = image[y1:y2, x1:x2]
                    
                    if plate_image is not None and plate_image.size > 0:
                        print(f"Detected License Plate {j+1}: Bbox {x1, y1, x2, y2}")
                        
                        characters, preprocessed_plate = self.segment_characters(plate_image)
                        print(f"Found {len(characters)} character candidates.")

                        plate_text = self.ocr_characters(characters)
                        print(f"OCR Result, Detected Plate: {plate_text}")
            
                        # Create output image with bounding boxes
                        output_image = plate_image.copy()
                        for char_info in characters:
                            x, y, w, h = char_info['bbox']
                            cv2.rectangle(output_image, (x, y), (x + w, y + h), (0, 255, 0), 2)
                        
                        # Show images
                        annotated_full_image = result.plot()
                        
                        plt.figure(figsize=(20, 5))
                        
                        # 1. Full image with detected plate
                        plt.subplot(1, 4, 1)
                        plt.imshow(cv2.cvtColor(annotated_full_image, cv2.COLOR_BGR2RGB))
                        plt.title("Licensed Plate Detection")
                        plt.axis('off')
                        
                        # 2. Cropped license plate
                        plt.subplot(1, 4, 2)
                        plt.imshow(cv2.cvtColor(plate_image, cv2.COLOR_BGR2RGB))
                        plt.title("Cropped License Plate")
                        plt.axis('off')
                        
                        # 3. Image preprocessed (binarized)
                        plt.subplot(1, 4, 3)
                        plt.imshow(preprocessed_plate, cmap='gray')
                        plt.title("Image Preprocessed")
                        plt.axis('off')
                        
                        # 4. Characters segmented
                        plt.subplot(1, 4, 4)
                        plt.imshow(cv2.cvtColor(output_image, cv2.COLOR_BGR2RGB))
                        plt.title(f"characters Segmented: {len(characters)}")
                        plt.axis('off')
                        
                        plt.suptitle(f"Image {i+1}")
                        plt.tight_layout()
                        plt.show()

                        
                    else:
                        print("Could not crop license plate. Skipping.")
            else:
                print(f"\n--- Processing Image {i+1} ---")
                print("No license plate detected in this image.")
    
    # =========================================================================
    # OCR Functions 
    # =========================================================================
        
    def load_char_recognizer(self, model_path="char_recognizer.pt"):
        self.char_model = CharRecognizer(len(CHAR_CLASSES)).to(self.device)
        self.char_model.load_state_dict(torch.load(model_path, map_location=self.device))
        self.char_model.eval()

    def ocr_characters(self, characters):
        """
        OCR using a simple CNN model for character recognition
        """
        plate_text = ""
        transform = transforms.Compose([
            transforms.Grayscale(num_output_channels=1),
            transforms.Resize((28,28)),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,))
        ])

        for char_info in characters:
            img_char = char_info['image']
            pil_img = cv2.cvtColor(img_char, cv2.COLOR_BGR2RGB)
            pil_img = transforms.ToPILImage()(pil_img)
            img_tensor = transform(pil_img).unsqueeze(0).to(self.device)

            with torch.no_grad():
                output = self.char_model(img_tensor)
                pred_idx = output.argmax(1).item()
                plate_text += CHAR_CLASSES[pred_idx]

        return plate_text

In [44]:
#* ================= MAIN EXECUTION =================

# dataset_path = "/content/BD/BD_LicensePlate"
# test_dir = "/content/BD/BD_LicensePlate/images/test"
# output_dir = "/content/test_results"

dataset_path = r"C:\Users\adria\OneDrive - UAB\4 ENGINY\Processament Imatge i Video\Repte Matriculas\BD_LicensePlate"
test_dir = r"C:\Users\adria\OneDrive - UAB\4 ENGINY\Processament Imatge i Video\Repte Matriculas\BD_LicensePlate\images\test"
output_dir = r"C:\Users\adria\OneDrive - UAB\4 ENGINY\Processament Imatge i Video\Repte Matriculas\test_results"

detector = LicensePlateDetector()
# ! LOAD THE OCR MODEL
detector.load_char_recognizer(model_path="char_recognizer.pt")

# best_model_path = detector.train_model(dataset_path, epochs=10, batch=8)
best_model_path = r"C:\Users\adria\OneDrive - UAB\4 ENGINY\Processament Imatge i Video\Repte Matriculas\License-Plate-Detection\models\best_license_plate.pt"
# best_model_path = r"C:\Users\adria\OneDrive - UAB\4 ENGINY\Processament Imatge i Video\Repte Matriculas\License-Plate-Detection\models\best.pt"
prediction_results = detector.detect_license_plate(best_model_path, test_dir, output_dir, conf=0.25)

  self.char_model.load_state_dict(torch.load(model_path, map_location=self.device))


FileNotFoundError: [Errno 2] No such file or directory: 'char_recognizer.pt'

In [42]:
# Segmentation and display of characters

detector.process_and_show_results(prediction_results)


--- Processing Image 1 ---
Detected License Plate 1: Bbox (1510, 762, 1879, 1063)
Found 5 character candidates.


AttributeError: 'LicensePlateDetector' object has no attribute 'char_model'