From b4546e1f740d1bfbdb17c680afe1860c3a5eb610 Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Thu, 18 Jan 2024 12:20:05 -0800 Subject: [PATCH 1/2] fix: parameters for chat create --- ai21/clients/studio/resources/studio_chat.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/ai21/clients/studio/resources/studio_chat.py b/ai21/clients/studio/resources/studio_chat.py index 710ed308..2fbfcac9 100644 --- a/ai21/clients/studio/resources/studio_chat.py +++ b/ai21/clients/studio/resources/studio_chat.py @@ -1,7 +1,8 @@ -from typing import List, Any, Optional, Dict +from typing import List, Optional from ai21.clients.common.chat_base import Chat from ai21.clients.studio.resources.studio_resource import StudioResource +from ai21.models import Penalty from ai21.models.chat_message import ChatMessage from ai21.models.responses.chat_response import ChatResponse @@ -20,9 +21,9 @@ def create( top_p: Optional[float] = 1.0, top_k_returns: Optional[int] = 0, stop_sequences: Optional[List[str]] = None, - frequency_penalty: Optional[Dict[str, Any]] = None, - presence_penalty: Optional[Dict[str, Any]] = None, - count_penalty: Optional[Dict[str, Any]] = None, + frequency_penalty: Optional[Penalty] = None, + presence_penalty: Optional[Penalty] = None, + count_penalty: Optional[Penalty] = None, **kwargs, ) -> ChatResponse: body = self._create_body( From 984e780b650289d865b16afcb693d543ace64330 Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Thu, 18 Jan 2024 12:20:28 -0800 Subject: [PATCH 2/2] fix: imports --- ai21/clients/studio/resources/studio_chat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ai21/clients/studio/resources/studio_chat.py b/ai21/clients/studio/resources/studio_chat.py index 2fbfcac9..4cbbcf8e 100644 --- a/ai21/clients/studio/resources/studio_chat.py +++ b/ai21/clients/studio/resources/studio_chat.py @@ -2,8 +2,8 @@ from ai21.clients.common.chat_base import Chat from ai21.clients.studio.resources.studio_resource import StudioResource -from ai21.models import Penalty from ai21.models.chat_message import ChatMessage +from ai21.models.penalty import Penalty from ai21.models.responses.chat_response import ChatResponse