Skip to content

Commit

Permalink
Convert tensors to bytes instead of numpy in multiprocessing result-queue (#20005)
Browse files Browse the repository at this point in the history
  • Loading branch information
awaelchli committed Jun 23, 2024
1 parent e330da5 commit 9304a2c
Showing 1 changed file with 11 additions and 9 deletions.
20 changes: 11 additions & 9 deletions src/lightning/pytorch/strategies/launchers/multiprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import logging
import os
import queue
Expand All @@ -19,7 +20,6 @@
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Literal, NamedTuple, Optional, Union

import numpy as np
import torch
import torch.backends.cudnn
import torch.multiprocessing as mp
Expand Down Expand Up @@ -226,7 +226,7 @@ def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Opt

def get_extra_results(self, trainer: "pl.Trainer") -> Dict[str, Any]:
"""Gather extra state from the Trainer and return it as a dictionary for sending back to the main process. To
avoid issues with memory sharing, we cast the data to numpy.
avoid issues with memory sharing, we convert tensors to bytes.
Args:
trainer: reference to the Trainer.
Expand All @@ -236,14 +236,15 @@ def get_extra_results(self, trainer: "pl.Trainer") -> Dict[str, Any]:
process this output.
"""
callback_metrics: dict = apply_to_collection(
trainer.callback_metrics, Tensor, lambda x: x.cpu().numpy()
) # send as numpy to avoid issues with memory sharing
return {"callback_metrics": callback_metrics}
callback_metrics = apply_to_collection(trainer.callback_metrics, Tensor, lambda t: t.cpu())
buffer = io.BytesIO()
torch.save(callback_metrics, buffer)
# send tensors as bytes to avoid issues with memory sharing
return {"callback_metrics_bytes": buffer.getvalue()}

def update_main_process_results(self, trainer: "pl.Trainer", extra: Dict[str, Any]) -> None:
"""Retrieve the :attr:`trainer.callback_metrics` dictionary from the given queue. To preserve consistency, we
cast back the data to ``torch.Tensor``.
convert bytes back to ``torch.Tensor``.
Args:
trainer: reference to the Trainer.
Expand All @@ -252,8 +253,9 @@ def update_main_process_results(self, trainer: "pl.Trainer", extra: Dict[str, An
"""
# NOTE: `get_extra_results` needs to be called before
callback_metrics = extra["callback_metrics"]
trainer.callback_metrics.update(apply_to_collection(callback_metrics, np.ndarray, lambda x: torch.tensor(x)))
callback_metrics_bytes = extra["callback_metrics_bytes"]
callback_metrics = torch.load(io.BytesIO(callback_metrics_bytes))
trainer.callback_metrics.update(callback_metrics)

@override
def kill(self, signum: _SIGNUM) -> None:
Expand Down

0 comments on commit 9304a2c

Please sign in to comment.