@@ -164,7 +164,9 @@ def retrieve(
164
164
config = _config_for_framework_and_scope (_framework , final_image_scope , accelerator_type )
165
165
166
166
original_version = version
167
- version = _validate_version_and_set_if_needed (version , config , framework )
167
+ version = _validate_version_and_set_if_needed (
168
+ version , config , framework , image_scope , base_framework_version
169
+ )
168
170
version_config = config ["versions" ][_version_for_config (version , config )]
169
171
170
172
if framework == HUGGING_FACE_FRAMEWORK :
@@ -199,9 +201,8 @@ def retrieve(
199
201
container_version = sdk_version + "-" + container_version
200
202
201
203
if framework == HUGGING_FACE_FRAMEWORK :
202
- pt_or_tf_version = (
203
- re .compile ("^(pytorch|tensorflow)(.*)$" ).match (base_framework_version ).group (2 )
204
- )
204
+ hf_base_fw_regex = _get_hf_base_framework_version_regex ()
205
+ pt_or_tf_version = hf_base_fw_regex .match (base_framework_version ).group (2 )
205
206
_version = original_version
206
207
207
208
if repo in [
@@ -255,6 +256,10 @@ def retrieve(
255
256
return ECR_URI_TEMPLATE .format (registry = registry , hostname = hostname , repository = repo )
256
257
257
258
259
+ def _get_hf_base_framework_version_regex ():
260
+ return re .compile ("^(pytorch|tensorflow)(.*)$" )
261
+
262
+
258
263
def _get_instance_type_family (instance_type ):
259
264
"""Return the family of the instance type.
260
265
@@ -509,7 +514,7 @@ def _get_end_of_support_warn_message(version, config, framework):
509
514
return ""
510
515
511
516
512
- def _validate_version_and_set_if_needed (version , config , framework ):
517
+ def _validate_version_and_set_if_needed (version , config , framework , scope , base_framework_version ):
513
518
"""Checks if the framework/algorithm version is one of the supported versions."""
514
519
available_versions = list (config ["versions" ].keys ())
515
520
aliased_versions = list (config .get ("version_aliases" , {}).keys ())
@@ -527,8 +532,21 @@ def _validate_version_and_set_if_needed(version, config, framework):
527
532
528
533
_validate_arg (version , available_versions + aliased_versions , "{} version" .format (framework ))
529
534
530
- # For DLCs, warn if image is out of support. Use aliased version to grab information about the latest
531
- end_of_support_warning = _get_end_of_support_warn_message (version , config , framework )
535
+ # For DLCs, warn if image is out of support.
536
+ eos_version , eos_config , eos_framework = version , config , framework
537
+
538
+ # Default to underlying base framework for HF, except for neuron which has its own policies.
539
+ inf_or_trn = "inf" in config .get ("processors" , []) or "trn" in config .get ("processors" , [])
540
+ if framework == HUGGING_FACE_FRAMEWORK and base_framework_version and not inf_or_trn :
541
+ hf_base_fw_regex = _get_hf_base_framework_version_regex ()
542
+ hf_match = hf_base_fw_regex .search (base_framework_version )
543
+ if hf_match :
544
+ eos_framework = hf_match .group (1 )
545
+ eos_version = hf_match .group (2 )
546
+ eos_config = _config_for_framework_and_scope (eos_framework , scope )
547
+ end_of_support_warning = _get_end_of_support_warn_message (
548
+ eos_version , eos_config , eos_framework
549
+ )
532
550
if end_of_support_warning :
533
551
logger .warning (end_of_support_warning )
534
552
0 commit comments