In [24]:
import os
import time
import tensorflow.compat.v1 as tf
import pickle

tf.enable_eager_execution() # No need for session to be created. Function instances are run immediately. 

from waymo_open_dataset import dataset_pb2 as open_dataset
from google.cloud import storage

import concurrent.futures as concurr

# CONFIG
# GCP project passed to storage.Client and the bucket holding the dataset files.
project = "Waymo3DObjectDetection"
bucket_name = 'waymo_open_dataset_v_1_2_0_individual_files'
# Extension appended to every locally downloaded blob.
suffix = '.tfrecord'
# Local root (with trailing slash) under the current working directory.
data_destination = f"{os.getcwd()}/data/"
download_batch_size = 1

def download_blob(blob, c):
    """Download one bucket object into the local data directory.

    Parameters
    ----------
    blob : object exposing ``download_to_filename(path)``
        In this notebook, an item produced by ``storage_client.list_blobs``.
    c : file counter used to build the local file name.

    Returns
    -------
    str
        Path of the downloaded file, ``{data_destination}blob_{c}{suffix}``.
    """
    # The data directory does not exist on a fresh checkout; create it so the
    # download cannot fail on a missing parent. ``exist_ok`` keeps this safe
    # when several worker threads get here at once.
    os.makedirs(data_destination, exist_ok=True)
    fname = f"{data_destination}blob_{c}{suffix}"
    blob.download_to_filename(fname)
    return fname

# Camera enum value -> camera name, mirroring ``CameraName.Name`` in the Waymo
# Open Dataset ``dataset.proto``. Value 0 (UNKNOWN) is deliberately absent.
_CAMERA_NAMES = {
    1: "FRONT",
    2: "FRONT_LEFT",
    3: "FRONT_RIGHT",
    4: "SIDE_LEFT",
    5: "SIDE_RIGHT",
}


def strip_frame(frame, idx, blob_idx):
    """Keep only the per-camera data of a frame (drops LIDAR and the rest).

    Cameras are identified by the ``name`` enum carried by each image and each
    label group, not by their position in ``frame.images`` /
    ``frame.camera_labels``. Positional pairing mislabels cameras whenever the
    stored order differs from a hard-coded list, and pairs an image with
    another camera's labels when the two lists are ordered differently.

    Parameters
    ----------
    frame : parsed frame exposing ``images``, ``camera_labels`` and ``context``.
    idx : index of the frame inside its blob (stored as ``time_frame_idx``).
    blob_idx : index of the blob the frame came from.

    Returns
    -------
    dict
        ``{camera_name: {'image', 'velocity', 'labels', 'context'}}`` with one
        entry per recognised camera image in the frame. ``'labels'`` is
        ``None`` when the frame carries no label group for that camera.
    """
    # Index the label groups by camera enum so each image gets its own labels.
    labels_by_camera = {labels.name: labels for labels in frame.camera_labels}

    cam_dict = {}
    for image in frame.images:
        camera = _CAMERA_NAMES.get(image.name)
        if camera is None:
            # UNKNOWN / unrecognised camera id: nothing sensible to file it under.
            continue
        cam_dict[camera] = {
            'image': image.image,
            'velocity': image.velocity,
            'labels': labels_by_camera.get(image.name),
            'context': {'stats': frame.context.stats,
                        'name': frame.context.name,
                        'blob_idx': blob_idx,
                        'time_frame_idx': idx},
        }
    return cam_dict

def save_frames(frames, blob_idx, dataset='training'):
    """Pickle every camera entry of every stripped frame, to preprocess later.

    Parameters
    ----------
    frames : iterable of dicts mapping camera name -> per-camera dict
        (the structure returned by ``strip_frame``).
    blob_idx : index of the source blob, used in the file name.
    dataset : sub-directory of ``data_destination`` to write into.

    Files are written to
    ``{data_destination}{dataset}/{camera}/blob_{blob_idx}_frame_{frame_idx}.pickle``.
    Returns ``None``.
    """
    for frame_idx, frame in enumerate(frames):
        for camera, camera_dict in frame.items():
            camera_dir = f'{data_destination}{dataset}/{camera}'
            # The per-camera directory may not exist yet; ``open`` would raise
            # FileNotFoundError. ``exist_ok`` tolerates concurrent workers.
            os.makedirs(camera_dir, exist_ok=True)
            with open(f'{camera_dir}/blob_{blob_idx}_frame_{frame_idx}.pickle', 'wb') as f:
                # Pickle the per-camera dictionary using the highest protocol available.
                pickle.dump(camera_dict, f, pickle.HIGHEST_PROTOCOL)
    return None

