In [1]:
import sys
import h5py

import pandas as pd
import numpy as np


import matplotlib.pyplot as plt

In [3]:
from pathlib import Path

from rosbags.rosbag2 import Reader
from rosbags.serde import deserialize_cdr
from rosbags.typesys import get_types_from_msg, register_types

In [4]:
from skimage.transform import resize
from utils import rgba2gray, rgba2rgb, storetable

In [8]:
class DataExtractor():
    """Decode sensor and vehicle-status messages from a rosbag2 recording.

    Topics are recognised by the part of their name after ``'ego_vehicle/'``:
    ``rgb_front/image``, ``depth_front/image``, ``imu``, ``speedometer`` and
    ``vehicle_status``. Decoded values are collected in ``self.data`` and, when
    ``save_data`` is true, written out in batches through ``storetable``.

    Parameters
    ----------
    pathname : path of the bag, handed to ``rosbags.rosbag2.Reader``.
    cmd_pathname : spreadsheet read with ``pd.read_excel``; its ``time`` and
        ``cmd`` columns are used. Required when ``manual_control`` is true.
    rgb_h, rgb_w : height/width used to reshape the raw RGB buffer (4 channels).
    manual_control : when true, the command attached to every status message is
        looked up in the spreadsheet; otherwise the command is always 0.
    depth_h, depth_w : height/width used to reshape the raw depth buffer.
    batch_size : a batch is written every ``batch_size * num_topics`` messages.
    steering_time : half-width (seconds) of the window around each command time.
    rgb_cropsize, depth_cropsize : output shapes passed to ``skimage`` ``resize``.
    save_dir : directory that receives the ``data_<n>.h5`` files.
    save_data : write batches to ``save_dir`` (and reset ``self.data`` after
        each write) instead of keeping everything in memory.

    Raises
    ------
    ValueError
        If ``manual_control`` is true but no ``cmd_pathname`` was given.
    """

    def __init__(self, pathname, cmd_pathname = None, rgb_h=600, rgb_w=800, manual_control = False,
                 depth_h=70, depth_w=400, batch_size=200, steering_time = 3,
                 rgb_cropsize=(224,224,3), depth_cropsize=(70,200),
                 save_dir = "./train_data", save_data = False):
        # Fail early with a clear message instead of deep inside pd.read_excel.
        if manual_control and cmd_pathname is None:
            raise ValueError("cmd_pathname is required when manual_control is True")

        self.pathname = pathname
        self.cmd_pathname = cmd_pathname
        self.rgb_h = rgb_h
        self.rgb_w = rgb_w
        self.depth_h = depth_h
        self.depth_w = depth_w
        self.manual_control = manual_control
        self.batch_size = batch_size
        self.rgb_cropsize = rgb_cropsize
        self.depth_cropsize = depth_cropsize
        self.save_dir = save_dir
        self.save_data = save_data
        self.steering_time = steering_time

        if self.save_data:
            Path(self.save_dir).mkdir(parents=True, exist_ok=True)

        self.data = {}
        # Per-key lists of decoded arrays that have not been merged into
        # self.data yet (see _materialize).
        self._pending = {}

        if self.manual_control:
            self.adjust_timing()
            self.cmd()

        self.get_topics()
        self.initialize()

    def cmd(self):
        """Load the command spreadsheet and point at its first row."""
        self.cmd_df = pd.read_excel(self.cmd_pathname)
        self.cmd_ptr = 0

    # get the first timestamp to move all the timestamps afterwards
    def adjust_timing(self):
        """Store the timestamp of the first message in the bag as ``ts_first``.

        ``ts_first`` stays ``None`` when the bag contains no messages.
        """
        self.ts_first = None
        with Reader(self.pathname) as reader:
            for connection, timestamp, rawdata in reader.messages():
                self.ts_first = pd.to_datetime(timestamp)
                break

    def get_topics(self):
        """Collect the topic names and the smallest per-topic message count."""
        self.topics = []
        self.msg_count = sys.maxsize
        with Reader(self.pathname) as reader:
            for connection in reader.connections:
                self.topics.append(connection.topic)
                if connection.msgcount < self.msg_count:
                    self.msg_count = connection.msgcount
        self.num_topics = len(self.topics)

    def initialize(self):
        """Reset ``self.data`` to empty containers for the topics in the bag.

        ``labels`` and ``command`` are always present; the sensor keys are only
        created for topics that exist in the recording.
        """
        self._pending = {}
        self.data['labels'] = np.empty((0, 3))
        self.data['command'] = []
        for t in self.topics:
            k = t.split('ego_vehicle/')[-1]
            if k == 'rgb_front/image':
                self.data[k] = np.empty((0, self.rgb_cropsize[0], self.rgb_cropsize[1], 3))
            elif k == 'depth_front/image':
                self.data[k] = np.empty((0, self.depth_cropsize[0], self.depth_cropsize[1]))
            elif k == 'imu':
                self.data[k] = np.empty((0, 10))
            elif k == 'speedometer':
                self.data[k] = []

    def _buffer(self, key, value):
        """Queue one decoded array for ``key`` without copying the whole batch."""
        self._pending.setdefault(key, []).append(value)

    def _materialize(self):
        """Merge all queued arrays into ``self.data`` with one concatenate per key."""
        for key, chunks in self._pending.items():
            if chunks:
                stacked = np.stack(chunks, axis=0)
                self.data[key] = np.concatenate([self.data[key], stacked], axis=0)
                chunks.clear()

    def _has_unsaved_data(self):
        """Return True when any container in ``self.data`` holds at least one entry."""
        return any(len(values) > 0 for values in self.data.values())

    def _flush(self, batch_id):
        """Write the current batch to ``<save_dir>/data_<batch_id>.h5`` and reset."""
        self._materialize()
        filename = f'{self.save_dir}/data_{batch_id}.h5'
        storetable(filename, self.data)
        self.initialize()

    def read_data(self):
        """Decode every message in the bag into ``self.data``.

        With ``save_data`` enabled a batch is written every
        ``batch_size * num_topics`` messages and when the message counter reaches
        ``msg_count * num_topics``; whatever is left when the bag ends is written
        as a final batch so trailing messages are not lost. Without ``save_data``
        everything is kept in ``self.data``.
        """
        flush_every = self.batch_size * self.num_topics
        common_total = self.msg_count * self.num_topics
        count = 0
        with Reader(self.pathname) as reader:
            for connection, timestamp, rawdata in reader.messages():
                k = connection.topic.split('ego_vehicle/')[-1]
                if k == 'rgb_front/image':
                    self._buffer(k, self.read_rgb(rawdata, connection.msgtype))
                elif k == 'depth_front/image':
                    self._buffer(k, self.read_depth(rawdata, connection.msgtype))
                elif k == 'speedometer':
                    self.data[k].append(self.read_speed(rawdata, connection.msgtype))
                elif k == 'imu':
                    self._buffer(k, self.read_imu(rawdata, connection.msgtype))
                elif k == 'vehicle_status':
                    labels, cmd = self.read_status(rawdata, connection.msgtype, timestamp)
                    self._buffer('labels', labels)
                    self.data['command'].append(cmd)

                count += 1

                if self.save_data and (count % flush_every == 0 or count == common_total):
                    self._flush(count // self.num_topics)

        # Make the buffered arrays visible through self.data / the getters.
        self._materialize()
        if self.save_data and self._has_unsaved_data():
            # Ceil division: the index is strictly greater than that of any
            # earlier batch file, so the tail never overwrites one.
            self._flush(-(-count // self.num_topics))

    def read_rgb(self, request, msgtype):
        """Decode an RGB image message and resize it to ``rgb_cropsize``.

        The raw buffer is reshaped to ``(rgb_h, rgb_w, 4)`` and passed through
        ``rgba2rgb`` (project helper; presumably drops the alpha channel).
        """
        msg = deserialize_cdr(request, msgtype)
        img = np.reshape(msg.data, (self.rgb_h, self.rgb_w, 4))
        img = rgba2rgb(img)
        img = resize(img, self.rgb_cropsize)
        return img

    def read_depth(self, request, msgtype):
        """Decode a depth image message and resize it to ``depth_cropsize``.

        The raw buffer is reshaped to ``(depth_h, depth_w, 4)`` and passed
        through ``rgba2gray`` (project helper; presumably yields one channel).
        """
        msg = deserialize_cdr(request, msgtype)
        img = np.reshape(msg.data, (self.depth_h, self.depth_w, 4))
        img = rgba2gray(img)
        img = resize(img, self.depth_cropsize)
        return img

    def read_speed(self, request, msgtype):
        """Decode a speedometer message and return its ``data`` field."""
        msg = deserialize_cdr(request, msgtype)
        return msg.data

    def read_imu(self, request, msgtype):
        """Decode an IMU message into a 10-element array.

        Layout: orientation (x, y, z, w), angular velocity (x, y, z),
        linear acceleration (x, y, z).
        """
        msg = deserialize_cdr(request, msgtype)
        cur = np.array([
            msg.orientation.x,
            msg.orientation.y,
            msg.orientation.z,
            msg.orientation.w,
            msg.angular_velocity.x,
            msg.angular_velocity.y,
            msg.angular_velocity.z,
            msg.linear_acceleration.x,
            msg.linear_acceleration.y,
            msg.linear_acceleration.z
        ])
        return cur

    def _command_at(self, timestamp):
        """Return the spreadsheet command active at ``timestamp``, else 0.

        A command is active while the time elapsed since ``ts_first`` lies
        strictly within ``steering_time`` seconds of the current row's ``time``.
        Once that window has passed, the row pointer advances by one and is
        clamped to the last row.
        """
        if len(self.cmd_df) == 0:
            return 0
        ts = pd.to_datetime(timestamp)
        sec = (ts - self.ts_first).total_seconds()
        # cmd_ptr is a positional counter, so index by position, not by label.
        row = self.cmd_df.iloc[self.cmd_ptr]
        window_start = row['time'] - self.steering_time
        window_end = row['time'] + self.steering_time
        if sec > window_start and sec < window_end:
            cmd = row['cmd']
        else:
            cmd = 0
        # update pointer to next entry
        if window_end < sec:
            self.cmd_ptr = min(self.cmd_ptr + 1, len(self.cmd_df) - 1)
        return cmd

    def read_status(self, request, msgtype, timestamp):
        """Decode a vehicle-status message.

        Returns
        -------
        (labels, cmd)
            ``labels`` is ``np.array([throttle, steer, brake])`` taken from
            ``msg.control``; ``cmd`` is 0 unless ``manual_control`` is set, in
            which case it comes from the command spreadsheet.
        """
        msg = deserialize_cdr(request, msgtype)
        cur = np.array([
            msg.control.throttle,
            msg.control.steer,
            msg.control.brake
        ])
        # going straight everytime
        if not self.manual_control:
            cmd = 0
        else:
            cmd = self._command_at(timestamp)
        return cur, cmd

    def get_rgb(self):
        """Return the collected RGB frames."""
        return self.data['rgb_front/image']
    def get_depth(self):
        """Return the collected depth frames."""
        return self.data['depth_front/image']
    def get_imu(self):
        """Return the collected IMU rows."""
        return self.data['imu']
    def get_speedometer(self):
        """Return the collected speedometer readings."""
        return self.data['speedometer']
    def get_cmd(self):
        """Return the collected commands."""
        return self.data['command']
    def get_label(self):
        """Return the collected (throttle, steer, brake) rows."""
        return self.data['labels']

In [6]:
# Load the CARLA message definitions that are not part of the default type
# system, then register them so vehicle-status messages can be deserialized.
control_text = Path('ros-msg/CarlaEgoVehicleControl.msg').read_text()
status_text = Path('ros-msg/CarlaEgoVehicleStatus.msg').read_text()

add_types = {
    **get_types_from_msg(control_text, 'carla_msgs/msg/CarlaEgoVehicleControl'),
    **get_types_from_msg(status_text, 'carla_msgs/msg/CarlaEgoVehicleStatus'),
}
register_types(add_types)

In [9]:
# Convert recordings 6, 7 and 8; each one is saved into its own dataset folder.
for i in (6, 7, 8):
    print(i)
    a = DataExtractor(
        f'training_data/Data_{i}_10min',
        manual_control=False,
        save_data=True,
        save_dir=f'train_data/dataset{i}',
    )
    a.read_data()

6
7
8
