In [None]:
%load_ext autoreload
%autoreload 2
from pathlib import Path
from watch_recognition.data_preprocessing import load_keypoints_data
import tensorflow as tf

In [None]:
import matplotlib.pyplot as plt

In [None]:
%matplotlib inline

In [None]:
from watch_recognition.data_preprocessing import load_keypoints_data_2

# NOTE(review): `image_size` is not used by this cell — it is not passed to the loader.
image_size = (224, 224)
keypoints_validation_dir = Path("../download_data", "keypoints", "validation")
X, y = load_keypoints_data_2(keypoints_validation_dir, mask_size=(14, 14), extent=(3, 3))
X.shape, y.shape

In [None]:
# One dataset element per example: the (X, y) pair is sliced along its first axis.
dataset = tf.data.Dataset.from_tensor_slices(tensors=(X, y))


In [None]:
from functools import partial

In [None]:
# Sentinel that lets tf.data tune parallelism / prefetch buffer sizes at runtime.
# `tf.data.AUTOTUNE` is the stable name; the `tf.data.experimental` alias is
# deprecated, so fall back to it only on TensorFlow releases that lack the former.
try:
    AUTOTUNE = tf.data.AUTOTUNE
except AttributeError:  # older TensorFlow
    AUTOTUNE = tf.data.experimental.AUTOTUNE

In [None]:
def process_image(image, masks, img_size):
    """Return the ``(image, masks)`` pair unchanged.

    Placeholder hook for per-example preprocessing in a ``tf.data`` map.
    ``img_size`` is accepted so callers can bind it with ``functools.partial``,
    but it is currently unused.
    """
    pair = (image, masks)
    return pair


# Baseline pipeline without augmentation (`process_image` is a pass-through),
# batched for display.
ds_tf = (
    dataset
    .map(partial(process_image, img_size=120), num_parallel_calls=AUTOTUNE)
    .batch(30)
    .prefetch(AUTOTUNE)
)
ds_tf

In [None]:
def view_image(ds, num_examples=5, tags=("Center", "Top", "Hour", "Minute")):
    """Plot examples from the first batch of ``ds``: each image next to its masks.

    A new iterator is created on every call, so the first batch is re-drawn;
    pipelines with random augmentation presumably show different samples each time.

    Args:
        ds: batched dataset yielding ``(images, masks)`` tensor pairs.
        num_examples: number of examples (figure rows) to show; capped at the
            batch size so small batches don't raise ``IndexError``.
        tags: panel titles for the mask channels, in channel order. Masks are
            indexed ``[example, :, :, channel]``, so they are assumed to be
            channels-last — TODO confirm against the loader.
    """
    images, masks = next(iter(ds))
    images, masks = images.numpy(), masks.numpy()

    n_rows = min(num_examples, len(images))
    n_cols = 1 + len(tags)
    panel_inches = 22 / 5  # reproduces the original 22x22 figure for the default 5x5 grid
    _, axes = plt.subplots(
        n_rows,
        n_cols,
        figsize=(panel_inches * n_cols, panel_inches * n_rows),
        squeeze=False,  # keep `axes` 2-D even when only one row is plotted
    )
    for row in range(n_rows):
        image_ax = axes[row, 0]
        # imshow expects uint8 (0-255) or floats in [0, 1] for RGB arrays.
        image_ax.imshow(images[row].astype("uint8"))
        image_ax.set_title("Image")
        for channel, tag in enumerate(tags):
            mask_ax = axes[row, channel + 1]
            # Float keeps fractional mask values that a uint8 cast would
            # truncate to 0; imshow colour-maps a 2-D array either way.
            mask_ax.imshow(masks[row, :, :, channel].astype("float32"))
            mask_ax.set_title(f"Point: {tag}")

    for ax in axes.flat:
        ax.set_xticks([])
        ax.set_yticks([])


In [None]:
# Un-augmented samples, for comparison with the augmented pipeline below.
view_image(ds=ds_tf)

In [None]:
from albumentations import (
    Compose, RandomBrightness, JpegCompression, HueSaturationValue, RandomContrast,
    HorizontalFlip,
    Rotate
)
from albumentations import RandomBrightnessContrast
import random
import albumentations as A

In [None]:
import numpy as np

# Augmentation pipeline over an image plus four keypoint masks; each mask is
# registered as an extra "mask" target so it is transformed alongside the image.
# NOTE(review): `transforms` is not referenced by any cell visible here —
# presumably `watch_recognition.augmentations.process_data` builds its own; confirm.
transforms = A.Compose(
    [
        A.OneOf(
            [
                A.RandomSizedCrop(min_max_height=(50, 101), height=224, width=224, p=0.5),
                A.PadIfNeeded(min_height=224, min_width=224, p=0.5),
            ],
            p=0.5,
        ),
        A.VerticalFlip(p=0.5),
        A.Rotate(p=0.5),
        A.OneOf([A.GridDistortion(p=0.5)], p=0.8),
    ],
    additional_targets={f"mask{i}": "mask" for i in range(4)},
)

# Albumentations may draw from both Python's `random` and `numpy.random`
# depending on the transform/version, so seed both for reproducible augmentations.
SEED = 11
random.seed(SEED)
np.random.seed(SEED)

In [None]:
from watch_recognition.augmentations import process_data

# Unbatched pipeline through `process_data`; shapes are fixed and batching is
# done in the following cell.
ds_alb = (
    dataset
    .map(
        partial(process_data, mask_size=(14, 14), image_size=(224, 224)),
        num_parallel_calls=AUTOTUNE,
    )
    .prefetch(AUTOTUNE)
)
ds_alb

In [None]:
from watch_recognition.augmentations import set_shapes

# This cell rebinds `ds_alb` from its own previous value, so running it a second
# time would re-apply `set_shapes` and batch an already-batched dataset.
# Batched images are rank 4 (batch, h, w, c); refuse instead of silently nesting batches.
if ds_alb.element_spec[0].shape.rank == 4:
    raise RuntimeError(
        "ds_alb is already batched; re-run the previous cell to rebuild it first."
    )

ds_alb = (
    ds_alb
    .map(
        partial(set_shapes, img_shape=(224, 224, 3), masks_shape=(14, 14, 4)),
        num_parallel_calls=AUTOTUNE,
    )
    .batch(32)
    .prefetch(AUTOTUNE)
)
ds_alb

In [None]:
# Augmented samples from the first batch.
view_image(ds=ds_alb)

In [None]:
# Drawing again builds a fresh iterator, which presumably re-runs the random
# augmentations and therefore shows different variants of the same examples.
view_image(ds=ds_alb)