# align
This script aligns the central body axis (the pelvis-mid to shoulder-mid line) of skeleton images, for use with the deterministic methods.

## Part 1

Before running, change `root` in the last cell of this part to your dataset directory.

In [None]:
import numpy as np
import cv2
import csv
import os
from tqdm import tqdm

def compute_midline_transform(P_src, S_src, P_ref, S_ref, eps=1e-6):
    """
    Build the similarity transform (scale + rotation + translation) that
    carries the source body midline onto the reference midline.

    The midline is the segment from pelvis_mid (P) to shoulder_mid (S).
    The transform sends P_src exactly to P_ref and S_src to S_ref up to
    the small eps term that is added to the scale denominator.

    Parameters:
        P_src, S_src: array-likes of 2 values, source pelvis_mid / shoulder_mid
        P_ref, S_ref: array-likes of 2 values, target pelvis_mid / shoulder_mid
        eps: length threshold below which the source midline is treated as
             degenerate; it is also added to the denominator of the scale

    Returns:
        M: 2x3 affine matrix [A | t] (the layout cv2.warpAffine expects)
        A: 2x2 scaled-rotation matrix, for mapping keypoints as x' = A x + t
        t: length-2 translation vector
        When the source midline is shorter than eps its direction is
        undefined, so the identity transform is returned instead.
    """
    pelvis_src = np.asarray(P_src, dtype=np.float32)
    shoulder_src = np.asarray(S_src, dtype=np.float32)
    pelvis_ref = np.asarray(P_ref, dtype=np.float32)
    shoulder_ref = np.asarray(S_ref, dtype=np.float32)

    axis_src = shoulder_src - pelvis_src
    axis_ref = shoulder_ref - pelvis_ref
    len_src = np.linalg.norm(axis_src)
    len_ref = np.linalg.norm(axis_ref)

    if len_src < eps:
        # Degenerate midline: no direction to rotate from, so do nothing.
        linear = np.eye(2, dtype=np.float32)
        shift = np.zeros(2, dtype=np.float32)
    else:
        # Ratio of midline lengths gives the isotropic scale.
        scale = len_ref / (len_src + eps)

        # Difference of the two midline headings gives the rotation.
        angle = (
            np.arctan2(axis_ref[1], axis_ref[0])
            - np.arctan2(axis_src[1], axis_src[0])
        )
        c, s = np.cos(angle), np.sin(angle)
        rotation = np.array([[c, -s], [s, c]], dtype=np.float32)
        linear = scale * rotation

        # Choose the translation so that the source pelvis lands on the
        # reference pelvis.
        shift = pelvis_ref - linear @ pelvis_src

    affine = np.hstack([linear, shift.reshape(2, 1)])  # shape (2, 3)
    return affine, linear, shift


In [None]:
def align_midline_single(
    img,
    keypoints_src,
    P_src,
    S_src,
    P_ref,
    S_ref,
    output_size=None,
    border_value=0,
):
    """
    Midline-align one skeleton image together with its keypoints.

    A similarity transform is derived from the pelvis_mid / shoulder_mid
    pairs via compute_midline_transform; the image is warped with
    cv2.warpAffine and the same transform is applied to every keypoint.

    Parameters:
        img: source image array; only its first two dimensions (H, W) are
             read here
        keypoints_src: array-like of shape (N, 2), source keypoints
        P_src, S_src: source pelvis_mid / shoulder_mid, 2 values each
        P_ref, S_ref: target pelvis_mid / shoulder_mid, 2 values each
        output_size: (W_out, H_out); None keeps the source image size
        border_value: fill value used by warpAffine outside the source image

    Returns:
        aligned_img: the warped image
        aligned_kps: the transformed keypoints, shape (N, 2)
    """
    src_h, src_w = img.shape[:2]
    out_w, out_h = (src_w, src_h) if output_size is None else output_size

    points = np.asarray(keypoints_src, dtype=np.float32)

    affine, linear, shift = compute_midline_transform(P_src, S_src, P_ref, S_ref)

    warped = cv2.warpAffine(
        img,
        affine,
        (out_w, out_h),
        flags=cv2.INTER_LINEAR,
        borderMode=cv2.BORDER_CONSTANT,
        borderValue=border_value,
    )

    # Row-vector form of x' = A x + t, applied to all N points at once.
    moved = points @ linear.T + shift

    return warped, moved


