In [2]:
import gymnasium as gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import gymnasium as gym
import matplotlib.pyplot as plt
from collections import deque

In [3]:

# Neural Network for the REINFORCE Agent
class ReinforceNetwork(nn.Module):
    def __init__(self, n_inputs, n_outputs):
        super(ReinforceNetwork, self).__init__()
        self.fc1 = nn.Linear(n_inputs, 16)
        self.fc2 = nn.Linear(16, 32)
        self.fc3 = nn.Linear(32, n_outputs)

    def forward(self, x):
        x = torch.tanh(self.fc1(x))
        x = torch.tanh(self.fc2(x))
        return torch.softmax(self.fc3(x), dim=-1)

# REINFORCE Agent
class ReinforceAgent:
    def __init__(self, model, learning_rate=0.005, gamma=0.99):
        self.model = model
        self.optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
        self.gamma = gamma
        self.saved_log_probs = []
        self.rewards = []

    def select_action(self, state):
        if isinstance(state, tuple):
            state = np.array(state)  # Convert tuple to np.ndarray
        state = torch.from_numpy(state).float().unsqueeze(0)
        probs = self.model(state)
        m = torch.distributions.Categorical(probs)
        action = m.sample()
        self.saved_log_probs.append(m.log_prob(action))
        return action.item()


    def finish_episode(self):
        R = 0
        policy_loss = []
        returns = []
        for r in self.rewards[::-1]:
            R = r + self.gamma * R
            returns.insert(0, R)
        
        returns = torch.tensor(returns)
        returns = (returns - returns.mean()) / (returns.std() + 1e-6)  # Normalize returns
        
        for log_prob, R in zip(self.saved_log_probs, returns):
            policy_loss.append(-log_prob * R)

        self.optimizer.zero_grad()
        policy_loss = torch.cat(policy_loss).sum()
        policy_loss.backward()
        self.optimizer.step()
        
        # Reset rewards and saved actions
        del self.rewards[:]
        del self.saved_log_probs[:]

# Hyperparameters
learning_rate = 0.005
gamma = 0.99
max_episodes = 5000
#max_episodes = 50

# Create the Lunar Lander environment
env = gym.make ("LunarLander-v2", continuous = False )

# Initialize the agent
n_inputs = env.observation_space.shape[0]
n_outputs = env.action_space.n
reinforce_net = ReinforceNetwork(n_inputs, n_outputs)
agent = ReinforceAgent(reinforce_net, learning_rate=learning_rate, gamma=gamma)

# Lists to store rewards and loss for plotting
episode_rewards = []
mean_rewards = []

# Training loop
for episode in range(max_episodes):
    state = env.reset()
    ep_reward = 0
    done = False

    while not done:
        # Asumiendo que 'state' es una tupla (estado, info)
        actual_state = state[0] if isinstance(state, tuple) else state
        action = agent.select_action(actual_state)
        
        #next_state, reward, done, _ = env.step(action)
        next_state = env.step(action)
        reward = next_state[1]
        done = next_state[2]
        agent.rewards.append(reward)
        ep_reward += reward
        state = next_state

        if action == "0":
            print("action: do nothing")
        elif action == "1":
            print("action: left")
        elif action == "1":
            print("action: main")
        else:
            print("action: right")

        print("reward: ", reward)
        print("done: ", done)
        print("episode reward: ", ep_reward)

    # Update model
    agent.finish_episode()
    episode_rewards.append(ep_reward)

    # Calculate mean reward for the last 100 episodes
    if episode >= 100:
        mean_reward = np.mean(episode_rewards[-100:])
        mean_rewards.append(mean_reward)
    else:
        mean_rewards.append(np.mean(episode_rewards))

    # Print progress
    if episode % 100 == 0:
        print(f"Episode {episode}, Average Reward: {mean_rewards[-1]}")

# Close the environment
env.close()


action:  3
action:  1
action:  1
action:  3
action:  1
action:  1
action:  1
action:  0
action:  1
action:  0
action:  2
action:  3
action:  0
action:  3
action:  3
action:  0
action:  2
action:  2
action:  3
action:  3
action:  3
action:  2
action:  3
action:  1
action:  3
action:  3
action:  0
action:  3
action:  3
action:  3
action:  2
action:  1
action:  3
action:  0
action:  2
action:  2
action:  1
action:  3
action:  3
action:  1
action:  0
action:  2
action:  0
action:  1
action:  2
action:  1
action:  2
action:  0
action:  1
action:  3
action:  1
action:  3
action:  3
action:  3
action:  3
action:  0
action:  3
action:  1
action:  1
action:  1
action:  0
action:  1
action:  3
action:  1
action:  1
action:  2
action:  1
action:  3
action:  3
action:  2
action:  1
action:  0
action:  2
action:  2
action:  2
action:  3
action:  0
action:  1
action:  1
action:  0
action:  2
action:  2
action:  2
action:  2
action:  1
action:  3
action:  0
action:  1
action:  1
action:  2
action:  2