In [112]:
import threading as th
import ipywidgets as w
import multiprocessing as mp
import traitlets as tr
import time
from enum import Enum

In [113]:
# Bordered area that collects dprint() debug messages, plus a button that wipes it.
debug_output_widget = w.Output(layout=w.Layout(border='solid'))
debug_clear_button = w.Button(description="clear debug")


def _clear_debug_output(_button):
    """Click handler: erase everything currently shown in the debug area."""
    debug_output_widget.clear_output()


debug_clear_button.on_click(_clear_debug_output)
def dprint(*args, sep=" ", end="\n", **_unused_print_kwargs):
    """print()-like helper that appends a debug message to ``debug_output_widget``.

    This is called from worker threads. Capturing stdout with ``with widget:``
    is not reliable off the main thread: output can land in whatever cell is
    currently executing. So the text is formatted here and appended with
    ``append_stdout`` instead.

    Positional args are converted with ``str()`` and joined by ``sep``, then
    ``end`` is appended, as ``print()`` does (``None`` selects the print default).
    Other ``print()`` keywords (e.g. ``flush``, ``file``) are accepted for
    compatibility and ignored; output always goes to the debug widget.
    """
    if sep is None:
        sep = " "
    if end is None:
        end = "\n"
    debug_output_widget.append_stdout(sep.join(str(arg) for arg in args) + end)

class Output:
    """
    Multi-modal result of a node computation.

    A node function may return an ``Output`` instead of a bare value. This
    separates what downstream nodes receive (``value``) from what is drawn in
    the node's display widget (``display``).
    """

    def __init__(self, value=None, display=None):
        # What downstream nodes see as this node's value.
        self.value = value
        # What is rendered in the node's display widget.
        self.display = display

In [114]:
class Pending:
    """Sentinel meaning "value not available yet".

    The class object itself is used as the marker (it is compared by identity
    and never instantiated here). A Node's ``value`` holds it while inputs are
    missing or a computation is outstanding.
    """

In [115]:
class Supervisor:
    """Schedules ready nodes onto a pool of worker threads.

    Nodes announce themselves with ``handle_node_to_ready`` when they need a
    (re)computation, and withdraw with ``handle_node_from_ready`` when that
    computation becomes obsolete. Workers block in ``get_work`` until a node is
    available and report results through ``return_work``.

    Locking: ``ready_nodes_cond`` guards ``ready_nodes``; ``work_lock`` guards
    ``work``. When both are needed, ``ready_nodes_cond`` is acquired first.
    """

    def __init__(self, num_workers=2):
        """Create the scheduler and start ``num_workers`` worker threads."""
        # Nodes waiting for a worker (the "queue" workers consume from).
        self.ready_nodes = set()

        # Condition used to wake consumers when ready_nodes gets a node.
        self.ready_nodes_cond = th.Condition()

        # Guards self.work.
        self.work_lock = th.Lock()

        # id(node) -> Worker currently computing that node.
        self.work = {}

        self.start_worker_threads(num_workers)

    def handle_node_to_ready(self, node):
        """Queue ``node`` for a new computation and wake one waiting worker."""
        with self.ready_nodes_cond:
            self.ready_nodes.add(node)
            self.ready_nodes_cond.notify()

    def handle_node_from_ready(self, node):
        """Cancel all activity for ``node``, which is leaving the ready state.

        If the node is still queued, it is simply removed from the queue.
        Otherwise, if a worker is computing it, that worker is interrupted and the
        node is unregistered from ``work``, so a late ``return_work`` is dropped.
        """
        dprint("Canceling work...")

        with self.ready_nodes_cond:
            dprint("  ready_nodes lock obtained...")
            if node in self.ready_nodes:
                dprint("  Node discovered in ready_nodes; deleting.")
                self.ready_nodes.remove(node)
                return

            with self.work_lock:
                dprint("  Work lock obtained")
                # `work` is keyed by id(node), not by the node object itself.
                curr_worker = self.work.pop(id(node), None)
                if curr_worker is not None:
                    dprint("  Node discovered in work set; interrupting and deleting.")
                    curr_worker.interrupt()

    def start_worker_threads(self, num_workers=2):
        """Start ``num_workers`` Worker threads that consume from ``ready_nodes``."""
        self.workers = set()
        for number in range(1, num_workers + 1):
            worker = Worker(self)
            dprint(f"Worker {number}:", id(worker))
            self.workers.add(worker)

    def get_work(self, worker):
        """Block until a node is ready, assign it to ``worker`` and return it."""
        with self.ready_nodes_cond:
            while not self.ready_nodes:
                self.ready_nodes_cond.wait()

            node = self.ready_nodes.pop()

            # Record ownership under work_lock, using the same lock order as
            # handle_node_from_ready (ready_nodes_cond, then work_lock).
            with self.work_lock:
                self.work[id(node)] = worker

            return node

    def return_work(self, worker, node, output):
        """Deliver ``output`` to ``node`` only if ``worker`` still owns the node.

        The ownership entry may have been removed by a cancellation, or
        reassigned to another worker; in both cases the result is dropped.
        """
        with self.work_lock:
            if self.work.get(id(node)) is not worker:
                return
            del self.work[id(node)]

        node.handle_computation_end(output)
        

