In [26]:
from collections import deque
from enum import Enum
import torch
import random
import tqdm
from torch import nn
import torch.nn.functional as F

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


## Robot class

Given a maze matrix, the robot will start from the start state and go to the target state.

Call `.train()` to train the robot, then read `.q_matrix` to get the learned Q matrix.

In [27]:

class Robot:
    """
    Tabular Q-learning agent that walks a grid maze from a start cell to a target cell.

    States are ``(row, col)`` cells of ``maze_matrix`` (``True`` marks a wall) and every
    action is a ``(d_row, d_col)`` offset added to the current state.  ``q_matrix`` has
    shape ``(n_action, height, width)`` and is updated in place during training.

    The maze is assumed to be enclosed by walls: neighbouring cells are indexed without
    a bounds check, so an open cell on the border would wrap around or raise.
    """
    def __init__(self, start_state: tuple[int,int], target_state: tuple[int,int], maze_matrix: torch.Tensor, 
                 actions: torch.Tensor = torch.tensor([[0, -1], [0, 1], [-1, 0], [1, 0]], dtype=torch.int8),
                 reward_matrix: torch.Tensor = None, restrict_matrix: torch.Tensor = None, q_matrix: torch.Tensor = None, 
                 max_step = 100, greedy_rate = 0.7, discount = 0.9, lr = 0.1, 
                 wall_reward = -100, target_reward = 100):
        """
        Args:
            start_state: (row, col) cell every episode starts from.
            target_state: (row, col) cell that ends an episode when reached.
            maze_matrix: boolean (height, width) tensor, True where there is a wall.
            actions: (n_action, 2) tensor of (d_row, d_col) moves.
            reward_matrix: reward for entering each cell; built from the maze when None.
            restrict_matrix: boolean mask of extra cells the robot may not enter;
                defaults to all False.
            q_matrix: (n_action, height, width) Q table to continue training in place;
                a zero table is created when None.
            max_step: maximum number of steps per episode.
            greedy_rate: probability of taking the greedy action (otherwise random).
            discount: discount factor applied to the next state's best Q value.
            lr: learning rate of the Q update.
            wall_reward: reward for stepping into a wall (default reward matrix only).
            target_reward: reward for stepping onto the target (default reward matrix only).
        """
        self.start_state = torch.tensor(start_state)
        self.target_state = torch.tensor(target_state)
        self.maze_matrix = maze_matrix
        self.actions = actions
        self.max_step = max_step
        self.greedy_rate = greedy_rate
        self.discount = discount
        self.lr = lr

        self.n_action = actions.size(0)

        if reward_matrix is None:
            # float32 (rather than int8) so rewards outside [-128, 127] or fractional
            # rewards can be stored; the Q arithmetic is done in float32 anyway.
            reward_matrix = torch.full_like(maze_matrix, -1, dtype=torch.float32)
            reward_matrix[maze_matrix] = wall_reward
            reward_matrix[target_state[0], target_state[1]] = target_reward
        self.reward_matrix = reward_matrix

        if restrict_matrix is None:
            restrict_matrix = torch.zeros_like(reward_matrix, dtype=torch.bool)
        self.restrict_matrix = restrict_matrix

        if q_matrix is None:
            q_matrix = torch.zeros_like(reward_matrix, dtype=torch.float32).repeat(self.n_action, 1, 1)
        self.q_matrix = q_matrix

    def choose_action_id(self, state: torch.Tensor, greedy_rate: float = None) -> int:
        """
        Pick an action index for ``state`` with an epsilon-greedy rule.

        With probability ``greedy_rate`` (``self.greedy_rate`` when None) the action with
        the highest Q value is returned, ties broken uniformly at random; otherwise a
        uniformly random action is returned.
        """
        if greedy_rate is None:
            greedy_rate = self.greedy_rate

        # random exploration
        if random.random() > greedy_rate:
            return random.randint(0, self.n_action - 1)

        # greedy choice among all actions that share the maximal Q value
        q_values = self.q_matrix[:, state[0], state[1]]
        max_value = torch.max(q_values)
        max_indices = torch.where(q_values == max_value)[0].tolist()
        return random.choice(max_indices)

    def update(self, state:torch.Tensor, action_id: int, next_state: torch.Tensor, discount = None, lr = None):
        """
        Apply one Q-learning update, in place, for taking ``action_id`` in ``state``.

        The update target is the reward of ``next_state`` plus the discounted best Q value
        of ``next_state``.  When ``next_state`` is a maze wall the target is the reward
        alone.  Only ``maze_matrix`` is consulted here; ``restrict_matrix`` is not.
        ``discount`` and ``lr`` fall back to the instance values when None.
        """
        if discount is None:
            discount = self.discount
        if lr is None:
            lr = self.lr

        row, col = next_state[0], next_state[1]
        reward = self.reward_matrix[row, col]
        q_old = self.q_matrix[action_id, state[0], state[1]]

        if self.maze_matrix[row, col]:
            q_new = reward
        else:
            q_next_max = torch.max(self.q_matrix[:, row, col])
            q_new = reward + discount * q_next_max

        self.q_matrix[action_id, state[0], state[1]] += lr * (q_new - q_old)

    def go(self, start_state: torch.Tensor, valid: bool = False, debug: bool = False) -> list:
        """
        Run one episode from ``start_state`` until the target is reached or
        ``max_step`` steps have been taken.

        Args:
            start_state: cell to start from.
            valid: when True always act greedily and leave the Q matrix untouched;
                when False act epsilon-greedily and update the Q matrix every step.
            debug: print a message each time a wall or restricted cell is hit.

        Returns:
            The start state followed by every attempted next state, including
            attempts that hit a wall or restricted cell (the robot stays put then).
        """
        states = [start_state]
        state = start_state
        # Coordinates are integers, so compare exactly instead of with allclose.
        while not bool((state == self.target_state).all()) and len(states) <= self.max_step:
            # A greedy rate of 1 always selects the best known action.
            action_id = self.choose_action_id(state, 1 if valid else None)
            next_state = state + self.actions[action_id]
            if not valid:
                self.update(state, action_id, next_state)

            row, col = next_state[0], next_state[1]
            blocked = bool(self.maze_matrix[row, col]) or bool(self.restrict_matrix[row, col])
            if blocked:
                if debug:
                    print(f'{next_state}, hit the wall or restricted')
            else:
                state = next_state
            states.append(next_state)
        return states

    def train(self, n_epoch: int = 100, valid: bool = False, debug: bool = False) -> list:
        """
        Run ``n_epoch`` episodes from ``self.start_state`` and return the list of paths.

        ``valid`` and ``debug`` are forwarded to :meth:`go`.
        """
        paths = []
        for epoch in tqdm.trange(n_epoch):
            paths.append(self.go(self.start_state, valid, debug))
        return paths

    def eval(self, debug: bool = False) -> list:
        """
        Run a single greedy episode from ``self.start_state`` without updating the Q matrix.
        """
        return self.go(self.start_state, True, debug)


