In [5]:
import sys
import subprocess
import importlib
import os
import struct
import socket
import time
import queue
import threading
import json 
import tkinter as tk
from tkinter import simpledialog, filedialog, messagebox
#os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

# ---------------------------------------------------------
# 1. 패키지 자동 설치
# ---------------------------------------------------------
def install_package(module_name, package_name=None):
    """Make sure ``module_name`` is importable, pip-installing it if it is not.

    Args:
        module_name: Importable module name to probe (e.g. ``"cv2"``).
        package_name: Distribution name passed to pip when the import fails
            (e.g. ``"opencv-python"``). Defaults to ``module_name``.

    Raises:
        subprocess.CalledProcessError: if ``pip install`` exits non-zero.
    """
    if package_name is None:
        package_name = module_name
    try:
        importlib.import_module(module_name)
    except ImportError:
        print(f"Installing {package_name} ...")
        # Use the running interpreter's pip so the package lands in this environment.
        subprocess.check_call([sys.executable, "-m", "pip", "install", package_name])
        # The import system caches directory listings of sys.path entries;
        # without this, a package installed after startup may not be found by
        # a later `import` in this same process.
        importlib.invalidate_caches()

print("Checking required packages...")
# (import name, pip distribution name -- None when it matches the import name)
for _module, _dist in (
    ("numpy", None),
    ("cv2", "opencv-python"),
    ("tensorflow", None),
):
    install_package(_module, _dist)
del _module, _dist

import cv2
import numpy as np
import tensorflow as tf

from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Dropout, Input
from tensorflow.keras.models import Model
from tensorflow.keras.applications import EfficientNetB0
from tensorflow.keras.applications.efficientnet import preprocess_input

# ---------------------------------------------------------
# 2. 설정 및 연결
# ---------------------------------------------------------
HOST_CAM = '192.168.0.60'
PORT_CAM = 80  # camera frames are read from this socket
PORT_MOT = 81  # single-byte motor commands are written to this socket

client_cam = None
client_mot = None

# Hidden Tk root used only as the parent for the dialogs below.
root = tk.Tk()
root.withdraw()

# Retry until both sockets connect. On failure, ask the user for a new IP;
# cancelling (or entering nothing) exits.
while True:
    try:
        print(f"Connecting to {HOST_CAM}...")
        client_cam = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        client_mot = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

        # Commands are tiny writes; disable Nagle so each one is sent immediately.
        client_mot.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)

        # Bounded timeout only while connecting; afterwards the sockets block.
        client_cam.settimeout(3)
        client_mot.settimeout(3)

        client_cam.connect((HOST_CAM, PORT_CAM))
        client_mot.connect((HOST_CAM, PORT_MOT))

        client_cam.settimeout(None)
        client_mot.settimeout(None)
        print("Connected successfully!")
        break
    except Exception as e:
        print(f"Connection Failed: {e}")
        # Close this attempt's sockets explicitly rather than relying on garbage
        # collection when the names are rebound: a camera socket that did
        # connect would otherwise stay open while the dialog waits for input.
        if client_cam is not None:
            client_cam.close()
        if client_mot is not None:
            client_mot.close()
        client_cam = None
        client_mot = None
        new_ip = simpledialog.askstring("Connection Failed",
                                        f"Failed to connect to {HOST_CAM}.\n\nEnter ESP32 IP Address:",
                                        parent=root,
                                        initialvalue=HOST_CAM)
        # Strip stray whitespace from pasted addresses; blank input counts as cancel.
        if new_ip and new_ip.strip():
            HOST_CAM = new_ip.strip()
        else:
            root.destroy()
            sys.exit()

