In [1]:
# Load Image
# Editor appearance set up & Load image

# Extend width of Jupyter Notebook Cell to the size of browser
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

# Import packages needed
import gc
from math import cos, sin, radians
import pickle
import platform
from tkinter import Tk, simpledialog, messagebox
from tkinter.filedialog import askopenfilename, asksaveasfilename

from ipywidgets import widgets
from math import atan, degrees
import matplotlib
from matplotlib.lines import Line2D
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.widgets import RectangleSelector, PolygonSelector
from matplotlib.ticker import AutoMinorLocator, MultipleLocator
import numpy as np
from osgeo import gdal
import pandas as pd
from scipy import ndimage
from skimage import io

import general_funcs


# OS related settings
if platform.system() == 'Windows':
    print('Windows')
#     %matplotlib tk
    %matplotlib qt
elif platform.system() == 'Darwin':
    print('macOS')
    Tk().withdraw()
    %matplotlib osx
elif platform == 'linux' or platform == 'linux2':
    print('Linux')
# This line of "print" must exist right after %matplotlib command, otherwise JN will hang on the first import statement after this.
print('Interactive plot activated')


# Load image & pre-process
# Tk().withdraw()
image_file = askopenfilename(title='Load image file', initialdir='./data/field_image')
img = io.imread(image_file)
print('Image File:', image_file)
print("Image Shape:", img.shape)

# Extract layers from the multilayer tiff file and create low res counter part for faster plotting
scale_factor_RGB = 4
scale_factor_IR = 1

h, w, d = img.shape
layer_RGB, layer_IR, layer_mask = general_funcs.extract_layers(img)
if d == 5:
    layer_RGB_low_res = general_funcs.low_res(layer_RGB, scale_factor_RGB)
#     layer_mask_low_res = general_funcs.low_res(layer_mask, scale_factor_RGB)
    scale_factor = scale_factor_RGB
elif d == 2:
#     layer_IR_low_res = general_funcs.low_res(layer_IR, scale_factor_IR)
#     layer_mask_low_res = general_funcs.low_res(layer_mask, scale_factor_IR)
    layer_IR_low_res = layer_IR
    scale_factor = scale_factor_IR
else:
    print('Unsupported layer number')
    

# Load geo info & calculate pixel size
ds = gdal.Open(image_file)
gt = ds.GetGeoTransform()
lon_meter_per_pix, lat_meter_per_pix = general_funcs.meter_per_pix(gt)
lon_meter_per_pix_low_res = lon_meter_per_pix * scale_factor
lat_meter_per_pix_low_res = lat_meter_per_pix * scale_factor

# Reduce memory usage
del(img)

Windows
Interactive plot activated
Image File: D:/rgb-ir_field_image_processing/data/field_image/BRC/20190828_140117.tif
Image Shape: (12586, 40599, 5)


In [2]:
# Temporary adjustment for plot size

def on_click_size_adjustment(event):
    global verts_size_adjustment
    if tb.mode == '':
        x, y = event.xdata, event.ydata
        verts_size_adjustment.append([x, y])
        if len(verts_size_adjustment) == 2:
            
            actual_length_in_meter = float(textbox_actual_length.value)
            
            verts_size_adjustment = np.asarray(verts_size_adjustment)
            length_in_pix = np.linalg.norm(verts_size_adjustment[0, :]-verts_size_adjustment[1, :])
            length_in_meter = length_in_pix * lon_meter_per_pix_low_res
            
            length_adjustment_coef = np.round(length_in_meter / actual_length_in_meter, 1)
            
            print('Lenght ajustment coefficient:', length_adjustment_coef)
            print('In the process below, when asked for length, time it by *', length_adjustment_coef, '* before input.')
            print('e.g. for 2 (meter), input ', 2*length_adjustment_coef)
            plt.close('all')
    
verts_size_adjustment = []
fig, ax = plt.subplots(figsize=(9, 9))
fig.canvas.set_window_title('Draw a line and tell me it\'s actual length')
tb = fig.canvas.toolbar
if d == 5:
    myax = ax.imshow(layer_RGB_low_res)
elif d == 2:
    # Set vmin & vmax to avoid colorgrading 0 values
    mask_not_0_inds = np.where(layer_mask == 1)
    vmin, vmax = general_funcs.cal_vmin_vmax(layer_IR, layer_mask)
    myax = ax.imshow(layer_IR_low_res, cmap='gist_gray', vmin=vmin, vmax=vmax)
