In [None]:
import gym
import numpy as np
import random
from random import choices
from collections import Counter, deque
from tensorboardX import SummaryWriter

import torch
import torch.nn as nn
import torch.nn.utils as nn_utils
import torch.nn.functional as F
import torch.optim as optim

In [None]:
# Hyperparameters
# ---------------
# Exploration and rollout settings.
epsilon = 0.1      # probability of taking a random action instead of the greedy one
num_rounds = 50    # episodes collected in each pass of the outer training loop
#num_episodes = 500

# Learning settings.
gamma = 0.99
batch_size = 32
learning_limit = 100

# Replay settings.
replay_limit = 1_000  # Number of steps until starting replay
#weight_update = 1000 # Number of steps until updating the target weights


In [None]:
#create model
class Model(nn.Module):
    """Fully connected network mapping a flattened 2-D grid to one value per action.

    Args:
        input_shape: two-element (rows, cols) shape of the observation grid; the
            first linear layer takes ``rows * cols`` inputs.
        n_actions: size of the output vector (one entry per discrete action).
    """

    def __init__(self, input_shape, n_actions):
        super(Model, self).__init__()

        self.net = nn.Sequential(
            nn.Linear(input_shape[0]*input_shape[1], 64),
            nn.ReLU(),
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, n_actions)
        )

    def forward(self, x):
        """Return the network output for a single observation.

        Args:
            x: one observation, given as a numpy array, a torch tensor, or a
                flat/nested sequence of numbers. It is flattened to 1-D, so it
                must hold exactly ``input_shape[0] * input_shape[1]`` values.

        Returns:
            1-D float32 tensor with ``n_actions`` entries.
        """
        # Use the argument that was passed in (not a notebook global) and put it
        # on the same device as the weights so the model also works after .to().
        target_device = next(self.parameters()).device
        tensor_array = torch.as_tensor(x, dtype=torch.float32, device=target_device)
        # flatten the observation space Box to linear tensor
        x_flat = torch.flatten(tensor_array)
        return self.net(x_flat)

