# This notebook converts Waymo TFRecords (tf.Example format) into JSON files, one self-contained scene per file.

In [8]:
import os

# Directory containing the Waymo TFRecord shards. Set the WAYMO_TF_RECORD_DIR
# environment variable to point the notebook elsewhere without editing code.
WAYMO_TF_RECORD_FILE_PATH = os.environ.get("WAYMO_TF_RECORD_DIR", '/home/yz8733/Github/isaac-rl/data')

In [9]:
import math
import os
import uuid
import time

from matplotlib import cm
import matplotlib.animation as animation
import matplotlib.pyplot as plt

import numpy as np
from IPython.display import HTML
import itertools
import tensorflow as tf
from pathlib import Path

# pathlib view of the TFRecord directory; used below to list the record files.
waymo_record_folder = Path(WAYMO_TF_RECORD_FILE_PATH)



from google.protobuf import text_format
from waymo_open_dataset.metrics.ops import py_metrics_ops
from waymo_open_dataset.metrics.python import config_util_py as config_util
from waymo_open_dataset.protos import motion_metrics_pb2

In [5]:
# Parsing spec for one scene record. Every feature is fixed-length (padded),
# so each entry is a FixedLenFeature with an explicit shape and no default.
num_map_samples = 30000
_NUM_AGENTS = 128
_NUM_TRAFFIC_LIGHTS = 16


def _fixed_len(shape, dtype):
    """Required fixed-length feature of the given shape/dtype (no default value)."""
    return tf.io.FixedLenFeature(shape, dtype, default_value=None)


# Map: one row per roadgraph sample point; (name, row width, dtype).
roadgraph_features = {
    f'roadgraph_samples/{name}': _fixed_len([num_map_samples, width], dtype)
    for name, width, dtype in (
        ('dir', 3, tf.float32),
        ('id', 1, tf.int64),
        ('type', 1, tf.int64),
        ('valid', 1, tf.int64),
        ('xyz', 3, tf.float32),
    )
}

# Per-timestep agent fields, shared by the current / future / past blocks.
_STATE_STEP_FIELDS = (
    ('bbox_yaw', tf.float32),
    ('height', tf.float32),
    ('length', tf.float32),
    ('timestamp_micros', tf.int64),
    ('valid', tf.int64),
    ('vel_yaw', tf.float32),
    ('velocity_x', tf.float32),
    ('velocity_y', tf.float32),
    ('width', tf.float32),
    ('x', tf.float32),
    ('y', tf.float32),
    ('z', tf.float32),
)

# Features of other agents: per-agent scalars first, then one
# [agents, steps] block per phase (current=1, future=80, past=10 steps).
state_features = {
    'state/id': _fixed_len([_NUM_AGENTS], tf.float32),
    'state/type': _fixed_len([_NUM_AGENTS], tf.float32),
    'state/is_sdc': _fixed_len([_NUM_AGENTS], tf.int64),
    'state/tracks_to_predict': _fixed_len([_NUM_AGENTS], tf.int64),
}
state_features.update({
    f'state/{phase}/{field}': _fixed_len([_NUM_AGENTS, steps], dtype)
    for phase, steps in (('current', 1), ('future', 80), ('past', 10))
    for field, dtype in _STATE_STEP_FIELDS
})

# Traffic lights: [steps, lights] per phase (current=1, past=10 steps).
traffic_light_features = {
    f'traffic_light_state/{phase}/{field}': _fixed_len([steps, _NUM_TRAFFIC_LIGHTS], dtype)
    for phase, steps in (('current', 1), ('past', 10))
    for field, dtype in (
        ('state', tf.int64),
        ('valid', tf.int64),
        ('x', tf.float32),
        ('y', tf.float32),
        ('z', tf.float32),
    )
}

features_description = {
    **roadgraph_features,
    **state_features,
    **traffic_light_features,
}