# ---------------------------------------------------------
# 3. 모델 구조 정의 및 가중치 로드 
# ---------------------------------------------------------
def create_model():
    """Rebuild the EfficientNetB0-based 4-class driving classifier.

    Only the architecture is created here: ``weights=None`` means no
    pretrained weights are fetched, and the trained parameters are loaded
    afterwards with ``model.load_weights``. The input is a 120x160x3 image
    (height x width x channels), matching the ``cv2.resize(frame, (160, 120))``
    the inference workers apply.

    Returns:
        An uncompiled ``Model`` mapping a (batch, 120, 160, 3) image batch to
        (batch, 4) softmax scores, one per entry in ``names``.

    NOTE(review): the layer stack and the ``trainable`` flags below presumably
    mirror the training script. Keep them identical to it: loading ``.h5``
    weights by topology depends on the model structure, and a layer's
    trainable status can change the order of a nested model's weight list.
    """
    base_model = EfficientNetB0(weights=None, include_top=False, input_shape=(120, 160, 3))
    # The last 60 backbone layers form the fine-tuning region.
    fine_tune_at = -60
    # Freeze every backbone layer before the fine-tuning region...
    for layer in base_model.layers[:fine_tune_at]:
        layer.trainable = False
    # ...and, inside it, keep only the BatchNormalization layers frozen.
    for layer in base_model.layers[fine_tune_at:]:
        if isinstance(layer, tf.keras.layers.BatchNormalization):
            layer.trainable = False
            
    # Classification head on top of the pooled backbone features.
    inputs = Input(shape=(120, 160, 3))
    x = base_model(inputs)
    x = GlobalAveragePooling2D()(x)
    x = Dense(128, activation='relu')(x)
    # Dropout only acts during training; it is a pass-through at inference.
    x = Dropout(0.5)(x)
    x = Dense(64, activation='relu')(x)
    x = Dropout(0.3)(x)
    # 4 classes, index-aligned with `names` / `cmd_chars`.
    outputs = Dense(4, activation='softmax')(x)
    return Model(inputs, outputs)

model_path = 'weights.h5'

# Fall back to a file picker when the default weights file is not in the CWD.
if not os.path.exists(model_path):
    # Raise the hidden root to the top so the dialog opens in front of other windows.
    root.lift()
    root.attributes('-topmost', True)
    model_path = filedialog.askopenfilename(title="Select Weights File", filetypes=[("H5 Files", "*.h5")])
    root.attributes('-topmost', False)

    if not model_path:
        print("No file selected.")
        # Release the Tk root before leaving. In a notebook kernel the process
        # survives sys.exit(), so the Tk instance would otherwise linger.
        root.destroy()
        sys.exit()

try:
    print("Building model architecture...")
    model = create_model()

    print(f"Loading weights: {model_path} ...")
    model.load_weights(model_path)
    print("Model loaded.")
except Exception as e:
    # Top-level boundary: report any build/load failure in a dialog, then stop.
    messagebox.showerror("Error", f"Failed to load weights:\n{e}")
    root.destroy()
    sys.exit()

# Dialogs are no longer needed past this point.
root.destroy()

# Class labels, index-aligned with the model's 4 softmax outputs.
names = ['_0_forward', '_1_right', '_2_left', '_3_stop']
# One-byte motor command per class index, in the same order as `names`.
cmd_chars = [bytes([code]) for code in b'FRLS']

# ---------------------------------------------------------
# 4. Multi-thread setup
# ---------------------------------------------------------
HOW_MANY_MESSAGES = 10  # capacity of each worker's frame queue
# One bounded frame queue per inference worker (worker i reads mq[i]).
mq = [queue.Queue(maxsize=HOW_MANY_MESSAGES) for _ in range(4)]
# Polled by the workers; setting it to True asks them to stop.
flag_exit = False

