In [1]:
import matplotlib.pyplot as plt
import torch as th
from stable_baselines3.common.vec_env import DummyVecEnv

import numpy as np

from imitation.algorithms.mce_irl import (
    MCEIRL,
    mce_occupancy_measures,
    mce_partition_fh,
    TabularPolicy,
)
from imitation.data import rollout
from imitation.rewards import reward_nets


In [2]:
from typing import Tuple
import gym
from gymnasium import spaces
import numpy as np
import networkx as nx


class GraphEnvironment(gym.Env):
    def __init__(self, G):
        super(GraphEnvironment, self).__init__()

        if not isinstance(G, nx.Graph):
            raise ValueError("G must be a networkx Graph")

        self.G = G

        # Create a mapping from node IDs to indices
        self.node_to_index = {node: idx for idx, node in enumerate(G.nodes())}
        self.index_to_node = {idx: node for node, idx in self.node_to_index.items()}

        self.num_nodes = G.number_of_nodes()
        self.num_edges = G.number_of_edges()

        self.node_encoding = np.eye(self.num_nodes)
        self.edge_encoding = np.eye(self.num_edges)

        self.action_space = spaces.Discrete(self.num_nodes)
        self.observation_space = spaces.Box(low=0, high=1, shape=(self.num_nodes,), dtype=np.float32)

        self.current_state = None

    def reset(self, seed=None, **kwargs):
        self.seed(seed)

        # Randomly select a starting node index
        start_node_idx = np.random.choice(self.num_nodes)
        start_node = self.index_to_node[start_node_idx]

        self.current_state = self.node_encoding[start_node_idx]
        observation = np.array(self.current_state, dtype=np.float32)

        self.update_action_space(start_node)
        return observation, {}

    def update_action_space(self, current_node):
        # Get the nodes connected to the current node
        connected_nodes = list(self.G[current_node])
        connected_node_indices = [self.node_to_index[node] for node in connected_nodes]

        self.action_mapping = {idx: node for idx, node in enumerate(connected_nodes)}
        self.action_space = spaces.Discrete(len(connected_nodes))

    def step(self, action):
        # We don't care about the step mechanics because we don't want an accurate simulation of navigation app, just imitation learning to learn the reward function

        # Map the action to the actual node
        chosen_node = self.action_mapping[action]
        chosen_node_idx = self.node_to_index[chosen_node]

        # Update current state and action space
        self.current_state = self.node_encoding[chosen_node_idx]
        self.update_action_space(chosen_node)

        # Define your reward, done, and info
        reward = ...  # Define your reward logic
        done = ...    # Define your done logic
        info = {}

        return self.current_state, reward, done, info

    def seed(self, seed=None):
        # You can use the seed method to set the seed for the environment's random number generator
        # This is optional and depends on whether your environment uses randomization
        self.np_random, seed = gym.utils.seeding.np_random(seed)
        return [seed]

    def render(self, mode='human', close=False):
        # Implement the render function (optional)
        # ...
        pass

    def close(self):
        # Implement the close method (optional)
        # ...
        pass
    # Implement other methods (step, render, etc.) as needed


In [3]:
import osmnx as ox
centre_point = (51.5074, 0.1278)
G = ox.graph_from_point(centre_point, dist=150, network_type='drive')

In [4]:
env = GraphEnvironment(G=G)

In [5]:

# def train_mce_irl(demos, hidden_sizes, lr=0.01, **kwargs):
#     reward_net = reward_nets.BasicRewardNet(
#         env.observation_space,
#         env.action_space,
#         hid_sizes=hidden_sizes,
#         use_action=False,
#         use_done=False,
#         use_next_state=False,
#     )

#     mce_irl = MCEIRL(
#         demos,
#         env,
#         reward_net,
#         log_interval=250,
#         optimizer_kwargs=dict(lr=lr),
#         rng=rng,
#     )
#     occ_measure = mce_irl.train(**kwargs)

