In [None]:
import torch
import numpy as np
from skimage import io
from PIL import Image, ImageDraw
from torchvision import transforms
from pathlib import Path
from nets.ROAR import ROAR
import pandas as pd
import tqdm
import matplotlib.pyplot as plt

Configuration

In [None]:
# Input CSV / image folder of the test set, trained ROAR weights, and the
# pretrained image weights handed to the ROAR constructor.
csv_file = Path('./data/test_set/labeled_data_test.csv')
image_folder = Path('./data/test_set/images_test/')
network_weights = Path('./experiments/roar/model.pth')
pretrained_image_weights = Path('./weights/VisionNavNet_state_hd.pth.tar')
# Id of the test sequence to evaluate (rows are filtered on the CSV's `sequence` column).
sequence = 1
# Fall back to CPU so the notebook still runs on machines without a CUDA GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Build the model with the same positional arguments used for training.
# NOTE(review): the trailing 10 presumably is the prediction horizon (the
# visualization plots 10 steps ahead) — confirm against nets/ROAR.py.
network = ROAR(device, True, pretrained_image_weights, 10)
network.to(device)
# map_location remaps the checkpoint's tensors onto `device`, so weights saved
# on a GPU also load when `device` is the CPU.
network.load_state_dict(torch.load(network_weights, map_location=device))

In [None]:
csv = pd.read_csv(csv_file)
# Keep only the rows of the sequence selected in the configuration cell.
csv = csv[csv['sequence'] == sequence].reset_index(drop=True)
# Fail loudly on a wrong sequence id instead of silently producing zero frames downstream.
assert len(csv) > 0, "No rows for sequence {} in {}".format(sequence, csv_file)
print("Number of samples: {}".format(csv.shape[0]))

Functions for reading data, running the network, and visualizing results

In [None]:
def get_ego_trajectory(pred_traj, f, cam_height, image_width, image_height):
    """Project 2-D trajectory points into image pixel coordinates.

    Each point ``(X, Y)`` (columns 0 and 1 of ``pred_traj``) is mapped to::

        u = image_width / 2  - f * Y / X
        v = image_height / 2 - f * cam_height / X

    i.e. a pinhole-style projection in which ``X`` plays the role of depth
    (NOTE(review): axis meaning inferred from the formula — confirm).

    Args:
        pred_traj: (N, 2) numpy array or torch tensor, indexed as ``pred_traj[:, k]``.
            Points with ``X == 0`` divide by zero and give non-finite coordinates.
        f: focal length, in the same pixel units as the image size.
        cam_height: camera height term used for the vertical coordinate.
        image_width, image_height: image size; the projection is centred on it.

    Returns:
        list of ``(u, v)`` tuples of plain Python floats, ready for
        ``PIL.ImageDraw.Draw.line``.
    """
    forward = pred_traj[:, 0]
    lateral = pred_traj[:, 1]
    us = -(f * lateral / forward) + image_width / 2
    vs = -(f * cam_height / forward) + image_height / 2
    # Convert to Python floats: PIL then receives plain numbers instead of 0-d
    # tensors, and its joint="curve" geometry does not run tensor ops per point.
    return [(float(u), float(v)) for u, v in zip(us, vs)]


