In [21]:
import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim

In [22]:
# Create an environment
env = gym.make('CartPole-v1')

In [23]:
state = env.reset()
print("Initial state:", state)

Initial state: (array([ 0.02242215, -0.01580391,  0.04297751,  0.00947539], dtype=float32), {})


In [24]:
class Action(nn.Module):
    def __init__(self,state_dim,action_dim):
        super(Action, self).__init__()
        self.l1 = nn.Linear(state_dim, 128)
        self.l2 = nn.Linear(128, 128)
        self.l3 = nn.Linear(128, 128)
        self.l4 = nn.Linear(128, action_dim)

    def forward(self, x_self):
        x_self = F.relu(self.l1(x_self))
        x_self = F.relu(self.l2(x_self))
        x_self = F.relu(self.l3(x_self))
        x_self = self.l4(x_self)
        
        return F.softmax(x_self, dim=-1)
        #This applies the Softmax function to the output of the last layer, transforming the raw scores into probabilities.


In [25]:
class Critic(nn.Module):
    #The critic evaluates the action taken by the actor (another component of the actor-critic framework) by estimating the value of the current state.
    #It helps provide the actor with feedback about how good or bad the actions taken are, facilitating the learning of a more effective policy.
    def __init__(self, state_dim):
        super(Critic, self).__init__()
        self.l1 = nn.Linear(state_dim, 128)
        self.l2 = nn.Linear(128,128)
        self.l3 = nn.Linear(128,128)
        self.l4 = nn.Linear(128,1)
        # means that this layer takes an input of size 128 and outputs a size of 1

    def forward(self, x_self):
        x_self = F.relu(self.l1(x_self))
        x_self = F.relu(self.l2(x_self))
        x_self = F.relu(self.l3(x_self))
        x_self = self.l4(x_self)
        return x_self

In [339]:
def compute_policy_gradient(policy_net, states, actions, advantages):
    #Calculates the loss to update the policy network based on the actions taken and their advantages.

    states_tensor = torch.FloatTensor(np.array(states)) # Make sure states is a tensor
    action = torch.FloatTensor(np.array(actions))
    advanatages_tens = torch.FloatTensor(np.array(advantages.detach().numpy()))
    
    dist = torch.distributions.Categorical(policy_net(states_tensor)) #This distribution is used to model the probabilities of selecting each action
    log_probs = dist.log_prob(action) #calculates the log probability of the actions that were taken, given the policy defined by the network
    policy_loss = -(log_probs * advantages).mean()
    # This computes the policy loss. The objective is to maximize the expected return, which is done by minimizing the negative log probability weighted by the advantages. Higher advantages will lead to a stronger push to increase the probability of those actions
    
    return policy_loss

In [341]:
def compute_kl_divergence(new_policy_net, old_policy_net, states):
    #This function calculates the KL divergence between the new policy and the old policy. KL divergence measures how one probability distribution diverges from a second, expected probability distribution.
    # Measures how much the new policy diverges from the old policy, which can be useful for ensuring that policy updates do not change the policy too drastically (often employed in trust region methods).

    states_tensor = torch.FloatTensor(np.array(states))
    
    old_dist = torch.distributions.Categorical(old_policy_net(states_tensor))
    new_dist = torch.distributions.Categorical(new_policy_net(states_tensor))
    kl_div = torch.distributions.kl.kl_divergence(new_dist, old_dist).mean()
    # This computes the KL divergence between the new and old policies. The mean is taken to summarize the divergence over the batch of states

    return kl_div

In [343]:
def conjugate_gradient(Avp, b, nsteps, epsilon=1e-10):
    
    #Avp: A function that computes the product of a matrix 𝐴 with a vector 𝑝. This is useful for scenarios where you cannot directly compute 𝐴 or it is too large to store explicitly
    #b: The right-hand side vector in the equation Ax = b
    #nsteps : the number of iterations
    #epsilon: A small number to avoid division by zero. This is used to ensure numerical stability, although it's not actively utilized in the provided code.

    # Ensure b is 1D
    if b.dim() == 0:
        b = b.view(1)  # Convert scalar to 1D tensor
    
    solution_vec = torch.zeros_like(b) #This vector will store the solution, initialized to zeros.
    residual_vec = b.clone() #The residual vector, initialized to 𝑏. It represents the difference between the left-hand side and the right-hand side of the equation and indicates how far the current estimate is from the true solution.
    search_direction = b.clone() #The search direction, initialized to 𝑏. This direction is updated during each iteration of the algorithm.
    rdotr = torch.dot(residual_vec, residual_vec) #The dot product of the residual with itself, which is a measure of how far the current solution is from satisfying Ax = b
    for _ in range(nsteps):
        Avp_ = Avp(search_direction) #Computes A*P using provided function
        alpha = rdotr / torch.dot(search_direction, Avp_) #Determines how far to move along the direction 𝑝 to reduce the residual. This is based on the steepest descent method
        solution_vec += alpha * search_direction #Updates the solution vector by moving in the direction of p
        solution_vec -= alpha * Avp_ #Updates the residual to reflect how much closer the current estimate is to the solution
        new_rdotr = torch.dot(residual_vec, residual_vec) #Calculates new dot product of the residual to see how much it has changed
        beta = new_rdotr / rdotr #computes parameter for updating the search direction
        search_direction = residual_vec + beta * seacrh_direction #Adjusts the search direction for the next iteration
        rdotr = new_rdotr
    return solution_vec

#The conjugate_gradient function efficiently solves the linear system Ax = b by iteratively refining an initial guess for 𝑥. It uses the properties of the Conjugate Gradient method to ensure that each search direction is optimal with respect to the residual, making it particularly useful in large-scale optimization problems found in machine learning and numerical analysis.

In [345]:
def TRO(policy_net, old_policy_net, states, actions, advantages, step_direction, kl_div, max_iterations=10):
    for _ in range(max_iterations):
        proposed_update = step_direction

        if kl_div <= delta:
            return proposed_update
        else:
            proposed_update *= 0.5


In [347]:
def collect_trajectory(env, policy_net, max_steps=1000):
    """
    Collects a trajectory of states, actions, rewards, next states, and done flags 
    by interacting with the specified environment using the provided policy network.

    Parameters:
    - env: The environment to interact with (e.g., a Gym environment).
    - policy_net: The neural network model used to select actions based on state input.
    - max_steps: The maximum number of steps to take in the environment during this trajectory.

    Returns:
    - states: A list of states observed during the trajectory.
    - actions: A list of actions taken by the agent.
    - rewards: A list of rewards received for each action taken.
    - next_states: A list of next states observed after taking actions.
    - dones: A list indicating whether each state was terminal (done).
    """

    # Initialize lists to store trajectory data
    states, actions, rewards, next_states, dones = [], [], [], [], []
    
    # Reset the environment to get the initial state
    state = env.reset()  # Obtain the initial observation from the environment
    print("Initial state:", state)  # Print the initial state for debugging

    state_array = state[0]  # Extract the first element of the state (the observation)
    state_tens = torch.FloatTensor(state_array).unsqueeze(0)  # Convert to tensor and reshape to [1, state_dim]

    # Loop for a maximum number of steps
    for _ in range(max_steps):
        # Get action probabilities from the policy network based on the current state
        action_probs = policy_net(state_tens)
        
        # Sample an action based on the action probabilities
        action = torch.multinomial(action_probs, 1).item()  # Select an action stochastically

        # Take a step in the environment using the selected action
        step_result = env.step(action)  # Interact with the environment
        print("Step result:", step_result)  # Print the result of the step for debugging
        
        # Check the structure of the result from the environment
        if isinstance(step_result, tuple):
            print("Step result is a tuple with length:", len(step_result))

        # Unpack the result based on its length (the expected output of env.step)
        if len(step_result) == 4:
            next_state, reward, done, info = step_result  # Unpack the 4 expected values
        elif len(step_result) == 5:  # Handle cases with additional return values
            next_state, reward, done, info, additional_info = step_result
        else:
            print("Unexpected step result format:", step_result)  # Print an error message
            break  # Exit the loop if the result format is not as expected
        
        # Print the next state for debugging
        print("Next state:", next_state)  # Output the next state for inspection

        # Extract the next state
        next_state_array = next_state[0] if isinstance(next_state, tuple) else next_state  # Handle tuple formats
        if isinstance(next_state_array, np.ndarray):
            # Convert next state to tensor and reshape it for consistency
            next_state_tens = torch.FloatTensor(next_state_array).unsqueeze(0)  # Reshape for next state
        else:
            print("Unexpected next_state format:", next_state_array)  # Print an error message
            continue  # Skip further processing if the format is unexpected

        # Store the current state, action, reward, next state, and done flag for later analysis
        states.append(state_tens.squeeze(0).numpy())  # Convert state tensor to NumPy array
        actions.append(action)  # Store the action taken
        rewards.append(reward)  # Store the received reward
        next_states.append(next_state_tens.squeeze(0).numpy())  # Convert next state tensor to NumPy
        dones.append(done)  # Store whether the episode is done

        # Update the current state tensor to the next state tensor for the next iteration
        state_tens = next_state_tens  # Move to the next state

        # Check if the episode is done; if so, exit the loop
        if done:
            break

    # Return the collected trajectory data
    return states, actions, rewards, next_states, dones


