In [None]:
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
import networkx as nx
from enum import Enum

In [None]:
def visualize_graph(n_nodes, positions, children):
    """Build a plotly figure of the tree: grey parent->child segments plus node markers.

    Args:
        n_nodes: Number of nodes; node ids are ``0 .. n_nodes - 1``.
        positions: Array indexed as ``positions[node, 0]`` (x) and ``positions[node, 1]`` (y).
        children: ``children[node]`` lists the child ids of ``node``.

    Returns:
        The new ``go.Figure``.
    """
    figure = go.Figure()

    # Edges first so the node markers are drawn on top of them.
    for parent in range(n_nodes):
        parent_x = positions[parent, 0]
        parent_y = positions[parent, 1]
        for child in children[parent]:
            segment = go.Scatter(
                x=[parent_x, positions[child, 0]],
                y=[parent_y, positions[child, 1]],
                mode='lines',
                line=dict(color='gray', width=2),
                showlegend=False,
            )
            figure.add_trace(segment)

    labels = list(map(str, range(n_nodes)))
    node_trace = go.Scatter(
        x=positions[:, 0],
        y=positions[:, 1],
        mode='markers',
        text=labels,
        textposition='middle center',
        marker=dict(size=5, color='lightblue', line=dict(color='black', width=2)),
    )
    figure.add_trace(node_trace)

    figure.update_layout(
        title="Random Tree Visualization",
        xaxis_title="X",
        yaxis_title="Y",
        showlegend=False,
    )
    return figure

def animate_trajectories(fig, node_positions, trajectories, node_colors=None, highlight=None):
    """Attach a per-timestep node-colouring animation (with a Play button) to ``fig``.

    Args:
        fig: plotly figure to extend; one marker trace is appended and animated.
        node_positions: Array with ``node_positions[:, 0]`` = x and ``[:, 1]`` = y per node.
        trajectories: Array-like of shape (n_trajectories, n_frames); entry ``[k, t]``
            is the node trajectory ``k`` occupies at frame ``t``.
        node_colors: Optional (n_frames, n_nodes) colour table used as the background
            colour of every node per frame. Defaults to fully transparent. The
            caller's table is copied, never modified.
        highlight: Optional trajectory index whose current node is drawn larger.

    Returns:
        ``fig`` (modified in place), for convenience.
    """
    trajectories = np.asarray(trajectories)

    n_nodes = node_positions.shape[0]
    n_trajectories = trajectories.shape[0]
    n_frames = trajectories.shape[1]
    # Index the new trace will get once it is appended below; frames update only it.
    trace_index = len(fig.data)

    if node_colors is None:
        # Default background: fully transparent nodes.
        node_colors = np.full([n_frames, n_nodes], 'rgba(0, 0, 0, 0)', dtype=object)
    else:
        # np.array copies; np.asarray would alias an object ndarray passed by the
        # caller and the trajectory colouring below would overwrite it.
        node_colors = np.array(node_colors, dtype=object)

    # Paint each trajectory's occupied node with its own palette colour (later
    # trajectories win where several share a node).
    palette = px.colors.qualitative.Set3
    frame_index = np.arange(n_frames)
    for k in range(n_trajectories):
        node_colors[frame_index, trajectories[k]] = palette[k % len(palette)]

    # Node positions do not change between frames.
    x = node_positions[:, 0]
    y = node_positions[:, 1]

    frames = []
    for t in range(n_frames):
        node_size = np.full(n_nodes, 8, dtype=int)
        if highlight is not None:
            node_size[trajectories[highlight, t]] = 14

        frames.append(
            go.Frame(
                data=go.Scatter(
                    x=x, y=y,
                    hoverinfo='none',
                    mode='markers',
                    marker=dict(size=node_size, color=node_colors[t], opacity=1)
                ),
                traces=[trace_index]
            )
        )

    fig.add_trace(frames[0].data[0])
    fig.update(frames=frames)

    fig.update_layout(updatemenus=[dict(
            type="buttons",
            showactive=False,
            buttons=[dict(label="Play",
                          method="animate",
                          args=[None, {"frame": {"duration": 500, "redraw": True},
                                        "fromcurrent": True,
                                        "transition": {"duration": 0}}])]
        )])
    return fig

