In [1]:
import torch
import numpy as np
import gymnasium as gym
import os
import pickle
from stable_baselines3 import PPO

In [2]:
env = gym.make("CartPole-v1", render_mode="human")

In [3]:
print("Action Space:", env.action_space.shape)
print("Observation Space:", env.observation_space.shape[0])

Action Space: ()
Observation Space: 4


In [3]:
import torch.nn as nn
from torch.nn import functional as F

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [5]:
STATE_SIZE = env.observation_space.shape[0]  # Number of features in the state
ACTION_SIZE = 2

#
# Define model
#
class PolicyNet(nn.Module):
    def __init__(self, hidden_dim=64):
        super().__init__()
        self.hidden = nn.Linear(STATE_SIZE, hidden_dim)
        self.classify = nn.Linear(hidden_dim, ACTION_SIZE)

    def forward(self, s):
        outs = self.hidden(s)
        outs = F.relu(outs)
        logits = self.classify(outs)
        return logits

#
# Generate model
#
policy_func = PolicyNet().to(device)

In [6]:
def pick_sample_and_logits(policy, s):
    """
    Stochastically pick up action and logits with policy model.

    Parameters
    ----------
    policy : torch.nn.Module
        Policy network to use
    s : torch.tensor((..., STATE_SIZE), dtype=int)
        The feature (one-hot) of state.
        The above "..." can have arbitrary shape with 0 or 1 dimension.

    Returns
    ----------
    action : torch.tensor((...), dtype=int)
        The picked-up actions.
        If input shape is (*, STATE_SIZE), this shape becomes (*).
    logits : torch.tensor((..., ACTION_SIZE), dtype=float)
        Logits of categorical distribution (used to optimize model).
        If input shape is (*, STATE_SIZE), this shape becomes (*, ACTION_SIZE).
    """
    # get logits from state
    # --> size : (*, ACTION_SIZE)
    logits = policy(s.float())
    # from logits to probabilities
    # --> size : (*, ACTION_SIZE)
    probs = F.softmax(logits, dim=-1)
    # pick up action's sample
    # --> size : (*, 1)
    a = torch.multinomial(probs, num_samples=1)
    # --> size : (*)
    a = a.squeeze()

    # Return
    return a, logits

In [7]:
def evaluate(policy, batch_size):
    total_reward = torch.tensor(0.0).to(device)
    s = env.reset()
    if isinstance(s, tuple):  # Gymnasium returns (obs, info)
        s = s[0]
    while True:
        s_tensor = torch.tensor(s, dtype=torch.float32).to(device)
        with torch.no_grad():
            a, _ = pick_sample_and_logits(policy, s_tensor)
        s, r, term, trunc, _ = env.step(a.item())
        total_reward += r
        done = term or trunc
        if done:
            break
    return total_reward.item() / batch_size

avg_reward = evaluate(policy_func, 300)
print(f"Estimated rewards (before training): {avg_reward}")

Estimated rewards (before training): 0.03666666666666667


# Save Dataset

In [None]:

os.makedirs("./expert_data", exist_ok=True)

In [14]:
env = gym.make("CartPole-v1")
expert_model = PPO.load("models/ppo_cartpole_expert")  # adjust path if needed

num_episodes = 1000
all_states = []
all_actions = []
timestep_lens = []

for ep in range(num_episodes):
    obs, _ = env.reset()
    done = False
    states = []
    actions = []
    while not done:
        states.append(obs)
        action, _ = expert_model.predict(obs, deterministic=True)
        actions.append(action)
        obs, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
    all_states.extend(states)
    all_actions.extend(actions)
    timestep_lens.append(len(states))

# Save expert data
expert_data = {
    "states": all_states,
    "actions": all_actions,
    "timestep_lens": timestep_lens
}

with open("./expert_data/ckpt0.pkl", "wb") as f:
    pickle.dump(expert_data, f)

print("Expert dataset saved to ./expert_data/ckpt0.pkl")

Expert dataset saved to ./expert_data/ckpt0.pkl


