In [None]:
#@title Package downloads
!pip install gymnasium --quiet

In [None]:
#@title Package imports
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import numpy as np
import random
import typing
import gymnasium as gym
import matplotlib.pyplot as plt
from IPython.display import display, clear_output

In [None]:
#@title Helper functions
class DiscretePolicyNetwork(nn.Module):
  """Categorical policy: three leaky-ReLU hidden layers, then a softmax over actions.

  Every layer is an ``nn.LazyLinear``, so the input size is inferred on the
  first forward pass.
  """

  def __init__(self, num_cells, action_space_size):
    super().__init__()
    self.action_space_size = action_space_size
    # The attribute names below are the state_dict keys; keep them stable.
    self.dense1 = nn.LazyLinear(num_cells)
    self.dense2 = nn.LazyLinear(num_cells)
    self.dense3 = nn.LazyLinear(num_cells)
    self.output = nn.LazyLinear(action_space_size)

  def forward(self, x):
    """Return action probabilities normalised along the last dimension."""
    hidden = x
    for layer in (self.dense1, self.dense2, self.dense3):
      hidden = F.leaky_relu(layer(hidden), negative_slope=0.01)
    logits = self.output(hidden)
    return F.softmax(logits, dim=-1)

  def generate_actions(self, x):
    """Sample one action per row of the 2-D batch ``x``.

    Returns ``(sampled actions, probability of each sampled action,
    full probability rows)``.
    """
    probs = self(x)
    sampled = torch.distributions.Categorical(probs).sample()
    chosen = probs.gather(1, sampled.unsqueeze(-1)).squeeze(-1)
    return sampled, chosen, probs

  def generate_action(self, x):
    """Sample a single action for one state (all singleton dims are squeezed away).

    Returns ``(action, probability of that action, probability vector)``.
    """
    probs = self(x).squeeze()
    sampled = torch.distributions.Categorical(probs).sample()
    return sampled, probs[sampled], probs


class ValueNetwork(nn.Module):
  """State-value estimator: three leaky-ReLU hidden layers and a single linear output unit.

  Every layer is an ``nn.LazyLinear``, so the input size is inferred on the
  first forward pass.
  """

  def __init__(self, num_cells):
    super().__init__()
    # The attribute names below are the state_dict keys; keep them stable.
    self.dense1 = nn.LazyLinear(num_cells)
    self.dense2 = nn.LazyLinear(num_cells)
    self.dense3 = nn.LazyLinear(num_cells)
    self.output = nn.LazyLinear(1)

  def forward(self, x):
    """Map states ``(..., obs_dim)`` to values ``(..., 1)``."""
    hidden = x
    for layer in (self.dense1, self.dense2, self.dense3):
      hidden = F.leaky_relu(layer(hidden), negative_slope=0.01)
    return self.output(hidden)

# helper function for creating minibatches
def create_minibatches(data, batch_size):
  """Shuffle ``data`` and split it into consecutive chunks of ``batch_size`` items.

  A shuffled *copy* is chunked, so the caller's sequence is never reordered.
  The last chunk may be shorter than ``batch_size``.

  Args:
    data: any finite sequence/iterable of items (e.g. trajectories).
    batch_size: maximum number of items per chunk; must be positive.

  Returns:
    A list of lists covering every item of ``data`` exactly once.

  Raises:
    ValueError: if ``batch_size`` is not positive.
  """
  if batch_size <= 0:
    raise ValueError(f"batch_size must be positive, got {batch_size}")
  shuffled = list(data)
  random.shuffle(shuffled)
  return [shuffled[i:i + batch_size] for i in range(0, len(shuffled), batch_size)]

