In [None]:
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Normal, TransformedDistribution
from torch.distributions.transforms import TanhTransform
import time
import matplotlib.pyplot as plt
import torch.nn.init as init

# Adjusted Hyperparameters
ENV_NAME = 'BipedalWalkerHardcore-v3'  # environment id passed to gym.make for both envs
HIDDEN_SIZE = 512          # width of every hidden Linear layer in ActorCritic
LEARNING_RATE = 1e-4       # initial Adam learning rate (scaled by the LambdaLR scheduler)
GAMMA = 0.99               # discount factor used in compute_gae
LAMBDA = 0.95              # GAE lambda used in compute_gae
CLIP_EPSILON = 0.2         # PPO ratio is clamped to [1 - eps, 1 + eps]
ENTROPY_COEF = 0.001       # weight of the entropy bonus subtracted from the loss
VALUE_LOSS_COEF = 0.5      # weight applied to the squared-error critic loss
MAX_GRAD_NORM = 0.5        # max norm passed to clip_grad_norm_
PPO_EPOCHS = 15            # passes over the collected batch per policy update
MINI_BATCH_SIZE = 256      # samples per gradient step within an epoch
TOTAL_EPISODES = 10000     # outer-loop episode budget; also the LR schedule horizon
ROLLOUT_LENGTH = 4096      # buffer size that forces a policy update (an update also runs at episode end)
EVAL_INTERVAL = 100       # Evaluate every 100 episodes
EVAL_EPISODES = 5          # episodes averaged per evaluation

# Early stopping parameters
patience = 10              # evaluations without improvement tolerated before stopping
min_delta = 1e-3           # minimum gain in average eval reward that counts as improvement

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Set random seeds for reproducibility
# NOTE(review): only the torch and numpy global RNGs are seeded here; the
# environments are seeded per episode through reset(seed=...).
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)

# Initialize the environment
# `env` collects training rollouts; `eval_env` is a second instance created
# with render_mode='human' and is only stepped by evaluate_policy.
env = gym.make(ENV_NAME)
eval_env = gym.make(ENV_NAME, render_mode='human')

# Seed the action-space samplers (distinct seeds for the two envs).
env.action_space.seed(seed)
eval_env.action_space.seed(seed + 1)

# First dimension of the observation / action space shapes; used to size the
# network and the observation normalizer.
obs_size = env.observation_space.shape[0]
action_size = env.action_space.shape[0]

# Running Mean and Std for observations
class RunningMeanStd:
    """Running per-feature mean and variance, merged batch by batch.

    Statistics are kept in float64. Batches are folded in with the parallel
    variance-combination formula, which works on *population* (biased)
    variances: ``M2 = var * count`` must be the sum of squared deviations.

    Args:
        epsilon: Initial pseudo-count, so the first merge never divides by zero.
        shape: Shape of a single observation (an int or a tuple).
    """

    def __init__(self, epsilon=1e-4, shape=()):
        self.mean = torch.zeros(shape, dtype=torch.float64).to(device)
        self.var = torch.ones(shape, dtype=torch.float64).to(device)
        self.count = epsilon

    def update(self, x):
        """Fold a batch ``x`` (samples along dim 0) into the running statistics."""
        x = x.to(torch.float64)
        batch_mean = torch.mean(x, dim=0)
        # Population variance is required by update_from_moments; the default
        # unbiased estimator over-weights small batches and is NaN for a
        # single-row batch.
        batch_var = torch.var(x, dim=0, unbiased=False)
        batch_count = x.size(0)
        self.update_from_moments(batch_mean, batch_var, batch_count)

    def update_from_moments(self, batch_mean, batch_var, batch_count):
        """Merge precomputed batch moments into the running mean/variance.

        Args:
            batch_mean: Per-feature mean of the batch.
            batch_var: Per-feature population variance of the batch.
            batch_count: Number of samples in the batch.
        """
        delta = batch_mean - self.mean
        total_count = self.count + batch_count

        new_mean = self.mean + delta * batch_count / total_count
        # Sums of squared deviations of the two groups ...
        m_a = self.var * self.count
        m_b = batch_var * batch_count
        # ... plus the correction for the shift between the two means.
        M2 = m_a + m_b + delta ** 2 * self.count * batch_count / total_count
        new_var = M2 / total_count

        self.mean = new_mean
        self.var = new_var
        self.count = total_count

    def normalize(self, x):
        """Return ``(x - mean) / (std + 1e-8)`` as a float64 tensor."""
        x = x.to(torch.float64)
        return (x - self.mean) / (torch.sqrt(self.var) + 1e-8)

