In [27]:
%matplotlib notebook
%reset



Once deleted, variables cannot be recovered. Proceed (y/[n])?  y


In [53]:
import gymnasium as gym
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm.notebook import tqdm
import numpy as np
from torch.utils.tensorboard import SummaryWriter
import datetime
import random

In [91]:
for i in range(10):
    print(random.random())

0.11940412722115057
0.2859398124709106
0.8640670813421467
0.6589195246675073
0.5375914640111615
0.717015007997773
0.5116349047869231
0.4261274685124038
0.49295167576509613
0.8832233085790496


In [54]:
# from tensorboard import notebook

# # 初始化 SummaryWriter，日志将保存在'runs'目录下
# writer = SummaryWriter(f"./logs/{datetime.datetime.now()}")

# # 启动 TensorBoard 并指定日志目录
# notebook.start("--logdir runs")

# # 可选：创建一个链接直接跳转到 TensorBoard 界面
# # notebook.display(height=400)

In [55]:
# %matplotlib qt5

In [273]:
class Actor(nn.Module):
    def __init__(self, n_state, n_action, hidden_size = 64):
        super(Actor, self).__init__()
        
        self.fc1 = torch.nn.Linear(n_state, hidden_size)
        self.fc2 = torch.nn.Linear(hidden_size, hidden_size)
        self.fc_mean = torch.nn.Linear(hidden_size, n_action)
        self.fc_std = torch.nn.Linear(hidden_size, n_action)

    def forward(self, state):
        x = self.fc1(state)
        x = self.fc2(F.relu(x))
        mu = torch.tanh(self.fc_mean(F.relu(x))) * 1.0
        # std = F.softplus(self.fc_std(F.relu(x))) + 0e-3
        std = torch.ones_like(mu) * 1.0  # 固定标准差，增强探索

        return mu, std

        
class Critic(nn.Module):
    def __init__(self, n_state, hidden_size=64):
        super(Critic, self).__init__()

        self.fc1 = torch.nn.Linear(n_state, hidden_size)
        self.fc2 = torch.nn.Linear(hidden_size, hidden_size)
        self.fc3 = torch.nn.Linear(hidden_size, 1)

    def forward(self, state):
        x = self.fc1(state)
        x = self.fc2(F.relu(x))
        x = self.fc3(F.relu(x))
        return x


def compute_advantage(gamma, lmbda, td_delta):
    td_delta = td_delta.detach().numpy()
    advantage_list = []
    advantage = 0.0
    for delta in td_delta[::-1]:
        advantage = gamma * lmbda * advantage + delta
        advantage_list.append(advantage)
    advantage_list.reverse()
    return torch.tensor(advantage_list, dtype=torch.float)

def shaped_reward(state, original_reward):
    position, velocity = state
    # 增加与位置相关的奖励（鼓励向右侧山顶移动）
    position_reward = 10 * (position + 0.5)  # 谷底位置约为 -0.5，向右侧移动时奖励增加
    # 增加与速度方向相关的奖励（鼓励沿目标方向加速）
    velocity_reward = 5 * velocity if position > -0.5 else 0  # 右侧加速时奖励更高
    return original_reward + position_reward + velocity_reward

