Fix multiprocess download (#383)
* fix multiprocess download

* simplify code

* another idea

* update

* Fix multiprocess _get_data using status file and atexit hook.

* Register the remove hook of _get_data in all processes to make sure the hook executes.

* use try except

* Move the atexit hook registration earlier.

* Move hook registration to the front and register together.

* Fix: move hook registration to the front and register together.

* Move lock files into a list so they can be removed together.

* Fix OSError breaking the loop.

* Refine nested try-catch.

Co-authored-by: guosheng <guosheng@baidu.com>
FrostML and guoshengCS authored May 20, 2021
1 parent b7dd5ce commit 7ac9971
Showing 1 changed file with 49 additions and 7 deletions.
56 changes: 49 additions & 7 deletions paddlenlp/datasets/dataset.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import atexit
import collections
import io
import math
@@ -20,11 +21,12 @@
import sys
import inspect
from multiprocess import Pool, RLock
import time

import paddle.distributed as dist
from paddle.io import Dataset, IterableDataset
from paddle.dataset.common import md5file
from paddle.utils.download import get_path_from_url
from paddle.utils.download import get_path_from_url, _get_unique_endpoints
from paddlenlp.utils.env import DATA_HOME
from typing import Iterable, Iterator, Optional, List, Any, Callable, Union
import importlib
@@ -494,19 +496,59 @@ def read_datasets(self, splits=None, data_files=None):
for split, filename in data_files.items()
]

def remove_if_exit(filepath):
if isinstance(filepath, (list, tuple)):
for file in filepath:
try:
os.remove(file)
except OSError:
pass
else:
try:
os.remove(filepath)
except OSError:
pass

if splits:
assert isinstance(splits, str) or (
isinstance(splits, list) and isinstance(splits[0], str)
) or (
isinstance(splits, tuple) and isinstance(splits[0], str)
), "`splits` should be a string or list of string or a tuple of string."
if isinstance(splits, str):
filename = self._get_data(splits)
datasets.append(self.read(filename=filename, split=splits))
else:
for split in splits:
filename = self._get_data(split)
datasets.append(self.read(filename=filename, split=split))
splits = [splits]
parallel_env = dist.ParallelEnv()
unique_endpoints = _get_unique_endpoints(
parallel_env.trainer_endpoints[:])
# Move the hook registration to the front and register all hooks together.
lock_files = []
for split in splits:
lock_file = os.path.join(DATA_HOME, self.__class__.__name__)
if self.name is not None:
lock_file = lock_file + "." + self.name
lock_file += "." + split + ".done" + "." + str(os.getppid())
lock_files.append(lock_file)
# Must register in all procs so that the lock file can be removed
# when any proc breaks. Otherwise, the single registered proc may
# not receive the proper signal sent by the parent proc to exit.
atexit.register(lambda: remove_if_exit(lock_files))
for split in splits:
filename = self._get_data(split)
lock_file = os.path.join(DATA_HOME, self.__class__.__name__)
if self.name is not None:
lock_file = lock_file + "." + self.name
lock_file += "." + split + ".done" + "." + str(os.getppid())
# `lock_file` indicates the finished status of `_get_data`.
# `_get_data` only runs in the proc specified by `unique_endpoints`,
# since `get_path_from_url` only works for it. The other procs
# wait for `_get_data` to finish.
if parallel_env.current_endpoint in unique_endpoints:
f = open(lock_file, "w")
f.close()
else:
while not os.path.exists(lock_file):
time.sleep(1)
datasets.append(self.read(filename=filename, split=split))

return datasets if len(datasets) > 1 else datasets[0]
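
For orientation, the coordination pattern applied in this change can be summarized in a minimal standalone sketch: the designated process performs the download and then creates a marker file, the other processes poll for that file, and every process registers an atexit cleanup so a crash does not leave a stale marker behind. The sketch below is illustrative only; `coordinate_download`, `marker_path`, and `download_fn` are hypothetical names, not PaddleNLP APIs.

import atexit
import os
import time

def coordinate_download(marker_path, is_downloader, download_fn, poll_interval=1):
    # Illustrative sketch of the lock-file coordination pattern; names are hypothetical.

    def remove_marker():
        # Registered in every process so the marker is cleaned up even if
        # one of them exits abnormally.
        try:
            os.remove(marker_path)
        except OSError:
            pass

    atexit.register(remove_marker)

    if is_downloader:
        download_fn()
        # Signal completion to the waiting processes.
        with open(marker_path, "w"):
            pass
    else:
        # Wait until the downloading process has written the marker.
        while not os.path.exists(marker_path):
            time.sleep(poll_interval)

As in the commit, the cleanup hook is registered in every process rather than only in the one that creates the marker, so an abnormal exit in any process still removes the stale file.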