In [None]:
class TreeNode:
    """Explicitly linked tree node: a value, a parent link and a list of children."""

    def __init__(self, value, parent=None):
        self.value, self.parent, self.children = value, parent, []

    def add_child(self, child_node):
        """Append ``child_node`` to this node's children and point its parent link here."""
        self.children.append(child_node)
        child_node.parent = self

    def __repr__(self):
        return "Node({})".format(self.value)

class ExploringAgent:
    """Tree-exploring agent that splits work by responsibility labels.

    The agent keeps its own partial view of the tree in ``node_map``: a dict from
    node value to an info dict that may hold ``'visited'``, ``'parent'``,
    ``'children'`` and ``'responsibilities'``. Responsibility labels (by default the
    agent's own id) are divided among the children of each node the agent stands
    on. The agent only descends into children whose label list contains the label
    it is currently working on.
    """

    def __init__(self, id, starting_node, true_map, all_responsibilities, taken_responsibilities, responsibility=None, root=0):
        """
        Args:
            id: Agent identifier; also the default responsibility label.
            starting_node: Node value the agent starts on.
            true_map: Mapping from node value to a node object exposing ``value``,
                ``parent`` and ``children`` (e.g. ``TreeNode``).
            all_responsibilities: List of every responsibility label; the root is
                seeded with a copy of it.
            taken_responsibilities: List of labels already handed out.
            responsibility: Label to explore first; ``None`` means use ``id``.
            root: Node value of the tree root.
        """
        self.id = id
        self.root = root
        self.current_node = starting_node
        self.node_map = {}
        self.finished_exploring = []
        self.true_map = true_map
        self.all_responsibilities = all_responsibilities
        self.taken_responsibilities = taken_responsibilities
        self.update_map()
        # setdefault: when the agent does not start on the root, the root has no
        # entry yet (indexing it directly raised KeyError).
        self.node_map.setdefault(self.root, {'visited': False})['responsibilities'] = all_responsibilities.copy()
        self.intercept_responsibility = None
        # Compare with None so a legitimate label such as 0 is not replaced by the id.
        self.responsibility = responsibility if responsibility is not None else self.id
        self.all_intercepts = []

    def visit_node(self, node):
        """Mark ``node`` visited and record its parent and children from ``true_map``.

        Already-visited entries are left untouched. Keys already present on an
        unvisited entry (e.g. ``'responsibilities'``) are kept.
        """
        node = self.true_map[node]
        current_entry = self.node_map.get(node.value, {'visited': False})
        if not current_entry['visited']:
            current_entry.update({
                'parent': node.parent.value if node.parent else None,
                'children': [child.value for child in node.children],
                'visited': True
            })
            self.node_map[node.value] = current_entry

    def current_children(self):
        """Child node values of the current (visited) node."""
        return self.node_map[self.current_node]['children']

    def current_parent(self):
        """Parent node value of the current node, or ``None`` at the root / if unknown."""
        return self.node_map[self.current_node].get('parent', None)

    def assign_responsibilities(self):
        """Split the current node's labels among its children and store them in ``node_map``.

        On the root the full ``all_responsibilities`` list is split; elsewhere the
        labels stored on the current node are used (none if it has no entry).
        """
        entry = self.node_map[self.current_node]
        if entry['parent'] is None:
            agents = self.all_responsibilities
        else:
            # Non-root node: redistribute the labels this node was given.
            agents = entry.get('responsibilities', [])

        responsibilities = self.calculate_responsibilities(self.current_node, agents)

        # Create placeholder entries for unseen children, overwrite labels otherwise.
        for child_value in entry['children']:
            if child_value not in self.node_map:
                self.node_map[child_value] = {
                    'responsibilities': responsibilities[child_value],
                    'visited': False
                }
            else:
                self.node_map[child_value]['responsibilities'] = responsibilities[child_value]

    def _is_subtree_fully_explored(self, node_value, responsibility):
        """
        Check if all nodes in the subtree rooted at node_value that carry
        ``responsibility`` have been visited.

        Unknown nodes count as unexplored. Nodes not labelled with
        ``responsibility`` (and therefore their subtrees) count as explored.
        """
        if node_value not in self.node_map:
            return False

        node_info = self.node_map[node_value]

        responsibilities = node_info.get('responsibilities', [])
        if responsibility not in responsibilities:
            return True

        if not node_info.get('visited', False):
            return False

        for child_value in node_info.get('children', []):
            if not self._is_subtree_fully_explored(child_value, responsibility):
                return False

        return True

    def step(self):
        """Return the node value to move to next (``None`` when at the root with nothing to do).

        Own responsibility takes priority over an intercept task; an idle agent
        heads towards the root.
        """
        if self.responsibility is not None:
            return self.open_explore_step()
        elif self.intercept_responsibility is not None:
            return self.intercept_step()
        return self.current_parent()

    def open_explore_step(self):
        """Depth-first step over the subtree labelled with ``self.responsibility``.

        When nothing labelled is left anywhere, the label is recorded in
        ``finished_exploring`` and cleared, and the agent falls back to ``step()``.
        """
        self.assign_responsibilities()

        for child_value in self.current_children():
            if not self._is_subtree_fully_explored(child_value, self.responsibility):
                return child_value
        if not self._is_subtree_fully_explored(self.root, self.responsibility):
            return self.current_parent()

        self.finished_exploring.append(self.responsibility)
        self.responsibility = None
        return self.step()

    def intercept_step(self):
        """Like ``open_explore_step`` for ``intercept_responsibility``, trying children in reverse order.

        Going through children last-to-first means the interceptor works through
        the subtree from the opposite side to the owner.
        """
        self.assign_responsibilities()

        for child_value in self.current_children()[::-1]:
            if not self._is_subtree_fully_explored(child_value, self.intercept_responsibility):
                return child_value
        if not self._is_subtree_fully_explored(self.root, self.intercept_responsibility):
            return self.current_parent()

        self.intercept_responsibility = None
        return self.step()

    def update_map(self):
        """Record the node the agent currently stands on."""
        self.visit_node(self.current_node)

    def calculate_responsibilities(self, node, agents):
        """
        Distribute responsibility labels across the children of a given node.

        Args:
            node: Node value (key into ``node_map``) whose children get labels.
            agents: List of responsibility labels to distribute.

        Returns:
            dict: Mapping of child node value to its list of labels. Empty if the
            node has no known children. With more labels than children, labels are
            split into contiguous runs (earlier children get the extra ones). With
            fewer labels than children, each label covers a contiguous run of children.
        """
        children = self.node_map[node].get('children', None)
        if not children:
            return {}

        responsibilities = {child: [] for child in children}
        n_agents = len(agents)
        n_children = len(children)

        if n_agents == 0:
            # Nothing to hand out; also avoids the division by zero below.
            return responsibilities

        if n_agents >= n_children:
            agents_per_child = n_agents // n_children
            extra_agents = n_agents % n_children

            remaining_agents = agents.copy()
            for i, child in enumerate(children):
                n_assigned = agents_per_child + (1 if i < extra_agents else 0)
                responsibilities[child].extend(remaining_agents[:n_assigned])
                remaining_agents = remaining_agents[n_assigned:]
        else:
            children_per_agent = n_children // n_agents
            extra_children = n_children % n_agents

            remaining_children = children.copy()
            for i, agent in enumerate(agents):
                n_assigned = children_per_agent + (1 if i < extra_children else 0)
                for child in remaining_children[:n_assigned]:
                    responsibilities[child].append(agent)
                remaining_children = remaining_children[n_assigned:]

        return responsibilities

