In [1]:
import gym
import ptan
import numpy as np
import argparse
from tensorboardX import SummaryWriter

import torch
import torch.nn as nn
import torch.nn.utils as nn_utils
import torch.nn.functional as F
import torch.optim as optim

gamma = 0.99
batch_size = 11
num_envs = 6
reward_steps = 4

  for external in metadata.entry_points().get(self.group, []):


In [2]:
class Model(nn.Module):
    def __init__(self, input_shape, n_actions):
        super(Model, self).__init__()

        self.net = nn.Sequential(
            nn.Linear(input_shape[0]*input_shape[1], 64),
            nn.ReLU(),
            nn.Linear(64, 128) 
        )
        
        self.actor = nn.Sequential(
            nn.Linear(128, 512),
            nn.ReLU(),
            nn.Linear(512, n_actions)
        )

        self.critic = nn.Sequential(
            nn.Linear(128, 512),
            nn.ReLU(),
            nn.Linear(512, 1)
        )


    def forward(self, x):
        # flatten the observation space Box to linear tensor
        x_flat = torch.flatten(x, 1,2).to(torch.float32)
        #print('x_flat', x_flat.size(), x_flat)
        init_out = self.net(x_flat)
        return self.actor(init_out), self.critic(init_out)

In [3]:
def unpack_batch(batch, model, device='cpu'):

    states = []
    actions = []
    rewards = []
    not_done_idx = []
    last_states = []
    #create lists of the states, actions and rewards
    for idx, exp in enumerate(batch):
        states.append(np.array(exp.state, copy=False))
        actions.append(int(exp.action))
        rewards.append(exp.reward)
        #separate out the last states to be able to calculate the rewards
        if exp.last_state is not None:
            not_done_idx.append(idx)
            last_states.append(np.array(exp.last_state, copy=False))

    #convert to tensors for calculations
    states = torch.FloatTensor(
        np.array(states, copy=False)).to(device)
    actions = torch.LongTensor(actions).to(device)

    # handle rewards
    rewards_np = np.array(rewards, dtype=np.float32)
    if not_done_idx:
        last_states = torch.FloatTensor(np.array(last_states, copy=False)).to(device)
        last_vals = model(last_states)[1]
        last_vals_np = last_vals.data.cpu().numpy()[:, 0]
        last_vals_np *= gamma ** reward_steps
        rewards_np[not_done_idx] += last_vals_np

    rewards = torch.FloatTensor(rewards_np).to(device)

    return states, actions, rewards

In [4]:
class SchedulerEnv(gym.Env):

    def __init__(self):
        
        #starting parameters
        num_gps = 100
        num_slots = 32
        num_pre_booked = 15
        to_book = [2,1,2,2,1,1,1]
        num_to_book = len(to_book)
        agent_pos = [0,0]
        reward_decay = 0.95
        
        #set parameters for the day
        self.num_gps = num_gps
        self.num_slots = num_slots
        self.num_pre_booked = num_pre_booked
        self.to_book = to_book
        self.num_to_book = num_to_book
        self.diary_slots = num_gps*num_slots
        self.agent_pos = agent_pos
        self.reward_decay = reward_decay

        #set action space to move around the grid
        self.action_space = gym.spaces.Discrete(4) #up, down, left, right
        
        #set observation space 
        self.observation_space = gym.spaces.Box(low=0, high=1, shape=(self.num_slots, self.num_gps), dtype=np.int32)
   
    #creates daily diary for each gp, randomly populates prebooked appointments and resets parameters
    def reset(self):

        #creates zero filled dataframe with row per time slot and column per gp
        self.state = np.zeros((self.num_slots, self.num_gps),dtype=float)

        #randomly enters a 1 for each pre booked appointments
        pre_booked = self.num_pre_booked
        while pre_booked>0:
            pre_booked -= 1
            self.state[np.random.randint(self.num_slots), np.random.randint(self.num_gps)] = 1
            
        #randomly sets the agent start space
        self.agent_pos = [np.random.randint(self.num_slots), np.random.randint(self.num_gps)]

        #resets parameters for new episode
        self.done = False
        self.reward = 0
        self.appt_idx = 0
        self.decay_steps = 1
        
        #print('starting state', self.state.sum(), self.state)

        return self.state
    
    #calculates new position of the agent based on the action
    def move_agent(self, action):

        #set boundaries for the grid
        max_row = self.num_slots - 1
        max_col = self.num_gps - 1

        #setting new co-ordinates for the agent
        new_row = self.agent_pos[0]
        new_col = self.agent_pos[1]

        #calculate what the new position may be based on the action without going out the grid
        if action == 0:
            #print('up')
            new_row = max(self.agent_pos[0] - 1, 0)
        if action == 1:
            #print('down')
            new_row = min(self.agent_pos[0] + 1, max_row)
        if action == 2:
            #print('left')
            new_col = max(self.agent_pos[1] - 1, 0)
        if action == 3:
            #print('right')
            new_col = min(self.agent_pos[1] + 1, max_col)

        new_pos = [new_row, new_col]
        #print('new pos', new_pos)

        return new_pos

    #checks if we can look to book appointment starting here
    def check_bookable(self):
        return self.state[self.agent_pos[0], self.agent_pos[1]] == 0.0
    
    #action if we can't book the appointment
    def invalid_booking(self):
        #print('cant book')
        self.decay_steps += 1
        self.reward = -1
        
    #action if we can book the appointment
    def valid_booking(self):
        #print('go ahead and book')
        self.appt_idx += 1
        self.decay_steps = 1
        self.reward = 1
    
    #checks if the appointment fits
    def check_and_book(self):
        
        max_row = self.num_slots - 1
        cells_to_check = self.to_book[self.appt_idx]
        
        if cells_to_check==1:
            #print('good to check for single')
            if self.state[self.agent_pos[0], self.agent_pos[1]] == 0:
                self.state[self.agent_pos[0], self.agent_pos[1]] = 1
                self.valid_booking()
            else:
                #print('single taken')
                self.invalid_booking()

        if cells_to_check==2:
            #check we're not at the bottom of the grid
            if self.agent_pos[0]<max_row:
                #check the next cells is also 0.0
                #print('good to check for double')
                if self.state[self.agent_pos[0], self.agent_pos[1]] == 0 and \
                self.state[(self.agent_pos[0]+1), self.agent_pos[1]] == 0:
                    self.state[self.agent_pos[0], self.agent_pos[1]] = 1
                    self.state[(self.agent_pos[0]+1), self.agent_pos[1]] = 1
                    self.valid_booking()
                    self.agent_pos = [(self.agent_pos[0]+1), self.agent_pos[1]]
                    #print('after booking', self.agent_pos)
                else:
                    #print('double taken')
                    self.invalid_booking()
            else:
                #print('not for double')
                self.invalid_booking()
                
        if cells_to_check==3:
            #check we're not at the bottom of the grid
            if self.agent_pos[0]+1<max_row:
                #print('good to check for treble')
                if self.state[self.agent_pos[0], self.agent_pos[1]] == 0 and \
                self.state[(self.agent_pos[0]+1), self.agent_pos[1]] == 0 \
                 and self.state[(self.agent_pos[0]+2), self.agent_pos[1]] == 0:
                    self.state[self.agent_pos[0], self.agent_pos[1]] = 1
                    self.state[(self.agent_pos[0]+1), self.agent_pos[1]] = 1
                    self.state[(self.agent_pos[0]+2), self.agent_pos[1]] = 1
                    self.valid_booking()
                    self.agent_pos = [(self.agent_pos[0]+2), self.agent_pos[1]]
                else:
                    #print('treble taken')
                    self.invalid_booking()
            else:
                #print('not for treble')
                self.invalid_booking()
                
        if cells_to_check==4:
            #check we're not at the bottom of the grid
            if self.agent_pos[0]+2<max_row:
                #check the next cells is also 0.0
                #print('good for quad')
                if self.state[self.agent_pos[0], self.agent_pos[1]] == 0 and \
                self.state[(self.agent_pos[0]+1), self.agent_pos[1]] == 0 \
                 and self.state[(self.agent_pos[0]+2), self.agent_pos[1]] == 0 and \
                self.state[(self.agent_pos[0]+3), self.agent_pos[1]] == 0:
                    self.state[self.agent_pos[0], self.agent_pos[1]] = 1
                    self.state[(self.agent_pos[0]+1), self.agent_pos[1]] = 1
                    self.state[(self.agent_pos[0]+2), self.agent_pos[1]] = 1
                    self.state[(self.agent_pos[0]+3), self.agent_pos[1]] = 1
                    self.valid_booking()
                    self.agent_pos = [(self.agent_pos[0]+3), self.agent_pos[1]]
                else:
                    #print('quad taken')
                    self.invalid_booking()
            else:
                #print('not for quad')
                self.invalid_booking()

        next_state = self.state

        return next_state

    def step(self, action):

        #print('start step' , self.decay_steps)
        #get new position of agent based on action
        new_agent_pos = self.move_agent(action)
        #print('new and old pos', new_agent_pos, self.agent_pos)
        
        #if the agent is stuck on an edge then move to a new position
        if new_agent_pos == self.agent_pos:
            self.agent_pos = [np.random.randint(self.num_slots), np.random.randint(self.num_gps)]
            #print('here1', self.agent_pos)
        else:
            self.agent_pos = new_agent_pos
            #print('here2', self.agent_pos)
        
        #print('trying to book', self.to_book, self.appt_idx)
        
        #check if it's possible to book then book
        if self.check_bookable():
            #print('checked here')
            self.state = self.check_and_book()
        else:
            #print('not bookable')
            self.invalid_booking()
        
        #work out if episode complete
        if self.appt_idx == len(self.to_book):
            #print('all booked')
            self.done = True
  
        #work out rewards
        #self.reward = (1 - (self.reward_decay**self.decay_steps))
        
        #print('step', self.decay_steps, self.reward)
        #print('end step')

        info = {}
        return self.state, self.reward, self.done, info