In [None]:
class SchedulerEnv(gym.Env):
    """Grid environment in which an agent books appointments into a GP diary.

    The diary is a ``(num_slots, num_gps)`` array: one row per time slot, one
    column per GP, 0 for a free cell and 1 for a booked one. On every step the
    agent moves one cell (actions 0-3 = up, down, left, right) and then tries
    to book the current entry of ``to_book`` starting at its new cell and
    running down the same column. A successful booking sets the reward to 1, a
    failed attempt sets it to -1, and the episode is done once every entry of
    ``to_book`` has been booked.

    Observations returned by both ``reset`` and ``step`` are copies of the
    diary with the agent's cell overwritten by ``AGENT_MARKER``.
    """

    # Value written into the observation at the agent's position.
    AGENT_MARKER = 5

    def __init__(self):

        #starting parameters
        num_gps = 10
        num_slots = 10

        num_pre_booked = 25
        #to_book = [2,1,2,2,1,1,1,3,3,1,2,1,3,2,1,1,2,1,3,2,3,2]
        to_book = [2,1,1,1,1]
        num_to_book = len(to_book)

        # Alternative (disabled) randomised day:
        #   num_pre_booked = random.randint(6*num_gps, 14*num_gps)
        #   num_to_book = random.randint(6*num_gps, 12*num_gps)
        #   to_book = [*choices([1,2,3],[.7, .25, .05])] repeated num_to_book times

        agent_pos = [0,0]

        #set parameters for the day
        self.num_gps = num_gps
        self.num_slots = num_slots
        self.num_pre_booked = num_pre_booked
        self.to_book = to_book
        self.num_to_book = num_to_book
        self.diary_slots = num_gps*num_slots
        self.agent_pos = agent_pos

        #set action space to move around the grid
        self.action_space = gym.spaces.Discrete(4) #up, down, left, right

        #set observation space: float grid holding 0 (free), 1 (booked) or
        #AGENT_MARKER (agent), matching what reset() and step() actually return
        self.observation_space = gym.spaces.Box(
            low=0,
            high=self.AGENT_MARKER,
            shape=(self.num_slots, self.num_gps),
            dtype=np.float64,
        )

    def _observation(self):
        """Return a copy of the diary with the agent's cell set to AGENT_MARKER."""
        agent_state = self.state.copy()
        agent_state[self.agent_pos[0], self.agent_pos[1]] = self.AGENT_MARKER
        return agent_state

    #creates daily diary for each gp, randomly populates prebooked appointments and resets parameters
    def reset(self):
        """Start a new episode and return the first observation.

        Returns:
            A fresh array (not the internal diary) with the agent's cell marked,
            so callers can store it without it changing under later bookings.
        """
        #creates zero filled array with row per time slot and column per gp
        self.state = np.zeros((self.num_slots, self.num_gps), dtype=float)

        #randomly enters a 1 for each pre booked appointment; cells are drawn
        #without replacement so two appointments never land on the same cell
        n_cells = self.state.size
        n_booked = min(self.num_pre_booked, n_cells)
        booked_cells = np.random.choice(n_cells, size=n_booked, replace=False)
        booked_rows, booked_cols = np.unravel_index(booked_cells, self.state.shape)
        self.state[booked_rows, booked_cols] = 1

        #randomly sets the agent start space
        self.agent_pos = [np.random.randint(self.num_slots), np.random.randint(self.num_gps)]

        #resets parameters for new episode
        self.done = False
        self.reward = 0
        self.appt_idx = 0

        return self._observation()

    #calculates new position of the agent based on the action
    def move_agent(self, action):
        """Return the ``[row, col]`` reached by applying ``action``, clamped to the grid.

        Actions other than 0-3 leave the position unchanged. ``self.agent_pos``
        itself is not modified.
        """
        #set boundaries for the grid
        max_row = self.num_slots - 1
        max_col = self.num_gps - 1

        new_row = self.agent_pos[0]
        new_col = self.agent_pos[1]

        if action == 0:    # up
            new_row = max(new_row - 1, 0)
        elif action == 1:  # down
            new_row = min(new_row + 1, max_row)
        elif action == 2:  # left
            new_col = max(new_col - 1, 0)
        elif action == 3:  # right
            new_col = min(new_col + 1, max_col)

        return [new_row, new_col]

    #checks if we can look to book appointment starting here
    def check_bookable(self):
        """Return True when the agent's current cell is free (0)."""
        return self.state[self.agent_pos[0], self.agent_pos[1]] == 0.0

    #action if we can't book the appointment
    def invalid_booking(self):
        """Record a failed booking attempt: reward -1, appointment index unchanged."""
        self.reward = -1

    #action if we can book the appointment
    def valid_booking(self):
        """Record a successful booking: advance to the next appointment, reward 1."""
        self.appt_idx += 1
        self.reward = 1

    #checks if the appointment fits
    def check_and_book(self):
        """Try to book the current appointment downwards from the agent's cell.

        The appointment needs ``to_book[appt_idx]`` consecutive free cells in the
        agent's column, starting at the agent's row. If they exist they are set
        to 1, the agent moves to the last booked cell and ``valid_booking`` is
        called; otherwise ``invalid_booking`` is called. Works for any positive
        appointment length, not only 1-4.

        Returns:
            The internal diary array (``self.state``).
        """
        row, col = self.agent_pos
        length = self.to_book[self.appt_idx]
        last_row = row + length - 1

        # The appointment must have a positive length and end inside the grid.
        fits_in_grid = length >= 1 and last_row <= self.num_slots - 1

        if fits_in_grid and np.all(self.state[row:last_row + 1, col] == 0):
            self.state[row:last_row + 1, col] = 1
            self.valid_booking()
            self.agent_pos = [last_row, col]
        else:
            self.invalid_booking()

        return self.state

    def step(self, action):
        """Move the agent, attempt a booking and return the gym 4-tuple.

        Must be called after ``reset`` and not again once ``done`` is True.

        Returns:
            ``(observation, reward, done, info)`` where ``observation`` is a
            copy of the diary with the agent's cell marked and ``info`` is an
            empty dict.
        """
        #get new position of agent based on action
        new_agent_pos = self.move_agent(action)

        #if the agent is stuck on an edge then move to a new position
        if new_agent_pos == self.agent_pos:
            self.agent_pos = [np.random.randint(self.num_slots), np.random.randint(self.num_gps)]
        else:
            self.agent_pos = new_agent_pos

        #check if it's possible to book then book
        if self.check_bookable():
            self.state = self.check_and_book()
        else:
            self.invalid_booking()

        #work out if episode complete
        if self.appt_idx == len(self.to_book):
            self.done = True

        info = {}
        return self._observation(), self.reward, self.done, info

In [None]:
#device = "cuda"
device = "cpu"

# Number of collect-then-fit iterations and the step cap for a single episode.
num_iterations = 7
max_episode_steps = 50

env = SchedulerEnv()

#start writing to tensorboard
writer = SummaryWriter(comment="Scheduler MC")

#create the action-value network
policy_model = Model(env.observation_space.shape, env.action_space.n).to(device)

optimizer = optim.Adam(policy_model.parameters(), lr=0.001, eps=1e-3)
stp_idx = 0      # environment steps taken so far, across all iterations
episode_idx = 0  # episodes played so far, across all iterations

