# Classification

## N-Ways

In [None]:
# Batch size passed to `train_ds.batch(...)` in the demo pipeline of this cell.
BATCH_SIZE = 5
# Size of the label range; `split_label` divides it into equally sized slots.
N_CLASSES = 102

def split_label(ways=2):
    """Build a ``Dataset.map`` function that splits labels into ``ways`` slots.

    The label range ``[0, N_CLASSES)`` is cut into ``ways`` consecutive slots
    of equal width. For each slot the mapped function emits a tensor with the
    shape and dtype of ``y`` that holds ``y - slot_start + 1`` where the label
    lies inside the slot and ``0`` everywhere else, so ``0`` always means
    "label is not in this slot".

    The slot width is rounded up, so every label below ``N_CLASSES`` lands in
    exactly one slot even when ``ways`` does not divide ``N_CLASSES`` (the
    last slot is then only partially used).

    Args:
        ways: Number of slots, i.e. number of label tensors produced per
            element. Must be a plain Python integer.

    Returns:
        A function ``split_fn(x, y)`` that returns ``(x, labels)``, where
        ``labels`` is a tuple of ``ways`` tensors shaped like ``y``.
    """
    def split_fn(x, y):
        # Ceiling division: flooring would leave the top
        # ``N_CLASSES % ways`` classes outside every slot, giving them an
        # all-zero label in each output.
        slot_size = -(-N_CLASSES // ways)
        zero_mask = tf.zeros(tf.shape(y), dtype=y.dtype)

        def label_fn(slot):
            start = tf.constant(slot * slot_size, dtype=y.dtype)
            end = tf.constant((slot + 1) * slot_size, dtype=y.dtype)
            in_slot = tf.logical_and(
                tf.math.greater_equal(y, start),
                tf.math.less(y, end),
            )
            # Shift by one so that 0 stays reserved for "not in this slot".
            return tf.where(in_slot, y - start + 1, zero_mask)

        # `ways` is static, so unroll in Python instead of going through
        # tf.map_fn (whose `dtype` argument is deprecated) and tf.unstack.
        slot_labels = tuple(label_fn(slot) for slot in range(ways))

        return (x, slot_labels)

    return split_fn

# Batch the training set, then split each label batch into 3 per-slot labels.
tds = train_ds.batch(BATCH_SIZE).map(split_label(3))
itr = iter(tds)
# Element 1 of the first batch is the tuple of split labels; as the last
# expression of the cell it is the value a notebook would display.
next(itr)[1]

# Images and Bounding Boxes

## Bounding boxes are transformed to an HW grid. HW locations are based on YX_MIN

In [None]:
# Target height and width for `tf.image.resize`; also passed to the grid helper.
IMG_SIZE = 256
# MAX_BOXES = 2
# MAX_BOXES = 5
# MAX_INPUT_BOXES = 2
# NOTE(review): MAX_BOXES and MAX_INPUT_BOXES are not read in this cell;
# presumably they are consumed by code elsewhere - confirm.
MAX_BOXES = 10
MAX_INPUT_BOXES = 6

@tf.function
def preprocess_as_grid(item):
    """Convert one dataset item into an ``(image, box_grid)`` pair.

    The image is resized to ``IMG_SIZE x IMG_SIZE`` and divided by 255; the
    face boxes are handed to ``yxyx_to_hw_grid`` together with ``IMG_SIZE``.

    Args:
        item: Mapping providing ``item['image']`` and
            ``item['faces']['bbox']``.

    Returns:
        Tuple ``(image, box_grid)``: the rescaled float32 image and whatever
        ``yxyx_to_hw_grid`` produces for the boxes.
    """
    raw_image = item['image']
    face_boxes = item['faces']['bbox']

    # Bring the image to the fixed square size, then scale values by 1/255.
    resized = tf.image.resize(raw_image, [IMG_SIZE, IMG_SIZE])
    scaled = tf.cast(resized, tf.float32) / 255.0

    grid = yxyx_to_hw_grid(face_boxes, IMG_SIZE)
    return scaled, grid

@tf.function
def filter_empty_bboxes(item):
    """Predicate that keeps only items with at least one bounding box.

    Args:
        item: Mapping providing ``item['faces']['bbox']``; only the size of
            its first dimension is inspected.

    Returns:
        Scalar boolean tensor, true when the first dimension of the bbox
        tensor is non-zero.
    """
    # Only the boxes matter here; the image entry is intentionally not read.
    bboxes = item['faces']['bbox']
    box_count = tf.shape(bboxes)[0]

    return tf.math.not_equal(box_count, 0)

# train_prep_ds = train_ds.filter(filter_empty_bboxes).map(preprocess_as_grid, num_parallel_calls=None)
# # train_prep_ds = train_ds.filter(filter_empty_bboxes).map(preprocess_as_grid, num_parallel_calls=tf.data.AUTOTUNE)
# val_prep_ds = val_ds.filter(filter_empty_bboxes).map(preprocess_as_grid, num_parallel_calls=tf.data.AUTOTUNE)
# test_prep_ds = test_ds.map(preprocess_as_grid, num_parallel_calls=tf.data.AUTOTUNE)

# # bitr = iter(train_prep_ds.batch(2))
# # images, bboxes = next(bitr)
# # images.shape, bboxes.shape

# itr = iter(train_prep_ds)
# image, y_true = next(itr)

# # display('image', image)
# display('y_true', y_true)
# itr = iter(train_ds)
# item = next(itr)
# image, boxes = item['image'], item['faces']['bbox']

## Bounding boxes are transformed to an HW grid. HW locations are based on CYCX

In [None]:
# Target height and width for `tf.image.resize`; also passed to the grid helper.
IMG_SIZE = 256
# MAX_BOXES = 2
# MAX_BOXES = 5
# MAX_INPUT_BOXES = 2
# NOTE(review): MAX_BOXES and MAX_INPUT_BOXES are not read in this cell;
# presumably they are consumed by code elsewhere - confirm.
MAX_BOXES = 10
MAX_INPUT_BOXES = 6

def preprocess_as_center_grid(item):
    """Convert one dataset item into an ``(image, box_grid)`` pair.

    The image is resized to ``IMG_SIZE x IMG_SIZE`` and divided by 255; the
    face boxes are handed to ``yxyx_to_cycxhw_grid`` together with
    ``IMG_SIZE``.

    Args:
        item: Mapping providing ``item['image']`` and
            ``item['faces']['bbox']``.

    Returns:
        Tuple ``(image, box_grid)``: the rescaled float32 image and whatever
        ``yxyx_to_cycxhw_grid`` produces for the boxes.
    """
    source_image = item['image']
    face_boxes = item['faces']['bbox']

    # Fixed square size first, then scale values by 1/255.
    normalized = tf.cast(
        tf.image.resize(source_image, [IMG_SIZE, IMG_SIZE]),
        tf.float32,
    ) / 255.0

    # The helper places the boxes on a grid sized by IMG_SIZE.
    center_grid = yxyx_to_cycxhw_grid(face_boxes, IMG_SIZE)

    return normalized, center_grid

@tf.function
def filter_empty_bboxes(item):
    """Predicate that keeps only items with at least one bounding box.

    Args:
        item: Mapping providing ``item['faces']['bbox']``; only the size of
            its first dimension is inspected.

    Returns:
        Scalar boolean tensor, true when the first dimension of the bbox
        tensor is non-zero.
    """
    # Only the boxes matter here; the image entry is intentionally not read.
    bboxes = item['faces']['bbox']
    box_count = tf.shape(bboxes)[0]

    return tf.math.not_equal(box_count, 0)

# train_prep_ds = train_ds.filter(filter_empty_bboxes).map(preprocess_as_center_grid, num_parallel_calls=None)
# val_prep_ds = val_ds.filter(filter_empty_bboxes).map(preprocess_as_center_grid, num_parallel_calls=tf.data.AUTOTUNE)
# test_prep_ds = test_ds.map(preprocess_as_center_grid, num_parallel_calls=tf.data.AUTOTUNE)

# bitr = iter(train_prep_ds.batch(2))
# images, y_true = next(bitr)
# images.shape, y_true.shape

# itr = iter(train_prep_ds)
# image, y_true = next(itr)

# display('image', image)
# display('y_true', y_true), tf.math.reduce_sum(mask, axis=-2)
# itr = iter(train_ds)
# item = next(itr)
# image, boxes = item['image'], item['faces']['bbox']
# image, mask = tf.function(preprocess_as_box_center_mask)(item)
# image, y_true = preprocess_as_center_grid(item)