In [6]:
import gym
import numpy as np
import random
from collections import Counter, deque
from tensorboardX import SummaryWriter

import torch
import torch.nn as nn
import torch.nn.utils as nn_utils
import torch.nn.functional as F
import torch.optim as optim

In [7]:
# Hyperparameters
batch_size = 32

gamma = 0.99

eps_start=1.0
eps_decay = 0.995
eps_min = 0.1      # Minimal exploration rate (epsilon-greedy)

num_rounds = 50000
num_episodes = 50
learning_limit = 100
replay_limit = 1000  # Number of steps until starting replay
weight_update = 1000 # Number of steps until updating the target weights


In [8]:
#create model
class Model(nn.Module):
    def __init__(self, input_shape, n_actions):
        super(Model, self).__init__()

        self.net = nn.Sequential(
            nn.Linear(input_shape[0]*input_shape[1], 64),
            nn.ReLU(),
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, n_actions)
            
        )

    def forward(self, x):
        # flatten the observation space Box to linear tensor
        tensor_array = torch.from_numpy(state)
        x_flat = torch.flatten(tensor_array).to(torch.float32)
        return self.net(x_flat)

In [9]:
class SchedulerEnv(gym.Env):

    def __init__(self):
        
        #starting parameters
        num_gps = 100
        num_slots = 32
        num_pre_booked = 750
        to_book = [2,1,2,2,1,1,1,3,3,1,2,1,3,2,1,1,2,1,3,2,3,2]
        num_to_book = len(to_book)
        agent_pos = [0,0]
        
        #set parameters for the day
        self.num_gps = num_gps
        self.num_slots = num_slots
        self.num_pre_booked = num_pre_booked
        self.to_book = to_book
        self.num_to_book = num_to_book
        self.diary_slots = num_gps*num_slots
        self.agent_pos = agent_pos

        #set action space to move around the grid
        self.action_space = gym.spaces.Discrete(4) #up, down, left, right
        
        #set observation space 
        self.observation_space = gym.spaces.Box(low=0, high=1, shape=(self.num_slots, self.num_gps), dtype=np.int32)
   
    #creates daily diary for each gp, randomly populates prebooked appointments and resets parameters
    def reset(self):

        #creates zero filled dataframe with row per time slot and column per gp
        self.state = np.zeros((self.num_slots, self.num_gps),dtype=float)

        #randomly enters a 1 for each pre booked appointments
        pre_booked = self.num_pre_booked
        while pre_booked>0:
            pre_booked -= 1
            self.state[np.random.randint(self.num_slots), np.random.randint(self.num_gps)] = 1
            
        #randomly sets the agent start space
        self.agent_pos = [np.random.randint(self.num_slots), np.random.randint(self.num_gps)]

        #resets parameters for new episode
        self.done = False
        self.reward = 0
        self.appt_idx = 0
        
        #print('starting state', self.state.sum(), self.state)

        return self.state
    
    #calculates new position of the agent based on the action
    def move_agent(self, action):

        #set boundaries for the grid
        max_row = self.num_slots - 1
        max_col = self.num_gps - 1

        #setting new co-ordinates for the agent
        new_row = self.agent_pos[0]
        new_col = self.agent_pos[1]

        #calculate what the new position may be based on the action without going out the grid
        if action == 0:
            #print('up')
            new_row = max(self.agent_pos[0] - 1, 0)
        if action == 1:
            #print('down')
            new_row = min(self.agent_pos[0] + 1, max_row)
        if action == 2:
            #print('left')
            new_col = max(self.agent_pos[1] - 1, 0)
        if action == 3:
            #print('right')
            new_col = min(self.agent_pos[1] + 1, max_col)

        new_pos = [new_row, new_col]
        #print('new pos', new_pos)

        return new_pos

    #checks if we can look to book appointment starting here
    def check_bookable(self):
        return self.state[self.agent_pos[0], self.agent_pos[1]] == 0.0
    
    #action if we can't book the appointment
    def invalid_booking(self):
        #print('cant book')
        self.reward = -1
        
    #action if we can book the appointment
    def valid_booking(self):
        #print('go ahead and book')
        self.appt_idx += 1
        self.reward = 1
    
    #checks if the appointment fits
    def check_and_book(self):
        
        max_row = self.num_slots - 1
        cells_to_check = self.to_book[self.appt_idx]
        
        if cells_to_check==1:
            #print('good to check for single')
            if self.state[self.agent_pos[0], self.agent_pos[1]] == 0:
                self.state[self.agent_pos[0], self.agent_pos[1]] = 1
                self.valid_booking()
            else:
                #print('single taken')
                self.invalid_booking()

        if cells_to_check==2:
            #check we're not at the bottom of the grid
            if self.agent_pos[0]<max_row:
                #check the next cells is also 0.0
                #print('good to check for double')
                if self.state[self.agent_pos[0], self.agent_pos[1]] == 0 and \
                self.state[(self.agent_pos[0]+1), self.agent_pos[1]] == 0:
                    self.state[self.agent_pos[0], self.agent_pos[1]] = 1
                    self.state[(self.agent_pos[0]+1), self.agent_pos[1]] = 1
                    self.valid_booking()
                    self.agent_pos = [(self.agent_pos[0]+1), self.agent_pos[1]]
                    #print('after booking', self.agent_pos)
                else:
                    #print('double taken')
                    self.invalid_booking()
            else:
                #print('not for double')
                self.invalid_booking()
                
        if cells_to_check==3:
            #check we're not at the bottom of the grid
            if self.agent_pos[0]+1<max_row:
                #print('good to check for treble')
                if self.state[self.agent_pos[0], self.agent_pos[1]] == 0 and \
                self.state[(self.agent_pos[0]+1), self.agent_pos[1]] == 0 \
                 and self.state[(self.agent_pos[0]+2), self.agent_pos[1]] == 0:
                    self.state[self.agent_pos[0], self.agent_pos[1]] = 1
                    self.state[(self.agent_pos[0]+1), self.agent_pos[1]] = 1
                    self.state[(self.agent_pos[0]+2), self.agent_pos[1]] = 1
                    self.valid_booking()
                    self.agent_pos = [(self.agent_pos[0]+2), self.agent_pos[1]]
                else:
                    #print('treble taken')
                    self.invalid_booking()
            else:
                #print('not for treble')
                self.invalid_booking()
                
        if cells_to_check==4:
            #check we're not at the bottom of the grid
            if self.agent_pos[0]+2<max_row:
                #check the next cells is also 0.0
                #print('good for quad')
                if self.state[self.agent_pos[0], self.agent_pos[1]] == 0 and \
                self.state[(self.agent_pos[0]+1), self.agent_pos[1]] == 0 \
                 and self.state[(self.agent_pos[0]+2), self.agent_pos[1]] == 0 and \
                self.state[(self.agent_pos[0]+3), self.agent_pos[1]] == 0:
                    self.state[self.agent_pos[0], self.agent_pos[1]] = 1
                    self.state[(self.agent_pos[0]+1), self.agent_pos[1]] = 1
                    self.state[(self.agent_pos[0]+2), self.agent_pos[1]] = 1
                    self.state[(self.agent_pos[0]+3), self.agent_pos[1]] = 1
                    self.valid_booking()
                    self.agent_pos = [(self.agent_pos[0]+3), self.agent_pos[1]]
                else:
                    #print('quad taken')
                    self.invalid_booking()
            else:
                #print('not for quad')
                self.invalid_booking()

        next_state = self.state

        return next_state

    def step(self, action):

        #get new position of agent based on action
        new_agent_pos = self.move_agent(action)
        #print('new and old pos', new_agent_pos, self.agent_pos)
        
        #if the agent is stuck on an edge then move to a new position
        if new_agent_pos == self.agent_pos:
            self.agent_pos = [np.random.randint(self.num_slots), np.random.randint(self.num_gps)]
        else:
            self.agent_pos = new_agent_pos
            #print('here2', self.agent_pos)
        
        #check if it's possible to book then book
        if self.check_bookable():
            #print('checked here')
            self.state = self.check_and_book()
        else:
            #print('not bookable')
            self.invalid_booking()
        
        #work out if episode complete
        if self.appt_idx == len(self.to_book):
            self.done = True

        info = {}
        return self.state, self.reward, self.done, info

