Skip to content

Commit

Permalink
Fix/need different cache dirs for different datasets (#4640)
Browse files Browse the repository at this point in the history
* Add path to data directory to cache dir

Signed-off-by: PeganovAnton <peganoff2@mail.ru>

* Further improvements in cache dir

Signed-off-by: PeganovAnton <peganoff2@mail.ru>

* Improve features name for token classification

Signed-off-by: PeganovAnton <peganoff2@mail.ru>

* Fix code style

Signed-off-by: PeganovAnton <peganoff2@mail.ru>

Co-authored-by: ekmb <ebakhturina@nvidia.com>
  • Loading branch information
PeganovAnton and ekmb committed Aug 2, 2022
1 parent 45bfc84 commit aaeac3c
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -958,7 +958,9 @@ def __init__(
self.batch_building_progress_queue = batch_building_progress_queue

master_device = is_global_rank_zero()
self.features_pkl = self._get_path_to_pkl_features(self.text_file, cache_dir, max_seq_length, num_samples)
self.features_pkl = self._get_path_to_pkl_features(
self.text_file, self.labels_file, cache_dir, max_seq_length, num_samples
)
features = None
if master_device and not (self.features_pkl.is_file() and use_cache):
if verbose:
Expand Down Expand Up @@ -1036,15 +1038,20 @@ def __init__(
self.capit_label_frequencies = self._calculate_and_save_label_frequencies(self.capit_labels, 'capit')

def _get_path_to_pkl_features(
self, text_file: Path, cache_dir: Optional[Union[str, os.PathLike]], max_seq_length: int, num_samples: int
self,
text_file: Path,
labels_file: Path,
cache_dir: Optional[Union[str, os.PathLike]],
max_seq_length: int,
num_samples: int,
) -> Path:
if cache_dir is None:
cache_dir = text_file.parent
else:
cache_dir = Path(cache_dir).expanduser()
vocab_size = getattr(self.tokenizer, "vocab_size", 0)
features_pkl = cache_dir / "cached.{}.{}.max_seq_length{}.vocab{}.{}.punctuation_capitalization.pkl".format(
text_file.stem,
'__' + text_file.name + '__' + labels_file.name + '__',
self.tokenizer.name,
max_seq_length,
vocab_size,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -217,17 +217,16 @@ def __init__(
""" Initializes BertTokenClassificationDataset. """

data_dir = os.path.dirname(text_file)
filename = os.path.basename(text_file)
text_filename = os.path.basename(text_file)
lbl_filename = os.path.basename(label_file)

if not filename.endswith('.txt'):
if not text_filename.endswith('.txt'):
raise ValueError("{text_file} should have extension .txt")

vocab_size = getattr(tokenizer, "vocab_size", 0)
features_pkl = os.path.join(
data_dir,
"cached_{}_{}_{}_{}_{}".format(
filename, tokenizer.name, str(max_seq_length), str(vocab_size), str(num_samples)
),
f"cached__{text_filename}__{lbl_filename}__{tokenizer.name}_{max_seq_length}_{vocab_size}_{num_samples}",
)

master_device = is_global_rank_zero()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -841,6 +841,13 @@ def _setup_dataloader_from_config(self, cfg: DictConfig, train: bool) -> torch.u
number_of_batches_is_multiple_of = 1
else:
number_of_batches_is_multiple_of = self._trainer.num_nodes * self._trainer.num_devices
if cfg.cache_dir is None:
cache_dir = cfg.cache_dir
else:
                # If pickled features are saved to a `cache_dir` that is not the same directory as the original
                # data files, then the full path to the data directory has to be appended to `cache_dir`. This is
                # done to avoid collisions when caches for different datasets are saved to the same `cache_dir`.
cache_dir = Path(cfg.cache_dir).joinpath('fsroot', *text_file.expanduser().resolve().parts[1:-1])
dataset = BertPunctuationCapitalizationDataset(
tokenizer=self.tokenizer,
text_file=text_file,
Expand All @@ -858,7 +865,7 @@ def _setup_dataloader_from_config(self, cfg: DictConfig, train: bool) -> torch.u
batch_shuffling_random_seed=batch_shuffling_random_seed,
verbose=cfg.verbose,
get_label_frequencies=cfg.get_label_frequences,
cache_dir=cfg.cache_dir,
cache_dir=cache_dir,
label_info_save_dir=cfg.label_info_save_dir,
)
if cfg.shuffle and cfg.use_tarred_dataset:
Expand Down

0 comments on commit aaeac3c

Please sign in to comment.