In [None]:
def synchronize_agent_maps(agents):
    """Merge the node maps of agents in contact and give each agent an independent copy.

    For every node, later agents' keys overwrite earlier ones, except ``'visited'``,
    which is OR-ed so knowledge of a visit is never lost.

    Each agent receives its own deep copy of the merged map. A shallow copy left
    all agents sharing the same entry dicts, so a later visit by one agent silently
    showed up in the maps of agents it was no longer in contact with.
    """
    import copy

    synchronized_map = {}
    for agent in agents:
        for node_value, info in agent.node_map.items():
            synced_entry = synchronized_map.setdefault(node_value, {})
            visited = synced_entry.get('visited', False) or info.get('visited', False)
            synced_entry.update(info)
            synced_entry['visited'] = visited

    for agent in agents:
        agent.node_map = copy.deepcopy(synchronized_map)

def synchronize_taken_responsibilities(agents):
    """Set every agent's ``taken_responsibilities`` to the union over all agents (one list each)."""
    union = list(set(label for agent in agents for label in agent.taken_responsibilities))
    for agent in agents:
        agent.taken_responsibilities = list(union)

def synchronize_finished_exploring(agents):
    """Set every agent's ``finished_exploring`` to the union over all agents (one list each)."""
    union = list(set(label for agent in agents for label in agent.finished_exploring))
    for agent in agents:
        agent.finished_exploring = list(union)