In [10]:
#device = "cuda"
device = "cpu"

env = SchedulerEnv()

#start writing to tensorboard
writer = SummaryWriter(comment="Scheduler DQN")

#create the current network and target network
policy_model = Model(env.observation_space.shape, env.action_space.n).to(device)

target_model = Model(env.observation_space.shape, env.action_space.n).to(device)
target_model.load_state_dict(policy_model.state_dict())

optimizer = optim.Adam(policy_model.parameters(), lr=0.001, eps=1e-3)

          # Exploration rate    
replay_buffer = deque(maxlen=1000)

step_idx = 0
epsilon = eps_start

for i in range(num_rounds):
    #change this for while not true once it works
    state = env.reset()
    episode_reward = 0

    for j in range(num_episodes):
#    while not done:
        
        step_idx += 1
        #print(i,j,step_idx)


        #epsilon for epsilon greedy strategy  
        if epsilon > eps_min:
            epsilon *= eps_decay
            
        #print('epsilon', epsilon)   
        #check = policy_model(state)
        #print(check)

        # Select and perform an action
        if step_idx > learning_limit:
            if np.random.rand() > epsilon:
                action = torch.argmax(policy_model(state))
        else:
            action = np.random.randint(env.action_space.n)

        next_state, reward, done, _ = env.step(action)
        reward = torch.tensor([reward], device=device)
        episode_reward += reward
        #print('here rewards', episode_reward, reward, step_idx)

        # Store other info in replay buffer
        replay_buffer.append((state, action, reward, next_state, done))
        #print('len buffer', len(replay_buffer))

        # Move to the next state
        state = next_state

        if done:
            break
            
    print('stopped episode', j, episode_reward)

    writer.add_scalar('episode_reward', episode_reward, step_idx)
        
    #print('step', step_idx)

    #once we're ready to learn then start learning with mini batches
    if len(replay_buffer) == replay_limit:
        #print('replay buffer')
        optimizer.zero_grad()
        
        minibatch = random.sample(replay_buffer, batch_size)

        for state, action, reward, next_state, done in minibatch:    
            #pass state to policy to get qval from policy
            pred_qval = max(policy_model(state)) 

        #pass next state to target policy to get next set of qvals (future gains)
        if not done:
            next_qval = (reward + (gamma * max(target_model(next_state).detach())))
        else:
            next_qval = reward   
        
        pred_qval = pred_qval.to(torch.float32).unsqueeze(-1)
        #print('pred_qval', pred_qval, pred_qval.size())
        next_qval = next_qval.to(torch.float32)
        #print('next_qval', next_qval, next_qval.size())

        loss = F.mse_loss(pred_qval, next_qval)
        #print('loss', loss)
        writer.add_scalar('loss', loss, step_idx)
        loss.backward()

        optimizer.step()
        #print('step', step_idx, 'i', i, 'j', j)

        # Update the target network, copying all weights and biases in DQN
        # Periodically update the target network by Q network to target Q network
        if i % 100 == 0:
            print('update weights', step_idx)
            # Update weights of target
            target_model.load_state_dict(policy_model.state_dict())

writer.close()

stopped episode 17 tensor([-44])
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
s

stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episo

stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episo

stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episo

stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episo

stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episo

stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episode 17 0
stopped episo

KeyboardInterrupt: 

In [None]:
pred_qval.unsqueeze(-1).size()