Skip to content

Commit

Permalink
feat: removing physics pausing, tracking number of times quadcopter f…
Browse files Browse the repository at this point in the history
…ailed to take off so that we can kill and restart the whole training sequence if so
  • Loading branch information
simojo committed Mar 20, 2024
1 parent e232140 commit 44e3e26
Showing 1 changed file with 20 additions and 11 deletions.
31 changes: 20 additions & 11 deletions src/clover_train/scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -633,7 +633,7 @@ def episode_init_and_grab_state(gazebo_world_filepath: str) -> State:
return local_state


def wait_for_all_state_topics_to_publish_once():
def wait_for_all_state_topics_to_publish_once() -> None:
"""Wait until all topics used to create the state have published once to ensure up-to-date data."""
# this is the most suckless piece of software simojo has ever created
def wfm(t: str, c: Any):
Expand Down Expand Up @@ -785,9 +785,16 @@ def calibrate_accelerometers() -> None:
# To store reward history of each episode
reward_history = []

for ep in range(num_episodes_per_world):
# keep track of how many episodes we've done in this world
episode_count_for_this_world = 0
while episode_count_for_this_world < num_episodes_per_world:

# increment episode count for this world
episode_count_for_this_world += 1

# purge ROS log files
util.rosclean_purge()

prev_state: State = episode_init_and_grab_state(gazebo_world_filepath)
print(f"[TRACE] begin episode {episode_count}")
episodic_reward = 0
Expand All @@ -796,8 +803,8 @@ def calibrate_accelerometers() -> None:
num_actions_taken = 0
num_actions_per_ep = 4

service_proxies.unpause_physics()
startup_takeoff_failed = True
n_failures = 0
# take off to get drone off of ground for first step.
# attempt to take off until we do it successfully.
while not rospy.is_shutdown() and startup_takeoff_failed:
Expand All @@ -811,7 +818,16 @@ def calibrate_accelerometers() -> None:
auto_arm=True,
timeout=10
)
service_proxies.pause_physics()
if startup_takeoff_failed:
n_failures += 1
if n_failures >= 3:
break
# if we've encountered three failed attempts to take off in a row,
# reset the environment to see if that will help.
if n_failures >= 3:
# undo this iteration's increment so the failed takeoff does not count as an episode for this world
episode_count_for_this_world -= 1
continue

while not rospy.is_shutdown() and num_actions_taken < num_actions_per_ep:
tf_prev_state = tf.expand_dims(
Expand All @@ -821,18 +837,11 @@ def calibrate_accelerometers() -> None:
action = policy(tf_prev_state, ou_noise)
print("[TRACE] policy calculated")

# calculations completed;
# we can unpause physics engine
service_proxies.unpause_physics()

# Receive state and reward from environment.
local_state, reward, done = episode_take_action(action)
print("[TRACE] action taken")
num_actions_taken += 1

# pause physics engine while learning is taking place
service_proxies.pause_physics()

buffer.record((prev_state, action, reward, local_state))
episodic_reward += reward
print("[TRACE] recorded")
Expand Down

0 comments on commit 44e3e26

Please sign in to comment.