In [17]:
import gymnasium as gym


class MyWrapper(gym.Wrapper):
    """CartPole-v1 wrapper exposing a legacy-style API with an episode step cap.

    ``reset`` returns only the initial observation (the info dict is dropped),
    and ``step`` returns ``(state, reward, done, info)`` where ``done`` merges
    gymnasium's ``terminated`` / ``truncated`` flags and is additionally forced
    to True once ``max_steps`` steps have been taken since the last reset.

    Args:
        max_steps: episode length cap applied by this wrapper (default 200).
    """

    def __init__(self, max_steps=200):
        env = gym.make('CartPole-v1', render_mode='rgb_array')
        # gym.Wrapper.__init__ stores the wrapped env as self.env.
        super().__init__(env)
        self.max_steps = max_steps
        self.step_n = 0

    def reset(self, *, seed=None, options=None):
        """Reset the wrapped env and the step counter; return only the observation.

        ``seed`` / ``options`` are forwarded to the wrapped env's ``reset`` so
        that episodes can be made reproducible; the defaults keep the old
        behaviour of ``reset()`` with no arguments.
        """
        state, _ = self.env.reset(seed=seed, options=options)
        self.step_n = 0
        return state

    def step(self, action):
        """Advance one step; return ``(state, reward, done, info)``."""
        state, reward, terminated, truncated, info = self.env.step(action)
        self.step_n += 1
        done = terminated or truncated or self.step_n >= self.max_steps
        return state, reward, done, info


env = MyWrapper()
env.reset()

array([ 0.01178736,  0.04867173, -0.01777672,  0.03269818], dtype=float32)

In [18]:
import torch
# Policy network: a 4-value observation -> a probability for each of the 2 actions.
# The trailing Softmax means the network outputs probabilities (not logits);
# downstream code samples from them and takes their log directly.
_policy_layers = [
    torch.nn.Linear(4, 128),   # observation -> hidden features
    torch.nn.ReLU(),
    torch.nn.Linear(128, 2),   # hidden features -> one score per action
    torch.nn.Softmax(dim=1),   # normalise across the action dimension of a (batch, 2) input
]
model = torch.nn.Sequential(*_policy_layers)

In [19]:
import random
def get_action(state, policy=None):
    """Sample an action index from the policy's action probabilities for one observation.

    Args:
        state: a single observation (array-like; flattened to shape (1, obs_dim)).
        policy: optional callable mapping a (1, obs_dim) float32 tensor to a
            (1, n_actions) tensor of action probabilities. Defaults to the
            module-level ``model``.

    Returns:
        int: an action index in ``range(n_actions)``, drawn with
        ``random.choices`` weighted by the policy's probabilities.
    """
    net = model if policy is None else policy
    state = torch.as_tensor(state, dtype=torch.float32).reshape(1, -1)
    # Acting only needs the probabilities, not gradients; no_grad avoids building
    # (and immediately discarding) an autograd graph on every environment step.
    with torch.no_grad():
        probs = net(state)
    weights = probs[0].tolist()
    return random.choices(range(len(weights)), weights=weights, k=1)[0]

In [20]:
def get_oneGame_data():
    """Play one full episode in ``env`` with ``get_action`` and record it.

    Returns:
        (states, rewards, actions): three parallel lists with one entry per
        step. ``states[t]`` is the observation the action was chosen from,
        ``actions[t]`` the chosen action, and ``rewards[t]`` the reward that
        ``env.step`` returned for it.
    """
    states, rewards, actions = [], [], []

    obs = env.reset()
    while True:
        act = get_action(obs)
        next_obs, r, finished, _info = env.step(act)

        states.append(obs)
        rewards.append(r)
        actions.append(act)

        if finished:
            break
        obs = next_obs

    return states, rewards, actions