def synchronize_intercepts(agents):
    """Set every agent's ``all_intercepts`` to the union over all agents (one list each)."""
    union = list(set(label for agent in agents for label in agent.all_intercepts))
    for agent in agents:
        agent.all_intercepts = list(union)
    

def synchronize_information(agents):
    """Share everything agents in contact know: node maps and all task bookkeeping lists.

    A group of one has nobody to talk to, so it is left unchanged.

    ``taken_responsibilities`` is synchronized as well. Without it, a responsibility
    handed out in one group stays "open" in the eyes of agents that were elsewhere
    and could be handed out a second time after the groups meet.
    """
    if len(agents) <= 1:
        return
    synchronize_agent_maps(agents)
    synchronize_taken_responsibilities(agents)
    synchronize_finished_exploring(agents)
    synchronize_intercepts(agents)

In [None]:
def find_groups(agents, edges):
    """Partition agents into communication groups.

    Two agents are in direct contact when they stand on the same node or on the two
    ends of one edge. Groups are the transitive closure of that relation, so a chain
    of contacts links everyone on it. A proper union-find is used: the previous
    pairwise ``min`` relabelling could split a chain into several groups.

    Args:
        agents: Sequence of objects with a ``current_node`` attribute.
        edges: Array-like of (parent, child) node pairs.

    Returns:
        List of groups (lists of agents). Groups are ordered by their lowest-index
        member, and agents keep their input order inside a group.
    """
    edge_set = {tuple(edge) for edge in np.asarray(edges).reshape(-1, 2).tolist()}

    def is_connected(i, j):
        a = agents[i].current_node
        b = agents[j].current_node
        return a == b or (a, b) in edge_set or (b, a) in edge_set

    parent = list(range(len(agents)))

    def find(k):
        # Path halving keeps the trees shallow.
        while parent[k] != k:
            parent[k] = parent[parent[k]]
            k = parent[k]
        return k

    for i in range(len(agents)):
        for j in range(i + 1, len(agents)):
            if is_connected(i, j):
                root_i, root_j = find(i), find(j)
                if root_i != root_j:
                    # The smaller index stays root, so each group's root is its lowest index.
                    parent[max(root_i, root_j)] = min(root_i, root_j)

    groups = {}
    for idx, agent in enumerate(agents):
        groups.setdefault(find(idx), []).append(agent)
    return list(groups.values())

