- DOING
    - 使用`./trial`目录下的txt文件模拟流程（可以按GNSS时间戳判断是否缺数据）
    - 使用提供的**trialname**测试
        - `T62288821d9d01`
        - `T62288821d9d02`

In [1]:
import requests
import numpy as np
import torch
import time
import copy
from parse import parse

from spd_dnn import *
from data_struct import *
from ftr_generator import *
from coord_utils import *


# 算法流程

## 航向推算

In [2]:
def infer_bearing(sec_gyro_unit_list, crnt_brng, collect_hz=250):
    gy_ys = [u.sen_y for u in sec_gyro_unit_list]
    bearing_diff = np.sum(gy_ys) / collect_hz * -1.0
    return (crnt_brng + bearing_diff) % 360.0

## 速度推算

In [3]:
ftr_hz = 25
collect_hz = 250

model = load_spd_dnn(trained_model='./spd_dnn_weight_ftrHz%d.pt' % ftr_hz,
                     ftr_hz=ftr_hz, collect_hz=collect_hz)
model.eval()

ftr_gen = ftrGenerator(ftr_hz=ftr_hz, collect_hz=collect_hz)

Load trained model: ./spd_dnn_weight_ftrHz25.pt


In [4]:
def calibrate_sample_freq(imu_unit_list, collect_hz=250):
    if len(imu_unit_list) < collect_hz:
        imu_unit_list.extend([copy.deepcopy(imu_unit_list[-1])] * (collect_hz - len(imu_unit_list)))
    elif len(imu_unit_list) > collect_hz:
        imu_unit_list = imu_unit_list[-collect_hz:]
    else:
        pass
    return imu_unit_list


def infer_speed(sec_acc_unit_list, sec_gyro_unit_list, crnt_spd,
                model=None, ftr_gen=None, collect_hz=250):
    sec_acc_unit_list = calibrate_sample_freq(sec_acc_unit_list, collect_hz=collect_hz)
    sec_gyro_unit_lis = calibrate_sample_freq(sec_gyro_unit_list, collect_hz=collect_hz)
    acc_ftr_list = ftr_gen.calc_imu_sec_ftr(sec_acc_unit_list)
    gy_ftr_list = ftr_gen.calc_imu_sec_ftr(sec_gyro_unit_lis)

    acc = torch.tensor([[acc_ftr_list]], dtype=torch.float).permute(0, 2, 1)
    gy = torch.tensor([[gy_ftr_list]], dtype=torch.float).permute(0, 2, 1)
    init_spd = torch.tensor([[[crnt_spd]]], dtype=torch.float)
    pred = model(acc, gy, init_spd=init_spd)
    spd_diff = pred.cpu().detach().squeeze().numpy().tolist()

    return max(0.0, min(33.0, crnt_spd + spd_diff))

# 服务端配置及请求

- 正式比赛配置

In [5]:
server = 'https://evaal.aaloa.org/evaalapi/'
trialname = 'S622team708_02'

estfmt = "{pts:.3f},{c:.3f},{h:.3f},{s:.3f},{pos}"
statefmt = "{trialts:.3f},{rem:.3f},{V:.3f},{S:.3f},{p:.3f},{h:.3f},{pts:.3f},{pos}"

- 使用官方服务器测试

In [6]:
# server = 'https://evaal.aaloa.org/evaalapi/'
# trialname = 'T62288821d9d01'

# estfmt = "{pts:.3f},{c:.3f},{h:.3f},{s:.3f},{pos}"
# statefmt = "{trialts:.3f},{rem:.3f},{V:.3f},{S:.3f},{p:.3f},{h:.3f},{pts:.3f},{pos}"

- 使用本地服务器测试

In [7]:
# server = 'http://127.0.0.1:5000/evaalapi/'
# trialname = 'demo2'

# estfmt = "{pts:.3f},{c:.3f},{h:.3f},{s:.3f},{pos}"
# statefmt = "{trialts:.3f},{rem:.3f},{V:.3f},{S:.3f},{p:.3f},{h:.3f},{pts:.3f},{pos}"

In [8]:
def do_state():
    r = requests.get(server + trialname + '/state')
    print(r.text)
    s = parse(statefmt, r.text)
    return s.named
    
    
def do_reload():
    r = requests.get(server + trialname + '/reload')
    s = parse(statefmt, r.text)
    print(s.named)

In [9]:
def parse_response(text):
    sen_str_list = text.splitlines()
    acc_unit_list, gy_unit_list = [], []
    gnss_unit = None
    for sen_str in sen_str_list:
        if sen_str.startswith('ACCE'):
            acc_unit_list.append(acceUnit(sen_str))
        elif sen_str.startswith('GYRO'):
            gy_unit_list.append(gyroUnit(sen_str))
        elif sen_str.startswith('GNSS'):
            gnss_unit = gnssUnit(sen_str)
        else:
            continue
    return acc_unit_list, gy_unit_list, gnss_unit


