In [64]:
import cv2
import numpy as np
import mediapipe as mp
import tensorflow as tf
from tensorflow.keras.models import load_model
import time
import datetime
import json
from pathlib import Path
from ipywidgets import Output, Button, Layout, HBox
from IPython.display import Image, display, clear_output

class LabelingTool:
    """Notebook widget for labeling images with MediaPipe face-mesh landmarks.

    Images matching ``pattern`` inside ``path`` are shown one at a time in an
    ``ipywidgets.Output`` frame, with back/next navigation buttons and one
    button per class name. Clicking a class button runs MediaPipe FaceMesh on
    the current image. If a face is found, the landmark coordinates are stored
    together with the clicked class name and the next image is shown.

    ``labeled_data`` maps ``str(path / image_name)`` to ``(landmarks, label)``.
    ``landmarks`` is a list of ``(x, y, z)`` tuples. After a JSON round trip
    both the pair and the tuples come back as lists.
    """

    # Name of the JSON labels file stored inside ``path``. The misspelling is
    # kept on purpose so that files saved by earlier sessions are still found.
    LABELS_FILENAME = "labaled_data.json"

    def __init__(self, classes: list, path: str, position=0, pattern: str = "*.jpg") -> None:
        """
        Args:
            classes: class names; one label button is created per entry.
            path: directory holding the images and the JSON labels file.
            position: index of the first image to show; out-of-range values fall back to 0.
            pattern: glob pattern selecting the images inside ``path``.
        """
        self.classes = classes
        self.path = Path(path)
        # Sorted so that ``position`` names the same image on every run;
        # ``Path.glob`` yields entries in arbitrary filesystem order.
        self.images = sorted(f.name for f in self.path.glob(pattern))
        self.labeled_data = {}
        if 0 <= position < len(self.images):
            self.position = position
        else:
            print("Position out of bound. Setting to 0 " + "number of images: " + str(len(self.images)))
            self.position = 0

    def _image_key(self, image_name: str) -> str:
        """Return the full path string used both to read ``image_name`` and as its key in ``labeled_data``."""
        # Path joining uses the platform separator. On Windows this equals the
        # old ``str(self.path) + "\\" + name`` keys, so saved files stay valid.
        return str(self.path / image_name)

    def _show(self, button_row: list) -> None:
        """Redraw the cell output: image frame, navigation row, then ``button_row``."""
        clear_output(wait=True)
        display(self.frame)
        display(HBox(self.navigation_buttons))
        display(HBox(button_row))

    def _go_to(self, position: int) -> None:
        """Show the image at ``position`` (wrapped into range) together with the class buttons."""
        # Python's modulo wraps -1 to the last index and len(images) to 0.
        self.position = position % len(self.images)
        with self.frame:
            clear_output(wait=True)
            display(Image(self.path / self.images[self.position]))
        self._show(self.class_buttons)

    def _show_notice(self, message: str) -> None:
        """Replace the class-button row with a single button showing ``message``."""
        self.no_face = [Button(description=message)]
        self._show(self.no_face)

    @staticmethod
    def _face_landmarks(rgb_image) -> "list | None":
        """Run FaceMesh on an RGB image; return the first face's (x, y, z) landmarks, or None if no face is found."""
        with mp.solutions.face_mesh.FaceMesh(
            max_num_faces=1,
            refine_landmarks=True,
            min_detection_confidence=0.5,
            min_tracking_confidence=0.5,
        ) as face_mesh:
            faces = face_mesh.process(rgb_image).multi_face_landmarks
            if not faces:
                return None
            # max_num_faces=1, so at most one face is reported.
            return [(lm.x, lm.y, lm.z) for lm in faces[0].landmark]

    def _next_image(self, *args) -> None:
        """Select the next image (wrapping to the first) and update the displays."""
        if self.images:
            self._go_to(self.position + 1)

    def _previous_image(self, *args) -> None:
        """Select the previous image (wrapping to the last) and update the displays."""
        if self.images:
            self._go_to(self.position - 1)

    def _store_labeled_data(self, button: "Button") -> None:
        """Annotate the current image with the button's description, then advance.

        An image that already has a label keeps it; the click just moves on.
        If the image cannot be read, or no face is detected, a notice replaces
        the class buttons until the user navigates away.
        """
        if not self.images:
            return
        key = self._image_key(self.images[self.position])
        if key in self.labeled_data:
            # First label wins; skip the costly FaceMesh pass and move on.
            self._next_image()
            return

        image = cv2.imread(key)
        if image is None:
            # cv2.imread reports missing/unreadable files by returning None.
            self._show_notice("Could not read image")
            return

        xyz = self._face_landmarks(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        if xyz is None:
            self._show_notice("No face detected")
            return

        self.labeled_data[key] = (xyz, button.description)
        self._next_image()

    def load_labeled_data_to_file(self) -> None:
        """Replace ``labeled_data`` with the contents of the JSON labels file in ``path``.

        Despite its name, this method reads the file.

        Raises:
            FileNotFoundError: if no labels file has been saved yet.
            json.JSONDecodeError: if the file is not valid JSON.
        """
        with open(self.path / self.LABELS_FILENAME, "r", encoding="utf-8") as infile:
            self.labeled_data = json.load(infile)

    def save_labeled_data_to_file(self) -> None:
        """Write ``labeled_data`` as indented JSON to the labels file in ``path``."""
        # Serialize before opening: if serialization fails, the existing file
        # is not truncated.
        labeled_data_json = json.dumps(self.labeled_data, indent=4)
        with open(self.path / self.LABELS_FILENAME, "w", encoding="utf-8") as outfile:
            outfile.write(labeled_data_json)

    def start(self) -> None:
        """Load previously saved labels (if any) and render the labeling UI.

        Shows the current image in a fixed-size ``Output`` frame, a row of
        back/next buttons, and one label button per entry in ``classes``.
        """
        try:
            self.load_labeled_data_to_file()
        except FileNotFoundError:
            # Only "nothing saved yet" is tolerated. A corrupt file raises, so a
            # later save cannot silently overwrite the user's labels.
            print("No saved data")
        if not self.images:
            print("No images found in " + str(self.path))
            return

        self.frame = Output(layout=Layout(height="300px", max_width="300px"))
        with self.frame:
            display(Image(self.path / self.images[self.position]))

        # navigation buttons
        backward_button = Button(description="< go back")
        backward_button.on_click(self._previous_image)
        forward_button = Button(description="next >")
        forward_button.on_click(self._next_image)
        self.navigation_buttons = [backward_button, forward_button]

        # class label buttons
        self.class_buttons = []
        for label in self.classes:
            label_button = Button(description=label)
            label_button.on_click(self._store_labeled_data)
            self.class_buttons.append(label_button)

        # display contents (no clear_output here: this is the first render)
        display(self.frame)
        display(HBox(self.navigation_buttons))
        display(HBox(self.class_buttons))

In [65]:
# Label the images in one wiki_crop shard as open or closed mouth.
tool = LabelingTool(
    classes=["open_mouth", "closed_mouth"],
    path="wiki_crop/00",
)
tool.start()

Output(layout=Layout(height='300px', max_width='300px'), outputs=({'output_type': 'display_data', 'data': {'im…

HBox(children=(Button(description='< go back', style=ButtonStyle()), Button(description='next >', style=Button…

HBox(children=(Button(description='open_mouth', style=ButtonStyle()), Button(description='closed_mouth', style…

In [67]:
# Persist the labels collected via the widget to the JSON labels file inside the image directory.
tool.save_labeled_data_to_file()