In [None]:
import sys
import subprocess
import importlib
import os
import struct
import socket
import time
import json
import tkinter as tk
from tkinter import simpledialog, filedialog, messagebox

# ---------------------------------------------------------
# 1. 패키지 자동 설치
# ---------------------------------------------------------
def install_package(module_name, package_name=None):
    """Ensure *module_name* is importable, pip-installing it if it is missing.

    Args:
        module_name: Name used in ``import`` statements (e.g. ``"cv2"``).
        package_name: Distribution name passed to ``pip install`` when it
            differs from the import name (e.g. ``"opencv-python"``).
            Defaults to *module_name*.

    Raises:
        subprocess.CalledProcessError: if ``pip install`` exits non-zero.
    """
    if package_name is None:
        package_name = module_name
    try:
        importlib.import_module(module_name)
    except ImportError:
        print(f"Installing {package_name} ...")
        subprocess.check_call([sys.executable, "-m", "pip", "install", package_name])
        # Path-based finders cache directory contents; a module installed
        # while the interpreter is running may not be found by the import
        # that follows unless those caches are invalidated.
        importlib.invalidate_caches()

print("Checking required packages...")
# (import name, pip distribution name); None means the two names are the same.
for _import_name, _pip_name in (
    ("numpy", None),
    ("cv2", "opencv-python"),
    ("tensorflow", None),
):
    install_package(_import_name, _pip_name)

import cv2
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import load_model

# ---------------------------------------------------------
# 2. 설정 및 연결
# ---------------------------------------------------------
HOST_CAM = '192.168.0.60'  # initial IP
PORT_CAM = 80
PORT_MOT = 81
# Timeout (seconds) used only while connecting, so a wrong IP fails fast.
CONNECT_TIMEOUT = 3

client_cam = None
client_mot = None

# Hidden Tk root that parents the IP prompt (and later file dialogs).
root = tk.Tk()
root.withdraw()

while True:
    try:
        print(f"Connecting to {HOST_CAM}...")
        client_cam = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        client_mot = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

        # Motor commands are tiny writes; disable Nagle so each is sent at once.
        client_mot.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)

        client_cam.settimeout(CONNECT_TIMEOUT)
        client_mot.settimeout(CONNECT_TIMEOUT)

        client_cam.connect((HOST_CAM, PORT_CAM))
        client_mot.connect((HOST_CAM, PORT_MOT))

        # Back to fully blocking mode for the streaming loop.
        client_cam.settimeout(None)
        client_mot.settimeout(None)
        print("Connected successfully!")
        break
    except Exception as e:
        print(f"Connection Failed: {e}")
        # Release this attempt's sockets before retrying. Otherwise every
        # failed attempt leaks descriptors, and a camera connection that
        # succeeded while the motor port failed is left half-open.
        for _sock in (client_cam, client_mot):
            if _sock is not None:
                _sock.close()
        client_cam = client_mot = None

        new_ip = simpledialog.askstring("Connection Failed",
                                        f"Failed to connect to {HOST_CAM}.\n\nEnter ESP32 IP Address:",
                                        parent=root,
                                        initialvalue=HOST_CAM)
        if not new_ip:
            # Cancelled or empty input: give up.
            sys.exit()
        # Tolerate stray whitespace around a pasted address; whitespace-only
        # input keeps the previous host and retries.
        HOST_CAM = new_ip.strip() or HOST_CAM

# ---------------------------------------------------------
# 3. 모델 로드
# ---------------------------------------------------------
model_path = 'model_CNN.h5'

if not os.path.exists(model_path):
    # Raise the hidden root to the top so the file dialog is not opened
    # behind other windows.
    root.lift()
    root.attributes('-topmost', True)
    model_path = filedialog.askopenfilename(title="Select AI Model File",
                                            filetypes=[("H5 Files", "*.h5")],
                                            parent=root)
    root.attributes('-topmost', False)

    # askopenfilename returns an empty value when the dialog is cancelled.
    if not model_path:
        print("No model selected.")
        sys.exit()

try:
    print(f"Loading model: {model_path} ...")
    # compile=False: the model is only used for inference here.
    model = load_model(model_path, compile=False)
    print("Model loaded.")
