In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import gym
from IPython import display

In [2]:
from ppo import PolicyNetwork, PPOAgent

In [3]:

# Environment setup
env = gym.make('CartPole-v1')  # Replace with your 2D maze environment


In [4]:

# PPO parameters
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
lr = 0.001
gamma = 0.99
epsilon_clip = 0.1

ppo_agent = PPOAgent(state_dim, action_dim, lr, gamma, epsilon_clip)


In [7]:

# Training loop
num_episodes = 1000

for episode in range(num_episodes):
    state = env.reset() # a tuple; the first element is the state
    state = state[0]
    done = False
    total_reward = 0
    states, actions, rewards, old_probs, values, dones = [], [], [], [], [], []

    while not done:
        # turn the state into a tensor with batch 1
        state = np.array(state)
        state = torch.FloatTensor(state).unsqueeze(0)
        # Collect data
        action_probs = ppo_agent.policy(state)
        action = torch.multinomial(action_probs, 1).item()
        value = ppo_agent.policy(state).detach().numpy()[0]
        next_state, reward, done, truncated, info = env.step(action)

        states.append(state)
        actions.append(action)
        rewards.append(reward)
        dones.append(done)

        old_probs.append(action_probs[0, action].item())
        values.append(value)

        state = next_state
        total_reward += reward

    print(values)

    # Compute returns and advantages
    returns = []
    advantages = ppo_agent.compute_advantage(rewards, values, dones)

    for t in range(len(rewards)):
        Gt = np.sum([r * (ppo_agent.gamma ** i) for i, r in enumerate(rewards[t:])])
        returns.append(Gt)

    # Normalize advantages
    advantages = (advantages - np.mean(advantages)) / (np.std(advantages) + 1e-8)

    # Update policy
    ppo_agent.update_policy(states, actions, old_probs, advantages, returns)

    # Print episode info
    print(f"Episode: {episode + 1}, Total Reward: {total_reward}")


[array([0.50861603, 0.49138394], dtype=float32), array([0.51267064, 0.48732942], dtype=float32), array([0.5086045, 0.4913955], dtype=float32), array([0.51282233, 0.48717767], dtype=float32), array([0.5085772 , 0.49142277], dtype=float32), array([0.51361966, 0.48638028], dtype=float32), array([0.5081432 , 0.49185678], dtype=float32), array([0.5126833 , 0.48731676], dtype=float32), array([0.5140834, 0.4859166], dtype=float32), array([0.512106  , 0.48789394], dtype=float32), array([0.51408076, 0.48591924], dtype=float32), array([0.5144085, 0.4855915], dtype=float32), array([0.5089071 , 0.49109292], dtype=float32), array([0.51109123, 0.4889087 ], dtype=float32), array([0.50822294, 0.49177712], dtype=float32), array([0.5090783, 0.4909217], dtype=float32), array([0.5147395, 0.4852605], dtype=float32), array([0.5165912 , 0.48340878], dtype=float32), array([0.5137822, 0.4862178], dtype=float32), array([0.5049338, 0.4950663], dtype=float32), array([0.50820714, 0.49179295], dtype=float32), array

ValueError: setting an array element with a sequence.