In [1]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
import random
from collections import deque
import time
from IPython.display import display, clear_output
from tqdm import tqdm

In [53]:
class Game:
    """Small grid world: move an agent to a reward cell while avoiding holes.

    The board has ``x`` rows and ``y`` columns. Cells are given as flat
    row-major indices in ``range(x * y)`` and decoded to ``(row, col)``.

    Args:
        x: number of board rows.
        y: number of board columns.
        random_pos_lst: flat cell indices; ``[0]`` is the agent's start cell,
            ``[1]`` the reward cell, and any further entries are holes.
        max_steps: the episode ends after this many calls to ``step``.

    Raises:
        ValueError: if any index in ``random_pos_lst`` is outside ``range(x * y)``.
    """

    # Integer codes written into the board by ``get_board``; observations are
    # divided by 9.0 (the largest code) so they lie in [0, 1].
    AGENT = 1
    HOLE = 5
    REWARD = 9

    # Action id -> (row delta, col delta). Unknown actions leave the agent in place.
    _MOVES = {
        0: (-1, 0),  # Up
        1: (1, 0),   # Down
        2: (0, -1),  # Left
        3: (0, 1),   # Right
    }

    # Characters used by ``render`` for each board code (empty cells are ' ').
    _GLYPHS = {AGENT: 'Y', REWARD: 'O', HOLE: 'X'}

    def __init__(self, x, y, random_pos_lst, max_steps=20) -> None:
        self.random_pos_lst = random_pos_lst
        self.x = x
        self.y = y
        self.max_steps = max_steps
        # Positions and the step counter are derived from random_pos_lst in one place.
        self.reset()

    def _to_row_col(self, pos):
        """Decode a flat cell index into ``(row, col)`` on the ``x`` by ``y`` board.

        Raises:
            ValueError: if ``pos`` is outside ``range(x * y)`` (negative indices
                would otherwise silently wrap around when indexing the board).
        """
        if not 0 <= pos < self.x * self.y:
            raise ValueError(f'cell index {pos} is outside a {self.x}x{self.y} board')
        # Row-major layout: each row holds ``y`` cells, so divide by the column count.
        return (pos // self.y, pos % self.y)

    def init(self):
        """Return the initial ``(observation, reward, done)`` triple without moving."""
        return self.get_board() / 9.0, 0, False

    def reset(self):
        """Restore the step counter and the agent/reward/hole cells from ``random_pos_lst``."""
        self.current_step = 0
        self.agent_pos = self._to_row_col(self.random_pos_lst[0])
        self.reward_pos = self._to_row_col(self.random_pos_lst[1])
        self.hole_pos_list = [self._to_row_col(pos) for pos in self.random_pos_lst[2:]]

    def step(self, action):
        """Apply ``action`` and return ``(observation, reward, done)``.

        Rewards:
            - Reaching the reward cell pays ``max_steps - current_step + 1``, so
              finishing sooner pays more.
            - Entering a hole costs -100.
            - A move that leaves the agent where it was (e.g. into a wall) costs -50.
            - Any other move pays 0.

        The episode ends on reaching the reward, entering a hole, or after
        ``max_steps`` steps.
        """
        previous_pos = self.agent_pos
        d_row, d_col = self._MOVES.get(action, (0, 0))
        # Clamp to the board so moving into an edge is a no-op (penalized below).
        row = min(max(previous_pos[0] + d_row, 0), self.x - 1)
        col = min(max(previous_pos[1] + d_col, 0), self.y - 1)
        self.agent_pos = (row, col)

        reached_reward = self.agent_pos == self.reward_pos
        fell_in_hole = self.agent_pos in self.hole_pos_list
        if reached_reward:
            reward = float(self.max_steps - self.current_step + 1)
        elif fell_in_hole:
            reward = -100.0
        elif self.agent_pos == previous_pos:
            reward = -50.0
        else:
            reward = 0.0

        self.current_step += 1
        done = reached_reward or fell_in_hole or self.current_step >= self.max_steps
        return self.get_board() / 9.0, reward, done

    def render(self):
        """Print the board as text: 'Y' agent, 'O' reward, 'X' hole, ' ' empty."""
        rows = (
            '|' + ''.join(self._GLYPHS.get(int(cell), ' ') + '|' for cell in row) + '\n'
            for row in self.get_board()
        )
        print(''.join(rows), end='\r')

    def get_board(self):
        """Return an ``(x, y)`` int8 grid of cell codes (0 = empty).

        Later writes win when cells coincide: the agent is drawn over the
        reward, and holes are drawn over the agent.
        """
        board = np.zeros((self.x, self.y), dtype=np.int8)
        board[self.reward_pos] = self.REWARD
        board[self.agent_pos] = self.AGENT
        for hole_pos in self.hole_pos_list:
            board[hole_pos] = self.HOLE
        return board

In [44]:
# Q-network: the flattened x*y board goes in, one Q-value per action (up/down/left/right) comes out.
x, y = 10, 10
learning_rate = 0.001

policy_network = keras.Sequential(
    [
        # `shape` must be a tuple: `(x*y)` is just an int, which Keras 3 rejects.
        keras.Input(shape=(x * y,)),
        layers.Dense(32, activation='relu'),
        layers.Dense(64, activation='relu'),
        layers.Dense(32, activation='relu'),
        layers.Dense(4, activation='linear'),  # raw Q-values, so no squashing activation
    ]
)
# Target network: a structurally identical copy starting from the same weights.
# The training loop re-syncs it periodically so the TD targets change slowly.
# It is only used for predict(), so it does not need to be compiled.
target_network = keras.models.clone_model(policy_network)
target_network.set_weights(policy_network.get_weights())

policy_network.compile(
    optimizer=keras.optimizers.Adam(learning_rate=learning_rate),
    loss=keras.losses.mean_squared_error,
)

In [45]:
# --- DQN hyperparameters (names are read by the training and evaluation cells) ---
c = 3                       # re-sync target_network from policy_network every `c` training steps
max_steps = 100             # per-episode step limit handed to Game
num_actions = 4             # 0=up, 1=down, 2=left, 3=right (see Game.step)
action_list = np.arange(num_actions)
gamma = 0.97                # discount factor applied to the bootstrapped next-state value
batch_size = 256            # replay minibatch size
exp_size = 10000            # replay buffer capacity
epsilin_down_factor = 0.9   # (sic) multiplicative epsilon decay applied after each won episode
consecutive_wins_lmt = 10   # move on to the next board after this many wins in a row

# Replay buffer of (state, action, reward, next_state, done) tuples; the oldest entries
# are evicted once `exp_size` is reached.
experience_replay = deque(maxlen=exp_size)

In [None]:
# Train the DQN over a series of randomly generated boards, sharing one replay buffer.
# Each board starts with epsilon = 1.0. Epsilon decays only when an episode is won.
# Training moves to the next board after `consecutive_wins_lmt` wins in a row (or 500 episodes).
experience_replay = deque(maxlen=exp_size)


def _select_action(state_input, epsilon):
    """Epsilon-greedy: random action with probability `epsilon`, else argmax of the policy's Q-values."""
    if np.random.rand() > epsilon:
        # verbose=0: predict() would otherwise print a progress bar on every env step.
        return int(np.argmax(policy_network.predict(state_input, verbose=0)))
    return np.random.choice(action_list, 1).item()


def _train_on_batch(memories):
    """Fit policy_network once on a replay minibatch of (state, action, reward, next_state, done)."""
    # Each stored state is shaped (1, x*y). Concatenating (rather than np.squeeze) keeps
    # the batch axis even for a batch of one.
    states = np.concatenate([m[0] for m in memories], axis=0)
    actions = np.array([m[1] for m in memories], dtype=np.int64)
    rewards = np.array([m[2] for m in memories], dtype=np.float32)
    next_states = np.concatenate([m[3] for m in memories], axis=0)
    dones = np.array([m[4] for m in memories], dtype=np.float32)

    q_values = policy_network.predict(states, verbose=0)
    next_q_values = target_network.predict(next_states, verbose=0)

    # Bellman targets: only the taken action's Q-value moves; terminal transitions don't bootstrap.
    targets = np.copy(q_values)
    targets[np.arange(len(memories)), actions] = (
        rewards + gamma * next_q_values.max(axis=1) * (1.0 - dones)
    )
    policy_network.fit(states, targets, batch_size=64, epochs=1, verbose=0)


# Counted across all episodes and boards. A per-episode counter never reached `c` once
# episodes became short, so the target network could stop being synced entirely.
train_steps = 0
for _ in tqdm(range(100)):
    game_board = random.sample(range(x * y), 2)
    env = Game(x, y, game_board, max_steps)
    epsilon = 1.0
    consecutive_wins = 0
    win_episodes = []

    for episode in range(500):
        # tqdm.write prints above the progress bar instead of breaking it.
        tqdm.write(f'wins: {win_episodes}, epsilon: {epsilon}')
        env.reset()
        observation, reward, done = env.init()
        while not done:
            state_input = observation.reshape(-1, x * y)
            action = _select_action(state_input, epsilon)
            next_observation, reward, done = env.step(action)
            next_state_input = next_observation.reshape(-1, x * y)

            experience_replay.append((state_input, action, reward, next_state_input, done))

            if done:
                # Only reaching the reward cell produces a positive terminal reward.
                if reward > 0:
                    epsilon *= epsilin_down_factor
                    win_episodes.append(episode)
                    consecutive_wins += 1
                else:
                    consecutive_wins = 0

            if len(experience_replay) >= batch_size:
                _train_on_batch(random.sample(experience_replay, batch_size))
                train_steps += 1
                if train_steps % c == 0:
                    target_network.set_weights(policy_network.get_weights())
            observation = next_observation
        if consecutive_wins >= consecutive_wins_lmt:
            break

  0%|                                                                                                                                                                                                                                                                              | 0/100 [00:00<?, ?it/s]

wins: [], epsilon: 1.0
wins: [0], epsilon: 0.9
wins: [0, 1], epsilon: 0.81
wins: [0, 1, 2], epsilon: 0.7290000000000001
wins: [0, 1, 2, 3], epsilon: 0.6561000000000001
wins: [0, 1, 2, 3, 4], epsilon: 0.5904900000000002
wins: [0, 1, 2, 3, 4], epsilon: 0.5904900000000002




wins: [0, 1, 2, 3, 4, 6], epsilon: 0.5314410000000002
wins: [0, 1, 2, 3, 4, 6, 7], epsilon: 0.47829690000000014
wins: [0, 1, 2, 3, 4, 6, 7, 8], epsilon: 0.43046721000000016






wins: [0, 1, 2, 3, 4, 6, 7, 8], epsilon: 0.43046721000000016


wins: [0, 1, 2, 3, 4, 6, 7, 8, 10], epsilon: 0.38742048900000015
wins: [0, 1, 2, 3, 4, 6, 7, 8, 10, 11], epsilon: 0.34867844010000015


wins: [0, 1, 2, 3, 4, 6, 7, 8, 10, 11, 12], epsilon: 0.31381059609000017
wins: [0, 1, 2, 3, 4, 6, 7, 8, 10, 11, 12, 13], epsilon: 0.28242953648100017
wins: [0, 1, 2, 3, 4, 6, 7, 8, 10, 11, 12, 13, 14], epsilon: 0.25418658283290013
wins: [0, 1, 2, 3, 4, 6, 7, 8, 10, 11, 12, 13, 14, 15], epsilon: 0.22876792454961012
wins: [0, 1, 2, 3, 4, 6, 7, 8, 10, 11, 12, 13, 14, 15, 16], epsilon: 0.2058911320946491
wins: [0, 1, 2, 3, 4, 6, 7, 8, 10, 11, 12, 13, 14, 15, 16, 17], epsilon: 0.1853020188851842
wins: [0, 1, 2, 3, 4, 6, 7, 8, 10, 11, 12, 13, 14, 15, 16, 17, 18], epsilon: 0.16677181699666577


  1%|██▌                                                                                                                                                                                                                                                                | 1/100 [02:04<3:25:58, 124.83s/it]

wins: [], epsilon: 1.0




wins: [], epsilon: 1.0




wins: [], epsilon: 1.0




wins: [], epsilon: 1.0




wins: [], epsilon: 1.0




wins: [], epsilon: 1.0




wins: [], epsilon: 1.0




wins: [], epsilon: 1.0




wins: [], epsilon: 1.0




wins: [], epsilon: 1.0




wins: [9], epsilon: 0.9




wins: [9], epsilon: 0.9




wins: [9], epsilon: 0.9




wins: [9], epsilon: 0.9




wins: [9, 13], epsilon: 0.81




wins: [9, 13], epsilon: 0.81




wins: [9, 13], epsilon: 0.81




wins: [9, 13], epsilon: 0.81






wins: [9, 13], epsilon: 0.81


wins: [9, 13, 18], epsilon: 0.7290000000000001






wins: [9, 13, 18], epsilon: 0.7290000000000001




wins: [9, 13, 18], epsilon: 0.7290000000000001




wins: [9, 13, 18], epsilon: 0.7290000000000001




wins: [9, 13, 18], epsilon: 0.7290000000000001






wins: [9, 13, 18], epsilon: 0.7290000000000001




wins: [9, 13, 18], epsilon: 0.7290000000000001




wins: [9, 13, 18], epsilon: 0.7290000000000001




wins: [9, 13, 18], epsilon: 0.7290000000000001






wins: [9, 13, 18], epsilon: 0.7290000000000001




wins: [9, 13, 18], epsilon: 0.7290000000000001
wins: [9, 13, 18, 29], epsilon: 0.6561000000000001






wins: [9, 13, 18, 29], epsilon: 0.6561000000000001




wins: [9, 13, 18, 29], epsilon: 0.6561000000000001




wins: [9, 13, 18, 29], epsilon: 0.6561000000000001


wins: [9, 13, 18, 29, 33], epsilon: 0.5904900000000002




wins: [9, 13, 18, 29, 33, 34], epsilon: 0.5314410000000002


wins: [9, 13, 18, 29, 33, 34, 35], epsilon: 0.47829690000000014


wins: [9, 13, 18, 29, 33, 34, 35, 36], epsilon: 0.43046721000000016
wins: [9, 13, 18, 29, 33, 34, 35, 36, 37], epsilon: 0.38742048900000015


wins: [9, 13, 18, 29, 33, 34, 35, 36, 37, 38], epsilon: 0.34867844010000015
wins: [9, 13, 18, 29, 33, 34, 35, 36, 37, 38, 39], epsilon: 0.31381059609000017


wins: [9, 13, 18, 29, 33, 34, 35, 36, 37, 38, 39, 40], epsilon: 0.28242953648100017
wins: [9, 13, 18, 29, 33, 34, 35, 36, 37, 38, 39, 40, 41], epsilon: 0.25418658283290013




  2%|█████▏                                                                                                                                                                                                                                                            | 2/100 [28:11<26:29:23, 973.10s/it]

wins: [], epsilon: 1.0




wins: [], epsilon: 1.0




wins: [], epsilon: 1.0




wins: [], epsilon: 1.0


wins: [3], epsilon: 0.9




wins: [3], epsilon: 0.9




wins: [3, 5], epsilon: 0.81




wins: [3, 5], epsilon: 0.81




wins: [3, 5, 7], epsilon: 0.7290000000000001






wins: [3, 5, 7], epsilon: 0.7290000000000001




wins: [3, 5, 7], epsilon: 0.7290000000000001




wins: [3, 5, 7], epsilon: 0.7290000000000001


wins: [3, 5, 7, 11], epsilon: 0.6561000000000001




wins: [3, 5, 7, 11], epsilon: 0.6561000000000001






wins: [3, 5, 7, 11], epsilon: 0.6561000000000001




wins: [3, 5, 7, 11], epsilon: 0.6561000000000001




wins: [3, 5, 7, 11], epsilon: 0.6561000000000001




wins: [3, 5, 7, 11, 16], epsilon: 0.5904900000000002




wins: [3, 5, 7, 11, 16], epsilon: 0.5904900000000002






wins: [3, 5, 7, 11, 16], epsilon: 0.5904900000000002




wins: [3, 5, 7, 11, 16], epsilon: 0.5904900000000002


wins: [3, 5, 7, 11, 16, 20], epsilon: 0.5314410000000002




wins: [3, 5, 7, 11, 16, 20, 21], epsilon: 0.47829690000000014


wins: [3, 5, 7, 11, 16, 20, 21, 22], epsilon: 0.43046721000000016


wins: [3, 5, 7, 11, 16, 20, 21, 22, 23], epsilon: 0.38742048900000015
wins: [3, 5, 7, 11, 16, 20, 21, 22, 23, 24], epsilon: 0.34867844010000015
wins: [3, 5, 7, 11, 16, 20, 21, 22, 23, 24, 25], epsilon: 0.31381059609000017


wins: [3, 5, 7, 11, 16, 20, 21, 22, 23, 24, 25, 26], epsilon: 0.28242953648100017
wins: [3, 5, 7, 11, 16, 20, 21, 22, 23, 24, 25, 26, 27], epsilon: 0.25418658283290013


wins: [3, 5, 7, 11, 16, 20, 21, 22, 23, 24, 25, 26, 27, 28], epsilon: 0.22876792454961012


  3%|███████▋                                                                                                                                                                                                                                                         | 3/100 [48:44<29:24:35, 1091.50s/it]

wins: [], epsilon: 1.0




wins: [], epsilon: 1.0




wins: [], epsilon: 1.0


wins: [2], epsilon: 0.9


wins: [2, 3], epsilon: 0.81
wins: [2, 3, 4], epsilon: 0.7290000000000001
wins: [2, 3, 4, 5], epsilon: 0.6561000000000001






wins: [2, 3, 4, 5], epsilon: 0.6561000000000001




wins: [2, 3, 4, 5], epsilon: 0.6561000000000001




wins: [2, 3, 4, 5], epsilon: 0.6561000000000001




wins: [2, 3, 4, 5], epsilon: 0.6561000000000001






wins: [2, 3, 4, 5], epsilon: 0.6561000000000001




wins: [2, 3, 4, 5], epsilon: 0.6561000000000001
wins: [2, 3, 4, 5, 12], epsilon: 0.5904900000000002




wins: [2, 3, 4, 5, 12], epsilon: 0.5904900000000002






wins: [2, 3, 4, 5, 12], epsilon: 0.5904900000000002




wins: [2, 3, 4, 5, 12, 15], epsilon: 0.5314410000000002




wins: [2, 3, 4, 5, 12, 15], epsilon: 0.5314410000000002
wins: [2, 3, 4, 5, 12, 15, 17], epsilon: 0.47829690000000014


wins: [2, 3, 4, 5, 12, 15, 17, 18], epsilon: 0.43046721000000016
wins: [2, 3, 4, 5, 12, 15, 17, 18, 19], epsilon: 0.38742048900000015
wins: [2, 3, 4, 5, 12, 15, 17, 18, 19, 20], epsilon: 0.34867844010000015


wins: [2, 3, 4, 5, 12, 15, 17, 18, 19, 20, 21], epsilon: 0.31381059609000017
wins: [2, 3, 4, 5, 12, 15, 17, 18, 19, 20, 21, 22], epsilon: 0.28242953648100017
wins: [2, 3, 4, 5, 12, 15, 17, 18, 19, 20, 21, 22, 23], epsilon: 0.25418658283290013
wins: [2, 3, 4, 5, 12, 15, 17, 18, 19, 20, 21, 22, 23, 24], epsilon: 0.22876792454961012
wins: [2, 3, 4, 5, 12, 15, 17, 18, 19, 20, 21, 22, 23, 24, 25], epsilon: 0.2058911320946491




  4%|██████████▎                                                                                                                                                                                                                                                       | 4/100 [59:48<24:36:53, 923.05s/it]

wins: [], epsilon: 1.0




wins: [], epsilon: 1.0




wins: [], epsilon: 1.0




wins: [], epsilon: 1.0




wins: [], epsilon: 1.0




wins: [], epsilon: 1.0


wins: [5], epsilon: 0.9




wins: [5], epsilon: 0.9




wins: [5], epsilon: 0.9




wins: [5], epsilon: 0.9






wins: [5], epsilon: 0.9




wins: [5], epsilon: 0.9




wins: [5], epsilon: 0.9




wins: [5], epsilon: 0.9




wins: [5], epsilon: 0.9




wins: [5], epsilon: 0.9




wins: [5], epsilon: 0.9




wins: [5, 16], epsilon: 0.81




wins: [5, 16], epsilon: 0.81




wins: [5, 16], epsilon: 0.81




wins: [5, 16], epsilon: 0.81






wins: [5, 16], epsilon: 0.81




wins: [5, 16], epsilon: 0.81




wins: [5, 16], epsilon: 0.81




wins: [5, 16], epsilon: 0.81




wins: [5, 16], epsilon: 0.81






wins: [5, 16], epsilon: 0.81




wins: [5, 16], epsilon: 0.81






In [65]:
# Watch the trained agent play one fixed board greedily (no exploration).
game_board = [55, 22]  # flat cell indices: agent start, then reward
frame_delay = 0.1      # seconds between rendered frames
env = Game(x, y, game_board, max_steps)
observation, reward, terminated = env.init()
env.render()
time.sleep(frame_delay)
while not terminated:  # Game.step ends the episode after max_steps at the latest
    state_input = observation.reshape(-1, x * y)
    # Act with policy_network: it holds the latest trained weights. target_network is
    # only re-synced from it periodically during training and may lag behind.
    action = int(np.argmax(policy_network.predict(state_input, verbose=0)))
    next_observation, reward, terminated = env.step(action)
    clear_output(wait=True)
    env.render()
    time.sleep(frame_delay)
    observation = next_observation

| | | | | | | | | | |
| | | | | | | | | | |
| | |Y| | | | | | | |
| | | | | | | | | | |
| | | | | | | | | | |
| | | | | | | | | | |
| | | | | | | | | | |
| | | | | | | | | | |
| | | | | | | | | | |
| | | | | | | | | | |


In [57]:
# Report which episodes of the last board were won, and the epsilon reached.
print('wins: {}, epsilon: {}'.format(win_episodes, epsilon))

wins: [2, 4, 5, 7, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18], epsilon: 0.22876792454961012
