In [2]:
import osmnx as ox
centre_point = (51.5074, 0.1278)
G = ox.graph_from_point(centre_point, dist=150, network_type='drive')

In [29]:
from typing import Tuple
import gym
from gym import spaces
import numpy as np
import networkx as nx
class GraphEnvironment(gym.Env):
    def __init__(self, G):
        super(GraphEnvironment, self).__init__()

        if not isinstance(G, nx.Graph):
            raise ValueError("G must be a networkx Graph")

        self.G = G

        # Create a mapping from node IDs to indices
        self.node_to_index = {node: idx for idx, node in enumerate(G.nodes())}
        self.index_to_node = {idx: node for node, idx in self.node_to_index.items()}

        self.num_nodes = G.number_of_nodes()
        self.num_edges = G.number_of_edges()

        self.node_encoding = np.eye(self.num_nodes)
        self.edge_encoding = np.eye(self.num_edges)

        self.action_space = spaces.Discrete(self.num_nodes)
        self.observation_space = spaces.Box(low=0, high=1, shape=(self.num_nodes,), dtype=np.float32)

        self.current_state = None

    def reset(self, seed=None, **kwargs):
        self.seed(seed)

        # Randomly select a starting node index
        start_node_idx = np.random.choice(self.num_nodes)
        start_node = self.index_to_node[start_node_idx]

        self.current_state = self.node_encoding[start_node_idx]
        observation = np.array(self.current_state, dtype=np.float32)

        self.update_action_space(start_node)
        return observation, {}

    def update_action_space(self, current_node):
        # Get the nodes connected to the current node
        connected_nodes = list(self.G[current_node])
        connected_node_indices = [self.node_to_index[node] for node in connected_nodes]

        self.action_mapping = {idx: node for idx, node in enumerate(connected_nodes)}
        self.action_space = spaces.Discrete(len(connected_nodes))

    def step(self, action):
        # Map the action to the actual node
        chosen_node = self.action_mapping[action]
        chosen_node_idx = self.node_to_index[chosen_node]

        # Update current state and action space
        self.current_state = self.node_encoding[chosen_node_idx]
        self.update_action_space(chosen_node)

        # Define your reward, done, and info
        reward = ...  # Define your reward logic
        done = ...    # Define your done logic
        info = {}

        return self.current_state, reward, done, info




    def seed(self, seed=None):
        # You can use the seed method to set the seed for the environment's random number generator
        # This is optional and depends on whether your environment uses randomization
        self.np_random, seed = gym.utils.seeding.np_random(seed)
        return [seed]

    def render(self, mode='human', close=False):
        # Implement the render function (optional)
        # ...
        pass

    def close(self):
        # Implement the close method (optional)
        # ...
        pass
    # Implement other methods (step, render, etc.) as needed


In [30]:
env = GraphEnvironment(G=G)


In [32]:
env.reset()

(array([0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       dtype=float32),
 {})

In [24]:
len(G[35486909])

2

In [33]:
from gym.utils.env_checker import check_env
check_env(env, warn=True)

  logger.warn("`check_env(warn=...)` parameter is now ignored.")
  logger.warn(
  logger.deprecation(
  if not isinstance(done, (bool, np.bool8)):
  logger.warn(
  logger.warn(
  logger.warn(f"{pre} is not within the observation space.")
  logger.warn(


__main__.GraphEnvironment

True