def cnn_main(args):
    """Inference worker: classify frames from ``mq[args]`` and send motor commands.

    Runs until the module-level ``flag_exit`` becomes True. For each frame it:
    - resizes the frame to the model's 160x120 input;
    - classifies it;
    - sends the byte in ``cmd_chars`` for the winning class over ``client_mot``.

    Args:
        args: Worker index; selects which queue in ``mq`` this thread consumes.
    """
    print(f"Thread {args} started")
    while not flag_exit:
        try:
            # Short timeout so the loop re-checks flag_exit about once a second.
            frame = mq[args].get(timeout=1)
        except queue.Empty:
            continue

        # Resize to the network input; cv2 dsize is (width, height) -> 160x120.
        model_input = cv2.resize(frame, (160, 120))

        # Add the batch axis and cast to float32 (the Input layer's default
        # dtype) instead of relying on implicit conversion of the uint8 frame.
        x = np.expand_dims(model_input, axis=0).astype(np.float32)

        image_tensor = preprocess_input(x)

        # Inference. Model.predict() is documented as meant for large batched
        # inputs, with per-call setup overhead and per-model loop state.
        # For a single frame (called from several threads), a direct
        # call with training=False is the recommended path.
        y_predict = model(image_tensor, training=False)
        idx = int(np.argmax(np.asarray(y_predict), axis=1)[0])

        try:
            client_mot.sendall(cmd_chars[idx])
        except OSError:
            # Best effort: a dropped motor link must not kill the worker;
            # the main loop decides when to shut everything down.
            pass

# Start one inference worker per frame queue.
threads = []
for i in range(len(mq)):
    # daemon=True: if something fails after this point but before the main
    # loop's `finally` sets flag_exit (e.g. window setup), the workers must
    # not keep the interpreter alive. The `finally` still joins them normally.
    t = threading.Thread(target=cnn_main, args=(i,), daemon=True)
    t.start()
    threads.append(t)

# ---------------------------------------------------------
# 5. 마스킹 설정 로드 및 트랙바 초기화
# ---------------------------------------------------------
CONFIG_FILE = "mask_config.json"
crop_top, crop_bottom, crop_left, crop_right = 0, 0, 0, 0

# Optional saved mask margins (used as initial trackbar positions); missing
# keys default to 0.
if os.path.exists(CONFIG_FILE):
    try:
        with open(CONFIG_FILE, 'r', encoding='utf-8') as f:
            config = json.load(f)
        # Convert all four values before assigning any of them:
        # - a bad value leaves every margin at its default instead of
        #   applying half the file;
        # - non-integer values are rejected here, not later in
        #   cv2.createTrackbar.
        crop_top, crop_bottom, crop_left, crop_right = (
            int(config.get(key, 0)) for key in ("top", "bottom", "left", "right")
        )
        print(f"Loaded mask config: {config}")
    except (OSError, ValueError, TypeError, AttributeError, OverflowError):
        # Unreadable file, invalid JSON, a non-object top level, or values
        # that cannot be converted to int.
        print("Failed to load mask config.")

def nothing(x):
    """No-op trackbar callback: OpenCV requires one, but positions are polled with getTrackbarPos."""
    return None

WINDOW_NAME = 'AI Driving Thread (Masking Control)'
cv2.namedWindow(WINDOW_NAME)

# One slider per edge: (label, initial position from the config, maximum).
# Vertical margins go up to 240 and horizontal margins up to 320.
for _bar, _start, _limit in (
    ('Top', crop_top, 240),
    ('Bottom', crop_bottom, 240),
    ('Left', crop_left, 320),
    ('Right', crop_right, 320),
):
    cv2.createTrackbar(_bar, WINDOW_NAME, _start, _limit, nothing)
del _bar, _start, _limit

# ---------------------------------------------------------
# 6. 메인 루프
# ---------------------------------------------------------
def _recv_exact(sock, size):
    """Read exactly ``size`` bytes from ``sock``.

    ``socket.recv`` may return fewer bytes than requested (even for a 4-byte
    header), so keep reading into a preallocated buffer until it is full.
    Returns a ``bytearray``, or ``None`` if the peer closes the connection first.
    """
    buf = bytearray(size)
    view = memoryview(buf)
    received = 0
    while received < size:
        n = sock.recv_into(view[received:], size - received)
        if n == 0:
            return None
        received += n
    return buf


# Sanity limit for the length header. A larger value means the stream is out
# of sync or corrupt; trying to allocate/read it would only hang or fail.
MAX_FRAME_BYTES = 10 * 1024 * 1024

