In [1]:
import warnings
warnings.filterwarnings("ignore")

import torch as th
from torch import multiprocessing

In [2]:
# Record whether the current multiprocessing start method is "fork".
is_fork = "fork" == multiprocessing.get_start_method()

# Pick CUDA device 0 only when CUDA is available and the start method is not
# "fork"; in every other case run on the CPU.
if th.cuda.is_available() and not is_fork:
    device = th.device(0)
else:
    device = th.device("cpu")

## Environment Preparation

#### Transform environment from `mlagents` to `gymnasium`

In [None]:
import gymnasium as gym
# Print the installed gymnasium version so it is recorded in the notebook output.
print(gym.__version__)

In [4]:
from mlagents_envs.side_channel.engine_configuration_channel import EngineConfigurationChannel
from mlagents_envs.environment import UnityEnvironment

from env_camera_raycasts_gymnasium_wrapper import UnityCameraRaycastsGymWrapper

# Default location of the Unity build launched by make_env().
env_path = "D:/_Thesis/warehouse-bot-training/environment_builds/warehouse_stage2_find/Warehouse_Bot.exe"


def make_env(file_name=None, time_scale=1):
    """Launch the Unity build and wrap it with UnityCameraRaycastsGymWrapper.

    Args:
        file_name: Path of the Unity executable to launch. When ``None``
            (the default) the module-level ``env_path`` is used.
        time_scale: Value forwarded to
            ``EngineConfigurationChannel.set_configuration_parameters``.
            Defaults to 1, the value that was previously hard-coded.

    Returns:
        The ``UnityCameraRaycastsGymWrapper`` built around the launched
        ``UnityEnvironment``. Its observation space is printed once.

    Raises:
        Any exception raised while configuring or wrapping the environment;
        the Unity environment is closed before the exception propagates.
    """
    if file_name is None:
        file_name = env_path

    channel = EngineConfigurationChannel()
    unity_env = UnityEnvironment(
        file_name=file_name,
        side_channels=[channel],
    )

    try:
        channel.set_configuration_parameters(time_scale=time_scale)
        gymnasium_env = UnityCameraRaycastsGymWrapper(unity_env)
    except Exception:
        # Do not leave the launched Unity process running when setup fails.
        unity_env.close()
        raise

    print(gymnasium_env.observation_space)

    return gymnasium_env

#### Policy Config

In [5]:
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.distributions import Categorical

# Actor-Critic Network
class ActorCritic(nn.Module):
    """Actor-critic MLP: a shared two-layer Tanh trunk with two linear heads.

    The policy head outputs ``act_dim`` logits for a categorical action
    distribution; the value head outputs a single state-value estimate.

    Args:
        obs_dim: Size of the last dimension of the observation tensor.
        act_dim: Number of discrete actions.
    """

    def __init__(self, obs_dim, act_dim):
        super().__init__()
        self.shared = nn.Sequential(
            nn.Linear(obs_dim, 64),
            nn.Tanh(),
            nn.Linear(64, 64),
            nn.Tanh()
        )
        self.policy_head = nn.Linear(64, act_dim)
        self.value_head = nn.Linear(64, 1)

    def forward(self, x):
        """Return ``(logits, value)``; ``value`` keeps its trailing dim of size 1."""
        x = self.shared(x)
        return self.policy_head(x), self.value_head(x)

    def get_action(self, obs):
        """Sample an action for ``obs``.

        Returns:
            Tuple ``(action, log_prob, entropy, value)``. ``value`` has the
            same leading shape as ``action``.
        """
        logits, value = self.forward(obs)
        dist = Categorical(logits=logits)
        action = dist.sample()
        # squeeze(-1) drops only the value head's trailing dim; a bare
        # squeeze() would also collapse a batch dimension of size 1.
        return action, dist.log_prob(action), dist.entropy(), value.squeeze(-1)

    def evaluate_actions(self, obs, actions):
        """Score given ``actions`` under the current policy.

        Returns:
            Tuple ``(log_probs, entropy, values)``, each with the same leading
            shape as ``actions``.
        """
        logits, values = self.forward(obs)
        dist = Categorical(logits=logits)
        log_probs = dist.log_prob(actions)
        entropy = dist.entropy()
        # Keep the batch dimension even when the minibatch holds one sample.
        return log_probs, entropy, values.squeeze(-1)

class Swish(nn.Module):
    """Element-wise activation that multiplies the input by its own sigmoid."""

    def forward(self, x):
        gate = th.sigmoid(x)
        return x * gate