In [None]:
def update_tasks(agents, root):
    """Hand out unclaimed work inside one group, coordinated from the root.

    Only idle agents (no responsibility and no intercept) take part. If any of them
    stands on ``root``, the lowest-id such agent becomes the coordinator. Using its
    knowledge, it hands out two kinds of work, in this order:

    1. every not-yet-taken responsibility, then
    2. every unfinished responsibility that is not yet intercepted.

    Work goes to the other idle agents, highest id first. The coordinator does not
    give itself a task. Finally, the updated taken/intercept lists are copied to
    every agent in ``agents``.

    Args:
        agents: Agents in one connected group.
        root: Node value of the tree root.
    """
    idle_agents = sorted(
        (agent for agent in agents
         if agent.responsibility is None and agent.intercept_responsibility is None),
        key=lambda agent: agent.id,
    )

    # Pick the coordinator without removing from the list while iterating over it.
    # The old in-loop remove() skipped agents: with three or more idle agents on the
    # root it chose a different coordinator and dropped one agent from the hand-out.
    coordinator = next((agent for agent in idle_agents if agent.current_node == root), None)
    if coordinator is None:
        return
    remaining_agents = [agent for agent in idle_agents if agent is not coordinator]

    all_intercepts = coordinator.all_intercepts.copy()
    taken_responsibilities = coordinator.taken_responsibilities.copy()
    open_intercepts = (set(coordinator.all_responsibilities) - set(coordinator.finished_exploring)) - set(coordinator.all_intercepts)
    open_responsibilities = set(coordinator.all_responsibilities) - set(coordinator.taken_responsibilities)

    for responsibility in open_responsibilities:
        if not remaining_agents:
            break
        remaining_agents.pop().responsibility = responsibility
        taken_responsibilities.append(responsibility)

    for intercept in open_intercepts:
        if not remaining_agents:
            break
        remaining_agents.pop().intercept_responsibility = intercept
        all_intercepts.append(intercept)

    for agent in agents:
        agent.all_intercepts = all_intercepts.copy()
        agent.taken_responsibilities = taken_responsibilities.copy()

In [None]:
def generate_random_tree(n_nodes=10, rng=None):
    """Generate a random tree on nodes ``0 .. n_nodes - 1`` rooted at 0, plus a layout.

    Node ``k >= 1`` picks its parent uniformly from ``0 .. k - 1``.

    Args:
        n_nodes: Number of nodes.
        rng: Optional ``numpy.random.Generator`` for reproducible trees. ``None``
            keeps the previous behaviour of drawing from numpy's global random state.

    Returns:
        ``(nodes, edges, positions, children)``:

        - ``nodes``: dict id -> linked ``TreeNode``.
        - ``edges``: (n_nodes - 1, 2) array of (parent, child) pairs.
        - ``positions``: per-node (x, y) from graphviz ``dot``, ordered by node id.
        - ``children``: list of child-id lists, indexed by node id.
    """
    uniform = np.random.uniform if rng is None else rng.uniform
    end = np.arange(n_nodes)
    # end * u lies in [0, k) for node k, so the truncated value is a valid earlier parent.
    start = (end * uniform(size=n_nodes)).astype(int)
    # Row 0 is the meaningless (0, 0) self-edge of the root; drop it.
    edges = np.stack([start, end], axis=1)[1:]

    children = [edges[edges[:, 0] == i, 1].tolist() for i in range(n_nodes)]

    nodes = {i: TreeNode(i) for i in range(n_nodes)}
    for i, child_indices in enumerate(children):
        for child_idx in child_indices:
            nodes[i].add_child(nodes[child_idx])

    G = nx.DiGraph()
    for parent, local_children in enumerate(children):
        for child in local_children:
            G.add_edge(parent, child)
    # Adds only nodes missing from the edge list (e.g. the lone root when
    # n_nodes == 1), so existing node order and thus the layout are unchanged.
    G.add_nodes_from(range(n_nodes))

    positions = nx.nx_agraph.graphviz_layout(G, prog='dot')
    positions = np.asarray([value for key, value in sorted(positions.items())])

    return nodes, edges, positions, children

