In [1]:
import fog_x 
import os
from logging import getLogger
import numpy as np
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # or any {'0', '1', '2'}
import tensorflow as tf

class BaseLoader():
    def __init__(self, data_path):
        super(BaseLoader, self).__init__()
        self.data_dir = data_path
        self.logger = getLogger(__name__)


    def __len__(self):
        raise NotImplementedError

    def __getitem__(self, idx):
        raise NotImplementedError

    def __iter___(self):
        raise NotImplementedError
    
class RTXLoader(BaseLoader):
    def __init__(self, data_path, split = 'train[:50]'):
        super(RTXLoader, self).__init__(data_path)
        import tensorflow_datasets as tfds

        builder = tfds.builder_from_directory(data_path)

        self.ds = builder.as_dataset(split=split)
        # https://www.determined.ai/blog/tf-dataset-the-bad-parts
        # data_source = builder.as_data_source()
        # print(data_source)


    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        return idx
    
    def __iter__(self):
        return self.ds.__iter__()


In [2]:

class BaseExporter():
    def __init__(self):
        super(BaseExporter, self).__init__()
        self.logger = getLogger(__name__)

    def export(self, loader: BaseLoader, output_path: str):
        raise NotImplementedError
        


In [3]:
import av
import pickle 

class MKVExporter(BaseExporter):
    def __init__(self):
        super(MKVExporter, self).__init__()

    # Function to create a frame from numpy array
    def create_frame(self, image_array, stream):
        frame = av.VideoFrame.from_ndarray(np.array(image_array), format='rgb24')
        frame.pict_type = 'NONE'
        frame.time_base = stream.time_base
        return frame
    
    # Function to create a frame from numpy array
    def create_frame_depth(self, image_array, stream):
        image_array = np.array(image_array)
        # if float, convert to uint8
        if image_array.dtype == np.float32:
            image_array = (image_array * 255).astype(np.uint8)
        # if 3 dim, convert to 2 dim
        if len(image_array.shape) == 3:
            image_array = image_array[:,:,0]
        frame = av.VideoFrame.from_ndarray(image_array, format='gray')
        frame.pict_type = 'NONE'
        frame.time_base = stream.time_base
        return frame

    def export(self, loader: BaseLoader, output_path: str):
        # Create an output container
        i = -1
        for traj_tensor in loader:
            i += 1
            trajectory = dict(traj_tensor)
            output = av.open(f'{output_path}/output_{i}.mkv', mode='w')
            # Define video streams (assuming images are 640x480 RGB)
            video_stream_1 = output.add_stream('libx264', rate=1)
            video_stream_1.width = 640
            video_stream_1.height = 480
            video_stream_1.pix_fmt = 'yuv420p'

            video_stream_2 = output.add_stream('libx264', rate=1)
            video_stream_2.width = 640
            video_stream_2.height = 480
            video_stream_2.pix_fmt = 'yuv420p'

            # Define custom data stream for vectors
            depth_stream = output.add_stream('libx264', rate=1)

            data_stream = output.add_stream('rawvideo', rate=1)

            ts = 0
            # convert step data to stream
            for step_tensor in trajectory["steps"]:
                step = dict(step_tensor)
                obesrvation = step_tensor["observation"].copy()
                obesrvation.pop("image")
                obesrvation.pop("hand_image")
                obesrvation.pop("image_with_depth")
                non_image_data_step = step.copy()
                non_image_data_step["observation"] = obesrvation

                non_image_data_bytes = pickle.dumps(non_image_data_step)
                packet = av.Packet(non_image_data_bytes)
                packet.stream = data_stream
                packet.pts = ts
                output.mux(packet)


                image =np.array(step["observation"]["image"])
                # Create a frame from the numpy array
                frame = self.create_frame(image, video_stream_1)
                frame.pts = ts
                packet = video_stream_1.encode(frame)
                
                output.mux(packet)

                hand_image =np.array(step["observation"]["hand_image"])
                # Create a frame from the numpy array
                frame = self.create_frame(hand_image, video_stream_2)
                frame.pts = ts
                packet = video_stream_2.encode(frame)
                output.mux(packet)

                # # Create a frame from the numpy array
                frame = self.create_frame_depth(step["observation"]["image_with_depth"], depth_stream)
                # frame.pts = ts
                # Encode the frame
                packet = depth_stream.encode(frame)
                # Write the packet to the output file
                output.mux(packet)

                ts += 1

            
            # Flush the remaining frames
            for packet in video_stream_1.encode():
                output.mux(packet)

            for packet in video_stream_2.encode():
                output.mux(packet)

            for packet in depth_stream.encode():
                output.mux(packet)
            output.close()



