- DOING
    - 使用`./trial`目录下的txt文件模拟流程（可以按GNSS时间戳判断是否缺数据）

In [1]:
import sys
import requests
import numpy as np
import torch
import time
import copy
import importlib
from parse import parse

In [2]:
sys.path.append('./')
from spd_dnn import *
from data_struct import *
from ftr_generator import *
from coord_utils import *

# 算法流程

## 航向推算

### 积分

- 数据投影、IMU零漂计算参考`2023_test_data_inference.ipynb`脚本（赛事方说测试数据与比赛数据的手机姿态是一致的）
- 因为比赛前半段会给相当时长的带有GNSS的数据，所以后续可以实现实时计算投影矩阵与估算IMU零漂

In [3]:
# def infer_bearing_via_integration(sec_gyro_unit_list, crnt_brng, collect_hz=250):
#     gy_ys = [u.sen_y for u in sec_gyro_unit_list]
#     bearing_diff = np.sum(gy_ys) / len(gy_ys) * -1.0
#     return (crnt_brng + bearing_diff) % 360.0


# def infer_bearing_via_model(sec_acc_unit_list, sec_gyro_unit_list, crnt_spd, crnt_brng, 
#                             model=None, ftr_gen=None, collect_hz=250):
#     sec_acc_unit_list = calibrate_sample_freq(sec_acc_unit_list, collect_hz=collect_hz)
#     sec_gyro_unit_lis = calibrate_sample_freq(sec_gyro_unit_list, collect_hz=collect_hz)
#     acc_ftr_list = ftr_gen.calc_imu_sec_ftr(sec_acc_unit_list)
#     gy_ftr_list = ftr_gen.calc_imu_sec_ftr(sec_gyro_unit_lis)

#     acc = torch.tensor([[acc_ftr_list]], dtype=torch.float).permute(0, 2, 1)
#     gy = torch.tensor([[gy_ftr_list]], dtype=torch.float).permute(0, 2, 1)
#     init_spd = torch.tensor([[[crnt_spd]]], dtype=torch.float)
#     init_brng = torch.tensor([[[crnt_brng]]], dtype=torch.float)
#     pred = model(acc, gy, init_spd=init_spd, init_brng=init_brng)
#     brng_diff = pred.cpu().detach().squeeze().numpy().tolist()[1]
    
#     return (crnt_brng + brng_diff) % 360


# def infer_bearing(sec_acc_unit_list, sec_gyro_unit_list, crnt_spd, crnt_brng, 
#                   model=None, ftr_gen=None, collect_hz=250):
#     integration_brng = infer_bearing_via_integration(sec_gyro_unit_list, crnt_brng, collect_hz=collect_hz)
    
#     return integration_brng


# def infer_bearing(sec_gyro_unit_list, crnt_brng, collect_hz=250, drift=0):
#     drift = -0.012397165517337367
#     gy_ys = [(u.sen_y - drift) for u in sec_gyro_unit_list]
#     bearing_diff = np.sum(gy_ys) / len(gy_ys) * -1.0
#     return (crnt_brng + bearing_diff) % 360.0

In [4]:
def project_gyro(obj, X=None, Y=None, Z=None):
    prj_ax = np.dot(np.array([obj.sen_x, obj.sen_y, obj.sen_z]), X)
    prj_ay = np.dot(np.array([obj.sen_x, obj.sen_y, obj.sen_z]), Y)
    prj_az = np.dot(np.array([obj.sen_x, obj.sen_y, obj.sen_z]), Z)

    prj_obj = copy.deepcopy(obj)
    prj_obj.gyr_x = prj_ax
    prj_obj.gyr_y = prj_ay
    prj_obj.gyr_z = prj_az

    return prj_obj


def infer_bearing(sec_gyro_unit_list, crnt_brng, collect_hz=250, drift=0):
    X = [-9.99832515e-01, 3.88289583e-04, -1.82972810e-02] 
    Y = [-0.01773261, -0.2678686, 0.96329226] 
    Z = [-0.00462569, 0.96195208, 0.27317907]
