
Add Callback Support #339

Merged · 41 commits · Aug 16, 2022
Conversation

@mattdeitke (Member) commented Mar 12, 2022

Background

Adds initial support for Callbacks, inspired by PyTorch Lightning.

The immediate use case is to enable logging during training with Weights and Biases.

Motivation

The motivation is to make it easier to log, debug, and inspect the training setup without having to manually modify runner.py.

Down the line, I suspect callbacks will also be the best place to write tests, with the tests living in callback functions such as on_checkpoint_load(model).

Example

An example usage might be to define a Callback class in the file training/callbacks/wandb_logging.py:

from typing import Any, Dict, Optional

import wandb
from allenact.base_abstractions.callbacks import Callback


class WandbLogging(Callback):
    def setup(self, name: str, **kwargs) -> None:
        # Start the wandb run, using the experiment's name and passing the
        # remaining keyword arguments through as the run config.
        wandb.init(
            project="test-project",
            entity="prior-ai2",
            name=name,
            config=kwargs,
        )

    def on_train_log(self, metric_means: Dict[str, float], step: int, **kwargs) -> None:
        # Log the mean training metrics at the current step.
        wandb.log({**metric_means, "step": step})

    def on_valid_log(
        self,
        metrics: Optional[Dict[str, Any]],
        metric_means: Dict[str, float],
        step: int,
        **kwargs
    ) -> None:
        # Log the mean validation metrics at the current step.
        wandb.log({**metric_means, "step": step})

    def on_test_log(
        self,
        checkpoint: str,
        metrics: Dict[str, Any],
        metric_means: Dict[str, float],
        step: int,
        **kwargs
    ) -> None:
        # Log the mean test metrics at the current step.
        wandb.log({**metric_means, "step": step})

and to use it, one would pass the file to the --callbacks flag of the allenact command:

allenact <...> --callbacks training/callbacks/wandb_logging.py

Note that this doesn't require modifying the experiment configs at all, and hence is fully opt-in functionality.

Notes

I'm still thinking about what callbacks would be best, and what should be passed into each of them.

Right now, the best approach I have for logging videos, images, or other more complex information is to save it to disk and then process, log, and delete it inside on_train_log(), but perhaps there's a cleaner solution.
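
For concreteness, a rough sketch of that disk-based approach might look like the following. The Callback base class and the on_train_log signature are taken from the example above; the video directory, file handling, and the assumption that some other component writes .mp4 files during rollouts are purely illustrative, and wandb.Video is the standard wandb helper for logging video files.

import os
from typing import Dict

import wandb

from allenact.base_abstractions.callbacks import Callback


class WandbVideoLogging(Callback):
    # Hypothetical directory that some other component fills with .mp4 files during rollouts.
    video_dir = "saved_videos"

    def on_train_log(self, metric_means: Dict[str, float], step: int, **kwargs) -> None:
        wandb.log({**metric_means, "step": step})

        # Upload and then delete any videos written to disk since the last call.
        if os.path.isdir(self.video_dir):
            for fname in sorted(os.listdir(self.video_dir)):
                if fname.endswith(".mp4"):
                    path = os.path.join(self.video_dir, fname)
                    wandb.log({"rollout_video": wandb.Video(path), "step": step})
                    os.remove(path)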

@mattdeitke mattdeitke marked this pull request as draft March 12, 2022 21:46
@lgtm-com (bot) commented Mar 13, 2022

This pull request introduces 2 alerts when merging bc49d47 into cc0d123 - view on LGTM.com

new alerts:

  • 2 for Variable defined multiple times
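
For reference, that alert typically flags dead assignments along these lines (a purely illustrative snippet, not code from this PR):

def example(config=None):
    # The first assignment is never read before being overwritten, which is
    # the pattern LGTM's "Variable defined multiple times" check reports.
    value = 0  # dead assignment
    value = config["value"] if config else 1
    return value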

@mattdeitke mattdeitke marked this pull request as ready for review March 13, 2022 01:41
@jordis-ai2 (Collaborator) left a comment

I like the idea a lot. Just wondering if we should actually include it in the experiment config but, other than that and some picky details, LGTM.


import torch
import torch.distributed as dist # type: ignore
import torch.distributions # type: ignore
import torch.multiprocessing as mp # type: ignore
import torch.nn as nn
import torch.optim as optim
from allenact.algorithms.onpolicy_sync.misc import TrackingInfo, TrackingInfoType
Collaborator:

Throughout the codebase we try to keep third-party imports before allenact ones, so I would move this a few lines below.
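
For example, the grouping being asked for here is roughly the following (a sketch of the ordering only, reusing imports from the diff above):

# Third-party packages first...
import torch
import torch.distributed as dist
import torch.nn as nn

# ...then allenact imports afterwards.
from allenact.algorithms.onpolicy_sync.misc import TrackingInfo, TrackingInfoType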

Collaborator:

If you're using PyCharm, I think this would be handled by running Code -> Optimize Imports.

Member Author (mattdeitke):

Ah, I will revert these import updates.

I am using isort, which works well with Black and is pretty popular for sorting and organizing imports. It autoformats on save in VS Code, so it ended up changing the import order automatically.

@@ -1736,7 +1734,8 @@ def run_eval(
lengths: List[int]
if self.num_active_samplers > 0:
lengths = self.vector_tasks.command(
"sampler_attr", ["length"] * self.num_active_samplers,
"sampler_attr",
Collaborator:

I see many formatting changes; are you also using Black to ensure consistency?

Collaborator:

Yes, note that we're using version 19.10b0 of black.

Member Author (mattdeitke):

Yes, I'm using Black, but I must be using different settings (such as the line-length limit). I will try with 19.10b0 :)

Additional resolved review threads: allenact/algorithms/onpolicy_sync/runner.py, allenact/algorithms/onpolicy_sync/vector_sampled_tasks.py (×2), allenact/base_abstractions/callbacks.py, allenact/main.py
@Lucaweihs (Collaborator) left a comment

Really like this, but I do have a few suggestions here and there that I think could improve usability. I think the main question for us is whether we see people wanting to mix and match callbacks across repeated runs of the same experiment. If not, then maybe it's sufficient to define callbacks at the experiment-config level. One important thing to remember for this discussion is that it need not be the same person running the experiment, so if we "hard code" certain callbacks (e.g. wandb logging), this will make things a bit more annoying for people who don't have all the appropriate permissions set up.


@@ -926,6 +926,12 @@ def _task_sampling_loop_generator_fn(
step_result = step_result.clone({"info": {}})
step_result.info[COMPLETE_TASK_METRICS_KEY] = metrics

task_callback_data = current_task.task_callback_data()
Collaborator:

I would prefer if tasks didn't know about callbacks, as this creates a bidirectional dependency between the tasks and callbacks. Could we instead pass a function (or callable object) to the VectorSampledTask and then do something like task_callback_data = task_callback_data_fn(current_task)? This task_callback_data_fn could be returned by the experiment config, similarly to how make_sampler_fn is.

Collaborator:

Actually, even better would be if each Callback were required to define this function on itself as a static method and we just passed in a list of these methods (that way the experiment config wouldn't need to know which callbacks were going to be used with it, making it easier to mix and match from the command line).

Member Author (mattdeitke):

That’s an interesting idea! Like an on_task_end(task: Task) callback.
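
A minimal sketch of how those two ideas could fit together is below; the names collect_callback_data and on_task_end, and the use of task.task_info, are illustrative assumptions rather than anything defined in this PR.

from typing import Any, Dict, Optional

from allenact.base_abstractions.callbacks import Callback
from allenact.base_abstractions.task import Task


class WandbLogging(Callback):
    @staticmethod
    def collect_callback_data(task: Task) -> Optional[Dict[str, Any]]:
        # Called by the vector sampled tasks when a task finishes, so the task
        # itself never needs to know that callbacks exist.
        return {"task_info": task.task_info}

    def on_task_end(self, data: Optional[Dict[str, Any]], **kwargs) -> None:
        # Receives whatever collect_callback_data gathered for the finished task.
        if data is not None:
            ...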

Additional resolved review threads: allenact/algorithms/onpolicy_sync/vector_sampled_tasks.py, allenact/base_abstractions/callbacks.py
Comment on lines 232 to 234
def task_callback_data(self) -> Optional[Any]:
"""Returns any data that should be passed to the log callback function."""
return None
Collaborator:

If we go with my suggestion above, this would be removed.

Additional resolved review thread: allenact/main.py
@jordis-ai2 (Collaborator) left a comment

Just a minor comment and a short question about usage, but other than that, LGTM.

@@ -86,14 +89,18 @@ def __init__(
disable_tensorboard: bool = False,
disable_config_saving: bool = False,
distributed_ip_and_port: str = "127.0.0.1:0",
distributed_preemption_threshold: float = 0.7,
Collaborator:

This is nice; not sure why it was never added as an argument, but good that you did. I guess it would also make sense to make it an arg in main?

Member Author (mattdeitke):

Sorry, I have just been periodically adding to this branch for any issues I've come across. I think I'm just about ready for it, other than addressing the existing comments above. I've been waiting a bit since it's nice to sometimes add arguments to the callbacks, but if it's merged into main, it'd be hard to add anything without breaking backwards compatibility.

Collaborator:

I meant main.py :)
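
For reference, exposing the flag in main.py could look roughly like the following self-contained sketch; the flag name and default mirror the runner parameter shown in the diff above, while the help text and the exact parser wiring are assumptions.

import argparse

# Hypothetical sketch of exposing the flag in allenact/main.py; the parsed value
# would then be forwarded to the runner's distributed_preemption_threshold parameter.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--distributed_preemption_threshold",
    type=float,
    default=0.7,
    help="Forwarded to the runner's distributed_preemption_threshold argument (default 0.7).",
)
args = parser.parse_args([])  # empty list here just shows the default being picked up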


message.append(f"tasks {num_tasks} checkpoint {checkpoint_file_name[0]}")
get_logger().info(" ".join(message))

for callback in self.callbacks:
Collaborator:

I guess it's also possible to make the callback start a thread and return immediately, to allow the logger to keep showing training stats in "real time". Is that right?

Member Author (mattdeitke):

This would be possible, but one of the issues is that if wandb is initialized in one process, you cannot log from a thread, which is why logging from a thread and then destroying it wouldn't work as one would expect. And if you try to log from all threads and processes, logging becomes prohibitively expensive.

See their notes on distributed training: https://docs.wandb.ai/guides/track/advanced/distributed-training

Collaborator:

I think I didn't express what I meant correctly. Currently, if the callback function takes a long time to process, packages in the runner queue for logging wait for a long time; when they're finally read, we quickly flush that queue, resulting in e.g. spikes in FPS (note that there's no timestamp in the logging packages sent from the trainers). I thought that passing the data for the callback to a thread (in the same process) could do the job, but if using an open wandb session from a thread in the same process is a no-go, then there's no discussion. 👍

@Lucaweihs (Collaborator) left a comment

Having merged the callback sensor API, LGTM.

@Lucaweihs merged commit b5c7192 into main Aug 16, 2022