def collect_trajectory(T, state, policy_net, angle_penalty, environment=None):
  """Roll out the current policy for at most ``T`` environment steps.

  Args:
    T: maximum number of steps to take.
    state: initial observation as a torch tensor.
    policy_net: object with ``generate_action(state) -> (action, action_prob, probs)``.
    angle_penalty: reward shaping coefficient; each step's reward is reduced by
      ``angle_penalty * abs(next_state[2])`` (the pole angle for CartPole).
    environment: env whose ``step`` returns
      ``(obs, reward, terminated, truncated, info)``. Defaults to the
      notebook-global ``env`` so existing calls keep working.

  Returns:
    A list of ``[state, action, shaped_reward, action_prob]`` entries, one per
    step. ``action_prob`` is produced under ``torch.no_grad`` so it is a plain
    constant, suitable as the "old" probability in the PPO ratio.
  """
  if environment is None:
    environment = env  # notebook-global created in the training cell
  trajectory = []

  for _ in range(T):
    # Rollouts only sample; building an autograd graph here wastes memory and
    # would attach stale graphs to the stored probabilities.
    with torch.no_grad():
      action, action_prob, _ = policy_net.generate_action(state)

    next_state, reward, terminated, truncated, _ = environment.step(action.item())
    shaped_reward = reward - abs(next_state[2]) * angle_penalty
    trajectory.append([state, action, shaped_reward, action_prob])

    # Stop on a real termination or on the environment's own time limit;
    # stepping a finished episode is undefined.
    if terminated or truncated:
      break
    state = torch.tensor(next_state)

  return trajectory

def extend_trajectory(T, trajectory, value_net, gamma, lambda_):
    """Append GAE advantages and value targets to every step of a trajectory.

    Each step ``[state, action, reward, action_prob]`` is extended in place to
    ``[state, action, reward, action_prob, advantage, value_target]``, where
    ``value_target = advantage + V(state)`` (the lambda-return) for every step.

    Args:
        T: maximum rollout length (currently unused; kept for call compatibility).
        trajectory: list of steps whose element 0 is a state tensor and element 2
            a scalar reward.
        value_net: module mapping a stacked ``(len, obs_dim)`` batch to values
            of ``len`` elements (e.g. shape ``(len, 1)``).
        gamma: discount factor.
        lambda_: GAE smoothing parameter.

    Returns:
        The same ``trajectory`` list, mutated in place.

    Note: the last step is always treated as terminal (no bootstrap). This also
    applies when the rollout was cut off by the step limit rather than by failure.
    """
    if not trajectory:
        return trajectory

    states = torch.stack([step[0] for step in trajectory])
    with torch.no_grad():
        # reshape(-1) rather than squeeze() keeps a 1-D tensor for one-step trajectories
        state_values = value_net(states).reshape(-1)
    rewards = torch.tensor([float(step[2]) for step in trajectory], dtype=torch.float32)

    n = len(trajectory)
    advantages = torch.zeros(n, dtype=torch.float32)
    gae = 0.0
    next_value = 0.0  # nothing to bootstrap from after the final (terminal) step
    for t in reversed(range(n)):
        delta = rewards[t] + gamma * next_value - state_values[t]
        gae = delta + gamma * lambda_ * gae
        advantages[t] = gae
        next_value = state_values[t]

    # The value target is A_t + V(s_t) for every step, including the last one
    # (which therefore equals the final reward).
    returns = advantages + state_values

    for step, advantage, value_target in zip(trajectory, advantages, returns):
        step.extend([advantage, value_target])

    return trajectory
def calc_lclipvh(minibatch, policy_network, value_network, v, h, clip_epsilon):
    """Compute the combined PPO loss ``-L_clip + v * L_v - h * H`` for a minibatch.

    Args:
        minibatch: iterable of trajectories; every step unpacks as
            ``state, action, reward, policy_old, advantage, v_target``.
        policy_network: module mapping a ``(N, obs_dim)`` batch to ``(N, num_actions)``
            action probabilities.
        value_network: module mapping a ``(N, obs_dim)`` batch to ``N`` values.
        v: weight of the value loss.
        h: weight of the entropy bonus.
        clip_epsilon: PPO clipping range.

    Returns:
        ``(L_v, total_loss)``: the value loss alone and the combined loss to minimise.
    """
    states, actions, policies_old, advantages, V_targets = [], [], [], [], []
    for trajectory in minibatch:
        for state, action, _reward, policy_old, advantage, v_target in trajectory:
            # as_tensor avoids the copy/warning torch.tensor() gives for tensor input
            states.append(torch.as_tensor(state, dtype=torch.float32))
            actions.append(int(action))
            # float() turns stored tensors into constants: they must not receive gradients
            policies_old.append(float(policy_old))
            advantages.append(float(advantage))
            V_targets.append(float(v_target))

    actions = torch.tensor(actions, dtype=torch.long)
    policies_old = torch.tensor(policies_old, dtype=torch.float32)
    advantages = torch.tensor(advantages, dtype=torch.float32)
    V_targets = torch.tensor(V_targets, dtype=torch.float32)

    L_clip, phi_ts = calc_lclip(policy_network, states, advantages, policies_old,
                                clip_epsilon, actions=actions)
    L_v = calc_lv(states, V_targets, value_network)
    H = calc_h(phi_ts)

    return L_v, -L_clip + v * L_v - h * H

