## Basic packages

In [None]:
# try:
#     %load_ext autoreload
#     %autoreload 2
# except:
#     pass
# other package imports
import matplotlib.pyplot as plt  # keep this import for CI to work
from zanj import ZANJ  # saving/loading data
from muutils.mlutils import pprint_summary  # pretty printing as json
import argparse
import os
import sys
import tqdm
import numpy as np

# maze_dataset imports
from maze_dataset import LatticeMaze, SolvedMaze, MazeDataset, MazeDatasetConfig
from maze_dataset.generation import LatticeMazeGenerators, GENERATORS_MAP
from maze_dataset.generation.default_generators import DEFAULT_GENERATORS
from maze_dataset.dataset.configs import MAZE_DATASET_CONFIGS
from maze_dataset.plotting import plot_dataset_mazes, print_dataset_mazes

# check the configs
# Default on-disk location for saving/loading datasets.
LOCAL_DATA_PATH: str = "../data/maze_dataset/"
# Serializer used for saving/loading; `external_list_threshold` presumably
# controls when ZANJ stores a list outside the main file (see ZANJ docs).
zanj: ZANJ = ZANJ(external_list_threshold=256)

# Show the names of the preset dataset configurations.
print(MAZE_DATASET_CONFIGS.keys())

def str2bool(x):
    """Convert a command-line flag value to ``bool`` (for argparse ``type=``).

    Booleans pass through unchanged. Strings are matched case-insensitively on
    their first character: '0', 'n', 'f' -> False; '1', 'y', 't' -> True, so
    'no', 'false', 'yes', 'true', '0' and '1' all work.

    Raises:
        ValueError: for the empty string or an unrecognised leading character.
            argparse reports a ValueError from a ``type`` callable as a clean
            "invalid str2bool value" usage error.
    """
    if isinstance(x, bool):
        return x
    x = x.lower()
    # Indexing x[0] on "" would raise IndexError, which argparse does not
    # convert into a usage error; reject it explicitly instead.
    if not x:
        raise ValueError('Invalid value: {}'.format(x))
    if x[0] in ('0', 'n', 'f'):
        return False
    if x[0] in ('1', 'y', 't'):
        return True
    raise ValueError('Invalid value: {}'.format(x))


## Basic dataset configuration

In [2]:
parser = argparse.ArgumentParser(description="Generate a dataset of mazes")
parser.add_argument("--dataset_name", type=str, default="Maze_test", help="Name of the dataset")
parser.add_argument("--grid_n", type=int, default=10, help="Number of rows/columns in the lattice")
parser.add_argument("--n_mazes", type=int, default=1000, help="Number of mazes to generate")
parser.add_argument("--maze_ctor", type=str, default="gen_dfs", help="Algorithm to generate the maze")
parser.add_argument("--do_download", type=str2bool, default=False, help="Download the dataset")
parser.add_argument("--load_local", type=str2bool, default=False, help="Load the dataset locally")
parser.add_argument("--do_generate", type=str2bool, default=True, help="Generate the dataset")
parser.add_argument("--save_local", type=str2bool, default=True, help="Save the dataset locally")
parser.add_argument("--local_base_path", type=str, default=None, help="Base path for local storage")
parser.add_argument("--verbose", type=str2bool, default=True, help="Print information about the dataset")
parser.add_argument("--gen_parallel", type=str2bool, default=False, help="Generate the mazes in parallel")
parser.add_argument("--min_length", type=int, default=0, help="Minimum length of the maze")
parser.add_argument("--max_length", type=int, default=100, help="Maximum length of the maze")

# `get_ipython` is only defined inside IPython/Jupyter; in a plain Python
# script the lookup raises NameError.
try:
    _ipython = get_ipython()
except NameError:
    _ipython = None

if _ipython is not None:
    # The Jupyter kernel is started with its own arguments (e.g. `-f <file>`),
    # so sys.argv must not be parsed here: start from the defaults and
    # override them in code.
    _ipython.run_line_magic('matplotlib', 'inline')
    args = parser.parse_args([])
    args.dataset_name = "Maze-test"  # alternatives: "Maze-train", "Maze-val", "TEST"
    args.grid_n = 30
    args.n_mazes = 1000
    args.maze_ctor = "gen_dfs"
    args.do_download = False
    args.load_local = False
    args.do_generate = True
    args.save_local = False
    args.local_base_path = "./x/dataset/maze"
    args.verbose = True
    args.gen_parallel = True
    args.min_length = 5
    args.max_length = 20
    is_jupyter = True
else:
    # Plain script: parse the real command line.
    args = parser.parse_args()
    is_jupyter = False

# --local_base_path defaults to None, and concatenating onto None raises
# TypeError, so fall back to LOCAL_DATA_PATH in that case.
_base_path = args.local_base_path if args.local_base_path is not None else LOCAL_DATA_PATH
# Per-run directory name; the format is kept as-is so existing directories still match.
args.local_base_path = os.path.join(
    _base_path,
    f"{args.dataset_name}grid_n-{args.grid_n}_n_mazes-{args.n_mazes}"
    f"_min_length-{args.min_length}_max_length-{args.max_length}",
)
# Resolve --maze_ctor by name. GENERATORS_MAP (imported from
# maze_dataset.generation) presumably maps generator names such as the
# default "gen_dfs" to the corresponding LatticeMazeGenerators function;
# confirm against the maze_dataset docs.
if args.maze_ctor not in GENERATORS_MAP:
    raise ValueError(
        f"Unknown --maze_ctor {args.maze_ctor!r}; expected one of {sorted(GENERATORS_MAP)}"
    )

cfg: MazeDatasetConfig = MazeDatasetConfig(
    name=args.dataset_name,  # name of the dataset
    grid_n=args.grid_n,  # number of rows/columns in the lattice
    n_mazes=args.n_mazes,  # number of mazes to generate
    maze_ctor=GENERATORS_MAP[args.maze_ctor],  # algorithm to generate the maze
    # MazeDatasetConfig accepts a few more optional arguments, not set here
)

# Generate dataset from configs

In [None]:
# Build the dataset described by `cfg`. Every keyword below is optional;
# each one is taken directly from the parsed `args` (plus the shared `zanj`).
dataset: MazeDataset = MazeDataset.from_config(
    cfg,
    # acquisition
    do_download=args.do_download,
    load_local=args.load_local,
    do_generate=args.do_generate,
    gen_parallel=args.gen_parallel,
    # storage
    save_local=args.save_local,
    local_base_path=args.local_base_path,
    zanj=zanj,
    # logging
    verbose=args.verbose,
)

# plot some datapoints

In [None]:
# Preview only the first few mazes: on large datasets, `count` keeps the plot small.
plot_dataset_mazes(dataset, count=6)

# Filter the dataset by solution path length

In [None]:
# Presumably keeps only mazes whose solution length is >= args.min_length
# (see maze_dataset's `path_length` filter).
# NOTE(review): args.max_length is not applied as a filter anywhere in the
# code visible here; it only appears in the output directory name.
min_len = args.min_length
dataset_filtered: MazeDataset = dataset.filter_by.path_length(min_length=min_len)

# Sanity check: draw a single maze from the filtered dataset.
plot_dataset_mazes(dataset_filtered, count=1)

# Convert the original dataset into an image dataset
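Each image is an $H \times W \times 3$ array, where $H = W = 2N_{grid} + 1$. The three channels are map, goal, and path. Map channel: -1 means free space, 1 means obstacle. Goal channel: 0 marks the start point, 1 marks the goal point, and -1 is used everywhere else. Path channel: -1 means not on the path, 1 means on the solution path (including the pixels that connect consecutive path nodes). Each image is saved as an int8 `.npy` file.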

In [None]:
# Channel values of the exported image (see the description above).
WALL = 1  # map channel: obstacle
FREE = -1  # map channel: free space; also the fill value for every channel
REE = FREE  # backward-compatible alias for the misspelled original name
START = 0  # goal channel: start cell
END = 1  # goal channel: end cell
PATH_POINT = 1  # path channel: cell on the solution path

# Record the number of kept mazes in the output directory name.
args.local_base_path = args.local_base_path + 'N-' + str(len(dataset_filtered))
os.makedirs(args.local_base_path, exist_ok=True)
plot_dataset_mazes(
    dataset_filtered, count=1
)  # for large datasets, set the count to some int to just plot the first few


def _node_to_pixel(coord):
    """Return the (row, col) pixel of lattice node `coord`; nodes sit at odd pixel indices."""
    # int() so that small integer dtypes in `coord` cannot overflow on `* 2 + 1`.
    return int(coord[0]) * 2 + 1, int(coord[1]) * 2 + 1


def maze_to_pixel_grid(maze):
    """Encode one solved maze as an (H, W, 3) int8 array.

    Channel 0 (map): WALL on obstacle pixels, FREE elsewhere.
    Channel 1 (goal): START at the start cell, END at the end cell, FREE elsewhere.
    Channel 2 (path): PATH_POINT on every solution node and on the pixel that
        joins each pair of consecutive nodes, FREE elsewhere.

    Args:
        maze: solved maze exposing `_as_pixels_bw()`, `start_pos`, `end_pos`
            and `solution` (a sequence of (row, col) lattice nodes).

    Returns:
        numpy.ndarray: int8 array of shape (*maze._as_pixels_bw().shape, 3).
    """
    pixel_grid_bw = maze._as_pixels_bw()
    pixel_grid = np.full((*pixel_grid_bw.shape, 3), FREE, dtype=np.int8)

    start_px = _node_to_pixel(maze.start_pos)
    end_px = _node_to_pixel(maze.end_pos)

    # Map channel. Which boolean `_as_pixels_bw` uses for open cells is not
    # established in this file (maze_dataset's own `as_pixels` appears to
    # treat True as open). The start cell is always traversable, so use its
    # value as the "open" reference and mark every other value as a wall.
    open_value = pixel_grid_bw[start_px]
    pixel_grid[pixel_grid_bw != open_value, 0] = WALL

    # Goal channel.
    pixel_grid[start_px[0], start_px[1], 1] = START
    pixel_grid[end_px[0], end_px[1], 1] = END

    # Path channel: the solution nodes themselves ...
    solution = maze.solution
    for coord in solution:
        row, col = _node_to_pixel(coord)
        pixel_grid[row, col, 2] = PATH_POINT
    # ... and the pixel between each pair of consecutive nodes.
    for coord, next_coord in zip(solution[:-1], solution[1:]):
        assert (
            np.linalg.norm(np.array(coord) - np.array(next_coord)) == 1
        ), f"Coords {coord} and {next_coord} are not adjacent"
        # Per axis, the joining pixel is (2c + 1) + (n - c) = c + n + 1.
        pixel_grid[
            int(coord[0]) + int(next_coord[0]) + 1,
            int(coord[1]) + int(next_coord[1]) + 1,
            2,
        ] = PATH_POINT
    return pixel_grid


pixel_grid = None  # stays None when dataset_filtered is empty
for i in range(len(dataset_filtered)):
    pixel_grid = maze_to_pixel_grid(dataset_filtered[i])
    np.save(os.path.join(args.local_base_path, f"maze_solved-{i}.npy"), pixel_grid)

if pixel_grid is not None:
    # Preview the last exported grid. imshow clips float RGB data to [0, 1],
    # so map the values {-1, 0, 1} to {0, 0.5, 1} to keep them distinguishable.
    img_show = (pixel_grid + 1.0) / 2.0
    plt.imshow(img_show)

print(f"Done! {len(dataset_filtered)} datapoints saved to {args.local_base_path}")

In [None]:
def plot_maze_from_pixel_grid(pixel_grid, save_path=None):
    """
    Render an (H, W, 3) maze pixel grid as a colour image, then save or show it.

    Colours, in drawing order (later layers paint over earlier ones):
    - white: background
    - black: wall (channel 0 == 1)
    - blue: solution path (channel 2 == 1)
    - green: start point (channel 1 == 0)
    - purple: end point (channel 1 == 1)

    Start and end are drawn after the path. This keeps them visible even when
    the path channel also marks those cells.

    Args:
        pixel_grid (numpy.ndarray): maze grid of shape [H, W, 3]. Channel 0 is
            the map, channel 1 the goal, channel 2 the path.
        save_path (str, optional): file to write the image to. If None, the
            plot is displayed instead.
    """
    pixel_grid = np.asarray(pixel_grid)
    H, W, _ = pixel_grid.shape

    # Float RGB image initialised to white.
    maze_rgb = np.ones((H, W, 3), dtype=np.float32)
    maze_rgb[pixel_grid[:, :, 0] == 1] = [0, 0, 0]  # walls -> black
    maze_rgb[pixel_grid[:, :, 2] == 1] = [0, 0, 1]  # path -> blue
    maze_rgb[pixel_grid[:, :, 1] == 0] = [0, 1, 0]  # start -> green
    maze_rgb[pixel_grid[:, :, 1] == 1] = [1, 0, 1]  # end -> purple

    # matplotlib's figsize is (width, height), i.e. (W, H) for an H x W image.
    plt.figure(figsize=(W / 10, H / 10))
    plt.imshow(maze_rgb)
    plt.axis('off')  # Hide axes

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Maze saved to {save_path}")
    else:
        plt.show()  # Show the maze if no save path is provided

    plt.close()

# Render the last exported (H, W, 3) grid. plot_dataset_mazes expects a
# MazeDataset, not a numpy array, so use the pixel-grid renderer instead.
plot_maze_from_pixel_grid(pixel_grid)

In [None]:
# Quick RGB preview of the last exported grid: walls black, everything else white.
pixel_grid_wall = pixel_grid[..., 0] == 1
pixel_grid_start = pixel_grid[..., 1] == 0
pixel_grid_end = pixel_grid[..., 1] == 1
pixel_grid_path = pixel_grid[..., 2] == 1
# Use uint8 explicitly: pixel_grid is int8, and int8 cannot hold 255
# (np.ones_like(pixel_grid) * 255 raises OverflowError under NumPy >= 2).
img_show = np.full(pixel_grid.shape, 255, dtype=np.uint8)
img_show[pixel_grid_wall] = [0, 0, 0]
# img_show[pixel_grid_start] = [0,255,0]
# img_show[pixel_grid_end] = [0,0,255]
# img_show[pixel_grid_path & ~pixel_grid_start & ~pixel_grid_end] = [0,0,122]
plt.imshow(img_show)