def load_frame(frame_idx, blob_idx, dataset='training', camera='FRONT'):
    """Load one per-camera frame dictionary written by ``save_frames``.

    Parameters
    ----------
    frame_idx : index of the frame inside its blob.
    blob_idx : index of the source blob.
    dataset : sub-directory of ``data_destination`` to read from.
    camera : camera sub-directory to read from (defaults to ``'FRONT'``).

    Returns
    -------
    The unpickled object stored at
    ``{data_destination}{dataset}/{camera}/blob_{blob_idx}_frame_{frame_idx}.pickle``.

    Raises
    ------
    FileNotFoundError
        If no such pickle exists.
    """
    path = f'{data_destination}{dataset}/{camera}/blob_{blob_idx}_frame_{frame_idx}.pickle'
    with open(path, 'rb') as f:
        # ``pickle.load`` accepts no protocol argument (the protocol is read
        # from the stream); passing one positionally raises TypeError.
        # NOTE(security): unpickling executes arbitrary code; only load files
        # this notebook produced itself.
        return pickle.load(f)


# Retrieve frames from selected files to download
def get_and_strip_frames_from_one_blob(downloaded_blob, blob_idx):
    """Parse every record of a downloaded tfrecord file into a stripped frame.

    Each record is decoded into an ``open_dataset.Frame`` and immediately
    reduced by ``strip_frame``; the reduced frames are returned as a list in
    record order.
    """
    records = tf.data.TFRecordDataset(downloaded_blob, compression_type='')
    stripped = []
    for position, record in enumerate(records):
        parsed = open_dataset.Frame()
        parsed.ParseFromString(bytearray(record.numpy()))
        # Drop LIDAR and everything else strip_frame does not keep.
        stripped.append(strip_frame(parsed, position, blob_idx))
    return stripped

def download_process_save_1_blob(blob, blob_idx, dataset='training'):
    """Download one blob, strip its frames, pickle them, then delete the tfrecord.

    Parameters
    ----------
    blob : bucket object to download (passed to ``download_blob``).
    blob_idx : index used for the local file names.
    dataset : sub-directory the pickles are saved under.

    Returns
    -------
    str
        ``"blob_{blob_idx}"`` once the blob has been fully processed.
    """
    print(f"Downloading blob_{blob_idx}")
    blob_fname = download_blob(blob, blob_idx)

    try:
        frames = get_and_strip_frames_from_one_blob(blob_fname, blob_idx)
        save_frames(frames, blob_idx, dataset)
    finally:
        # The raw tfrecord is no longer needed, whether or not processing
        # succeeded. Remove the exact path download_blob returned instead of a
        # hard-coded relative one, so cleanup still works if the working
        # directory or the configured suffix changes.
        if os.path.exists(blob_fname):
            os.remove(blob_fname)

    return f"blob_{blob_idx}"
          

In [28]:
# Storage client bound to the configured project
storage_client = storage.Client(project=project)
# Handle on the dataset bucket
bucket = storage_client.get_bucket(bucket_name)
# All objects stored under the 'validation/' prefix of that bucket
blobs = list(storage_client.list_blobs(bucket_name, prefix='validation/'))

# Number of blobs found under the prefix above
n_blobs = len(blobs)
print(f'Total number of blobs is {n_blobs}')




Total number of blobs is 202


In [29]:
# Download, strip and save every blob currently held in `blobs`.
# NOTE: every job below is tagged 'validation', so the pickles are written to
# the 'validation' split regardless of this cell's original "TRAINING" header.
start = time.time()
downloaded_blobs = []

# One (blob, index, split) job per listed blob.
thread_iterable = [(blob_obj, position, 'validation') for position, blob_obj in enumerate(blobs)]


def _run_job(job):
    """Unpack one (blob, index, split) job into the worker function."""
    return download_process_save_1_blob(*job)


with concurr.ThreadPoolExecutor(max_workers=2) as executor:
    # Results are consumed in submission order; elapsed time is printed per result.
    results = executor.map(_run_job, thread_iterable)
    for finished in results:
        print(f'\n Time elapsed {time.time() - start}')
        downloaded_blobs.append(finished)

end = time.time()
print(f'Total time taken {end - start}')


Downloading blob_0
Downloading blob_1
Downloading blob_2
 Time elapsed 19.75239658355713

Downloading blob_3
 Time elapsed 21.364221811294556

Downloading blob_4
Downloading blob_5
 Time elapsed 41.881749868392944


 Time elapsed 41.884921073913574
Downloading blob_6
 Time elapsed 58.90697193145752

Downloading blob_7

 Time elapsed 61.08712816238403
Downloading blob_8
 Time elapsed 76.28627276420593

