<a href="https://colab.research.google.com/github/OneFineStarstuff/TheOneEverAfter/blob/main/Proximal_Policy_Optimization_(PPO).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import gym
import numpy as np
from torch.distributions import Categorical
from torch.utils.tensorboard import SummaryWriter

# Define the actor-critic network architecture
class PPOActorCritic(nn.Module):
    """Actor-critic network with one shared hidden layer and two heads.

    The shared layer ``fc1`` feeds both the ``actor`` head (one output per
    action) and the ``critic`` head (a single scalar output).

    Args:
        state_dim: Number of input features per state.
        action_dim: Number of outputs of the actor head.
    """

    def __init__(self, state_dim, action_dim):
        super().__init__()
        hidden_units = 256
        # Layers are created in the same order as before so that seeded
        # initialisation and state_dict key order stay the same.
        self.fc1 = nn.Linear(state_dim, hidden_units)
        self.actor = nn.Linear(hidden_units, action_dim)
        self.critic = nn.Linear(hidden_units, 1)

    def forward(self, state):
        """Return ``(policy, value)`` for ``state``.

        ``policy`` is the softmax over the last dimension of the actor head
        output; ``value`` is the raw critic head output (trailing size 1).
        """
        hidden = self.fc1(state).relu()
        action_scores = self.actor(hidden)
        return action_scores.softmax(dim=-1), self.critic(hidden)

# Define the PPO algorithm
class PPO:
    """Clipped-surrogate PPO agent for discrete action spaces.

    Two copies of the actor-critic network are kept: ``policy`` is the one
    being optimised, ``policy_old`` is the snapshot used to sample actions.
    The snapshot is re-synchronised at the end of every ``update`` call.

    The class reads a module-level ``device`` (a ``torch.device``), which must
    be defined before an instance is created.

    Args:
        state_dim: Number of features in a state vector.
        action_dim: Number of discrete actions.
        lr: Learning rate passed to Adam.
        gamma: Discount factor used when accumulating returns.
        eps_clip: Half-width of the clipping interval around a ratio of 1.
        K_epochs: Number of gradient steps taken per ``update`` call.
    """

    def __init__(self, state_dim, action_dim, lr=3e-4, gamma=0.99, eps_clip=0.2, K_epochs=4):
        self.policy = PPOActorCritic(state_dim, action_dim).to(device)
        self.optimizer = optim.Adam(self.policy.parameters(), lr=lr)
        self.gamma = gamma
        self.eps_clip = eps_clip
        self.K_epochs = K_epochs
        # Behaviour policy: starts as an exact copy of the trainable policy.
        self.policy_old = PPOActorCritic(state_dim, action_dim).to(device)
        self.policy_old.load_state_dict(self.policy.state_dict())
        self.MseLoss = nn.MSELoss()

    def select_action(self, state):
        """Sample an action from the snapshot policy.

        Args:
            state: A single state (array-like of floats).

        Returns:
            ``(action, logprob)`` where ``action`` is a Python int and
            ``logprob`` is a tensor holding the log-probability of that
            action under ``policy_old``.
        """
        state = torch.as_tensor(state, dtype=torch.float32, device=device)
        with torch.no_grad():
            policy, _ = self.policy_old(state)
        action_prob = Categorical(policy)
        action = action_prob.sample()
        return action.item(), action_prob.log_prob(action)

    def evaluate(self, state, action):
        """Evaluate ``action`` under the trainable policy.

        Returns:
            ``(log_probs, state_values, entropy)``. Only the trailing size-1
            critic dimension is removed from ``state_values``, so a batch
            containing a single row keeps its batch axis and lines up with the
            returns tensor built in ``update``.
        """
        policy, value = self.policy(state)
        action_prob = Categorical(policy)
        action_logprobs = action_prob.log_prob(action)
        dist_entropy = action_prob.entropy()
        return action_logprobs, value.squeeze(-1), dist_entropy

    def update(self, memory):
        """Run ``K_epochs`` optimisation steps on the transitions in ``memory``.

        Args:
            memory: Object exposing equally long ``states``, ``actions``,
                ``logprobs``, ``rewards`` and ``dones`` sequences.

        Does nothing when ``memory`` holds no transitions.
        """
        if not memory.rewards:
            # An empty batch cannot be fed through the network.
            return

        # Stack into one ndarray first: building a tensor directly from a
        # list of ndarrays goes through a slow element-by-element path.
        states = torch.as_tensor(
            np.asarray(memory.states), dtype=torch.float32, device=device
        )
        actions = torch.as_tensor(memory.actions, dtype=torch.long, device=device)
        old_logprobs = torch.as_tensor(
            memory.logprobs, dtype=torch.float32, device=device
        )

        # Discounted returns, accumulated back-to-front in plain Python floats
        # and restarted whenever a transition is flagged as done.
        returns = []
        discounted_return = 0.0
        for reward, done in zip(reversed(memory.rewards), reversed(memory.dones)):
            if done:
                discounted_return = 0.0
            discounted_return = float(reward) + self.gamma * discounted_return
            returns.append(discounted_return)
        returns.reverse()

        returns = torch.tensor(returns, dtype=torch.float32, device=device)
        # The sample standard deviation of a single value is NaN, which would
        # propagate through the loss into the weights; skip normalisation then.
        if returns.numel() > 1:
            returns = (returns - returns.mean()) / (returns.std() + 1e-5)

        for _ in range(self.K_epochs):
            logprobs_new, state_values, dist_entropy = self.evaluate(states, actions)

            # Probability ratio between the current and the snapshot policy.
            ratios = torch.exp(logprobs_new - old_logprobs)
            advantages = returns - state_values.detach()

            surr1 = ratios * advantages
            surr2 = torch.clamp(ratios, 1 - self.eps_clip, 1 + self.eps_clip) * advantages
            # Clipped surrogate + value regression - entropy bonus.
            loss = (
                -torch.min(surr1, surr2)
                + 0.5 * self.MseLoss(state_values, returns)
                - 0.01 * dist_entropy
            )

            self.optimizer.zero_grad()
            loss.mean().backward()
            self.optimizer.step()

        # Sync the snapshot so the next rollout samples from the new weights.
        self.policy_old.load_state_dict(self.policy.state_dict())

