diff --git a/model/model_training/custom_datasets/entities.py b/model/model_training/custom_datasets/entities.py
new file mode 100644
index 0000000000..6aaa63c6a1
--- /dev/null
+++ b/model/model_training/custom_datasets/entities.py
@@ -0,0 +1,7 @@
+from enum import Enum
+
+
+class Mode(Enum):
+    sft = "sft"
+    rm = "rm"
+    rl = "rl"
diff --git a/model/model_training/custom_datasets/formatting.py b/model/model_training/custom_datasets/formatting.py
index 633c1e8306..91e25b513e 100644
--- a/model/model_training/custom_datasets/formatting.py
+++ b/model/model_training/custom_datasets/formatting.py
@@ -1,7 +1,8 @@
 from itertools import zip_longest
 from random import shuffle
 
-from model_training.custom_datasets.entities import Language, Mode
+from langcodes import Language
+from model_training.custom_datasets.entities import Mode
 from pydantic import BaseModel, validator
 from pydantic.fields import ModelField
 
@@ -25,12 +26,19 @@ def format_system_prefix(prefix, eos_token):
 class DatasetEntry(BaseModel):
     questions: list[str]
     answers: list[str]
-    context: str | None
-    lang: Language | None
-    length: int | None
-    quality: float | None
-    humor: float | None
-    creativity: float | None
+    context: str | None = None
+    lang: str | None = None
+    length: int | None = None
+    quality: float | None = None
+    humor: float | None = None
+    creativity: float | None = None
+
+    @validator("lang")
+    def valid_lang(cls, v) -> str | None:
+        if v is not None:
+            if not (lang := Language.get(v)).is_valid():
+                raise ValueError(f"Language {v} is not valid. Please provide BCP 47 compatible language codes.")
+            return str(lang)
 
     @validator("length")
     def above_zero(cls, v) -> int:
diff --git a/model/model_training/tests/test_formatting.py b/model/model_training/tests/test_formatting.py
index 01eb3e6e55..21c9e3ab3f 100644
--- a/model/model_training/tests/test_formatting.py
+++ b/model/model_training/tests/test_formatting.py
@@ -1,5 +1,6 @@
 import pytest
-from model_training.custom_datasets.entities import Language, Mode
+from langcodes import Language
+from model_training.custom_datasets.entities import Mode
 from model_training.custom_datasets.formatting import QA_SPECIAL_TOKENS, DatasetEntry
 
 
@@ -29,7 +30,7 @@ def test_dataset_entry():
         questions=["What is the capital of France?"],
         answers=["The capital of France is Paris."],
         context="Some context",
-        lang=Language("en"),
+        lang="en",
         length=100,
         quality=1.0,
         humor=0.0,
diff --git a/model/pyproject.toml b/model/pyproject.toml
index bc13ff6f81..bca221403a 100644
--- a/model/pyproject.toml
+++ b/model/pyproject.toml
@@ -27,6 +27,7 @@ dependencies = [
    "sentencepiece==0.1.97",
    "scikit-learn==1.2.0",
    "tokenizers==0.13.2",
+   "langcodes==3.3.0",
    "torch==1.13.1",
    "tqdm==4.65.0",
    "pydantic==1.10.7",
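
The sketch below (not taken from the patch) illustrates how the new langcodes-based lang validation is expected to behave. Entry is a hypothetical, stripped-down stand-in for DatasetEntry that reproduces only the lang field; the validator body is copied from the formatting.py hunk above, and langcodes==3.3.0 plus pydantic 1.x are assumed as pinned in pyproject.toml.

from langcodes import Language
from pydantic import BaseModel, ValidationError, validator


class Entry(BaseModel):
    # Stripped-down stand-in for DatasetEntry: only the optional lang field.
    lang: str | None = None

    # Same validator body as the DatasetEntry.valid_lang hunk above.
    @validator("lang")
    def valid_lang(cls, v) -> str | None:
        if v is not None:
            if not (lang := Language.get(v)).is_valid():
                raise ValueError(f"Language {v} is not valid. Please provide BCP 47 compatible language codes.")
            return str(lang)


print(Entry(lang="en").lang)   # "en": a registered BCP 47 tag passes and is stored as str(Language)
print(Entry(lang=None).lang)   # None: the field stays optional, the validator skips None

try:
    Entry(lang="zz")           # "zz" is not a registered language subtag, so is_valid() is False
except ValidationError as err:
    print(err)                 # pydantic wraps the ValueError raised by the validator

One design consequence worth noting: because the validator returns str(Language.get(v)) rather than the raw input, the stored value is the tag as langcodes re-serializes it, so it may differ from what the caller passed in (for example in casing of region subtags).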