diff --git a/configs/config.yaml b/configs/config.yaml index 2778f9c..4ce2dde 100644 --- a/configs/config.yaml +++ b/configs/config.yaml @@ -1,4 +1,4 @@ -model: /model-weights/Llama-2-7b-chat-hf +model: /model-weights/Llama-2-7b-chat-hf/ enable_wandb_logging: True wandb_config: @@ -7,7 +7,7 @@ wandb_config: train_parameters: output_dir: your/output/dir - max_seq_len: 1024 + max_seq_len: 4096 epochs: 1 seed: 11 @@ -18,6 +18,7 @@ train_parameters: use_mp: True use_activation_checkpointing: True use_flash_attention: True + low_cpu_mem_usage: True # Gradient norm clipping max_grad_norm: 1 diff --git a/docs/config.md b/docs/config.md index db56806..0541c10 100644 --- a/docs/config.md +++ b/docs/config.md @@ -28,6 +28,7 @@ The key-value pairs stored under `wandb_config` are directly passed into the [`w * `use_mp`: Whether to use mixed precision. This is done using bf16. * `use_activation_checkpointing`: Whether to use activation checkpointing. This greatly reduces memory footprint as only a few intermediate activations as saved during the forward pass, and are then recomputed for the backward pass on the fly. However, the tradeoff between compute vs. memory usually makes this worth it. * `use_flash_attention`: Whether to use Flash Attention. If it is supported for your model in HuggingFace, you can enable this option. +* `low_cpu_mem_usage`: Whether to efficiently load the model. If enabled, the model weights are only loaded once on rank 0 and are broadcasted to the rest of the world from the main rank. It will prevent the CPU memory from exploding when loading large models (e.g. LLaMa-70B). ### Gradient diff --git a/examples/llama_example.py b/examples/llama_example.py index 9e9f4eb..d794740 100644 --- a/examples/llama_example.py +++ b/examples/llama_example.py @@ -26,6 +26,7 @@ def parse_args() -> Namespace: Returns ------- The parsed arguments. + """ parser = argparse.ArgumentParser() parser.add_argument( @@ -62,6 +63,8 @@ def main(config: Config) -> None: training_args.use_mp, training_args.use_flash_attention, training_args.max_seq_len, + local_rank, + training_args.low_cpu_mem_usage, ) model = shard_model( @@ -70,6 +73,8 @@ def main(config: Config) -> None: training_args.use_mp, training_args.use_activation_checkpointing, training_args.sharding_strategy, + local_rank, + training_args.low_cpu_mem_usage, ) # load dataset diff --git a/pyproject.toml b/pyproject.toml index 3cf671e..c9ca70d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,8 +4,8 @@ build-backend = "setuptools.build_meta" [tool.ruff] line-length = 80 -select = ["ALL"] -ignore = [ +lint.select = ["ALL"] +lint.ignore = [ "ANN101", "FBT", "D100", @@ -15,9 +15,10 @@ ignore = [ "N817", "TCH001", "E731", - "PLR0913" + "PLR0913", + "T201" ] -[tool.ruff.per-file-ignores] +[tool.ruff.lint.per-file-ignores] # Ignore `F401` (import violations) in all `__init__.py` files. "__init__.py" = ["F401", "D104"] diff --git a/vectorlm/dataset.py b/vectorlm/dataset.py index 768ef8f..b383803 100644 --- a/vectorlm/dataset.py +++ b/vectorlm/dataset.py @@ -28,6 +28,7 @@ class Dataset: train_bs: A per-device batch size for training. eval_bs: A per-device batch size for evaluating. _processed_ids: A tensor of already trained examples. + """ def __init__( @@ -41,6 +42,7 @@ def __init__( ---- config: The dataset config. tokenizer: The input tokenizer. + """ self.config = config self._processed_ids = torch.tensor([]).to(torch.cuda.current_device()) diff --git a/vectorlm/trainer.py b/vectorlm/trainer.py index 173c892..70ac1da 100644 --- a/vectorlm/trainer.py +++ b/vectorlm/trainer.py @@ -6,11 +6,11 @@ import torch import torch.distributed as dist +import wandb from torch.optim import Optimizer from torch.optim.lr_scheduler import LRScheduler, ReduceLROnPlateau from transformers import PreTrainedTokenizer -import wandb from vectorlm.dataset import Dataset from vectorlm.utils.data_utils import Config from vectorlm.utils.save_utils import ( @@ -47,6 +47,7 @@ class Trainer: epoch. max_steps: An integer maximum number of training steps for this run. saving_steps: An integer for how often we save. + """ def __init__(self, @@ -62,6 +63,7 @@ def __init__(self, enable_wandb_logging: Whether to enable wandb logging. original_dataset_length: The length of the original dataset (divided by the batch size). + """ self.config = config self.gas = config.gradient_accumulation_steps @@ -114,6 +116,7 @@ def prepare_trainer( dataset: The `Dataset` class. optimizer: The training optimizer. lr_scheduler: The LR scheduler. + """ self.model = model self.tokenizer = tokenizer @@ -127,6 +130,7 @@ def save_checkpoint(self, epoch: int) -> None: Args: ---- epoch: The current training epoch. + """ rank = dist.get_rank() gathered_processed_ids = _gather( @@ -161,6 +165,7 @@ def load_checkpoint(self, checkpoint_dir: str) -> int: Returns: ------- The checkpointed epoch to be used by the outer loop. + """ rank = dist.get_rank() step, epoch, ids = load_metadata(checkpoint_dir) @@ -184,6 +189,7 @@ def find_checkpoint(self, checkpoint_dir: str) -> int: ------- The checkpointed epoch. If no checkpoint exists, it returns a default value of 0. + """ checkpoint = checkpoint_exists(checkpoint_dir) if checkpoint: @@ -210,6 +216,7 @@ def step( ---- train_batch: The training batch. epoch: The current training epoch. + """ if ( self.config.checkpointing_enabled @@ -234,6 +241,7 @@ def train_step(self, batch: dict[str, torch.Tensor], epoch: int) -> float: ---- batch: The training batch. epoch: The current training epoch. + """ ids = batch.pop("id").to(torch.cuda.current_device()) batch["input_ids"] = batch["input_ids"].type(torch.LongTensor) @@ -274,6 +282,7 @@ def eval_step(self, epoch: int) -> float: Args: ---- epoch: The current training epoch. + """ print_main("Evaluating") self.model.eval() @@ -305,6 +314,7 @@ def log(self, loss: float, epoch: int, mode: str = "train") -> None: loss: The loss being logged. epoch: The current training epoch. mode: One of `train` or `eval`. + """ if mode not in {"train", "eval"}: msg = "`mode` argument needs to be 'train' or 'eval'." diff --git a/vectorlm/utils/convert_to_hf.py b/vectorlm/utils/convert_to_hf.py index 3e66dce..c4f2405 100644 --- a/vectorlm/utils/convert_to_hf.py +++ b/vectorlm/utils/convert_to_hf.py @@ -14,6 +14,7 @@ def parse_args() -> Namespace: Returns ------- The parsed arguments. + """ parser = argparse.ArgumentParser() parser.add_argument("--config_path", default="configs/config.yaml") @@ -28,6 +29,7 @@ def converter(config: Config) -> None: Args: ---- config: The full config. + """ state_dict = torch.load( os.path.join( diff --git a/vectorlm/utils/data_utils.py b/vectorlm/utils/data_utils.py index 4e14dd4..ddf021f 100644 --- a/vectorlm/utils/data_utils.py +++ b/vectorlm/utils/data_utils.py @@ -12,6 +12,7 @@ class Config: ---------- yaml_path: A path to the yaml file that stores our config. to_box: A boolean indicating whether to box our config. + """ def __init__(self, yaml_path: str, to_box: bool = True) -> None: @@ -21,6 +22,7 @@ def __init__(self, yaml_path: str, to_box: bool = True) -> None: ---- yaml_path: The string path to the config yaml. to_box: Defines whether this initialization will use dot notation. + """ self.yaml_path = yaml_path self.to_box = to_box @@ -55,6 +57,7 @@ class DataCollatorWithPadding: ignore_index: A value used for ignoring a given token in labels. max_seq_len: An integer denoting the maximum sequence length. padding_side: A side of the sequence that gets padded. + """ def __init__( @@ -73,6 +76,7 @@ def __init__( loss. max_seq_len: The maximum sequence length to expect. padding_side: The side of the sequence which is padded. + """ self.pad_token_id = pad_token_id self.ignore_index = ignore_index @@ -99,6 +103,7 @@ def __call__( Returns: ------- A dictionary containing a batch that we can input to our model. + """ batch = {} keys = ["input_ids", "labels"] diff --git a/vectorlm/utils/misc_utils.py b/vectorlm/utils/misc_utils.py index 30c1d67..6e8cc1a 100644 --- a/vectorlm/utils/misc_utils.py +++ b/vectorlm/utils/misc_utils.py @@ -5,6 +5,7 @@ import torch.distributed as dist import wandb + from vectorlm.utils.data_utils import Config diff --git a/vectorlm/utils/model_utils.py b/vectorlm/utils/model_utils.py index 09d5c27..e4d7950 100644 --- a/vectorlm/utils/model_utils.py +++ b/vectorlm/utils/model_utils.py @@ -1,7 +1,7 @@ from __future__ import annotations import functools -from typing import Any +from typing import Any, Callable import torch import torch.distributed as dist @@ -56,6 +56,7 @@ def load_peft_model_and_tokenizer( Returns: ------- The PEFT model and tokenizer. + """ model, tokenizer = load_model_and_tokenizer( path, @@ -72,11 +73,15 @@ def load_peft_model_and_tokenizer( ) return peft_model, tokenizer + def load_model_and_tokenizer( path: str, use_mp: bool, use_fa: bool, max_seq_len: int, + local_rank: int, + low_cpu_mem_usage: bool, + use_safetensors: bool = True, ) -> tuple[PreTrainedModel, PreTrainedTokenizer]: """Load the model and tokenizer. @@ -86,13 +91,19 @@ def load_model_and_tokenizer( use_mp: Whether to use mixed-precision. use_fa: Whether to use Flash Attention 2. max_seq_len: The maximum sequence length. + local_rank: The local rank of the current worker. + low_cpu_mem_usage: Whether to only load model weights on main rank, and + then scatter them to the other workers. + use_safetensors: Whether to use HF safe tensors. Note that this format + loads significantly faster. Returns: ------- The model and tokenizer. + """ # load model - model_args = {"use_cache": False} + model_args = {"use_cache": False, "use_safetensors": use_safetensors} if use_mp: model_args["torch_dtype"] = torch.bfloat16 @@ -101,10 +112,18 @@ def load_model_and_tokenizer( msg = "Use FA with bf16 (mixed precision)" raise ValueError(msg) model_args["attn_implementation"] = "flash_attention_2" - model = AutoModelForCausalLM.from_pretrained( - path, - **model_args, - ) + + if not low_cpu_mem_usage or local_rank == 0: + model = AutoModelForCausalLM.from_pretrained( + path, + **model_args, + ) + else: + with torch.device("meta"): + model = AutoModelForCausalLM.from_pretrained( + path, + **model_args, + ) # load tokenizer tokenizer = AutoTokenizer.from_pretrained(path) @@ -125,6 +144,8 @@ def fsdp_config( use_mp: bool, layer_to_wrap: nn.Module, strategy: str, + local_rank: int, + low_cpu_mem_usage: bool, ) -> dict[str, Any]: """Get FSDP config. @@ -133,11 +154,23 @@ def fsdp_config( use_mp: Whether to use mixed-precision. layer_to_wrap: The layer we are wrapping using FSDP. strategy: The sharding strategy to use. + local_rank: The local rank of the current worker. + low_cpu_mem_usage: Whether to only load model weights on main rank, and + then scatter them to the other workers. Returns: ------- A dictionary containing the configurations. + """ + + def _module_init_fn(module: nn.Module) -> Callable: + """Return the function used for initializing modules on FSDP workers.""" + return module.to_empty( + device=torch.cuda.current_device(), + recurse=False, + ) + strategy_exists = hasattr(ShardingStrategy, strategy) if not strategy_exists: msg = f"The sharding strategy {strategy} does not exist." @@ -161,6 +194,9 @@ def fsdp_config( ret_dict["auto_wrap_policy"] = auto_wrap_policy ret_dict["sharding_strategy"] = sharding_strategy ret_dict["device_id"] = torch.cuda.current_device() + if low_cpu_mem_usage: + ret_dict["param_init_fn"] = _module_init_fn if local_rank != 0 else None + ret_dict["sync_module_states"] = True return ret_dict @@ -170,6 +206,8 @@ def shard_model( use_mp: bool, use_activation_checkpointing: bool, strategy: str, + local_rank: int, + low_cpu_mem_usage: bool, ) -> nn.Module: """Shard the model to workers using FSDP. @@ -180,12 +218,18 @@ def shard_model( use_mp: Whether to use mixed-precision. use_activation_checkpointing: Whether to use activation checkpointing. strategy: The sharding strategy to use. + local_rank: The local rank of the current worker. + low_cpu_mem_usage: Whether to only load model weights on main rank, and + then scatter them to the other workers. Returns: ------- The sharded module with the requested configurations. + """ - fsdp_cfg = fsdp_config(use_mp, layer_to_wrap, strategy) + fsdp_cfg = fsdp_config( + use_mp, layer_to_wrap, strategy, local_rank, low_cpu_mem_usage, + ) if dist.get_rank() == 0: print(f"FSDP config: {fsdp_cfg}") model = FSDP(model, **fsdp_cfg) @@ -209,6 +253,7 @@ def hook_activation_checkpointing( ---- model: The model we are using. layer: The layer to which we hook activation checkpointing to. + """ non_reentrant_wrapper = functools.partial( checkpoint_wrapper, diff --git a/vectorlm/utils/optimizer_utils.py b/vectorlm/utils/optimizer_utils.py index ff479a9..2d99d04 100644 --- a/vectorlm/utils/optimizer_utils.py +++ b/vectorlm/utils/optimizer_utils.py @@ -31,6 +31,7 @@ class PlateaeuWithWarmup(ReduceLROnPlateau): The maximum LR is determined by the number of warmup steps and the current step. base_lrs: A list of base LRs for the optimizer's param groups. + """ def __init__( @@ -63,6 +64,7 @@ def __init__( otherwise the update is ignored. verbose: Whether to print messages to stdout every LR update. num_warmup_steps: The number of steps we warmup the LR for. + """ super().__init__( optimizer=optimizer, @@ -85,6 +87,7 @@ def step(self, metrics: float, epoch: int | None = None) -> None: --------- metrics: The metric we are using to measure change in LR. epoch: The current step. + """ if epoch is None: epoch = self.last_epoch + 1 @@ -159,9 +162,11 @@ def get_custom_scheduler( name: The name of the scheduler args: The scheduler specific args. kwargs: The scheduler specific kwargs. - + Returns: + ------- The scheduler. + """ if name == "plataeu-with-warmup": scheduler = PlateaeuWithWarmup(*args, **kwargs) diff --git a/vectorlm/utils/save_utils.py b/vectorlm/utils/save_utils.py index aef9993..cd6e8df 100644 --- a/vectorlm/utils/save_utils.py +++ b/vectorlm/utils/save_utils.py @@ -28,6 +28,7 @@ def checkpoint_exists(output_dir: str) -> bool: Returns: ------- Returns whether a checkpoint exists. + """ if os.path.isdir(os.path.join(output_dir, "checkpoints")): return True @@ -44,6 +45,7 @@ def save_metadata( ---- out_dir: The directory to save to. meta_dict: The dictionary containing the meta data. + """ os.makedirs(out_dir, exist_ok=True) torch.save(meta_dict, os.path.join(out_dir, "meta_data.pkl")) @@ -62,6 +64,7 @@ def load_metadata( ------- A tuple containing the checkpointed step, epoch, and the processed training dataset ids. + """ save_path = os.path.join(in_dir, "meta_data.pkl") meta_dict = torch.load(save_path) @@ -81,6 +84,7 @@ def get_latest_checkpoint_dir(folder_path: str) -> str: Returns: ------- The subpath (i.e. two levels) of the latest checkpoint's directory. + """ epoch_pattern = re.compile(r"^epoch_(\d+)$") folder_pattern = re.compile(r"^checkpoint_(\d+)$") @@ -112,6 +116,7 @@ def save_model(model: nn.Module, output_dir: str, rank: int) -> None: model: The sharded model. output_dir: The checkpointing directory. rank: The worker's rank. + """ os.makedirs(output_dir, exist_ok=True) weights_name = f"model_rank{rank}.bin" @@ -131,6 +136,7 @@ def load_model(model: nn.Module, input_dir: str, rank: int) -> None: model: The sharded model. input_dir: The checkpointing directory. rank: The worker's rank. + """ weights_name = f"model_rank{rank}.bin" input_model_file = os.path.join(input_dir, weights_name) @@ -154,6 +160,7 @@ def save_consolidated_model( model: The sharded model. save_dir: The checkpointing directory. rank: The worker's rank. + """ os.makedirs(save_dir, exist_ok=True) cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) @@ -178,6 +185,7 @@ def save_optimizer( model: The sharded model. output_dir: The checkpointing directory. rank: The worker's rank. + """ opt_name = f"optimizer_rank{rank}.bin" output_optimizer_file = os.path.join(output_dir, opt_name) @@ -207,6 +215,7 @@ def load_optimizer( model: The sharded model. input_dir: The checkpointing directory. rank: The worker's rank. + """ opt_name = f"optimizer_rank{rank}.bin" input_optimizer_file = os.path.join(input_dir, opt_name) @@ -237,6 +246,7 @@ def save_scheduler( scheduler: The LR scheduler. output_dir: The checkpointing directory. rank: The worker's rank. + """ sched_name = f"scheduler_rank{rank}.bin" output_scheduler_file = os.path.join(output_dir, sched_name) @@ -258,6 +268,7 @@ def load_scheduler( scheduler: The LR scheduler. input_dir: The checkpointing directory. rank: The worker's rank. + """ sched_name = f"scheduler_rank{rank}.bin" input_scheduler_file = os.path.join(input_dir, sched_name)