else:
    print('Unsupported layer number')

ps = PolygonSelector(ax, general_funcs.do_nothing, useblit=True)

text_layout = widgets.Layout(width='95%')
textbox_actual_length = widgets.Text(placeholder='What\'s the acutal length of the line in meter?', layout=text_layout)
display(textbox_actual_length)

cid_size_adjustment = fig.canvas.mpl_connect('button_press_event', on_click_size_adjustment)

Text(value='', layout=Layout(width='95%'), placeholder="What's the acutal length of the line in meter?")

Traceback (most recent call last):
  File "C:\Users\lj_ji\anaconda3\envs\hyperspect\lib\site-packages\matplotlib\cbook\__init__.py", line 224, in process
    func(*args, **kwargs)
  File "<ipython-input-2-e218fbef7ae8>", line 10, in on_click_size_adjustment
    actual_length_in_meter = float(textbox_actual_length.value)
ValueError: could not convert string to float: ''


In [3]:
# Rotate Image


def on_click_rot(event):
    global rot_degree, verts_rot_line
    if tb.mode == '':
        x, y = event.xdata, event.ydata
        verts_rot_line.append((x, y))
        if len(verts_rot_line) == 2:
            rot_degree = calc_rot(verts_rot_line)
            apply_rot(rot_degree)
            print('Rotation degree:', rot_degree)
            plt.close('all')
        
def calc_rot(verts):
    verts = np.array(verts)
    vert1 = verts[0, :]
    vert2 = verts[1, :]
    diff = vert2 - vert1
    rot_rad = atan(diff[1]/diff[0])
    rot_degree = degrees(rot_rad)
    return rot_degree
    
    
def apply_rot(rot_degree):
    global layer_RGB_low_res_rot, layer_IR_low_res_rot, rot_center
    if d == 5:
        layer_RGB_low_res_rot = ndimage.rotate(layer_RGB_low_res, rot_degree, reshape=False)
        rot_center = np.flip(np.asarray(layer_RGB_low_res_rot.shape)[0:2] / 2)
    elif d == 2:
        layer_IR_low_res_rot = ndimage.rotate(layer_IR_low_res, rot_degree, reshape=False)
        rot_center = np.flip(np.asarray(layer_IR_low_res_rot.shape)[0:2] / 2)
    else:
        print('Unsupported layer number')
        
def key_press_horizontal_line(event):
    global verts_rot_line
    if event.key == 'escape':
        verts_rot_line = []

rot_degree = 0
verts_rot_line = []

plt.close('all')
fig_size = general_funcs.fig_size(h, w)
fig, ax = plt.subplots(figsize=fig_size)
fig.canvas.set_window_title('Set horizontal line')
fig.canvas.mpl_connect('key_press_event', key_press_horizontal_line)
plt.tight_layout()
tb = fig.canvas.toolbar
if d == 5:
    myax = ax.imshow(layer_RGB_low_res)
elif d == 2:
    # Set vmin & vmax to avoid colorgrading 0 values
    mask_not_0_inds = np.where(layer_mask == 1)
    vmin, vmax = general_funcs.cal_vmin_vmax(layer_IR, layer_mask)
    myax = ax.imshow(layer_IR_low_res, cmap='gist_gray', vmin=vmin, vmax=vmax)
else:
    print('Unsupported layer number')

    

ps = PolygonSelector(ax, general_funcs.do_nothing, useblit=True)

cid_horizotal = fig.canvas.mpl_connect('button_press_event', on_click_rot)

Rotation degree: -12.586027176993067


In [4]:
# Show Image
    
def save_area(button):
    global interested_area, ul_x, ul_y, ul_x_low_res_rot, ul_y_low_res_rot, original_extent, original_extent_low_res, grid_extent, interested_area_low_res, interested_area_low_res_rot
    
    if rs.center == (0, 0.5):
        # Ask user to crop the image if they haven't done it
        print('Please select area')
        return
    else:
        interested_area_low_res_rot = np.round(np.asarray(rs.corners)).astype(int)
        interested_area_low_res = general_funcs.undo_rotation(interested_area_low_res_rot, rot_degree, rot_center)
        interested_area = interested_area_low_res * scale_factor
        interested_area = interested_area.astype(int)
        
        ul_x_low_res_rot, ul_y_low_res_rot = interested_area_low_res_rot[:, 0]
        br_x_low_res_rot, br_y_low_res_rot = interested_area_low_res_rot[:, 2]
        
        ul_x, ul_y = interested_area[:, 0]
        br_x, br_y = interested_area[:, 2]
    
    button_set1 = widgets.HBox(children=[radio_plot_form, button_load, button_create_plot, button_remove, button_done], layout=box_layout)
