In [172]:
from random import randint
import torch

Carbin_Count = 9  # 


#定义环境
class Environment:
    def __init__(self):
        self.state = None
        self.time_spend = 0  # 使用的总共时间 要惩罚时间的
        self.reset()

    def reset(self):
        # 每个船舱的重量
        self.time_spend = 0  # 使用的总共时间 要惩罚时间的
        self.state = torch.randn(1, Carbin_Count) + torch.full((1, Carbin_Count), 1).resize(1,
                                                                                            Carbin_Count)
        return self.state

    def is_over(self):
        # 判断state每一项是否小于0
        return torch.all(self.state < 0)

    def get_state_reword(self, action: torch.Tensor):
        # 当前状态的分数,比如为负数了还有人来卸载就应该扣分
        reword = 0
        for i in action:
            index = int(i)  # 映射到货轮
            if index == 0:
                continue
            # 动起来的奖励
            reword += 0.5
            if self.state[0][index - 1] < 0:
                reword -= 1
        if self.is_over():
            # 惩罚时间 
            reword += 1000
        return reword

    def step(self, action):
        for i in action:
            index = int(i)  # 映射到货轮
            if index == 0:
                continue
            self.state[0][index-1] -= 0.1
        self.time_spend += 1
        reword = self.get_state_reword(action)
        return self.state, reword, self.is_over(), None


class MyWrapper:
    N = 3  # 港机数量

    def __init__(self):
        self.env = Environment()
        self.step_n = 0

    def reset(self):
        state = self.env.reset()
        self.step_n = 0
        return state

    def step(self, action):
        # print(action)
        state, reward, terminated, info = self.env.step(action)
        over = terminated
        #限制最大步数
        self.step_n += 1
        if self.step_n > 200:
            over = True
        return state, reward, over

    def to_show_state(self):
        return '%s  %s  %s' % (str(self.env.state), str(self.env.time_spend), str(self.env.is_over()))


env = MyWrapper()

env.reset()
env.step(torch.full((1, 1), 0.5, dtype=torch.float32))
env.to_show_state()

'tensor([[-0.8474,  0.7139,  0.8216,  2.1499,  0.1014, -0.7658,  1.8822,  2.3720,\n          0.0216]])  1  tensor(False)'

In [173]:
import torch


class A2C:

    def __init__(self, model_actor, model_critic, model_critic_delay,
                 optimizer_actor, optimizer_critic):
        self.model_actor = model_actor
        self.model_critic = model_critic
        self.model_critic_delay = model_critic_delay
        self.optimizer_actor = optimizer_actor
        self.optimizer_critic = optimizer_critic

        self.model_critic_delay.load_state_dict(self.model_critic.state_dict())
        self.requires_grad(self.model_critic_delay, False)

    def soft_update(self, _from, _to):
        for _from, _to in zip(_from.parameters(), _to.parameters()):
            value = _to.data * 0.99 + _from.data * 0.01
            _to.data.copy_(value)

    def requires_grad(self, model, value):
        for param in model.parameters():
            param.requires_grad_(value)

    def train_critic(self, state, reward, next_state, over):
        self.requires_grad(self.model_critic, True)
        self.requires_grad(self.model_actor, False)

        #计算values和targets
        value = self.model_critic(state)

        with torch.no_grad():
            target = self.model_critic_delay(next_state)
        target = target * 0.99 * (1 - over) + reward
        # print('xxxx', value.size(), target.size(), reward.size())
        #时序差分误差,也就是tdloss
        loss = torch.nn.functional.mse_loss(value, target)

        loss.backward()
        self.optimizer_critic.step()
        self.optimizer_critic.zero_grad()
        self.soft_update(self.model_critic, self.model_critic_delay)

        #减去value相当于去基线
        return (target - value).detach()

    def train_actor(self, state, action, value):
        self.requires_grad(self.model_critic, False)
        self.requires_grad(self.model_actor, True)

        #重新计算动作的概率
        prob = self.model_actor(state)
        prob = prob.gather(dim=1, index=action)

        #根据策略梯度算法的导函数实现
        #函数中的Q(state,action),这里使用critic模型估算
        prob = (prob + 1e-8).log() * value
        loss = -prob.mean()

        loss.backward()
        self.optimizer_actor.step()
        self.optimizer_actor.zero_grad()

        return loss.item()


model_actor = [
    torch.nn.Sequential(
        torch.nn.Linear(9, 6 * 64),
        torch.nn.ReLU(),
        torch.nn.Linear(6 * 64, 6 * 64),
        torch.nn.ReLU(),
        torch.nn.Linear(6 * 64, Carbin_Count + 1),
        torch.nn.Softmax(dim=1),
    ) for _ in range(env.N)
]

model_critic, model_critic_delay = [
    torch.nn.Sequential(
        torch.nn.Linear(9, 6 * 64),
        torch.nn.ReLU(),
        torch.nn.Linear(6 * 64, 6 * 64),
        torch.nn.ReLU(),
        torch.nn.Linear(6 * 64, 1),
    ) for _ in range(2)
]

