In [None]:
import copy
from pathlib import Path
from collections import deque
from typing import List, Callable

import numpy as np


from PIL import Image, ImageDraw
import numpy as np
import matplotlib.pyplot as plt
import math
from heapq import heappop, heappush
from typing import Tuple, List, Iterable, Callable, Type, Dict, Union, Optional
import numpy.typing as npt

In [None]:
class Map:
    """
    Grid of square cells that represents the environment for our moving agent.

    Attributes
    ----------
    _width : int
        The number of columns in the grid.

    _height : int
        The number of rows in the grid.

    _cells : ndarray[int, ndim=2]
        The matrix that represents the grid. A cell is traversable iff its value
        is 0; any non-zero value (normally 1) marks a blocked cell.
    """

    # Moves considered by get_neighbors: 4 cardinal moves, then 4 diagonal ones.
    # The order is kept stable because it fixes the order of the returned list.
    _NEIGHBOR_DELTAS = ((0, 1), (1, 0), (0, -1), (-1, 0),
                        (1, 1), (-1, 1), (1, -1), (-1, -1))

    def __init__(self, cells: npt.NDArray):
        """
        Initialization of map by 2d array of cells.

        Parameters
        ----------
        cells : ndarray[int, ndim=2]
            The matrix that represents the grid: 0 - cell is traversable,
            non-zero (normally 1) - cell is blocked. Indexed as cells[row, column].
        """
        self._width = cells.shape[1]
        self._height = cells.shape[0]
        self._cells = cells

    def in_bounds(self, i: int, j: int) -> bool:
        """
        Check if the cell (i, j) is on the grid.

        Parameters
        ----------
            i : int
                Number of the cell row in grid
            j : int
                Number of the cell column in grid
        Returns
        ----------
             bool
                Is the cell inside grid.
        """
        return (0 <= j < self._width) and (0 <= i < self._height)

    def traversable(self, i: int, j: int) -> bool:
        """
        Check if the cell (i, j) is not an obstacle (its value is 0).

        Parameters
        ----------
            i : int
                Number of the cell row in grid
            j : int
                Number of the cell column in grid
        Returns
        ----------
             bool
                Is the cell traversable.
        """
        return not self._cells[i, j]

    def get_neighbors(self, i: int, j: int) -> List[Tuple[int, int]]:
        """
        Get a list of neighbouring cells as (i, j) tuples.

        The grid is treated as 8-connected: the four cardinal moves and the four
        diagonal moves are considered. A neighbour is returned when it lies on
        the grid and is traversable. Diagonal moves are NOT checked for corner
        cutting: a diagonal neighbour is returned even if both orthogonal cells
        between it and (i, j) are blocked.

        Parameters
        ----------
            i : int
                Number of the cell row in grid
            j : int
                Number of the cell column in grid
        Returns
        ----------
            neighbors : List[Tuple[int, int]]
                List of neighbouring cells (cardinal moves first, then diagonals).
        """
        return [(i + di, j + dj)
                for di, dj in self._NEIGHBOR_DELTAS
                if self.in_bounds(i + di, j + dj) and self.traversable(i + di, j + dj)]

    def get_size(self) -> Tuple[int, int]:
        """
        Returns size of grid in cells.

        Returns
        ----------
            (height, width) : Tuple[int, int]
                Number of rows and columns in grid
        """
        return (self._height, self._width)

In [None]:
class Node:
    """
    Search node bound to one grid cell.

    Equality and hashing use only the cell coordinates (i, j): two nodes for the
    same cell compare equal regardless of their `level` or `parent`, which lets
    nodes be stored in sets / used as dict keys (e.g. for a CLOSED collection).
    """

    def __init__(self,
                 i: int, j: int,
                 level: Union[float, int] = 0,
                 parent: Optional['Node'] = None):
        """
        Initialization of search node.

        Parameters
        ----------
        i, j : int, int
            Coordinates (row, column) of the corresponding grid element.
        level : float | int
            Numeric value attached to the node, e.g. an accumulated path cost
            from the search origin. Defaults to 0.
        parent : Node | None
            Pointer to the parent node, or None if the node has no parent.
        """
        self.i = i
        self.j = j
        self.level = level
        self.parent = parent

    def __eq__(self, other):
        """
        Two search nodes are the same when they refer to the same cell,
        which is needed to detect duplicates in the search tree.

        Comparing with a non-Node returns NotImplemented, so e.g.
        `node == (1, 2)` evaluates to False instead of raising AttributeError.
        """
        if not isinstance(other, Node):
            return NotImplemented
        return (self.i == other.i) and (self.j == other.j)

    def __hash__(self):
        """
        Hash on the cell coordinates only, consistent with __eq__, so that
        nodes can be kept in sets / dicts.
        """
        return hash((self.i, self.j))

    def __str__(self) -> str:
        return f"({self.i}, {self.j}) -> level={self.level}"

