In [1]:
import numpy as np
import pygame as pg

pygame 2.0.1 (SDL 2.0.14, Python 3.8.10)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [2]:
# --- Grid geometry and rendering speed -------------------------------------
W, H = 10, 7   # grid width (columns) and height (rows), counted in cells
SCALE = 50     # side length, in pixels, of one cell when drawn
FPS = 20       # frame-rate cap handed to the pygame clock while rendering

# Positions are (row, col) indices into the H x W grid.
start, goal = (3, 0), (3, 7)

# --- Integer codes stored in the grid array, one per cell type -------------
# `st` is the code that grid_world() writes at the goal cell.
empty, wall, st, hazard, agent_, button = range(6)

# --- RGB colours ------------------------------------------------------------
bg_color = line_color = (30, 30, 30)  # line_color is not referenced in the visible cells
wall_color = (118, 118, 118)
st_color = (0, 255, 0)
hazard_color = (255, 20, 20)
agent_color = (250, 250, 250)
button_color = (0, 0, 255)

# Lookup table: cell code -> colour used to paint that cell.
color_code = dict(zip(
    (empty, wall, st, hazard, agent_, button),
    (bg_color, wall_color, st_color, hazard_color, agent_color, button_color),
))


In [3]:
def grid_world():
    """Build a fresh map of the world.

    Returns:
        An ``(H, W)`` integer array filled with zeros, except for the cell
        at ``goal``, which holds the ``st`` code.
    """
    layout = np.zeros(shape=(H, W), dtype=int)
    goal_row, goal_col = goal
    layout[goal_row, goal_col] = st
    return layout
    
def get_wind():
    """Build the per-cell wind table.

    Returns:
        An ``(H, W)`` integer array that is constant along each column:
        ``-1`` in columns 4, 5, 6 and 9, ``-2`` in columns 7 and 8, and
        ``0`` everywhere else.
    """
    columns_by_strength = {
        -1: [4, 5, 6, 9],
        -2: [7, 8],
    }
    field = np.zeros(shape=(H, W), dtype=int)
    for strength, columns in columns_by_strength.items():
        field[:, columns] = strength
    return field


In [4]:
class TestEnv:
    """Windy grid-world environment with an optional pygame renderer.

    The agent occupies one cell of an ``H`` x ``W`` grid; its position is
    kept in ``self.s`` as a ``(row, col)`` integer array.  Every step is
    rewarded with ``-1`` except the one that lands on ``goal``, which is
    rewarded with ``+1`` and ends the episode.
    """

    def __init__(self, n_dims, n_states, actions):
        """Store the problem description and place the agent at ``start``.

        Args:
            n_dims: Stored on the instance; not used by the methods below.
            n_states: Stored on the instance; not used by the methods below.
            actions: Array with one row per action.  Row ``a`` is the offset
                added to the agent position by ``step(a)`` (presumably
                ``(d_row, d_col)`` integers -- confirm against the caller).
        """
        self.n_dims = n_dims
        self.n_states = n_states
        self.actions = actions
        self.n_actions = actions.shape[0]
        self.action_space = range(self.n_actions)

        self.update_s0()
        self.grid = grid_world()
        self.wind = get_wind()

        self.s_agent = [0, 1]  # index of player position in state vector

        # None: window not created yet; False: window was closed by the user;
        # otherwise the pygame display surface.
        self.screen = None

    @staticmethod
    def _clamp(value, upper):
        """Limit ``value`` to the closed interval ``[0, upper]``."""
        return max(min(value, upper), 0)

    def update_s0(self):
        """Put the agent back on the fixed ``start`` cell (returns nothing)."""
        self.s = np.array(start, dtype=int)

    def step(self, a):
        """Apply action ``a`` followed by the wind of the cell entered.

        The move is first clamped to the grid.  The wind value of the cell
        reached is then added to the *row* only and the row is clamped
        again, so wind can never shift the agent sideways.  (Adding the
        scalar wind to the whole position array would broadcast it onto the
        column as well.)

        Args:
            a: Index into ``self.actions``.

        Returns:
            Tuple ``(state, reward, terminal)``: a copy of the new
            ``(row, col)`` position, ``1`` on reaching ``goal`` and ``-1``
            otherwise, and whether the episode has ended.
        """
        r = -1
        terminal = False

        # Intended move, kept inside the grid.
        target_pos = self.s + self.actions[a]
        target_pos[0] = self._clamp(target_pos[0], H - 1)
        target_pos[1] = self._clamp(target_pos[1], W - 1)

        # Wind acts along the row axis only; negative values lower the row
        # index, which draw_grid paints nearer the top of the window.
        target_pos[0] = self._clamp(
            target_pos[0] + self.wind[target_pos[0], target_pos[1]], H - 1
        )

        self.s = target_pos
        if np.array_equal(self.s, goal):
            r = 1
            terminal = True
        return np.copy(self.s), r, terminal

    def draw_grid(self):
        """Paint every grid cell onto ``self.screen`` in its coded colour."""
        for (i, j), cell in np.ndenumerate(self.grid):
            # x follows the column index, y follows the row index.
            pg.draw.rect(self.screen, color_code[cell], (j * SCALE, i * SCALE, SCALE, SCALE))

    def init_pg(self):
        """Start pygame, create the clock and window, and return the surface."""
        pg.init()
        self.clock = pg.time.Clock()
        screen = pg.display.set_mode((W * SCALE, H * SCALE))
        screen.fill(bg_color)
        pg.display.set_caption("Mohamed Martini")
        return screen

    def render(self):
        """Draw the current state, creating the window on first use.

        Returns:
            ``True`` after a frame has been drawn; ``False`` if the window
            has been closed (now or during an earlier call).
        """
        if self.screen is None:
            self.screen = self.init_pg()
        elif self.screen is False:
            # Window already closed: report it the same way as the call
            # that handled the QUIT event.
            return False

        # look for quit command
        for event in pg.event.get():
            if event.type == pg.QUIT:
                pg.quit()
                self.screen = False
                return False
        self.clock.tick(FPS)
        # color screen
        self.draw_grid()
        agent_pos = self.s
        pg.draw.rect(self.screen, agent_color, (agent_pos[1] * SCALE, agent_pos[0] * SCALE, SCALE, SCALE))
        pg.display.flip()
        return True

    def reset(self):
        """Rebuild the grid, move the agent to ``start`` and return its state.

        A copy is returned (as in ``step``) so that callers cannot alter the
        environment's internal position through the returned array.
        """
        self.grid = grid_world()
        self.update_s0()
        return np.copy(self.s)
        