def calc_lv(states, V_targets, value_network):
    """Mean squared error between the value network's estimates and ``V_targets``.

    ``states`` is a list of 1-D state tensors; it is stacked into a batch here.
    """
    state_values = value_network(torch.stack(states)).flatten()
    squared_differences = (state_values - V_targets) ** 2
    return torch.mean(squared_differences)

def calc_h(phi_ts):
    """Mean per-state entropy of the action distributions in ``phi_ts``.

    ``phi_ts`` holds ``(N, num_actions)`` probabilities (rows sum to 1). The
    entropy ``-sum_a p log p`` is summed over actions and averaged over states.
    ``Categorical.entropy`` stays finite (forward and backward) when some
    probabilities are exactly zero, where a naive ``p * log(p)`` gives NaN.
    """
    return torch.distributions.Categorical(probs=phi_ts).entropy().mean()

def calc_lclip(policy_network, states, advantages, policies_old, clip_epsilon, actions=None):
    """Clipped PPO surrogate objective (to be maximised).

    Args:
        policy_network: module mapping a ``(N, obs_dim)`` batch to action probabilities.
        states: list of 1-D state tensors.
        advantages: ``(N,)`` advantage estimates.
        policies_old: ``(N,)`` probabilities of the taken actions under the
            policy that collected the data.
        clip_epsilon: clipping range for the probability ratio.
        actions: ``(N,)`` indices of the actions actually taken. PPO's ratio
            ``pi_new(a|s) / pi_old(a|s)`` must be evaluated at these actions.
            If ``None``, the legacy behaviour (sampling *fresh* actions via
            ``generate_actions``) is used. That does not yield the PPO ratio and
            is kept only for backward compatibility.

    Returns:
        ``(mean clipped objective, (N, num_actions) probabilities under the current policy)``.
    """
    states_tensor = torch.stack(states)
    if actions is None:
        _, policies_new, phi_ts = policy_network.generate_actions(states_tensor)
    else:
        phi_ts = policy_network(states_tensor)
        actions = torch.as_tensor(actions, dtype=torch.long)
        policies_new = phi_ts.gather(1, actions.unsqueeze(-1)).squeeze(-1)

    policy_ratios = policies_new / policies_old
    clipped_ratios = torch.clamp(policy_ratios, 1 - clip_epsilon, 1 + clip_epsilon)
    clipped_objectives = torch.min(policy_ratios * advantages, clipped_ratios * advantages)
    return torch.mean(clipped_objectives), phi_ts

def init_weights(m):
  """Orthogonal weights and a 0.01 bias for every materialised linear layer.

  Intended for ``module.apply(init_weights)``. ``nn.LazyLinear`` turns itself
  into ``nn.Linear`` after its first forward pass, so the original check
  ``type(m) == nn.LazyLinear`` never matched a materialised layer and the init
  was silently skipped. ``isinstance(m, nn.Linear)`` covers both classes.
  Layers whose weights are still uninitialised (no forward pass yet) are left
  alone, because there is nothing to initialise.
  """
  if not isinstance(m, nn.Linear):
    return
  if isinstance(m.weight, nn.parameter.UninitializedParameter):
    return
  init.orthogonal_(m.weight)
  if m.bias is not None:
    init.constant_(m.bias, 0.01)

In [5]:
## HYPERPARAMETERS
num_trajectories = 100  # How many trajectories we collect per outer loop
T = 500  # Maximum number of steps per trajectory
gamma = 0.90  # Discount factor for future rewards
lambda_ = 0.95  # GAE smoothing parameter
v = 0.5  # weighting factor for value loss
h = 0.5  # weighting factor for entropy bonus
lr = 0.01  # learning rate
# Minibatches are built from whole trajectories: with num_trajectories <= minibatch_size
# every epoch is a single minibatch containing all collected data.
minibatch_size = 128  # minibatch size (in trajectories)
clip_epsilon = 0.2  # PPO clipping range
num_epochs = 30  # optimisation passes over each batch of collected data
angle_penalty = 0.5  # reward shaping: subtract angle_penalty * |pole angle| per step
num_cells = 256  # number of cells for each layer
max_outer_loops = 200  # collect/train iterations (the cell used to loop forever)
seed = 0  # seeds python, numpy, torch and the environment for reproducibility

