# Explore the dataset


In this notebook, we will perform an EDA (Exploratory Data Analysis) on the processed Waymo dataset (data in the `processed` folder). In the first part, you will create a function to display an image with its bounding boxes.

In [None]:
from utils import get_dataset
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
import numpy as np
import tensorflow as tf
import seaborn as sns
import pandas as pd
%matplotlib inline

In [None]:
data_dir = "data/train/*.tfrecord" 
# Note: get_dataset was modified so that the data is not shuffled when loading.
dataset = get_dataset(data_dir, shuffle = False, label_map='./label_map.pbtxt')

## Write a function to display an image and the bounding boxes

Implement the `display_instances` function below. This function takes a batch as input and displays an image with its corresponding bounding boxes. The only requirement is that the classes should be color coded (e.g., vehicles in red, pedestrians in blue, cyclists in green).

In [None]:
def display_instances(batch, ax):
    """
    Takes a batch from the dataset and displays the image with
    the associated bounding boxes on the given matplotlib axis.
    """

    # color mapping of classes: 1 (vehicle) red, 2 (pedestrian) green, 4 (cyclist) blue
    colormap = {1: [1, 0, 0], 2: [0, 1, 0], 4: [0, 0, 1]}
    img = batch["image"].numpy()
    gt_bboxes = batch['groundtruth_boxes'].numpy()
    gt_classes = batch['groundtruth_classes'].numpy()

    ax.imshow(img)

    # draw boxes; each box is [y1, x1, y2, x2], scaled here from
    # normalized coordinates to pixels
    for cl, bb in zip(gt_classes, gt_bboxes):
        y1, x1, y2, x2 = bb
        x1 *= img.shape[1]
        x2 *= img.shape[1]
        y1 *= img.shape[0]
        y2 *= img.shape[0]
        rec = Rectangle((x1, y1), x2-x1, y2-y1, facecolor='none', edgecolor=colormap[cl])
        ax.add_patch(rec)
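
The coordinate scaling used in `display_instances` can be checked in isolation. The sketch below assumes, as the function above does, that boxes are stored as normalized `[y1, x1, y2, x2]`; the helper name `boxes_to_pixel_rects` is hypothetical and not part of the project code.

```python
import numpy as np

def boxes_to_pixel_rects(boxes, img_height, img_width):
    """Convert normalized [y1, x1, y2, x2] boxes to pixel (x, y, width, height)."""
    boxes = np.asarray(boxes, dtype=np.float64).reshape(-1, 4)
    y1, x1, y2, x2 = boxes.T
    # x coordinates scale with the image width, y coordinates with the height
    x = x1 * img_width
    y = y1 * img_height
    w = (x2 - x1) * img_width
    h = (y2 - y1) * img_height
    return np.stack([x, y, w, h], axis=1)

# One made-up box covering the right half of the middle band of a 480x640 image
rects = boxes_to_pixel_rects([[0.25, 0.5, 0.75, 1.0]], img_height=480, img_width=640)
print(rects)
```

Each row is in the `(x, y), width, height` form that `Rectangle` expects.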

## Display 10 images 

Using the dataset created in the second cell and the function you just coded, display 10 random images with the associated bounding boxes. You can use the methods `take` and `shuffle` on the dataset.
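
When combining `shuffle` and `take`, keep in mind that a buffer-based shuffle only mixes the elements currently in its buffer. The pure-Python sketch below is a simplified model of that behavior, not TensorFlow's implementation; `buffered_shuffle` is a hypothetical helper. It suggests that with a buffer of 100, the first 10 samples can only come from the start of an unshuffled dataset.

```python
import random
from itertools import islice

def buffered_shuffle(items, buffer_size, seed=None):
    """Yield items in an order produced by a fixed-size shuffle buffer."""
    rng = random.Random(seed)
    it = iter(items)
    buffer = list(islice(it, buffer_size))  # fill the buffer first
    while buffer:
        pos = rng.randrange(len(buffer))
        yield buffer[pos]
        try:
            buffer[pos] = next(it)  # refill the freed slot from the source
        except StopIteration:
            buffer.pop(pos)  # source exhausted: shrink the buffer

# Take 10 samples from 10,000 ordered indices with a buffer of 100
sample = list(islice(buffered_shuffle(range(10_000), buffer_size=100, seed=1), 10))
print(max(sample) < 109)  # True: only the first 109 indices are reachable
```

To sample from the whole dataset, the buffer would need to be closer to the dataset size.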

In [None]:
## STUDENT SOLUTION HERE
shuffle_seed = 1
shuffle_value = 100

layout_plt = [2, 5]  # rows, columns
number_images = layout_plt[0] * layout_plt[1]
f, ax = plt.subplots(*layout_plt, figsize=(20, 20))
for idx, batch in enumerate(dataset.shuffle(shuffle_value, seed=shuffle_seed).take(number_images)):
    # row-major position of the idx-th image in the grid
    x, y = divmod(idx, layout_plt[1])
    display_instances(batch, ax[x, y])
    ax[x, y].axis('off')

plt.tight_layout()
plt.show()
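
Placing the i-th image in a subplot grid is an index-mapping problem. A minimal sketch, assuming a row-major layout (`grid_position` is a hypothetical helper): `divmod` gives a position that works for any grid shape, whereas a modulo-based mapping such as `(idx % n_rows, idx % n_cols)` only visits every cell when the row and column counts share no common factor.

```python
def grid_position(idx, n_cols):
    """Return (row, col) of the idx-th cell in a row-major grid with n_cols columns."""
    return divmod(idx, n_cols)

n_rows, n_cols = 2, 5
positions = [grid_position(i, n_cols) for i in range(n_rows * n_cols)]
print(positions[:6])  # [(0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (1, 0)]

# On a 2x4 grid the modulo mapping reuses cells and leaves others empty
naive = {(i % 2, i % 4) for i in range(8)}
print(len(naive))  # 4
```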

## Additional EDA

In this last part, you are free to perform any additional analysis of the dataset. What else would you like to know about the data?
For example, think about data distribution. So far, you have only looked at a single file...

#### Class distribution

In [None]:
# For each frame, count the objects of each class (1: vehicle, 2: pedestrian, 4: cyclist)
n_frames = 10000

vehicle_count = np.zeros(n_frames)
pedestrian_count = np.zeros(n_frames)
cyclist_count = np.zeros(n_frames)

for i, frame in enumerate(dataset.take(n_frames)):
    frame_classes = frame['groundtruth_classes'].numpy()
    values, count = np.unique(frame_classes, return_counts=True)
    for j, cl in enumerate(values):
        if cl == 1:
            vehicle_count[i] += count[j]
        if cl == 2:
            pedestrian_count[i] += count[j]
        if cl == 4:
            cyclist_count[i] += count[j]

d = {'vehicle': vehicle_count, 'pedestrian': pedestrian_count, 'cyclist': cyclist_count}
df = pd.DataFrame(data=d)

# Plot histograms of the per-frame object counts for each class
fig, ax = plt.subplots(1,3, figsize=(20, 5))
fig.suptitle('Distribution of classes per frame')
sns.histplot(ax=ax[0], data=df, x='vehicle' , kde=True, discrete=True, stat='percent')
sns.histplot(ax=ax[1], data=df, x='pedestrian' , kde=True, discrete=True, stat='percent')
sns.histplot(ax=ax[2], data=df, x='cyclist' , kde=True, discrete=True, stat='percent')
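
The per-frame counting above can be tried on a few hand-made frames without loading the dataset. This is a sketch using `np.bincount`; the frames are made up for illustration, `count_classes` is a hypothetical helper, and the class IDs follow the cell above.

```python
import numpy as np

# Class IDs as used in the cell above: 1 = vehicle, 2 = pedestrian, 4 = cyclist
CLASS_NAMES = {1: 'vehicle', 2: 'pedestrian', 4: 'cyclist'}

def count_classes(frame_classes):
    """Return a dict mapping each class name to its number of boxes in one frame."""
    counts = np.bincount(np.asarray(frame_classes, dtype=np.int64), minlength=5)
    return {name: int(counts[cl]) for cl, name in CLASS_NAMES.items()}

# Hypothetical frames, made up for illustration only (the last one has no objects)
frames = [[1, 1, 2], [1, 4], []]
per_frame = [count_classes(f) for f in frames]
totals = {name: sum(c[name] for c in per_frame) for name in CLASS_NAMES.values()}
print(totals)  # {'vehicle': 3, 'pedestrian': 1, 'cyclist': 1}
```

Summing the per-frame counts shows the overall class balance, which the per-frame histograms do not show directly.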

#### Light Distribution
The code below plots a histogram of the per-image median of the value channel (V in the HSV color model). Images with a median V below 80 (on a 0-255 scale) were probably taken at night.

In [None]:
# For each frame, compute the median of the image value channel (V of HSV)
from matplotlib import colors

n_frames = 3000
value_median = np.zeros(n_frames)

for i, frame_ in enumerate(dataset.take(n_frames)):
    frame = frame_['image'].numpy()
    # rgb_to_hsv expects RGB in [0, 1]; this assumes 8-bit images (0-255)
    frame_hsv = colors.rgb_to_hsv(frame / 255.0)
    # rescale V back to 0-255 so that it matches the threshold of 80
    V = frame_hsv[..., 2] * 255
    value_median[i] = np.median(V)

fig, ax = plt.subplots(1,1, figsize=(20, 5))
ax.set(xlabel='Median V per frame')
sns.histplot(ax=ax, data=value_median, kde=True, discrete=True, stat='count')
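
The brightness measure can be illustrated without the dataset. In the HSV model, V is the per-pixel maximum of the R, G and B channels, so the sketch below computes it directly with NumPy on synthetic 8-bit images. The images, the `median_value` helper and the `NIGHT_THRESHOLD` name are made up for illustration; the threshold of 80 is the rough heuristic from the text above, not a validated cutoff.

```python
import numpy as np

def median_value(img_rgb):
    """Median of the HSV value channel, computed as the per-pixel max of R, G, B."""
    return float(np.median(np.asarray(img_rgb).max(axis=-1)))

# Synthetic uint8 images, made up for illustration only
dark = np.full((4, 4, 3), 30, dtype=np.uint8)
bright = np.full((4, 4, 3), 200, dtype=np.uint8)
bright[..., 0] = 120  # V stays 200 because it is the max over channels

NIGHT_THRESHOLD = 80  # heuristic from the text above, on a 0-255 scale
for name, img in [('dark', dark), ('bright', bright)]:
    v = median_value(img)
    print(name, v, 'night' if v < NIGHT_THRESHOLD else 'day')
```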