In [13]:
import multiprocessing as mp
import threading as th
import time
import traceback
from enum import Enum

import ipywidgets as w
import traitlets as tr

In [14]:
class Output:
    """
    Container class for multi-modal node outputs.

    A node function may return an ``Output`` to separate the value passed on
    to downstream nodes (``value``) from what is rendered in the node's
    display widget (``display``).
    """

    def __init__(self, value=None, display=None):
        # what downstream nodes receive
        self.value = value
        # what the node's display widget shows
        self.display = display

In [15]:
class Pending:
    """Sentinel (used as the class object itself) for a value not yet available."""

In [16]:
class Supervisor:
    """Schedules node computations onto a pool of ``Worker`` threads.

    Nodes ready for computation wait in ``ready_nodes``; a worker pops one and
    records itself in ``work`` under ``id(node)`` while computing it.
    Lock order everywhere: ``ready_nodes_cond`` before ``work_lock``.
    """

    def __init__(self, n_workers=2):
        """Create the scheduling state and start ``n_workers`` workers (default 2)."""
        # the "queue" to consume from
        self.ready_nodes = set()

        # condition to notify consumers; also guards ready_nodes
        self.ready_nodes_cond = th.Condition()

        # guards self.work
        self.work_lock = th.Lock()

        # id(node) -> worker computing that node. A None value marks a
        # computation that was cancelled while in flight.
        self.work = {}

        self.start_worker_threads(n_workers)

    def handle_node_to_ready(self, node):
        """The node is ready for a new computation: queue it and wake a consumer."""
        with self.ready_nodes_cond:
            self.ready_nodes.add(node)
            self.ready_nodes_cond.notify()

    def handle_node_away_from_ready(self, node):
        """If node reverts to a non-ready state,
        cancel all worker-related activity for that node.

        A queued node is simply removed from ``ready_nodes``. A computation
        that is already running cannot be interrupted (Python threads cannot
        be killed), so its ``work`` entry is set to None instead; a worker can
        check that entry before delivering and drop the stale result.
        """
        key = id(node)
        with self.ready_nodes_cond:
            # still queued: dequeue it and we're done
            if node in self.ready_nodes:
                self.ready_nodes.remove(node)
                return

            # being computed: `work` is keyed by id(node), not by the node
            with self.work_lock:
                if key in self.work:
                    self.work[key] = None

    def start_worker_threads(self, n_workers=2):
        """Start ``n_workers`` worker threads consuming from the ready node set.

        Raises
        ------
        ValueError
            If ``n_workers`` is less than 1 (no node could ever be computed).
        """
        if n_workers < 1:
            raise ValueError(f"n_workers must be at least 1, got {n_workers}")
        self.workers = {Worker(self) for _ in range(n_workers)}

