<a href="https://colab.research.google.com/github/SIshikawa1106/planner/blob/master/kShorterPath.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import numpy as np
import heapq

class Graph(object):
  """Undirected weighted graph that enumerates low-cost simple paths.

  Nodes are appended with ``add_node_edge``, which returns the new node's
  index.  Each undirected edge is stored once, as ``self.link[lo][hi]`` with
  ``lo < hi``: a new node may only link to nodes that already exist.

  ``find_shorter_path`` runs a best-first search ordered by
  ``cost_from_start + euclidean_distance_to_goal`` and collects the simple
  start->goal paths whose every partial route has an estimate within
  ``th_cost``.  ``add_node_edge_and_find_shorter_path`` then adds one node and
  searches only for the new paths through it, reusing the recorded routes.
  """

  def __init__(self):
    self.node_list = []
    # link[lo][hi] -> cost of the edge between node indices lo < hi.
    self.link = {}
    # route_links[i] -> routes (lists of node indices beginning at the start
    # node) pushed during the search that end at node i;
    # route_link_costs[i] holds their costs in the same order.
    self.route_links = {}
    self.route_link_costs = {}
    # Search state, (re)initialised by find_shorter_path().  Defining it here
    # keeps get_heuristic_cost() usable before any search has run.
    self.heap = []
    self.heuristic_costs = {}
    self.start_index = None
    self.goal_index = None
    self.th_cost = None

  def add_node_edge(self, node, links, costs):
    """Append ``node`` and connect it to already-added nodes.

    Args:
      node: position of the node; get_heuristic_cost() takes Euclidean
        distances between node positions, so it must be numeric array-like.
      links: indices of existing nodes to connect the new node to.
      costs: edge cost for each entry of ``links`` (same length).

    Returns:
      The index assigned to the new node.

    Raises:
      AssertionError: if the lengths differ, a link target does not exist, or
        a target is listed twice.  The graph is left unchanged in that case.
    """
    assert len(links) == len(costs)
    # Validate everything before mutating so a rejected call leaves the graph
    # untouched.
    for idx in links:
      assert idx in self.link
    assert len(set(links)) == len(links)

    index = len(self.node_list)
    self.node_list.append(node)
    self.link[index] = {}
    # Every existing index is smaller than ``index``, so the edge is stored on
    # the existing node's dict, keeping the ``link[lo][hi]`` invariant.
    for idx, cost in zip(links, costs):
      self.link[idx][index] = cost
    return index

  def get_link_node_idx_list(self, index):
    """Return the indices of all nodes adjacent to ``index``, ascending."""
    assert index in self.link
    # Edges to lower-indexed neighbours live in those neighbours' dicts;
    # edges to higher-indexed neighbours live in ``link[index]``.
    lower = [n for n in range(index) if index in self.link[n]]
    return lower + sorted(self.link[index])

  def get_link_cost(self, idx1, idx2):
    """Return the cost of the edge between ``idx1`` and ``idx2`` (any order)."""
    pre_idx, nxt_idx = min(idx1, idx2), max(idx1, idx2)
    assert pre_idx in self.link
    assert nxt_idx in self.link[pre_idx]
    return self.link[pre_idx][nxt_idx]

  def get_heuristic_cost(self, target_idx, goal_idx):
    """Return the Euclidean distance between two nodes' positions.

    Results are memoised in ``self.heuristic_costs`` keyed by
    ``(target_idx, goal_idx)``; the goal is part of the key so querying the
    same target for a different goal never returns a stale distance.
    """
    assert max([target_idx, goal_idx]) < len(self.node_list)
    key = (target_idx, goal_idx)
    if key not in self.heuristic_costs:
      diff = np.subtract(self.node_list[target_idx], self.node_list[goal_idx])
      self.heuristic_costs[key] = np.linalg.norm(diff)
    return self.heuristic_costs[key]

  def _push_route(self, route, cost, verbose):
    """Record ``route`` under its last node and queue it for expansion."""
    node_idx = route[-1]
    heuristic_cost = 0
    if node_idx != self.goal_index:
      heuristic_cost = self.get_heuristic_cost(node_idx, self.goal_index)
    if verbose:
      print("new_route={}".format(route))
    self.route_links.setdefault(node_idx, []).append(route)
    self.route_link_costs.setdefault(node_idx, []).append(cost)
    # Heap entry: (estimated total cost, node index, route, cost from start).
    heapq.heappush(self.heap, (cost + heuristic_cost, node_idx, route, cost))

  def _expand_queued_routes(self, th_cost, verbose):
    """Pop routes cheapest-estimate first until the heap is empty.

    Each popped route within budget that has not reached the goal is extended
    by every neighbour not already on it.
    """
    while self.heap:
      est_cost, target_index, current_route, current_cost = heapq.heappop(self.heap)
      if verbose:
        print("\n\nest_cost={}\ntarget_index={}\ncurrent_route={}".format(est_cost, target_index, current_route))
      # Routes over budget are pruned; routes that reached the goal are
      # complete and never extended.
      if est_cost > th_cost:
        continue
      if target_index == self.goal_index:
        continue
      for link_node_idx in self.get_link_node_idx_list(target_index):
        # Only simple paths: never revisit a node already on the route.
        if link_node_idx in current_route:
          continue
        link_cost = self.get_link_cost(target_index, link_node_idx)
        self._push_route(current_route + [link_node_idx],
                         current_cost + link_cost, verbose)

  def _goal_routes(self, first, th_cost):
    """Return ``(routes, costs)`` recorded at the goal from position ``first``
    on, keeping only routes whose cost is within ``th_cost``."""
    routes = self.route_links.get(self.goal_index, [])[first:]
    costs = self.route_link_costs.get(self.goal_index, [])[first:]
    kept = [(route, cost) for route, cost in zip(routes, costs) if cost <= th_cost]
    return [route for route, _ in kept], [cost for _, cost in kept]

  def add_node_edge_and_find_shorter_path(self, node, links, costs, th_cost=None, *, verbose=True):
    """Add a node and return only the new start->goal routes through it.

    Reuses the start, goal, threshold and recorded routes of the last
    ``find_shorter_path`` call: each recorded route ending at a linked node is
    extended onto the new node, then the search continues from there.

    Args:
      node, links, costs: as for ``add_node_edge``.
      th_cost: cost bound; defaults to the one used by ``find_shorter_path``.
      verbose: print a trace of popped and pushed routes (default True).

    Returns:
      ``(routes, costs)`` for the newly found goal routes (possibly empty).

    Raises:
      RuntimeError: if ``find_shorter_path`` has not been called yet (checked
        before the node is added).
    """
    if self.goal_index is None:
      raise RuntimeError(
          "find_shorter_path() must be called before "
          "add_node_edge_and_find_shorter_path()")
    if th_cost is None:
      th_cost = self.th_cost
    goal_index = self.goal_index
    target_index = self.add_node_edge(node, links, costs)
    found_before = len(self.route_links.get(goal_index, []))

    for link_node_idx, link_cost in zip(links, costs):
      # Routes that reached the goal are complete; they are never extended.
      if link_node_idx == goal_index:
        continue
      heuristic_cost = self.get_heuristic_cost(link_node_idx, goal_index)
      # (route, cost, estimate) for every recorded route ending here.
      seeds = [(route, cost, cost + heuristic_cost)
               for route, cost in zip(self.route_links.get(link_node_idx, []),
                                      self.route_link_costs.get(link_node_idx, []))]
      if link_node_idx == self.start_index:
        # The start route is never recorded; it is queued with estimate 0.
        seeds.append(([self.start_index], 0, 0))
      for route, cost, est_cost in seeds:
        # Only extend routes the previous search would itself have expanded.
        if est_cost > th_cost:
          continue
        # Each seed ends at link_node_idx; step onto the new node.
        self._push_route(route + [target_index], cost + link_cost, verbose)

    self._expand_queued_routes(th_cost, verbose)
    return self._goal_routes(found_before, th_cost)

  def find_shorter_path(self, start_index, goal_index, th_cost, *, verbose=True):
    """Enumerate simple paths from ``start_index`` to ``goal_index``.

    Args:
      start_index: index of the start node.
      goal_index: index of the goal node.
      th_cost: cost bound; a route is extended only while
        ``cost_from_start + heuristic`` stays within it, and only goal routes
        whose cost is within it are returned.
      verbose: print a trace of popped and pushed routes (default True).

    Returns:
      ``(routes, costs)``: node-index routes and their costs, in the order
      they were discovered.  Both lists are empty if the goal is unreachable.
    """
    self.heap = []
    self.route_links = {start_index: []}
    self.route_link_costs = {start_index: []}
    self.heuristic_costs = {}
    self.start_index = start_index
    self.goal_index = goal_index
    self.th_cost = th_cost
    # The start route carries no heuristic term: estimate 0, cost 0.
    heapq.heappush(self.heap, (0, start_index, [start_index], 0))
    self._expand_queued_routes(th_cost, verbose)
    return self._goal_routes(0, th_cost)

