In [38]:
%matplotlib notebook
%reset

Nothing done.


In [2]:
import gymnasium as gym
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm.notebook import tqdm
import numpy as np
from torch.utils.tensorboard import SummaryWriter
import datetime

In [3]:
# from tensorboard import notebook

# # 初始化 SummaryWriter，日志将保存在'runs'目录下
# writer = SummaryWriter(f"./logs/{datetime.datetime.now()}")

# # 启动 TensorBoard 并指定日志目录
# notebook.start("--logdir runs")

# # 可选：创建一个链接直接跳转到 TensorBoard 界面
# # notebook.display(height=400)

In [3]:
%matplotlib qt5

In [4]:
class Actor(nn.Module):
    def __init__(self, n_state, n_action, hidden_size = 64):
        super(Actor, self).__init__()
        
        self.fc1 = torch.nn.Linear(n_state, hidden_size)
        self.fc2 = torch.nn.Linear(hidden_size, hidden_size)
        self.fc_mean = torch.nn.Linear(hidden_size, n_action)
        self.fc_std = torch.nn.Linear(hidden_size, n_action)

    def forward(self, state):
        x = self.fc1(state)
        x = self.fc2(F.relu(x))
        mu = torch.tanh(self.fc_mean(F.relu(x))) * 1.0
        std = F.softplus(self.fc_std(F.relu(x)))
        return mu, std

        
class Critic(nn.Module):
    def __init__(self, n_state, hidden_size=64):
        super(Critic, self).__init__()

        self.fc1 = torch.nn.Linear(n_state, hidden_size)
        self.fc2 = torch.nn.Linear(hidden_size, hidden_size)
        self.fc3 = torch.nn.Linear(hidden_size, 1)

    def forward(self, state):
        x = self.fc1(state)
        x = self.fc2(F.relu(x))
        x = self.fc3(F.relu(x))
        return x


def compute_advantage(gamma, lmbda, td_delta):
    td_delta = td_delta.detach().numpy()
    advantage_list = []
    advantage = 0.0
    for delta in td_delta[::-1]:
        advantage = gamma * lmbda * advantage + delta
        advantage_list.append(advantage)
    advantage_list.reverse()
    return torch.tensor(advantage_list, dtype=torch.float)

class PPOContinuous(nn.Module):

    def __init__(self, n_state, n_action, n_hidden = 64, actor_lr=1e-4, critic_lr=1e-4, lmbda=0.1, epochs=10, eps=0.01, gamma=0.99, device="cpu"):
        super(PPOContinuous, self).__init__()
        print(f"{n_state=}, {n_action=}, {n_hidden=}")

        self.actor = Actor(n_state, n_action, hidden_size=n_hidden)
        self.actor_opt = torch.optim.Adam(self.actor.parameters(), lr = actor_lr)

        self.critic = Critic(n_state, hidden_size=n_hidden)
        self.critic_opt = torch.optim.Adam(self.critic.parameters(), lr = critic_lr)

        self.lmbda = lmbda
        self.gamma = gamma
        self.eps = eps
        self.epochs = epochs
        self.device = device

    def take_action(self, state):
        state = torch.tensor([state], dtype=torch.float).to(self.device)
        mu, std = self.actor(state)

        action_dist = torch.distributions.Normal(mu, std)
        action = action_dist.sample()
        # print(action)
        return [action.item()]
        

    def update(self, transition_dict):
        states = torch.tensor(transition_dict['states'], dtype=torch.float).to(self.device)
        actions = torch.tensor(transition_dict['actions']).view(-1, 1).to(self.device)
        rewards = torch.tensor(transition_dict['rewards'], dtype=torch.float).view(-1, 1).to(self.device)
        next_states = torch.tensor(transition_dict['next_states'], dtype=torch.float).to(self.device)
        dones = torch.tensor(transition_dict['dones'], dtype=torch.float).view(-1, 1).to(self.device) 

        td_target = rewards + self.gamma * self.critic(next_states) * (1 - dones)
        td_delta = td_target - self.critic(states)
        
        advantage = compute_advantage(self.gamma, self.lmbda, td_delta)

        mu, std = self.actor(states)
        action_dist = torch.distributions.Normal(mu.detach(), std.detach())
        old_log_probs = action_dist.log_prob(actions)

        for _ in range(self.epochs):
            mu, std = self.actor(states)
            action_dists = torch.distributions.Normal(mu, std)
            log_probs = action_dists.log_prob(actions)

            ratio = torch.exp(log_probs - old_log_probs)

            l1 = ratio * advantage
            l2 = torch.clamp(ratio, 1 - self.eps, 1 + self.eps) * advantage
            l3 = - torch.min(l1, l2)

            actor_loss = torch.mean(l3)
            critic_loss = torch.mean(F.mse_loss(self.critic(states), td_target.detach()))

            self.actor_opt.zero_grad()
            self.critic_opt.zero_grad()
            actor_loss.backward()
            critic_loss.backward()
            self.actor_opt.step()
            self.critic_opt.step()
        