## Generate maze

In [28]:

def randomPrimMaze(width, height) -> torch.Tensor:
    """
    Generate a random maze with a randomized Prim-style algorithm.

    The maze is a boolean ``(height, width)`` tensor where ``True`` is a wall.  Logical
    cell ``(x, y)`` lives at grid position ``[2*y + 1, 2*x + 1]``; the grid positions
    between two neighbouring cells are walls that may be carved open.

    Args:
        width: number of grid columns (at least 3).
        height: number of grid rows (at least 3).

    Returns:
        The boolean wall matrix.

    Raises:
        ValueError: if the grid is too small to hold a single cell.
    """
    n_cols = (width - 1) // 2
    n_rows = (height - 1) // 2
    if n_cols < 1 or n_rows < 1:
        raise ValueError(f"maze must be at least 3x3, got width={width}, height={height}")

    grid = torch.ones((height, width), dtype=torch.bool)

    # Neighbour offsets as (dx, dy), in the fixed order left, up, right, down.
    neighbour_offsets = ((-1, 0), (0, -1), (1, 0), (0, 1))

    def carve_random_neighbour(x, y, frontier):
        """Open a random unvisited neighbour of cell (x, y); return False if none exists."""
        candidates = []
        for dx, dy in neighbour_offsets:
            nx, ny = x + dx, y + dy
            # A cell that is still a wall has not been visited yet.
            if 0 <= nx < n_cols and 0 <= ny < n_rows and grid[2 * ny + 1, 2 * nx + 1]:
                candidates.append((dx, dy))

        if not candidates:
            return False

        dx, dy = random.choice(candidates)
        grid[2 * (y + dy) + 1, 2 * (x + dx) + 1] = False  # the neighbouring cell
        grid[2 * y + 1 + dy, 2 * x + 1 + dx] = False      # the wall between the two cells
        frontier.append((x + dx, y + dy))
        return True

    startX, startY = (random.randint(0, n_cols - 1), random.randint(0, n_rows - 1))
    print("start(%d, %d)" % (startX, startY))
    grid[2 * startY + 1, 2 * startX + 1] = False

    frontier = [(startX, startY)]
    while frontier:
        # Grow the maze from a random visited cell; drop cells that cannot grow any more.
        cell = random.choice(frontier)
        if not carve_random_neighbour(cell[0], cell[1], frontier):
            frontier.remove(cell)

    return grid