except Exception as e:
    print(f"Failed to load model: {e}")
    messagebox.showerror("Error", f"Failed to load model:\n{e}", parent=root)
    # Non-zero status so callers/shells can tell this was a failure.
    sys.exit(1)

root.destroy()

# Class labels in model-output order (indexed by argmax of the prediction)
# and, at the same index, the newline-terminated motor command to send.
names = ['_0_forward', '_1_right', '_2_left', '_3_stop']
cmd_chars = [letter + b'\n' for letter in (b'F', b'R', b'L', b'S')]

# ---------------------------------------------------------
# 4. 마스킹 설정 로드 및 트랙바 초기화
# ---------------------------------------------------------
CONFIG_FILE = "mask_config.json"
# Mask margins in pixels; they stay 0 unless a config file is loaded.
crop_top = crop_bottom = crop_left = crop_right = 0
# Config file to read; becomes None when the user opts for the defaults.
load_path = CONFIG_FILE

if not os.path.exists(CONFIG_FILE):
    # Temporary Tk root that only serves as parent for the dialogs below.
    mask_root = tk.Tk()
    mask_root.withdraw()

    def bring_front():
        """Show the root fully transparent, topmost and focused so the next dialog appears in front."""
        mask_root.deiconify()
        mask_root.attributes('-alpha', 0.0)
        mask_root.attributes('-topmost', True)
        mask_root.lift()
        mask_root.focus_force()

    bring_front()
    ans = messagebox.askyesno(
        title="Config Missing",
        message="mask_config.json not found.\nDo you want to select a mask config file manually?",
        parent=mask_root,
    )

    load_path = None
    if not ans:
        print("Using Default (0,0,0,0).")
    else:
        bring_front()
        selected_file = filedialog.askopenfilename(
            title="Select Mask Config JSON",
            filetypes=[("JSON Files", "*.json"), ("All Files", "*.*")],
            parent=mask_root,
        )
        if not selected_file:
            print("Selection cancelled. Using Default (0,0,0,0).")
        else:
            load_path = selected_file

    mask_root.destroy()

if load_path and os.path.exists(load_path):
    try:
        # JSON text is UTF-8 by specification; don't depend on the locale.
        with open(load_path, 'r', encoding='utf-8') as f:
            config = json.load(f)
        if not isinstance(config, dict):
            raise ValueError("top-level JSON value must be an object")
        # Parse all four values before assigning any, so one bad entry cannot
        # leave the crop settings half-loaded. int() rejects values that
        # cv2.createTrackbar cannot accept (e.g. strings) and truncates floats
        # such as 12.0. Negative margins are clamped to 0.
        crop_top, crop_bottom, crop_left, crop_right = [
            max(0, int(config.get(key, 0)))
            for key in ("top", "bottom", "left", "right")
        ]
        print(f"Loaded mask config from: {os.path.basename(load_path)}")
    except Exception as e:
        print(f"Failed to load mask config: {e}")
        print("Using default (0).")

def nothing(x):
    """Ignore trackbar changes; cv2.createTrackbar requires a callback.

    Trackbar positions are read with cv2.getTrackbarPos inside the driving
    loop, so nothing needs to happen when a slider moves.
    """
    return None

WINDOW_NAME = 'AI Driving CNN'
cv2.namedWindow(WINDOW_NAME)

# One slider per masked edge: (label, initial position, maximum).
# 240 and 320 are half of the 480/640 display height/width.
for _bar_label, _bar_start, _bar_max in (
    ('Top', crop_top, 240),
    ('Bottom', crop_bottom, 240),
    ('Left', crop_left, 320),
    ('Right', crop_right, 320),
):
    cv2.createTrackbar(_bar_label, WINDOW_NAME, _bar_start, _bar_max, nothing)

# ---------------------------------------------------------
# 5. 주행 루프
# ---------------------------------------------------------
t_prev = time.time()
cnt_frame = 0

DISPLAY_WIDTH = 640
DISPLAY_HEIGHT = 480