class PPOContinuous(nn.Module):

    def __init__(self, n_state, n_action, n_hidden = 64, actor_lr=1e-4, critic_lr=1e-4, lmbda=0.1, epochs=10, eps=0.01, gamma=0.99, device="cpu"):
        super(PPOContinuous, self).__init__()
        print(f"{n_state=}, {n_action=}, {n_hidden=}")

        self.actor = Actor(n_state, n_action, hidden_size=n_hidden)
        self.actor_opt = torch.optim.Adam(self.actor.parameters(), lr = actor_lr)

        self.critic = Critic(n_state, hidden_size=n_hidden)
        self.critic_opt = torch.optim.Adam(self.critic.parameters(), lr = critic_lr)

        self.lmbda = lmbda
        self.gamma = gamma
        self.eps = eps
        self.epochs = epochs
        self.device = device

    def take_action(self, state, eval = False):
        state = torch.tensor([state], dtype=torch.float).to(self.device)
        mu, std = self.actor(state)

        action_dist = torch.distributions.Normal(mu, std)
        action = action_dist.sample()

        if not eval:
            r = random.uniform(-1, 1) * 0.0
        else:
            r = 0.

        # return [np.clip(action.item() + r, -1, 1)]
        return [action.item()]
        

    def update(self, transition_dict):
        states = torch.tensor(transition_dict['states'], dtype=torch.float).to(self.device)
        actions = torch.tensor(transition_dict['actions']).view(-1, 1).to(self.device)
        rewards = torch.tensor(transition_dict['rewards'], dtype=torch.float).view(-1, 1).to(self.device)
        next_states = torch.tensor(transition_dict['next_states'], dtype=torch.float).to(self.device)
        dones = torch.tensor(transition_dict['dones'], dtype=torch.float).view(-1, 1).to(self.device) 

        td_target = rewards + self.gamma * self.critic(next_states) * (1 - dones)
        td_delta = td_target - self.critic(states)
        
        advantage = compute_advantage(self.gamma, self.lmbda, td_delta)

        mu, std = self.actor(states)
        action_dist = torch.distributions.Normal(mu.detach(), std.detach())
        old_log_probs = action_dist.log_prob(actions)

        for _ in range(self.epochs):
            mu, std = self.actor(states)
            # print(f"====={mu.max()=}, {std.max()=}, {std.min()=}")

            action_dists = torch.distributions.Normal(mu, std)
            log_probs = action_dists.log_prob(actions)

            ratio = torch.exp(log_probs - old_log_probs)
            # ratio = torch.clamp(ratio, min=-1e6, max=1e6)
            l1 = ratio * advantage
            l2 = torch.clamp(ratio, 1 - self.eps, 1 + self.eps) * advantage
            l3 = - torch.min(l1, l2)

            actor_loss = torch.mean(l3)
            critic_loss = torch.mean(F.mse_loss(self.critic(states), td_target.detach()))
            # print(f"====={actor_loss=}, {critic_loss=}, {l1.min()=}, {ratio.min()=}, {ratio.max()=} \n")


            self.actor_opt.zero_grad()
            self.critic_opt.zero_grad()
            actor_loss.backward()
            critic_loss.backward()
            self.actor_opt.step()
            self.critic_opt.step()
        

In [274]:
actor_lr = 1e-4
critic_lr = 1e-3
num_episodes = 1000
hidden = 128

gamma = 0.98
lmbda = 0.95
epochs = 10
eps = 0.4

device = "cpu"
env_name = "MountainCarContinuous-v0"
# env_name = "Pendulum-v1"


env = gym.make(env_name)
torch.manual_seed(0)
print(f"state space:" , env.observation_space.sample())
print(f"action space: {env.action_space}")
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]

agent = PPOContinuous(n_state=state_dim, n_action=action_dim, n_hidden=hidden,actor_lr=actor_lr, critic_lr=critic_lr)

state space: [-0.9018829   0.02360933]
action space: Box(-1.0, 1.0, (1,), float32)
n_state=2, n_action=1, n_hidden=128


