In [3]:
from hydra import core, initialize, compose
from omegaconf import OmegaConf


# CHANGE ME
DATASET_DIR = '/shares/datasets/nuscenes/v1.0-trainval'
LABELS_DIR = '/shares/datasets/cvt_labels_nuscenes'


core.global_hydra.GlobalHydra.instance().clear()        # required for Hydra in notebooks

initialize(config_path='../config')

# Add additional command line overrides
cfg = compose(
    config_name='config',
    overrides=[
        'experiment.save_dir=../logs/',                 # required for Hydra in notebooks
        'data=nuscenes',
        f'data.dataset_dir={DATASET_DIR}',
        f'data.labels_dir={LABELS_DIR}',
        'data.version=v1.0-trainval',
        'loader.batch_size=1',
    ]
)

# resolve config references
OmegaConf.resolve(cfg)

print(list(cfg.keys()))

['experiment', 'loader', 'optimizer', 'scheduler', 'trainer', 'data', 'loss', 'metrics']


The version_base parameter is not specified.
Please specify a compatability version level, or None.
Will assume defaults for version 1.1
  initialize(config_path='../config')


In [11]:
import torch
import numpy as np

from cross_view_transformer.common import setup_data_module


# Additional splits can be added to cross_view_transformer/data/splits/nuscenes/
SPLIT = 'val_qualitative_000'
SUBSAMPLE = 10


data = setup_data_module(cfg)

dataset = data.get_split(SPLIT, loader=False)
dataset = torch.utils.data.ConcatDataset(dataset)
dataset = torch.utils.data.Subset(dataset, range(0, len(dataset), SUBSAMPLE))

loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=2)

print(len(dataset))

25


In [5]:
from pathlib import Path

from cross_view_transformer.common import load_backbone


# Download a pretrained model (13 Mb)
VEHICLE_MODEL_URL = 'https://www.cs.utexas.edu/~bzhou/cvt/cvt_nuscenes_vehicles_50k.ckpt'
VEHICLE_CHECKPOINT_PATH = '../logs/cvt_nuscenes_vehicles_50k.ckpt'

ROAD_MODEL_URL = 'https://www.cs.utexas.edu/~bzhou/cvt/cvt_nuscenes_road_75k.ckpt'
ROAD_CHECKPOINT_PATH = '../logs/cvt_nuscenes_road_75k.ckpt'

!mkdir -p $(dirname ${VEHICLE_CHECKPOINT_PATH})
!wget $VEHICLE_MODEL_URL -O $VEHICLE_CHECKPOINT_PATH
!wget $ROAD_MODEL_URL -O $ROAD_CHECKPOINT_PATH


vehicle_network = load_backbone(VEHICLE_CHECKPOINT_PATH)
road_network = load_backbone(ROAD_CHECKPOINT_PATH)

--2023-06-14 15:21:50--  https://www.cs.utexas.edu/~bzhou/cvt/cvt_nuscenes_vehicles_50k.ckpt
Loaded CA certificate '/etc/ssl/certs/ca-certificates.crt'
Resolving www.cs.utexas.edu (www.cs.utexas.edu)... 128.83.120.48
Connecting to www.cs.utexas.edu (www.cs.utexas.edu)|128.83.120.48|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 13407137 (13M)
Saving to: ‘../logs/cvt_nuscenes_vehicles_50k.ckpt’


2023-06-14 15:21:50 (83.6 MB/s) - ‘../logs/cvt_nuscenes_vehicles_50k.ckpt’ saved [13407137/13407137]

--2023-06-14 15:21:50--  https://www.cs.utexas.edu/~bzhou/cvt/cvt_nuscenes_road_75k.ckpt
Loaded CA certificate '/etc/ssl/certs/ca-certificates.crt'
Resolving www.cs.utexas.edu (www.cs.utexas.edu)... 128.83.120.48
Connecting to www.cs.utexas.edu (www.cs.utexas.edu)|128.83.120.48|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 13402145 (13M)
Saving to: ‘../logs/cvt_nuscenes_road_75k.ckpt’


2023-06-14 15:21:51 (83.5 MB/s) - ‘../logs/cvt_nuscene

In [8]:
# load IOU metric
from cross_view_transformer.metrics import IoUMetric

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# get information about metrics from config files
iou_vehicle_config = OmegaConf.load('/home/dxa230028/Documents/cross_view_transformer/config/data/nuscenes_vehicle.yaml')
iou_road_config = OmegaConf.load('/home/dxa230028/Documents/cross_view_transformer/config/data/nuscenes_road.yaml')

iou_vehicle_metric = IoUMetric(label_indices=iou_vehicle_config['data']['label_indices']).to(device)
iou_road_metric = IoUMetric(label_indices=iou_road_config['data']['label_indices']).to(device)



In [9]:
%load_ext autoreload
%autoreload 2

import torch
import time
import imageio
import ipywidgets as widgets

from cross_view_transformer.visualizations.nuscenes_stitch_viz import NuScenesStitchViz


GIF_PATH = './predictions.gif'

# Show more confident predictions, note that if show_images is True, GIF quality with be degraded.
viz = NuScenesStitchViz(vehicle_threshold=0.6, road_threshold=0.6, show_images=False)

vehicle_network.to(device)
vehicle_network.eval()

road_network.to(device)
road_network.eval()

images = list()

iou_vehicle_metric.reset()
iou_road_metric.reset()

total_vehicle_040_iou = 0
total_vehicle_050_iou = 0
total_road_040_iou = 0
total_road_050_iou = 0

with torch.no_grad():
    for batch in loader:
        batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}

        
        # update vehicle iou
        vehicle_pred = vehicle_network(batch)['bev'].to(device)
        iou_vehicle_metric.update(vehicle_pred, batch)
        total_vehicle_040_iou += iou_vehicle_metric.compute()['@0.40']
        total_vehicle_050_iou += iou_vehicle_metric.compute()['@0.50']
        
        # update road iou
        road_pred = road_network(batch)['bev'].to(device)
        iou_road_metric.update(road_pred, batch)
        total_road_040_iou += iou_road_metric.compute()['@0.40']
        total_road_050_iou += iou_road_metric.compute()['@0.50']
        
        

        # visualization = np.vstack(viz(batch, road_pred, vehicle_pred))

        # images.append(visualization)


average_vehicle_040_iou = total_vehicle_040_iou / len(dataset)
average_vehicle_050_iou = total_vehicle_050_iou / len(dataset)
average_road_040_iou = total_road_040_iou / len(dataset)
average_road_050_iou = total_road_050_iou / len(dataset)

print(f"Vehicle IOU Metric Value for threshold 0.40: {average_vehicle_040_iou * 100}")
print(f"Vehicle IOU Metric Value for threshold 0.50: {average_vehicle_050_iou * 100}")
print(f"Road IOU Metric Value for threshold 0.40: {average_road_040_iou * 100}")
print(f"Road IOU Metric Value for threshold 0.50: {average_road_050_iou * 100}")


# Save a gif
# duration = [0.5 for _ in images[:-1]] + [2 for _ in images[-1:]]
# imageio.mimsave(GIF_PATH, images, duration=duration)

# html = f'''
# <div align="center">
# <img src="{GIF_PATH}?modified={time.time()}" width="80%">
# </div>
# '''

# display(widgets.HTML(html))

Vehicle IOU Metric Value for threshold 0.40: 38.937167167663574
Vehicle IOU Metric Value for threshold 0.50: 34.8044193983078
Road IOU Metric Value for threshold 0.40: 82.86128187179565
Road IOU Metric Value for threshold 0.50: 82.68776774406433