In [17]:
class Worker:
    """Consumer thread that computes nodes taken from a supervisor's ready set.

    The supervisor must provide ``ready_nodes`` (a set), ``ready_nodes_cond``
    (a Condition guarding it), ``work`` (dict from ``id(node)`` to the worker
    computing it) and ``work_lock`` (guarding ``work``). Locks are taken in
    the order ``ready_nodes_cond`` -> ``work_lock``.
    """

    def __init__(self, supervisor):
        self.supervisor = supervisor
        # consume() never returns; a daemon thread keeps it from blocking
        # interpreter / notebook-kernel shutdown.
        self.consumer_thread = th.Thread(target=self.consume, daemon=True)
        self.consumer_thread.start()

    def consume(self):
        """Consume tasks forever.

        A result is delivered via ``node.handle_computation_end`` only if
        ``supervisor.work`` still maps the node to this worker when the
        computation finishes; otherwise (entry cleared, or the node was
        re-dispatched to another worker) the result is discarded.
        Exceptions raised by ``node.f`` are printed and the node is dropped,
        so the worker thread stays alive.
        """
        while True:
            node = self._claim_ready_node()
            try:
                # straight-up run the function--in the future we'll subprocessify this
                inputs = [arg.value for arg in node.args]
                output = node.f(*inputs)
            except Exception:
                traceback.print_exc()
                self._release_claim(id(node))
                continue
            self._deliver(node, output)

    def _claim_ready_node(self):
        """Block until a node is ready, pop it and register it in ``work``."""
        sup = self.supervisor
        with sup.ready_nodes_cond:
            while not sup.ready_nodes:
                sup.ready_nodes_cond.wait()
            node = sup.ready_nodes.pop()
            with sup.work_lock:
                sup.work[id(node)] = self
        return node

    def _release_claim(self, key):
        """Drop this worker's ``work`` entry for ``key``; True if it was still ours."""
        sup = self.supervisor
        with sup.work_lock:
            owner = sup.work.get(key)
            # Keep an entry that belongs to another worker (the node was
            # re-dispatched); drop our own entry or an empty (None) one.
            if owner is self or owner is None:
                sup.work.pop(key, None)
        return owner is self

    def _deliver(self, node, output):
        """Hand ``output`` to ``node`` unless this worker no longer owns it."""
        sup = self.supervisor
        # ready_nodes_cond is taken first, matching the lock order used when
        # claiming and cancelling; holding it keeps a concurrent cancellation
        # from interleaving with delivery. work_lock is released before the
        # callback because node value observers may call back into the
        # supervisor, which takes work_lock (a default threading.Condition
        # wraps an RLock, so re-acquiring ready_nodes_cond there is fine).
        with sup.ready_nodes_cond:
            if self._release_claim(id(node)):
                node.handle_computation_end(output)


