3 changes: 3 additions & 0 deletions .gitignore
@@ -11,3 +11,6 @@ data/
/.vscode
/data
/env

.env
logs/
13 changes: 7 additions & 6 deletions configs/config_gemma.yaml
@@ -1,4 +1,4 @@
model: google/gemma-2b
model: google/gemma-2b-it
enable_wandb_logging: False

wandb_config:
@@ -7,9 +7,9 @@ wandb_config:
# tags: ["20240418-1a-preemption"]

train_parameters:
output_dir: weights
max_seq_len: 128
epochs: 10
output_dir: /network/scratch/j/jacob-junqi.tian/vectorlm/weights
max_seq_len: 1024
epochs: 100000
seed: 11

# Sharding strategy
@@ -33,10 +33,11 @@ train_parameters:
# Gradient norm clipping
max_grad_norm: 1
gradient_accumulation_steps: 4
batch_size: 2

# Optimizer
optimizer:
lr: 1.0e-4
lr: 2.0e-5
weight_decay: 0.1
betas: [0.9, 0.95]
eps: 1.0e-5
@@ -47,7 +48,7 @@ train_parameters:

# Checkpointing
checkpointing_enabled: False
logging_steps: 10
logging_steps: 100
save_frequency: 0.10

# Sampling during training
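The new `batch_size: 2` works together with the unchanged `gradient_accumulation_steps: 4` to set the effective global batch size. A quick arithmetic sketch; the per-device interpretation of `batch_size` and the world size of 4 GPUs are assumptions, not part of this config:

```python
# Illustration only: effective batch size implied by the updated config,
# assuming batch_size is per device and the job runs on 4 GPUs.
batch_size = 2                   # per-device micro-batch (new in this config)
gradient_accumulation_steps = 4  # unchanged
world_size = 4                   # assumed GPU count

effective_batch = batch_size * gradient_accumulation_steps * world_size
print(effective_batch)  # 32 sequences per optimizer step, each up to max_seq_len=1024 tokens
```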
49 changes: 21 additions & 28 deletions examples/llama_example.py
@@ -14,7 +14,7 @@
from transformers import set_seed

from vectorlm.dataset import Dataset
from vectorlm.trainer import Trainer
from vectorlm.trice import ICEMTrainer
from vectorlm.utils.data_utils import Config
from vectorlm.utils.misc_utils import cleanup, setup, wandb_setup
from vectorlm.utils.model_utils import (
@@ -92,7 +92,6 @@ def main(
# setup wandb
if rank == 0 and config.enable_wandb_logging:
wandb_setup(config, **config.wandb_config)
dist.barrier()

# load model and tokenizer
model, tokenizer = load_model_and_tokenizer(
@@ -152,7 +151,7 @@ def main(
)

# instantiate trainer
trainer = Trainer(
trainer = ICEMTrainer(
config=training_args,
enable_wandb_logging=config.enable_wandb_logging,
original_dataset_length=dataset.original_length,
@@ -186,33 +185,27 @@ def main(

# Checkpoint check. Always call before training.
# If no checkpoint, it returns 0.
checkpointed_epoch = trainer.find_checkpoint(training_args.output_dir)

for epoch in range(checkpointed_epoch, training_args.epochs):
train_dl_iterator = iter(dataset.train_dataloader)
for _ in tqdm(
range(len(dataset.train_dataloader)),
disable=rank != 0,
file=sys.__stdout__,
):
batch = next(train_dl_iterator)
trainer.step(batch, epoch)
trainer.model.train()
trainer.find_checkpoint(training_args.output_dir)
eval_acc = 0

pbar = tqdm(
range(config.train_parameters.epochs),
disable=rank != 0,
file=sys.__stdout__,
ncols=75,
)
for index in pbar:
train_loss, eval_output = trainer.step({}, index)
eval_acc = eval_output if eval_output is not None else eval_acc

if epoch == training_args.epochs - 1:
hf_save_dir = os.path.join(training_args.output_dir, "final-model")
else:
hf_save_dir = os.path.join(
training_args.output_dir,
"checkpoints",
f"epoch_{epoch}",
"end-epoch-model",
)
pbar.set_description(f"{train_loss:.3e}, {eval_acc * 100:.0f}%")

if is_lora_enabled:
save_peft_adapter(trainer.model, hf_save_dir)
else:
save_consolidated_model(trainer.model, hf_save_dir, rank)
dataset.reset_dataloaders()
if is_lora_enabled:
save_peft_adapter(trainer.model, hf_save_dir)
else:
save_consolidated_model(trainer.model, hf_save_dir, rank)
dataset.reset_dataloaders()

sys.exit(0)

1 change: 1 addition & 0 deletions vectorlm/sampling/__init__.py
@@ -5,6 +5,7 @@
ManagedMultiProcGPUExecutor,
SamplingEngineProvider,
SynchronizationBarriers,
batch_process,
handle_sample,
multiprocess_wrap,
)
27 changes: 25 additions & 2 deletions vectorlm/sampling/abstract.py
@@ -9,7 +9,7 @@

if TYPE_CHECKING:
import torch
from vectorlm.trainer import Trainer
from transformers import PreTrainedTokenizer

from .utils import SynchronizationBarriers

@@ -38,12 +38,18 @@ def __init__(
self.vllm_train_step = -1

@abstractmethod
def update(self, model: torch.nn.Module, train_step: int) -> None:
def update(
self,
model: torch.nn.Module,
train_step: int,
tokenizer: PreTrainedTokenizer | None = None,
) -> None:
"""Update model in sampling engine if the current copy is stale.

Params:
model: PeftModel, up-to-date model
train_step: int, train step of the given model.
tokenizer: optionally, an updated copy of the tokenizer.
"""
if self.vllm_train_step != train_step:
# Update parameters of self.vllm_llm using the given `model`.
@@ -54,11 +60,16 @@ def generate(
self,
prompts: list[str],
sampling_params: vllm.SamplingParams | None = None,
use_tqdm: bool = False,
) -> list[vllm.RequestOutput]:
"""Generate continuation for the given prompts synchronously.

Invoke at all ranks. Output will be broadcast to all ranks.

Only one thread should execute this method at a time. For better
throughput, supply many prompts in a single call rather than one
prompt per call.

Params:
------
prompts: List of input prompts.
@@ -70,3 +81,15 @@
Output from vllm: list[vllm.RequestOutput], one for each prompt.

"""

def generate_text_only(
self,
prompts: list[str],
sampling_params: vllm.SamplingParams | None = None,
use_tqdm: bool = False,
) -> list[str]:
"""Generate and return text only."""
return [
response.outputs[0].text
for response in self.generate(prompts, sampling_params, use_tqdm)
]
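`generate_text_only` is a thin convenience wrapper over `generate` that keeps only the first completion's text for each prompt. A minimal usage sketch, assuming `engine` is an already-constructed sampling engine instance (engine construction is not part of this diff), and following the docstring's advice to batch many prompts per call:

```python
from vllm import SamplingParams

# `engine` is assumed to be an AbstractSamplingEngine subclass instance
# (e.g. the LoRASamplingEngine below); invoke at all ranks.
prompts = [f"Question {i}: ..." for i in range(256)]
params = SamplingParams(temperature=0.7, max_tokens=128)

# One large call instead of 256 single-prompt calls, per the docstring above.
completions = engine.generate_text_only(prompts, params)
assert len(completions) == len(prompts)
```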
18 changes: 14 additions & 4 deletions vectorlm/sampling/sampling_lora.py
@@ -1,9 +1,8 @@
from __future__ import annotations

import os
from typing import TYPE_CHECKING, Callable
from typing import TYPE_CHECKING

import torch
import torch.distributed as dist
import vllm
from vllm.lora.request import LoRARequest
@@ -15,6 +14,7 @@

if TYPE_CHECKING:
from peft.peft_model import PeftModel
from transformers import PreTrainedTokenizer


class LoRASamplingEngine(AbstractSamplingEngine):
@@ -63,15 +63,24 @@ def __init__(
self.generate_fn = multiprocess_wrap(generate_fn_raw, self.barriers)
self.vllm_train_step = -1

def update(self, model: PeftModel, train_step: int) -> None:
def update(
self,
model: PeftModel,
train_step: int,
tokenizer: PreTrainedTokenizer | None = None,
) -> None:
"""Update model in sampling engine if the current copy is stale.

Params:
model: PeftModel, up-to-date model
train_step: int, train step of the given model.
tokenizer: optionally, an updated copy of the tokenizer to save alongside the adapter.
"""
self.barriers.before_generation.wait()
if self.vllm_train_step != train_step:
if tokenizer is not None:
tokenizer.save_pretrained(self.adapter_temp_folder)

save_peft_adapter(model, self.adapter_temp_folder)
self.vllm_train_step = train_step
self.lora_request = LoRARequest(
@@ -86,6 +95,7 @@ def generate(
self,
prompts: list[str],
sampling_params: vllm.SamplingParams | None = None,
use_tqdm: bool = False,
) -> list[vllm.RequestOutput]:
"""Generate continuation for the given prompts. Invoke at all ranks.

@@ -106,7 +116,7 @@
prompts,
sampling_params,
lora_request=self.lora_request,
use_tqdm=False,
use_tqdm=use_tqdm,
)
assert len(return_value) == len(prompts)

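With the extended `update` signature, a training loop can push both the latest LoRA adapter and a refreshed tokenizer to the sampling engine before generating. A sketch under stated assumptions: apart from `update` and `generate_text_only`, every name here is an illustrative placeholder, not part of this diff.

```python
from __future__ import annotations

from peft.peft_model import PeftModel
from transformers import PreTrainedTokenizer


def refresh_and_sample(
    sampling_engine,  # assumed: a LoRASamplingEngine instance built elsewhere
    model: PeftModel,
    tokenizer: PreTrainedTokenizer,
    prompts: list[str],
    train_step: int,
) -> list[str]:
    """Illustrative helper: sync the adapter/tokenizer, then sample."""
    # Saves the PEFT adapter (and tokenizer, if given) to the engine's temp
    # folder and points a fresh LoRARequest at it; the save is skipped when
    # the engine's copy already matches train_step.
    sampling_engine.update(model, train_step, tokenizer=tokenizer)
    return sampling_engine.generate_text_only(prompts)
```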
41 changes: 40 additions & 1 deletion vectorlm/sampling/utils.py
@@ -7,7 +7,15 @@
import threading
import time
from functools import partial
from typing import TYPE_CHECKING, Any, Callable, Iterable, NamedTuple, TypeVar
from typing import (
TYPE_CHECKING,
Any,
Callable,
Iterable,
Iterator,
NamedTuple,
TypeVar,
)

from vllm import LLM, LLMEngine, SamplingParams
from vllm.engine.arg_utils import EngineConfig
@@ -59,6 +67,8 @@ class SynchronizationBarriers(NamedTuple):


Fn = TypeVar("Fn", bound=Callable[..., Any])
InputItem = TypeVar("InputItem")
OutputItem = TypeVar("OutputItem")


def multiprocess_wrap(fn: Fn | None, barriers: SynchronizationBarriers) -> Fn:
@@ -116,6 +126,35 @@ def _wrapped_fn(*args, **kwargs) -> ...: # noqa: ANN002,ANN003
return _wrapped_fn # type: ignore[reportReturnType]


def batch_process(
input_data: Iterable[InputItem],
fn: Callable[[Iterable[InputItem]], Iterable[OutputItem]],
max_batch_size: int,
) -> Iterator[OutputItem]:
"""Process input data one batch at a time.

Params:
------
input_data: iterable of items to feed into fn.
fn: function that accepts a batch of data and produces
an output of equal length.
max_batch_size: maximum size of a batch.

Yields
------
Output items produced by fn, flattened across batches.

"""
input_batch: list[InputItem] = []
for input_item in input_data:
input_batch.append(input_item)
if len(input_batch) == max_batch_size:
yield from fn(input_batch)
input_batch = []

yield from fn(input_batch)


class ManagedMultiProcGPUExecutor(MultiprocessingGPUExecutor):
"""MultiProcGPUExecutor, but with VectorLM launched alongside vLLM.

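`batch_process` makes it easy to stream a large prompt set through the sampling engine in fixed-size chunks instead of one oversized call. A minimal sketch; only `batch_process` itself comes from this diff, and `engine` is an assumed sampling-engine instance as above:

```python
from vectorlm.sampling import batch_process


def generate_in_chunks(engine, prompts: list[str], chunk_size: int = 256) -> list[str]:
    """Illustrative only: run generate_text_only over prompts, chunk_size items at a time."""
    return list(
        batch_process(
            prompts,
            lambda batch: engine.generate_text_only(list(batch)),
            chunk_size,
        )
    )
```

Note that the final, smaller batch is also flushed through `fn`, which the new unit test below exercises with an input whose length is not a multiple of the batch size.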
15 changes: 15 additions & 0 deletions vectorlm/tests/test_sampling.py
@@ -0,0 +1,15 @@
from vectorlm.sampling import batch_process

RATIONALE_ANSWER_REGEXP = r"(.+)\s\(([A-C])\)[^\(\)]*$"


def test_batch_process() -> None:
"""Test batch_process."""
example_input = list("banana")
output = []
for output_item in batch_process(example_input, lambda x: x, 5):
print(output_item)
output.append(output_item)


assert output == example_input