### Trust Region Policy Optimization

**Principle:** Trust Region Policy Updates

**Definition:** Constrains policy updates within a Kullback-Leibler Divergence trust region for monotonic improvement (TRPO).

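**Algorithm Description:** Trust Region Policy Optimization constrains each policy update to a trust region defined by the KL divergence between the old and new policies. It uses the conjugate gradient method to compute an approximate natural-gradient search direction, then a backtracking line search to take the largest step along that direction that stays inside the trust region. In theory this guarantees monotonic performance improvement; in practice the guarantee is approximate.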

**Typical Use Cases:**
- Avoiding catastrophic performance drops during training
- Making stable and monotonic policy updates
- Caveat: more complex to implement than PPO
- When guarantees on update step size are needed
- Works with continuous and discrete action spaces

**Assumptions:**
- Computationally intensive
- Continuous/discrete actions
- On-policy learning
- Stable updates



### 1. Import Libraries

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import gymnasium as gym
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Categorical

# Notebook-wide plotting defaults: seaborn grid theme first, then the default figure size.
sns.set_style('whitegrid')
plt.rcParams.update({'figure.figsize': (12, 6)})

### 2. TRPO Algorithm

TRPO (Trust Region Policy Optimization) constrains policy updates:

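**Maximize:** $\mathbb{E}\left[\dfrac{\pi_\theta(a|s)}{\pi_{\theta_\text{old}}(a|s)}\, A(s,a)\right]$

**Subject to:** $\mathrm{KL}\left(\pi_{\theta_\text{old}} \,\|\, \pi_\theta\right) \le \delta$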

Where:
- KL = Kullback-Leibler divergence
- δ = trust region constraint (e.g., 0.01)
- A = advantage function

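**Note:** This notebook uses a simplified implementation. The hard KL constraint is replaced by an adaptive KL penalty (β · KL) added to the loss, so no conjugate gradient or line search is performed.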

In [None]:
class PolicyNetwork(nn.Module):
    """Two-hidden-layer MLP that maps a state to a softmax distribution over actions."""

    def __init__(self, state_dim, action_dim):
        super().__init__()
        hidden = 64
        # Layers are created in the same order they are applied.
        layers = [
            nn.Linear(state_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, action_dim),
            nn.Softmax(dim=-1),  # normalize over the last (action) dimension
        ]
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        """Return action probabilities for input ``x`` (they sum to 1 along the last dim)."""
        action_probs = self.fc(x)
        return action_probs

class ValueNetwork(nn.Module):
    """Two-hidden-layer MLP that maps a state to a single scalar value estimate."""

    def __init__(self, state_dim):
        super().__init__()
        widths = (state_dim, 64, 64)
        layers = []
        # One Linear + ReLU pair per hidden layer, then a linear head with one output.
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.ReLU())
        layers.append(nn.Linear(widths[-1], 1))
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        """Return the value estimate for ``x``; the last dimension has size 1."""
        value = self.fc(x)
        return value

In [None]:
class TRPOAgent:
    """Simplified TRPO agent that uses an adaptive KL penalty instead of a hard constraint.

    Each call to ``update`` fits the value network to the (normalized)
    discounted returns of the buffered transitions, then takes
    ``policy_epochs`` gradient steps on the penalized surrogate loss

        -E[pi_new(a|s) / pi_old(a|s) * A(s, a)] + beta * KL(pi_old || pi_new)

    where ``pi_old`` is the policy as it was when ``update`` was called.
    Afterwards ``beta`` is doubled or halved depending on how the resulting KL
    compares with ``kl_target``.

    Args:
        state_dim: Input feature size passed to the policy and value networks.
        action_dim: Number of outputs of the policy network (discrete actions).
        lr: Adam learning rate used for both networks.
        gamma: Discount factor applied when accumulating returns.
        kl_target: KL divergence between pre- and post-update policy that
            ``beta`` is adapted towards.
        policy_epochs: Number of policy gradient steps taken per ``update``.

    Raises:
        ValueError: If ``policy_epochs`` is smaller than 1.
    """

    def __init__(self, state_dim, action_dim, lr=3e-4, gamma=0.99, kl_target=0.01,
                 policy_epochs=10):
        if policy_epochs < 1:
            raise ValueError('policy_epochs must be at least 1')

        self.gamma = gamma
        self.kl_target = kl_target
        self.policy_epochs = policy_epochs
        self.beta = 1.0  # KL penalty coefficient

        self.policy = PolicyNetwork(state_dim, action_dim)
        self.value = ValueNetwork(state_dim)
        self.policy_optimizer = optim.Adam(self.policy.parameters(), lr=lr)
        self.value_optimizer = optim.Adam(self.value.parameters(), lr=lr)

        self.buffer = self._empty_buffer()

    @staticmethod
    def _empty_buffer():
        """Return a fresh, empty transition buffer."""
        return {'states': [], 'actions': [], 'rewards': [], 'dones': []}

    def select_action(self, state):
        """Sample an action index from the current policy for a single state."""
        with torch.no_grad():
            state_t = torch.as_tensor(state, dtype=torch.float32)
            probs = self.policy(state_t)
            action = Categorical(probs).sample()
        return action.item()

    def compute_kl(self, old_probs, new_probs):
        """Return the batch mean of KL(old || new) for categorical probabilities.

        The small constant inside the logs avoids ``log(0)``.
        """
        log_ratio = torch.log(old_probs + 1e-10) - torch.log(new_probs + 1e-10)
        return (old_probs * log_ratio).sum(-1).mean()

    def _discounted_returns(self, rewards, dones):
        """Return per-step discounted returns, restarting the sum at episode ends."""
        returns = []
        running = 0.0
        for reward, done in zip(reversed(rewards), reversed(dones)):
            if done:
                # Walking backwards: a done flag marks the last step of an episode.
                running = 0.0
            running = reward + self.gamma * running
            returns.append(running)
        returns.reverse()  # append + reverse is O(n); insert(0, ...) would be O(n^2)
        return torch.tensor(returns, dtype=torch.float32)

    def update(self):
        """Run one value update and ``policy_epochs`` policy updates, then clear the buffer.

        Does nothing when the buffer is empty.
        """
        if not self.buffer['states']:
            return

        # Stack through NumPy first: building a tensor from a list of arrays is slow.
        states = torch.as_tensor(np.asarray(self.buffer['states'], dtype=np.float32))
        actions = torch.as_tensor(self.buffer['actions'], dtype=torch.long)

        returns = self._discounted_returns(self.buffer['rewards'], self.buffer['dones'])
        if returns.numel() > 1:
            # std() of a single element is NaN, so only normalize real batches.
            returns = (returns - returns.mean()) / (returns.std() + 1e-7)

        # Update value function. squeeze(-1) keeps the batch dimension even for
        # a one-sample batch, so predictions and targets always share a shape.
        values = self.value(states).squeeze(-1)
        value_loss = F.mse_loss(values, returns)
        self.value_optimizer.zero_grad()
        value_loss.backward()
        self.value_optimizer.step()

        # Advantages and the old policy are constants during the policy update.
        with torch.no_grad():
            advantages = returns - self.value(states).squeeze(-1)
            old_probs = self.policy(states)
            old_log_probs = Categorical(old_probs).log_prob(actions)

        # Several steps against the frozen old policy: on the first step the KL
        # term and its gradient are zero, so a single step would never be
        # penalized and beta would only ever shrink.
        for _ in range(self.policy_epochs):
            new_probs = self.policy(states)
            new_log_probs = Categorical(new_probs).log_prob(actions)
            ratio = torch.exp(new_log_probs - old_log_probs)

            surrogate_loss = -(ratio * advantages).mean()
            kl_div = self.compute_kl(old_probs, new_probs)
            total_loss = surrogate_loss + self.beta * kl_div

            self.policy_optimizer.zero_grad()
            total_loss.backward()
            self.policy_optimizer.step()

        # Measure how far the policy actually moved and adapt the penalty.
        with torch.no_grad():
            kl_after = self.compute_kl(old_probs, self.policy(states)).item()

        if kl_after > 1.5 * self.kl_target:
            self.beta *= 2
        elif kl_after < self.kl_target / 1.5:
            self.beta *= 0.5

        # Clear buffer: on-policy data is only valid for the policy that produced it.
        self.buffer = self._empty_buffer()

### 3. Train TRPO Agent

In [None]:
# Build the environment and an agent sized to its observation and action spaces.
env = gym.make('CartPole-v1')
agent = TRPOAgent(env.observation_space.shape[0], env.action_space.n)

episode_rewards = []
update_interval = 20  # episodes collected between agent updates

print('Training TRPO...')
for episode in range(500):
    state, _ = env.reset()
    total_reward = 0

    for t in range(500):
        action = agent.select_action(state)
        next_state, reward, terminated, truncated, _ = env.step(action)
        episode_over = terminated or truncated

        # Record the transition; the agent consumes the buffer in update().
        transition = (
            ('states', state),
            ('actions', action),
            ('rewards', reward),
            ('dones', episode_over),
        )
        for key, item in transition:
            agent.buffer[key].append(item)

        total_reward += reward
        state = next_state

        if episode_over:
            break

    episode_rewards.append(total_reward)
    finished_episodes = episode + 1

    # Update on the data gathered over the last `update_interval` episodes.
    if finished_episodes % update_interval == 0:
        agent.update()

    # Progress report every 50 episodes.
    if finished_episodes % 50 == 0:
        print(f'Episode {episode+1}, Avg: {np.mean(episode_rewards[-50:]):.2f}, Beta: {agent.beta:.3f}')

### 4. Visualize Results

In [None]:
# Plot per-episode rewards together with a 20-episode moving average.
plt.figure(figsize=(12, 5))
plt.plot(episode_rewards, alpha=0.3, label='Raw')
if len(episode_rewards) > 20:
    ma = np.convolve(episode_rewards, np.ones(20)/20, mode='valid')
    # ma[i] averages episodes i..i+19, so plot it against the last episode of
    # each window; plotting from x=0 would shift the curve 19 episodes left.
    plt.plot(np.arange(19, len(episode_rewards)), ma, linewidth=2, label='MA(20)')

# Use the environment's registered reward threshold for the "Solved" line.
# CartPole-v1 is registered with 475; 195 is the CartPole-v0 threshold.
# Fall back to 475 if the spec is missing or carries no threshold.
solved_threshold = getattr(getattr(env, 'spec', None), 'reward_threshold', None) or 475
plt.axhline(solved_threshold, color='r', linestyle='--', label='Solved')
plt.xlabel('Episode')
plt.ylabel('Reward')
plt.title('TRPO Training Performance')
plt.legend()
plt.grid(alpha=0.3)
plt.show()

print(f'\nFinal Avg Reward: {np.mean(episode_rewards[-100:]):.2f}')
env.close()