# Had to make sure that the environment was returning the needed states, then I realized it was returning it in a tuple format so we had to unpack it and make it a tensor.
# Had to resize the data so it was in a format that our models could input
# Our models intake tensors so always check the data types

In [349]:
def compute_advantages_and_returns(rewards, states, value_net, gamma=0.95, lam=0.95):
    """
    Computes advantages and returns using Generalized Advantage Estimation (GAE).
    
    Parameters:
    - rewards: List of rewards from the trajectory.
    - states: List of states from the trajectory.
    - value_net: A neural network model that estimates the value of each state.
    - gamma: Discount factor for future rewards.
    - lam: Lambda for GAE, controlling the bias-variance tradeoff.

    Returns:
    - advantages: Computed advantages for each state.
    - returns: Computed returns for each state.
    """
    
    # Convert lists to tensors
    rewards = torch.tensor(rewards, dtype=torch.float32)
    states = torch.tensor(states, dtype=torch.float32)

    # Get state values from the value network
    state_values = value_net(states).squeeze()  # Assuming value_net returns a tensor of shape [N, 1]

    # Calculate the returns and advantages
    returns = torch.zeros_like(rewards)
    advantages = torch.zeros_like(rewards)

    # Compute returns and advantages in reverse order
    # By iterating backward, the method can use the value of future states to compute the current state’s advantage effectively
    # The delta represents the temporal difference error, which is the difference between the actual reward plus the discounted value of the next state and the estimated value of the current state.
    
    for t in reversed(range(len(rewards))):
        if t == len(rewards) - 1:
            delta = rewards[t] + gamma * state_values[t] - state_values[t]  # Last state value (next state)
        else:
            delta = rewards[t] + gamma * state_values[t + 1] - state_values[t]

        advantages[t] = delta + gamma * lam * advantages[t + 1] if t + 1 < len(rewards) else delta
        returns[t] = advantages[t] + state_values[t]  # R_t = A_t + V(s_t)

    return torch.FloatTensor(advantages).unsqueeze(0), torch.FloatTensor(returns).unsqueeze(0)


