In [None]:
from ray.rllib.utils.spaces.repeated import Repeated
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from gym.spaces import Discrete
from gym.utils import seeding
from gym.envs.registration import register
import torch
import numpy as np

# random.choice picks a random element from a sequence,
# e.g. random.choice([1, 2, 3, 4, 5]) -> 2
# random.randint returns a random int from a closed interval,
# e.g. random.randint(0, 10) -> 8
import random

import unittest

In [5]:
from ray.rllib.policy.torch_policy_template import build_policy_class
import torch

def policy_gradient_loss(policy, model, dist_class, train_batch):
    """Return the policy-gradient loss for one training batch.

    The loss is the negated mean of ``logp(actions) * rewards``.

    Args:
        policy: The policy being trained (not used by this loss).
        model: Object whose ``from_batch`` yields ``(logits, state)``.
        dist_class: Action-distribution class, built as
            ``dist_class(logits, model)``.
        train_batch: Mapping with ``"actions"`` and ``"rewards"`` entries.

    Returns:
        A scalar tensor holding the loss.
    """
    logits, _state = model.from_batch(train_batch)
    log_probs = dist_class(logits, model).logp(train_batch["actions"])
    weighted = log_probs * train_batch["rewards"]
    return -torch.mean(weighted)

# Torch policy class produced by the `build_policy_class` template helper,
# using `policy_gradient_loss` as its loss function.
MyTorchPolicy = build_policy_class(
    name="MyTorchPolicy",
    framework="torch",
    loss_fn=policy_gradient_loss,
)

In [None]:
# For the distance function:
# https://networkx.org/documentation/stable/reference/algorithms/shortest_paths.html

class RoadNet(MultiAgentEnv):
    """Multi-agent road-network environment (work in progress).

    Every agent is backed by its own ``MockEnv`` sub-environment, and the
    class additionally tracks one shared position that is meant to reach
    ``goal``.

    NOTE(review): ``MockEnv``, ``distance`` and ``self.MOVE_LF`` are used
    below but are not defined in the visible part of this file -- confirm
    they are provided elsewhere before running ``step``.
    """

    def __init__(self, num_as):
        """Set up constants, per-agent sub-environments and the spaces.

        Args:
            num_as: Number of agents to create.
        """
        self.POS_MIN = 1
        self.POS_MAX = 32
        self.MAX_STEPS = 15

        self.REWARD_AWAY = -2
        self.REWARD_STEP = -1
        # Must go through `self`: there is no module-level MAX_STEPS.
        self.REWARD_GOAL = self.MAX_STEPS

        self.PLATOON = 0

        # Build one sub-environment per agent. Multiplying a one-element
        # list would make every agent share the same instance.
        self.agents = [MockEnv(self.MAX_STEPS) for _ in range(num_as)]

        self.dones = set()
        self.obs = {}
        self.rew = {}
        self.done = {}
        self.info = {}
        self.num_as = num_as

        # Episode state; defined here so render()/step() do not hit an
        # AttributeError when called before the first reset().
        self.count = 0
        self.reward = 0
        self.position = None
        self.state = None

        self.goal = 25
        # list.remove() returns None, so filter the goal out instead.
        self.init_positions = [
            pos for pos in range(self.POS_MIN, self.POS_MAX)
            if pos != self.goal
        ]
        self.seed()

        self.action_space = Repeated(child_space=Discrete(self.POS_MAX),
                                     max_len=5)
        self.observation_space = Discrete(self.POS_MAX + 1)

    def reset(self):
        """Start a new episode.

        Returns:
            Dict mapping each agent index to the value returned by that
            agent's sub-environment ``reset()``.
        """
        # TODO:
        #   - Do not allow agents to start from the same pos.
        #   - Make this as dicts per agent.

        self.dones = set()
        self.obs = {}
        self.rew = {}
        self.done = {}
        self.info = {}

        self.position = self.np_random.choice(self.init_positions)
        self.count = 0
        # Do not carry the previous episode's reward into the new one.
        self.reward = 0

        self.state = self.position

        return {i: a.reset() for i, a in enumerate(self.agents)}

    def step(self, action_dict):
        """Advance every agent named in ``action_dict`` by one step.

        Args:
            action_dict: Mapping from agent index to that agent's action.

        Returns:
            Tuple ``(obs, rew, done, info)`` of per-agent dicts; ``done``
            also carries the ``"__all__"`` key.
        """
        obs, rew, done, info = {}, {}, {}, {}

        for i, action in action_dict.items():
            obs[i], rew[i], done[i], info[i] = self.agents[i].step(action)
            if done[i]:
                self.dones.add(i)
        done["__all__"] = len(self.dones) == len(self.agents)

        # NOTE(review): everything below uses `action` as left over from the
        # loop above, i.e. the last agent's action only (and it is unbound
        # when action_dict is empty).
        if self.done:
            print("EPISODE DONE!!!")
        elif self.count == self.MAX_STEPS:
            self.done = True
        else:
            assert self.action_space.contains(action)
            self.count += 1

        # ---------------------------------------------
        # here define how an action will transition the position
        # of the agent from the current state to the next state
        # in observation_space.
        # also the reward should be included.
        # also include the platoon: should compare the states (pos)
        # of agents and see if they collide, when true platoon increases
        # a reward should be given

        # moving left
        # NOTE(review): MOVE_LF is not defined on this class in the visible
        # source; this comparison raises AttributeError until it is.
        if action == self.MOVE_LF:
            # if we are at the far left then we lost
            # give agent -2
            # We can change this by checking the distance between
            # current pose and goal if it is more than a diameter
            # then -2 should be given else normal move
            if self.position == self.POS_MIN:
                # invalid
                self.reward = self.REWARD_AWAY
            else:
                # else normal left move
                self.position -= 1

            # if we reached the goal
            if self.position == self.goal:
                # on goal now
                self.reward = self.REWARD_GOAL
                self.done = True
            elif distance(self.position, self.goal) > self.MAX_STEPS:
                # moving away from goal
                self.reward = self.REWARD_AWAY
            else:
                # moving toward goal
                self.reward = self.REWARD_STEP
        # ---------------------------------------------

        # Publish the new position first so the check below validates the
        # state produced by this step rather than the previous one.
        self.state = self.position
        if not self.observation_space.contains(self.state):
            print("INVALID STATE", self.state)

        # define the other state and info objects
        self.info["dist"] = distance(self.position, self.goal)

        return obs, rew, done, info

    def render(self, mode="human"):
        """Print the current state, last reward and info dict."""
        # simple print later should render the steps taken by the agents
        print(f'position: {self.state} reward: {self.reward:.2f} info: {self.info}')

    def seed(self, seed=None):
        """Seed the environment's random generator.

        Args:
            seed: Optional seed value passed to ``seeding.np_random``.

        Returns:
            One-element list holding the seed that was used.
        """
        self.np_random, seed = seeding.np_random(seed)
        return [seed]

    def close(self):
        """Nothing to release."""
        pass
    
# Register the environment with gym under the id "roadnet-v0"; the entry
# point string names the RoadNet class in a module called `env`.
register(id="roadnet-v0", entry_point="env:RoadNet")