In [1]:
import gymnasium as gym
import numpy as np
import soulsgym
import torch
from torchrl.envs.utils import check_env_specs
from torchrl.envs import GymWrapper, TransformedEnv
from torchrl.data import (
    Composite,
    OneHot,
    Bounded,
    TensorSpec
)
from tensordict import TensorDict
from tensordict.nn import TensorDictModule
import multiprocessing



In [2]:
# Use CUDA only when it is available and the multiprocessing start method is
# not "fork"; otherwise fall back to the CPU.
is_fork = multiprocessing.get_start_method() == "fork"
if torch.cuda.is_available() and not is_fork:
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

device(type='cuda')

In [3]:
raw_env = gym.make("SoulsGymIudex-v0")

print("Raw Environment Specs:")
print("Observation Space:", raw_env.observation_space)

# The observation space is a Dict; list each named sub-space separately.
for sub_name, sub_space in raw_env.observation_space.spaces.items():
    print(sub_name, sub_space, '\n')

print("Action Space:", raw_env.action_space)

Raw Environment Specs:
Observation Space: Dict('boss_animation': Discrete(33, start=-1), 'boss_animation_duration': Box(0.0, 10.0, (1,), float32), 'boss_hp': Box(0.0, 1037.0, (1,), float32), 'boss_max_hp': Discrete(1, start=1037), 'boss_pose': Box([110.     540.     -73.      -3.1416], [190.     640.     -55.       3.1416], (4,), float32), 'camera_pose': Box([110. 540. -73.  -1.  -1.  -1.], [190. 640. -55.   1.   1.   1.], (6,), float32), 'lock_on': Discrete(2), 'phase': Discrete(2, start=1), 'player_animation': Discrete(51, start=-1), 'player_animation_duration': Box(0.0, 10.0, (1,), float32), 'player_hp': Box(0.0, 454.0, (1,), float32), 'player_max_hp': Discrete(1, start=454), 'player_max_sp': Discrete(1, start=95), 'player_pose': Box([110.     540.     -73.      -3.1416], [190.     640.     -55.       3.1416], (4,), float32), 'player_sp': Box(0.0, 95.0, (1,), float32))
boss_animation Discrete(33, start=-1) 

boss_animation_duration Box(0.0, 10.0, (1,), float32) 

