Apply isort and black reformatting
Signed-off-by: arendu <arendu@users.noreply.github.com>
arendu committed May 13, 2024
1 parent 27ee06e commit 4282dba
Showing 1 changed file with 18 additions and 7 deletions.
@@ -61,7 +61,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
 if self.cfg.get("do_mrl", False):
     min_mrl = self.cfg.get("min_mrl_dim", int(np.log2(32))) - 1
     max_mrl = int(np.log2(self.cfg.hidden_size // 2))
-    self.mrl_dims = [2 ** i for i in range(max_mrl, min_mrl, -1)]
+    self.mrl_dims = [2**i for i in range(max_mrl, min_mrl, -1)]
 else:
     self.mrl_dims = []
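The reformatted comprehension builds a descending ladder of embedding sizes, presumably for Matryoshka-style representation learning. A minimal standalone sketch of what it produces, assuming an illustrative hidden_size of 4096 (a value not fixed by this commit):

import numpy as np

hidden_size = 4096  # assumed example value, not taken from this commit
min_mrl = int(np.log2(32)) - 1            # -> 4
max_mrl = int(np.log2(hidden_size // 2))  # -> 11
mrl_dims = [2**i for i in range(max_mrl, min_mrl, -1)]
print(mrl_dims)  # [2048, 1024, 512, 256, 128, 64, 32]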

@@ -262,7 +262,14 @@ def gather_and_maybe_write_predictions(self, output, data_cfg, mode, averaged_me
 gathered_output_batches = [None for _ in range(parallel_state.get_data_parallel_world_size())]
 torch.distributed.all_gather_object(
     gathered_output_batches,
-    [{'q_hs': batch['q_hs'], 'd_hs': batch['d_hs'], 'metadata': batch['metadata'],} for batch in output],
+    [
+        {
+            'q_hs': batch['q_hs'],
+            'd_hs': batch['d_hs'],
+            'metadata': batch['metadata'],
+        }
+        for batch in output
+    ],
     group=parallel_state.get_data_parallel_group(),
 )
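This hunk only re-wraps the list comprehension that black expanded; the underlying call gathers every rank's prediction dicts across the data-parallel group. A single-process sketch of that pattern, assuming the gloo backend and a hard-coded rendezvous address (both illustrative, not taken from this commit):

import torch.distributed as dist

# Single-process stand-in for a data-parallel group (illustrative setup).
dist.init_process_group(
    backend="gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1
)

local_output = [{'q_hs': [0.1, 0.2], 'd_hs': [0.3, 0.4], 'metadata': {}}]
gathered = [None for _ in range(dist.get_world_size())]
dist.all_gather_object(gathered, local_output)
# gathered[r] now holds rank r's list of batch dicts on every rank.
dist.destroy_process_group()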

@@ -279,7 +286,11 @@ def gather_and_maybe_write_predictions(self, output, data_cfg, mode, averaged_me
 l_d_hs = listify(batch['d_hs'])
 l_m = batch['metadata']
 assert len(l_m) == len(l_q_hs) == len(l_d_hs)
-for q_hs, d_hs, metadata in zip(l_q_hs, l_d_hs, l_m,):
+for q_hs, d_hs, metadata in zip(
+    l_q_hs,
+    l_d_hs,
+    l_m,
+):
     total_size += 1
     if not metadata.get("__AUTOGENERATED__", False):
         deduplicated_outputs['q_hs'].append(q_hs)
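The re-wrapped loop filters out entries whose metadata carries the __AUTOGENERATED__ flag, presumably padding examples added so batches divide evenly across ranks. A self-contained toy run of that filter (values are illustrative):

l_q_hs = [[0.1], [0.2]]
l_d_hs = [[0.3], [0.4]]
l_m = [{}, {"__AUTOGENERATED__": True}]  # second entry mimics batch padding

kept = [
    (q, d, m)
    for q, d, m in zip(l_q_hs, l_d_hs, l_m)
    if not m.get("__AUTOGENERATED__", False)
]
print(len(kept))  # 1 -- the flagged entry is dropped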
@@ -333,10 +344,10 @@ def write_embeddings_to_file(self, outputs, output_file_path, d_idx):

 def local_validation_step(self, dataloader_iter):
     """
-    Our dataloaders produce a micro-batch and then we fetch
-    a number of microbatches depending on the global batch size and model parallel size
-    from the dataloader to produce a list of microbatches.
-    The list of microbatches is then piped through the pipeline using megatron-core fwd/bwd functions.
+    Our dataloaders produce a micro-batch and then we fetch
+    a number of microbatches depending on the global batch size and model parallel size
+    from the dataloader to produce a list of microbatches.
+    The list of microbatches is then piped through the pipeline using megatron-core fwd/bwd functions.
     """
     # Check if iterator is exhausted
     # dataloader_iter, done = self._val_iterator_done(dataloader_iter)
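The removed and added docstring lines are identical in content, so the change here is presumably only the trailing whitespace that black strips. For the behavior the docstring describes, here is a minimal sketch with a hypothetical helper name; the actual micro-batch count is computed by megatron-core from the global batch size and the parallel sizes:

def fetch_microbatches(dataloader_iter, num_microbatches):
    # Pull one global batch's worth of micro-batches from the iterator.
    return [next(dataloader_iter) for _ in range(num_microbatches)]

# Illustrative usage: the resulting list is handed to megatron-core's
# forward/backward function in place of the raw dataloader.
# microbatches = fetch_microbatches(dataloader_iter, num_microbatches=8)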