def read_datapoint(datapoint):
    """Turn one CSV row into the network inputs, labels and display data for a frame.

    Reads the camera image from the module-level ``image_folder`` (set in the
    configuration cell) and rasterises the row's trajectory with
    ``get_ego_trajectory``.

    Args:
        datapoint: one row of the test CSV (e.g. ``csv.iloc[i]``). Columns read:
            ``image_name``, ``pred_traj_x``, ``pred_traj_y``, ``lidar_scan``,
            ``label``, ``image_occluded`` and ``lidar_occluded``. The list-valued
            columns are strings whose first and last characters are dropped
            (presumably ``[`` and ``]``) before splitting on commas.

    Returns:
        dict with
            ``image``: the camera image as a ``PIL.Image``;
            ``image_tensor``: that image after ``ToTensor`` + ``Normalize``;
            ``lidar_scan``: raw LiDAR ranges as a float numpy array;
            ``lidar_scan_clipped``: ranges clipped to ``[0, lidar_clip]`` and divided
                by ``lidar_clip`` (float32 tensor with values in ``[0, 1]``);
            ``traj_image_tensor``: (1, 128, 320) float tensor holding the trajectory
                drawn in white on black, i.e. the bottom rows of a 320x240 canvas;
            ``image_occluded`` / ``lidar_occluded``: ints read from the CSV;
            ``label``: float32 tensor of the integer labels.
    """
    # Constants
    lidar_clip = 1.85      # LiDAR ranges are clipped to this value, then scaled into [0, 1]
    f = 460                # focal length used to project the trajectory (pixel units)
    image_width, image_height = 320, 240   # canvas the trajectory is drawn on
    cam_height = -0.23
    traj_crop_top = 112    # keep rows [traj_crop_top, image_height) of the trajectory canvas

    def parse_list(text, cast):
        # Drop the surrounding bracket characters, split on commas, cast every token.
        return np.array([cast(token) for token in text[1:-1].split(',')])

    # Read image
    image = Image.fromarray(io.imread(image_folder / datapoint['image_name']))

    # Trajectory points (X made non-negative) as an (N, 2) float32 tensor
    pred_x = np.abs(parse_list(datapoint['pred_traj_x'], float))
    pred_y = parse_list(datapoint['pred_traj_y'], float)
    pred_traj = torch.tensor(np.stack([pred_x, pred_y])).to(torch.float32).permute(1, 0)

    # LiDAR: raw ranges for plotting, clipped + normalised copy for the network
    lidar_scan = parse_list(datapoint['lidar_scan'], float)
    lidar_scan_clipped = torch.as_tensor(
        np.clip(lidar_scan, a_min=0, a_max=lidar_clip) / lidar_clip, dtype=torch.float32)

    # Label: integers in the CSV, stored as float32
    label = torch.as_tensor(parse_list(datapoint['label'], int), dtype=torch.float32)

    # Occlusion flags
    image_occluded = int(datapoint['image_occluded'])
    lidar_occluded = int(datapoint['lidar_occluded'])

    # Camera image -> tensor, normalised with the commonly used ImageNet mean/std values
    image_tx = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
    image_tensor = image_tx(image)

    # Rasterise the projected trajectory on a black canvas, then keep its bottom part
    ego_traj = get_ego_trajectory(pred_traj, f, cam_height, image_width, image_height)
    traj_image = Image.new(mode="L", size=(image_width, image_height))
    ImageDraw.Draw(traj_image).line(ego_traj, fill="white", width=6, joint="curve")
    traj_image = traj_image.crop((0, traj_crop_top, image_width, image_height))
    traj_image_tensor = transforms.ToTensor()(traj_image)

    return {
        'image': image,
        'image_tensor': image_tensor,
        'lidar_scan': lidar_scan,
        'lidar_scan_clipped': lidar_scan_clipped,
        'traj_image_tensor': traj_image_tensor,
        'image_occluded': image_occluded,
        'lidar_occluded': lidar_occluded,
        'label': label
    }