In [275]:
def train_on_policy_agent(env, agent, num_episodes):
    return_list = []
    episode_return = 0

    for epoch in range(num_episodes//10):
    
        for i in tqdm(range(10), position=0, desc=f"###{epoch} : {episode_return} "):
            episode_return = 0
            transition_dict = {'states': [], 'actions': [], 'next_states': [], 'rewards': [], 'dones': []}
            state, _ = env.reset()
            done , truncated = False, False
            while not done and not truncated:
                action = agent.take_action(state)
                next_state, reward, done, truncated, info = env.step(action)
                # reward = shaped_reward(state, reward)
                
    
                transition_dict['states'].append(state)
                transition_dict['actions'].append(action)
                transition_dict['next_states'].append(next_state)
                transition_dict['rewards'].append(reward)
                transition_dict['dones'].append(done)
                state = next_state
                episode_return += reward
            return_list.append(episode_return)
            agent.update(transition_dict)


    return return_list
        
return_list = train_on_policy_agent(env, agent, num_episodes)
plt.plot(return_list)
plt.show()

###0 : 0 :   0%|          | 0/10 [00:00<?, ?it/s]

###1 : -101.16863295896077 :   0%|          | 0/10 [00:00<?, ?it/s]

###2 : -102.39011370411221 :   0%|          | 0/10 [00:00<?, ?it/s]

###3 : -105.47359589021839 :   0%|          | 0/10 [00:00<?, ?it/s]

###4 : -102.29501268629689 :   0%|          | 0/10 [00:00<?, ?it/s]

###5 : -99.340455300911 :   0%|          | 0/10 [00:00<?, ?it/s]

###6 : -100.94980686846606 :   0%|          | 0/10 [00:00<?, ?it/s]

###7 : -96.75876040022787 :   0%|          | 0/10 [00:00<?, ?it/s]

###8 : -101.85771243855943 :   0%|          | 0/10 [00:00<?, ?it/s]

###9 : -102.63427133849662 :   0%|          | 0/10 [00:00<?, ?it/s]

###10 : -90.92225238770455 :   0%|          | 0/10 [00:00<?, ?it/s]

###11 : -96.55385903801596 :   0%|          | 0/10 [00:00<?, ?it/s]

###12 : -110.89649797018883 :   0%|          | 0/10 [00:00<?, ?it/s]

###13 : -101.57235204309994 :   0%|          | 0/10 [00:00<?, ?it/s]

###14 : -104.80319872940647 :   0%|          | 0/10 [00:00<?, ?it/s]

###15 : -101.76233496493475 :   0%|          | 0/10 [00:00<?, ?it/s]

###16 : -104.36130911214508 :   0%|          | 0/10 [00:00<?, ?it/s]

###17 : -101.23465711925358 :   0%|          | 0/10 [00:00<?, ?it/s]

###18 : -102.21231397200125 :   0%|          | 0/10 [00:00<?, ?it/s]

###19 : -104.74761074095095 :   0%|          | 0/10 [00:00<?, ?it/s]

###20 : -103.43861539271262 :   0%|          | 0/10 [00:00<?, ?it/s]

###21 : -108.09601290502984 :   0%|          | 0/10 [00:00<?, ?it/s]

###22 : -105.18816802762194 :   0%|          | 0/10 [00:00<?, ?it/s]

###23 : -102.49470865160818 :   0%|          | 0/10 [00:00<?, ?it/s]

###24 : -98.1860985780383 :   0%|          | 0/10 [00:00<?, ?it/s]

###25 : -103.54040733824138 :   0%|          | 0/10 [00:00<?, ?it/s]

###26 : -106.45173897127871 :   0%|          | 0/10 [00:00<?, ?it/s]

###27 : -98.74851194150637 :   0%|          | 0/10 [00:00<?, ?it/s]

###28 : -92.71029441571764 :   0%|          | 0/10 [00:00<?, ?it/s]

###29 : -100.22000165113498 :   0%|          | 0/10 [00:00<?, ?it/s]

###30 : -97.31288856087465 :   0%|          | 0/10 [00:00<?, ?it/s]

###31 : -97.72180849656226 :   0%|          | 0/10 [00:00<?, ?it/s]

###32 : -105.7962488189959 :   0%|          | 0/10 [00:00<?, ?it/s]

###33 : -106.40851796926037 :   0%|          | 0/10 [00:00<?, ?it/s]

###34 : -106.12339400343046 :   0%|          | 0/10 [00:00<?, ?it/s]

###35 : -96.93539209945929 :   0%|          | 0/10 [00:00<?, ?it/s]

###36 : -103.9481170000257 :   0%|          | 0/10 [00:00<?, ?it/s]

###37 : -94.14638789690734 :   0%|          | 0/10 [00:00<?, ?it/s]

###38 : -100.29241159319258 :   0%|          | 0/10 [00:00<?, ?it/s]

###39 : -108.62596651452995 :   0%|          | 0/10 [00:00<?, ?it/s]

###40 : -95.50407299574728 :   0%|          | 0/10 [00:00<?, ?it/s]

###41 : -101.76460190865096 :   0%|          | 0/10 [00:00<?, ?it/s]

###42 : -104.41162065495166 :   0%|          | 0/10 [00:00<?, ?it/s]

###43 : -97.20250914290486 :   0%|          | 0/10 [00:00<?, ?it/s]

###44 : -100.33320207139622 :   0%|          | 0/10 [00:00<?, ?it/s]

###45 : -105.31425632404247 :   0%|          | 0/10 [00:00<?, ?it/s]

###46 : -98.2440285522618 :   0%|          | 0/10 [00:00<?, ?it/s]

###47 : -109.04009097509825 :   0%|          | 0/10 [00:00<?, ?it/s]

###48 : -103.70117025544502 :   0%|          | 0/10 [00:00<?, ?it/s]

###49 : -97.50791095042737 :   0%|          | 0/10 [00:00<?, ?it/s]

###50 : -112.25030827204264 :   0%|          | 0/10 [00:00<?, ?it/s]

###51 : -101.28818509881239 :   0%|          | 0/10 [00:00<?, ?it/s]

###52 : -106.03872722808617 :   0%|          | 0/10 [00:00<?, ?it/s]

###53 : -90.06553266907477 :   0%|          | 0/10 [00:00<?, ?it/s]

###54 : -97.4817635638676 :   0%|          | 0/10 [00:00<?, ?it/s]

###55 : -98.08531798688554 :   0%|          | 0/10 [00:00<?, ?it/s]

###56 : -97.93722378541725 :   0%|          | 0/10 [00:00<?, ?it/s]

###57 : 51.18833423935983 :   0%|          | 0/10 [00:00<?, ?it/s]

###58 : 2.4289820869580296 :   0%|          | 0/10 [00:00<?, ?it/s]

###59 : -111.97290565951704 :   0%|          | 0/10 [00:00<?, ?it/s]

###60 : -99.1286448011075 :   0%|          | 0/10 [00:00<?, ?it/s]

###61 : -106.29012086468788 :   0%|          | 0/10 [00:00<?, ?it/s]

###62 : -111.84767245338995 :   0%|          | 0/10 [00:00<?, ?it/s]

###63 : 31.944352814987084 :   0%|          | 0/10 [00:00<?, ?it/s]

###64 : 28.182395882817687 :   0%|          | 0/10 [00:00<?, ?it/s]

###65 : 38.55355823181373 :   0%|          | 0/10 [00:00<?, ?it/s]

###66 : 40.53372296873587 :   0%|          | 0/10 [00:00<?, ?it/s]

###67 : 20.98763121439589 :   0%|          | 0/10 [00:00<?, ?it/s]

###68 : 59.71209041717982 :   0%|          | 0/10 [00:00<?, ?it/s]

###69 : 53.820605795762916 :   0%|          | 0/10 [00:00<?, ?it/s]

###70 : 46.5801317734701 :   0%|          | 0/10 [00:00<?, ?it/s]

###71 : 63.1520281613472 :   0%|          | 0/10 [00:00<?, ?it/s]

###72 : 66.72666100562606 :   0%|          | 0/10 [00:00<?, ?it/s]

###73 : 68.42909008207337 :   0%|          | 0/10 [00:00<?, ?it/s]

###74 : 67.77730172752001 :   0%|          | 0/10 [00:00<?, ?it/s]

###75 : 74.07206324574528 :   0%|          | 0/10 [00:00<?, ?it/s]

###76 : 66.02464225722788 :   0%|          | 0/10 [00:00<?, ?it/s]

###77 : 69.81715618281305 :   0%|          | 0/10 [00:00<?, ?it/s]

###78 : 75.83231993152225 :   0%|          | 0/10 [00:00<?, ?it/s]

###79 : 73.16242661732217 :   0%|          | 0/10 [00:00<?, ?it/s]

###80 : 76.72343837525906 :   0%|          | 0/10 [00:00<?, ?it/s]

###81 : 76.80292050830697 :   0%|          | 0/10 [00:00<?, ?it/s]

###82 : 77.66847334399138 :   0%|          | 0/10 [00:00<?, ?it/s]

###83 : 76.15222326649013 :   0%|          | 0/10 [00:00<?, ?it/s]

###84 : 74.37851571048108 :   0%|          | 0/10 [00:00<?, ?it/s]

###85 : 74.07723417937837 :   0%|          | 0/10 [00:00<?, ?it/s]

###86 : 80.20637116447489 :   0%|          | 0/10 [00:00<?, ?it/s]

###87 : 75.82817837775114 :   0%|          | 0/10 [00:00<?, ?it/s]

###88 : 81.76175455758646 :   0%|          | 0/10 [00:00<?, ?it/s]

###89 : 67.54864219838828 :   0%|          | 0/10 [00:00<?, ?it/s]

###90 : 72.50945313873419 :   0%|          | 0/10 [00:00<?, ?it/s]

###91 : 65.24415151096298 :   0%|          | 0/10 [00:00<?, ?it/s]

###92 : 77.24721342283718 :   0%|          | 0/10 [00:00<?, ?it/s]

###93 : 74.3055954099092 :   0%|          | 0/10 [00:00<?, ?it/s]

###94 : 80.98704721291878 :   0%|          | 0/10 [00:00<?, ?it/s]

###95 : 74.3688575891939 :   0%|          | 0/10 [00:00<?, ?it/s]

###96 : 73.02562732962436 :   0%|          | 0/10 [00:00<?, ?it/s]

###97 : 77.33879709503299 :   0%|          | 0/10 [00:00<?, ?it/s]

###98 : 68.26796689870154 :   0%|          | 0/10 [00:00<?, ?it/s]

###99 : 64.85853396090027 :   0%|          | 0/10 [00:00<?, ?it/s]

In [276]:
def test_agent(agent, env_name):
    env = gym.make(env_name, render_mode="human")

    state, info = env.reset()

    print(f"Starting observation: {state}")

    episode_over = False
    total_reward = 0
    action_list = []

    while not episode_over:
        action = agent.take_action(state, eval = True)
        state, reward, terminated, truncated, info = env.step(action)
        total_reward += reward
        episode_over = terminated or truncated
        # print(f"{action=}")
        action_list.append(action)

    print(f"Episode finished! Total reward: {total_reward}")
    env.close()
    plt.hist(action_list)

test_agent(agent, env_name)

Starting observation: [-0.49269798  0.        ]
Episode finished! Total reward: 81.37953463048275
