1 change: 1 addition & 0 deletions .gitignore
@@ -1,6 +1,7 @@
 *.out
 *.err
 *.egg-info
+**/*.sh
 __pycache__/
 wandb/
 build/
2 changes: 1 addition & 1 deletion docs/config.md
@@ -65,7 +65,7 @@ Similar to the wandb config above, these keyword parameters are fed directly int
 * `overlap`: When we chunk a data point during packing, we can choose to have some overlap between the current chunk and the next chunk. This might help the model understand surrounding context during training (although this isn't something we have empirically investigated, we keep this option available to users).
 * `add_bos_eos_tokens`: Whether to add `BOS` and `EOS` tokens as defined by the respective HuggingFace tokenizer. If using packing, these will be added after packing is done, so that each chunk of size `max_seq_len` has these tokens.
 * `from_disk`: Whether we are going to be loading the dataset to preprocess from disk (the other option is to download straight from HuggingFace).
-* `seperator`: If using conditional finetuning (i.e. in a given data point, everything before `separator` will not be used for calculating the loss and its labels will be `ignore_index`).
+* `seperator`: If using conditional finetuning (i.e. in a given data point, everything before `separator` will not be used for calculating the loss and its labels will be `ignore_index`). **Note:** if `separator` is not found in a given sequence, the default behavior is that the datapoint will be skipped and will not be part of the final dataset.
 * `load_path`: The directory containing the HuggingFace dataset we are loading to preprocess.
 * `split`: If `load_path` is a dataset dictionary, `split` specifies which key in this dictionary contains the dataset we are preprocessing.
 * `save_path`: The directory we will be saving the processed dataset to.
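To make the `seperator` behavior above concrete, here is a minimal sketch of the described loss masking. It is an illustration only, not the repo's implementation; the `-100` ignore value, the helper name, and the token IDs are all assumptions.

```python
# Minimal sketch of conditional-finetuning label masking (assumed
# ignore_index of -100 and made-up token IDs, for illustration only).
IGNORE_INDEX = -100

def mask_prefix(token_ids: list[int], prefix_len: int) -> list[int]:
    """Exclude everything up to and including the separator from the loss."""
    return [IGNORE_INDEX] * prefix_len + token_ids[prefix_len:]

# A datapoint "question <sep> answer" whose prefix ("question <sep>")
# tokenizes to 4 IDs and whose answer tokenizes to 3 IDs:
tokens = [11, 12, 13, 14, 21, 22, 23]
print(mask_prefix(tokens, prefix_len=4))
# -> [-100, -100, -100, -100, 21, 22, 23]
```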
84 changes: 62 additions & 22 deletions preprocess_data.py
@@ -93,21 +93,28 @@ def tokenize_dataset(
     all_input_ids = []
     all_attention_mask = []
     all_labels = []
+    # Adding bos/eos
     if add_bos_eos:
        bos, eos = tokenizer.bos_token, tokenizer.eos_token
     else:
         bos, eos = "", ""
     for example in examples[data_field]:
+        # If we want to include a prepended prompt to each datapoint
         if pre_pend:
             prompt = f"{bos}{pre_pend}{example}{eos}"
         else:
             prompt = f"{bos}{example}{eos}"
+        # If we've specified a separator present in each sequence
         if not separator:
             tokenized = tokenizer.encode(prompt, add_special_tokens=False)
-            if truncate and len(tokenized) > tokenizer.max_model_length:
-                tokenized = tokenized[:tokenizer.max_model_length]
+            if truncate and len(tokenized) > tokenizer.model_max_length:
+                tokenized = tokenized[:tokenizer.model_max_length - 1]
+                tokenized.append(tokenizer.eos_token_id)
             all_labels.append(deepcopy(tokenized))
         else:
+            if separator not in prompt:
+                continue
+            # Perform tokenization separately to allow for conditional prompting
             separation_idx = prompt.find(separator) + len(separator)
             prefix, postfix = prompt[:separation_idx], prompt[separation_idx:]
             tokenized_prefix = tokenizer.encode(
@@ -117,12 +124,27 @@ def tokenize_dataset(
                 postfix, add_special_tokens=False,
             )
             tokenized = tokenized_prefix + tokenized_postfix
-            if truncate and len(tokenized) > tokenizer.max_model_length:
-                tokenized = tokenized[:tokenizer.max_model_length]
-            tokenized = tokenized_prefix + tokenized_postfix
-            all_labels.append(
-                [-100] * len(tokenized_prefix) + deepcopy(tokenized_postfix),
-            )
+            if truncate and len(tokenized) > tokenizer.model_max_length:
+                tokenized = tokenized[:tokenizer.model_max_length - 1]
+                tokenized.append(tokenizer.eos_token_id)
+            # We need to address this separately, because labels need to
+            # backprop on bos/eos tokens
+            if add_bos_eos:
+                label = (
+                    [tokenizer.bos_token_id]
+                    + ([-100] * (len(tokenized_prefix) - 1))
+                    + deepcopy(tokenized_postfix)
+                )
+            else:
+                label = (
+                    [-100] * len(tokenized_prefix)
+                    + deepcopy(tokenized_postfix)
+                )
+            # If truncated, labels should be the same.
+            if truncate and len(label) > tokenizer.model_max_length:
+                label = label[:tokenizer.model_max_length - 1]
+                label.append(tokenizer.eos_token_id)
+            all_labels.append(label)
         all_input_ids.append(tokenized)
         all_attention_mask.append([1] * len(tokenized))

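A standalone sketch of the label construction added above, with made-up token IDs standing in for a real HuggingFace tokenizer (the BOS/EOS IDs, `MODEL_MAX_LENGTH`, and the example lengths are all assumptions). Keeping the loss on the BOS token while masking the rest of the prefix mirrors the diff's comment that labels need to backprop on bos/eos tokens.

```python
# Standalone sketch of the conditional-finetuning labels built above,
# with made-up token IDs (a real run would use the HuggingFace tokenizer's
# bos_token_id, eos_token_id, and model_max_length).
BOS_ID, EOS_ID, IGNORE = 1, 2, -100
MODEL_MAX_LENGTH = 8  # stand-in for tokenizer.model_max_length

def build_labels(
    tokenized_prefix: list[int],
    tokenized_postfix: list[int],
    add_bos_eos: bool,
    truncate: bool,
) -> list[int]:
    if add_bos_eos:
        # Keep the loss on the BOS token itself; mask the rest of the prefix.
        labels = (
            [BOS_ID]
            + [IGNORE] * (len(tokenized_prefix) - 1)
            + list(tokenized_postfix)
        )
    else:
        labels = [IGNORE] * len(tokenized_prefix) + list(tokenized_postfix)
    if truncate and len(labels) > MODEL_MAX_LENGTH:
        # Truncate to max length - 1, then re-append EOS so the labels
        # still end on an EOS token.
        labels = labels[:MODEL_MAX_LENGTH - 1] + [EOS_ID]
    return labels

prefix = [BOS_ID, 10, 11, 12]        # "question <sep>" tokens, BOS included
postfix = [20, 21, 22, 23, EOS_ID]   # "answer" tokens, EOS included
print(build_labels(prefix, postfix, add_bos_eos=True, truncate=True))
# -> [1, -100, -100, -100, 20, 21, 22, 2]
```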
@@ -160,7 +182,8 @@ def pack_examples(
     """
     chunk_size = tokenizer.model_max_length
     if add_bos_eos:
-        chunk_size -= 2  # For BOS and EOS tokens.
+        # For BOS and EOS tokens.
+        chunk_size -= 2
         bos, eos = [tokenizer.bos_token_id], [tokenizer.eos_token_id]
     else:
         bos, eos = [], []
@@ -169,25 +192,42 @@
     if packing_type == "full":
         joined_examples = {k: sum(examples[k], []) for k in all_keys}
         total_length = len(joined_examples["input_ids"])
-        result = {
-            k: [
-                bos + v[i:i + chunk_size] + eos for i in range(
-                    0, total_length, stride,
-                )
-            ] for k, v in joined_examples.items()
-        }
+        result = {}
+        for k, v in joined_examples.items():
+            value_chunked_lst = []
+            for i in range(0, total_length, stride):
+                if k != "attention_mask":
+                    value_chunked_lst.append(bos + v[i:i + chunk_size] + eos)
+                else:
+                    if add_bos_eos:
+                        # Need to do this explicitly because attention mask
+                        # is just 1s or 0s.
+                        value_chunked_lst.append(
+                            [1] + v[i:i + chunk_size] + [1]
+                        )
+                    else:
+                        value_chunked_lst.append(v[i:i + chunk_size])
+            # Store the chunked values for this key.
+            result[k] = value_chunked_lst
     elif packing_type == "partial":
         result = {k:[] for k in examples}
         _key = all_keys[0]
         for idx in range(len(examples[_key])):
             total_length = len(examples[_key][idx])
             for key in all_keys:
-                sliced_example = [
-                    (
-                        bos + examples[key][idx][i:i + chunk_size] + eos
-                    ) for i in range(0, total_length, stride)
-                ]
-                result[key].extend(sliced_example)
+                for i in range(0, total_length, stride):
+                    if key != "attention_mask":
+                        sliced_example = [
+                            bos + examples[key][idx][i:i + chunk_size] + eos
+                        ]
+                    else:
+                        if add_bos_eos:
+                            sliced_example = [
+                                [1] + examples[key][idx][i:i + chunk_size] + [1]
+                            ]
+                        else:
+                            sliced_example = [
+                                examples[key][idx][i:i + chunk_size]
+                            ]
+                    result[key].extend(sliced_example)
     else:
         msg = "`packing_type` needs to either be `full` or `partial`."
         raise ValueError(msg)
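For intuition, a rough standalone sketch of the `full` packing path above (made-up token IDs; `MODEL_MAX_LENGTH`, `BOS_ID`, and `EOS_ID` are stand-ins for the tokenizer attributes, and the function name is hypothetical):

```python
# Rough sketch of the "full" packing path above, with made-up token IDs.
MODEL_MAX_LENGTH = 6
BOS_ID, EOS_ID = 1, 2

def pack_full(input_ids, attention_mask, add_bos_eos, overlap=0):
    # Reserve two positions so each chunk plus BOS/EOS fits in the max length.
    chunk_size = MODEL_MAX_LENGTH - (2 if add_bos_eos else 0)
    stride = chunk_size - overlap
    bos, eos = ([BOS_ID], [EOS_ID]) if add_bos_eos else ([], [])
    chunks, masks = [], []
    for i in range(0, len(input_ids), stride):
        chunks.append(bos + input_ids[i:i + chunk_size] + eos)
        # The attention mask gets literal 1s in place of BOS/EOS, since it
        # only ever holds 1s and 0s, not token IDs.
        pad = [1] if add_bos_eos else []
        masks.append(pad + attention_mask[i:i + chunk_size] + pad)
    return chunks, masks

ids = list(range(10, 19))  # 9 tokens joined across examples
chunks, masks = pack_full(ids, [1] * len(ids), add_bos_eos=True)
print(chunks)  # [[1, 10, 11, 12, 13, 2], [1, 14, 15, 16, 17, 2], [1, 18, 2]]
print(masks)   # [[1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1], [1, 1, 1]]
```

Note how reserving two positions (`chunk_size -= 2`) keeps every packed chunk, BOS/EOS included, within the model's maximum sequence length.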
2 changes: 1 addition & 1 deletion setup.py
@@ -5,7 +5,7 @@

 setup(
     name="vectorlm",
-    version="1.0",
+    version="0.1.0",
     packages=find_packages(),
     install_requires=requirements,
     python_requires=">=3.10",