Skip to content

Commit

Permalink
fix: various fixes to ensure training stability
Browse files Browse the repository at this point in the history
* Most notably, wait to receive a message from each state topic before actually moving forward with the next step. This has shown IMMEDIATE improvement. I think we might have it.
  • Loading branch information
simojo committed Mar 20, 2024
1 parent ee1b5f3 commit e232140
Showing 1 changed file with 31 additions and 6 deletions.
37 changes: 31 additions & 6 deletions src/clover_train/scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from gazebo_msgs.msg import ContactState, ContactsState
from tensorflow.keras import layers
from threading import Lock, Thread
from typing import Tuple, List, Callable
from typing import Tuple, List, Callable, Any
from dataclasses import dataclass
from pymavlink import mavutil
from mavros_msgs import msg as mavros_msg
Expand Down Expand Up @@ -616,11 +616,13 @@ def episode_init_and_grab_state(gazebo_world_filepath: str) -> State:
desired_pose = two_free_poses.pop()

# start simulation
simulation_nodes.launch_clover_simulation(gazebo_world_filepath, gui=True, clover_pose=start_pose)
simulation_nodes.launch_clover_simulation(gazebo_world_filepath, gui=False, clover_pose=start_pose)

# await simulation to come online by reinitializing service proxies
service_proxies.init()

# ensure all topics essential to the state have published at least once
wait_for_all_state_topics_to_publish_once()
# append desired pose to the state
state_mutex.acquire()
state.x_desired = desired_pose.position.x
Expand All @@ -631,6 +633,28 @@ def episode_init_and_grab_state(gazebo_world_filepath: str) -> State:
return local_state


def wait_for_all_state_topics_to_publish_once():
    """Block until every topic that feeds the state has delivered one message.

    One ``rospy.wait_for_message`` call is issued per topic, each on its own
    thread so the waits overlap instead of running back to back. The function
    returns only after every one of those threads has finished.
    """
    # primary topics go in first, followed by the ten rangefinder topics
    message_type_by_topic = {
        "/mavros/local_position/velocity_local": TwistStamped,
        "/mavros/local_position/pose": PoseStamped,
    }
    message_type_by_topic.update(
        (f"/rangefinder_{index}/range", sensor_msg.Range) for index in range(10)
    )

    # one waiter per topic; the received message itself is discarded
    waiters = []
    for topic, message_type in message_type_by_topic.items():
        waiters.append(
            Thread(target=rospy.wait_for_message, args=(topic, message_type))
        )

    # launch every waiter before joining any, so all waits run concurrently
    for waiter in waiters:
        waiter.start()
    for waiter in waiters:
        waiter.join()


def episode_take_action(action: Action) -> Tuple[State, float, bool]:
"""Take the given action in the environment."""
global state
Expand Down Expand Up @@ -659,6 +683,7 @@ def episode_take_action(action: Action) -> Tuple[State, float, bool]:
auto_arm=True,
timeout=10,
)
wait_for_all_state_topics_to_publish_once()
state_mutex.acquire()
local_state = copy.deepcopy(state)
state_mutex.release()
Expand Down Expand Up @@ -775,14 +800,14 @@ def calibrate_accelerometers() -> None:
startup_takeoff_failed = True
# take off to get drone off of ground for first step.
# attempt to take off until we do it successfully.
while not rospy.is_shutdown() and not startup_takeoff_failed:
while not rospy.is_shutdown() and startup_takeoff_failed:
print("attempting to navigate to initial position of (0, 0, 0.5)")
# (effectively, startup_takeoff_failed is not timeout_passed)
startup_takeoff_failed = not navigate_wait(
# (effectively, startup_takeoff_failed is timeout_passed)
startup_takeoff_failed = navigate_wait(
x=0,
y=0,
z=0.5,
frame_id="body",
frame_id="map",
auto_arm=True,
timeout=10
)
Expand Down

0 comments on commit e232140

Please sign in to comment.