#     button_set2 = widgets.HBox(children=[button_modify, ], layout=box_layout)
#     button_set3 = widgets.HBox(children=[], layout=box_layout)
    all_widgets.children = [textbox_plot_name, textbox_notes, button_set1, textlabel_info, textbox_info]
    fig.canvas.mpl_disconnect(cid_save_area)
    fig.canvas.mpl_connect('key_press_event', key_press_plot_selection)
    
    if d == 5 or d == 4:
        ax.imshow(layer_RGB_low_res_rot[ul_y_low_res_rot:br_y_low_res_rot, ul_x_low_res_rot:br_x_low_res_rot])
    elif d == 2:
        ax.imshow(layer_IR_low_res_rot[ul_y_low_res_rot:br_y_low_res_rot, ul_x_low_res_rot:br_x_low_res_rot], cmap='gist_gray', vmin=vmin, vmax=vmax)

    rs.set_active(False)
    fig.canvas.set_window_title('Set Plots')

    
def choose_plot_form(sender):
    global plot_form_ind, standard_plot_size_meter, standard_plot_size_pix_low_res
    plot_form_ind = sender.owner.options.index(sender.new)
    
    if plot_form_ind == 0 or plot_form_ind == 2:
        if standard_plot_size_meter == [0, 0]:
            ROOT = Tk()
            ROOT.withdraw()
            user_input = simpledialog.askstring(title="Plot size (m)", prompt="e.g. For a 2m(Horizontal) by 1m(Vertical) size plot, input \"2 1\"")
            standard_plot_size_meter = [float(side) for side in user_input.split()]
            standard_plot_size_pix_low_res = [i/lon_meter_per_pix_low_res for i in standard_plot_size_meter]

    if plot_form_ind == 0:
        free_form_off()
        batch_off()
        standard_form_on()
            
    elif plot_form_ind == 1:
        standard_form_off(remove_lines=True)
        batch_off()
        free_form_on()
        
    elif plot_form_ind == 2:
        standard_form_off(remove_lines=True)
        free_form_off()
        batch_on()
        
        
def standard_form_on():
    global ps_standard_form, cid_ps_standard
    ps_standard_form = PolygonSelector(ax, on_click_ps_standard, useblit=True, lineprops=dict(linewidth=0), markerprops=dict(markersize=5, mec='r', mfc='y', alpha=0.5))
    cid_ps_standard = fig.canvas.mpl_connect('button_press_event', on_click_ps_standard)
    
def standard_form_off(remove_lines):
    if 'ps_standard_form' in globals():
        ps_standard_form.set_active(False)
    if 'cid_ps_standard' in globals():
        fig.canvas.mpl_disconnect(cid_ps_standard)
    if remove_lines:
        remove_from_ax_component(ax.lines)
        fig.canvas.draw()

def remove_from_ax_component(ax_component):
    for i in range(len(ax_component)):
        ax_component[0].remove();
    fig.canvas.draw()
    
def free_form_on():
    global ps_free_form
    ps_free_form = PolygonSelector(ax, ps_free_form_onselect, useblit=True, lineprops=dict(color=matplotlib.colors.to_rgba('orange', 0.5), linestyle='-', linewidth=2),  markerprops=dict(markersize=5, mec='r', mfc='y', alpha=0.5))

def free_form_off():
    if 'ps_free_form' in globals():
        ps_free_form.set_active(False)

def batch_on():
    global rs_batch
    rs_batch = RectangleSelector(ax, set_batch_plot, drawtype='box', useblit=True, button=[1], minspanx=50, minspany=50,
                      rectprops=dict(facecolor='red', edgecolor='black', alpha=0, fill=True), spancoords='pixels', interactive=True)
    ax.patches[-1].remove()
    
def batch_off():
    if 'rs_batch' in globals():
        rs_batch.set_active(False)
    remove_from_ax_component(ax.lines)
        
def set_batch_plot(eclick, erelease):
    global batch_plot_num, batch_mode, top_verts, bot_verts
