In [None]:
# Load Image
# Editor appearance set up & Load image

# Extend width of Jupyter Notebook cells to the full browser width.
# `display`/`HTML` come from the public IPython.display module; importing them from
# IPython.core.display is deprecated since IPython 7.14.
from IPython.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

# Import packages needed
import gc
import os
import pickle
import platform
from tkinter import Tk
from tkinter.filedialog import askopenfilename, asksaveasfilename
from tkinter.messagebox import showerror

import matplotlib
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.widgets import RectangleSelector, PolygonSelector
from matplotlib.ticker import AutoMinorLocator, MultipleLocator
import numpy as np
import pandas as pd
from skimage import io
from ipywidgets import widgets
from osgeo import gdal

import general_funcs


# OS related settings
if platform.system() == 'Windows':
    print('Windows')
    # %matplotlib nbagg
    # Sometimes tk/qt will not let cells rerun after an ERROR occurs
#     %matplotlib tk
    %matplotlib qt
elif platform.system() == 'Darwin':
    print('macOS')
    Tk().withdraw()
    %matplotlib osx
elif platform == 'linux' or platform == 'linux2':
    print('Linux')
# This line of "print" must exist right after %matplotlib command, 
# otherwise JN will hang on the first import statement after this.
print('Interactive plot activated')


# Load image and print size & pre-process

# Use skimage to load the multi-layer tiff file picked in a file dialog.
image_file = askopenfilename(title='Load image file', initialdir='./data/field_image')
if not image_file:
    # An empty result means the dialog was cancelled; stop with a clear message
    # instead of an obscure error from io.imread.
    raise ValueError('No image file selected. Re-run this cell and choose an image.')
img = io.imread(image_file)
print('Image File:', image_file)
print("Image Shape:", img.shape)

# Extract layers from the multilayer tiff file and do some adjustments.
# scale_factor is the down-sampling factor passed to general_funcs.low_res for the preview.
scale_factor = 4

h, w, d = img.shape
layer_RGB, layer_IR, layer_mask = general_funcs.extract_layers(img)
if d != 2:
    # Multi-band images get a down-sampled RGB preview that is displayed instead of the full image.
    layer_RGB_low_res = general_funcs.low_res(layer_RGB, scale_factor)

# Pixel size from the geo-transform; the *_low_res values describe one preview pixel.
ds = gdal.Open(image_file)
gt = ds.GetGeoTransform()
lon_meter_per_pix, lat_meter_per_pix = general_funcs.meter_per_pix(gt)
lon_meter_per_pix_low_res = lon_meter_per_pix * scale_factor
lat_meter_per_pix_low_res = lat_meter_per_pix * scale_factor
print('Meters per pixel (lon, lat):', lon_meter_per_pix, lat_meter_per_pix)

# Drop the raw image array; the extracted layers are what the rest of the notebook uses.
# NOTE(review): if extract_layers returns views of img, its memory stays referenced anyway.
del img
gc.collect()


In [None]:
# Show Image

def line_select_callback(eclick, erelease):
    """RectangleSelector callback; intentionally a no-op.

    The selected area is read later from ``rs.corners`` when "Save Area" is clicked,
    so nothing needs to happen while the rectangle is being drawn.
    """

