In [None]:
import os

# Select the tf.keras backend via SM_FRAMEWORK. Presumably this is read by the
# `segmentation_models` package (not imported in the cells shown), which is why
# it is set in the very first cell, before any such import can happen.
os.environ.update(SM_FRAMEWORK="tf.keras")

In [None]:
%load_ext autoreload
%autoreload 2
import tempfile
from pathlib import Path
from typing import List

import tensorflow as tf
import cv2
from PIL import Image
from PIL.Image import BICUBIC
import numpy as np
import matplotlib.pyplot as plt

from watch_recognition.utilities import BBox

In [None]:
# NOTE(review): not referenced by the cells shown; the video loop reads the
# keypoint input size from `kp_model.inputs` instead.
kp_image_size = (224, 224)
# Detector class names, indexed by the class id the detector returns.
classes = ["clock-face"]
# Box/label colour used by `plot_detection_on_image` (frames are RGB there).
COLORS = [(255, 0, 255)]

detection_model = (
    "./models/detection/efficientdet_lite0/run_1633100188.371347/model.tflite"
)

# Scratch PNG that `_run_detector_tflite` writes and `preprocess_image` re-reads.
# Built from the platform temp dir instead of a hard-coded "/tmp".
temp_file = str(Path(tempfile.gettempdir()) / "test-image.png")
detector = tf.lite.Interpreter(model_path=detection_model)
detector.allocate_tensors()

keypoint_model = 'models/keypoint/efficientnetb0-unet-sigmoid/run_1635001687.225322'
kp_model = tf.keras.models.load_model(keypoint_model, compile=False)

# The path gets its own name so `rotation_model` is only ever bound to the
# loaded model, never to a string (previously it was rebound in place).
rotation_model_path = "./rotation-model-cls-4-fc-4.h5"
rotation_model = tf.keras.models.load_model(rotation_model_path, compile=False)

# Per-detection crops are saved here by the video loop.
frames_dir = Path("./frames/")
frames_dir.mkdir(exist_ok=True, parents=True)

In [None]:
from functools import lru_cache

In [None]:
# Removed a leftover interactive `help(lru_cache)` call: it only printed the
# docstring, and `lru_cache` is not used by any of the cells shown here.

In [None]:
from watch_recognition.utilities import Point


def detect_objects(interpreter, image, threshold):
    """Run the TFLite detector on ``image`` and keep the confident detections.

    Args:
        interpreter: an allocated TFLite interpreter for the detection model.
        image: data written into the interpreter's first input tensor
            (see ``set_input_tensor``).
        threshold: minimum score (inclusive) for a detection to be kept.

    Returns:
        A list of dicts with keys ``"bounding_box"``, ``"class_id"`` and
        ``"score"``, in the order the model emitted them.
    """
    set_input_tensor(interpreter, image)
    interpreter.invoke()

    # Output layout assumed here: 0 = boxes, 1 = class ids, 2 = scores,
    # 3 = number of detections to scan.
    boxes, class_ids, scores = (get_output_tensor(interpreter, idx) for idx in range(3))
    num_detections = int(get_output_tensor(interpreter, 3))

    return [
        {
            "bounding_box": boxes[k],
            "class_id": class_ids[k],
            "score": scores[k],
        }
        for k in range(num_detections)
        if scores[k] >= threshold
    ]


# functions to run object detector in tflite from object detector model maker


def set_input_tensor(interpreter, image):
    """Copy ``image`` in place into the first entry of the interpreter's first input.

    ``interpreter.tensor(i)`` returns a callable that yields a numpy view of
    the interpreter's buffer. The view is only held for the duration of this call.
    """
    first_input = interpreter.get_input_details()[0]
    buffer_view = interpreter.tensor(first_input["index"])()
    # Fill batch entry 0; numpy casts/broadcasts `image` to the buffer's dtype/shape.
    buffer_view[0][:, :] = image


def get_output_tensor(interpreter, index):
    """Return the interpreter's output tensor at position ``index``, squeezed.

    NOTE: ``np.squeeze`` drops *every* length-1 axis, not just the batch axis.
    For example, a boxes output holding exactly one detection would lose its
    detection axis as well.
    """
    details = interpreter.get_output_details()
    raw = interpreter.get_tensor(details[index]["index"])
    return np.squeeze(raw)