In [18]:
class Node(tr.HasTraits):
    """Dataflow node whose ``value`` is ``f`` applied to its parents' values.

    Parents (``args``) are Nodes or ipywidgets DOMWidgets; the node observes
    each parent's ``value`` trait and reacts in ``handle_parent_change``.
    Computations are scheduled through the notebook-global ``supervisor``
    object (NOTE(review): it must exist before any parent value changes),
    and results arrive via ``handle_computation_end``.
    """

    # Output of the node; ``Pending`` until a computation result is delivered.
    value = tr.Any()

    @tr.default("value")
    def _set_default_value(self):
        return Pending

    class State(Enum):
        DONE = 0     # not waiting on anything (initial state, or result delivered/restored)
        PENDING = 1  # at least one input is Pending
        READY = 2    # all inputs available; queued for, or undergoing, computation

    def __init__(self, args=None, f=lambda: None, display_widget=None):
        """Create a node and subscribe to its parents' ``value`` traits.

        Parameters
        ----------
        args : list of Node or ipywidgets.DOMWidget, optional
            Parents; their ``value`` traits are the positional inputs of ``f``.
            Defaults to no parents.
        f : callable
            Computes the node's value from its inputs. It may return an
            ``Output`` to render something other than the value itself.
        display_widget : ipywidgets.Output, optional
            Widget each delivered result is rendered into.

        Raises
        ------
        TypeError
            If an entry of ``args`` is neither a Node nor a DOMWidget. This is
            checked before any subscription is made.
        """
        super().__init__()
        if args is None:
            args = []
        print("Created Node with ID", id(self))

        for arg in args:
            if not isinstance(arg, (Node, w.DOMWidget)):
                raise TypeError(
                    "Node args must be Node or DOMWidget instances, "
                    f"got {type(arg).__name__}"
                )

        self.args = args
        self.f = f

        if display_widget:
            assert isinstance(display_widget, w.widgets.widget_output.Output)
        self.display_widget = display_widget

        # Result cache: old_value was computed from old_inputs. None means no
        # computation has completed yet, so there is nothing to reuse.
        self.old_inputs = None
        self.old_value = None
        # Inputs at the moment the node was last handed to the supervisor.
        self._queued_inputs = None
        self.state = Node.State.DONE

        # Subscribe last: a parent Node's value may change on a worker thread,
        # and the callback needs every attribute above to exist.
        for arg in args:
            arg.observe(self.handle_parent_change, names=["value"])

    def handle_parent_change(self, change):
        """Traitlets observer: re-run the node state machine after a parent changed.

        Transition          : on node                : on supervisor
        READY   -> READY    : value stays Pending    : cancel, then queue
        READY   -> DONE     : value <- old_value     : cancel
        READY   -> PENDING  : value stays Pending    : cancel
        DONE    -> READY    : value <- Pending       : queue
        DONE    -> DONE     : value <- old_value     : -
        DONE    -> PENDING  : value <- Pending       : -
        PENDING -> READY    : value stays Pending    : queue
        PENDING -> DONE     : value <- old_value     : -
        PENDING -> PENDING  : -                      : -

        The target state is PENDING if any input is ``Pending``, DONE if the
        inputs equal those of the last completed computation, else READY.
        ``change`` is unused; all inputs are re-read from ``self.args``.
        """
        current_inputs = [arg.value for arg in self.args]

        # Leaving READY (even READY -> READY): drop the node from the queue or
        # cancel its in-flight computation.
        if self.state == Node.State.READY:
            supervisor.handle_node_away_from_ready(self)

        if any(x is Pending for x in current_inputs):
            self.value = Pending
            self.state = Node.State.PENDING
        elif current_inputs == self.old_inputs:
            # same inputs as the last completed computation: reuse its result
            self.value = self.old_value
            self.state = Node.State.DONE
        else:
            # Update the node *before* queueing it; otherwise a fast worker
            # could deliver its result and then have it overwritten here.
            self.value = Pending
            self.state = Node.State.READY
            self._queued_inputs = current_inputs
            supervisor.handle_node_to_ready(self)

    def handle_computation_end(self, value):
        """Store a finished computation's result and render it.

        ``value`` may be an ``Output``: its ``value`` becomes the node's value
        and its ``display`` is rendered. Any other object is used for both.
        Strings are rendered with ``append_stdout``; other objects with
        ``append_display_data``, falling back to their ``str()``.
        """
        if isinstance(value, Output):
            display_value = value.display
            value = value.value
        else:
            display_value = value

        # Cache the result with the inputs it was requested for, so reverting
        # to those inputs can reuse it without recomputing.
        self.old_inputs = self._queued_inputs
        self.old_value = value

        self.value = value
        self.state = Node.State.DONE

        # Render the result into the display widget, if any
        if self.display_widget:
            # 3px fully transparent (alpha 0) border
            self.display_widget.layout = w.Layout(border="solid 3px rgba(0,0,0,0)")
            self.display_widget.clear_output()

            if isinstance(display_value, str):
                self.display_widget.append_stdout(display_value)
            else:
                try:
                    self.display_widget.append_display_data(display_value)
                except Exception:
                    # append_stdout expects text, so fall back to str()
                    self.display_widget.append_stdout(str(display_value))
                    

In [19]:
def f1(a):
    """Demo node function: return ``a + a`` after a one-second delay."""
    time.sleep(1)  # artificial delay
    doubled = a + a
    return doubled

def f2(a, b):
    """Demo node function: concatenate ``str(a)`` and ``str(b)`` after a one-second delay."""
    left, right = str(a), str(b)
    time.sleep(1)  # artificial delay
    return left + right

In [20]:
# Create the supervisor first: Node.handle_parent_change looks up the global
# name `supervisor` whenever an input changes.
supervisor = Supervisor()

# one text input feeding one node, whose result is shown below the input
in_widget1 = w.Text()
output_widget1 = w.Output()
container_widget = w.VBox([in_widget1, output_widget1])

first_node = Node(args=[in_widget1], f=f1, display_widget=output_widget1)

display(container_widget)

Created Node with ID 140164781762208


VBox(children=(Text(value=''), Output()))

In [None]:
# Set the node's `value` trait by hand: this does not run f and does not change first_node.state.
first_node.value = "ho"