for a in range(num_iterations):
    print(a)
    # Flat list of [state, action, reward, done] for every step of this iteration.
    episode_list = []
    # [state, action, return] for every step of this iteration.
    tot_reward = []

    for i in range(num_rounds):
        state = env.reset()
        episode_reward = 0
        done = False
        episode_steps = []

        for j in range(max_episode_steps):

            # Select and perform an action (epsilon-greedy)
            if np.random.rand() > epsilon:
                # No graph is needed for action selection; convert to a plain int
                # so the env and the replay records never hold tensors.
                with torch.no_grad():
                    action = int(torch.argmax(policy_model(state)))
            else:
                action = np.random.randint(env.action_space.n)

            # Snapshot before stepping: the env may hand out (and later mutate)
            # its own internal array as the observation.
            state_snapshot = np.array(state, copy=True)

            next_state, reward, done, _ = env.step(action)
            episode_reward += reward

            episode_steps.append([state_snapshot, action, reward, done])

            # Move to the next state
            state = next_state

            stp_idx += 1

            if done:
                break

        episode_list.extend(episode_steps)

        # Undiscounted reward-to-go (gamma is not applied), computed per episode
        # so that an episode cut off by the step cap cannot leak its rewards
        # into the episode recorded before it.
        reward_to_add = 0
        for item in reversed(episode_steps):
            reward_to_add += item[2]
            tot_reward.append([item[0], int(item[1]), reward_to_add])

        # Use a global episode counter so points from different outer
        # iterations do not overwrite each other in tensorboard.
        writer.add_scalar('episode_reward', episode_reward, episode_idx)
        episode_idx += 1

    # create dictionary to be able to average action values to use for training
    test_dict = {}
    for item in tot_reward:
        flat_state = tuple(item[0].flatten())
        int_act = int(item[1])
        k = (flat_state, int_act)
        test_dict.setdefault(k, []).append(item[2])

    state_list = []
    action_list = []
    state_value = []
    for key, v in test_dict.items():
        state_list.append(key[0])
        action_list.append(key[1])
        state_value.append(sum(v) / float(len(v)))

    optimizer.zero_grad()

    # torch.stack keeps every prediction attached to the autograd graph;
    # torch.Tensor(list_of_tensors) would copy the values and cut the graph,
    # leaving the network without gradients.
    predicted = torch.stack(
        [policy_model(s)[act] for s, act in zip(state_list, action_list)]
    )
    target = torch.tensor(state_value, dtype=torch.float32, device=predicted.device)

    loss = F.mse_loss(predicted, target)
    loss.backward()
    writer.add_scalar('loss', loss.item(), stp_idx)

    optimizer.step()


writer.close()

In [None]:
# Shallow copy of episode_list, used by the scratch cells below.
test_example = episode_list.copy()

In [None]:
# Number of recorded entries in test_example.
len(test_example)

In [None]:
# Show the first recorded entry.
test_example[0]

In [None]:
# Walk the recorded entries from last to first and attach a value to each one.
# An entry whose done flag (index 3) is set keeps its own reward and restarts
# the running total at 1; any other entry adds its reward to the running total
# and is stored with that total.
tot_reward = []
reward_to_add = 0
for item in test_example[::-1]:
    if not item[3]:
        reward_to_add += item[2]
        step_value = reward_to_add
    else:
        step_value = item[2]
        reward_to_add = 1
    tot_reward.append([item[0], int(item[1]), step_value])

In [None]:
# Group the values in tot_reward by (flattened state tuple, action) key,
# appending to the lists already present in test_dict.
for item in tot_reward:
    flat_state = tuple(item[0].flatten())
    int_act = int(item[1])
    k = (flat_state, int_act)
    test_dict.setdefault(k, []).append(item[2])

In [None]:
# Second element (the integer action) of the last key built in the cell above.
k[1]

In [None]:
# Call the network with the flattened-state tuple of that key and show the output.
ans = policy_model(k[0])
ans

In [None]:
# Pick the output entry at the key's action index.
ans[k[1]]

In [None]:
# The same call and indexing as the two cells above, written as one expression.
policy_model(k[0])[k[1]]

In [None]:
# #Train
# One manual optimisation step on the averaged values stored in test_dict.
optimizer.zero_grad()

state_list = []
action_list = []
state_value = []
for key in test_dict.keys():
    state_list.append(key[0])
    action_list.append(key[1])
    v = test_dict[key]
    # Mean of the values recorded for this (state, action) key.
    state_value.append(sum(v) / float(len(v)))

predicted = []
target = []
for i in range(len(state_list)):
    predicted.append(policy_model(state_list[i])[action_list[i]])
    target.append(state_value[i])

# torch.stack keeps the predictions attached to the autograd graph so that
# backward() reaches the network weights. torch.Tensor(predicted) would copy
# the values into a new leaf tensor, and forcing requires_grad on the loss
# would only hide the fact that no gradient flows to the model.
predicted_tensor = torch.stack(predicted)
target_tensor = torch.tensor(target, dtype=torch.float32, device=predicted_tensor.device)

loss = F.mse_loss(predicted_tensor, target_tensor)
loss.backward()

optimizer.step()

In [None]:
# Toy dictionary for checking the key-splitting and averaging logic in isolation.
test_dict1 = {
    ('abd', 45): [1, 2, 3],
    ('acd', 54): [1, 3, 3],
    ('a3d', 52): [0, 0, 0],
}

state_list, action_list, state_value = [], [], []
for key, v in test_dict1.items():
    print(len(key))
    state_part, action_part = key
    state_list.append(state_part)
    action_list.append(action_part)
    # A zero total is stored as-is; anything else is stored as the mean.
    total = sum(v)
    state_value.append(total if total == 0 else total / float(len(v)))

In [None]:
# First elements of the test_dict1 keys, in insertion order.
state_list

In [None]:
# Second elements of the test_dict1 keys, in insertion order.
action_list

In [None]:
# Per-key value computed in the loop above (mean of the list, or 0 when it sums to 0).
state_value