In [1]:
import gymnasium as gym


class MyWrapper(gym.Wrapper):
    """CartPole-v1 environment with a simplified, single-`done`-flag interface.

    Differences from the wrapped Gymnasium environment:

    * ``reset`` returns only the observation (the info dict is dropped).
    * ``step`` returns ``(state, reward, done, info)`` where ``done`` merges
      ``terminated`` and ``truncated`` and is also forced to ``True`` once
      ``max_steps`` steps have been taken in the current episode.

    Args:
        max_steps: per-episode step cap enforced by this wrapper (default 200).
    """

    def __init__(self, max_steps=200):
        env = gym.make('CartPole-v1', render_mode='rgb_array')
        # gym.Wrapper.__init__ stores the wrapped env as ``self.env``.
        super().__init__(env)
        self.max_steps = max_steps
        # Number of steps taken since the last reset.
        self.step_n = 0

    def reset(self, *, seed=None, options=None):
        """Start a new episode and return its initial observation.

        Args:
            seed: optional seed forwarded to the wrapped env's ``reset``.
            options: optional options dict forwarded to the wrapped env.

        Returns:
            The initial observation only; the info dict is discarded.
        """
        state, _ = self.env.reset(seed=seed, options=options)
        self.step_n = 0
        return state

    def step(self, action):
        """Advance one step.

        Returns:
            tuple: ``(state, reward, done, info)``. ``done`` is true when the
            wrapped env reports termination or truncation, or when the
            wrapper's own ``max_steps`` cap has been reached.
        """
        state, reward, terminated, truncated, info = self.env.step(action)
        self.step_n += 1
        done = terminated or truncated or self.step_n >= self.max_steps
        return state, reward, done, info


# Module-level environment instance; the functions defined in later cells
# read this global `env` directly.
env = MyWrapper()
# Smoke test: reset once and display the initial observation.
env.reset()

array([-0.02780182, -0.01095542, -0.03379848, -0.00699443], dtype=float32)

In [2]:
import  torch
# Policy network (actor): 4 input features -> 128 hidden units -> 2 outputs.
# Softmax over dim=1 turns each row into a probability distribution, so the
# input must be batched with shape (batch, 4).
actor_model = torch.nn.Sequential(
    torch.nn.Linear(4, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 2),
    torch.nn.Softmax(dim=1),
)
# Value network (critic): 4 input features -> 128 hidden units -> 1 scalar
# output per row.
critic_model = torch.nn.Sequential(
    torch.nn.Linear(4, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 1),
)
# Smoke test: run both networks on a random batch of two inputs and display
# the outputs.
actor_model(torch.randn((2,4))),critic_model(torch.randn((2,4)))

(tensor([[0.4731, 0.5269],
         [0.5060, 0.4940]], grad_fn=<SoftmaxBackward0>),
 tensor([[-0.1387],
         [-0.1858]], grad_fn=<AddmmBackward0>))

In [3]:
import random
def get_action(state):
    """Sample one discrete action from the actor's policy.

    Args:
        state: a single observation (list or array of numbers) whose length
            matches the input size of the global ``actor_model``.

    Returns:
        int: index of the sampled action, drawn with probability given by
        the actor's softmax output.
    """
    # Sampling never back-propagates, so skip building the autograd graph.
    with torch.no_grad():
        # Add a batch dimension; the feature size is taken from the input
        # instead of being hard-coded.
        batch = torch.as_tensor(state, dtype=torch.float).reshape(1, -1)
        probs = actor_model(batch)[0].tolist()
    # The number of actions follows the actor's output width.
    return random.choices(range(len(probs)), weights=probs, k=1)[0]
# Smoke test: sample an action for an arbitrary 4-element state.
get_action([1,2,3,4])

1

In [4]:
def get_oneGame_data():
    """Play one full episode with the current policy and return it as tensors.

    Uses the global ``env`` and ``get_action``. With T the number of steps
    taken in the episode, the returned tuple is, in this order:

    * ``states``      - float tensor, one row per step
    * ``rewards``     - float tensor of shape (T, 1)
    * ``actions``     - long tensor of shape (T, 1)
    * ``next_states`` - float tensor, one row per step
    * ``overs``       - long tensor of shape (T, 1); 1 on the step where the
      environment reported the episode as finished, else 0
    """
    states = []
    next_states = []
    actions = []
    rewards = []
    overs = []

    state = env.reset()
    over = False
    while not over:
        action = get_action(state)
        next_state, reward, over, _ = env.step(action)

        states.append(state)
        rewards.append(reward)
        actions.append(action)
        overs.append(over)
        next_states.append(next_state)

        state = next_state

    # Convert each observation on its own and stack the results.
    # torch.tensor() on a Python list of numpy arrays takes a slow
    # element-by-element path and emits a UserWarning.
    states = torch.stack(
        [torch.as_tensor(s, dtype=torch.float) for s in states])
    next_states = torch.stack(
        [torch.as_tensor(s, dtype=torch.float) for s in next_states])
    actions = torch.tensor(actions, dtype=torch.long).reshape(-1, 1)
    rewards = torch.tensor(rewards, dtype=torch.float).reshape(-1, 1)
    overs = torch.tensor(overs, dtype=torch.long).reshape(-1, 1)

    return states, rewards, actions, next_states, overs

# Smoke test: collect one episode and display the resulting tensors.
get_oneGame_data()

  states = torch.tensor(states, dtype=torch.float)


