In [3]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import gym
import matplotlib.pyplot as plt
from torch.distributions import Categorical
import torch.nn.functional as F

class MLP(nn.Module):
    def __init__(self, input_size, output_size):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(input_size, 64)
        self.fc2 = nn.Linear(64, 32)
        self.fc3 = nn.Linear(32, output_size)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

def select_action(state, policy_net):
    state = torch.from_numpy(state).float().unsqueeze(0)
    probs = policy_net(state)
    probs = F.softmax(probs, dim=-1)  # 添加这一行来归一化概率值
    m = Categorical(probs)
    action = m.sample()
    return action.item(), m.log_prob(action)


def train(env, policy_net, optimizer, num_episodes=1000):
    rewards = []
    for episode in range(num_episodes):
        state = env.reset()
        total_reward = 0
        done = False
        while not done:
            action, log_prob = select_action(state, policy_net)
            next_state, reward, done, _ = env.step(action)
            optimizer.zero_grad()
            reward += 10 * (-log_prob)
            reward.backward()
            optimizer.step()
            total_reward += reward
            state = next_state
        rewards.append(total_reward)
        if episode % 100 == 0:
            print(f"Episode {episode}, Total Reward: {total_reward}")
    return rewards

def plot_rewards(rewards):
    plt.plot(rewards)
    plt.xlabel("Episode")
    plt.ylabel("Total Reward")
    plt.title("Rewards over Time")
    plt.show()

if __name__ == "__main__":
    env = gym.make("Ant-v2")
    policy_net = MLP(env.observation_space.shape[0], env.action_space.shape[0])
    optimizer = optim.Adam(policy_net.parameters(), lr=0.001)
    rewards = train(env, policy_net, optimizer)



Episode 0, Total Reward: tensor([11845.5605], grad_fn=<AddBackward0>)


KeyboardInterrupt: 