In [3]:
# Import libraries
import numpy as np
import torch
from collections import deque
from Env.environment import make_env
from Dqn.dqn_agent import DQNAgent
import matplotlib.pyplot as plt


In [None]:
# Initialise the environment, the agent and the replay buffer
SEED = 42

# Seed the global NumPy and PyTorch RNGs with the same value handed to the
# environment. The training cell draws replay indices with np.random.choice,
# so without this two runs of the notebook sample different minibatches.
# (Presumably the agent's network initialisation also draws from the torch
# RNG, which is why this comes before DQNAgent is built - confirm.)
np.random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
env = make_env("BipedalWalker-v3", seed=SEED, render_mode=None)

state_dim = env.observation_space.shape[0]

# `.n` only exists on discrete action spaces. Stock BipedalWalker-v3 exposes a
# continuous Box action space, so this relies on make_env discretising the
# actions - NOTE(review): confirm that make_env does so. Fail with an
# actionable message instead of a bare missing-attribute error.
if not hasattr(env.action_space, "n"):
    raise AttributeError(
        f"DQN needs a discrete action space exposing `.n`, got {env.action_space!r}; "
        "wrap the environment with an action-discretising wrapper first."
    )
action_dim = env.action_space.n

agent = DQNAgent(state_dim, action_dim, device=device)

# Bounded FIFO store of transitions; the oldest entries are dropped once
# 100000 items are held.
replay_buffer = deque(maxlen=100000)


In [4]:
# Training hyperparameters

# Run length: number of episodes and the per-episode step cap.
num_episodes, max_steps = 200, 2000

# Minibatch size drawn from the replay buffer on each learning step.
batch_size = 64

# Discount factor.
# NOTE(review): not referenced by the training cell; presumably the agent
# keeps its own discount - confirm or pass this value to the agent.
gamma = 0.99

# Exploration schedule: epsilon starts at eps_start and is multiplied by
# eps_decay after every episode, never dropping below eps_end.
eps_start, eps_end, eps_decay = 1.0, 0.05, 0.995

# agent.update_target() is called every `target_update` episodes.
target_update = 10


In [None]:
# Training loop
epsilon = eps_start
rewards_all = []  # summed reward of every episode, in episode order

for episode in range(1, num_episodes + 1):
    state, _ = env.reset()
    total_reward = 0

    for t in range(max_steps):
        # The agent picks the action given the current exploration rate.
        action = agent.act(state, epsilon)
        next_state, reward, terminated, truncated, _ = env.step(action)
        # Either signal ends the rollout of this episode.
        done = terminated or truncated

        # Store `terminated`, not `done`, as the terminal flag: a truncation
        # (time limit) is not a real terminal state, so the value of
        # next_state should still be bootstrapped for those transitions.
        # Assumes agent.learn uses this flag to mask the bootstrap term -
        # TODO confirm against DQNAgent.learn.
        replay_buffer.append((state, action, reward, next_state, terminated))

        state = next_state
        total_reward += reward

        # Start learning as soon as one full minibatch is available.
        if len(replay_buffer) >= batch_size:
            # Uniformly sample batch_size distinct transitions by index.
            batch = np.random.choice(len(replay_buffer), batch_size, replace=False)
            # Transpose the list of transitions into per-field tuples.
            states, actions, rewards, next_states, dones = zip(*[replay_buffer[i] for i in batch])
            agent.learn(states, actions, rewards, next_states, dones)

        if done:
            break

    # Multiplicative per-episode decay, floored at eps_end.
    epsilon = max(eps_end, epsilon * eps_decay)
    rewards_all.append(total_reward)

    # Periodic target update, counted in episodes (not environment steps).
    if episode % target_update == 0:
        agent.update_target()

    print(f"Episode {episode}, Reward: {total_reward:.2f}, Epsilon: {epsilon:.3f}")


In [None]:
# Plot the total reward of each episode
# The training loop numbers episodes from 1, so plot against explicit 1-based
# episode numbers; plotting the list alone would put the first episode at x=0
# and shift every point on the "Episode" axis by one.
episode_numbers = range(1, len(rewards_all) + 1)
plt.plot(episode_numbers, rewards_all)
plt.xlabel("Episode")
plt.ylabel("Total Reward")
plt.title("Training Rewards of DQN Agent")
plt.show()


In [None]:
# Save the model
algorithm = "dqn"
# The algorithm name is interpolated into the file name so that changing
# `algorithm` actually changes the checkpoint that gets written (previously
# the variable was unused and the name was hard-coded).
# NOTE(review): this is an absolute, machine-specific Windows path; consider
# moving it to a configurable directory so the notebook runs elsewhere.
save_path = rf"D:\code_etc\Python\_File_chay_code\DRL\Bidepal_Gym\Model\actor_{algorithm}.pth"
agent.save(save_path)
print(f"Model saved to {save_path}")