class ActorCriticMultimodal(nn.Module):
    """Actor-critic network over a dict observation with image and vector parts.

    The image goes through a small CNN followed by an MLP, the vector through
    an MLP; both 64-dim feature vectors are concatenated and passed through a
    shared Tanh trunk that feeds a policy head (logits) and a value head.

    Args:
        act_dim: Number of discrete actions.
        visual_size: ``(bands, height, width)`` of the image observation.
        vector_obs_size: Length of the vector observation.
    """

    def __init__(self, act_dim, visual_size=(3, 36, 64), vector_obs_size=128):
        super().__init__()
        bands = visual_size[0]

        # Shapes of image and vector inputs: [<batch size>, <bands, height, width>], [<batch size>, <length>]

        visual_out_size = 64
        vector_out_size = 64

        # Visual Encoder
        self.visual_encoder_cnn = nn.Sequential(
            nn.Conv2d(bands, 16, kernel_size=5, stride=4, padding=0),
            nn.LeakyReLU(0.01),
            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.01),
            nn.Flatten(),
        )
        # Infer the flattened CNN output size by pushing a dummy image through it.
        dummy_input = th.zeros(1, bands, visual_size[1], visual_size[2])
        with th.no_grad():
            visual_encoder_cnn_out_size = self.visual_encoder_cnn(dummy_input).shape[1]

        self.visual_encoder_mlp = nn.Sequential(
            nn.Linear(visual_encoder_cnn_out_size, 64),
            Swish(),
            nn.Linear(64, visual_out_size),
            Swish()
        )

        # Vector Encoder
        self.vector_encoder = nn.Sequential(
            nn.Linear(vector_obs_size, 64),
            Swish(),
            nn.Linear(64, vector_out_size),
            Swish()
        )

        # Concatenation Network
        self.shared = nn.Sequential(
            nn.Linear(visual_out_size + vector_out_size, 64),
            nn.Tanh(),
            nn.Linear(64, 64),
            nn.Tanh()
        )
        self.policy_head = nn.Linear(64, act_dim)
        self.value_head = nn.Linear(64, 1)

    def forward(self, observations):
        """Return ``(logits, value)`` for a batched observation dict.

        Args:
            observations: Mapping with an ``"image"`` tensor of shape
                ``[batch, bands, height, width]`` and a ``"vector"`` tensor of
                shape ``[batch, length]``. Both are cast to float.
        """
        image = observations["image"].float()
        # Cast like the image so integer or float64 vectors do not break nn.Linear.
        vector = observations["vector"].float()

        image_features = self.visual_encoder_cnn(image)
        image_features = self.visual_encoder_mlp(image_features)
        vector_features = self.vector_encoder(vector)

        combined = th.cat([image_features, vector_features], dim=1)
        x = self.shared(combined)
        return self.policy_head(x), self.value_head(x)

    def get_action(self, obs):
        """Sample an action; returns ``(action, log_prob, entropy, value)``."""
        logits, value = self.forward(obs)
        dist = Categorical(logits=logits)
        action = dist.sample()
        # squeeze(-1) drops only the value head's trailing dim, so a batch of
        # one keeps its batch dimension.
        return action, dist.log_prob(action), dist.entropy(), value.squeeze(-1)

    def evaluate_actions(self, obs, actions):
        """Score ``actions``; returns ``(log_probs, entropy, values)`` per sample."""
        logits, values = self.forward(obs)
        dist = Categorical(logits=logits)
        log_probs = dist.log_prob(actions)
        entropy = dist.entropy()
        return log_probs, entropy, values.squeeze(-1)
    

#### Training Algorithm

In [6]:
# Hyperparameters
env_id = "CartPole-v1"  # only referenced by the commented-out gym.make call in the training cell

# Discount factor and GAE lambda handed to RolloutBuffer.compute_gae
gamma, lam = 0.99, 0.95

# PPO update settings
clip_eps = 0.2            # clamp range for the probability ratio
ppo_epochs = 40           # passes over the rollout per update
batch_size = 128          # minibatch size inside each pass
update_timesteps = 1024   # step threshold checked when an episode ends
lr = 3e-4                 # Adam learning rate

# Loss weights
val_loss_coef = 0.5
ent_loss_coef = 0.002

# Note: this rebinds the `device` chosen in the earlier cell.
device = th.device("cuda") if th.cuda.is_available() else th.device("cpu")

