-
Notifications
You must be signed in to change notification settings - Fork 99
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit b91e375
Showing
144 changed files
with
11,712 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,151 @@ | ||
| outputs | ||
| tables/*/*.csv | ||
| tables/*/*.csv# | ||
| tables/*.csv | ||
| tables/*.csv# | ||
| tables/*.ods | ||
|
|
||
| wandb-metadata.json | ||
|
|
||
| dedup | ||
|
|
||
| .vs/ | ||
|
|
||
| images | ||
|
|
||
| *.temp.sh | ||
|
|
||
| # Byte-compiled / optimized / DLL files | ||
| __pycache__/ | ||
| *.py[cod] | ||
| *$py.class | ||
|
|
||
| # C extensions | ||
| *.so | ||
|
|
||
| # Distribution / packaging | ||
| .Python | ||
| build/ | ||
| develop-eggs/ | ||
| dist/ | ||
| downloads/ | ||
| eggs/ | ||
| .eggs/ | ||
| lib/ | ||
| lib64/ | ||
| parts/ | ||
| sdist/ | ||
| var/ | ||
| wheels/ | ||
| pip-wheel-metadata/ | ||
| share/python-wheels/ | ||
| *.egg-info/ | ||
| .installed.cfg | ||
| *.egg | ||
| MANIFEST | ||
|
|
||
| # PyInstaller | ||
| # Usually these files are written by a python script from a template | ||
| # before PyInstaller builds the exe, so as to inject date/other infos into it. | ||
| *.manifest | ||
| *.spec | ||
|
|
||
| # Installer logs | ||
| pip-log.txt | ||
| pip-delete-this-directory.txt | ||
|
|
||
| # Unit test / coverage reports | ||
| htmlcov/ | ||
| .tox/ | ||
| .nox/ | ||
| .coverage | ||
| .coverage.* | ||
| .cache | ||
| nosetests.xml | ||
| coverage.xml | ||
| *.cover | ||
| *.py,cover | ||
| .hypothesis/ | ||
| .pytest_cache/ | ||
|
|
||
| # Translations | ||
| *.mo | ||
| *.pot | ||
|
|
||
| # Django stuff: | ||
| *.log | ||
| local_settings.py | ||
| db.sqlite3 | ||
| db.sqlite3-journal | ||
|
|
||
| # Flask stuff: | ||
| instance/ | ||
| .webassets-cache | ||
|
|
||
| # Scrapy stuff: | ||
| .scrapy | ||
|
|
||
| # Sphinx documentation | ||
| docs/_build/ | ||
|
|
||
| # PyBuilder | ||
| target/ | ||
|
|
||
| # Jupyter Notebook | ||
| .ipynb_checkpoints | ||
|
|
||
| # IPython | ||
| profile_default/ | ||
| ipython_config.py | ||
|
|
||
| # pyenv | ||
| .python-version | ||
|
|
||
| # pipenv | ||
| # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. | ||
| # However, in case of collaboration, if having platform-specific dependencies or dependencies | ||
| # having no cross-platform support, pipenv may install dependencies that don't work, or not | ||
| # install all needed dependencies. | ||
| #Pipfile.lock | ||
|
|
||
| # PEP 582; used by e.g. github.com/David-OConnor/pyflow | ||
| __pypackages__/ | ||
|
|
||
| # Celery stuff | ||
| celerybeat-schedule | ||
| celerybeat.pid | ||
|
|
||
| # SageMath parsed files | ||
| *.sage.py | ||
|
|
||
| # Environments | ||
| .env | ||
| .venv | ||
| env/ | ||
| venv/ | ||
| ENV/ | ||
| env.bak/ | ||
| venv.bak/ | ||
|
|
||
| # Spyder project settings | ||
| .spyderproject | ||
| .spyproject | ||
|
|
||
| # Rope project settings | ||
| .ropeproject | ||
|
|
||
| # mkdocs documentation | ||
| /site | ||
|
|
||
| # mypy | ||
| .mypy_cache/ | ||
| .dmypy.json | ||
| dmypy.json | ||
|
|
||
| # Pyre type checker | ||
| .pyre/ | ||
|
|
||
| *.csv | ||
| *.txt | ||
| *.pth |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,32 @@ | ||
| # precommit hooks from https://github.com/ashleve/lightning-hydra-template | ||
| repos: | ||
| - repo: https://github.com/pre-commit/pre-commit-hooks | ||
| rev: v3.4.0 | ||
| hooks: | ||
| # list of supported hooks: https://pre-commit.com/hooks.html | ||
| - id: trailing-whitespace | ||
| - id: end-of-file-fixer | ||
| - id: check-yaml | ||
| - id: check-added-large-files | ||
| - id: debug-statements | ||
| - id: detect-private-key | ||
|
|
||
| # python code formatting | ||
| - repo: https://github.com/psf/black | ||
| rev: 22.3.0 | ||
| hooks: | ||
| - id: black | ||
| args: [--line-length, "140", "--fast"] # ;> | ||
|
|
||
| # yaml formatting | ||
| - repo: https://github.com/pre-commit/mirrors-prettier | ||
| rev: v2.3.0 | ||
| hooks: | ||
| - id: prettier | ||
| types: [yaml] | ||
|
|
||
| # python code analysis | ||
| - repo: https://github.com/PyCQA/flake8 | ||
| rev: 4.0.1 | ||
| hooks: | ||
| - id: flake8 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,7 @@ | ||
| # added by check-manifest | ||
| include *.py | ||
| include *.yaml | ||
| recursive-include cramming *.md | ||
| recursive-include cramming *.yaml | ||
| global-exclude *.pyc | ||
| global-exclude __pycache__ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,82 @@ | ||
| # Cramming Language Model (Pretraining) | ||
|
|
||
| This repository contains code to replicate our research described in "Cramming: Training a Language Model on a Single GPU in One Day". We experiment with language model pretraining a BERT-type model with limited compute, wondering "how bad can it really be"? | ||
|
|
||
|
|
||
| You can find our abstract below: | ||
|
|
||
| > Recent trends in language modeling have focused on increasing performance through scaling, and have resulted in an environment where training language models is out of reach for most researchers and practitioners. While most in the community are asking how to push the limits of extreme computation, we ask the opposite question: | ||
| How far can we get with a single GPU in just one day? | ||
|
|
||
| > We investigate the downstream performance achievable with a transformer-based language model trained completely from scratch with masked language modeling for a *single* day on a *single consumer* GPU. | ||
| Aside from re-analyzing nearly all components of the pretraining pipeline for this scenario and providing a modified pipeline with performance close to BERT, we investigate why scaling down is hard, and which modifications actually improve performance in this scenario. We provide evidence that even in this constrained setting, performance closely follows scaling laws observed in large-compute settings. Through the lens of scaling laws, we categorize a range of recent improvements to training and architecture and discuss their merit and practical applicability (or lack thereof) for the limited compute setting. | ||
|
|
||
| ## The Rules for Cramming | ||
| Setting: | ||
| * A transformer-based language model of arbitrary size is trained with masked-language modeling, completely from scratch. | ||
| * Existing pretrained models cannot be included in any part of the pipeline. | ||
| * Any raw text (excluding downstream data) can be included for training. This means that one can achieve speedups by making judicious choices about how and when to sample data, provided the sampling mechanism does not require a pre-trained model. | ||
| * The downloading and pre-processing of raw data is exempted from the total compute budget. Pre-processing may include CPU-based tokenizer construction, tokenization, and filtering, but cannot include representation learning (e.g. pre-training a word embedding is not allowed, unless it is counted towards the final runtime). | ||
| * Training proceeds on a single GPU for 24 hours. | ||
| * Downstream performance is evaluated on GLUE (Wang et al., 2018). Downstream finetuning on GLUE is limited to brief training with only the training data of the downstream task (we consider 5 epochs or less) and needs to work with hyperparameters set globally for all GLUE tasks. Downstream finetuning is excluded from the total compute budget. | ||
|
|
||
|
|
||
| # How to run the code | ||
|
|
||
| ## Requirements: | ||
| * PyTorch: `torch` | ||
| * huggingface: `transformers`, `tokenizers`, `datasets` | ||
| * `hydra-core` | ||
| * [OPTIONAL]`deepspeed` | ||
| * [OPTIONAL] `flash-attention` | ||
| * `psutil` | ||
| * `einops` | ||
| * [OPTIONAL] For The-Pile data, install `zstandard` | ||
|
|
||
| ## Installation | ||
| * Just clone for now, and install the packages as described above. | ||
| * [Optional] Follow the instructions at https://pre-commit.com/ to install the pre-commit hooks. | ||
| * [Optional] For deduplication, first install rust `curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh `, then | ||
| `git clone https://github.com/google-research/deduplicate-text-datasets/tree/dev-v1` and then run `cargo install --target-dir ../cramming/dedup` | ||
| * [Optional] For FlashAttention, install package as instructed at https://github.com/HazyResearch/flash-attention | ||
|
|
||
| ## Replicate the final recipe | ||
|
|
||
| To replicate the final recipe discussed in the paper, run | ||
| ``` | ||
| python pretrain.py name=amp_b4096_c5_o3_final arch=bert-c5 train=bert-o3 train.batch_size=4096 data=c4-subset-processed | ||
| ``` | ||
| to pretrain and | ||
| ``` | ||
| python eval.py eval=GLUE_sane name=amp_b4096_c5_o3_final eval.checkpoint=latest impl.microbatch_size=16 impl.shuffle_in_dataloader=True | ||
| ``` | ||
| to evaluate the model. The recipe called "crammed BERT" in the paper corresponds to the architecture called `bert-c5` trained with training setup `bert-o3` on data `c4-subset-processed`. | ||
|
|
||
| ## Additional Recipes | ||
| Pretraining: | ||
| Single GPU: | ||
| ``` | ||
| python pretrain.py name=bert data=bookcorpus-wikipedia arch=bert-original train=bert-original | ||
| ``` | ||
| Multi-GPU: | ||
| ``` | ||
| torchrun --nproc_per_node=4 --standalone pretrain.py name=bert4gpu data=bookcorpus-wikipedia arch=bert-original train=bert-original | ||
| ``` | ||
|
|
||
| Eval a huggingface checkpoint: | ||
| ``` | ||
| python eval.py dryrun=True eval=rte name=bert-finetuning eval.checkpoint=hf://bert-base-uncased | ||
| ``` | ||
|
|
||
| Sanity check for distributed code on CPU: | ||
| ``` | ||
| torchrun --nproc_per_node=4 --standalone pretrain.py name=speedtest1 dryrun=True data=sanity-check-2 impl.backend=gloo | ||
| ``` | ||
|
|
||
| Additional examples for recipes can be found in the `/scripts` folder. | ||
|
|
||
| # Contact | ||
|
|
||
| Please, feel free to contact us with any questions, or open an issue on github. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,43 @@ | ||
| """Initialize cramming""" | ||
|
|
||
| from cramming.architectures import construct_model | ||
| from cramming.backend import load_backend | ||
| from cramming.data import load_pretraining_corpus, prepare_task_dataloaders | ||
| from cramming import utils | ||
|
|
||
| __all__ = [ | ||
| "construct_model", | ||
| "load_backend", | ||
| "load_pretraining_corpus", | ||
| "prepare_task_dataloaders", | ||
| "utils", | ||
| ] | ||
|
|
||
|
|
||
| import hydra | ||
|
|
||
| """Construct interfaces to some cfg folders for use in packaged installations:""" | ||
|
|
||
|
|
||
def get_config(overrides=None):
    """Return the default hydra config.

    Args:
        overrides: Optional list of hydra override strings (e.g. ["name=test"]).
            Defaults to no overrides. (A None sentinel is used instead of a
            mutable [] default, which would be shared across calls.)

    Returns:
        The composed hydra config object for config name "cfg".
    """
    overrides = [] if overrides is None else overrides
    with hydra.initialize(config_path="config"):
        cfg = hydra.compose(config_name="cfg", overrides=overrides)
        print(f"Loading default config {cfg.name}.")
    return cfg
|
|
||
|
|
||
def get_model_config(arch="hf-bert-tiny", overrides=None):
    """Return the default hydra config for a given model architecture.

    Args:
        arch: Name of an architecture config file under config/arch.
        overrides: Optional list of hydra override strings. Defaults to no
            overrides. (A None sentinel replaces the mutable [] default, which
            would be shared across calls.)

    Returns:
        The composed hydra config object for the requested architecture.
    """
    overrides = [] if overrides is None else overrides
    with hydra.initialize(config_path="config/arch"):
        cfg = hydra.compose(config_name=arch, overrides=overrides)
        print(f"Loading model configuration {cfg.architecture}.")
    return cfg
|
|
||
|
|
||
def get_backend_config(backend="torch-default", overrides=None):
    """Return the default hydra config for a given backend implementation.

    Args:
        backend: Name of a backend config file under config/impl.
        overrides: Optional list of hydra override strings. Defaults to no
            overrides. (A None sentinel replaces the mutable [] default, which
            would be shared across calls.)

    Returns:
        The composed hydra config object for the requested backend.
    """
    overrides = [] if overrides is None else overrides
    with hydra.initialize(config_path="config/impl"):
        cfg = hydra.compose(config_name=backend, overrides=overrides)
        print(f"Loading backend {cfg.name}.")
    return cfg
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,5 @@ | ||
| """This module handles all questions of model architecture.""" | ||
|
|
||
| from .construction import construct_model | ||
|
|
||
| __all__ = ["construct_model"] |
Oops, something went wrong.