In [10]:
# Sorted so that scene numbering (scene_000000.json, ...) is reproducible:
# Path.iterdir() yields entries in an arbitrary, filesystem-dependent order.
filenames = sorted(
    p.name for p in waymo_record_folder.iterdir()
    if p.is_file() and p.name.startswith("uncompressed_tf_example")
)

print(filenames)

['uncompressed_tf_example_training_training_tfexample.tfrecord-00000-of-01000']


In [25]:
parsed = []

for file in filenames:
    record_path = os.path.join(WAYMO_TF_RECORD_FILE_PATH, file)
    records = tf.data.TFRecordDataset(record_path, compression_type="")
    scenes_before = len(parsed)
    # Each raw record is a string tensor holding one serialized example (one scene).
    parsed.extend(
        tf.io.parse_single_example(raw, features_description) for raw in records
    )
    print(file, "Number of Scenes:", len(parsed) - scenes_before)


uncompressed_tf_example_training_training_tfexample.tfrecord-00000-of-01000 Number of Scenes: 455


In [34]:
import os, json, hashlib
import numpy as np
import tensorflow as tf

def ensure_dir(p):
    """Create directory ``p`` plus any missing parents; no error if it already exists."""
    os.makedirs(name=p, exist_ok=True)

def _to_1d(x):
    return np.asarray(x).reshape(-1)

def _safe_float(x):
    return float(x)  # keep exact float32->python float (no rounding)

def order_polyline_pca_xy(xyz: np.ndarray) -> np.ndarray:
    """
    Return indices that sort points along their dominant XY axis (first principal axis).

    Args:
        xyz: (n, >=2) array of points; only columns 0 and 1 (x, y) are used.

    Returns:
        Integer index array of length n. For n <= 2 the identity order is returned.

    Notes:
        - The sign of an SVD singular vector is arbitrary, so the order may run
          in either direction along the axis; callers that care about direction
          must orient the result themselves.
        - A stable sort keeps points with equal projections (e.g. duplicate
          samples) in their input order, so the result is deterministic.
    """
    n = xyz.shape[0]
    if n <= 2:
        return np.arange(n)
    p = xyz[:, :2].astype(np.float64)
    p0 = p.mean(axis=0, keepdims=True)
    X = p - p0
    _, _, vt = np.linalg.svd(X, full_matrices=False)
    axis = vt[0]      # principal axis, shape (2,)
    s = X @ axis      # signed projection of each point onto that axis
    return np.argsort(s, kind="stable")

def _hash_bytes(b: bytes) -> str:
    return hashlib.md5(b).hexdigest()

# -----------------------------
# Extract everything Waymo provides (no filtering, no downsampling)
# -----------------------------
def extract_road_polylines_all(parsed_ex, order_points=True):
    """
    Group every valid roadgraph sample into one polyline per (type, id) pair.

    No groups are dropped and no points are subsampled; only samples whose
    ``roadgraph_samples/valid`` flag is 0 are removed.

    Args:
        parsed_ex: mapping from feature name to a tensor-like object exposing
            ``.numpy()`` (e.g. the output of ``tf.io.parse_single_example``).
        order_points: if True, sort each group's points along its dominant XY
            axis and then flip the order if it runs against the group's ``dir``
            vectors (the PCA axis sign is arbitrary). If False, points keep the
            order in which they appear in the record.
            NOTE(review): PCA ordering assumes a roughly straight feature; curved
            or closed features (loops, polygons) may come out scrambled. Check
            whether record order is already sequential before relying on it.

    Returns:
        (polylines, stats): ``polylines`` is a list of dicts with keys ``type``,
        ``id``, ``n``, ``xyz`` and ``dir`` (nested Python lists), in ascending
        (type, id) order; ``stats`` holds point and group counts.
    """
    rg_valid = _to_1d(parsed_ex["roadgraph_samples/valid"].numpy()).astype(bool)
    rg_xyz   = parsed_ex["roadgraph_samples/xyz"].numpy()          # (num_map_samples, 3)
    rg_dir   = parsed_ex["roadgraph_samples/dir"].numpy()          # (num_map_samples, 3)
    rg_id    = _to_1d(parsed_ex["roadgraph_samples/id"].numpy())
    rg_type  = _to_1d(parsed_ex["roadgraph_samples/type"].numpy())

    # only remove invalid points (that's not "downsampling"; it's required)
    xyz   = rg_xyz[rg_valid]
    direc = rg_dir[rg_valid]
    ids   = rg_id[rg_valid].astype(np.int64)
    types = rg_type[rg_valid].astype(np.int64)

    polylines = []
    num_groups = 0
    if xyz.shape[0] > 0:
        # group by (type,id); np.unique returns the key rows sorted lexicographically
        keys = np.stack([types, ids], axis=1)  # (M,2)
        uniq, inv = np.unique(keys, axis=0, return_inverse=True)
        # Some NumPy 2.0 releases return a non-1-D inverse when `axis` is given;
        # flatten so the per-group boolean mask below is always 1-D.
        inv = inv.reshape(-1)
        num_groups = int(uniq.shape[0])

        for gi, (t, i) in enumerate(uniq):
            m = (inv == gi)
            pts = xyz[m]
            d   = direc[m]

            if order_points:
                idx = order_polyline_pca_xy(pts)
                pts = pts[idx]
                d   = d[idx]
                # Orient the polyline to agree with the summed direction vectors
                # (assumes `dir` points along the feature, as in Waymo's format —
                # confirm). A zero sum leaves the order unchanged.
                travel = pts[-1, :2] - pts[0, :2]
                if float(travel @ d[:, :2].sum(axis=0)) < 0.0:
                    pts = pts[::-1]
                    d   = d[::-1]

            polylines.append({
                "type": int(t),
                "id": int(i),
                "n": int(pts.shape[0]),
                "xyz": pts.tolist(),   # keep full precision lists
                "dir": d.tolist(),
            })

    stats = {
        "road_valid_points": int(xyz.shape[0]),
        "num_groups_total": num_groups,
        "num_polylines_saved": int(len(polylines)),
    }
    return polylines, stats

def extract_agents_start_end_all_valid(parsed_ex, end_mode="last_valid"):
    """
    Collect a start pose and an end pose for every agent valid at the current step.

    The start pose is the agent's ``state/current/*`` pose. The end pose is read
    from the ``state/future/*`` arrays at the last (``end_mode="last_valid"``) or
    first (``end_mode="first_valid"``) future step whose valid flag is set. It is
    None when the agent has no valid future step.

    Args:
        parsed_ex: mapping from feature name to a tensor-like object exposing
            ``.numpy()`` (e.g. the output of ``tf.io.parse_single_example``).
        end_mode: "last_valid" or "first_valid".

    Returns:
        (agents, sdc): ``agents`` is a list of dicts (track_idx, is_sdc,
        agent_type, agent_id, start, end) in track-index order; ``sdc`` is the
        first agent flagged ``is_sdc``, or None.

    Raises:
        ValueError: if ``end_mode`` is not one of the two supported values.
    """
    # Previously any unrecognized value silently behaved like "first_valid".
    if end_mode not in ("last_valid", "first_valid"):
        raise ValueError(
            f"end_mode must be 'last_valid' or 'first_valid', got {end_mode!r}"
        )

    cur_valid = _to_1d(parsed_ex["state/current/valid"].numpy()).astype(bool)
    is_sdc    = _to_1d(parsed_ex["state/is_sdc"].numpy()).astype(bool)

    # type/id are parsed as float32 (see the feature spec); cast per agent below
    a_type = _to_1d(parsed_ex["state/type"].numpy())
    a_id   = _to_1d(parsed_ex["state/id"].numpy())

    cx   = _to_1d(parsed_ex["state/current/x"].numpy())
    cy   = _to_1d(parsed_ex["state/current/y"].numpy())
    cz   = _to_1d(parsed_ex["state/current/z"].numpy())
    cyaw = _to_1d(parsed_ex["state/current/bbox_yaw"].numpy())

    f_valid = parsed_ex["state/future/valid"].numpy().astype(bool)   # (agents, future steps)
    fx   = parsed_ex["state/future/x"].numpy()
    fy   = parsed_ex["state/future/y"].numpy()
    fz   = parsed_ex["state/future/z"].numpy()
    fyaw = parsed_ex["state/future/bbox_yaw"].numpy()

    agents = []
    for i in np.where(cur_valid)[0]:
        start = {
            "x": _safe_float(cx[i]),
            "y": _safe_float(cy[i]),
            "z": _safe_float(cz[i]),
            "yaw": _safe_float(cyaw[i]),
        }

        end = None
        js = np.where(f_valid[i])[0]
        if js.size > 0:
            j = int(js[-1] if end_mode == "last_valid" else js[0])
            end = {
                "x": _safe_float(fx[i, j]),
                "y": _safe_float(fy[i, j]),
                "z": _safe_float(fz[i, j]),
                "yaw": _safe_float(fyaw[i, j]),
                "t_idx": int(j),
            }

        agents.append({
            "track_idx": int(i),
            "is_sdc": bool(is_sdc[i]),
            "agent_type": int(a_type[i]) if np.isfinite(a_type[i]) else None,
            "agent_id": float(a_id[i]) if np.isfinite(a_id[i]) else None,
            "start": start,
            "end": end,
        })

    sdc = [a for a in agents if a["is_sdc"]]
    return agents, (sdc[0] if sdc else None)

def save_scene_json_full(parsed_ex, out_path, source_file, scene_index_global,
                         order_points=True, end_mode="last_valid"):
    """
    Serialize one parsed scene (road polylines + agent start/end poses) as JSON.

    The JSON object has three top-level keys: "meta" (source file and global
    scene index), "road" (stats + polylines) and "agents" (valid count, sdc,
    items).

    The payload is first written to ``<out_path>.tmp`` and then moved onto
    ``out_path`` with ``os.replace``, so a failure mid-write never leaves a
    truncated JSON file at ``out_path``.

    Args:
        parsed_ex: parsed scene (feature name -> tensor-like with ``.numpy()``).
        out_path: destination file path.
        source_file: TFRecord file name recorded in the metadata.
        scene_index_global: running scene index recorded in the metadata.
        order_points: forwarded to ``extract_road_polylines_all``.
        end_mode: forwarded to ``extract_agents_start_end_all_valid``.
    """
    polylines, road_stats = extract_road_polylines_all(parsed_ex, order_points=order_points)
    agents, sdc = extract_agents_start_end_all_valid(parsed_ex, end_mode=end_mode)

    payload = {
        "meta": {
            "source_file": source_file,
            "scene_index_global": int(scene_index_global),
        },
        "road": {
            "stats": road_stats,
            "polylines": polylines,
        },
        "agents": {
            "count_valid": int(len(agents)),
            "sdc": sdc,          # includes start/end
            "items": agents,     # all valid agents
        }
    }

    tmp_path = f"{out_path}.tmp"
    try:
        with open(tmp_path, "w") as f:
            json.dump(payload, f)
        os.replace(tmp_path, out_path)
    except BaseException:
        # Don't leave a half-written temp file next to the real outputs.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise

def export_tfrecords_to_json_full(filenames, tfrecord_dir, out_dir,
                                 features_description, compression_type="",
                                 out_name_mode="sequential",
                                 order_points=True,
                                 end_mode="last_valid"):
    """
    Write one JSON per TFRecord record (scene). No downsampling, no type filtering.

    Args:
        filenames: record file names (relative to ``tfrecord_dir``), read in
            the given order; the global scene index follows this order.
        tfrecord_dir: directory containing the record files.
        out_dir: output directory; created if missing.
        features_description: feature spec passed to ``tf.io.parse_single_example``.
        compression_type: forwarded to ``tf.data.TFRecordDataset``.
        out_name_mode: "sequential" -> ``scene_<6-digit global index>.json``;
            "hash" -> ``scene_<md5 of raw record bytes>.json``.
        order_points: forwarded to ``save_scene_json_full``.
        end_mode: forwarded to ``save_scene_json_full``.

    Raises:
        ValueError: if ``out_name_mode`` is neither "sequential" nor "hash".
    """
    # Previously any unrecognized value (e.g. a typo) silently became "sequential".
    if out_name_mode not in ("sequential", "hash"):
        raise ValueError(
            f"out_name_mode must be 'sequential' or 'hash', got {out_name_mode!r}"
        )

    ensure_dir(out_dir)
    scene_idx = 0

    for file in filenames:
        path = os.path.join(tfrecord_dir, file)
        ds = tf.data.TFRecordDataset(path, compression_type=compression_type)

        for raw in ds:
            parsed_ex = tf.io.parse_single_example(raw, features_description)

            if out_name_mode == "hash":
                # NOTE: byte-identical records map to the same name and overwrite each other.
                name = f"scene_{_hash_bytes(raw.numpy())}.json"
            else:
                name = f"scene_{scene_idx:06d}.json"

            out_path = os.path.join(out_dir, name)
            save_scene_json_full(
                parsed_ex,
                out_path=out_path,
                source_file=file,
                scene_index_global=scene_idx,
                order_points=order_points,
                end_mode=end_mode,
            )
            scene_idx += 1

    print("Done. Total scenes written:", scene_idx)


In [35]:
# Output lives under the configured data directory, so only one path needs
# editing to relocate the notebook.
OUT_DIR = os.path.join(WAYMO_TF_RECORD_FILE_PATH, "processed", "waymo_scenes_json")
export_tfrecords_to_json_full(
    filenames=filenames,
    tfrecord_dir=WAYMO_TF_RECORD_FILE_PATH,
    out_dir=OUT_DIR,
    features_description=features_description,
    compression_type="",
    out_name_mode="sequential",
    order_points=True,   # set False if you truly want raw point order
    end_mode="last_valid",
)

Done. Total scenes written: 455


#### Debugging helpers: inspect a single parsed scene

In [32]:
import numpy as np

# -----------------------------
# Helpers: road polylines + start/end poses
# -----------------------------

def _to_1d(x):
    a = np.asarray(x)
    return a.reshape(-1)

def _to_2d(x, last_dim):
    a = np.asarray(x)
    a = a.reshape(-1, last_dim)
    return a

def order_polyline_pca_xy(xyz: np.ndarray) -> np.ndarray:
    """
    Returns indices that order points along the dominant 2D axis (XY PCA).
    Good enough to turn a grouped point cloud into a polyline.

    Args:
        xyz: (n, >=2) array; only the x and y columns are used.

    Returns:
        Index array of length n (identity order when n <= 2).

    The principal-axis sign from SVD is arbitrary, so the order may run either
    way along the axis. Ties are broken by input order (stable sort), which
    makes the result deterministic for duplicate points.
    """
    n = xyz.shape[0]
    if n <= 2:
        return np.arange(n)
    p = xyz[:, :2].astype(np.float64)
    p0 = p.mean(axis=0, keepdims=True)
    X = p - p0
    # principal axis
    _, _, vt = np.linalg.svd(X, full_matrices=False)
    axis = vt[0]  # (2,)
    s = X @ axis  # (n,)
    return np.argsort(s, kind="stable")

def extract_road_polylines(parsed_ex, min_points=5, downsample_step=1, order_points=True):
    """
    Build map as grouped polylines from roadgraph_samples/{xyz,dir,id,type,valid}.

    Args:
        parsed_ex: mapping from feature name to a tensor-like object exposing
            ``.numpy()`` (e.g. the output of ``tf.io.parse_single_example``).
        min_points: groups with fewer valid points than this are skipped
            (checked before downsampling).
        downsample_step: when > 1, keep every ``downsample_step``-th point.
        order_points: if True, sort points along the dominant XY axis, then
            flip the order if it runs against the summed ``dir`` vectors (the
            PCA axis sign is arbitrary).

    Returns a dict with:
      - 'polylines': list of dicts: {type, id, xyz(N,3) float32, dir(N,3) float32, n_points}
      - 'stats': basic counts
    """
    rg_valid = _to_1d(parsed_ex["roadgraph_samples/valid"].numpy()).astype(bool)
    rg_xyz   = parsed_ex["roadgraph_samples/xyz"].numpy()            # (num_map_samples, 3)
    rg_dir   = parsed_ex["roadgraph_samples/dir"].numpy()            # (num_map_samples, 3)
    rg_id    = _to_1d(parsed_ex["roadgraph_samples/id"].numpy())
    rg_type  = _to_1d(parsed_ex["roadgraph_samples/type"].numpy())

    xyz = rg_xyz[rg_valid]
    direc = rg_dir[rg_valid]
    ids = rg_id[rg_valid].astype(np.int64)
    types = rg_type[rg_valid].astype(np.int64)

    polylines = []
    num_groups = 0
    if xyz.shape[0] > 0:
        # group by (type, id); key rows come back sorted lexicographically
        keys = np.stack([types, ids], axis=1)  # (M,2)
        uniq, inv = np.unique(keys, axis=0, return_inverse=True)
        # Some NumPy 2.0 releases return a non-1-D inverse when `axis` is given.
        inv = inv.reshape(-1)
        num_groups = int(uniq.shape[0])

        for gi, (t, i) in enumerate(uniq):
            m = (inv == gi)
            pts = xyz[m]
            d   = direc[m]
            if pts.shape[0] < min_points:
                continue

            if order_points:
                idx = order_polyline_pca_xy(pts)
                pts = pts[idx]
                d   = d[idx]
                # Orient along the stored direction vectors (assumes `dir`
                # points along the feature — confirm). Zero sum: no change.
                travel = pts[-1, :2] - pts[0, :2]
                if float(travel @ d[:, :2].sum(axis=0)) < 0.0:
                    pts = pts[::-1]
                    d   = d[::-1]

            if downsample_step > 1:
                pts = pts[::downsample_step]
                d   = d[::downsample_step]

            polylines.append({
                "type": int(t),
                "id": int(i),
                "xyz": pts.astype(np.float32),
                "dir": d.astype(np.float32),
                "n_points": int(pts.shape[0]),
            })

    stats = {
        "road_valid_points": int(xyz.shape[0]),
        "num_groups_total": num_groups,
        "num_polylines_kept": int(len(polylines)),
    }
    return {"polylines": polylines, "stats": stats}

def extract_agent_start_end(parsed_ex, use_gt_future_end=True, end_mode="last_valid"):
    """
    Returns list of agents (one per agent valid at the current step) with:
      - track_idx
      - is_sdc
      - start: (x,y,z,yaw) from state/current/*
      - end:   (x,y,z,yaw) from state/future/*, or None if unavailable
    end_mode: "last_valid" or "first_valid" over future valid mask.

    Raises:
        ValueError: if ``end_mode`` is not one of the two supported values
        (previously any other value silently behaved like "first_valid").
    """
    if end_mode not in ("last_valid", "first_valid"):
        raise ValueError(
            f"end_mode must be 'last_valid' or 'first_valid', got {end_mode!r}"
        )

    cur_valid = _to_1d(parsed_ex["state/current/valid"].numpy()).astype(bool)
    is_sdc    = _to_1d(parsed_ex["state/is_sdc"].numpy()).astype(bool)

    cx = _to_1d(parsed_ex["state/current/x"].numpy())
    cy = _to_1d(parsed_ex["state/current/y"].numpy())
    cz = _to_1d(parsed_ex["state/current/z"].numpy())
    cyaw = _to_1d(parsed_ex["state/current/bbox_yaw"].numpy())

    # optional: future arrays, (agents, future steps)
    if use_gt_future_end:
        f_valid = parsed_ex["state/future/valid"].numpy().astype(bool)
        fx = parsed_ex["state/future/x"].numpy()
        fy = parsed_ex["state/future/y"].numpy()
        fz = parsed_ex["state/future/z"].numpy()
        fyaw = parsed_ex["state/future/bbox_yaw"].numpy()

    agents = []
    idxs = np.where(cur_valid)[0]
    for i in idxs:
        start = (float(cx[i]), float(cy[i]), float(cz[i]), float(cyaw[i]))
        end = None

        if use_gt_future_end:
            fv = f_valid[i]
            js = np.where(fv)[0]
            if js.size > 0:
                j = int(js[-1] if end_mode == "last_valid" else js[0])
                end = (float(fx[i, j]), float(fy[i, j]), float(fz[i, j]), float(fyaw[i, j]))

        agents.append({
            "track_idx": int(i),
            "is_sdc": bool(is_sdc[i]),
            "start": start,
            "end": end,
        })

    return agents

# -----------------------------
# “Inspect” one parsed scene
# -----------------------------

def inspect_scene(parsed_ex, *, min_points=8, downsample_step=5, order_points=True, max_show=3):
    """
    Print a quick summary of one parsed scene and return the extracted data.

    Prints road statistics, the first ``max_show`` polylines (with their first
    and last XY point), the SDC's start/end pose, and the first non-SDC agent's
    start/end pose.

    Returns:
        (road, agents) as produced by ``extract_road_polylines`` and
        ``extract_agent_start_end``.
    """
    road = extract_road_polylines(
        parsed_ex,
        min_points=min_points,
        downsample_step=downsample_step,
        order_points=order_points,
    )
    agents = extract_agent_start_end(parsed_ex, use_gt_future_end=True, end_mode="last_valid")
    sdc_agents = [agent for agent in agents if agent["is_sdc"]]
    other_agents = [agent for agent in agents if not agent["is_sdc"]]

    print("=== Road stats ===")
    print(road["stats"])
    print("Example polylines:")
    for polyline in road["polylines"][:max_show]:
        first_x, first_y = polyline["xyz"][0][:2]
        last_x, last_y = polyline["xyz"][-1][:2]
        print(f"  (type={polyline['type']}, id={polyline['id']}) points={polyline['n_points']} "
              f"first=({first_x:.2f},{first_y:.2f}) last=({last_x:.2f},{last_y:.2f})")

    print("\n=== Agent start/end ===")
    print(f"valid agents: {len(agents)} | sdc count: {len(sdc_agents)}")
    if not sdc_agents:
        print("No SDC found in this scene (unexpected, but possible in some subsets).")
    else:
        ego = sdc_agents[0]
        print("SDC start (x,y,z,yaw):", ego["start"])
        print("SDC end   (x,y,z,yaw):", ego["end"])
    if other_agents:
        first_other = other_agents[0]
        print(first_other["start"], first_other["end"])

    return road, agents

# -----------------------------
# Usage
# -----------------------------
# parsed is your list of parsed examples
# road, agents = inspect_scene(parsed[0])


In [33]:
# Inspect the first parsed scene with the default debug settings.
road, agents = inspect_scene(parsed_ex=parsed[0])

=== Road stats ===
{'road_valid_points': 20933, 'num_groups_total': 343, 'num_polylines_kept': 327}
Example polylines:
  (type=2, id=96) points=57 first=(783.39,1452.30) last=(504.18,1454.57)
  (type=2, id=97) points=58 first=(783.72,1455.51) last=(500.70,1457.70)
  (type=2, id=99) points=41 first=(782.72,1449.13) last=(584.00,1450.62)

=== Agent start/end ===
valid agents: 9 | sdc count: 1
SDC start (x,y,z,yaw): (571.58447265625, 1450.7574462890625, -147.02467346191406, 3.134042978286743)
SDC end   (x,y,z,yaw): (410.5989685058594, 1452.0958251953125, -146.0395050048828, 3.1409289836883545)
(587.8990478515625, 1457.5389404296875, -147.53790283203125, 3.1334471702575684) (451.7677307128906, 1458.6298828125, -146.79776000976562, 3.127753496170044)
