Import libraries

In [11]:
import os
import time

import torch
import numpy as np

import torch.optim as optim
from torch_geometric.loader import DataLoader
import torch.nn.functional as F

from utils import config
from utils.dataset import GraphDataset, GraphData, get_sequential_edge_index
from vectornet.vectornet import VectornetGNN

import matplotlib.pyplot as plt
from utils.data_utils import draw_trajectory, create_ego_raster, create_agents_raster, create_map_raster, decoding_features, increment_to_trajectories

Use the training set to sanity-check both the model output and the plotting functions.

In [12]:
# Seed NumPy and PyTorch so any stochastic behaviour below is reproducible.
np.random.seed(config.SEED)
torch.manual_seed(config.SEED)

# NOTE: the *training* split is loaded on purpose, to sanity-check the model
# output and the plotting helpers. The `validate_*` names are kept because the
# plotting cell below reads `validate_loader`.
validate_data = GraphDataset(config.TRAIN_PATH)

# One graph per batch: the plotting cell decodes and draws each batch as a
# single scene.
validate_loader = DataLoader(
    validate_data,
    batch_size=1
)

# Use the first GPU when available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

model = VectornetGNN(
    in_channels=config.IN_CHANNELS,
    out_channels=config.OUT_CHANNELS,
).to(device)

# Which saved epoch to evaluate; checkpoints are named `model_epoch_<N>.pth`.
CHECKPOINT_EPOCH = 100
checkpoint_path = os.path.join(config.WEIGHT_PATH, f'model_epoch_{CHECKPOINT_EPOCH}.pth')

# `map_location` remaps tensors that were saved on a GPU onto `device`, so the
# checkpoint also loads on a CPU-only machine. This is the cell's last
# expression, so the load report is displayed as the cell output.
model.load_state_dict(
    torch.load(checkpoint_path, map_location=device)
)


<All keys matched successfully>

In [13]:
# Directory that receives one rendered figure per sample.
RESULT_DIR = './result/'
# Set to an int to render only the first N samples; None renders every sample.
MAX_PLOTS = None

# Make sure the output directory exists (no error if it already does).
os.makedirs(RESULT_DIR, exist_ok=True)

model.eval()
with torch.no_grad():
    for i, data in enumerate(validate_loader):
        if MAX_PLOTS is not None and i >= MAX_PLOTS:
            break

        data = data.to(device)
        # Hand the decoding/plotting helpers host-memory tensors, mirroring the
        # `.cpu()` applied to `model_out` below. Presumably the helpers end up
        # in NumPy/matplotlib, which cannot consume CUDA tensors. This is a
        # no-op when running on the CPU.
        observations = data.x.cpu()
        model_out = model(data)

        (ego_past, agent_past_list,
         lane_list, crosswalk_list,
         route_lane_list,
         agent_current_pose_list) = decoding_features(observations)
        predicted_path_list = increment_to_trajectories(model_out.cpu(), agent_current_pose_list)

        # Start a fresh figure per sample so scenes do not pile up on the same
        # axes.
        plt.figure()
        # Draw the observed history and the map context.
        draw_trajectory(ego_past, agent_past_list)
        create_map_raster(lane_list, crosswalk_list, route_lane_list)
        # Draw the predictions. Row 0 goes in the same argument slot as
        # `ego_past` above, and rows 1: go in the agents' slot.
        draw_trajectory(predicted_path_list[0, :, :], predicted_path_list[1:, :, :], alpha=0.5, linewidth=2)
        plt.axis('equal')

        # Persist the current figure, then release every open figure so memory
        # does not grow with the number of samples rendered.
        plt.savefig(os.path.join(RESULT_DIR, f'{i}.png'))
        plt.close('all')