In [None]:
# Grab imports
import numpy as np
import tensorflow as tf
from tensorflow import data
import cv2

In [None]:
from data_utils import data_preparer

In [None]:
# Set the configs for the data
class Configs():
    def __init__(self):
        self.data_dir = "../dataset"
        self.num_frames_sequence = 9
        self.original_image_shape = (1920,1080)
        self.processed_image_shape = (320, 128)

In [None]:
configs = Configs()
events_infor, events_labels = data_preparer(configs=configs)

In [24]:
class TTNetDataset():
    def __init__(self, events_infor, org_size, input_size):
        self.events_infor = events_infor
        self.w_org = org_size[0]
        self.h_org = org_size[1]
        self.w_input = input_size[0]
        self.h_input = input_size[1]
        self.w_resize_ratio = self.w_org / self.w_input
        self.h_resize_ratio = self.h_org / self.h_input

    def parse_images(self, images: np.ndarray):
        """Open and perform operations on all images.
        
        Parameters:
            images (np.ndarray): Array of image filepaths
        Returns:
            image_stack (np.array): Stack of processed images
        """
        # Processing if the image is a group of images.
        image_stack = []
        for image_path in images:
            image_path = tf.compat.as_str_any(image_path)
            image = cv2.imread(image_path)
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            image = cv2.resize(image, (self.w_input, self.h_input))
            image_stack.append(image)
        image_stack = np.dstack(image_stack)
        return image_stack

    def parse_masks(self, mask_path: np.ndarray):
        """Open and perform operations on the masks."""
        mask = cv2.imread(tf.compat.as_str_any(mask_path))
        mask = cv2.cvtColor(mask, cv2.COLOR_BGR2RGB)
        mask = np.asarray(mask).astype(np.int8)
        return mask

    def coordinate_adjustment(self, ball_position: np.ndarray):
        """Change the position coordinates of the ball to scale to training."""
        print(ball_position)
        ball_position[0] = ball_position[0]/self.w_resize_ratio
        ball_position[1] = ball_position[1]/self.h_resize_ratio
        ball_position = np.asarray(ball_position, dtype=np.int32)
        print(ball_position)
        return ball_position

    def get_dataset(self):
        """Creates and zips the dataset."""
        # Separate the data and convert into lists
        events_infor = np.asarray(self.events_infor)
        image_fps = events_infor[:,0].tolist()
        ball_position = events_infor[:,1].tolist()
        target_events = events_infor[:,2].tolist()
        segmentation_fp = events_infor[:,3].tolist()
        # Convert all of the data into tensor slices
        image_ds = data.Dataset.from_tensor_slices(image_fps)
        position_ds = data.Dataset.from_tensor_slices(ball_position)
        mask_ds = data.Dataset.from_tensor_slices(segmentation_fp)
        events_ds = data.Dataset.from_tensor_slices(target_events)
        # Map the associated function to the tensor slices
        print("-------------------------------------- Mapping element spec --------------------------------------")
        # print(list(mask_ds.as_numpy_iterator()))
        print(position_ds.element_spec)

        position_ds = position_ds.map(
            lambda x: tf.numpy_function(
                self.coordinate_adjustment, inp=[x], Tout=tf.int32), 
            num_parallel_calls=data.experimental.AUTOTUNE)
        
        print(list(position_ds.as_numpy_iterator()))
        print(position_ds.element_spec)
        print("-------------------------------------- End of element spec --------------------------------------")

        """
        image_ds = image_ds.map(
            lambda x: tf.numpy_function(self.parse_images, inp=[x], Tout=[tf.uint8]),
            num_parallel_calls=data.experimental.AUTOTUNE)
        mask_ds = mask_ds.map(
            lambda x: tf.numpy_function(
                self.parse_masks, inp=[x], Tout=[tf.int8]))
        """

        # 
        # ds = data.Dataset.zip((image_ds, position_ds, mask_ds, events_ds))
        # return ds

In [25]:
ttnet_data = TTNetDataset(
    events_infor=events_infor, 
    org_size=configs.original_image_shape, 
    input_size=configs.processed_image_shape)

ttnet_dataset = ttnet_data.get_dataset()

  events_infor = np.asarray(self.events_infor)


-------------------------------------- Mapping element spec --------------------------------------
TensorSpec(shape=(2,), dtype=tf.int32, name=None)
[578 539][611 526]
[101  62]
[648 514]
[859 522]
[143  61]
[680 559]
[113  66]
[635 518]
[105  61]
[786 536]
[131  63]

[624 522]
[104  61]
[600 530]
[100  62]
[644 567]
[107  67]
[897 516]
[149  61]
[96 63]
[566 544][825 528]
[137  62]
[659 511]
[109  60]
[715 551]
[119  65]
[108  60]
[751 543][589 535]
[98 63]

[125  64]
[94 64]

[607 577]
[101  68]
[500 606][469 604]
[78 71]

[83 71]
[436 599]
[72 70]
[407 595]
[67 70]
[370 590]
[61 69]
[338 586]
[56 69]
[308 583]
[51 69]
[272 581][500 606]
[83 71]

[242 578]
[40 68]
[45 68][469 604]
[78 71]
[407 595]
[67 70]

[436 599]
[72 70]
[308 583][370 590]
[61 69]

[338 586]
[56 69]
[886 302]
[147  35]
[272 581]
[45 68]
[903 303]
[150  35]
[51 69][242 578]
[40 68]

[917 306]
[152  36]
[932 309][948 312]
[158  36]

[155  36]
[962 315]
[160  37]
[977 319]
[162  37]
[991 322]
[165  38]
[1005  326]
[

KeyboardInterrupt: 