In [20]:
import json
import pathlib
import os

import numpy as np
import plotly.express as px
from plotly import graph_objects as go
from torchvision import io
import pandas as pd
from PIL import Image

import ipywidgets as widgets
from ipywidgets import interact, fixed

In [90]:
# Location of the per-image error dump that is loaded into ``error_per_image`` below.
path_to_file = pathlib.Path("data") / "resnet_errors.json"

In [91]:
# Parse the JSON dump into a dict: image file name -> error value.
error_per_image = json.loads(path_to_file.read_text(encoding="utf-8"))

In [92]:
# Distribution of the per-image error over every image in the JSON dump.
# Passing the values as ``x`` (rather than a bare tuple as the data frame)
# avoids plotly's wide-form mode and its meaningless "variable=0" legend.
px.histogram(
    x=list(error_per_image.values()),
    title="Per-image error, all images",
    labels={"x": "error"},
)

In [93]:
# All per-image errors as a NumPy array (handy for quantile-based thresholds).
errors = np.array(tuple(error_per_image.values()))

# Show summary statistics instead of the raw array. The 96th percentile is the
# data-driven alternative to the fixed ``max_error`` threshold chosen next.
pd.Series(errors).describe(percentiles=[0.5, 0.9, 0.96, 0.99])

In [94]:
# Fixed error threshold for selecting images to inspect.
# Data-driven alternative: np.quantile(errors, 0.96).
max_error = 7e-4

In [95]:
max_error  # show the error threshold currently in use

0.0007

In [96]:
# Keep every image whose error reaches the threshold (dict insertion order).
images = [name for name, error in error_per_image.items() if error >= max_error]

In [97]:
# Order the selection from largest to smallest error, in place (same list object).
# Python's sort is stable even with reverse=True, so ties keep their previous order.
images[:] = sorted(images, key=error_per_image.__getitem__, reverse=True)

In [98]:
# Distribution of the error restricted to the selected images (error >= max_error).
# ``x=`` plus explicit labels gives a titled, self-explanatory figure.
px.histogram(
    x=[error_per_image[name] for name in images],
    title=f"Per-image error, {len(images)} images with error >= {max_error}",
    labels={"x": "error"},
)

In [99]:
# Directory holding the training images; landmarks.csv lives in its parent.
image_dir = pathlib.Path("data") / "test-train" / "contest01_data" / "train" / "images"

In [100]:
# Tab-separated landmark table, indexed by image file name.
landmarks = pd.read_csv(image_dir.parent / "landmarks.csv", index_col="file_name", sep="\t", engine="c")

# show_landmarks looks rows up by file name and reshapes each row into (x, y)
# pairs. A duplicated name would make ``.loc`` return several rows that get
# plotted together silently, and an odd column count cannot form pairs, so
# fail early here with a clear message.
assert landmarks.index.is_unique, "duplicate file names in landmarks.csv"
assert landmarks.shape[1] % 2 == 0, "landmark columns do not form (x, y) pairs"

In [101]:
# Dropdown over the selected images; ``interact`` below binds it to ``image_name``.
image_list = widgets.Dropdown(
    options=images,
)

In [102]:
def show_landmarks(image_name, image_dir, landmarks, error_per_image):
    """Render one image with its landmark points overlaid.

    Parameters
    ----------
    image_name : str
        File name of the image inside ``image_dir``; also the row label in
        ``landmarks`` and the key in ``error_per_image``.
    image_dir : str or os.PathLike
        Directory containing the image files.
    landmarks : pandas.DataFrame
        Indexed by image file name. The row's values are reshaped to
        ``(-1, 2)``; column 0 is plotted as x and column 1 as y.
    error_per_image : Mapping[str, float]
        Error value per image name, shown in the figure title.

    Returns
    -------
    plotly.graph_objects.FigureWidget
        An image trace plus a marker scatter of the landmark points.
    """
    image_path = pathlib.Path(image_dir) / image_name
    # Read the pixels inside ``with`` so the file handle is closed right away;
    # a bare ``Image.open`` keeps it open until the object is garbage-collected.
    # ``convert("RGB")`` guarantees the 3-channel array ``go.Image`` expects
    # (e.g. for grayscale or CMYK JPEGs).
    with Image.open(image_path) as image:
        rgb_image = np.asarray(image.convert("RGB"))

    fig = go.FigureWidget(data=go.Image(z=rgb_image))

    # Use a separate name so the ``landmarks`` table argument is not shadowed.
    points = landmarks.loc[image_name, :].to_numpy().reshape(-1, 2)
    fig.add_trace(go.Scatter(x=points[:, 0], y=points[:, 1], mode="markers"))

    fig.update_layout(title=f"{image_name}: error = {error_per_image[image_name]:.6g}")
    return fig


In [103]:
# Interactive viewer: choose an image in the dropdown to see it with its
# landmarks; every other argument of show_landmarks is held fixed.
# The trailing ``;`` suppresses the echo of the function object that
# ``interact`` returns (the widget itself is still displayed).
interact(
    show_landmarks,
    image_name=image_list,
    image_dir=fixed(image_dir),
    landmarks=fixed(landmarks),
    error_per_image=fixed(error_per_image),
);

interactive(children=(Dropdown(description='image_name', options=('10d949376b4324574a6b560efeaf110b.jpg', '133…

<function __main__.show_landmarks(image_name, image_dir, landmarks, error_per_image)>

In [104]:
len(images)  # number of images with error >= max_error

815

In [105]:
# Save the selected image names (largest error first), one per line, next to
# the error dump. The trailing ``;`` hides the character count returned by write_text.
(path_to_file.parent / "filter_images.txt").write_text("\n".join(images), encoding="utf-8");