def save_area(button):
    """Fix the area of interest from the rectangle selector and switch to plot drawing.

    Click handler of the "Save Area" button. It reads the rectangle corners from ``rs``,
    stores the area in full-resolution pixels (``interested_area``, upper-left corner
    ``ul_x``/``ul_y``), swaps the widget panel to the plot-editing controls, shows the
    cropped image and starts a ``PolygonSelector`` (``ps``) that reports to ``onselect``.
    Enter in the figure then saves a plot via ``save_when_enter_pressed``.
    """
    global interested_area, ps, ul_x, ul_y

    rs.set_active(False)
    # FigureCanvasBase.set_window_title was removed in matplotlib 3.6; the manager
    # method works on old and new versions.
    if fig.canvas.manager is not None:
        fig.canvas.manager.set_window_title('Set Plots')

    # rs.corners is (xs, ys) for four corners; column 0 is used as the upper-left and
    # column 2 as the bottom-right corner of the selection.
    corners = np.asarray(rs.corners).astype(int)
    if d == 2:
        # The single-band IR layer is displayed at full resolution.
        interested_area = corners
    else:
        # Other images are displayed as the low-res preview (same rule as onselect).
        interested_area_low_res = corners
        interested_area = interested_area_low_res * scale_factor
        ul_x_low_res, ul_y_low_res = interested_area_low_res[:, 0]
        br_x_low_res, br_y_low_res = interested_area_low_res[:, 2]
    ul_x, ul_y = interested_area[:, 0]
    br_x, br_y = interested_area[:, 2]

    # Replace the "Save Area" button with the plot-editing controls.
    button_set1 = widgets.HBox(children=[button_load, button_save_plot, button_add_grid, button_remove, button_done], layout=box_layout)
    all_widgets.children = [textbox_plot_name, textbox_notes, textlabel_info, textbox_info, button_set1]

    # Show only the selected area from now on.
    if d == 2:
        ax.imshow(layer_IR[ul_y:br_y, ul_x:br_x], cmap='gist_gray', vmin=vmin, vmax=vmax)
    else:
        ax.imshow(layer_RGB_low_res[ul_y_low_res:br_y_low_res, ul_x_low_res:br_x_low_res])
    ps = PolygonSelector(ax, onselect, useblit=True)
    # Registered once: the "Save Area" button is no longer in the panel after this call.
    fig.canvas.mpl_connect('key_press_event', save_when_enter_pressed)
    
    
def low_res(button):
    """Click handler for the "Low Resolution: On" button; not implemented yet (no-op)."""
    
    
    
def onselect(vert):
    """PolygonSelector callback: record the vertices of the polygon just drawn.

    ``vert`` holds the polygon vertices in the displayed image's coordinates.
    When ``d != 2`` the display is the low-resolution preview, so the raw vertices
    are kept in ``one_plot_vertices_low_res`` and scaled by ``scale_factor`` into
    full-resolution pixels. ``one_plot_vertices`` is always rounded to whole pixels.
    """
    global one_plot_vertices_low_res, one_plot_vertices
    drawn = np.array(vert)
    if d == 2:
        one_plot_vertices = np.round(drawn)
        return
    one_plot_vertices_low_res = drawn
    one_plot_vertices = np.round(drawn * scale_factor)
    
    
def load_plot_name_from_file(button):
    """Load an ordered list of plot names from a CSV file and pre-fill the first one.

    Click handler of the "Load plot name From File" button. The CSV has no header row.
    Its cells are read column by column (Fortran order), empty cells are skipped, and
    every name is kept as the literal text from the file. It sets the globals
    ``plot_name_loaded`` (array of names) and ``load_plot_name_from_file_flag``, which
    ``save_plot`` uses to advance to the next name after each save.
    """
    global plot_name_loaded, load_plot_name_from_file_flag
    plot_name_file = askopenfilename(title='Load plot name file', initialdir='./data/plot_name_csv')
    if not plot_name_file:
        # Dialog cancelled: keep whatever names were loaded before.
        return
    # dtype=str / keep_default_na=False keep names such as '007' or 'NA' verbatim;
    # empty cells of ragged rows come back as '' and are dropped instead of becoming 'nan'.
    cells = pd.read_csv(plot_name_file, header=None, dtype=str, keep_default_na=False).to_numpy().flatten('F')
    names = cells[cells != '']
    if names.size == 0:
        textbox_info.value = 'Failed! No plot names found in ' + plot_name_file
        return
    plot_name_loaded = names
    textbox_plot_name.value = plot_name_loaded[0]
    load_plot_name_from_file_flag = True
    
    
