# In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt

from environment_static import Env

# Seed NumPy and PyTorch up front so that the Xavier weight initialisation
# and the multinomial action sampling below draw from fixed random streams.
# NOTE(review): full reproducibility also depends on Env being deterministic
# (or using one of these seeded generators) -- confirm against environment_static.
np.random.seed(0)
torch.manual_seed(0)

# Build the environment once at module level; the sizes derived from it are
# read as globals when the ActorCritic layers are constructed.
env = Env()
# Length of the state returned by reset(); used as the network input width.
state_dim = len(env.reset())
# Number of entries in env.action_space; used as the actor head output width.
action_dim = len(env.action_space)

# Actor-critic network architecture
class ActorCritic(nn.Module):
    """Shared-trunk MLP with a policy head (actor) and a value head (critic).

    Two ReLU hidden layers feed both heads. The actor head emits one
    unnormalised logit per action; the critic head emits one scalar value.

    Args:
        input_dim: Width of a state vector. When omitted, the module-level
            ``state_dim`` is used.
        output_dim: Number of discrete actions. When omitted, the
            module-level ``action_dim`` is used.
        hidden_dim: Width of both hidden layers.
    """

    def __init__(self, input_dim=None, output_dim=None, hidden_dim=128):
        super().__init__()
        # Fall back to the sizes derived from the environment at module level
        # so that a bare ``ActorCritic()`` keeps working.
        if input_dim is None:
            input_dim = state_dim
        if output_dim is None:
            output_dim = action_dim

        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.actor = nn.Linear(hidden_dim, output_dim)
        self.critic = nn.Linear(hidden_dim, 1)

        # Re-initialise every Linear layer created above.
        self.apply(self.init_weights)

    def init_weights(self, m):
        """Xavier-uniform weights and zero biases for ``nn.Linear`` modules.

        Any other module type passed in is left untouched.
        """
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            nn.init.constant_(m.bias, 0.0)

    def forward(self, state):
        """Return ``(logits, value)`` for ``state``.

        ``logits`` has ``output_dim`` entries in its last dimension and
        ``value`` has one.
        """
        x = torch.relu(self.fc1(state))
        x = torch.relu(self.fc2(x))
        logits = self.actor(x)
        value = self.critic(x)
        return logits, value

