In [1]:
import os
import random
import math
import numpy as np
import itertools
import time
import tensorflow.compat.v1 as tf
import torch

tf.enable_eager_execution() # No need for session to be created. Function instances are run immediately. 

from waymo_open_dataset.utils import range_image_utils
from waymo_open_dataset.utils import transform_utils
from waymo_open_dataset.utils import  frame_utils
from waymo_open_dataset import dataset_pb2 as open_dataset
from google.cloud import storage

import concurrent.futures as concurr

# CONFIG
# Project id handed to the storage client.
project = 'Waymo3DObjectDetection'
# Name of the bucket that is listed and downloaded from.
bucket_name = "waymo_open_dataset_v_1_2_0_individual_files"
# Extension appended to every locally saved file.
suffix = ".tfrecord"
# Local target directory: "<current working directory>/data/".
data_destination = f"{os.getcwd()}/data/"
# Number of files per download batch.
download_batch_size = 1

def partition(list_in, n):
    """Randomly distribute the items of ``list_in`` over ``n`` groups.

    A shuffled copy of the input is dealt out round-robin, so group sizes
    differ by at most one item. The caller's sequence is left untouched.

    Args:
        list_in: Iterable of items to distribute.
        n: Number of groups to build.

    Returns:
        A list holding ``n`` lists. For ``n <= 0`` the result is an empty list.
    """
    # Shuffle a private copy: random.shuffle works in place and would
    # otherwise silently reorder the list owned by the caller.
    shuffled = list(list_in)
    random.shuffle(shuffled)
    return [shuffled[i::n] for i in range(n)]

# def download_batch(blobs, batch_num=0):
#     fnames = []
#     for c, blob in enumerate(blobs):
#         fname = f"{data_destination}batch_{batch_num}file_{c}{suffix}"
#         print(blob)
#         blob.download_to_filename(fname)
#         fnames.append(fname)
#     return fnames

def download_blob(blob, c, batch_num=0):
    """Download a single bucket blob into the local data directory.

    Args:
        blob: Blob object exposing ``download_to_filename`` (an object, not a
            file name).
        c: File counter; becomes part of the local file name and the
            progress message.
        batch_num: Batch number; becomes part of the local file name and the
            progress message.

    Returns:
        The local path the blob was written to.
    """
    # Create the target directory on first use; exist_ok keeps this safe when
    # several worker threads reach this line at the same time.
    os.makedirs(data_destination, exist_ok=True)
    fname = f"{data_destination}batch_{batch_num}file_{c}{suffix}"
    blob.download_to_filename(fname)
    print(f'File {c} of batch {batch_num} has downloaded')
    return fname

# Open a storage client for the configured project.
# (Alternative with explicit credentials:
#  storage.Client(project="Waymo3DObjectDetection", credentials=credentials))
storage_client = storage.Client(project=project)
# Bucket handle for the configured bucket name.
bucket = storage_client.get_bucket(bucket_name)
# Materialise the listing of all blobs under the 'training/' prefix.
blobs = list(storage_client.list_blobs(bucket_name, prefix='training/'))
# Partition files into batches
# batch_num = int(len(blobs)/download_batch_size)
# batches = partition(blobs, batch_num)

# Eventually will look like this
# for c, batch in enumerate(batches):
#     fnames = download_batch(batch, c)
#     dataset = tf.data.TFRecordDataset(fname, compression_type='')
    # DO SOMETHING




In [3]:
import pickle

def _strip_frame(frame, idx, blob_idx):
    """Strip frame from garbage such as LIDAR data"""
    
    cam_dict = {}
    for i, camera in enumerate(["FRONT", "FRONT_LEFT", "SIDE_LEFT", "FRONT_RIGHT", "SIDE_RIGHT"]):
        cam_dict[camera] = {}
#         cam_dict[camera]['image'] = torch.tensor((tf.image.decode_jpeg(frame.images[i].image)).numpy())
#         cam_dict[camera]['image'] = tf.image.decode_jpeg(frame.images[i].image)
        cam_dict[camera]['image'] = frame.images[i].image
        cam_dict[camera]['velocity'] = frame.images[i].velocity
        cam_dict[camera]['labels'] = frame.camera_labels[i]
        
    cam_dict['context']={'stats':frame.context.stats, 
                       'name': frame.context.name, 
                       'blob_idx':blob_idx,
                       'time_frame_idx':idx}
    return cam_dict