In [None]:
def align_midline_dataset(
    img_dir,
    csv_in,
    ref_csv_path,
    img_out_dir,
    csv_out,
    target_size=256,
    border_value=0,
    pelvis_target=(128.0, 110.0),
    shoulder_target=(128.0, 70.0),
):
    """
    Midline-align every skeleton image in a folder together with its
    keypoints, writing the aligned images and an aligned keypoints CSV.

    Assumptions:
        - img_dir holds images named like "000123.png".
        - csv_in holds the keypoints of those images; ref_csv_path holds the
          reference keypoints.
        - Both CSVs have a "filename" column; every other column of the
          reference CSV is read as consecutive (x, y) pairs and must include
          shoulder_mid_x/shoulder_mid_y and pelvis_mid_x/pelvis_mid_y.
          The source CSV must contain the same keypoint columns.
        - Source rows, reference rows and images are matched by file stem
          (name without extension). An image is skipped with a warning when
          either row is missing.

    Parameters:
        img_dir: folder with the source images
        csv_in: source keypoints CSV
        ref_csv_path: reference keypoints CSV
        img_out_dir: folder for the aligned images (created if needed)
        csv_out: path of the aligned keypoints CSV
        target_size: expected image side length; a mismatch only triggers a
            warning, the image is still aligned at its own size
        border_value: fill value for warped-in borders
        pelvis_target: (x, y) that pelvis_mid is moved to. The default keeps
            the value this cell previously hard-coded. The original inline
            notes list [128, 110] for "2", [128, 120] for "4" and
            [128, 135] for "6" (presumably the numeric dataset prefix --
            confirm). Pass None to use the matching reference-CSV row's
            pelvis_mid instead.
        shoulder_target: (x, y) that shoulder_mid is moved to. The original
            notes list [128, 70] for "2" and "4", [128, 65] for "6".
            Pass None to use the matching reference-CSV row's shoulder_mid.

    Raises:
        ValueError: a CSV lacks "filename", the reference keypoint columns
            are not in pairs or lack the midline columns, or the source CSV
            lacks a keypoint column.
    """
    os.makedirs(img_out_dir, exist_ok=True)
    csv_out_parent = os.path.dirname(csv_out)
    if csv_out_parent:
        # The CSV may live outside img_out_dir's ancestors.
        os.makedirs(csv_out_parent, exist_ok=True)

    # --------------------------
    # 1) Read the reference CSV
    # --------------------------
    with open(ref_csv_path, "r", newline="") as f_ref:
        ref_reader = csv.DictReader(f_ref)
        ref_header = ref_reader.fieldnames
        if ref_header is None or "filename" not in ref_header:
            raise ValueError("ref CSV 必須包含 'filename' 欄位")

        # Drop "filename" wherever it sits instead of assuming column 0.
        ref_cols = [col for col in ref_header if col != "filename"]
        if len(ref_cols) % 2 != 0:
            raise ValueError("除了 filename 之外，ref CSV 欄位數應為偶數 (x,y 成對)")

        # Keypoint column order plus the positions of the two midline points.
        kp_pairs = list(zip(ref_cols[0::2], ref_cols[1::2]))
        idx_shoulder = None
        idx_pelvis = None
        for idx, (x_col, _) in enumerate(kp_pairs):
            if x_col == "shoulder_mid_x":
                idx_shoulder = idx
            if x_col == "pelvis_mid_x":
                idx_pelvis = idx

        if idx_shoulder is None or idx_pelvis is None:
            raise ValueError("在 ref CSV 裡找不到 shoulder_mid_x 或 pelvis_mid_x 欄位")

        ref_rows_by_stem = {
            os.path.splitext(row["filename"])[0]: row for row in ref_reader
        }

    # --------------------------
    # 2) Read the source CSV
    # --------------------------
    with open(csv_in, "r", newline="") as f_src:
        src_reader = csv.DictReader(f_src)
        src_header = src_reader.fieldnames
        if src_header is None or "filename" not in src_header:
            raise ValueError("src CSV 必須包含 'filename' 欄位")

        src_rows_by_stem = {
            os.path.splitext(row["filename"])[0]: row for row in src_reader
        }

    # Fail before any output is written if the source CSV does not follow
    # the reference keypoint template.
    missing_cols = [
        col for pair in kp_pairs for col in pair if col not in src_header
    ]
    if missing_cols:
        raise ValueError(f"src CSV 缺少 keypoints 欄位: {missing_cols}")

    # --------------------------
    # 3) Collect the source images
    # --------------------------
    exts = (".png", ".jpg", ".jpeg", ".bmp", ".tiff", ".tif")
    # Sorted so the output CSV row order is reproducible across runs.
    img_files = sorted(
        name for name in os.listdir(img_dir) if name.lower().endswith(exts)
    )

    print(f"中軸對齊: 從 {img_dir} 讀取 {len(img_files)} 張圖片")

    pelvis_cols = kp_pairs[idx_pelvis]
    shoulder_cols = kp_pairs[idx_shoulder]

    # --------------------------
    # 4) Align every image and write the output CSV
    # --------------------------
    with open(csv_out, "w", newline="") as f_out:
        writer = csv.DictWriter(f_out, fieldnames=src_header)
        writer.writeheader()

        for img_name in tqdm(img_files, desc="align midline"):
            img = cv2.imread(os.path.join(img_dir, img_name))
            if img is None:
                print(f"警告：無法讀取圖片 {img_name}")
                continue

            h, w = img.shape[:2]
            if (h != target_size) or (w != target_size):
                print(f"警告：圖片 {img_name} 不是 {target_size}x{target_size}，仍然照原尺寸對齊")

            stem = os.path.splitext(img_name)[0]
            src_row = src_rows_by_stem.get(stem)
            ref_row = ref_rows_by_stem.get(stem)
            if src_row is None or ref_row is None:
                print(f"警告：找不到 stem={stem} 的 src/ref keypoints，略過 {img_name}")
                continue

            src_kps = np.array(
                [[float(src_row[x_col]), float(src_row[y_col])]
                 for x_col, y_col in kp_pairs],
                dtype=np.float32,
            )

            # Midline endpoints of the source skeleton.
            P_src = src_kps[idx_pelvis]
            S_src = src_kps[idx_shoulder]

            # Midline endpoints to align to: the fixed target when one is
            # given, otherwise this image's reference-CSV row.
            if pelvis_target is None:
                P_ref = [float(ref_row[pelvis_cols[0]]),
                         float(ref_row[pelvis_cols[1]])]
            else:
                P_ref = pelvis_target
            if shoulder_target is None:
                S_ref = [float(ref_row[shoulder_cols[0]]),
                         float(ref_row[shoulder_cols[1]])]
            else:
                S_ref = shoulder_target

            aligned_img, aligned_kps = align_midline_single(
                img,
                src_kps,
                P_src,
                S_src,
                P_ref,
                S_ref,
                output_size=(w, h),
                border_value=border_value,
            )

            out_img_path = os.path.join(img_out_dir, img_name)
            if not cv2.imwrite(out_img_path, aligned_img):
                # Keep the CSV in sync with the images that really exist.
                print(f"警告：無法寫入圖片 {out_img_path}")
                continue

            # Copy the source row (the filename column is kept as is) and
            # overwrite the keypoint columns with the aligned coordinates.
            new_row = dict(src_row)
            for (x_col, y_col), (x_new, y_new) in zip(kp_pairs, aligned_kps):
                new_row[x_col] = float(x_new)
                new_row[y_col] = float(y_new)

            writer.writerow(new_row)

    print(f"完成中軸對齊：圖片輸出到 {img_out_dir}, CSV 輸出到 {csv_out}")