def estimate_next_loc(acc_unit_list, gy_unit_list, gnss_unit, 
                      model=None, ftr_gen=None, ftr_hz=50, collect_hz=250):
    crnt_brng = gnss_unit.bearing
    crnt_spd = gnss_unit.loc_speed
    if crnt_spd < 0:
        crnt_spd = gnss_unit.speed
    
    next_brng = infer_bearing(gy_unit_list, crnt_brng, collect_hz=collect_hz)
    next_spd = infer_speed(acc_unit_list, gy_unit_list, crnt_spd, 
                           model=model, ftr_gen=ftr_gen, collect_hz=collect_hz)
    
    next_lng, next_lat = geo_util.position_from_angle_and_dist(gnss_unit.lng, gnss_unit.lat, 
                                                               next_brng, next_spd)
    return next_lng, next_lat, next_spd, next_brng


def estimate_next_loc_by_pred(acc_unit_list, gy_unit_list, pred_loc, 
                              model=None, ftr_gen=None, ftr_hz=25, collect_hz=250):
    crnt_lng, crnt_lat, crnt_spd, crnt_brng = pred_loc
    next_brng = infer_bearing(gy_unit_list, crnt_brng, collect_hz=collect_hz)
    next_spd = infer_speed(acc_unit_list, gy_unit_list, crnt_spd, 
                           model=model, ftr_gen=ftr_gen, collect_hz=collect_hz)
    next_lng, next_lat = geo_util.position_from_angle_and_dist(crnt_lng, crnt_lat, 
                                                               next_brng, next_spd)
    return next_lng, next_lat, next_spd, next_brng
    
    
def estimate_next_loc_by_dft(crnt_lng, crnt_lat, crnt_spd, crnt_brng):
    next_brng, next_spd = crnt_brng, crnt_spd
    next_lng, next_lat = geo_util.position_from_angle_and_dist(crnt_lng, crnt_lat, 
                                                               next_brng, next_spd)
    return next_lng, next_lat, next_spd, next_brng
    

In [10]:
# do_reload()

In [11]:
# s = do_state()
# print(s)

In [10]:
# r = requests.get(server + trialname + '/estimates')
# print(r.text)

In [11]:
# r = requests.get(server + trialname + '/nextdata?horizon=1.0&position=0,0,0')
# print(r.text)

# r = requests.get(server + trialname + '/log')
# print(r.text)

In [15]:
init_pos = None  # 车辆初始位置
altitude = 0.0  # 高程
pre_gnss = None  # 上一时刻GNSS
pred_loc = None
next_position = None
gt_gnss_unit_list = []

sec_cnt = 0
request_arr = []

# do_reload()  # 正式比赛需要注释
while True:
    s = do_state()
    if s['trialts'] == 0.0:
        print('testing start...')
        try:
            lat, lng, altitude = [float(x) for x in s['pos'].split(',')]
        except ValueError as e:
            lat, lng, altitude = [float(x) for x in s['pos'].split(';')]
        init_pos = [lng, lat]
    elif s['trialts'] < 0:
        print('testing timeout...')
        break
    else:
        pass
    
    if next_position is None:
        r = requests.get(server + trialname + '/nextdata?horizon=1.0')
    else:
        r = requests.get(server + trialname + \
                         '/nextdata?horizon=1.0&position=%s,%.1f' % (next_position, altitude))
    
    acc_unit_list, gy_unit_list, gnss_unit = parse_response(r.text)
    
#     request_arr.append(r.text.splitlines())
    if gnss_unit is None:
        print(len(acc_unit_list))
        print(len(gy_unit_list))
        
    if gnss_unit is not None:
        gt_gnss_unit_list.append(gnss_unit)
        
    if len(acc_unit_list) < 10 and gnss_unit is None:
        if pre_gnss is not None:
            lng, lat, spd, brng = pre_gnss.lng, pre_gnss.lat, pre_gnss.loc_speed, pre_gnss.bearing
            if spd < 0:
                spd = pre_gnss.speed
            next_lng, next_lat, \
            next_spd, next_brng = estimate_next_loc_by_dft(lng, lat, spd, brng)
            next_position = '%.6f,%.6f' % (next_lat, next_lng)
            pred_loc = [next_lng, next_lat, next_spd, next_brng]
        elif pred_loc is not None:
            next_lng, next_lat, \
            next_spd, next_brng = estimate_next_loc_by_dft(pred_loc[0], pred_loc[1], 
                                                           pred_loc[2], pred_loc[3])
            next_position = '%.6f,%.6f' % (next_lat, next_lng)
            pred_loc = [next_lng, next_lat, next_spd, next_brng]
        else:
            pass
        time.sleep(1.001)
        sec_cnt += 1
        continue
    
    if gnss_unit is not None:
        gnss_unit.pre_gnss = pre_gnss
        pre_gnss = gnss_unit
        next_position = '%.6f,%.6f' % (gnss_unit.lat, gnss_unit.lng)
    elif pre_gnss is not None:
        next_lng, next_lat, \
        next_spd, next_brng = estimate_next_loc(acc_unit_list, gy_unit_list, pre_gnss, 
                                                model=model, ftr_gen=ftr_gen, 
                                                ftr_hz=ftr_hz, collect_hz=collect_hz)
        next_position = '%.6f,%.6f' % (next_lat, next_lng)
        pred_loc = [next_lng, next_lat, next_spd, next_brng]
        pre_gnss = None
    elif pred_loc is not None:
        next_lng, next_lat, \
        next_spd, next_brng = estimate_next_loc_by_pred(acc_unit_list, gy_unit_list, pred_loc, 
                                                        model=model, ftr_gen=ftr_gen, 
                                                        ftr_hz=ftr_hz, collect_hz=collect_hz)
        next_position = '%.6f,%.6f' % (next_lat, next_lng)
        pred_loc = [next_lng, next_lat, next_spd, next_brng]
    else:
        pass
        
    time.sleep(1.001)
    sec_cnt += 1