def save_plot(button):
    """Store the currently selected polygon under the name typed in the plot-name box.

    Used by the "Save Plot" button, by ``textbox_plot_name.on_submit`` and by
    ``save_when_enter_pressed``; the argument is ignored.

    On success the vertices go into ``plot_vertices``, the notes into ``plot_notes``,
    and the polygon plus its name are drawn on ``ax``. A fresh ``PolygonSelector`` is
    then started. If names were loaded from a file, the name box advances to the next
    loaded name (it is left unchanged after the last one). Failures are reported in
    ``textbox_info``.
    """
    global one_plot_vertices, one_plot_vertices_low_res, ps
    plot_name = textbox_plot_name.value
    if plot_name == '':
        textbox_info.value = 'Failed! Please input plot name!'
        return
    if plot_name in plot_vertices:
        textbox_info.value = 'Failed! Plot already EXISTS!'
        return
    if one_plot_vertices is None:
        textbox_info.value = 'Failed! Please select vertices!'
        return

    plot_vertices[plot_name] = one_plot_vertices
    plot_notes[plot_name] = textbox_notes.value

    # Draw in display coordinates: the full-res IR layer when d == 2, otherwise the
    # low-res preview (same rule as onselect).
    display_vertices = one_plot_vertices if d == 2 else one_plot_vertices_low_res
    polygon = patches.Polygon(display_vertices, closed=True,
                              facecolor=matplotlib.colors.to_rgba('red', 0.1),
                              edgecolor=matplotlib.colors.to_rgba('orange', 0.5))
    text_loc = np.mean(display_vertices, 0)
    ax.add_patch(polygon)
    ax.text(text_loc[0], text_loc[1], plot_name, ha='center', va='center')

    # Restart the selector so the next polygon starts from scratch.
    ps.set_active(False)
    ps = PolygonSelector(ax, onselect, useblit=True)
    fig.canvas.draw()

    one_plot_vertices = None
    textbox_info.value = 'Success! Plot ' + plot_name + ' saved!'
    if load_plot_name_from_file_flag:
        matches = np.flatnonzero(plot_name_loaded == plot_name)
        if matches.size and matches[0] + 1 < len(plot_name_loaded):
            textbox_plot_name.value = plot_name_loaded[matches[0] + 1]
        elif matches.size:
            # Last loaded name used: there is no next name to pre-fill.
            textbox_info.value += ' (last name in the loaded list)'
        
        
# def modify_plot(button):
#     global one_plot_vertices, ax
#     if one_plot_vertices is None:
#         textbox_info.value = 'Failed! Please select vertices!'
#         return
    
#     elif not textbox_plot_name.value in plot_vertices.keys():
#         textbox_info.value = 'Failed! Plot ' + textbox_plot_name.value + ' does not exist!'
#     else:
#         keys = list(plot_vertices.keys())
#         patch_ind = keys.index(textbox_plot_name.value)
#         patch = ax.patches[patch_ind]
#         patch.remove()
#         polygon = patches.Polygon(one_plot_vertices, True, facecolor = matplotlib.colors.to_rgba('red', 0.1), edgecolor=matplotlib.colors.to_rgba('orange', 0.5))
#         ax.add_patch(polygon)
#         text_loc = np.mean(one_plot_vertices, 0)
#         text_patch = ax.texts[patch_ind]
#         text_patch.remove()
#         ax.text(text_loc[0], text_loc[1], textbox_plot_name.value, ha='center', va='center')
        
        
#         plot_vertices.pop(textbox_plot_name.value)
#         plot_notes.pop(textbox_plot_name.value)
#         plot_vertices[textbox_plot_name.value] = one_plot_vertices
#         plot_notes[textbox_plot_name.value] = textbox_notes.value
#         textbox_info.value = 'Success! Plot ' + textbox_plot_name.value + ' modified!'
        
#         one_plot_vertices = None
        
