Skip to content

Commit

Permalink
feat: rename dataclasses
Browse files Browse the repository at this point in the history
Ref: #145

Signed-off-by: Tomas Dvorak <toomas2d@gmail.com>
  • Loading branch information
Tomas2D committed Sep 25, 2023
1 parent e64c798 commit dafceb0
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 7 deletions.
7 changes: 4 additions & 3 deletions examples/user/generate_with_moderation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from genai.credentials import Credentials
from genai.model import Model
from genai.schemas import GenerateParams
from genai.schemas.generate_params import ModerationHAPOptions, ModerationOptions
from genai.schemas.generate_params import HAPOptions, ModerationsOptions

# make sure you have a .env file under genai root with
# GENAI_KEY=<your-genai-key>
Expand All @@ -21,9 +21,10 @@
min_new_tokens=10,
max_new_tokens=20,
stream=True,
moderations=ModerationOptions(
moderations=ModerationsOptions(
# Threshold is set to very low level to flag everything (testing purposes)
hap=ModerationHAPOptions(input=True, output=True, threshold=0.01)
# or set to True to enable HAP with default settings
hap=HAPOptions(input=True, output=True, threshold=0.01)
),
)

Expand Down
8 changes: 4 additions & 4 deletions src/genai/schemas/generate_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,18 +45,18 @@ def __init__(self, *args, **kwargs):
# Link to doc : https://workbench.res.ibm.com/docs/api-reference#generate


class ModerationHAPOptions(BaseModel):
class HAPOptions(BaseModel):
input: bool = Field(description=tx.HAP_INPUT, default=True)
output: bool = Field(description=tx.HAP_OUTPUT, default=True)
threshold: float = Field(description=tx.HAP_THRESHOLD, ge=0, le=1, multiple_of=0.01, default=0.75)


class ModerationOptions(BaseModel):
class ModerationsOptions(BaseModel):
class Config:
extra = Extra.allow
allow_population_by_field_name = True

hap: Union[bool, ModerationHAPOptions] = Field(description=tx.HAP, default=False)
hap: Union[bool, HAPOptions] = Field(description=tx.HAP, default=False)


class GenerateParams(BaseModel):
Expand Down Expand Up @@ -84,4 +84,4 @@ class Config:
beam_width: Optional[int] = Field(None, description=tx.BEAM_WIDTH, ge=0)
return_options: Optional[ReturnOptions] = Field(None, description=tx.RETURN)
returns: Optional[Return] = Field(None, description=tx.RETURN, alias="return", deprecated=True)
moderations: Optional[ModerationOptions] = Field(None, description=tx.MODERATIONS)
moderations: Optional[ModerationsOptions] = Field(None, description=tx.MODERATIONS)

0 comments on commit dafceb0

Please sign in to comment.