# TD3 (Twin Delayed Deep Deterministic Policy Gradient) on BipedalWalker-v3

In [None]:
import time
import gym
import torch
from torch.utils.tensorboard import SummaryWriter
import numpy as np
from collections import deque
import random
import copy
import os
import shutil

In [None]:
class CriticNet(torch.nn.Module):
    """Twin Q-networks used by TD3.

    Two independent MLPs (critic1: fc1-fc3, critic2: fc4-fc6) each map a
    concatenated (observation, action) batch to one scalar Q-value per sample.

    Args:
        env: environment whose ``observation_space.shape[0]`` and
            ``action_space.shape[0]`` determine the input width.
    """

    def __init__(self, env):
        super(CriticNet, self).__init__()
        in_dim = env.observation_space.shape[0] + env.action_space.shape[0]

        # critic1
        self.fc1 = torch.nn.Linear(in_dim, 128)
        self.relu1 = torch.nn.ReLU()
        self.fc2 = torch.nn.Linear(128, 128)
        self.relu2 = torch.nn.ReLU()
        # Q(s, a) is a scalar, so each head has exactly one output unit.
        self.fc3 = torch.nn.Linear(128, 1)

        # critic2
        self.fc4 = torch.nn.Linear(in_dim, 128)
        self.relu4 = torch.nn.ReLU()
        self.fc5 = torch.nn.Linear(128, 128)
        self.relu5 = torch.nn.ReLU()
        self.fc6 = torch.nn.Linear(128, 1)

    def _q1(self, cat_x):
        """Run critic1 on an already-concatenated (observation, action) batch."""
        x1 = self.relu1(self.fc1(cat_x))
        x1 = self.relu2(self.fc2(x1))
        return self.fc3(x1)

    def _q2(self, cat_x):
        """Run critic2 on an already-concatenated (observation, action) batch."""
        x2 = self.relu4(self.fc4(cat_x))
        x2 = self.relu5(self.fc5(x2))
        return self.fc6(x2)

    def forward(self, observation, action):
        """Return ``(Q1, Q2)``, each of shape ``(batch, 1)``.

        ``observation`` and ``action`` are concatenated along dim 1, so both
        must be 2-D batches with the same batch size.
        """
        cat_x = torch.cat([observation, action], dim=1)
        return self._q1(cat_x), self._q2(cat_x)

    def Q1(self, observation, action):
        """Return only critic1's Q-value, shape ``(batch, 1)`` (used for the actor loss)."""
        return self._q1(torch.cat([observation, action], dim=1))

In [None]:
class ActorNet(torch.nn.Module):
    """Deterministic policy network mapping observations to bounded actions.

    Two 128-unit ReLU hidden layers feed a tanh output layer. The tanh output
    is scaled by ``max_action``, so every action component lies in
    ``[-max_action, max_action]``.

    Args:
        env: environment whose ``observation_space.shape[0]`` is the input
            width and ``action_space.shape[0]`` is the output width.
        max_action: scale applied to the tanh output.
    """

    def __init__(self, env, max_action):
        super().__init__()
        obs_dim = env.observation_space.shape[0]
        act_dim = env.action_space.shape[0]
        width = 128

        self.fc1 = torch.nn.Linear(obs_dim, width)
        self.relu1 = torch.nn.ReLU()
        self.fc2 = torch.nn.Linear(width, width)
        self.relu2 = torch.nn.ReLU()
        self.fc3 = torch.nn.Linear(width, act_dim)
        self.tanh = torch.nn.Tanh()
        self.max_action = max_action

    def forward(self, x):
        """Return ``max_action * tanh(MLP(x))``."""
        first = self.relu1(self.fc1(x))
        second = self.relu2(self.fc2(first))
        squashed = self.tanh(self.fc3(second))
        return self.max_action * squashed

