In [18]:
from PIL import Image, ImageDraw
import numpy as np
import matplotlib.pyplot as plt
import math
%matplotlib inline


EPS = 1e-6

In [19]:
class GridState:
    def __init__(self, i, j):
        self.i = i
        self.j = j
        
    def __str__(self):
        return f'{self.i}_{self.j}'
    
    def __repr__(self):
        return str(self)

In [20]:
class Map:

    # Default constructor
    def __init__(self, imagePath=None):
        if imagePath is not None:
            self.cells = np.abs(1 - cv2.imread(imagePath, cv2.IMREAD_GRAYSCALE) // 255)
        else:
            self.cells = []
    
    def SetGridCells(self, gridCells):
        self.cells = gridCells

    # Checks cell is on grid.
    def inBounds(self, i, j):
        height, width = self.cells.shape
        return (0 <= j < width) and (0 <= i < height)
    
    # Checks cell is not obstacle.
    def Traversable(self, i, j):
        return not self.cells[i][j]

    # Creates a list of neighbour cells as (i,j) tuples.
    def GetNeighbors(self, state):
        i = state.i
        j = state.j
        # TODO Change the function so that the list includes the diagonal neighbors of the cell.
        # Cutting corners must be prohibited
        neighbors = []
        delta = [[0, 1], [1, 0], [0, -1], [-1, 0], [-1, -1], [-1, 1], [1, -1], [1, 1]]
        for d in delta:
            next_i = i + d[0]
            next_j = j + d[1]
            if self.inBounds(next_i, next_j) and self.Traversable(next_i, next_j) and self.Traversable(i+d[0], j) and self.Traversable(i, j+d[1]):
                neighbors.append(GridState(next_i, next_j))

        return neighbors

In [21]:
class Node:
    def __init__(self, state, g = math.inf, h = 0, f = None, parent = None, k = 0, depth=0):
        self.state = state
        self.g = g
        self.h = h
        self.k = k
        self.depth = depth
        if f is None:
            self.f = self.g + self.h
        else:
            self.f = f        
        self.parent = parent
    
    def __eq__(self, other):
        return (self.state.i == other.state.i) and (self.state.j == other.state.j)
    
    def __str__(self):
        return f'Node(i={self.state.i}, j={self.state.j}, g={self.g}, h={self.h})'
    
    def __repr__(self):
        return str(self)

    def __lt__(self, other):
        return self.f < other.f or (self.f == other.f and self.h < other.h)\
        or ((self.f == other.f) and (self.h == other.h) and (self.k < other.k))

In [22]:
from heapq import heapify, heappop, heappush

class Open():
    def __init__(self):
        self.pr_queue = []
        self.dict = {}
        
    def __iter__(self):
        return iter(self.dict.values())

    def __len__(self):
        return len(self.dict)

    def isEmpty(self):
        return len(self.dict) == 0
    
    def AddNode(self, item):
        node = self.dict.get(str(item.state), None)
        if node is None or item.g < node.g:
            self.dict[str(item.state)] = item
            heappush(self.pr_queue, item)
            
    def GetBestNode(self, pop=True):
        bestNode = self.pr_queue[0]
        
        while str(bestNode.state) not in self.dict:
            heappop(self.pr_queue)
            bestNode = self.pr_queue[0]
        
        if pop:
            self.dict.pop(str(bestNode.state))
            heappop(self.pr_queue)
        
        return bestNode


class Closed ():
    def __init__(self):
        self.elements = {}
        
    def __iter__(self):
        return iter(self.elements.values())
    
    def __len__(self):
        return len(self.elements)
        
    def AddNode(self, item, *args):
        self.elements[str(item.state)] = item
        
    def WasExpanded(self, item):
        return str(item.state) in self.elements
    
    def GetNode(self, item):
        return self.elements.get(str(item.state))
    
    def RemoveNode(self, item):
        self.elements.pop(str(item.state), None)

In [23]:
def ManhattanDistance(state1, state2):
    dx = abs(state1.j - state2.j)
    dy = abs(state1.i - state2.i)
    return abs(dx - dy)


def EuclideanDistance(state1,state2):
    return math.sqrt((state1.i - state2.i) ** 2 + (state1.j - state2.j) ** 2)


def DiagonalDistance(state1, state2):
    c_hv = 1
    c_d = math.sqrt(2)
    
    dy = abs(state1.i - state2.i)
    dx = abs(state1.j - state2.j)
    
    return c_hv * abs(dx-dy) + c_d * min(dx,dy)


def radCost(state1, state2):
    dy = abs(state1.i - state2.i)
    dx = abs(state1.j - state2.j)
    return max(dx, dy)

In [24]:
def MakePath(goal):
    length = goal.g
    current = goal
    path = []
    while current.parent:
        path.append(current)
        current = current.parent
    path.append(current)
    return path[::-1], length

In [25]:
def AStar(gridMap, startState, goalState, calcCost = EuclideanDistance,
          heuristicFunction = DiagonalDistance, reExpansion=False, weight=1):
    OPEN = Open()
    CLOSED = Closed()
    
    start = Node(startState, 0, 0)
    goal = Node(goalState) if goalState is not None else None
    pathFound = False
    
    OPEN.AddNode(start)
    
    while len(OPEN):
        best_node = OPEN.GetBestNode()
        CLOSED.AddNode(best_node)
        
        if goal is not None and best_node == goal:
            pathFound = True
            goal = best_node
            break
             
        for state in gridMap.GetNeighbors(best_node.state):
            neighbor = Node(state, parent=best_node, k=len(CLOSED))
            neighbor.g = best_node.g + calcCost(best_node.state, neighbor.state)
            if abs(weight) > EPS and goal is not None:
                neighbor.h = heuristicFunction(neighbor.state, goal.state)
                
            neighbor.f = neighbor.g + neighbor.h * weight
            
            node = CLOSED.GetNode(neighbor)
            
            if node is None:
                OPEN.AddNode(neighbor)
    
    return pathFound, goal, CLOSED, OPEN

In [122]:
class Oracle:
    def __init__(self, map_):
        self.map = map_
        self.q = None
        
    def calculateQ(self, goal):
        startState = goal
        goalState = None
        self.q = AStar(self.map, startState, goalState, weight=0)[2].elements
        
        h, w = self.map.cells.shape
        
        for i in range(h):
            for j in range(w):
                state = GridState(i, j) 
                if str(state) not in self.q:
                    self.q[str(state)] = Node(state, g = EuclideanDistance(GridState(0,0), GridState(h,w))+1)
#         self.q = sorted(self.q.values(), key=lambda item: item.g)

        
    def getQ(self, state):
        if str(state) in self.q:
            return self.q[str(state)].g
        return None
    
    def policy(self, open_):
        isOpen = sorted([node for node in self.q.values() if str(node.state) in open_], 
                        key=lambda node: node.g)
        
        return isOpen[0]
        
#         best_idx = 0
#         while best_idx < len(self.q) and str(self.q[best_idx].state) not in open_:
#             print(str(self.q[best_idx].state))
#             best_idx += 1
            
#         return self.q[best_idx]

In [158]:
import torch.nn as nn
import torch
import torch.nn.functional as F

class QNet(nn.Module):
    def __init__(self, input_dim=17, hid_dims=(100, 50), output_dim=1):
        super().__init__()
        self.q = nn.Sequential(
            nn.Linear(input_dim, hid_dims[0]),
            nn.ReLU(),
            nn.Linear(hid_dims[0], hid_dims[1]),
            nn.ReLU(),
            nn.Linear(hid_dims[1], output_dim)
        )          
        
    def forward(self, x):
        return self.q(x)

In [168]:
import os
import cv2

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

class Trainer:
    def __init__(self):
        self.oracle = None
        self.start = None
        self.goal = None
        self.map = None
        self.agent = QNet().to(device)
        self.open = {}
        self.closed = {}
        self.invalid = {}
        
        self.optimizer = torch.optim.Adam(self.agent.parameters(), lr=1e-3)
        
    def sampleMap(self):
        datasetPath = './motion_planning_datasets/'
        idx = np.random.randint(len(os.listdir(datasetPath)))
        el = os.listdir(datasetPath)[idx]
        imagePath = os.path.join(datasetPath, el, 'train', f'{np.random.randint(800)}.png')     
        
        self.map = Map(imagePath)
        self.oracle = Oracle(self.map)
        
    def sampleTargetPoints(self):
        startState = None
        goalState = None

        while startState is None or goalState is None:
            s_i = np.random.randint(self.map.cells.shape[0])
            s_j = np.random.randint(self.map.cells.shape[1])  

            if self.map.Traversable(s_i, s_j):
                startState = GridState(s_i, s_j)

            g_i = np.random.randint(self.map.cells.shape[0])                
            g_j = np.random.randint(self.map.cells.shape[1])     

            if self.map.Traversable(g_i, g_j):
                goalState = GridState(g_i, g_j)

        self.start = startState
        self.goal = goalState
        
        self.oracle.calculateQ(self.goal)
        
        
    def extractFeatures(self, node):
        features = []
        
        features.extend([node.state.i, node.state.j])
        features.extend([self.goal.i, self.goal.j])
        features.append(node.g)
        features.append(EuclideanDistance(node.state, self.goal))
        features.append(ManhattanDistance(node.state, self.goal))
        features.append(node.depth)
        
        minDist = (EuclideanDistance(GridState(0,0), GridState(self.map.cells.shape[0], self.map.cells.shape[1]))+1, 
                   GridState(self.map.cells.shape[0]+1, self.map.cells.shape[1]+1))
        
        minDistX = (EuclideanDistance(GridState(0,0), GridState(self.map.cells.shape[0], self.map.cells.shape[1]))+1, 
                   GridState(self.map.cells.shape[0]+1, self.map.cells.shape[1]+1))
        
        minDistY = (EuclideanDistance(GridState(0,0), GridState(self.map.cells.shape[0], self.map.cells.shape[1]))+1, 
                   GridState(self.map.cells.shape[0]+1, self.map.cells.shape[1]+1))
        
        for obs in self.invalid.values():
            dist = EuclideanDistance(node.state, obs.state)
            distX = abs(node.state.j - obs.state.j)
            distY = abs(node.state.i - obs.state.i)
            
            if dist < minDist[0]:
                minDist[0] = dist
                minDist[1] = obs.state
                
            if distX < minDistX[0]:
                minDistX[0] = distX
                minDistX[1] = obs.state
                
            if distY < minDistY[0]:
                minDistY[0] = distY
                minDistY[1] = obs.state
                
        features.extend([minDist[0], minDist[1].i, minDist[1].j])
        features.extend([minDistX[0], minDistX[1].i, minDistX[1].j])                    
        features.extend([minDistY[0], minDistY[1].i, minDistY[1].j])
                    
        return np.array(features)
           
    def mixPolicy(self, beta):
        if np.random.random(1) < beta:
            res_node = self.oracle.policy(self.open)
        else:
            res = []
            for node in self.open.values():
                x = self.extractFeatures(node)
                with torch.no_grad():
                    res.append((node, self.agent(x).item()))
                
            res.sort(key=lambda item: item[1])
            res_node = res[0]
        
        self.open.pop(str(res_node.state), None)
        return res_node
    
    def expand(self, node):
        for child in self.map.GetNeighbors(node.state):
            if not self.map.Traversable(child.i, child.j):
                self.invalid.add(str(child))
                
            if str(child) not in self.closed:
                childNode = Node(child, parent=node, depth=node.depth+1, 
                                 g=node.g+EuclideanDistance(node.state, child))
                self.open[str(child)] = childNode
    
    def update(self, D):
        inputs, truths = D
        
        inputs = torch.FloatTensor(np.array(inputs)).to(device)
        truths = torch.FloatTensor(np.array(truths)).to(device)
        
        preds = self.agent(inputs).squeeze()
        
        loss = F.mse_loss(preds, truths)
        
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()
        
    
    def train(self, N, m, k, T):
        D = [[], []]
        for i in range(N):
            for j in range(m):
                self.sampleMap()
                self.sampleTargetPoints()
                
                timesteps = set(np.random.choice(range(T), size=k, replace=False))
                self.open[str(self.start)] = Node(self.start, g=0, depth=1)
                t = 0
                while t < T and len(self.open) > 0:
                    node = self.mixPolicy(1)
                    self.expand(node)
                    if t in timesteps:
                        D[0].append(self.extractFeatures(node))
                        D[1].append(self.oracle.getQ(node.state))
                    
                    t += 1
            loss = self.update(D)
            print(loss)

In [172]:
trainer = Trainer()

In [173]:
trainer.train(10, 5, 20, 100)

665.579345703125
261.5365295410156
289.1980285644531
285.23358154296875
217.87440490722656
141.48098754882812
93.59748840332031
89.37812805175781
110.33057403564453
121.92019653320312