#     imitation_trajs = rollout.generate_trajectories(
#         policy=mce_irl.policy,
#         venv=state_venv,
#         sample_until=rollout.make_min_timesteps(5000),
#         rng=rng,
#     )
#     print("Imitation stats: ", rollout.rollout_stats(imitation_trajs))

#     plt.figure(figsize=(10, 5))
#     plt.subplot(1, 2, 1)
#     env.draw_value_vec(occ_measure)
#     plt.title("Occupancy for learned reward")
#     plt.xlabel("Gridworld x-coordinate")
#     plt.ylabel("Gridworld y-coordinate")
#     plt.subplot(1, 2, 2)
#     _, true_occ_measure = mce_occupancy_measures(env)
#     env.draw_value_vec(true_occ_measure)
#     plt.title("Occupancy for true reward")
#     plt.xlabel("Gridworld x-coordinate")
#     plt.ylabel("Gridworld y-coordinate")
#     plt.show()

#     plt.figure(figsize=(10, 5))
#     plt.subplot(1, 2, 1)
#     env.draw_value_vec(
#         reward_net(th.as_tensor(env.observation_matrix), None, None, None)
#         .detach()
#         .numpy()
#     )
#     plt.title("Learned reward")
#     plt.xlabel("Gridworld x-coordinate")
#     plt.ylabel("Gridworld y-coordinate")
#     plt.subplot(1, 2, 2)
#     env.draw_value_vec(env.reward_matrix)
#     plt.title("True reward")
#     plt.xlabel("Gridworld x-coordinate")
#     plt.ylabel("Gridworld y-coordinate")
#     plt.show()

#     return mce_irl

In [6]:
def train_mce_irl(env, demos, hidden_sizes, lr=0.01, **kwargs):
    reward_net = reward_nets.BasicRewardNet(
        observation_space=env.observation_space,
        action_space=env.action_space,
        hid_sizes=hidden_sizes,
        use_action=False,
        use_done=False,
        use_next_state=False,
    )

    # Assuming rng is defined; otherwise, you need to define it
    rng = np.random.default_rng()

    mce_irl = MCEIRL(
        demos,
        env,
        reward_net,
        log_interval=250,
        optimizer_kwargs=dict(lr=lr),
        rng=rng,
    )
    occ_measure = mce_irl.train(**kwargs)

    # ... (rest of your training function)


In [15]:
import numpy as np

demos = []
num_demos = 20
demo_length = 5  # Assuming each demo has 5 steps

for _ in range(num_demos):
    # Sample random nodes from G
    random_nodes = np.random.choice(list(G.nodes()), size=demo_length, replace=False)
    
    for i in range(demo_length - 1):
        # Current observation (state)
        obs = env.node_encoding[env.node_to_index[random_nodes[i]]] 
        if not isinstance(obs, np.ndarray):
            obs = np.array(obs)

        # Next observation (state)
        next_obs = env.node_encoding[env.node_to_index[random_nodes[i + 1]]]
        if not isinstance(next_obs, np.ndarray):
            next_obs = np.array(next_obs)

        # Assuming the action is the index of the next node
        # This depends on how your environment defines actions
        action = env.node_to_index[random_nodes[i + 1]]

        # Convert action to a numpy array if it's not already
        if not isinstance(action, np.ndarray):
            action = np.array([action])  # Make it an array

        # Check if it's the last step in the demo
        done = i == (demo_length - 2)

        demo_step = {
            "obs": obs,
            "actions": action,
            "next_obs": next_obs,
            # "dones": np.array([done])  # Make sure done is an array
        }
        demos.append(demo_step)


In [16]:
demos[-1]

{'obs': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.]),
 'actions': array([10]),
 'next_obs': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])}

In [17]:
len(G.nodes())

32

In [18]:
len(demos[0]['obs'])

32

In [19]:
mce_irl_model = train_mce_irl(env, demos, hidden_sizes=[64, 64])

AttributeError: 'GraphEnvironment' object has no attribute 'state_dim'

In [None]:
env.observation_space

Box(0.0, 1.0, (32,), float32)

In [None]:
spaces.utils.flatdim(env.observation_space)


32

array(Box(0.0, 1.0, (32,), float32), dtype=object)