In [1]:
#@title 默认标题文本

import torch
import torch.nn as nn
import time
import retro

In [2]:
env = retro.make(game='Airstriker-Genesis')

In [3]:
class Actor_NN(nn.Module):
    def __init__(self):
        super(Actor_NN, self).__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.BatchNorm2d(8, affine=True),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False),
            nn.Conv2d(8, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.BatchNorm2d(16, affine=True),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        )

        # self.cnn = models.vgg19(pretrained=True).features.to(DEVICE).eval().requires_grad_(requires_grad=False)
        self.pool = nn.Sequential(
            nn.AdaptiveAvgPool2d(output_size=(7, 7))
        )
        self.fc = nn.Sequential(
            nn.Linear(784, 128),
            nn.BatchNorm1d(128, affine=True),
            nn.ReLU(),
            nn.Linear(128, 32),
            nn.BatchNorm1d(32, affine=True),
            nn.ReLU(),
            nn.Linear(32, 6),
            nn.Softmax(dim=1)
        )

    def forward(self, x):
        x2 = x.permute(0, 3, 2, 1)
        y = self.cnn(x2)
        y1 = self.pool(y)
        y1 = torch.flatten(y1, start_dim=1, end_dim=-1)
        z = self.fc(y1)
        return z


class Critic_NN(nn.Module):
    def __init__(self):
        super(Critic_NN, self).__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.BatchNorm2d(8, affine=True),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False),
            nn.Conv2d(8, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.BatchNorm2d(16, affine=True),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        )

        # self.cnn = models.vgg19(pretrained=True).features.to(DEVICE).eval().requires_grad_(requires_grad=False)
        self.pool = nn.Sequential(
            nn.AdaptiveAvgPool2d(output_size=(7, 7))
        )
        self.fc = nn.Sequential(
            nn.Linear(784, 64),
            nn.BatchNorm1d(64, affine=True),
            nn.ReLU(),
            nn.Linear(64, 16),
            nn.BatchNorm1d(16, affine=True),
            nn.ReLU(),
            nn.Linear(16, 1),
        )

    def forward(self, x):
        x2 = x.permute(0, 3, 2, 1)
        y = self.cnn(x2)
        y1 = self.pool(y)
        y1 = torch.flatten(y1, start_dim=1, end_dim=-1)
        z = self.fc(y1)
        return z


In [4]:
class PPO(nn.Module):
    def __init__(self, timestep, discount, batch, device, lr, lamb, actor_update, critic_update):
        super(PPO, self).__init__()
        self.actor_update = actor_update
        self.critic_update = critic_update
        self.lamb = lamb  
        self.lr = lr
        self.device = device
        self.batch_size = batch
        self.timestep = timestep
        self.gamma = discount
        self.ji = self.lamb * self.gamma
        self.coeff = torch.tensor([self.ji**i for i in range(timestep)]).to(self.device)  # 算advantage
        self.target_coeff = torch.tensor([self.gamma**i for i in range(timestep)]).to(device) # 算target，折扣因子
        self.epsilon = 0.2
        self.actor = Actor_NN().to(self.device)
        self.old_actor = Actor_NN().to(self.device)
        self.critic = Critic_NN().to(self.device)
        self.actor_optim = torch.optim.Adam(self.actor.parameters(), lr=self.lr)
        self.critic_optim = torch.optim.Adam(self.critic.parameters(), lr=self.lr)
        self.discrete_action = []
        self.decode_action()

    def decode_action(self):
        buttons = env.unwrapped.buttons
        combos = [[],['B'],['LEFT'],['RIGHT'],['B','LEFT'],['B','RIGHT']]
        for combo in combos:
            # arr = np.array([0] * env.action_space.n)
            arr = [0,0,0,0,0,0,0,0,0,0,0,0]
            for button in combo:
                arr[buttons.index(button)] = 1
            self.discrete_action.append(arr)
        # self.discrete_action = np.array(self.discrete_action)

    def action_select(self, s):
        s = (torch.FloatTensor(s)).to(self.device)
        prob = self.actor(s)
        dist = torch.distributions.Categorical(prob.squeeze(0))
        action = self.discrete_action[dist.sample()]
        return action


    def actor_learn(self, s, a, advantage):

        prob = self.actor(s[:self.batch_size])
        pi = torch.distributions.Categorical(prob)

        old_prob = self.old_actor(s[:self.batch_size])
        old_pi = torch.distributions.Categorical(old_prob)
        # print(self.discrete_action)
        # print(a)
        a = torch.tensor([self.discrete_action.index(list(i.cpu())) for i in a]).to(self.device)
        # print(a)
        # print(pi.log_prob(a))
        # print(old_pi.log_prob(a))
        ratio = torch.exp(pi.log_prob(a) - old_pi.log_prob(a))

        advantage = torch.tensor(advantage).to(self.device)
        surr = (ratio * advantage).reshape(-1,1)  # torch.Size([batch, 1])

        # mean取均值，就是期望
        # 这个loss就是画面价值，越高越好
        loss = -torch.mean(
            torch.min(surr, torch.clamp(ratio, 1 - self.epsilon, 1 + self.epsilon) * advantage.reshape(-1, 1)))

        self.actor_optim.zero_grad()
        loss.backward()
        self.actor_optim.step()

    def critic_learn(self, targets, s): # 传入s重新计算V而不用self.V一个是因为actor反向传播后会清空梯度，
                      # 在这里就无法反向传播了，再一个就是多次反向传播每次都会清空梯度
        
        self.V = self.critic(s).reshape(self.batch_size+self.timestep)
        targets_ = torch.FloatTensor(targets).to(self.device)

        loss_func = nn.MSELoss()
        # print(targets_.shape)
        # print(self.V[:self.batch_size].shape)
        loss = loss_func(self.V[:self.batch_size], targets_)

        self.critic_optim.zero_grad()
        loss.backward()
        self.critic_optim.step()

    def target_cal(self, r):  # 用于更新critic
        targets = []
        for i in range(self.batch_size):
          # print(r[i].device)
          # print(self.V.device)
          # print(self.target_coeff.device)
          target = r[i] + torch.sum(self.V[i+1:i+self.timestep+1,0]*self.target_coeff)
          targets.append(int(target))
        
        return targets

    def delta_cal(self, r, s):  

        self.V = self.critic(s)
        delta = []
        for i in range(len(self.V)-1):
            delta0 = r[i] + self.gamma * self.V[i+1,0] - self.V[i,0]
            delta.append(int(delta0))

        return delta  # 用来计算advantage

    def adv_cal(self, delta):

        advantage = []
        delta_ = torch.tensor(delta).to(self.device)
        for i in range(self.batch_size):
          adv = torch.sum(self.coeff * delta_[i:i+self.timestep])
          advantage.append(int(adv))

        return advantage

    def update(self, s, a, delta, targets):
        self.old_actor.load_state_dict(self.actor.state_dict())
        advantage = self.adv_cal(delta)

        for i in range(self.actor_update):
            self.actor_learn(s, a, advantage)

        for i in range(self.critic_update):
            self.critic_learn(targets, s)