#     ROOT = Tk()
# #     ROOT.withdraw()
#     simpledialog.askstring(title="Batch plot num", prompt="How many plots in the batch?")
#     print('????')
#     batch_plot_num = int(simpledialog.askstring(title="Batch plot num", prompt="How many plots in the batch?", initialvalue=batch_plot_num))
    
    x1, y1 = eclick.xdata, eclick.ydata
    x2, y2 = erelease.xdata, erelease.ydata
    
    ul = [min(x1, x2), min(y1, y2)]
    br = [max(x1, x2), max(y1, y2)]
    
    remove_from_ax_component(ax.lines)
    
    if br[0] - ul[0] > br[1] - ul[1]:
        batch_mode = 'horizontal'
        
        top_verts = [[ul[0], ul[1]]]
        bot_verts = [[ul[0], br[1]]]
        step = (br[0] - ul[0])/batch_plot_num
        for i in range(batch_plot_num):
            top_verts.append([ul[0]+(i+1)*step, ul[1]])
            bot_verts.append([ul[0]+(i+1)*step, br[1]])
            
        lineprops = dict(color=matplotlib.colors.to_rgba('orange', 0.5), linestyle='-', linewidth=2)

        for i in range(batch_plot_num+1):
            line = Line2D([top_verts[i][0], bot_verts[i][0]], [top_verts[i][1], bot_verts[i][1]], **lineprops)
            ax.add_line(line)
        fig.canvas.draw()
    else:
        batch_mode = 'vertical'
        left_verts = [[ul[0], ul[1]]]
        right_verts = [[br[0], ul[1]]]
        step = (br[1] - ul[1])/batch_plot_num
        for i in range(batch_plot_num):
            left_verts.append([ul[0], ul[1]+(i+1)*step])
            right_verts.append([br[0], ul[1]+(i+1)*step])
            
        lineprops = dict(color=matplotlib.colors.to_rgba('orange', 0.5), linestyle='-', linewidth=2)

        for i in range(batch_plot_num+1):
            line = Line2D([left_verts[i][0], right_verts[i][0]], [left_verts[i][1], right_verts[i][1]], **lineprops)
            ax.add_line(line)
        fig.canvas.draw()
        
def calc_corners(ul_corner, sides):
    corners = [[ul_corner[0], ul_corner[1]],
               [ul_corner[0] + sides[0], ul_corner[1]],
               [ul_corner[0] + sides[0], ul_corner[1] + sides[1]], 
               [ul_corner[0], ul_corner[1] + sides[1]],
               [ul_corner[0], ul_corner[1]],]
    return corners
        
def on_click_ps_standard(event):
    global tb, standard_plot_size_meter, standard_plot_size_pix_low_res, one_plot_vertices_low_res_rot_crop, one_plot_vertices, ul_x_low_res_rot, ul_y_low_res_rot
    if tb.mode == '':
        ul_corner = np.asarray([event.xdata, event.ydata])
        
        corners_low_res_rot = calc_corners(ul_corner, standard_plot_size_pix_low_res)
        corners_low_res_rot = np.asarray(corners_low_res_rot)

        lineprops = dict(color=matplotlib.colors.to_rgba('orange', 0.5), linestyle='-', linewidth=2)
        for i in range(4):
            line = Line2D(corners_low_res_rot[i:i+2, 0], corners_low_res_rot[i:i+2, 1], **lineprops)
            ax.add_line(line)
            
        one_plot_vertices_low_res_rot_crop = corners_low_res_rot[0:4, :]
        one_plot_vertices_low_res_rot = one_plot_vertices_low_res_rot_crop + np.asarray([ul_x_low_res_rot, ul_y_low_res_rot])
        one_plot_vertices_low_res = general_funcs.undo_rotation(one_plot_vertices_low_res_rot.transpose(), rot_degree, rot_center)
        one_plot_vertices = one_plot_vertices_low_res.transpose() * scale_factor
        one_plot_vertices = one_plot_vertices.astype(int)
        fig.canvas.draw()
        
        standard_form_off(remove_lines=False)

        
def ps_free_form_onselect(verts):
    global one_plot_vertices_low_res_rot_crop, one_plot_vertices, ul_x_low_res_rot, ul_y_low_res_rot
    
    one_plot_vertices_low_res_rot_crop = np.asarray(verts).astype(int)
    one_plot_vertices_low_res_rot = one_plot_vertices_low_res_rot_crop + np.asarray([ul_x_low_res_rot, ul_y_low_res_rot])
    one_plot_vertices_low_res = general_funcs.undo_rotation(one_plot_vertices_low_res_rot.transpose(), rot_degree, rot_center)
    one_plot_vertices = one_plot_vertices_low_res.transpose() * scale_factor
    one_plot_vertices = one_plot_vertices.astype(int)
    
    
