In [51]:
import matplotlib 
import matplotlib.pyplot as plt
import cv2
from matplotlib.widgets import Button
import pandas as pd
import os
import tkinter as tk
import glob

%matplotlib qt

def get_image_status(image_dir = 'images', image_ext = '.png'):

    image_glob = os.path.join(image_dir, f'*{image_ext}')
    images = glob.glob(image_glob, recursive=True)
    print(f'{len(images)} {image_ext} images found')

    csv_glob = os.path.join(image_dir, f'*.csv')
    csvs = glob.glob(csv_glob, recursive=True)
    print(f'{len(csvs)} annotation .csv files found')

    def process_images(fullpath):
        basename = os.path.basename(fullpath)
        has_annotations =  os.path.isfile(f'{fullpath}.csv')
        d = {'fullpath':fullpath, 'basename': basename, 'has_annotations': has_annotations}
        return d

    d = [process_images(x) for x in images]

    df_images = pd.DataFrame(d)

    return df_images


image_dir = 'images'
image_ext = '.png'
df_images = get_image_status(image_dir = image_dir,  image_ext = image_ext )

print(df_images)

2 .png images found
1 annotation .csv files found
           fullpath   basename  has_annotations
0   images\demo.png   demo.png             True
1  images\demo1.png  demo1.png            False


In [52]:


def plot_points(df_xy):
    hpoints = ax.plot(df_xy['x'], df_xy['y'],  marker = 'o', markersize = 20, c = 'r', fillstyle = 'none', linestyle = '--')
    return hpoints


def message(msg):
    global ax
    print(msg)
    ax.set_title(msg)




def onclick(event):
    global hpoints
    global df_xy
    global fig
    global ax

    print(event)
    ix, iy = event.xdata, event.ydata

    if event.inaxes != ax:
        return None

    if (ix ==None) or (iy ==None):
        return None

    if event.button ==1: # left click, add point
        ix = int(ix)
        iy = int(iy)

        msg = f'add point: x = {ix}, y = {iy}'
        message(msg)

        df_xy = df_xy._append( {'x':ix, 'y':iy} , ignore_index=True)

    if event.button ==3 and len(df_xy)>0: # right click, remove point
        d = (ix-df_xy['x'])**2 + (iy-df_xy['y'])**2
        idx = d.idxmin()
        df_xy = df_xy.drop(idx)
        msg = f'remove point: index = {idx}'
        message(msg)

    # sort points and redraw
    df_xy=df_xy.sort_values('x').reset_index(drop = True)
    try:
        hpoints.pop(0).remove()
    except:
        pass
    hpoints = plot_points(df_xy)
    print(df_xy)
    fig.canvas.draw()
    return df_xy

def save_data(event):
    global ax
    global fig
    global save_fn

    if len(df_xy)>2:
        df_xy.to_csv(save_fn)
        msg = f'saved {save_fn}'
        message(msg)
        fig.canvas.draw()

def load_image(file, ax):
    global fig
    global hpoints 
    global save_fn
    ax.clear()
    save_fn = f'{file}.csv'
    img = cv2.imread(file)

    plt.suptitle(file)

    if os.path.isfile(save_fn):
        df_xy = pd.read_csv(save_fn, header = 0, index_col = 0)
    else:
        df_xy = pd.DataFrame(columns = ['x', 'y', 'user', 'add_datetime', 'filename'])

    ax.imshow(img)
    hpoints = plot_points(df_xy)
    ax.set_title(file)
    print(f'loaded {file}')
    fig.canvas.draw()
    return df_xy

def load_image_idx(idx, ax):
    print(idx)
    file = df_images.loc[df_images.index[image_index], 'fullpath']
    df_xy = load_image(file, ax)
    return df_xy

def previous_image(idx, ax):
    global fig
    global image_index
    global hpoints 
    hpoints = []
    image_index = image_index-1
    image_index = image_index % len(df_images)
    load_image_idx(image_index, ax)
    print('previous image')


plt.close() 


spec = matplotlib.gridspec.GridSpec(ncols=2, nrows=1,
                         width_ratios=[2, 1], wspace=0.1,
                         hspace=0.1)


fig = matplotlib.pyplot.figure() 

ax = fig.add_subplot(spec[0])

ax.set_xlabel('[left click]: add point, [right click]: remove point')



image_index = 0
#df_xy = load_image_idx(image_index)
file = df_images.loc[df_images.index[image_index], 'fullpath']
df_xy = load_image(file, ax)
cid = fig.canvas.mpl_connect('button_press_event', onclick)


ax_confirm = fig.add_axes([0.7, 0.3, 0.1, 0.075])
ax_next = fig.add_axes([0.7, 0.2, 0.1, 0.075])
ax_previous = fig.add_axes([0.7, 0.1, 0.1, 0.075])

button_confirm = Button(ax_confirm, 'save')
button_previous = Button(ax_previous, 'previous')
button_next = Button(ax_next, 'next')

button_confirm.on_clicked(save_data)
button_previous.on_clicked(lambda x: previous_image(x, ax = ax))



loaded images\demo.png


0

button_press_event: xy=(251, 197) xydata=(932.3636088709677, 679.9869153225807) button=1 dblclick=False inaxes=Axes(0.125,0.319713;0.492063x0.350574)
add point: x = 932, y = 679
      x    y  user  add_datetime  filename
0   359  679   NaN           NaN       NaN
1   681  505   NaN           NaN       NaN
2   932  679   NaN           NaN       NaN
3  1368  259   NaN           NaN       NaN
button_press_event: xy=(254, 255) xydata=(948.729637096774, 363.57703629032267) button=1 dblclick=False inaxes=Axes(0.125,0.319713;0.492063x0.350574)
add point: x = 948, y = 363
      x    y  user  add_datetime  filename
0   359  679   NaN           NaN       NaN
1   681  505   NaN           NaN       NaN
2   932  679   NaN           NaN       NaN
3   948  363   NaN           NaN       NaN
4  1368  259   NaN           NaN       NaN
button_press_event: xy=(500, 97) xydata=(0.8125000000000009, 0.02777777777777768) button=1 dblclick=False inaxes=Axes(0.7,0.2;0.1x0.075)
button_press_event: xy=(493, 73) x