(tensor([[ 2.6506e-02, -1.9506e-02, -4.2783e-02, -1.3262e-02],
         [ 2.6116e-02,  1.7620e-01, -4.3048e-02, -3.1913e-01],
         [ 2.9640e-02,  3.7191e-01, -4.9430e-02, -6.2507e-01],
         [ 3.7078e-02,  1.7751e-01, -6.1932e-02, -3.4836e-01],
         [ 4.0629e-02, -1.6677e-02, -6.8899e-02, -7.5829e-02],
         [ 4.0295e-02, -2.1075e-01, -7.0416e-02,  1.9434e-01],
         [ 3.6080e-02, -1.4692e-02, -6.6529e-02, -1.1969e-01],
         [ 3.5786e-02, -2.0880e-01, -6.8923e-02,  1.5128e-01],
         [ 3.1610e-02, -1.2763e-02, -6.5897e-02, -1.6233e-01],
         [ 3.1355e-02, -2.0688e-01, -6.9143e-02,  1.0886e-01],
         [ 2.7217e-02, -4.0095e-01, -6.6966e-02,  3.7895e-01],
         [ 1.9198e-02, -2.0494e-01, -5.9387e-02,  6.5932e-02],
         [ 1.5100e-02, -3.9917e-01, -5.8068e-02,  3.3930e-01],
         [ 7.1162e-03, -2.0327e-01, -5.1282e-02,  2.8888e-02],
         [ 3.0508e-03, -7.4497e-03, -5.0705e-02, -2.7952e-01],
         [ 2.9018e-03, -2.0181e-01, -5.6295e-02, -3.254

In [6]:
def test():
    """Play one episode with the current policy and return its total reward.

    Actions are sampled through ``get_action`` on the global ``env``; the
    episode ends when ``env.step`` reports it as finished.
    """
    total = 0
    obs = env.reset()
    while True:
        obs, reward, finished, _ = env.step(get_action(obs))
        total += reward
        if finished:
            return total
# Smoke test: total reward of one episode under the current policy.
test()

31.0

In [9]:
def GAE(deltas, gamma=0.98, lam=0.95):
    """Accumulate TD errors into advantage estimates.

    Walks the sequence backwards and computes, for each position ``t``,
    ``A_t = delta_t + (gamma * lam) * A_{t+1}`` with ``A`` after the last
    element taken as 0.

    Args:
        deltas: sequence of per-step TD errors, in time order.
        gamma: discount factor (default 0.98).
        lam: GAE smoothing factor (default 0.95).

    Returns:
        list: advantages in the same order as ``deltas``; empty if
        ``deltas`` is empty.
    """
    decay = gamma * lam
    advantages = []
    running = 0.0
    for delta in reversed(deltas):
        running = decay * running + delta
        advantages.append(running)
    # Values were collected last-to-first; restore time order.
    advantages.reverse()
    return advantages
# Worked example with the default discount factors.
GAE(deltas=[1,2,3,4])

[8.690100963999999, 8.260044, 6.724, 4.0]

In [15]:
def train(epochs=1000, ppo_steps=10, clip_eps=0.2, actor_lr=1e-3,
          critic_lr=1e-2, log_every=100, eval_games=10):
    """Train the global actor and critic with a clipped-ratio (PPO) objective.

    Each epoch collects one episode with ``get_oneGame_data``, derives
    advantages from the critic's TD errors via ``GAE``, and then performs
    ``ppo_steps`` critic and actor updates on that episode.

    Args:
        epochs: number of episodes to collect and train on.
        ppo_steps: gradient updates performed per collected episode.
        clip_eps: probability ratios are clamped to
            ``[1 - clip_eps, 1 + clip_eps]``.
        actor_lr: Adam learning rate for ``actor_model``.
        critic_lr: Adam learning rate for ``critic_model``.
        log_every: run an evaluation and print it every this many epochs.
        eval_games: number of ``test()`` episodes averaged per evaluation.

    Note:
        ``overs`` is 1 on whichever step ended the episode, so the
        bootstrapped value is dropped on that step regardless of why the
        episode ended.
    """
    # Discount used for the TD target. GAE() applies its own 0.98 discount,
    # so the two values need to be kept in sync.
    gamma = 0.98

    optimizer = torch.optim.Adam(actor_model.parameters(), lr=actor_lr)
    optimizer_c = torch.optim.Adam(critic_model.parameters(), lr=critic_lr)
    loss_fc = torch.nn.MSELoss()

    for epoch in range(epochs):
        states, rewards, actions, next_states, overs = get_oneGame_data()

        # Everything here is treated as a constant during the updates below,
        # so compute it without recording an autograd graph.
        with torch.no_grad():
            # critic: one-step TD target and TD errors
            values = critic_model(states)
            target = rewards + (1 - overs) * gamma * critic_model(next_states)
            deltas = (target - values).squeeze(dim=1).tolist()

            # Probability the pre-update policy assigned to each taken action.
            old_probs = actor_model(states).gather(dim=1, index=actions)

        # compute advantages
        advantages = torch.tensor(GAE(deltas), dtype=torch.float).reshape(-1, 1)

        for _ in range(ppo_steps):
            new_probs = actor_model(states).gather(dim=1, index=actions)

            ratios = new_probs / old_probs
            surr1 = ratios * advantages
            surr2 = torch.clamp(ratios, 1 - clip_eps, 1 + clip_eps) * advantages

            # Negated because the optimizer minimizes.
            loss = (-torch.min(surr1, surr2)).mean()

            # Critic regresses onto the fixed TD target.
            loss_c = loss_fc(critic_model(states), target)

            optimizer_c.zero_grad()
            loss_c.backward()
            optimizer_c.step()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if (epoch + 1) % log_every == 0:
            test_result = sum(test() for _ in range(eval_games)) / eval_games
            print('epoch:', epoch, 'test_result:', test_result)
        
# Run the full training loop with its default settings.
train() 

[1.0160866975784302, 0.9789121747016907, 0.9819928407669067, 1.0150607824325562, 1.0157073736190796, 1.0107085704803467, 1.001929759979248, 0.9951879978179932, 0.9818902015686035, 1.0108790397644043, 0.9836628437042236, 1.0116976499557495, 1.002530813217163, 0.9956793785095215, 1.0108733177185059, 0.9944018721580505, 1.0121004581451416, 1.0079628229141235, 0.9837712049484253]
