In [2]:
import gymnasium as gym
import numpy as np
from stable_baselines3 import DDPG
from stable_baselines3.common.env_util import DummyVecEnv
import torch
import csv

# Use CUDA when PyTorch can see a GPU; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


  from .autonotebook import tqdm as notebook_tqdm


# 1. Train a DDPG agent on Pendulum-v1

In [3]:
# Create Pendulum-v1 inside a single-environment DummyVecEnv.
# The rollout cell further down relies on the VecEnv API:
#   - reset() returns only the batched observation;
#   - step() returns (obs, rewards, dones, infos).
# For training alone, SB3 would wrap a plain gym env in a DummyVecEnv itself.
# The factory builds its own environment instead of closing over the name `env`,
# which this very statement rebinds to the VecEnv (the old closure only worked
# because DummyVecEnv happens to call its factories eagerly).
ENV_ID = 'Pendulum-v1'
env = DummyVecEnv([lambda: gym.make(ENV_ID)])

# Initialize the DDPG agent on the selected device (GPU when available).
# TODO: pass seed=... here for reproducible training runs.
model = DDPG('MlpPolicy', env, device=device)

In [None]:

# Train for a fixed number of environment *timesteps* (not episodes).
# Pendulum-v1 episodes are truncated at 200 steps, so 10,000 steps is about
# 50 episodes.
TOTAL_TIMESTEPS = 10_000
model.learn(total_timesteps=TOTAL_TIMESTEPS)

model.save("ddpg_pendulum")

# 2. Collect trajectory data

Roll out the trained policy and record each transition as `[state, action, next_state]`.

In [4]:

def save_data_to_csv(data, filename):
    """Write transition rows to ``filename`` as CSV, overwriting any existing file.

    The first row is the header ``State,Action,NextState``. Each item of
    ``data`` is then written as one row exactly as given. ``csv.writer``
    stringifies non-string fields, so NumPy arrays end up as their ``str()``
    text (e.g. ``[0.1 0.2 0.3]``). That text must be parsed back when the file
    is read.

    Args:
        data: iterable of row sequences, typically ``(state, action, next_state)``.
        filename: path of the CSV file to create.
    """
    header = ['State', 'Action', 'NextState']
    with open(filename, 'w', newline='') as out:
        writer = csv.writer(out)
        writer.writerow(header)
        writer.writerows(data)


In [7]:
# Roll out the trained policy and record (state, action, next_state) transitions.
# Uses `env` (single-env DummyVecEnv over Pendulum-v1) and `save_data_to_csv`
# from earlier cells.
N_EPISODES = 50
UPRIGHT_TOL = 1e-5             # how close cos(theta) must be to 1 (theta == 0 is upright)
SMALL_ACTION = 0.6             # |torque| below this counts as "holding still"
STILL_STEPS_BEFORE_STOP = 10   # episode ends once this counter drops below zero
OUTPUT_CSV = 'state_action_data2.csv'

model = DDPG.load("ddpg_pendulum")
# List to store state-action-next_state transitions
state_action_data = []

for episode in range(N_EPISODES):
    obs = env.reset()  # VecEnv API: returns only the batched observation
    done = False
    count = STILL_STEPS_BEFORE_STOP
    while not done:
        action, _ = model.predict(obs)
        # To watch rollouts, build the env with render_mode='human'; the env is
        # created without a render_mode, so calling env.render() draws nothing.
        next_obs, _, dones, infos = env.step(action)
        done = bool(dones[0])
        # DummyVecEnv auto-resets a finished sub-env, so on the last step of an
        # episode `next_obs` is already the first observation of the NEXT episode.
        # The real successor state is kept in infos[0]["terminal_observation"].
        true_next_obs = infos[0]["terminal_observation"] if done else next_obs[0]
        state_action_data.append((obs[0], action[0], true_next_obs))
        obs = next_obs
        # Early stop: count steps where the pendulum is (nearly) upright and the
        # policy applies only a small torque.
        # The observation is [cos(theta), sin(theta), theta_dot]. The previous
        # abs(cos) test also fired when the pendulum hung straight down
        # (cos == -1).
        if 1 - true_next_obs[0] < UPRIGHT_TOL and abs(action[0][0]) < SMALL_ACTION:
            count -= 1
            if count < 0:
                done = True


# Save the state-action-next_state data to a CSV file
save_data_to_csv(state_action_data, OUTPUT_CSV)

# Close the environment. Re-running this cell requires re-running the
# env-creation cell first.
env.close()




[[0.8885336]] <class 'numpy.ndarray'>
[[-0.21499872]] <class 'numpy.ndarray'>
[[-1.6799346]] <class 'numpy.ndarray'>
[[-1.997861]] <class 'numpy.ndarray'>
[[-1.996724]] <class 'numpy.ndarray'>
[[-1.9874488]] <class 'numpy.ndarray'>
[[-1.9654365]] <class 'numpy.ndarray'>
[[-1.9450122]] <class 'numpy.ndarray'>
[[-1.9563553]] <class 'numpy.ndarray'>
[[-1.951995]] <class 'numpy.ndarray'>
[[-1.96415]] <class 'numpy.ndarray'>
[[-1.9666692]] <class 'numpy.ndarray'>
[[-1.9797423]] <class 'numpy.ndarray'>
[[-1.985506]] <class 'numpy.ndarray'>
[[-1.995594]] <class 'numpy.ndarray'>
[[-1.9978536]] <class 'numpy.ndarray'>
[[-1.9979451]] <class 'numpy.ndarray'>
[[-1.9975319]] <class 'numpy.ndarray'>
[[-1.9968513]] <class 'numpy.ndarray'>
[[-1.9958446]] <class 'numpy.ndarray'>
[[-1.9941527]] <class 'numpy.ndarray'>
[[-1.9920483]] <class 'numpy.ndarray'>
[[-1.9901091]] <class 'numpy.ndarray'>
[[-1.9930316]] <class 'numpy.ndarray'>
[[-1.9941797]] <class 'numpy.ndarray'>
[[-1.9937817]] <class 'numpy.nda