In [None]:
import numpy as np
import os
import dlib
import albumentations as A
import cv2
from imutils import face_utils
import tensorflow as tf
from sklearn.utils import shuffle
from tensorflow.python.data import AUTOTUNE
import yaml
import wandb

In [None]:
# Authenticate with Weights & Biases and load the project configuration.
# Context managers close both files; strip() removes the trailing newline that
# editors usually append to the key file, which would otherwise be sent as part
# of the API key.
with open('../secrets/wandb_key.txt', 'r') as key_file:
    wandb.login(key=key_file.read().strip(), relogin=True)
with open('../config.yaml', 'r') as config_file:
    config = yaml.safe_load(config_file)

## Create the dataset

### Helper functions

In [None]:
def create_dataset(dataset_path: str, image_size: int, maxi=np.inf,
                   predictor_path: str = "../shape_predictor_68_face_landmarks.dat"):
    """Build an augmented face-landmark dataset from a directory of images.

    Every entry of ``dataset_path`` is visited in sorted name order. For each
    entry that OpenCV can decode, the first face returned by dlib's frontal
    face detector is annotated with the 68-point shape predictor. The image
    and its landmarks are then augmented together with albumentations. The
    landmarks are rescaled to the output resolution, and the image is resized
    to ``image_size`` x ``image_size``. Entries that cannot be decoded, or
    that contain no detected face, are skipped.

    Args:
        dataset_path: Directory containing the source images.
        image_size: Side length in pixels of the square output images.
        maxi: Maximum number of directory entries to examine (default: all).
        predictor_path: Path to dlib's 68-landmark shape predictor model file.

    Returns:
        ``(images, keypoints)``:

        - ``images``: uint8 array of shape ``(N, image_size, image_size, 3)``,
          in OpenCV's BGR channel order as returned by ``cv2.imread``.
        - ``keypoints``: int16 array of shape ``(N, 68, 2)`` holding (x, y)
          landmark coordinates in the resized image's pixel space.
    """
    detector = dlib.get_frontal_face_detector()
    predictor = dlib.shape_predictor(predictor_path)

    # NOTE(review): HorizontalFlip mirrors landmark coordinates but does not
    # re-index them. After a flip, e.g. the "left eye" indices lie on the
    # right side of the face. Confirm this is acceptable for the model/loss.
    transform = A.Compose(
        [A.Rotate(p=0.6, limit=15),
         A.HorizontalFlip(p=0.5),
         A.ImageCompression(quality_lower=20, quality_upper=70, p=1),
         A.GaussianBlur(blur_limit=(3, 13), sigma_limit=0, p=0.8),
         A.RandomBrightnessContrast(p=0.4)],
        # Keep landmarks that rotation pushes outside the frame, so that every
        # sample still carries all of its points.
        keypoint_params=A.KeypointParams(format='xy', remove_invisible=False),
    )

    image_list = []
    keypoint_list = []
    for examined, filename in enumerate(sorted(os.listdir(dataset_path))):
        if examined >= maxi:
            break
        image = cv2.imread(os.path.join(dataset_path, filename))
        if image is None:
            # Not decodable as an image, e.g. the .npz splits saved next to the data.
            continue
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        rects = detector(gray, 0)
        if len(rects) == 0:
            continue
        # Only the first detected face of each image is used.
        shape = face_utils.shape_to_np(predictor(gray, rects[0]))

        transformed = transform(image=image, keypoints=shape)
        augmented = np.asarray(transformed['image'], dtype=np.uint8)

        # Rescale landmarks from the augmented image's real resolution to
        # image_size: x is divided by the width, y by the height.
        height, width = augmented.shape[:2]
        scale = np.array([image_size / width, image_size / height])
        points = np.asarray(transformed['keypoints'], dtype=np.float64) * scale
        # int16 rather than uint8, so off-frame (negative or large) coordinates
        # keep their value instead of wrapping around.
        keypoint_list.append(points.astype(np.int16))

        image_list.append(cv2.resize(augmented, (image_size, image_size),
                                     interpolation=cv2.INTER_AREA))

    if not image_list:
        return (np.empty([0, image_size, image_size, 3], dtype=np.uint8),
                np.empty([0, 68, 2], dtype=np.int16))
    # Stack once at the end instead of np.append per sample (quadratic copying).
    return np.stack(image_list), np.stack(keypoint_list)