optimizer_actor = [
    torch.optim.Adam(model_actor[i].parameters(), lr=1e-3)
    for i in range(env.N)
]
optimizer_critic = torch.optim.Adam(model_critic.parameters(), lr=5e-3)

a2c = [
    A2C(model_actor[i], model_critic, model_critic_delay, optimizer_actor[i],
        optimizer_critic) for i in range(env.N)
]
# x = torch.FloatTensor([1,2,3,4,6,7,8,8,4]).resize(1,Carbin_Count)
# print(model_actor[0](x))
model_actor = None
model_critic = None
model_critic_delay = None
optimizer_actor = None
optimizer_critic = None

a2c


[<__main__.A2C at 0x1690fa50c10>,
 <__main__.A2C at 0x1690fa7a550>,
 <__main__.A2C at 0x1690fa7afd0>]

In [174]:

import random


#玩一局游戏并记录数据
def play(show=False):
    state = []
    action = []
    reward = []
    next_state = []
    over = []

    s = env.reset()
    o = False
    while not o:
        a = []
        for i in range(env.N):
            #计算动作
            prob = a2c[i].model_actor(torch.FloatTensor(s).reshape(
                1, -1))[0].tolist()
            # print(s, prob)
            a.append(random.choices(range(Carbin_Count + 1), weights=prob, k=1)[0])

        #执行动作
        ns, r, o = env.step(a)

        state.append(s)
        action.append(a)
        reward.append(r)
        next_state.append(ns)
        over.append(o)

        s = ns

        if show:
            print(env.to_show_state())
    # print(state[0])
    # print(type(state), len(state))
    state = torch.tensor([item.numpy() for item in state])
    action = torch.LongTensor(action).unsqueeze(-1)
    reward = torch.FloatTensor(reward).unsqueeze(-1).unsqueeze(-1)
    next_state = torch.tensor([item.numpy() for item in next_state])
    over = torch.LongTensor(over).reshape(-1, 1)

    return state, action, reward, next_state, over, reward.sum().item()


state, action, reward, next_state, over, reward_sum = play()

reward_sum, state.size(), reward.size(), action.size()

(960.5,
 torch.Size([112, 1, 9]),
 torch.Size([112, 1, 1]),
 torch.Size([112, 3, 1]))

In [175]:
def train():
    #训练N局
    for epoch in range(3_000):
        state, action, reward, next_state, over, _ = play()

        #合并部分字段
        state_c = state.flatten(start_dim=1)
        reward_c = reward.sum(dim=1)
        next_state_c = next_state.flatten(start_dim=1)

        for i in range(env.N):
            value = a2c[i].train_critic(state_c, reward_c, next_state_c, over)
            loss = a2c[i].train_actor(state_c, action[:, i], value)

        if epoch % 250 == 0:
            test_result = sum([play()[-1] for _ in range(20)]) / 20
            print(epoch, loss, test_result)


train()

0 -1.8459137678146362 964.325
250 4.482891082763672 62.275
500 -8.162630081176758 168.875
750 -0.0 552.225
1000 -0.0 926.225
1250 -0.0 935.85
1500 0.05349911004304886 953.25
1750 -0.0 954.375
2000 -0.0 955.775
2250 -0.0 968.625
2500 -1.3746167421340942 971.75
2750 -0.0 972.475


In [176]:
play(True)[-1]

tensor([[-0.0775,  0.1910,  0.3738,  2.0309, -1.8991,  1.0799,  0.9674,  0.3243,
         -0.1416]])  1  tensor(False)
tensor([[-0.0775,  0.0910,  0.3738,  1.9309, -1.9991,  1.0799,  0.9674,  0.3243,
         -0.1416]])  2  tensor(False)
tensor([[-0.0775, -0.0090,  0.2738,  1.9309, -2.0991,  1.0799,  0.9674,  0.3243,
         -0.1416]])  3  tensor(False)
tensor([[-0.0775, -0.1090,  0.2738,  1.8309, -2.1991,  1.0799,  0.9674,  0.3243,
         -0.1416]])  4  tensor(False)
tensor([[-0.0775, -0.2090,  0.2738,  1.7309, -2.2991,  1.0799,  0.9674,  0.3243,
         -0.1416]])  5  tensor(False)
tensor([[-0.0775, -0.3090,  0.2738,  1.7309, -2.3991,  0.9799,  0.9674,  0.3243,
         -0.1416]])  6  tensor(False)
tensor([[-0.0775, -0.4090,  0.2738,  1.6309, -2.4991,  0.9799,  0.9674,  0.3243,
         -0.1416]])  7  tensor(False)
tensor([[-0.0775, -0.5090,  0.2738,  1.5309, -2.5991,  0.9799,  0.9674,  0.3243,
         -0.1416]])  8  tensor(False)
tensor([[-0.0775, -0.6090,  0.2738,  1.4309, -2.

968.5