def _save_frames(frames):
    """Pickle the stripped frames of one blob so they can be preprocessed later.

    The output file is ``<data_destination>pickled/blob_<blob_idx>.pickle``,
    where ``blob_idx`` is read from the context of the first frame.

    Args:
        frames: Non-empty, indexable sequence of dicts that each carry
            ``['context']['blob_idx']``.

    Returns:
        None.

    Raises:
        IndexError: If ``frames`` is empty.
    """
    blob_idx = frames[0]['context']['blob_idx']
    # The pickle directory is not created anywhere else, so make sure it exists
    # before opening the output file for writing.
    os.makedirs(f'{data_destination}pickled', exist_ok=True)
    with open(f'{data_destination}pickled/blob_{blob_idx}.pickle', 'wb') as f:
        # Use the highest pickle protocol available.
        pickle.dump(frames, f, pickle.HIGHEST_PROTOCOL)
    return None

def _load_frame(frame_idx, blob_idx):
    """Load a stripped frame from the pickle file of blob ``blob_idx``.

    Args:
        frame_idx: Position of the wanted frame inside the pickled sequence.
            Pass None to get the whole pickled object back.
        blob_idx: Index of the blob; selects
            ``<data_destination>pickled/blob_<blob_idx>.pickle``.

    Returns:
        The frame at ``frame_idx``, or the complete unpickled object when
        ``frame_idx`` is None.

    Raises:
        FileNotFoundError: If no pickle exists for ``blob_idx``.
        IndexError: If ``frame_idx`` is out of range.
    """
    with open(f'{data_destination}pickled/blob_{blob_idx}.pickle', 'rb') as f:
        # pickle.load detects the protocol from the file and accepts no
        # positional protocol argument; passing one raised TypeError.
        # NOTE: unpickling can execute arbitrary code -- only load files that
        # were written locally by _save_frames.
        frames = pickle.load(f)
    if frame_idx is None:
        return frames
    return frames[frame_idx]

######################################################
def process_frame(data, idx, blob_idx):
    """Parse one serialized record into a Frame and return its stripped form.

    Args:
        data: Record exposing ``numpy()`` that yields the serialized frame.
        idx: Position of the record inside its blob.
        blob_idx: Index of the blob the record came from.

    Returns:
        Whatever ``_strip_frame`` returns for the parsed frame.
    """
    serialized = bytearray(data.numpy())
    parsed = open_dataset.Frame()
    parsed.ParseFromString(serialized)
    # Drop LIDAR and the other payload that is not needed downstream.
    stripped = _strip_frame(parsed, idx, blob_idx)
    return stripped
    
# Retrieve frames from selected files to download
def get_frames_from_one_blob_m_thread(downloaded_blob, blob_idx):
    """Parse and strip every frame of a downloaded blob using a thread pool.

    Args:
        downloaded_blob: Path of the downloaded TFRecord file.
        blob_idx: Index of the blob; forwarded to ``process_frame`` for
            every record.

    Returns:
        A list with one stripped frame per record, in record order.
    """
    # Load into tf record dataset and materialise it so the records can be
    # counted and handed to the workers.
    dataset = tf.data.TFRecordDataset(downloaded_blob, compression_type='')
    records = list(dataset)
    record_indices = range(len(records))
    # Executor.map zips its iterables, so the scalar blob index has to be
    # repeated once per record (passing the bare int raised TypeError).
    blob_indices = [blob_idx] * len(records)
    with concurr.ThreadPoolExecutor(max_workers=None) as executor:
        # Collect into a list so callers receive an indexable sequence, as
        # _save_frames expects, instead of a one-shot iterator.
        return list(executor.map(process_frame, records, record_indices, blob_indices))
######################################################

# Retrieve frames from selected files to download
def get_frames_from_one_blob(downloaded_blob, blob_idx):
    """Sequentially parse and strip every frame stored in a downloaded blob.

    Args:
        downloaded_blob: Path of the downloaded TFRecord file.
        blob_idx: Index of the blob; recorded in every stripped frame.

    Returns:
        A list with one stripped frame per record, in record order.
    """
    records = tf.data.TFRecordDataset(downloaded_blob, compression_type='')

    def _parse(record):
        # Deserialize a single record into a Frame message.
        parsed = open_dataset.Frame()
        parsed.ParseFromString(bytearray(record.numpy()))
        return parsed

    # Strip away LIDAR and other unused payload right after parsing.
    return [
        _strip_frame(_parse(record), position, blob_idx)
        for position, record in enumerate(records)
    ]

