In [None]:
import warnings
warnings.filterwarnings("ignore")

import torch
from torch import multiprocessing
from torch import nn
from tensordict.nn import TensorDictModule

from collections import defaultdict

In [None]:
import warnings
warnings.filterwarnings("ignore")

import torch
from torch import multiprocessing

In [None]:
# Pick the compute device: the first CUDA device when CUDA is available and the
# multiprocessing start method is not "fork"; otherwise the CPU.
# NOTE(review): presumably "fork" is excluded because CUDA cannot be used in
# forked worker processes — confirm.
is_fork = multiprocessing.get_start_method() == "fork"

if torch.cuda.is_available() and not is_fork:
    device = torch.device(0)
else:
    device = torch.device("cpu")

## Environment Preparation

#### Load the Unity environment using `mlagents_envs`

In [None]:
import os

from mlagents_envs.side_channel.engine_configuration_channel import EngineConfigurationChannel
from mlagents_envs.environment import UnityEnvironment

from torchrl.envs import UnityMLAgentsEnv

# Location of the Unity build. The default is the original author's local
# build; set the WAREHOUSE_BOT_ENV_PATH environment variable to run the
# notebook on another machine without editing this cell.
DEFAULT_ENV_PATH = "C:/Users/Pawel/Documents/Unity_Project/warehouse-bot-training/environment_builds/test_env_simplified/Warehouse_Bot.exe"
env_path = os.environ.get("WAREHOUSE_BOT_ENV_PATH", DEFAULT_ENV_PATH)

# Side channel used to push engine settings (e.g. time scale) to Unity.
channel = EngineConfigurationChannel()

# Launch the Unity executable and connect to it.
unity_env = UnityEnvironment(
    file_name=env_path,
    side_channels=[channel],
    # additional_args=["-batchmode", "-nographics"]
)

# Ask Unity to run the simulation at 3x time scale.
channel.set_configuration_parameters(time_scale=3)

#### Transform environment from `mlagents` to `gymnasium`

In [None]:
import gymnasium as gym

In [None]:
# Show which Gymnasium version is installed in this kernel.
print(gym.__version__)

In [None]:
import numpy as np
from gymnasium import spaces
from mlagents_envs.environment import UnityEnvironment
from mlagents_envs.base_env import ActionTuple