def compress_splits(X, Y, dir):
    """Save ``X`` and ``Y`` as ``Xvalues.npz`` and ``Yvalues.npz`` inside ``dir``.

    Both files are written with ``np.savez_compressed``, so each array is
    stored under numpy's default key ``arr_0``; ``uncompress_splits`` reads
    that key back.

    Args:
        X: Array of inputs to store in ``Xvalues.npz``.
        Y: Array of targets to store in ``Yvalues.npz``.
        dir: Target directory. It may be given with or without a trailing
            separator, and it is created if it does not exist yet. An empty
            string means the current working directory.
    """
    if dir:
        os.makedirs(dir, exist_ok=True)
    # os.path.join instead of string concatenation, so "../data" and "../data/"
    # both resolve to the same files.
    np.savez_compressed(os.path.join(dir, 'Xvalues.npz'), X)
    np.savez_compressed(os.path.join(dir, 'Yvalues.npz'), Y)


def uncompress_splits(dir: str):
    """Load the arrays stored in ``Xvalues.npz`` and ``Yvalues.npz`` inside ``dir``.

    Each file is expected to contain its array under numpy's default key
    ``arr_0``, as written by ``np.savez_compressed(path, array)``.

    Args:
        dir: Directory holding the two ``.npz`` files. It may be given with or
            without a trailing separator.

    Returns:
        ``(X, Y)``: the arrays loaded fully into memory.
    """
    def _load(name):
        # np.load on an .npz returns an NpzFile that keeps the file open.
        # Indexing it copies the array into memory, so the file can be
        # closed immediately afterwards.
        with np.load(os.path.join(dir, name)) as archive:
            return archive['arr_0']

    return _load('Xvalues.npz'), _load('Yvalues.npz')


def split_dataset(X, Y, test_ratio: float = 0.20):
    """Split paired sequences into train and test parts, preserving order.

    The leading ``int(len(X) * test_ratio)`` samples form the test split and
    everything after them forms the train split. No shuffling happens here,
    so shuffle beforehand if the input order is meaningful.

    Returns:
        ``(train_x, test_x, train_y, test_y)``.
    """
    n_test = int(len(X) * test_ratio)
    test_x, train_x = X[:n_test], X[n_test:]
    test_y, train_y = Y[:n_test], Y[n_test:]
    return train_x, test_x, train_y, test_y


def normalize_images(images):
    """Return a copy of ``images`` with every pixel value divided by 255.

    In-place true division on an integer array (e.g. the uint8 image arrays
    built in ``create_dataset``) raises a casting error. Integer inputs are
    therefore converted to float32, while floating-point inputs keep their
    dtype. The caller's array is never modified.

    Args:
        images: Array-like of pixel values.

    Returns:
        A new floating-point array with the same shape as ``images``.
    """
    images = np.asarray(images)
    if np.issubdtype(images.dtype, np.floating):
        dtype = images.dtype
    else:
        dtype = np.float32
    normalized = images.astype(dtype)  # astype copies by default
    normalized /= 255
    return normalized


def normalize_keypoints(keypoints, image_size):
    """Return a copy of ``keypoints`` divided by ``image_size``.

    In-place true division on an integer array (e.g. the int16 keypoint
    arrays built in ``create_dataset``) raises a casting error. Integer inputs
    are therefore converted to float32, while floating-point inputs keep their
    dtype. The caller's array is never modified.

    Args:
        keypoints: Array-like of pixel coordinates.
        image_size: Divisor, applied with normal numpy broadcasting (a scalar
            side length divides every coordinate).

    Returns:
        A new floating-point array with the same shape as ``keypoints``.
    """
    keypoints = np.asarray(keypoints)
    if np.issubdtype(keypoints.dtype, np.floating):
        dtype = keypoints.dtype
    else:
        dtype = np.float32
    normalized = keypoints.astype(dtype)  # astype copies by default
    normalized /= image_size
    return normalized