In [None]:
root = "/Users/kosu/dataset_noSTN/2_stand_open_arm_156"

# Each variant lives in <root>/resized/<variant> with its keypoints in
# <root>/resized/resized_<variant>.csv. All variants are aligned against the
# same reference keypoints file and written to <root>/aligned/<variant> and
# <root>/aligned/aligned_<variant>.csv.
#
# NOTE: the CSV for "resize_lower_1.3" was previously written as
# "aligned_lower_1.3.csv", unlike every other variant. It now follows the
# common "aligned_<variant>.csv" pattern; update any downstream reader that
# still expects the old name.
variants = [
    "original",
    # "original_70",  # for verification: align the reference set to itself
    "swirl_180",
    "resize_left_0.7",
    "resize_left_1.5",
    "resize_right_0.5",
    "resize_right_1.3",
    "resize_upper_0.7",
    "resize_upper_1.5",
    "resize_lower_0.5",
    "resize_lower_1.3",
]

for variant in variants:
    align_midline_dataset(
        img_dir=f"{root}/resized/{variant}",
        csv_in=f"{root}/resized/resized_{variant}.csv",
        ref_csv_path=f"{root}/resized/resized_original_70.csv",
        img_out_dir=f"{root}/aligned/{variant}",
        csv_out=f"{root}/aligned/aligned_{variant}.csv",
        target_size=256,
    )


