More modular fit() and support for progress and logging callbacks #1021

Open
chiragjn opened this issue Jun 23, 2021 · 9 comments
Comments

@chiragjn

Hello,
Is there a plan to make .fit() more modular?

For context, I am integrating the library into an async worker and want to use Python logging, TensorBoard, or wandb to log metrics, losses, etc. every n steps or epochs. At the moment the fit function is not modular enough for me to inherit from the class and override the right points. There is support for a callback function, but it is only called when an evaluator is provided.
For example, libraries like fastai and PyTorch Lightning provide callbacks that run before and after each batch, epoch, etc.

[1] https://docs.fast.ai/callback.progress.html
[2] https://pytorch-lightning.readthedocs.io/en/latest/extensions/callbacks.html
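To make the request concrete, here is a minimal, library-agnostic sketch of the hook style those libraries use. TrainingCallback, HistoryCallback, and run_training are hypothetical names made up for this example; none of them exist in sentence-transformers, and the loop only stands in for what fit() does internally.

```python
from typing import Callable, List, Sequence, Tuple


class TrainingCallback:
    """Hypothetical base class: subclasses override only the hooks they need."""

    def on_epoch_begin(self, epoch: int) -> None:
        pass

    def on_batch_end(self, epoch: int, step: int, loss: float) -> None:
        pass

    def on_epoch_end(self, epoch: int, mean_loss: float) -> None:
        pass


class HistoryCallback(TrainingCallback):
    """Records (epoch, global step, loss) every `every_n_steps` global steps."""

    def __init__(self, every_n_steps: int = 1) -> None:
        self.every_n_steps = max(1, every_n_steps)
        self.records: List[Tuple[int, int, float]] = []

    def on_batch_end(self, epoch: int, step: int, loss: float) -> None:
        if step % self.every_n_steps == 0:
            self.records.append((epoch, step, loss))


def run_training(
    batches: Sequence[float],
    epochs: int,
    train_step: Callable[[float], float],
    callbacks: Sequence[TrainingCallback],
) -> None:
    """Stand-in training loop that fires the hooks at fixed points."""
    global_step = 0
    for epoch in range(epochs):
        for cb in callbacks:
            cb.on_epoch_begin(epoch)
        losses: List[float] = []
        for batch in batches:
            loss = train_step(batch)  # forward/backward would happen here
            global_step += 1
            losses.append(loss)
            for cb in callbacks:
                cb.on_batch_end(epoch, global_step, loss)
        mean_loss = sum(losses) / len(losses) if losses else 0.0
        for cb in callbacks:
            cb.on_epoch_end(epoch, mean_loss)
```

With hooks like these, a TensorBoard or wandb integration would just be another TrainingCallback subclass, and no evaluator would be needed to get the training loss.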


If acceptable I can help work on a PR for this :D

@nreimers
Member

Yes, it is planned to integrate a more modular fit (and evaluators) in version 2.1. They should provide options to log to tensorboard / wandb or to your own custom code.

A PR on this would be really nice. I haven't started yet on this topic.

@chiragjn
Author

chiragjn commented Jun 30, 2021

Just saw that 2.0 was released a few days ago! :D I forked it and have started refactoring the code to add events via callbacks. I'll update when I have something working. In the meantime, I kinda hacked the loss function to be able to do some logging:

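import logging
from timeit import default_timer as timer
from typing import Dict, Iterable

import torch
import sentence_transformers as st
from sentence_transformers import losses as st_losses

# Module-level logger used by the loss class below
logger = logging.getLogger(__name__)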

class LoggedDenoisingAutoEncoderLoss(st_losses.DenoisingAutoEncoderLoss):
    def __init__(self,
                 model: st.SentenceTransformer,
                 decoder_name_or_path: str = None,
                 tie_encoder_decoder: bool = True,
                 record_every_n_steps: int = 1000):
        super().__init__(model=model, decoder_name_or_path=decoder_name_or_path, tie_encoder_decoder=tie_encoder_decoder)
        self._record_every_n_steps = max(1, record_every_n_steps)
        self._step_counter = 0
        self._mtime = 0
        self._running_loss = torch.tensor(0.0)

    def _pre_forward(self):
        if self.encoder.training:
            if self._step_counter == 0 or self._step_counter % self._record_every_n_steps == 0:
                self._mtime = timer()
                self._running_loss = torch.tensor(0.0)

            self._step_counter += 1

    def _post_forward(self, loss: torch.Tensor):
        if self.encoder.training:
            self._running_loss += loss.item()

            if self._step_counter % self._record_every_n_steps == 0:
                metrics = {
                    'step_range': f'{self._step_counter - self._record_every_n_steps + 1}-{self._step_counter}',
                    'time_elapsed': timer() - self._mtime,
                    'total_train_loss': self._running_loss.item(),
                    'avg_train_loss_per_step': self._running_loss.item() / self._record_every_n_steps,
                }
                logger.info(f'{type(self).__name__}: {metrics}')

    def forward(self, sentence_features: Iterable[Dict[str, torch.Tensor]], labels: torch.Tensor) -> torch.Tensor:
        self._pre_forward()
        loss: torch.Tensor = super().forward(sentence_features=sentence_features, labels=labels)
        self._post_forward(loss=loss.clone().detach().cpu())
        return loss

@skewwhiff

If there's no open PR for this, I'll be happy to take it up. I currently subclass SentenceTransformer and CrossEncoder and implement hard-coded wandb logging. I think it can be generalized to custom callbacks for tensorboard/wandb.

@chiragjn
Author

chiragjn commented Nov 2, 2021

@skewwhiff Please feel free to take this up, and sorry for the lack of updates. I did some prototyping back then, but I was not happy with the architecture myself. I think an API like fastai's or pytorch-lightning's would be pretty good. I'll be happy to help in any way I can.

@skewwhiff

All right, I will open a new PR for generic logging. @nreimers, I have a broad overview of what to do:

  • Add a backward-compatible log_step parameter and a logger object to the fit functions of SentenceTransformer and CrossEncoder.
  • Add an evaluation logger parameter to all Evaluators, along with a callable that lets users do further data-wrangling.
  • Possibly add example code for integrating Wandb and TensorBoard?
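One possible shape for that plan, as a rough sketch only: MetricLogger, InMemoryLogger, maybe_log, and the transform argument are all hypothetical names for this example and are not part of sentence-transformers. Treating logger=None as "do nothing" is what would keep existing fit() calls working unchanged.

```python
from typing import Callable, Dict, List, Optional, Protocol, Tuple

Metrics = Dict[str, float]


class MetricLogger(Protocol):
    """Hypothetical interface a wandb/TensorBoard adapter would implement."""

    def log(self, metrics: Metrics, step: int) -> None:
        ...


class InMemoryLogger:
    """Toy backend that keeps rows in a list; handy for tests."""

    def __init__(self) -> None:
        self.rows: List[Tuple[int, Metrics]] = []

    def log(self, metrics: Metrics, step: int) -> None:
        self.rows.append((step, dict(metrics)))


def maybe_log(
    logger: Optional[MetricLogger],
    log_step: int,
    step: int,
    metrics: Metrics,
    transform: Optional[Callable[[Metrics], Metrics]] = None,
) -> bool:
    """Log every `log_step` steps; return True if something was logged.

    logger=None or log_step <= 0 disables logging, so old call sites
    that pass neither argument behave exactly as before.
    """
    if logger is None or log_step <= 0 or step % log_step != 0:
        return False
    # Optional user hook for further data-wrangling before logging
    logger.log(transform(metrics) if transform else metrics, step)
    return True
```

A fit() or evaluator implementation would call maybe_log once per step (or per evaluation) and leave the choice of backend to the caller.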

Do you have anything else in mind wrt logging?

@Exr0n

Exr0n commented Dec 29, 2021

chiragjn's loss function looks great, but I was turned away by how complicated it looked. For anyone else looking for a quick hack, I think this is the basic idea:

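import torch
import wandb

class LoggingLoss:
    """Wraps a loss function and logs every computed loss to wandb."""

    def __init__(self, loss_fn, wandb):
        self.loss_fn = loss_fn
        self.wandb = wandb

    def __call__(self, logits, labels):
        loss = self.loss_fn(logits, labels)
        # Log a plain float rather than a tensor still attached to the graph
        self.wandb.log({ 'train_loss': loss.item() })
        return loss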

# ...

wandb.init()
wandb.watch(model.model)
model.fit(
    # ...
    loss_fct=LoggingLoss(torch.nn.BCEWithLogitsLoss(), wandb),
    # ...
)
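Logging on every call can get noisy, so here is a hedged variant that combines this wrapper with the every-n-steps averaging from the longer version earlier in the thread. ThrottledLoggingLoss and its log_fn / log_every parameters are made up for this sketch; log_fn is any callable that takes a dict, so passing wandb.log should work, but that pairing is an assumption and is not tested here.

```python
from typing import Any, Callable, Dict


class ThrottledLoggingLoss:
    """Wraps a loss callable and logs the mean loss every `log_every` calls."""

    def __init__(
        self,
        loss_fn: Callable[[Any, Any], Any],
        log_fn: Callable[[Dict[str, float]], Any],
        log_every: int = 100,
    ) -> None:
        self.loss_fn = loss_fn
        self.log_fn = log_fn
        self.log_every = max(1, log_every)
        self._calls = 0
        self._running = 0.0

    def __call__(self, logits: Any, labels: Any) -> Any:
        loss = self.loss_fn(logits, labels)
        # Torch tensors expose .item(); plain numbers go through float()
        value = float(loss.item()) if hasattr(loss, "item") else float(loss)
        self._calls += 1
        self._running += value
        if self._calls % self.log_every == 0:
            self.log_fn({
                "train_loss": self._running / self.log_every,
                "step": self._calls,
            })
            self._running = 0.0  # start a fresh averaging window
        # Return the original loss untouched so backprop is unaffected
        return loss
```

Because the logging target is injected, the same wrapper can write to a list in tests, to Python logging, or to an experiment tracker.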

Looking forward to hooks and proper integration! :)

@SnoozingSimian

Commenting to keep an eye on this issue. I did it by hacking the fit() function in the SentenceTransformer.py file, but I guess it makes more sense to use the losses to do the logging.

@minhrongcon2000

Hi! Are there any updates on this issue?

@johnsonice

Any updates? It is a bit crazy that we can't get the training loss?!

7 participants