random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)


## GYM INITIALIZATION
# video mode
#env = gym.make("CartPole-v1", render_mode="rgb_array")
env = gym.make("CartPole-v1")

action_space_size = env.action_space.n
initial_state = torch.tensor(env.reset(seed=seed)[0])

## INITIALIZATION OF NETWORKS
# Lazy layers need one forward pass to materialise their weights before
# init_weights (and the optimizers) can see them.
policy_net = DiscretePolicyNetwork(num_cells, action_space_size)
policy_net(initial_state)
policy_net.apply(init_weights)

value_net = ValueNetwork(num_cells)
value_net(initial_state)
value_net.apply(init_weights)

optimizer_policy = torch.optim.Adam(policy_net.parameters(), lr=lr)
optimizer_value = torch.optim.Adam(value_net.parameters(), lr=lr)
# Learning-rate schedulers (not stepped at the moment; see the commented calls below)
scheduler_policy = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_policy, 'max', patience=15, factor=0.8)
scheduler_value = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_value, 'max', patience=15, factor=0.8)


epoch_losses = []  # mean loss of every optimisation epoch
average_rewards = []  # mean shaped return per outer loop
max_episode_lengths = []  # longest episode per outer loop
median_rewards = []  # median shaped return per outer loop
outer_loop = 0

while outer_loop < max_outer_loops:
    # ---- Collect rollouts with the current policy ----
    training_data = []
    total_rewards = []
    episode_lengths = []

    for _ in range(num_trajectories):
        state = torch.tensor(env.reset()[0])
        trajectory = collect_trajectory(T, state, policy_net, angle_penalty)
        trajectory = extend_trajectory(T, trajectory, value_net, gamma, lambda_)

        training_data.append(trajectory)
        total_rewards.append(sum(step[2] for step in trajectory))  # step[2] is the shaped reward
        episode_lengths.append(len(trajectory))

    average_rewards.append(sum(total_rewards) / len(total_rewards))
    max_episode_lengths.append(max(episode_lengths))
    median_rewards.append(np.median(total_rewards))

    # ---- PPO optimisation ----
    policy_net.train()
    value_net.train()
    for epoch in range(num_epochs):
        epoch_loss = []  # losses of each minibatch in the current epoch

        for minibatch in create_minibatches(training_data, minibatch_size):
            # PyTorch accumulates gradients, so clear them first
            optimizer_policy.zero_grad()
            optimizer_value.zero_grad()

            _value_loss, L_clipvh = calc_lclipvh(minibatch, policy_net, value_net, v, h, clip_epsilon)
            # One backward pass on the combined loss fills the gradients of both networks
            L_clipvh.backward()

            optimizer_policy.step()
            optimizer_value.step()
            epoch_loss.append(L_clipvh.item())

        #scheduler_policy.step(average_rewards[-1])
        #scheduler_value.step(average_rewards[-1])

        epoch_losses.append(sum(epoch_loss) / len(epoch_loss))

    # ---- Live plot ----
    clear_output(wait=True)
    fig, axes = plt.subplots(2, 2, figsize=(12, 8))

    ax = axes[0, 0]
    ax.plot(average_rewards, label='Average Reward', color="blue")
    ax.plot(median_rewards, label="Median Reward", color="green")
    ax.set(xlabel='Loop', ylabel='Reward', title='Average and Median Reward Over Time')
    ax.legend()

    ax = axes[0, 1]
    ax.plot(max_episode_lengths, label='Maximum Length', color='red')
    ax.set(xlabel='Loop', ylabel='Maximum Length', title='Maximum Length Over Time')
    ax.legend()

    ax = axes[1, 0]
    ax.hist(total_rewards, bins=30, label='Reward Distribution')
    ax.set(xlabel='Reward', ylabel='Frequency', title='Reward Distribution (latest loop)')
    ax.legend()

    ax = axes[1, 1]
    ax.plot(epoch_losses, label='Epoch Loss', color="purple")
    ax.set(xlabel='Epoch', ylabel='Loss', title='Epoch Loss Over Time')
    ax.legend()

    fig.tight_layout()
    display(fig)
    plt.close(fig)  # otherwise one open figure accumulates per loop

    outer_loop += 1

KeyboardInterrupt: 