# Définition de l'environnement

In [1]:
import numpy as np
from tqdm import tqdm

class GridEnvironment:
    """Deterministic grid world in which an agent walks from a start cell to a goal.

    Positions are ``(row, col)`` tuples. Actions: 0 = up (row - 1),
    1 = down (row + 1), 2 = left (col - 1), 3 = right (col + 1). Moves that
    would leave the grid are clamped, so the agent stays on the border; any
    other action value leaves the agent where it is.

    Rewards: -10 for entering an obstacle cell (the move is penalised but not
    blocked -- the agent ends up on the obstacle), +10 for reaching the goal
    (which ends the episode), -1 for every other step.
    """

    # Action id -> (row delta, col delta).
    _MOVES = {0: (-1, 0), 1: (1, 0), 2: (0, -1), 3: (0, 1)}

    def __init__(self, grid_size=(10, 10), obstacle_positions=None,
                 goal_position=(9, 9), start_position=(0, 0)):
        """Build the world; all arguments default to the original 10x10 layout.

        Args:
            grid_size: ``(n_rows, n_cols)``.
            obstacle_positions: iterable of ``(row, col)`` obstacle cells;
                ``None`` selects the default layout.
            goal_position: ``(row, col)`` of the terminal goal cell.
            start_position: ``(row, col)`` the agent is placed on by ``reset``.
        """
        self.grid_size = tuple(grid_size)
        if obstacle_positions is None:
            obstacle_positions = [(2, 2), (2, 3), (2, 4), (5, 5), (6, 5), (7, 5)]
        # Stored as a list of tuples so ``(row, col) in obstacle_positions`` works.
        self.obstacle_positions = [tuple(p) for p in obstacle_positions]
        self.goal_position = tuple(goal_position)
        self.start_position = tuple(start_position)
        # Place the agent immediately so step() is usable even before reset().
        self.reset()

    def reset(self):
        """Put the agent back on the start cell and return that position."""
        self.agent_position = self.start_position
        return self.agent_position

    def step(self, action):
        """Apply *action* and return ``(new_position, reward, done)``.

        ``done`` is True exactly when the new position is the goal cell.
        """
        row, col = self.agent_position
        delta_row, delta_col = self._MOVES.get(action, (0, 0))
        if delta_row:
            row = min(max(row + delta_row, 0), self.grid_size[0] - 1)
        if delta_col:
            col = min(max(col + delta_col, 0), self.grid_size[1] - 1)

        position = (row, col)
        if position in self.obstacle_positions:
            reward = -10
        elif position == self.goal_position:
            reward = 10
        else:
            reward = -1

        self.agent_position = position
        return position, reward, position == self.goal_position

# Définition de la Q-Table

In [2]:
class QTable:
    """Tabular action-value function Q(state, action) stored in a NumPy array.

    The table has shape ``(*state_size, action_size)`` and is initialised with
    values drawn uniformly from [0, 1) by ``np.random.rand``. A state is
    whatever indexes the leading ``state_size`` axes (e.g. a ``(row, col)``
    tuple for a 2-D grid).
    """

    def __init__(self, state_size, action_size):
        self.q_table = np.random.rand(*state_size, action_size)

    def get_action(self, state, epsilon):
        """Choose an action for *state* with an epsilon-greedy rule.

        With probability *epsilon* a uniformly random action index is returned;
        otherwise the index of the largest Q-value for *state*. ``epsilon=0``
        therefore always returns the greedy action.
        """
        action_values = self.q_table[state]
        if np.random.uniform() < epsilon:
            return np.random.choice(len(action_values))
        return np.argmax(action_values)

    def update(self, state, action, reward, next_state, alpha, gamma, done=False):
        """Apply one Q-learning update for the transition ``state --action--> next_state``.

        ``Q[s, a] <- (1 - alpha) * Q[s, a] + alpha * target`` where the target is
        ``reward + gamma * max_a' Q[next_state, a']`` for a non-terminal
        transition and just ``reward`` when *done* is True. When the episode
        has ended, there is no future return to bootstrap from, so the
        (randomly initialised) values of the terminal state must not leak into
        the estimate.

        Args:
            done: pass True when this transition ended the episode. It defaults
                to False, which keeps the always-bootstrap behaviour.
        """
        if done:
            target = reward
        else:
            target = reward + gamma * np.max(self.q_table[next_state])
        # q_table[state] is a view for basic indexing, so this writes in place.
        self.q_table[state][action] = (1 - alpha) * self.q_table[state][action] + alpha * target

# Entraînement de la Q-Table

In [3]:
# Print a banner announcing the training phase ("Training the Q-table").
print("Entraînement de la Q-Table")

Entraînement de la Q-Table


In [4]:
# Environment and a freshly (randomly) initialised Q-table with one row of
# 4 action values per grid cell.
env = GridEnvironment()
q_table = QTable(state_size=env.grid_size, action_size=4)