#     drift = -0.00032083975943054683
    drift = -0.011834270481663509
    
    prj_gyro_unit_list = [project_gyro(u, X=X, Y=Y, Z=Z) for u in sec_gyro_unit_list]
    gy_zs = [(u.gyr_z - drift) for u in prj_gyro_unit_list]
    bearing_diff = np.sum(gy_zs) / len(gy_zs) * -1.0
    return (crnt_brng + bearing_diff) % 360.0

## 速度推算（LSTM）

In [5]:
ftr_hz = 25
collect_hz = 250

model = load_spd_dnn(trained_model='./spd_dnn_weight_ftrHz%d.pt' % ftr_hz,
                     ftr_hz=ftr_hz, collect_hz=collect_hz)
model.eval()

ftr_gen = ftrGenerator(ftr_hz=ftr_hz, collect_hz=collect_hz)

Load trained model: ./spd_dnn_weight_ftrHz25.pt


In [6]:
def calibrate_sample_freq(imu_unit_list, collect_hz=250):
    if len(imu_unit_list) < collect_hz:
        imu_unit_list.extend([copy.deepcopy(imu_unit_list[-1])] * (collect_hz - len(imu_unit_list)))
    elif len(imu_unit_list) > collect_hz:
        imu_unit_list = imu_unit_list[-collect_hz:]
    else:
        pass
    return imu_unit_list


def infer_speed(sec_acc_unit_list, sec_gyro_unit_list, crnt_spd, crnt_brng,
                model=None, ftr_gen=None, collect_hz=250):
    sec_acc_unit_list = calibrate_sample_freq(sec_acc_unit_list, collect_hz=collect_hz)
    sec_gyro_unit_lis = calibrate_sample_freq(sec_gyro_unit_list, collect_hz=collect_hz)
    acc_ftr_list = ftr_gen.calc_imu_sec_ftr(sec_acc_unit_list)
    gy_ftr_list = ftr_gen.calc_imu_sec_ftr(sec_gyro_unit_lis)

    acc = torch.tensor([[acc_ftr_list]], dtype=torch.float).permute(0, 2, 1)
    gy = torch.tensor([[gy_ftr_list]], dtype=torch.float).permute(0, 2, 1)
    init_spd = torch.tensor([[[crnt_spd]]], dtype=torch.float)
    init_brng = torch.tensor([[[crnt_brng]]], dtype=torch.float)
    pred = model(acc, gy, init_spd=init_spd, init_brng=init_brng)
    spd_diff = pred.cpu().detach().squeeze().numpy().tolist()[0]

    return max(0.0, min(15.0, crnt_spd + spd_diff))

## 速度推算（OBD）

- **Trick**
    - `*0.80`减小OBD车速变相约束误差发散
    - 不严谨，需更多验证

In [7]:
# def infer_speed_via_obd(obd_unit_list, crnt_spd=None):
#     obd_spds = [u.vehicle_speed for u in obd_unit_list]
#     spd = np.sum(obd_spds) / len(obd_spds) * 1.0
#     return max(-1, min(15, spd * 0.8))


def infer_speed_via_obd(obd_unit_list, crnt_spd=None):
    obd_spds = [u.vehicle_speed for u in obd_unit_list]
    spd = np.sum(obd_spds) / len(obd_spds) * 1.0
    return max(-1, min(15, spd * 0.8))

## 速度推算（积分）

In [8]:
def project_acce(obj, X=None, Y=None, Z=None):
    prj_ax = np.dot(np.array([obj.sen_x, obj.sen_y, obj.sen_z]), X)
    prj_ay = np.dot(np.array([obj.sen_x, obj.sen_y, obj.sen_z]), Y)
    prj_az = np.dot(np.array([obj.sen_x, obj.sen_y, obj.sen_z]), Z)

    prj_obj = copy.deepcopy(obj)
    prj_obj.acc_x = prj_ax
    prj_obj.acc_y = prj_ay
    prj_obj.acc_z = prj_az

    return prj_obj


