5 changes: 3 additions & 2 deletions configs/config.yaml
@@ -1,4 +1,4 @@
model: /model-weights/Llama-2-7b-chat-hf
model: /model-weights/Llama-2-7b-chat-hf/
enable_wandb_logging: True

wandb_config:
@@ -7,7 +7,7 @@ wandb_config:

train_parameters:
output_dir: your/output/dir
max_seq_len: 1024
max_seq_len: 4096
epochs: 1
seed: 11

@@ -18,6 +18,7 @@ train_parameters:
use_mp: True
use_activation_checkpointing: True
use_flash_attention: True
low_cpu_mem_usage: True

# Gradient norm clipping
max_grad_norm: 1
1 change: 1 addition & 0 deletions docs/config.md
@@ -28,6 +28,7 @@ The key-value pairs stored under `wandb_config` are directly passed into the [`w
* `use_mp`: Whether to use mixed precision. This is done using bf16.
* `use_activation_checkpointing`: Whether to use activation checkpointing. This greatly reduces the memory footprint, as only a few intermediate activations are saved during the forward pass and then recomputed on the fly during the backward pass. The compute-for-memory tradeoff usually makes this worth it.
* `use_flash_attention`: Whether to use Flash Attention. If it is supported for your model in HuggingFace, you can enable this option.
* `low_cpu_mem_usage`: Whether to load the model memory-efficiently. If enabled, the model weights are loaded only once on rank 0 and then broadcast from that main rank to the rest of the world. This prevents CPU memory from being exhausted when loading large models (e.g. LLaMa-70B).

### Gradient

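The rank-0-load-plus-broadcast behaviour described for `low_cpu_mem_usage` above is the kind of thing FSDP's `sync_module_states` flag provides; whether the repository implements it exactly this way is not visible in the diff. Below is a minimal, hypothetical sketch of the pattern for a Hugging Face causal LM with an already-initialised process group. The function name and arguments are illustrative, not the repository's actual loading/sharding helpers.

```python
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from transformers import AutoConfig, AutoModelForCausalLM


def load_and_shard(model_path: str, local_rank: int, low_cpu_mem_usage: bool) -> FSDP:
    """Sketch of rank-0 weight loading plus broadcast via FSDP (illustrative only)."""
    rank = dist.get_rank()
    if low_cpu_mem_usage and rank != 0:
        # Non-zero ranks only build the architecture on the meta device:
        # no weights are read from disk, so host RAM stays flat.
        config = AutoConfig.from_pretrained(model_path)
        with torch.device("meta"):
            model = AutoModelForCausalLM.from_config(config)
    else:
        # Rank 0 (or every rank, when the option is off) loads real weights.
        model = AutoModelForCausalLM.from_pretrained(model_path)

    return FSDP(
        model,
        device_id=local_rank,
        # Broadcast rank 0's weights to every other rank at wrap time.
        sync_module_states=low_cpu_mem_usage,
        # Meta-device parameters must be materialised before the sync.
        param_init_fn=(
            (lambda module: module.to_empty(
                device=torch.device("cuda", local_rank), recurse=False,
            ))
            if low_cpu_mem_usage and rank != 0
            else None
        ),
    )
```

When every rank calls `from_pretrained` directly instead, peak host memory scales with the number of ranks per node, which is exactly the failure mode the option is meant to avoid.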
5 changes: 5 additions & 0 deletions examples/llama_example.py
@@ -26,6 +26,7 @@ def parse_args() -> Namespace:
Returns
-------
The parsed arguments.

"""
parser = argparse.ArgumentParser()
parser.add_argument(
@@ -62,6 +63,8 @@ def main(config: Config) -> None:
training_args.use_mp,
training_args.use_flash_attention,
training_args.max_seq_len,
local_rank,
training_args.low_cpu_mem_usage,
)

model = shard_model(
@@ -70,6 +73,8 @@
training_args.use_mp,
training_args.use_activation_checkpointing,
training_args.sharding_strategy,
local_rank,
training_args.low_cpu_mem_usage,
)

# load dataset
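In `llama_example.py`, the two call sites above now thread `local_rank` and `training_args.low_cpu_mem_usage` through to the loading and sharding helpers. For context, here is a hedged sketch of the usual surrounding setup when launching with `torchrun`, which exports `LOCAL_RANK` for each process. `shard_model` appears in the diff; the loader's name is hidden above the visible hunk, so `load_model` in the comment is an assumption.

```python
import os

import torch
import torch.distributed as dist


def setup_distributed() -> int:
    """Initialise the process group and return this process's local rank.

    Assumes the script is launched with `torchrun`, which sets LOCAL_RANK
    for every process on the node.
    """
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend="nccl")
    return local_rank


# The local rank is then forwarded to the repository-specific helpers,
# matching the argument order visible in the diff above, e.g.:
#   load_model(..., training_args.max_seq_len, local_rank, training_args.low_cpu_mem_usage)
#   model = shard_model(..., training_args.sharding_strategy, local_rank, training_args.low_cpu_mem_usage)
```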
9 changes: 5 additions & 4 deletions pyproject.toml
@@ -4,8 +4,8 @@ build-backend = "setuptools.build_meta"

[tool.ruff]
line-length = 80
select = ["ALL"]
ignore = [
lint.select = ["ALL"]
lint.ignore = [
"ANN101",
"FBT",
"D100",
@@ -15,9 +15,10 @@ ignore = [
"N817",
"TCH001",
"E731",
"PLR0913"
"PLR0913",
"T201"
]

[tool.ruff.per-file-ignores]
[tool.ruff.lint.per-file-ignores]
# Ignore `F401` (import violations) in all `__init__.py` files.
"__init__.py" = ["F401", "D104"]
2 changes: 2 additions & 0 deletions vectorlm/dataset.py
@@ -28,6 +28,7 @@ class Dataset:
train_bs: A per-device batch size for training.
eval_bs: A per-device batch size for evaluating.
_processed_ids: A tensor of already trained examples.

"""

def __init__(
@@ -41,6 +42,7 @@ def __init__(
----
config: The dataset config.
tokenizer: The input tokenizer.

"""
self.config = config
self._processed_ids = torch.tensor([]).to(torch.cuda.current_device())
12 changes: 11 additions & 1 deletion vectorlm/trainer.py
@@ -6,11 +6,11 @@

import torch
import torch.distributed as dist
import wandb
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler, ReduceLROnPlateau
from transformers import PreTrainedTokenizer

import wandb
from vectorlm.dataset import Dataset
from vectorlm.utils.data_utils import Config
from vectorlm.utils.save_utils import (
@@ -47,6 +47,7 @@ class Trainer:
epoch.
max_steps: An integer maximum number of training steps for this run.
saving_steps: An integer for how often we save.

"""

def __init__(self,
@@ -62,6 +63,7 @@ def __init__(self,
enable_wandb_logging: Whether to enable wandb logging.
original_dataset_length: The length of the original dataset
(divided by the batch size).

"""
self.config = config
self.gas = config.gradient_accumulation_steps
@@ -114,6 +116,7 @@ def prepare_trainer(
dataset: The `Dataset` class.
optimizer: The training optimizer.
lr_scheduler: The LR scheduler.

"""
self.model = model
self.tokenizer = tokenizer
@@ -127,6 +130,7 @@ def save_checkpoint(self, epoch: int) -> None:
Args:
----
epoch: The current training epoch.

"""
rank = dist.get_rank()
gathered_processed_ids = _gather(
@@ -161,6 +165,7 @@ def load_checkpoint(self, checkpoint_dir: str) -> int:
Returns:
-------
The checkpointed epoch to be used by the outer loop.

"""
rank = dist.get_rank()
step, epoch, ids = load_metadata(checkpoint_dir)
@@ -184,6 +189,7 @@ def find_checkpoint(self, checkpoint_dir: str) -> int:
-------
The checkpointed epoch. If no checkpoint exists, it returns a
default value of 0.

"""
checkpoint = checkpoint_exists(checkpoint_dir)
if checkpoint:
@@ -210,6 +216,7 @@ def step(
----
train_batch: The training batch.
epoch: The current training epoch.

"""
if (
self.config.checkpointing_enabled
@@ -234,6 +241,7 @@ def train_step(self, batch: dict[str, torch.Tensor], epoch: int) -> float:
----
batch: The training batch.
epoch: The current training epoch.

"""
ids = batch.pop("id").to(torch.cuda.current_device())
batch["input_ids"] = batch["input_ids"].type(torch.LongTensor)
@@ -274,6 +282,7 @@ def eval_step(self, epoch: int) -> float:
Args:
----
epoch: The current training epoch.

"""
print_main("Evaluating")
self.model.eval()
@@ -305,6 +314,7 @@ def log(self, loss: float, epoch: int, mode: str = "train") -> None:
loss: The loss being logged.
epoch: The current training epoch.
mode: One of `train` or `eval`.

"""
if mode not in {"train", "eval"}:
msg = "`mode` argument needs to be 'train' or 'eval'."
2 changes: 2 additions & 0 deletions vectorlm/utils/convert_to_hf.py
@@ -14,6 +14,7 @@ def parse_args() -> Namespace:
Returns
-------
The parsed arguments.

"""
parser = argparse.ArgumentParser()
parser.add_argument("--config_path", default="configs/config.yaml")
@@ -28,6 +29,7 @@ def converter(config: Config) -> None:
Args:
----
config: The full config.

"""
state_dict = torch.load(
os.path.join(
5 changes: 5 additions & 0 deletions vectorlm/utils/data_utils.py
@@ -12,6 +12,7 @@ class Config:
----------
yaml_path: A path to the yaml file that stores our config.
to_box: A boolean indicating whether to box our config.

"""

def __init__(self, yaml_path: str, to_box: bool = True) -> None:
@@ -21,6 +22,7 @@ def __init__(self, yaml_path: str, to_box: bool = True) -> None:
----
yaml_path: The string path to the config yaml.
to_box: Defines whether this initialization will use dot notation.

"""
self.yaml_path = yaml_path
self.to_box = to_box
@@ -55,6 +57,7 @@ class DataCollatorWithPadding:
ignore_index: A value used for ignoring a given token in labels.
max_seq_len: An integer denoting the maximum sequence length.
padding_side: A side of the sequence that gets padded.

"""

def __init__(
@@ -73,6 +76,7 @@ def __init__(
loss.
max_seq_len: The maximum sequence length to expect.
padding_side: The side of the sequence which is padded.

"""
self.pad_token_id = pad_token_id
self.ignore_index = ignore_index
@@ -99,6 +103,7 @@ def __call__(
Returns:
-------
A dictionary containing a batch that we can input to our model.

"""
batch = {}
keys = ["input_ids", "labels"]
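The collator's documented behaviour above (pad `input_ids` with `pad_token_id`, pad `labels` with `ignore_index`, respect `max_seq_len` and `padding_side`, and return a batch dictionary ready for the model) boils down to the pattern below. This is a standalone, hypothetical sketch of that behaviour, including the default values, and is not the class's actual implementation.

```python
import torch


def pad_batch(
    features: list[dict[str, list[int]]],
    pad_token_id: int = 0,
    ignore_index: int = -100,
    max_seq_len: int = 1024,
    padding_side: str = "right",
) -> dict[str, torch.Tensor]:
    """Pad `input_ids` and `labels` to a common length (illustrative sketch)."""
    longest = min(max(len(f["input_ids"]) for f in features), max_seq_len)
    pad_values = {"input_ids": pad_token_id, "labels": ignore_index}
    batch: dict[str, torch.Tensor] = {}
    for key, pad_value in pad_values.items():
        rows = []
        for feature in features:
            seq = feature[key][:max_seq_len]
            # Labels padded with `ignore_index` are excluded from the loss.
            padding = [pad_value] * (longest - len(seq))
            rows.append(padding + seq if padding_side == "left" else seq + padding)
        batch[key] = torch.tensor(rows, dtype=torch.long)
    return batch
```

Two features of lengths 3 and 5, for example, both come back as length-5 rows, with the shorter one padded on the configured side.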
1 change: 1 addition & 0 deletions vectorlm/utils/misc_utils.py
@@ -5,6 +5,7 @@

import torch.distributed as dist
import wandb

from vectorlm.utils.data_utils import Config

