In [None]:
import argparse
import math
import os
import pickle
import warnings

import data_generator
from util.MT3DataConvertor import MT3DataConvertor
import PerformanceEval
import matplotlib.pyplot as plt
import numpy as np
import torch
from util.load_config_files import load_yaml_into_dotdict
from util.misc import super_load
import scipy.io as scio
import PerformanceEval

# Expose only GPU index 1 to this process.
# NOTE(review): this only takes effect if CUDA has not already been initialised
# (e.g. by one of the project imports above) — confirm, or move it before them.
os.environ.update(CUDA_VISIBLE_DEVICES='1')

In [None]:
# Command-line options. The defaults point at one specific training run.
# parse_known_args (rather than parse_args) is used so that unrecognised arguments —
# presumably the ones the Jupyter kernel process is launched with — are ignored
# instead of aborting the notebook.
parser = argparse.ArgumentParser()
parser.add_argument('-rp', '--result_filepath', default='/home/weixinwei/data/MT3/Results/2023-05-03_022739',
                    help='Results folder of the trained model to evaluate.')
parser.add_argument('-tp', '--task_params', default='/home/weixinwei/data/MT3/Results/2023-05-03_022739/code_used/task_params.yaml',
                    help="Task YAML overriding the trained model's parameters; pass an empty string to reuse the trained task.")
# NOTE(review): --model_params is parsed but not used anywhere in this notebook;
# the model is loaded from --result_filepath by super_load.
parser.add_argument('-mp', '--model_params', default='/home/weixinwei/data/MT3/Results/2023-05-03_022739/code_used/model_params.yaml')
# Previously a hard-coded path; now configurable like the other YAML inputs (same default).
parser.add_argument('-ep', '--eval_params', default='/home/weixinwei/study/MT3-test/configs/eval/default.yaml',
                    help='Evaluation YAML merged into the parameters last.')
args = parser.parse_known_args()[0]
print(f'Evaluating results from folder: {args.result_filepath}...')

model, params = super_load(args.result_filepath, verbose=True)

# Test that the model was trained in the task chosen for evaluation.
# A truthiness check (not `is not None`) so that `-tp ''` selects the fallback branch
# instead of trying to load a YAML file named ''.
if args.task_params:
    task_params = load_yaml_into_dotdict(args.task_params)
    for k, v in task_params.data_generation.items():
        if k not in params.data_generation:
            warnings.warn(f"Key '{k}' not found in trained model's hyperparameters")
        elif params.data_generation[k] != v:
            warnings.warn(f"Different values for key '{k}'. Task: {v}\tTrained: {params.data_generation[k]}")
    # Use task params, not the ones from the trained model
    params.recursive_update(task_params)  # note: parameters specified only on trained model will remain untouched
else:
    warnings.warn('Evaluation task was not specified; inferring it from the task specified in the results folder.')

# Evaluation settings are merged last, so (presumably) they take precedence over the
# trained/task values for any keys they define.
eval_params = load_yaml_into_dotdict(args.eval_params)
params.recursive_update(eval_params)

# Batch source handed to SeqPredPlot, which draws from it with next().
GetSeqBatch = data_generator.GetSeqBatch(params)

In [None]:
def SeqPredPlot(dataGenerator, model, timeStep, existanceThreshold = 0.9):
    '''
    Plot sequential model predictions against ground-truth target positions.

    Draws `timeStep` consecutive batches from `dataGenerator` and runs `model` on each.
    Everything goes on one figure:
      - ground-truth targets as blue '+';
      - predictions whose sigmoid existence score exceeds `existanceThreshold` as red 'x';
      - clutter measurements (unique id == -1) as black '+', for the first step only.
    Points from earlier steps are drawn more opaque than points from later steps.

    Parameters
    ----------
    dataGenerator : iterator
        Yields `(batch, labels, unique_ids)` tuples. `batch.tensors[i]` holds the
        measurements of sample i, `unique_ids[i]` their ids, and `labels[i]` the
        ground-truth target states.
    model :
        Tracking model (presumably a torch.nn.Module). It is called as
        `model.forward(batch, unique_ids)` and must return a 5-tuple whose first element
        is a dict with 'state' and 'logits' tensors.
    timeStep : int
        Number of batches (time steps) to draw and plot.
    existanceThreshold : float, optional
        Minimum sigmoid existence probability for a prediction to be plotted.

    Returns
    -------
    None. The figure is shown with `plt.show()`.
    '''
    fig, ax = plt.subplots(figsize=(6, 6), dpi=300)
    # Global font settings so the Chinese legend labels render. ASCII minus signs are
    # used presumably because SimHei lacks the Unicode minus glyph.
    plt.rcParams['font.sans-serif'] = ['SimHei']
    plt.rcParams['axes.unicode_minus'] = False
    # The 'Clutter' entry must appear in the legend only once, even though clutter is
    # drawn once per batch element.
    clutter_labelled = False
    for stepID in range(timeStep):
        batch, labels, unique_ids = next(dataGenerator)
        # Inference only: don't build an autograd graph.
        # NOTE(review): model.eval() is not called here; confirm super_load returns the
        # model in eval mode, otherwise dropout/normalisation layers run in training mode.
        with torch.no_grad():
            output, _, _, _, _ = model.forward(batch, unique_ids)
        output_state = output['state'].detach()
        output_logits = output['logits'].sigmoid().detach()
        bs = output_state.shape[0]
        # Fade out later time steps so the temporal order is visible.
        pointAlpha = 1 - stepID / timeStep
        for batchID in range(bs):
            if stepID == 0:
                # Measurements whose unique id is -1 are treated as clutter.
                falseMeas = batch.tensors[batchID][(unique_ids[batchID] == -1)].cpu()
                if falseMeas.shape[0] != 0:
                    # Per-point alpha scales with column 2 relative to its maximum.
                    # NOTE(review): assumes column 2 is positive so alpha stays in [0, 1].
                    ax.scatter(falseMeas.T[0][:-1], falseMeas.T[1][:-1], color='k', marker='+', alpha=falseMeas.T[2][:-1]/falseMeas.T[2].max()/1.2)
                    # The last clutter point is drawn separately so it can carry the
                    # legend label.
                    ax.scatter(falseMeas.T[0][-1], falseMeas.T[1][-1], color='k', marker='+', alpha=1/1.2,
                               label='_nolegend_' if clutter_labelled else 'Clutter')
                    clutter_labelled = True
            # Keep only queries whose existence probability exceeds the threshold.
            alive_idx = output_logits[batchID, :].squeeze(-1) > existanceThreshold
            alive_output = output_state[batchID, alive_idx, :].cpu()
            current_targets = labels[batchID].cpu()
            # Legend labels ('true target position' / 'algorithm-predicted position')
            # are attached only to the very first scatter of each kind.
            first = stepID == 0 and batchID == 0
            ax.scatter(current_targets.T[0], current_targets.T[1], color='b', marker='+', alpha=pointAlpha/2,
                       label='目标真实位置' if first else '_nolegend_')
            ax.scatter(alive_output.T[0], alive_output.T[1], color='r', marker='x', alpha=pointAlpha/1.2,
                       label='算法预测位置' if first else '_nolegend_')
    ax.set_xlabel('X / km')
    ax.set_ylabel('Y / km')
    ax.legend(loc=1)  # loc=1 == 'upper right'
    ax.grid(True, linestyle="--", color="k", linewidth=0.5, alpha=0.3)
    plt.show()

In [None]:
# Plot 19 consecutive time steps, keeping predictions whose existence probability exceeds 0.4.
SeqPredPlot(GetSeqBatch, model, timeStep=19, existanceThreshold=0.4)