def infer_speed_diff_via_integration(acce_unit_list, crnt_spd):
    X = [-9.99832515e-01, 3.88289583e-04, -1.82972810e-02] 
    Y = [-0.01773261, -0.2678686, 0.96329226] 
    Z = [-0.00462569, 0.96195208, 0.27317907]
#     drift = 0.06431939105769646
    drift = 0.054283966428103465
    
    prj_acce_unit_list = [project_acce(u, X=X, Y=Y, Z=Z) for u in acce_unit_list]
    prj_ys = [(u.acc_y - drift) for u in prj_acce_unit_list]
    spd_diff = np.sum(prj_ys) / len(prj_ys) * -1.0
    return spd_diff

## 速度推算（加权）

In [9]:
def infer_speed_via_multi(acce_unit_list, obd_unit_list, crnt_spd):
    acce_spd_diff = infer_speed_diff_via_integration(acce_unit_list, crnt_spd)
    obd_spd = infer_speed_via_obd(obd_unit_list)
    
    if obd_spd < 1e-4 and acce_spd_diff < 0:
        spd = acce_spd_diff
    else:
        spd = obd_spd
    return spd

# 服务端配置及请求

## ATTETION

- 比赛数据无法reload，且需要我们严格按照间隔**一秒**请求
- 可以在本地**yaml**配置文件中将**reloadable**字段修改为**False**，模拟比赛环境

## 配置

- 正式比赛配置

In [25]:
# server = 'http://121.196.218.26/evaalapi/'
# trialname = 'S623b14b068102'

# estfmt = "{pts:.3f},{c:.3f},{h:.3f},{s:.3f},{pos}"
# statefmt = "{trialts:.3f},{rem:.3f},{V:.3f},{S:.3f},{p:.3f},{h:.3f},{pts:.3f},{pos}"

- 使用官方服务器测试

In [11]:
# server = 'http://121.196.218.26/evaalapi/'
# trialname = 'T623412f0da401'

# estfmt = "{pts:.3f},{c:.3f},{h:.3f},{s:.3f},{pos}"
# statefmt = "{trialts:.3f},{rem:.3f},{V:.3f},{S:.3f},{p:.3f},{h:.3f},{pts:.3f},{pos}"

- 使用本地服务器测试

In [21]:
server = 'http://127.0.0.1:5000/evaalapi/'
trialname = 'demo2023'

estfmt = "{pts:.3f},{c:.3f},{h:.3f},{s:.3f},{pos}"
statefmt = "{trialts:.3f},{rem:.3f},{V:.3f},{S:.3f},{p:.3f},{h:.3f},{pts:.3f},{pos}"

In [12]:
def do_state():
    r = requests.get(server + trialname + '/state')
    print(r.text)
    s = parse(statefmt, r.text)
    return s.named
    
    
def do_reload():
    r = requests.get(server + trialname + '/reload')
    s = parse(statefmt, r.text)
    print(s.named)

In [13]:
def parse_response(text):
    sen_str_list = text.splitlines()
    acc_unit_list, gy_unit_list = [], []
    obd_unit_list = []
    gnss_unit = None
    for sen_str in sen_str_list:
        if sen_str.startswith('ACCE'):
            acc_unit_list.append(acceUnit(sen_str))
        elif sen_str.startswith('GYRO'):
            gy_unit_list.append(gyroUnit(sen_str))
        elif sen_str.startswith('GNSS'):
            gnss_unit = gnssUnit(sen_str)
        elif sen_str.startswith('OBD'):
            obd_unit_list.append(obdUnit(sen_str))
        else:
            continue
    return acc_unit_list, gy_unit_list, obd_unit_list, gnss_unit


def estimate_next_loc(acc_unit_list, gy_unit_list, obd_unit_list, gnss_unit, 
                      model=None, ftr_gen=None, ftr_hz=50, collect_hz=250):
    crnt_brng = gnss_unit.bearing
    crnt_spd = gnss_unit.loc_speed
    if crnt_spd < 0:
        crnt_spd = gnss_unit.speed
    
    next_brng = infer_bearing(gy_unit_list, crnt_brng, 
                              collect_hz=collect_hz)
    
    if len(obd_unit_list) > 10:
#         next_spd = infer_speed_via_obd(obd_unit_list)
        next_spd = infer_speed_via_multi(acc_unit_list, obd_unit_list, crnt_spd)
    else:
        next_spd = infer_speed(acc_unit_list, gy_unit_list, crnt_spd, crnt_brng, 
                               model=model, ftr_gen=ftr_gen, collect_hz=collect_hz)

    next_lng, next_lat = geo_util.position_from_angle_and_dist(gnss_unit.lng, gnss_unit.lat, 
                                                               next_brng, next_spd)
    return next_lng, next_lat, next_spd, next_brng


def estimate_next_loc_by_pred(acc_unit_list, gy_unit_list, obd_unit_list, pred_loc, 
                              model=None, ftr_gen=None, ftr_hz=25, collect_hz=250):
    crnt_lng, crnt_lat, crnt_spd, crnt_brng = pred_loc
    
    next_brng = infer_bearing(gy_unit_list, crnt_brng, 
                              collect_hz=collect_hz)
    
    if len(obd_unit_list) > 10:
#         next_spd = infer_speed_via_obd(obd_unit_list)
        next_spd = infer_speed_via_multi(acc_unit_list, obd_unit_list, crnt_spd)
    else:
        next_spd = infer_speed(acc_unit_list, gy_unit_list, crnt_spd, crnt_brng, 
                               model=model, ftr_gen=ftr_gen, collect_hz=collect_hz)
    
    next_lng, next_lat = geo_util.position_from_angle_and_dist(crnt_lng, crnt_lat, 
                                                               next_brng, next_spd)
    return next_lng, next_lat, next_spd, next_brng
    
    
def estimate_next_loc_by_dft(crnt_lng, crnt_lat, crnt_spd, crnt_brng):
    next_brng, next_spd = crnt_brng, crnt_spd
    next_lng, next_lat = geo_util.position_from_angle_and_dist(crnt_lng, crnt_lat, 
                                                               next_brng, next_spd)
    return next_lng, next_lat, next_spd, next_brng
    

In [14]:
# do_reload()

In [15]:
# s = do_state()
# print(s)

In [16]:
# r = requests.get(server + trialname + '/estimates')
# print(r.text)

In [17]:
# r = requests.get(server + trialname + '/nextdata?horizon=1.0&position=0,0,0')
# print(r.text)

# r = requests.get(server + trialname + '/log')
# print(r.text)

In [18]:
# 注意time.sleep

In [None]:
init_pos = None  # 车辆初始位置
altitude = 0.0  # 高程
pre_gnss = None  # 上一时刻GNSS
pred_loc = None
next_position = None
gt_gnss_unit_list = []

sec_cnt = 0
request_arr = []

# do_reload()  # 正式比赛需要注释
while True:
    s = do_state()
    if s['trialts'] == 0.0:
        print('testing start...')
        try:
            lat, lng, altitude = [float(x) for x in s['pos'].split(',')]
        except ValueError as e:
            lat, lng, altitude = [float(x) for x in s['pos'].split(';')]
        init_pos = [lng, lat]
    elif s['trialts'] < 0:
        print('testing timeout...')
        break
    else:
        pass
    
    if next_position is None:
        r = requests.get(server + trialname + '/nextdata?horizon=1.0')
    else:
        r = requests.get(server + trialname + \
                         '/nextdata?horizon=1.0&position=%s,%.1f' % (next_position, altitude))
    
    acc_unit_list, gy_unit_list, obd_unit_list, gnss_unit = parse_response(r.text)
    