In [None]:
def find_inactive_agents(agents, target_nodes):
    """Split agents into movers and parked agents for this tick.

    An agent is parked (keyed by its current node) when it has no task or no target.
    Among agents with a task, only the first to claim a given target node becomes a
    mover; later claimants of the same node are in neither result.

    Returns:
        ``(movers, parked)``: a list of agents and a dict current_node -> agent.
    """
    movers = []
    parked = {}
    claimed = set()

    for agent in agents:
        target = target_nodes[agent.id]
        idle = agent.responsibility is None and agent.intercept_responsibility is None
        if idle or target is None:
            parked[agent.current_node] = agent
        elif target not in claimed:
            claimed.add(target)
            movers.append(agent)

    return movers, parked

def transfer_responsibilities(old_agent, new_agent):
    """Hand ``old_agent``'s task to ``new_agent``.

    An intercept task is moved over and cleared on ``old_agent``. Otherwise the two
    agents swap their ``responsibility`` values.
    """
    if old_agent.intercept_responsibility is not None:
        new_agent.intercept_responsibility = old_agent.intercept_responsibility
        old_agent.intercept_responsibility = None
        return
    old_agent.responsibility, new_agent.responsibility = new_agent.responsibility, old_agent.responsibility
    
def move_active_agents_if_possible(active_agents, blocked_nodes, target_nodes):
    """Move active agents whose target is not occupied, repeating until nobody else can move.

    ``blocked_nodes`` maps agent id -> occupied node; an agent that gets to move
    frees its node, which may unblock others on the next pass. Agents that never
    unblock (e.g. two agents wanting to swap nodes) stay put. ``active_agents`` and
    ``blocked_nodes`` are consumed in place; positions change only at the end.
    """
    movers = []

    progress = True
    while progress:
        progress = False
        for agent in list(active_agents):
            if target_nodes[agent.id] in blocked_nodes.values():
                continue
            movers.append(agent)
            del blocked_nodes[agent.id]
            active_agents.remove(agent)
            progress = True

    for agent in movers:
        agent.current_node = target_nodes[agent.id]

def move_inactive_agents_if_possible(inactive_agents, all_agents, target_nodes):
    """Move idle agents (lowest id first) onto their target when no agent stands there.

    Occupancy is re-read for every agent, so earlier moves in this loop count.
    Agents with a ``None`` target stay where they are.
    """
    for mover in sorted(inactive_agents, key=lambda a: a.id):
        destination = target_nodes[mover.id]
        if destination is None:
            continue
        occupied = any(other.current_node == destination for other in all_agents)
        if not occupied:
            mover.current_node = destination

def no_collision_move(agents):
    """Advance one group of agents by one step without putting two moving agents on one node.

    Steps:

    1. Every agent's ``step()`` is called once up front to get its desired target.
       These calls can also update the agent: they re-assign child responsibilities
       in its map and may mark its responsibility as finished.
    2. An active agent whose target is occupied by an idle agent hands its task to
       that idle agent (``transfer_responsibilities``) instead of moving. The idle
       agent then plans its own next step, and this can chain.
    3. Active agents move where their target is free.
    4. Idle agents then move to free targets.

    Args:
        agents: Agents in one connected group (e.g. one entry of ``find_groups``).
    """
    target_nodes = {agent.id: agent.step() for agent in agents}
    agents = sorted(agents, key=lambda agent: agent.id)
    # Positions before anything moves; used as the initial set of blocked nodes.
    current_nodes = {agent.id: agent.current_node for agent in agents}
    active_agents, inactive_positions = find_inactive_agents(agents, target_nodes)
    
    for agent_iter in active_agents.copy():
        old_agent = agent_iter
        target = target_nodes[old_agent.id]
        # While the target is occupied by an idle agent, swap roles instead of moving:
        # the idle agent takes over the task and the former active agent parks.
        while target in inactive_positions:
            new_agent = inactive_positions[target]

            transfer_responsibilities(old_agent, new_agent)
            active_agents.remove(old_agent)
            # NOTE(review): if another idle agent is already keyed on this node it is
            # overwritten here and will not be considered by the idle-move step — confirm intended.
            inactive_positions[old_agent.current_node] = old_agent
            
            new_target = new_agent.step()
            if new_target is None:
                # NOTE(review): new_agent now holds the task but stays in
                # inactive_positions with its earlier target — confirm intended.
                break

            active_agents.append(new_agent)
            del inactive_positions[target]

            # Continue the chain from the agent that just took over.
            old_agent = new_agent
            target_nodes[new_agent.id] = new_target
            target = new_target

    move_active_agents_if_possible(active_agents, current_nodes.copy(), target_nodes)
    move_inactive_agents_if_possible(inactive_positions.values(), agents, target_nodes)