#     if sec_cnt == 2500:
#         break

-1.000,-338.390,9.000,45.000,1662123281.621,1.000,116.303,40.093923,116.283254,0.0
testing timeout...


In [18]:
est_fd = open('./ScoringResult02.csv', 'w')

r = requests.get(server + trialname + '/estimates')
for i, line in enumerate(r.text.splitlines()):
    if i == 0:
        continue
    s = parse(estfmt, line)
    ts = int(s.named['pts'])
    lat, lng, altitude = [float(x) for x in s.named['pos'].split(',')]
    est_fd.write('%d,%.6f,%.6f,%.1f\n' % (ts, lat, lng, altitude))

In [19]:
r = requests.get(server + trialname + '/estimates')
print(len(r.text.splitlines()))

4979


In [20]:
print(parse(estfmt, r.text.splitlines()[-1]).named)

{'pts': 5001.518, 'c': 1662121141.554, 'h': 1.0, 's': 45.0, 'pos': '40.094209,116.282839,0.0'}


## 评估

In [21]:
ts_estimate_loc_map = {}

r = requests.get(server + trialname + '/estimates')
for i, line in enumerate(r.text.splitlines()):
    if i in {0, 1}:
        continue
    s = parse(estfmt, line)
    ts = int(float(s.named['pts']))
    loc_str = s.named['pos']
    lat, lng = float(loc_str.split(',')[0]), float(loc_str.split(',')[1])
    lng, lat = geo_util.wgs84_to_gcj02(lng, lat)
    ts_estimate_loc_map[ts] = [lng, lat]
    
ts_list = sorted(ts_estimate_loc_map.keys(), reverse=False)
with open('offline_estimate.csv', 'w') as fd:
    fd.write('ts,lng,lat\n')
    for i, ts in enumerate(ts_list):
        lng, lat = ts_estimate_loc_map[ts]
        if i > 0:
            ts = ts - ts_list[0]
        fd.write('%d,%.6f,%.6f\n' % (ts, lng, lat))

In [22]:
ts_gnss_loc_map = {}

for i, gnss in enumerate(gt_gnss_unit_list):
    ts = int(round(gnss.app_ts, 0))
    lng, lat = gnss.lng_gcj02, gnss.lat_gcj02
    ts_gnss_loc_map[ts] = [lng, lat]
    
ts_list = sorted(ts_gnss_loc_map.keys(), reverse=False)
with open('offline_gnss.csv', 'w') as fd:
    fd.write('ts,lng,lat\n')
    for i, ts in enumerate(ts_list):
        lng, lat = ts_gnss_loc_map[ts]
        if i > 0:
            ts = ts - ts_list[0]
        fd.write('%d,%.6f,%.6f\n' % (ts, lng, lat))

In [23]:
len(ts_gnss_loc_map)

3642

In [24]:
dist_err_list = []

for ts, (lng, lat) in ts_estimate_loc_map.items():
    if ts not in ts_gnss_loc_map and ts + 1 in ts_gnss_loc_map:
        dist_err = geo_util.distance(lng, lat, 
                                     ts_gnss_loc_map[ts + 1][0], ts_gnss_loc_map[ts + 1][1])
        dist_err_list.append(dist_err)

In [25]:
print(np.percentile(dist_err_list, 25))
print(np.percentile(dist_err_list, 50))
print(np.percentile(dist_err_list, 75))

9.39333925
14.1145405
29.604348


In [27]:
dist_err_list

[0.783204,
 7.565693,
 11.24752,
 15.515795,
 12.713286,
 9.542579,
 22.544692,
 84.965147,
 12.418362,
 133.85211,
 18.113812,
 119.167889,
 18.11654,
 0.9869,
 8.94562,
 50.783316]