#### Step 1: Import Library

###### 차선 검출에 필요한 라이브러리들을 현재 노트북으로 불러온다(import).

In [13]:
# Step 1: Import Required Libraries
import cv2
import torch
import time
from torchvision import transforms
from PIL import Image as PILImage
import numpy as np
from model.lanenet.LaneNet import LaneNet
from sklearn.cluster import DBSCAN
from sklearn.preprocessing import StandardScaler
from collections import defaultdict
import threading


#### Step 2: GPU Setting

In [14]:
# Step 2: Pick the compute device (first CUDA GPU when present, otherwise CPU)
# and report whether CUDA is usable.
DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(torch.cuda.is_available())


True


#### Step 3: Define the LaneDetector Class

In [15]:
# Step 3: Define the LaneDetector Class
class LaneDetector:
    """Real-time lane detector built around a LaneNet (DeepLabv3+) model.

    Frames are handed over with :meth:`image_callback`. A daemon worker thread
    started in ``__init__`` runs :meth:`process_frames`, which segments the
    bottom region of interest (ROI) of the newest frame, clusters the lane
    pixels by their instance embeddings and shows the annotated result in
    OpenCV windows.
    """

    # Checkpoint loaded when ``model_path`` is not given. This is the location
    # that used to be hard-coded, kept so ``LaneDetector()`` keeps working.
    DEFAULT_MODEL_PATH = 'C:\\Users\\yth12\\Dropbox\\4. 기타 자료\\해군 AI 특강\\AI_Lecture\\Day1\\Deep-Learning\\LaneNet\\log\\lanenet_DeepLabv3+_CrossEntrophy_epoch100_batchsize8.pth'

    def __init__(self, model_path=None):
        """Load the model, build the preprocessing pipeline and start the worker.

        Args:
            model_path: Path of the LaneNet state-dict checkpoint. ``None``
                (the default) falls back to ``DEFAULT_MODEL_PATH``.
        """
        self.resize_height = 480
        self.resize_width = 640
        self.roi_height = self.resize_height // 3  # ROI height setting

        # Load LaneNet Model
        print("Loading LaneNet model...")
        self.model = LaneNet(arch='DeepLabv3+')
        # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
        state_dict = torch.load(model_path or self.DEFAULT_MODEL_PATH, map_location=DEVICE)
        self.model.load_state_dict(state_dict)
        self.model.eval()
        self.model.to(DEVICE)
        print("Model loaded successfully!")

        self.data_transform = transforms.Compose([
            transforms.Resize((self.roi_height, self.resize_width)),  # Resize to ROI size
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

        # Horizontal position of the vehicle centre, expressed for an image
        # that is ``resize_width`` pixels wide.
        self.vehicle_center_x = self.resize_width // 2

        # Colours are applied to a BGR image, so tuples are (blue, green, red).
        self.lane_colors = {
            'left': (255, 0, 0),    # Blue (BGR)
            'center': (0, 255, 0),  # Green (BGR)
            'right': (0, 0, 255)    # Red (BGR)
        }

        self.processing_lock = threading.Lock()
        self.current_frame = None
        # Set whenever an unprocessed frame is waiting in ``current_frame``.
        self._frame_ready = threading.Event()
        # Set to make the worker loop exit.
        self._stop_event = threading.Event()
        self.processing_thread = threading.Thread(target=self.process_frames)
        self.processing_thread.daemon = True
        self.processing_thread.start()

    def stop(self):
        """Ask the worker thread to exit. Safe to call more than once."""
        self._stop_event.set()
        self._frame_ready.set()  # wake the worker if it is waiting for a frame

    @staticmethod
    def _safe_fps(elapsed):
        """Return ``1 / elapsed``, or ``inf`` when no measurable time passed."""
        return 1.0 / elapsed if elapsed > 0 else float('inf')

    # Step 4: Define Utility Functions
    def load_test_data(self, img):
        """Convert an image array into the normalised tensor the model expects.

        Args:
            img: Image array accepted by ``PIL.Image.fromarray``.

        Returns:
            The result of ``self.data_transform`` (resize, to-tensor, normalise).
        """
        img = PILImage.fromarray(img)
        img = self.data_transform(img)
        return img

    def _morphological_process(self, image, kernel_size=5):
        """Scale ``image`` by 255, cast to uint8 and apply a morphological closing.

        Args:
            image: Array with values expected in ``[0, 1]``.
            kernel_size: Side length of the elliptical structuring element.

        Returns:
            The closed uint8 image.
        """
        image = (image * 255).astype(np.uint8)
        kernel = cv2.getStructuringElement(shape=cv2.MORPH_ELLIPSE, ksize=(kernel_size, kernel_size))
        closing = cv2.morphologyEx(image, cv2.MORPH_CLOSE, kernel, iterations=1)
        return closing

    def _connect_components_analysis(self, image):
        """Run 8-connected component labelling with statistics on ``image``.

        Returns:
            The ``(num_labels, labels, stats, centroids)`` tuple from
            ``cv2.connectedComponentsWithStats``.
        """
        return cv2.connectedComponentsWithStats(image, connectivity=8, ltype=cv2.CV_32S)

    # Step 5: Clustering Embeddings
    def cluster_embeddings(self, binary_seg, instance_seg):
        """Group lane pixels into lanes by clustering their embeddings.

        Args:
            binary_seg: 2-D mask; pixels with a value above 0.5 count as lane
                pixels.
            instance_seg: Embedding array indexed as (channel, row, column).

        Returns:
            A mapping ``cluster label -> list of (row, col)`` pixel
            coordinates. DBSCAN noise points (label ``-1``) are dropped. An
            empty dict is returned when there is nothing to cluster or the
            clustering fails.
        """
        idxs = np.where(binary_seg > 0.5)
        if len(idxs[0]) == 0:
            print("No pixels found above threshold in binary segmentation")
            return {}

        # One row per lane pixel, one column per embedding channel.
        embeddings = instance_seg[:, idxs[0], idxs[1]].transpose(1, 0)

        if embeddings.shape[0] < 2:
            print("Not enough samples for clustering")
            return {}

        # Best effort: a clustering failure must not take down the frame loop,
        # so it is reported and treated as "no lanes".
        try:
            scaler = StandardScaler()
            embeddings = scaler.fit_transform(embeddings)

            # Cap min_samples so that small masks can still form a cluster.
            db = DBSCAN(eps=0.5, min_samples=min(100, embeddings.shape[0] // 2))
            db.fit(embeddings)

            lanes = defaultdict(list)
            for idx, label in enumerate(db.labels_):
                if label != -1:
                    lanes[label].append((idxs[0][idx], idxs[1][idx]))

            return lanes
        except Exception as e:
            print(f"Error in clustering: {e}")
            return {}

    # Step 6: Assign Lane Positions
    def assign_lane_positions(self, lanes):
        """Name lanes 'left' / 'center' / 'right' by their mean column.

        Args:
            lanes: Mapping ``lane id -> list of (row, col)`` points.

        Returns:
            Mapping ``lane id -> position name``. One lane is 'center', two
            lanes are 'left' and 'right', and with three or more lanes only
            the three left-most ones are named (the others get no entry).
        """
        if not lanes:
            return {}

        lane_avg_x = {lane_id: np.mean([p[1] for p in points]) for lane_id, points in lanes.items()}
        sorted_lanes = sorted(lane_avg_x.items(), key=lambda x: x[1])

        if len(sorted_lanes) == 1:
            names = ('center',)
        elif len(sorted_lanes) == 2:
            names = ('left', 'right')
        else:
            names = ('left', 'center', 'right')

        # zip stops at the shorter sequence, so lanes beyond the third are skipped.
        return {lane_id: name for (lane_id, _), name in zip(sorted_lanes, names)}

    # Step 7: Process Frame
    def process_frame(self, frame):
        """Detect lanes in the bottom ROI of ``frame`` and draw them.

        Also shows the binary and instance segmentation outputs in their own
        OpenCV windows.

        Args:
            frame: BGR image array; only its last ``roi_height`` rows are used.

        Returns:
            A BGR copy of the ROI with lane pixels and the vehicle centre line
            drawn on it (same shape as the ROI slice of ``frame``).
        """
        roi = frame[-self.roi_height:, :]
        roi_h, roi_w = roi.shape[:2]
        input_np = cv2.cvtColor(roi, cv2.COLOR_BGR2RGB)

        dummy_input = self.load_test_data(input_np).to(DEVICE)
        dummy_input = torch.unsqueeze(dummy_input, dim=0)

        with torch.no_grad():
            outputs = self.model(dummy_input)

        binary_seg = torch.squeeze(outputs['binary_seg_pred']).cpu().numpy()
        instance_seg = torch.squeeze(outputs['instance_seg_logits']).cpu().numpy()

        # Min-max normalise to [0, 1]; the epsilon guards a constant prediction.
        binary_seg = (binary_seg - binary_seg.min()) / (binary_seg.max() - binary_seg.min() + 1e-8)
        binary_seg = self._morphological_process(binary_seg)

        num_labels, labels, stats, centroids = self._connect_components_analysis(binary_seg)

        # Drop small connected components (area <= 100 px); label 0 is background.
        for index in range(1, num_labels):
            if stats[index][cv2.CC_STAT_AREA] <= 100:
                binary_seg[labels == index] = 0

        binary_seg = binary_seg.astype(float) / 255.0

        lanes = self.cluster_embeddings(binary_seg, instance_seg)
        lane_positions = self.assign_lane_positions(lanes)

        result_img = roi.copy()

        if lanes:
            # Lane points are in segmentation-map coordinates. The map is sized
            # by the network (its input is resized to roi_height x resize_width),
            # which differs from the ROI whenever the camera frame is not
            # resize_width pixels wide, so map the points back onto the ROI.
            seg_h, seg_w = binary_seg.shape[-2:]
            scale_x = roi_w / seg_w
            scale_y = roi_h / seg_h

            for lane_id, points in lanes.items():
                position = lane_positions.get(lane_id, 'center')
                color = self.lane_colors[position]
                for point in points:
                    # Plain Python ints: cv2 drawing functions want native ints.
                    px = min(int(point[1] * scale_x), roi_w - 1)
                    py = min(int(point[0] * scale_y), roi_h - 1)
                    cv2.circle(result_img, (px, py), 1, color, -1)

            print(f"Detected lanes: {len(lanes)} lanes")
        else:
            print("No lanes detected in this frame")

        # vehicle_center_x is defined for a resize_width-wide image; rescale it
        # so the line stays centred for other frame widths.
        center_x = int(round(self.vehicle_center_x * roi_w / self.resize_width))
        cv2.line(result_img, (center_x, 0), (center_x, roi_h), (255, 255, 255), 2)

        # Show the binary and instance segmentation results.
        binary_seg_img = (binary_seg * 255).astype(np.uint8)
        binary_seg_img = cv2.resize(binary_seg_img, (self.resize_width, self.roi_height))  # resize for display
        cv2.imshow('Binary Output', binary_seg_img)

        instance_seg_img = np.sum(instance_seg, axis=0)  # sum over the channel axis to get a single image
        instance_seg_img = (instance_seg_img - instance_seg_img.min()) / (instance_seg_img.max() - instance_seg_img.min() + 1e-8) * 255.0
        instance_seg_img = instance_seg_img.astype(np.uint8)
        instance_seg_img = cv2.resize(instance_seg_img, (self.resize_width, self.roi_height))  # resize for display
        cv2.imshow('Instance Output', instance_seg_img)

        return result_img

    # Step 8: Image Callback
    def image_callback(self, frame):
        """Publish ``frame`` as the newest frame for the worker thread."""
        with self.processing_lock:
            self.current_frame = frame
            self._frame_ready.set()

    # Step 9: Frame Processing Loop
    def process_frames(self):
        """Worker loop: process each newly published frame and display it.

        Runs until 'q' is pressed in the display window or :meth:`stop` is
        called.

        NOTE(review): the OpenCV window calls run on this worker thread; GUI
        calls outside the main thread may not work on every platform - confirm
        on the target system.
        """
        while not self._stop_event.is_set():
            # Sleep until a frame is published instead of spinning on the lock.
            # The timeout keeps the loop responsive to the stop event.
            if not self._frame_ready.wait(timeout=0.1):
                continue
            if self._stop_event.is_set():
                break

            with self.processing_lock:
                # Mark the frame as consumed so it is processed only once.
                self._frame_ready.clear()
                if self.current_frame is None:
                    continue
                frame = self.current_frame.copy()

            start_time = time.perf_counter()
            result_img = self.process_frame(frame)

            full_result = frame.copy()
            full_result[-self.roi_height:, :] = result_img

            # Draw the ROI box.
            cv2.rectangle(full_result, (0, frame.shape[0] - self.roi_height), (frame.shape[1], frame.shape[0]), (0, 255, 255), 2)

            cv2.imshow('Lane Detection with Vehicle Center', full_result)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                print("Shutting down...")
                self._stop_event.set()
                break

            # perf_counter is monotonic and high resolution; _safe_fps guards
            # against a zero elapsed time.
            fps = self._safe_fps(time.perf_counter() - start_time)
            print(f"FPS: {fps:.2f}")

#### Step 10: Main Function

###### 메인 루프에서 웹캠 영상을 프레임 단위로 캡처하여 LaneDetector 인스턴스의 image_callback 으로 전달한다.

In [16]:
# Step 10: Main Function
if __name__ == "__main__":
    # Creating the detector loads the model and starts its worker thread.
    lane_detector = LaneDetector()

    cap = cv2.VideoCapture(0)

    try:
        # Stop feeding frames once the worker thread has ended (the user
        # pressed 'q' in the display window, or processing raised); otherwise
        # this loop would keep reading the camera forever.
        while cap.isOpened() and lane_detector.processing_thread.is_alive():
            ret, frame = cap.read()
            if not ret:
                break

            # Hand the newest frame to the detector; processing and display
            # happen on its worker thread.
            lane_detector.image_callback(frame)

    except KeyboardInterrupt:
        print("Shutting down")
    finally:
        cap.release()
        cv2.destroyAllWindows()

Loading LaneNet model...
Use DeepLabv3+ as backbone


  self.model.load_state_dict(torch.load('C:\\Users\\yth12\\Dropbox\\4. 기타 자료\\해군 AI 특강\\AI_Lecture\\Day1\\Deep-Learning\\LaneNet\\log\\lanenet_DeepLabv3+_CrossEntrophy_epoch100_batchsize8.pth'))


Model loaded successfully!
No pixels found above threshold in binary segmentation
No lanes detected in this frame
FPS: 13.42
No pixels found above threshold in binary segmentation
No lanes detected in this frame
FPS: 40.01
No pixels found above threshold in binary segmentation
No lanes detected in this frame
FPS: 24.27
No pixels found above threshold in binary segmentation
No lanes detected in this frame
FPS: 32.26
No pixels found above threshold in binary segmentation
No lanes detected in this frame
FPS: 31.25
No pixels found above threshold in binary segmentation
No lanes detected in this frame
FPS: 31.55
No pixels found above threshold in binary segmentation
No lanes detected in this frame
FPS: 31.27
No pixels found above threshold in binary segmentation
No lanes detected in this frame
FPS: 31.30
No pixels found above threshold in binary segmentation
No lanes detected in this frame
FPS: 40.00
No pixels found above threshold in binary segmentation
No lanes detected in this frame
FPS: