# Explore the dataset


In this notebook, we will perform an EDA (Exploratory Data Analysis) on the processed Waymo dataset (data in the `processed` folder). In the first part, you will create a function to display an image together with its bounding boxes.

In [None]:
%matplotlib inline
from utils import get_dataset
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.patches import Rectangle
import glob
import seaborn as sns
import pandas as pd

In [None]:
# Single training segment used for the visual checks in the following cells.
train_record_path = "/data/waymo/processed/train/segment-10072231702153043603_5725_000_5745_000_with_camera_labels.tfrecord"
dataset = get_dataset(train_record_path, label_map="label_map.pbtxt")

## Write a function to display an image and the bounding boxes

Implement the `display_instances` function below. This function takes a batch as an input and displays an image with its corresponding bounding boxes. The only requirement is that the classes should be color coded (e.g., vehicles in red, pedestrians in blue, cyclists in green).

In [None]:
def display_instances(batch):
    """
    Display one image from the dataset with its ground-truth bounding boxes.

    Boxes are color coded by class id: 1 (vehicle) in red, 2 (pedestrian) in
    blue, 4 (cyclist) in green; any other id is drawn in yellow.

    Args:
        batch: mapping whose values expose `.numpy()`:
            "image"               -- H x W x C pixel array,
            "groundtruth_boxes"   -- one row per box, read as (y1, x1, y2, x2)
                                     and scaled by the image size (so the
                                     coordinates are presumably normalized),
            "groundtruth_classes" -- one class id per box.
    """
    class_colors = {1: "r", 2: "b", 4: "g"}
    fallback_color = "y"

    image = batch["image"].numpy()
    h, w, _ = image.shape
    _, ax = plt.subplots(1, figsize=(10, 10))
    ax.imshow(image.astype(np.uint8))

    # Rows are (y1, x1, y2, x2): y coordinates scale with the image height,
    # x coordinates with the image width.
    gt_boxes = batch["groundtruth_boxes"].numpy().reshape(-1, 4) * np.array([h, w, h, w])
    gt_classes = batch["groundtruth_classes"].numpy()
    for (y1, x1, y2, x2), label in zip(gt_boxes, gt_classes):
        rect = Rectangle((x1, y1), x2 - x1, y2 - y1,
                         facecolor="none",
                         edgecolor=class_colors.get(int(label), fallback_color),
                         linewidth=2)
        ax.add_patch(rect)

    plt.show()


## Display 10 images 

Using the dataset created in the second cell and the function you just coded, display 10 random images with the associated bounding boxes. You can use the methods `take` and `shuffle` on the dataset.

In [None]:
## STUDENT SOLUTION HERE
# Shuffle once and take 10 samples from the same shuffled stream, instead of
# rebuilding the shuffle buffer for every image and keeping only its first
# element.
for sample in dataset.shuffle(64).take(10):
    display_instances(sample)  # displays 1 image

## Additional EDA

In this last part, you are free to perform any additional analysis of the dataset. What else would you like to know about the data?
For example, think about data distribution. So far, you have only looked at a single file...

In [None]:
def display_instances_with_detail(batch):
    """
    Display one image with its ground-truth boxes, each annotated with its
    class name and drawn in a class-specific color.

    Args:
        batch: mapping of NumPy arrays (e.g. an element produced by
            `as_numpy_iterator()`):
            "image"               -- H x W x C pixel array,
            "groundtruth_boxes"   -- one row per box, read as (y1, x1, y2, x2)
                                     and scaled by the image size (so the
                                     coordinates are presumably normalized),
            "groundtruth_classes" -- one class id per box.

    The input arrays are not modified. Class ids missing from the label map
    are drawn in white and labelled with the raw id.
    """
    label_map = {1: "vehicle", 2: "pedestrian", 3: "Signs", 4: "Cyclists"}
    color_map = {1: "xkcd:fresh green", 2: "xkcd:cherry red", 3: "xkcd:azure", 4:"xkcd:butter yellow"}
    image = batch["image"]
    h, w, _ = image.shape
    _, ax = plt.subplots(1, figsize=(10, 10))
    ax.imshow(image.astype(np.uint8))

    # Rows are (y1, x1, y2, x2): y coordinates scale with the image height,
    # x coordinates with the image width. The product is a new array, so the
    # caller's boxes stay untouched.
    gt_boxes = batch["groundtruth_boxes"].reshape(-1, 4) * np.array([h, w, h, w])
    for (y1, x1, y2, x2), label in zip(gt_boxes, batch["groundtruth_classes"]):
        label = int(label)
        color = color_map.get(label, "white")
        rect = Rectangle((x1, y1), x2 - x1, y2 - y1, facecolor="none", edgecolor=color, linewidth=2)
        ax.text(x1, y1, label_map.get(label, str(label)), color=color, size="large")
        ax.add_patch(rect)

    plt.show()


In [None]:
# Draw the samples from the shuffled dataset itself (as NumPy dicts, which is
# what display_instances_with_detail expects).
for sample in dataset.shuffle(64).take(10).as_numpy_iterator():
    display_instances_with_detail(sample)  # displays 1 image


