# Explore the dataset


In this notebook, we will perform an EDA (Exploratory Data Analysis) on the processed Waymo dataset (data in the `processed` folder). In the first part, you will create a function to display an image and its bounding boxes.

In [17]:
from utils import get_dataset
import IPython.display as display
import matplotlib
# Necessary to see matplotlib outside of container
import tkinter
matplotlib.use('TKAgg')

In [18]:
# dataset = get_dataset("/mnt/data/processed/*.tfrecord")

## Write a function to display an image and the bounding boxes

Implement the `display_instances` function below. This function takes a batch as input and displays each image with its corresponding bounding boxes. The only requirement is that the classes are color-coded (e.g., vehicles in red, pedestrians in blue, cyclists in green).

In [19]:
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
import IPython.display as display
import glob
from helpers.visualization import visualize_tf_record_dataset
from helpers.exploratory_analysis import display_structure_of_dataset_item, show_dataset_basics


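def display_instances(batch):
    """
    This function takes a batch from the dataset and displays each image with
    the associated bounding boxes.
    """
    # TODO: draw the bounding boxes; only the raw image is shown so far.
    # Assumes each sample['image'] is a tf.Tensor (hence the .numpy() call).
    for idx, sample in enumerate(batch):
        print(f"Printing image {idx}")
        plt.imshow(sample['image'].numpy())
        plt.show()  # show each image in its own figure window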


def parse_record(record):
    '''Function to parse one record.'''
    image_feature_description = {
        'image/height': tf.io.FixedLenFeature([], tf.int64),
        'image/width': tf.io.FixedLenFeature([], tf.int64),
        'image/filename': tf.io.FixedLenFeature([], tf.string),
        'image/source_id': tf.io.FixedLenFeature([], tf.string),
        'image/encoded': tf.io.FixedLenFeature([], tf.string),
        'image/format': tf.io.FixedLenFeature([], tf.string),
        'image/object/bbox/xmin': tf.io.VarLenFeature(dtype=tf.float32),
        'image/object/bbox/xmax': tf.io.VarLenFeature(dtype=tf.float32),
        'image/object/bbox/ymin': tf.io.VarLenFeature(dtype=tf.float32),
        'image/object/bbox/ymax': tf.io.VarLenFeature(dtype=tf.float32),
        'image/object/class/text': tf.io.VarLenFeature(dtype=tf.string),
        'image/object/class/label': tf.io.VarLenFeature(dtype=tf.int64),
    }
    return tf.io.parse_single_example(record, image_feature_description)


def transform_record(record):
    '''Convert one parsed record to NumPy/Python values, with boxes in pixels.'''
    ret_dict = {
        'image': tf.image.decode_image(record['image/encoded']).numpy(),
        'filename': record['image/filename'].numpy(),
        'width': record['image/width'].numpy(),
        'height': record['image/height'].numpy(),
        'classes_text': record['image/object/class/text'].values.numpy(),
        'classes': record['image/object/class/label'].values.numpy()
    }
    # Boxes are stored normalized to [0, 1]; scale y by height and x by width to get pixels.
    boxes_odd = np.array([
        record['image/object/bbox/ymin'].values.numpy()*ret_dict['height'],
        record['image/object/bbox/xmin'].values.numpy()*ret_dict['width'],
        record['image/object/bbox/ymax'].values.numpy()*ret_dict['height'],
        record['image/object/bbox/xmax'].values.numpy()*ret_dict['width'],
    ])
    ret_dict['boxes'] = [boxes_odd[:, idx]
                         for idx in range(boxes_odd.shape[1])]
    return ret_dict


def project1_visualize_inspect(tf_record_path_array):
    '''Function to visualize and inspect dataset according to project1'''
    raw_image_dataset = tf.data.TFRecordDataset(tf_record_path_array)
    display_structure_of_dataset_item(raw_image_dataset)

    parsed_image_dataset = raw_image_dataset.map(parse_record)
    # Convert every record to NumPy/Python values (this loads the whole dataset into memory)
    transformed_dataset = [transform_record(
        element) for element in parsed_image_dataset]

    show_dataset_basics(transformed_dataset)

    # Debug visualization (disabled)
    if False:
        for image_element in transformed_dataset[0:2]:
            plt.imshow(image_element['image'])
            plt.show()
            # display.display(display.Image(data=image_element['image']))

    # Visualize
    visualize_tf_record_dataset(
        transformed_dataset,
        n_show=100,
        x_max=3, y_max=4,
        show_gt_class_names=True,
        class_names=['', 'car', 'pedestrian', '', 'bike'])
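
The color-coding requirement can be sketched as follows. This is a minimal sketch under stated assumptions, not the notebook's own solution. It assumes samples shaped like the output of `transform_record`: an `image` array, `boxes` as `[ymin, xmin, ymax, xmax]` in pixels, and integer `classes`. The class ids 1, 2 and 4 are taken from the `class_names` list passed to `visualize_tf_record_dataset`. The names `CLASS_COLORS`, `box_to_rect` and `draw_sample` are hypothetical.

```python
# Hypothetical helpers; ids follow class_names=['', 'car', 'pedestrian', '', 'bike'].
CLASS_COLORS = {1: "red", 2: "blue", 4: "green"}  # vehicle, pedestrian, cyclist


def box_to_rect(box):
    """Convert [ymin, xmin, ymax, xmax] (pixels) to (x, y, width, height)."""
    ymin, xmin, ymax, xmax = (float(v) for v in box)
    return xmin, ymin, xmax - xmin, ymax - ymin


def draw_sample(sample, ax):
    """Draw one image and its color-coded boxes on a matplotlib Axes."""
    # Imported lazily so the pure helpers above work without matplotlib.
    from matplotlib.patches import Rectangle

    ax.imshow(sample["image"])
    for box, cls in zip(sample["boxes"], sample["classes"]):
        x, y, w, h = box_to_rect(box)
        color = CLASS_COLORS.get(int(cls), "yellow")  # fallback for unknown ids
        ax.add_patch(Rectangle((x, y), w, h, fill=False,
                               edgecolor=color, linewidth=1))
    ax.axis("off")
```

For a record `sample` returned by `transform_record`, a call could look like `fig, ax = plt.subplots()` followed by `draw_sample(sample, ax)` and `plt.show()`.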

## Display 10 images 

Using the dataset created in the second cell and the function you just coded, display 10 random images with the associated bounding boxes. You can use the methods `take` and `shuffle` on the dataset.

In [None]:
## STUDENT SOLUTION HERE
all_tf_records = glob.glob('/mnt/data/processed/*.tfrecord')
project1_visualize_inspect(all_tf_records)  # all_tf_records[0:5]
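
Because the output of `transform_record` is collected into a plain Python list in this notebook, picking 10 random records can be sketched with the standard library instead of `shuffle` and `take`. On a `tf.data.Dataset`, the counterpart would be `dataset.shuffle(buffer_size).take(10)`. The helper name `pick_random_samples` and its `seed` parameter are hypothetical, and the stand-in records are made up for illustration.

```python
import random


def pick_random_samples(samples, n=10, seed=None):
    """Return up to n distinct records chosen at random from a list."""
    rng = random.Random(seed)  # local generator, global random state is untouched
    return rng.sample(samples, min(n, len(samples)))


# Stand-in records; in the notebook these would come from transform_record.
records = [{"filename": f"frame_{i}"} for i in range(25)]
chosen = pick_random_samples(records, n=10, seed=0)
```

Capping `n` at `len(samples)` keeps the helper usable when fewer than 10 records were loaded, for example when only a few files are passed in.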

## Additional EDA

In this last part, you are free to perform any additional analysis of the dataset. What else would you like to know about the data?
For example, think about data distribution. So far, you have only looked at a single file...
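
One possible starting point is the class distribution across all records. The sketch below counts objects per class and boxes per image. It assumes records shaped like the output of `transform_record` (a `classes` sequence of integer ids) and reuses the ids from the `class_names` list above. `class_distribution` is a hypothetical helper, and the toy records are made up for illustration, not real dataset statistics.

```python
from collections import Counter

# Ids follow class_names=['', 'car', 'pedestrian', '', 'bike'] used earlier.
CLASS_NAMES = {1: "car", 2: "pedestrian", 4: "bike"}


def class_distribution(samples):
    """Count objects per class name and collect the number of boxes per image."""
    per_class = Counter()
    boxes_per_image = []
    for sample in samples:
        ids = [int(c) for c in sample["classes"]]
        per_class.update(CLASS_NAMES.get(i, f"unknown({i})") for i in ids)
        boxes_per_image.append(len(ids))
    return per_class, boxes_per_image


# Toy records for illustration only.
toy = [{"classes": [1, 1, 2]}, {"classes": [4]}, {"classes": []}]
per_class, boxes_per_image = class_distribution(toy)
```

The two return values could then feed a bar chart of `per_class` and a histogram of `boxes_per_image`, which would show class imbalance and how crowded the scenes are.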