[agent][Fix] Fix SkyRLAgentPPOTrainer after switch to async #1237
Signed-off-by: SumanthRH <sumanthrh99@gmail.com>
Code Review
This pull request correctly transitions the SkyRLAgentPPOTrainer.train method to be asynchronous, aligning it with the async method in its base class. The changes replace blocking asyncio.run() calls with non-blocking await expressions, which is the correct approach for handling coroutines within an async function. The asyncio import is also correctly removed as it's no longer used with these changes. I have one point of feedback regarding blocking calls that remain in the train method.
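As a rough illustration of why `asyncio.run()` gives way to `await` once the surrounding method is a coroutine, the sketch below uses only the standard library. `wake_up` here is a hypothetical stand-in for an async client call, not the actual SkyRL implementation.

```python
import asyncio


async def wake_up(tags):
    # Hypothetical stand-in for an async client call such as wake_up(tags=[...]).
    await asyncio.sleep(0)
    return f"awake:{','.join(tags)}"


async def train_with_nested_run():
    # asyncio.run() refuses to start a new event loop inside a running one.
    coro = wake_up(["weights"])
    try:
        return asyncio.run(coro)
    except RuntimeError as exc:
        coro.close()  # avoid a "coroutine was never awaited" warning
        return f"error: {exc}"


async def train_with_await():
    # Inside an async function, the coroutine is simply awaited.
    return await wake_up(["weights"])


nested_result = asyncio.run(train_with_nested_run())
awaited_result = asyncio.run(train_with_await())
print(nested_result)
print(awaited_result)
```

The nested call fails because the outer `asyncio.run()` already owns the running loop, while the awaited version runs on that same loop.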
  if self.colocate_all:
      self.policy_model.offload_to_cpu(offload_optimizer=True, offload_model=False)
-     asyncio.run(self.inference_engine_client.wake_up(tags=["weights"]))
+     await self.inference_engine_client.wake_up(tags=["weights"])
While this change to use await is correct, the train method still contains blocking calls like ray.get() on lines 302 and 422. These calls will block the asyncio event loop, which can negate the benefits of making this method asynchronous. To make this method fully non-blocking, these should be replaced with asynchronous equivalents. For example, you can await Ray's ObjectRefs, possibly using asyncio.gather for lists of them. This would require re-importing asyncio.
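The effect described in this comment can be sketched without Ray. In the standard-library example below, `blocking_fetch` is a hypothetical stand-in for a blocking call such as `ray.get()`, and `asyncio.to_thread` is used as one generic way to keep the loop responsive. It is not the Ray-specific approach the comment suggests. A background heartbeat task counts how often the event loop gets a chance to run while the fetch is in progress.

```python
import asyncio
import time


def blocking_fetch(delay):
    # Hypothetical stand-in for a blocking call such as ray.get(...).
    time.sleep(delay)
    return "result"


async def fetch_inline(delay):
    # Calls the blocking function directly on the event loop thread.
    return blocking_fetch(delay)


async def fetch_in_thread(delay):
    # Offloads the blocking function to a worker thread (Python 3.9+).
    return await asyncio.to_thread(blocking_fetch, delay)


async def count_heartbeats(fetch, delay=0.3, interval=0.01):
    ticks = 0

    async def heartbeat():
        nonlocal ticks
        while True:
            await asyncio.sleep(interval)
            ticks += 1

    task = asyncio.create_task(heartbeat())
    await asyncio.sleep(0)  # give the heartbeat a chance to start
    result = await fetch(delay)
    task.cancel()
    try:
        await task
    except asyncio.CancelledError:
        pass
    return result, ticks


_, ticks_inline = asyncio.run(count_heartbeats(fetch_inline))
_, ticks_threaded = asyncio.run(count_heartbeats(fetch_in_thread))
print(ticks_inline, ticks_threaded)
```

With the inline call the heartbeat never gets to tick during the fetch, whereas the threaded variant lets it tick repeatedly. Whether that matters depends on whether anything else needs the loop at that point in training.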
This is correct and meant to be synchronous.
The new trainer uses the equivalent dispatch.save_weights_for_sampler() method.
What does this PR do?
Fixes SkyRLAgentPPOTrainer after #1235. Previously, SkyRLAgentPPOTrainer.train was a sync function, even though we switched to making the base class's method RayPPOTrainer.train async in #868. Training still progressed as usual, but it would have errored out at the end of training, when the return value would be evaluated by asyncio.run(...). This PR is a follow-up to #1235 to transition SkyRLAgentPPOTrainer.train to an async function.
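The failure mode described above can be reproduced in miniature. The classes below are hypothetical stand-ins, not the SkyRL trainers: a sync override runs all of its work when called, and only afterwards does `asyncio.run(...)` reject the `None` it returned. The exception type differs across Python versions, so the sketch catches both `ValueError` and `TypeError`.

```python
import asyncio


class BaseTrainer:
    async def train(self):
        return "base done"


class SyncTrainer(BaseTrainer):
    # Hypothetical subclass that overrides the async method with a sync one.
    def __init__(self):
        self.steps = []

    def train(self):
        for step in range(3):
            self.steps.append(step)  # the "training" work runs eagerly
        # Implicitly returns None rather than a coroutine.


class AsyncTrainer(BaseTrainer):
    # Same work, but as a coroutine that asyncio.run() can drive.
    def __init__(self):
        self.steps = []

    async def train(self):
        for step in range(3):
            self.steps.append(step)
        return "done"


sync_trainer = SyncTrainer()
try:
    asyncio.run(sync_trainer.train())
    sync_outcome = "ok"
except (TypeError, ValueError):
    # Raised only after train() has already finished all of its steps.
    sync_outcome = f"failed after {len(sync_trainer.steps)} steps"

async_trainer = AsyncTrainer()
async_outcome = asyncio.run(async_trainer.train())
print(sync_outcome)
print(async_outcome)
```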