In [5]:
#device = "cuda"
device = "cpu"

#create multiple environments for multiprocessing
make_env = lambda: SchedulerEnv()
envs = [make_env() for _ in range(num_envs)]

#start writing to tensorboard
writer = SummaryWriter(comment="Scheduler")

#initialise model, agent and run through episodes to get experience
model = Model(envs[0].observation_space.shape, envs[0].action_space.n).to(device)
agent = ptan.agent.PolicyAgent(lambda x: model(x)[0], apply_softmax=True, device=device)
exp_source = ptan.experience.ExperienceSourceFirstLast(envs, agent, gamma=gamma, steps_count=reward_steps)

optimizer = optim.Adam(model.parameters(), lr=0.001, eps=1e-3)

#create list to capture batches
batch = []

#create lists to be used to record values for tracking averages
reward_stack = []
loss_stack = []

#work through each experience source to capture state, actions etc
for step_idx, exp in enumerate(exp_source):
    batch.append(exp)

    if len(batch) < batch_size:
        continue

    states, actions, rewards = unpack_batch(batch, model, device=device)
    batch.clear()

    optimizer.zero_grad()

    # using the network to give actions and state_value
    actor_val, critic_val = model(states)
    # [CRITIC] calculate the loss between value_state (just predicted now) and reward from the batch
    critic_loss = F.mse_loss(critic_val.squeeze(-1), rewards)

    # Runs the log_softmax against actor output (just predicted now)
    log_prob = F.log_softmax(actor_val, dim=1)
    # Advantage equals reward from the batch (size:[batch_size]) minus the value_state (just predicted now)
    advantage = rewards - critic_val.detach()

    # multiples the advantage at each step by the log probability of the chosen action for that step
    log_prob_actions = advantage * log_prob[range(batch_size), actions]
    # calculate the policy gradient adjustment to make (negated to move toward policy improvement)
    actor_loss = -log_prob_actions.mean()

    # perform softmax on action estimates (from ACTOR) (just predicted now)
    prob_val = F.softmax(actor_val, dim=1)
    # calculating the action entropy 
    entropy_loss = 0.01 * (prob_val * log_prob).sum(dim=1).mean()

    # calculate policy gradients only

    # [ACTOR] backpropogate
    actor_loss.backward(retain_graph=True)

    # apply entropy and value gradients
    # [CRITIC] backpropagate and apply entropy
    loss = entropy_loss + critic_loss
    loss.backward()

    optimizer.step()

    #send average loss and rewards to tensorboard
    if len(reward_stack) > 0 and step_idx % 10 == 0:
        #print(step_idx)
        avg_rewards = np.mean(reward_stack)
        avg_loss = np.mean(loss_stack)
        writer.add_scalar('ave_batch_reward', avg_rewards, step_idx)
        writer.add_scalar('ave_batch_loss', avg_loss, step_idx)
        print('ave_batch_reward', avg_rewards, 'step', step_idx)
        print('ave_batch_loss', avg_loss, 'step', step_idx)
        reward_stack.clear()
        loss_stack.clear()
    else:
        reward_stack.append(torch.mean(rewards).item())
        loss_stack.append(torch.mean(critic_loss).item())
        
    if step_idx > 100000:
        break

