In [29]:
import numpy as np
import pygame
import gymnasium as gym
from gymnasium import spaces
import math
from bond_graph import *
from bond_graph_nodes import*
from itertools import permutations
import random
import copy
from gymnasium.envs.registration import register

In [78]:
# --- Configuration -----------------------------------------------------------
seed = None  # set to an int for reproducible sampling; None leaves the spaces unseeded
MAX_PARAM_VAL = 10
num_node_types = 6
max_nodes = 8
# Lowest node-type index the add-node action may choose. Indices below it are
# never offered for adding (the NONE type is among them; what the remaining
# skipped indices stand for is not visible in this cell -- TODO confirm).
FIRST_ADDABLE_NODE_TYPE = 3

# --- Action space ------------------------------------------------------------
# The sub-spaces are intentionally built without their own seed: handing the
# same integer to each of them would give every sub-space an identically seeded
# RNG. Seeding the enclosing Dict once lets gymnasium derive a separate
# sub-seed for each entry instead.

# Node additions pick a node type in [FIRST_ADDABLE_NODE_TYPE, num_node_types).
add_node_space = spaces.Discrete(
    num_node_types - FIRST_ADDABLE_NODE_TYPE,
    start=FIRST_ADDABLE_NODE_TYPE,
)
# Edge additions: two entries in [0, max_nodes) plus one binary entry.
add_edge_space = spaces.MultiDiscrete([max_nodes, max_nodes, 2])

action_space = spaces.Dict(
    {
        'node_or_bond': spaces.Discrete(2, start=0),
        'node_param': spaces.Discrete(MAX_PARAM_VAL, start=1),
        "node_type": add_node_space,
        "bond": add_edge_space,
    },
    seed=seed,
)

print("Action Space: ", action_space)

# --- Flatten / unflatten round trip ------------------------------------------
flattened_action_space = spaces.utils.flatten_space(action_space)
print("Flattened Action Space: ", flattened_action_space)
print(flattened_action_space.shape)

# NOTE: `obs` holds a sample drawn from the *action* space; the name is kept
# as-is because other cells may refer to it.
obs = action_space.sample()
flat_obs = spaces.utils.flatten(action_space, obs)
print("Obs: ", obs)
print("Flat Obs: ", flat_obs)
print(flat_obs.shape)

# Unflattening the flat vector should give back the structured sample above.
unflattened_obs = spaces.utils.unflatten(action_space, flat_obs)
print(unflattened_obs)


Action Space:  Dict('bond': MultiDiscrete([8 8 2]), 'node_or_bond': Discrete(2), 'node_param': Discrete(10, start=1), 'node_type': Discrete(3, start=3))
Flattened Action Space:  Box(0, 1, (33,), int64)
(33,)
Obs:  OrderedDict([('bond', array([7, 2, 0], dtype=int64)), ('node_or_bond', 1), ('node_param', 1), ('node_type', 5)])
Flat Obs:  [0 0 0 0 0 0 0 1 0 0 1 0 0 0 0 0 1 0 0 1 1 0 0 0 0 0 0 0 0 0 0 0 1]
(33,)
OrderedDict([('bond', array([7, 2, 0], dtype=int64)), ('node_or_bond', 1), ('node_param', 1), ('node_type', 5)])


In [28]:
# Smaller flattening illustration: a Dict that mixes a Discrete entry with a
# 2x2 Box entry. The flattened vector has 6 elements (2 for the Discrete entry
# plus 4 for the Box values).
space = spaces.Dict(
    dict(
        position=spaces.Discrete(2),
        velocity=spaces.Box(low=0, high=1, shape=(2, 2)),
    )
)
flat_space = spaces.utils.flatten_space(space)

# Draw one structured sample and convert it to its flat vector form.
obs = space.sample()
flat_obs = spaces.utils.flatten(space, obs)

print(obs)
print(flat_space.shape)
print(flat_obs)


OrderedDict([('position', 0), ('velocity', array([[0.7972299 , 0.27445236],
       [0.9572579 , 0.88732606]], dtype=float32))])
(6,)
[1.         0.         0.79722989 0.27445236 0.95725793 0.88732606]
