In [1]:

import warnings
warnings.filterwarnings("ignore")

import torch as th
from torch import multiprocessing
import numpy as np

# Environment imports
from src.environments.env_utils import make_env

In [None]:
# NOTE(review): the banner below mentions "Stage1" and "Training", while the
# env_path further down points at a stage2 build — confirm which is intended.
print("Starting PPO Training for Warehouse Stage1 Complex Pos Neg 3...")

# Device selection: use CUDA device 0 only when CUDA is available AND the
# multiprocessing start method is not "fork"; in every other case use the CPU.
# (Presumably because CUDA does not work in forked subprocesses — TODO confirm.)
is_fork = multiprocessing.get_start_method() == "fork"
if th.cuda.is_available() and not is_fork:
    device = th.device(0)
else:
    device = th.device("cpu")
print(f"Using device: {device}")

# Build the environment from the packaged executable, with graphics enabled
# and verbose output, in "multimodal" mode.
print("\nCreating environment...")
env = make_env(
    time_scale=1,
    no_graphics=False,
    verbose=True,
    env_type="multimodal",
    env_path='environment_builds/stage2/S2_Find_Items_64x36camera120deg_rew0_100/Warehouse_Bot.exe',
)

# Show the observation space so the expected inputs can be inspected.
print(env.observation_space)

In [None]:
from src.models.actor_critic import ActorCriticMultimodal

# Pull the sizes the policy needs out of the environment's spaces:
#   - shape of the 'visual' observation,
#   - first dimension of the 'vector' observation,
#   - number of discrete actions.
image_shape = env.observation_space['visual'].shape
vector_shape = env.observation_space['vector'].shape[0]
n_actions = env.action_space.n

# One value per line, for a quick sanity check.
print(image_shape, vector_shape, n_actions, sep="\n")

# Build the multimodal actor-critic on the selected device.
model = ActorCriticMultimodal(
    act_dim=n_actions,
    visual_obs_size=image_shape,
    vector_obs_size=vector_shape,
    device=device,
)

# Roll the model out for 10 episodes, seeding each reset with the episode
# index. Rewards are received from the environment but not used here.
for episode in range(10):
    obs, _ = env.reset(seed=episode)
    done = truncated = False

    # Keep stepping until the environment reports the episode as done or truncated.
    while not done and not truncated:
        # Query the model for an action without tracking gradients.
        with th.no_grad():
            action, _, _, _ = model.get_action(obs)
        # Pass the action to the environment as a plain Python scalar.
        obs, reward, done, truncated, _ = env.step(action.item())