This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

FairScale integration and T5-11B fine-tuning (#271)
* pass ddpwrapper

* add options to T5 model

* add weights_path param

* beam search as a parameter

* fix

* CHANGELOG

* add checkpoint_wrapper arg

* ignore missing weights in state dict if tied

* update

* add improved config

* update CHANGELOG

* address comments

* try fix dep

* try fix again

* revert

* fix config

* fix post load state dict hook

* rename 'ddp_wrapper' -> 'ddp_accelerator'

* fix

* update CHANGELOG

* revert CI patch
epwalsh committed Jul 19, 2021
1 parent 5dc2cf6 commit db0e21a
Showing 5 changed files with 151 additions and 40 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -9,6 +9,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

Nothing so far

### Added

- Added some additional `__init__()` parameters to the `T5` model in `allennlp_models.generation` for customizing
beam search and other options.
- Added a configuration file for fine-tuning `t5-11b` on CNN-DM (requires at least 8 GPUs).


## [v2.6.0](https://github.com/allenai/allennlp-models/releases/tag/v2.6.0) - 2021-07-19

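Before the code diffs, a rough, non-authoritative sketch of how the `__init__()` additions described in the CHANGELOG entry above can be used from Python. The argument values are arbitrary examples and `Vocabulary.empty()` stands in for a real vocabulary:

```python
from allennlp.common.lazy import Lazy
from allennlp.data import Vocabulary
from allennlp.nn.beam_search import BeamSearch
from allennlp_models.generation.models.t5 import T5

# Illustrative values only; see the t5.py diff below for the actual defaults.
model = T5(
    vocab=Vocabulary.empty(),
    model_name="t5-small",
    # Beam search is now configurable instead of hard-coded.
    beam_search=Lazy(BeamSearch, beam_size=3, max_steps=50),
    # Optionally load weights from a local state dict instead of downloading them.
    weights_path=None,
    # Optionally wrap transformer blocks for activation checkpointing
    # (e.g. the "fairscale" CheckpointWrapper used in the training config below).
    checkpoint_wrapper=None,
)
```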
2 changes: 2 additions & 0 deletions allennlp_models/generation/models/bart.py
@@ -396,3 +396,5 @@ def get_metrics(self, reset: bool = False) -> Dict[str, float]:
metrics.update(self._rouge.get_metric(reset=reset))
metrics.update(self._bleu.get_metric(reset=reset))
return metrics

default_predictor = "seq2seq"
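The one-line change above (mirrored in `t5.py` below) declares a default predictor on the model class, so a trained archive can be loaded without naming a predictor explicitly. A hedged sketch; the archive path is a placeholder:

```python
from allennlp.common.util import import_module_and_submodules
from allennlp.predictors import Predictor

# Ensure the allennlp-models registrations (the "bart"/"t5" models, "cnn_dm" reader, etc.)
# are imported, then let `default_predictor = "seq2seq"` pick the predictor for us.
import_module_and_submodules("allennlp_models")
predictor = Predictor.from_path("/path/to/trained-model.tar.gz")  # placeholder path
print(predictor.predict_json({"source": "summarize: The quick brown fox jumps over the lazy dog."}))
```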
41 changes: 38 additions & 3 deletions allennlp_models/generation/models/t5.py
@@ -1,23 +1,41 @@
from typing import Optional, Dict, Any
from os import PathLike
from typing import Optional, Dict, Any, Union, List, Tuple

from overrides import overrides
import torch

from allennlp.common.lazy import Lazy
from allennlp.data import TextFieldTensors, Vocabulary
from allennlp.data.tokenizers import PretrainedTransformerTokenizer
from allennlp.models.model import Model
from allennlp.modules.transformer.t5 import T5 as T5Module, T5Output, IntT, BoolT
from allennlp.nn.beam_search import BeamSearch
from allennlp.nn.checkpoint import CheckpointWrapper
from allennlp.training.metrics import ROUGE, BLEU


@Model.register("t5")
class T5(Model):
def __init__(self, vocab: Vocabulary, model_name: str, **kwargs) -> None:
def __init__(
self,
vocab: Vocabulary,
model_name: str,
beam_search: Lazy[BeamSearch] = Lazy(BeamSearch, beam_size=3, max_steps=50),
checkpoint_wrapper: Optional[CheckpointWrapper] = None,
weights_path: Optional[Union[str, PathLike]] = None,
**kwargs
) -> None:
super().__init__(vocab, **kwargs)
self._model_name = model_name
# We only instantiate this when we need it.
self._tokenizer: Optional[PretrainedTransformerTokenizer] = None
self.t5 = T5Module.from_pretrained_module(model_name)
self.t5 = T5Module.from_pretrained_module(
model_name,
beam_search=beam_search,
ddp_accelerator=self.ddp_accelerator,
checkpoint_wrapper=checkpoint_wrapper,
weights_path=weights_path,
)

exclude_indices = {
self.t5.pad_token_id,
@@ -29,6 +47,21 @@ def __init__(self, vocab: Vocabulary, model_name: str, **kwargs) -> None:
BLEU(exclude_indices=exclude_indices),
]

@overrides
def _post_load_state_dict(
self, missing_keys: List[str], unexpected_keys: List[str]
) -> Tuple[List[str], List[str]]:
missing_keys_to_ignore = [
"t5.encoder.token_embeddings.weight",
"t5.decoder.token_embeddings.weight",
]
if self.t5._tie_word_embeddings:
missing_keys_to_ignore.append("t5.lm_head.weight")
for key in missing_keys_to_ignore:
if key in missing_keys:
missing_keys.remove(key)
return missing_keys, unexpected_keys

@property
def tokenizer(self) -> PretrainedTransformerTokenizer:
if self._tokenizer is None:
@@ -117,3 +150,5 @@ def get_metrics(self, reset: bool = False) -> Dict[str, float]:
for metric in self._metrics:
metrics.update(metric.get_metric(reset=reset))
return metrics

default_predictor = "seq2seq"
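A plain-Python restatement of what the `_post_load_state_dict()` hook above does, included only as an illustration: when the encoder, decoder, and LM head share one embedding matrix, a saved state dict stores that tensor once, so the "missing" tied keys are expected and should not be reported as load errors.

```python
# Illustrative only; not the library code. One hypothetical extra key is included
# to show that genuinely missing weights still get reported.
missing_keys = [
    "t5.encoder.token_embeddings.weight",
    "t5.decoder.token_embeddings.weight",
    "t5.lm_head.weight",
    "t5.decoder.final_layer_norm.weight",  # hypothetical, genuinely missing
]
tie_word_embeddings = True

keys_to_ignore = [
    "t5.encoder.token_embeddings.weight",
    "t5.decoder.token_embeddings.weight",
]
if tie_word_embeddings:
    keys_to_ignore.append("t5.lm_head.weight")

print([k for k in missing_keys if k not in keys_to_ignore])
# -> ['t5.decoder.final_layer_norm.weight']
```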
12 changes: 6 additions & 6 deletions tests/generation/modules/seq_decoders/auto_regressive_test.py
@@ -84,7 +84,7 @@ def test_auto_regressive_seq_decoder_init(self):
vocab,
decoder_net,
Embedding(num_embeddings=vocab.get_vocab_size(), embedding_dim=decoder_inout_dim),
beam_search=Lazy(BeamSearch, contructor_extras={"max_steps": 10}),
beam_search=Lazy(BeamSearch, constructor_extras={"max_steps": 10}),
)

with pytest.raises(ConfigurationError):
@@ -94,7 +94,7 @@ def test_auto_regressive_seq_decoder_init(self):
Embedding(
num_embeddings=vocab.get_vocab_size(), embedding_dim=decoder_inout_dim + 1
),
beam_search=Lazy(BeamSearch, contructor_extras={"max_steps": 10}),
beam_search=Lazy(BeamSearch, constructor_extras={"max_steps": 10}),
)

def test_auto_regressive_seq_decoder_forward(self):
@@ -105,7 +105,7 @@ def test_auto_regressive_seq_decoder_forward(self):
vocab,
decoder_net,
Embedding(num_embeddings=vocab.get_vocab_size(), embedding_dim=decoder_inout_dim),
beam_search=Lazy(BeamSearch, contructor_extras={"max_steps": 10, "beam_size": 4}),
beam_search=Lazy(BeamSearch, constructor_extras={"max_steps": 10, "beam_size": 4}),
)

encoded_state = torch.rand(batch_size, time_steps, decoder_inout_dim)
@@ -128,7 +128,7 @@ def test_auto_regressive_seq_decoder_indices_to_tokens(self):
vocab,
decoder_net,
Embedding(num_embeddings=vocab.get_vocab_size(), embedding_dim=decoder_inout_dim),
beam_search=Lazy(BeamSearch, contructor_extras={"max_steps": 10}),
beam_search=Lazy(BeamSearch, constructor_extras={"max_steps": 10}),
)

predictions = torch.tensor([[3, 2, 5, 0, 0], [2, 2, 3, 5, 0]])
@@ -145,7 +145,7 @@ def test_auto_regressive_seq_decoder_post_process(self):
vocab,
decoder_net,
Embedding(num_embeddings=vocab.get_vocab_size(), embedding_dim=decoder_inout_dim),
beam_search=Lazy(BeamSearch, contructor_extras={"max_steps": 10}),
beam_search=Lazy(BeamSearch, constructor_extras={"max_steps": 10}),
)

predictions = torch.tensor([[3, 2, 5, 0, 0], [2, 2, 3, 5, 0]])
@@ -169,7 +169,7 @@ def test_auto_regressive_seq_decoder_tensor_and_token_based_metric(self):
vocab,
decoder_net,
Embedding(num_embeddings=vocab.get_vocab_size(), embedding_dim=decoder_inout_dim),
beam_search=Lazy(BeamSearch, contructor_extras={"max_steps": 10, "beam_size": 4}),
beam_search=Lazy(BeamSearch, constructor_extras={"max_steps": 10, "beam_size": 4}),
tensor_based_metric=BLEU(),
token_based_metric=DummyMetric(),
).eval()
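The test edits above only correct the misspelled `constructor_extras` keyword. For context, a hedged sketch of the `Lazy` pattern those tests rely on: extras are stored up front and merged with whatever the owning object supplies when it finally constructs the beam search. The `end_index` value here is an arbitrary example:

```python
from allennlp.common.lazy import Lazy
from allennlp.nn.beam_search import BeamSearch

# Defer construction: beam-search options now, the end-of-sequence index later.
lazy_beam_search = Lazy(BeamSearch, constructor_extras={"max_steps": 10, "beam_size": 4})
beam_search = lazy_beam_search.construct(end_index=0)  # 0 is an arbitrary example index
```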
130 changes: 99 additions & 31 deletions training_config/generation/t5_cnn_dm.jsonnet
@@ -1,49 +1,117 @@
local model_name = "t5-small"; // TODO: change to large model
// =================== Configurable Settings ======================

local debug = true;

local model_name = if debug then "t5-small" else "t5-11b";

local batch_size_per_gpu = if debug then 4 else 1;

// To train "t5-11b" you will probably need 8 GPUs.
local num_gpus = 8;

// This is probably necessary for t5-11b unless you have more than 8 GPUs.
local activation_checkpointing = true;

// Set to `false` if you want to skip validation.
local validate = true;

// AMP is currently unusably slow with t5-11b, which may be due to a bug within
// FairScale, but I'm not sure yet.
local use_amp = false;

// These are reasonable defaults.
local source_length = 512;
local target_length = 54;

// Set to `true` to log to Weights & Biases.
local use_wandb = false;

// ================================================================

// ------ !! You probably don't need to edit below here !! --------

local data_base_url = "https://storage.googleapis.com/allennlp-public-data/cnndm-combined-data-2020.07.13.tar.gz";
local train_data = data_base_url + "!cnndm-combined-data-2020.07.13/url_lists/all_train.txt";
local dev_data = data_base_url + "!cnndm-combined-data-2020.07.13/url_lists/all_val.txt";

{
"train_data_path": train_data,
"validation_data_path": dev_data,
"dataset_reader": {
"type": "cnn_dm",
"source_tokenizer": {
local dataset_reader = {
"type": "cnn_dm",
"source_tokenizer": {
"type": "pretrained_transformer",
"model_name": model_name,
},
"source_token_indexers": {
"tokens": {
"type": "pretrained_transformer",
"model_name": model_name,
},
"source_token_indexers": {
"tokens": {
"type": "pretrained_transformer",
"model_name": model_name,
"namespace": "tokens",
}
},
"source_max_tokens": 512,
"target_max_tokens": 54,
"source_prefix": "summarize: ",
"max_instances": 1000 // DEBUG setting
"namespace": "tokens",
}
},
"source_max_tokens": source_length,
"target_max_tokens": target_length,
"source_prefix": "summarize: ",
};

local data_loader = {
"batch_size": batch_size_per_gpu,
"shuffle": true,
};

local wandb_callback = {
"type": "wandb",
"project": "allennlp-t5",
"entity": "allenai-team1",
"watch_model": false,
"summary_interval": 1,
"should_log_parameter_statistics": false,
"should_log_learning_rate": false,
};

{
"train_data_path": train_data,
[if validate then "validation_data_path"]: dev_data,
"dataset_reader": dataset_reader + {
[if debug then "max_instances"]: batch_size_per_gpu * 40,
},
"validation_dataset_reader": dataset_reader + {
"max_instances": if debug then batch_size_per_gpu * 4 else batch_size_per_gpu * 10,
},
"model": {
"type": "t5",
"model_name": model_name,
"beam_search": {
"beam_size": 3,
"max_steps": if debug then 5 else 50,
},
[if activation_checkpointing then "checkpoint_wrapper"]: {
"type": "fairscale",
"offload_to_cpu": true,
"maintain_forward_counter": true,
},
},
"data_loader": data_loader + {
[if !debug then "max_instances_in_memory"]: batch_size_per_gpu * 128,
[if !debug then "num_workers"]: 1,
},
"data_loader": {
"batch_size": 4,
"shuffle": true,
"validation_data_loader": data_loader,
"vocabulary": {
"type": "empty",
},
"trainer": {
"use_amp": use_amp,
[if use_amp then "grad_scaling"]: false, # TODO: use grad scaling once it's fixed in FairScale.
"num_epochs": 3,
"optimizer": {
"type": "huggingface_adamw",
"lr": 3e-5,
"betas": [0.9, 0.999],
"eps": 1e-8,
"correct_bias": true,
},
"learning_rate_scheduler": {
"type": "polynomial_decay",
"type": "huggingface_adafactor",
},
"grad_norm": 1.0,
}
[if use_wandb then "callbacks"]: [wandb_callback],
},
[if num_gpus > 1 then "distributed"]: {
"cuda_devices": std.range(0, num_gpus - 1),
"ddp_accelerator": {
"type": "fairscale_fsdp",
"mixed_precision": use_amp,
},
},
}
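This config would normally be launched with `allennlp train`; a rough Python equivalent is sketched below. The serialization directory is a placeholder, and with `debug = true` on a smaller machine you would also want to shrink `num_gpus` in the jsonnet first:

```python
from allennlp.commands.train import train_model_from_file
from allennlp.common.util import import_module_and_submodules

# Register the "t5" model and "cnn_dm" dataset reader from allennlp-models,
# then train from the jsonnet config added in this commit.
import_module_and_submodules("allennlp_models")
train_model_from_file(
    parameter_filename="training_config/generation/t5_cnn_dm.jsonnet",
    serialization_dir="/tmp/t5_cnn_dm",  # placeholder output directory
)
```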