In [116]:
class Worker:
    """Background thread that computes nodes handed out by a supervisor.

    Each task runs ``node.f(*[arg.value for arg in node.args])`` and passes the
    result to ``supervisor.return_work``. ``interrupt()`` marks the current task
    as stale, so its result is discarded instead of returned.
    """

    def __init__(self, supervisor):
        self.supervisor = supervisor
        # Must exist before the thread starts, because consume() reads it.
        self.stale = False
        # Daemon thread: consume() never returns, and a non-daemon thread would
        # keep the interpreter from exiting.
        self.consumer_thread = th.Thread(target=self.consume, daemon=True)
        self.consumer_thread.start()

    def interrupt(self):
        """Mark the task currently being computed as obsolete."""
        self.stale = True
        dprint("Worker", id(self), "has become stale")

    def consume(self):
        """Consume tasks forever."""
        while True:
            # Clear staleness left over from the previous task, e.g. from an
            # interrupt that arrived after that task's stale check. New interrupts
            # can only target this worker once get_work has registered it.
            self.stale = False

            # Blocks until the supervisor has a ready node for us.
            node = self.supervisor.get_work(self)

            dprint("Worker", id(self), "picked up task from node", id(node))

            # straight-up run the function--in the future we'll subprocessify this
            try:
                inputs = [arg.value for arg in node.args]
                output = node.f(*inputs)
            except Exception as exc:
                # Keep the worker alive. Return the exception as the result so the
                # supervisor releases the node and the failure is visible.
                dprint("Worker", id(self), "failed on node", id(node), "with", repr(exc))
                output = exc

            if self.stale:
                dprint("Result from worker", id(self), "discarded due to staleness")
                continue

            dprint("Result from worker", id(self), "returned")
            self.supervisor.return_work(self, node, output)

