From db0e21ae116e5d241252e4482bb8b9514645b979 Mon Sep 17 00:00:00 2001
From: Pete
Date: Mon, 19 Jul 2021 16:39:21 -0700
Subject: [PATCH] FairScale integration and T5-11B fine-tuning (#271)

* pass ddpwrapper
* add options to T5 model
* add weights_path param
* beam search as a parameter
* fix
* CHANGELOG
* add checkpoint_wrapper arg
* ignore missing weights in state dict if tied
* update
* add improved config
* update CHANGELOG
* address comments
* try fix dep
* try fix again
* revert
* fix config
* fix post load state dict hook
* rename 'ddp_wrapper' -> 'ddp_accelerator'
* fix
* update CHANGELOG
* revert CI patch
---
 CHANGELOG.md                                  |   6 +
 allennlp_models/generation/models/bart.py     |   2 +
 allennlp_models/generation/models/t5.py       |  41 +++++-
 .../seq_decoders/auto_regressive_test.py      |  12 +-
 training_config/generation/t5_cnn_dm.jsonnet  | 130 +++++++++++++-----
 5 files changed, 151 insertions(+), 40 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index c4630fe8a..356672f24 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,6 +9,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 Nothing so far
 
+### Added
+
+- Added some additional `__init__()` parameters to the `T5` model in `allennlp_models.generation` for customizing
+  beam search and other options.
+- Added a configuration file for fine-tuning `t5-11b` on CNN-DM (requires at least 8 GPUs).
+
 
 ## [v2.6.0](https://github.com/allenai/allennlp-models/releases/tag/v2.6.0) - 2021-07-19
 
diff --git a/allennlp_models/generation/models/bart.py b/allennlp_models/generation/models/bart.py
index 304a1d3e0..398b0091a 100644
--- a/allennlp_models/generation/models/bart.py
+++ b/allennlp_models/generation/models/bart.py
@@ -396,3 +396,5 @@ def get_metrics(self, reset: bool = False) -> Dict[str, float]:
             metrics.update(self._rouge.get_metric(reset=reset))
             metrics.update(self._bleu.get_metric(reset=reset))
         return metrics
+
+    default_predictor = "seq2seq"
diff --git a/allennlp_models/generation/models/t5.py b/allennlp_models/generation/models/t5.py
index 60eca585a..5f028c1df 100644
--- a/allennlp_models/generation/models/t5.py
+++ b/allennlp_models/generation/models/t5.py
@@ -1,23 +1,41 @@
-from typing import Optional, Dict, Any
+from os import PathLike
+from typing import Optional, Dict, Any, Union, List, Tuple
 
 from overrides import overrides
 import torch
 
+from allennlp.common.lazy import Lazy
 from allennlp.data import TextFieldTensors, Vocabulary
 from allennlp.data.tokenizers import PretrainedTransformerTokenizer
 from allennlp.models.model import Model
 from allennlp.modules.transformer.t5 import T5 as T5Module, T5Output, IntT, BoolT
+from allennlp.nn.beam_search import BeamSearch
+from allennlp.nn.checkpoint import CheckpointWrapper
 from allennlp.training.metrics import ROUGE, BLEU
 
 
 @Model.register("t5")
 class T5(Model):
-    def __init__(self, vocab: Vocabulary, model_name: str, **kwargs) -> None:
+    def __init__(
+        self,
+        vocab: Vocabulary,
+        model_name: str,
+        beam_search: Lazy[BeamSearch] = Lazy(BeamSearch, beam_size=3, max_steps=50),
+        checkpoint_wrapper: Optional[CheckpointWrapper] = None,
+        weights_path: Optional[Union[str, PathLike]] = None,
+        **kwargs
+    ) -> None:
         super().__init__(vocab, **kwargs)
         self._model_name = model_name
         # We only instantiate this when we need it.
         self._tokenizer: Optional[PretrainedTransformerTokenizer] = None
-        self.t5 = T5Module.from_pretrained_module(model_name)
+        self.t5 = T5Module.from_pretrained_module(
+            model_name,
+            beam_search=beam_search,
+            ddp_accelerator=self.ddp_accelerator,
+            checkpoint_wrapper=checkpoint_wrapper,
+            weights_path=weights_path,
+        )
 
         exclude_indices = {
             self.t5.pad_token_id,
@@ -29,6 +47,21 @@ def __init__(self, vocab: Vocabulary, model_name: str, **kwargs) -> None:
             BLEU(exclude_indices=exclude_indices),
         ]
 
+    @overrides
+    def _post_load_state_dict(
+        self, missing_keys: List[str], unexpected_keys: List[str]
+    ) -> Tuple[List[str], List[str]]:
+        missing_keys_to_ignore = [
+            "t5.encoder.token_embeddings.weight",
+            "t5.decoder.token_embeddings.weight",
+        ]
+        if self.t5._tie_word_embeddings:
+            missing_keys_to_ignore.append("t5.lm_head.weight")
+        for key in missing_keys_to_ignore:
+            if key in missing_keys:
+                missing_keys.remove(key)
+        return missing_keys, unexpected_keys
+
     @property
     def tokenizer(self) -> PretrainedTransformerTokenizer:
         if self._tokenizer is None:
@@ -117,3 +150,5 @@ def get_metrics(self, reset: bool = False) -> Dict[str, float]:
             for metric in self._metrics:
                 metrics.update(metric.get_metric(reset=reset))
         return metrics
+
+    default_predictor = "seq2seq"
diff --git a/tests/generation/modules/seq_decoders/auto_regressive_test.py b/tests/generation/modules/seq_decoders/auto_regressive_test.py
index 413712c61..962252cc5 100644
--- a/tests/generation/modules/seq_decoders/auto_regressive_test.py
+++ b/tests/generation/modules/seq_decoders/auto_regressive_test.py
@@ -84,7 +84,7 @@ def test_auto_regressive_seq_decoder_init(self):
             vocab,
             decoder_net,
             Embedding(num_embeddings=vocab.get_vocab_size(), embedding_dim=decoder_inout_dim),
-            beam_search=Lazy(BeamSearch, contructor_extras={"max_steps": 10}),
+            beam_search=Lazy(BeamSearch, constructor_extras={"max_steps": 10}),
         )
 
         with pytest.raises(ConfigurationError):
@@ -94,7 +94,7 @@ def test_auto_regressive_seq_decoder_init(self):
                 Embedding(
                     num_embeddings=vocab.get_vocab_size(), embedding_dim=decoder_inout_dim + 1
                 ),
-                beam_search=Lazy(BeamSearch, contructor_extras={"max_steps": 10}),
+                beam_search=Lazy(BeamSearch, constructor_extras={"max_steps": 10}),
             )
 
     def test_auto_regressive_seq_decoder_forward(self):
@@ -105,7 +105,7 @@ def test_auto_regressive_seq_decoder_forward(self):
             vocab,
             decoder_net,
             Embedding(num_embeddings=vocab.get_vocab_size(), embedding_dim=decoder_inout_dim),
-            beam_search=Lazy(BeamSearch, contructor_extras={"max_steps": 10, "beam_size": 4}),
+            beam_search=Lazy(BeamSearch, constructor_extras={"max_steps": 10, "beam_size": 4}),
         )
 
         encoded_state = torch.rand(batch_size, time_steps, decoder_inout_dim)
@@ -128,7 +128,7 @@ def test_auto_regressive_seq_decoder_indices_to_tokens(self):
             vocab,
             decoder_net,
             Embedding(num_embeddings=vocab.get_vocab_size(), embedding_dim=decoder_inout_dim),
-            beam_search=Lazy(BeamSearch, contructor_extras={"max_steps": 10}),
+            beam_search=Lazy(BeamSearch, constructor_extras={"max_steps": 10}),
         )
 
         predictions = torch.tensor([[3, 2, 5, 0, 0], [2, 2, 3, 5, 0]])
@@ -145,7 +145,7 @@ def test_auto_regressive_seq_decoder_post_process(self):
             vocab,
             decoder_net,
             Embedding(num_embeddings=vocab.get_vocab_size(), embedding_dim=decoder_inout_dim),
-            beam_search=Lazy(BeamSearch, contructor_extras={"max_steps": 10}),
+            beam_search=Lazy(BeamSearch, constructor_extras={"max_steps": 10}),
         )
 
         predictions = torch.tensor([[3, 2, 5, 0, 0], [2, 2, 3, 5, 0]])
@@ -169,7 +169,7 @@ def test_auto_regressive_seq_decoder_tensor_and_token_based_metric(self):
             vocab,
             decoder_net,
             Embedding(num_embeddings=vocab.get_vocab_size(), embedding_dim=decoder_inout_dim),
-            beam_search=Lazy(BeamSearch, contructor_extras={"max_steps": 10, "beam_size": 4}),
+            beam_search=Lazy(BeamSearch, constructor_extras={"max_steps": 10, "beam_size": 4}),
             tensor_based_metric=BLEU(),
             token_based_metric=DummyMetric(),
         ).eval()
diff --git a/training_config/generation/t5_cnn_dm.jsonnet b/training_config/generation/t5_cnn_dm.jsonnet
index 19f0c19b6..4364ccb68 100644
--- a/training_config/generation/t5_cnn_dm.jsonnet
+++ b/training_config/generation/t5_cnn_dm.jsonnet
@@ -1,49 +1,117 @@
-local model_name = "t5-small"; // TODO: change to large model
+// =================== Configurable Settings ======================
+
+local debug = true;
+
+local model_name = if debug then "t5-small" else "t5-11b";
+
+local batch_size_per_gpu = if debug then 4 else 1;
+
+// To train "t5-11b" you will probably need 8 GPUs.
+local num_gpus = 8;
+
+// This is probably necessary for t5-11b unless you have more than 8 GPUs.
+local activation_checkpointing = true;
+
+// Set to `false` if you want to skip validation.
+local validate = true;
+
+// AMP is currently unusably slow with t5-11b, which may be due to a bug within
+// FairScale, but I'm not sure yet.
+local use_amp = false;
+
+// These are reasonable defaults.
+local source_length = 512;
+local target_length = 54;
+
+// Set to `true` to log to Weights & Biases.
+local use_wandb = false;
+
+// ================================================================
+
+// ------ !! You probably don't need to edit below here !! --------
+
 local data_base_url = "https://storage.googleapis.com/allennlp-public-data/cnndm-combined-data-2020.07.13.tar.gz";
 local train_data = data_base_url + "!cnndm-combined-data-2020.07.13/url_lists/all_train.txt";
 local dev_data = data_base_url + "!cnndm-combined-data-2020.07.13/url_lists/all_val.txt";
 
-{
-    "train_data_path": train_data,
-    "validation_data_path": dev_data,
-    "dataset_reader": {
-        "type": "cnn_dm",
-        "source_tokenizer": {
+local dataset_reader = {
+    "type": "cnn_dm",
+    "source_tokenizer": {
+        "type": "pretrained_transformer",
+        "model_name": model_name,
+    },
+    "source_token_indexers": {
+        "tokens": {
             "type": "pretrained_transformer",
             "model_name": model_name,
-        },
-        "source_token_indexers": {
-            "tokens": {
-                "type": "pretrained_transformer",
-                "model_name": model_name,
-                "namespace": "tokens",
-            }
-        },
-        "source_max_tokens": 512,
-        "target_max_tokens": 54,
-        "source_prefix": "summarize: ",
-        "max_instances": 1000 // DEBUG setting
+            "namespace": "tokens",
+        }
+    },
+    "source_max_tokens": source_length,
+    "target_max_tokens": target_length,
+    "source_prefix": "summarize: ",
+};
+
+local data_loader = {
+    "batch_size": batch_size_per_gpu,
+    "shuffle": true,
+};
+
+local wandb_callback = {
+    "type": "wandb",
+    "project": "allennlp-t5",
+    "entity": "allenai-team1",
+    "watch_model": false,
+    "summary_interval": 1,
+    "should_log_parameter_statistics": false,
+    "should_log_learning_rate": false,
+};
+
+{
+    "train_data_path": train_data,
+    [if validate then "validation_data_path"]: dev_data,
+    "dataset_reader": dataset_reader + {
+        [if debug then "max_instances"]: batch_size_per_gpu * 40,
+    },
+    "validation_dataset_reader": dataset_reader + {
+        "max_instances": if debug then batch_size_per_gpu * 4 else batch_size_per_gpu * 10,
     },
     "model": {
         "type": "t5",
         "model_name": model_name,
+        "beam_search": {
+            "beam_size": 3,
+            "max_steps": if debug then 5 else 50,
+        },
+        [if activation_checkpointing then "checkpoint_wrapper"]: {
+            "type": "fairscale",
+            "offload_to_cpu": true,
+            "maintain_forward_counter": true,
+        },
+    },
+    "data_loader": data_loader + {
+        [if !debug then "max_instances_in_memory"]: batch_size_per_gpu * 128,
+        [if !debug then "num_workers"]: 1,
     },
-    "data_loader": {
-        "batch_size": 4,
-        "shuffle": true,
+    "validation_data_loader": data_loader,
+    "vocabulary": {
+        "type": "empty",
     },
     "trainer": {
+        "use_amp": use_amp,
+        [if use_amp then "grad_scaling"]: false, # TODO: use grad scaling once it's fixed in FairScale.
         "num_epochs": 3,
         "optimizer": {
-            "type": "huggingface_adamw",
-            "lr": 3e-5,
-            "betas": [0.9, 0.999],
-            "eps": 1e-8,
-            "correct_bias": true,
-        },
-        "learning_rate_scheduler": {
-            "type": "polynomial_decay",
+            "type": "huggingface_adafactor",
        },
         "grad_norm": 1.0,
-    }
+        [if use_wandb then "callbacks"]: [wandb_callback],
+    },
+    [if num_gpus > 1 then "distributed"]: {
+        "cuda_devices": std.range(0, num_gpus - 1),
+        "ddp_accelerator": {
+            "type": "fairscale_fsdp",
+            "mixed_precision": use_amp,
+        },
+    },
 }
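
Usage note (not part of the patch above): the sketch below shows one way the new `T5`
constructor arguments introduced here (`beam_search`, `checkpoint_wrapper`, `weights_path`)
might be used when building the model directly in Python rather than from a config file.
The vocabulary, model name, beam-search settings, and checkpoint-wrapper options are
illustrative placeholders that mirror the training config, not values prescribed by this
patch.

    # A minimal sketch under the assumptions stated above; running it will download
    # the pretrained "t5-small" weights.
    from allennlp.common.lazy import Lazy
    from allennlp.common.params import Params
    from allennlp.data import Vocabulary
    from allennlp.nn.beam_search import BeamSearch
    from allennlp.nn.checkpoint import CheckpointWrapper

    from allennlp_models.generation.models.t5 import T5

    model = T5(
        vocab=Vocabulary.empty(),
        model_name="t5-small",
        # Overrides the default of Lazy(BeamSearch, beam_size=3, max_steps=50).
        beam_search=Lazy(BeamSearch, beam_size=5, max_steps=100),
        # Activation checkpointing, mirroring the "fairscale" block in the config above.
        checkpoint_wrapper=CheckpointWrapper.from_params(
            Params({"type": "fairscale", "offload_to_cpu": True, "maintain_forward_counter": True})
        ),
    )

To train the full config, the standard CLI entry point applies, e.g.
`allennlp train training_config/generation/t5_cnn_dm.jsonnet -s /path/to/output --include-package allennlp_models`.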