Skip to content

A repository for research on medium sized language models.

License

Notifications You must be signed in to change notification settings

TRI-ML/linear_open_lm

 
 

Repository files navigation

Linear OpenLM

This repository contains the code for Linearizing Large Language Models. This is a fork of the original OpenLM repository.

Linearizing Large Language Models

Quickstart

Our Mistral-SUPRA model is publicly available on HuggingFace!

Detailed instructions on how to run the model can be found on the Mistral-SUPRA HF page. We also recommend you check out our Mamba-7B model. If you want to simply use the models for inference/generation, you can do the following:

First pip install our fork of OpenLM.

pip install git+https://github.com/tri-ml/linear_open_lm.git

Import the OpenLM classes with

from open_lm.open_lm_hf import *

The model can then be loaded normally using AutoTokenizer and AutoModelForCausalLM as follows:

from open_lm.open_lm_hf import *
from transformers import AutoTokenizer, AutoModelForCausalLM
tokenizer = AutoTokenizer.from_pretrained("tri-ml/mistral-supra")
model = AutoModelForCausalLM.from_pretrained("tri-ml/mistral-supra")

inputs = tokenizer(["Machine learning is"], return_tensors="pt")
gen_kwargs = {"max_new_tokens": 50, "top_p": 0.8, "temperature": 0.8, "do_sample": True, "repetition_penalty": 1.1}
output = model.generate(inputs['input_ids'], **gen_kwargs)
output = tokenizer.decode(output[0].tolist(), skip_special_tokens=True)
print(output)

If you are interested in further training this model or in training another linear model, we recommend you use this repo. Our models were trained with OpenLM, then the weights were copied over to HuggingFace. We have not tested training directly using HuggingFace.

How to train a linear model

See the Run training section below for the original OpenLM instructions on how to train a model. The only difference is that you should use linear models instead of the open_lm models. The available linear models are:

Model Name
linear_tiny
linear_1b
linear_7b
mistral_7b_linear
llama2_7b_linear

How to uptrain a linear model

To uptrain a linear model, you can use the same training script as for pre-training a linear model from scratch. The only difference is that you should use the --pretrained flag to specify the checkpoint you want to start from. For example, to uptrain a linear model from the checkpoint.pt checkpoint, you can use the following command:

>>> export CUDA_VISIBLE_DEVICES=0,1,2,3
>>> torchrun --nproc-per-node 4 -m open_lm.main   \
 --model linear_tiny \
 --dataset-manifest refined_web_tokenized/manifest.jsonl \
 --train-num-samples 1_000_000 \
 --precision "amp_bfloat16" \
 --fsdp-amp \
 --fsdp-pure-bf16 \
 --workers 1 \
 --global-batch-size 9 \
 --log-every-n-steps 100 \
 --grad-clip-norm 1 \
 --data-key json.gz \
 --lr 3e-4 \
 --accum-freq 1 \
 --warmup 10 \
 --wd 0.1 \
 --beta2 0.98 \
 --epochs 10 \
 --report-to wandb \
 --wandb-project-name linear_open_lm \
 --name linear_tiny_example \
 --logs logs \
 --z-loss-coefficient 1e-4 \
 --load-not-strict \
 --pretrained checkpoint.pt

How to evaluate a linear model

See the Evaluate Model section below for the original instructions on how to evaluate a model. The only difference is that you should use linear models instead of the open_lm models. Note that for the reference paper, we used the EleutherAI LM Harness evaluation suite, which is not available in this repository.

How to generate text from a linear model

An example of how to generate text from a linear model is shown below. The only difference is that you should use linear models instead of the open_lm models.

python scripts/generate.py \
--model linear_1b \
--checkpoint /path/to/linear_checkpoint.pt \
--input-text "Are you conscious, can you talk to me?" \
--tokenizer EleutherAI/gpt-neox-20b \
--use-cache

Pre-trained Models

We provide the following pre-trained models:

  • Mistral-SUPRA-7B: Our method uptraining the Mistral-7B model on 100B tokens of refined web into an RNN.
  • Mamba-7B: Our baseline model that we trained on 1.2T tokens of refined web (2 epochs) in the present repository. See the Mamba repo for the original code.