def remove_plot(button):
    """Delete the plot named in the plot-name box from the figure and from the stored dicts.

    Click handler of the "Remove Plot" button. Assumes ``ax.patches`` and ``ax.texts``
    hold exactly one polygon and one label per saved plot, in the insertion order of
    ``plot_vertices`` (``save_plot`` adds one of each per plot) — the artists are
    located by that position. The result is reported in ``textbox_info``.
    """
    plot_name = textbox_plot_name.value
    if plot_name not in plot_vertices:
        textbox_info.value = 'Failed! Plot ' + plot_name + ' does not exist!'
        return

    patch_ind = list(plot_vertices).index(plot_name)
    ax.patches[patch_ind].remove()
    ax.texts[patch_ind].remove()

    del plot_vertices[plot_name]
    del plot_notes[plot_name]
    # Redraw so the removed polygon disappears immediately, not on the next key press.
    fig.canvas.draw()
    textbox_info.value = 'Success! Plot ' + plot_name + ' removed!'
        
def save_plot_loc_to_file(button):
    """Convert all saved plots to geo-coordinates and pickle them to a user-chosen file.

    Click handler of the "Done" button. Three objects are pickled, in this order:
    ``interested_area`` (corner pixels of the selected area), a dict mapping each plot
    name to an array with one ``general_funcs.pix2geo`` result per vertex, and
    ``plot_notes``. Nothing is written if the save dialog is cancelled. Write errors
    are shown in a Tk error dialog.
    """
    gt = gdal.Open(image_file).GetGeoTransform()

    # Plot vertices are relative to the selected area; shift them by its upper-left
    # corner to get pixel coordinates in the full image before converting.
    offset = np.array([ul_x, ul_y])
    plot_vertices_gps = {
        plot_name: np.array([general_funcs.pix2geo(vertex + offset, gt) for vertex in vertices])
        for plot_name, vertices in plot_vertices.items()
    }

    # Default file name: the image file name without directory and (last) extension,
    # so dotted names like "field.2020.tif" are not truncated.
    default_name = os.path.splitext(os.path.basename(image_file))[0]
    file_name = asksaveasfilename(filetypes=[('pickle', '*.pkl')], title='Save plot locations',
                                  initialfile=default_name + '_plot_loc', initialdir='./data/plot_location')
    if not file_name:
        return
    if not file_name.endswith('.pkl'):
        file_name += '.pkl'

    try:
        with open(file_name, 'wb') as f:
            pickle.dump(interested_area, f)
            pickle.dump(plot_vertices_gps, f)
            pickle.dump(plot_notes, f)
    except Exception as e:
        showerror(type(e).__name__, str(e))
        return
    print('GPS coordinates saved to', file_name)
        
        
def add_grid(button):
    """Toggle a metric grid on the image: major lines every 5 m, minor lines every 1 m.

    Click handler of the grid button; the button label doubles as the toggle state
    ('Add Grid (1 meter)' <-> 'Remove Grid'). The figure is redrawn afterwards so the
    change is visible immediately.
    """
    if button.description == 'Add Grid (1 meter)':
        # The IR layer (d == 2) is shown at full resolution, other images as the
        # low-res preview, so use the pixel size of what is actually displayed.
        meter_per_pix = lon_meter_per_pix if d == 2 else lon_meter_per_pix_low_res
        # NOTE(review): the x pixel size is used for both axes, i.e. square pixels are assumed.
        five_meter_in_pix = 5 / meter_per_pix
        for axis in (ax.xaxis, ax.yaxis):
            axis.set_major_locator(MultipleLocator(five_meter_in_pix))
            # 5 minor intervals per 5 m major interval -> 1 m minor grid.
            axis.set_minor_locator(AutoMinorLocator(5))
        ax.grid(which='minor', color='red', alpha=0.5)
        ax.grid(which='major', color='red', alpha=1)
        button.description = 'Remove Grid'
    elif button.description == 'Remove Grid':
        ax.grid(False, which='both')
        button.description = 'Add Grid (1 meter)'
    fig.canvas.draw()
    
    
    
