<a href="https://colab.research.google.com/github/RO-AD/waymo-od-motion-pred/blob/main/tutorial/2_waymo_official_tutorial/hj-waymo_official_tutorial.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

tutorial/2_waymo_official_tutorial/hj-waymo_official_tutorial.ipynb

# Waymo Open Dataset Motion Tutorial

https://github.com/waymo-research/waymo-open-dataset/blob/master/tutorial/tutorial_motion.ipynb


## 튜토리얼 목차
- 데이터를 디코딩하고 해석하는 방법
- Tensorflow로 간단한 모델을 훈련시키는 방법

## 데이터셋 다운로드
- https://waymo.com/open/



## 패키지 설치
- https://github.com/waymo-research/waymo-open-dataset/blob/master/tutorial/tutorial.ipynb

In [1]:
!rm -rf waymo-od > /dev/null
!git clone https://github.com/waymo-research/waymo-open-dataset.git waymo-od
!cd waymo-od && git branch -a
!cd waymo-od && git checkout remotes/origin/master

Cloning into 'waymo-od'...
remote: Enumerating objects: 2404, done.[K
remote: Counting objects: 100% (2404/2404), done.[K
remote: Compressing objects: 100% (689/689), done.[K
remote: Total 2404 (delta 1707), reused 2352 (delta 1688), pack-reused 0[K
Receiving objects: 100% (2404/2404), 86.01 MiB | 25.29 MiB/s, done.
Resolving deltas: 100% (1707/1707), done.
* [32mmaster[m
  [31mremotes/origin/HEAD[m -> origin/master
  [31mremotes/origin/master[m
  [31mremotes/origin/om2[m
  [31mremotes/origin/r1.0[m
  [31mremotes/origin/r1.0-tf1.15[m
  [31mremotes/origin/r1.0-tf2.0[m
  [31mremotes/origin/r1.2[m
  [31mremotes/origin/r1.3[m
Note: switching to 'remotes/origin/master'.

You are in 'detached HEAD' state. You can look around, make experimental
changes and commit them, and you can discard any commits you make in this
state without impacting any branches by switching back to a branch.

If you want to create a new branch to retain commits you create, you may
do so (now or 

In [2]:
%%capture
!pip3 install --upgrade pip
!pip3 install waymo-open-dataset-tf-2-6-0

## 데이터셋 로드
데이터셋의 크기가 커서 구글 드라이브에 업로드하였고, 해당 데이터를 로드해야 한다.

In [3]:
# 내 드라이브를 현재 위치에 가져오기
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


## Import

In [4]:
FILENAME = '/content/drive/MyDrive/waymo-od-dataset/tutorial/uncompressed_tf_example_validation_validation_tfexample.tfrecord-00000-of-00150'

In [5]:
import math
import os
import uuid
import time

from matplotlib import cm
import matplotlib.animation as animation
import matplotlib.pyplot as plt

import numpy as np
from IPython.display import HTML
import itertools
import tensorflow as tf

from google.protobuf import text_format
from waymo_open_dataset.metrics.ops import py_metrics_ops
from waymo_open_dataset.metrics.python import config_util_py as config_util
from waymo_open_dataset.protos import motion_metrics_pb2


In [6]:
# Example field definition
roadgraph_features = {
    'roadgraph_samples/dir':
        tf.io.FixedLenFeature([30000, 3], tf.float32, default_value=None),
    'roadgraph_samples/id':
        tf.io.FixedLenFeature([30000, 1], tf.int64, default_value=None),
    'roadgraph_samples/type':
        tf.io.FixedLenFeature([30000, 1], tf.int64, default_value=None),
    'roadgraph_samples/valid':
        tf.io.FixedLenFeature([30000, 1], tf.int64, default_value=None),
    'roadgraph_samples/xyz':
        tf.io.FixedLenFeature([30000, 3], tf.float32, default_value=None),
}

# Features of other agents.
state_features = {
    'state/id':
        tf.io.FixedLenFeature([128], tf.float32, default_value=None),
    'state/type':
        tf.io.FixedLenFeature([128], tf.float32, default_value=None),
    'state/is_sdc':
        tf.io.FixedLenFeature([128], tf.int64, default_value=None),
    'state/tracks_to_predict':
        tf.io.FixedLenFeature([128], tf.int64, default_value=None),
    'state/current/bbox_yaw':
        tf.io.FixedLenFeature([128, 1], tf.float32, default_value=None),
    'state/current/height':
        tf.io.FixedLenFeature([128, 1], tf.float32, default_value=None),
    'state/current/length':
        tf.io.FixedLenFeature([128, 1], tf.float32, default_value=None),
    'state/current/timestamp_micros':
        tf.io.FixedLenFeature([128, 1], tf.int64, default_value=None),
    'state/current/valid':
        tf.io.FixedLenFeature([128, 1], tf.int64, default_value=None),
    'state/current/vel_yaw':
        tf.io.FixedLenFeature([128, 1], tf.float32, default_value=None),
    'state/current/velocity_x':
        tf.io.FixedLenFeature([128, 1], tf.float32, default_value=None),
    'state/current/velocity_y':
        tf.io.FixedLenFeature([128, 1], tf.float32, default_value=None),
    'state/current/width':
        tf.io.FixedLenFeature([128, 1], tf.float32, default_value=None),
    'state/current/x':
        tf.io.FixedLenFeature([128, 1], tf.float32, default_value=None),
    'state/current/y':
        tf.io.FixedLenFeature([128, 1], tf.float32, default_value=None),
    'state/current/z':
        tf.io.FixedLenFeature([128, 1], tf.float32, default_value=None),
    'state/future/bbox_yaw':
        tf.io.FixedLenFeature([128, 80], tf.float32, default_value=None),
    'state/future/height':
        tf.io.FixedLenFeature([128, 80], tf.float32, default_value=None),
    'state/future/length':
        tf.io.FixedLenFeature([128, 80], tf.float32, default_value=None),
    'state/future/timestamp_micros':
        tf.io.FixedLenFeature([128, 80], tf.int64, default_value=None),
    'state/future/valid':
        tf.io.FixedLenFeature([128, 80], tf.int64, default_value=None),
    'state/future/vel_yaw':
        tf.io.FixedLenFeature([128, 80], tf.float32, default_value=None),
    'state/future/velocity_x':
        tf.io.FixedLenFeature([128, 80], tf.float32, default_value=None),
    'state/future/velocity_y':
        tf.io.FixedLenFeature([128, 80], tf.float32, default_value=None),
    'state/future/width':
        tf.io.FixedLenFeature([128, 80], tf.float32, default_value=None),
    'state/future/x':
        tf.io.FixedLenFeature([128, 80], tf.float32, default_value=None),
    'state/future/y':
        tf.io.FixedLenFeature([128, 80], tf.float32, default_value=None),
    'state/future/z':
        tf.io.FixedLenFeature([128, 80], tf.float32, default_value=None),
    'state/past/bbox_yaw':
        tf.io.FixedLenFeature([128, 10], tf.float32, default_value=None),
    'state/past/height':
        tf.io.FixedLenFeature([128, 10], tf.float32, default_value=None),
    'state/past/length':
        tf.io.FixedLenFeature([128, 10], tf.float32, default_value=None),
    'state/past/timestamp_micros':
        tf.io.FixedLenFeature([128, 10], tf.int64, default_value=None),
    'state/past/valid':
        tf.io.FixedLenFeature([128, 10], tf.int64, default_value=None),
    'state/past/vel_yaw':
        tf.io.FixedLenFeature([128, 10], tf.float32, default_value=None),
    'state/past/velocity_x':
        tf.io.FixedLenFeature([128, 10], tf.float32, default_value=None),
    'state/past/velocity_y':
        tf.io.FixedLenFeature([128, 10], tf.float32, default_value=None),
    'state/past/width':
        tf.io.FixedLenFeature([128, 10], tf.float32, default_value=None),
    'state/past/x':
        tf.io.FixedLenFeature([128, 10], tf.float32, default_value=None),
    'state/past/y':
        tf.io.FixedLenFeature([128, 10], tf.float32, default_value=None),
    'state/past/z':
        tf.io.FixedLenFeature([128, 10], tf.float32, default_value=None),
}

traffic_light_features = {
    'traffic_light_state/current/state':
        tf.io.FixedLenFeature([1, 16], tf.int64, default_value=None),
    'traffic_light_state/current/valid':
        tf.io.FixedLenFeature([1, 16], tf.int64, default_value=None),
    'traffic_light_state/current/x':
        tf.io.FixedLenFeature([1, 16], tf.float32, default_value=None),
    'traffic_light_state/current/y':
        tf.io.FixedLenFeature([1, 16], tf.float32, default_value=None),
    'traffic_light_state/current/z':
        tf.io.FixedLenFeature([1, 16], tf.float32, default_value=None),
    'traffic_light_state/past/state':
        tf.io.FixedLenFeature([10, 16], tf.int64, default_value=None),
    'traffic_light_state/past/valid':
        tf.io.FixedLenFeature([10, 16], tf.int64, default_value=None),
    'traffic_light_state/past/x':
        tf.io.FixedLenFeature([10, 16], tf.float32, default_value=None),
    'traffic_light_state/past/y':
        tf.io.FixedLenFeature([10, 16], tf.float32, default_value=None),
    'traffic_light_state/past/z':
        tf.io.FixedLenFeature([10, 16], tf.float32, default_value=None),
}

features_description = {}
features_description.update(roadgraph_features)
features_description.update(state_features)
features_description.update(traffic_light_features)
     


## Visualize TF Example sample

### Create Dataset

In [7]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [8]:
dataset = tf.data.TFRecordDataset(FILENAME, compression_type='')
data = next(dataset.as_numpy_iterator())
parsed = tf.io.parse_single_example(data, features_description)

### Generate visualization images

In [11]:
def create_figure_and_axes(size_pixels) :
  """
  Initializes a unique figure and axes for plotting.

  Args:
    size_pixels: 출력 이미지의 픽셀 크기

  Returns:
    설정된 Figure와 Axes 객체
  """
  fig, ax = plt.subplots(1, 1, num=uuid.uuid4()) # 고유한 식별자 부여
  
  # Sets output image to pixel resolution
  dpi = 100  # 출력 이미지 해상도
  size_inches = size_pixels / dpi  # 출력 이미지의 크기를 인치 단위로 변환
  fig.set_size_inches([size_inches, size_inches])
  fig.set_dpi(dpi)
  fig.set_facecolor('white')
  ax.set_facecolor('white')
  ax.xaxis.label.set_color('black')
  ax.tick_params(axis='x', colors='black')
  ax.yaxis.label.set_color('black')
  ax.tick_params(axis='y', colors='black')
  fig.set_tight_layout(True)  # Figure 객체 여백 최소화
  ax.grid(False)  # 그리드 제거
  
  return fig, ax

def fig_canvas_image(fig) :
  """
  Returns a [H, W, 3] uint8 np.array image from fig.canvas.tostring_rgb().
  Matplotlib을 사용하여 그린 그래프를 이미지 데이터로 변환하는 함수
  Figure 객체의 캔버스를 RGB 형식의 바이트 문자열로 변환한 다음, 이를 NumPy 배열로 변환
  NumPy 배열의 형태를 `fig.canvas.get_width_height()[::-1] + (3,)`로 지정하여 [H, W, 3] 형태의 이미지로 변환

  Args:
    fig: Matplotlib의 Figure 객체

  Returns:
    이미지 데이터
  """
  # Just enough margin in the figure to display xticks and yticks.
  # Figure 객체의 여백 설정
  fig.subplots_adjust(
      left=0.08, bottom=0.08, right=0.98, top=0.98, wspace=0.0, hspace=0.0) 
  fig.canvas.draw()
  data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)

  return data.reshape(fig.canvas.get_width_height()[::-1] + (3,))

def get_colormap(num_agents) :
  """
  Compute a color map array of shape [num_agents, 4].
  Matplotlib의 cm 모듈을 사용하여 컬러맵 생성

  Args:
    num_agents: 정수. 컬러맵 생성할 때 사용할 색상의 개수

  Returns:
    컬러맵 배열. 배열의 크기는 [num_agents, 4]. 각 행은 하나의 색상, 네 번째 열은 알파값(투명도).
  """
  # jet 컬러맵 생성
  colors = cm.get_cmap('jet', num_agents)
  # num_agents 개수의 색상을 추출
  colors = colors(range(num_agents))
  # 색상 무작위로 섞기. 색상의 순서는 무작위화 되며, 각 에이전트가 고유한 색상을 가짐
  np.random.shuffle(colors)
  return colors

def get_viewport(all_states, all_states_mask) :
  """Gets the region containing the data.
  주어진 데이터 영역의 중심과 너비를 계산하는 함수

  Args: 
    all_states: 모든 에이전트의 상태를 나타내는 배열. shape은 [num_agents, num_steps, 2],
                num_agents는 에이전트의 수, num_steps는 각 에이전트의 시간 스텝 수
    all_states_mask: all_states와 같이 [num_agents, num_steps] shape의 바이너리 마스크 배열.
                     마스크의 각 원소는 all_states의 해당 상태가 유효한 상태인지 아닌지를 나타냄
  
  Returns:
    center_y: float. 데이터의 중심 y 좌표.
    center_x: float. 데이터의 중심 x 좌표.
    width: float. 데이터의 너비.
  """
  # 유효한 상태만 선택. valid_states의 shape은 [num_valid_states, 2]. num_valid_states는 유효한 상태의 수
  valid_states = all_states[all_states_mask]
  # 모든 y 좌표와 x 좌표를 각각 추출
  all_y = valid_states[..., 1]
  all_x = valid_states[..., 0]

  # y, x 좌표의 최대값과 최소값의 평균 저장
  center_y = (np.max(all_y) + np.min(all_y)) / 2
  center_x = (np.max(all_x) + np.min(all_x)) / 2

  # np.ptp 함수는 배열의 최대값과 최소값의 차이 계산. y 좌표와 x 좌표의 범위를 각각 계산. 
  range_y = np.ptp(all_y)
  range_x = np.ptp(all_x)

  # 이 범위 중 더 큰 값을 width로 저장.
  width = max(range_y, range_x)

  return center_y, center_x, width


def visualize_one_step(states,
                       mask,
                       roadgraph,
                       title,
                       center_y,
                       center_x,
                       width,
                       color_map,
                       size_pixels=1000):
  """Generate visualization for a single step.
  하나의 스텝(시간)에 대한 시각화 생성

  Args: 
    states: 에이전트의 위치 정보를 담은 배열. 
            shape은 [num_agents, num_steps, 2]. num_agents는 차량 수. num_steps는 시간 단계 수.
            첫 번째 차원은 차량. 두 번째 차원은 시간. 세 번째 차원은 차량의 위치.
    mask: states에서 해당 스텝에서 활동중인 에이전트를 나타내는 바이너리 마스크. shape은 [num_agents].
    roadgraph: 도로 그래프 정보를 담은 배열. 형태는 [num_points, 2].
    title: 그래프 제목.
    center_y: 데이터의 중심 y 좌표.
    center_x: 데이터의 중심 x 좌표.
    width: 데이터의 너비.
    color_map: 색상 맵 배열. 형태는 [num_agents, 4].
    size_pixels: 출력 이미지의 픽셀 크기.

  Returns:
    시각화된 이미지를 나타내는 [H, W, 3] 형태의 numpy 배열
  """

  # Create figure and axes. 새로운 이미지 생성
  fig, ax = create_figure_and_axes(size_pixels=size_pixels)

  # 도로 그래프를 점(Plot)으로 그리기
  rg_pts = roadgraph[:, :2].T
  ax.plot(rg_pts[0, :], rg_pts[1, :], 'k.', alpha=1, ms=2)


  masked_x = states[:, 0][mask]
  masked_y = states[:, 1][mask]
  colors = color_map[mask]

  # 각 에이전트의 현재 위치를 산점도로 그리기
  ax.scatter(
      masked_x,
      masked_y,
      marker='o',
      linewidths=3,
      color=colors,
  )

  # Title.
  ax.set_title(title)

  # 축 정의. 축은 적어도 한 쪽은 10m 이상이어야 하며, 모든 에이전트가 표시되도록 축 범위가 조정됨
  size = max(10, width * 1.0)
  ax.axis([
      -size / 2 + center_x, size / 2 + center_x, -size / 2 + center_y,
      size / 2 + center_y
  ])
  ax.set_aspect('equal')

  # 이미지를 Numpy 배열 형태로 반환
  image = fig_canvas_image(fig)
  plt.close(fig)

  return image

In [12]:
def visualize_all_agents_smooth(
    decoded_example,
    size_pixels=1000,
):
  """Visualizes all agent predicted trajectories in a serie of images.
  주어진 예측된 에이전트들의 모든 경로 시각화

  Args:
    decoded_example: 예측된 에이전트들의 정보를 가진 딕셔너리. 예측된 경로를 시각화하기 위해 필요한 모든 정보.
    size_pixels: 출력 이미지의 크기의 픽셀.

  Returns:
    이미지 리스트 반환. [H, W, 3].
  """
  # decoded_example에서 에이전트의 과거, 현재, 미래 위치와 유효한 마스크 가져오기
  # [num_agents, num_past_steps, 2] float32.
  past_states = tf.stack(
      [decoded_example['state/past/x'], decoded_example['state/past/y']],
      -1).numpy()
  past_states_mask = decoded_example['state/past/valid'].numpy() > 0.0

  # [num_agents, 1, 2] float32.
  current_states = tf.stack(
      [decoded_example['state/current/x'], decoded_example['state/current/y']],
      -1).numpy()
  current_states_mask = decoded_example['state/current/valid'].numpy() > 0.0

  # [num_agents, num_future_steps, 2] float32.
  future_states = tf.stack(
      [decoded_example['state/future/x'], decoded_example['state/future/y']],
      -1).numpy()
  future_states_mask = decoded_example['state/future/valid'].numpy() > 0.0

  # [num_points, 3] float32.
  roadgraph_xyz = decoded_example['roadgraph_samples/xyz'].numpy()

  num_agents, num_past_steps, _ = past_states.shape
  num_future_steps = future_states.shape[1]

  color_map = get_colormap(num_agents)

  # all_states, all_states_mask를 생성하여 모든 상태와 해당 마스크 연결
  # [num_agens, num_past_steps + 1 + num_future_steps, depth] float32.
  all_states = np.concatenate([past_states, current_states, future_states], 1)

  # [num_agens, num_past_steps + 1 + num_future_steps] float32.
  all_states_mask = np.concatenate(
      [past_states_mask, current_states_mask, future_states_mask], 1)

  # viewport 위치와 크기 계산
  center_y, center_x, width = get_viewport(all_states, all_states_mask)

  images = []

  # past_states와 past_states_mask에 대한 이미지 생성
  # Generate images from past time steps.
  for i, (s, m) in enumerate(
      zip(
          np.split(past_states, num_past_steps, 1),
          np.split(past_states_mask, num_past_steps, 1))):
    im = visualize_one_step(s[:, 0], m[:, 0], roadgraph_xyz,
                            'past: %d' % (num_past_steps - i), center_y,
                            center_x, width, color_map, size_pixels)
    images.append(im) # 각 시간 단계마다 생성된 이미지를 리스트에 추가

  # current_states, current_states_mask에 대한 이미지 생성
  # Generate one image for the current time step.
  s = current_states
  m = current_states_mask

  im = visualize_one_step(s[:, 0], m[:, 0], roadgraph_xyz, 'current', center_y,
                          center_x, width, color_map, size_pixels)
  images.append(im) # 각 시간 단계마다 생성된 이미지를 리스트에 추가

  # future_states, future_states_mask에 대한 이미지 생성
  # Generate images from future time steps.
  for i, (s, m) in enumerate(
      zip(
          np.split(future_states, num_future_steps, 1),
          np.split(future_states_mask, num_future_steps, 1))):
    im = visualize_one_step(s[:, 0], m[:, 0], roadgraph_xyz,
                            'future: %d' % (i + 1), center_y, center_x, width,
                            color_map, size_pixels)
    images.append(im) # 각 시간 단계마다 생성된 이미지를 리스트에 추가

  return images


images = visualize_all_agents_smooth(parsed)

  colors = cm.get_cmap('jet', num_agents)


### Display animation.

In [13]:
def create_animation(images):
  """ Creates a Matplotlib animation of the given images.
  Matplotlib 애니메이션 생성

  Args:
    images: A list of numpy arrays representing the images. 이미지 리스트.

  Returns:
    A matplotlib.animation.Animation. 애니메이션

  Usage:
    .avi 확장자로 저장하거나 HTML5 동영상으로 출력 가능
    anim = create_animation(images)
    anim.save('/tmp/animation.avi')
    HTML(anim.to_html5_video())
  """

  plt.ioff()
  fig, ax = plt.subplots()
  dpi = 100
  size_inches = 1000 / dpi
  fig.set_size_inches([size_inches, size_inches])
  plt.ion()

  def animate_func(i):
    # 애니메이션의 각 프레임마다 수행할 작업.
    ax.imshow(images[i])
    ax.set_xticks([])
    ax.set_yticks([])
    ax.grid('off')

  # 시각화된 이미지들을 애니메이션으로 생성.
  # frames: 프레임 수. interval: 애니메이션 속도.
  anim = animation.FuncAnimation(
      fig, animate_func, frames=len(images) // 2, interval=100)
  plt.close(fig)
  return anim


anim = create_animation(images[::5])
HTML(anim.to_html5_video())


## Simple MLP model with TF

이것은 입력 구문 분석 및 메트릭 계산을 보여주는 매우 간단한 예제 모델이다. 전혀 경쟁력이 없다.

에러가 발생해서 keras 버전 지정하여 재설치

```
!pip install keras==2.6
```


In [9]:
!pip install keras==2.6

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting keras==2.6
  Downloading keras-2.6.0-py2.py3-none-any.whl (1.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m13.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: keras
  Attempting uninstall: keras
    Found existing installation: keras 2.12.0
    Uninstalling keras-2.12.0:
      Successfully uninstalled keras-2.12.0
Successfully installed keras-2.6.0
[0m

In [10]:
def _parse(value):
  decoded_example = tf.io.parse_single_example(value, features_description)

  past_states = tf.stack([
      decoded_example['state/past/x'], decoded_example['state/past/y'],
      decoded_example['state/past/length'], decoded_example['state/past/width'],
      decoded_example['state/past/bbox_yaw'],
      decoded_example['state/past/velocity_x'],
      decoded_example['state/past/velocity_y']
  ], -1)

  cur_states = tf.stack([
      decoded_example['state/current/x'], decoded_example['state/current/y'],
      decoded_example['state/current/length'],
      decoded_example['state/current/width'],
      decoded_example['state/current/bbox_yaw'],
      decoded_example['state/current/velocity_x'],
      decoded_example['state/current/velocity_y']
  ], -1)

  input_states = tf.concat([past_states, cur_states], 1)[..., :2]

  future_states = tf.stack([
      decoded_example['state/future/x'], decoded_example['state/future/y'],
      decoded_example['state/future/length'],
      decoded_example['state/future/width'],
      decoded_example['state/future/bbox_yaw'],
      decoded_example['state/future/velocity_x'],
      decoded_example['state/future/velocity_y']
  ], -1)

  gt_future_states = tf.concat([past_states, cur_states, future_states], 1)

  past_is_valid = decoded_example['state/past/valid'] > 0
  current_is_valid = decoded_example['state/current/valid'] > 0
  future_is_valid = decoded_example['state/future/valid'] > 0
  gt_future_is_valid = tf.concat(
      [past_is_valid, current_is_valid, future_is_valid], 1)

  # If a sample was not seen at all in the past, we declare the sample as
  # invalid.
  sample_is_valid = tf.reduce_any(
      tf.concat([past_is_valid, current_is_valid], 1), 1)

  inputs = {
      'input_states': input_states,
      'gt_future_states': gt_future_states,
      'gt_future_is_valid': gt_future_is_valid,
      'object_type': decoded_example['state/type'],
      'tracks_to_predict': decoded_example['state/tracks_to_predict'] > 0,
      'sample_is_valid': sample_is_valid,
  }
  return inputs


def _default_metrics_config():
  config = motion_metrics_pb2.MotionMetricsConfig()
  config_text = """
  track_steps_per_second: 10
  prediction_steps_per_second: 2
  track_history_samples: 10
  track_future_samples: 80
  speed_lower_bound: 1.4
  speed_upper_bound: 11.0
  speed_scale_lower: 0.5
  speed_scale_upper: 1.0
  step_configurations {
    measurement_step: 5
    lateral_miss_threshold: 1.0
    longitudinal_miss_threshold: 2.0
  }
  step_configurations {
    measurement_step: 9
    lateral_miss_threshold: 1.8
    longitudinal_miss_threshold: 3.6
  }
  step_configurations {
    measurement_step: 15
    lateral_miss_threshold: 3.0
    longitudinal_miss_threshold: 6.0
  }
  max_predictions: 6
  """
  text_format.Parse(config_text, config)

  return config


class SimpleModel(tf.keras.Model):
  """A simple one-layer regressor."""

  def __init__(self, num_agents_per_scenario, num_states_steps,
               num_future_steps):
    super(SimpleModel, self).__init__()
    self._num_agents_per_scenario = num_agents_per_scenario
    self._num_states_steps = num_states_steps
    self._num_future_steps = num_future_steps
    self.regressor = tf.keras.layers.Dense(num_future_steps * 2)

  def call(self, states):
    states = tf.reshape(states, (-1, self._num_states_steps * 2))
    pred = self.regressor(states)
    pred = tf.reshape(
        pred, [-1, self._num_agents_per_scenario, self._num_future_steps, 2])
    return pred


class MotionMetrics(tf.keras.metrics.Metric):
  """Wrapper for motion metrics computation."""

  def __init__(self, config):
    super().__init__()
    self._prediction_trajectory = []
    self._prediction_score = []
    self._ground_truth_trajectory = []
    self._ground_truth_is_valid = []
    self._prediction_ground_truth_indices = []
    self._prediction_ground_truth_indices_mask = []
    self._object_type = []
    self._metrics_config = config

  def reset_state(self):
    self._prediction_trajectory = []
    self._prediction_score = []
    self._ground_truth_trajectory = []
    self._ground_truth_is_valid = []
    self._prediction_ground_truth_indices = []
    self._prediction_ground_truth_indices_mask = []
    self._object_type = []

  def update_state(self, prediction_trajectory, prediction_score,
                   ground_truth_trajectory, ground_truth_is_valid,
                   prediction_ground_truth_indices,
                   prediction_ground_truth_indices_mask, object_type):
    self._prediction_trajectory.append(prediction_trajectory)
    self._prediction_score.append(prediction_score)
    self._ground_truth_trajectory.append(ground_truth_trajectory)
    self._ground_truth_is_valid.append(ground_truth_is_valid)
    self._prediction_ground_truth_indices.append(
        prediction_ground_truth_indices)
    self._prediction_ground_truth_indices_mask.append(
        prediction_ground_truth_indices_mask)
    self._object_type.append(object_type)

  def result(self):
    # [batch_size, num_preds, 1, 1, steps, 2].
    # The ones indicate top_k = 1, num_agents_per_joint_prediction = 1.
    prediction_trajectory = tf.concat(self._prediction_trajectory, 0)
    # [batch_size, num_preds, 1].
    prediction_score = tf.concat(self._prediction_score, 0)
    # [batch_size, num_agents, gt_steps, 7].
    ground_truth_trajectory = tf.concat(self._ground_truth_trajectory, 0)
    # [batch_size, num_agents, gt_steps].
    ground_truth_is_valid = tf.concat(self._ground_truth_is_valid, 0)
    # [batch_size, num_preds, 1].
    prediction_ground_truth_indices = tf.concat(
        self._prediction_ground_truth_indices, 0)
    # [batch_size, num_preds, 1].
    prediction_ground_truth_indices_mask = tf.concat(
        self._prediction_ground_truth_indices_mask, 0)
    # [batch_size, num_agents].
    object_type = tf.cast(tf.concat(self._object_type, 0), tf.int64)

    # We are predicting more steps than needed by the eval code. Subsample.
    interval = (
        self._metrics_config.track_steps_per_second //
        self._metrics_config.prediction_steps_per_second)
    prediction_trajectory = prediction_trajectory[...,
                                                  (interval - 1)::interval, :]

    return py_metrics_ops.motion_metrics(
        config=self._metrics_config.SerializeToString(),
        prediction_trajectory=prediction_trajectory,
        prediction_score=prediction_score,
        ground_truth_trajectory=ground_truth_trajectory,
        ground_truth_is_valid=ground_truth_is_valid,
        prediction_ground_truth_indices=prediction_ground_truth_indices,
        prediction_ground_truth_indices_mask=prediction_ground_truth_indices_mask,
        object_type=object_type)


model = SimpleModel(128, 11, 80)
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
loss_fn = tf.keras.losses.MeanSquaredError()
metrics_config = _default_metrics_config()
motion_metrics = MotionMetrics(metrics_config)
metric_names = config_util.get_breakdown_names_from_motion_config(
    metrics_config)


def train_step(inputs):
  with tf.GradientTape() as tape:
    # [batch_size, num_agents, D]
    states = inputs['input_states']

    # Predict. [batch_size, num_agents, steps, 2].
    pred_trajectory = model(states, training=True)

    # Set training target.
    prediction_start = metrics_config.track_history_samples + 1

    # [batch_size, num_agents, steps, 7]
    gt_trajectory = inputs['gt_future_states']
    gt_targets = gt_trajectory[..., prediction_start:, :2]

    # [batch_size, num_agents, steps]
    gt_is_valid = inputs['gt_future_is_valid']
    # [batch_size, num_agents, steps]
    weights = (
        tf.cast(inputs['gt_future_is_valid'][..., prediction_start:],
                tf.float32) *
        tf.cast(inputs['tracks_to_predict'][..., tf.newaxis], tf.float32))

    loss_value = loss_fn(gt_targets, pred_trajectory, sample_weight=weights)
  grads = tape.gradient(loss_value, model.trainable_weights)
  optimizer.apply_gradients(zip(grads, model.trainable_weights))

  # [batch_size, num_agents, steps, 2] ->
  # [batch_size, num_agents, 1, 1, steps, 2].
  # The added dimensions are top_k = 1, num_agents_per_joint_prediction = 1.
  pred_trajectory = pred_trajectory[:, :, tf.newaxis, tf.newaxis]

  # Fake the score since this model does not generate any score per predicted
  # trajectory.
  pred_score = tf.ones(shape=tf.shape(pred_trajectory)[:3])

  # [batch_size, num_agents].
  object_type = inputs['object_type']

  # [batch_size, num_agents].
  batch_size = tf.shape(inputs['tracks_to_predict'])[0]
  num_samples = tf.shape(inputs['tracks_to_predict'])[1]

  pred_gt_indices = tf.range(num_samples, dtype=tf.int64)
  # [batch_size, num_agents, 1].
  pred_gt_indices = tf.tile(pred_gt_indices[tf.newaxis, :, tf.newaxis],
                            (batch_size, 1, 1))
  # [batch_size, num_agents, 1].
  pred_gt_indices_mask = inputs['tracks_to_predict'][..., tf.newaxis]

  motion_metrics.update_state(pred_trajectory, pred_score, gt_trajectory,
                              gt_is_valid, pred_gt_indices,
                              pred_gt_indices_mask, object_type)

  return loss_value


dataset = tf.data.TFRecordDataset(FILENAME)
dataset = dataset.map(_parse)
dataset = dataset.batch(32)

epochs = 2
num_batches_per_epoch = 10

for epoch in range(epochs):
  print('\nStart of epoch %d' % (epoch,))
  start_time = time.time()

  # Iterate over the batches of the dataset.
  for step, batch in enumerate(dataset):
    loss_value = train_step(batch)

    # Log every 10 batches.
    if step % 10 == 0:
      print('Training loss (for one batch) at step %d: %.4f' %
            (step, float(loss_value)))
      print('Seen so far: %d samples' % ((step + 1) * 64))

    if step >= num_batches_per_epoch:
      break

  # Display metrics at the end of each epoch.
  train_metric_values = motion_metrics.result()
  for i, m in enumerate(
      ['min_ade', 'min_fde', 'miss_rate', 'overlap_rate', 'map']):
    for j, n in enumerate(metric_names):
      print('{}/{}: {}'.format(m, n, train_metric_values[i, j]))


Start of epoch 0
Training loss (for one batch) at step 0: 2847022.7500
Seen so far: 64 samples
Training loss (for one batch) at step 10: 1037661.0625
Seen so far: 704 samples
min_ade/TYPE_VEHICLE_5: 5824.185546875
min_ade/TYPE_VEHICLE_9: 6199.0654296875
min_ade/TYPE_VEHICLE_15: 6473.201171875
min_ade/TYPE_PEDESTRIAN_5: 5711.22314453125
min_ade/TYPE_PEDESTRIAN_9: 6084.462890625
min_ade/TYPE_PEDESTRIAN_15: 6340.5634765625
min_ade/TYPE_CYCLIST_5: 4747.6552734375
min_ade/TYPE_CYCLIST_9: 4969.04248046875
min_ade/TYPE_CYCLIST_15: 5171.52197265625
min_fde/TYPE_VEHICLE_5: 6121.1044921875
min_fde/TYPE_VEHICLE_9: 5386.2001953125
min_fde/TYPE_VEHICLE_15: 3629.50390625
min_fde/TYPE_PEDESTRIAN_5: 6004.4443359375
min_fde/TYPE_PEDESTRIAN_9: 5286.3935546875
min_fde/TYPE_PEDESTRIAN_15: 3336.2470703125
min_fde/TYPE_CYCLIST_5: 4783.87890625
min_fde/TYPE_CYCLIST_9: 4592.359375
min_fde/TYPE_CYCLIST_15: 3965.519775390625
miss_rate/TYPE_VEHICLE_5: 1.0
miss_rate/TYPE_VEHICLE_9: 1.0
miss_rate/TYPE_VEHICLE_15:

# Senario 데이터셋

### 데이터셋 구조 파일(proto)을 이용하여 로드
- tfrecord는 압축된 형태로 존재한다. `TFRecordDataset`을 통해 데이터를 읽을 수 있지만, 그 전에 압축된 데이터를 해제할 수 있도록 구조 정의가 필요하다.

In [None]:
type(data)

In [None]:
data.numpy()[:100]

In [None]:
from waymo_open_dataset.protos import scenario_pb2 
# Protocol Buffers (PB)
# 구글에서 개발한 경량 데이터 교환 형식 중 하나로, 프로그래밍 언어와 플랫폼 간의 데이터 통신에 사용

dataset = tf.data.TFRecordDataset(FILENAME, compression_type='')

for _, data in enumerate(dataset) :
  scenario = scenario_pb2.Scenario()
  scenario.ParseFromString(data.numpy())
  break

In [None]:
scenario.scenario_id

- past : scenario.tracks[0].states[:10] # 10개
- current : scenario.tracks[0].states[10] # 1개
- future : scenario.tracks[0].states[11:] # 80개
- state/id => scenario.tracks[0].states.id
- traffic_light_state/current/state => scenario.dynamic_map_states[10].lane_states[0].state

In [None]:
dataset = tf.data.TFRecordDataset(FILENAME, compression_type='')
for num_data, data in enumerate(dataset) :
  pass

num_data

- 해당 tfrecord 내부에는 269개의 데이터가 있다.

## 데이터셋 구조

- scenario.proto
  - message ObjectState
  - message Track
  - message DynamicMapState
  - message RequiredPrediction
  - message Scenario
- map.proto
  - message Map
  - message DynamicState
  - message TrafficSignalLaneState
  - message MapFeature
  - message MapPoint
  - message BoundarySegment
  - message LaneNeighbor
  - message LaneCenter
  - message RoadEdge
  - message RoadLine
  - message StopSign
  - message Crosswalk
  - message SpeedBump

## Senario TF(TensorFlow 데이터 포맷) Example sample 시각화

### 정적 map 데이터 출력 

In [None]:
def poly2position(poly) :
  x, y, z = [], [], []
  point(poly)
  for point in poly :
    x.append(point.x)
    y.append(point.y)
    z.append(point.z)
  return x, y, z

In [None]:
def visualizationMap(ax, map_features) :
  for feature in map_features :
    lane_x, lane_y, lane_z = poly2position(feature.lane.polyline)
    ax.plot(lane_x, lane_y, '-', c='#d6d6d6', lw=3) # 도로부분이라서 연한 회색
    
    road_line_x, road_line_y, road_line_z = poly2position(feature.road_line.polyline)
    ax.plot(road_line_x, road_line_y, '--', c='gray') # 차선이라서 일단 회색 대시 선
    
    road_edge_x, road_edge_y, road_edge_z = poly2position(feature.road_edge.polyline)
    ax.plot(road_edge_x, road_edge_y, '-', c='black') # 엣지 부분은 검은 실선
    
    for point in [feature.stop_sign.position]:
        pass #ax.plot(point.x, point.y, marker='o', c='red', ms=10) # 데이터가 좀 이상함,,,
    
    crosswalk_x, crosswalk_y, crosswalk_z = poly2position(feature.crosswalk.polygon)
    ax.plot(crosswalk_x, crosswalk_y, '-', c='#f1f289', lw=3)
    
    speed_bump_x, speed_bump_y, speed_bump_z = poly2position(feature.speed_bump.polygon)
    ax.plot(speed_bump_x, speed_bump_y, '-', c='#fab6e6', lw=3)

In [None]:
for data in dataset:
  proto_string = data.numpy()
  scenario = scenario_pb2.Scenario()
  scenario.ParseFromString(proto_string)

  break

fig, ax = plt.subplots(figsize=(10,10))
ax.title.set_text("Scenario ID : " + scenario.scenario_id)

visualizationMap(ax, scenario.map_features)