# Rollout Buffer
class RolloutBuffer:
    """Stores the transitions of one rollout and turns them into PPO training tensors.

    ``self.buffer`` maps each field name (``obs``, ``acts``, ``logps``,
    ``rews``, ``vals``, ``dones``) to a list with one entry per stored step.
    """

    def __init__(self):
        self.buffer = {
            'obs': [],
            'acts': [],
            'logps': [],
            'rews': [],
            'vals': [],
            'dones': []
        }

    def add(self, obs, act, logp, rew, val, done):
        """Append one transition (observation, action, log-prob, reward, value, done flag)."""
        self.buffer['obs'].append(obs)
        self.buffer['acts'].append(act)
        self.buffer['logps'].append(logp)
        self.buffer['rews'].append(rew)
        self.buffer['vals'].append(val)
        self.buffer['dones'].append(done)

    @staticmethod
    def _stack(items):
        """Stack per-step observations (tensors, arrays or sequences) into one float32 tensor."""
        # as_tensor accepts tensors as well as arrays/lists, unlike calling
        # th.tensor() on a list of multi-element tensors, which raises.
        return th.stack([th.as_tensor(item, dtype=th.float32, device=device) for item in items])

    def compute_gae(self, gamma=0.99, lam=0.95, last_val=0.0):
        """Compute GAE advantages and returns, batch the stored data, and clear the buffer.

        For each step ``t`` (walking backwards):
            delta_t = r_t + gamma * V_{t+1} * (1 - done_t) - V_t
            A_t     = delta_t + gamma * lam * (1 - done_t) * A_{t+1}

        Args:
            gamma: Discount factor.
            lam: GAE lambda.
            last_val: Value estimate used as ``V`` after the final stored step.

        Returns:
            Tuple ``(obs, acts, logps, returns, advantages)``. ``obs`` is a
            float32 tensor, or a dict of float32 tensors when the stored
            observations are dicts; ``acts`` is int64; the rest are float32.
            All tensors live on the module-level ``device``.

        Raises:
            ValueError: If no transition has been stored.
        """
        rewards = self.buffer['rews']
        dones = self.buffer['dones']
        n_steps = len(rewards)
        if n_steps == 0:
            raise ValueError("compute_gae() called on an empty RolloutBuffer")

        values = [float(v) for v in self.buffer['vals']] + [float(last_val)]

        # Backward recursion on Python floats: one tensor is built at the end
        # instead of one device op per step, and filling by index avoids the
        # quadratic cost of inserting at the front of a list.
        adv = [0.0] * n_steps
        gae = 0.0
        for t in reversed(range(n_steps)):
            not_done = 1.0 - float(dones[t])
            delta = float(rewards[t]) + gamma * values[t + 1] * not_done - values[t]
            gae = delta + gamma * lam * not_done * gae
            adv[t] = gae

        advantages = th.tensor(adv, dtype=th.float32, device=device)
        returns = advantages + th.tensor(values[:-1], dtype=th.float32, device=device)

        stored_obs = self.buffer['obs']
        first = stored_obs[0]
        if isinstance(first, dict):
            # Dictionary observations: stack every key separately.
            obs = {key: self._stack([item[key] for item in stored_obs]) for key in first.keys()}
        else:
            obs = self._stack(stored_obs)

        acts = th.tensor(self.buffer['acts'], dtype=th.int64, device=device)
        logps = th.tensor(self.buffer['logps'], dtype=th.float32, device=device)

        # Clear buffer
        for key in self.buffer:
            self.buffer[key].clear()

        return obs, acts, logps, returns, advantages