Model Size Tokens HellaSwag PIQA WG ARC-E ARC-C MMLU Average
StableLM2 1.6B 2000 69.0 76.7 63.6 68.6 38.9 38.4 59.2
StableLM 3B 1000 73.8 79.3 65.8 72.1 40.0 44.2 62.5
Gemma 2B 2000 71.4 78.6 64.4 74.0 41.5 41.2 61.9
Mamba 1.4B 600 59.0 73.9 61.4 65.5 32.9 25.2 53.0
RWKV-5 1.5B 1100 53.1 71.6 59.0 62.2 32.7 26.2 50.8
Mamba 2.8B 600 66.2 75.8 63.4 69.7 36.3 26.3 56.3
Llama2 7B 2000 76.0 79.1 69.1 76.3 46.3 45.9 65.4
Gemma 7B 6000 80.7 81.9 73.7 81.1 53.2 62.9 72.2
Mistral 7B 8000(?) 81.0 82.1 74.0 80.9 53.8 62.4 72.4
RetNet 6.7B 200 60.7 75.4 58.1 -- -- -- --
RWKV-5 7B 1100 70.9 77.2 67.4 71.8 43.6 31.0 60.3
RWKV-5-1.7T 7B 1700 73.0 78.6 72.9 75.8 45.6 34.9 63.5
Mamba (ours) 7B 1200 77.9 81.0 71.8 77.5 46.7 33.3 64.7
Llama2-SUPRA 7B +20 71.8 78.6 65.8 71.1 39.5 24.9 58.6
Mistral-SUPRA 7B +20 74.8 80.1 67.4 74.6 42.3 28.0 61.2
Mistral-SUPRA 7B +100 77.1 80.4 70.3 75.9 45.8 34.2 64.0

Last 7 rows are linear models. 5-shot results are used for MMLU. Norm results are used for PIQA, HellaSwag, ARC-C. RetNet results taken from RetNet paper.

Citation

@article{Mercat2024Linearizing,
  title={Linearizing Large Language Models},
  author={Jean Mercat and Igor Vasiljevic and Sedrick Keh and Kushal Arora and Achal Dave and Adrien Gaidon and Thomas Kollar},
  journal={ArXiv},
  year={2024},
  volume={},
}

OpenLM

This part is copied from the original OpenLM repository, only the paragraph used in the linear models are kept. Refer to the original repository for more information.

Quickstart

Here we'll go over a basic example where we start from a fresh install, download and preprocess some training data, and train a model.

Setup

We require python >=3.9, and a current installation of pyTorch, as well as several other packages. The full list of requirements is contained in requirements.txt and can be installed in your python enviornment via >>> pip install -r requirements.txt Next, to access open_lm everywhere in your virtual environment, install it using pip (from within the top level github repo) >>> pip install --editable . Some considerations:

  • We like WandB and tensorboard for logging. We specify how to use these during training below.

Process Training Data

Next you must specify a collection of tokenized data. For the purposes of this example, we will use a recent dump of english Wikipedia, available on HuggingFace. To download this locally, we've included a script located at open_lm/datapreprocess/wiki_download.py. All you have to do is specify an output directory for where the raw data should be stored:

python open_lm/datapreprocess/wiki_download.py --output-dir path/to/raw_data

Next we process our training data by running it through a BPE tokenizer and chunk it into chunks of appropriate length. By default we use the tokenizer attached with GPT-NeoX-20B. To do this, use the script datapreprocess/make_2048.py:

