Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
remove mandatory
index
key from output of metric_function
in `Dat…
…aAnalysis` map operation (microsoft#5112) When performing the map operation required for the curriculum learning, the output of `metric_function` requires an `index` field: ``` def update_metric_results(self, data, metric_types, metric_dtypes, metric_functions, metric_results): for m_idx in range(len(metric_types)): [...] if metric_type == 'single_value_per_sample': for row in range(metric_values.size()[0]): metric_result["sample_to_metric_builder"].add_item(metric_values[row].reshape(-1)) metric_result["metric_to_sample_dict"][metric_values[row].item()].append( data['index'][row][0].item()). ##<------- data['index']?? ``` There is no mention to this in the documentation, where it specifies that the output of `metric_function` should be a dict/DataFrame (?) with an `index` key/column. To makes things worse, on top of that, there is no way for an user to be able to specify a proper `index` value for each sample, because the distribution of samples across workers/threads is not know, as it's done inside `DataAnalysis`: ``` def run_map_helper(self, thread_id): start_idx, end_idx = self.thread_splits[thread_id][0], \ self.thread_splits[thread_id][1] logger.info(f"worker {self.worker_id} thread {thread_id}: start working " \ f"on data subset {start_idx} to {end_idx}") thread_dataset = Subset(self.dataset, list(range(start_idx, end_idx))) sampler = BatchSampler(SequentialSampler(thread_dataset), batch_size=self.batch_size, drop_last=False) ``` Since by design you picked a `SequentialSampler`, then you know beforehand the global id of each each sample of each batch of each thread of each worker by looking at ``` self.worker_splits, self.thread_splits = split_dataset(self.dataset, self.num_workers, self.worker_id, self.num_threads) start_idx, end_idx = thread_splits[t_idx_reduce][0], thread_splits[t_idx_reduce][1] ``` and you can populate that index value correctly, instead of asking the user to provide it. This PR removes the need for `'index'` key in `data` and uses instead the batch, thread, and worker ids to compute the global index of each sample.
- Loading branch information