#     request_arr.append(r.text.splitlines())
    if gnss_unit is None:
        print(len(acc_unit_list))
        print(len(gy_unit_list))
        
    if gnss_unit is not None:
        gt_gnss_unit_list.append(gnss_unit)
        
    if len(acc_unit_list) < 10 and gnss_unit is None:
        if pre_gnss is not None:
            lng, lat, spd, brng = pre_gnss.lng, pre_gnss.lat, pre_gnss.loc_speed, pre_gnss.bearing
            if spd < 0:
                spd = pre_gnss.speed
            next_lng, next_lat, \
            next_spd, next_brng = estimate_next_loc_by_dft(lng, lat, spd, brng)
            next_position = '%.6f,%.6f' % (next_lng, next_lat)
            pred_loc = [next_lng, next_lat, next_spd, next_brng]
        elif pred_loc is not None:
            next_lng, next_lat, \
            next_spd, next_brng = estimate_next_loc_by_dft(pred_loc[0], pred_loc[1], 
                                                           pred_loc[2], pred_loc[3])
            next_position = '%.6f,%.6f' % (next_lng, next_lat)
            pred_loc = [next_lng, next_lat, next_spd, next_brng]
        else:
            pass
        time.sleep(1.001)
        sec_cnt += 1
        continue
    
    if gnss_unit is not None:
        gnss_unit.pre_gnss = pre_gnss
        pre_gnss = gnss_unit
        next_position = '%.6f,%.6f' % (gnss_unit.lng, gnss_unit.lat)
    elif pre_gnss is not None:
        next_lng, next_lat, \
        next_spd, next_brng = estimate_next_loc(acc_unit_list, gy_unit_list, obd_unit_list, pre_gnss, 
                                                model=model, ftr_gen=ftr_gen, 
                                                ftr_hz=ftr_hz, collect_hz=collect_hz)
        next_position = '%.6f,%.6f' % (next_lng, next_lat)
        pred_loc = [next_lng, next_lat, next_spd, next_brng]
        pre_gnss = None
    elif pred_loc is not None:
        next_lng, next_lat, \
        next_spd, next_brng = estimate_next_loc_by_pred(acc_unit_list, gy_unit_list, obd_unit_list, pred_loc, 
                                                        model=model, ftr_gen=ftr_gen, 
                                                        ftr_hz=ftr_hz, collect_hz=collect_hz)
        next_position = '%.6f,%.6f' % (next_lng, next_lat)
        pred_loc = [next_lng, next_lat, next_spd, next_brng]
    else:
        pass
        
    time.sleep(1.001)  # 正式比赛需要
    sec_cnt += 1

In [27]:
est_fd = open('./EvaluationRun02.csv', 'w')

cnt = 0
init_gps_tow = gt_gnss_unit_list[0].gps_tow

r = requests.get(server + trialname + '/estimates')
for i, line in enumerate(r.text.splitlines()):
    if i in {0, 1}:
        continue
    s = parse(estfmt, line)
    ts = float(s.named['pts'])
    lng, lat, altitude = [float(x) for x in s.named['pos'].split(',')]
    
    if cnt == 0:
        ts = init_gps_tow
    else:
        ts = init_gps_tow + cnt
    est_fd.write('%.3f,%.6f,%.6f,%.1f\n' % (ts, lat, lng, altitude))
    cnt += 1
    
est_fd.close()

In [28]:
r = requests.get(server + trialname + '/estimates')
print(len(r.text.splitlines()))

5055


In [29]:
print(parse(estfmt, r.text.splitlines()[-1]).named)

{'pts': 5074.885, 'c': 1694612423.31, 'h': 1.0, 's': 15.0, 'pos': '116.390356,40.006868,0.0'}


In [30]:
print(gt_gnss_unit_list[0].gps_tow)
print(gt_gnss_unit_list[0])

280758.0
21.931,116.390449,40.006868,182.000,0.000


## 评估

In [31]:
ts_estimate_loc_map = {}

r = requests.get(server + trialname + '/estimates')
for i, line in enumerate(r.text.splitlines()):
    if i in {0, 1}:
        continue
    s = parse(estfmt, line)
    ts = int(float(s.named['pts']))
    loc_str = s.named['pos']
    lng, lat = float(loc_str.split(',')[0]), float(loc_str.split(',')[1])
    lng, lat = geo_util.wgs84_to_gcj02(lng, lat)
    ts_estimate_loc_map[ts] = [lng, lat]
    