def predict(network, sequence, reset_state_each_frame=False, device=None):
    """Run ``network`` frame by frame over ``sequence`` and collect its predictions.

    Args:
        network: model called as
            ``network(image, traj_image, lidar, initial_state=state)``; its output
            dict must contain ``state``, ``pred_inv_score``, ``pred_img_score`` and
            ``pred_lidar_score``. It is switched to eval mode.
        sequence: iterable of dicts as returned by ``read_datapoint``.
        reset_state_each_frame: if True, every frame is run with
            ``initial_state=None``; otherwise the ``state`` returned for one frame is
            passed to the next.
        device: device the inputs are moved to. Defaults to the device of the
            network's first parameter (CPU if it has none).

    Returns:
        list with one dict per frame:
            ``predicted_label``: list of the per-step scores from ``pred_inv_score``;
            ``predicted_lidar``: scalar from ``pred_lidar_score``;
            ``predicted_camera``: scalar from ``pred_img_score``.
    """
    if device is None:
        first_param = next(network.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    def as_input(tensor):
        # Add two leading singleton dims (presumably batch and time — TODO confirm against ROAR).
        return tensor.to(device).unsqueeze(0).unsqueeze(0)

    predictions = []
    network.eval()
    with torch.no_grad():
        state = None
        for entry in tqdm.tqdm(sequence):
            out = network(as_input(entry['image_tensor']),
                          as_input(entry['traj_image_tensor']),
                          as_input(entry['lidar_scan_clipped']),
                          initial_state=state)
            state = None if reset_state_each_frame else out['state']
            # flatten(0, 1) merges the two leading dims; [0] takes the single frame.
            predictions.append({
                'predicted_label': list(out['pred_inv_score'].flatten(0, 1).cpu().numpy()[0]),
                'predicted_lidar': out['pred_lidar_score'].flatten(0, 1).cpu().numpy()[0][0],
                'predicted_camera': out['pred_img_score'].flatten(0, 1).cpu().numpy()[0][0],
            })
    return predictions

def visualize(frame_data, frame_predictions, title="", show_gts=False):
    """Draw a three-panel figure for one frame.

    Panels: (1) the camera image with the trajectory mask painted blue,
    (2) the LiDAR scan converted from polar to Cartesian points around the sensor
    (blue marker at the origin), (3) the predicted probability of failure for each
    step ahead, with a red reference line at 0.5.

    Args:
        frame_data: dict from ``read_datapoint``; uses ``image``,
            ``traj_image_tensor``, ``lidar_scan`` and, if ``show_gts``, ``label``.
        frame_predictions: dict from ``predict``; uses ``predicted_label``.
        title: figure suptitle; ``None`` or an empty string draws no title.
        show_gts: if True, scatter the ground-truth labels over the predictions.

    Side effect: updates the global matplotlib ``rcParams`` (font size and serif font).
    """
    plt.rcParams.update({
        'font.size': 14
    })
    plt.rcParams['font.family'] = 'serif'
    # Put 'Times New Roman' first exactly once: prepending it unconditionally
    # appended another copy to font.serif on every call.
    other_serifs = [name for name in plt.rcParams['font.serif'] if name != 'Times New Roman']
    plt.rcParams['font.serif'] = ['Times New Roman'] + other_serifs

    fig = plt.figure(figsize=(14, 4))
    if title:
        fig.suptitle(title)
    axs = fig.subplots(1, 3, gridspec_kw={'width_ratios': [2, 1, 1]})

    # Camera image with the trajectory overlay. The trajectory tensor covers only
    # the bottom rows of the frame, so zero-pad it at the top up to the image size
    # (taken from the image itself rather than assuming 240x320).
    image = transforms.ToTensor()(frame_data['image'])
    image_height, image_width = image.shape[-2:]
    traj_bottom = frame_data['traj_image_tensor']
    top_pad = torch.zeros(1, image_height - traj_bottom.shape[1], image_width)
    traj = torch.cat([top_pad, traj_bottom], dim=1)[0]
    # Darken the image where the mask is set and add the mask back in the blue channel.
    blue_overlay = torch.stack([torch.zeros_like(traj), torch.zeros_like(traj), traj])
    render = transforms.ToPILImage()(image * (1.0 - traj) + blue_overlay)

    # LiDAR: one beam angle per range reading, spread evenly over [-0.25*pi, 1.25*pi].
    lidar_scan = frame_data['lidar_scan']
    theta = np.linspace(-0.25 * np.pi, 1.25 * np.pi, len(lidar_scan))
    lidar_xs = lidar_scan * np.cos(theta)
    lidar_ys = lidar_scan * np.sin(theta)

    axs[0].imshow(render)
    axs[0].axis('off')

    axs[1].plot(lidar_xs, lidar_ys, ls='None', color='white', marker='.', markersize=5)
    axs[1].plot(0, 0, color='tab:blue', marker='^', markersize=15)
    axs[1].axis([-1, 1, -0.75, 2.0])
    axs[1].set_xticks([])
    axs[1].set_yticks([])
    axs[1].set_facecolor('black')

    # One score per step ahead; the horizon comes from the predictions instead of a fixed 10.
    scores = frame_predictions['predicted_label']
    steps = range(1, len(scores) + 1)
    axs[2].plot(steps, scores)
    axs[2].set_xlabel('Timesteps ahead')
    axs[2].set_ylabel('Probability of Failure')
    axs[2].set_ylim(-0.1, 1.1)
    axs[2].set_facecolor('white')
    axs[2].plot(steps, [0.5] * len(scores), color='red')
    if show_gts:
        axs[2].scatter(steps, frame_data['label'], color='orange')

Run prediction

In [None]:
# Parse every row of the selected sequence, then run the network over the frames
# in order, passing each frame's returned state on to the next one.
# NOTE(review): this rebinds `sequence`, which the configuration cell set to the
# integer sequence id; re-run the configuration cell before re-running the
# CSV-filtering cell.
sequence = [read_datapoint(csv.iloc[row_idx]) for row_idx in range(len(csv))]
predictions = predict(
    network=network,
    sequence=sequence,
    reset_state_each_frame=False,
)

Visualize

In [None]:
# Frames to render; narrow this (e.g. range(0, len(sequence), 10)) for a quicker look.
frames_to_show = list(range(len(sequence)))

for idx in frames_to_show:
    visualize(sequence[idx], predictions[idx], show_gts=False)
    # Display this frame now and release it. Otherwise every figure stays open
    # until the cell finishes, so memory grows with the sequence length and
    # matplotlib warns once too many figures are open.
    plt.show()
    plt.close('all')