class Dummy_Event:
    def __init__(self, xdata, ydata):
        self.xdata = xdata
        self.ydata = ydata
        

def create_plot(button):
    global one_plot_vertices, plot_form_ind, batch_mode, top_verts, bot_verts, left_verts, right_verts, standard_plot_size_pix_low_res
    if textbox_plot_name.value == '':
        textbox_info.value = 'Failed! Please input plot name!'
        return
    elif textbox_plot_name.value in plot_vertices.keys():
        textbox_info.value = 'Failed! Plot already EXISTS!'
        return
    elif one_plot_vertices is None and plot_form_ind !=2:
        textbox_info.value = 'Failed! Please select vertices!'
        return
    else:
        if plot_form_ind == 0:
            save_one_plot()
            standard_form_off(remove_lines=True)
            standard_form_on()
        elif plot_form_ind == 1:
            save_one_plot()
            free_form_off()
            free_form_on()
        elif plot_form_ind == 2:
            if batch_mode == 'horizontal':
                for i in range(batch_plot_num):
                    plot_center = [(top_verts[i][0]+bot_verts[i+1][0])/2, (top_verts[i][1]+bot_verts[i+1][1])/2]
                    ul_corner = np.asarray(plot_center) - np.asarray(standard_plot_size_pix_low_res)/2
                    dummy_event = Dummy_Event(ul_corner[0],ul_corner[1])
                    on_click_ps_standard(dummy_event)
                    save_one_plot()
                remove_from_ax_component(ax.lines)
            elif batch_mode == 'vertical':
                for i in range(batch_plot_num):
                    plot_center = [(left_verts[i][0]+right_verts[i+1][0])/2, (left_verts[i][1]+right_verts[i+1][1])/2]
                    ul_corner = np.asarray(plot_center) - np.asarray(standard_plot_size_pix_low_res)/2
                    dummy_event = Dummy_Event(ul_corner[0],ul_corner[1])
                    on_click_ps_standard(dummy_event)
                    save_one_plot()
                remove_from_ax_component(ax.lines)
        
            
            
def save_one_plot():
    global one_plot_vertices, one_plot_vertices_low_res_rot_crop, plot_form_ind
    plot_vertices[textbox_plot_name.value] = one_plot_vertices
    plot_notes[textbox_plot_name.value] = textbox_notes.value
    
#     if d == 5:
    polygon = patches.Polygon(one_plot_vertices_low_res_rot_crop, True, facecolor = matplotlib.colors.to_rgba('red', 0.1), edgecolor=matplotlib.colors.to_rgba('orange', 0.5))
    text_loc = np.mean(one_plot_vertices_low_res_rot_crop, 0)
#     elif d == 2:
#         polygon = patches.Polygon(one_plot_vertices, True, facecolor = matplotlib.colors.to_rgba('red', 0.1), edgecolor=matplotlib.colors.to_rgba('orange', 0.5))
#         text_loc = np.mean(one_plot_vertices, 0)
        
    ax.add_patch(polygon)
    ax.text(text_loc[0], text_loc[1], textbox_plot_name.value, ha='center', va='center')
            
    one_plot_vertices = None
    textbox_info.value = 'Success! Plot ' + textbox_plot_name.value + ' created!'
    if load_plot_name_from_file_flag:
        if textbox_plot_name.value == plot_name_loaded[-1]:
            textbox_info.value = 'Last plot!'
        else:
            plot_name_index = np.where(plot_name_loaded == textbox_plot_name.value)[0][0]
            textbox_plot_name.value = plot_name_loaded[plot_name_index + 1]
    
    
