In [1]:
!apt-get install ffmpeg freeglut3-dev xvfb  # For visualization
!pip install "stable-baselines3[extra]>=2.0.0a4"
!pip install tensorboard
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import gym
from torch.distributions import Categorical

Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
ffmpeg is already the newest version (7:4.4.2-0ubuntu0.22.04.1).
The following additional packages will be installed:
  freeglut3 libegl-dev libfontenc1 libgl-dev libgl1-mesa-dev libgles-dev libgles1 libglu1-mesa
  libglu1-mesa-dev libglvnd-core-dev libglvnd-dev libglx-dev libice-dev libopengl-dev libsm-dev
  libxfont2 libxkbfile1 libxt-dev x11-xkb-utils xfonts-base xfonts-encodings xfonts-utils
  xserver-common
Suggested packages:
  libice-doc libsm-doc libxt-doc
The following NEW packages will be installed:
  freeglut3 freeglut3-dev libegl-dev libfontenc1 libgl-dev libgl1-mesa-dev libgles-dev libgles1
  libglu1-mesa libglu1-mesa-dev libglvnd-core-dev libglvnd-dev libglx-dev libice-dev libopengl-dev
  libsm-dev libxfont2 libxkbfile1 libxt-dev x11-xkb-utils xfonts-base xfonts-encodings xfonts-utils
  xserver-common xvfb
0 upgraded, 25 newly installed, 0 to remove and 49 not upgraded.
Need t

In [3]:
# Create the CartPole-v1 environment
env = gym.make('CartPole-v1')

# Hyperparameters
gamma = 0.99
lr = 1e-4  # Try lowering the learning rate
H = 20  # Number of inner loops
K = 500  # Reduce the outer loops for testing
eta = 0.009  # Policy update rate
alpha = 0.9
beta = 0.9
xi = 0.9
entropy_weight = 0.01  # Entropy regularization

# Policy Network
class PolicyNetwork(nn.Module):
    def __init__(self):
        super(PolicyNetwork, self).__init__()
        self.fc1 = nn.Linear(4, 128)  # Input is 4-dimensional for CartPole
        self.fc2 = nn.Linear(128, 2)  # Output is 2-dimensional (left/right)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return torch.softmax(x, dim=-1)

# Accelerated Natural Policy Gradient (ANPG)
class ANPG:
    def __init__(self, env, policy_network):
        self.env = env
        self.policy_network = policy_network
        self.optimizer = optim.Adam(self.policy_network.parameters(), lr=lr)

    def select_action(self, state):
        state = torch.tensor(state, dtype=torch.float32).unsqueeze(0)  # Shape: [1, 4]
        probs = self.policy_network(state)
        dist = Categorical(probs)
        action = dist.sample()
        return action.item(), dist.log_prob(action), probs

    def compute_returns(self, rewards):
        R = 0
        returns = []
        for r in rewards[::-1]:
            R = r + gamma * R
            returns.insert(0, R)
        return returns

    def train(self):
        for k in range(K):
            state = self.env.reset()  # Reset environment and get initial state
            x0 = [torch.zeros_like(param) for param in self.policy_network.parameters()]  # Initialize x0
            v0 = [torch.zeros_like(param) for param in self.policy_network.parameters()]  # Initialize v0

            # Outer loop - K iterations
            for h in range(H):
                log_probs = []
                rewards = []
                y_h = [alpha * x0_ + (1 - alpha) * v0_ for x0_, v0_ in zip(x0, v0)]  # Layer-wise update for y_h

                # Inner loop - Sampling trajectories
                for step in range(200):
                    action, log_prob, probs = self.select_action(state)
                    next_state, reward, done, _ = self.env.step(action)

                    # Debug prints
                    #print(f"Step: {step}, Action: {action}, Reward: {reward}, Done: {done}, Probs: {probs}")

                    log_probs.append(log_prob)
                    rewards.append(reward)
                    state = next_state

                    if done:
                        break

                # Compute returns and policy loss
                returns = self.compute_returns(rewards)
                returns = torch.tensor(returns, dtype=torch.float32)  # Ensure it's a tensor
                log_probs = torch.stack(log_probs)

                policy_loss = -log_probs * (returns - returns.mean())  # Policy gradient loss
                policy_loss = policy_loss.sum()

                # Entropy regularization to encourage exploration
                entropy = -torch.sum(probs * torch.log(probs), dim=-1)
                policy_loss -= entropy_weight * entropy.sum()

                self.optimizer.zero_grad()
                grad_logp = torch.autograd.grad(policy_loss, self.policy_network.parameters(), create_graph=True, retain_graph=True)
                policy_loss.backward()
                self.optimizer.step()

                # Accelerated Gradient Steps (layer-wise)
                with torch.no_grad():
                    x_h = [y_h_ - eta * g for y_h_, g in zip(y_h, grad_logp)]  # Update x_h layer-wise
                    z_h = [beta * y_h_ + (1 - beta) * v0_ for y_h_, v0_ in zip(y_h, v0)]
                    v_h = [z_h_ - xi * g for z_h_, g in zip(z_h, grad_logp)]

                # Update x0 for the next inner loop iteration
                x0 = x_h

            # Tail averaging (layer-wise)
            omega_k = [2 / H * sum(x_h_) for x_h_ in x0]  # Calculate omega_k for each layer

            # Update policy parameters (layer-wise)
            for param, omega in zip(self.policy_network.parameters(), omega_k):
                param.data -= eta * omega

            # Logging
            print(f'Episode {k}: Reward Sum {sum(rewards)}')

# Instantiate policy network and ANPG algorithm
policy_net = PolicyNetwork()
anpg = ANPG(env, policy_net)

# Training ANPG
anpg.train()

env.close()


  deprecation(
  deprecation(


Episode 0: Reward Sum 0.0
Episode 1: Reward Sum 0.0
Episode 2: Reward Sum 0.0
Episode 3: Reward Sum 0.0
Episode 4: Reward Sum 0.0
Episode 5: Reward Sum 0.0
Episode 6: Reward Sum 0.0
Episode 7: Reward Sum 0.0
Episode 8: Reward Sum 0.0
Episode 9: Reward Sum 0.0
Episode 10: Reward Sum 0.0
Episode 11: Reward Sum 0.0
Episode 12: Reward Sum 0.0
Episode 13: Reward Sum 0.0
Episode 14: Reward Sum 0.0
Episode 15: Reward Sum 0.0
Episode 16: Reward Sum 0.0
Episode 17: Reward Sum 0.0
Episode 18: Reward Sum 0.0
Episode 19: Reward Sum 0.0
Episode 20: Reward Sum 0.0
Episode 21: Reward Sum 0.0
Episode 22: Reward Sum 0.0
Episode 23: Reward Sum 0.0
Episode 24: Reward Sum 0.0
Episode 25: Reward Sum 0.0
Episode 26: Reward Sum 0.0
Episode 27: Reward Sum 0.0
Episode 28: Reward Sum 0.0
Episode 29: Reward Sum 0.0
Episode 30: Reward Sum 0.0
Episode 31: Reward Sum 0.0
Episode 32: Reward Sum 0.0
Episode 33: Reward Sum 0.0
Episode 34: Reward Sum 0.0
Episode 35: Reward Sum 0.0
Episode 36: Reward Sum 0.0
Episode 37: