# Explore the dataset


In this notebook, we will perform an EDA (Exploratory Data Analysis) on the processed Waymo dataset (data in the `processed` folder). In the first part, you will create a function to display an image with its bounding boxes.

In [None]:
from utils import get_dataset
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np
import tensorflow as tf

%matplotlib inline

In [None]:
dataset = get_dataset("./data/train/*.tfrecord")

## Write a function to display an image and the bounding boxes

Implement the `display_instances` function below. This function takes a batch as input and displays an image with its corresponding bounding boxes. The only requirement is that the classes should be color coded (e.g., vehicles in red, pedestrians in blue, cyclists in green).

In [None]:
def display_instances(batch):
    """
    This function takes a batch from the dataset and displays the image with
    the associated bounding boxes, color coded by class.
    """
    # get the height, width and depth of the image
    # (image tensors are laid out as (height, width, depth))
    height, width, depth = batch['image'].shape

    # get the ground-truth classes and boxes as NumPy arrays
    gt_class = batch['groundtruth_classes'].numpy()
    gt_box = batch['groundtruth_boxes'].numpy()

    # plot the image using matplotlib
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(batch['image'].numpy().astype('uint8'))

    # assign color for each class
    color = {
        1 : 'red',
        2 : 'blue',
        4 : 'green'
    }

    for i in range(len(gt_box)):
        ymin, xmin, ymax, xmax = gt_box[i]
        classes = gt_class[i]

        box = patches.Rectangle((xmin * width, ymin * height), (xmax - xmin)*width, 
                                (ymax - ymin)*height, edgecolor=color[classes], facecolor='none')

        ax.add_patch(box)

    plt.show()
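
The coordinate conversion inside the loop of `display_instances` can be isolated and checked on its own. The sketch below assumes, as the multiplication by `width` and `height` above implies, that boxes are stored as normalized `[ymin, xmin, ymax, xmax]` values in the range 0 to 1. The helper name `box_to_pixel_rect` is hypothetical and not part of `utils`.

```python
def box_to_pixel_rect(box, height, width):
    """Convert a normalized [ymin, xmin, ymax, xmax] box into
    (x, y, rect_width, rect_height) in pixels, i.e. the arguments
    expected by matplotlib.patches.Rectangle."""
    ymin, xmin, ymax, xmax = box
    x = xmin * width                    # left edge in pixels
    y = ymin * height                   # top edge in pixels
    rect_width = (xmax - xmin) * width
    rect_height = (ymax - ymin) * height
    return x, y, rect_width, rect_height


# Toy box covering the right half of the middle band of a 640x640 image
# (assumed values, for illustration only).
rect = box_to_pixel_rect([0.25, 0.5, 0.75, 1.0], height=640, width=640)
print(rect)
```

Note that the x values are scaled by the image width and the y values by the image height; mixing the two only goes unnoticed when the images are square.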

## Display 10 images 

Using the dataset created in the second cell and the function you just coded, display 10 random images with the associated bounding boxes. You can use the methods `take` and `shuffle` on the dataset.

In [None]:
## STUDENT SOLUTION HERE
# shuffle the dataset, then display 10 images
# (the shuffle buffer size of 100 is an arbitrary choice)
for data in dataset.shuffle(100).take(10):
    display_instances(data)

## Additional EDA

In this last part, you are free to perform any additional analysis of the dataset. What else would you like to know about the data?
For example, think about data distribution. So far, you have only looked at a single file...

In [None]:
# Count how many vehicles, pedestrians and cyclists appear
# in 10,000 images taken from the training data

counting = {}

for data in dataset.take(10000):
    # convert the class tensor once per image rather than once per object
    gt_classes = data['groundtruth_classes'].numpy()

    for cls in gt_classes:
        cls = int(cls)  # plain int keys are easier to read than NumPy scalars
        counting[cls] = counting.get(cls, 0) + 1

counting
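
Raw counts are easier to compare once they are turned into shares per class. The following is a minimal sketch using `collections.Counter`. It assumes the class IDs used in `display_instances` (1 = vehicle, 2 = pedestrian, 4 = cyclist). The helper names and the toy input are hypothetical and stand in for the per-image class arrays read from the dataset.

```python
from collections import Counter

# Class IDs as used for the color coding above (assumed mapping).
CLASS_NAMES = {1: "vehicle", 2: "pedestrian", 4: "cyclist"}


def count_classes(per_image_classes):
    """Count object instances per class ID over an iterable of
    per-image class lists (or 1-D arrays)."""
    counts = Counter()
    for classes in per_image_classes:
        counts.update(int(c) for c in classes)
    return counts


def class_shares(counts):
    """Turn absolute counts into fractions of all counted objects."""
    total = sum(counts.values())
    if total == 0:
        return {}
    return {CLASS_NAMES.get(k, str(k)): v / total for k, v in counts.items()}


# Toy input, NOT real dataset values: four images, one with no objects.
toy_classes = [[1, 1, 2], [1, 4], [], [2, 1]]
counts = count_classes(toy_classes)
shares = class_shares(counts)
print(counts)
print(shares)
```

With the real dataset, the toy list could be replaced by a generator such as `(d['groundtruth_classes'].numpy() for d in dataset.take(10000))`.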