In [None]:
# Automatically re-import edited project modules (e.g. utils, RobotEnv) before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
import random

import numpy as np
import torch

from utils import *
from RobotEnv import RobotEnv

# Reproducibility: seed the RNGs this notebook draws from (network initialisation,
# the epsilon-greedy coin flips, numpy).
# NOTE(review): RobotEnv and its action_space may keep their own RNG; seed those too
# if RobotEnv supports it.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Policy network: 8 observation features -> 4 action scores (the greedy action is
# the argmax). Used to choose actions when collecting trajectories.
t_model = torch.nn.Sequential(
    torch.nn.Linear(8, 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, 4),
)
env = RobotEnv(400, 400)
# Play episodes to obtain trajectories.
controller = Controller(t_model, env)
t_list = []
# Reset once up front; `s` holds the initial observation.
s, _ = env.reset()


def get_trajectory(model, env, epsilon=0.1):
    """Roll out one episode with an epsilon-greedy policy and return the visited states.

    The environment is reset at the start of the call. Then, until the episode
    terminates or is truncated:
    - the current observation is scored by ``model``;
    - the argmax is taken as the greedy action;
    - with probability ``epsilon`` that action is replaced by ``env.action_space.sample()``.

    Args:
        model: callable mapping a ``(1, state_dim)`` float tensor to action scores.
        env: environment exposing ``reset() -> (obs, info)``,
            ``step(a) -> (obs, reward, terminated, truncated, info)`` and
            ``action_space.sample()``.
        epsilon: probability of taking a random action instead of the greedy one
            (default 0.1, the previously hard-coded value).

    Returns:
        ``np.ndarray`` stacking, in order, every observation returned by ``env.step``.
        The initial observation from ``reset`` is not included.
    """
    # Start from this episode's own initial observation. The previous version read a
    # stale module-level `s` and never advanced it, so the policy ignored the
    # actual state.
    state, _ = env.reset()
    visited = []
    terminated = truncated = False
    while not (terminated or truncated):
        with torch.no_grad():  # data collection only; no autograd graph needed
            scores = model(torch.as_tensor(state, dtype=torch.float32).reshape(1, -1))
        action = scores.argmax().item()
        if random.random() < epsilon:
            action = env.action_space.sample()
        state, _, terminated, truncated, _ = env.step(action)
        visited.append(state)
    return np.array(visited)

In [None]:
# Collect N_TRAJECTORIES rollouts of the policy into t_list.
# The list is rebuilt here instead of appended to, so re-running this cell does not
# duplicate the dataset.
N_TRAJECTORIES = 200
t_list = [get_trajectory(t_model, env) for _ in range(N_TRAJECTORIES)]
len(t_list)

In [None]:
# t_list = t_list[:10]

In [None]:
import torch
from torch import nn
from torch.optim import Adam
from torch.nn import MSELoss
from torch.utils.data import DataLoader

# Dimensions: every observation is treated as an 8-dim state vector. The same size is
# passed as the last GRUModel argument, presumably the output size, so that the
# prediction matches the next state (TODO confirm GRUModel's signature in utils).
state_size, hidden_size = 8, 128

# One-step prediction is scored with mean-squared error and optimised with Adam.
loss_fn = MSELoss()
model = GRUModel(state_size, hidden_size, state_size)
optimizer = Adam(model.parameters(), lr=1e-3)

# Optional output-shape sanity check (disabled):
# x = torch.randn(1, 1, state_size)
# model(x).shape

In [None]:
from torch.utils.tensorboard import SummaryWriter

# TensorBoard writer; train() logs the per-epoch evaluation loss to it under the tag "Loss".
writer = SummaryWriter("./logs/state_prediction")


def test():
    loss_sum = 0
    for trajectory in t_list:
        trajectory = torch.tensor(trajectory, dtype=torch.float32)
        m = len(trajectory)
        t_loss = 0
        for j, state in enumerate(trajectory):
            if j == m - 1:
                break
            s = torch.FloatTensor(state).reshape(1, 1, 8)
            next_s = torch.FloatTensor(trajectory[j + 1]).reshape(1, 8)
            with torch.no_grad():
                predict_state = model(s)
            t_loss += loss_fn(predict_state, next_s)
        avg_loss_t = t_loss.item() / m
        loss_sum += avg_loss_t
    return loss_sum / len(t_list)


def train(epochs=30, repeat=10):
    """Fit the global ``model`` to predict the next state from the current one.

    Each epoch walks every trajectory in the global ``t_list``. Each trajectory is
    replayed ``repeat`` times back-to-back. For every transition
    ``trajectory[j] -> trajectory[j + 1]`` the model takes one ``optimizer`` step on
    ``loss_fn``.

    After each epoch, ``test()`` is run and its result is logged to ``writer`` under
    the tag ``"Loss"``. Note that ``test()`` evaluates over the same ``t_list`` used
    for training (there is no held-out split).

    Args:
        epochs: number of passes over ``t_list`` (default 30, as before).
        repeat: replays of each trajectory per epoch (default 10, as before).
    """
    for epoch in range(epochs):
        for i, trajectory in enumerate(t_list):
            traj = torch.as_tensor(trajectory, dtype=torch.float32)
            n_transitions = len(traj) - 1
            t_loss = 0.0
            for _ in range(repeat):
                for j in range(n_transitions):
                    s = traj[j].reshape(1, 1, -1)
                    next_s = traj[j + 1].reshape(1, -1)
                    predict_state = model(s)
                    loss = loss_fn(predict_state, next_s)
                    # backprop and update
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    t_loss += loss.item()
                # NOTE(review): a per-replay `reset_model(model)` call was disabled here.
                # If GRUModel keeps hidden state across calls, confirm that carrying it
                # over between replays is intended.
            # Average over the predictions actually made, (m - 1) * repeat rather than
            # m * repeat. max(..., 1) keeps single-state trajectories at 0.
            avg_loss = t_loss / max(n_transitions * repeat, 1)
            print(f"Epoch {epoch+1}, Trajectory {i+1}, Loss: {avg_loss}")
        # Once per epoch, record the mean prediction loss over all trajectories.
        epoch_loss = test()
        writer.add_scalar("Loss", epoch_loss, global_step=epoch + 1)
        print(f"Test Epoch {epoch+1}, Loss: {epoch_loss}")


train()
# Push buffered TensorBoard events to disk now instead of waiting for the writer's
# next periodic flush.
writer.flush()

In [None]:
# Persist the predictor's learned parameters (its state_dict, not the module object).
torch.save(obj=model.state_dict(), f="model.pth")