# Module-level observation normalizer (one mean/variance entry per observation
# dimension); read by both the training loop and evaluate_policy.
obs_rms = RunningMeanStd(shape=obs_size)

# Define the Actor-Critic Network
class TanhNormal(TransformedDistribution):
    """Diagonal Normal pushed through a tanh, so samples lie in (-1, 1).

    Args:
        loc: Mean of the underlying Normal.
        scale: Standard deviation of the underlying Normal.
    """

    def __init__(self, loc, scale):
        gaussian = Normal(loc, scale)
        self.normal = gaussian
        # cache_size=1 lets log_prob reuse the pre-tanh value of the sample
        # that was just drawn instead of inverting the tanh.
        super().__init__(gaussian, [TanhTransform(cache_size=1)])
        self.loc, self.scale = loc, scale

    @property
    def mean(self):
        """Base-distribution mean pushed through every transform in order."""
        squashed = self.normal.mean
        for squash in self.transforms:
            squashed = squash(squashed)
        return squashed

    def entropy(self):
        """Entropy of the *base* Normal (the tanh correction is not applied)."""
        return self.base_dist.entropy()

class ActorCritic(nn.Module):
    """Shared-trunk actor-critic network for continuous actions.

    The actor outputs the mean of a tanh-squashed Normal with a learned,
    state-independent log standard deviation; the critic outputs a scalar
    state value.

    Args:
        obs_size: Number of observation features.
        action_size: Number of action dimensions.
        hidden_size: Width of the hidden layers. Defaults to the module-level
            ``HIDDEN_SIZE`` when omitted.
    """

    def __init__(self, obs_size, action_size, hidden_size=None):
        super().__init__()
        hidden_size = HIDDEN_SIZE if hidden_size is None else hidden_size
        # Common network
        self.shared = nn.Sequential(
            nn.Linear(obs_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
        )
        # Actor network
        self.actor_mean = nn.Linear(hidden_size, action_size)
        # Actor log_std (learned); zeros => initial std of 1 per action dim
        self.actor_log_std = nn.Parameter(torch.zeros(action_size))
        # Critic network
        self.critic = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 1)
        )
        # Initialize weights orthogonally. This must run after *every*
        # sub-module exists, otherwise the critic layers silently keep
        # PyTorch's default initialisation.
        for layer in self.modules():
            if isinstance(layer, nn.Linear):
                init.orthogonal_(layer.weight, gain=init.calculate_gain('relu'))
                init.zeros_(layer.bias)

    def forward(self, x):
        """Return ``(dist, value)`` for observation(s) ``x``.

        ``dist`` is a ``TanhNormal`` over actions; ``value`` is the critic
        output with a trailing dimension of size 1.
        """
        shared_out = self.shared(x)
        # Actor
        mean = self.actor_mean(shared_out)
        std = self.actor_log_std.exp().expand_as(mean)
        dist = TanhNormal(mean, std)
        # Critic
        value = self.critic(shared_out)
        return dist, value

# Initialize the network and optimizer
model = ActorCritic(obs_size, action_size).to(device)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Learning rate scheduler
from torch.optim.lr_scheduler import LambdaLR
# Linear decay from LEARNING_RATE towards zero over TOTAL_EPISODES scheduler
# steps. The factor is clamped at zero: without the clamp, stepping the
# scheduler more than TOTAL_EPISODES times yields a negative learning rate,
# which turns gradient descent into gradient ascent.
scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: max(0.0, 1 - epoch / TOTAL_EPISODES))

# Storage for rollouts
class RolloutBuffer:
    """Plain per-step storage for one rollout (parallel Python lists)."""

    def __init__(self):
        # A fresh buffer is simply an emptied one.
        self.clear()

    def clear(self):
        """Replace every per-step list with a new empty list."""
        self.obs, self.actions, self.log_probs = [], [], []
        self.rewards, self.dones, self.values = [], [], []

# Single module-level buffer: the training loop appends one entry per
# environment step and clears it at every policy update.
buffer = RolloutBuffer()

