Add dataset adapter for loading dolly15k_multilingual dataset (#3660)
Dataset:
[argilla/databricks-dolly-15k-curated-multilingual](https://huggingface.co/datasets/argilla/databricks-dolly-15k-curated-multilingual)
andreaskoepf committed Aug 18, 2023
1 parent 2cc90ff commit cf166c4
Showing 3 changed files with 83 additions and 33 deletions.
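A minimal usage sketch of the new adapter (not part of this commit): it instantiates Dolly15kMultilingual directly with the constructor arguments visible in the diff below; the import path and cache directory are illustrative assumptions.

# Sketch only -- assumes the repository's model/ directory is on PYTHONPATH and
# that a local Hugging Face cache directory is acceptable.
from model_training.custom_datasets.qa_datasets import Dolly15kMultilingual

ds = Dolly15kMultilingual(cache_dir=".cache/huggingface", mode="sft")
print(len(ds))  # number of accepted rows across the en/de/es/fr splits
print(ds[0])    # a DatasetEntry built by create_dataset_entry_qa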
3 changes: 3 additions & 0 deletions model/model_training/custom_datasets/__init__.py
@@ -18,6 +18,7 @@
SODA,
AlpacaGpt4,
DatabricksDolly15k,
Dolly15kMultilingual,
GPTeacher_Roleplay,
JokeExplaination,
QADataset,
@@ -175,6 +176,8 @@ def get_one_dataset(
train, eval = load_hellaswag()
elif dataset_name == "dolly15k":
dataset = DatabricksDolly15k(cache_dir=data_path, mode=mode, **kwargs)
elif dataset_name == "dolly15k_multilingual":
dataset = Dolly15kMultilingual(cache_dir=data_path, mode=mode, **kwargs)
elif dataset_name == "alpaca_gpt4":
dataset = AlpacaGpt4(cache_dir=data_path, mode=mode, **kwargs)
elif dataset_name == "red_pajama":
37 changes: 37 additions & 0 deletions model/model_training/custom_datasets/qa_datasets.py
@@ -601,6 +601,43 @@ def __getitem__(self, index: int) -> DatasetEntry:
return dialogue


class Dolly15kMultilingual(Dataset):
    def __init__(self, cache_dir: str | Path, mode: str = "sft") -> None:
        super().__init__()
        self.rows = []
        self.citation_regex = re.compile(r"\[[a-zA-Z]\]")  # removes citations in the form of e.g. [a] or [A]
        if mode not in ("sft", "rl"):
            raise NotImplementedError(f"Currently only the modes 'sft' and 'rl' are implemented. Received {mode}.")
        self.mode = mode
        splits = load_dataset("argilla/databricks-dolly-15k-curated-multilingual", cache_dir=cache_dir)
        for lang in ("en", "de", "es", "fr"):
            data = splits[lang]
            for line in data:
                if (c := self._process_instruction(line, lang=lang)) is not None:
                    self.rows.append(c)

    def _process_instruction(self, row: dict[str, str], lang: str) -> DatasetEntry | None:
        context = re_reference_remove.sub("", row["context"])
        # further remove references
        context = context.replace("[citation needed]", "")
        context = self.citation_regex.sub("", context)
        if _filter_by_words(row["instruction"]) and _filter_by_words(row["response"]):
            return create_dataset_entry_qa(
                mode=self.mode,
                questions=[row["instruction"]],
                answers=[row["response"]],
                context=context,
                lang=lang,
            )

    def __len__(self) -> int:
        return len(self.rows)

    def __getitem__(self, index: int) -> DatasetEntry:
        dialogue = self.rows[index]
        return dialogue


class AlpacaGpt4(Dataset):
def __init__(self, cache_dir: str | Path, mode: str = "sft") -> None:
super().__init__()
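As a standalone illustration of the context clean-up performed in _process_instruction above (re_reference_remove is defined elsewhere in qa_datasets.py and is omitted here), the following sketch applies the same literal replacement and citation regex to a sample string:

import re

citation_regex = re.compile(r"\[[a-zA-Z]\]")  # same pattern as in the new class

context = "Paris is the capital of France[a] and its largest city[citation needed]."
context = context.replace("[citation needed]", "")
context = citation_regex.sub("", context)
print(context)  # -> "Paris is the capital of France and its largest city."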
76 changes: 43 additions & 33 deletions model/pretokenizer/pretokenize.py
@@ -19,6 +19,7 @@ class IntRole(IntEnum):
System = 0
Prompter = 1
Assistant = 2
Context = 3


class Encoder(object):
@@ -72,6 +73,9 @@ def format_sft_entry(entry: DatasetEntrySft) -> tuple[list[str], list[int]]:
turns.append(f"<|im_start|>system\n{entry.system_message}<|im_end|>\n")
roles.append(IntRole.System.value) # 0
for m in entry.conversation:
if m.context:
turns.append(f"<|im_start|>context\n{m.context}<|im_end|>\n")
roles.append(IntRole.Context.value) # 3
if m.role == Role.prompter:
turns.append(f"<|im_start|>user\n{m.text}<|im_end|>\n")
roles.append(IntRole.Prompter.value) # 1
@@ -90,6 +94,21 @@ def format_conversation(messages) -> str:
return format_pairs(messages)


def get_dataset_name(d: Dataset):
    if isinstance(d, Subset):
        inner = d
        while isinstance(inner, Subset):
            inner = inner.dataset
        name = f"Subset of {type(inner).__name__}"
        if hasattr(inner, "name"):
            name += f" ({inner.name})"
    else:
        name = type(d).__name__
        if hasattr(d, "name"):
            name += f" ({d.name})"
    return name


class TokenStats:
def __init__(self, name: str, total_samples: int, fraction: float = 1):
self.name = name
@@ -156,17 +175,7 @@ def tokenize_dataset(

for i in range(len(datasets)):
d = datasets[i]
if isinstance(d, Subset):
if hasattr(d.dataset, "name"):
name = d.dataset.name
else:
name = f"Subset of {type(d.dataset).__name__}"
else:
if hasattr(d, "name"):
name = d.name
else:
name = type(d).__name__

name = get_dataset_name(d)
frac = 1
if dataset_target_sizes:
frac = fractions[i]
@@ -257,20 +266,28 @@
if jsonl_file:
jsonl_file.close()

print(f"\n# Stats for {full_prefix}*\n")
per_dataset_stats.append(total_stats)

for stats in per_dataset_stats:
print(f"## Stats for '{stats.name}' ({stats.total_samples} samples ({stats.fraction:.1%}))")
print("-----------------")
print(
f" Accepted: {stats.accepted_samples}/{stats.processed_samples} ({stats.accepted_samples/stats.processed_samples:.1%})"
)
print(f" Accepted tokens: {stats.accepted_tokens}")
print(f" Skipped: {stats.skipped_samples} ({stats.skipped_samples/stats.processed_samples:.1%})")
print(f" Min tokens per sample: {stats.min_tokens}")
print(f" Max tokens per sample: {stats.max_tokens}")
print(f" Avg tokens per sample: {stats.accepted_tokens/stats.accepted_samples}")
print("-----------------\n")
stats_path = Path(full_prefix + "_stats.txt")
with stats_path.open("w", encoding="UTF-8") as stats_file:
for f in (None, stats_file):
print(f"\n# Stats for {full_prefix}*\n", file=f)

for stats in per_dataset_stats:
print(f"## Stats for '{stats.name}' ({stats.total_samples} samples ({stats.fraction:.1%}))", file=f)
print("-----------------", file=f)
print(
f" Accepted: {stats.accepted_samples}/{stats.processed_samples} ({stats.accepted_samples/stats.processed_samples:.1%})",
file=f,
)
print(f" Accepted tokens: {stats.accepted_tokens}", file=f)
print(
f" Skipped: {stats.skipped_samples} ({stats.skipped_samples/stats.processed_samples:.1%})", file=f
)
print(f" Min tokens per sample: {stats.min_tokens}", file=f)
print(f" Max tokens per sample: {stats.max_tokens}", file=f)
print(f" Avg tokens per sample: {stats.accepted_tokens/stats.accepted_samples}", file=f)
print("-----------------\n", file=f)


def parse_args():
@@ -381,20 +398,13 @@ def main():
print("Training dataset sizes (before sampling):")
total = len(train)
for d in train.datasets:
if isinstance(d, Subset):
name = f"Subset of {type(d.dataset).__name__}"
if hasattr(d.dataset, "name"):
name += f" ({d.dataset.name})"
else:
name = type(d).__name__
if hasattr(d, "name"):
name += f" ({d.name})"
name = get_dataset_name(d)
print(f"{name}: {len(d)} ({len(d) / total:.2%})")

output_dir.mkdir(parents=True, exist_ok=True)

fn = output_dir / "special_tokens.json"
with fn.open("w") as f:
with fn.open("w", encoding="UTF-8") as f:
json.dump(encoder.special_tokens, f)

val = ConcatDataset(evals.values())
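For reference, a rough sketch of the turn layout that format_sft_entry now produces for an entry carrying a context field; the system/context/user lines mirror the diff above, while the assistant turn is assumed to be formatted symmetrically (it lies outside the shown hunk).

turns = [
    "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n",      # IntRole.System    -> 0
    "<|im_start|>context\nParis is the capital of France.<|im_end|>\n",  # IntRole.Context   -> 3 (new)
    "<|im_start|>user\nWhat is the capital of France?<|im_end|>\n",      # IntRole.Prompter  -> 1
    "<|im_start|>assistant\nParis.<|im_end|>\n",                         # IntRole.Assistant -> 2 (assumed format)
]
roles = [0, 3, 1, 2]
print("".join(turns))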