# Proximal Policy Optimization (PPO) algorithm
class PPO:
    """Proximal Policy Optimization agent with a clipped surrogate objective.

    Transitions are collected with the current policy and used for a few
    gradient epochs, then discarded. An update is triggered when the rollout
    buffer reaches ``update_timestep`` transitions, when the episode
    terminates, or when the per-episode step limit is hit.

    Args:
        policy: Optional actor-critic module returning ``(logits, value)``.
            When omitted, a default ``ActorCritic()`` is created.

    Attributes:
        gamma: Discount factor.
        eps_clip: Half-width of the probability-ratio clipping interval.
        update_timestep: Buffer size that triggers a policy update.
        k_epochs: Gradient steps taken on each collected batch.
        max_steps: Maximum number of environment steps per episode.
    """

    def __init__(self, policy=None):
        self.policy = ActorCritic() if policy is None else policy
        self.learning_rate = 0.0005
        self.optimizer = optim.Adam(self.policy.parameters(), lr=self.learning_rate)
        self.gamma = 0.99
        self.eps_clip = 0.2
        self.update_timestep = 20
        self.k_epochs = 4
        self.max_steps = 1000

    @staticmethod
    def _as_float_tensor(data):
        """Convert a tensor or array-like (e.g. a list of arrays) to float32."""
        if isinstance(data, torch.Tensor):
            return data.detach().to(torch.float32)
        # Going through one NumPy array avoids the slow element-by-element
        # conversion torch performs on a list of arrays.
        return torch.as_tensor(np.asarray(data, dtype=np.float32))

    @staticmethod
    def _bootstrap_value(next_value):
        """Reduce ``next_value`` to the scalar used to bootstrap a rollout.

        Accepts ``None`` (no bootstrap), a number, or a tensor. For a tensor
        with several entries the last one is used, i.e. the value of the
        state that follows the final transition.
        """
        if next_value is None:
            return 0.0
        flat = torch.as_tensor(next_value).detach().reshape(-1)
        return float(flat[-1]) if flat.numel() else 0.0

    def _discounted_returns(self, rewards, dones, bootstrap=0.0):
        """Return the discounted return of every step, oldest first.

        The running sum starts from ``bootstrap`` and is reset wherever
        ``dones`` marks a terminal transition.
        """
        running = float(bootstrap)
        returns = []
        for reward, done in zip(reversed(rewards), reversed(dones)):
            running = reward + self.gamma * running * (1 - done)
            returns.append(running)
        returns.reverse()
        return returns

    def choose_action(self, state):
        """Sample an action index from the current policy for ``state``."""
        state = self._as_float_tensor(state).unsqueeze(0)
        # Sampling needs no gradients; skip building the autograd graph.
        with torch.no_grad():
            logits, _ = self.policy(state)
        action_probs = torch.softmax(logits, dim=-1)
        action = torch.multinomial(action_probs, 1)
        return action.item()

    def calculate_advantages(self, rewards, dones, values, next_value=None):
        """Return ``discounted_return - value`` for every step of a rollout.

        Args:
            rewards: Per-step rewards, oldest first.
            dones: Per-step terminal flags (bool or 0/1).
            values: Per-step critic estimates; each entry must be convertible
                with ``float()``. Entries beyond ``len(rewards)`` are ignored.
            next_value: Critic estimate for the state after the last step,
                used to bootstrap a rollout that stopped mid-episode. See
                ``_bootstrap_value`` for the accepted forms.

        Returns:
            A list of Python floats, one per step.
        """
        returns = self._discounted_returns(
            rewards, dones, self._bootstrap_value(next_value)
        )
        return [ret - float(value) for ret, value in zip(returns, values)]

    def update_policy(self, states, actions, advantages, returns, old_log_probs=None):
        """Run ``k_epochs`` clipped-surrogate gradient steps on one batch.

        Args:
            states: Batch of states, one row per transition.
            actions: Action index taken in each state.
            advantages: Advantage estimate per transition.
            returns: Critic regression target per transition.
            old_log_probs: Log-probability of each action under the policy
                that collected the batch. When omitted, it is snapshotted
                from the current policy before the first gradient step.
        """
        if len(states) == 0:
            return

        states = self._as_float_tensor(states)
        actions = torch.as_tensor(actions, dtype=torch.long).reshape(-1, 1)
        advantages = self._as_float_tensor(advantages).reshape(-1)
        returns = self._as_float_tensor(returns).reshape(-1)

        if old_log_probs is None:
            with torch.no_grad():
                logits, _ = self.policy(states)
                old_log_probs = (
                    torch.log_softmax(logits, dim=-1).gather(1, actions).squeeze(1)
                )
        else:
            old_log_probs = self._as_float_tensor(old_log_probs).reshape(-1)

        for _ in range(self.k_epochs):
            logits, values = self.policy(states)
            log_probs = torch.log_softmax(logits, dim=-1).gather(1, actions).squeeze(1)

            # Probability ratio new / old of the actions actually taken.
            ratios = torch.exp(log_probs - old_log_probs)
            surr1 = ratios * advantages
            surr2 = torch.clamp(ratios, 1 - self.eps_clip, 1 + self.eps_clip) * advantages
            actor_loss = -torch.min(surr1, surr2).mean()

            # reshape(-1) keeps a batch of one as a 1-D tensor.
            critic_loss = nn.MSELoss()(values.reshape(-1), returns)

            loss = actor_loss + 0.5 * critic_loss

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

    def _update_from_rollout(self, states, actions, rewards, dones, last_next_state):
        """Compute returns and advantages for the buffered rollout and update."""
        with torch.no_grad():
            _, values = self.policy(self._as_float_tensor(states))
            if dones[-1]:
                # Nothing follows a terminal transition; no bootstrap needed.
                next_value = None
            else:
                _, next_value = self.policy(
                    self._as_float_tensor(last_next_state).unsqueeze(0)
                )

        returns = self._discounted_returns(
            rewards, dones, self._bootstrap_value(next_value)
        )
        advantages = self.calculate_advantages(rewards, dones, values, next_value)
        self.update_policy(states, actions, advantages, returns)

    def train(self, env, num_episodes):
        """Train on ``env`` for ``num_episodes`` episodes.

        ``env.reset()`` must return the initial state and ``env.step(action)``
        must return ``(next_state, reward, done)``.

        Returns:
            A tuple ``(episode_rewards, success_count, success_steps)``.
            ``episode_rewards`` holds the total reward of every episode.
            ``success_steps`` holds, per episode, the number of steps taken
            if the episode was successful and 0 otherwise.
        """
        total_timesteps = 0
        episode_rewards = []
        success_count = 0
        success_steps = []

        for episode in range(num_episodes):
            states, actions, rewards, dones = [], [], [], []
            episode_reward = 0
            episode_steps = 0
            done = False
            state = env.reset()

            while episode_steps < self.max_steps and not done:
                action = self.choose_action(state)
                next_state, reward, done = env.step(action)
                done = bool(done)

                states.append(state)
                actions.append(action)
                rewards.append(reward)
                dones.append(done)

                episode_reward += reward
                state = next_state
                episode_steps += 1
                total_timesteps += 1

                # Flush on a full buffer and also when the episode ends or is
                # cut off, so no collected transition is thrown away.
                if (
                    done
                    or len(states) >= self.update_timestep
                    or episode_steps >= self.max_steps
                ):
                    self._update_from_rollout(states, actions, rewards, dones, next_state)
                    # The policy just changed, so the buffered data is stale.
                    for buffer in (states, actions, rewards, dones):
                        buffer.clear()

            # An episode counts as successful when it terminated (rather than
            # hitting the step limit) with a non-negative total reward.
            succeeded = done and episode_reward >= 0
            if succeeded:
                success_count += 1
            success_steps.append(episode_steps if succeeded else 0)

            episode_rewards.append(episode_reward)
            print(f"Episode: {episode + 1}, Total Timesteps: {total_timesteps}, Episode Reward: {episode_reward}")

        return episode_rewards, success_count, success_steps