# Function to compute Generalized Advantage Estimation (GAE)
def compute_gae(next_value, rewards, dones, values, gamma=None, lam=None):
    """Compute GAE(lambda) return targets for one rollout.

    Args:
        next_value: Value estimate of the state following the last step.
        rewards: Per-step rewards.
        dones: Per-step done flags; a truthy flag stops both the bootstrap
            and the advantage accumulation at that step.
        values: Per-step value estimates, same length as ``rewards``.
        gamma: Discount factor. Defaults to the module-level ``GAMMA``.
        lam: GAE lambda. Defaults to the module-level ``LAMBDA``.

    Returns:
        A list of ``advantage + value`` targets, one per step, in time order.
        Subtracting ``values`` element-wise recovers the advantages.
    """
    gamma = GAMMA if gamma is None else gamma
    lam = LAMBDA if lam is None else lam

    # Append the bootstrap value so values[step + 1] is always defined.
    values = list(values) + [next_value]
    gae = 0
    returns = []
    for step in reversed(range(len(rewards))):
        not_done = 1 - dones[step]
        delta = rewards[step] + gamma * values[step + 1] * not_done - values[step]
        gae = delta + gamma * lam * not_done * gae
        # Appending and reversing once is O(n); insert(0, ...) is O(n^2).
        returns.append(gae + values[step])
    returns.reverse()
    return returns

# Function to evaluate the agent
def evaluate_policy(model, eval_env, episodes=5):
    """Run the deterministic policy for a few episodes and report the mean return.

    The action at every step is the distribution mean (no sampling).
    Observations are normalised with the module-level ``obs_rms`` before each
    forward pass, and each episode is seeded with ``seed + episode``.

    Args:
        model: Actor-critic network; called as ``model(state) -> (dist, value)``.
        eval_env: Environment to step; its ``reset`` and ``step`` are called.
        episodes: Number of evaluation episodes.

    Returns:
        The average undiscounted episode reward over ``episodes`` episodes.
    """
    model.eval()
    total_rewards = []
    try:
        for episode in range(episodes):
            state, info = eval_env.reset(seed=seed + episode)
            state = torch.FloatTensor(state).to(device)
            state = obs_rms.normalize(state).to(torch.float32)
            terminated = truncated = False
            episode_reward = 0
            while not (terminated or truncated):
                with torch.no_grad():
                    dist, _ = model(state)
                    action = dist.mean
                # Step the environment
                next_state, reward, terminated, truncated, _ = eval_env.step(action.detach().cpu().numpy())
                next_state = torch.FloatTensor(next_state).to(device)
                # Keep the *normalised* observation as the next policy input;
                # feeding the raw observation would not match training.
                state = obs_rms.normalize(next_state).to(torch.float32)
                episode_reward += reward
                # Short pause between steps (slows playback of the eval env).
                time.sleep(0.01)
            total_rewards.append(episode_reward)
    finally:
        # Always hand the model back in training mode, even if an episode fails.
        model.train()
    avg_reward = np.mean(total_rewards)
    print(f"Evaluation over {episodes} episodes: Average Reward = {avg_reward}")
    return avg_reward

# Initialize variables for early stopping and tracking
best_avg_reward = -np.inf        # best evaluation average seen so far
no_improvement_counter = 0       # consecutive evaluations without improvement
all_episode_rewards = []         # NOTE(review): not written to in the visible code
all_episode_lengths = []         # NOTE(review): not written to in the visible code
all_losses = []                  # mean total loss, one entry per policy update
all_actor_losses = []            # mean actor loss, one entry per policy update
all_critic_losses = []           # mean critic loss, one entry per policy update
all_entropies = []               # mean entropy, one entry per policy update
all_avg_rewards = []             # evaluation averages, one entry per evaluation
episode_rewards = []             # training reward per episode
episode_lengths = []             # training step count per episode
total_timesteps = 0              # environment steps taken across all training episodes
next_eval = EVAL_INTERVAL        # NOTE(review): not referenced in the visible code
episode_count = 0                # completed training episodes (outer-loop counter)

