In [1]:
import torch
import gymnasium as gym
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import display, clear_output
from collections import deque
from tqdm import tqdm

from walker import PPO, Normalize

In [25]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize environment
env = gym.make('Walker2d-v4', render_mode='rgb_array')

# Number of state and action
N_S = env.observation_space.shape[0]
N_A = env.action_space.shape[0]

# Initialize PPO model
ppo = PPO(N_S, N_A)
normalize = Normalize(N_S)

# Load the saved model
log_dir = "../runs/20240709_14-44-31/1000/ppo/"
ppo.load(log_dir)
normalize.load_params(log_dir + "../../normalize_params.npy")

# Test the model
test_total_reward = 0
test_episodes = 10  # Number of episodes to test
for episode_id in range(test_episodes):
    now_state, _ = env.reset(seed=500)
    now_state = normalize(now_state)
    score = 0
    for _ in range(1000):
#         env.render()

        with torch.no_grad():
            ppo.actor_net.eval()
            a = ppo.actor_net.choose_action(torch.from_numpy(np.array(now_state).astype(np.float32)).unsqueeze(0))[0]
        now_state, r, done, _, _ = env.step(a)
        now_state = normalize(now_state)
        score += r

        if done:
            break
    print("episode: ", episode_id, "\tscore: ", score)
    
# for _ in range(test_episodes):
#     state, _ = env.reset()
#     state = normalize(state)
#     done = False
#     episode_reward = 0
#     while not done:
#         action = ppo.actor_net.choose_action(torch.from_numpy(np.array(state).astype(np.float32)).unsqueeze(0))[0]
#         next_state, reward, truncated, terminated, info = env.step(action)
#         episode_reward += reward
#         state = normalize(next_state)
#         done = truncated or terminated
#     test_total_reward += episode_reward
# average_test_reward = test_total_reward / test_episodes
# print('Average test reward: {:.2f}'.format(average_test_reward))

episode:  0 	score:  2.208516010088635
episode:  1 	score:  7.215898441911826
episode:  2 	score:  -15.911611703431799
episode:  3 	score:  -11.40454163017532
episode:  4 	score:  -22.853406463803356
episode:  5 	score:  -22.70895037888944
episode:  6 	score:  -19.829054643318152
episode:  7 	score:  -42.53286963474444
episode:  8 	score:  -16.60295789560731
episode:  9 	score:  -23.146209907657855
