Skip to content

Commit

Permalink
Merge branch 'lmcafee/retro-jul23' into 'main'
Browse files Browse the repository at this point in the history
Retro updates

See merge request ADLR/megatron-lm!676
  • Loading branch information
jaredcasper committed Jul 14, 2023
2 parents 156533e + e4d3995 commit 040eac9
Show file tree
Hide file tree
Showing 8 changed files with 207 additions and 295 deletions.
30 changes: 20 additions & 10 deletions megatron/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,17 +358,27 @@ def validate_args(args, defaults={}):
if not args.add_bias_linear:
args.bias_gelu_fusion = False

# Load retro args.
if args.retro_workdir:
# Retro checks.
if args.retro_add_retriever:

# Sequence parallelism unsupported.
assert not args.sequence_parallel, \
"retro currently does not support sequence parallelism."

# Pipeline parallelism unsupported.
assert args.pipeline_model_parallel_size == 1, \
"retro currently does not support pipeline parallelism."

# Load retro args.
retro_args_path = get_retro_args_path(args.retro_workdir)
if os.path.exists(retro_args_path):
with open(retro_args_path) as f:
retro_args = types.SimpleNamespace(**json.load(f))
retro_args.retro_return_doc_ids = args.retro_return_doc_ids
retro_args.retro_gpt_retrieved_length = \
args.retro_num_retrieved_chunks * \
retro_args.retro_gpt_chunk_length
set_retro_args(retro_args)
assert os.path.exists(retro_args_path), "retro workdir missing args.json"
with open(retro_args_path) as f:
retro_args = types.SimpleNamespace(**json.load(f))
retro_args.retro_return_doc_ids = args.retro_return_doc_ids
retro_args.retro_gpt_retrieved_length = \
args.retro_num_retrieved_chunks * \
retro_args.retro_gpt_chunk_length
set_retro_args(retro_args)

# Legacy RoPE arguments
if args.use_rotary_position_embeddings:
Expand Down
3 changes: 1 addition & 2 deletions megatron/model/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -785,8 +785,7 @@ def __init__(self, config,
# Retriever (bi-directional transformer with cross attention)
if layer_type == LayerType.retro_decoder_with_retriever:
self.retriever = ParallelTransformer(
init_method,
output_layer_init_method,
config=config,
model_type=ModelType.retro_encoder,
self_attn_mask_type=AttnMaskType.padding,
pre_process=True,
Expand Down
3 changes: 3 additions & 0 deletions tools/bert_embedding/embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from megatron import get_args, get_tokenizer, print_rank_0
from megatron import core
from megatron.arguments import core_transformer_config_from_args
from megatron.core.enums import ModelType
from megatron.core.pipeline_parallel import get_forward_backward_func
from megatron.model import BertModel
Expand All @@ -28,8 +29,10 @@ def model_provider(pre_process=True, post_process=True):
print_rank_0(" > build Bert model.")

args = get_args()
config = core_transformer_config_from_args(args)
num_tokentypes = 2 if args.bert_binary_head else 0
model = BertModel(
config=config,
num_tokentypes=num_tokentypes,
add_binary_head=args.bert_binary_head,
parallel_output=True,
Expand Down
13 changes: 5 additions & 8 deletions tools/retro/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,11 @@ The following overview goes into more detail on the pipeline, code structure, us
<!-- ################ quick start ################ -->
# Quick start

See `examples/get_preprocess_cmd.sh` for example arguments.

Key files:

- `main.py` : Entry point.
- `examples/get_preprocess_cmd.sh` : Build preprocessing command (for `main.py`).
- `examples/preprocess_data.sh` : Run preprocessing (calls `get_preprocess_cmd.sh`, `main.py`).
- `main.py` : Entry point for processing.
- `examples/preprocess_data.sh` : Example preprocessing launch (calls `main.py`).
- `examples/pretrain_model.sh` : Example pretraining launch (calls `pretrain_retro.py`).

Use `--retro-tasks` to move through the preprocessing pipeline.

Expand Down Expand Up @@ -86,9 +84,8 @@ Multiple tasks can be specified by separating with commas (e.g., `--retro-tasks

Example scripts for setting arguments and launching Retro preprocessing. The key files here are:

- **`get_preprocess_cmd.sh`** : Sets up arguments and command for preprocessing. **Important note**: this script assumes a few environment variables are already set before it is called. Please see the `Environment vars.` section at the top of this file. Generally, environment variables must be set to determine the location of Retro workdirs, input datasets, and GPT and Bert model information.
- **`preprocess_data.sh`** : Calls `get_preprocess_cmd.sh` to get arguments, and then calls `main.py` to launch preprocessing.
- **`pretrain_model.sh`** : Example script for pretraining on Wikipedia data, after preprocessing is complete.
- **`preprocess_data.sh`** : Example launch script for preprocessing retro data.
- **`pretrain_model.sh`** : Example launch script for pretraining a retro model.

### `tools/retro/db`

Expand Down
43 changes: 0 additions & 43 deletions tools/retro/examples/get_dataset_configs.sh

This file was deleted.

137 changes: 0 additions & 137 deletions tools/retro/examples/get_preprocess_cmd.sh

This file was deleted.

Loading

0 comments on commit 040eac9

Please sign in to comment.