In [4]:
import os
import numpy as np
import csv
import sys

def read_local_pts(txt_path):
    """
    读取每个输出子文件夹下的 smplx_vs_openpose_joints.txt，返回 numpy 数组 (N, 3)
    并交换第1<->2，3<->4个点
    """
    local_pts = []
    with open(txt_path, 'r', encoding='utf-8') as f:
        reader = csv.reader(f)
        next(reader)  # 跳过表头
        for row in reader:
            x, y, z = float(row[3]), float(row[4]), float(row[5])
            local_pts.append([x, y, z])

    local_pts = np.array(local_pts, dtype=np.float32)

    # ✅ 交换第1-2和第3-4个点
    # if local_pts.shape[0] >= 4:
    #     local_pts[[0, 1]] = local_pts[[1, 0]]
    #     local_pts[[2, 3]] = local_pts[[3, 2]]

    return local_pts

def build_image_model_dict(input_dir, output_dir):
    """
    遍历输入图像，构建包含 image 路径、对应 model.obj 路径和 local_pts 的字典列表
    """
    image_model_list = []

    for filename in os.listdir(input_dir):
        if filename.lower().endswith((".jpg", ".jpeg", ".png", ".bmp")):
            image_path = os.path.join(input_dir, filename)

            # ✅ 不去掉扩展名
            sub_output_dir = os.path.join(output_dir, filename)  # 如 output/00027.png

            model_path = os.path.join(sub_output_dir, "model_transformed.obj")
            joint_txt_path = os.path.join(sub_output_dir, "smplx_vs_openpose_joints.txt")

            if not (os.path.isfile(model_path) and os.path.isfile(joint_txt_path)):
                print(f"⚠️ 缺失 model 或 joint txt 文件：{sub_output_dir}")
                continue

            local_pts = read_local_pts(joint_txt_path)

            image_model_list.append({
                "image": image_path,
                "model": model_path,
                "local_pts": local_pts
            })

    return image_model_list

def read_selected_joints2d(txt_path):
    """
    从 selected_joints2d.txt 中读取第 3, 6, 10, 12 行的 (x, y)，返回 numpy 数组 (4, 2)
    """
    selected_indices = [2, 5, 9, 12]
    selected_pts = []

    with open(txt_path, 'r', encoding='utf-8') as f:
        lines = f.readlines()
        for idx in selected_indices:
            if idx < len(lines):
                line = lines[idx].strip().replace(',', ' ')
                parts = line.split()
                if len(parts) >= 2:
                    x, y = float(parts[0]), float(parts[1])
                    selected_pts.append([x, y])
                else:
                    print(f"⚠️ 第 {idx+1} 行格式错误：{lines[idx]}")
                    selected_pts.append([0.0, 0.0])
            else:
                print(f"⚠️ 文件行数不足，缺少第 {idx+1} 行")
                selected_pts.append([0.0, 0.0])

    return np.array(selected_pts, dtype=np.float32)

def build_image_model_dict_screen(input_dir, output_dir):

    """
    遍历输入图像，构建包含 image 路径、对应 model.obj 路径、
    local_pts 和 selected_pts 的字典列表
    """
    image_model_list = []

    for filename in os.listdir(input_dir):
        if filename.lower().endswith((".jpg", ".jpeg", ".png", ".bmp")):
            image_path = os.path.join(input_dir, filename)

            # ✅ 不去掉扩展名
            sub_output_dir = os.path.join(output_dir, filename)  # 如 output/00027.png

            model_path = os.path.join(sub_output_dir, "model_transformed.obj")
            joint_txt_path = os.path.join(sub_output_dir, "smplx_vs_openpose_joints.txt")
            selected_joints_path = os.path.join(sub_output_dir, "selected_joints2d.txt")

            # ✅ 检查所有必要文件是否存在
            if not (os.path.isfile(model_path) and os.path.isfile(joint_txt_path) and os.path.isfile(selected_joints_path)):
                print(f"⚠️ 缺失必要文件：{sub_output_dir}")
                continue

            local_pts = read_local_pts(joint_txt_path)
            selected_pts = read_selected_joints2d(selected_joints_path)

            image_model_list.append({
                "image": image_path,
                "model": model_path,
                "local_pts": local_pts,
                "selected_pts": selected_pts
            })

    return image_model_list