In [4]:
class MKVLoader(BaseLoader):
    def __init__(self, data_path):
        super(MKVLoader, self).__init__(data_path)
        self.files = [data_path + f for f in os.listdir(data_path) if f.endswith('.mkv')]
        self.index = 0

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        return idx
    
    def __iter__(self):
        return self

    def __next__(self):
        if self.index < len(self.files):
            result = self.files[self.index]
            self.index += 1
            return self._parse_mkv_file(result)
        else:
            raise StopIteration
    
    def _parse_mkv_file(self, filename):
        print(filename)
        input_container = av.open(filename)
        video_stream1 = input_container.streams.video[0] 
        video_stream1.thread_type = 'AUTO'
        video_stream2 = input_container.streams.video[1] 
        video_stream2.thread_type = 'AUTO'
        depth_stream = input_container.streams.video[2] 
        depth_stream.thread_type = 'AUTO'
        data_stream = input_container.streams[3] 

        decoded_stream_1 = []
        decoded_stream_2 = []
        decoded_stream_depth = []
        decoded_stream_data = []

        pkt_counter = 0
        for packet in input_container.demux(video_stream1, video_stream2, depth_stream, data_stream):
            pkt_counter += 1
            if packet.stream.index == video_stream1.index: 
                frame = packet.decode()
                if frame:
                    for f in frame:
                        image = f.to_ndarray(format='rgb24')
                        decoded_stream_1.append(image)
            elif packet.stream.index == video_stream2.index:
                frame = packet.decode()
                if frame:
                    for f in frame:
                        image = f.to_ndarray(format='rgb24')
                        decoded_stream_2.append(image)
            elif packet.stream.index == depth_stream.index:
                frame = packet.decode()
                if frame:
                    for f in frame:
                        image = f.to_ndarray(format='gray')
                        decoded_stream_depth.append(image)
            elif packet.stream.index == data_stream.index:
                packet_in_bytes = bytes(packet)
                if packet_in_bytes:
                    non_dict = pickle.loads(packet_in_bytes)
                    decoded_stream_data.append(non_dict)
            else:
                print("Unknown stream")
        print(pkt_counter, len(decoded_stream_1), len(decoded_stream_2), len(decoded_stream_depth), len(decoded_stream_data))
        input_container.close()

        


In [5]:
number_of_samples = 50
def iterate_dataset(loader: BaseLoader, number_of_samples = 50):
    for i, data in enumerate(loader): 
        list(dict(data)["steps"])
        if i == number_of_samples:
            break

In [12]:

rtx_loader = RTXLoader(os.path.expanduser("~/datasets/berkeley_autolab_ur5/0.1.0"), split = f'train[:{number_of_samples}]')
iterate_dataset(rtx_loader)

I 2024-06-22 13:57:09,758 dataset_info.py:617] Load dataset info from /home/kych/datasets/berkeley_autolab_ur5/0.1.0
I 2024-06-22 13:57:09,797 reader.py:261] Creating a tf.data.Dataset reading 26 files located in folders: /home/kych/datasets/berkeley_autolab_ur5/0.1.0.
I 2024-06-22 13:57:09,915 logging_logger.py:49] Constructing tf.data.Dataset berkeley_autolab_ur5 for split train[:50], from /home/kych/datasets/berkeley_autolab_ur5/0.1.0


In [None]:

exporter = MKVExporter()
output_path = os.path.expanduser("~") + "/fog_x/examples/dataloader/mkv_output/"
%timeit exporter.export(rtx_loader, output_path)

In [11]:
mkv_loader = MKVLoader(output_path)
def iterate_dataset(loader: BaseLoader, number_of_samples = 50):
    for i, data in enumerate(loader): 
        if i == number_of_samples:
            break
iterate_dataset(mkv_loader, number_of_samples)

/home/kych/fog_x/examples/dataloader/mkv_output/output_48.mkv
424 105 105 105 105
/home/kych/fog_x/examples/dataloader/mkv_output/output_10.mkv
484 120 120 120 120
/home/kych/fog_x/examples/dataloader/mkv_output/output_35.mkv
488 121 121 121 121
/home/kych/fog_x/examples/dataloader/mkv_output/output_14.mkv
412 102 102 102 102
/home/kych/fog_x/examples/dataloader/mkv_output/output_8.mkv
340 84 84 84 84
/home/kych/fog_x/examples/dataloader/mkv_output/output_49.mkv
512 127 127 127 127
/home/kych/fog_x/examples/dataloader/mkv_output/output_16.mkv
436 108 108 108 108
/home/kych/fog_x/examples/dataloader/mkv_output/output_12.mkv
464 115 115 115 115
/home/kych/fog_x/examples/dataloader/mkv_output/output_24.mkv
328 81 81 81 81
/home/kych/fog_x/examples/dataloader/mkv_output/output_11.mkv
408 101 101 101 101
/home/kych/fog_x/examples/dataloader/mkv_output/output_17.mkv
388 96 96 96 96
/home/kych/fog_x/examples/dataloader/mkv_output/output_4.mkv
496 123 123 123 123
/home/kych/fog_x/examples/data