In [8]:
# use the following expert dataset
dest_dir = "./expert_data"
checkpoint_files = ["ckpt0.pkl"]

# create optimizer
opt = torch.optim.AdamW(policy_func.parameters(), lr=0.001)

for ckpt in checkpoint_files:
    # load expert data from pickle
    with open(f"{dest_dir}/{ckpt}", "rb") as f:
        all_data = pickle.load(f)
    all_states = all_data["states"]
    all_actions = all_data["actions"]
    timestep_lens = all_data["timestep_lens"]
    # loop all episodes in demonstration
    current_timestep = 0
    for i, timestep_len in enumerate(timestep_lens):
        # pick up states and actions in a single episode
        states = all_states[current_timestep:current_timestep+timestep_len]
        actions = all_actions[current_timestep:current_timestep+timestep_len]
        # collect loss and optimize (train)
        opt.zero_grad()
        loss = []
        # for s, a in zip(states, actions):
        #     s_tensor = torch.tensor(s, dtype=torch.float32).to(device)  # Use float for CartPole
        #     _, logits = pick_sample_and_logits(policy_func, s_tensor)
        #     logits = logits.unsqueeze(0)
        #     print(f"Logits shape: {logits.shape}, Action: {a}")
        #     print(f"Logits: {logits}, Action: {a}")
        #     l = F.cross_entropy(logits, torch.tensor([a]).to(device), reduction="none")
        #     loss.append(l)
        for s, a in zip(states, actions):
            s_tensor = torch.tensor(s, dtype=torch.float32).to(device)
            _, logits = pick_sample_and_logits(policy_func, s_tensor)
            logits = logits.unsqueeze(0)  # shape: [1, ACTION_SIZE]
            
            a = int(a)  # ✅ convert to native Python int
            target = torch.tensor([a], dtype=torch.long, device=device)  # shape: [1]
            
            l = F.cross_entropy(logits, target, reduction="none")
            loss.append(l)
        total_loss = torch.stack(loss, dim=0)
        total_loss.sum().backward()
        opt.step()
        # log
        print("Processed {:5d} episodes in checkpoint {}...".format(i + 1, ckpt), end="\r")
        # run evaluation in each 1000 episodes
        if i % 100 == 99:
            avg = evaluate(policy_func, 200)
            print(f"\nEvaluation result (Average reward): {avg}")
        # proceed to next episode
        current_timestep += timestep_len

Processed   100 episodes in checkpoint ckpt0.pkl...
Evaluation result (Average reward): 0.09
Processed   200 episodes in checkpoint ckpt0.pkl...
Evaluation result (Average reward): 0.405
Processed   300 episodes in checkpoint ckpt0.pkl...
Evaluation result (Average reward): 0.31
Processed   400 episodes in checkpoint ckpt0.pkl...
Evaluation result (Average reward): 0.39
Processed   500 episodes in checkpoint ckpt0.pkl...
Evaluation result (Average reward): 0.42
Processed   600 episodes in checkpoint ckpt0.pkl...
Evaluation result (Average reward): 0.465
Processed   700 episodes in checkpoint ckpt0.pkl...
Evaluation result (Average reward): 0.44
Processed   800 episodes in checkpoint ckpt0.pkl...
Evaluation result (Average reward): 0.445
Processed   900 episodes in checkpoint ckpt0.pkl...
Evaluation result (Average reward): 1.555
Processed  1000 episodes in checkpoint ckpt0.pkl...
Evaluation result (Average reward): 1.005


In [9]:
policy_func = policy_func.to(device)

In [10]:
env = gym.make("CartPole-v1", render_mode="human")
obs, _ = env.reset()
done = False
total_reward = 0

while not done:
    s_tensor = torch.tensor(obs, dtype=torch.float32).to(device)
    with torch.no_grad():
        action, _ = pick_sample_and_logits(policy_func, s_tensor)
    obs, reward, terminated, truncated, _ = env.step(action.item())
    total_reward += reward
    done = terminated or truncated

print(f"Total reward: {total_reward}")
env.close()

Total reward: 175.0