boss_hp Box(0.0, 10

In [4]:
from gymnasium import ObservationWrapper
from gymnasium.spaces import Box, Discrete

class FlattenObservationWrapper(ObservationWrapper):
    """Flatten a Gymnasium ``Dict`` observation into one 1-D float32 NumPy vector.

    Each sub-space of the wrapped env's ``Dict`` observation space contributes:

    * ``Box``: all of its elements, flattened (``reshape(-1)``);
    * ``Discrete``: a single element holding its value cast to float32
      (not one-hot encoded). ``Discrete(1)`` sub-spaces are skipped because
      they can only ever hold one value.

    Any other sub-space type raises ``NotImplementedError``.

    ``reset`` and ``step`` follow the Gymnasium API: ``reset`` returns
    ``(obs, info)`` and ``step`` returns ``(obs, reward, terminated, truncated, info)``.

    Attributes:
        ob_keys: keys of the wrapped ``Dict`` space, in ``spaces.keys()`` order.
        space_info: list of ``(key, space, flat_dim)`` for every kept key, in the
            order its values appear in the flattened vector.
        observation_space: unbounded float32 ``Box`` of shape ``(total_dim,)``.
    """

    def __init__(self, env):
        super().__init__(env)

        self.ob_keys = list(self.env.observation_space.spaces.keys())
        self.space_info = []

        # Work out how many flat elements each kept sub-space contributes.
        total_dim = 0
        for key in self.ob_keys:
            space = self.env.observation_space[key]
            if isinstance(space, Box):
                dim = int(np.prod(space.shape))
            elif isinstance(space, Discrete):
                if space.n == 1:
                    # A constant (e.g. boss_max_hp) carries no information.
                    continue
                dim = 1
            else:
                raise NotImplementedError(f"Unsupported space type: {type(space)}")
            self.space_info.append((key, space, dim))
            total_dim += dim

        self._flat_dim = total_dim

        # observation() does not clip values, so advertise an unbounded space.
        self.observation_space = Box(
            low=-float("inf"),
            high=float("inf"),
            shape=(total_dim,),
            dtype=np.float32,
        )

    def observation(self, observation_dict):
        """Concatenate the kept entries of ``observation_dict`` into one vector.

        Args:
            observation_dict: mapping containing every key in ``space_info``
                (array-like values for Box entries, scalars for Discrete entries).

        Returns:
            np.ndarray: float32 array of shape ``(total_dim,)``, matching
            ``self.observation_space``.

        Raises:
            KeyError: if a kept key is missing from ``observation_dict``.
            ValueError: if an entry's element count differs from its declared size.
        """
        parts = []
        for key, _space, dim in self.space_info:
            # reshape(-1) makes Box arrays 1-D and Discrete scalars shape (1,).
            part = np.asarray(observation_dict[key], dtype=np.float32).reshape(-1)
            if part.size != dim:
                raise ValueError(
                    f"Observation '{key}' has {part.size} elements, expected {dim}"
                )
            parts.append(part)
        if not parts:
            # Every sub-space was skipped; np.concatenate rejects an empty list.
            return np.zeros(0, dtype=np.float32)
        return np.concatenate(parts)

    def reset(self, **kwargs):
        """Reset the wrapped env and return ``(flat_observation, info)``."""
        obs, info = self.env.reset(**kwargs)
        return self.observation(obs), info

    def step(self, action):
        """Step the wrapped env and return the Gymnasium 5-tuple with a flat observation."""
        obs, reward, terminated, truncated, info = self.env.step(action)
        return self.observation(obs), reward, terminated, truncated, info

flattened_env = FlattenObservationWrapper(raw_env)
print("Flattened Observation Shape:", flattened_env.observation_space.shape)

# Show how each kept sub-space maps onto the flat vector: Box entries report
# their shape, Discrete entries report their number of values.
print("\nFlattened Observation Keys and Dimensions:")
for sub_key, sub_space, flat_dim in flattened_env.space_info:
    if isinstance(sub_space, Box):
        shape_desc = sub_space.shape
    else:
        shape_desc = sub_space.n
    print(f" - {sub_key}: shape={shape_desc}, flat_dim={flat_dim}")


Flattened Observation Shape: (23,)

Flattened Observation Keys and Dimensions:
 - boss_animation: shape=33, flat_dim=1
 - boss_animation_duration: shape=(1,), flat_dim=1
 - boss_hp: shape=(1,), flat_dim=1
 - boss_pose: shape=(4,), flat_dim=4
 - camera_pose: shape=(6,), flat_dim=6
 - lock_on: shape=2, flat_dim=1
 - phase: shape=2, flat_dim=1
 - player_animation: shape=51, flat_dim=1
 - player_animation_duration: shape=(1,), flat_dim=1
 - player_hp: shape=(1,), flat_dim=1
 - player_pose: shape=(4,), flat_dim=4
 - player_sp: shape=(1,), flat_dim=1


In [5]:
# Expose the flattened Gymnasium env through TorchRL's GymWrapper.
wrapped_env = GymWrapper(env=flattened_env)

# Helper that normalizes reset() output to an (observation, info) pair.
def safe_reset(env):
    """Call ``env.reset()`` and always return an ``(observation, info)`` pair.

    Some environments return ``(observation, info)`` from ``reset()`` while
    others return only the observation. This helper normalizes both forms.

    Args:
        env: any object exposing a no-argument ``reset()`` method.

    Returns:
        tuple: ``(observation, info)``. ``info`` is ``{}`` when ``reset()``
        did not return an ``(observation, dict)`` pair.
    """
    result = env.reset()
    # Require a dict in the second slot so that an observation which is itself
    # a 2-tuple is not mistaken for an (observation, info) pair.
    if isinstance(result, tuple) and len(result) == 2 and isinstance(result[1], dict):
        return result
    return result, {}


In [6]:
# Wrap in a TransformedEnv (no transform is passed yet), then move it to `device`.
env = TransformedEnv(wrapped_env)
env = env.to(device)

In [7]:
# Show the TorchRL specs, then validate them against a short rollout.
print(f"Observation spec: {env.observation_spec}")
print(f"Action spec: {env.action_spec}")
# NOTE(review): the saved run of this cell ended with a 'SetForegroundWindow'
# error. Presumably soulsgym could not focus the game window while resetting;
# confirm the game is running and its window can take focus.
check_env_specs(env)

Observation spec: Composite(
    observation: UnboundedContinuous(
        shape=torch.Size([23]),
        space=ContinuousBox(
            low=Tensor(shape=torch.Size([23]), device=cuda:0, dtype=torch.float32, contiguous=True),
            high=Tensor(shape=torch.Size([23]), device=cuda:0, dtype=torch.float32, contiguous=True)),
        device=cuda:0,
        dtype=torch.float32,
        domain=continuous),
    device=cuda:0,
    shape=torch.Size([]))
Action spec: OneHot(
    shape=torch.Size([20]),
    space=CategoricalBox(n=20),
    device=cuda:0,
    dtype=torch.int64,
    domain=discrete)


error: (0, 'SetForegroundWindow', 'No error message is available')