Skip to content

Commit

Permalink
ddp_fix (#270)
Browse files Browse the repository at this point in the history
* 1 node 2 gpus works

* log only on rank zero

* black and remove comments

* clearml logger DDP support

* metrics - fix mypy typing and clearly state what we support in metrics ddp

---------

Co-authored-by: Daniel Shats Daniel.Shats1@ibm.com <shatz@cccxl016.pok.ibm.com>
  • Loading branch information
shatz01 and Daniel Shats Daniel.Shats1@ibm.com committed Feb 14, 2023
1 parent 18aa0d3 commit 2d47362
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 34 deletions.
47 changes: 32 additions & 15 deletions fuse/dl/lightning/pl_funcs.py
Expand Up @@ -19,6 +19,7 @@
Collection of useful functions to implement FuseMedML pytorch lightning based module and train loop
"""
import os
import traceback
from typing import Any, Dict, List, OrderedDict, Sequence, Union, Mapping, TypeVar
from statistics import mean
Expand Down Expand Up @@ -57,6 +58,7 @@ def start_clearml_logger(
) -> TaskInstance:
"""
Just a fuse function to quickly start the clearml logger. It sets up patches to pytorch lightning logging hooks so it doesn't need to be passed to any lightning logger.
This function also checks if the NODE_RANK and LOCAL_RANK env variables have been set. In which case clearml will only be initialized on global rank 0.
For information on all the arguments please see: https://clear.ml/docs/latest/docs/references/sdk/task/ or https://github.com/allegroai/clearml/blob/master/clearml/task.py
General Clearml instructions:
Expand All @@ -68,19 +70,34 @@ def start_clearml_logger(
from fuse.dl.lightning.pl_funcs import start_clearml_logger
start_clearml_logger(project_name="my_project_name", task_name="test_01")
"""
task = Task.init(
project_name=project_name,
task_name=task_name,
tags=tags,
reuse_last_task_id=reuse_last_task_id,
continue_last_task=continue_last_task,
output_uri=output_uri,
auto_connect_arg_parser=auto_connect_arg_parser,
auto_connect_frameworks=auto_connect_frameworks,
auto_resource_monitoring=auto_resource_monitoring,
auto_connect_streams=auto_connect_streams,
deferred_init=deferred_init,
)
bool_start_logger = False
task = None

# check if we are in a distributed setting (if we are, must check that we are also on global rank 0)
distributed = ("NODE_RANK" in os.environ) and ("LOCAL_RANK" in os.environ)
if distributed:
node_rank = int(os.environ["NODE_RANK"])
local_rank = int(os.environ["LOCAL_RANK"])
if (node_rank == 0) and (local_rank == 0):
bool_start_logger = True
else:
# if not in a distributed setting, we can just start logger
bool_start_logger = True

if bool_start_logger:
task = Task.init(
project_name=project_name,
task_name=task_name,
tags=tags,
reuse_last_task_id=reuse_last_task_id,
continue_last_task=continue_last_task,
output_uri=output_uri,
auto_connect_arg_parser=auto_connect_arg_parser,
auto_connect_frameworks=auto_connect_frameworks,
auto_resource_monitoring=auto_resource_monitoring,
auto_connect_streams=auto_connect_streams,
deferred_init=deferred_init,
)
return task


Expand Down Expand Up @@ -210,7 +227,7 @@ def epoch_end_compute_and_log_losses(
else:
losses.append(elem[key])
loss = mean(losses)
pl.log(f"{mode}{sep}losses.{key}", loss, on_epoch=True, sync_dist=True)
pl.log(f"{mode}{sep}losses.{key}", loss, on_epoch=True, sync_dist=True, rank_zero_only=True)


def epoch_end_compute_and_log_metrics(
Expand Down Expand Up @@ -241,4 +258,4 @@ def epoch_end_compute_and_log_metrics(
# log metrics
for key in epoch_results.keypaths():
if epoch_results[key] is not None and not isinstance(epoch_results[key], (PerSampleData)):
pl.log(f"{mode}{sep}{key}", epoch_results[key], on_epoch=True, sync_dist=True)
pl.log(f"{mode}{sep}{key}", epoch_results[key], on_epoch=True, sync_dist=True, rank_zero_only=True)
77 changes: 58 additions & 19 deletions fuse/eval/metrics/metrics_common.py
Expand Up @@ -18,7 +18,7 @@
"""

from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Hashable, Optional, Sequence, Tuple, Union
from typing import Any, Callable, Dict, Hashable, Optional, Sequence, Tuple, Union, List
import copy
from fuse.utils import uncollate
import torch.distributed as dist
Expand Down Expand Up @@ -112,18 +112,6 @@ def collect(self, batch: Dict) -> None:
if not isinstance(batch, NDict):
batch = NDict(batch)

# If in distributed mode (multi gpu training) we shall gather the result from all the machine to evaluate with respect to the entire batch.
if dist.is_initialized():
world_size = dist.get_world_size() # num of gpus
samples_gather = [None for rank in range(world_size)]
# samples_gather[i] will have the 'samples' value of the i's GPU
dist.all_gather_object(samples_gather, batch)

# union all the GPU's samples into one samples list
samples = []
for rank in range(world_size):
samples += samples_gather[rank]

if self._pre_collect_process_func is not None or self._post_collect_process_func is not None:
samples = uncollate(batch)
for sample in samples:
Expand All @@ -150,8 +138,14 @@ def collect(self, batch: Dict) -> None:
if self._batch_pre_collect_process_func is not None:
batch = self._batch_pre_collect_process_func(batch)
batch_to_collect = {}

for name, key in self._keys_to_collect.items():
value = batch[key]

# collect distributed
if dist.is_initialized():
value = self.sync_tensor_data_and_concat(value)

if isinstance(value, torch.Tensor):
value = value.detach().cpu().numpy()

Expand All @@ -165,11 +159,56 @@ def collect(self, batch: Dict) -> None:
for key in self._id_keys:
if key in batch:
ids = batch[key]
# collect distributed
if dist.is_initialized():
ids = self.sync_ids(ids)
break

if ids is not None:
self._collected_ids.extend(ids)

@staticmethod
def sync_tensor_data_and_concat(data: torch.Tensor) -> torch.Tensor:
"""
Collect the arg data (which is a tensor) into a list, concat along the batch dim (assumed first dim) and return
:param data: value to collect accross gpus
"""
assert isinstance(
data, torch.Tensor
), f"ERROR, Fuse Metrics only supports gathering of torch.Tensor at this time. You tried gathering {type(data)}"

# if data is 1d, torch.vstack will add another dimension which we do not want
if len(data.shape) == 1:
data = data[:, None] # add dim to the end

# gather
world_size = dist.get_world_size()
data_list = [torch.zeros_like(data) for _ in range(world_size)]
dist.all_gather(data_list, data)

# stack
stacked_data = torch.vstack(data_list)

return stacked_data

@staticmethod
def sync_ids(ids: List[Tuple[str, int]]) -> List[Any]:
"""
Collect the arg ids into a list, flatten this list and return
:param ids: list of tuples ex [('mnist-train', 0), ('mnist-train', 1)]
"""
# gather
world_size = dist.get_world_size()
data_list = [None for _ in range(world_size)]
dist.all_gather_object(data_list, ids)

# flatten list
data_list = [item for sublist in data_list for item in sublist]

return data_list

@staticmethod
def _df_dict_apply(data: pd.Series, func: Callable) -> pd.Series:
result = func(NDict(data.to_dict()))
Expand Down Expand Up @@ -278,7 +317,7 @@ def __init__(
batch_pre_collect_process_func: Optional[Callable] = None,
external_data_collector: Optional[MetricCollector] = None,
extract_ids: bool = False,
**kwargs,
**kwargs: Any,
) -> None:
"""
:param pre_collect_process_func: Optional callable - the callable will get as an input a batch_dict or a dataframe and can preprocess it if required
Expand Down Expand Up @@ -371,7 +410,7 @@ class MetricDefault(MetricWithCollectorBase):
Can be used for any metric getting as an input list of prediction, list of targets and optionally additional parameters
"""

def __init__(self, metric_func: Callable, pred: Optional[str] = None, target: Optional[str] = None, **kwargs):
def __init__(self, metric_func: Callable, pred: Optional[str] = None, target: Optional[str] = None, **kwargs: Any):
"""
:param pred: prediction key to collect
:param target: target key to collect
Expand Down Expand Up @@ -409,7 +448,7 @@ class MetricPerSampleDefault(MetricWithCollectorBase):
"""

def __init__(
self, pred: str, target: str, metric_per_sample_func: Callable, result_aggregate_func: Callable, **kwargs
self, pred: str, target: str, metric_per_sample_func: Callable, result_aggregate_func: Callable, **kwargs: Any
):
"""
:param pred: prediction key to collect
Expand Down Expand Up @@ -442,7 +481,7 @@ class GroupAnalysis(MetricWithCollectorBase):
{'mean': <>, 'std': <>, 'median': <>, <group 0>: <>, <group 1>: <>, ...}
"""

def __init__(self, metric: MetricBase, group: str, **super_kwargs) -> None:
def __init__(self, metric: MetricBase, group: str, **super_kwargs: Any) -> None:
"""
:param metric: metric to analyze
:param group: key to extract the group from
Expand Down Expand Up @@ -524,7 +563,7 @@ class Filter(MetricWithCollectorBase):
Evaluate a sub-group of data. This utility will filter non relevant samples and will call to the given metric.
"""

def __init__(self, metric: MetricBase, filter: str, **super_kwargs) -> None:
def __init__(self, metric: MetricBase, filter: str, **super_kwargs: Any) -> None:
"""
:param metric: metric to filter samples for
:param group: key to extract filter
Expand Down Expand Up @@ -583,7 +622,7 @@ def __init__(
rnd_seed: int = 1234,
conf_interval: float = 95,
ci_method: str = "PERCENTILE",
**super_kwargs,
**super_kwargs: Any,
) -> None:
"""
:param metric: metric to compute the confidence interval for
Expand Down

0 comments on commit 2d47362

Please sign in to comment.