Add script to prepare dataset from csv #462

Merged
merged 28 commits into from
Sep 14, 2023

Commits
28 commits
213b084
Add script to prepare dataset from csv
Anindyadeep Aug 24, 2023
83e9aa3
Merge branch 'main' into anindya/add_csv_script
aniketmaurya Sep 11, 2023
0e8a604
Merge branch 'main' into anindya/add_csv_script
aniketmaurya Sep 11, 2023
ad0a7a4
Merge branch 'main' into anindya/add_csv_script
aniketmaurya Sep 11, 2023
ce2ff3e
update
aniketmaurya Sep 11, 2023
6a58378
Merge branch 'main' into anindya/add_csv_script
aniketmaurya Sep 11, 2023
10c4cf9
fixes
aniketmaurya Sep 11, 2023
2f13cb0
Merge branch 'main' into anindya/add_csv_script
aniketmaurya Sep 11, 2023
69e4f49
fix
aniketmaurya Sep 11, 2023
f2611cb
formatting
aniketmaurya Sep 11, 2023
a723a8a
format
aniketmaurya Sep 11, 2023
1c716ee
update requirements file
aniketmaurya Sep 11, 2023
41bed21
Merge branch 'main' into anindya/add_csv_script
aniketmaurya Sep 11, 2023
ef06941
add check
aniketmaurya Sep 11, 2023
a474126
Fixes and test
carmocca Sep 11, 2023
7798cbf
Forgot to commit the test
carmocca Sep 11, 2023
5ed85ac
Set encoding for windows
carmocca Sep 11, 2023
1b753e7
add tutorial section
rasbt Sep 12, 2023
6e6dd25
Merge branch 'main' into anindya/add_csv_script
rasbt Sep 12, 2023
853cfad
Add documentation to prepare dataset from csv.
Anindyadeep Sep 12, 2023
2f2043b
Merge branch 'main' into anindya/add_csv_script
aniketmaurya Sep 13, 2023
59b2af1
Fix: Small changes in documentation with rewordings
Anindyadeep Sep 13, 2023
a0182f0
Update scripts/prepare_csv.py
Anindyadeep Sep 14, 2023
8cd00c9
Update tutorials/prepare_dataset.md
Anindyadeep Sep 14, 2023
1113cad
Update tutorials/prepare_dataset.md
Anindyadeep Sep 14, 2023
2aa6b1e
Update tutorials/prepare_dataset.md
Anindyadeep Sep 14, 2023
fc73906
Merge branch 'main' into anindya/add_csv_script
aniketmaurya Sep 14, 2023
2edc7b1
Apply suggestions from code review
carmocca Sep 14, 2023
2 changes: 1 addition & 1 deletion .github/azure-gpu-tests.yml
@@ -50,7 +50,7 @@ jobs:
python -c "import torch ; mgpu = torch.cuda.device_count() ; assert mgpu == 2, f'GPU: {mgpu}'"
displayName: 'Image info & NVIDIA'

- script: pip install pytest pytest-rerunfailures -r requirements.txt transformers einops bitsandbytes scipy tokenizers zstandard
- script: pip install pytest pytest-rerunfailures -r requirements.txt transformers einops bitsandbytes scipy tokenizers zstandard pandas
displayName: 'Install dependencies'

- bash: pytest -v --durations=10 --disable-pytest-warnings --strict-markers --color=yes
2 changes: 1 addition & 1 deletion .github/workflows/cpu-tests.yml
@@ -55,7 +55,7 @@ jobs:

- name: Run tests without the package installed
run: |
pip install pytest pytest-rerunfailures transformers einops bitsandbytes scipy tokenizers zstandard
pip install pytest pytest-rerunfailures transformers einops bitsandbytes scipy tokenizers zstandard pandas
pip list

pytest --disable-pytest-warnings --strict-markers --color=yes
1 change: 1 addition & 0 deletions requirements.txt
@@ -10,3 +10,4 @@ jsonargparse[signatures] # CLI
# datasets # quantize/gptq.py
# zstandard # scripts/prepare_redpajama.py
# git+https://github.com/EleutherAI/lm-evaluation-harness.git@master # eval
# pandas
141 changes: 141 additions & 0 deletions scripts/prepare_csv.py
@@ -0,0 +1,141 @@
import json
import logging
import sys
from pathlib import Path

import torch
from torch.utils.data import random_split
from tqdm import tqdm

# support running without installing as a package
wd = Path(__file__).parent.parent.resolve()
logger = logging.getLogger(__name__)
sys.path.append(str(wd))

from lit_gpt.tokenizer import Tokenizer

COLUMNS = ("instruction", "input", "output")


def prepare(
csv_path: Path,
destination_path: Path = Path("data/csv"),
checkpoint_dir: Path = Path("checkpoints/stabilityai/stablelm-base-alpha-3b"),
test_split_fraction: float = 0.1,
seed: int = 42,
mask_inputs: bool = False,
ignore_index: int = -1,
) -> None:
"""Prepare a CSV dataset for instruction tuning.

The output is a training and test dataset saved as `train.pt` and `test.pt`,
which stores the preprocessed and tokenized prompts and labels.
"""
with open(checkpoint_dir / "lit_config.json", "r") as file:
config = json.load(file)
max_seq_length = config["block_size"]

destination_path.mkdir(parents=True, exist_ok=True)
logger.info("Loading data file ...")
import pandas as pd

df = pd.read_csv(csv_path, dtype=str).fillna("")
if not (df.columns.values == COLUMNS).all():
raise ValueError(f"CSV columns must be {COLUMNS}, found {df.columns.values}")
data = json.loads(df.to_json(orient="records", indent=4))

print("Loading tokenizer...")
tokenizer = Tokenizer(checkpoint_dir)

# Partition the dataset into train and test
train_set, test_set = random_split(
data, [1.0 - test_split_fraction, test_split_fraction], generator=torch.Generator().manual_seed(seed)
)
train_set, test_set = list(train_set), list(test_set)

print(f"train has {len(train_set):,} samples")
print(f"test has {len(test_set):,} samples")

print("Processing train split ...")
train_set = [
prepare_sample(
example=sample,
tokenizer=tokenizer,
max_length=max_seq_length,
mask_inputs=mask_inputs,
ignore_index=ignore_index,
)
for sample in tqdm(train_set)
]
torch.save(train_set, destination_path / "train.pt")

print("Processing test split ...")
test_set = [
prepare_sample(
example=sample,
tokenizer=tokenizer,
max_length=max_seq_length,
mask_inputs=mask_inputs,
ignore_index=ignore_index,
)
for sample in tqdm(test_set)
]
torch.save(test_set, destination_path / "test.pt")


def prepare_sample(example: dict, tokenizer: Tokenizer, max_length: int, mask_inputs: bool, ignore_index: int):
"""Processes a single sample.

Each sample in the dataset consists of:
- instruction: A string describing the task
- input: A string holding a special input value for the instruction.
This only applies to some samples, and in others this is empty.
- output: The response string

This function processes this data to produce a prompt text and a label for
supervised training. The prompt text is formed as a single message including both
the instruction and the input. The label/target is the same message but with the
response attached.

Finally, both the prompt and the label get tokenized. If `mask_inputs` is enabled, all tokens
in the label that correspond to the original prompt get masked out.
"""
full_prompt = generate_prompt(example)
full_prompt_and_response = full_prompt + example["output"]
encoded_full_prompt = tokenizer.encode(full_prompt, max_length=max_length)
encoded_full_prompt_and_response = tokenizer.encode(full_prompt_and_response, eos=True, max_length=max_length)

# The labels are the full prompt with response, but with the prompt masked out
labels = encoded_full_prompt_and_response.clone()
if mask_inputs:
labels[: len(encoded_full_prompt)] = ignore_index

return {
**example,
"input_ids": encoded_full_prompt_and_response,
"input_ids_no_response": encoded_full_prompt,
"labels": labels,
}


def generate_prompt(example):
"""Generates a standardized message to prompt the model with an instruction, optional input and a
'response' field."""

if example["input"]:
return (
"Below is an instruction that describes a task, paired with an input that provides further context. "
"Write a response that appropriately completes the request.\n\n"
f"### Instruction:\n{example['instruction']}\n\n### Input:\n{example['input']}\n\n### Response:"
)
return (
"Below is an instruction that describes a task. "
"Write a response that appropriately completes the request.\n\n"
f"### Instruction:\n{example['instruction']}\n\n### Response:"
)


if __name__ == "__main__":
from jsonargparse import CLI

CLI(prepare, as_positional=False)
99 changes: 99 additions & 0 deletions tests/test_prepare_csv.py
@@ -0,0 +1,99 @@
import json
import subprocess
import sys
from pathlib import Path
from unittest import mock
from unittest.mock import ANY, call


def test_prepare_csv(tmp_path, fake_checkpoint_dir):
with mock.patch("lit_gpt.tokenizer.Tokenizer"):
from scripts.prepare_csv import prepare

# create fake data
config = dict(block_size=128, padded_vocab_size=256, n_layer=3, n_head=8, n_embd=16)
with open(fake_checkpoint_dir / "lit_config.json", "w") as fp:
json.dump(config, fp)
csv_path = tmp_path / "data.csv"
mock_data = (
"instruction,input,output\n"
"Add,2+2,4\n"
"Subtract,5-3,2\n"
"Multiply,6*4,24\n"
"Divide,10/2,5\n"
"Exponentiate,2^3,8\n"
"Square root,√9,3\n"
)
with open(csv_path, "w", encoding="utf-8") as fp:
fp.write(mock_data)

with mock.patch("torch.save") as save_mock:
prepare(csv_path, destination_path=tmp_path, checkpoint_dir=fake_checkpoint_dir, test_split_fraction=0.5)

assert len(save_mock.mock_calls) == 2
train_calls, test_calls = save_mock.mock_calls
assert train_calls == call(
[
{
"instruction": "Add",
"input": "2+2",
"output": "4",
"input_ids": ANY,
"input_ids_no_response": ANY,
"labels": ANY,
},
{
"instruction": "Divide",
"input": "10/2",
"output": "5",
"input_ids": ANY,
"input_ids_no_response": ANY,
"labels": ANY,
},
{
"instruction": "Multiply",
"input": "6*4",
"output": "24",
"input_ids": ANY,
"input_ids_no_response": ANY,
"labels": ANY,
},
],
tmp_path / "train.pt",
)
assert test_calls == call(
[
{
"instruction": "Exponentiate",
"input": "2^3",
"output": "8",
"input_ids": ANY,
"input_ids_no_response": ANY,
"labels": ANY,
},
{
"instruction": "Subtract",
"input": "5-3",
"output": "2",
"input_ids": ANY,
"input_ids_no_response": ANY,
"labels": ANY,
},
{
"instruction": "Square root",
"input": "√9",
"output": "3",
"input_ids": ANY,
"input_ids_no_response": ANY,
"labels": ANY,
},
],
tmp_path / "test.pt",
)


def test_cli():
cli_path = Path(__file__).parent.parent / "scripts" / "prepare_csv.py"
output = subprocess.check_output([sys.executable, cli_path, "-h"])
output = str(output.decode())
assert "Prepare a CSV dataset" in output
64 changes: 57 additions & 7 deletions tutorials/prepare_dataset.md
@@ -121,19 +121,19 @@ Please read the [tutorials/finetune_*.md](../tutorials) documents for more infor
The models in Lit-GPT expect datasets for instruction finetuning in the following format:

```
[
    {
        "instruction": "Write a limerick about a
         pelican.",
        "input": "",
        "output": "There once was a pelican so fine,
         \nHis beak was as colorful as
         sunshine,\nHe would fish all day,\nIn
         a very unique way,\nThis pelican was
         truly divine!\n\n\n"
    },
    {
        "instruction": "Identify the odd one out from
         the group.",
        "input": "Carrot, Apple, Banana, Grape",
        "output": "Carrot\n\n"
@@ -142,8 +142,58 @@
```
(Note that, depending on the task, the `"input"` text can be an empty string, as shown above.)

Custom datasets can be prepared by either creating a new `scripts/prepare_dataset.py` script or reading the dataset
from a CSV file.

The easiest way to prepare a new dataset is to copy and modify one of the existing dataset preparation scripts:
 

### Preparing Custom Datasets From a CSV File

You can prepare a dataset directly from a CSV file that contains the following columns:

- `instruction`: A string describing the task.
- `input`: A string holding a special input value for the instruction. This only applies to some samples; for others it is an empty string.
- `output`: The expected response string.

If any of these columns is missing, the script will fail to create the dataset.
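
For example, a minimal CSV in this format (using a few rows of the mock data from the test added in this PR) looks like the following. Note that the script's column check requires the columns to be named and ordered exactly as shown:

```
instruction,input,output
Add,2+2,4
Subtract,5-3,2
Square root,√9,3
```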

Before starting to finetune, the data from the CSV needs to be read, tokenized, and written out in a binary format. The simplest way to prepare the dataset is to run:

```bash
python scripts/prepare_csv.py path/to/the/file.csv
```
You can also customize the dataset generation with the following additional parameters:

- `destination_path`: The folder where the binary data will be saved. Defaults to `data/csv`.

- `checkpoint_dir`: The model checkpoint directory. The model's tokenizer is used to convert the text into input ids. Defaults to `"checkpoints/stabilityai/stablelm-base-alpha-3b"`.

- `test_split_fraction`: The fraction of the data reserved for the test split. Defaults to `0.1`.

- `seed`: The seed value used to make the random train/test split reproducible. Defaults to `42`.

- `mask_inputs`: Whether to mask the prompt tokens in the labels so that only the response contributes to the loss. Defaults to `False`.

- `ignore_index`: The value used to replace the masked prompt tokens in the labels when `mask_inputs` is enabled. Defaults to `-1`.

To use the settings described above, add the respective command line arguments when calling `prepare_csv.py`, as shown in the example below:

```bash
python scripts/prepare_csv.py test_data.csv \
--destination_path data/csv \
--checkpoint_dir checkpoints/stabilityai/stablelm-base-alpha-3b \
--test_split_fraction 0.1 \
--seed 42 \
--mask_inputs false \
--ignore_index -1
```
Replace `test_data.csv` with the path to your CSV file and adjust the other parameters as needed. Executing the command above creates two binary files, `train.pt` and `test.pt`, inside `data/csv`. You can then use them to finetune your model.
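
Each of these files stores a Python list of dictionaries (saved with `torch.save`) containing the original CSV fields plus the tokenized tensors produced by `prepare_sample`. A minimal sketch for inspecting a sample, assuming the default `data/csv` destination:

```python
import torch

# train.pt / test.pt each hold a list of dicts produced by prepare_sample()
train_set = torch.load("data/csv/train.pt")

sample = train_set[0]
print(sample["instruction"], sample["input"], sample["output"])
print(sample["input_ids"])              # tokenized prompt + response (with EOS)
print(sample["input_ids_no_response"])  # tokenized prompt only
print(sample["labels"])                 # same as input_ids; prompt masked with ignore_index if mask_inputs is enabled
```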

 

### Preparing Custom Datasets Using a Dataset Preparation Script

If you don't have a CSV file following the format described in the previous section, the easiest way to prepare a new dataset is to copy and modify one of the existing dataset preparation scripts:

- [`scripts/prepare_alpaca.py`](https://github.com/Lightning-AI/lit-gpt/blob/main/scripts/prepare_alpaca.py) (if you plan to load a dataset from a JSON file);
- [`scripts/prepare_lima.py`](https://github.com/Lightning-AI/lit-gpt/blob/main/scripts/prepare_lima.py) (if you plan to load a dataset using the `datasets` Python library).
Expand Down