def preprocess_image(image_path, input_size):
    """Read an image file and prepare it for the TFLite detector.

    Args:
        image_path: path of the image file, passed to ``tf.io.read_file``.
        input_size: ``(height, width)`` target size for ``tf.image.resize``.

    Returns:
        ``(batch, original)``. ``batch`` is the resized image with a leading
        batch axis of size 1; it is float32, as ``tf.image.resize`` returns with
        its default method. ``original`` is the decoded 3-channel uint8 image.
    """
    encoded = tf.io.read_file(image_path)
    decoded = tf.image.convert_image_dtype(
        tf.io.decode_image(encoded, channels=3), tf.uint8
    )
    batch = tf.expand_dims(tf.image.resize(decoded, input_size), axis=0)
    return batch, decoded

def _run_detector_tflite(image_path, threshold: float = 0.5) -> List[BBox]:
    """Run the global TFLite ``detector`` on an image and return its detections.

    Processing steps:
      * Downscale the image in place to fit within 512x512.
      * Save it to ``temp_file`` as PNG.
      * Re-read it with ``preprocess_image`` at the detector's input size.
      * Pass it to ``detect_objects``.

    Args:
        image_path: path (``str``/``Path``) of an image file, or an image as a
            numpy array accepted by ``Image.fromarray``.
        threshold: minimum detection score to keep (default 0.5, as before).

    Returns:
        One ``BBox`` per kept detection. Coordinates are the detector's raw box
        values, which callers multiply by the image width/height. Each box is
        named ``classes[class_id]``.

    Raises:
        ValueError: if ``image_path`` is neither a path nor a numpy array.
    """
    if isinstance(image_path, (str, Path)):
        im = Image.open(image_path)
    elif isinstance(image_path, np.ndarray):
        im = Image.fromarray(image_path)
    else:
        raise ValueError(f"unrecognized image type {type(image_path)}")

    # `with` closes the file handle opened by Image.open once the image is saved.
    with im:
        # LANCZOS is the filter the old `Image.ANTIALIAS` alias pointed to;
        # ANTIALIAS was removed in Pillow 10.
        im.thumbnail((512, 512), Image.LANCZOS)
        # TODO skip temp file?
        im.save(temp_file, "PNG")

    # Input shape required by the model: (batch, height, width, channels).
    _, input_height, input_width, _ = detector.get_input_details()[0]["shape"]
    preprocessed_image, _original_image = preprocess_image(
        temp_file, (input_height, input_width)
    )

    results = detect_objects(detector, preprocessed_image, threshold=threshold)

    bboxes = []
    for obj in results:
        # The detector emits boxes as (ymin, xmin, ymax, xmax).
        ymin, xmin, ymax, xmax = obj["bounding_box"]
        bboxes.append(
            BBox(
                x_min=xmin,
                y_min=ymin,
                x_max=xmax,
                y_max=ymax,
                name=classes[int(obj["class_id"])],
                score=float(obj["score"]),
            )
        )
    return bboxes


def plot_detection_on_image(original_image, results: List[BBox]):
    """Draw each detection's box and name onto a uint8 copy of ``original_image``.

    Args:
        original_image: image array. It is copied via ``astype(np.uint8)``, so
            the caller's array is left untouched.
        results: boxes in relative coordinates. They are multiplied by the image
            width (``shape[1]``) and height (``shape[0]``) to get pixel positions.

    Returns:
        The annotated uint8 image.
    """
    canvas = original_image.astype(np.uint8)
    img_h, img_w = canvas.shape[0], canvas.shape[1]
    for det in results:
        rel_x0, rel_y0, rel_x1, rel_y1 = det.as_coordinates_tuple
        left, right = int(rel_x0 * img_w), int(rel_x1 * img_w)
        top, bottom = int(rel_y0 * img_h), int(rel_y1 * img_h)

        box_color = list(map(int, COLORS[0]))
        cv2.rectangle(canvas, (left, top), (right, bottom), box_color, 2)
        # Label goes 15 px above the box when there is more than 15 px of room
        # above that; otherwise 15 px below the box's top edge.
        label_y = top - 15 if top - 15 > 15 else top + 15
        cv2.putText(
            canvas, det.name, (left, label_y), cv2.FONT_HERSHEY_SIMPLEX, 1, box_color, 2
        )
    return canvas