## Generate distance matrix

In [29]:

def gen_d_matrix(start: tuple[int, int], maze_matrix: torch.Tensor, 
                 actions: torch.Tensor = torch.tensor([[0, -1], [0, 1], [-1, 0], [1, 0]])) -> torch.Tensor:
    """
    Breadth-first-search distance from ``start`` to every cell of the maze.

    Args:
        start: (row, col) of the source cell.
        maze_matrix: boolean (height, width) tensor, True where there is a wall.
        actions: (n_action, 2) tensor of (d_row, d_col) moves.

    Returns:
        A float tensor shaped like ``maze_matrix`` holding the number of steps from
        ``start``; walls are ``inf`` and open cells that cannot be reached stay ``-1``.

    Raises:
        IndexError: if ``start`` lies outside the maze.
    """
    height, width = maze_matrix.shape
    row0, col0 = int(start[0]), int(start[1])
    if not (-height <= row0 < height and -width <= col0 < width):
        raise IndexError(f"start {(row0, col0)} is outside a maze of shape {(height, width)}")
    # Accept negative indices the way tensor indexing does.
    row0, col0 = row0 % height, col0 % width

    moves = [tuple(move) for move in actions.tolist()]

    d_matrix = torch.full(maze_matrix.shape, -1.0, device=maze_matrix.device)
    d_matrix[maze_matrix] = float('inf')
    d_matrix[row0, col0] = 0

    queue = deque([(row0, col0)])
    while queue:
        row, col = queue.popleft()
        next_distance = d_matrix[row, col] + 1
        for d_row, d_col in moves:
            n_row, n_col = row + d_row, col + d_col
            # Skip moves that leave the grid; a negative index would otherwise wrap
            # around to the opposite edge and an overly large one would raise.
            if not (0 <= n_row < height and 0 <= n_col < width):
                continue
            if d_matrix[n_row, n_col] == -1:
                d_matrix[n_row, n_col] = next_distance
                queue.append((n_row, n_col))

    return d_matrix


## Get a batch of data

### Generate data

1. Generate a maze
2. Generate a distance matrix (min distance from any start_state)
3. For every walkable cell (other than the start states) used as the target_state:
    - For each start_state, train a separate q matrix
    - Train one q_matrix_combo that is shared by the robots of all start_states

Return:
- maze_matrix: (height, width)
- distance_matrix: (height, width)
- q_matrices_dict: {target_states: {start_states: q_matrix}}
- q_matrix_combo_dict: {target_states: q_matrix_combo}

### Convert data to tensor

1. Convert q_matrices_dict to q_matrices_tensor: (B, C, H, W), C = n_action * n_robot
2. Convert q_matrix_combo_dict to q_matrix_combo_tensor: (B, C, H, W), C = n_action
3. Convert maze_matrix to maze_tensor: (H, W), optional
4. Convert distance_matrix to distance_tensor: (H, W), optional

In [30]:

