# Future Object Detection

The purpose of this notebook is to explore the capabilities of our Future Object Detection model. The model is trained on the [nuScenes](https://www.nuscenes.org/) or [nuImages](https://www.nuscenes.org/nuimages) dataset, both of which contain sequences of images with bounding-box annotations. The model takes two images at times $T-1$ and $T$, together with the corresponding ego-motion information, and predicts the 2D bounding boxes at time $T+1$.

For full information, please see the paper ["Future Object Detection with Spatiotemporal Transformers"](https://arxiv.org/abs/2204.10321).

In [None]:
import argparse
import sys

sys.path.append(".")
sys.path.append("./ConditionalDETR")

import cv2
import gdown
import matplotlib.pyplot as plt
import numpy as np
import torch
from einops import rearrange

from future_od.datasets import nu_scenes
from future_od.models.st_detr import SpatioTemporalDETRArgs
from future_od.utils.recursive_functions import recursive_to
from future_od.utils.visualization import revert_imagenet_normalization, draw_boxes, COLOURS

from runs._loader import get_nuim_loaders
from runs._model import build_model
from config import config

In [None]:
# Download the pretrained weights unless they are already cached at `path`.
# Use gdown's direct-download URL form (`uc?id=<file id>`): the `/file/d/<id>`
# share link points at Drive's HTML viewer page, which gdown only rewrites when
# `fuzzy=True` is passed.
CHECKPOINT_PATH = gdown.cached_download(
    url="https://drive.google.com/uc?id=1BkKvCfrJYORvRtPRAr5Uonltc4Nf4IGa",
    path="checkpoints/nuim_spatiotemporal_imu.pth.tar",
    quiet=True,
)

In [None]:
DEVICE = "cpu"  # "cuda"
DEBUG = False  # Disable debug mode to run on the full dataset (recommended)
IMAGE_SIZE = (896, 1600)  # (H, W) passed to the data loaders below

args = argparse.Namespace(
    device=DEVICE,
    distributed=False,
    debug=DEBUG,
    night=False,
    short_train=True,
    num_workers=0,
)
detr_args = SpatioTemporalDETRArgs(
    pretrained_backbone=False,  # all weights are restored from the checkpoint below
    num_classes=len(nu_scenes.CATEGORY_DICT),
    num_queries=128,
    lr_backbone=1e-4,
)
model = build_model(args, detr_args)
# NOTE: torch.load unpickles the file (arbitrary code execution risk);
# only load checkpoints from a trusted source.
checkpoint_dict = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
model.load_state_dict(checkpoint_dict["net"])
model.to(DEVICE)
# This notebook only runs inference: put layers such as dropout (if present)
# into evaluation mode so predictions are deterministic.
model.eval()
# Keep the decoder's image-attention weights so they can be plotted later.
for layer in model._model.detector.decoder.layers:
    for slot_to_image_attention in layer.image_attend:
        slot_to_image_attention.store_attention = True
train_loader, val_loaders = get_nuim_loaders(
    IMAGE_SIZE, offsets=[-2, -1, 0], config=config, args=args, train_batch_size=1
)
val_loader = val_loaders['val0']

In [None]:
def process_sample(sample_idx, train=True):
    """Run the model on a single dataset sample.

    Args:
        sample_idx (int): Index into the training dataset (``train=True``) or the
            validation dataset (``train=False``).
        train (bool): Which split to read the sample from.

    Returns:
        tuple: ``(outputs, class_scores, boxes, video)``. ``outputs`` is the raw
        model output, ``class_scores`` and ``boxes`` are
        ``outputs[...][0, 0, 0]``, and ``video`` is the sample's ``'video'``
        entry with the added batch dimension removed again.
    """
    dataset = (train_loader if train else val_loader).dataset
    batch = dataset[sample_idx]
    # Fake a batch of one: tensors get a leading batch dim, other values are wrapped in a list.
    for name in list(batch):
        value = batch[name]
        batch[name] = value.unsqueeze(0) if isinstance(value, torch.Tensor) else [value]
    batch = recursive_to(batch, DEVICE)
    with torch.no_grad():
        outputs, _state, _loss, _stats, _od_map_stuffs = model(batch)
    class_scores = outputs['class_scores'][0, 0, 0]
    boxes = outputs['boxes'][0, 0, 0]
    video = batch['video'][0]
    return outputs, class_scores, boxes, video

## Visualize Future Object Detection(s)

In [None]:
def visualize(image, classes, boxes, labels=None, score_threshold=0.1, background_class=None):
    """Draw predicted boxes (and optional text labels) on a de-normalized image.

    Args:
        image (Tensor): ImageNet-normalized image of size (3, H, W).
        classes (Tensor): Either floating-point per-class scores of size (M, C),
            or integer class ids of size (M,).
        boxes (Tensor or None): Of size (M, 4), encoded as (x1, y1, x2, y2).
            If None, only the de-normalized image is returned.
        labels (sequence, optional): One label per row of ``boxes``; each kept
            box gets its own label drawn just above its top-left corner.
        score_threshold (float): For score input, rows whose best score is
            below this value are treated as background and not drawn.
        background_class (int, optional): Class id that is never drawn.
            Defaults to C for score input (the id assigned to low-score rows);
            for integer input, the default None draws every box.

    Returns:
        np.ndarray: Image of size (H, W, 3).
    """
    vis = revert_imagenet_normalization(image)
    if boxes is None:
        # Use the same (H, W, 3) array layout as the drawing branch below.
        return vis.permute(1, 2, 0).cpu().numpy().copy()

    if classes.is_floating_point():  # per-class scores of size (M, C)
        if background_class is None:
            background_class = classes.size(1)
        scores, classes = classes.max(dim=1)
        classes[scores < score_threshold] = background_class
    if background_class is None:
        keep = torch.ones_like(classes, dtype=torch.bool)
    else:
        keep = classes != background_class

    kept_boxes = boxes[keep]
    colours = COLOURS[classes[keep]]
    # .copy() gives cv2.putText a contiguous, writable array.
    vis = draw_boxes(vis, kept_boxes, colours).permute(1, 2, 0).cpu().numpy().copy()
    if labels is not None:
        # Filter labels with the same mask as the boxes so each label stays on its own box.
        kept_labels = [label for label, is_kept in zip(labels, keep.tolist()) if is_kept]
        for label, box in zip(kept_labels, kept_boxes):
            cv2.putText(vis, str(label), (int(box[0]), int(box[1]) - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
    return vis

In [None]:
# Process a random sample from the training set. Its index is shown in the
# figure title; set SAMPLE_IDX to that value to reproduce the figure.
SAMPLE_IDX = np.random.randint(len(train_loader.dataset))
CONFIDENCE_THRESHOLD = 0.3  # threshold applied to the last column of class_scores
outputs, class_scores, boxes, video = process_sample(SAMPLE_IDX)

# Keep queries whose last-column score exceeds the threshold. Their query
# indices are used as labels (usable as OBJ_IDX_OVERRIDE for this sample).
idxs = (class_scores[:, -1] > CONFIDENCE_THRESHOLD).nonzero().squeeze(1)
score, box = class_scores[idxs], boxes[idxs]
labels = idxs.cpu().numpy()

# Plot Future Object Detections on the last (future) frame
plt.figure(dpi=200, frameon=False)
plt.axis('off')
plt.title(f"Future object detections (train sample {SAMPLE_IDX})")
img = visualize(video[-1], score, box, labels)
plt.imshow(img)
plt.show()

## Visualize attention

In [None]:
def plot_attention(class_scores, boxes, video, obj_idx, stride=32) -> None:
    """Plot one query's image attention on each past frame next to its predicted future box.

    Reads ``stored_attention`` from the last decoder layer of the global
    ``model``. Storage must have been enabled via ``store_attention = True``.
    Draws one panel per past frame, followed by a panel with the query's
    predicted box on the last frame of ``video``.

    Args:
        class_scores (Tensor): Per-query scores of size (Q, C).
        boxes (Tensor): Per-query boxes of size (Q, 4).
        video (Tensor): Normalized frames of size (T, 3, H, W); the last frame
            is the one the boxes are drawn on, the first T-1 are the past frames.
        obj_idx (int): Index of the query to visualize.
        stride (int): Downsampling factor between the input frames and the
            attended feature map.
    """
    query = int(obj_idx)
    num_past = video.shape[0] - 1
    num_panels = num_past + 1
    feat_h, feat_w = video.shape[-2] // stride, video.shape[-1] // stride
    image_attend = model._model.detector.decoder.layers[-1].image_attend

    plt.figure(frameon=False, figsize=(16 * num_panels, 9), dpi=50)
    # Plot past frames with attention overlay
    for frame_idx in range(num_past):
        plt.subplot(1, num_panels, frame_idx + 1)
        # Attention modules are indexed in the reverse order of the frames.
        att = image_attend[num_past - 1 - frame_idx].stored_attention[0]
        att = rearrange(att, 'q (h w) -> q h w', h=feat_h, w=feat_w)
        # Upsample only the selected query's map; upsampling all Q maps to full
        # resolution would waste hundreds of MB for a single panel.
        att_zoomed = torch.nn.functional.interpolate(
            att[query][None, None], scale_factor=stride, mode='bilinear'
        )[0, 0].cpu().numpy()
        img = rearrange(revert_imagenet_normalization(video[frame_idx]), 'c h w -> h w c')
        plt.imshow(img.cpu().numpy(), interpolation='nearest')
        plt.imshow(att_zoomed, alpha=np.clip(att_zoomed * 50, 0, 0.4),
                   interpolation='nearest', cmap='coolwarm')
        plt.axis('off')
    # Plot future frame with the query's bounding box
    plt.subplot(1, num_panels, num_panels)
    img = visualize(video[-1], class_scores[query:query + 1], boxes[query:query + 1])
    plt.subplots_adjust(wspace=0.01, hspace=0)
    plt.imshow(img)
    plt.axis('off')
    plt.show()

In [None]:
# Use a fixed training sample so the attention plots below are reproducible.
SAMPLE_IDX = 6
outputs, class_scores, boxes, video = process_sample(sample_idx=SAMPLE_IDX, train=True)

In [None]:
OBJ_RANK_IDX = 2  # 0 selects the query with the highest last-column score
OBJ_IDX_OVERRIDE = None  # set to a query index to bypass the ranking

# Rank at least the top 10 queries, and enough to cover OBJ_RANK_IDX,
# without asking topk for more queries than exist.
num_ranked = min(max(10, OBJ_RANK_IDX + 1), class_scores.shape[0])
obj_idxs = torch.topk(class_scores[:, -1], num_ranked).indices.cpu().numpy()
if OBJ_IDX_OVERRIDE is not None:
    obj_idx = OBJ_IDX_OVERRIDE
else:
    obj_idx = obj_idxs[OBJ_RANK_IDX]
plot_attention(class_scores, boxes, video, obj_idx)