def load_plot_name_from_file(button):
    global plot_name_loaded, load_plot_name_from_file_flag
    plot_name_file = askopenfilename(title='Load plot name file', initialdir='./data/plot_name_csv')
    
    ROOT = Tk()
    ROOT.withdraw()
    # the input dialog
    starting_plot = simpledialog.askstring(title="Reading Format", prompt="Starting with?:")
    direction = int(simpledialog.askstring(title="Reading Format", prompt="Direction?(1 for vertical, 2 for horizontal):"))

    plot_name_matrix = pd.read_csv(plot_name_file, header=None).astype(str).to_numpy()
    for i in range(3):
        if np.where(plot_name_matrix==starting_plot) != (0, 0):
            plot_name_matrix = np.rot90(plot_name_matrix)
            
    if np.where(plot_name_matrix==starting_plot) != (0, 0):
        print('Please input starting plot at the corners')
        return
    
    if direction == 1:
        plot_name_loaded = plot_name_matrix.flatten('F')
    elif direction == 2:
        plot_name_loaded = plot_name_matrix.flatten('C')
    textbox_plot_name.value = plot_name_loaded[0]
    load_plot_name_from_file_flag = True
#     info_box.value = 'Failed'

        
def remove_plot(button):
#     global ax
    if not textbox_plot_name.value in plot_vertices.keys():
        textbox_info.value = 'Failed! Plot ' + textbox_plot_name.value + ' does not exist!'
    else:
        keys = list(plot_vertices.keys())
        patch_ind = keys.index(textbox_plot_name.value)
        patch = ax.patches[patch_ind]
        patch.remove()
        text_patch = ax.texts[patch_ind]
        text_patch.remove()
        
        plot_vertices.pop(textbox_plot_name.value)
        plot_notes.pop(textbox_plot_name.value)
        textbox_info.value = 'Success! Plot ' + textbox_plot_name.value + ' removed!'
        
def save_plot_loc_to_file(button):
    # Save plot info coordinates to file
#     global ul_x, ul_y

    ds = gdal.Open(image_file)
    gt = ds.GetGeoTransform()

    plot_vertices_gps = {}
    
    for plot_name in plot_vertices.keys():
        one_plot_vertices = plot_vertices[plot_name]
        one_plot_vertices_gps = []
        for vertex in one_plot_vertices:
            geo_loc = general_funcs.pix2geo(vertex, gt)
            one_plot_vertices_gps.append(geo_loc)
        one_plot_vertices_gps = np.array(one_plot_vertices_gps)
        plot_vertices_gps[plot_name] = one_plot_vertices_gps

#     fn = image_file.split('/')[-2] + '_' + image_file.split('/')[-1].split('.')[0]
    fn = image_file.split('/')[-1].split('.')[0]
    file_name = asksaveasfilename(filetypes=[('pickle', '*.pkl')], title='Save plot locations', initialfile=fn+'_plot_loc', initialdir='./data/plot_location')
    if not file_name:
        return
    if not file_name.endswith('.pkl'):
        file_name += '.pkl'
        
    try:
        with open(file_name, 'wb') as f:
            pickle.dump(interested_area, f)
            pickle.dump(plot_vertices_gps, f)
            pickle.dump(plot_notes, f)
        print('GPS coordinates saved to', file_name)
    except Exception as e:
        showerror(type(e).__name__, str(e))
        
        
# def add_grid(button):
#     if button.description == 'Add Grid (1 meter)':

#         five_meter_in_low_res = 5 / lon_meter_per_pix_low_res
#         ax.xaxis.set_major_locator(MultipleLocator(five_meter_in_low_res))
#         ax.yaxis.set_major_locator(MultipleLocator(five_meter_in_low_res))

#         ax.xaxis.set_minor_locator(AutoMinorLocator(5))
#         ax.yaxis.set_minor_locator(AutoMinorLocator(5))
        
# #         ax.grid(which='both', axis='both', linestyle='-')
#         ax.grid(which='minor', color='red', alpha=0.5)
#         ax.grid(which='major', color='red', alpha=1)
#         button.description = 'Remove Grid'
#     elif button.description == 'Remove Grid':
#         ax.grid(False, which='both')
#         button.description = 'Add Grid (1 meter)'
        
        
# def create_auxiliary_grid(button):
#     ps.set_active(False)
#     rs_grid = RectangleSelector(ax, line_select_callback, drawtype='box', useblit=True, button=[1], minspanx=50, minspany=50,
#                                 rectprops=dict(facecolor='red', edgecolor='black', alpha=0.1, fill=True), spancoords='pixels', interactive=True)


# def auto_generate_plot(button):
#     ;

    
def key_press_area_selection(event):
    if event.key == 'enter':
        save_area(event)
    fig.canvas.draw()
    
    
def key_press_plot_selection(event):
    if event.key == 'enter':
        create_plot(event)
    elif event.key == 'escape':
        if plot_form_ind == 0:
            standard_form_off(remove_lines=True)
            standard_form_on()
        if plot_form_ind == 2:
            batch_off()
            batch_on()
    fig.canvas.draw()