# Main PPO training loop.
#
# Each outer iteration plays one episode. A policy update runs when the
# episode ends or the buffer reaches ROLLOUT_LENGTH, after which the inner
# loop is left, so there is exactly one update per outer iteration.
while episode_count < TOTAL_EPISODES:
    # Reset the environment and get the initial state
    raw_state, info = env.reset(seed=seed + episode_count)
    raw_state = torch.FloatTensor(raw_state).to(device)
    state = obs_rms.normalize(raw_state).to(torch.float32)
    episode_reward = 0
    episode_length = 0
    terminated = truncated = False
    # Un-normalised observations of this rollout. The running statistics must
    # be updated from raw data, not from observations that were already
    # normalised. The buffer is always empty here because every exit from the
    # inner loop goes through an update + clear.
    raw_obs = []

    while not (terminated or truncated):
        # No autograd graph is needed while collecting experience.
        with torch.no_grad():
            dist, value = model(state)
            action = dist.rsample()
            log_prob = dist.log_prob(action).sum(dim=-1)
        # Step the environment
        next_raw_state, reward, terminated, truncated, _ = env.step(action.detach().cpu().numpy())
        next_raw_state = torch.FloatTensor(next_raw_state).to(device)
        next_state = obs_rms.normalize(next_raw_state).to(torch.float32)
        done = terminated or truncated

        # Store experience in buffer (detach tensors). buffer.obs holds the
        # exact normalised input the policy saw, so the stored log-probs and
        # the log-probs recomputed during the update refer to the same input.
        raw_obs.append(raw_state)
        buffer.obs.append(state)
        buffer.actions.append(action.detach())
        buffer.log_probs.append(log_prob.detach())
        buffer.rewards.append(reward)
        buffer.dones.append(done)
        buffer.values.append(value.detach().squeeze())
        state = next_state
        raw_state = next_raw_state
        episode_reward += reward
        episode_length += 1
        total_timesteps += 1

        # Check if it's time to update the policy
        if len(buffer.rewards) >= ROLLOUT_LENGTH or done:
            # Compute next value (bootstrap; masked out by compute_gae when done)
            with torch.no_grad():
                _, next_value = model(state)
            next_value = next_value.detach().squeeze()

            # Compute returns and advantages
            returns = compute_gae(next_value, buffer.rewards, buffer.dones, buffer.values)
            returns = torch.stack(
                [torch.as_tensor(ret, dtype=torch.float32, device=device) for ret in returns]
            )
            values_tensor = torch.stack(buffer.values).to(device)
            advantages = returns - values_tensor

            # Flatten the buffers
            obs_tensor = torch.stack(buffer.obs)
            actions_tensor = torch.stack(buffer.actions)
            log_probs_tensor = torch.stack(buffer.log_probs)

            # Update observation normalization from the RAW observations.
            # The batch in obs_tensor is deliberately not re-normalised.
            obs_rms.update(torch.stack(raw_obs))

            # Clear buffer
            buffer.clear()
            raw_obs = []

            # Normalize advantages (std of a single sample is undefined)
            if advantages.numel() > 1:
                advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

            # PPO Optimization step
            total_loss = 0
            total_actor_loss = 0
            total_critic_loss = 0
            total_entropy = 0
            num_updates = 0
            for _ in range(PPO_EPOCHS):
                # Create mini-batches
                indices = np.arange(len(obs_tensor))
                np.random.shuffle(indices)
                for start in range(0, len(obs_tensor), MINI_BATCH_SIZE):
                    end = start + MINI_BATCH_SIZE
                    mini_batch_indices = indices[start:end]
                    mb_obs = obs_tensor[mini_batch_indices]
                    mb_actions = actions_tensor[mini_batch_indices]
                    mb_log_probs = log_probs_tensor[mini_batch_indices]
                    mb_returns = returns[mini_batch_indices]
                    mb_advantages = advantages[mini_batch_indices]
                    # Forward pass
                    dist, value = model(mb_obs)
                    # Compute entropy using base distribution
                    entropy = dist.base_dist.entropy().sum(dim=-1).mean()
                    new_log_probs = dist.log_prob(mb_actions).sum(dim=-1)
                    # Ratio for clipping
                    ratio = (new_log_probs - mb_log_probs).exp()
                    surr1 = ratio * mb_advantages
                    surr2 = torch.clamp(ratio, 1.0 - CLIP_EPSILON, 1.0 + CLIP_EPSILON) * mb_advantages
                    actor_loss = -torch.min(surr1, surr2).mean()
                    critic_loss = VALUE_LOSS_COEF * (mb_returns - value.squeeze(-1)).pow(2).mean()
                    loss = actor_loss + critic_loss - ENTROPY_COEF * entropy
                    # Backpropagation
                    optimizer.zero_grad()
                    loss.backward()
                    nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)
                    optimizer.step()
                    # Accumulate losses
                    total_loss += loss.item()
                    total_actor_loss += actor_loss.item()
                    total_critic_loss += critic_loss.item()
                    total_entropy += entropy.item()
                    num_updates += 1

            # Advance the LR schedule once per policy update (i.e. once per
            # episode). The schedule horizon is TOTAL_EPISODES, so stepping
            # per mini-batch would exhaust it after a fraction of training.
            scheduler.step()

            # Compute average losses
            avg_loss = total_loss / num_updates
            avg_actor_loss = total_actor_loss / num_updates
            avg_critic_loss = total_critic_loss / num_updates
            avg_entropy = total_entropy / num_updates

            # Store losses
            all_losses.append(avg_loss)
            all_actor_losses.append(avg_actor_loss)
            all_critic_losses.append(avg_critic_loss)
            all_entropies.append(avg_entropy)

            # Verbose logging
            print(f"Episode {episode_count} | Timesteps {total_timesteps} | Avg Loss: {avg_loss:.4f} | "
                  f"Actor Loss: {avg_actor_loss:.4f} | Critic Loss: {avg_critic_loss:.4f} | "
                  f"Entropy: {avg_entropy:.4f}")
            break  # Exit loop after policy update

    episode_rewards.append(episode_reward)
    episode_lengths.append(episode_length)
    episode_count += 1

    # Print average reward every 10 episodes
    if episode_count % 10 == 0:
        avg_reward = np.mean(episode_rewards[-10:])
        avg_length = np.mean(episode_lengths[-10:])
        print(f"Episode {episode_count} | Average Reward (last 10 episodes): {avg_reward:.2f} | "
              f"Average Length: {avg_length:.2f}")

    # Evaluate the agent periodically
    if episode_count % EVAL_INTERVAL == 0:
        print(f"\nEvaluating at episode {episode_count}...")
        avg_reward = evaluate_policy(model, eval_env, episodes=EVAL_EPISODES)
        all_avg_rewards.append(avg_reward)

        # Early stopping and model saving
        if avg_reward > best_avg_reward + min_delta:
            best_avg_reward = avg_reward
            no_improvement_counter = 0
            # Save the model
            torch.save(model.state_dict(), f'best_model_episode_{episode_count}.pth')
            print(f"Best model saved with average reward {best_avg_reward} at episode {episode_count}")
        else:
            no_improvement_counter += 1
            print(f"No improvement for {no_improvement_counter} evaluation(s)")

        if no_improvement_counter >= patience:
            print(f"Early stopping at episode {episode_count} due to no improvement in average reward")
            break
        print()