def flip_selected_pts_y(selected_pts, image_height):
    flipped_pts = selected_pts.copy()
    flipped_pts[:, 1] = image_height - flipped_pts[:, 1]
    return flipped_pts



import torch
import numpy as np
from PIL import Image
from transformers import AutoProcessor, VitPoseForPoseEstimation
from imgutils.detect import detect_person

# 提取关键点坐标 (脖子、右肩、左肩、骨盆)
def KeyPoint(image_pose_result, image_width, image_height):
    if image_pose_result is None or len(image_pose_result) == 0:
        return None
    keypoints = image_pose_result[0]['keypoints']
    return [
        # [(keypoints[0][0].item() / image_width), 1-(keypoints[0][1].item() / image_height)],  
        # [(keypoints[3][0].item() / image_width), 1-(keypoints[3][1].item() / image_height)],  
        # [(keypoints[4][0].item() / image_width), 1-(keypoints[4][1].item() / image_height)],  
        [(keypoints[5][0].item() / image_width), 1-(keypoints[5][1].item() / image_height)],  
        [(keypoints[6][0].item() / image_width), 1-(keypoints[6][1].item() / image_height)],  
 
        [(keypoints[11][0].item() / image_width), 1-(keypoints[11][1].item() / image_height)], 
        [(keypoints[12][0].item() / image_width), 1-(keypoints[12][1].item() / image_height)],

        # [(keypoints[7][0].item() / image_width), 1-(keypoints[7][1].item() / image_height)],  
        # [(keypoints[8][0].item() / image_width), 1-(keypoints[8][1].item() / image_height)] 
    ]