# PPO Agent
class PPOAgent:
    """PPO trainer built around an actor-critic model.

    The model must provide ``get_action(obs)`` and
    ``evaluate_actions(obs, actions)``. Hyperparameters (``lr``, ``clip_eps``,
    ``ppo_epochs``, ``batch_size``, ``update_timesteps``, ``gamma``, ``lam``,
    the loss coefficients) and ``device`` are read from module-level names.
    """

    def __init__(self, model_net):
        self.model = model_net.to(device)
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)

    def calculate_loss(self, mb_obs, mb_acts, mb_old_logps, mb_returns, mb_advantages):
        """Return the clipped-surrogate PPO loss for one minibatch.

        loss = policy_loss + val_loss_coef * value_loss - ent_loss_coef * mean_entropy
        """
        logps, entropy, values = self.model.evaluate_actions(mb_obs, mb_acts)
        ratios = th.exp(logps - mb_old_logps)

        surr1 = ratios * mb_advantages
        surr2 = th.clamp(ratios, 1 - clip_eps, 1 + clip_eps) * mb_advantages
        policy_loss = -th.min(surr1, surr2).mean()

        value_loss = ((values - mb_returns)**2).mean()
        entropy_bonus = entropy.mean()

        loss = policy_loss + val_loss_coef * value_loss - ent_loss_coef * entropy_bonus
        return loss

    def update(self, obs, acts, old_logps, returns, advantages):
        """Run ``ppo_epochs`` passes of shuffled minibatch updates over one rollout.

        Args:
            obs: Observation tensor, or dict of tensors sharing the same first dimension.
            acts, old_logps, returns, advantages: Per-step tensors of equal length.
        """
        # std() of a single element is NaN, so only normalise with >1 samples.
        if advantages.numel() > 1:
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        # The rollout length does not change between epochs.
        batch_len = len(next(iter(obs.values()))) if isinstance(obs, dict) else len(obs)

        for _ in range(ppo_epochs):
            idxs = np.random.permutation(batch_len)

            for start in range(0, batch_len, batch_size):
                mb_idx = idxs[start:start + batch_size]

                # Handle dictionary or tensor observations
                if isinstance(obs, dict):
                    mb_obs = {key: obs[key][mb_idx] for key in obs.keys()}
                else:
                    mb_obs = obs[mb_idx]

                loss = self.calculate_loss(
                    mb_obs, acts[mb_idx], old_logps[mb_idx], returns[mb_idx], advantages[mb_idx]
                )

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

    @staticmethod
    def _obs_to_tensor(obs):
        """Convert an observation (or dict of observations) to float32 tensors on ``device``."""
        if isinstance(obs, dict):
            return {key: th.tensor(obs[key], dtype=th.float32, device=device) for key in obs.keys()}
        return th.tensor(obs, dtype=th.float32, device=device)

    def train(self, env, iterations):
        """Collect rollouts from ``env`` and update the model ``iterations`` times.

        Each iteration resets the environment and steps it until an episode
        ends with at least ``update_timesteps`` steps collected, so a rollout
        always consists of whole episodes. ``terminated`` and ``truncated``
        are both treated as episode end. One summary line is printed per
        iteration.

        Args:
            env: Environment with gymnasium-style ``reset()`` / ``step()``.
            iterations: Number of collect-then-update cycles.
        """
        for iteration in range(iterations):
            obs, _ = env.reset()
            buffer = RolloutBuffer()
            ep_return = 0
            ep_returns = []
            ep_steps = []

            t = 0      # steps collected in this iteration
            ep_t = 0   # steps taken in the current episode
            while True:
                obs_tensor = self._obs_to_tensor(obs)

                # Rollout collection needs no autograd graph; only the
                # detached scalars are stored.
                with th.no_grad():
                    action, logp, _, value = self.model.get_action(obs_tensor)

                next_obs, reward, terminated, truncated, _ = env.step(action.item())
                done = terminated or truncated

                # Drop a leading dimension of size 1 (if any) before storing;
                # squeeze(0) leaves other shapes untouched.
                if isinstance(obs_tensor, dict):
                    stored_obs = {key: obs_tensor[key].squeeze(0) for key in obs_tensor.keys()}
                else:
                    stored_obs = obs_tensor.squeeze(0)

                # Store observations in buffer
                buffer.add(stored_obs, action.item(), logp.item(), reward, value.item(), done)
                ep_return += reward
                obs = next_obs

                # Count the step before the done check so the finishing step
                # is included in both the episode length and the total.
                t += 1
                ep_t += 1

                if done:
                    ep_returns.append(ep_return)
                    ep_return = 0
                    ep_steps.append(ep_t)
                    ep_t = 0
                    if t >= update_timesteps:
                        # The next iteration resets the env itself.
                        break
                    obs, _ = env.reset()

            # Training step
            obs, acts, logps, returns, advantages = buffer.compute_gae(gamma, lam)
            self.update(obs, acts, logps, returns, advantages)

            # Stats per real episode
            ep_steps_np = np.array(ep_steps)
            mean_steps = ep_steps_np.mean() if len(ep_steps_np) > 0 else 0.0
            std_steps = ep_steps_np.std(ddof=0) if len(ep_steps_np) > 0 else 0.0

            ep_returns_np = np.array(ep_returns)
            mean_return = ep_returns_np.mean() if len(ep_returns_np) > 0 else 0.0
            std_return = ep_returns_np.std(ddof=0) if len(ep_returns_np) > 0 else 0.0

            print(f"Iteration {iteration} | Episodes: {len(ep_returns)} | "
                  f"Mean Return: {mean_return:.2f} | Std Return: {std_return:.2f} "
                  f"Mean steps: {mean_steps:.2f} | Std steps: {std_steps:.2f}")

In [None]:
# Training Loop
# Create the environment through make_env(), which launches the Unity build.
# Alternative kept for reference (a standard gymnasium task): env = gym.make(env_id)
env = make_env()

In [None]:
# Size of the discrete action space reported by the environment.
act_dim = env.action_space.n

# Network over the dict observation ("image" + "vector"), wrapped in a PPO agent.
model_net = ActorCriticMultimodal(
    act_dim,
    visual_size=[3, 36, 64],
    vector_obs_size=80,
)
agent = PPOAgent(model_net)

# Run 30 collect-and-update iterations.
agent.train(env, 30)