# Camera framing used below: send the request byte 12, then read a 4-byte
# unsigned length header followed by that many bytes of encoded image.
# '<I' pins the header to 4-byte little-endian, which is what the previous
# native 'I' meant on x86/ARM hosts.  NOTE(review): confirm the firmware
# sends the length little-endian.
FRAME_REQUEST = struct.pack('B', 12)
FRAME_HEADER = struct.Struct('<I')
# Sent on shutdown so the car is not left executing its last prediction
# (same bytes as the '_3_stop' entry of cmd_chars).
STOP_CMD = b'S\n'


def _recv_exact(sock, size):
    """Read exactly *size* bytes from *sock*.

    ``recv`` may return fewer bytes than requested, so keep reading into a
    preallocated buffer (no quadratic ``bytes +=``). Returns a bytearray, or
    None if the peer closed the connection before *size* bytes arrived.
    """
    buf = bytearray(size)
    view = memoryview(buf)
    received = 0
    while received < size:
        n = sock.recv_into(view[received:])
        if n == 0:
            return None
        received += n
    return buf


try:
    while True:
        # 1. Request and receive one encoded frame
        try:
            client_cam.sendall(FRAME_REQUEST)
            header = _recv_exact(client_cam, FRAME_HEADER.size)
            if header is None:
                print("Camera connection closed")
                break
            (data_len,) = FRAME_HEADER.unpack(header)

            img_data = _recv_exact(client_cam, data_len)
            if img_data is None:
                print("Camera connection closed")
                break

            np_data = np.frombuffer(img_data, dtype=np.uint8)
            frame = cv2.imdecode(np_data, cv2.IMREAD_COLOR)
            if frame is None:
                continue  # undecodable frame: request the next one
        except Exception as e:
            print(f"Connection Error: {e}")
            break

        # 2. Rotate 180 degrees and resize to 640x480, the resolution the
        #    mask margins are expressed in.
        frame = cv2.rotate(frame, cv2.ROTATE_180)
        frame_resized = cv2.resize(frame, (DISPLAY_WIDTH, DISPLAY_HEIGHT))

        c_top = cv2.getTrackbarPos('Top', WINDOW_NAME)
        c_bottom = cv2.getTrackbarPos('Bottom', WINDOW_NAME)
        c_left = cv2.getTrackbarPos('Left', WINDOW_NAME)
        c_right = cv2.getTrackbarPos('Right', WINDOW_NAME)

        # Black out the masked margins in place.
        h, w, _ = frame_resized.shape
        if c_top > 0: frame_resized[:c_top, :] = 0
        if c_bottom > 0: frame_resized[h-c_bottom:, :] = 0
        if c_left > 0: frame_resized[:, :c_left] = 0
        if c_right > 0: frame_resized[:, w-c_right:] = 0

        cv2.imshow(WINDOW_NAME, frame_resized)

        # 3. Preprocess for the model: downscale to 160x120, scale pixels
        #    to [0, 1], and add a leading batch axis.
        model_input = cv2.resize(frame_resized, (160, 120))
        image_norm = model_input.astype(np.float32) / 255.0
        image_tensor = np.expand_dims(image_norm, axis=0)

        # 4. Inference
        y_predict = model.predict(image_tensor, verbose=0)
        idx = np.argmax(y_predict, axis=1)[0]

        print(f"Pred: {names[idx]} ({y_predict[0][idx]:.2f})")

        # 5. Motor control (best-effort, but no longer silent, and no longer
        #    swallowing KeyboardInterrupt the way a bare ``except:`` did).
        try:
            client_mot.sendall(cmd_chars[idx])
        except OSError as e:
            print(f"Motor send failed: {e}")

        key = cv2.waitKey(1)
        if key == 27: break  # ESC

        cnt_frame += 1
        if time.time() - t_prev > 1.0:
            print(f"FPS: {cnt_frame}")
            t_prev = time.time()
            cnt_frame = 0

except KeyboardInterrupt:
    print("Stopping...")
finally:
    if client_cam is not None:
        client_cam.close()
    if client_mot is not None:
        # Ask the car to stop before disconnecting; the last command sent
        # could otherwise be e.g. forward.  NOTE(review): what the firmware
        # does on disconnect alone is not visible here.
        try:
            client_mot.sendall(STOP_CMD)
        except OSError:
            pass  # motor link already gone; nothing more we can do
        client_mot.close()
    cv2.destroyAllWindows()