In [None]:
import pandas as pd
import cv2
import matplotlib.pyplot as plt
import json
import numpy as np
import torch
from ultralytics import YOLO
import matplotlib.path as mplPath
from tqdm import tqdm

In [None]:
# Locations of the raw scene video and of the annotated video this notebook works with.
# Both live in the same analysis directory, so the directory is spelled out once.
analysis_dir = '/home/lnt/PycharmProjects/analyze_gaze'
world_raw_path = analysis_dir + '/world_raw.mp4'
world_output_path = analysis_dir + '/world_output.mp4'

In [None]:
# Load the exported gaze samples; later cells read the columns
# 'world_index', 'norm_pos_x', 'norm_pos_y' and 'gaze_timestamp' from this table.
gaze_csv_path = '/home/lnt/PycharmProjects/analyze_gaze/gaze_positions.csv'
gaze_positions_df = pd.read_csv(gaze_csv_path)
gaze_positions_df

In [None]:
# Sanity check: convert the third gaze sample from normalised coordinates to pixels.
# The y axis is flipped (1 - y) before scaling.
# NOTE(review): 1088 x 1080 is presumably the scene-camera frame size — confirm.
preview_norm_x = gaze_positions_df['norm_pos_x'].iloc[2]
preview_norm_y = gaze_positions_df['norm_pos_y'].iloc[2]
preview_norm_x * 1088, (1 - preview_norm_y) * 1080

In [None]:
# Open both videos for frame-by-frame reading.
world_raw_cap = cv2.VideoCapture(world_raw_path)
world_output_cap = cv2.VideoCapture(world_output_path)

# cv2.VideoCapture does not raise when a file cannot be opened; without this check
# a wrong path only shows up later as failed frame reads.
if not world_raw_cap.isOpened():
    raise IOError(f"Could not open video: {world_raw_path}")

# The output video is also the file written by the processing loop further down,
# so it may legitimately not exist yet: warn instead of failing.
if not world_output_cap.isOpened():
    print(f"Warning: could not open video: {world_output_path}")

In [None]:
# Position each capture on the frame to preview next.
# NOTE(review): the offset of 22 frames for the output video is presumably the
# alignment between the two videos — confirm.
raw_preview_frame = 0
output_preview_frame = 22
world_raw_cap.set(cv2.CAP_PROP_POS_FRAMES, raw_preview_frame)
world_output_cap.set(cv2.CAP_PROP_POS_FRAMES, output_preview_frame)

In [None]:
# Grab one preview frame from each video.
# Reading advances the capture by one frame. The processing loop further down reads
# world_raw_cap sequentially, so remember the position and restore it afterwards;
# otherwise every frame in that loop would be shifted by one relative to the gaze data.
raw_preview_position = world_raw_cap.get(cv2.CAP_PROP_POS_FRAMES)

raw_frame_ok, world_raw_frame = world_raw_cap.read()
ret, world_output_frame = world_output_cap.read()

# Trailing ';' suppresses the True/False cell output of set().
world_raw_cap.set(cv2.CAP_PROP_POS_FRAMES, raw_preview_position);

In [None]:
# Read the scene-camera calibration from JSON and turn the two entries used by
# this notebook ("camera_matrix" and "dist_coefs") into NumPy arrays.
with open("scene_camera.json") as camera_file:
    data = json.load(camera_file)

pupil_camera_matrix, pupil_dist_coeffs = (
    np.array(data[key]) for key in ("camera_matrix", "dist_coefs")
)

In [None]:
# Load the YOLO model from the 'yolov8x-seg.pt' weights. The processing loop below
# calls model.track() on each scene frame and reads masks, boxes and class names
# from the result.
model = YOLO('yolov8x-seg.pt')

In [None]:
# For every gaze sample: fetch the scene frame it belongs to, run YOLO tracking on
# that frame (once per frame), record which segmented object the gaze point falls
# on, draw the gaze point, and write each frame to the output video once all of its
# gaze points have been drawn.
num_rows = gaze_positions_df.shape[0]

world_img = None       # frame currently being annotated (None if unavailable)
results = None         # YOLO result for world_img
gazed_data = []        # one record per (gaze sample, object hit)
mask_data = []

# Per-frame lookup tables, rebuilt whenever a new frame is processed.
mask_paths = []        # polygon path per detected mask (None if degenerate)
class_ids = []         # class id per detection
track_ids = None       # tracker id per detection, or None if the tracker gave none