# 检测人物并估计姿态
def detect_and_estimate_pose(image_path):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    image = Image.open(image_path)

    result = detect_person(image_path)
    person_boxes = result[0][0]
    if not person_boxes:
        print("No person detected in the image.")
        return None

    person_boxes = np.array(person_boxes, dtype=np.float32)
    if person_boxes.ndim == 1:
        person_boxes = person_boxes[np.newaxis, :]
    person_boxes[:, 2] -= person_boxes[:, 0]
    person_boxes[:, 3] -= person_boxes[:, 1]

    processor = AutoProcessor.from_pretrained("usyd-community/vitpose-base-simple")
    model = VitPoseForPoseEstimation.from_pretrained(
        "usyd-community/vitpose-base-simple", device_map=device
    )
    inputs = processor(image, boxes=[person_boxes], return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    pose_results = processor.post_process_pose_estimation(outputs, boxes=[person_boxes])
    return pose_results[0] if pose_results else None

def make_unity_view_matrix(pos, euler_deg, include_z_flip: bool = True) -> np.ndarray:
    pos = np.asarray(pos, dtype=np.float32)
    euler_deg = np.asarray(euler_deg, dtype=np.float32)
    pitch, yaw, roll = np.deg2rad(euler_deg)
    Rz = np.array([[ np.cos(roll), -np.sin(roll), 0], [ np.sin(roll),  np.cos(roll), 0], [ 0, 0, 1]], dtype=np.float32)
    Rx = np.array([[1, 0, 0], [0, np.cos(pitch), -np.sin(pitch)], [0, np.sin(pitch), np.cos(pitch)]], dtype=np.float32)
    Ry = np.array([[ np.cos(yaw), 0, np.sin(yaw)], [0, 1, 0], [-np.sin(yaw), 0, np.cos(yaw)]], dtype=np.float32)
    Rcw = Ry @ Rx @ Rz
    view_rot = Rcw.T
    view_trans = -view_rot @ pos
    view = np.eye(4, dtype=np.float32)
    view[:3, :3] = view_rot
    view[:3,  3] = view_trans
    if include_z_flip:
        flipZ = np.diag([1, 1, -1, 1]).astype(np.float32)
        view = flipZ @ view
    return view

def perspective_projection(fov_deg, aspect, near, far) -> np.ndarray:
    f = 1.0 / np.tan(np.deg2rad(fov_deg) * 0.5)
    M = np.zeros((4,4), dtype=np.float32)
    M[0,0] = f / aspect
    M[1,1] = f
    M[2,2] = -(far / (far - near))
    M[2,3] = near * far / (far - near)
    M[3,2] = -1.0
    return M

def project_point(point_3d, camera_params, image_width=1024, image_height=1024, fov_deg=60.0, fixed_roll=0.0):
    if len(camera_params) == 5:
        px, py, pz, pitch, yaw = camera_params
        roll = fixed_roll
    else:
        px, py, pz, pitch, yaw, roll = camera_params

    view = make_unity_view_matrix([px, py, pz], [pitch, yaw, roll], include_z_flip=True)
    aspect = image_width / image_height
    proj = perspective_projection(fov_deg, aspect, near=0.1, far=1000.0)
    # print(view)
    # print(proj)
    mvp = proj @ view
    pt_h = np.array([*point_3d, 1.0], dtype=np.float32)
    clip = mvp @ pt_h
    ndc = clip[:3] / (clip[3] if clip[3] != 0 else 1)
    sx = (ndc[0] + 1) * 0.5 #适配右手系先这样，因为
    sy = (ndc[1] + 1) * 0.5
    return np.array([sx, sy])

# 加入方向比例残差
def pairwise_vector_ratios(points):
    p = np.array(points)
    v01 = p[1] - p[0]
    v23 = p[3] - p[2]
    v02 = p[2] - p[0]
    v13 = p[3] - p[1]
    ratios = np.array([
        v01[0], v01[1],
        v23[0], v23[1],
        v02[0], v02[1],
        v13[0], v13[1],
    ], dtype=np.float32)
    return ratios

def compute_residuals(camera_params, local_points, screen_points, image_width=1024,image_height=1024,fixed_roll=0.0, ratio_weight=0):
    residuals = []
    projected_points = []
    for i, (local, screen) in enumerate(zip(local_points, screen_points)):
        proj = project_point(local, camera_params,image_width=image_width,image_height=image_height, fixed_roll=fixed_roll)
        projected_points.append(proj)
        # 对前三个点应用权重，其余保持原权重（1.0）
        if i < 3:
            weight = ratio_weight
        if i > 6:
            weight=0.5
        else:
            weight = 1.0
        # weight = ratio_weight if i < 3 else 1.0  #head weight
        residuals.extend((proj - screen) * weight)

    return np.array(residuals, dtype=np.float32)

def compute_jacobian(camera_params, local_points, screen_points,image_width=1024,image_height=1024, fixed_roll=0.0, ratio_weight=0):
    n_params = len(camera_params)
    eps = 1e-6 * (1 + np.abs(camera_params))
    base = compute_residuals(camera_params, local_points, screen_points,image_width=image_width,image_height=image_height, fixed_roll=fixed_roll, ratio_weight=ratio_weight)
    J = np.zeros((base.size, n_params), dtype=np.float32)
    for i in range(n_params):
        p2 = camera_params.copy()
        p2[i] += eps[i]
        r2 = compute_residuals(p2, local_points, screen_points,image_width=image_width,image_height=image_height, fixed_roll=fixed_roll, ratio_weight=ratio_weight)
        J[:, i] = (r2 - base) / eps[i]
    return J

def levenberg_marquardt_optimization(camera_params, local_points, screen_points,image_width=1024,image_height=1024, max_iter=200, tol=1e-6, fixed_roll=0.0, ratio_weight=0):
    lam = 1e-3
    cost = np.sum(compute_residuals(camera_params, local_points, screen_points,image_width=image_width,image_height=image_height, fixed_roll=fixed_roll, ratio_weight=ratio_weight)**2)
    for _ in range(max_iter):
        r = compute_residuals(camera_params, local_points, screen_points,image_width=image_width,image_height=image_height, fixed_roll=fixed_roll, ratio_weight=ratio_weight)
        J = compute_jacobian(camera_params, local_points, screen_points,image_width=image_width,image_height=image_height, fixed_roll=fixed_roll, ratio_weight=ratio_weight)
        A = J.T @ J
        g = J.T @ r
        diag = np.diag(A)
        L = lam * np.diag(diag + 1e-6)
        try:
            dp = np.linalg.solve(A + L, -g)
        except np.linalg.LinAlgError:
            break
        new_p = camera_params + dp
        new_cost = np.sum(compute_residuals(new_p, local_points, screen_points,image_width=image_width,image_height=image_height, fixed_roll=fixed_roll, ratio_weight=ratio_weight)**2)
        if new_cost < cost:
            camera_params = new_p
            cost = new_cost
            lam *= 0.1
            if np.linalg.norm(dp) < tol:
                break
        else:
            lam *= 10
    return camera_params

def left_to_right(local_pts):
    return local_pts * np.array([-1, 1, 1], dtype=np.float32)

def estimate_camera_pose_from_image(image_path, local_pts, fixed_roll=0.0, ratio_weight=0.0):
    """
    根据图像路径和3D关键点，估计并优化相机位姿（固定 roll），返回优化后的参数。
    
    参数:
        image_path (str): 图像文件路径
        local_pts (np.ndarray): 3D 本地坐标点数组，形状为 (N, 3)
        fixed_roll (float): 固定的相机 roll 角，默认 0.0
        ratio_weight (float): 比例项残差权重，默认 0.0
    
    返回:
        np.ndarray: 优化后的相机参数（长度为6的数组：[x, y, z, pitch, yaw]）
    """
    # 获取图像尺寸
    with Image.open(image_path) as img:
        image_width, image_height = img.size

    # 估计2D关键点
    pose_res = detect_and_estimate_pose(image_path)
    screen_pts = KeyPoint(pose_res, image_width, image_height)

    # 左右坐标转换
    local_pts = left_to_right(local_pts)

    # 初始化相机参数
    init_pos = np.array([0.0, 0.0, 2.0], dtype=np.float32)
    init_rot = np.array([0.0, 180.0], dtype=np.float32)  # pitch, yaw
    camera_params = np.concatenate([init_pos, init_rot])  # 不包括 roll

    # 优化
    opt_params = levenberg_marquardt_optimization(
        camera_params, local_pts, np.array(screen_pts),
        image_width=image_width,
        image_height=image_height,
        fixed_roll=fixed_roll,
        ratio_weight=ratio_weight
    )

    return opt_params

def maybe_reverse_result(render_result, result):
    """
    判断 render_result 的 keypoint_0 和 keypoint_1 顺序是否与 result 相反。
    
    返回:
    - True: 需要反转
    - False: 保持不变
        """
    r0, r1 = render_result['keypoint_0'], render_result['keypoint_1']
    o0, o1 = result['keypoint_0'], result['keypoint_1']


    # 判断是否方向相反
    opposite = (r0 > r1 and o0 < o1) or (r0 < r1 and o0 > o1)
    return opposite
def extract_given_points_depths(image_path, keypoints_norm):
    """
    输入图像路径和归一化二维点，输出这些点的平均深度值。
    
    参数:
    - image_path: 图像路径
    - keypoints_norm: 归一化二维点列表，如 [[x1, y1], [x2, y2], ...]，范围在 0~1
    
    返回:
    - 字典，键为 'keypoint_0' 等，值为平均深度
    """
    # 加载模型
    model = load_depth_model()

    # 获取图像尺寸
    with Image.open(image_path) as img:
        image_width, image_height = img.size
    raw_img = cv2.imread(image_path)

    # 深度推理
    depth = model.infer_image(raw_img)

    # 计算关键点深度
    mean_depths = {}
    for idx in [0, 1, 2, 3]:
        mean_depths[f'keypoint_{idx}'] = get_mean_depth(depth, keypoints_norm[idx], image_width, image_height, kernel=1,if_screen_pos=False)

    return mean_depths

SYS_PATH = r"C:\Users\31878\Desktop\pose_model\PoseCtrl-main\Depth-Anything-V2"
sys.path.insert(0, SYS_PATH)

from depth_anything_v2.dpt import DepthAnythingV2

def load_depth_model():
    """加载深度预测模型"""
    model = DepthAnythingV2(encoder='vitl', features=256, out_channels=[256, 512, 1024, 1024])
    model.load_state_dict(torch.load(
        r"C:\Users\31878\Desktop\pose_model\PoseCtrl-main\Depth-Anything-V2\checkpoints\depth_anything_v2_vitl.pth",
        map_location='cpu'
    ))
    model.eval()
    return model

def get_mean_depth(depth_image, keypoint, image_width, image_height, kernel=1, if_screen_pos=False):
    """
    计算 keypoint 周围区域的平均深度。

    参数:
    - depth_image: 2D numpy array，深度图
    - keypoint: 坐标点
    - image_width, image_height: 图像宽高
    - kernel: 邻域半径
    - if_screen_pos: bool，True 表示 keypoint 是像素坐标（不用再乘图像大小），
                     False 表示 keypoint 是归一化坐标（0~1），需要乘图像大小转换成像素坐标
    """
    if if_screen_pos:
        cx = int(round(keypoint[0] * image_width))
        cy = int(round((1 - keypoint[1]) * image_height))

    else:
        cx = keypoint[0]
        cy = keypoint[1]


    h, w = depth_image.shape

    vals = []
    for dy in range(-kernel, kernel + 1):
        for dx in range(-kernel, kernel + 1):
            x = cx + dx
            y = cy + dy
            if 0 <= x < w and 0 <= y < h:
                vals.append(float(depth_image[y, x]))
    if not vals:
        return None
    return sum(vals) / len(vals)


def extract_keypoints_depths(image_path):
    """
    输入图像路径，输出关键点的平均深度值（索引 5~8）。
    """
    # 加载模型
    model = load_depth_model()

    # 读取图像
    with Image.open(image_path) as img:
        image_width, image_height = img.size
    raw_img = cv2.imread(image_path)

    # 替换为你自己的关键点检测函数
    pose_res = detect_and_estimate_pose(image_path)
    screen_pts = KeyPoint(pose_res, image_width, image_height)
    if screen_pts is None:
        print("未检测到人体关键点。")
        return None

    # 深度推理
    depth = model.infer_image(raw_img)

    # 计算关键点深度
    mean_depths = {}
    for idx in [0, 1, 2, 3]:
        mean_depths[f'keypoint_{idx}'] = get_mean_depth(depth, screen_pts[idx], image_width, image_height, kernel=1,if_screen_pos=True)

    return mean_depths


import math

def mirror_point_about_yz_plane(point):
    x, y, z = point
    return (-x, y, z)

def mirror_vector_about_yz_plane(vec):
    return mirror_point_about_yz_plane(vec)

def euler_to_forward(pitch, yaw):
    pitch_rad = math.radians(pitch)
    yaw_rad = math.radians(yaw)
    x = math.sin(yaw_rad) * math.cos(pitch_rad)
    y = math.sin(pitch_rad)
    z = math.cos(yaw_rad) * math.cos(pitch_rad)
    return (x, y, z)

def forward_to_euler(forward):
    x, y, z = forward
    yaw = math.degrees(math.atan2(x, z))
    hyp = math.sqrt(x*x + z*z)
    pitch = math.degrees(math.atan2(y, hyp))
    return (pitch, yaw)

def mirror_camera(camera_pos, camera_euler):
    # 镜像位置
    new_pos = mirror_point_about_yz_plane(camera_pos)

    pitch, yaw = camera_euler
    roll=0
    forward_vec = euler_to_forward(pitch, yaw)

    # 镜像观察向量
    mirrored_forward = mirror_vector_about_yz_plane(forward_vec)

    # 转换回欧拉角
    new_pitch, new_yaw = forward_to_euler(mirrored_forward)

    # roll取负
    new_roll = -roll

    return new_pos + (new_pitch, new_yaw % 360, new_roll % 360)

# from unity_controller import UnityController
# from functools import lru_cache
# from concurrent.futures import ThreadPoolExecutor
# controller = UnityController()

# image = controller.render_obj(
#     obj_path=r"C:\Users\31878\Desktop\output\3.png_002\3.obj",
#     camera_position=(0, 0, 2),
#     camera_rotation=(0, 180),
#     resolution=(1024, 1024)
# )




xFormers not available
xFormers not available


In [5]:
from unity_controller import UnityController
from functools import lru_cache
from concurrent.futures import ThreadPoolExecutor
controller = UnityController()
input_dir = r"C:\Users\31878\Desktop\test\input"
output_dir = r"C:\Users\31878\Desktop\test\output"

result = build_image_model_dict(input_dir, output_dir)

# 查看结果
for item in result:
    local_pts = left_to_right(item["local_pts"])
    opt_params = estimate_camera_pose_from_image(item["image"], local_pts)
    print(local_pts)
    # print("图像:", item["image"])
    # print("模型:", item["model"])
    # print("local_pts:\n", item["local_pts"])
    # print("opt_params:\n", opt_params)

    with Image.open(item["image"]) as img:
        image_width, image_height = img.size
        
    opt_params = [float(x) for x in opt_params]

    image = controller.render_obj(
    obj_path=item["model"],
    camera_position=(opt_params[0], opt_params[1], opt_params[2]),
    camera_rotation=(opt_params[3], opt_params[4]),
    resolution=(image_width, image_height)    )

    filename = os.path.basename(item["image"])              # 获取文件名 例如 00027.png
    save_path = os.path.join(r"C:\Users\31878\Desktop\test\unity_image", filename)
    image.save(save_path)

    
    # ori_image_point = [None] * len(local_pts)  # 初始化列表，长度和local_pts一样
    # # ori_pixel_points = [None] * len(item["local_pts"])
    # ori_pixel_points = []
    # for i, pt in enumerate(local_pts):
    #     ori_image_point = project_point(pt, opt_params, image_width=image_width, image_height=image_height, fixed_roll=0.0)
    #     pixel = (
    #         int(round(ori_image_point[0] * image_width)),
    #         int(round(ori_image_point[1] * image_height))
    #     )
    #     # print(pixel)
    #     ori_pixel_points.append(pixel)


    # ##渲染图
    # print(save_path)
    # print(ori_pixel_points)
    
    # render_result=extract_given_points_depths(save_path, ori_pixel_points)
    # ##原图
    # result = extract_keypoints_depths(item["image"])

    # print(type(render_result), render_result)
    # print(type(result), result)
    
    # if maybe_reverse_result(render_result, result):
    #     print(filename+" inverse")
    #     new_opt_params = mirror_camera((opt_params[0], opt_params[1], opt_params[2]), (opt_params[3], opt_params[4]))
    #     image = controller.render_obj(
    #     obj_path=item["model"],
    #     camera_position=(new_opt_params[0], new_opt_params[1], new_opt_params[2]),
    #     camera_rotation=(new_opt_params[3], new_opt_params[4]),
    #     resolution=(image_width, image_height)    )

    #     filename = os.path.basename(item["image"])              # 获取文件名 例如 00027.png
    #     save_path = os.path.join(r"C:\Users\31878\Desktop\test\unity_image", filename)
    #     image.save(save_path)


    print("=" * 50)


Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
  L = lam * np.diag(diag + 1e-6)


[[ 0.183033  0.019016  0.057625]
 [-0.151864  0.056067  0.071633]
 [ 0.055617 -0.490661 -0.009767]
 [-0.056351 -0.478296 -0.015449]]


ConnectionRefusedError: [WinError 10061] 由于目标计算机积极拒绝，无法连接。

In [12]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

# 可微版本 perspective projection
def perspective_projection_torch(fov_deg, aspect, near, far):
    f = 1.0 / torch.tan(torch.deg2rad(torch.tensor(fov_deg * 0.5)))
    M = torch.zeros(4, 4, dtype=torch.float32)
    M[0, 0] = f / aspect
    M[1, 1] = f
    M[2, 2] = -(far / (far - near))
    M[2, 3] = near * far / (far - near)
    M[3, 2] = -1.0
    return M

def make_unity_view_matrix_torch(pos, euler_deg, include_z_flip=True):
    pitch, yaw, roll = euler_deg  # pitch, yaw, roll 必须是 torch.Tensor

    cp, sp = torch.cos(torch.deg2rad(pitch)), torch.sin(torch.deg2rad(pitch))
    cy, sy = torch.cos(torch.deg2rad(yaw)), torch.sin(torch.deg2rad(yaw))
    cr, sr = torch.cos(torch.deg2rad(roll)), torch.sin(torch.deg2rad(roll))

    device = pitch.device  # 获取当前设备

    # ✅ 所有常数变成 tensor
    zero = torch.tensor(0.0, device=device)
    one = torch.tensor(1.0, device=device)

    Rz = torch.stack([
        torch.stack([ cr, -sr, zero]),
        torch.stack([ sr,  cr, zero]),
        torch.stack([zero, zero, one])
    ])

    Rx = torch.stack([
        torch.stack([one, zero, zero]),
        torch.stack([zero,  cp, -sp]),
        torch.stack([zero,  sp,  cp])
    ])

    Ry = torch.stack([
        torch.stack([ cy, zero,  sy]),
        torch.stack([zero,  one, zero]),
        torch.stack([-sy, zero,  cy])
    ])

    R = Ry @ Rx @ Rz
    view_rot = R.T
    view_trans = -view_rot @ pos

    view = torch.eye(4, device=device)
    view[:3, :3] = view_rot
    view[:3, 3] = view_trans

    if include_z_flip:
        flipZ = torch.diag(torch.tensor([1, 1, -1, 1], dtype=torch.float32, device=device))
        view = flipZ @ view

    return view


def project_point_torch(point_3d, camera_params, image_width, image_height, fov_deg=60.0):
    px, py, pz, pitch, yaw = camera_params
    roll = torch.tensor(0.0)
    pos = torch.stack([px, py, pz])
    angles = torch.stack([pitch, yaw, roll])

    view = make_unity_view_matrix_torch(pos, angles)
    aspect = image_width / image_height
    proj = perspective_projection_torch(fov_deg, aspect, near=0.1, far=1000.0)
    mvp = proj @ view

    pt_h = torch.cat([point_3d, torch.tensor([1.0])])
    clip = mvp @ pt_h
    ndc = clip[:3] / clip[3]
    sx = (ndc[0] + 1) * 0.5
    sy = (ndc[1] + 1) * 0.5
    return torch.stack([sx, sy])

# ==== 优化入口 ====

def optimize_camera_torch(
    local_pts, selected_pts, image_width, image_height,
    fov_deg=60.0, lr=1e-2, iterations=300
):
    device = torch.device("cpu")

    local_pts = torch.tensor(local_pts, dtype=torch.float32, device=device)
    selected_pts = torch.tensor(selected_pts, dtype=torch.float32, device=device)
    selected_pts = selected_pts / torch.tensor([image_width, image_height], dtype=torch.float32)

    # 初始相机参数 px, py, pz, pitch, yaw
    centroid = local_pts.mean(dim=0)
    init_params = torch.tensor([
        centroid[0],
        centroid[1],
        centroid[2] + 2.0,
        0,     # 初始 pitch，略有扰动
        180.0     # 初始 yaw
    ], dtype=torch.float32)
    camera_params = nn.Parameter(init_params.clone(), requires_grad=True)

    optimizer = optim.Adam([camera_params], lr=lr)

    for i in range(iterations):
        optimizer.zero_grad()

        # 拆解 camera 参数（保持计算图）
        px, py, pz, pitch, yaw = camera_params
        pos = torch.stack([px, py, pz])

        # roll 固定值，但放在同图中，避免断图
        roll = torch.zeros(1, device=camera_params.device)[0]  # 不用 tensor(0.0)，会断计算图！
        angles = torch.stack([pitch, yaw, roll])  # 不打断计算图

        # 得到 view 和 projection 矩阵
        view = make_unity_view_matrix_torch(pos, angles)
        aspect = image_width / image_height
        proj = perspective_projection_torch(fov_deg, aspect, near=0.1, far=1000.0)
        mvp = proj @ view

        # 批量投影
        ones = torch.ones(local_pts.shape[0], 1, device=local_pts.device)
        pts_homo = torch.cat([local_pts, ones], dim=1)
        clip = (mvp @ pts_homo.T).T
        ndc = clip[:, :3] / clip[:, 3:4]
        screen = (ndc[:, :2] + 1.0) * 0.5

        # Loss: 平均像素误差
        loss = ((screen - selected_pts) ** 2).mean()
        loss.backward()
        optimizer.step()

        if i % 10 == 0 or i == iterations - 1:
            print(f"[{i}] Loss: {loss.item():.6f}, Params: {camera_params.data.tolist()}")

    return camera_params.data.detach().cpu().numpy()



In [14]:
result = build_image_model_dict_screen(input_dir, output_dir)

for item in result:
    local_pts = item["local_pts"]
    with Image.open(item["image"]) as img:
        image_width, image_height = img.size
    selected_pts = flip_selected_pts_y(item["selected_pts"], image_height)

    # 使用 PyTorch 反向传播求相机参数
    cam_params = optimize_camera_torch(
        local_pts=local_pts,
        selected_pts=selected_pts,
        image_width=image_width,
        image_height=image_height,
        fov_deg=60.0,
        iterations=500
    )
    print("求解相机参数:", cam_params)

    # 投影并对比
    projected_pts = np.array([
        project_point(pt3d, cam_params, image_width, image_height)
        for pt3d in local_pts
    ])
    proj_pts_pixel = projected_pts * np.array([image_width, image_height])
    selected_pts_pixel = np.array(selected_pts)

    errors = np.linalg.norm(proj_pts_pixel - selected_pts_pixel, axis=1)
    print(f"投影误差 (像素): {errors}")
    print(f"平均误差: {errors.mean():.2f} px")

    print("\n投影点:")
    print(proj_pts_pixel)
    print("\n原始 selected_pts:")
    print(selected_pts_pixel)


[0] Loss: 0.015454, Params: [0.002391248010098934, -0.2134685218334198, 2.0160105228424072, -0.009999928064644337, 179.99000549316406]
[10] Loss: 0.007659, Params: [0.09723693132400513, -0.1166161596775055, 1.9143435955047607, -0.10679274797439575, 179.89547729492188]
[20] Loss: 0.006189, Params: [0.1547548472881317, -0.041153088212013245, 1.8214061260223389, -0.18320006132125854, 179.83848571777344]
[30] Loss: 0.006282, Params: [0.14743460714817047, -0.01416033785790205, 1.7959777116775513, -0.21623750030994415, 179.84378051757812]
[40] Loss: 0.006044, Params: [0.12037810683250427, -0.028240254148840904, 1.8091837167739868, -0.21354253590106964, 179.868408203125]
[50] Loss: 0.006010, Params: [0.11718661338090897, -0.047277260571718216, 1.789986252784729, -0.20739829540252686, 179.8715057373047]
[60] Loss: 0.005966, Params: [0.12553799152374268, -0.05332431569695473, 1.7477352619171143, -0.21373961865901947, 179.8642120361328]
[70] Loss: 0.005954, Params: [0.12282471358776093, -0.05214