fn = 0
t_prev = time.time()
cnt_frame = 0

DISPLAY_WIDTH = 640
DISPLAY_HEIGHT = 480

try:
    while True:
        try:
            # Send the request byte (12), then read a 4-byte length header
            # followed by that many bytes of encoded image.
            client_cam.sendall(struct.pack('B', 12))
            header = _recv_exact(client_cam, 4)
            if header is None: break
            # '<I' = little-endian uint32, identical to native 'I' on x86/ARM
            # hosts. NOTE(review): assumes the ESP32 writes its native
            # little-endian length; confirm against the firmware.
            data_len = struct.unpack('<I', header)[0]
            if data_len > MAX_FRAME_BYTES:
                print(f"Connection Error: invalid frame length {data_len}")
                break
            if data_len == 0:
                continue  # empty frame: nothing to decode, request the next one

            img_data = _recv_exact(client_cam, data_len)
            if img_data is None: break

            np_data = np.frombuffer(img_data, dtype=np.uint8)
            frame = cv2.imdecode(np_data, cv2.IMREAD_COLOR)
            if frame is None: continue  # undecodable frame: skip it

        except (OSError, cv2.error) as e:
            print(f"Connection Error: {e}")
            break

        # Rotate 180 degrees (presumably the camera is mounted inverted), then
        # resize to the display size.
        frame = cv2.rotate(frame, cv2.ROTATE_180)
        frame_resized = cv2.resize(frame, (DISPLAY_WIDTH, DISPLAY_HEIGHT))

        # Black out the margins chosen on the trackbars. The masked frame is
        # both what is displayed and what the workers classify.
        c_top = cv2.getTrackbarPos('Top', WINDOW_NAME)
        c_bottom = cv2.getTrackbarPos('Bottom', WINDOW_NAME)
        c_left = cv2.getTrackbarPos('Left', WINDOW_NAME)
        c_right = cv2.getTrackbarPos('Right', WINDOW_NAME)

        h, w, _ = frame_resized.shape
        if c_top > 0: frame_resized[:c_top, :] = 0
        if c_bottom > 0: frame_resized[h-c_bottom:, :] = 0
        if c_left > 0: frame_resized[:, :c_left] = 0
        if c_right > 0: frame_resized[:, w-c_right:] = 0

        cv2.imshow(WINDOW_NAME, frame_resized)

        # Round-robin frames across the worker queues. If that worker is still
        # backed up, drop the frame rather than block capture.
        worker_queue = mq[fn % len(mq)]
        if not worker_queue.full():
            worker_queue.put(frame_resized)
        fn += 1

        key = cv2.waitKey(1)
        if key == 27: break  # Esc

        cnt_frame += 1
        if time.time() - t_prev > 1.0:
            print(f"FPS: {cnt_frame}")
            t_prev = time.time()
            cnt_frame = 0

except KeyboardInterrupt:
    print("Stopping...")

finally:
    flag_exit = True
    print("Waiting for threads...")
    for t in threads:
        t.join()

    # Workers have stopped, so nothing else writes to client_mot now. Send a
    # final stop command (best effort) in case the firmware keeps executing
    # the last command after the socket closes. NOTE(review): confirm
    # firmware behavior.
    if client_mot:
        try:
            client_mot.sendall(cmd_chars[names.index('_3_stop')])
        except OSError:
            pass

    if client_cam: client_cam.close()
    if client_mot: client_mot.close()
    cv2.destroyAllWindows()

Checking required packages...
Connecting to 192.168.0.60...
Connection Failed: timed out
Connecting to 192.168.137.24...
Connected successfully!
Building model architecture...
Loading weights: D:/python_workplace/ESP32_autonomous_driving/weights_EN_20251226113009.h5 ...
Model loaded.
Thread 0 started
Thread 1 started
Thread 2 started
Thread 3 started
Loaded mask config: {'top': 0, 'bottom': 0, 'left': 0, 'right': 0}
FPS: 24
FPS: 25
FPS: 27
FPS: 22
Waiting for threads...
