Commit

(4/n) Data Refactor - Finetuning Scripts (#950)
Co-authored-by: rasbt <mail@sebastianraschka.com>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
3 people committed Mar 18, 2024
1 parent 3d77652 commit 63f162f
Showing 31 changed files with 345 additions and 2,559 deletions.
117 changes: 49 additions & 68 deletions finetune/adapter.py
@@ -8,6 +8,7 @@

import lightning as L
import torch
from torch.utils.data import DataLoader
from lightning.fabric.loggers import CSVLogger
from lightning.fabric.plugins import BitsandbytesPrecision
from lightning.fabric.strategies import FSDPStrategy
@@ -20,6 +21,7 @@
from generate.base import generate
from lit_gpt.adapter import GPT, Block, Config, adapter_filter, mark_only_adapter_as_trainable
from lit_gpt.args import EvalArgs, IOArgs, TrainArgs
from lit_gpt.data import Alpaca, LitDataModule, apply_prompt_template
from lit_gpt.tokenizer import Tokenizer
from lit_gpt.utils import (
CLI,
@@ -28,20 +30,19 @@
get_default_supported_precision,
load_checkpoint,
num_parameters,
CycleIterator,
)
from scripts.prepare_alpaca import generate_prompt


def setup(
precision: Optional[str] = None,
quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8-training"]] = None,
devices: int = 1,
seed: int = 1337,
data: Optional[LitDataModule] = None,
io: IOArgs = IOArgs(
train_data_dir=Path("data/alpaca"),
val_data_dir=Path("data/alpaca"),
checkpoint_dir=Path("checkpoints/stabilityai/stablelm-base-alpha-3b"),
out_dir=Path("out/adapter/alpaca"),
out_dir=Path("out/adapter"),
),
train: TrainArgs = TrainArgs(
save_interval=1000,
@@ -50,13 +51,16 @@ def setup(
micro_batch_size=4,
lr_warmup_steps=100,
epochs=5,
epoch_size=50000,
learning_rate=1e-3,
max_seq_length=None,
),
eval: EvalArgs = EvalArgs(interval=600, max_new_tokens=100, max_iters=100),
) -> None:

print(locals())
if data is None:
data = Alpaca()

precision = precision or get_default_supported_precision(training=True)

plugins = None
@@ -85,25 +89,24 @@

logger = CSVLogger(io.out_dir.parent, io.out_dir.name, flush_logs_every_n_steps=train.log_interval)
fabric = L.Fabric(devices=devices, strategy=strategy, precision=precision, loggers=logger, plugins=plugins)
fabric.launch(main, devices, seed, Config.from_name(name=io.checkpoint_dir.name), io, train, eval)
fabric.launch(main, devices, seed, Config.from_name(name=io.checkpoint_dir.name), data, io, train, eval)


def main(fabric: L.Fabric, devices: int, seed: int, config: Config, io: IOArgs, train: TrainArgs, eval: EvalArgs) -> None:
def main(fabric: L.Fabric, devices: int, seed: int, config: Config, data: LitDataModule, io: IOArgs, train: TrainArgs, eval: EvalArgs) -> None:
validate_args(io, train, eval)

steps_per_epoch = train.epoch_size // devices // train.batch_size(devices)
lr_max_steps = train.epochs * steps_per_epoch

check_valid_checkpoint_dir(io.checkpoint_dir)

tokenizer = Tokenizer(io.checkpoint_dir)
train_dataloader, val_dataloader = get_dataloaders(fabric, data, tokenizer, train)
steps_per_epoch = len(train_dataloader) // train.gradient_accumulation_iters(devices)
lr_max_steps = min(train.epochs * steps_per_epoch, (train.max_steps or float("inf")))

fabric.seed_everything(seed) # same seed for every process to init model (FSDP)

if fabric.global_rank == 0:
os.makedirs(io.out_dir, exist_ok=True)

train_data = torch.load(io.train_data_dir / "train.pt")
val_data = torch.load(io.val_data_dir / "test.pt")

checkpoint_path = io.checkpoint_dir / "lit_model.pth"
fabric.print(f"Loading model {str(checkpoint_path)!r} with {config.__dict__}")
with fabric.init_module(empty_init=(devices > 1)):
@@ -131,10 +134,8 @@ def main(fabric: L.Fabric, devices: int, seed: int, config: Config, io: IOArgs,
# strict=False because missing keys due to Adapter weights not contained in state dict
load_checkpoint(fabric, model, checkpoint_path, strict=False)

fabric.seed_everything(1337 + fabric.global_rank)

train_time = time.perf_counter()
fit(fabric, model, optimizer, scheduler, train_data, val_data, devices, io, train, eval)
fit(fabric, model, optimizer, scheduler, train_dataloader, val_dataloader, devices, io, train, eval)
fabric.print(f"Training time: {(time.perf_counter()-train_time):.2f}s")
if fabric.device.type == "cuda":
fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB")
@@ -149,34 +150,37 @@ def fit(
model: GPT,
optimizer: torch.optim.Optimizer,
scheduler: torch.optim.lr_scheduler,
train_data: List[Dict],
val_data: List[Dict],
train_dataloader: DataLoader,
val_dataloader: DataLoader,
devices: int,
io: IOArgs,
train: TrainArgs,
eval: EvalArgs,
) -> None:
tokenizer = Tokenizer(io.checkpoint_dir)
longest_seq_length, longest_seq_ix = get_longest_seq_length(train_data)
longest_seq_length, longest_seq_ix = get_longest_seq_length(train_dataloader.dataset)
model.max_seq_length = min(longest_seq_length, train.max_seq_length or float("inf"))
fabric.print(
f"The longest sequence length in the train data is {longest_seq_length}, the model's maximum sequence length is"
f" {model.max_seq_length} and context length is {model.config.block_size}"
)

validate(fabric, model, val_data, tokenizer, dataclasses.replace(eval, max_iters=2), train) # sanity check
validate(fabric, model, val_dataloader, tokenizer, dataclasses.replace(eval, max_iters=2)) # sanity check

train_iterator = CycleIterator(train_dataloader)
throughput = ThroughputMonitor(fabric, window_size=50)
max_steps = train.max_steps or float("inf")
step_count = 0
iter_num = 0
total_lengths = 0
total_t0 = time.perf_counter()

for iter_num in range(1, train.max_iters(devices) + 1):
while step_count < max_steps and train_iterator.epoch < train.epochs:
iter_num += 1
iter_t0 = time.perf_counter()

input_ids, targets = get_batch(
fabric, train_data, train.micro_batch_size, train.max_seq_length, longest_seq_ix if iter_num == 1 else None
)
batch = next(train_iterator)
input_ids, targets = batch["input_ids"], batch["labels"]

is_accumulating = iter_num % train.gradient_accumulation_iters(devices) != 0
with fabric.no_backward_sync(model, enabled=is_accumulating):
@@ -207,7 +211,7 @@ def fit(

if not is_accumulating and step_count % eval.interval == 0:
t0 = time.perf_counter()
val_loss = validate(fabric, model, val_data, tokenizer, eval, train)
val_loss = validate(fabric, model, val_dataloader, tokenizer, eval)
t1 = time.perf_counter() - t0
fabric.print(f"iter {iter_num}: val loss {val_loss.item():.4f}, val time: {t1 * 1000:.2f} ms")
fabric.barrier()
@@ -219,13 +223,15 @@ def fit(
# the adapter "kv cache" cannot be initialized under `inference_mode`
@torch.no_grad()
def validate(
fabric: L.Fabric, model: GPT, val_data: List[Dict], tokenizer: Tokenizer, eval: EvalArgs, train: TrainArgs
fabric: L.Fabric, model: GPT, val_dataloader: DataLoader, tokenizer: Tokenizer, eval: EvalArgs,
) -> torch.Tensor:
fabric.print("Validating ...")
model.eval()
losses = torch.zeros(eval.max_iters)
val_iterator = iter(val_dataloader)
for k in range(eval.max_iters):
input_ids, targets = get_batch(fabric, val_data, train.micro_batch_size, train.max_seq_length)
batch = next(val_iterator)
input_ids, targets = batch["input_ids"], batch["labels"]
logits = model(input_ids)
losses[k] = chunked_cross_entropy(logits[..., :-1, :], targets[..., 1:], chunk_size=0)
val_loss = losses.mean()
@@ -234,7 +240,7 @@ def validate(
instruction = "Recommend a movie for me to watch during the weekend and explain the reason."
fabric.print(instruction)
sample = {"instruction": instruction, "input": ""}
prompt = generate_prompt(sample)
prompt = apply_prompt_template(val_dataloader.dataset.prompt_template, sample)
encoded = tokenizer.encode(prompt, device=fabric.device)
with fabric.init_tensor():
# do not set `max_seq_length=max_returned_token` because memory is not a concern here
@@ -250,51 +256,24 @@ def validate(
return val_loss


def get_batch(
fabric: L.Fabric,
data: List[Dict],
micro_batch_size: int,
max_seq_length: Optional[int],
longest_seq_ix: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
ix = torch.randint(len(data), (micro_batch_size,))
if longest_seq_ix is not None:
# force the longest sample at the beginning so potential OOMs happen right away
ix[0] = longest_seq_ix

input_ids = [data[i]["input_ids"].type(torch.int64) for i in ix]
labels = [data[i]["labels"].type(torch.int64) for i in ix]

# this could be `longest_seq_length` to have a fixed size for all batches
max_len = max(len(s) for s in input_ids)

def pad_right(x, pad_id):
# pad right based on the longest sequence
n = max_len - len(x)
return torch.cat((x, torch.full((n,), pad_id, dtype=x.dtype)))

x = torch.stack([pad_right(x, pad_id=0) for x in input_ids])
y = torch.stack([pad_right(x, pad_id=-1) for x in labels])

# Truncate if needed
if max_seq_length:
x = x[:, :max_seq_length]
y = y[:, :max_seq_length]

if fabric.device.type == "cuda" and x.device.type == "cpu":
x, y = fabric.to_device((x.pin_memory(), y.pin_memory()))
else:
x, y = fabric.to_device((x, y))
return x, y


def get_lr_scheduler(optimizer, warmup_steps: int, max_steps: int):
# linear warmup followed by cosine annealing
scheduler1 = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: step / warmup_steps)
scheduler2 = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=(max_steps - warmup_steps))
return torch.optim.lr_scheduler.SequentialLR(optimizer, [scheduler1, scheduler2], milestones=[warmup_steps])


def get_dataloaders(fabric: L.Fabric, data: LitDataModule, tokenizer: Tokenizer, train: TrainArgs) -> Tuple[DataLoader, DataLoader]:
data.connect(tokenizer=tokenizer, batch_size=train.micro_batch_size, max_seq_length=train.max_seq_length)
with fabric.rank_zero_first():
data.prepare_data()
data.setup()
train_dataloader = data.train_dataloader()
val_dataloader = data.val_dataloader()
train_dataloader, val_dataloader = fabric.setup_dataloaders(train_dataloader, val_dataloader)
return train_dataloader, val_dataloader


def get_longest_seq_length(data: List[Dict]) -> Tuple[int, int]:
# find out the minimum max_seq_length required during fine-tuning (saves memory!)
lengths = [len(d["input_ids"]) for d in data]
@@ -316,14 +295,16 @@ def validate_args(io: IOArgs, train: TrainArgs, eval: EvalArgs) -> None:
if getattr(args, name) is not None:
issues.append(f"{__file__} doesn't support the {name!r} argument. This is set in {args}")
required = [
(io, ["checkpoint_dir", "train_data_dir", "val_data_dir"]),
(train, ["epoch_size", "epochs"]),
(io, ["checkpoint_dir"]),
(train, ["epochs"]),
(eval, ["max_new_tokens"]),
]
for args, names in required:
for name in names:
if getattr(args, name) is None:
issues.append(f"{__file__} requires the {name!r} argument. This is set in {args}")
if not train.epochs and not train.max_steps:
issues.append(f"{__file__} requires either epochs or max_steps to be set. This is set in {train}")
if issues:
raise ValueError("\n".join(issues))

(Diffs for the remaining 30 changed files are not shown.)
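For context on how the refactored entry point is meant to be driven, below is a minimal, hypothetical usage sketch of finetune/adapter.py after this change. It reuses only the argument names and defaults visible in the diff above; the direct import and call are assumptions (in the repository the function is normally launched from the command line via the CLI wrapper imported from lit_gpt.utils), so treat this as an illustration rather than the project's documented API.

# Hypothetical usage sketch, not part of this commit: after the data refactor,
# finetuning receives a LitDataModule instance (defaulting to Alpaca()) instead
# of train_data_dir/val_data_dir paths. Import style and direct call are assumptions.
from pathlib import Path

from lit_gpt.args import EvalArgs, IOArgs, TrainArgs
from lit_gpt.data import Alpaca

from finetune.adapter import setup  # assumes finetune/ is importable as a package

setup(
    data=Alpaca(),  # any LitDataModule; if omitted, the script falls back to Alpaca()
    io=IOArgs(
        checkpoint_dir=Path("checkpoints/stabilityai/stablelm-base-alpha-3b"),
        out_dir=Path("out/adapter"),
    ),
    train=TrainArgs(
        save_interval=1000,
        micro_batch_size=4,
        lr_warmup_steps=100,
        epochs=5,
        learning_rate=1e-3,
    ),
    eval=EvalArgs(interval=600, max_new_tokens=100, max_iters=100),
)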

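The rewritten training loop also depends on a CycleIterator from lit_gpt.utils and the epoch counter it exposes. The following is a minimal sketch of the behaviour the loop appears to assume; it is an illustrative stand-in, not the library's actual implementation.

# Sketch of the behaviour the new loop assumes from lit_gpt.utils.CycleIterator:
# cycle through a dataloader indefinitely while counting completed passes.
# Illustrative stand-in only, not the library's actual code.
from typing import Any, Iterable, Iterator, Optional


class CycleIteratorSketch:
    def __init__(self, iterable: Iterable) -> None:
        self.iterable = iterable
        self.epoch = 0  # number of completed passes over `iterable`
        self._iterator: Optional[Iterator] = None

    def __iter__(self) -> "CycleIteratorSketch":
        return self

    def __next__(self) -> Any:
        if self._iterator is None:
            self._iterator = iter(self.iterable)
        try:
            return next(self._iterator)
        except StopIteration:
            # restart from the beginning and count the completed pass
            self.epoch += 1
            self._iterator = iter(self.iterable)
            return next(self._iterator)

Under that behaviour, the loop condition `while step_count < max_steps and train_iterator.epoch < train.epochs` stops finetuning after either the requested number of optimizer steps or the requested number of passes over the training dataloader, whichever comes first.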