ts_list = sorted(ts_estimate_loc_map.keys(), reverse=False)
with open('offline_estimate2.csv', 'w') as fd:
    fd.write('ts,lng,lat\n')
    for i, ts in enumerate(ts_list):
        lng, lat = ts_estimate_loc_map[ts]
        if i > 0:
            ts = ts - ts_list[0]
        fd.write('%d,%.6f,%.6f\n' % (ts, lng, lat))

In [32]:
ts_gnss_loc_map = {}

for i, gnss in enumerate(gt_gnss_unit_list):
    ts = int(round(gnss.app_ts, 0))
    lng, lat = gnss.lng_gcj02, gnss.lat_gcj02
    ts_gnss_loc_map[ts] = [lng, lat]
    
ts_list = sorted(ts_gnss_loc_map.keys(), reverse=False)
with open('offline_gnss2.csv', 'w') as fd:
    fd.write('ts,lng,lat\n')
    for i, ts in enumerate(ts_list):
        lng, lat = ts_gnss_loc_map[ts]
        if i > 0:
            ts = ts - ts_list[0]
        fd.write('%d,%.6f,%.6f\n' % (ts, lng, lat))

In [33]:
len(ts_gnss_loc_map)

2083

In [34]:
dist_err_list = []

for ts, (lng, lat) in ts_estimate_loc_map.items():
    if ts not in ts_gnss_loc_map and ts + 1 in ts_gnss_loc_map:
        dist_err = geo_util.distance(lng, lat, 
                                     ts_gnss_loc_map[ts + 1][0], ts_gnss_loc_map[ts + 1][1])
        dist_err_list.append(dist_err)

In [35]:
print(np.percentile(dist_err_list, 25))
print(np.percentile(dist_err_list, 50))
print(np.percentile(dist_err_list, 75))

5.402017
25.768239
42.24721025


In [36]:
dist_err_list

[0.0, 12.717748, 55.174156, 2.96344, 38.81873, 43.390037]

In [34]:
cnt = 0
init_gps_tow = gt_gnss_unit_list[0].gps_tow

tow_estimate_locs = []

r = requests.get(server + trialname + '/estimates')
for i, line in enumerate(r.text.splitlines()):
    if i in {0, 1}:
        continue
    s = parse(estfmt, line)
    ts = float(s.named['pts'])
    lng, lat, altitude = [float(x) for x in s.named['pos'].split(',')]
    
    if cnt == 0:
        ts = init_gps_tow
    else:
        ts = init_gps_tow + cnt
    cnt += 1
    tow_estimate_locs.append([lng, lat, ts])
    
gt_gps_tows = set([u.gps_tow for u in gt_gnss_unit_list])

In [35]:
posi_unit_list = []
posi_unit_map = {}

with open('./trials2023/POSI_testingData.txt', 'r') as fd:
    for line in fd:
        if line.startswith('POSI'):
            unit = posiUnit(line.strip())
            posi_unit_list.append(unit)
            posi_unit_map[unit.gps_tow] = unit
            
print(len(posi_unit_list))

5494


In [36]:
errs = []
for posi_unit, (lng, lat, ts) in zip(posi_unit_list, tow_estimate_locs):
    if ts in gt_gps_tows:
        continue
    err = geo_util.distance(posi_unit.lng, posi_unit.lat, lng, lat)
    errs.append(err)
    
print(len(errs))
print(np.percentile(errs, 25))
print(np.percentile(errs, 50))
print(np.percentile(errs, 75))

2526
27.6935805
54.8975455
73.80828825


In [37]:
errs = []
for posi_unit, (lng, lat, ts) in zip(posi_unit_list, tow_estimate_locs):
    err = geo_util.distance(posi_unit.lng, posi_unit.lat, lng, lat)
    errs.append(err)
    
print(len(errs))
print(np.percentile(errs, 25))
print(np.percentile(errs, 50))
print(np.percentile(errs, 75))

4847
2.9857665
8.735258
57.664604


In [37]:
print(len(errs))
print(np.percentile(errs, 25))
print(np.percentile(errs, 50))
print(np.percentile(errs, 75))

2526
27.6935805
54.8975455
73.80828825
