Add Callback Support #339
Conversation
This pull request introduces 2 alerts when merging bc49d47 into cc0d123 - view on LGTM.com.
I really like the idea. I'm just wondering if we should actually include it in the experiment config, but other than that, and some picky details, LGTM.
import torch
import torch.distributed as dist  # type: ignore
import torch.distributions  # type: ignore
import torch.multiprocessing as mp  # type: ignore
import torch.nn as nn
import torch.optim as optim
from allenact.algorithms.onpolicy_sync.misc import TrackingInfo, TrackingInfoType
Throughout the code base we try to keep third party imports before allenact ones, so I would move this a few lines below.
If you're using PyCharm I think this would be handled by running Code -> Optimize Imports.
Ah, I will revert these import updates.
I am using isort, which works well with Black and is pretty popular for sorting and organizing imports. It autoformats on save for VSCode so it ended up changing the import order automatically.
@@ -1736,7 +1734,8 @@ def run_eval(
    lengths: List[int]
    if self.num_active_samplers > 0:
        lengths = self.vector_tasks.command(
-           "sampler_attr", ["length"] * self.num_active_samplers,
+           "sampler_attr",
I see many formatting changes - are you also using black to ensure consistency?
Yes, note that we're using version 19.10b0 of black.
Yes, I'm using Black, but I must be using some different preferences with it (such as the number of characters in a line). I will try with 19.10b0 :)
Really like this, but I do have a few suggestions here and there that I think could improve usability. I think the main question for us is whether we see people wanting to mix and match callbacks across repeated runs of the same experiment. If not, then maybe it's sufficient to define callbacks at the experiment config level. One important thing to remember for this discussion is that it need not be the same person running the experiment, so if we "hard code" certain callbacks (e.g. wandb logging) then this will make it a bit more annoying for people who don't have all the appropriate permissions set up.
@@ -926,6 +926,12 @@ def _task_sampling_loop_generator_fn(
    step_result = step_result.clone({"info": {}})
    step_result.info[COMPLETE_TASK_METRICS_KEY] = metrics

+   task_callback_data = current_task.task_callback_data()
I would prefer if tasks didn't know about callbacks, as this creates a bidirectional dependency between the tasks and callbacks. Could we instead pass a function (or callable object) to the VectorSampledTask and then do something like task_callback_data = task_callback_data_fn(current_task)? This task_callback_data_fn could be returned by the experiment config, similarly to how make_sampler_fn is.
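A minimal sketch of the decoupling suggested here, assuming a `task_callback_data_fn` returned by the experiment config (the names and signatures are illustrative, not the final API):

```python
from typing import Any, Callable, Optional


class Task:
    """Stand-in for allenact's Task; the real task knows nothing about callbacks."""

    def metrics(self) -> dict:
        return {"ep_length": 42}


# Hypothetical signature: maps a finished task to whatever the callbacks
# want logged; returned by the experiment config, much like make_sampler_fn.
TaskCallbackDataFn = Callable[[Task], Optional[Any]]


def default_task_callback_data_fn(task: Task) -> Optional[Any]:
    # The extraction logic lives with the config/callbacks, not on Task itself.
    return {"metrics": task.metrics()}


def on_task_complete(current_task: Task, task_callback_data_fn: TaskCallbackDataFn):
    # What the sampling loop would do with the function it was handed:
    return task_callback_data_fn(current_task)


print(on_task_complete(Task(), default_task_callback_data_fn))
```

This keeps the dependency one-directional: the task exposes its usual metrics, and only the config-supplied function decides what the callbacks see.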
Actually, even better would be if each Callback were required to have this function defined on itself as a static method, and we just passed in a list of these methods (that way the experiment config wouldn't need to know which callbacks were going to be used with it, making it easier to mix and match from the command line).
That's an interesting idea! Like an on_task_end(task: Task) callback.
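Roughly, that shape could look like the sketch below; the `Callback` base class, the `on_task_end` name, and the runner-side collection function are assumptions drawn from this thread, not code from the PR:

```python
from typing import Any, List, Optional


class Task:
    """Stand-in for allenact.base_abstractions.task.Task."""

    def metrics(self) -> dict:
        return {"success": True}


class Callback:
    @staticmethod
    def on_task_end(task: Task) -> Optional[Any]:
        """Extract whatever this callback needs from a finished task."""
        return None


class WandbCallback(Callback):
    @staticmethod
    def on_task_end(task: Task) -> Optional[Any]:
        return {"wandb": task.metrics()}


def collect_task_data(task: Task, callbacks: List[type]) -> List[Any]:
    # The runner only needs the static methods, so the experiment config
    # never has to know which callbacks are in play.
    data = (cb.on_task_end(task) for cb in callbacks)
    return [d for d in data if d is not None]


print(collect_task_data(Task(), [Callback, WandbCallback]))
```

Because the methods are static, a list of them can be shipped across process boundaries to the vector tasks without instantiating the callbacks there.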
allenact/base_abstractions/task.py (outdated)
def task_callback_data(self) -> Optional[Any]:
    """Returns any data that should be passed to the log callback function."""
    return None
If doing my above suggestion then this would be removed.
Just a minor comment and a short question about usage, but other than that LGTM.
@@ -86,14 +89,18 @@ def __init__(
    disable_tensorboard: bool = False,
    disable_config_saving: bool = False,
    distributed_ip_and_port: str = "127.0.0.1:0",
+   distributed_preemption_threshold: float = 0.7,
This is nice - not sure why it was never added as an argument, but good that you did. I guess it would also make sense to make it an arg in main?
Sorry, have just been periodically adding to this branch for any issues I’ve come across. I think I’m just about ready for it, other than addressing the existing comments above. I’ve been waiting a bit since it’s nice to sometimes add arguments to the callbacks, but if it’s merged into main, it’d be hard to add anything without breaking backwards compatibility.
I meant main.py :)
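For illustration, exposing the new `distributed_preemption_threshold` flag in `main.py` might look like the following argparse sketch (the flag name matches the diff above, but the help text and default wiring are assumptions about allenact's actual argument handling):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--distributed_preemption_threshold",
    type=float,
    default=0.7,  # matches the default added to the runner's __init__ above
    help=(
        "Fraction of rollout progress after which remaining distributed "
        "workers may be preempted (hypothetical description)."
    ),
)

args = parser.parse_args(["--distributed_preemption_threshold", "0.5"])
print(args.distributed_preemption_threshold)  # 0.5
```

The parsed value would then simply be forwarded into the runner's constructor alongside the other `disable_*` and `distributed_*` options.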
message.append(f"tasks {num_tasks} checkpoint {checkpoint_file_name[0]}")
get_logger().info(" ".join(message))

for callback in self.callbacks:
I guess it's also possible to make the callback start a thread and return immediately, to allow the logger to keep showing train stats in "real time". Is that right?
This would be possible, but one of the issues is that if wandb is initialized in one process, you cannot log from a thread, which is why logging from a thread and then destroying it wouldn't work as one would expect. If you try to log from all threads and processes, logging becomes prohibitively expensive.
See their notes on distributed training: https://docs.wandb.ai/guides/track/advanced/distributed-training
I think I didn't express correctly what I meant. Currently, if the callback function takes a long time to process, packages in the runner queue for logging wait for a long time. When they're read, we quickly flush that queue, resulting in e.g. spikes of FPS (note that there's no timestamp in the logging packages sent from the trainers). I thought that passing the data for the callback to a thread (in the same process) could do the job, but if using an open wandb session from a thread in the same process is a no go, then there's no discussion. 👍
…d error handling.
… to preprocessor.
…into callbacks-cpca-softmax
Merging main into callbacks and fixing merge conflict.
…k doesn't need to know anything about callbacks.
Fixing callbacks PR comments and other misc improvements.
Having merged the callback sensor API, LGTM.
Background
Adds initial support for Callbacks, inspired by PyTorch Lightning.
The immediate use case is to enable logging during training with Weights and Biases.
Motivation
The motivation is to make it easier to log, debug, and inspect the training setup without having to manually modify runner.py. Down the line, I suspect callbacks will also be the best place to write tests, where the tests may be in callback functions like on_checkpoint_load(model).
.Example
An example usage might be to define a Callback class under the file training/callbacks/wandb_logging.py, and to use it, one would add the file to the --callbacks flag in the allenact command. Note that this doesn't require modifying the experiment configs at all, and hence is fully opt-in functionality.
Notes
I'm still thinking about what callbacks would be best, and what should be passed into each of them. Right now, the best approach I have for logging videos, images, or other more complex information is to save that information to disk, and then process, log, and delete it inside of on_train_log(), but perhaps there's a cleaner solution.
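That save-then-process pattern can be sketched minimally as follows; the helper names and the idea of `on_train_log` scanning a media directory are assumptions, with the actual logging call (e.g. to wandb) left as a comment:

```python
import os
import tempfile


def save_task_video(dir_path: str, name: str, data: bytes) -> str:
    # Engine-side: heavy artifacts are written to disk during training.
    path = os.path.join(dir_path, name)
    with open(path, "wb") as f:
        f.write(data)
    return path


def on_train_log_media(dir_path: str) -> list:
    # Callback-side: process, log, then delete each saved artifact.
    logged = []
    for fname in sorted(os.listdir(dir_path)):
        path = os.path.join(dir_path, fname)
        # e.g. wandb.log({"video": wandb.Video(path)}) would go here
        logged.append(fname)
        os.remove(path)  # clean up so disk usage stays bounded
    return logged


with tempfile.TemporaryDirectory() as tmp:
    save_task_video(tmp, "episode0.mp4", b"\x00\x01")
    print(on_train_log_media(tmp))  # ['episode0.mp4']
```

The deletion step keeps the disk footprint bounded, at the cost of losing the raw artifact once it has been logged.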