Commit
feat: update generate params
Ref: #112
Tomas2D committed Jul 31, 2023
1 parent e05e32a commit 11a0f9f
Showing 4 changed files with 27 additions and 13 deletions.
1 change: 1 addition & 0 deletions src/genai/schemas/descriptions.py
@@ -33,6 +33,7 @@ class Descriptions:
TOP_P = "If set to value < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation. The range is 0.00 to 1.00. Valid only with decoding_method=sample."
REPETITION_PENALTY = "The parameter for repetition penalty. 1.0 means no penalty."
TRUNCATE_INPUT_TOKENS = "Truncate to this many input tokens. Can be used to avoid requests failing due to input being longer than configured limits. Zero means don't truncate."
+ BEAM_WIDTH = "Multiple output sequences of tokens are generated, using your decoding selection, and then the output sequence with the highest overall probability is returned. When beam search is enabled, there will be a performance penalty, and Stop sequences will not be available." # noqa

# Params.Token
RETURN_TOKEN = "Return tokens with the response. Defaults to false."
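The BEAM_WIDTH description added above says that several candidate sequences are generated and the one with the highest overall probability is returned. As a rough illustration of that selection step only (not SDK code; the candidate data below is invented), the idea can be sketched in a few lines of Python:

import math

# Illustrative sketch, not code from the repository; candidate data is invented.
# Each candidate: (tokens, per-token log-probabilities).
candidates = [
    (["The", "sky", "is", "blue"], [-0.1, -0.3, -0.2, -0.4]),
    (["The", "sky", "was", "clear"], [-0.1, -0.3, -0.9, -0.6]),
]

def sequence_log_prob(token_log_probs):
    # The overall probability of a sequence is the product of its token
    # probabilities, i.e. the sum of their log-probabilities.
    return sum(token_log_probs)

# Beam search keeps several hypotheses and returns the highest-scoring one.
best_tokens, best_log_probs = max(candidates, key=lambda c: sequence_log_prob(c[1]))
print(" ".join(best_tokens), math.exp(sequence_log_prob(best_log_probs)))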
17 changes: 10 additions & 7 deletions src/genai/schemas/generate_params.py
@@ -53,16 +53,19 @@ class Config:

decoding_method: Optional[Literal["greedy", "sample"]] = Field(None, description=tx.DECODING_METHOD)
length_penalty: Optional[LengthPenalty] = Field(None, description=tx.LENGTH_PENALTY)
- max_new_tokens: Optional[int] = Field(None, description=tx.MAX_NEW_TOKEN)
- min_new_tokens: Optional[int] = Field(None, description=tx.MIN_NEW_TOKEN)
- random_seed: Optional[int] = Field(None, description=tx.RANDOM_SEED, ge=1, le=9999)
- stop_sequences: Optional[list[str]] = Field(None, description=tx.STOP_SQUENCES)
+ max_new_tokens: Optional[int] = Field(None, description=tx.MAX_NEW_TOKEN, ge=1)
+ min_new_tokens: Optional[int] = Field(None, description=tx.MIN_NEW_TOKEN, ge=0)
+ random_seed: Optional[int] = Field(None, description=tx.RANDOM_SEED, ge=1)
+ stop_sequences: Optional[list[str]] = Field(None, description=tx.STOP_SQUENCES, min_length=1)
stream: Optional[bool] = Field(None, description=tx.STREAM)
- temperature: Optional[float] = Field(None, description=tx.TEMPERATURE, ge=0.00, le=2.00)
+ temperature: Optional[float] = Field(None, description=tx.TEMPERATURE, ge=0.05, le=2.00)
time_limit: Optional[int] = Field(None, description=tx.TIME_LIMIT)
top_k: Optional[int] = Field(None, description=tx.TOP_K, ge=1)
top_p: Optional[float] = Field(None, description=tx.TOP_P, ge=0.00, le=1.00)
- repetition_penalty: Optional[float] = Field(None, description=tx.REPETITION_PENALTY)
- truncate_input_tokens: Optional[int] = Field(None, description=tx.TRUNCATE_INPUT_TOKENS)
+ repetition_penalty: Optional[float] = Field(
+ None, description=tx.REPETITION_PENALTY, multiple_of=0.01, ge=1.00, le=2.00
+ )
+ truncate_input_tokens: Optional[int] = Field(None, description=tx.TRUNCATE_INPUT_TOKENS, ge=0)
+ beam_width: Optional[int] = Field(None, description=tx.BEAM_WIDTH, ge=0)
return_options: Optional[ReturnOptions] = Field(None, description=tx.RETURN)
returns: Optional[Return] = Field(None, description=tx.RETURN, alias="return", deprecated=True)
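The tightened Field constraints above are enforced by pydantic when a GenerateParams instance is created. A minimal sketch of what that means for callers, assuming the import path matches the file shown in this diff; the parameter values are illustrative:

from pydantic import ValidationError

from genai.schemas.generate_params import GenerateParams

# Illustrative sketch, not code from the repository.
# Values inside the new bounds validate as before.
params = GenerateParams(
    decoding_method="sample",
    temperature=0.7,         # now constrained to [0.05, 2.00]
    repetition_penalty=1.2,  # now constrained to [1.00, 2.00] in steps of 0.01
    max_new_tokens=64,       # now must be >= 1
    beam_width=2,            # new field, must be >= 0
)

# Out-of-range values are rejected at construction time.
try:
    GenerateParams(temperature=0)  # below the new 0.05 minimum
except ValidationError as err:
    print(err)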
2 changes: 1 addition & 1 deletion tests/test_concurrent.py
@@ -33,7 +33,7 @@ def mock_generate_json(self, mocker):

@pytest.fixture
def generate_params(self):
- return GenerateParams(temperature=0, max_new_tokens=3, return_options=ReturnOptions(input_text=True))
+ return GenerateParams(temperature=0.05, max_new_tokens=3, return_options=ReturnOptions(input_text=True))

@pytest.fixture
def mock_tokenize_json(self, mocker):
20 changes: 15 additions & 5 deletions tests/test_generate_schema.py
@@ -28,6 +28,7 @@ def setup_method(self):
top_p=0.7,
repetition_penalty=1.2,
truncate_input_tokens=2,
+ beam_width=1,
return_options=ReturnOptions(
input_text=True,
generated_tokens=True,
@@ -122,8 +123,6 @@ def test_random_seed_invalid_type(self):
GenerateParams(random_seed="dummy")
with pytest.raises(ValidationError):
GenerateParams(random_seed=0)
- with pytest.raises(ValidationError):
- GenerateParams(random_seed=10000)

def test_random_seed_valid_type(self, request_body):
# test that random_seed must be an integer between 1 and 9999
@@ -162,7 +161,7 @@ def test_stream_valid_type(self, request_body):
assert isinstance(params.stream, bool)

def test_temperature_invalid_type(self):
- # test that temperature must be a float between 0 and 1
+ # test that temperature must be a float between 0.05 and 1
with pytest.raises(ValidationError):
GenerateParams(temperature="")
with pytest.raises(ValidationError):
@@ -238,7 +237,7 @@ def test_top_p_valid_type(self, request_body):

def test_repetition_penalty_invalid_type(self):
# test that repetition_penalty must be a float
- # NOTE: repetition_penalty can be 0 or less then 0?
+ # NOTE: repetition_penalty can be 0 or less than 0?
with pytest.raises(ValidationError):
GenerateParams(repetition_penalty="")
with pytest.raises(ValidationError):
@@ -253,14 +252,25 @@ def test_repetition_penalty_valid_type(self, request_body):
assert isinstance(params.repetition_penalty, float)

def test_truncate_input_tokens_invalid_type(self):
- # test that truncate_input_tokens must be a interger
+ # test that truncate_input_tokens must be an integer
with pytest.raises(ValidationError):
GenerateParams(truncate_input_tokens="")
with pytest.raises(ValidationError):
GenerateParams(truncate_input_tokens=[0, 1, 2])
with pytest.raises(ValidationError):
GenerateParams(truncate_input_tokens="dummy")

+ def test_beam_width_valid_type(self, request_body):
+ params = request_body["params"]
+ assert isinstance(params.beam_width, int)
+
+ def test_beam_width_invalid_type(self):
+ # test that beam_width must be an non-negative integer
+ with pytest.raises(ValidationError):
+ GenerateParams(beam_width="")
+ with pytest.raises(ValidationError):
+ GenerateParams(beam_width=-100)

def test_truncate_input_tokens_valid_type(self, request_body):
params = request_body["params"]
assert isinstance(params.truncate_input_tokens, int)
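The dropped random_seed upper-bound assertion and the tightened temperature minimum suggest boundary checks along the following lines. This is a hedged sketch with hypothetical test names, not part of the commit:

import pytest
from pydantic import ValidationError

from genai.schemas.generate_params import GenerateParams


def test_temperature_lower_bound():
    # Hypothetical test: 0.05 is the new minimum, so values below it should fail.
    assert GenerateParams(temperature=0.05).temperature == 0.05
    with pytest.raises(ValidationError):
        GenerateParams(temperature=0.04)


def test_random_seed_upper_bound_removed():
    # Hypothetical test: the le=9999 constraint was removed, so large seeds validate.
    assert GenerateParams(random_seed=10000).random_seed == 10000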