def gen_data(width: int, height: int, start_states: list[tuple[int,int]]) -> tuple[torch.Tensor, torch.Tensor, dict, dict]:
    """
    Generate one random maze together with trained Q matrices for every target cell.

    Args:
        width: maze width passed to ``randomPrimMaze``.
        height: maze height passed to ``randomPrimMaze``.
        start_states: (row, col) start cell of each robot.

    Return:
    maze_matrix: torch.Tensor, the maze matrix
    d_matrix: torch.Tensor, the minimum d matrix among all start states
    q_matrices_dict: dict[target_state, dict[start_state, torch.Tensor]], the q matrix
        trained independently for each target state and each start state
    q_matrix_combo_dict: dict[target_state, torch.Tensor], one q matrix per target state
        that is shared by (trained in turn from) all start states

    Raises:
        ValueError: if ``start_states`` is empty.
    """
    if not start_states:
        raise ValueError("start_states must contain at least one start state")

    # Generate maze matrix
    maze_matrix = randomPrimMaze(width, height)

    # Element-wise minimum of the distance matrices of all start states
    d_matrix = torch.stack(
        [gen_d_matrix(start_state, maze_matrix) for start_state in start_states]
    ).min(dim=0).values

    # Every open cell that is not a start state becomes a target state.
    # torch.nonzero yields cells in row-major order, so the order is deterministic.
    excluded = set(start_states)
    open_cells = (tuple(cell) for cell in torch.nonzero(~maze_matrix).tolist())
    target_states = [cell for cell in open_cells if cell not in excluded]

    q_matrices_dict: dict[tuple, dict[tuple, torch.Tensor]] = {}
    q_matrix_combo_dict: dict[tuple, torch.Tensor] = {}

    for target_state in target_states:
        # One independently trained q matrix per start state
        q_matrices_by_start_dict: dict[tuple, torch.Tensor] = {}
        for start_state in start_states:
            robot = Robot(start_state, target_state, maze_matrix)
            robot.train()
            q_matrices_by_start_dict[start_state] = robot.q_matrix

        # One q matrix shared by all start states.  It starts from zeros and is a
        # separate tensor, so the per-start matrices stored above are not modified;
        # Robot updates the tensor it is given in place.
        q_matrix_combo = torch.zeros_like(q_matrices_by_start_dict[start_states[0]])
        for start_state in start_states:
            robot = Robot(start_state, target_state, maze_matrix, q_matrix=q_matrix_combo)
            robot.train()

        q_matrices_dict[target_state] = q_matrices_by_start_dict
        q_matrix_combo_dict[target_state] = q_matrix_combo

    return maze_matrix, d_matrix, q_matrices_dict, q_matrix_combo_dict

In [37]:
        
def data2tensor(data: tuple[torch.Tensor, torch.Tensor, dict, dict]):
    """
    Pack the output of ``gen_data`` into tensors on ``device``.

    Returns, in order:
    1. maze_tensor: the maze matrix, (H, W)
    2. d_tensor: the distance matrix, (H, W)
    3. q_matrices_tensor: (B, C, H, W), the per-start q matrices of each target state
       concatenated along the channel axis, C = n_action * n_robot
    4. q_matrix_combo_tensor: (B, C, H, W), C = n_action
    5. target_states: int8 tensor of the target states, in the same order as the batch axis
    6. start_states: int8 tensor of the start states of the first target state
    """
    maze_matrix, d_matrix, q_matrices_dict, q_matrix_combo_dict = data

    # One entry per target state: all of its per-start q matrices joined on dim 0.
    stacked_per_target = []
    for per_start_dict in q_matrices_dict.values():
        stacked_per_target.append(torch.cat(list(per_start_dict.values())))
    q_matrices_tensor = torch.stack(stacked_per_target).to(device)

    combo_list = list(q_matrix_combo_dict.values())
    q_matrix_combo_tensor = torch.stack(combo_list).to(device)

    target_keys = list(q_matrices_dict.keys())
    target_states = torch.tensor(target_keys, dtype=torch.int8).to(device)

    first_per_start_dict = next(iter(q_matrices_dict.values()))
    start_keys = list(first_per_start_dict.keys())
    start_states = torch.tensor(start_keys, dtype=torch.int8).to(device)

    maze_tensor = maze_matrix.to(device)
    d_tensor = d_matrix.to(device)

    return maze_tensor, d_tensor, q_matrices_tensor, q_matrix_combo_tensor, target_states, start_states

In [None]:
# Generate n_batch independent random 9x9 mazes, each with robots starting in the
# bottom-left and bottom-right cells, and store the packed tensors on disk.
n_batch = 100
MAZE_WIDTH, MAZE_HEIGHT = 9, 9
START_STATES = [(7, 1), (7, 7)]

batches = []
for _ in range(n_batch):
    batch = data2tensor(gen_data(MAZE_WIDTH, MAZE_HEIGHT, START_STATES))
    batches.append(batch)

torch.save(batches, 'batches.pt')