In [117]:
class Node(tr.HasTraits):
    """Reactive dataflow node.

    A node observes the ``value`` trait of each of its ``args`` (other Nodes or
    ipywidgets). When they change, it schedules
    ``f(*[arg.value for arg in args])`` on the module-level ``supervisor``.
    Finished results arrive through ``handle_computation_end``. While no result
    is available, ``value`` is the ``Pending`` sentinel.
    """

    # Result exposed to downstream nodes, which observe it like a widget value.
    value = tr.Any()

    @tr.default("value")
    def _set_default_value(self):
        return Pending

    class State(Enum):
        DONE = 0     # holds a finished or reused result (or the initial default)
        PENDING = 1  # at least one input is Pending
        READY = 2    # handed to the supervisor; computation queued or running

    def __init__(self, args=None, f=lambda: None, display_widget=None):
        """Create a node and subscribe it to its inputs.

        Parameters
        ----------
        args : list of Node or ipywidgets.DOMWidget, optional
            Inputs whose ``value`` traits are passed positionally to ``f``.
            ``None`` means no inputs.
        f : callable
            Function computed from the input values (run by a worker thread).
        display_widget : ipywidgets.Output, optional
            If given, every finished result is rendered into it.

        Raises
        ------
        TypeError
            If an element of ``args`` is neither a Node nor a DOMWidget.
        """
        super().__init__()
        dprint("Created Node with ID", id(self))
        if args is None:
            args = []

        # Validate everything first so a bad argument cannot leave the node
        # subscribed to only some of its inputs.
        for arg in args:
            if not isinstance(arg, (Node, w.DOMWidget)):
                raise TypeError(
                    "Node args must be Node or DOMWidget instances, "
                    f"got {type(arg).__name__}"
                )
        for arg in args:
            # subscribe to the input's value traitlet
            arg.observe(self.handle_parent_change, names=["value"])

        self.args = args
        self.f = f

        if display_widget:
            assert isinstance(display_widget, w.widgets.widget_output.Output)

        self.display_widget = display_widget
        # Inputs and result of the last completed computation. Inputs start as
        # None because nothing has been computed yet, so the initial inputs are
        # never treated as already having a cached result.
        self.old_inputs = None
        self.old_value = None
        # Inputs of the computation most recently handed to the supervisor.
        self._requested_inputs = None
        self.state = Node.State.DONE

    def handle_parent_change(self, change):
        """Observer for the inputs' ``value`` traits; drives the state machine.

        Transitions (node side : supervisor side):

            ready   -> ready   : value <- Pending    : cancel, then queue
            ready   -> done    : value <- old_value  : cancel
            ready   -> pending : value <- Pending    : cancel
            done    -> ready   : value <- Pending    : queue
            done    -> done    : value <- old_value  : -
            done    -> pending : value <- Pending    : -
            pending -> ready   : value <- Pending    : queue
            pending -> done    : value <- old_value  : -
            pending -> pending : value <- Pending    : -

        "done" is chosen when every input is available and the inputs equal
        those of the last completed computation, whose result is reused.
        """
        # TODO: use change.owner / change.new to avoid re-reading every input.
        current_inputs = [arg.value for arg in self.args]

        # Leaving READY: the queued or running computation is obsolete either way.
        if self.state == Node.State.READY:
            supervisor.handle_node_from_ready(self)

        if any(item is Pending for item in current_inputs):
            # -> PENDING: some upstream value is not available yet.
            self.value = Pending
            self.state = Node.State.PENDING
        elif current_inputs != self.old_inputs:
            # -> READY. Update our own state *before* queueing: a worker may pick
            # the node up and finish immediately, and its result must not then
            # be overwritten by these assignments.
            self._requested_inputs = current_inputs
            self.value = Pending
            self.state = Node.State.READY
            supervisor.handle_node_to_ready(self)
        else:
            # -> DONE: same inputs as the last completed computation; reuse it.
            self.value = self.old_value
            self.state = Node.State.DONE

    def handle_computation_end(self, value):
        """Store a finished result, remember it for reuse, and redraw the display.

        ``value`` is either a plain result or an ``Output``. For an ``Output``,
        ``.value`` becomes the node value and ``.display`` is what gets drawn.
        """
        # todo--only erase and redraw when output is actually different
        if isinstance(value, Output):
            display_value = value.display
            value = value.value
        else:
            display_value = value

        # Remember which inputs produced this result, so that returning to them
        # later (READY -> DONE) reuses it without recomputing.
        # NOTE(review): this runs on a worker thread without a node-level lock;
        # a return racing with a concurrent input change can still briefly
        # record the newer request's inputs against this result.
        self.old_inputs = self._requested_inputs
        self.old_value = value

        self.value = value
        self.state = Node.State.DONE

        if self.display_widget:
            self.display_widget.clear_output()

            if isinstance(display_value, str):
                self.display_widget.append_stdout(display_value)
            else:
                try:
                    self.display_widget.append_display_data(display_value)
                except Exception:
                    # Not displayable as rich output; fall back to its text form
                    # (append_stdout expects a string).
                    self.display_widget.append_stdout(str(display_value))
                    

In [118]:
# Simulated computation time (seconds) for the demo node functions below.
DEMO_DELAY_S = 2


def f1(a, delay=DEMO_DELAY_S):
    """Demo node function: sleep ``delay`` seconds, then return ``a + a``."""
    time.sleep(delay)
    return a + a


def f2(a, b, delay=DEMO_DELAY_S):
    """Demo node function: return ``str(a) + str(b)`` after sleeping ``delay`` seconds."""
    a = str(a)
    b = str(b)
    time.sleep(delay)
    return a + b

In [119]:
# Two text inputs each feed an f1 node; a third node combines both results
# with f2 and renders into output_widget1.
in_widget1, in_widget2 = w.Text(), w.Text()
output_widget1 = w.Output()

first_node = Node(args=[in_widget1], f=f1)
second_node = Node(args=[in_widget2], f=f1)
third_node = Node(args=[first_node, second_node], f=f2, display_widget=output_widget1)

container_widget = w.VBox(
    [in_widget1, in_widget2, output_widget1, debug_output_widget, debug_clear_button]
)

# Nodes look up the module-level `supervisor` only when an input changes,
# so it can be created after them.
supervisor = Supervisor()

display(container_widget)

VBox(children=(Text(value=''), Text(value=''), Output(), Output(layout=Layout(border='solid')), Button(descrip…

Result from worker 140293903539616 returned