# Frames passed to VideoWriter.write() must have the size declared here, so take the
# size from the source video instead of assuming 1280x720.
frame_width = int(world_raw_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
frame_height = int(world_raw_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
if frame_width <= 0 or frame_height <= 0:
    # Size unavailable (e.g. capture not opened): fall back to the previous default.
    frame_width, frame_height = 1280, 720

# NOTE(review): world_output_path is also opened for reading earlier in the notebook;
# creating the writer here overwrites that file — confirm this is intended.
fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # output codec
out0 = cv2.VideoWriter(world_output_path, fourcc, 30.0, (frame_width, frame_height))

previous_world_index = -1  # sentinel that never matches a real frame index

try:
    for row_n in tqdm(range(num_rows)):
        row = gaze_positions_df.iloc[row_n]
        current_world_index = row['world_index']

        # A sample without a frame index cannot be assigned to any frame.
        if pd.isna(current_world_index):
            continue

        # A new frame index: finish the previous frame and fetch the next one.
        if current_world_index != previous_world_index:
            # All gaze points of the previous frame are drawn by now, so write it.
            if world_img is not None:
                out0.write(world_img)

            previous_world_index = current_world_index

            # Seek only when the capture is not already on the requested frame, so
            # gaps in world_index or earlier reads cannot desynchronise the video
            # from the gaze data.
            frame_number = int(current_world_index)
            if int(world_raw_cap.get(cv2.CAP_PROP_POS_FRAMES)) != frame_number:
                world_raw_cap.set(cv2.CAP_PROP_POS_FRAMES, frame_number)

            ret, world_img = world_raw_cap.read()

            if not ret or world_img is None:
                print(f"Failed to read frame {current_world_index}.")
                # Re-create the capture; the seek above repositions it for the next frame.
                world_raw_cap.release()
                world_raw_cap = cv2.VideoCapture(world_raw_path)
                world_img = None
                results = None
                continue  # skip this sample; the rest of this frame is skipped below

            results = model.track(world_img, verbose=False, classes=[0, 1, 2, 3, 5, 7, 9, 11, 30], persist=True, conf=0.7)[0]
            # NOTE(review): the annotated image is not used in this cell or written to
            # the video; kept in case a later cell displays it.
            result = results.plot()

            # Build the per-frame lookups once instead of once per gaze sample.
            if results.masks is not None:
                mask_paths = [
                    mplPath.Path(xy) if len(xy) >= 3 else None  # < 3 points: no area
                    for xy in results.masks.xy
                ]
                class_ids = results.boxes.cls.cpu().tolist()
                track_ids = (
                    results.boxes.id.cpu().tolist()
                    if results.boxes.id is not None else None
                )
            else:
                mask_paths, class_ids, track_ids = [], [], None

        # The frame for this sample could not be read: nothing to test or draw on.
        if world_img is None:
            continue

        # Samples without a gaze position cannot be converted to pixels.
        if pd.isna(row['norm_pos_x']) or pd.isna(row['norm_pos_y']):
            continue

        # Normalised gaze position -> pixel coordinates (y axis flipped).
        x = int(row['norm_pos_x'] * world_img.shape[1])
        y = int((1 - row['norm_pos_y']) * world_img.shape[0])

        # Record every detected object whose mask polygon contains the gaze point.
        for mid, poly_path in enumerate(mask_paths):
            if poly_path is None or not poly_path.contains_point((x, y)):
                continue

            cls = class_ids[mid]

            if track_ids is not None:
                bid = track_ids[mid]
            else:
                bid = None
                print(f"No ID for object at index {mid} in frame {row['world_index']}")

            # cls comes back as a float; the names mapping is looked up by integer id.
            name = results.names[int(cls)]
            gazed_data.append({
                'index': row.name, 'timestamp': row['gaze_timestamp'], 'frame': row['world_index'],
                'name': name, 'x': x, 'y': y, 'cls': cls, 'bid': bid
            })

        # Draw the gaze point on the frame (large disc with a smaller disc on top).
        cv2.circle(world_img, (x, y), 30, (0, 255, 0), -1)
        cv2.circle(world_img, (x, y), 10, (255, 255, 0), -1)

    # Write the last frame, which no index change has flushed yet.
    if world_img is not None:
        out0.write(world_img)
finally:
    # Always finalise the video file, even if processing failed part-way.
    out0.release()

In [None]:
# Collect the gaze/object hits into a table. The columns are listed explicitly so
# that the table (and the CSV written from it) keeps its schema even when no gaze
# sample hit any object and gazed_data is empty.
gazed_columns = ['index', 'timestamp', 'frame', 'name', 'x', 'y', 'cls', 'bid']
gazed_df = pd.DataFrame(gazed_data, columns=gazed_columns)
gazed_df

In [None]:
# Save the gaze/object hits to 'gazed.csv' in the current working directory.
# to_csv writes the DataFrame index as the first column by default.
gazed_df.to_csv('gazed.csv')