In [1]:
import argparse
import gymnasium as gym
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
class Agent(nn.Module):
    """Agent holding two independent MLPs: a policy network and a value network.

    Both networks map an observation of size ``input_dim`` through hidden
    layers of width ``hidden_dim * 2`` and ``hidden_dim`` (ReLU activations).
    The policy head emits ``out_dim`` logits that parameterize a categorical
    action distribution; the value head emits a single scalar.

    Args:
        input_dim: Size of a single observation vector.
        hidden_dim: Base hidden width (first hidden layer is twice this).
        out_dim: Number of discrete actions (size of the logits vector).
    """

    def __init__(self, input_dim, hidden_dim, out_dim):
        super().__init__()
        # NOTE: attribute names and layer order define the state_dict keys,
        # so they must stay as-is for existing checkpoints to load.
        self.policy_net = nn.Sequential(
                         nn.Linear(input_dim, hidden_dim*2),
                         nn.ReLU(),
                         nn.Linear(hidden_dim*2, hidden_dim),
                         nn.ReLU(),
                         nn.Linear(hidden_dim, out_dim))

        self.value_net = nn.Sequential(
                         nn.Linear(input_dim, hidden_dim*2),
                         nn.ReLU(),
                         nn.Linear(hidden_dim*2, hidden_dim),
                         nn.ReLU(),
                         nn.Linear(hidden_dim, 1))

    def _as_input(self, obs):
        """Convert an observation to a tensor matching the networks' dtype/device.

        ``torch.as_tensor`` avoids a needless copy (and the warning that
        ``torch.tensor`` emits) when ``obs`` is already a tensor, and the
        explicit dtype makes float64 observations (NumPy's default) usable
        with the float32 linear layers.
        """
        reference = next(self.policy_net.parameters())
        return torch.as_tensor(obs, dtype=reference.dtype, device=reference.device)

    def act(self, obs):
        """Sample an action for a single observation.

        Args:
            obs: One observation (array-like or tensor) of size ``input_dim``.
                Exactly one action must result, because the sampled action is
                converted with ``.item()``.

        Returns:
            tuple: ``(action, log_prob)`` where ``action`` is a Python int and
            ``log_prob`` is the tensor log-probability of that action under
            the current policy.
        """
        obs = self._as_input(obs)
        pd_params = self.policy_net(obs)
        prob_dist = torch.distributions.Categorical(logits=pd_params)
        action = prob_dist.sample()
        # log pi(a_t | s_t): log-probability of the sampled action given obs
        log_prob = prob_dist.log_prob(action)
        return action.item(), log_prob

    def compute_state_value(self, obs):
        """Return the value network's estimate for ``obs``.

        Args:
            obs: Observation (array-like or tensor) of size ``input_dim``.

        Returns:
            torch.Tensor: Output of the value head; its last dimension has
            size 1 (shape ``(1,)`` for a single 1-D observation).
        """
        obs = self._as_input(obs)
        state_value = self.value_net(obs)
        return state_value

