# In [None]:
import asyncio
import functools
from time import time

import numpy as np
from ipycanvas import MultiCanvas, hold_canvas
from ipyevents import Event
from ipywidgets import Image
from ipywidgets import ColorPicker, IntSlider, link, AppLayout, HBox, VBox, HTML, Label

def throttle(wait):
    """Decorator factory: run the wrapped function at most once every ``wait`` seconds.

    The first call schedules ``fn`` via ``Timer`` after whatever remains of the
    ``wait`` window since the previous run. Calls that arrive while a run is
    already scheduled only replace the pending arguments, so the scheduled run
    uses the most recent ``*args``/``**kwargs``.

    The wrapper itself returns ``None``; ``fn``'s return value is discarded.
    """
    def decorator(fn):
        time_of_last_call = 0
        scheduled = False
        new_args, new_kwargs = (), {}

        def call_it():
            nonlocal time_of_last_call, scheduled
            time_of_last_call = time()
            try:
                fn(*new_args, **new_kwargs)
            finally:
                # Always re-arm. Without this, an exception in fn would leave
                # `scheduled` True forever and the wrapper would never fire again.
                scheduled = False

        @functools.wraps(fn)
        def throttled(*args, **kwargs):
            nonlocal new_args, new_kwargs, scheduled
            # Remember the latest arguments; a pending run will pick them up.
            new_args, new_kwargs = args, kwargs
            if not scheduled:
                remaining = max(0, wait - (time() - time_of_last_call))
                Timer(remaining, call_it)
                scheduled = True

        return throttled
    return decorator

class Timer:
    """One-shot asyncio timer that invokes ``callback()`` after ``timeout`` seconds.

    The countdown task is scheduled with ``asyncio.ensure_future`` as soon as the
    instance is constructed, so an asyncio event loop must be available
    (presumably the notebook kernel's loop here). Call ``cancel()`` to stop a
    pending countdown.
    """

    def __init__(self, timeout, callback):
        self._timeout, self._callback = timeout, callback
        job = self._job()
        self._task = asyncio.ensure_future(job)

    async def _job(self):
        """Sleep for the configured delay, then fire the callback."""
        delay = self._timeout
        await asyncio.sleep(delay)
        self._callback()

    def cancel(self):
        """Cancel the underlying task; harmless if it has already completed."""
        task = self._task
        task.cancel()

def debounce(wait):
    """Decorator factory: postpone the wrapped function until ``wait`` seconds
    have passed since the most recent call.

    Every call cancels the still-pending ``Timer`` (if any) and starts a new one
    bound to that call's arguments. Only the last call of a burst therefore
    reaches ``fn``. The wrapper returns ``None``; ``fn``'s result is discarded.
    """
    def decorator(fn):
        timer = None

        @functools.wraps(fn)
        def debounced(*args, **kwargs):
            nonlocal timer

            def call_it():
                fn(*args, **kwargs)

            if timer is not None:
                # Cancelling an already-finished timer is a no-op.
                timer.cancel()
            timer = Timer(wait, call_it)

        return debounced
    return decorator