def preprocess(images, keypoints, image_size):
    """Normalise a dataset before training.

    Pixel values go through ``normalize_images`` (division by 255) and
    keypoints go through ``normalize_keypoints`` (division by ``image_size``).

    Returns:
        ``(images, keypoints)`` after normalisation.
    """
    normalized_images = normalize_images(images)
    normalized_keypoints = normalize_keypoints(keypoints, image_size)
    return normalized_images, normalized_keypoints


def _to_tf_dataset(x, y, batch_size, shuffle_buffer_size=0):
    """Wrap paired arrays in a batched, cached and prefetched tf.data pipeline."""
    dataset = tf.data.Dataset.from_tensor_slices((x, y))
    if shuffle_buffer_size > 0:
        # Cache the unbatched samples, then shuffle and batch, so each epoch
        # gets a fresh sample order and fresh batch composition.
        dataset = (dataset.cache()
                   .shuffle(shuffle_buffer_size, reshuffle_each_iteration=True)
                   .batch(batch_size))
    else:
        dataset = dataset.batch(batch_size).cache()
    return dataset.prefetch(buffer_size=tf.data.AUTOTUNE)


def fetch_ds(config, op_type='train', shuffle_buffer_size: int = 0):
    """Load the compressed dataset, normalise it, split it and wrap it in tf.data.

    Args:
        config: Project configuration dict. This function reads
            ``config['dataset']['compressed_dir']``,
            ``config['dataset']['split_ratio']``, ``config['img_shape']`` and
            ``config[op_type]['batch_size']``.
        op_type: Config section that holds the batch size, e.g. ``'train'``
            or ``'optimize'``.
        shuffle_buffer_size: If > 0, the training split is reshuffled every
            epoch with a buffer of this many samples. The default of 0 keeps
            the fixed order produced by the one-off shuffle below.

    Returns:
        ``(train_dataset, test_dataset)``: batched ``tf.data.Dataset`` objects
        yielding ``(images, keypoints)`` pairs.
    """
    images, keypoints = uncompress_splits(config['dataset']['compressed_dir'])

    # NOTE(review): config['img_shape'] is also passed to compile_model as an
    # input shape. If it is a shape tuple such as (H, W, C) rather than a
    # scalar, keypoint normalisation will not broadcast as intended. Confirm
    # against config.yaml.
    images, keypoints = preprocess(images, keypoints, config['img_shape'])

    # One deterministic shuffle before splitting, so both splits are drawn
    # from the whole set.
    images, keypoints = shuffle(images, keypoints, random_state=0)
    train_x, test_x, train_y, test_y = split_dataset(
        images, keypoints, config['dataset']['split_ratio'])

    batch_size = config[op_type]['batch_size']
    train_dataset = _to_tf_dataset(train_x, train_y, batch_size, shuffle_buffer_size)
    test_dataset = _to_tf_dataset(test_x, test_y, batch_size)
    return train_dataset, test_dataset


In [None]:
# Output side length in pixels of the square images built by create_dataset.
image_size = 192
# Directory that create_dataset scans for source images.
data_dir = '../data/'

# NOTE(review): neither constant below is read in this notebook. fetch_ds takes
# its batch size from config[op_type]['batch_size']. Confirm whether they can go.
BATCH_SIZE, SHUFFLE_BUFFER_SIZE = 64, 100


### Create dataset from scratch

In [None]:
# Build the augmented (image, landmark) arrays from the raw images and save
# them in the directory that fetch_ds loads from
# (config['dataset']['compressed_dir']), so save and load can never point at
# different places.
with wandb.init(project=config['wandb']['project'],
                name='Dataset',
                config=config):
    i, k = create_dataset(dataset_path=data_dir, image_size=image_size)
    compress_splits(i, k, config['dataset']['compressed_dir'])

### Load the dataset

In [None]:
# Uses the fetch_ds defined above; later cells re-import fetch_ds from src.dataset.
train_dataset, test_dataset = fetch_ds(config, op_type='train')

## Model training


In [None]:
import tensorflow as tf
from keras.callbacks import EarlyStopping, ReduceLROnPlateau
from wandb.keras import WandbCallback

from src.data_tests import pass_tests_before_fitting
from src.model import compile_model
from src.dataset import fetch_ds