In [None]:
from scipy.spatial import distance

def compute_cost(x1, y1, x2, y2):
    """
    Euclidean length of the move between cells (x1, y1) and (x2, y2).

    For unit grid moves this is 1 for cardinal and sqrt(2) for diagonal steps.
    math.hypot computes the same distance as scipy.spatial.distance.euclidean
    but without scipy's per-call array conversion/validation, which matters
    because this is evaluated for every edge relaxation of a search.
    """
    return math.hypot(x2 - x1, y2 - y1)

In [None]:
def convert_string_to_cells(cell_str: str) -> npt.NDArray:
    """
    Converting a string (with '#' representing obstacles and '.' representing free cells) to a grid.

    Each line of `cell_str` becomes one grid row. Every character other than
    '.' and '#' (spaces, tabs, '\\r', ...) is ignored, so cells may be separated
    by whitespace. Lines that contain no '.' or '#' at all (empty or
    whitespace-only lines) are skipped.

    Parameters
    ----------
    cell_str : str
        String, which contains information about grid map ('#' representing obstacles
        and '.' representing free cells).

    Returns
    ----------
        cells : ndarray[np.int8, ndim=2]
            Grid map representation as matrix (0 - free, 1 - obstacle).
            Shape is (0, 0) when the string contains no cells.

    Raises
    ----------
        ValueError
            If the rows do not all contain the same number of cells.
    """
    symbol_to_cell = {'.': 0, '#': 1}

    rows = []
    for line in cell_str.splitlines():
        row = [symbol_to_cell[char] for char in line if char in symbol_to_cell]
        if row:  # skip blank / whitespace-only lines instead of adding an empty row
            rows.append(row)

    if not rows:
        return np.zeros((0, 0), dtype=np.int8)

    widths = {len(row) for row in rows}
    if len(widths) != 1:
        raise ValueError(
            f"All map rows must have the same number of cells, got row widths {sorted(widths)}"
        )
    return np.array(rows, dtype=np.int8)


In [None]:
def draw(grid_map: Map,
         start: Node = None,
         goal: Node = None,
         path: Iterable[Node] = None,
         nodes_discovered: Iterable[Node] = None,
         nodes_expanded: Iterable[Node] = None,
         nodes_reexpanded: Iterable[Node] = None,
         scale: int = 5):
    """
    Auxiliary function that visualizes the environment, the path and
    the expanded/not yet expanded/re-expanded nodes.
    Re-expansions are not happening in Dijkstra, but still...

    The function assumes that nodes_discovered/nodes_expanded/nodes_reexpanded
    are iterable collections of search nodes (objects with `.i` and `.j`).

    Layers are painted in this order, later ones on top: obstacles, discovered,
    expanded, re-expanded, path, start, goal. Path cells lying on obstacles are
    painted orange instead of blue; start/goal are only painted when they are
    on traversable cells.

    Parameters
    ----------
    grid_map : Map
        Environment representation in form of grid.
    start, goal : Node, Node
        Nodes for start and goal positions of agent.
    path : Iterable[Node]
        Sequence of nodes, which represents the path between start and goal cells.
        `None` entries are skipped.
    nodes_discovered : Iterable[Node]
        Nodes, which were discovered during search process.
    nodes_expanded : Iterable[Node]
        Nodes, which were expanded during search process.
    nodes_reexpanded : Iterable[Node]
        Nodes, which were reexpanded during search process.
    scale : int
        Side length of one grid cell in image pixels (default 5).

    Raises
    ------
    ValueError
        If `scale` is smaller than 1.
    """
    if scale < 1:
        raise ValueError(f"scale must be a positive number of pixels, got {scale}")

    height, width = grid_map.get_size()
    im = Image.new('RGB', (width * scale, height * scale), color='white')
    canvas = ImageDraw.Draw(im)

    def fill_cell(i, j, color):
        # PIL rectangle corners are inclusive, hence the "- 1" on the far corner.
        canvas.rectangle((j * scale, i * scale, (j + 1) * scale - 1, (i + 1) * scale - 1),
                         fill=color, width=0)

    for i in range(height):
        for j in range(width):
            if not grid_map.traversable(i, j):
                fill_cell(i, j, (70, 80, 80))

    node_layers = ((nodes_discovered, (213, 219, 219)),
                   (nodes_expanded, (131, 145, 146)),
                   (nodes_reexpanded, (255, 145, 146)))
    for nodes, color in node_layers:
        if nodes is not None:
            for node in nodes:
                fill_cell(node.i, node.j, color)

    if path is not None:
        for step in path:
            if step is None:
                continue
            # Orange marks path cells that sit on obstacles (i.e. an invalid path).
            color = (52, 152, 219) if grid_map.traversable(step.i, step.j) else (230, 126, 34)
            fill_cell(step.i, step.j, color)

    if (start is not None) and grid_map.traversable(start.i, start.j):
        fill_cell(start.i, start.j, (40, 180, 99))

    if (goal is not None) and grid_map.traversable(goal.i, goal.j):
        fill_cell(goal.i, goal.j, (231, 76, 60))

    _, ax = plt.subplots(dpi=150)
    ax.axes.xaxis.set_visible(False)
    ax.axes.yaxis.set_visible(False)
    ax.imshow(np.asarray(im))
    plt.show()

