In [1]:
import gymnasium as gym


class MyWrapper(gym.Wrapper):
    def __init__(self):
        env = gym.make('CartPole-v1', render_mode='rgb_array')
        super().__init__(env)
        self.env = env
        self.step_n = 0

    def reset(self):
        state, _ = self.env.reset()
        self.step_n = 0
        return state

    def step(self, action):
        state, reward, terminated, truncated, info = self.env.step(action)
        done = terminated or truncated
        self.step_n += 1
        if self.step_n >= 200:
            done = True
        return state, reward, done, info


env = MyWrapper()
env.reset()

array([-0.03236017,  0.00699723, -0.02335546, -0.00839442], dtype=float32)

In [2]:
import  torch
actor_model = torch.nn.Sequential(
    torch.nn.Linear(4, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 2),
    torch.nn.Softmax(dim=1),
)
critic_model = torch.nn.Sequential(
    torch.nn.Linear(4, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 1),
)
actor_model(torch.randn((2,4))),critic_model(torch.randn((2,4)))

(tensor([[0.4809, 0.5191],
         [0.4871, 0.5129]], grad_fn=<SoftmaxBackward0>),
 tensor([[-0.0339],
         [ 0.0812]], grad_fn=<AddmmBackward0>))

In [3]:
import random
def get_action(state):
    state=torch.FloatTensor(state).reshape(1,4)
    probs = actor_model(state)
    action=random.choices(range(2), weights=probs[0].tolist(),k=1)[0]
    return action
get_action([1,2,3,4])

1

In [4]:
def get_oneGame_data():
    states=[]
    next_states=[]
    actions=[]
    rewards=[]
    overs=[]
    state=env.reset()
    over=False
    while not over:
        action=get_action(state)
        next_state,reward,over,_=env.step(action)
        states.append(state)
        rewards.append(reward)
        actions.append(action)
        overs.append(over)
        next_states.append(next_state)
        
        state=next_state
        
    states = torch.tensor(states, dtype=torch.float)
    actions = torch.tensor(actions, dtype=torch.long).reshape(-1,1)
    rewards = torch.tensor(rewards, dtype=torch.float).reshape(-1,1)
    next_states = torch.tensor(next_states,dtype=torch.float)
    overs = torch.tensor(overs,dtype=torch.long).reshape(-1,1)

    return states, rewards, actions, next_states, overs

get_oneGame_data()

        

  states = torch.tensor(states, dtype=torch.float)


(tensor([[-0.0387, -0.0324,  0.0395,  0.0121],
         [-0.0393,  0.1621,  0.0397, -0.2678],
         [-0.0361, -0.0335,  0.0344,  0.0371],
         [-0.0367,  0.1611,  0.0351, -0.2445],
         [-0.0335,  0.3557,  0.0302, -0.5259],
         [-0.0264,  0.5504,  0.0197, -0.8089],
         [-0.0154,  0.3550,  0.0035, -0.5101],
         [-0.0083,  0.1598, -0.0067, -0.2163],
         [-0.0051,  0.3550, -0.0110, -0.5111],
         [ 0.0020,  0.5503, -0.0212, -0.8072],
         [ 0.0130,  0.7457, -0.0374, -1.1065],
         [ 0.0279,  0.5511, -0.0595, -0.8258],
         [ 0.0389,  0.3569, -0.0760, -0.5524],
         [ 0.0461,  0.5530, -0.0871, -0.8680],
         [ 0.0571,  0.7491, -0.1044, -1.1868],
         [ 0.0721,  0.5555, -0.1282, -0.9285],
         [ 0.0832,  0.7521, -0.1467, -1.2586],
         [ 0.0983,  0.9488, -0.1719, -1.5934],
         [ 0.1172,  1.1455, -0.2038, -1.9344]]),
 tensor([[1.],
         [1.],
         [1.],
         [1.],
         [1.],
         [1.],
         [1.],


In [5]:
def test():
    state=env.reset()
    reward_sum=0
    over=False
    while not over:
        action=get_action(state)
        next_state,reward,over,_=env.step(action)
        state=next_state
        reward_sum+=reward
    return reward_sum
test()

19.0

In [7]:
def train():
    optimizer = torch.optim.Adam(actor_model.parameters(), lr=1e-3)
    optimizer_c=torch.optim.Adam(critic_model.parameters(), lr=1e-2)
    loss_fc=torch.nn.MSELoss()
    
    for epoch in range(1000):
        states, rewards, actions, next_states, overs=get_oneGame_data()
        
        #cirtic
        values=critic_model(states)
        target=rewards+(1-overs)*0.98*critic_model(next_states).detach()
        loss_c=loss_fc(values,target)
        
        optimizer_c.zero_grad()
        loss_c.backward()
        optimizer_c.step()
        
        #actor
        probs=actor_model(states)
        probs=probs.gather(1,actions)
        delta=(target-values).detach()
        loss=(-torch.log(probs)*delta).mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if (epoch+1) % 100 == 0:
            test_result=sum([test() for _ in range(10) ])/10
            print('epoch:',epoch,'test_result:',test_result)
        
train()

epoch: 99 test_result: 200.0
epoch: 199 test_result: 200.0
epoch: 299 test_result: 200.0
epoch: 399 test_result: 200.0
epoch: 499 test_result: 200.0
epoch: 599 test_result: 195.1
epoch: 699 test_result: 200.0
epoch: 799 test_result: 200.0
epoch: 899 test_result: 200.0
epoch: 999 test_result: 200.0