([array([-0.04920211, -0.02577092, -0.02133037,  0.02168001], dtype=float32),
  array([-0.04971753, -0.22058058, -0.02089677,  0.30755737], dtype=float32),
  array([-0.05412914, -0.41539863, -0.01474562,  0.59357744], dtype=float32),
  array([-0.06243711, -0.22007342, -0.00287407,  0.2962864 ], dtype=float32),
  array([-0.06683858, -0.41515428,  0.00305166,  0.5880615 ], dtype=float32),
  array([-0.07514167, -0.61031884,  0.01481289,  0.88170415], dtype=float32),
  array([-0.08734804, -0.4154012 ,  0.03244697,  0.5937146 ], dtype=float32),
  array([-0.09565607, -0.6109619 ,  0.04432127,  0.8964391 ], dtype=float32),
  array([-0.1078753 , -0.41646796,  0.06225005,  0.6180108 ], dtype=float32),
  array([-0.11620466, -0.22226828,  0.07461026,  0.3455652 ], dtype=float32),
  array([-0.12065003, -0.02828257,  0.08152157,  0.07731123], dtype=float32),
  array([-0.12121568, -0.22447282,  0.08306779,  0.39456007], dtype=float32),
  array([-0.12570514, -0.4206692 ,  0.09095899,  0.7122337 ], dt

In [21]:
def test():
    """Play one episode with the current policy and return its summed reward.

    Actions come from ``get_action``, which samples rather than taking the
    most likely action, so the score varies between calls.
    """
    total = 0
    obs = env.reset()
    finished = False
    while not finished:
        obs, r, finished, _info = env.step(get_action(obs))
        total += r
    return total


test()

15.0

In [23]:
def train(epochs=1000, lr=1e-3, gamma=0.98, eval_every=100, n_eval_episodes=10):
    """Train ``model`` with REINFORCE, taking one optimizer step per episode.

    Each epoch plays one episode via ``get_oneGame_data``, computes the
    discounted return-to-go ``G_t`` for every step, and takes one Adam step on
    the loss ``-sum_t log pi(a_t | s_t) * G_t``.

    Args:
        epochs: number of episodes (and optimizer steps).
        lr: Adam learning rate.
        gamma: discount factor used for the return-to-go.
        eval_every: every this many epochs, evaluate and print the mean score.
        n_eval_episodes: number of ``test()`` episodes averaged per evaluation.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    for epoch in range(epochs):
        states, rewards, actions = get_oneGame_data()

        # Discounted return-to-go, G_t = r_t + gamma * G_{t+1}, accumulated backwards.
        returns = [0.0] * len(rewards)
        running = 0.0
        for t in reversed(range(len(rewards))):
            running = running * gamma + rewards[t]
            returns[t] = running

        # One batched forward pass over the whole episode instead of one call per step.
        state_batch = torch.stack(
            [torch.as_tensor(s, dtype=torch.float32).reshape(-1) for s in states]
        )
        probs = model(state_batch)
        action_idx = torch.as_tensor(actions, dtype=torch.long).unsqueeze(1)
        chosen_probs = probs.gather(1, action_idx).squeeze(1)
        returns_t = torch.as_tensor(returns, dtype=torch.float32)

        # NOTE(review): there is no baseline or return normalisation, so updates are
        # high-variance; the recorded run's eval scores are non-monotone
        # (e.g. 200.0 followed by 108.9).
        loss = -(torch.log(chosen_probs) * returns_t).sum()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (epoch + 1) % eval_every == 0:
            test_result = sum(test() for _ in range(n_eval_episodes)) / n_eval_episodes
            print('epoch:', epoch, 'test_result:', test_result)


train()

epoch: 99 test_result: 43.8
epoch: 199 test_result: 52.1
epoch: 299 test_result: 157.1
epoch: 399 test_result: 187.5
epoch: 499 test_result: 166.1
epoch: 599 test_result: 200.0
epoch: 699 test_result: 200.0
epoch: 799 test_result: 153.2
epoch: 899 test_result: 108.9
epoch: 999 test_result: 196.8
