From 4282dba51d28f575e52a88178383da4691da998c Mon Sep 17 00:00:00 2001 From: arendu Date: Mon, 13 May 2024 18:28:37 +0000 Subject: [PATCH] Apply isort and black reformatting Signed-off-by: arendu --- .../megatron_gpt_embedding_model.py | 25 +++++++++++++------ 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/nemo/collections/nlp/models/information_retrieval/megatron_gpt_embedding_model.py b/nemo/collections/nlp/models/information_retrieval/megatron_gpt_embedding_model.py index dd701e7d1cfc..389c90d7f97c 100644 --- a/nemo/collections/nlp/models/information_retrieval/megatron_gpt_embedding_model.py +++ b/nemo/collections/nlp/models/information_retrieval/megatron_gpt_embedding_model.py @@ -61,7 +61,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer): if self.cfg.get("do_mrl", False): min_mrl = self.cfg.get("min_mrl_dim", int(np.log2(32))) - 1 max_mrl = int(np.log2(self.cfg.hidden_size // 2)) - self.mrl_dims = [2 ** i for i in range(max_mrl, min_mrl, -1)] + self.mrl_dims = [2**i for i in range(max_mrl, min_mrl, -1)] else: self.mrl_dims = [] @@ -262,7 +262,14 @@ def gather_and_maybe_write_predictions(self, output, data_cfg, mode, averaged_me gathered_output_batches = [None for _ in range(parallel_state.get_data_parallel_world_size())] torch.distributed.all_gather_object( gathered_output_batches, - [{'q_hs': batch['q_hs'], 'd_hs': batch['d_hs'], 'metadata': batch['metadata'],} for batch in output], + [ + { + 'q_hs': batch['q_hs'], + 'd_hs': batch['d_hs'], + 'metadata': batch['metadata'], + } + for batch in output + ], group=parallel_state.get_data_parallel_group(), ) @@ -279,7 +286,11 @@ def gather_and_maybe_write_predictions(self, output, data_cfg, mode, averaged_me l_d_hs = listify(batch['d_hs']) l_m = batch['metadata'] assert len(l_m) == len(l_q_hs) == len(l_d_hs) - for q_hs, d_hs, metadata in zip(l_q_hs, l_d_hs, l_m,): + for q_hs, d_hs, metadata in zip( + l_q_hs, + l_d_hs, + l_m, + ): total_size += 1 if not metadata.get("__AUTOGENERATED__", False): deduplicated_outputs['q_hs'].append(q_hs) @@ -333,10 +344,10 @@ def write_embeddings_to_file(self, outputs, output_file_path, d_idx): def local_validation_step(self, dataloader_iter): """ - Our dataloaders produce a micro-batch and then we fetch - a number of microbatches depending on the global batch size and model parallel size - from the dataloader to produce a list of microbatches. - The list of microbatches is then piped through the pipeline using megatron-core fwd/bwd functions. + Our dataloaders produce a micro-batch and then we fetch + a number of microbatches depending on the global batch size and model parallel size + from the dataloader to produce a list of microbatches. + The list of microbatches is then piped through the pipeline using megatron-core fwd/bwd functions. """ # Check if iterator is exhausted # dataloader_iter, done = self._val_iterator_done(dataloader_iter)