In [2]:
import gymnasium as gym
import numpy as np
from stable_baselines3 import DDPG
from stable_baselines3.common.env_util import DummyVecEnv
import torch
import csv

# Pick the compute device: CUDA when torch reports a usable GPU, otherwise the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


  from .autonotebook import tqdm as notebook_tqdm


# 1. Training Pendulum model with DDPG

In [3]:
# Build a single-environment vectorized wrapper around Pendulum-v1.
# The factory creates the underlying environment itself, so it does not
# depend on the name `env`, which is rebound to the wrapper on this line.
env = DummyVecEnv([lambda: gym.make('Pendulum-v1')])

# DDPG agent with an MLP policy, placed on the device selected above
model = DDPG('MlpPolicy', env, device=device)

In [19]:

# Train the agent for 10,000 environment timesteps (timesteps, not episodes)
model.learn(total_timesteps=10000)

# Save under the checkpoint name that the data-collection cell loads
# ("ddpg_pendulum_10k"), so a fresh Restart & Run All finds the file.
model.save("ddpg_pendulum_10k")

# 2. Collect trajectory data
Each recorded transition is a `[state, action, next_state]` triple.

In [4]:

def save_data_to_csv(data, filename):
    """Write transition records to a CSV file.

    The file at ``filename`` is created or overwritten. A fixed header row
    (``State``, ``Action``, ``NextState``) is written first, followed by one
    CSV row per item of ``data``.

    Args:
        data: Iterable of rows; each row is itself an iterable of fields
            accepted by ``csv.writer``.
        filename: Path of the CSV file to write.
    """
    header = ['State', 'Action', 'NextState']
    with open(filename, 'w', newline='') as out_file:
        writer = csv.writer(out_file)
        writer.writerow(header)
        # One output row per record, in iteration order
        writer.writerows(data)


In [22]:
# Load the saved DDPG policy from the "ddpg_pendulum_10k" checkpoint.
# NOTE(review): this checkpoint must already exist on disk — confirm the name
# matches the one used when the model was saved.
model = DDPG.load("ddpg_pendulum_10k")
# Accumulates (state, action, next_state) transition tuples from the
# collection loops that follow
state_action_data = []

# Random-action exploration: collect 5000 transitions as short rollouts.
# The environment is reset once up front and then again after every 10 steps,
# so each rollout chains 10 consecutive transitions. (Resetting at the top of
# every iteration would discard `obs = next_obs` and reduce every sample to a
# single step from a fresh initial state.)
n_random_steps = 5000
rollout_len = 10

obs = env.reset()
for step in range(1, n_random_steps + 1):
    # Uniformly random action from the action space (no policy involved)
    action = env.action_space.sample()
    # A VecEnv step returns (observations, rewards, dones, infos), batched over envs
    next_obs, rewards, dones, infos = env.step([action])
    if dones[0]:
        # On episode end the VecEnv has already auto-reset, so `next_obs` is the
        # first observation of a new episode; the true successor state is kept
        # in the info dict.
        next_state = infos[0]["terminal_observation"]
    else:
        next_state = next_obs[0]
    state_action_data.append((obs[0], action, next_state))
    obs = next_obs

    # Start a new rollout from a fresh initial state every `rollout_len` steps
    if step % rollout_len == 0:
        obs = env.reset()

# Policy rollouts: record 50 trajectories driven by the trained model.
# A trajectory ends when either
#   * the stabilisation heuristic below has fired more than 10 times, or
#   * the environment itself reports the episode as done.
# Ending on `done` keeps every trajectory inside a single episode and
# guarantees the inner loop terminates even if the heuristic never fires.
for _ in range(50):
    obs = env.reset()
    done = False
    count = 10  # remaining "stabilised" steps to record before stopping
    while not done:
        action, _ = model.predict(obs)
        # A VecEnv step returns (observations, rewards, dones, infos), batched over envs
        next_obs, rewards, dones, infos = env.step(action)
        if dones[0]:
            # The VecEnv has already auto-reset, so `next_obs` belongs to a new
            # episode; the true successor state is stored in the info dict.
            next_state = infos[0]["terminal_observation"]
            done = True
        else:
            next_state = next_obs[0]
        state_action_data.append((obs[0], action[0], next_state))
        obs = next_obs
        # Heuristic: first observation component within 0.001 of +/-1 and a
        # small action magnitude. Per the Pendulum-v1 docs that component is
        # cos(theta), so |cos(theta)| ~ 1 matches both the upright and the
        # hanging-down pose.
        if not done and 1 - abs(obs[0][0]) < 0.001 and abs(action[0]) < 0.6:
            count -= 1
            if count < 0:
                done = True


# Write every collected (state, action, next_state) row to 's_a_random.csv'
# using the save_data_to_csv helper
save_data_to_csv(state_action_data, 's_a_random.csv')

# Close the environment now that data collection is finished
env.close()
