Add WandbLogger callback for customizing checkpoint artifact logging #17913

Open
schmidt-ai opened this issue Jun 24, 2023 · 1 comment · May be fixed by #18253
Labels: needs triage (Waiting to be triaged by maintainers), refactor


schmidt-ai (Contributor) commented Jun 24, 2023

### Outline & Motivation

It would be useful to add a hook (an overridable method) to `WandbLogger` for custom handling of checkpoint artifacts. Example use cases:

  1. I'm already writing checkpoints to persistent storage (e.g. using a `ModelCheckpoint` that writes to S3), so I only want `WandbLogger` to log reference artifacts pointing to them instead of uploading the files again.
  2. I want to add additional files or metadata to my WandB checkpoint artifacts.
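For use case 1, a W&B reference artifact records a URI (for example `s3://bucket/...` or `file:///...`) via `wandb.Artifact.add_reference` rather than uploading bytes, so each checkpoint path has to be expressed as a URI. A minimal sketch of that normalization, assuming remote paths already carry a scheme and anything else is a local path; `to_reference_uri` is a hypothetical helper, not part of Lightning or wandb:

```python
from pathlib import Path
from urllib.parse import urlparse


def to_reference_uri(path: str) -> str:
    """Return a URI usable with a reference artifact (hypothetical helper).

    Paths that already carry a scheme such as ``s3://`` or ``gs://`` are
    returned unchanged; anything else is treated as a local file and turned
    into an absolute ``file://`` URI.
    """
    scheme = urlparse(path).scheme
    # A one-letter "scheme" is almost certainly a Windows drive letter (C:\...)
    if len(scheme) > 1:
        return path
    return Path(path).resolve().as_uri()
```

For instance, `s3://bucket/ckpts/epoch=1.ckpt` passes through unchanged, while `ckpts/last.ckpt` becomes an absolute `file://` URI rooted at the current working directory.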

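We could refactor `WandbLogger._scan_and_log_checkpoints` slightly so that it builds the artifact and then hands it to a new `on_log_checkpoint_artifact` hook: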

```python
# Proposed change to lightning/pytorch/loggers/wandb.py (excerpt: only the
# changed members of WandbLogger are shown; the rest of the class is unchanged).
from pathlib import Path

import wandb
from torch import Tensor

from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers.utilities import _scan_checkpoints

# `_WANDB_GREATER_EQUAL_0_10_22` is the module-level version flag already defined in wandb.py.


class WandbLogger:
    def on_log_checkpoint_artifact(self, artifact, checkpoint_timestamp, path, score, tag):
        """Hook: attach the checkpoint to ``artifact`` and return it.

        The default keeps the current behaviour and uploads the checkpoint file.
        """
        artifact.add_file(path, name="model.ckpt")
        return artifact

    def _scan_and_log_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> None:
        # get checkpoints to be saved with associated score
        checkpoints = _scan_checkpoints(checkpoint_callback, self._logged_model_time)

        # log iteratively all new checkpoints
        for t, p, s, tag in checkpoints:
            metadata = (
                {
                    "score": s.item() if isinstance(s, Tensor) else s,
                    "original_filename": Path(p).name,
                    checkpoint_callback.__class__.__name__: {
                        k: getattr(checkpoint_callback, k)
                        for k in [
                            "monitor",
                            "mode",
                            "save_last",
                            "save_top_k",
                            "save_weights_only",
                            "_every_n_train_steps",
                        ]
                        # ensure it does not break if `ModelCheckpoint` args change
                        if hasattr(checkpoint_callback, k)
                    },
                }
                if _WANDB_GREATER_EQUAL_0_10_22
                else None
            )
            if not self._checkpoint_name:
                self._checkpoint_name = f"model-{self.experiment.id}"
            artifact = wandb.Artifact(name=self._checkpoint_name, type="model", metadata=metadata)

            # let the hook decide what the artifact contains (file upload, reference, extra files, ...)
            artifact = self.on_log_checkpoint_artifact(artifact, t, p, s, tag)

            aliases = ["latest", "best"] if p == checkpoint_callback.best_model_path else ["latest"]
            self.experiment.log_artifact(artifact, aliases=aliases)
            # remember logged models - timestamp needed in case filename didn't change (last.ckpt or custom name)
            self._logged_model_time[p] = t
```
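The bookkeeping in this loop decides which checkpoints get (re-)logged and which aliases each one receives; the proposed hook leaves it untouched. A simplified, library-independent sketch of that logic, assuming each checkpoint is a `(timestamp, path, score, tag)` tuple as in the loop above; `select_new_checkpoints` and `aliases_for` are hypothetical names, not Lightning functions:

```python
from typing import Dict, List, Optional, Tuple

# (timestamp, path, score, tag), mirroring the tuples unpacked in the loop above
Checkpoint = Tuple[float, str, Optional[float], str]


def select_new_checkpoints(
    checkpoints: List[Checkpoint], logged_model_time: Dict[str, float]
) -> List[Checkpoint]:
    """Keep checkpoints never logged before, or rewritten since (same path, newer timestamp)."""
    return [
        c
        for c in sorted(checkpoints, key=lambda c: c[0])
        if c[1] not in logged_model_time or logged_model_time[c[1]] < c[0]
    ]


def aliases_for(path: str, best_model_path: str) -> List[str]:
    """Every logged checkpoint becomes "latest"; the current best also gets "best"."""
    return ["latest", "best"] if path == best_model_path else ["latest"]


# "last.ckpt" keeps its filename but was overwritten, so only its timestamp changed.
logged = {"epoch=0.ckpt": 1.0, "last.ckpt": 1.0}
scanned = [(1.0, "epoch=0.ckpt", 0.5, "best"), (2.0, "last.ckpt", None, "last")]
for t, p, s, tag in select_new_checkpoints(scanned, logged):
    print(p, aliases_for(p, best_model_path="epoch=0.ckpt"))  # prints: last.ckpt ['latest']
    logged[p] = t  # remember when this path was last logged
```

Comparing timestamps rather than only filenames is what lets a fixed-name file such as `last.ckpt` be logged again each time it is overwritten.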

Then, if users want custom artifact logging, they can subclass `WandbLogger` and override `on_log_checkpoint_artifact`. For example, to log reference artifacts instead of uploading the checkpoint files:

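```python
from lightning.pytorch.loggers import WandbLogger  # assumes the refactor above has landed


class ReferenceArtifactLogger(WandbLogger):
    def on_log_checkpoint_artifact(self, artifact, checkpoint_timestamp, path, score, tag):
        # Record a pointer to the existing checkpoint instead of uploading it;
        # `path` must be a URI that wandb can reference (e.g. s3://...).
        artifact.add_reference(path)
        return artifact
```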
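Because the hook is a plain overridable method (a template-method design), its control flow can be exercised without wandb or Lightning. A minimal sketch, assuming a stand-in `FakeArtifact` that only records calls and a stand-in `BaseLogger` playing the role of the refactored `WandbLogger`; none of these classes exist in either library, and `config.yaml` is a hypothetical extra file. The last subclass covers use case 2 (extra metadata and files):

```python
from typing import Any, Dict, List, Optional, Tuple


class FakeArtifact:
    """Stand-in for wandb.Artifact (hypothetical): records what was attached."""

    def __init__(self, name: str, metadata: Optional[Dict[str, Any]] = None) -> None:
        self.name = name
        self.metadata: Dict[str, Any] = dict(metadata or {})
        self.files: List[Tuple[str, Optional[str]]] = []
        self.references: List[str] = []

    def add_file(self, local_path: str, name: Optional[str] = None) -> None:
        self.files.append((local_path, name))

    def add_reference(self, uri: str, name: Optional[str] = None) -> None:
        self.references.append(uri)


class BaseLogger:
    """Stand-in for the refactored WandbLogger: builds the artifact, then calls the hook."""

    def on_log_checkpoint_artifact(self, artifact, checkpoint_timestamp, path, score, tag):
        artifact.add_file(path, name="model.ckpt")  # default: upload the checkpoint bytes
        return artifact

    def build_artifact(self, t: float, p: str, s: Optional[float], tag: str) -> FakeArtifact:
        artifact = FakeArtifact("model-run", metadata={"score": s})
        return self.on_log_checkpoint_artifact(artifact, t, p, s, tag)


class ReferenceLogger(BaseLogger):
    """Use case 1: the checkpoint already lives in S3, so only record a reference."""

    def on_log_checkpoint_artifact(self, artifact, checkpoint_timestamp, path, score, tag):
        artifact.add_reference(path)
        return artifact


class MetadataLogger(BaseLogger):
    """Use case 2: keep the default upload, then attach extra metadata and files."""

    def on_log_checkpoint_artifact(self, artifact, checkpoint_timestamp, path, score, tag):
        artifact = super().on_log_checkpoint_artifact(artifact, checkpoint_timestamp, path, score, tag)
        artifact.metadata["tag"] = tag
        artifact.add_file("config.yaml")  # hypothetical extra file shipped with the checkpoint
        return artifact
```

Calling `super()` in the override is what lets a subclass extend the default upload rather than replace it.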


cc @justusschock @awaelchli
schmidt-ai added the needs triage and refactor labels on Jun 24, 2023
schmidt-ai linked a pull request on Aug 7, 2023 that will close this issue
noamsgl commented Oct 4, 2023

Hi, I would also like this feature; my use case is identical to your use case 1.
