In [None]:
import json
import uuid
from multiprocessing import Pool
from tqdm.auto import tqdm
import random
import cv2
import os
from timeout_decorator import timeout, TimeoutError
import matplotlib.pyplot as plt

In [None]:
# Detection results: each key is an image path and each value is its list of
# detections (dicts with "score" and "bbox"), as consumed by `_trim` below.
# JSON is UTF-8 by spec; don't rely on the platform's default text encoding.
with open("/output/output.json", "r", encoding="utf-8") as fp:
    data = json.load(fp)
len(data)

In [None]:
# Show a random 5x5 grid of input images. Each panel is titled with its index
# in `originals` so a bad image can be picked out for removal below.
input_paths = [os.path.join("/input", path) for path in os.listdir("/input")]
# Don't fail when the directory holds fewer than 25 images.
originals = random.sample(input_paths, min(25, len(input_paths)))
fig, axes = plt.subplots(5, 5, figsize=(48, 48))
for i, ax in enumerate(axes.flat):
    if i >= len(originals):
        ax.axis("off")  # hide panels with no image to show
        continue
    img = cv2.imread(originals[i])
    if img is None:
        # cv2.imread returns None for unreadable/corrupt files; cvtColor would raise.
        ax.set_title("{}: unreadable".format(i), fontsize=24)
        continue
    ax.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))  # OpenCV loads BGR, matplotlib expects RGB
    ax.set_title(str(i), fontsize=24)
plt.show()

In [None]:
# Inspect one sample at full size before deciding to drop it.
# `n` indexes `originals` (the random sample shown above, row-major from 0).
n = 16
# OpenCV loads images as BGR; convert so matplotlib shows the true colours.
plt.imshow(cv2.cvtColor(cv2.imread(originals[n]), cv2.COLOR_BGR2RGB))
plt.title(originals[n])
plt.show()

In [None]:
# WARNING: destructive. `originals` comes from an unseeded random sample, so a
# fresh "Run All" would delete a *different* input file here. Print the path so
# the removal is recorded in the output; consider hard-coding it instead.
removed_path = originals[n]
os.remove(removed_path)
print("removed", removed_path)

In [None]:
@timeout(600)
def _trim(item, size=512, min_score=0.8, out_dir="/faces"):
    """Save random square crops around the confident, large detections of one image.

    Parameters
    ----------
    item : tuple
        ``(image_path, detections)`` pair, as produced by ``data.items()``.
        Each detection is a dict read via its ``"score"`` and ``"bbox"`` keys;
        ``bbox`` is unpacked as ``(min_x, min_y, max_x, max_y)`` pixel coords.
    size : int
        Side length of every saved crop. Images smaller than this in either
        dimension are skipped, as are detections whose shorter side is below
        ``0.7 * size``.
    min_score : float
        Detections scoring below this are ignored.
    out_dir : str
        Directory the JPEG crops are written to (random uuid-based names).

    For each qualifying detection, the crop side is drawn uniformly between
    the box's longer side (at least ``size``) and three times that, capped by
    the image's shorter side. The crop is placed randomly so the whole box
    stays inside it, then resized to ``size x size``. Detections whose box is
    not fully inside the image, or is too large for the image, are skipped.

    Returns None; output is only written to disk.

    Raises
    ------
    TimeoutError
        timeout_decorator's TimeoutError, when the call exceeds 600 s.
    """
    key, values = item
    img = cv2.imread(key)
    if img is None:
        print("error loading", key)
        return
    img_h, img_w = img.shape[:2]
    if img_h < size or img_w < size:
        return
    for value in values:
        try:
            if value["score"] < min_score:
                continue
            min_x, min_y, max_x, max_y = (int(v) for v in value["bbox"])
            box_w = max_x - min_x
            box_h = max_y - min_y

            if min(box_h, box_w) < size * 0.7:
                continue
            # A box sticking out of the image can never be fully contained in a crop.
            if min_x < 0 or min_y < 0 or max_x > img_w or max_y > img_h:
                continue
            size_lo = max(size, box_h, box_w)
            size_hi = min(img_h, img_w, max(box_h, box_w) * 3)
            if size_lo > size_hi:
                # Box is longer than the image's shorter side: no valid square crop.
                continue
            _size = random.randint(size_lo, size_hi)

            # The pre-checks above guarantee both placement ranges are non-empty.
            left = random.randint(max(0, max_x - _size), min(min_x, img_w - _size))
            right = left + _size

            up = random.randint(max(0, max_y - _size), min(min_y, img_h - _size))
            down = up + _size

            assert max_x <= right <= img_w
            assert 0 <= left <= min_x
            assert max_y <= down <= img_h
            assert 0 <= up <= min_y

            # Slicing is enough (no full-image copy); cv2.resize returns a new array.
            # _size >= size, so this always shrinks; INTER_AREA avoids aliasing there.
            img_trim = cv2.resize(
                img[up:down, left:right], (size, size), interpolation=cv2.INTER_AREA
            )
            out_path = os.path.join(
                out_dir, "{0}-{1}.jpg".format(str(uuid.uuid4()), str(uuid.uuid4()))
            )
            if not cv2.imwrite(out_path, img_trim):
                # cv2.imwrite reports failure (e.g. missing directory) via its return value.
                print("error writing", out_path)
        except TimeoutError:
            # Let the timeout propagate to the caller. Otherwise the generic handler
            # below would catch it and keep looping with no time limit left.
            raise
        except Exception as e:
            # repr() so that e.g. a bare AssertionError doesn't print an empty line.
            print("error trimming", key, repr(e))
            continue
    
def trim(item):
    """Pool worker entry point: run `_trim` on one ``(image_path, detections)`` pair.

    An image whose processing exceeds `_trim`'s time limit is logged and
    skipped, so one slow image cannot abort the whole pool run.
    """
    try:
        _trim(item)
    except TimeoutError:
        # Best-effort: skip the image, but say which one so it can be investigated.
        print("timeout trimming", item[0])

In [None]:
# Run several passes over every image. `_trim` draws a new random crop for each
# qualifying face on every pass, so each pass adds more crops to /faces.
N_PASSES = 5
items = list(data.items())  # materialize once instead of on every pass
with Pool() as p:  # reuse one worker pool rather than forking a new one per pass
    for pass_idx in range(N_PASSES):
        # trim returns None, so result order is irrelevant. Unordered iteration
        # lets the progress bar advance as soon as any image finishes.
        progress = tqdm(
            p.imap_unordered(trim, items),
            total=len(items),
            desc="pass {}/{}".format(pass_idx + 1, N_PASSES),
        )
        for _ in progress:
            pass

In [None]:
# Spot-check a random 5x5 grid of the generated face crops.
face_paths = [os.path.join("/faces", path) for path in os.listdir("/faces")]
# Don't fail when fewer than 25 crops were produced.
faces = random.sample(face_paths, min(25, len(face_paths)))
fig, axes = plt.subplots(5, 5, figsize=(48, 48))
for i, ax in enumerate(axes.flat):
    if i >= len(faces):
        ax.axis("off")  # hide panels with no image to show
        continue
    img = cv2.imread(faces[i])
    if img is None:
        # cv2.imread returns None for unreadable files; cvtColor would raise.
        ax.set_title("{}: unreadable".format(i), fontsize=24)
        continue
    ax.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))  # OpenCV loads BGR, matplotlib expects RGB
    ax.set_title(str(i), fontsize=24)
plt.show()