class UnityGymWrapper(gym.Env):
    """Expose one agent of a Unity ML-Agents environment through the Gymnasium API.

    The wrapper drives the first behavior reported by ``unity_env`` and, within
    that behavior, the first agent that requests a decision. Only the first
    observation sensor of the behavior is exposed, and only purely discrete
    action specs are supported.

    Args:
        unity_env: An already constructed ``UnityEnvironment``.
        seed: Accepted for backward compatibility; it is not used by the wrapper.

    Raises:
        NotImplementedError: If the behavior's action spec is not purely discrete.
    """

    def __init__(self, unity_env, seed=None):
        super().__init__()
        self.unity_env = unity_env
        self.unity_env.reset()
        self.behavior_name = list(self.unity_env.behavior_specs.keys())[0]
        # Kept as ``behavior_spec`` rather than ``spec``: ``gym.Env.spec`` is
        # reserved by Gymnasium for an ``EnvSpec`` (used e.g. by ``str(env)``
        # and ``gym.Wrapper.spec``), so it must not hold a Unity BehaviorSpec.
        self.behavior_spec = self.unity_env.behavior_specs[self.behavior_name]

        # Id of the Unity agent being controlled; resolved from the decision
        # steps on reset() (or lazily on the first step()).
        self._agent_id = None

        # Observation space of the first sensor. ML-Agents supplies observations
        # as floating point arrays, so the space is an unbounded float32 Box;
        # a uint8 Box would make consumers that allocate buffers from the space
        # dtype truncate the values.
        obs_shape = self.behavior_spec.observation_specs[0].shape
        self.observation_space = spaces.Box(
            low=-np.inf, high=np.inf, shape=obs_shape, dtype=np.float32
        )

        # Action space: one discrete branch maps to Discrete, several branches
        # map to MultiDiscrete. Anything else is rejected up front instead of
        # leaving ``action_space`` undefined.
        action_spec = self.behavior_spec.action_spec
        if not action_spec.is_discrete():
            raise NotImplementedError(
                "UnityGymWrapper only supports purely discrete action specs, "
                f"got {action_spec!r} for behavior '{self.behavior_name}'."
            )
        branches = action_spec.discrete_branches
        if len(branches) == 1:
            self.action_space = spaces.Discrete(branches[0])
        else:
            self.action_space = spaces.MultiDiscrete(list(branches))

    def _track_agent(self, decision_steps):
        """Record the first agent awaiting a decision and return its single-agent step."""
        if len(decision_steps) == 0:
            raise RuntimeError(
                f"No agent of behavior '{self.behavior_name}' is requesting a decision."
            )
        self._agent_id = int(decision_steps.agent_id[0])
        return decision_steps[self._agent_id]

    @staticmethod
    def _to_observation(agent_step):
        """Return the first sensor's observation of a single-agent step as float32."""
        return np.asarray(agent_step.obs[0], dtype=np.float32)

    def reset(self, seed=None, options=None):
        """Reset the Unity environment and return ``(observation, info)``.

        ``seed`` is forwarded to ``gym.Env.reset`` so ``self.np_random`` is
        seeded; ``options`` is ignored.
        """
        super().reset(seed=seed)
        self.unity_env.reset()
        decision_steps, _ = self.unity_env.get_steps(self.behavior_name)
        agent_step = self._track_agent(decision_steps)
        return self._to_observation(agent_step), {}

    def step(self, action):
        """Apply ``action`` to the tracked agent and advance Unity by one step.

        Returns:
            ``(observation, reward, terminated, truncated, info)`` where
            ``reward`` is a Python float, both flags are Python bools and
            ``info`` is an empty dict.

        Raises:
            RuntimeError: If the tracked agent is in neither the decision nor
                the terminal steps after the simulation step.
        """
        if self._agent_id is None:
            # step() called without a preceding reset(): pick up the agent from
            # the state left by the reset performed in __init__.
            pending_steps, _ = self.unity_env.get_steps(self.behavior_name)
            self._track_agent(pending_steps)

        # ML-Agents expects a 2-D (n_agents, n_branches) integer array.
        action_tuple = ActionTuple()
        action_tuple.add_discrete(np.asarray(action, dtype=np.int32).reshape(1, -1))

        # Address the tracked agent rather than a hard-coded id 0.
        # NOTE(review): ML-Agents appears to hand out a new agent id after an
        # episode ends, so id 0 is only valid for the first episode — confirm.
        self.unity_env.set_action_for_agent(self.behavior_name, self._agent_id, action_tuple)
        self.unity_env.step()

        decision_steps, terminal_steps = self.unity_env.get_steps(self.behavior_name)

        if self._agent_id in terminal_steps:
            agent_step = terminal_steps[self._agent_id]
            # ``interrupted`` marks an episode cut short from outside the MDP
            # (e.g. a step limit), which is Gymnasium's "truncated"; any other
            # episode end is "terminated". The two flags are mutually exclusive.
            truncated = bool(agent_step.interrupted)
            terminated = not truncated
        elif self._agent_id in decision_steps:
            agent_step = decision_steps[self._agent_id]
            terminated = False
            truncated = False
        else:
            raise RuntimeError(
                f"Agent {self._agent_id} of behavior '{self.behavior_name}' produced "
                "neither a decision step nor a terminal step."
            )

        obs = self._to_observation(agent_step)
        reward = float(agent_step.reward)
        return obs, reward, terminated, truncated, {}

    def render(self, mode='human'):
        """No-op: Unity renders its own window."""
        pass

    def close(self):
        """Shut down the underlying Unity environment."""
        self.unity_env.close()

In [None]:
# Wrap the running Unity environment so it can be used through the Gymnasium API.
gymnasium_env = UnityGymWrapper(unity_env)

In [None]:
# gymnasium_env.step(0)

#### Creating stable_baselines3 model

In [None]:
from stable_baselines3 import PPO

# Build a PPO agent with the built-in MLP policy on the wrapped Unity
# environment (verbose=1 enables training logs), then train it for
# 100,000 environment steps.
model = PPO(
    policy="MlpPolicy",
    env=gymnasium_env,
    verbose=1,
)
model.learn(total_timesteps=100_000)