In [1]:
import gymnasium as gym
import torch
from tqdm import tqdm

In [2]:
import numpy as np

# 读取数据文件
with open('expert_data.txt', 'r') as file:
    header = file.readline()
    data_lines = file.readlines()

# 解析形状信息
data_shape = tuple(map(int, header.strip().replace('(','').replace(')','').split(',')))
data_shape[1]
loaded_data = np.zeros(data_shape)
for i, line in enumerate(data_lines):
    loaded_data[i] = np.fromstring(line, sep=',')

# 将加载的数据分隔为 expert_s 和 expert_a 数组
expert_s = loaded_data[:, :-1]
expert_a = loaded_data[:, -1:]

In [3]:
# class PolicyNet(torch.nn.Module):
#     def __init__(self,state_dim,hidden_dim,action_dim):
#         super(PolicyNet,self).__init__()
#         self.fc1=torch.nn.Linear(state_dim,hidden_dim)
#         self.relu=torch.nn.ReLU()
#         self.fc2=torch.nn.Linear(hidden_dim,action_dim)
#         self.softmax=torch.nn.Softmax(dim=1)
#     def forward(self,x):
#         x=self.fc1(x)
#         x=self.relu(x)
#         x=self.fc2(x)
#         x=self.softmax(x)
#         return x
class Policy_net(torch.nn.Module):
    def __init__(self,state_dim,hidden_dim,action_dim):
        super(Policy_net,self).__init__()
        self.fc1=torch.nn.Linear(state_dim,hidden_dim)
        self.fc2=torch.nn.Linear(hidden_dim,action_dim)
        self.softmax=torch.nn.Softmax(dim=1)
    def forward(self,x):
        return self.softmax(self.fc2(self.fc1(x)))

In [4]:
class REINFORCE:
    def __init__(self,state_dim,hidden_dim,action_dim,gamma,lr,device):
        self.policy_net=Policy_net(state_dim,hidden_dim,action_dim).to(device)
        self.policy_optimizer=torch.optim.Adam(self.policy_net.parameters(),lr)
        self.gamma=gamma
        self.device=device
    def take_action(self,state):
        state=torch.tensor([state],dtype=torch.float).to(device)
        probs=self.policy_net(state)
        action_dist=torch.distributions.Categorical(probs)
        action=action_dist.sample()
        return action.item()
    def update(self,transition_dict):
        state_list=transition_dict["states"]
        reward_list=transition_dict["rewards"]
        action_list=transition_dict["actions"]
        G=0
        self.policy_optimizer.zero_grad()
        for i in reversed(range(len(state_list))):
            reward=torch.tensor(reward_list[i]).view(-1,1).to(self.device)
            action=torch.tensor([action_list[i]]).view(-1,1).to(self.device)
            state=torch.tensor([state_list[i]],dtype=torch.float).to(self.device)
            log_prob=torch.log(self.policy_net(state).gather(1,action))
            G=self.gamma*G+reward
            loss=-1*log_prob*G
            loss.backward()
        self.policy_optimizer.step()
        

In [5]:
class Discriminator(torch.nn.Module):
    def __init__(self,state_dim,hidden_dim,action_dim):
        super(Discriminator,self).__init__()
        self.fc1=torch.nn.Linear(state_dim+action_dim,hidden_dim)
        self.relu=torch.nn.ReLU()
        self.fc2=torch.nn.Linear(hidden_dim,1)
        self.sigmoid=torch.nn.Sigmoid()
    def forward(self,state,action):
        x=torch.cat([state,action],dim=1)
        x=self.fc1(x)
        x=self.relu(x)
        x=self.fc2(x)
        x=self.sigmoid(x)
        return x

In [6]:
class GAIL:
    def __init__(self,agent,state_dim,hidden_dim,action_dim,lr,device):
        self.agent=agent
        self.discriminator=Discriminator(state_dim,hidden_dim,action_dim).to(device)
        self.discriminator_optimizer=torch.optim.Adam(self.discriminator.parameters(),lr)
    def learn(self,expert_s, expert_a, agent_s, agent_a, next_s, dones):
        expert_states = torch.tensor(expert_s, dtype=torch.float).to(device)
        expert_actions = torch.tensor(expert_a,dtype=torch.int64).view(-1).to(device)
        agent_states = torch.tensor(agent_s, dtype=torch.float).to(device)
        agent_actions = torch.tensor(agent_a).to(device)
        expert_actions=torch.nn.functional.one_hot(expert_actions,num_classes=2).float()
        agent_actions=torch.nn.functional.one_hot(agent_actions,num_classes=2).float()
        expert_prob = self.discriminator(expert_states, expert_actions)
        agent_prob = self.discriminator(agent_states, agent_actions)
        loss=torch.nn.BCELoss()
        discriminator_loss=loss(expert_prob,torch.ones_like(expert_prob))+loss(agent_prob,torch.zeros_like(agent_prob))
        self.discriminator_optimizer.zero_grad()
        discriminator_loss.backward()
        self.discriminator_optimizer.step()
        rewards=agent_prob.detach().cpu().numpy()
        transition_dict = {
            'states': agent_s,
            'actions': agent_a,
            'rewards': rewards,
            'next_states': next_s,
            'dones': dones
        }
        self.agent.update(transition_dict)

In [9]:
env_name="CartPole-v1"
env=gym.make(env_name)
torch.manual_seed(0)
lr_d = 1e-3
lr=1e-3
state_dim=env.observation_space.shape[0]
action_dim=env.action_space.n
hidden_dim=128
gamma=0.98
device=torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
agent=REINFORCE(state_dim, hidden_dim, action_dim, gamma, lr,device=device)
gail=GAIL(agent, state_dim, hidden_dim, action_dim, lr_d,device)
n_episode = 2000
return_list = []

In [10]:
with tqdm(total=n_episode, desc="进度条") as pbar:
    for i in range(n_episode):
        episode_return = 0
        state = env.reset()[0]
        done = False
        state_list = []
        action_list = []
        next_state_list = []
        done_list = []
        while not done:
            action = agent.take_action(state)
            next_state,reward,terminated,truncated,_=env.step(action)
            done=terminated or truncated
            state_list.append(state)
            action_list.append(action)
            next_state_list.append(next_state)
            done_list.append(done)
            state = next_state
            episode_return += reward
        return_list.append(episode_return)
        gail.learn(expert_s, expert_a, state_list, action_list,
                   next_state_list, done_list)
        if (i + 1) % 10 == 0:
            pbar.set_postfix({'return': '%.3f' % np.mean(return_list[-10:])})
        pbar.update(1)

进度条: 100%|██████████████████████████████████████████| 2000/2000 [08:28<00:00,  3.93it/s, return=500.000]