In [400]:
def Avp(vector):
    # Compute the product of the Fisher information matrix and the vector
    # This often involves using the policy network to compute gradients
    # and then take the dot product with the vector.
    policy_net.zero_grad()  # Clear previous gradients
    
    # Calculate the gradient of the policy loss
    loss = compute_policy_gradient(policy_net, states, actions, advantages)
    loss.backward(retain_graph=True)  # Backpropagate to compute gradients

    # Get the gradient and flatten it
    grad = torch.cat([param.grad.view(-1) for param in policy_net.parameters()])

    # Initialize the Fisher information product
    fisher_info_product = torch.zeros_like(grad)

    # Loop through parameters to compute Fisher information product
    offset = 0
    for param in policy_net.parameters():
        param_grad = param.grad.view(-1)  # Flatten the gradient
        param_size = param_grad.size(0)  # Total number of elements in the flattened gradient

        if param_grad is not None:
            # Check the size of the vector for compatibility
            vector_slice = vector[offset:offset + param_size]

            # Ensure the sizes are compatible
            if vector_slice.size(0) != param_size:
                raise ValueError(f"Size mismatch: vector_slice size {vector_slice.size(0)} does not match param_grad size {param_size}")

            # Compute the Fisher information product
            fisher_info_product[offset:offset + param_size] += (param_grad * vector_slice).sum() * param_grad
        
        offset += param_size  # Update the offset

    return fisher_info_product

In [407]:
policy_net = Action(state_dim=env.observation_space.shape[0], action_dim=env.action_space.n)
old_policy_net = Action(state_dim=env.observation_space.shape[0], action_dim=env.action_space.n)
value_net = Critic(state_dim=env.observation_space.shape[0])

# Optimizers
policy_optimizer = optim.Adam(policy_net.parameters(), lr=1e-3)
value_optimizer = optim.Adam(value_net.parameters(), lr=1e-3)

num_epochs = 10

for epoch in range(num_epochs):
    #print(collect_trajectory(env, policy_net))
    states, actions, rewards, next_states, dones = collect_trajectory(env, policy_net)
    
    # Compute advantages and returns
    advantages, returns = compute_advantages_and_returns(rewards, states, value_net)
    
    # Compute policy gradient and KL divergence
    policy_loss = compute_policy_gradient(policy_net, states, actions, advantages)
    kl_div = compute_kl_divergence(policy_net, old_policy_net, states)
    
    # Compute the step direction using conjugate gradient
    # print(policy_loss)
    # print(policy_loss.shape)  # Should be a scalar (0-dimensional)
    policy_loss.retain_grad()
    policy_loss.backward(retain_graph=True)
    policy_optimizer.zero_grad()
    
    # print(policy_loss.grad)
    step_direction = conjugate_gradient(Avp, policy_loss.grad, nsteps=10)
    
    # Perform line search to find the optimal step size
    step_size = TRO(policy_net, old_policy_net, states, actions, advantages, step_direction, kl_div)
    
    # Update policy network
    for param, step in zip(policy_net.parameters(), step_direction):
        param.data += step_size * step
    
    # Update value network
    value_loss = compute_value_loss(value_net, states, returns)
    value_loss.backward()
    value_optimizer.step()
    value_optimizer.zero_grad()

    # Update old policy network
    old_policy_net.load_state_dict(policy_net.state_dict())

    #render the environment
    env.render()


Initial state: (array([-0.03367027,  0.03490659,  0.02899447,  0.03661744], dtype=float32), {})
Step result: (array([-0.03297214, -0.16061889,  0.02972682,  0.33830556], dtype=float32), 1.0, False, False, {})
Step result is a tuple with length: 5
Next state: [-0.03297214 -0.16061889  0.02972682  0.33830556]
Step result: (array([-0.03618452,  0.03406772,  0.03649293,  0.05514307], dtype=float32), 1.0, False, False, {})
Step result is a tuple with length: 5
Next state: [-0.03618452  0.03406772  0.03649293  0.05514307]
Step result: (array([-0.03550316, -0.16155796,  0.03759579,  0.35911277], dtype=float32), 1.0, False, False, {})
Step result is a tuple with length: 5
Next state: [-0.03550316 -0.16155796  0.03759579  0.35911277]
Step result: (array([-0.03873432,  0.03300994,  0.04477805,  0.07851771], dtype=float32), 1.0, False, False, {})
Step result is a tuple with length: 5
Next state: [-0.03873432  0.03300994  0.04477805  0.07851771]
Step result: (array([-0.03807412,  0.22746232,  0.04

ValueError: Size mismatch: vector_slice size 1 does not match param_grad size 512