remove mandatory index key from output of metric_function in `DataAnalysis` map operation (microsoft#5112)

When performing the map operation required for curriculum learning, the output of `metric_function` requires an `index` field:
```python
    def update_metric_results(self, data, metric_types, metric_dtypes, metric_functions, metric_results):
        for m_idx in range(len(metric_types)):
            [...]
            if metric_type == 'single_value_per_sample':
                for row in range(metric_values.size()[0]):
                    metric_result["sample_to_metric_builder"].add_item(metric_values[row].reshape(-1))
                    metric_result["metric_to_sample_dict"][metric_values[row].item()].append(
                        data['index'][row][0].item())  # <------- data['index']??
```
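
For context, this is the burden that requirement places on callers: every sample coming out of the dataset must carry its own position so that the collated batch exposes `data['index']`. Below is a minimal sketch of such a dataset; it is hypothetical, written for this description, and only the `data['index'][row][0].item()` access pattern above is taken from the actual code.

```python
import torch
from torch.utils.data import Dataset

class IndexedDataset(Dataset):
    """Hypothetical dataset satisfying the old 'index' requirement: default
    collation stacks the per-sample shape-[1] tensors into a [batch_size, 1]
    tensor, which is what data['index'][row][0].item() expects."""

    def __init__(self, samples):
        self.samples = samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return {
            "text": self.samples[idx],
            "index": torch.tensor([idx]),  # mandatory bookkeeping, pre-PR
        }
```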

There is no mention of this in the documentation, which nowhere specifies that the output of `metric_function` should be a dict/DataFrame (?) with an `index` key/column. To make things worse, there is no way for a user to specify a proper `index` value for each sample, because the distribution of samples across workers/threads is not known to the user; it is done inside `DataAnalysis`:
```python
    def run_map_helper(self, thread_id):
        start_idx, end_idx = self.thread_splits[thread_id][0], \
            self.thread_splits[thread_id][1]
        logger.info(f"worker {self.worker_id} thread {thread_id}: start working " \
            f"on data subset {start_idx} to {end_idx}")
        thread_dataset = Subset(self.dataset, list(range(start_idx, end_idx)))
        sampler = BatchSampler(SequentialSampler(thread_dataset), batch_size=self.batch_size, drop_last=False)
```
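
For readers unfamiliar with `split_dataset`: the splits are contiguous, deterministic ranges, first per worker and then per thread within each worker. The sketch below restates that layout for illustration, assuming even contiguous chunking; it is not the actual `split_dataset` source.

```python
# Illustrative only: each worker owns a contiguous chunk of range(len(dataset)),
# and each thread owns a contiguous sub-chunk of its worker's chunk.
def contiguous_splits(total, parts):
    base, rem = divmod(total, parts)
    splits, start = [], 0
    for p in range(parts):
        end = start + base + (1 if p < rem else 0)
        splits.append((start, end))
        start = end
    return splits

num_workers, num_threads = 4, 2
worker_splits = contiguous_splits(10_000, num_workers)   # worker 0 -> (0, 2500)
w_start, w_end = worker_splits[0]
thread_splits = [(w_start + s, w_start + e)
                 for s, e in contiguous_splits(w_end - w_start, num_threads)]
```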

Since by design you picked a `SequentialSampler`, you know beforehand the global id of each sample of each batch of each thread of each worker by looking at
```python
self.worker_splits, self.thread_splits = split_dataset(self.dataset, self.num_workers, self.worker_id,
                                                               self.num_threads)
start_idx, end_idx = thread_splits[t_idx_reduce][0], thread_splits[t_idx_reduce][1]
```
and you can populate that index value correctly, instead of asking the
user to provide it.
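
Concretely: with a `SequentialSampler` and `BatchSampler(..., drop_last=False)` over the thread's contiguous subset, the global id is a closed-form function of the thread's `start_idx`, the batch size, and the position within the batch. A small sketch of that arithmetic (this is what the `batch_start_idx` bookkeeping in the diff below reconstructs):

```python
def global_sample_idx(start_idx, batch_size, batch_idx, row):
    """Global dataset index of `row` in batch `batch_idx` for a thread whose
    contiguous subset starts at `start_idx` (sequential sampling, no drop)."""
    return start_idx + batch_idx * batch_size + row

# e.g. a thread owning samples [1000, 2000) with batch_size=32:
assert global_sample_idx(1000, 32, 3, 5) == 1101
```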

This PR removes the need for an `'index'` key in `data` and instead uses the batch, thread, and worker ids to compute the global index of each sample.
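
After the change, a map run might look like the sketch below. It is hedged: only the constructor arguments visible in this commit are assumed, plus a `run_map()` entry point dispatching to the `run_map_helper` shown in the diff; `PAD_ID` and the toy dataset are invented for the example.

```python
import torch
from deepspeed.runtime.data_pipeline.data_sampling.data_analyzer import DataAnalyzer

PAD_ID = 0  # assumption for the example

def seqlen_metric(data):
    # 'single_value_per_sample': one scalar per sample in the batch
    return (data["text"] != PAD_ID).sum(dim=1)

# plain map-style dataset: no 'index' field needed anymore
dataset = [{"text": torch.tensor([5, 6, 0, 0])},
           {"text": torch.tensor([7, 8, 9, 0])}]

analyzer = DataAnalyzer(dataset,
                        batch_size=2,
                        metric_names=["seqlen"],
                        metric_functions=[seqlen_metric],
                        metric_types=["single_value_per_sample"],
                        metric_dtypes=[torch.int64],
                        sample_indices=None)  # or a permutation, if iteration is user-shuffled
analyzer.run_map()
```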
bm-synth authored and rraminen committed May 9, 2024
1 parent f148a25 commit f3059c7
Showing 1 changed file with 19 additions and 6 deletions: deepspeed/runtime/data_pipeline/data_sampling/data_analyzer.py
@@ -36,7 +36,8 @@ def __init__(self,
                  custom_map_init=None,
                  custom_map_update=None,
                  custom_map_finalize=None,
-                 custom_reduce=None):
+                 custom_reduce=None,
+                 sample_indices=None):
         super().__init__()
         self.dataset = dataset
         self.num_workers = num_workers
@@ -55,6 +56,7 @@ def __init__(self,
         self.custom_map_update = custom_map_update
         self.custom_map_finalize = custom_map_finalize
         self.custom_reduce = custom_reduce
+        self.sample_indices = sample_indices
 
     def init_metric_results(self, thread_id, metric_names, metric_types, metric_dtypes, save_path, worker_id):
         metric_results = []
@@ -82,7 +84,13 @@ def init_metric_results(self, thread_id, metric_names, metric_types, metric_dtypes, save_path, worker_id):
             metric_results.append({"metric_value": metric_value, "metric_value_fname": metric_value_fname})
         return metric_results
 
-    def update_metric_results(self, data, metric_types, metric_dtypes, metric_functions, metric_results):
+    def update_metric_results(self,
+                              data,
+                              metric_types,
+                              metric_dtypes,
+                              metric_functions,
+                              metric_results,
+                              batch_start_idx=0):
         for m_idx in range(len(metric_types)):
             metric_type, metric_dtype, metric_function, metric_result = metric_types[m_idx], \
                 metric_dtypes[m_idx], metric_functions[m_idx], metric_results[m_idx]
@@ -97,9 +105,13 @@ def update_metric_results(self, data, metric_types, metric_dtypes, metric_functions, metric_results):
 
             if metric_type == 'single_value_per_sample':
                 for row in range(metric_values.size()[0]):
+                    sample_idx = batch_start_idx + row  # sample idx following dataset iteration order
+                    if 'index' in data:  # Megatron use case, sample idx provided in 'index' field
+                        sample_idx = data['index'][row][0].item()
+                    elif self.sample_indices is not None:  # user defined shuffling of indices
+                        sample_idx = self.sample_indices[sample_idx]
                     metric_result["sample_to_metric_builder"].add_item(metric_values[row].reshape(-1))
-                    metric_result["metric_to_sample_dict"][metric_values[row].item()].append(
-                        data['index'][row][0].item())
+                    metric_result["metric_to_sample_dict"][metric_values[row].item()].append(sample_idx)
                 for m_value in metric_result["metric_to_sample_dict"]:
                     if len(metric_result["metric_to_sample_dict"][m_value]) > 100:
                         metric_fname = metric_result["metric_to_sample_fname"]
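
For readability, here is the precedence the hunk above introduces, restated as a standalone function (illustrative only; the real logic lives inline in `update_metric_results`): an explicit `data['index']` wins, then a user-supplied `sample_indices` remap, else the sequential position.

```python
def resolve_sample_idx(batch_start_idx, row, data, sample_indices=None):
    sample_idx = batch_start_idx + row   # default: dataset iteration order
    if 'index' in data:                  # Megatron case: index provided in the batch
        return data['index'][row][0].item()
    if sample_indices is not None:       # user-defined shuffling of indices
        return sample_indices[sample_idx]
    return sample_idx
```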
@@ -159,12 +171,13 @@ def run_map_helper(self, thread_id):
         while True:
             try:
                 data = next(iterator)
+                batch_start_idx = start_idx + processed_sample
                 if self.custom_map_update is None:
                     self.update_metric_results(data, self.metric_types, self.metric_dtypes, self.metric_functions,
-                                               metric_results)
+                                               metric_results, batch_start_idx)
                 else:
                     self.custom_map_update(data, self.metric_types, self.metric_dtypes, self.metric_functions,
-                                           metric_results)
+                                           metric_results, batch_start_idx)
                 processed_sample += self.batch_size
                 duration = (time.time() - start) / 3600.0
                 remain_duration = duration * total_sample / processed_sample - duration