writer.close()

ave_batch_reward 0.7927932790480554 step 120
ave_batch_loss 3.3803622245788576 step 120
ave_batch_reward 1.3872011817163892 step 230
ave_batch_loss 4.314593394597371 step 230
ave_batch_reward 1.8789456619156732 step 340
ave_batch_loss 5.333387633164723 step 340
ave_batch_reward 2.8976784812079535 step 450
ave_batch_loss 4.413786225848728 step 450
ave_batch_reward 4.293225871192084 step 560
ave_batch_loss 7.188605255550808 step 560
ave_batch_reward 4.762633138232761 step 670
ave_batch_loss 10.464578628540039 step 670
ave_batch_reward 4.35230376985338 step 780
ave_batch_loss 7.136023786332872 step 780
ave_batch_reward 4.16373085975647 step 890
ave_batch_loss 7.5537782775031195 step 890
ave_batch_reward 4.729552931255764 step 1000
ave_batch_loss 9.058580981360542 step 1000
ave_batch_reward 4.9304408232371015 step 1110
ave_batch_loss 10.592374695671928 step 1110
ave_batch_reward 4.828351126776801 step 1220
ave_batch_loss 9.110197226206461 step 1220
ave_batch_reward 4.797639634874132 step 1

ave_batch_reward 5.031325552198622 step 10350
ave_batch_loss 10.770365715026855 step 10350
ave_batch_reward 5.227910306718615 step 10460
ave_batch_loss 10.902318742540148 step 10460
ave_batch_reward 5.092733383178711 step 10570
ave_batch_loss 10.113860236273872 step 10570
ave_batch_reward 5.130991988711887 step 10680
ave_batch_loss 10.076834360758463 step 10680
ave_batch_reward 5.240369743771023 step 10790
ave_batch_loss 9.878457652197945 step 10790
ave_batch_reward 5.169296317630344 step 10900
ave_batch_loss 10.745310889350044 step 10900
ave_batch_reward 5.276642534467909 step 11010
ave_batch_loss 10.382305992974175 step 11010
ave_batch_reward 4.976211123996311 step 11120
ave_batch_loss 9.901999261644152 step 11120
ave_batch_reward 5.2561508019765215 step 11230
ave_batch_loss 9.319781409369575 step 11230
ave_batch_reward 5.11783848868476 step 11340
ave_batch_loss 9.054891374376085 step 11340
ave_batch_reward 5.137901306152344 step 11450
ave_batch_loss 9.625856717427572 step 11450
ave_

ave_batch_reward 5.312419043646918 step 20580
ave_batch_loss 10.758958286709255 step 20580
ave_batch_reward 5.252944469451904 step 20690
ave_batch_loss 9.460780620574951 step 20690
ave_batch_reward 5.11913447909885 step 20800
ave_batch_loss 10.78576374053955 step 20800
ave_batch_reward 5.182332356770833 step 20910
ave_batch_loss 9.54203626844618 step 20910
ave_batch_reward 5.087407535976833 step 21020
ave_batch_loss 9.84144073062473 step 21020
ave_batch_reward 5.104547712537977 step 21130
ave_batch_loss 9.27966441048516 step 21130
ave_batch_reward 4.99278810289171 step 21240
ave_batch_loss 10.818244722154406 step 21240
ave_batch_reward 5.164246188269721 step 21350
ave_batch_loss 10.306227684020996 step 21350
ave_batch_reward 5.269475248124865 step 21460
ave_batch_loss 10.070697943369547 step 21460
ave_batch_reward 5.311410744984944 step 21570
ave_batch_loss 10.658455106947157 step 21570
ave_batch_reward 5.2812844382392035 step 21680
ave_batch_loss 10.474566565619575 step 21680
ave_batc