if __name__ == "__main__":
  np.random.seed(0)
  node_list = [np.random.rand(2) for _ in range(10)]
  print(node_list)

  # Graph topology (node indices):
  #
  #   0-1-2--
  #   |   |  \
  #   3---4-5-9
  #   |   |\ /
  #   8   6-7
  #
  # A node may only link to nodes added before it, so each edge is listed on
  # its higher-indexed endpoint.  One np.random.rand(1) cost is drawn per link,
  # node by node, in list order.
  adjacency = [[], [0], [1], [0], [2, 3], [4], [4], [4, 6], [3], [2, 5, 7]]
  graph = Graph()
  for node, links in zip(node_list, adjacency):
    graph.add_node_edge(node, links, [np.random.rand(1) for _ in links])

  # Enumerate the routes from node 0 to node 1 within a cost bound of 100.
  routes, route_costs = graph.find_shorter_path(0, 1, 100)
  print("\n".join(map(str, routes)))
  print("\n".join(map(str, route_costs)))

[array([0.5488135 , 0.71518937]), array([0.60276338, 0.54488318]), array([0.4236548 , 0.64589411]), array([0.43758721, 0.891773  ]), array([0.96366276, 0.38344152]), array([0.79172504, 0.52889492]), array([0.56804456, 0.92559664]), array([0.07103606, 0.0871293 ]), array([0.0202184 , 0.83261985]), array([0.77815675, 0.87001215])]


est_cost=0
target_index=0
current_route=[0]
new_route=[0, 1]
new_route=[0, 3]


est_cost=[0.84568726]
target_index=3
current_route=[0, 3]
new_route=[0, 3, 4]
new_route=[0, 3, 8]


est_cost=[0.97511663]
target_index=4
current_route=[0, 3, 4]
new_route=[0, 3, 4, 2]
new_route=[0, 3, 4, 5]
new_route=[0, 3, 4, 6]
new_route=[0, 3, 4, 7]


est_cost=[0.97861834]
target_index=1
current_route=[0, 1]


est_cost=[1.10540033]
target_index=6
current_route=[0, 3, 4, 6]
new_route=[0, 3, 4, 6, 7]


est_cost=[1.40931166]
target_index=5
current_route=[0, 3, 4, 5]
new_route=[0, 3, 4, 5, 9]


est_cost=[1.52587282]
target_index=8
current_route=[0, 3, 8]


est_cost=[1.56591149]
tar

In [0]:
ｘｚ