# Memory class to store experiences
class Memory:
    """Rollout buffer: five parallel lists, one entry per stored transition."""

    def __init__(self):
        self.states, self.actions, self.logprobs = [], [], []
        self.rewards, self.dones = [], []

    def clear_memory(self):
        """Empty every buffer in place (the list objects themselves are kept)."""
        buffers = (
            self.states,
            self.actions,
            self.logprobs,
            self.rewards,
            self.dones,
        )
        for buffer in buffers:
            buffer.clear()

# Training function for PPO
def train_ppo(env_name, max_episodes=1000):
    """Train a PPO agent on a Gym environment, updating once per episode.

    Works with both Gym API generations: the older one where ``reset()``
    returns the observation and ``step()`` a 4-tuple, and the newer one
    (Gym >= 0.26 / Gymnasium) where ``reset()`` returns ``(obs, info)`` and
    ``step()`` returns ``(obs, reward, terminated, truncated, info)``.

    Args:
        env_name: Environment id passed to ``gym.make``.
        max_episodes: Number of episodes to run.

    The running mean of episode rewards is printed and written to TensorBoard
    under the tag ``'Average Reward'`` after every episode.
    """
    env = gym.make(env_name)  # Create the environment
    writer = None
    try:
        state_dim = env.observation_space.shape[0]
        action_dim = env.action_space.n
        memory = Memory()
        ppo = PPO(state_dim, action_dim)

        # Initialize TensorBoard writer
        writer = SummaryWriter()

        running_reward = 0

        for episode in range(max_episodes):
            state = env.reset()
            # Newer API: reset() yields (observation, info); keep the observation.
            if isinstance(state, tuple) and len(state) == 2 and isinstance(state[1], dict):
                state = state[0]

            episode_reward = 0

            # Hard cap of 1000 steps per episode.
            for _ in range(1000):
                action, logprob = ppo.select_action(state)

                step_result = env.step(action)
                if len(step_result) == 5:
                    # Newer API splits episode end into terminated / truncated.
                    next_state, reward, terminated, truncated, _ = step_result
                    done = bool(terminated or truncated)
                else:
                    next_state, reward, done, _ = step_result

                memory.states.append(state)
                memory.actions.append(action)
                memory.logprobs.append(logprob.item())
                memory.rewards.append(reward)
                memory.dones.append(done)

                state = next_state
                episode_reward += reward

                if done:
                    break

            ppo.update(memory)
            memory.clear_memory()

            running_reward += episode_reward
            avg_reward = running_reward / (episode + 1)

            # Log average reward to TensorBoard
            writer.add_scalar('Average Reward', avg_reward, episode)

            print(f'Episode {episode}, Average Reward: {avg_reward}')
    finally:
        # Release the log file handle and the environment even if training fails.
        if writer is not None:
            writer.close()
        env.close()

# Pick the compute device: first CUDA GPU when one is available, CPU otherwise.
# `device` is a module-level global read by the PPO class.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Kick off PPO training on CartPole-v1 with the default episode budget.
train_ppo("CartPole-v1")