# Environment preparation

In [None]:
import torch
from torch.cuda.amp import autocast
from torch.utils.data import DataLoader

from data.dataset import split_dataset
from utils.vis import *
from utils.tools import load_model
from utils.trainer import _set_seed

%matplotlib inline
plt.rcParams['figure.figsize'] = (10.0, 8.0)  # set default size of plots

# life save magic code
%load_ext autoreload
%autoreload 2

In [None]:
print(f"torch version: {torch.__version__}")
use_cuda = torch.cuda.is_available()
# Use the first GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda:0") if use_cuda else torch.device("cpu")
if use_cuda:
    # Report how many GPUs are visible, the first one's properties, and the
    # CUDA version torch was built against.
    GPU_nums, GPU = torch.cuda.device_count(), torch.cuda.get_device_properties(0)
    print(f"There are {GPU_nums} GPUs in total.\nThe first GPU: {GPU}")
    print(f"CUDA version: {torch.version.cuda}")
print(f"Using {device} now!")

# Load Model & Data

In [None]:
# Fill your run name and log dir!
run_name = None
log_dir = None
# Move the model to the selected device, then switch it to eval mode: this
# notebook only runs inference, so dropout / batch-norm layers (if any) must not
# behave as in training. `.eval()` is a no-op if load_model already set it.
model = load_model(run_name, log_dir).to(device).eval()

In [None]:
# Fix the seed and request deterministic behavior before splitting/loading.
# NOTE(review): exactly which RNGs `_set_seed` seeds is defined in utils.trainer — confirm there.
_set_seed(seed=0, deterministic=True)

train_dataset, val_dataset = split_dataset(
    dataset_root='',  # Fill your dataset root!
    train_ratio=0.8,
    route_len=250,
    total_len=250,
)

# DataLoader settings shared by both splits. `shuffle` is left at its default
# (False), so batches come out in dataset order.
loader_kwargs = dict(
    batch_size=6,
    num_workers=4,
    pin_memory=True,
    prefetch_factor=4,
    persistent_workers=True,
)


def _make_loader(dataset):
    """Build a DataLoader over `dataset` with `loader_kwargs`; return it and a fresh iterator."""
    loader = DataLoader(dataset, **loader_kwargs)
    # With num_workers > 0, iter() starts the worker processes immediately.
    return loader, iter(loader)


train_loader, iter_train_loader = _make_loader(train_dataset)
val_loader, iter_val_loader = _make_loader(val_dataset)

# Test & Plot!

In [None]:
# Pull the next validation batch; re-running this cell advances to the following batch.
frames, gt_routes, map_sizes = next(iter_val_loader)
tuple(tensor.shape for tensor in (frames, gt_routes))

In [None]:
# Predict routes for the current batch and plot each ground-truth route against
# its prediction.
# DataLoader yields CPU tensors; move them to the model's device. `.to()` is a
# no-op if they are already there, and non_blocking pairs with pin_memory=True.
model_inputs = (
    frames.to(device, non_blocking=True),
    gt_routes.to(device, non_blocking=True),
)
# Mixed precision only when CUDA is present; on CPU autocast would just warn and disable itself.
with torch.no_grad(), autocast(enabled=use_cuda):
    pred_routes = model.vis_forward(model_inputs)

gt_routes_np = gt_routes.cpu().numpy()
# Cast to float32 before .numpy(): autocast can produce half-precision outputs,
# and NumPy has no bfloat16 dtype.
pred_routes_np = pred_routes.float().cpu().numpy()
for gt, pred in zip(gt_routes_np, pred_routes_np):
    draw_routes(routes=(gt, pred))