def plot_kp_on_image(original_image, point: Point):
    """Draw a cross marker at ``point`` on a uint8 copy of ``original_image``.

    ``point`` is in relative coordinates. Its x is multiplied by the image width
    (``shape[1]``) and its y by the height (``shape[0]``). The marker colour
    (255, 0, 0) is red for RGB frames, which is what the video loop passes in.
    The caller's array is not modified.
    """
    canvas = original_image.astype(np.uint8)
    rel_x, rel_y = point.as_coordinates_tuple
    pixel = (int(rel_x * canvas.shape[1]), int(rel_y * canvas.shape[0]))
    cv2.drawMarker(canvas, pixel, (255, 0, 0), cv2.MARKER_CROSS, thickness=3)
    return canvas

In [None]:
# Detector output per video frame, keyed by the frame index used in the video loop.
# NOTE: keys are bare frame indices, so re-run this cell when switching to a
# different video; otherwise stale detections are reused.
detections_cache = dict()

In [None]:
from PIL import ImageOps
from watch_recognition.targets_encoding import convert_mask_outputs_to_keypoints
from skimage.transform import rotate
from tqdm import tqdm
import dataclasses
from watch_recognition.models import points_to_time

file = Path('../IMG_1200.MOV')
assert file.exists(), f"input video not found: {file}"
cap = cv2.VideoCapture(str(file))

frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
# Frame count as reported by the container. Some backends report 0/-1 (or only
# an estimate); for non-positive values tqdm shows a plain counter instead.
reported_frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
total_frames = reported_frame_count if reported_frame_count > 0 else None