class Analyzer:
    """Static helpers for plotting and summarising training results."""

    @staticmethod
    def _plot_series(values, ylabel, title):
        """Draw ``values`` as a line plot against the episode index and show it."""
        plt.plot(values)
        plt.xlabel('Episode')
        plt.ylabel(ylabel)
        plt.title(title)
        plt.show()

    @staticmethod
    def plot_total_rewards(episode_rewards):
        """Plot the total reward obtained in each episode."""
        Analyzer._plot_series(
            episode_rewards, 'Total Rewards', 'Total Rewards Obtained per Episode'
        )

    @staticmethod
    def calculate_success_rate(success_count, n_episodes):
        """Return ``success_count / n_episodes``, or 0.0 when no episode was run."""
        # Avoid ZeroDivisionError when summarising an empty run.
        if n_episodes == 0:
            return 0.0
        return success_count / n_episodes

    @staticmethod
    def plot_steps_to_success(success_steps):
        """Plot the step counts recorded in ``success_steps``."""
        Analyzer._plot_series(
            success_steps,
            'Steps to Success',
            'Total Steps Needed to Reach Successful Episodes',
        )

if __name__ == "__main__":
    # Length of the training run and the agent that will be trained.
    num_episodes = 1000
    ppo_agent = PPO()

    # Run training on the module-level environment and keep its three results.
    episode_rewards, success_count, success_steps = ppo_agent.train(env, num_episodes)

    # Reward curve over all episodes.
    Analyzer.plot_total_rewards(episode_rewards)

    # Fraction of episodes that were counted as successful.
    success_rate = Analyzer.calculate_success_rate(success_count, num_episodes)
    print("Success rate:", success_rate)

    # The step plot is only drawn when at least one episode succeeded.
    if success_count <= 0:
        print("No successful episodes.")
    else:
        Analyzer.plot_steps_to_success(success_steps)
