In [2]:
from SPEN.pose import get_pos_encoder, get_pos_decoder
from SPEN.pose import get_ori_encoder, get_ori_decoder
import numpy as np
import torch
from torch import nn
import json
from pathlib import Path

In [5]:
# Load the SPEED training annotations: a dict keyed by image file name whose
# values are annotation dicts (see the displayed entry: "ori", "pos", "bbox").
label_path = Path("../datasets/speed/train.json")
# JSON text is UTF-8 by specification; pass the encoding explicitly instead of
# relying on the platform locale default (which differs e.g. on Windows).
with open(label_path, "r", encoding="utf-8") as f:
    labels = json.load(f)
labels["img000001.jpg"]

{'ori': [-0.419541, -0.484436, -0.214179, 0.73718],
 'pos': [-0.21081, -0.094466, 6.705986],
 'bbox': [539, 222, 1036, 700]}

In [40]:
# Round-trip check for a single label: encode, decode, compare.
def check_encoder_decoder(label, key, encoder, decoder, atol=1e-6, sign_invariant=None):
    """Return True if ``decoder`` recovers ``label[key]`` from ``encoder``'s output.

    Args:
        label: one annotation dict, indexed with ``key``.
        key: field to round-trip, e.g. ``"pos"`` or ``"ori"``.
        encoder: object whose ``encode(np.ndarray)`` returns a dict of per-sample values.
        decoder: object whose ``decode_batch(dict of tensors)`` returns a tensor with a
            leading batch dimension.
        atol: absolute tolerance passed to ``np.allclose``.
        sign_invariant: if True, ``v`` and ``-v`` are treated as equal. This is needed
            for quaternions, where ``q`` and ``-q`` describe the same rotation. It
            defaults to ``key == "ori"``, so a decoder that negates a position is
            still reported as a failure.

    Returns:
        bool: whether the decoded value matches ``label[key]`` within ``atol``.
    """
    if sign_invariant is None:
        sign_invariant = key == "ori"

    target = np.array(label[key])
    encoded = encoder.encode(target)
    # Add a batch dimension of size 1 so the batch decoder can consume a single sample.
    encoded_batch = {k: torch.as_tensor(v).unsqueeze(0) for k, v in encoded.items()}
    decoded = decoder.decode_batch(encoded_batch).squeeze(0).detach().cpu().numpy()

    if decoded.shape != target.shape:
        return False
    # Pick the hemisphere with the full dot product: the first component alone
    # can be ~0, which makes its sign meaningless.
    if sign_invariant and np.dot(target, decoded) < 0:
        decoded = -decoded
    return bool(np.allclose(target, decoded, atol=atol))

In [41]:
# Run the round-trip check over every labelled image.
def check_loop(labels, key, encoder, decoder):
    """Round-trip ``label[key]`` for every entry of ``labels`` and collect failures.

    Prints ``Error in <image_name>`` as soon as a sample fails, so progress is
    visible during long sweeps.

    Args:
        labels: mapping from image name to annotation dict.
        key: field to round-trip, e.g. ``"pos"`` or ``"ori"``.
        encoder: encoder passed through to ``check_encoder_decoder``.
        decoder: decoder passed through to ``check_encoder_decoder``.

    Returns:
        list[str]: names of failing images in iteration order; empty when all pass,
        so callers can ``assert not check_loop(...)``.
    """
    failures = []
    for image_name, label in labels.items():
        if not check_encoder_decoder(label, key, encoder, decoder):
            print(f"Error in {image_name}")
            failures.append(image_name)
    return failures

# 位置编码

## 笛卡尔坐标系位置编码

In [33]:
# Cartesian position codec: spot-check one sample, then sweep the whole label set.
cart_encoder = get_pos_encoder("Cart")
cart_decoder = get_pos_decoder("Cart")
sample_ok = check_encoder_decoder(labels["img000001.jpg"], "pos", cart_encoder, cart_decoder)
print(sample_ok)
check_loop(labels, "pos", cart_encoder, cart_decoder)

True


## 球坐标系位置编码

In [34]:
# Spherical position codec; the encoder and decoder share one parameter set.
spher_kwargs = {"r_max": 50}
spher_encoder = get_pos_encoder("Spher", **spher_kwargs)
spher_decoder = get_pos_decoder("Spher", **spher_kwargs)
print(check_encoder_decoder(labels["img000001.jpg"], "pos", spher_encoder, spher_decoder))
check_loop(labels, "pos", spher_encoder, spher_decoder)

True


## 离散球坐标系位置编码

In [36]:
# Discrete spherical position codec. Keep one kwargs dict so the encoder and
# decoder cannot drift apart in their binning parameters.
discrete_spher_kwargs = dict(angle_stride=1, r_stride=1, r_max=50, alpha=0.1, neighbor=5)
discrete_spher_encoder = get_pos_encoder("DiscreteSpher", **discrete_spher_kwargs)
discrete_spher_decoder = get_pos_decoder("DiscreteSpher", **discrete_spher_kwargs)
spot_label = labels["img000150.jpg"]
print(check_encoder_decoder(spot_label, "pos", discrete_spher_encoder, discrete_spher_decoder))
check_loop(labels, "pos", discrete_spher_encoder, discrete_spher_decoder)

True


# 角度编码

## 四元数角度编码

In [37]:
# Quaternion orientation codec: one spot check, then the full sweep.
quat_encoder, quat_decoder = get_ori_encoder("Quat"), get_ori_decoder("Quat")
quat_spot_label = labels["img000001.jpg"]
print(check_encoder_decoder(quat_spot_label, "ori", quat_encoder, quat_decoder))
check_loop(labels, "ori", quat_encoder, quat_decoder)

True


## 欧拉角角度编码

In [42]:
# Euler-angle orientation codec, compared against the quaternion labels.
euler_encoder, euler_decoder = get_ori_encoder("Euler"), get_ori_decoder("Euler")
euler_sample_ok = check_encoder_decoder(labels["img000001.jpg"], "ori", euler_encoder, euler_decoder)
print(euler_sample_ok)
check_loop(labels, "ori", euler_encoder, euler_decoder)

True


## 离散欧拉角角度编码

In [43]:
# Discrete Euler-angle orientation codec; shared kwargs keep encoder and decoder consistent.
discrete_euler_kwargs = dict(stride=1, alpha=0.1, neighbor=5)
discrete_euler_encoder = get_ori_encoder("DiscreteEuler", **discrete_euler_kwargs)
discrete_euler_decoder = get_ori_decoder("DiscreteEuler", **discrete_euler_kwargs)
print(check_encoder_decoder(labels["img000001.jpg"], "ori", discrete_euler_encoder, discrete_euler_decoder))
check_loop(labels, "ori", discrete_euler_encoder, discrete_euler_decoder)

True
