In [1]:
import gymnasium as gym
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import rl_utils
import random

In [2]:
class Policy_net(torch.nn.Module):
    def __init__(self,state_dim,hidden_dim,action_dim):
        super(Policy_net,self).__init__()
        self.fc1=torch.nn.Linear(state_dim,hidden_dim)
        self.fc2=torch.nn.Linear(hidden_dim,action_dim)
        self.softmax=torch.nn.Softmax(dim=1)
    def forward(self,x):
        return self.softmax(self.fc2(self.fc1(x)))

In [3]:
class REINFORCE:
    def __init__(self,state_dim,hidden_dim,action_dim,learning_rate,gamma,device):
        self.state_dim=state_dim
        self.hidden_dim=hidden_dim
        self.action_dim=action_dim
        self.learning_rate=learning_rate
        self.gamma=gamma
        self.device=device
        self.policy_net=Policy_net(state_dim,hidden_dim,action_dim).to(self.device)
        self.optimizer=torch.optim.Adam(self.policy_net.parameters(),lr=self.learning_rate)
    def take_action(self,state):
        state=torch.tensor([state],dtype=torch.float).to(self.device)
        probs=self.policy_net(state)
        action_dist=torch.distributions.Categorical(probs)
        action=action_dist.sample()
        return action.item()
    def update(self,transition_dict):
        state_list=transition_dict["states"]
        reward_list=transition_dict["rewards"]
        action_list=transition_dict["actions"]
        G=0
        self.optimizer.zero_grad()
        for i in reversed(range(len(state_list))):
            reward=reward_list[i]
            action=torch.tensor([action_list[i]]).view(-1,1).to(self.device)
            state=torch.tensor([state_list[i]],dtype=torch.float).to(self.device)
            log_prob=torch.log(self.policy_net(state).gather(1,action))
            G=self.gamma*G+reward
            loss=-1*log_prob*G
            loss.backward()
        self.optimizer.step()
        

In [5]:
learning_rate=1e-3
num_episodes=2000
hidden_dim=128
gamma=0.98
env_name="CartPole-v1"
device=torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
env=gym.make(env_name)
state_dim=env.observation_space.shape[0]
action_dim=env.action_space.n
torch.manual_seed(0)
agent=REINFORCE(state_dim, hidden_dim, action_dim, learning_rate, gamma, device)
return_list=[]
for i in range(10):
    with tqdm(total=int(num_episodes/10),desc="第%d轮"%i) as pbar:
        for i_epoisode in range(int(num_episodes/10)):
            state=env.reset()[0]
            done=False
            transition_dict={
                "states":[],
                "rewards":[],
                "actions":[]
            }
            episode_return = 0
            while not done:
                action=agent.take_action(state)
                next_state,reward,terminated,truncated,_=env.step(action)
                
                transition_dict["states"].append(state)
                transition_dict["rewards"].append(reward)
                transition_dict["actions"].append(action)
                state=next_state
                episode_return+=reward
                
                done=terminated or truncated
            return_list.append(episode_return)
            agent.update(transition_dict)
            if((i_epoisode+1)%10==0):
                pbar.set_postfix({
                    'episode':'%d'%(num_episodes/10*i+i_epoisode+1),
                    'return':'%.3f'%np.mean(return_list[-10:])
                })
            pbar.update(1)
            
                                

第0轮: 100%|█████████████████████████████████| 200/200 [00:06<00:00, 30.83it/s, episode=200, return=77.300]
第1轮: 100%|████████████████████████████████| 200/200 [00:42<00:00,  4.74it/s, episode=400, return=298.800]
第2轮: 100%|████████████████████████████████| 200/200 [01:10<00:00,  2.86it/s, episode=600, return=340.100]
第3轮: 100%|████████████████████████████████| 200/200 [00:37<00:00,  5.40it/s, episode=800, return=163.600]
第4轮: 100%|███████████████████████████████| 200/200 [00:31<00:00,  6.31it/s, episode=1000, return=131.600]
第5轮: 100%|███████████████████████████████| 200/200 [00:44<00:00,  4.47it/s, episode=1200, return=277.700]
第6轮: 100%|███████████████████████████████| 200/200 [01:06<00:00,  2.99it/s, episode=1400, return=165.000]
第7轮: 100%|███████████████████████████████| 200/200 [00:44<00:00,  4.53it/s, episode=1600, return=135.900]
第8轮: 100%|███████████████████████████████| 200/200 [01:10<00:00,  2.83it/s, episode=1800, return=495.200]
第9轮: 100%|███████████████████████████████| 200

In [8]:
def sample_expert_data(n_episode):
    states = []
    actions = []
    for episode in range(n_episode):
        state = env.reset()[0]
        done = False
        while not done:
            action = agent.take_action(state)
            states.append(state)
            actions.append(action)
            next_state, reward, done, _,_ = env.step(action)
            state = next_state
    return np.array(states), np.array(actions)



torch.manual_seed(0)
random.seed(0)
n_episode = 1
expert_s, expert_a = sample_expert_data(n_episode)
data = np.column_stack((expert_s, expert_a))
# 获取合并后的数据的形状
data_shape = data.shape

# 保存合并后的数据及其形状为文本文件
np.savetxt('expert_data.txt', data, delimiter=',', header=str(data_shape), comments='')

In [14]:
def test_agent(agent,state,n_episodes):
    return_list=[]
    for episode in range(n_episodes):
        state=env.reset(seed=0)[0]
        done=False
        episode_return=0
        while not done:
            action=agent.take_action(state)
            next_state,reward,terminated,truncated,_=env.step(action)
            done=terminated or truncated
            state=next_state
            episode_return+=reward
        return_list.append(episode_return)
    return return_list
test_returns=[]
with tqdm(total=100,desc="进度条") as pbar:
    for i in range(1000):
        current_return = test_agent(agent, env, 10)
        test_returns.append(current_return)
        if (i + 1) % 10 == 0:
            pbar.set_postfix({'return': '%.3f' % np.mean(test_returns[-10:])})
        pbar.update(1)

进度条: 1000it [36:18,  2.18s/it, return=484.000]                                                          
