In [1]:
import sys

sys.path.append("/Users/matthewnicastro/Desktop/reinforcement-learning/")

import gymnasium as gym
import torch
from models.MLP import MLP
from utils.algorithms.ppo import proximal_policy_optimization
from wrappers.training import ModelTrainingWrapper
import matplotlib.pyplot as plt

# Seed torch's global RNG so weight initialisation and action sampling
# (torch.multinomial) are repeatable. The gym environment has its own RNG and is
# not seeded by this; pass `seed=` to `env.reset` if that is needed too.
SEED = 42

# A bare `torch.device("mps")` expression only builds a device descriptor and
# throws it away; it does not move any model or tensor. Keep the chosen device in
# a named constant instead, falling back to CPU where MPS is unavailable.
# NOTE(review): nothing in this notebook calls `.to(DEVICE)` — models and tensors
# stay on their default device unless they are moved explicitly.
DEVICE = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

torch.manual_seed(SEED)


<torch._C.Generator at 0x12e260170>

In [None]:
env = gym.make("CartPole-v1")

# Derive the network's input/output widths from the environment's spaces rather
# than hard-coding them (for CartPole-v1 these are 4 observation features and
# 2 discrete actions, the values previously written inline).
obs_dim = int(env.observation_space.shape[0])
n_actions = int(env.action_space.n)

# Policy network: observation -> 64 ReLU units -> log-probabilities over actions
# (LogSoftmax over the last dim, which policy_output_parser relies on).
# NOTE(review): assumes MLP's architecture entries are (width, activation, kwargs)
# with the first entry describing the input layer — confirm against models/MLP.py.
policy_model = MLP(
    architecture=[
        (obs_dim, "", {}),
        (64, "ReLU", {}),
        (n_actions, "LogSoftmax", {"dim": -1}),
    ]
)


def policy_output_parser(outputs, index=None):
    """Sample actions (or look up given ones) and return their log-probabilities.

    Args:
        outputs: log-probabilities produced by the policy network (its last layer
            is ``LogSoftmax(dim=-1)``), with the action axis last.
        index: optional previously taken actions. When given, nothing is sampled;
            the log-probabilities of these actions are gathered instead. Any shape
            holding one action per row is accepted (e.g. ``(N,)`` or ``(N, 1)``),
            and values are cast to int64 as ``torch.gather`` requires.

    Returns:
        ``(actions, log_probs)``: ``actions`` is the index tensor used for the
        gather (shape ``(N, 1)`` for batched input, ``(1,)`` when sampling from a
        single 1-D output), and ``log_probs`` holds the gathered log-probabilities
        with singleton dimensions squeezed out.
    """
    if index is None:
        # exp() turns the log-probabilities back into sampling weights.
        actions = torch.multinomial(torch.exp(outputs), num_samples=1)
    else:
        # reshape(-1, 1) rather than view(index.shape[-1], 1): also handles (N, 1)
        # and non-contiguous inputs; gather needs int64 indices.
        actions = index.reshape(-1, 1).long()
    return actions, torch.gather(outputs, -1, index=actions).squeeze()


# Bundle the policy network with its optimiser settings and the parser that
# samples actions / gathers their log-probabilities.
policy_wrapper = ModelTrainingWrapper(
    network=policy_model,
    optimizer_name="Adam",
    optimizer_params={"lr": 3e-3},
    output_parser=policy_output_parser,
)

# Value (critic) network: observation -> 64 ELU units -> one scalar estimate.
# The input width is read from the environment instead of a hard-coded 4
# (CartPole-v1's observation size), matching how the policy network is sized.
value_model = MLP(
    architecture=[
        (int(env.observation_space.shape[0]), "", {}),
        (64, "ELU", {}),
        (1, "", {}),
    ]
)
value_wrapper = ModelTrainingWrapper(
    network=value_model,
    optimizer_name="Adam",
    optimizer_params={"lr": 1e-3},
    # The critic's raw network output is used unchanged.
    output_parser=lambda output: output,
)
# Regression loss for the critic; passed to PPO as `value_loss`.
value_loss = torch.nn.MSELoss()

def parse_state(state):
    """Convert a raw gym observation into a float32 tensor for the networks."""
    return torch.tensor(state, dtype=torch.float32)


def shaped_reward(rewards, reward):
    """Reward-shaping hook handed to PPO.

    Returns 1 when ``reward`` is positive; otherwise returns a penalty equal to
    -2 times the sum of ``rewards``.

    NOTE(review): CartPole-v1 normally emits a positive reward on every step, so
    the penalty branch may rarely fire — confirm what proximal_policy_optimization
    passes as ``rewards`` and ``reward``.
    """
    if reward > 0:
        return 1
    return -2 * sum(rewards)


# NOTE(review): judging by the arguments, this may run up to
# epochs * num_trajectories * num_steps environment steps — expect a long cell.
logger = proximal_policy_optimization(
    env=env,
    state_parser=parse_state,
    policy_wrapper=policy_wrapper,
    value_wrapper=value_wrapper,
    value_loss=value_loss,
    reward_func=shaped_reward,
    epochs=1000,
    num_trajectories=10,
    num_steps=500,
    updates_per_epoch=10,
    discount_factor=0.99,
    gae_lambda=1.0,
    clipping_parameter=0.2,
)

# Training curves: one panel per series recorded in `logger`, as (key, title).
panels = [
    ("policy_losses", "Policy Loss"),
    ("value_losses", "Value Losses"),
    ("num_steps", "Num steps over time"),
]

fig, axs = plt.subplots(1, len(panels), figsize=(12, 4))
for ax, (key, title) in zip(axs, panels):
    ax.plot(logger[key])
    ax.set_title(title)

plt.show()

# Watch one episode of the trained policy in a rendered window.
# `env` is rebound below, so close the training environment first; otherwise the
# original instance would never be closed.
env.close()
env = gym.make("CartPole-v1", render_mode="human")
state = env.reset()[0]
states = [state]
probs = []
done = False
# Inference only: no_grad stops each forward pass from recording an autograd
# graph, which the tensors kept in `probs` would otherwise hold alive.
with torch.no_grad():
    while not done:
        state_t = torch.tensor(state, dtype=torch.float32)
        action_prob = policy_wrapper.network(state_t)
        probs.append(action_prob)
        action, prob = policy_wrapper.output_parser(action_prob)
        action = action.item()
        # With render_mode="human", gymnasium draws a frame inside step(), so an
        # extra env.render() call is not needed.
        new_state, reward, terminated, truncated, _ = env.step(action)
        # Stop on truncation as well (CartPole-v1's time limit). Checking only
        # `terminated` lets a well-trained policy keep running past the limit.
        done = terminated or truncated
        states.append(new_state)
        state = new_state
env.close()


In [2]:
# import sys

# sys.path.append("/Users/matthewnicastro/Desktop/reinforcement-learning/")

# import gymnasium as gym
# import torch
# from utils.io.model import load_model
# from models.MLP import MLP

# policy_wrapper = load_model(
#     MLP, "../weights/cart-pole-v1.pt", "../config/cart-pole-v1.pkl", eval_mode=True
# )