In [20]:
# Number of blobs in the training dataset
n_blobs = len(blobs)
print(f'Number of blobs is {n_blobs}')

now = time.time()
downloaded_blobs = []

# Half-open index range [n1, n2) of the blobs to download in this run.
n1 = 0
n2 = 1
idx_list = list(range(n1, n2))

with concurr.ThreadPoolExecutor(max_workers=2) as executor:
    # Each selected blob is paired with its index, which ends up in the file name.
    results = executor.map(download_blob, blobs[n1:n2], idx_list)
    for r in results:
        print(f'\n Time is {time.time() - now}')
        downloaded_blobs.append(r)

then = time.time()
print(f'Elapsed time is {then - now}')

Number of blobs is 798
File 0 of batch 0 has downloaded

 Time is 16.485257148742676
Elapsed time is 16.486326694488525


In [23]:
# TODO
# Now we just need to do in the same loop (multi-threaded):
# 1 - download a blob
# 2 - process the frames and save that blob
# 3 - discard the blob and move to the next blob (memory efficient)

In [21]:
# Parse the first downloaded blob and pickle its stripped frames.
# TODO: loop over every index of downloaded_blobs instead of only blobidx = 0.

blobidx = 0
frames = get_frames_from_one_blob(downloaded_blobs[blobidx], blobidx)
_save_frames(frames)

In [225]:
# Visualise images
import matplotlib.pyplot as plt
import matplotlib.patches as patches

def show_camera_image(camera_image, camera_labels, layout, cmap=None):
    """Show a camera image and its label boxes in a subplot of the current figure.

    Args:
        camera_image: Object exposing ``name`` and JPEG-encoded ``image``.
        camera_labels: Iterable of label groups, each exposing ``name`` and
            ``labels``; only groups whose name equals ``camera_image.name``
            are drawn.
        layout: ``(rows, cols, index)`` arguments forwarded to ``plt.subplot``.
        cmap: Optional colormap forwarded to ``plt.imshow``.
    """
    ax = plt.subplot(*layout)

    # Draw the camera labels.
    for camera_label in camera_labels:
        # Ignore camera labels that do not correspond to this camera.
        if camera_label.name != camera_image.name:
            continue

        for label in camera_label.labels:
            # Draw the object bounding box. Rectangle is anchored at a corner,
            # so shift from the box centre by half its size.
            ax.add_patch(patches.Rectangle(
                xy=(label.box.center_x - 0.5 * label.box.length,
                    label.box.center_y - 0.5 * label.box.width),
                width=label.box.length,
                height=label.box.width,
                linewidth=1,
                edgecolor='red',
                facecolor='none'))

    # Show the camera image.
    plt.imshow(tf.image.decode_jpeg(camera_image.image), cmap=cmap)
    plt.title(open_dataset.CameraName.Name.Name(camera_image.name))
    plt.grid(False)
    plt.axis('off')


# Create the shared figure once, before any subplot is drawn. Creating it
# inside show_camera_image opened a fresh figure after every call, so the
# camera views never ended up on the same canvas.
plt.figure(figsize=(25, 20))

# NOTE(review): this loop needs frames[20] to expose `.images` and
# `.camera_labels`; get_frames_from_one_blob returns stripped dicts that do
# not have these attributes -- confirm which `frames` this cell should use.
for index, image in enumerate(frames[20].images):
    show_camera_image(image, frames[20].camera_labels, [3, 3, index+1])

In [None]:
# Multithread tryout

def do_something(file, idx):
    """Sleep for ``idx`` seconds and report it; ``file`` is accepted but unused."""
    time.sleep(idx)
    message = f'Slept for {idx} seconds'
    return message

now = time.time()

# Collect the messages in a dedicated list. Reusing the name `downloaded_blobs`
# here replaced the downloaded file paths with sleep messages, which breaks
# any cell that later reads `downloaded_blobs`.
sleep_messages = []
n1 = 0
n2 = 10
idx_list = list(range(n1, n2))

with concurr.ThreadPoolExecutor(max_workers=None) as executor:
    # Task i sleeps for i seconds; map yields the results in submission order.
    results = executor.map(do_something, blobs[n1:n2], idx_list)
    for r in results:
        print(r)
        sleep_messages.append(r)

then = time.time()
print(f'Elapsed time is {then - now}')