In [None]:
import json
from collections import defaultdict
from pathlib import Path

import pandas as pd
from PIL import Image
from tqdm.notebook import tqdm
from watch_recognition.utilities import Point, BBox


In [None]:
from pprint import pprint


def _to_pixels(percent, size):
    """Convert a label-studio percentage coordinate into pixels."""
    return percent / 100.0 * size


def _normalize_filename(filename):
    """Recover the original image filename from a label-studio ``file_upload``.

    label-studio (Django) added a hash to images that were re-uploaded; the two
    rules below undo the naming patterns seen in the exports.
    """
    # TODO correct this later
    if 'IMG' not in filename and '_' in filename:
        filename = filename.split('_')[0] + '.jpg'
    if 'IMG' in filename and '-' in filename:
        filename = filename.split('-')[1]
    return filename


def _parse_task(tag):
    """Return ``(watch_faces, watch_keypoints)`` in pixel units for one task.

    label-studio stores coordinates as percentages of the original image size,
    so each value is rescaled by ``original_width`` / ``original_height``.
    Only ``keypointlabels`` and ``rectanglelabels`` results are used; any other
    result type is reported and skipped.
    """
    watch_faces = []
    watch_keypoints = []
    for annotation in tag['annotations']:
        for result in annotation['result']:
            result_type = result['type']
            # Check the type first so unknown result types (which may lack the
            # size fields) are reported instead of raising KeyError.
            if result_type not in ('keypointlabels', 'rectanglelabels'):
                print(f"unknown annotation type {result['type']}")
                continue
            original_width = result['original_width']
            original_height = result['original_height']
            value_ = result['value']
            label = value_[result_type][0]
            pixel_x = _to_pixels(value_['x'], original_width)
            pixel_y = _to_pixels(value_['y'], original_height)
            if result_type == 'keypointlabels':
                watch_keypoints.append(Point(pixel_x, pixel_y, label))
            else:
                pixel_width = _to_pixels(value_['width'], original_width)
                pixel_height = _to_pixels(value_['height'], original_height)
                watch_faces.append(BBox(
                    pixel_x,
                    pixel_y,
                    pixel_x + pixel_width,
                    pixel_y + pixel_height,
                    label,
                ))
    return watch_faces, watch_keypoints


def _group_keypoints_by_box(watch_faces, watch_keypoints):
    """Pair every box with the keypoints it contains, skipping boxes with none.

    A list of pairs (instead of a dict keyed by box) keeps boxes that compare
    equal separate, so their keypoints are not appended twice to one crop.
    """
    pairs = []
    for box in watch_faces:
        keypoints = [kp for kp in watch_keypoints if box.contains(kp)]
        if keypoints:
            pairs.append((box, keypoints))
    return pairs


def generate_kp_dataset(image_dir: Path, tags_path: Path, save_dir: Path):
    """Crop watch faces from a label-studio export and write a keypoint CSV.

    For every task in the JSON export at ``tags_path``, each rectangle that
    contains at least one keypoint is cropped from ``image_dir / filename`` and
    saved as ``<image_id>_<i>.jpg`` in ``save_dir``. Keypoints are translated
    into crop coordinates and divided by the crop's width/height. One row per
    (crop, keypoint) is written to ``save_dir / "tags.csv"``.

    Args:
        image_dir: directory containing the source images.
        tags_path: path to the label-studio JSON export.
        save_dir: output directory (created if missing).

    Returns:
        The DataFrame that was written to ``tags.csv`` (possibly empty, but
        always with the expected columns).
    """
    columns = ['image_id', 'image_file', 'crop_file', 'crop_id', 'label', 'x', 'y']

    # JSON interchange is UTF-8; don't depend on the platform default encoding.
    with tags_path.open('r', encoding='utf-8') as f:
        tag_data = json.load(f)

    save_dir.mkdir(exist_ok=True, parents=True)

    records = []
    # TODO make sure that there's only a single tag data for every filename
    for tag in tqdm(tag_data):
        filename = _normalize_filename(tag['file_upload'])
        watch_faces, watch_keypoints = _parse_task(tag)
        boxes_with_keypoints = _group_keypoints_by_box(watch_faces, watch_keypoints)

        image_id = Path(filename).stem
        with Image.open(image_dir / filename) as img:
            for i, (box, keypoints) in enumerate(boxes_with_keypoints):
                crop_box = tuple(map(int, box.as_coordinates_tuple))
                crop_id = f"{image_id}_{i}"
                crop_file = save_dir / f"{crop_id}.jpg"
                crop = img.crop(crop_box)
                # JPEG cannot store palette/alpha modes; convert those only.
                if crop.mode in ('P', 'PA', 'RGBA', 'LA'):
                    crop = crop.convert('RGB')
                crop.save(crop_file)
                # PIL crops from the int-truncated (left, upper) corner, so
                # translate keypoints by that same origin, not the float box
                # corner, to keep them aligned with the saved crop pixels.
                origin_x, origin_y = crop_box[0], crop_box[1]
                for kp in keypoints:
                    kp = kp.translate(-origin_x, -origin_y)
                    records.append({
                        'image_id': image_id,
                        'image_file': filename,
                        'crop_file': crop_file.name,
                        'crop_id': crop_id,
                        'label': kp.name,
                        'x': kp.x / crop.width,
                        'y': kp.y / crop.height,
                    })

    # Explicit columns so an empty result still has a 'crop_id' column
    # (pd.DataFrame([]) has none, which made the lookup below raise KeyError).
    df = pd.DataFrame(records, columns=columns)
    print(df['crop_id'].nunique())
    df.to_csv(save_dir / "tags.csv", index=False)
    return df

# Build the keypoint crop dataset for each data split. Every split reads its
# images and label-studio export from ../download_data and writes the crops
# plus a tags.csv under ../download_data/keypoints/<split>.
for split in ('train', 'validation'):
    tags_path = Path(f"../download_data/{split}-tags.json")
    image_dir = Path(f"../download_data/{split}")
    save_dir = Path(f"../download_data/keypoints/{split}")
    generate_kp_dataset(
        image_dir=image_dir,
        tags_path=tags_path,
        save_dir=save_dir,
    )

print("done")