In [None]:
# Demo environment: 15 rows x 30 columns with several wall segments.
height = 15
width = 30
map_str = '''
. . . . . . . . # # . . . . . . . . . . . # # . . . . . . .  
. . . # # . . . # # . . . . . . . . . . . # # . . . . . . . 
. . . # # . . . # # . . . . . . . . . . . # # . . . . . . . 
. . . # # . . . # # . . . # # . . . . . . # # . . . . . . . 
. . . # # . . . # # . . . # # . . . . . . # # . . . . . . . 
. . . # # . . . # # . . . # # . . . . . . # # # # # . . . . 
. . . # # . . . # # . . . # # . . . . . . # # # # # . . . . 
. . . # # . . . # # . . . # # . . . . . . . . . . . . . . . 
. . . # # . . . # # . . . # # . . . . . . . . . . . . . . . 
. . . # # . . . # # . . . # # . . . . . . . . . . . . . . . 
. . . # # . . . # # . . . # # . . . . . . . . . . . . . . . 
. . . # # . . . # # . . . # # . . . . . . . . . . . . . . . 
. . . # # . . . # # . . . # # . . . . . . . . . . . . . . . 
. . . # # . . . # # . . . # # . . . . . . . . . . . . . . .
. . . # # . . . . . . . . # # . . . . . . . . . . . . . . .
'''

# For an obstacle-free map of the same size, use
# Map(np.zeros((height, width), dtype=np.int8)) instead of a hand-written string.

cells = convert_string_to_cells(map_str)
# Catch typos in the hand-written map (a row with a missing or extra cell).
assert cells.shape == (height, width), f"expected map of shape {(height, width)}, got {cells.shape}"
test_map = Map(cells)
start = Node(1, 1)
goal = Node(7, 15)
draw(test_map, start, goal)

In [None]:
from collections import deque
from matplotlib import pyplot as plt

def fill_heuristic_values(finish_node: Node, task_map: Map, heuristic_func: Callable):
    """
    Evaluate a heuristic for every cell of the map relative to the finish cell.

    Parameters
    ----------
    finish_node : Node
        Goal cell; only its `.i` (row) and `.j` (column) are read.
    task_map : Map
        Grid whose size and obstacle layout are used.
    heuristic_func : Callable
        Called once as heuristic_func(rows, cols, finish_i, finish_j), where
        rows[i, j] == i and cols[i, j] == j are integer index grids shaped like the map.

    Returns
    -------
    ndarray
        The heuristic grid, with every cell whose map value is non-zero
        (an obstacle) set to 0.
    """
    row_idx, col_idx = np.indices(task_map.get_size())
    heuristic = heuristic_func(row_idx, col_idx, finish_node.i, finish_node.j)
    is_blocked = task_map._cells != 0
    heuristic[is_blocked] = 0
    return heuristic