## Part 2
Align central body axis of skeletons after swirl transformation.

Before running, set `root` in the last cell of this part to your dataset directory.

In [None]:
def align_midline_dataset(
    img_dir,
    csv_in,
    ref_csv_path,
    img_out_dir,
    csv_out,
    target_size=256,
    border_value=0,
    pelvis_target=(128.0, 135.0),
    shoulder_target=(128.0, 70.0),
):
    """
    Midline-align every skeleton image in a folder together with its
    keypoints, writing the aligned images and an aligned keypoints CSV.

    This is the Part 2 (swirl) variant; it differs from the Part 1
    definition of the same name only in its default pelvis target.

    Assumptions:
        - img_dir holds images named like "000123.png".
        - csv_in holds the keypoints of those images; ref_csv_path holds the
          reference keypoints.
        - Both CSVs have a "filename" column; every other column of the
          reference CSV is read as consecutive (x, y) pairs and must include
          shoulder_mid_x/shoulder_mid_y and pelvis_mid_x/pelvis_mid_y.
          The source CSV must contain the same keypoint columns.
        - Source rows, reference rows and images are matched by file stem
          (name without extension). An image is skipped with a warning when
          either row is missing.

    Parameters:
        img_dir: folder with the source images
        csv_in: source keypoints CSV
        ref_csv_path: reference keypoints CSV
        img_out_dir: folder for the aligned images (created if needed)
        csv_out: path of the aligned keypoints CSV
        target_size: expected image side length; a mismatch only triggers a
            warning, the image is still aligned at its own size
        border_value: fill value for warped-in borders
        pelvis_target: (x, y) that pelvis_mid is moved to. The default keeps
            the value this cell previously hard-coded. The original inline
            notes list [128, 135] for "6", [128, 130] for "4" and
            [128, 110] for "2" (presumably the numeric dataset prefix --
            confirm). Pass None to use the matching reference-CSV row's
            pelvis_mid instead.
        shoulder_target: (x, y) that shoulder_mid is moved to. The original
            notes list [128, 70] for "6" and "2", [128, 80] for "4".
            Pass None to use the matching reference-CSV row's shoulder_mid.

    Raises:
        ValueError: a CSV lacks "filename", the reference keypoint columns
            are not in pairs or lack the midline columns, or the source CSV
            lacks a keypoint column.
    """
    os.makedirs(img_out_dir, exist_ok=True)
    csv_out_parent = os.path.dirname(csv_out)
    if csv_out_parent:
        # The CSV may live outside img_out_dir's ancestors.
        os.makedirs(csv_out_parent, exist_ok=True)

    # --------------------------
    # 1) Read the reference CSV
    # --------------------------
    with open(ref_csv_path, "r", newline="") as f_ref:
        ref_reader = csv.DictReader(f_ref)
        ref_header = ref_reader.fieldnames
        if ref_header is None or "filename" not in ref_header:
            raise ValueError("ref CSV 必須包含 'filename' 欄位")

        # Drop "filename" wherever it sits instead of assuming column 0.
        ref_cols = [col for col in ref_header if col != "filename"]
        if len(ref_cols) % 2 != 0:
            raise ValueError("除了 filename 之外，ref CSV 欄位數應為偶數 (x,y 成對)")

        # Keypoint column order plus the positions of the two midline points.
        kp_pairs = list(zip(ref_cols[0::2], ref_cols[1::2]))
        idx_shoulder = None
        idx_pelvis = None
        for idx, (x_col, _) in enumerate(kp_pairs):
            if x_col == "shoulder_mid_x":
                idx_shoulder = idx
            if x_col == "pelvis_mid_x":
                idx_pelvis = idx

        if idx_shoulder is None or idx_pelvis is None:
            raise ValueError("在 ref CSV 裡找不到 shoulder_mid_x 或 pelvis_mid_x 欄位")

        ref_rows_by_stem = {
            os.path.splitext(row["filename"])[0]: row for row in ref_reader
        }

    # --------------------------
    # 2) Read the source CSV
    # --------------------------
    with open(csv_in, "r", newline="") as f_src:
        src_reader = csv.DictReader(f_src)
        src_header = src_reader.fieldnames
        if src_header is None or "filename" not in src_header:
            raise ValueError("src CSV 必須包含 'filename' 欄位")

        src_rows_by_stem = {
            os.path.splitext(row["filename"])[0]: row for row in src_reader
        }

    # Fail before any output is written if the source CSV does not follow
    # the reference keypoint template.
    missing_cols = [
        col for pair in kp_pairs for col in pair if col not in src_header
    ]
    if missing_cols:
        raise ValueError(f"src CSV 缺少 keypoints 欄位: {missing_cols}")

    # --------------------------
    # 3) Collect the source images
    # --------------------------
    exts = (".png", ".jpg", ".jpeg", ".bmp", ".tiff", ".tif")
    # Sorted so the output CSV row order is reproducible across runs.
    img_files = sorted(
        name for name in os.listdir(img_dir) if name.lower().endswith(exts)
    )

    print(f"中軸對齊: 從 {img_dir} 讀取 {len(img_files)} 張圖片")

    pelvis_cols = kp_pairs[idx_pelvis]
    shoulder_cols = kp_pairs[idx_shoulder]

    # --------------------------
    # 4) Align every image and write the output CSV
    # --------------------------
    with open(csv_out, "w", newline="") as f_out:
        writer = csv.DictWriter(f_out, fieldnames=src_header)
        writer.writeheader()

        for img_name in tqdm(img_files, desc="align midline"):
            img = cv2.imread(os.path.join(img_dir, img_name))
            if img is None:
                print(f"警告：無法讀取圖片 {img_name}")
                continue

            h, w = img.shape[:2]
            if (h != target_size) or (w != target_size):
                print(f"警告：圖片 {img_name} 不是 {target_size}x{target_size}，仍然照原尺寸對齊")

            stem = os.path.splitext(img_name)[0]
            src_row = src_rows_by_stem.get(stem)
            ref_row = ref_rows_by_stem.get(stem)
            if src_row is None or ref_row is None:
                print(f"警告：找不到 stem={stem} 的 src/ref keypoints，略過 {img_name}")
                continue

            src_kps = np.array(
                [[float(src_row[x_col]), float(src_row[y_col])]
                 for x_col, y_col in kp_pairs],
                dtype=np.float32,
            )

            # Midline endpoints of the source skeleton.
            P_src = src_kps[idx_pelvis]
            S_src = src_kps[idx_shoulder]

            # Midline endpoints to align to: the fixed target when one is
            # given, otherwise this image's reference-CSV row.
            if pelvis_target is None:
                P_ref = [float(ref_row[pelvis_cols[0]]),
                         float(ref_row[pelvis_cols[1]])]
            else:
                P_ref = pelvis_target
            if shoulder_target is None:
                S_ref = [float(ref_row[shoulder_cols[0]]),
                         float(ref_row[shoulder_cols[1]])]
            else:
                S_ref = shoulder_target

            aligned_img, aligned_kps = align_midline_single(
                img,
                src_kps,
                P_src,
                S_src,
                P_ref,
                S_ref,
                output_size=(w, h),
                border_value=border_value,
            )

            out_img_path = os.path.join(img_out_dir, img_name)
            if not cv2.imwrite(out_img_path, aligned_img):
                # Keep the CSV in sync with the images that really exist.
                print(f"警告：無法寫入圖片 {out_img_path}")
                continue

            # Copy the source row (the filename column is kept as is) and
            # overwrite the keypoint columns with the aligned coordinates.
            new_row = dict(src_row)
            for (x_col, y_col), (x_new, y_new) in zip(kp_pairs, aligned_kps):
                new_row[x_col] = float(x_new)
                new_row[y_col] = float(y_new)

            writer.writerow(new_row)

    print(f"完成中軸對齊：圖片輸出到 {img_out_dir}, CSV 輸出到 {csv_out}")


In [None]:
root = ""  # set this to the dataset directory before running

# With an empty root every path below would start at the filesystem root
# ("/resized/...", "/aligned/..."), so stop early with a clear message.
if not root:
    raise ValueError("root is empty: set it to the dataset directory before running this cell")

# Align each swirl variant against the same reference keypoints file.
for swirl_angle in (15, 30, 45, 60):
    variant = f"swirl_{swirl_angle}"
    align_midline_dataset(
        img_dir=f"{root}/resized/{variant}",
        csv_in=f"{root}/resized/resized_{variant}.csv",
        ref_csv_path=f"{root}/resized/resized_original_70.csv",
        img_out_dir=f"{root}/aligned/{variant}",
        csv_out=f"{root}/aligned/aligned_{variant}.csv",
        target_size=256,
    )