# Release both environments before plotting.
eval_env_to_close = (env, eval_env)
for _env in eval_env_to_close:
    _env.close()

# Plotting the results
episodes = range(len(episode_rewards))

# Figure 1: training reward per episode
plt.figure()
plt.plot(episodes, episode_rewards)
plt.xlabel('Episode')
plt.ylabel('Reward')
plt.title('Episode Reward Over Time')
plt.savefig('rewards.png')
plt.show()

# Figure 2: the three loss curves, indexed by policy update
plt.figure()
for _series, _label in (
    (all_losses, 'Total Loss'),
    (all_actor_losses, 'Actor Loss'),
    (all_critic_losses, 'Critic Loss'),
):
    plt.plot(range(len(_series)), _series, label=_label)
plt.xlabel('Policy Update')
plt.ylabel('Loss')
plt.title('Losses Over Time')
plt.legend()
plt.savefig('losses.png')
plt.show()

# Figure 3: evaluation averages; the k-th evaluation ran at episode EVAL_INTERVAL * k
eval_episodes = [EVAL_INTERVAL * (k + 1) for k in range(len(all_avg_rewards))]
plt.figure()
plt.plot(eval_episodes, all_avg_rewards)
plt.xlabel('Episode')
plt.ylabel('Average Reward')
plt.title('Average Reward During Evaluation')
plt.savefig('avg_rewards.png')
plt.show()


Episode 0 | Timesteps 47 | Avg Loss: 943.8116 | Actor Loss: -0.0132 | Critic Loss: 943.8305 | Entropy: 5.6750
Episode 1 | Timesteps 1000 | Avg Loss: 34.3401 | Actor Loss: 0.0156 | Critic Loss: 34.3301 | Entropy: 5.6756
Episode 2 | Timesteps 3000 | Avg Loss: 1.6720 | Actor Loss: 0.0143 | Critic Loss: 1.6634 | Entropy: 5.6723
Episode 3 | Timesteps 5000 | Avg Loss: 0.1187 | Actor Loss: -0.0267 | Critic Loss: 0.1510 | Entropy: 5.6485
Episode 4 | Timesteps 6944 | Avg Loss: 15.0784 | Actor Loss: -0.0018 | Critic Loss: 15.0859 | Entropy: 5.6334
Episode 5 | Timesteps 7003 | Avg Loss: 411.8330 | Actor Loss: 0.1043 | Critic Loss: 411.7344 | Entropy: 5.6330
Episode 6 | Timesteps 9003 | Avg Loss: 1.6675 | Actor Loss: -0.0159 | Critic Loss: 1.6890 | Entropy: 5.6250
Episode 7 | Timesteps 9070 | Avg Loss: 366.1867 | Actor Loss: 0.1499 | Critic Loss: 366.0424 | Entropy: 5.6036
Episode 8 | Timesteps 11070 | Avg Loss: 0.1154 | Actor Loss: -0.0261 | Critic Loss: 0.1471 | Entropy: 5.5912
Episode 9 | Times