In [None]:
n_nodes = 30
# Seed the global numpy RNG so Restart & Run All reproduces the same tree and figures.
SEED = 0
np.random.seed(SEED)
true_map, edges, positions, children = generate_random_tree(n_nodes=n_nodes)

In [None]:
n_agents = 2
n_responsibilities = 3
root = 0
n_steps = 150  # number of simulation ticks

agents = [
    ExploringAgent(i, root, true_map, root=root,
                   all_responsibilities=list(range(n_responsibilities)),
                   taken_responsibilities=list(range(n_agents)))
    for i in range(n_agents)
]


def visited_nodes(agent):
    """Node values this agent currently believes have been visited."""
    return [k for k, v in agent.node_map.items() if v.get('visited', False)]


# Frame 0 is the starting state: each agent is on `root` (not a hard-coded 0),
# which its constructor has already visited.
trajectories = [[agent.current_node] for agent in agents]
explored_nodes = [[visited_nodes(agent)] for agent in agents]

for step in range(n_steps):
    for group in find_groups(agents, edges):
        synchronize_information(group)
        update_tasks(group, root)
        no_collision_move(group)

    for agent in agents:
        agent.update_map()
        # Indexing by id assumes ids are 0..n_agents-1, as constructed above.
        trajectories[agent.id].append(agent.current_node)

    # Snapshot only after every agent has updated its map for this tick.
    for idx, agent in enumerate(agents):
        explored_nodes[idx].append(visited_nodes(agent))

# TODO
Replace `self.all_agents` with `all_responsibilities` and add another list, `open_responsibilities`. While there are still open responsibilities, agents that have finished their own responsibility will be dealt a new one before any intercepts are handed out.

In [None]:
def get_node_color_for_agent(i, explored=None, num_nodes=None):
    """Per-timestep node colours showing what agent ``i`` knew to be explored.

    Explored nodes are green; all others get the opaque default colour.

    Args:
        i: Agent index into ``explored``.
        explored: Per-agent list of per-timestep lists of explored node ids.
            Defaults to the notebook global ``explored_nodes``.
        num_nodes: Number of nodes (columns). Defaults to the notebook global ``n_nodes``.

    Returns:
        Object array of shape (n_timesteps, num_nodes) holding colour strings.
    """
    # The globals remain the defaults so the existing call sites keep working.
    if explored is None:
        explored = explored_nodes
    if num_nodes is None:
        num_nodes = n_nodes

    history = explored[i]
    node_colors = np.full((len(history), num_nodes), 'rgba(0, 0, 0, 255)', dtype=object)

    for t, explored_list in enumerate(history):
        for node_id in explored_list:
            node_colors[t, node_id] = 'rgba(100, 255, 100, 255)'
    return node_colors

In [None]:
# Animate all agents; highlight agent 1 and colour nodes by what agent 1 had explored.
highlighted_agent = 1
fig = visualize_graph(n_nodes, positions, children)
animate_trajectories(
    fig,
    positions,
    trajectories,
    node_colors=get_node_color_for_agent(highlighted_agent),
    highlight=highlighted_agent,
)
fig.show()