15
15
from typing import List , Optional
16
16
17
17
from sagemaker .jumpstart import artifacts , utils as jumpstart_utils
18
+ from sagemaker .jumpstart .constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
19
+ from sagemaker .session import Session
18
20
19
21
20
22
def retrieve_options (
@@ -23,6 +25,7 @@ def retrieve_options(
23
25
model_version : Optional [str ] = None ,
24
26
tolerate_vulnerable_model : bool = False ,
25
27
tolerate_deprecated_model : bool = False ,
28
+ sagemaker_session : Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION ,
26
29
) -> List [str ]:
27
30
"""Retrieves the supported content types for the model matching the given arguments.
28
31
@@ -40,6 +43,10 @@ def retrieve_options(
40
43
tolerate_deprecated_model (bool): True if deprecated models should be tolerated
41
44
(exception not raised). False if these models should raise an exception.
42
45
(Default: False).
46
+ sagemaker_session (sagemaker.session.Session): A SageMaker Session
47
+ object, used for SageMaker interactions. If not
48
+ specified, one is created using the default AWS configuration
49
+ chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
43
50
Returns:
44
51
list: The supported content types to use for the model.
45
52
@@ -57,6 +64,7 @@ def retrieve_options(
57
64
region ,
58
65
tolerate_vulnerable_model ,
59
66
tolerate_deprecated_model ,
67
+ sagemaker_session = sagemaker_session ,
60
68
)
61
69
62
70
@@ -66,6 +74,7 @@ def retrieve_default(
66
74
model_version : Optional [str ] = None ,
67
75
tolerate_vulnerable_model : bool = False ,
68
76
tolerate_deprecated_model : bool = False ,
77
+ sagemaker_session : Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION ,
69
78
) -> str :
70
79
"""Retrieves the default content type for the model matching the given arguments.
71
80
@@ -83,6 +92,10 @@ def retrieve_default(
83
92
tolerate_deprecated_model (bool): True if deprecated models should be tolerated
84
93
(exception not raised). False if these models should raise an exception.
85
94
(Default: False).
95
+ sagemaker_session (sagemaker.session.Session): A SageMaker Session
96
+ object, used for SageMaker interactions. If not
97
+ specified, one is created using the default AWS configuration
98
+ chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
86
99
Returns:
87
100
str: The default content type to use for the model.
88
101
@@ -100,6 +113,7 @@ def retrieve_default(
100
113
region ,
101
114
tolerate_vulnerable_model ,
102
115
tolerate_deprecated_model ,
116
+ sagemaker_session = sagemaker_session ,
103
117
)
104
118
105
119
0 commit comments