# The output frame rate is the source fps minus this value.
FPS_REDUCTION = 10
# Written frames must have exactly this size. OpenCV typically drops
# mismatched frames without raising, so the crop below must yield 1080x1080.
OUTPUT_SIZE = (1080, 1080)
out = cv2.VideoWriter(
    "outpy_12.mov",
    cv2.VideoWriter_fourcc("M", "J", "P", "G"),
    cap.get(cv2.CAP_PROP_FPS) - FPS_REDUCTION,
    OUTPUT_SIZE,
)
use_angle_model = False
frame_id = 0
# TODO profile this loop and make it faster
try:
    with tqdm(total=total_frames) as pbar:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break  # end of the video (or a read error)
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

            # TODO crop rectangle from the center of the frame
            # Trim rows so the frame becomes (roughly) width x width, with the
            # window shifted 100 px towards the top.
            # NOTE(review): assumes portrait input (height > width); for landscape
            # frames x_min is negative and this slice is not a square crop.
            x = frame.shape[1]  # frame width
            y = frame.shape[0]  # frame height
            x_min = int((y - x) / 2)  # rows trimmed per side before the shift
            frame = frame[x_min - 100 : -x_min - 100, :, :]

            if frame_id in detections_cache:
                results = detections_cache[frame_id]
            else:
                results = _run_detector_tflite(frame)
                detections_cache[frame_id] = results
            results = sorted(results, key=lambda box: box.x_min)

            new_results = []
            with Image.fromarray(frame) as pil_img:
                for i, bbox in enumerate(results):
                    try:
                        # Relative bbox -> pixels, then grow the shorter side so
                        # the crop is square.
                        scaled_bbox = bbox.scale(pil_img.width, pil_img.height)
                        max_dim = max(scaled_bbox.width, scaled_bbox.height)
                        scaled_bbox = scaled_bbox.center_scale(
                            max_dim / scaled_bbox.width,
                            max_dim / scaled_bbox.height,
                        )
                        crop_box = tuple(map(int, scaled_bbox.as_coordinates_tuple))

                        crop = pil_img.crop(crop_box)
                        crop.save(frames_dir / f"{frame_id}_{i}.jpg")

                        kp_input_size = kp_model.inputs[0].shape[1:3]
                        pred_angle = 0
                        if use_angle_model:
                            # Rotation classifier: class k -> k * 90 degrees.
                            # NOTE(review): the crop is padded to the *keypoint*
                            # model's input size before going into rotation_model.
                            crop_resized = ImageOps.pad(
                                crop, tuple(kp_input_size), BICUBIC
                            )
                            crop_resized_np = tf.keras.preprocessing.image.img_to_array(
                                crop_resized
                            )
                            pred_angle = rotation_model.predict(
                                np.expand_dims(crop_resized_np, 0)
                            )
                            pred_angle = pred_angle.argmax(axis=1) * 90
                            pred_angle = pred_angle[0]

                            if pred_angle:
                                # Rotate the crop by -pred_angle before keypoint detection.
                                crop_np = tf.keras.preprocessing.image.img_to_array(crop)
                                rotated_crop = rotate(crop_np, -pred_angle).astype(
                                    "float32"
                                )
                                rotated_crop = tf.keras.preprocessing.image.array_to_img(
                                    rotated_crop
                                )
                                crop_rotated_resized = rotated_crop.resize(
                                    kp_input_size, BICUBIC
                                )
                                crop_rotated_resized_np = (
                                    tf.keras.preprocessing.image.img_to_array(
                                        crop_rotated_resized
                                    )
                                )
                            else:
                                crop_rotated_resized_np = crop.resize(
                                    kp_input_size, BICUBIC
                                )
                        else:
                            # A PIL image here; np.expand_dims below converts it.
                            crop_rotated_resized_np = crop.resize(
                                kp_input_size, BICUBIC
                            )

                        # keypoints
                        predicted = kp_model.predict(
                            np.expand_dims(crop_rotated_resized_np, 0)
                        )[0]
                        # NOTE(review): shape[0] is usually the row (height) axis, so
                        # x/y look swapped here; harmless while the output is square.
                        scale_x = crop.width / predicted.shape[0]
                        scale_y = crop.height / predicted.shape[1]
                        outputs = convert_mask_outputs_to_keypoints(predicted)

                        if use_angle_model:
                            # Rotate the keypoints by -pred_angle about the mask
                            # centre, intended to map them back onto the un-rotated
                            # crop (rotation conventions not verified here).
                            crop_center = Point(
                                predicted.shape[0] / 2, predicted.shape[1] / 2
                            )
                            oriented_outputs = [
                                p.rotate_around_origin_point(crop_center, -pred_angle)
                                for p in outputs
                            ]
                        else:
                            oriented_outputs = outputs
                        # Mask coords -> crop pixels -> frame pixels -> relative
                        # frame coords (what plot_kp_on_image expects).
                        crop_keypoints = [
                            p.scale(scale_x, scale_y)
                            .translate(scaled_bbox.x_min, scaled_bbox.y_min)
                            .scale(1 / pil_img.width, 1 / pil_img.height)
                            for p in oriented_outputs
                        ]
                        for kp in crop_keypoints:
                            frame = plot_kp_on_image(frame, kp)

                        # The time is read from the keypoints in the keypoint
                        # model's own coordinates.
                        center, top, hour, minute = [
                            np.array(p.as_coordinates_tuple).astype(float)
                            for p in outputs
                        ]
                        read_hour, read_minute = points_to_time(
                            center, hour, minute, top
                        )
                        time = f"{read_hour:.0f}:{read_minute:.0f}"
                        if use_angle_model:
                            time += f"[{int(pred_angle)}]"

                        new_results.append(dataclasses.replace(bbox, name=time))
                    except Exception as e:
                        # Report which frame failed, then re-raise unchanged.
                        print(e, frame_id)
                        raise
            frame = plot_detection_on_image(frame, new_results)
            # Back to OpenCV's BGR channel order before writing.
            frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
            out.write(frame)
            frame_id += 1
            pbar.update(1)
finally:
    # Release even if a frame raised, so the output file is finalized and the
    # capture handle is freed.
    cap.release()
    out.release()

cv2.destroyAllWindows()
frame_id