In [None]:
class TD3:
    """Twin Delayed DDPG (TD3) agent.

    Holds an actor, twin critics, their target copies, Adam optimizers and a
    FIFO replay buffer (at most 10**6 entries). ``learn()`` expects each buffer
    entry to be ``[state, action, [reward], next_state, [done]]``.

    Args:
        env: environment providing ``observation_space.shape``,
            ``action_space.shape`` and ``action_space.high``.
        batch_size: number of transitions sampled per ``learn()`` call.
    """

    def __init__(self, env, batch_size=64):
        self.critic = CriticNet(env)
        self.critic_target = copy.deepcopy(self.critic)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=1e-4)

        self.actor = ActorNet(env, max_action=env.action_space.high[0])
        self.actor_target = copy.deepcopy(self.actor)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=1e-4)

        # Only high[0] is used and actions are clipped to [-max, max]; this
        # assumes a symmetric action box with the same bound on every dimension.
        self.max_action = env.action_space.high[0]
        self.mse = torch.nn.MSELoss()
        self.buffer = deque(maxlen=(10 ** 6))
        self.batch_size = batch_size
        self.gamma = 0.99        # discount factor
        self.tau = 0.005         # Polyak averaging rate for the target networks
        # Target-policy smoothing: clipped Gaussian noise added to target actions.
        self.noise_clip = 0.5 * self.max_action
        self.policy_noise = 0.2 * self.max_action
        self.iter = 0            # number of learn() calls so far
        self.policy_freq = 2     # actor/targets updated once every policy_freq critic updates
        self.env = env

    def choose_action(self, state):
        """Return the deterministic policy action for one state as a NumPy array.

        No exploration noise is added; callers add their own. ``state`` may be
        a NumPy array or a tensor.
        """
        # as_tensor accepts both ndarrays and tensors without a copy-construct warning.
        state = torch.as_tensor(state, dtype=torch.float32).unsqueeze(0)
        with torch.no_grad():
            action = self.actor(state).squeeze(0).numpy()
        return action

    @staticmethod
    def _soft_update(net, target_net, tau):
        """Polyak-average in place: target <- tau * net + (1 - tau) * target."""
        for param, target_param in zip(net.parameters(), target_net.parameters()):
            target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)

    def update_target(self):
        """Softly move both target networks toward their online networks."""
        self._soft_update(self.critic, self.critic_target, self.tau)
        self._soft_update(self.actor, self.actor_target, self.tau)

    def learn(self):
        """Run one TD3 update on a random minibatch from the replay buffer.

        The twin critics are updated on every call. The actor and both target
        networks are updated only every ``policy_freq`` calls (delayed policy
        updates).

        Raises:
            ValueError: from ``random.sample`` if the buffer holds fewer than
                ``batch_size`` transitions.
        """
        self.iter += 1
        batch_samples = random.sample(self.buffer, self.batch_size)
        # Stack each field into one ndarray first; building a tensor directly
        # from a tuple of ndarrays is extremely slow.
        state_lst, action_lst, reward_lst, new_state_lst, done_lst = (
            torch.as_tensor(np.asarray(field), dtype=torch.float32)
            for field in zip(*batch_samples)
        )

        # Critic update: clipped double-Q target with target-policy smoothing.
        # The target is a constant for this step, so build it without autograd;
        # otherwise backward() would also accumulate gradients into critic_target.
        with torch.no_grad():
            noise = torch.clip(torch.randn_like(action_lst) * self.policy_noise, -self.noise_clip, self.noise_clip)
            new_action = torch.clip(self.actor_target(new_state_lst) + noise, -self.max_action, self.max_action)
            q_target1, q_target2 = self.critic_target(new_state_lst, new_action)
            q_target = reward_lst + self.gamma * torch.min(q_target1, q_target2) * (1 - done_lst)

        q_value1, q_value2 = self.critic(state_lst, action_lst)
        td_error = self.mse(q_value1, q_target) + self.mse(q_value2, q_target)
        self.critic_optimizer.zero_grad()
        td_error.backward()
        self.critic_optimizer.step()

        # Delayed actor update.
        if self.iter % self.policy_freq == 0:
            action = self.actor(state_lst)
            # Minimizing -mean(Q1) maximizes critic1's value of the actor's actions.
            loss_actor = -torch.mean(self.critic.Q1(state_lst, action))
            self.actor_optimizer.zero_grad()
            loss_actor.backward()
            self.actor_optimizer.step()
            self.update_target()

    def model_save(self, path):
        """Save online/target networks and both optimizer states to ``path``."""
        torch.save({
            'actor_model_state_dict': self.actor.state_dict(),
            'actor_target_model_state_dict': self.actor_target.state_dict(),
            'critic_model_state_dict': self.critic.state_dict(),
            'critic_target_model_state_dict': self.critic_target.state_dict(),
            'actor_optimizer_state_dict': self.actor_optimizer.state_dict(),
            'critic_optimizer_state_dict': self.critic_optimizer.state_dict(),
        }, path)

    def model_load(self, path):
        """Restore everything written by ``model_save`` from ``path``."""
        checkpoint = torch.load(path)
        self.actor.load_state_dict(checkpoint['actor_model_state_dict'])
        self.actor_target.load_state_dict(checkpoint['actor_target_model_state_dict'])
        self.critic.load_state_dict(checkpoint['critic_model_state_dict'])
        # Key must match model_save exactly (was misspelled 'ctitic_...').
        self.critic_target.load_state_dict(checkpoint['critic_target_model_state_dict'])
        self.actor_optimizer.load_state_dict(checkpoint['actor_optimizer_state_dict'])
        self.critic_optimizer.load_state_dict(checkpoint['critic_optimizer_state_dict'])

