In [2]:
from ddpg import ReplayBuffer, DDPGAgent
from fsae.envs import *
import time

In [29]:
# Create the RandomTrackEnv instance that the training loop resets and steps.
# The constructor seed is fixed at 0; the training loop additionally passes a
# per-episode seed to env.reset().
# NOTE(review): render_mode="tp_camera" is presumably a third-person camera
# view — confirm against fsae.envs.
env = RandomTrackEnv(render_mode="tp_camera", seed=0)

In [33]:
# Initialize the agent, replay buffer, and hyperparameters, then run the
# DDPG training loop against `env` (created in the environment cell).
state_dim = 12  # Dimension of the state space
action_dim = 2  # Dimension of the action space
hidden_dim = 256  # Forwarded to DDPGAgent as its hidden_dim argument
max_action = 0.6  # Maximum value of the action
num_episodes = 10000  # Number of episodes to run
max_steps = 1000  # Upper bound on environment steps per episode
batch_size = 5000  # Batch size passed to agent.train() inside the loop

# Seed the replay buffer with previously recorded transitions.
# NOTE(review): assumes successive load_from_csv() calls append to the buffer
# rather than replace its contents — confirm in ddpg.ReplayBuffer.
replay_buffer = ReplayBuffer(buffer_size=1000000, state_dim=state_dim, action_dim=action_dim)
replay_buffer.load_from_csv("replayBuffer_teleop_test.csv")
replay_buffer.load_from_csv("replayBuffer_train_test.csv")  # add multiple files together
agent = DDPGAgent(state_dim, action_dim, hidden_dim, replay_buffer, max_action)

# Single training call on the pre-loaded data before any new experience is collected.
# NOTE(review): this passes the whole buffer size as the batch size, not
# `batch_size` — confirm that one full-buffer update is what is intended here.
if replay_buffer.size > batch_size:
    print("Replay buffer size after load: ", replay_buffer.size, " vs Batch Size: ", batch_size)
    agent.train(replay_buffer.size)

# Training loop
for episode in range(num_episodes):
    # Each episode is seeded with its own index.
    state = env.reset(seed=episode)
    print(state)
    episode_reward = 0

    for step in range(max_steps):
        action = agent.get_action(state)
        next_state, reward, done, _ = env.step(action)
        replay_buffer.add(state, action, reward, next_state, done)
        state = next_state
        episode_reward += reward

        # Only train once the buffer holds more transitions than one batch.
        if replay_buffer.size > batch_size:
            agent.train(batch_size)

        if done:
            break

    # `reward` here is the reward of the final step taken in this episode.
    print(f"total: {episode_reward}, last step reward: {reward}")
    # Reported after every episode (nothing is loaded here), so label it as such.
    print("Replay buffer size after episode: ", replay_buffer.size, " vs Batch Size: ", batch_size)


Replay buffer size after load:  12880  vs Batch Size:  2000
[ 4.20173299 -1.27204461  1.          4.33277174  1.74832553  2.
  8.10899952 -0.85079498  1.          8.50119707  2.24202562  2.        ]
total: -1.1202814561196535, step: 0.07537343443269129
Replay buffer size after load:  12914  vs Batch Size:  2000
[ 4.01519101 -1.42767105  1.          4.22467488  1.588432    2.
  7.51890428 -1.10474868  1.          8.0536501   2.00839409  2.        ]


KeyboardInterrupt: 

In [32]:
# Save the replay buffer to the same CSV file that the setup cell loads with
# load_from_csv("replayBuffer_train_test.csv").
# NOTE(review): the buffer was also filled from "replayBuffer_teleop_test.csv";
# if save_as_csv() writes the entire buffer, those teleop transitions would be
# loaded twice on the next run (once from each file) — confirm and consider a
# separate output file.
replay_buffer.save_as_csv("replayBuffer_train_test.csv")

In [27]:
# Close the environment once training has finished or been interrupted.
# NOTE(review): presumably releases the simulator/renderer resources — confirm
# in fsae.envs; `env` must be re-created before running the training loop again.
env.close()