def get_pixel_row_col(x, y, size=None):
    """Return the integer grid-cell indices ``(x // size, y // size)`` for point (x, y).

    Despite the "row, col" name, the first value comes from ``x`` (horizontal
    cell index) and the second from ``y`` (vertical cell index).

    Both values are converted to ``int``. Mouse coordinates can arrive as floats,
    and ``float // int`` yields a float, which NumPy refuses as an array index.

    Args:
        x, y: canvas coordinates.
        size: cell edge length; defaults to the module-level ``pixel_size``.
    """
    if size is None:
        size = pixel_size
    return int(x // size), int(y // size)


def get_pixel_boundaries(x,y):
    """Return ``(x1, y1, x2, y2)``: top-left and bottom-right corners of the grid
    cell containing canvas point (x, y).

    The cell size is the module-level ``pixel_size``, so ``x2 - x1 == y2 - y1 == pixel_size``.
    """
    cell_x, cell_y = get_pixel_row_col(x, y)
    left = cell_x * pixel_size
    top = cell_y * pixel_size
    return left, top, left + pixel_size, top + pixel_size


def plot_grid(layer, pixel_size):
    """Draw grid lines on ``layer`` spaced ``pixel_size`` apart.

    Vertical lines are drawn at x = k * pixel_size and horizontal lines at
    y = k * pixel_size, for k = 1 .. (dimension // pixel_size) - 1. The k = 0
    edge is never drawn. All lines are batched inside ``hold_canvas``.
    """
    columns = range(1, layer.width // pixel_size)
    rows = range(1, layer.height // pixel_size)
    with hold_canvas(layer):
        for k in columns:
            x_pos = k * pixel_size
            draw_line(layer, (x_pos, 0), (x_pos, layer.height))
        for k in rows:
            y_pos = k * pixel_size
            draw_line(layer, (0, y_pos), (layer.width, y_pos))
        
    
def draw_line(layer, start, end):
    """Stroke one straight segment from ``start`` to ``end`` on ``layer``.

    ``start`` and ``end`` are indexable ``(x, y)`` pairs. The stroke width is
    fixed at 0.4 and stays set on ``layer`` after the call.
    """
    x_from, y_from = start[0], start[1]
    x_to, y_to = end[0], end[1]
    layer.line_width = 0.4
    layer.begin_path()
    layer.move_to(x_from, y_from)
    layer.line_to(x_to, y_to)
    layer.stroke()
    # Closing after the stroke adds nothing that gets painted; kept for parity.
    layer.close_path()
    

def draw_coords(canvas, x,y):
    """Write the corner coordinates of the grid cell under (x, y) as text at (x, y).

    The label has the form ``(x1,y1), (x2,y2)`` (truncated to ints) and is drawn
    with a soft drop shadow. The shadow settings remain on ``canvas`` afterwards.
    """
    shadow = (
        ('shadow_offset_x', 2),
        ('shadow_offset_y', 2),
        ('shadow_blur', 3),
        ('shadow_color', "rgba(44,44,44,0.3)"),
    )
    for attr, value in shadow:
        setattr(canvas, attr, value)
    left, top, right, bottom = get_pixel_boundaries(x, y)
    label = '({0},{1}), ({2},{3})'.format(int(left), int(top), int(right), int(bottom))
    canvas.fill_text(label, x, y)

    
def highlight_pixel(layer, x,y):
    """Outline the grid cell containing canvas point (x, y) on ``layer``.

    The rectangle is ``pixel_size`` square and uses the layer's current stroke
    style. The cell state in ``data`` is not read or modified.
    """
    x1, y1, _, _ = get_pixel_boundaries(x, y)
    layer.stroke_rect(x1, y1, pixel_size, pixel_size)
    
    
def switch_pixel(layer, x, y, state='On'):
    """Turn the grid cell under canvas point (x, y) on or off.

    Stores 1 (``state == 'On'``) or 0 (any other ``state``) in the module-level
    ``data`` array, then fills or clears that cell on ``layer``. Points that fall
    outside the grid are ignored.

    ``data`` is created with shape ``(height // pixel_size, width // pixel_size)``,
    i.e. (rows, cols). It is therefore indexed as ``data[row, col]`` with the row
    taken from ``y`` and the column from ``x``.
    """
    # int(): mouse coordinates may be floats, and NumPy rejects float indices.
    col = int(x // pixel_size)
    row = int(y // pixel_size)

    n_rows, n_cols = data.shape
    # Without this guard, a negative index would silently wrap to the opposite
    # edge and an index past the edge would raise IndexError.
    if not (0 <= row < n_rows and 0 <= col < n_cols):
        return

    turn_on = state == 'On'
    data[row, col] = 1 if turn_on else 0

    x1, y1 = col * pixel_size, row * pixel_size
    with hold_canvas(layer):
        if turn_on:
            layer.fill_rect(x1, y1, pixel_size, pixel_size)
        else:
            layer.clear_rect(x1, y1, pixel_size, pixel_size)
        
    
@debounce(0.009)
@throttle(0.009)
def on_mouse_move(x, y):
    """Mouse-move handler.

    Paints (``drawing``) or erases (``erasing``) the cell under the cursor on
    ``interaction_layer``. It then clears ``drawing_layer`` and redraws the
    hover outline around the current cell. ``draw_coords(drawing_layer, x, y)``
    can be added inside the hold block to show cell coordinates while debugging.

    NOTE(review): debounce wraps throttle with the same 0.009 s wait. Every event
    arriving within 9 ms of the previous one cancels the pending call, so in a
    fast burst only the burst's last event reaches this body, which may skip
    cells while dragging. Confirm whether throttle alone was intended.
    """
    for active, state in ((drawing, 'On'), (erasing, 'Off')):
        if active:
            switch_pixel(interaction_layer, x, y, state)

    with hold_canvas(canvas):
        drawing_layer.clear()
        highlight_pixel(drawing_layer, x, y)
        
    
def on_mouse_down(x, y):
    """Mouse-down handler: switch drawing mode on; ``x``/``y`` are unused.

    NOTE(review): not registered as a callback anywhere in this cell.
    """
    global drawing
    drawing = True
    
        
def on_mouse_up(x, y):
    """Mouse-up handler: switch drawing mode off; ``x``/``y`` are unused.

    NOTE(review): not registered as a callback anywhere in this cell.
    """
    global drawing
    drawing = False
    
    
# keyboard modifiers handler
def handle_keyboard(event):
    """Keyboard handler: select the editing mode from the modifier keys.

    Registered for ``'keydown'``/``'keyup'`` via ipyevents.
    - Shift held -> drawing.
    - Ctrl held without Shift -> erasing.
    - Otherwise -> neither.

    Updates the module globals ``key_states``, ``drawing`` and ``erasing``, and
    shows the active mode in the status label ``l``.

    Args:
        event: DOM event dict. Only entries whose key ends with ``'Key'`` are
            inspected (presumably the modifier flags such as ``shiftKey`` and
            ``ctrlKey``).
    """
    global key_states, drawing, erasing

    # Snapshot all *Key flags so other code can inspect them.
    key_states = {k: v for k, v in event.items() if k.endswith('Key')}

    shift = bool(event.get('shiftKey', False))
    ctrl = bool(event.get('ctrlKey', False))
    drawing = shift
    # Shift takes precedence, so drawing and erasing are never both True.
    erasing = ctrl and not shift

    if drawing:
        l.value = 'MODE: Drawing (Shift pressed)'
    elif erasing:
        l.value = 'MODE: Erasing (Ctrl pressed)'
    else:
        l.value = l.default_text


# --- App setup: state, layered canvas, event wiring and layout ---------------

# Mode flags read by on_mouse_move; set by handle_keyboard (and by
# on_mouse_down/on_mouse_up, which are not registered in this cell).
drawing = False
erasing = False
height, width = 320, 320
pixel_size = 40
# NOTE(review): not read by any code in this cell; draw_line hard-codes 0.4.
line_width = 0.4

# Three stacked layers: [0] grid, [1] hover outline, [2] painted cells.
canvas = MultiCanvas(3, width=width, height = height)
# One int per grid cell (0 = off, 1 = on); shape (height // pixel_size, width // pixel_size).
data = np.zeros((height//pixel_size, width//pixel_size), dtype=int)
key_states = {'is':'empty'}   # placeholder until handle_keyboard records real key flags

# Static grid drawn once on the bottom layer.
background_layer = canvas[0]
plot_grid(background_layer,pixel_size)

# Layer cleared and redrawn on every mouse move (hover highlight / optional text).
drawing_layer = canvas[1]
drawing_layer.font = '14px "serif"'

# Top layer: painted cells; mouse-move events are registered on it.
interaction_layer = canvas[2]
interaction_layer.on_mouse_move(on_mouse_move)

# Disabled alternative: mouse handling via ipyevents (handle_mouse is not defined in this cell).
#mouse = Event(source = canvas, watched_events=['mousedown', 'mouseup', 'mousemove'])
#mouse.on_dom_event(handle_mouse)

# Route keydown/keyup DOM events on the canvas to handle_keyboard.
keyboard = Event(source = canvas, watched_events=['keydown','keyup'])
keyboard.on_dom_event(handle_keyboard)

# Status label showing the current mode.
l = HTML(value='MODE: Drawing disabled (Press Shift or Ctrl)')
l.default_text = l.value # remember the idle-mode text so handle_keyboard can restore it
#---------------------------------------
# Layout: label above the canvas (presumably displayed as the cell's output).
AppLayout(center=VBox((l, canvas)))