>>> python open_lm/datapreprocess/make_2048.py \
    --input-files path_to_raw_data/*.jsonl
    --output-dir preproc_data
    --num-workers 32
    --num-consumers 1

Where input-files passes all of its (possibly many) arguments through the python glob module, allowing for wildcards. Optionally, data can be stored in S3 by setting the environment variables: S3_BASE, and passing the flag --upload-to-s3 to the script. This saves sharded data to the given bucket with prefix of S3_BASE. E.g.

>>> export S3_BASE=preproc_data-v1/
>>> python open_lm/datapreprocess/make2048.py --upload-to-s3 ... # same arguments as before

Run Training

Tokenized data can now be passed to the main training script, open_lm/main.py. Distributed computatation is handled via torchrun, and hyperparameters are specified by a variety of keyword arguments. We highlight several of the most important ones here:

  • train-data: location of the sharded tokenized training data. If locally generated and stored, this will point to a directory containing files like preproc_data/2048-v1/0/XXXXXXX.tar. Data are processed using the webdataset package where wildcards are supported like preproc_data/2048-v1/0/{0000000..0000099}.tar to select the first 100 .tar files.
  • model: Which model to use. See the table below to see valid options and parameter sizes for each.
  • train-num-samples: how many samples to use from the specified training dataset
  • name: name of this particular training run for logging purposes
  • report-to: if present, can be wandb, tensorboard, or all to stash logging information on WandB or Tensorboard.

Model choices are contained in the following table, where, for instance 11m indicates an 11 million parameter model and 1b indicates a 1 billion parameter model.

Model Name
open_lm_11m
open_lm_25m
open_lm_87m
open_lm_160m
open_lm_411m
open_lm_830m
open_lm_1b
open_lm_3b
open_lm_7b

An example training run can be called as follows:

>>> export CUDA_VISIBLE_DEVICES=0,1,2,3
>>> torchrun --nproc-per-node 4 -m open_lm.main   \
 --model open_lm_3b \
 --train-data /preproc_data/shard-{0000000..0000099}.tar \
 --train-num-samples 1000000000 \
 --workers 8 \
 --dataset-resampled \
 --precision amp_bfloat16 \
 --batch-size 8 \
 --grad-checkpointing \
 --log-every-n-steps 100 \
 --grad-clip-norm 1 \
 --data-key txt \
 --lr 3e-4 \
 --fsdp --fsdp-amp \
 --warmup 2000 \
 --wd 0.1 \
 --beta2 0.95 \
 --epochs 100 \
 --report-to wandb \
 --wandb-project-name open_lm_example \
 --name open_lm_ex_$RANDOM \
 --resume latest \
 --logs path/to/logging/dir/

Checkpoints and final model weights will be saved to the specified logs directory.

During training, the above command will pick shards to train on via sampling with replacement. Training can also be done by picking shards via sampling without replacement. To do this, the input dataset(s) must first be preprocessed using the following command:

python -m open_lm.utils.make_wds_manifest --data-dir /preproc_data/

This will create a file called manifest.jsonl under /preproc_data. Training can then be done by sampling wihout replacement via the following example commands:

>>> export CUDA_VISIBLE_DEVICES=0,1,2,3
>>> torchrun --nproc-per-node 4 -m open_lm.main   \
 --model open_lm_3b \
 --dataset-manifest /preproc_data/manifest.jsonl \
 --train-num-samples 1000000000 \
 --workers 8 \
 --precision amp_bfloat16 \
 --batch-size 8 \
 --grad-checkpointing \
 --log-every-n-steps 100 \
 --grad-clip-norm 1 \
 --data-key txt \
 --lr 3e-4 \
 --fsdp --fsdp-amp \
 --warmup 2000 \
 --wd 0.1 \
 --beta2 0.95 \
 --epochs 100 \
 --report-to wandb \
 --wandb-project-name open_lm_example \
 --name open_lm_ex_$RANDOM \
 --resume latest \
 --logs path/to/logging/dir/

Dataset manifest

The manifest created with open_lm/utils/make_wds_manifest.py is a jsonl file describing the dataset. Each line in this file corresponds to a shard of the dataset and is a json object containing two fields:

  • "shard": the name of a shard in the dataset.
  • "num_sequences": the number of sequences contained in the shards. Each sequence contains a set length of tokens.

This manifest file provides auxiliary information about the dataset, and is assumed to be found within the same directory as the shards.

Evaluate Model

Once trained, we can evaluate the model. This requires LLM Foundry, which can be installed via pip install llm-foundry. Next some configurations are required to pass to the evaluator: a skeleton of these parameters is located at eval/in_memory_hf_eval.yaml. Then just run the following script, making sure to point it at the checkpoint of your trained model (and it's correspending config .json file):

cd eval

python eval_openlm_ckpt.py \
--eval-yaml in_memory_hf_eval.yaml \
--model open_lm_1b  \
--checkpoint /path/to/openlm_checkpoint.pt
--positional_embedding_type head_rotary

Note that --positional-embedding-type head_rotary is only necessary if using the pretrained open_lm_1b model hosted below. See discussion in the next section about this.

Generate Text

One can also use a trained model to generate text. This is accessible via the script located at scripts/generate.py. The parameters are similar to those used in evaluation:

cd scripts

python generate.py \
--model open_lm_1b \
--checkpoint /path/to/openlm_checkpoint.pt \
--positional-embedding-type head_rotary \
--input-text "Please give me a recipe for chocolate chip cookies"

Citations

If you use this model in your work, please use the following BibTeX citations:

@inproceedings{Mercat2024LinearizingLL,
  title={Linearizing Large Language Models},
  author={Jean Mercat and Igor Vasiljevic and Sedrick Scott Keh and Kushal Arora and Achal Dave and Adrien Gaidon and Thomas Kollar},
  year={2024},
  url={https://arxiv.org/abs/2405.06640}
}
@misc{open_lm,
  author = {Gururangan, Suchin and Wortsman, Mitchell and Gadre, Samir Yitzhak and Dave, Achal and Kilian, Maciej and Shi, Weijia and Mercat, Jean and Smyrnis, Georgios and Ilharco, Gabriel and Jordan, Matt and Heckel, Reinhard and Dimakis, Alex and Farhadi, Ali and Shankar, Vaishaal and Schmidt, Ludwig},
  title = {{open_lm}:  a minimal but performative language modeling (LM) repository},
  year = {2023},
  note = {GitHub repository},
  url = {https://github.com/mlfoundations/open_lm/}
}

About

A repository for research on medium sized language models.

Resources

License

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages

  • Python 97.8%
  • Jupyter Notebook 1.4%
  • Other 0.8%