def fill_true_dists(finish_node: Node, task_map: Map, cost_func: Optional[Callable] = None):
    """
    Exact shortest-path distance of every cell from the finish cell.

    Runs Dijkstra's algorithm (binary heap with lazy deletion) outward from
    `finish_node` over the graph given by `task_map.get_neighbors`. This replaces
    a FIFO label-correcting search that re-queued improved cells and relaxed
    stale queue entries; the resulting distances are the same, but each cell is
    now expanded at most once.

    Parameters
    ----------
    finish_node : Node
        Search origin; only `.i` and `.j` are read (its `level` is ignored and
        the origin always gets distance 0).
    task_map : Map
        Grid providing `get_size()` and `get_neighbors(i, j)`.
    cost_func : Callable | None
        Edge cost as cost_func(i1, j1, i2, j2). Defaults to `compute_cost`
        (Euclidean step length). Must be non-negative for Dijkstra to be exact.

    Returns
    -------
    node_levels : ndarray[float, ndim=2]
        Array shaped like the map: 0 at the finish cell, np.inf for cells
        that are never reached.
    """
    step_cost = compute_cost if cost_func is None else cost_func

    node_levels = np.full(task_map.get_size(), np.inf)
    source = (finish_node.i, finish_node.j)
    node_levels[source] = 0.0

    frontier = [(0.0, source)]
    while frontier:
        cur_level, (ci, cj) = heappop(frontier)
        # Lazy deletion: this entry was superseded by a shorter distance.
        if cur_level > node_levels[ci, cj]:
            continue
        for ni, nj in task_map.get_neighbors(ci, cj):
            new_level = cur_level + step_cost(ci, cj, ni, nj)
            if new_level < node_levels[ni, nj]:
                node_levels[ni, nj] = new_level
                heappush(frontier, (new_level, (ni, nj)))

    return node_levels

def fill_cf_values(finish_node: Node, task_map: Map, heuristic_func: Callable):
    """
    Per-cell ratio of heuristic estimate to true shortest-path distance.

    cf = heuristic(cell) / true_distance(cell), where true distances come from
    `fill_true_dists` and heuristic values from `fill_heuristic_values`.
    (Presumably the TransPath "correction factor" target — confirm.)

    Resulting values:
      * finish cell: 1 (set explicitly);
      * obstacle cells (any non-zero map value): 0;
      * free cells never reached from the finish: 0 (finite heuristic / inf).

    Parameters
    ----------
    finish_node : Node
        Goal cell (`.i`, `.j`).
    task_map : Map
        Grid to evaluate.
    heuristic_func : Callable
        Heuristic with signature heuristic_func(rows, cols, goal_i, goal_j).

    Returns
    -------
    cf_values : ndarray[float, ndim=2]
        Array shaped like the map.
    """
    true_dists = fill_true_dists(finish_node, task_map)
    # Only the finish cell has distance 0; use 1 there to avoid 0/0
    # (its value is overwritten with 1 below anyway).
    true_dists[true_dists == 0.0] = 1

    h_values = fill_heuristic_values(finish_node, task_map, heuristic_func)
    cf_values = h_values / true_dists

    # Same obstacle test as Map.traversable / fill_heuristic_values: non-zero means blocked.
    cf_values[task_map._cells != 0] = 0
    cf_values[finish_node.i, finish_node.j] = 1
    return cf_values

def euclidean_distance(x_arr, y_arr, goal_x, goal_y):
    """
    Straight-line (L2) distance from each (x, y) to the goal.

    Works element-wise on numpy arrays as well as on scalars.
    """
    return np.sqrt((x_arr - goal_x) ** 2 + (y_arr - goal_y) ** 2)


def manhattan_distance(x_arr, y_arr, goal_x, goal_y):
    """
    Sum of absolute coordinate differences (L1) from each (x, y) to the goal.

    Works element-wise on numpy arrays as well as on scalars. NOTE: with the
    diagonal moves of cost sqrt(2) used elsewhere in this notebook, this can
    overestimate the true distance, so its ratio to the true distance may exceed 1.
    """
    return np.abs(x_arr - goal_x) + np.abs(y_arr - goal_y)

def diagonal_distance(x_arr, y_arr, goal_x, goal_y):
    """
    Octile distance from each (x, y) to the goal.

    Cover the shorter axis offset with diagonal steps of cost sqrt(2), then the
    remaining offset with straight steps of cost 1. Works element-wise on numpy
    arrays as well as on scalars.
    """
    abs_dx, abs_dy = np.abs(x_arr - goal_x), np.abs(y_arr - goal_y)
    n_diagonal = np.minimum(abs_dx, abs_dy)
    n_straight = np.abs(abs_dx - abs_dy)
    return n_straight + np.sqrt(2) * n_diagonal

