In [1]:
# Move to the parent directory (the repo root printed below) so relative paths such as config/*.yaml resolve.
# NOTE: not idempotent — re-running this cell moves up one more directory level.
%cd ..

/home/romet/projects/ut/milrem/waypoint_planner


In [2]:
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
import torch
import yaml
import ipywidgets as widgets
from ipywidgets import interact, interactive, fixed, interact_manual

from data.dataset import MilremVizDataset
from model.util import load_model
from viz.base_visualizer import BaseVisualizer, GREEN, RED
from viz.gnm_visualizer import GNMVisualizer
from viz.nomad_visualizer import NomadVisualizer
from viz.util_camera import to_camera_frame

In [3]:
# Which planner to visualise: one of the keys of MODEL_REGISTRY below.
model_type = 'vint'

# model_type -> (YAML config, checkpoint path, visualizer class).
# ViNT and both GNM variants share the GNM visualizer; only NoMaD has its own.
MODEL_REGISTRY = {
    'gnm': (
        "config/gnm.yaml",
        "/home/romet/projects/ut/milrem/waypoint_planner/model_weights/gnm_large.pth",
        GNMVisualizer,
    ),
    'gnm-finetuned': (
        "config/gnm.yaml",
        "/home/romet/projects/ut/milrem/models/gnm-finetuned.ckpt",
        GNMVisualizer,
    ),
    'vint': (
        "config/vint.yaml",
        "/home/romet/projects/ut/milrem/models/vint.pth",
        GNMVisualizer,
    ),
    'nomad': (
        "config/nomad.yaml",
        "/home/romet/projects/ut/milrem/models/nomad.pth",
        NomadVisualizer,
    ),
}

# Fail here with a clear message instead of a NameError on `model_config` in a later cell.
if model_type not in MODEL_REGISTRY:
    raise ValueError(
        f"Unknown model_type {model_type!r}; expected one of {sorted(MODEL_REGISTRY)}"
    )

model_config, model_path, visualizer_cls = MODEL_REGISTRY[model_type]
# The positional False is passed through unchanged from the original cell — TODO document its meaning.
viz = visualizer_cls(False)


In [4]:
# Parse the selected model's YAML hyper-parameters (file is read and closed in one call).
config = yaml.safe_load(Path(model_config).read_text())

# Restore the checkpoint using those hyper-parameters.
model = load_model(model_path, config)
# Switch to evaluation mode (assumes a torch nn.Module); trailing `;` suppresses the repr.
model.eval();



In [5]:
# Single recorded drive used for visual inspection; the model config is forwarded as dataset options.
dataset_path = Path("/home/romet/projects/ut/milrem/test-data") / "2023-07-28-14-08-06"
dataset = MilremVizDataset(dataset_path, **config)

In [7]:
# IntSlider bounds are inclusive, so the last selectable frame must be len(dataset) - 1.
LAST_FRAME_IDX = len(dataset) - 1
# Upper bound (in samples) for how far ahead of the observation the goal image is taken.
MAX_GOAL_OFFSET = 180

frame_id_slider = widgets.IntSlider(min=0, max=LAST_FRAME_IDX, step=1, value=0, continuous_update=True)
goal_id_slider = widgets.IntSlider(min=0, max=MAX_GOAL_OFFSET, step=1, value=0, continuous_update=True)


@interact(frame_id=frame_id_slider, goal_id=goal_id_slider)
def draw_predictions(frame_id, goal_id):
    """Render ground-truth (green) and predicted (red) waypoints for one frame/goal pair.

    Args:
        frame_id: index of the observation sample in ``dataset``.
        goal_id: offset in samples from ``frame_id`` to the sample whose image is used
            as the goal; clamped so the goal index never runs past the end of ``dataset``.
    """
    obs_tensor, _, labels, data, obs_img, _ = dataset[frame_id]
    rectified_image = viz.rectify_image(obs_img)

    # Clamp so large offsets near the end of the recording don't raise IndexError.
    goal_idx = min(frame_id + goal_id, len(dataset) - 1)
    last_obs_tensor, _, _, last_data, waypoint_img, _ = dataset[goal_idx]
    # The last 3 channels of the goal sample's observation tensor serve as the goal image
    # (presumably its most recent RGB frame — TODO confirm channel layout).
    waypoint_tensor = last_obs_tensor[-3:]
    data["wp_idx"] = last_data["idx"]

    obs_batch = obs_tensor.unsqueeze(dim=0)
    goal_batch = waypoint_tensor.unsqueeze(dim=0)
    # Pure inference: skip building an autograd graph.
    with torch.no_grad():
        if model_type == 'nomad':
            predictions = model(obs_batch, goal_batch, True)
        else:
            predictions = model(obs_batch, goal_batch)

    predicted_dist = predictions[0][0].item()
    predicted_actions = predictions[1].squeeze().detach().cpu().numpy()

    to_camera_frame(rectified_image, labels[0], GREEN)
    if model_type == 'nomad':
        # First axis presumably indexes sampled trajectories; draw each one.
        for trajectory in predicted_actions:
            to_camera_frame(rectified_image, trajectory, RED)
    else:
        to_camera_frame(rectified_image, predicted_actions, RED)
    viz.draw_info_overlay(rectified_image, data, predicted_dist)
    viz.draw_top_town_overlay(rectified_image, None, predicted_actions.squeeze(), labels[0])
    viz.draw_waypoint_img(rectified_image, waypoint_img)

    # OpenCV draws in BGR; matplotlib expects RGB.
    rectified_image = cv2.cvtColor(rectified_image, cv2.COLOR_BGR2RGB)
    fig, ax = plt.subplots(figsize=(12, 12))
    ax.imshow(rectified_image)
    ax.set_title(f"{model_type}: frame {frame_id}, goal {goal_idx}")
    plt.show()

interactive(children=(IntSlider(value=0, description='frame_id', max=2338), IntSlider(value=0, description='go…