In [431]:
import cv2
import numpy as np
import mediapipe as mp
import tensorflow as tf
from tensorflow.keras.models import load_model
import time
import datetime
import json
from pathlib import Path
from ipywidgets import Output, Button, Layout, HBox
from IPython.display import Image, display, clear_output
from matplotlib import pyplot as plt

class LabelingTool:
    def __init__(self, classes: list, path: str, position = 0) -> None:
        self.classes = classes
        self.path = Path(path)
        self.images = [f.name for f in self.path.glob("*.jpg")]
        self.labeled_data = {}
        if (0 <= position < len(self.images)):
            self.position = position
        else:
            print("Position out of bound. Setting to 0 " + "number of images: " + str(len(self.images)))
            self.position = 0
            
    def _add_landmarks(self) -> None:
        mp_face_mesh = mp.solutions.face_mesh
        mpDraw = mp.solutions.drawing_utils
        mp_drawing_styles = mp.solutions.drawing_styles
        
        with mp_face_mesh.FaceMesh(
            max_num_faces=1,
            refine_landmarks=True,
            min_detection_confidence=0.5,
            min_tracking_confidence=0.5) as face_mesh:
            
#             print(str(self.path)+ "\\" + current_image)
            
            image = cv2.imread(str(self.path)+ "\\" + self.images[self.position])
#             image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            resultFace = face_mesh.process(image)
            if resultFace.multi_face_landmarks:
                for face_landmarks in resultFace.multi_face_landmarks:
                    mpDraw.draw_landmarks(
                        image=image,
                        landmark_list=face_landmarks,
                        connections=mp_face_mesh.FACEMESH_TESSELATION,
                        landmark_drawing_spec=None,
                        connection_drawing_spec=mp_drawing_styles
                        .get_default_face_mesh_tesselation_style())
                    mpDraw.draw_landmarks(
                        image=image,
                        landmark_list=face_landmarks,
                        connections=mp_face_mesh.FACEMESH_CONTOURS,
                        landmark_drawing_spec=None,
                        connection_drawing_spec=mp_drawing_styles
                        .get_default_face_mesh_contours_style())
                    mpDraw.draw_landmarks(
                        image=image,
                        landmark_list=face_landmarks,
                        connections=mp_face_mesh.FACEMESH_IRISES,
                        landmark_drawing_spec=None,
                        connection_drawing_spec=mp_drawing_styles
                        .get_default_face_mesh_iris_connections_style())
        cv2.imwrite("temp.jpg", image)
        
    
    def _next_image(self, *args) -> None:
        """Select the next image and update the displays."""
        self.position += 1
        if self.position == len(self.images):
            self.position = 0

        # refresh display
        self._add_landmarks()
        image = Image("temp.jpg")
        o_image = Image(self.path / self.images[self.position])
        
        clear_output(wait=True)
        with self.o_frame:
            clear_output(wait=True)
            display(o_image)
        with self.frame:
            clear_output(wait=True)
            display(image)
        
        display(self.o_frame)
        display(self.frame)
        display(HBox(self.navigation_buttons))
        display(HBox(self.class_buttons))
        if (str(self.path)+"\\"+self.images[self.position]) in self.labeled_data:
            self.existing_label = []
            existing_label_button = Button(description="Label: " + self.labeled_data.get(str(self.path)+"\\"+self.images[self.position])[1])
            self.existing_label.append(existing_label_button)
            display(HBox(self.existing_label))

    def _previous_image(self, *args) -> None:
        """Select the previous image and update the displays."""
        self.position -= 1
        if self.position == -1:
            self.position = len(self.images) - 1

        # refresh display
        self._add_landmarks()
        image = Image("temp.jpg")
        o_image = Image(self.path / self.images[self.position])
        with self.o_frame:
            clear_output(wait=True)
            display(o_image)
        
        clear_output(wait=True)
        with self.frame:
            clear_output(wait=True)
            display(image)
            
        display(self.o_frame)
        display(self.frame)
        display(HBox(self.navigation_buttons))
        display(HBox(self.class_buttons))
        if (str(self.path)+"\\"+self.images[self.position]) in self.labeled_data:
            self.existing_label = []
            existing_label_button = Button(description="Label: " + self.labeled_data.get(str(self.path)+"\\"+self.images[self.position])[1])
            self.existing_label.append(existing_label_button)
            display(HBox(self.existing_label))

    def _store_labeled_data(self, button: Button) -> None:
        """Annotates the current image with the button's description."""
        # store label
        current_image = self.images[self.position]
#         display(current_image)
        mp_face_mesh_d = mp.solutions.face_mesh
        
        with mp_face_mesh_d.FaceMesh(
            max_num_faces=1,
            refine_landmarks=True,
            min_detection_confidence=0.5,
            min_tracking_confidence=0.5) as face_mesh:
            
#             print(str(self.path)+ "\\" + current_image)
            
            image = cv2.imread(str(self.path)+ "\\" + current_image)
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            resultFace = face_mesh.process(image)
            if resultFace.multi_face_landmarks:
                for face_landmarks in resultFace.multi_face_landmarks:
                    if face_landmarks:
                        #                         current_time = datetime.datetime.now(datetime.timezone.utc)
#                         if not (str(self.path)+"\\"+current_image) in self.labeled_data:
                            xyz = [(lm.x, lm.y, lm.z) for lm in face_landmarks.landmark]
                            self.labeled_data.update({(str(self.path)+"\\"+current_image) : (xyz, button.description)})
    #                         self.labeled_data[str(self.path)+ "\\" + current_image] = (xyzt, button.description)