In [3]:
def sample_trajectories(env, agent, nb_episodes, nb_timesteps, gamma):
    """Roll out episodes and record per-episode return sums, with and without a baseline.

    For every episode the discounted reward-to-go ``R[t]`` is computed for
    each step, then two scalars are stored:

    * ``sum_t R[t]``                 (no baseline)
    * ``sum_t (R[t] - V(s_t))``      (state-value baseline)

    Args:
        env: Environment exposing ``reset() -> (obs, info)`` and
            ``step(action) -> (obs, reward, terminated, truncated, info)``.
        agent: Object exposing ``act(obs) -> (action, log_prob)`` and
            ``compute_state_value(obs) -> tensor``.
        nb_episodes: Number of episodes to sample.
        nb_timesteps: Maximum number of steps per episode.
        gamma: Discount factor.

    Returns:
        tuple[list[float], list[float]]: ``(returns_no_baseline,
        returns_baseline)``, one entry per episode. An episode with zero
        steps contributes ``0.0`` to both lists.
    """
    returns_no_baseline, returns_baseline = [], []
    for _ in range(nb_episodes):
        obs, _ = env.reset()
        rewards, state_values = [], []
        # Pure evaluation: nothing is back-propagated, so skip building
        # autograd graphs for every forward pass.
        with torch.no_grad():
            for _ in range(nb_timesteps):
                action, _log_prob = agent.act(obs)
                state_value = agent.compute_state_value(obs)
                obs, reward, terminated, truncated, _ = env.step(action)
                rewards.append(reward)
                # Flatten so that 0-dim and (1,)-shaped outputs both concatenate
                # into a 1-D vector aligned with the returns.
                state_values.append(state_value.reshape(-1))
                if terminated or truncated:
                    break

        # Reward-to-go: accumulate backwards so each step is O(1):
        # R[t] = r[t] + gamma * R[t+1]
        returns = []
        future_return = 0.0
        for t in reversed(range(len(rewards))):
            future_return = rewards[t] + gamma * future_return
            returns.append(future_return)
        returns.reverse()  # returns[t] now corresponds to timestep t
        returns = torch.tensor(returns)
        # torch.cat rejects an empty list, so handle the zero-step episode explicitly.
        if state_values:
            state_values = torch.cat(state_values)
        else:
            state_values = torch.zeros_like(returns)

        no_baseline = torch.sum(returns).item()
        baseline = torch.sum(returns - state_values).item()
        returns_no_baseline.append(no_baseline)
        returns_baseline.append(baseline)

    return returns_no_baseline, returns_baseline

In [4]:
# Experiment configuration. Explicit `type=` converters ensure that any value
# supplied as a string is turned into the numeric type the code below expects
# (range() needs ints, discounting needs a float) instead of staying a str.
parser = argparse.ArgumentParser()
parser.add_argument("--env", type=str, default="CartPole-v1")
parser.add_argument("--nb_episodes", type=int, default=300)
parser.add_argument("--nb_timesteps", type=int, default=200)
parser.add_argument("--gamma", type=float, default=0.99)
# Parse an explicit empty list so the defaults above are used and the
# process's own command line is ignored.
args = parser.parse_args(args=[])

env = gym.make(args.env)

# Network sizes are derived from the environment's observation/action spaces.
agent = Agent(input_dim=env.observation_space.shape[0],
                hidden_dim=32, out_dim=env.action_space.n)

In [5]:
# Restore previously trained weights into the freshly constructed agent.
ckpt_name = "checkpoints/agent_ckpt-ep_150.pt"
# map_location="cpu" lets a checkpoint saved from CUDA tensors load on a
# machine without a GPU; the agent itself is never moved off the CPU here.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source (consider weights_only=True where the installed PyTorch supports it).
agent.load_state_dict(torch.load(ckpt_name, map_location="cpu"))

<All keys matched successfully>

In [6]:
# Sample episodes with the restored agent. The episode count is fixed at one
# million here rather than taken from args.nb_episodes; the step cap and the
# discount factor come from the parsed configuration.
returns_no_baseline, returns_baseline = sample_trajectories(
    env,
    agent,
    nb_episodes=1_000_000,
    nb_timesteps=args.nb_timesteps,
    gamma=args.gamma,
)

In [7]:
# Turn the per-episode Python lists into tensors so torch can reduce them.
returns_no_baseline = torch.tensor(returns_no_baseline)
returns_baseline = torch.tensor(returns_baseline)

# Variance across episodes of each quantity (torch's default variance settings).
var_no_baseline = returns_no_baseline.var()
var_baseline = returns_baseline.var()

In [8]:
# Report both variances, one per line, rounded to three decimals.
print(
    f"Variance for no baseline: {var_no_baseline:.3f}",
    f"Variance for baseline: {var_baseline:.3f}",
    sep="\n",
)

Variance for no baseline: 714.624
Variance for baseline: 341.414


---