(4/n) Data Refactor - Finetuning Scripts #950

Merged
merged 76 commits into main from refactor/data on Feb 29, 2024
Commits (changes shown from 8 of the 76 commits)
1655751
alpaca
awaelchli Feb 15, 2024
0b65b09
fixes
awaelchli Feb 23, 2024
1e273fe
fixes
awaelchli Feb 23, 2024
719ac2f
separate
awaelchli Feb 23, 2024
8b31c77
lima
awaelchli Feb 23, 2024
e4daef9
Merge branch 'main' into refactor/data
awaelchli Feb 24, 2024
a2047f6
integrate
awaelchli Feb 24, 2024
815e03b
remove converted datasets
awaelchli Feb 24, 2024
34bf71c
tinyllama
awaelchli Feb 24, 2024
1768d56
update
awaelchli Feb 24, 2024
0b9eca8
small typo fix: laoder -> loader
rasbt Feb 26, 2024
c83f468
refactor base class
awaelchli Feb 26, 2024
1f4a4ee
args stuff
awaelchli Feb 26, 2024
a8e6ce4
Merge branch 'refactor/data' of ssh://github.com/Lightning-AI/lit-gpt…
awaelchli Feb 26, 2024
65d4a21
max_seq_length needs to be specified differently
awaelchli Feb 27, 2024
19cbc37
fix for max steps
awaelchli Feb 27, 2024
095cc7b
fix init
awaelchli Feb 27, 2024
35537c0
tinyllama
awaelchli Feb 27, 2024
2f5658e
model config
awaelchli Feb 27, 2024
7f81bbe
remove epoch size
awaelchli Feb 27, 2024
da0710f
simplify
awaelchli Feb 27, 2024
9638b0b
fix
awaelchli Feb 27, 2024
edbd4e0
refactor
awaelchli Feb 27, 2024
ff0ba0e
init
awaelchli Feb 27, 2024
0d36d09
revert
awaelchli Feb 27, 2024
eae63aa
docs
awaelchli Feb 27, 2024
e475962
docs
awaelchli Feb 27, 2024
b759d48
fix test
awaelchli Feb 27, 2024
0c18563
Update tests/test_pretrain_tinyllama.py
awaelchli Feb 27, 2024
a5d1ae5
Update pretrain/tinyllama.py
awaelchli Feb 27, 2024
bb250f2
update gitnignore
awaelchli Feb 27, 2024
066b776
tests
awaelchli Feb 27, 2024
097a58b
no test loader
awaelchli Feb 27, 2024
7edd888
rename base
awaelchli Feb 27, 2024
175f223
remove name arg
awaelchli Feb 27, 2024
6bb4ec3
datasets collides with hf datasets import :(
awaelchli Feb 27, 2024
463cb56
Merge branch 'refactor/data-tinyllama' into refactor/data
awaelchli Feb 27, 2024
a13dfb6
move
awaelchli Feb 27, 2024
b30771c
Merge branch 'main' into refactor/data
awaelchli Feb 27, 2024
ec45af5
restore
awaelchli Feb 28, 2024
abce4e0
restore
awaelchli Feb 28, 2024
1e5bf65
tests
awaelchli Feb 28, 2024
7835dea
test
awaelchli Feb 28, 2024
c990ce5
test
awaelchli Feb 28, 2024
006c09d
update
awaelchli Feb 28, 2024
adc99b5
csv
awaelchli Feb 28, 2024
b4c9b71
test csv
awaelchli Feb 28, 2024
63ba730
remove old test
awaelchli Feb 28, 2024
216da43
dolly
awaelchli Feb 28, 2024
f28ffd7
longform
awaelchli Feb 28, 2024
49b5b0a
fixes
awaelchli Feb 28, 2024
5b99057
flan
awaelchli Feb 28, 2024
c47785f
fix
awaelchli Feb 28, 2024
d9e035f
update
awaelchli Feb 28, 2024
e1c2766
optional data
awaelchli Feb 28, 2024
3e90c4a
fix test split
awaelchli Feb 28, 2024
0b0fa20
todos
awaelchli Feb 28, 2024
1315b3f
Merge branch 'main' into refactor/data
awaelchli Feb 28, 2024
e8a7677
update test
awaelchli Feb 28, 2024
efd5b7e
tinyllama
awaelchli Feb 28, 2024
0cd28d6
update
awaelchli Feb 29, 2024
9c6135c
lora
awaelchli Feb 29, 2024
2014b18
adapter
awaelchli Feb 29, 2024
e4c6396
adapter v2
awaelchli Feb 29, 2024
8fcfe26
update
awaelchli Feb 29, 2024
eac6bd6
update tests
awaelchli Feb 29, 2024
563c580
update
awaelchli Feb 29, 2024
ebb5b7b
update
awaelchli Feb 29, 2024
2053cc0
tests
awaelchli Feb 29, 2024
e5637e3
tests
awaelchli Feb 29, 2024
d69c73a
reset
awaelchli Feb 29, 2024
490c51e
Merge branch 'main' into refactor/data
awaelchli Feb 29, 2024
6bcff6a
Run CI on wip branch
carmocca Feb 29, 2024
8d7a2b3
Merge branch 'main' into refactor/data
awaelchli Feb 29, 2024
7215ac5
require either epochs or max_steps to be set
awaelchli Feb 29, 2024
5e10c9e
don't inline max_steps redefinition
awaelchli Feb 29, 2024
Files changed
1 change: 1 addition & 0 deletions .gitignore
@@ -9,6 +9,7 @@ build
# data
data
datasets
!lit_gpt/datasets
checkpoints
out
wandb
104 changes: 36 additions & 68 deletions finetune/full.py
@@ -7,8 +7,10 @@
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

import lightning as L
import torch
from torch.utils.data import DataLoader

import lightning as L
from lightning.fabric.loggers import CSVLogger
from lightning.fabric.strategies import FSDPStrategy
from torchmetrics import RunningMean
@@ -21,24 +23,23 @@
from lit_gpt.args import EvalArgs, IOArgs, TrainArgs
from lit_gpt.model import GPT, Block, Config
from lit_gpt.tokenizer import Tokenizer
from lit_gpt.datasets import Alpaca, apply_prompt_template
from lit_gpt.utils import (
CLI,
check_valid_checkpoint_dir,
chunked_cross_entropy,
get_default_supported_precision,
load_checkpoint,
num_parameters,
CycleIterator,
)
from scripts.prepare_alpaca import generate_prompt


def setup(
precision: Optional[str] = None,
devices: int = 1,
resume: Union[bool, Path] = False,
io: IOArgs = IOArgs(
train_data_dir=Path("data/alpaca"),
val_data_dir=Path("data/alpaca"),
checkpoint_dir=Path("checkpoints/stabilityai/stablelm-base-alpha-3b"),
out_dir=Path("out/full/alpaca"),
),
@@ -49,7 +50,6 @@ def setup(
micro_batch_size=1,
lr_warmup_steps=100,
epochs=5,
epoch_size=50000,
learning_rate=3e-3,
max_seq_length=None,
),
@@ -85,7 +85,10 @@ def main(
) -> None:
validate_args(io, train, eval)

steps_per_epoch = train.epoch_size // devices // train.batch_size(devices)
datamodule = Alpaca(io.checkpoint_dir)
train_dataloader, val_dataloader = get_dataloaders(fabric, datamodule)

steps_per_epoch = len(train_dataloader) // train.gradient_accumulation_iters(devices)
lr_max_steps = train.epochs * steps_per_epoch

check_valid_checkpoint_dir(io.checkpoint_dir)
@@ -95,9 +98,6 @@
if fabric.global_rank == 0:
os.makedirs(io.out_dir, exist_ok=True)

train_data = torch.load(io.train_data_dir / "train.pt")
val_data = torch.load(io.val_data_dir / "test.pt")

checkpoint_path = io.checkpoint_dir / "lit_model.pth"
fabric.print(f"Loading model {str(checkpoint_path)!r} with {config.__dict__}")
with fabric.init_module(empty_init=(devices > 1)):
@@ -121,10 +121,8 @@ def main(
else:
load_checkpoint(fabric, state["model"], checkpoint_path)

fabric.seed_everything(1337 + fabric.global_rank)

train_time = time.perf_counter()
fit(fabric, state, train_data, val_data, devices, resume, io, train, eval)
fit(fabric, state, train_dataloader, val_dataloader, devices, resume, io, train, eval)
fabric.print(f"Training time: {(time.perf_counter()-train_time):.2f}s")
if fabric.device.type == "cuda":
fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB")
@@ -136,8 +134,8 @@
def fit(
fabric: L.Fabric,
state: Dict,
train_data: List[Dict],
val_data: List[Dict],
train_dataloader: DataLoader,
val_dataloader: DataLoader,
devices: int,
resume: Union[bool, Path],
io: IOArgs,
@@ -155,14 +153,15 @@ def fit(
f" {model.max_seq_length} and context length is {model.config.block_size}"
)

validate(fabric, model, val_data, tokenizer, dataclasses.replace(eval, max_iters=2), train) # sanity check
validate(fabric, model, val_dataloader, tokenizer, dataclasses.replace(eval, max_iters=2), train) # sanity check
initial_iter = state["iter_num"]
train_iterator = CycleIterator(train_dataloader)

# resume data loader state by fast-forwarding through all seen batches
if resume:
resume_t0 = time.perf_counter()
for resume_iter in range(initial_iter):
get_batch(fabric, train_data, None)
next(train_iterator)
if resume_iter % 1000 == 0:
fabric.print(f"Resuming dataset: {resume_iter} / {initial_iter}")
fabric.barrier()
@@ -176,16 +175,11 @@
)
fabric.barrier()

for state["iter_num"] in range(state["iter_num"] + 1, train.max_iters(devices) + 1):
while state["iter_num"] <= train.max_iters(devices) and train_iterator.epoch < train.epochs:
state["iter_num"] += 1
iter_t0 = time.perf_counter()

input_ids, targets = get_batch(
fabric,
train_data,
train.micro_batch_size,
train.max_seq_length,
longest_seq_ix if state["iter_num"] == 1 else None,
)
batch = next(train_iterator)
input_ids, targets = batch["input_ids"], batch["labels"]

is_accumulating = state["iter_num"] % train.gradient_accumulation_iters(devices) != 0
with fabric.no_backward_sync(model, enabled=is_accumulating):
@@ -224,7 +218,7 @@

if not is_accumulating and state["step_count"] % eval.interval == 0:
t0 = time.perf_counter()
val_loss = validate(fabric, model, val_data, tokenizer, eval, train)
val_loss = validate(fabric, model, val_dataloader, tokenizer, eval, train)
t1 = time.perf_counter() - t0
fabric.print(f"iter {state['iter_num']}: val loss {val_loss.item():.4f}, val time: {t1 * 1000:.2f} ms")
metrics = {"val_loss": val_loss, "val_ppl": math.exp(val_loss)}
@@ -239,22 +233,23 @@
# FSDP has issues with `inference_mode`
@torch.no_grad()
def validate(
fabric: L.Fabric, model: GPT, val_data: List[Dict], tokenizer: Tokenizer, eval: EvalArgs, train: TrainArgs
fabric: L.Fabric, model: GPT, val_dataloader: DataLoader, tokenizer: Tokenizer, eval: EvalArgs, train: TrainArgs
) -> torch.Tensor:
fabric.print("Validating ...")
model.eval()
losses = torch.zeros(eval.max_iters)
val_iterator = iter(val_dataloader)
for k in range(eval.max_iters):
input_ids, targets = get_batch(fabric, val_data, train.micro_batch_size, train.max_seq_length)
batch = next(val_iterator)
input_ids, targets = batch["input_ids"], batch["labels"]
logits = model(input_ids)
losses[k] = chunked_cross_entropy(logits[..., :-1, :], targets[..., 1:], chunk_size=0)
val_loss = losses.mean()

# produce an example:
instruction = "Recommend a movie for me to watch during the weekend and explain the reason."
fabric.print(instruction)
sample = {"instruction": instruction, "input": ""}
prompt = generate_prompt(sample)
prompt = apply_prompt_template(val_dataloader.dataset.prompt_template, {"instruction": instruction})
encoded = tokenizer.encode(prompt, device=fabric.device)
with fabric.init_tensor():
# do not set `max_seq_length=max_returned_token` because memory is not a concern here
@@ -270,51 +265,24 @@ def validate(
return val_loss


def get_batch(
fabric: L.Fabric,
data: List[Dict],
micro_batch_size: int,
max_seq_length: Optional[int],
longest_seq_ix: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
ix = torch.randint(len(data), (micro_batch_size,))
if longest_seq_ix is not None:
# force the longest sample at the beginning so potential OOMs happen right away
ix[0] = longest_seq_ix

input_ids = [data[i]["input_ids"].type(torch.int64) for i in ix]
labels = [data[i]["labels"].type(torch.int64) for i in ix]

# this could be `longest_seq_length` to have a fixed size for all batches
max_len = max(len(s) for s in input_ids)

def pad_right(x, pad_id):
# pad right based on the longest sequence
n = max_len - len(x)
return torch.cat((x, torch.full((n,), pad_id, dtype=x.dtype)))

x = torch.stack([pad_right(x, pad_id=0) for x in input_ids])
y = torch.stack([pad_right(x, pad_id=-1) for x in labels])

# Truncate if needed
if max_seq_length:
x = x[:, :max_seq_length]
y = y[:, :max_seq_length]

if fabric.device.type == "cuda" and x.device.type == "cpu":
x, y = fabric.to_device((x.pin_memory(), y.pin_memory()))
else:
x, y = fabric.to_device((x, y))
return x, y


def get_lr_scheduler(optimizer, warmup_steps: int, max_steps: int):
# linear warmup followed by cosine annealing
scheduler1 = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: step / warmup_steps)
scheduler2 = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=(max_steps - warmup_steps))
return torch.optim.lr_scheduler.SequentialLR(optimizer, [scheduler1, scheduler2], milestones=[warmup_steps])


def get_dataloaders(fabric: L.Fabric, datamodule: L.LightningDataModule) -> Tuple[DataLoader, DataLoader]:
if fabric.global_rank == 0:
datamodule.prepare_data()
fabric.barrier()
datamodule.setup()
train_dataloader = datamodule.train_dataloader()
val_dataloader = datamodule.val_dataloader()
train_dataloader, val_dataloader = fabric.setup_dataloaders(train_dataloader, val_dataloader)
return train_dataloader, val_dataloader


def get_longest_seq_length(data: List[Dict]) -> Tuple[int, int]:
# find out the minimum max_seq_length required during fine-tuning (saves memory!)
lengths = [len(d["input_ids"]) for d in data]
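Note on the training loop above: it replaces the old random `get_batch` sampling with a `CycleIterator` wrapped around the train dataloader and stops once `train_iterator.epoch` reaches `train.epochs`. The iterator itself lives in `lit_gpt.utils` and is not part of this diff; the following is a minimal sketch of the behaviour the loop relies on, offered as an assumption rather than the actual implementation.

```python
# Minimal sketch of a CycleIterator-style wrapper (assumption: the real
# lit_gpt.utils.CycleIterator may differ in details). It restarts the
# underlying dataloader when exhausted and counts completed epochs, which is
# what the `train_iterator.epoch < train.epochs` check in fit() relies on.
from typing import Any, Iterable, Iterator, Optional


class CycleIterator:
    def __init__(self, iterable: Iterable) -> None:
        self.iterable = iterable
        self.epoch = 0  # number of times the iterable has been fully consumed
        self._iterator: Optional[Iterator] = None

    def __next__(self) -> Any:
        if self._iterator is None:
            self._iterator = iter(self.iterable)
        try:
            return next(self._iterator)
        except StopIteration:
            # start over from the beginning and bump the epoch counter
            self._iterator = iter(self.iterable)
            self.epoch += 1
            return next(self._iterator)

    def __iter__(self) -> "CycleIterator":
        return self
```

Because a wrapper like this never raises `StopIteration`, the resume path can fast-forward through previously seen batches simply by calling `next(train_iterator)` once per batch, as the diff does.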
5 changes: 5 additions & 0 deletions lit_gpt/datasets/__init__.py
@@ -0,0 +1,5 @@
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.

from lit_gpt.datasets.base import SFTDataset, apply_prompt_template, sft_collate_fn
from lit_gpt.datasets.alpaca import Alpaca
from lit_gpt.datasets.lima import LIMA
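The new `lit_gpt.datasets` package exposes the `Alpaca` and `LIMA` datamodules alongside `SFTDataset`, `apply_prompt_template`, and `sft_collate_fn`. Below is a minimal sketch of driving a datamodule the same way `get_dataloaders` and `validate` in `finetune/full.py` do. Everything beyond the calls visible in this diff is an assumption; in particular, only `Alpaca(io.checkpoint_dir)` appears in the PR, and the checkpoint directory is assumed to already contain a tokenizer.

```python
# Sketch only: mirrors how finetune/full.py consumes the datamodule in this PR.
# The Alpaca constructor signature beyond the checkpoint directory is assumed.
from pathlib import Path

import lightning as L

from lit_gpt.datasets import Alpaca, apply_prompt_template


def demo(checkpoint_dir: Path = Path("checkpoints/stabilityai/stablelm-base-alpha-3b")) -> None:
    # single-process Fabric, matching the defaults in finetune/full.py
    fabric = L.Fabric(devices=1)
    fabric.launch()

    datamodule = Alpaca(checkpoint_dir)

    if fabric.global_rank == 0:
        datamodule.prepare_data()  # download / preprocess once per node
    fabric.barrier()
    datamodule.setup()

    train_dataloader, val_dataloader = fabric.setup_dataloaders(
        datamodule.train_dataloader(), datamodule.val_dataloader()
    )

    # each batch carries token ids and labels, as consumed in fit()
    batch = next(iter(train_dataloader))
    print(batch["input_ids"].shape, batch["labels"].shape)

    # prompt construction for the generation sample in validate()
    prompt = apply_prompt_template(
        val_dataloader.dataset.prompt_template,
        {"instruction": "Recommend a movie for me to watch during the weekend."},
    )
    print(prompt)


if __name__ == "__main__":
    demo()
```

The `prepare_data` / `setup` split follows the usual `LightningDataModule` contract: downloading and preprocessing run once on rank zero, while per-process state is built in `setup`, which is why `get_dataloaders` guards `prepare_data` with `fabric.global_rank == 0` and a barrier.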