In [None]:
import os
from typing import Tuple

import torch
from torch import Tensor

from matplotlib import pyplot as plt
from Agents.A2CAgent import A2CAgent
import yaml
from TrainUtils import *
from torch.utils.tensorboard import SummaryWriter

from Models.ModelUtils import *
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
import cv2

# Dataset

In [None]:
class MemoryDataset(Dataset):
    """
    Dataset of saved vehicle-state memories, one sample per file.

    dataset directory structure:
    YYYYMMDD-HHMMSS
        - 1
            - 00000.pt
            - 00001.pt
            - ...
        - 2
        - ...

    Each item is returned as ``(lidar_map, safety_map, grid_world)``.
    """
    # Chosen once at class-definition time.
    # NOTE(review): not referenced inside this class; kept for external readers.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def __init__(self, memory_dirs: List[str]):
        """Index every memory file found under ``memory_dirs``.

        Args:
            memory_dirs: run directories laid out as in the class docstring.

        Directory entries are visited in sorted order so that a given index
        always maps to the same file, independent of the filesystem's listing
        order. Non-directory entries at the checkpoint level and non-file
        entries at the memory-file level are skipped instead of crashing.
        """
        self.memory_dirs = memory_dirs

        self.file_paths = []
        for memory_dir in memory_dirs:
            for checkpoint_dir in sorted(os.listdir(memory_dir)):
                checkpoint_path = os.path.join(memory_dir, checkpoint_dir)
                if not os.path.isdir(checkpoint_path):
                    continue
                for memory_file in sorted(os.listdir(checkpoint_path)):
                    memory_path = os.path.join(checkpoint_path, memory_file)
                    if os.path.isfile(memory_path):
                        self.file_paths.append(memory_path)

    def __len__(self):
        return len(self.file_paths)

    def getSaftyMap(self, lidar_map: Tensor) -> Tensor:
        """Build a 5-channel safety map from a lidar map.

        Args:
            lidar_map: (B, 3, H, W) tensor. Channel 0 marks the target, channel 1
                the start, channel 2 the obstacle intensity.

        Returns:
            (B, 5, H, W) float32 tensor on the same device:
                0 (blue)  - 1 where the dilated obstacle value is in [0.1, 0.4): not suggested
                1 (green) - 1 where it is < 0.1: absolutely safe
                2 (red)   - 1 where it is >= 0.4: dangerous
                3         - copy of the target channel
                4         - copy of the start channel
            Every pixel belongs to exactly one of channels 0-2.
        """
        batch, _, height, width = lidar_map.shape

        blur_map = lidar_map[:, 2:3]  # (B, 1, H, W)
        # Drop weak returns, then dilate obstacles with two 5x5 max-pools (9x9 in total).
        blur_map = blur_map * (blur_map > 0.2)
        blur_map = torch.nn.functional.max_pool2d(blur_map, kernel_size=5, stride=1, padding=2)
        blur_map = torch.nn.functional.max_pool2d(blur_map, kernel_size=5, stride=1, padding=2)
        blur = blur_map[:, 0]  # (B, H, W)

        safety_map = torch.zeros(batch, 5, height, width, dtype=torch.float32, device=lidar_map.device)
        safety_map[:, 1] = (blur < 0.1).float()  # green means absolutely safe
        safety_map[:, 0] = ((blur >= 0.1) & (blur < 0.4)).float()  # blue means not suggested
        safety_map[:, 2] = (blur >= 0.4).float()  # red means dangerous
        safety_map[:, 3] = lidar_map[:, 0]  # channel 3 means the target
        safety_map[:, 4] = lidar_map[:, 1]  # channel 4 means the start

        return safety_map

    # Correctly spelled alias; the original (misspelled) name stays for existing callers.
    getSafetyMap = getSaftyMap

    def getGridWorld(self, safty_map: Tensor) -> Tensor:
        """Downsample a safety map into a coarse grid world for path finding.

        Args:
            safty_map: (B, 5, H, W) tensor as produced by ``getSaftyMap``.

        Returns:
            (B, 3, G, G') float32 tensor (32x32 for a 127x127 input):
                0 - goal attractiveness: 1 at the target cell, decreasing linearly
                    with Euclidean distance to 0 at the farthest cell
                1 - start: a single 1 at the grid centre
                2 - obstacles: max-pooled red (dangerous) channel
        """
        cell, pad = 4, 2  # pooling kernel/stride and padding
        grid_world = torch.nn.functional.max_pool2d(
            safty_map[:, 2:3], kernel_size=cell, stride=cell, padding=pad
        )[:, 0]  # (B, G, G')
        batch, grid_h, grid_w = grid_world.shape
        width = safty_map.shape[-1]
        device = grid_world.device

        # Pixel p lands in pooled cell (p + pad) // cell. The clamp handles the last
        # pixel row/column, which the padded pooling windows do not cover.
        target_loc = torch.argmax(safty_map[:, 3].flatten(1), dim=1)  # (B,)
        target_rows = ((target_loc // width + pad) // cell).clamp(max=grid_h - 1)
        target_cols = ((target_loc % width + pad) // cell).clamp(max=grid_w - 1)

        # Distance from every cell to the target, per batch element.
        rows = torch.arange(grid_h, dtype=torch.float32, device=device).view(1, grid_h, 1)
        cols = torch.arange(grid_w, dtype=torch.float32, device=device).view(1, 1, grid_w)
        v_dist_map = (rows - target_rows.view(-1, 1, 1)).abs()
        h_dist_map = (cols - target_cols.view(-1, 1, 1)).abs()
        dist_map = torch.sqrt(h_dist_map ** 2 + v_dist_map ** 2)  # (B, G, G')
        end = 1 - dist_map / dist_map.flatten(1).max(dim=1).values.view(-1, 1, 1)

        start = torch.zeros((batch, grid_h, grid_w), dtype=torch.float32, device=device)
        start[:, grid_h // 2, grid_w // 2] = 1
        return torch.stack([end, start, grid_world], dim=1)

    def __getitem__(self, idx) -> Tuple[Tensor, Tensor, Tensor]:
        """Load sample ``idx`` and return ``(lidar_map, safety_map, grid_world)``.

        Shapes: lidar_map (1, 3, 127, 127) (as produced by ``state.getTensor()``;
        assumed), safety_map (1, 5, 127, 127), grid_world (1, 3, 32, 32).
        """
        # SECURITY: torch.load unpickles arbitrary objects -- only load trusted files.
        # NOTE(review): recent PyTorch versions default to weights_only=True, which may
        # refuse a pickled VehicleState; weights_only=False may be required -- confirm.
        state: VehicleState = torch.load(self.file_paths[idx])[0]

        lidar_map, _ = state.getTensor()    # lidar_map: (1, 3, 127, 127)
        safety_map = self.getSaftyMap(lidar_map)  # (1, 5, 127, 127)
        grid_world = self.getGridWorld(safety_map)  # (1, 3, 32, 32)

        return lidar_map, safety_map, grid_world


def collectFunc(batch: List[Tuple[Tensor, ...]]) -> Tuple[Tensor, ...]:
    """Collate dataset items that already carry a leading batch dimension of 1.

    Each item is a tuple such as ``(lidar_map, safety_map, grid_world)`` whose
    tensors are shaped (1, ...). A single item is returned unchanged. Several
    items are merged field-by-field along dim 0; previously every item after
    the first was silently dropped.

    Raises:
        ValueError: if ``batch`` is empty.
    """
    if not batch:
        raise ValueError("collectFunc received an empty batch")
    if len(batch) == 1:
        return batch[0]
    return tuple(torch.cat(fields, dim=0) for fields in zip(*batch))

# A*-style Greedy Path Search

In [None]:
def getPathAStar(grid_world, obstacle_map, start_pos, end_pos, value_map, visualize=True, max_steps=None):
    """Greedy best-neighbour walk from ``start_pos`` to ``end_pos`` on a 2-D grid.

    Despite the name this is not a textbook A* search. At every step it moves to
    the free 4-connected neighbour with the highest ``value_map`` entry. It then
    lowers the value of the cell it just left by 1/16, so repeatedly visited
    cells become less attractive and the walk can escape local maxima.

    Args:
        grid_world: only used for the debug window when ``visualize`` is True
            (indexed as ``grid_world[0]`` and permuted to HWC).
        obstacle_map: (H, W) tensor; a cell is walkable when its value is 0.
        start_pos: [row, col] of the start cell.
        end_pos: [row, col] of the goal cell.
        value_map: (H, W) tensor of cell attractiveness. It is modified IN PLACE
            (visited cells are decremented); if it is a view of ``grid_world``,
            ``grid_world`` changes too.
        visualize: if True (default, the previous behaviour), show the grid and the
            partial path with OpenCV after every step and block on ``cv2.waitKey(0)``.
        max_steps: optional cap on the number of moves; ``None`` means unlimited.

    Returns:
        List of [row, col] cells from start to goal, both inclusive.

    Raises:
        RuntimeError: if the current cell has no walkable neighbour, or the goal
            is not reached within ``max_steps`` moves.
    """
    n_rows, n_cols = obstacle_map.shape[0], obstacle_map.shape[1]
    goal = list(end_pos)
    path = [list(start_pos)]

    # Debug canvas: 0.2 marks the goal, 0.5 visited cells, 1 the current cell.
    path_map = torch.zeros_like(obstacle_map)
    path_map[goal[0], goal[1]] = 0.2
    path_map[path[0][0], path[0][1]] = 1

    steps = 0
    while path[-1] != goal:
        if max_steps is not None and steps >= max_steps:
            raise RuntimeError(f"goal {goal} not reached within {max_steps} steps")
        current_r, current_c = path[-1]
        best_neighbor = None
        # -inf rather than -1: penalised cells can drop below -1 and must stay selectable.
        best_value = float("-inf")
        for dr, dc in ((0, 1), (0, -1), (1, 0), (-1, 0)):
            neighbor_r = current_r + dr
            neighbor_c = current_c + dc
            if 0 <= neighbor_r < n_rows and 0 <= neighbor_c < n_cols and obstacle_map[neighbor_r, neighbor_c] == 0:
                value = float(value_map[neighbor_r, neighbor_c])
                if value > best_value:
                    best_value = value
                    best_neighbor = [neighbor_r, neighbor_c]
        if best_neighbor is None:
            raise RuntimeError(f"no walkable neighbour around cell {[current_r, current_c]}")

        path.append(best_neighbor)
        path_map[current_r, current_c] = 0.5
        path_map[best_neighbor[0], best_neighbor[1]] = 1
        # Penalise the cell we just left so the walk does not oscillate forever.
        value_map[current_r, current_c] -= 1 / 16
        steps += 1

        if visualize:
            grid_world_np = grid_world[0].permute(1, 2, 0).detach().cpu().numpy()
            cv2.imshow("grid_world", cv2.resize(grid_world_np, (512, 512), interpolation=cv2.INTER_NEAREST))
            cv2.imshow("path_map", cv2.resize(path_map.cpu().numpy(), (512, 512), interpolation=cv2.INTER_NEAREST))
            cv2.waitKey(0)

    return path

# Smooth Path

In [None]:
def smoothPath(path, obstacle_map):
    """Shorten a 4-connected grid path by replacing detours with staircase shortcuts.

    For each path point ``path[i]``, every later point ``path[j]`` (j >= i + 2)
    is tried. A shortcut exists when a monotone 4-connected "staircase" walk from
    ``path[i]`` to ``path[j]`` crosses only free cells (``obstacle_map == 0``);
    each step tries the row move first, then the column move. A shortcut is used
    only if it is strictly shorter than the sub-path it replaces. Among valid
    shortcuts, the furthest ``j`` wins. The segment ``path[i..j]`` is then
    replaced by the shortcut (endpoints included once), and the scan continues
    from the shortcut's end.

    Args:
        path: list of [row, col] points.
        obstacle_map: 2-D tensor/array; cells equal to 0 are walkable.

    Returns:
        The smoothed path as a new list; ``path`` itself is not modified. A path
        with fewer than 3 points is returned as-is.
    """
    if len(path) < 3:
        return path

    def _sign(x):
        # Plain-int sign, so no numpy scalars leak into the path.
        return (x > 0) - (x < 0)

    def _staircase(start, end):
        # Monotone walk from start to end over free cells, or None if it gets blocked.
        start_r, start_c = start
        end_r, end_c = end
        step_r, step_c = _sign(end_r - start_r), _sign(end_c - start_c)
        linkage = [[start_r, start_c]]
        while linkage[-1] != [end_r, end_c]:
            moved = False
            r, c = linkage[-1]
            if r != end_r and obstacle_map[r + step_r, c] == 0:
                linkage.append([r + step_r, c])
                moved = True
            r, c = linkage[-1]
            if c != end_c and obstacle_map[r, c + step_c] == 0:
                linkage.append([r, c + step_c])
                moved = True
            if not moved:
                return None
        return linkage

    i = 0
    while i < len(path) - 2:
        ri, ci = path[i]
        best_j, best_linkage = -1, None
        for j in range(i + 2, len(path)):
            rj, cj = path[j]
            l1_dist = abs(rj - ri) + abs(cj - ci)
            if j - i <= l1_dist:
                continue  # a staircase could not be shorter than the existing sub-path
            linkage = _staircase(path[i], path[j])
            if linkage is not None:
                best_j, best_linkage = j, linkage

        if best_linkage is None:
            i += 1
        else:
            # The linkage already contains path[i] and path[best_j], so exclude them from the kept slices.
            path = path[:i] + best_linkage + path[best_j + 1:]
            # Continue from the shortcut's endpoint, indexed in the NEW path.
            # If the linkage is a single point (a loop back to path[i]), i stays put but the path got shorter.
            i += len(best_linkage) - 1

    return path

# Path finder
Combine the functions above: greedy path search followed by path smoothing.

In [None]:
def findPathTraditional(grid_world):
    """Run the greedy path search and path smoothing on one grid world, then display it.

    Args:
        grid_world: (1, 3, H, W) tensor (32x32 for MemoryDataset output).
            channel 0: end -- goal attractiveness; its argmax is taken as the goal
            channel 1: start -- its argmax is taken as the start
            channel 2: obstacle -- 0 means walkable

    Returns:
        List of [row, col] points in the PADDED grid (one extra cell on every
        side), from start to goal.

    Side effects: opens OpenCV windows and blocks on ``cv2.waitKey(0)``.
    """
    # Pad every channel with a 1-cell border of zeros: the border is walkable
    # (obstacle 0) and unattractive (value 0).
    grid_world = torch.nn.functional.pad(grid_world, (1, 1, 1, 1), mode="constant", value=0)
    width = grid_world.shape[-1]

    # run path finding algorithm
    obstacle_map = grid_world[0, 2]  # (H+2, W+2)
    value_map = grid_world[0, 0]     # (H+2, W+2); a view, so in-place writes also change grid_world

    # Locate start/goal from their channels instead of hard-coding coordinates;
    # argmax avoids crashing when the goal channel has zero or several exact 1s.
    start_pos = list(divmod(int(torch.argmax(grid_world[0, 1])), width))
    end_pos = list(divmod(int(torch.argmax(value_map)), width))

    path = getPathAStar(grid_world, obstacle_map, start_pos, end_pos, value_map)

    path = smoothPath(path, obstacle_map)

    # Draw the final path into the green channel for display.
    grid_world_np = grid_world[0].permute(1, 2, 0).detach().cpu().numpy()
    for r, c in path:
        grid_world_np[r, c, 1] = 1
    cv2.imshow("grid_world", cv2.resize(grid_world_np, (512, 512), interpolation=cv2.INTER_NEAREST))
    cv2.waitKey(0)

    return path

# Main

# Load the config, build the dataset and run the traditional path finder on one sample.
# (The original cell mixed 0/4/8-space indentation and raised IndentationError.)
with open("config_with_mem.yaml", "r") as in_file:
    configs = yaml.load(in_file, Loader=yaml.FullLoader)

dataset = MemoryDataset(configs["data_folders"])

SAMPLE_INDEX = 600  # sample to inspect; magic number kept from the original script
lidar_map, safety_map, grid_world = dataset[SAMPLE_INDEX]
cv2.imshow("lidar", lidar_map[0].cpu().permute(1, 2, 0).numpy())
cv2.imshow("safety", safety_map[0, :3].cpu().permute(1, 2, 0).numpy())
cv2.waitKey(0)

findPathTraditional(grid_world)