In [35]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# 初始化

In [62]:
# Set up the environment, the two Q-networks, the controller and the data pool
import torch
import random
from RobotEnv import RobotEnv
from utils import Controller, Pool
def build_q_network(observation_length, action_length, hidden_size=64):
    """Build the MLP that scores every action for a given observation.

    Args:
        observation_length: Number of input features.
        action_length: Number of outputs (one value per action).
        hidden_size: Width of the two hidden layers.

    Returns:
        A freshly initialised ``torch.nn.Sequential`` network.
    """
    return torch.nn.Sequential(
        torch.nn.Linear(observation_length, hidden_size),
        torch.nn.ReLU(),
        torch.nn.Linear(hidden_size, hidden_size),
        torch.nn.ReLU(),
        torch.nn.Linear(hidden_size, action_length),
    )


env = RobotEnv(screen_width=400, screen_height=400)
# Network input/output sizes are read from the environment's spaces.
observation_length = env.observation_space.shape[0]
action_length = env.action_space.n

# Online network: estimates the value of each action in a state.
model = build_q_network(observation_length, action_length)

# Delayed-update network used to compute training targets. It is built by the
# same helper so its layout always matches `model`, then given the same weights.
model_delay = build_q_network(observation_length, action_length)
model_delay.load_state_dict(model.state_dict())

controller = Controller(model, env)
pool = Pool(controller)
# Fill the pool once up front; as the last expression, its return value is
# displayed as the cell output.
pool.update()

10042

# 训练

In [63]:
from torch.utils.tensorboard import SummaryWriter
# 训练
def train(total_steps=50_000_000, updates_per_collect=200, lr=2e-4, gamma=0.99,
          log_interval=100_000, n_eval_episodes=20, log_dir="./logs"):
    """Train the notebook-level ``model`` using Double-DQN style targets.

    Reads ``model``, ``model_delay`` and ``pool`` from the notebook's global
    namespace. Each outer iteration calls ``pool.update()`` (its return value
    is added to the step counter) and then runs ``updates_per_collect``
    gradient steps on batches drawn with ``pool.sample()``.

    Every time at least ``log_interval`` steps have accumulated since the last
    sync, the weights of ``model`` are copied into ``model_delay``, the policy
    is evaluated, and the results are printed and written to TensorBoard.

    Args:
        total_steps: Training stops once the step counter reaches this value.
        updates_per_collect: Gradient steps performed after each pool update.
        lr: Learning rate for the Adam optimizer.
        gamma: Discount factor applied to the bootstrapped next-state value.
        log_interval: Minimum number of steps between target syncs / logging.
        n_eval_episodes: Number of ``play(mode="test")`` calls averaged per
            evaluation.
        log_dir: Directory passed to ``SummaryWriter``.

    Returns:
        None.
    """
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = torch.nn.MSELoss()
    n_step = 0
    last_log_step = 0
    loss = None  # stays None if no gradient step has been run yet
    writer = SummaryWriter(log_dir)
    # try/finally guarantees the writer is closed even when training is
    # stopped early (e.g. by interrupting the kernel).
    try:
        while n_step < total_steps:
            n_step += pool.update()
            # Run several gradient steps after each data refresh.
            for _ in range(updates_per_collect):
                # Draw one batch of transitions.
                state, action, reward, next_state, terminated = pool.sample()

                # Value predicted for the action that was actually taken.
                value = model(state).gather(dim=1, index=action)

                with torch.no_grad():
                    # The online model picks the next action and the delayed
                    # model scores it, which further mitigates bootstrapping
                    # bias.
                    next_action = model(next_state).argmax(dim=1, keepdim=True)
                    target = model_delay(next_state).gather(dim=1,
                                                            index=next_action)
                # No bootstrapped value is added for terminated transitions.
                target = target * gamma * (1 - terminated) + reward

                loss = loss_fn(value, target)
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()

            # Sync the delayed model and report progress.
            if (n_step - last_log_step) >= log_interval:
                model_delay.load_state_dict(model.state_dict())
                # The last element returned by play() is used as the score.
                test_result = sum(
                    pool.controller.play(mode="test")[-1]
                    for _ in range(n_eval_episodes)
                ) / n_eval_episodes
                print(f"step:{n_step},test_result:{test_result}")
                last_log_step = n_step
                # Write step count, test result and latest loss to TensorBoard.
                writer.add_scalar('Step', n_step, global_step=n_step)
                writer.add_scalar('Test Result', test_result, global_step=n_step)
                if loss is not None:
                    writer.add_scalar('Loss', loss.item(), global_step=n_step)
    finally:
        writer.close()
train()

step:101262,test_result:-55.482500000000016


KeyboardInterrupt: 

In [None]:
# Write the current weights of `model` to the checkpoint file.
torch.save(obj=model.state_dict(), f='DQN.pth')

In [6]:
# Read the checkpoint from disk, then copy its weights into `model`.
saved_weights = torch.load('DQN.pth')
model.load_state_dict(saved_weights)

<All keys matched successfully>

# 测试

In [56]:
# Create a fresh environment and a controller around `model`, then play in
# test mode with show=True.
# NOTE(review): the output stored with this cell is `error: display Surface quit`;
# presumably the display window was closed while play() was still running —
# confirm against RobotEnv / Controller.
env = RobotEnv(screen_height=400, screen_width=400)
controller = Controller(model, env)
controller.play(show=True, mode="test")

error: display Surface quit