In [None]:
import os
import json
from tqdm import tqdm
from dataclasses import dataclass
from ultralytics import YOLO 
import torch
import numpy as np
import warnings

# 忽略 tqdm 的未來警告
warnings.filterwarnings("ignore", "tqdm is not a Jupyter environment")

# 載入 VLM 邏輯檔案 (確保 LLM_PositionFilter.py 在相同目錄下)
try:
    from LLM_PositionFilter import select_target_position_filter, encode_full_frame_to_base64 
except ImportError:
    print("ERROR: LLM_PositionFilter.py not found. Please ensure the file is in the same directory.")
    raise

# -------------------- Path configuration (replace with your actual paths) --------------------
# Project root that contains the `refer-kitti/` dataset folder. Defaults to the current working
# directory; set the REFER_KITTI_HOME environment variable to point elsewhere without editing code.
HOME = os.environ.get("REFER_KITTI_HOME") or os.getcwd()
# The trailing "" keeps a trailing separator, like the original string paths had.
KITTI_PATH = os.path.join(HOME, "refer-kitti", "KITTI", "training", "image_02", "")
EXPRESSION_PATH = os.path.join(HOME, "refer-kitti", "expression", "")
EXP_PATH = "./tracking_position_filter_results"  # output directory for prediction files

os.makedirs(EXP_PATH, exist_ok=True)

# -------------------- Tracking model --------------------
model = YOLO('yolov8n.pt')
print("Tracking model loaded (YOLOv8n).")

# -------------------- Data structure (Crop_Img) --------------------
@dataclass
class Crop_Img:
    """One tracked detection in a single frame.

    `bbox` is a dict of corner coordinates under the keys 'x1', 'y1', 'x2', 'y2'.
    The x1/y1/x2/y2 properties read from it and return 0 for any missing key.
    """
    frame_id: int
    tracker_id: int
    bbox: dict
    width: float
    height: float
    crop_path: str = ""
    cls: int = -1

    def _corner(self, key):
        # Shared accessor for the coordinate properties; missing keys read as 0.
        return self.bbox.get(key, 0)

    @property
    def x1(self):
        return self._corner('x1')

    @property
    def y1(self):
        return self._corner('y1')

    @property
    def x2(self):
        return self._corner('x2')

    @property
    def y2(self):
        return self._corner('y2')


print("Crop_Img dataclass defined.")

Tracking model loaded (YOLOv8n).
Crop_Img dataclass defined.


In [2]:
def _reset_tracker_state(yolo_model):
    """Reset ByteTrack trackers left on `yolo_model` by an earlier `track(..., persist=True)` call.

    With persist=True, ultralytics reuses the predictor's tracker objects across calls. Without a
    reset, the track-ID counter and live tracks of the previous run would leak into this one.
    The getattr/hasattr guards cover the first call, before any predictor or tracker exists.
    """
    predictor = getattr(yolo_model, "predictor", None)
    for tracker in getattr(predictor, "trackers", None) or []:
        if hasattr(tracker, "reset"):
            tracker.reset()


def _load_sentence(seq_path):
    """Return the "sentence" field of an expression JSON file ("" if absent), or None on failure."""
    try:
        with open(seq_path, "r", encoding="utf-8") as f:
            prompt_data = json.load(f)
    except FileNotFoundError:
        print(f"[ERROR] Expression file not found: {seq_path}")
        return None
    except json.JSONDecodeError as e:
        print(f"[ERROR] Invalid expression file {seq_path}: {e}")
        return None
    return prompt_data.get("sentence", "")


def _frame_width(result, default_width=1242):
    """Frame width in pixels from `result.orig_shape` (h, w); falls back to the KITTI width 1242."""
    orig_shape = getattr(result, "orig_shape", None)
    if orig_shape is not None and len(orig_shape) == 2:
        return orig_shape[1]
    return default_width


def _describe_position(center_x, img_w):
    """Map a box's horizontal center to a coarse left / center / right phrase for the VLM prompt."""
    rel_x = center_x / img_w
    if rel_x < 0.35:
        return "on the left side of the road"
    if rel_x > 0.65:
        return "on the right side of the road"
    return "in the center or directly ahead"


def _collect_detections(result, frame_idx):
    """Convert one ultralytics tracking result into Crop_Img records; untracked boxes are skipped."""
    detections = []
    if result.boxes is None or result.boxes.id is None:
        return detections
    for box in result.boxes:
        if box.id is None:
            continue
        x1, y1, x2, y2 = map(float, box.xyxy[0])
        detections.append(Crop_Img(
            frame_id=frame_idx, tracker_id=int(box.id[0]),
            bbox={'x1': x1, 'y1': y1, 'x2': x2, 'y2': y2},
            width=x2 - x1, height=y2 - y1, cls=int(box.cls[0]),
        ))
    return detections