# CF map of the demo environment using the octile (diagonal) heuristic.
cf_map = fill_cf_values(goal, test_map, diagonal_distance)

# Fixed colour range [0, 1] so this plot is comparable with the dataset CF plots below.
fig, ax = plt.subplots()
cf_image = ax.imshow(cf_map, vmin=0, vmax=1)
fig.colorbar(cf_image, ax=ax, label="heuristic / true distance")
ax.set_title(f"CF map (octile heuristic), goal = ({goal.i}, {goal.j})")
plt.show()

In [None]:
# Spot-check the exact distances in the bottom-right 5x5 corner of the demo map.
dist_map = fill_true_dists(finish_node=goal, task_map=test_map)
dist_map[-5:, -5:]

# Loading maps

In [None]:
from pathlib import Path

# Root of the TransPath dataset on the shared archive; the validation split lives in val/.
load_dir = Path("/archive/savkin/raw_datasets/MIPT/TransPath_data")

load_path = load_dir / "val"

# Named `abs_values` rather than `abs` so the builtin abs() is not shadowed
# for the rest of the notebook.
abs_values = np.load(load_path / "abs.npy")
cf = np.load(load_path / "cf.npy")
focal = np.load(load_path / "focal.npy")
goals = np.load(load_path / "goals.npy")
maps = np.load(load_path / "maps.npy")
starts = np.load(load_path / "starts.npy")

# Side effect: (over)writes val/val.npz, storing the maps under the key "dem".
np.savez(load_path / "val.npz", dem=maps)

In [None]:
# Sanity-check the DEM archive. NOTE(review): this reads dems/val.npz, while the
# loading cell writes val/val.npz — confirm the file was copied here on purpose.
# np.load on an .npz keeps the file open until closed, hence the context manager.
with np.load("/archive/savkin/raw_datasets/MIPT/TransPath_data/dems/val.npz") as dems_npz:
    assert "dem" in dems_npz.files, f"unexpected keys in DEM archive: {dems_npz.files}"
tmp = np.load("/archive/savkin/raw_datasets/MIPT/TransPath_data/val/cf_pred.npy")

In [None]:
# tmp["dem"][0][0]

In [None]:
from tqdm import tqdm

def extract_node_pos(map):
    """
    Return the (row, col) of the first cell equal to 1 in a position mask.

    Cells are scanned in row-major order, so if several cells equal 1 the
    top-most (then left-most) one is returned.

    Parameters
    ----------
    map : ndarray
        Position mask, e.g. a one-hot start/goal layer. (The parameter name
        shadows the builtin `map` only inside this function; it is kept for
        backward compatibility with keyword callers.)

    Returns
    -------
    (int, int)
        Row and column of the first matching cell, as Python ints.

    Raises
    ------
    ValueError
        If no cell equals 1.
    """
    hits = np.argwhere(map == 1)
    if hits.size == 0:
        raise ValueError("position mask contains no cell equal to 1")
    first = hits[0]
    return (int(first[0]), int(first[1]))

# Number of validation samples to compare (kept small: each sample produces a figure).
N_SAMPLES = 10

cf_pred_arr = []
for i in tqdm(range(N_SAMPLES)):
    # `grid_map` rather than `map`, so the builtin map() is not shadowed notebook-wide.
    grid_map = Map(maps[i][0])
    cf_true = cf[i][0]

    goal_node = Node(*extract_node_pos(goals[i][0]))
    cf_pred = fill_cf_values(goal_node, grid_map, diagonal_distance)
    # Extra list level adds a leading channel axis to each sample in the stacked array.
    cf_pred_arr.append([cf_pred])

    # Dataset CF and locally computed CF side by side, on the same [0, 1] colour scale.
    fig, (ax_true, ax_pred) = plt.subplots(1, 2, figsize=(8, 4))
    for ax, values, title in ((ax_true, cf_true, "CF (dataset)"),
                              (ax_pred, cf_pred, "CF (computed)")):
        ax.imshow(values, vmin=0, vmax=1)
        ax.set_title(title)
        ax.axis("off")
    fig.suptitle(f"sample {i}")
    plt.show()
    plt.close(fig)

cf_pred_arr = np.array(cf_pred_arr)
cf_pred_arr.shape