In [5]:
DEVICE = torch.device('cuda')
BATCH = 32
EPSILON = 0.98
DISCOUNT = 0.99
TIME_STEP = 256
GAE_PARA = 0.95
CLIPING = 0.1
N_EPISODE = 1201
LR = 2.5*1e-4
Actor_Update = 3
Critic_Update = 3

In [6]:
env.seed(23)
torch.manual_seed(23)

agent = PPO(TIME_STEP, DISCOUNT, BATCH, DEVICE, LR, 
            GAE_PARA, Actor_Update, Critic_Update).to(DEVICE)

In [7]:
agent.load_state_dict(torch.load("D:\Python_Projects\D3QN__Airstriker-Genesis\models\ppo_Atari\ppo_one\ppo_600.pth"))

<All keys matched successfully>

In [None]:

all_ep_r = []
s = env.reset()
while True:
    s = torch.FloatTensor(s[24:,:,:]).unsqueeze(0)
    a = agent.eval().action_select(s)
    s, r, done, _ = env.step(a)
    env.render()
    time.sleep(0.012)
    if done is True:
        s = env.reset()


ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.

ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.



Traceback (most recent call last):
  File "D:\ProgramData\Anaconda3\lib\site-packages\IPython\core\interactiveshell.py", line 3457, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "C:\Users\JACKLE~1\AppData\Local\Temp/ipykernel_43740/4007662215.py", line 7, in <module>
    env.render()
  File "D:\ProgramData\Anaconda3\lib\site-packages\retro\retro_env.py", line 230, in render
    self.viewer.imshow(img)
  File "D:\ProgramData\Anaconda3\lib\site-packages\gym\envs\classic_control\rendering.py", line 449, in imshow
    self.window.flip()
  File "D:\ProgramData\Anaconda3\lib\site-packages\pyglet\window\win32\__init__.py", line 388, in flip
    self.context.flip()
  File "D:\ProgramData\Anaconda3\lib\site-packages\pyglet\gl\win32.py", line 252, in flip
    _gdi32.SwapBuffers(self.canvas.hdc)
KeyboardInterrupt

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "D:\ProgramData\Anaconda3\lib\site-packages\IPy

ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.



Traceback (most recent call last):
  File "D:\ProgramData\Anaconda3\lib\site-packages\IPython\core\interactiveshell.py", line 3457, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "C:\Users\JACKLE~1\AppData\Local\Temp/ipykernel_43740/4007662215.py", line 7, in <module>
    env.render()
  File "D:\ProgramData\Anaconda3\lib\site-packages\retro\retro_env.py", line 230, in render
    self.viewer.imshow(img)
  File "D:\ProgramData\Anaconda3\lib\site-packages\gym\envs\classic_control\rendering.py", line 449, in imshow
    self.window.flip()
  File "D:\ProgramData\Anaconda3\lib\site-packages\pyglet\window\win32\__init__.py", line 388, in flip
    self.context.flip()
  File "D:\ProgramData\Anaconda3\lib\site-packages\pyglet\gl\win32.py", line 252, in flip
    _gdi32.SwapBuffers(self.canvas.hdc)
KeyboardInterrupt

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "D:\ProgramData\Anaconda3\lib\site-packages\IPy

tensor([[1., 2., 3.]])