In [9]:
actor_lr = 1e-3
critic_lr = 1e-2
num_episodes = 500
hidden = 128

gamma = 0.98
lmbda = 0.95
epochs = 10
eps = 0.2

device = "cpu"
env_name = "MountainCarContinuous-v0"
# env_name = "Pendulum-v1"


env = gym.make(env_name)
torch.manual_seed(0)
print(f"state space:" , env.observation_space.sample())
print(f"action space: {env.action_space}")
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]

agent = PPOContinuous(n_state=state_dim, n_action=action_dim, n_hidden=hidden,actor_lr=actor_lr, critic_lr=critic_lr)

state space: [-1.0480918  -0.05745233]
action space: Box(-1.0, 1.0, (1,), float32)
n_state=2, n_action=1, n_hidden=128


In [13]:
def train_on_policy_agent(env, agent, num_episodes):
    return_list = []
    episode_return = 0

    for epoch in range(num_episodes//10):
    
        for i in tqdm(range(10), position=0, desc=f"###{epoch} : {episode_return} "):
            episode_return = 0
            transition_dict = {'states': [], 'actions': [], 'next_states': [], 'rewards': [], 'dones': []}
            state, _ = env.reset()
            done , truncated = False, False
            while not done and not truncated:
                action = agent.take_action(state)
                next_state, reward, done, truncated, info = env.step(action)
    
                transition_dict['states'].append(state)
                transition_dict['actions'].append(action)
                transition_dict['next_states'].append(next_state)
                transition_dict['rewards'].append(reward)
                transition_dict['dones'].append(done)
                state = next_state
                episode_return += reward
            return_list.append(episode_return)
            agent.update(transition_dict)

        # writer.add_scalar("reward", episode_return, i)

    plt.plot(return_list)
        
train_on_policy_agent(env, agent, num_episodes)

###0 : 0 :   0%|          | 0/10 [00:00<?, ?it/s]

###1 : -93.84133858334292 :   0%|          | 0/10 [00:00<?, ?it/s]

###2 : -92.39987013551166 :   0%|          | 0/10 [00:00<?, ?it/s]

###3 : -88.55647541766075 :   0%|          | 0/10 [00:00<?, ?it/s]

###4 : -80.43104085687393 :   0%|          | 0/10 [00:00<?, ?it/s]

###5 : -77.18342917274362 :   0%|          | 0/10 [00:00<?, ?it/s]

###6 : -72.14911113487037 :   0%|          | 0/10 [00:00<?, ?it/s]

###7 : -71.7140634745839 :   0%|          | 0/10 [00:00<?, ?it/s]

###8 : -70.4080494939098 :   0%|          | 0/10 [00:00<?, ?it/s]

###9 : -62.22980941862138 :   0%|          | 0/10 [00:00<?, ?it/s]

###10 : -57.879411668363744 :   0%|          | 0/10 [00:00<?, ?it/s]

###11 : -59.87772323753635 :   0%|          | 0/10 [00:00<?, ?it/s]

###12 : -58.677413395246205 :   0%|          | 0/10 [00:00<?, ?it/s]

###13 : -58.95170923699332 :   0%|          | 0/10 [00:00<?, ?it/s]

###14 : -59.549150147953505 :   0%|          | 0/10 [00:00<?, ?it/s]

###15 : -60.45109098623604 :   0%|          | 0/10 [00:00<?, ?it/s]

###16 : -59.40602003343395 :   0%|          | 0/10 [00:00<?, ?it/s]

###17 : -59.18408743383421 :   0%|          | 0/10 [00:00<?, ?it/s]

###18 : -54.91135063294087 :   0%|          | 0/10 [00:00<?, ?it/s]

###19 : -54.32389178790463 :   0%|          | 0/10 [00:00<?, ?it/s]

###20 : -50.72366656587637 :   0%|          | 0/10 [00:00<?, ?it/s]

###21 : -51.09753756590697 :   0%|          | 0/10 [00:00<?, ?it/s]

###22 : -49.47372743496942 :   0%|          | 0/10 [00:00<?, ?it/s]

###23 : -50.22477689525824 :   0%|          | 0/10 [00:00<?, ?it/s]

###24 : -46.78695111113428 :   0%|          | 0/10 [00:00<?, ?it/s]

###25 : -45.29352730109139 :   0%|          | 0/10 [00:00<?, ?it/s]

###26 : -45.04152146242341 :   0%|          | 0/10 [00:00<?, ?it/s]

###27 : -44.27996499668439 :   0%|          | 0/10 [00:00<?, ?it/s]

###28 : -41.78275410837183 :   0%|          | 0/10 [00:00<?, ?it/s]

###29 : -40.63479637278872 :   0%|          | 0/10 [00:00<?, ?it/s]

###30 : -41.92397429810087 :   0%|          | 0/10 [00:00<?, ?it/s]

###31 : -41.281765249816004 :   0%|          | 0/10 [00:00<?, ?it/s]

###32 : -39.044433745615216 :   0%|          | 0/10 [00:00<?, ?it/s]

###33 : -34.8778399567056 :   0%|          | 0/10 [00:00<?, ?it/s]

###34 : -32.79638766044619 :   0%|          | 0/10 [00:00<?, ?it/s]

###35 : -29.42503018076689 :   0%|          | 0/10 [00:00<?, ?it/s]

###36 : -27.105351176212498 :   0%|          | 0/10 [00:00<?, ?it/s]

###37 : -26.32031908662521 :   0%|          | 0/10 [00:00<?, ?it/s]

###38 : -25.333014820695375 :   0%|          | 0/10 [00:00<?, ?it/s]

###39 : -24.161414780142678 :   0%|          | 0/10 [00:00<?, ?it/s]

###40 : -23.85957351366895 :   0%|          | 0/10 [00:00<?, ?it/s]

###41 : -22.658708874744548 :   0%|          | 0/10 [00:00<?, ?it/s]

###42 : -23.025017757783374 :   0%|          | 0/10 [00:00<?, ?it/s]

###43 : -22.101948573833138 :   0%|          | 0/10 [00:00<?, ?it/s]

###44 : -19.017011336701476 :   0%|          | 0/10 [00:00<?, ?it/s]

###45 : -19.011303495119265 :   0%|          | 0/10 [00:00<?, ?it/s]

###46 : -17.549007237332585 :   0%|          | 0/10 [00:00<?, ?it/s]

###47 : -17.988541203591332 :   0%|          | 0/10 [00:00<?, ?it/s]

###48 : -17.425611770320153 :   0%|          | 0/10 [00:00<?, ?it/s]

###49 : -15.570866185637152 :   0%|          | 0/10 [00:00<?, ?it/s]

In [14]:
def test_agent(agent, env_name):
    env = gym.make(env_name, render_mode="human")

    state, info = env.reset()

    print(f"Starting observation: {state}")

    episode_over = False
    total_reward = 0

    while not episode_over:
        action = agent.take_action(state)
        state, reward, terminated, truncated, info = env.step(action)
        total_reward += reward
        episode_over = terminated or truncated
        print(f"{action=}")

    print(f"Episode finished! Total reward: {total_reward}")
    env.close()

test_agent(agent, env_name)

Starting observation: [-0.42971027  0.        ]
action=[-0.3821898400783539]
action=[-0.37004613876342773]
action=[-0.3761988878250122]
action=[-0.4050670564174652]
action=[-0.36931857466697693]
action=[-0.3646675944328308]
action=[-0.3717162311077118]
action=[-0.38405758142471313]
action=[-0.3775842487812042]
action=[-0.37856075167655945]
action=[-0.37106603384017944]
action=[-0.36460646986961365]
action=[-0.3850235939025879]
action=[-0.4000374972820282]
action=[-0.3903217017650604]
action=[-0.3964332938194275]
action=[-0.3655615746974945]
action=[-0.3703799247741699]
action=[-0.4208000898361206]
action=[-0.38364875316619873]
action=[-0.3751783072948456]
action=[-0.327769935131073]
action=[-0.3548699915409088]
action=[-0.35705021023750305]
action=[-0.3747051954269409]
action=[-0.41080325841903687]
action=[-0.3790697455406189]
action=[-0.39908212423324585]
action=[-0.38081857562065125]
action=[-0.3495803475379944]
action=[-0.368195503950119]
action=[-0.42072319984436035]
action=[-0.351