In [None]:
# Start with an empty TensorBoard log directory: delete any previous run's
# files, then (re)create the directory so the cell always ends in the same state.
log_dir = './runs'
log_dir_existed = os.path.exists(log_dir)
if log_dir_existed:
    try:
        shutil.rmtree(log_dir)
        print(f'文件夹 {log_dir} 已成功删除。')
    except OSError as error:
        print(f'删除文件夹 {log_dir} 失败: {error}')
os.makedirs(log_dir, exist_ok=True)
if not log_dir_existed:
    print(f'文件夹 {log_dir} 不存在，已创建文件夹 {log_dir}。')

In [ ]:
# Start with an empty checkpoint directory: delete checkpoints from earlier
# runs, then (re)create the directory so the cell always ends in the same state.
model_save_dir = "./model_save"
model_save_dir_existed = os.path.exists(model_save_dir)
if model_save_dir_existed:
    try:
        shutil.rmtree(model_save_dir)
        print(f'文件夹 {model_save_dir} 已成功删除。')
    except OSError as error:
        print(f'删除文件夹 {model_save_dir} 失败: {error}')
os.makedirs(model_save_dir, exist_ok=True)
if not model_save_dir_existed:
    print(f'文件夹 {model_save_dir} 不存在，已创建文件夹 {model_save_dir}。')

In [None]:
# Train a TD3 agent on BipedalWalker-v3, logging per-episode returns to
# TensorBoard and checkpointing every 10 episodes.
logwriter = SummaryWriter(log_dir=log_dir)
env = gym.make("BipedalWalker-v3")
batch_size = 256
td3 = TD3(env, batch_size)
episode = 2000      # number of training episodes
steps = 3000        # safety cap on environment steps per episode
all_reward = []     # per-episode sum of (shaped) rewards
now_epoch = 0       # index of the last finished episode (used by the final save)
expl_noise = 0.25   # exploration-noise std as a fraction of max_action, decayed each episode
try:
    for epoch in range(episode):
        state, _ = env.reset()
        step = 0
        episode_rewards = 0
        expl_noise *= 0.999
        while True:
            noise = np.random.normal(0, td3.max_action * expl_noise, size=env.action_space.shape[0])
            action = (td3.choose_action(state) + noise).clip(-td3.max_action, td3.max_action)

            # step() returns (obs, reward, terminated, truncated, info).
            new_state, reward, terminated, truncated, _ = env.step(action)
            fell = reward <= -100
            if fell:
                # Presumably shrinks the -100 crash penalty so a single fall
                # does not dominate the critic targets.
                reward = -1
            # Only real terminal states stop bootstrapping; a time-limit
            # truncation is not terminal, so it is stored as not-done.
            td3.buffer.append([state, action, [reward], new_state, [terminated or fell]])
            state = new_state
            episode_rewards += reward
            step += 1

            # Warm-up: start learning only once the buffer has some data.
            if len(td3.buffer) > 2000:
                td3.learn()
            if terminated or truncated or step >= steps:
                break

        if epoch % 10 == 0:
            os.makedirs("./model_save/", exist_ok=True)
            td3.model_save('./model_save/{}.pth'.format(epoch))
        now_epoch = epoch
        all_reward.append(episode_rewards)
        logwriter.add_scalar('episode_rewards', episode_rewards, epoch)
        print("Epoch/Episode: {}/{},reward: {}".format(epoch + 1, episode, episode_rewards))
finally:
    # Flush and close the TensorBoard writer even if training is interrupted.
    logwriter.close()

In [None]:
# Watch the trained policy for 50 rendered episodes (no exploration noise).
# Release the training environment before rebinding `env`.
env.close()
env = gym.make("BipedalWalker-v3", render_mode='human')
for _ in range(50):
    start_time = time.time()
    state, _ = env.reset()
    step = 0
    episode_rewards = 0
    while True:
        # choose_action converts the NumPy observation itself.
        a = td3.choose_action(state)
        new_state, reward, terminated, truncated, _ = env.step(a)
        episode_rewards += reward
        step += 1
        state = new_state
        # Stop on truncation too; otherwise a policy that never falls loops forever.
        if terminated or truncated:
            end_time = time.time()
            print("time: {:.2f}s, steps: {}, reward: {}".format(end_time - start_time, step, episode_rewards))
            break

In [None]:
# Close the rendering evaluation environment once viewing is finished.
env.close()

In [None]:
# Save a final checkpoint tagged with the last finished episode. Recreate the
# directory first, as the training loop does, in case it was removed meanwhile.
os.makedirs('./model_save/', exist_ok=True)
td3.model_save('./model_save/{}.pth'.format(now_epoch))