In [1]:
import torch
import gymnasium as gym
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import display, clear_output
from collections import deque
from tqdm import tqdm

from walker import PPO, Normalize

# ANSI escape sequences used to colour console output
RED, GREEN, YELLOW = "\033[31m", "\033[32m", "\033[33m"
BLUE, MAGENTA, CYAN = "\033[34m", "\033[35m", "\033[36m"
RESET = "\033[0m"

# Use the GPU when CUDA is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [2]:
# Run directory handed to both the PPO agent and the normalizer below
log_dir = "../runs/20240712_02-38-11/"

# Create the Walker2d environment with the rgb_array render mode
env = gym.make("Walker2d-v4", render_mode="rgb_array")

# Sizes of the observation and action vectors (first axis of each space)
N_S = env.observation_space.shape[0]
N_A = env.action_space.shape[0]

# Build the agent and the state normalizer for this run directory
ppo = PPO(N_S, N_A, log_dir)
normalize = Normalize(N_S, log_dir)

# Load the saved actor, put it in eval mode, then load the normalizer params
ppo.actor_net.load_model()
ppo.actor_net.eval()
normalize.load_params()

Actor(
  (fc1): Linear(in_features=17, out_features=64, bias=True)
  (fc2): Linear(in_features=64, out_features=64, bias=True)
  (sigma): Linear(in_features=64, out_features=6, bias=True)
  (mu): Linear(in_features=64, out_features=6, bias=True)
)

In [3]:
# Test the model: roll out the loaded policy for a fixed number of episodes and
# print each episode's undiscounted return plus the mean over all episodes.
#
# NOTE(review): the first "initial state" recorded in this cell's output is
# +/-0.7071 in every component, which is what a running mean/std normalizer
# yields after only two samples. Presumably Normalize keeps updating its
# statistics when called and load_params() did not restore the training-time
# statistics -- confirm in walker.py, since that could explain the negative
# scores.

# Kept from the original: one reset/normalize before the loop. Its result is
# overwritten by the first episode's reset; it stays because normalize() may
# be stateful (see note above) -- TODO remove once that is confirmed.
state, _ = env.reset()
state = normalize(state)

test_total_reward = 0
test_episodes = 10  # Number of episodes to test
max_steps = 1000  # Upper bound on steps per episode
try:
    for episode_id in range(test_episodes):
        state, _ = env.reset()
        state = normalize(state)
        print('initial state: ', state)
        score = 0
        for _ in range(max_steps):
            action = ppo.actor_net.choose_action(state)
            # Gymnasium's step() returns (obs, reward, terminated, truncated, info);
            # an episode is over when either flag is set.
            state, reward, terminated, truncated, _ = env.step(action)
            state = normalize(state)
            score += reward

            if terminated or truncated:
                # No reset here: the next episode resets at the top of the loop.
                break

        test_total_reward += score
        print("episode: ", episode_id, "\tscore: ", score)
finally:
    # Release the environment even if a rollout raises.
    env.close()

print("average score: ", test_total_reward / test_episodes)

initial state:  [-0.70710428  0.70710538 -0.70710392 -0.7070993   0.70708507 -0.70710467
  0.70710514 -0.70709734  0.70710566 -0.70710407 -0.70710551  0.7070915
 -0.70709833 -0.7071049  -0.70710268  0.70710484 -0.70710183]
episode:  0 	score:  -36.67559178923674
initial state:  [ 4.83876177e-01  9.49394189e-01 -1.87067416e-01 -1.13418201e-01
  7.32700303e-01 -2.08805995e+00 -2.53796650e-01 -1.27467210e-01
  1.57910588e+00  6.46092782e-01  9.55926646e-01 -1.74592316e-02
  1.34176342e-03  1.50983192e-01 -5.52326491e-02  9.83781270e-04
  1.32822805e-01]
episode:  1 	score:  -10.560587481711615
initial state:  [ 0.7032588   0.97823902 -1.00089282 -0.37876452 -0.06091392 -2.67230465
 -0.52421287 -0.31443171  1.49370478  0.68153022  0.95315263 -0.0066872
 -0.00659636  0.12432343 -0.03438324  0.00657912  0.06273744]
episode:  2 	score:  -12.611072934886023
initial state:  [ 7.77006200e-01  9.97541230e-01 -9.69830708e-01 -6.01462015e-01
 -2.24659944e-01 -3.34123353e+00 -7.59112263e-01 -3.79797