ave_batch_reward 5.355544567108154 step 30590
ave_batch_loss 10.236200650533041 step 30590
ave_batch_reward 5.264405303531223 step 30700
ave_batch_loss 10.496074358622232 step 30700
ave_batch_reward 5.083683013916016 step 30810
ave_batch_loss 9.901006274753147 step 30810
ave_batch_reward 5.073610226313273 step 30920
ave_batch_loss 10.464320288764107 step 30920
ave_batch_reward 5.463833332061768 step 31030
ave_batch_loss 11.491325166490343 step 31030
ave_batch_reward 5.124676465988159 step 31140
ave_batch_loss 10.957761340671116 step 31140
ave_batch_reward 5.23734548356798 step 31250
ave_batch_loss 9.907944467332628 step 31250
ave_batch_reward 5.129859103096856 step 31360
ave_batch_loss 9.8459259668986 step 31360
ave_batch_reward 5.216495831807454 step 31470
ave_batch_loss 10.501654518975151 step 31470
ave_batch_reward 5.092924859788683 step 31580
ave_batch_loss 10.322150389353434 step 31580
ave_batch_reward 5.288170178731282 step 31690
ave_batch_loss 10.827070660061306 step 31690
ave_b

ave_batch_reward 5.143732865651448 step 40820
ave_batch_loss 10.990800751580132 step 40820
ave_batch_reward 5.308857176038954 step 40930
ave_batch_loss 10.934568405151367 step 40930
ave_batch_reward 5.185686376359728 step 41040
ave_batch_loss 10.428281784057617 step 41040
ave_batch_reward 5.404014958275689 step 41150
ave_batch_loss 9.994269476996529 step 41150
ave_batch_reward 5.116740279727512 step 41260
ave_batch_loss 9.893460591634115 step 41260
ave_batch_reward 5.239932060241699 step 41370
ave_batch_loss 9.63934294382731 step 41370
ave_batch_reward 5.137573136223687 step 41480
ave_batch_loss 10.48480913374159 step 41480
ave_batch_reward 5.274077415466309 step 41590
ave_batch_loss 10.920336405436197 step 41590
ave_batch_reward 5.345853964487712 step 41700
ave_batch_loss 11.065010070800781 step 41700
ave_batch_reward 5.28741807407803 step 41810
ave_batch_loss 10.980045000712076 step 41810
ave_batch_reward 5.127713044484456 step 41920
ave_batch_loss 10.38546699947781 step 41920
ave_ba

ave_batch_reward 5.208127074771458 step 51050
ave_batch_loss 11.141449610392252 step 51050
ave_batch_reward 5.28033791648017 step 51160
ave_batch_loss 11.062172465854221 step 51160
ave_batch_reward 5.164061228434245 step 51270
ave_batch_loss 10.087337017059326 step 51270
ave_batch_reward 5.259736167060004 step 51380
ave_batch_loss 10.50805950164795 step 51380
ave_batch_reward 5.1154369248284235 step 51490
ave_batch_loss 9.998206297556559 step 51490
ave_batch_reward 5.290635744730632 step 51600
ave_batch_loss 9.926471869150797 step 51600
ave_batch_reward 5.19767411549886 step 51710
ave_batch_loss 10.967096010843912 step 51710
ave_batch_reward 5.496160560184055 step 51820
ave_batch_loss 11.084408654106987 step 51820
ave_batch_reward 5.2479141553243 step 51930
ave_batch_loss 11.043375968933105 step 51930
ave_batch_reward 5.0048148896959095 step 52040
ave_batch_loss 10.500757641262478 step 52040
ave_batch_reward 5.122239854600695 step 52150
ave_batch_loss 10.034801165262857 step 52150
ave_

ave_batch_reward 5.0342640611860485 step 61170
ave_batch_loss 8.839059935675728 step 61170
ave_batch_reward 5.078638262218899 step 61280
ave_batch_loss 9.882431772020128 step 61280
ave_batch_reward 4.927578952577379 step 61390
ave_batch_loss 9.820196257697212 step 61390
ave_batch_reward 5.189442740546332 step 61500
ave_batch_loss 10.426003138224283 step 61500
ave_batch_reward 5.105572541554769 step 61610
ave_batch_loss 10.06177446577284 step 61610
ave_batch_reward 5.302096260918511 step 61720
ave_batch_loss 10.54061190287272 step 61720
ave_batch_reward 5.040703747007582 step 61830
ave_batch_loss 9.21821599536472 step 61830
ave_batch_reward 4.937919722663032 step 61940
ave_batch_loss 10.193031205071343 step 61940
ave_batch_reward 4.843268182542589 step 62050
ave_batch_loss 9.816924201117622 step 62050
ave_batch_reward 4.92431214120653 step 62160
ave_batch_loss 9.347163518269857 step 62160
ave_batch_reward 5.1809073024325905 step 62270
ave_batch_loss 10.328520350986057 step 62270
ave_bat

