In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from RobotEnv import RobotEnv
import random
env = RobotEnv(screen_width=400, screen_height=400)
# Sizes taken from the environment's spaces: the first dimension of the
# observation space's shape, and `.n` of the action space (presumably the
# number of discrete actions — confirm against RobotEnv).
observation_length = env.observation_space.shape[0]
action_length = env.action_space.n

# 定义模型

In [None]:
import torch


class SelectOutput(torch.nn.Module):
    """Adapter that keeps only the output part of an ``(output, (hidden, cell))`` pair.

    ``torch.nn.LSTM`` returns its output together with a ``(h_n, c_n)`` state
    tuple; this module drops the state tuple so that the following layer
    receives a plain tensor.
    """

    def forward(self, inputs):
        """Unpack ``inputs`` as ``(output, (hidden, cell))`` and return ``output``.

        Raises the usual unpacking error if ``inputs`` does not have that
        nested two-element structure.
        """
        sequence_out, (_hidden_state, _cell_state) = inputs
        return sequence_out


def _build_q_network(hidden_size=64):
    """Build one Q-network: Linear -> ReLU -> LSTM -> SelectOutput -> ReLU -> Linear.

    The input width is ``observation_length`` and the output width is
    ``action_length``. Both networks below are created through this single
    factory so their architectures (and therefore their ``state_dict`` keys)
    cannot drift apart.
    """
    return torch.nn.Sequential(
        torch.nn.Linear(observation_length, hidden_size),
        torch.nn.ReLU(),
        torch.nn.LSTM(hidden_size, hidden_size, batch_first=True),
        SelectOutput(),  # keep the LSTM output, drop its (hidden, cell) state
        torch.nn.ReLU(),
        torch.nn.Linear(hidden_size, action_length),
    )


# Two networks with the same architecture: `model` and a second copy,
# `model_delay`, that is initialised with the same weights.
model = _build_q_network()
model_delay = _build_q_network()
# Copy the parameters so both networks start out identical.
model_delay.load_state_dict(model.state_dict())

# 单局游戏

In [None]:
# Create the Controller (constructed from the model and the environment) and
# the Pool built on top of it; both come from the project-local `utils` module.
# NOTE(review): the previous comment here read "define the optimizer", but no
# optimizer is created in this cell.
from utils import Controller,Pool
controller = Controller(model, env)
pool = Pool(controller)
# Presumably collects an initial batch of data into the pool — confirm against utils.Pool.
pool.update()
# Draw one sample; as the cell's last expression its value is displayed.
pool.sample()

# 训练

In [7]:
from torch.utils.tensorboard import SummaryWriter
# 训练
def train(total_steps=50_000_000, updates_per_collect=200, gamma=0.99,
          lr=2e-4, sync_interval=100_000, eval_episodes=20,
          log_dir="./logs/DRQN_logs"):
    """Train the global ``model`` on data sampled from the global ``pool``.

    Each outer iteration calls ``pool.update()`` (its return value is added to
    the running step counter) and then performs ``updates_per_collect``
    gradient updates. For every update the online ``model`` selects the next
    action and the delayed ``model_delay`` supplies the value of that action,
    which is discounted by ``gamma`` and masked by ``terminated``.

    Whenever at least ``sync_interval`` steps have passed since the last sync,
    ``model_delay`` is overwritten with the weights of ``model``, the policy is
    evaluated over ``eval_episodes`` test plays, and the step count, mean test
    result and most recent loss are printed / written to TensorBoard.

    Args:
        total_steps: Stop once the accumulated step counter reaches this value.
        updates_per_collect: Gradient updates performed after each pool update.
        gamma: Discount factor applied to the bootstrapped target.
        lr: Learning rate of the Adam optimizer.
        sync_interval: Minimum number of steps between target syncs / logging.
        eval_episodes: Number of test plays averaged for the reported result.
        log_dir: Directory the TensorBoard event files are written to.

    The defaults reproduce the previously hard-coded values, so ``train()``
    behaves as before.
    """
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = torch.nn.MSELoss()

    n_step = 0
    last_log_step = 0
    loss = None  # stays None only if no gradient update has run yet

    writer = SummaryWriter(log_dir)
    # try/finally guarantees the writer is flushed and closed even when the
    # (very long) loop is interrupted or raises.
    try:
        while n_step < total_steps:
            n_step += pool.update()

            # After every data refresh, run a fixed number of gradient updates.
            for _ in range(updates_per_collect):
                state, action, reward, next_state, terminated = pool.sample()

                # Value of the action that was actually taken.
                value = model(state).gather(dim=1, index=action)

                with torch.no_grad():
                    # The online model picks the next action and the delayed
                    # model evaluates it, which reduces bootstrapping bias.
                    next_action = model(next_state).argmax(dim=1, keepdim=True)
                    target = model_delay(next_state).gather(dim=1,
                                                            index=next_action)
                target = target * gamma * (1 - terminated) + reward

                loss = loss_fn(value, target)
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()

            if (n_step - last_log_step) >= sync_interval:
                # Sync the delayed network with the online one.
                model_delay.load_state_dict(model.state_dict())
                test_result = sum(
                    pool.controller.play(mode="test")[-1]
                    for _ in range(eval_episodes)
                ) / eval_episodes
                print(f"step:{n_step},test_result:{test_result}")
                last_log_step = n_step
                # Record step count, test result and latest loss in TensorBoard.
                writer.add_scalar('Step', n_step, global_step=n_step)
                writer.add_scalar('Test Result', test_result, global_step=n_step)
                if loss is not None:
                    writer.add_scalar('Loss', loss.item(), global_step=n_step)
    finally:
        writer.close()
train()

step:101286,test_result:-96.73750000000003
step:202098,test_result:-68.7875
step:303756,test_result:-49.02250000000001
step:405020,test_result:-31.8825
step:506121,test_result:-55.57250000000001
step:607524,test_result:-44.410000000000004
step:708584,test_result:-40.81250000000001
step:809928,test_result:-82.68750000000006
step:911629,test_result:-47.40999999999998
step:1013063,test_result:-59.6475
step:1114264,test_result:-64.21750000000004
step:1215964,test_result:-63.76500000000001
step:1317366,test_result:-146.25000000000003
step:1418985,test_result:-136.82999999999998
step:1520087,test_result:-83.77000000000001
step:1621289,test_result:-44.41500000000001
step:1722760,test_result:-28.615
step:1824132,test_result:-133.2425
step:1925113,test_result:-136.45000000000002
step:2025915,test_result:-53.902499999999996
step:2127532,test_result:-66.61250000000001
step:2228982,test_result:-36.13000000000001
step:2330683,test_result:-57.09500000000001
step:2431784,test_result:-36.8575000000000

In [None]:
# Save the model parameters (its state_dict) to 'DRQN.pth'.
torch.save(model.state_dict(), 'DRQN.pth')