#                             self._next_image()
            else:
                clear_output(wait=True)
                self.no_face = []
                label_button = Button(description="No face detected")
                self.no_face.append(label_button)
                display(self.frame)
                display(HBox(self.navigation_buttons))
                display(HBox(self.no_face))
        # move on
    def plot_scatter(self) -> None:
        %matplotlib notebook
        data_list = list(tool.labeled_data.keys())
        # display(data_list)
        current_image=str(tool.path)+"\\"+tool.images[tool.position]
        points = np.array(tool.labeled_data.get(current_image)[0])
        fig = plt.figure(figsize=(10, 10))
        ax = fig.add_subplot(111, projection="3d")
        ax.scatter(points[:, 0], points[:, 1], points[:, 2])
        ax.set_title(
                "image: {:}, label: {:}".format(
                    data_list[-1], tool.labeled_data.get(current_image)[1]
                )
        )
        ax.set_axis_off()
        plt.show()
    
    def load_labeled_data_to_file(self) -> None:
        with open(("labaled_data.json"), "r") as outfile:
            self.labeled_data = json.load(outfile)
        
    def save_labeled_data_to_file(self) -> None:
        labeled_data_json = json.dumps(self.labeled_data, indent=4)
#         print((str(self.path)+"\\"+"labaled_data.json"))
        with open(("labaled_data.json"), "w") as outfile:
            outfile.write(labeled_data_json)
        
    def start(self) -> None:
        try:
            self.load_labeled_data_to_file()
        except:
            print("No saved data")
        o_image = Image(self.path / self.images[self.position])
        self.frame = Output(layout=Layout(height="300px", max_width="300px"))
        self.o_frame = Output(layout=Layout(height="300px", max_width="300px"))
        
        with self.o_frame:
            clear_output(wait=True)
            display(o_image)
        
        self._add_landmarks()
        image = Image("temp.jpg")
        with self.frame:
            clear_output(wait=True)
            display(image)

        # navigation buttons
        backward_button = Button(description="< go back")
        backward_button.on_click(self._previous_image)
        forward_button = Button(description="next >")
        forward_button.on_click(self._next_image)
        self.navigation_buttons = [backward_button, forward_button]

        # class label buttons
        self.class_buttons = []
        for label in self.classes:
            label_button = Button(description=label)
            label_button.on_click(self._store_labeled_data)
            self.class_buttons.append(label_button)
        display(self.o_frame)
        display(self.frame)
        # display contents
        display(HBox(self.navigation_buttons))
        display(HBox(self.class_buttons))
        if (str(self.path)+"\\"+self.images[self.position]) in self.labeled_data:
            self.existing_label = []
            existing_label_button = Button(description="Label: " + self.labeled_data.get(str(self.path)+"\\"+self.images[self.position])[1])
            self.existing_label.append(existing_label_button)
            display(HBox(self.existing_label))

In [432]:
tool = LabelingTool(path="wiki_crop/00", classes = ["open_mouth", "closed_mouth"], position=278)

tool.start()


Output(layout=Layout(height='300px', max_width='300px'), outputs=({'output_type': 'display_data', 'data': {'im…

Output(layout=Layout(height='300px', max_width='300px'), outputs=({'output_type': 'display_data', 'data': {'im…

HBox(children=(Button(description='< go back', style=ButtonStyle()), Button(description='next >', style=Button…

HBox(children=(Button(description='open_mouth', style=ButtonStyle()), Button(description='closed_mouth', style…

In [458]:
# display(tool.labeled_data.get("wiki_crop\\00\\12170800_1982-12-15_2013.jpg")[1])
display(tool.position)

315

In [522]:
display(tool.position)
tool.save_labeled_data_to_file()
tool.plot_scatter()

420

<IPython.core.display.Javascript object>

In [408]:
tool.save_labeled_data_to_file()

In [407]:
%matplotlib notebook
data_list = list(tool.labeled_data.keys())
# display(data_list)
current_image=str(tool.path)+"\\"+tool.images[tool.position]
points = np.array(tool.labeled_data.get(current_image)[0])
fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(111, projection="3d")
ax.scatter(points[:, 0], points[:, 1], points[:, 2])
ax.set_title(
        "image: {:}, label: {:}".format(
            data_list[-1], tool.labeled_data.get(current_image)[1]
        )
)
ax.set_axis_off()
plt.show()

# display(points)
# display(str(tool.path)+ "\\" + tool.images[tool.position-1])
# display(data_list.index(str(tool.path)+ "\\" + tool.images[tool.position-1]))

263

<IPython.core.display.Javascript object>

In [202]:
tool.labeled_data.pop(current_image)

([(0.3752659559249878, 0.6305263042449951, -0.02217859774827957),
  (0.31992027163505554, 0.5875664949417114, -0.056693900376558304),
  (0.35709500312805176, 0.5961862206459045, -0.02466323785483837),
  (0.3107072710990906, 0.5202952027320862, -0.0372917577624321),
  (0.3118478059768677, 0.5695884823799133, -0.06252408027648926),
  (0.3125709891319275, 0.541977047920227, -0.06098701059818268),
  (0.3244013786315918, 0.46655312180519104, -0.041284073144197464),
  (0.27751022577285767, 0.46571600437164307, 0.09326519072055817),
  (0.32009878754615784, 0.4184405207633972, -0.0412781797349453),
  (0.3112500309944153, 0.3950541019439697, -0.04729628562927246),
  (0.30046072602272034, 0.29384490847587585, -0.052521053701639175),
  (0.377672016620636, 0.6379831433296204, -0.019771777093410492),
  (0.3820529580116272, 0.6446865200996399, -0.01486701238900423),
  (0.3871069550514221, 0.6488887071609497, -0.006835169158875942),
  (0.39202436804771423, 0.6652513146400452, -0.004633735399693251),


<IPython.core.display.Javascript object>