ave_batch_reward 4.73583369784885 step 71400
ave_batch_loss 10.029064231448704 step 71400
ave_batch_reward 5.08882310655382 step 71510
ave_batch_loss 9.01286416583591 step 71510
ave_batch_reward 5.092113018035889 step 71620
ave_batch_loss 10.244057920244005 step 71620
ave_batch_reward 5.23761166466607 step 71730
ave_batch_loss 10.463149070739746 step 71730
ave_batch_reward 5.164229843351576 step 71840
ave_batch_loss 10.802684677971733 step 71840
ave_batch_reward 5.045222653283013 step 71950
ave_batch_loss 10.95126681857639 step 71950
ave_batch_reward 5.358484268188477 step 72060
ave_batch_loss 10.97456497616238 step 72060
ave_batch_reward 5.104195303387112 step 72170
ave_batch_loss 10.828447024027506 step 72170
ave_batch_reward 5.13822078704834 step 72280
ave_batch_loss 9.499467902713352 step 72280
ave_batch_reward 5.138306670718723 step 72390
ave_batch_loss 10.240410380893284 step 72390
ave_batch_reward 5.281153122584025 step 72500
ave_batch_loss 9.886225912306044 step 72500
ave_batch

ave_batch_reward 5.294468694263035 step 81630
ave_batch_loss 11.053453551398384 step 81630
ave_batch_reward 5.310504966311985 step 81740
ave_batch_loss 10.643436008029514 step 81740
ave_batch_reward 5.197653558519152 step 81850
ave_batch_loss 10.487485567728678 step 81850
ave_batch_reward 5.204704019758436 step 81960
ave_batch_loss 10.172848383585611 step 81960
ave_batch_reward 5.1916919814215765 step 82070
ave_batch_loss 10.446391847398546 step 82070
ave_batch_reward 5.258729298909505 step 82180
ave_batch_loss 10.445213317871094 step 82180
ave_batch_reward 5.175738493601481 step 82290
ave_batch_loss 10.387900670369467 step 82290
ave_batch_reward 5.215944078233507 step 82400
ave_batch_loss 10.174155341254341 step 82400
ave_batch_reward 5.108971277872722 step 82510
ave_batch_loss 10.448970158894857 step 82510
ave_batch_reward 5.010713524288601 step 82620
ave_batch_loss 10.209090974595812 step 82620
ave_batch_reward 5.145132011837429 step 82730
ave_batch_loss 9.596218374040392 step 82730

ave_batch_reward 5.301935778723823 step 91860
ave_batch_loss 10.830933888753256 step 91860
ave_batch_reward 5.292596764034695 step 91970
ave_batch_loss 11.030575964185926 step 91970
ave_batch_reward 5.216263771057129 step 92080
ave_batch_loss 10.467002126905653 step 92080
ave_batch_reward 5.263163725535075 step 92190
ave_batch_loss 10.864399062262642 step 92190
ave_batch_reward 5.228837966918945 step 92300
ave_batch_loss 10.853144751654732 step 92300
ave_batch_reward 5.268167018890381 step 92410
ave_batch_loss 11.212727228800455 step 92410
ave_batch_reward 5.021034585105048 step 92520
ave_batch_loss 10.067067305246988 step 92520
ave_batch_reward 5.206266429689196 step 92630
ave_batch_loss 10.495266278584799 step 92630
ave_batch_reward 5.138407892651028 step 92740
ave_batch_loss 10.678661028544107 step 92740
ave_batch_reward 5.193681425518459 step 92850
ave_batch_loss 11.118967692057291 step 92850
ave_batch_reward 5.1619523366292315 step 92960
ave_batch_loss 10.097411897447374 step 9296