# Training budget: number of episodes and per-episode step cap.
num_episodes, max_steps = 1000, 1000

# Epsilon-greedy Q-learning hyper-parameters.
epsilon = 0.1  # exploration rate
alpha = 0.5    # learning rate
gamma = 0.9    # discount factor

for episode in tqdm(range(num_episodes)):
    env.reset()
    state, done, steps = env.agent_position, False, 0

    # An episode ends at the goal or after max_steps moves, whichever comes first.
    # NOTE(review): update() is not told that the final (goal) transition is
    # terminal, so it still bootstraps from the goal cell's Q-values.
    while steps < max_steps and not done:
        action = q_table.get_action(state, epsilon)
        next_state, reward, done = env.step(action)
        q_table.update(state, action, reward, next_state, alpha, gamma)
        state, steps = next_state, steps + 1

100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 1300.76it/s]


# Définition de la fonction de visualisation

In [5]:
import pygame
import time

def _draw_frame(screen, font, env, steps):
    """Render the grid world and the step counter onto *screen*, then update the display.

    Obstacle cells are black, the goal cell green and other cells white; the
    agent is a red circle and ``Step: <steps>`` is written in the top-left corner.
    """
    screen_width, screen_height = screen.get_size()
    n_rows, n_cols = env.grid_size
    cell_width = screen_width // n_cols
    cell_height = screen_height // n_rows
    obstacles = set(env.obstacle_positions)  # O(1) membership per cell

    screen.fill((255, 255, 255))
    for row in range(n_rows):
        for col in range(n_cols):
            if (row, col) in obstacles:
                color = (0, 0, 0)
            elif (row, col) == env.goal_position:
                color = (0, 255, 0)
            else:
                color = (255, 255, 255)
            pygame.draw.rect(screen, color, (col * cell_width, row * cell_height, cell_width, cell_height))

    # Grid lines are drawn once per row / column, on top of all cells.
    for row in range(n_rows):
        y = row * cell_height
        pygame.draw.line(screen, (0, 0, 0), (0, y), (screen_width, y), 1)
    for col in range(n_cols):
        x = col * cell_width
        pygame.draw.line(screen, (0, 0, 0), (x, 0), (x, screen_height), 1)

    agent_row, agent_col = env.agent_position
    center = (agent_col * cell_width + cell_width // 2, agent_row * cell_height + cell_height // 2)
    pygame.draw.circle(screen, (255, 0, 0), center, min(cell_width, cell_height) // 2)
    screen.blit(font.render(f"Step: {steps}", True, (0, 0, 0)), (10, 10))
    pygame.display.update()


def _present_frame(screen, font, env, steps, delay):
    """Show one frame, handle window events and pause; return False if the window was closed."""
    _draw_frame(screen, font, env, steps)
    # pygame.event.get() drains the whole queue into a list before any() scans it.
    if any(event.type == pygame.QUIT for event in pygame.event.get()):
        return False
    time.sleep(delay)
    return True


def visualize(q_table, env, num_episodes, max_steps, *, screen_size=(400, 400), delay=0.1):
    """Animate the greedy policy of *q_table* on *env* in a pygame window.

    Args:
        q_table: object with ``get_action(state, epsilon)``; it is called with
            ``epsilon=0`` so the agent always takes its greedy action.
        env: grid environment exposing ``reset()``, ``step(action)`` returning
            ``(state, reward, done)``, and the attributes ``agent_position``,
            ``grid_size``, ``obstacle_positions`` and ``goal_position``.
        num_episodes: number of episodes to play.
        max_steps: per-episode step cap; an episode also ends when ``step``
            reports ``done``.
        screen_size: window size in pixels as ``(width, height)``.
        delay: seconds to wait after each displayed frame.

    The start position is shown before the first move of every episode.
    Closing the window stops the animation early, and pygame is shut down on
    every exit path, including exceptions.
    """
    pygame.init()
    try:
        screen = pygame.display.set_mode(tuple(screen_size))
        font = pygame.font.Font(None, 30)

        for _ in range(num_episodes):
            env.reset()
            state = env.agent_position
            done = False
            steps = 0
            if not _present_frame(screen, font, env, steps, delay):
                return

            while not done and steps < max_steps:
                action = q_table.get_action(state, epsilon=0)
                state, _, done = env.step(action)
                steps += 1
                if not _present_frame(screen, font, env, steps, delay):
                    return
    finally:
        pygame.quit()

pygame 2.1.3 (SDL 2.0.22, Python 3.9.2)
Hello from the pygame community. https://www.pygame.org/contribute.html


# Visualisation

In [6]:
# visualize() queries the policy with epsilon=0 (pure greedy), and both that
# choice and GridEnvironment.step are deterministic, so every episode replays
# exactly the same trajectory: a single episode is enough to watch it. If the
# greedy policy cycles without reaching the goal, the episode lasts max_steps moves.
visualize(q_table, env, num_episodes=1, max_steps=1000)