# In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import gym

class PolicyNetwork(nn.Module):
    """Policy head: maps states to action probabilities.

    Architecture: ``state_dim -> 64 -> 64 -> action_dim`` fully connected
    layers with ReLU after the two hidden layers. The final layer's output is
    passed through a softmax over the last dimension, so each output row is a
    probability vector over the ``action_dim`` actions.
    """

    def __init__(self, state_dim, action_dim):
        super().__init__()
        width = 64
        # Attribute names fc1..fc3 are the state_dict keys; creation order
        # fixes the order in which parameters are initialised.
        self.fc1 = nn.Linear(state_dim, width)
        self.fc2 = nn.Linear(width, width)
        self.fc3 = nn.Linear(width, action_dim)

    def forward(self, x):
        """Return softmax action probabilities for input ``x``."""
        features = self.fc2(self.fc1(x).relu()).relu()
        return self.fc3(features).softmax(dim=-1)

class ValueNetwork(nn.Module):
    """Critic: maps each state to a single scalar value estimate.

    Architecture: ``state_dim -> 64 -> 64 -> 1`` fully connected layers with
    ReLU after the two hidden layers and no activation on the output.
    """

    def __init__(self, state_dim):
        super().__init__()
        width = 64
        # Keep the fc1..fc3 names (state_dict keys) and creation order.
        self.fc1 = nn.Linear(state_dim, width)
        self.fc2 = nn.Linear(width, width)
        self.fc3 = nn.Linear(width, 1)

    def forward(self, x):
        """Return an output of shape ``(..., 1)`` for input ``x``."""
        for hidden_layer in (self.fc1, self.fc2):
            x = hidden_layer(x).relu()
        return self.fc3(x)

class PPO:
    """Proximal Policy Optimization agent for discrete action spaces.

    Holds a trainable policy network, a copy ``policy_old`` that is used to
    sample actions and to provide the reference log-probabilities for the
    clipped probability ratio, and a separate value network with its own
    optimizer.

    Args:
        state_dim: Length of one observation vector.
        action_dim: Number of discrete actions.
        lr: Adam learning rate for both the policy and the value network.
        gamma: Discount factor.
        clip_epsilon: Clipping range for the probability ratio.
        c1: Weight applied to the value (MSE) loss.
        c2: Weight of the entropy bonus in the policy objective.
        k_epochs: Optimisation steps taken on each collected batch. With a
            single step ``policy`` still equals ``policy_old`` when the loss is
            evaluated, so the ratio is always 1 and clipping never engages.
        gae_lambda: GAE lambda; the default 1.0 gives the plain discounted
            advantage recursion.
    """

    def __init__(self, state_dim, action_dim, lr=3e-4, gamma=0.99, clip_epsilon=0.2, c1=0.5, c2=0.01,
                 k_epochs=4, gae_lambda=1.0):
        self.state_dim = state_dim
        self.action_dim = action_dim

        self.policy = PolicyNetwork(state_dim, action_dim)
        self.policy_old = PolicyNetwork(state_dim, action_dim)
        self.policy_old.load_state_dict(self.policy.state_dict())
        self.optimizer = optim.Adam(self.policy.parameters(), lr=lr)
        self.value_network = ValueNetwork(state_dim)
        self.value_optimizer = optim.Adam(self.value_network.parameters(), lr=lr)

        self.gamma = gamma
        self.clip_epsilon = clip_epsilon
        self.c1 = c1
        self.c2 = c2
        self.k_epochs = k_epochs
        self.gae_lambda = gae_lambda

    def select_action(self, state):
        """Sample an action from ``policy_old`` for a single observation.

        Args:
            state: Array-like of shape ``(state_dim,)`` (NumPy array, list or
                tuple of numbers). With gym>=0.26 pass ``env.reset()[0]``,
                not the whole ``(obs, info)`` tuple.

        Returns:
            The sampled action index (a NumPy integer).

        Raises:
            ValueError: If ``state`` is not numeric or has the wrong shape.
        """
        # Convert before inspecting, so a non-array input (e.g. gym's
        # (obs, info) tuple) is reported as a ValueError instead of failing on
        # a missing `.shape` attribute while building the message.
        try:
            state_arr = np.asarray(state, dtype=np.float32)
        except (TypeError, ValueError) as err:
            raise ValueError(
                f"Expected state to be array-like of shape ({self.state_dim},), "
                f"but got an object of type {type(state).__name__}"
            ) from err
        if state_arr.shape != (self.state_dim,):
            raise ValueError(
                f"Expected state to be of shape ({self.state_dim},), but got {state_arr.shape}"
            )

        with torch.no_grad():
            probs = self.policy_old(torch.as_tensor(state_arr).unsqueeze(0))[0]
        # Renormalise in float64 so float32 rounding cannot trip the
        # sum-to-one check in np.random.choice.
        probs = probs.numpy().astype(np.float64)
        probs /= probs.sum()
        return np.random.choice(len(probs), p=probs)

    def compute_advantages(self, rewards, values, next_values, dones):
        """Compute generalized advantage estimates for a time-ordered batch.

        All four arguments are equal-length sequences of numbers (floats,
        NumPy scalars or 0-d tensors). A truthy ``dones[i]`` marks an episode
        end: the next state's value is not bootstrapped and the running
        estimate restarts, so advantages do not leak across episodes.

        Returns:
            A list of float advantages, one per transition.
        """
        advantages = [0.0] * len(rewards)
        gae = 0.0
        for i in reversed(range(len(rewards))):
            not_done = 1.0 - float(dones[i])
            delta = float(rewards[i]) + self.gamma * float(next_values[i]) * not_done - float(values[i])
            gae = delta + self.gamma * self.gae_lambda * not_done * gae
            advantages[i] = gae
        return advantages

    def update(self, states, actions, rewards, next_states, dones):
        """Run ``k_epochs`` PPO optimisation steps on one batch of transitions.

        Args:
            states: Sequence of observations, each of shape ``(state_dim,)``.
            actions: Sequence of integer action indices.
            rewards: Sequence of scalar rewards.
            next_states: Sequence of successor observations.
            dones: Sequence of episode-end flags.

        An empty batch is a no-op. Afterwards ``policy_old`` is synchronised
        with ``policy``.
        """
        if len(rewards) == 0:
            return

        # np.asarray first: building a tensor from a list of ndarrays is slow.
        states = torch.as_tensor(np.asarray(states), dtype=torch.float32)
        next_states = torch.as_tensor(np.asarray(next_states), dtype=torch.float32)
        actions = torch.as_tensor(np.asarray(actions), dtype=torch.long)
        rewards = torch.as_tensor(np.asarray(rewards, dtype=np.float32))
        dones = torch.as_tensor(np.asarray(dones, dtype=np.float32))

        # Targets and reference quantities are fixed for this batch: compute
        # them once, without building an autograd graph.
        with torch.no_grad():
            # squeeze(-1), not squeeze(): stays 1-D for a batch of one.
            values = self.value_network(states).squeeze(-1)
            next_values = self.value_network(next_states).squeeze(-1)
            advantages = torch.tensor(
                self.compute_advantages(rewards.tolist(), values.tolist(), next_values.tolist(), dones.tolist()),
                dtype=torch.float32,
            )
            # One-step TD target for the value network.
            returns = rewards + self.gamma * next_values * (1 - dones)
            old_log_probs = torch.distributions.Categorical(probs=self.policy_old(states)).log_prob(actions)

        for _ in range(self.k_epochs):
            # Policy step: clipped surrogate objective plus entropy bonus.
            # One forward pass serves both log-probs and entropy; Categorical
            # clamps probabilities so log(0) cannot produce NaN.
            dist = torch.distributions.Categorical(probs=self.policy(states))
            ratio = torch.exp(dist.log_prob(actions) - old_log_probs)
            surr1 = ratio * advantages
            surr2 = torch.clamp(ratio, 1 - self.clip_epsilon, 1 + self.clip_epsilon) * advantages
            policy_loss = -torch.min(surr1, surr2).mean()
            entropy = dist.entropy().mean()

            self.optimizer.zero_grad()
            (policy_loss - self.c2 * entropy).backward()
            self.optimizer.step()

            # Value step on its own graph and optimizer, so every backward()
            # traverses a fresh graph exactly once.
            value_loss = F.mse_loss(self.value_network(states).squeeze(-1), returns)
            self.value_optimizer.zero_grad()
            (self.c1 * value_loss).backward()
            self.value_optimizer.step()

        # Update old policy
        self.policy_old.load_state_dict(self.policy.state_dict())

if __name__ == "__main__":
    env = gym.make("CartPole-v1")
    ppo = PPO(state_dim=env.observation_space.shape[0], action_dim=env.action_space.n)

    num_episodes = 1000
    max_timesteps = 300  # per-episode cap on environment steps
    update_timestep = 2000  # run a PPO update every this many total steps
    timestep = 0

    # The rollout buffer spans episodes and is only cleared after an update;
    # resetting it at every episode start would discard most collected data.
    states, actions, rewards, next_states, dones = [], [], [], [], []

    for episode in range(num_episodes):
        # gym>=0.26 returns (obs, info) from reset(); older gym returns obs.
        reset_out = env.reset()
        state = reset_out[0] if isinstance(reset_out, tuple) else reset_out
        episode_rewards = 0
        episode_steps = 0
        done = False

        while not done:
            action = ppo.select_action(state)
            step_out = env.step(action)
            if len(step_out) == 5:
                # gym>=0.26: (obs, reward, terminated, truncated, info)
                next_state, reward, terminated, truncated, _ = step_out
                done = terminated or truncated
            else:
                next_state, reward, done, _ = step_out
            episode_steps += 1
            # Reaching the per-episode cap ends the episode like gym's own time
            # limit, which also keeps GAE from crossing into the next episode.
            done = done or episode_steps >= max_timesteps

            states.append(state)
            actions.append(action)
            rewards.append(reward)
            next_states.append(next_state)
            dones.append(done)

            state = next_state
            episode_rewards += reward
            timestep += 1

            if timestep % update_timestep == 0:
                ppo.update(states, actions, rewards, next_states, dones)
                states, actions, rewards, next_states, dones = [], [], [], [], []

        print(f"Episode {episode + 1}: Total Reward = {episode_rewards}")

    env.close()


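# Cell output: AttributeError: 'tuple' object has no attribute 'shape'
# Cause: with gym>=0.26, env.reset() returns an (obs, info) tuple, which reached select_action.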