Downloading blob_9

 Time elapsed 85.01444435119629
Downloading blob_10
 Time elapsed 97.5887930393219

Downloading blob_11

 Time elapsed 101.1943039894104
Downloading blob_12
 Time elapsed 115.45931386947632

Downloading blob_13
 Time elapsed 128.10792994499207

Downloading blob_14
 Time elapsed 130.71124076843262

Downloading blob_15
Downloading blob_16

 Time elapsed 145.56687951087952

 Time elapsed 145.56696963310242
Downloading blob_17
 Time elapsed 163.14707493782043

Downloading blob_18
 Time elapsed 171.65074944496155

Downloading blob_19
 Time elapsed 176.26040601730347

Downlo

Downloading blob_155
 Time elapsed 1461.321121931076

Downloading blob_156
 Time elapsed 1469.8529164791107

Downloading blob_157
Downloading blob_158

 Time elapsed 1487.2735691070557

 Time elapsed 1487.2736513614655
Downloading blob_159
 Time elapsed 1502.7554144859314

Downloading blob_160
 Time elapsed 1505.1350963115692

Downloading blob_161
Downloading blob_162
 Time elapsed 1525.766099691391


 Time elapsed 1525.7668397426605
Downloading blob_163
 Time elapsed 1540.658563375473

Downloading blob_164

 Time elapsed 1551.5301611423492
Downloading blob_165
 Time elapsed 1559.035404920578

Downloading blob_166
 Time elapsed 1568.6766300201416

Downloading blob_167
 Time elapsed 1578.9145395755768

Downloading blob_168
 Time elapsed 1590.2959082126617

Downloading blob_169
 Time elapsed 1600.2218596935272

Downloading blob_170
 Time elapsed 1607.4462540149689

Downloading blob_171
Downloading blob_172

 Time elapsed 1624.5390865802765

 Time elapsed 1624.5392315387726
Downloading bl

In [None]:
# VALIDATION
# Re-list the objects under the 'validation/' prefix.
blobs = list(storage_client.list_blobs(bucket_name, prefix='validation/'))

start = time.time()
downloaded_blobs = []

# Lazily produced (blob, index, split) jobs, one per listed blob.
thread_iterable = ((blob_obj, position, 'validation') for position, blob_obj in enumerate(blobs))


def _run_job(job):
    """Unpack one (blob, index, split) job into the worker function."""
    return download_process_save_1_blob(*job)


with concurr.ThreadPoolExecutor(max_workers=2) as executor:
    # Results are consumed in submission order; elapsed time is printed per result.
    results = executor.map(_run_job, thread_iterable)
    for finished in results:
        print(f'\n Time elapsed {time.time() - start}')
        downloaded_blobs.append(finished)

end = time.time()
print(f'Total time taken {end - start}')


In [12]:
import os

# Sorted names of the pickles saved for the SIDE_LEFT camera of the training split.
x = sorted(os.listdir('data/training/SIDE_LEFT'))
# Distinct 9-character file-name prefixes, as a quick view of which blobs are present.
{fname[:9] for fname in x}

{'blob_0_fr',
 'blob_100_',
 'blob_101_',
 'blob_102_',
 'blob_103_',
 'blob_104_',
 'blob_105_',
 'blob_106_',
 'blob_107_',
 'blob_108_',
 'blob_109_',
 'blob_10_f',
 'blob_110_',
 'blob_111_',
 'blob_112_',
 'blob_113_',
 'blob_114_',
 'blob_115_',
 'blob_116_',
 'blob_117_',
 'blob_118_',
 'blob_119_',
 'blob_11_f',
 'blob_120_',
 'blob_121_',
 'blob_122_',
 'blob_123_',
 'blob_124_',
 'blob_125_',
 'blob_126_',
 'blob_127_',
 'blob_128_',
 'blob_129_',
 'blob_12_f',
 'blob_130_',
 'blob_131_',
 'blob_132_',
 'blob_133_',
 'blob_134_',
 'blob_135_',
 'blob_136_',
 'blob_137_',
 'blob_138_',
 'blob_139_',
 'blob_13_f',
 'blob_140_',
 'blob_141_',
 'blob_142_',
 'blob_143_',
 'blob_144_',
 'blob_145_',
 'blob_146_',
 'blob_147_',
 'blob_148_',
 'blob_149_',
 'blob_14_f',
 'blob_150_',
 'blob_151_',
 'blob_152_',
 'blob_153_',
 'blob_154_',
 'blob_155_',
 'blob_156_',
 'blob_157_',
 'blob_158_',
 'blob_159_',
 'blob_15_f',
 'blob_160_',
 'blob_161_',
 'blob_162_',
 'blob_163_',
 'blob