@@ -75,3 +75,61 @@ def _retrieve_model_package_arn(
7575 return regional_arn
7676
7777 raise NotImplementedError (f"Model Package ARN not supported for scope: '{ scope } '" )
78+
79+
80+ def _retrieve_model_package_model_artifact_s3_uri (
81+ model_id : str ,
82+ model_version : str ,
83+ region : Optional [str ],
84+ scope : Optional [str ] = None ,
85+ tolerate_vulnerable_model : bool = False ,
86+ tolerate_deprecated_model : bool = False ,
87+ ) -> Optional [str ]:
88+ """Retrieves s3 artifact uri associated with model package.
89+
90+ Args:
91+ model_id (str): JumpStart model ID of the JumpStart model for which to
92+ retrieve the model package artifact.
93+ model_version (str): Version of the JumpStart model for which to retrieve the
94+ model package artifact.
95+ region (Optional[str]): Region for which to retrieve the model package artifact.
96+ (Default: None).
97+ scope (Optional[str]): Scope for which to retrieve the model package artifact.
98+ (Default: None).
99+ tolerate_vulnerable_model (bool): True if vulnerable versions of model
100+ specifications should be tolerated (exception not raised). If False, raises an
101+ exception if the script used by this version of the model has dependencies with known
102+ security vulnerabilities. (Default: False).
103+ tolerate_deprecated_model (bool): True if deprecated versions of model
104+ specifications should be tolerated (exception not raised). If False, raises
105+ an exception if the version of the model is deprecated. (Default: False).
106+
107+ Returns:
108+ str: the model package artifact uri to use for the model or None.
109+
110+ Raises:
111+ NotImplementedError: If an unsupported script is used.
112+ """
113+
114+ if scope == JumpStartScriptScope .TRAINING :
115+
116+ if region is None :
117+ region = JUMPSTART_DEFAULT_REGION_NAME
118+
119+ model_specs = verify_model_region_and_return_specs (
120+ model_id = model_id ,
121+ version = model_version ,
122+ scope = scope ,
123+ region = region ,
124+ tolerate_vulnerable_model = tolerate_vulnerable_model ,
125+ tolerate_deprecated_model = tolerate_deprecated_model ,
126+ )
127+
128+ if model_specs .training_model_package_artifact_uris is None :
129+ return None
130+
131+ model_s3_uri = model_specs .training_model_package_artifact_uris .get (region )
132+
133+ return model_s3_uri
134+
135+ raise NotImplementedError (f"Model Package Artifact URI not supported for scope: '{ scope } '" )
0 commit comments