def inference_position_filter(EXP_PATH, EXPRESSION_PATH, seq, video_id):
    """Track one Refer-KITTI sequence and keep only the tracks the VLM matches to the expression.

    Args:
        EXP_PATH: Output root. Predictions are written to
            ``EXP_PATH/<seq[0]>/<seq[1] with a .txt extension>``.
        EXPRESSION_PATH: Root directory holding the expression JSON files.
        seq: ``(sequence_dir, expression_file_name)`` pair, as produced by ``get_sequences``.
        video_id: Index of this expression in the run. Currently unused; kept for the call signature.

    Each track ID is sent to the VLM once, on the first frame it appears in. IDs the VLM accepts
    are written on every frame where they are tracked, one line per box in the format
    ``frame,id,x1,y1,w,h,1,1,1`` with 1-based frame numbers. Returns None. Errors reading the
    expression or running the tracker are printed and abort this expression only.
    """
    seq_path = os.path.join(EXPRESSION_PATH, seq[0], seq[1])
    save_format = "{frame},{id},{x1},{y1},{w},{h},1,1,1\n"

    sentence = _load_sentence(seq_path)
    if sentence is None:
        return

    print(f"\nProcessing {seq[0]}/{seq[1]} with prompt: '{sentence}'")

    # Output layout: EXP_PATH/<sequence>/<expression>.txt (e.g. EXP_PATH/0005/exp_a.txt).
    exp_dir = os.path.join(EXP_PATH, seq[0])
    os.makedirs(exp_dir, exist_ok=True)
    predict_path = os.path.join(exp_dir, os.path.splitext(seq[1])[0] + '.txt')

    ori_img_dir = os.path.join(KITTI_PATH, seq[0])

    confirmed_ids = set()  # IDs the VLM matched to the sentence
    checked_ids = set()    # IDs already sent to the VLM, matched or not

    print(f"Running ByteTrack on {ori_img_dir}...")
    # This function is called once per expression on the same model; start each run clean.
    _reset_tracker_state(model)
    try:
        # NOTE(review): without stream=True every frame's Results (including the image) stays in
        # RAM until the loop below finishes; consider streaming for long sequences.
        results = model.track(
            source=ori_img_dir,
            tracker="bytetrack.yaml",
            persist=True,
            conf=0.25,
            iou=0.5,
            save=False,
            verbose=False
        )
    except Exception as e:
        print(f"[ERROR] ByteTrack failed for sequence {seq[0]}: {e}")
        return

    print("Tracking complete.")

    with open(predict_path, "w") as fout:
        for frame_idx, result in enumerate(tqdm(results), start=1):
            detections = _collect_detections(result, frame_idx)
            if not detections:
                continue

            # VLM check only for IDs not seen before.
            new_detections = [d for d in detections if d.tracker_id not in checked_ids]
            if new_detections:
                # Prefer the path ultralytics reports for this frame. The fallback assumes KITTI
                # frames are named 000000.png, 000001.png, ... with no gaps.
                full_image_path = getattr(result, "path", None) or os.path.join(
                    ori_img_dir, f"{frame_idx - 1:06d}.png"
                )
                img_w = _frame_width(result)
                try:
                    # Encode the full frame once and reuse it for every new ID in this frame.
                    base64_img_str = encode_full_frame_to_base64(full_image_path)
                except Exception as e:
                    print(f"[ERROR] Skipping VLM check on frame {frame_idx}: Failed to encode image: {e}")
                    # Mark as checked so these IDs are not re-queried (they will never be confirmed).
                    checked_ids.update(d.tracker_id for d in new_detections)
                else:
                    for det in new_detections:
                        center_x = (det.x1 + det.x2) / 2
                        is_match = select_target_position_filter(
                            prompt=sentence,
                            tracker_id=det.tracker_id,
                            x1=det.x1, y1=det.y1, x2=det.x2, y2=det.y2,
                            position_desc=_describe_position(center_x, img_w),
                            base64_image_string=base64_img_str,
                            quiet=True
                        )
                        if is_match:
                            confirmed_ids.add(det.tracker_id)
                        checked_ids.add(det.tracker_id)

            # Always reached, even if the VLM step failed, so earlier-confirmed IDs keep their boxes.
            for det in detections:
                if det.tracker_id in confirmed_ids:
                    fout.write(save_format.format(
                        frame=frame_idx, id=det.tracker_id,
                        x1=det.x1, y1=det.y1, w=det.width, h=det.height
                    ))

    print(f"Finished processing sequence {seq[0]}/{seq[1]}. Results saved to {predict_path}")

In [3]:
# Collect every expression file for the sequences that have ground truth (0005, 0011, 0013 by default).
def get_sequences(expression_path, target_sequences=('0005', '0011', '0013')):
    """List ``(sequence_dir, json_file_name)`` pairs for expressions of the target sequences.

    Args:
        expression_path: Root directory of the expression JSON files. It is walked recursively,
            and a directory is selected when its own name is one of `target_sequences`.
        target_sequences: Iterable of sequence directory names to include.

    Returns:
        A sorted list of ``(relative_dir, file_name)`` tuples, one per ``.json`` file. Sorting makes
        the processing order reproducible. The list is empty, with a warning printed, when
        `expression_path` is not an existing directory; os.walk would otherwise silently yield nothing.
    """
    if not os.path.isdir(expression_path):
        print(f"[WARNING] Expression directory not found: {expression_path}")
        return []

    targets = set(target_sequences)
    sequences = []
    for root, _, files in os.walk(expression_path):
        relative_dir = os.path.relpath(root, expression_path)
        if os.path.basename(relative_dir) not in targets:
            continue
        sequences.extend((relative_dir, f) for f in files if f.endswith('.json'))
    return sorted(sequences)

sequences_to_run = get_sequences(EXPRESSION_PATH)
expression_count = len(sequences_to_run)
print(f"Found {expression_count} expressions across target sequences to process.")

# Run the tracking + VLM filter once per expression file; the index is passed as video_id.
for video_idx, expression_seq in enumerate(sequences_to_run):
    inference_position_filter(EXP_PATH, EXPRESSION_PATH, expression_seq, video_idx)

print("\n--- All sequence processing complete. ---")

Found 0 expressions across target sequences to process.

--- All sequence processing complete. ---