In [None]:
with wandb.init(project=config['wandb']['project'],
                name=config['wandb']['name'],
                config=config):
    # Seed TensorFlow before the model is built, so weight initialisation is
    # reproducible. This matches the seeding done in the optimization cell.
    tf.random.set_seed(config['random_seed'])

    model = compile_model(
        input_shape=config['img_shape'], output_shape=config['kp_shape'])

    callbacks = [EarlyStopping(**config['callbacks']['EarlyStopping']),
                 ReduceLROnPlateau(**config['callbacks']['ReduceLROnPlateau']),
                 WandbCallback(**config['callbacks']['WandbCallback'])]

    # Data sanity checks on both splits before spending time on training.
    for split in (train_dataset, test_dataset):
        pass_tests_before_fitting(
            data=split, img_shape=config['img_shape'], keypoint_shape=config['kp_shape'])

    history = model.fit(
        train_dataset, epochs=config['train']['epochs'],
        validation_data=test_dataset, callbacks=callbacks)


## Model optimization

In [None]:
import tempfile
import tensorflow as tf
import tensorflow_model_optimization as tfmot
import numpy as np
from keras.callbacks import EarlyStopping, ReduceLROnPlateau
from wandb.keras import WandbCallback


from src.data_tests import pass_tests_before_fitting
from src.model import compile_model
from src.dataset import fetch_ds


In [None]:
with wandb.init(project=config['wandb']['project'],
                name='Optimization',
                config=config):
    prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude
    optimize_cfg = config['optimize']

    tf.random.set_seed(config['random_seed'])

    train_dataset, test_dataset = fetch_ds(config, 'optimize')

    # The pruning schedule should finish at the last optimizer step of the
    # fit: end_step = batches per epoch * epochs. The batch count is taken
    # from the dataset itself. If the cardinality is unknown (< 0), fall back
    # to the sample count stored in config['amount'].
    steps_per_epoch = int(train_dataset.cardinality())
    if steps_per_epoch < 0:
        steps_per_epoch = int(np.ceil(config['amount'] / optimize_cfg['batch_size']))
    end_step = steps_per_epoch * optimize_cfg['epochs']

    pruning_params = {
        'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(
            initial_sparsity=optimize_cfg.get('initial_sparsity', 0.50),
            final_sparsity=optimize_cfg.get('final_sparsity', 0.80),
            begin_step=0,
            end_step=end_step)
    }

    model_for_pruning = prune_low_magnitude(model, **pruning_params)

    # prune_low_magnitude returns a new model whose layers are wrapped for
    # pruning, and that model needs a recompile. Compile the wrapped model,
    # not the original `model`; otherwise the unpruned model is what gets
    # trained. Assumes compile_model compiles and returns the `model` it is
    # given (TODO confirm in src/model.py).
    model_for_pruning = compile_model(
        input_shape=config['img_shape'], output_shape=config['kp_shape'],
        model=model_for_pruning)

    logdir = tempfile.mkdtemp()
    # UpdatePruningStep is required for pruning to advance during fit.
    # TODO: confirm WandbCallback (e.g. its model saving) works with the pruning-wrapped model.
    callbacks = [tfmot.sparsity.keras.UpdatePruningStep(),
                 tfmot.sparsity.keras.PruningSummaries(log_dir=logdir),
                 EarlyStopping(**config['callbacks']['EarlyStopping']),
                 ReduceLROnPlateau(**config['callbacks']['ReduceLROnPlateau']),
                 WandbCallback(**config['callbacks']['WandbCallback'])]

    # Data sanity checks on both splits before fitting.
    for split in (train_dataset, test_dataset):
        pass_tests_before_fitting(
            data=split, img_shape=config['img_shape'], keypoint_shape=config['kp_shape'])

    # NOTE: if EarlyStopping ends the fit before end_step, the final sparsity
    # is not reached.
    history = model_for_pruning.fit(
        train_dataset, epochs=optimize_cfg['epochs'],
        validation_data=test_dataset, callbacks=callbacks)

    # Remove the pruning wrappers, leaving a plain sparse model for export/conversion.
    model_for_export = tfmot.sparsity.keras.strip_pruning(model_for_pruning)

    