def save_when_enter_pressed(event):
    """Matplotlib key-press handler for the figure.

    Pressing Enter saves the polygon currently drawn (same as the "Save Plot" button).
    Every key press redraws the canvas afterwards.
    """
    enter_pressed = (event.key == 'enter')
    if enter_pressed:
        save_plot(event)
    canvas = fig.canvas
    canvas.draw()


# State shared by the callbacks above.
plot_vertices = {}  # plot name -> vertex array (full-res pixels, relative to the selected area)
plot_notes = {}  # plot name -> free-text notes
one_plot_vertices = None  # polygon currently drawn; set by onselect, cleared by save_plot
load_plot_name_from_file_flag = False  # True once plot names were loaded from a CSV

# Text widgets
text_layout = widgets.Layout(width='95%')
textbox_plot_name = widgets.Text(placeholder='Input plot name here', layout=text_layout)
textbox_notes = widgets.Textarea(placeholder='Put plot notes here', layout=widgets.Layout(height='50px', width='100%'))
textlabel_info = widgets.Label(value='ERROR/Info will be shown below', layout=text_layout)
textbox_info = widgets.Label(layout=text_layout)

# Button widgets
button_layout = widgets.Layout(width='auto')
button_save_area = widgets.Button(description="Save Area", layout=button_layout)
button_low_res = widgets.Button(description="Low Resolution: On", layout=button_layout)
button_load = widgets.Button(description="Load plot name From File", layout=button_layout)
button_save_plot = widgets.Button(description="Save Plot (Press Enter)", layout=button_layout)
button_remove = widgets.Button(description='Remove Plot', layout=button_layout, button_style='danger')
button_done = widgets.Button(description='Done', layout=button_layout, button_style='success')
button_add_grid = widgets.Button(description='Add Grid (1 meter)', layout=button_layout)

button_save_area.on_click(save_area)
button_low_res.on_click(low_res)
button_load.on_click(load_plot_name_from_file)
button_save_plot.on_click(save_plot)
button_remove.on_click(remove_plot)
button_done.on_click(save_plot_loc_to_file)
button_add_grid.on_click(add_grid)
# NOTE(review): Text.on_submit is deprecated in newer ipywidgets; a plain `observe` on
# 'value' would also fire when save_plot pre-fills the next name, so it is kept.
textbox_plot_name.on_submit(save_plot)

# Box widgets: initially only "Save Area"; save_area swaps in the plot-editing controls.
box_layout = widgets.Layout(width='auto')
all_widgets = widgets.VBox(children=[button_save_area], layout=box_layout)
display(all_widgets)

out = widgets.Output()
display(out)
with out:
    fig_size = general_funcs.fig_size(h, w)
    plt.close('all')
    fig, ax = plt.subplots(figsize=fig_size)
    # FigureCanvasBase.set_window_title was removed in matplotlib 3.6.
    if fig.canvas.manager is not None:
        fig.canvas.manager.set_window_title('Select area of interest')
    plt.tight_layout()
    if d == 2:
        # Single-band IR layer shown at full resolution with contrast limits from the mask.
        mask_not_0_inds = np.where(layer_mask > 0)
        vmin, vmax = general_funcs.cal_vmin_vmax(layer_IR, layer_mask)
        myax = ax.imshow(layer_IR, cmap='gist_gray', vmin=vmin, vmax=vmax)
    else:
        # Other images are shown as the low-res RGB preview built in the first cell.
        myax = ax.imshow(layer_RGB_low_res)

    # NOTE(review): `drawtype`/`rectprops` are deprecated in matplotlib 3.5 (use `props`
    # there); kept for compatibility with the matplotlib version this was written for.
    rs = RectangleSelector(ax, line_select_callback, drawtype='box', useblit=True, button=[1], minspanx=50, minspany=50,
                           rectprops=dict(facecolor='red', edgecolor='black', alpha=0.1, fill=True), spancoords='pixels', interactive=True)
    # NOTE(review): presumably removes the selector's own rectangle patch so that
    # ax.patches only holds plot polygons (remove_plot indexes ax.patches by plot order) — confirm.
    ax.patches[0].remove()

    plt.show()