Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updated notebook to fix batch configuration and precision bugs #4447

Merged
merged 6 commits into from
Jun 24, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 69 additions & 17 deletions tutorials/nlp/Multitask_Prompt_and_PTuning.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
"metadata": {},
"outputs": [],
"source": [
"BRANCH=\"main\""
"BRANCH=\"r1.10.0\""
]
},
{
Expand Down Expand Up @@ -848,7 +848,7 @@
"os.environ[\"RANK\"] = '0'\n",
"os.environ[\"WORLD_SIZE\"] = '1'\n",
"\n",
"plugins = [NLPDDPPlugin(find_unused_parameters=False), TorchElasticEnvironment()]\n",
"plugins = [NLPDDPPlugin(find_unused_parameters=False, no_ddp_communication_hook=True), TorchElasticEnvironment()]\n",
"trainer = pl.Trainer(plugins=plugins, **config.trainer)\n",
"\n",
"print(\"Trainer config - \\n\")\n",
Expand Down Expand Up @@ -901,7 +901,7 @@
"source": [
"# Set some of the learning parameters\n",
"config.model.optim.lr = 1e-4\n",
"config.model.batch_size = 16"
"config.model.precision = config.trainer.precision"
]
},
{
Expand Down Expand Up @@ -1009,7 +1009,9 @@
"cell_type": "code",
"execution_count": null,
"id": "74a5a358",
"metadata": {},
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"response = model.generate(inputs=test_examples, length_params=None)\n",
Expand All @@ -1032,15 +1034,27 @@
"We need to update:\n",
"\n",
"1. `name`\n",
"3. `model.restore_path`\n",
"5. `model.existing_tasks`\n",
"6. `model.new_tasks`\n",
"7. `model.data.train_ds`\n",
"8. `model.data.validation_ds`\n",
"2. `model.restore_path`\n",
"3. `model.existing_tasks`\n",
"4. `model.new_tasks`\n",
"5. `model.virtual_prompt_style`\n",
"6. `model.data.train_ds`\n",
"7. `model.data.validation_ds`\n",
"\n",
"Remember that we already set `task_templates` for SQuAD when we were defining the task template for the other two tasks. We would add it here if we had not already set it above."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b5ec279d",
"metadata": {},
"outputs": [],
"source": [
"# Change the experiment name\n",
"config.name = 'squad_p_tuning'"
]
},
{
"cell_type": "markdown",
"id": "6adb09a3",
Expand All @@ -1052,13 +1066,10 @@
{
"cell_type": "code",
"execution_count": null,
"id": "b5ec279d",
"id": "2e196967",
"metadata": {},
"outputs": [],
"source": [
"# Change the experiment name\n",
"config.name = 'squad_p_tuning'\n",
"\n",
"# Change restore path from null to the p-tuned model we just finished training\n",
"config.model.restore_path = \"multitask_p_tuned_gpt.nemo\"\n",
"\n",
Expand All @@ -1067,6 +1078,25 @@
"config.model.new_tasks = [\"squad\"]"
]
},
{
"cell_type": "markdown",
"id": "4dc088ec",
"metadata": {},
"source": [
    "After the first round of p-tuning finishes, the ``virtual_prompt_style`` is automatically set to ``inference`` at the end of training. This is done so the prompt learning model is ready for inference as soon as training completes. For the second round of p-tuning, we need to set ``virtual_prompt_style`` back to ``p-tuning``."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c49128a1",
"metadata": {},
"outputs": [],
"source": [
"# Reset virtual prompt style to \"p-tuning\" from \"inference\"\n",
"config.model.virtual_prompt_style = \"p-tuning\""
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down Expand Up @@ -1102,18 +1132,40 @@
"# Limiting the number of validation batches for the sake of time\n",
"config.trainer.limit_val_batches = 100\n",
"\n",
"# Adjust learning rate for the task\n",
"config.model.optim.lr = 5e-4\n",
"config.model.optim.sched.min_lr = 1e-5\n",
"config.model.batch_size = 4\n",
"\n",
"# Reset the trainer\n",
"plugins = [NLPDDPPlugin(find_unused_parameters=False), TorchElasticEnvironment()]\n",
"plugins = [NLPDDPPlugin(find_unused_parameters=False, no_ddp_communication_hook=True), TorchElasticEnvironment()]\n",
"trainer = pl.Trainer(plugins=plugins, **config.trainer)\n",
"\n",
"print(\"Trainer config - \\n\")\n",
"print(OmegaConf.to_yaml(config.trainer))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0ac21b0c",
"metadata": {},
"outputs": [],
"source": [
"from apex.transformer import parallel_state\n",
"from apex.transformer.pipeline_parallel.utils import _reconfigure_microbatch_calculator\n",
"from nemo.utils import AppState\n",
"\n",
"app_state = AppState()\n",
"\n",
"# Need to reconfigure micro batch calculator with apex for new p-tuning session\n",
"_reconfigure_microbatch_calculator(\n",
" rank=app_state.global_rank,\n",
" rampup_batch_size=None,\n",
" global_batch_size=config.model.global_batch_size,\n",
" micro_batch_size=config.model.micro_batch_size,\n",
" data_parallel_size=parallel_state.get_data_parallel_world_size(),\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down Expand Up @@ -1145,7 +1197,7 @@
"execution_count": null,
"id": "1b3d95f1",
"metadata": {
"scrolled": true
"scrolled": false
},
"outputs": [],
"source": [
Expand Down