Skip to content

Commit

Permalink
Remove callback_metric conversion to numpy by cloning tensors
Browse files Browse the repository at this point in the history
  • Loading branch information
Peiffap committed Jun 4, 2024
1 parent 3af297b commit e5007d6
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 8 deletions.
2 changes: 1 addition & 1 deletion src/lightning/pytorch/loggers/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@

import functools
import operator
import statistics
from abc import ABC
from collections import defaultdict
from typing import Any, Callable, Dict, Mapping, Optional, Sequence

import statistics
from typing_extensions import override

from lightning.fabric.loggers import Logger as FabricLogger
Expand Down
12 changes: 5 additions & 7 deletions src/lightning/pytorch/strategies/launchers/multiprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,8 +224,7 @@ def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Opt
return _WorkerOutput(best_model_path, weights_path, trainer.state, results, extra)

def get_extra_results(self, trainer: "pl.Trainer") -> Dict[str, Any]:
"""Gather extra state from the Trainer and return it as a dictionary for sending back to the main process. To
avoid issues with memory sharing, we cast the data to numpy.
"""Gather extra state from the Trainer and return it as a dictionary for sending back to the main process.
Args:
trainer: reference to the Trainer.
Expand All @@ -236,13 +235,12 @@ def get_extra_results(self, trainer: "pl.Trainer") -> Dict[str, Any]:
"""
callback_metrics: dict = apply_to_collection(
trainer.callback_metrics, Tensor, lambda x: x.cpu().numpy()
) # send as numpy to avoid issues with memory sharing
trainer.callback_metrics, Tensor, lambda x: x.cpu().detach().clone()
)
return {"callback_metrics": callback_metrics}

def update_main_process_results(self, trainer: "pl.Trainer", extra: Dict[str, Any]) -> None:
"""Retrieve the :attr:`trainer.callback_metrics` dictionary from the given queue. To preserve consistency, we
cast back the data to ``torch.Tensor``.
"""Retrieve the :attr:`trainer.callback_metrics` dictionary from the given queue.
Args:
trainer: reference to the Trainer.
Expand All @@ -252,7 +250,7 @@ def update_main_process_results(self, trainer: "pl.Trainer", extra: Dict[str, An
"""
# NOTE: `get_extra_results` needs to be called before
callback_metrics = extra["callback_metrics"]
trainer.callback_metrics.update(apply_to_collection(callback_metrics, type(Tensor().numpy()), lambda x: torch.tensor(x)))
trainer.callback_metrics.update(callback_metrics)

@override
def kill(self, signum: _SIGNUM) -> None:
Expand Down

0 comments on commit e5007d6

Please sign in to comment.