In [1]:
from model import BotDemineur
from utils import Transition, ReplayMemory
from env import DemineurInterface
import torch
import torch.optim as optim
import torch.nn as nn
import random
import math

import pyautogui


In [2]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Interface to the Minesweeper game that the agent plays.
env = DemineurInterface()

In [3]:
# Number of transitions sampled per optimisation step; optimize_model() does
# nothing until the replay memory holds at least this many.
BATCH_SIZE = 128
# Discount factor applied to the bootstrapped next-state value.
GAMMA = 0.999
# Epsilon-greedy schedule used by select_action(): the exploration threshold
# starts at EPS_START and decays exponentially towards EPS_END, with
# EPS_DECAY acting as the decay constant measured in calls to select_action().
EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 200
# The target network is re-synchronised with the policy network every
# TARGET_UPDATE episodes.
TARGET_UPDATE = 10

In [4]:
# Size of the discrete action set, as reported by the game interface.
n_actions = env.action_space_nb


def _make_net():
    """Build a BotDemineur sized for the current grid and move it to `device`."""
    return BotDemineur(env.grid.rows, env.grid.cols, n_actions).to(device)


# policy_net is the network being optimised. target_net starts as an exact
# copy of it, is kept in eval mode, and is only used to compute the
# bootstrap targets in optimize_model().
policy_net = _make_net()
target_net = _make_net()
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()

# Replay memory constructed with a capacity argument of 10000, and an RMSprop
# optimiser (library-default settings) over the policy network's parameters.
memory = ReplayMemory(10000)
optimizer = optim.RMSprop(policy_net.parameters())

# Start from a fresh game; steps_done counts calls to select_action() and
# drives the epsilon decay.
env.reset()
steps_done = 0

In [5]:
def select_action(state):
    """
    Choose the next action with an epsilon-greedy policy.

    The exploration threshold decays exponentially from EPS_START towards
    EPS_END as the global ``steps_done`` counter grows; the counter is
    incremented on every call.

    state: observation passed unchanged to ``policy_net``
           (presumably a single-sample batch, since the greedy result is
           reshaped with ``view(1, 1)`` -- TODO confirm against env.get_state()).

    Returns a ``(1, 1)`` long tensor holding the index of the chosen action.
    """
    global steps_done

    draw = random.random()

    decay_factor = math.exp(-1. * steps_done / EPS_DECAY)
    eps_threshold = EPS_END + (EPS_START - EPS_END) * decay_factor
    steps_done += 1

    # Explore: uniformly random action index.
    if draw <= eps_threshold:
        random_index = random.randrange(n_actions)
        return torch.tensor([[random_index]], device=device, dtype=torch.long)

    # Exploit: index of the largest Q-value predicted by the policy network.
    with torch.no_grad():
        _, best_index = policy_net(state).max(1)
        return best_index.view(1, 1)

    

In [6]:
# Grab the current observation once, outside the training loop; the loop
# itself re-reads the state after every env.reset().
state = env.get_state()

In [7]:
def optimize_model():
    """
    Run one DQN optimisation step on a random minibatch from ``memory``.

    Does nothing (returns None) while the replay memory holds fewer than
    BATCH_SIZE transitions. Otherwise it:

    1. samples BATCH_SIZE transitions and regroups them field by field;
    2. computes Q(s_t, a_t) with ``policy_net``;
    3. computes max_a Q(s_{t+1}, a) with ``target_net`` for non-terminal
       next states (terminal ones, stored as ``None``, keep a value of 0);
    4. minimises the Huber loss between Q(s_t, a_t) and
       ``reward + GAMMA * max_a Q(s_{t+1}, a)``, clipping every gradient
       element to [-1, 1] before the optimiser step.

    Relies on the module-level ``memory``, ``policy_net``, ``target_net``,
    ``optimizer``, ``device``, ``BATCH_SIZE`` and ``GAMMA``.
    """
    if len(memory) < BATCH_SIZE:
        return

    transitions = memory.sample(BATCH_SIZE)
    # Transpose the batch: a list of Transition tuples becomes a single
    # Transition whose fields are tuples of per-sample values.
    batch = Transition(*zip(*transitions))

    # A next_state of None marks a transition that ended the episode.
    non_final_mask = torch.tensor([s is not None for s in batch.next_state],
                                  device=device, dtype=torch.bool)
    non_final_next_states = [s for s in batch.next_state if s is not None]

    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)

    # Q(s_t, a_t): evaluate every action, then pick the column of the action
    # that was actually taken.
    state_action_values = policy_net(state_batch).gather(1, action_batch)

    # V(s_{t+1}) from the "older" target network; stays 0 for terminal states.
    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    # torch.cat() raises on an empty list, so skip the target pass entirely
    # when every sampled transition is terminal.
    if non_final_next_states:
        # No gradient is needed through the target network.
        with torch.no_grad():
            next_state_values[non_final_mask] = target_net(
                torch.cat(non_final_next_states)).max(1)[0]

    # Bellman target.
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch

    # Huber loss between predicted and target Q-values.
    criterion = nn.SmoothL1Loss()
    loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))

    # Optimise the policy network.
    optimizer.zero_grad()
    loss.backward()
    # Clamp each gradient element to [-1, 1]; unlike a manual loop over
    # param.grad this skips parameters that received no gradient.
    nn.utils.clip_grad_value_(policy_net.parameters(), 1)
    optimizer.step()
    

In [10]:
# Only needed for the 'q' abort hotkey polled inside the loop below.
import keyboard

num_episodes = 200

# pyautogui sleeps after each of its calls by default; disable that delay
# (presumably env drives the game window through pyautogui -- TODO confirm).
pyautogui.PAUSE = 0
print("start")

# Set when the user holds 'q'. A flag is needed because ending only the
# inner loop would let the outer loop reset the game and keep training.
stop_requested = False

for i_episode in range(num_episodes):

    env.reset()
    state = env.get_state()

    done = False

    while not done:

        action = select_action(state)
        reward, done = env.step(action.item())

        # Pin the dtype so the stored rewards do not depend on whether the
        # environment returns an int, a Python float or a NumPy scalar.
        reward = torch.tensor([reward], device=device, dtype=torch.float32)

        # A terminal transition is stored with next_state=None.
        next_state = None if done else env.get_state()

        memory.push(state, action, next_state, reward)

        state = next_state

        optimize_model()

        if keyboard.is_pressed('q'):
            stop_requested = True
            break

    if stop_requested:
        # Leave the training loop entirely instead of starting a new episode.
        break

    # Periodically copy the policy weights into the target network.
    if i_episode % TARGET_UPDATE == 0:
        target_net.load_state_dict(policy_net.state_dict())


start


In [11]:
# Number of transitions currently held in the replay memory.
len(memory)

3110