In [None]:
def get_dataset_statistics(batch):
    """Compute bounding-box sizes and per-class counts for one dataset sample.

    Args:
        batch: mapping whose values expose `.numpy()`:
            "image"               -- H x W x C array (only its shape is used),
            "groundtruth_boxes"   -- one row per box, read as (y1, x1, y2, x2)
                                     and scaled by the image size (so the
                                     coordinates are presumably normalized),
            "groundtruth_classes" -- one class id per box.

    Returns:
        dict with
            'dims'    -- list of [width, height] pairs, one per box, scaled to
                         the image size,
            'classes' -- dict with the number of boxes for each of
                         "pedestrian", "cyclist", "sign" and "vehicle".

    Raises:
        KeyError: if a class id is not one of 1, 2, 3, 4.
    """
    image = batch["image"].numpy()
    h, w, _ = image.shape

    # Rows are (y1, x1, y2, x2): y coordinates scale with the image height,
    # x coordinates with the image width.
    boxes_px = batch["groundtruth_boxes"].numpy().reshape(-1, 4) * np.array([h, w, h, w])
    widths = boxes_px[:, 3] - boxes_px[:, 1]
    heights = boxes_px[:, 2] - boxes_px[:, 0]
    dims = np.column_stack((widths, heights)).tolist()

    class_counts = {"pedestrian":0, "cyclist":0, "sign":0, "vehicle":0}
    label_map = {1: "vehicle", 2: "pedestrian", 3: "sign", 4: "cyclist"}
    for label in batch["groundtruth_classes"].numpy():
        class_counts[label_map[int(label)]] += 1

    return {'dims':dims, 'classes':class_counts}


In [None]:
# Upper bound on the number of frames read from each tfrecord.
MAX_FRAMES_PER_RECORD = 25000

dim_list = []
num_images = 0
pedestrians = 0
cyclists = 0
signs = 0
vehicles = 0
# Accumulate statistics over the first MAX_FRAMES_PER_RECORD frames of each
# tfrecord in the test split.
# NOTE(review): this loop rebinds the notebook-level `dataset`; cells that run
# afterwards and use `dataset` see the last record opened here.
for ds_path in glob.glob("/data/waymo/processed/test/segment*.tfrecord"):
    dataset = get_dataset(ds_path, label_map="label_map.pbtxt")
    for sample in dataset.take(MAX_FRAMES_PER_RECORD):
        num_images += 1
        stats = get_dataset_statistics(sample)
        dim_list.extend(stats['dims'])
        cyclists += stats['classes']['cyclist']
        pedestrians += stats['classes']['pedestrian']
        signs += stats['classes']['sign']
        vehicles += stats['classes']['vehicle']


## Distribution of bounding box dimensions

In [None]:
print("Total number of bounding boxes: ", len(dim_list))
# Set the theme before the axes are created so the style applies to them.
sns.set_theme(style="whitegrid")
# One row per box: [width, height]. The reshape keeps the array 2-D even when
# no boxes were collected.
dim_list = np.array(dim_list).reshape(-1, 2)
fig, ax = plt.subplots()
# x and y must be passed by keyword; recent seaborn releases reject them
# as positional arguments.
sns.scatterplot(x=dim_list[:, 0], y=dim_list[:, 1], ax=ax)
ax.set_xlabel("Width")
ax.set_ylabel("Height")
ax.set_title("Distribution of Bounding Box Dimensions in Test Split")
plt.show()

In [None]:
# Box area = width * height, one value per bounding box.
areas = dim_list[:, 0] * dim_list[:, 1]
sns.set_theme(style="ticks")
f, ax = plt.subplots(figsize=(7, 5))
sns.despine(f)
sns.histplot(
    # Degenerate (zero-area) boxes cannot be placed on a logarithmic axis.
    areas[areas > 0],
    palette="rocket",
    edgecolor=".3",
    linewidth=7,
    log_scale=True,
    ax=ax,
)
ax.set_xlabel("Area")
plt.show()

## Statistics on Class Distribution

In [None]:
# One row per class with the total number of annotations collected above.
data = pd.DataFrame(
    {"Number of Annotations": [cyclists, pedestrians, signs, vehicles], "label": ["cyclists", "pedestrians", "signs", "vehicles"]}
)
print(cyclists)
print(pedestrians)
print(signs)
print(vehicles)
# x and y must be passed by keyword; recent seaborn releases reject them
# as positional arguments.
sns.catplot(data=data, x="label", y="Number of Annotations", kind="bar", legend=True)


In [None]:
def display_instances_show_only_area(batch, area_thresh):
    """
    Display one image with only the ground-truth boxes whose area is below
    `area_thresh`. Nothing is drawn when no box qualifies.

    Args:
        batch: mapping of NumPy arrays (e.g. an element produced by
            `as_numpy_iterator()`):
            "image"             -- H x W x C pixel array,
            "groundtruth_boxes" -- one row per box, read as (y1, x1, y2, x2)
                                   and scaled by the image size (so the
                                   coordinates are presumably normalized).
        area_thresh: area limit, in the same units as the scaled boxes; only
            boxes with a strictly smaller area are shown.
    """
    image = batch["image"]
    h, w, _ = image.shape

    # Rows are (y1, x1, y2, x2): y coordinates scale with the image height,
    # x coordinates with the image width. The product is a new array, so the
    # caller's boxes stay untouched.
    gt_boxes = batch["groundtruth_boxes"].reshape(-1, 4) * np.array([h, w, h, w])
    areas = (gt_boxes[:, 3] - gt_boxes[:, 1]) * (gt_boxes[:, 2] - gt_boxes[:, 0])
    small_boxes = gt_boxes[areas < area_thresh]
    if len(small_boxes) == 0:
        # Nothing to highlight: skip the figure entirely rather than leave an
        # unused one open.
        return

    _, ax = plt.subplots(1, figsize=(10, 10))
    ax.imshow(image.astype(np.uint8))
    for y1, x1, y2, x2 in small_boxes:
        rect = Rectangle((x1, y1), x2 - x1, y2 - y1, facecolor="none", edgecolor="r", linewidth=2)
        ax.add_patch(rect)
    plt.show()

# Draw the samples from the shuffled dataset itself (as NumPy dicts, which is
# what display_instances_show_only_area expects).
for sample in dataset.shuffle(64).take(10).as_numpy_iterator():
    display_instances_show_only_area(sample, 200)  # displays at most 1 image