plot_vertices = {}
plot_notes = {}
one_plot_vertices = None
load_plot_name_from_file_flag = False
standard_plot_size_meter = [0, 0]
batch_plot_num = 33


# Text widgets
text_layout = widgets.Layout(width='95%')
textbox_plot_name = widgets.Text(placeholder='Input plot name here', layout=text_layout)
textbox_notes = widgets.Textarea(placeholder='Put plot notes here', layout=widgets.Layout(height='50px', width='100%'))
textlabel_info = widgets.Label(value='ERROR/Info will be shown below', layout=text_layout)
textbox_info = widgets.Label(layout=text_layout)

textbox_plot_name.on_submit(create_plot)

# Button widgets
button_layout = widgets.Layout(width='auto')
button_save_area = widgets.Button(description="Save Area (Press Enter)", layout=button_layout)
# button_low_res = widgets.Button(description="Low Resolution: On", layout=button_layout)
button_load = widgets.Button(description="Load plot name From File", layout=button_layout)
# button_plot_mode = widgets.Button(description="Use")
button_create_plot = widgets.Button(description="Create Plot (Press Enter)", layout=button_layout)
# button_modify = widgets.Button(description='Modify Plot', layout=button_layout, button_style='danger')
button_remove = widgets.Button(description='Remove Plot', layout=button_layout, button_style='danger')
button_done = widgets.Button(description='Save and Exit', layout=button_layout, button_style='success')
# button_add_grid = widgets.Button(description='Add Grid (1 meter)', layout=button_layout)
# button_create_auxiliary_grid = widgets.Button(description='Create Auxiliary Grid', layout=button_layout)
# button_auto_generate_plot = widgets.Button(description='Auto Generate Plot', layout=button_layout)
radio_plot_form = widgets.RadioButtons(options=['Standard plot', 'Free plot', 'Batch'], value=None, layout=button_layout)

button_save_area.on_click(save_area)
# button_low_res.on_click(low_res)
button_load.on_click(load_plot_name_from_file)
button_create_plot.on_click(create_plot)
# button_modify.on_click(modify_plot)
button_remove.on_click(remove_plot)
button_done.on_click(save_plot_loc_to_file)
# button_add_grid.on_click(add_grid)
# button_create_auxiliary_grid.on_click(create_auxiliary_grid)
# button_auto_generate_plot.on_click(auto_generate_plot)
radio_plot_form.observe(choose_plot_form, names='value')

# Box widgets
box_layout = widgets.Layout(width='auto')
all_widgets = widgets.VBox(children=[button_save_area], layout=box_layout)
# button_set1 = widgets.HBox(children=[button_load, button_save], layout = box_layout)
# button_set2 = widgets.HBox(children=[button_modify, button_remove], layout = box_layout)
# all_widgets = widgets.VBox(children=[plot_name_box, notes_box, button_set1, button_set2, info_label, info_box], layout=box_layout)
display(all_widgets)

out = widgets.Output()
display(out)
with out:
    plt.close('all')
    fig_size = general_funcs.fig_size(h, w)
    fig, ax = plt.subplots(figsize=fig_size)
    fig.canvas.set_window_title('Select area of interest')
    plt.tight_layout()
    tb = fig.canvas.toolbar
    if d == 5:
        myax = ax.imshow(layer_RGB_low_res_rot)
    elif d == 2:
        # Set vmin & vmax to avoid colorgrading 0 values
        mask_not_0_inds = np.where(layer_mask == 1)
        vmin, vmax = general_funcs.cal_vmin_vmax(layer_IR, layer_mask)
        
        myax = ax.imshow(layer_IR_low_res_rot, cmap='gist_gray', vmin=vmin, vmax=vmax)
    else:
        print('Unsupported layer number')

    rs = RectangleSelector(ax, general_funcs.do_nothing, drawtype='box', useblit=True, button=[1], minspanx=50, minspany=50,
                      rectprops=dict(facecolor='red', edgecolor='black', alpha=0.1, fill=True), spancoords='pixels', interactive=True)
    ax.patches[-1].remove()
    cid_save_area = fig.canvas.mpl_connect('key_press_event', key_press_area_selection)
    
    plt.show()

VBox(children=(Button(description='Save Area (Press Enter)', layout=Layout(width='auto'), style=ButtonStyle())…

Output()