-
Notifications
You must be signed in to change notification settings - Fork 451
/
__init__.py
2498 lines (2086 loc) · 96 KB
/
__init__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import json
import logging
import os
import re
import subprocess
import sys
import time
import pprint
from enum import Enum
import boto3
import requests
import boto3
from botocore.config import Config
from botocore.exceptions import ClientError
from glob import glob
from invoke import run
from invoke.context import Context
from packaging.version import InvalidVersion, Version, parse
from packaging.specifiers import SpecifierSet
from datetime import date, datetime, timedelta
from retrying import retry
from pathlib import Path
import dataclasses
import uuid
# from security import EnhancedJSONEncoder
from src import config
# Module-level logger: emits INFO and above to stderr via a StreamHandler.
LOGGER = logging.getLogger(__name__)
LOGGER.setLevel(logging.INFO)
LOGGER.addHandler(logging.StreamHandler(sys.stderr))
# Constant to represent default region for boto3 commands
DEFAULT_REGION = "us-west-2"
# Constant to represent region where p3dn tests can be run
P3DN_REGION = "us-east-1"
# Constant to represent region where p4de tests can be run
P4DE_REGION = "us-east-1"
def get_ami_id_boto3(region_name, ami_name_pattern, IncludeDeprecated=False):
    """
    For a given region and ami name pattern, return the latest ami-id

    :param region_name: str AWS region to query, e.g. "us-west-2"
    :param ami_name_pattern: str EC2 image name filter (supports ? and * wildcards)
    :param IncludeDeprecated: bool whether deprecated AMIs are included in the search.
        (Parameter name intentionally kept non-snake_case for backward compatibility
        with existing keyword-argument callers.)
    :return: str AMI ID of the most recently created matching image
    :raises LookupError: if no AMI matches the given name pattern
    """
    # Use max_attempts=10 because this function is used in global context, and all test jobs
    # get AMI IDs for tests regardless of whether they are used in that job.
    ec2_client = boto3.client(
        "ec2",
        region_name=region_name,
        config=Config(retries={"max_attempts": 10, "mode": "standard"}),
    )
    ami_list = ec2_client.describe_images(
        Filters=[{"Name": "name", "Values": [ami_name_pattern]}],
        Owners=["amazon"],
        IncludeDeprecated=IncludeDeprecated,
    )
    images = ami_list["Images"]
    # NOTE: Hotfix for fetching latest DLAMI before certain creation date.
    # replace `images` with `filtered_images` in max() if needed.
    # filtered_images = [
    #     element
    #     for element in images
    #     if datetime.strptime(element["CreationDate"], "%Y-%m-%dT%H:%M:%S.%fZ")
    #     < datetime.strptime("2024-05-02", "%Y-%m-%d")
    # ]
    if not images:
        # Fail with an actionable message instead of the opaque ValueError that
        # max() would raise on an empty sequence.
        raise LookupError(
            f"No AMI found in {region_name} matching name pattern {ami_name_pattern!r}"
        )
    ami = max(images, key=lambda x: x["CreationDate"])
    return ami["ImageId"]
def get_ami_id_ssm(region_name, parameter_path):
    """
    For a given region and parameter path, return the latest ami-id

    :param region_name: str AWS region to query
    :param parameter_path: str SSM parameter path whose value is a JSON document
        containing an "image_id" field (e.g. the ECS optimized-AMI parameters)
    :return: str AMI ID stored in the parameter value
    """
    # Use max_attempts=10 because this function is used in global context, and all test jobs
    # get AMI IDs for tests regardless of whether they are used in that job.
    ssm_client = boto3.client(
        "ssm",
        region_name=region_name,
        config=Config(retries={"max_attempts": 10, "mode": "standard"}),
    )
    ami = ssm_client.get_parameter(Name=parameter_path)
    # The parameter value is a JSON document. Parse it with json.loads instead of
    # eval() so that a remotely-stored value can never execute arbitrary code here.
    ami_id = json.loads(ami["Parameter"]["Value"])["image_id"]
    return ami_id
# DLAMI Base is split between OSS Nvidia Driver and Proprietary Nvidia Driver.
# See https://docs.aws.amazon.com/dlami/latest/devguide/important-changes.html
# NOTE: every constant below is resolved at import time with an EC2 describe_images call.
UBUNTU_20_BASE_OSS_DLAMI_US_WEST_2 = get_ami_id_boto3(
    region_name="us-west-2",
    ami_name_pattern="Deep Learning Base OSS Nvidia Driver GPU AMI (Ubuntu 20.04) ????????",
)
UBUNTU_20_BASE_OSS_DLAMI_US_EAST_1 = get_ami_id_boto3(
    region_name="us-east-1",
    ami_name_pattern="Deep Learning Base OSS Nvidia Driver GPU AMI (Ubuntu 20.04) ????????",
)
UBUNTU_20_BASE_PROPRIETARY_DLAMI_US_WEST_2 = get_ami_id_boto3(
    region_name="us-west-2",
    ami_name_pattern="Deep Learning Base Proprietary Nvidia Driver GPU AMI (Ubuntu 20.04) ????????",
)
UBUNTU_20_BASE_PROPRIETARY_DLAMI_US_EAST_1 = get_ami_id_boto3(
    region_name="us-east-1",
    ami_name_pattern="Deep Learning Base Proprietary Nvidia Driver GPU AMI (Ubuntu 20.04) ????????",
)
AML2_BASE_OSS_DLAMI_US_WEST_2 = get_ami_id_boto3(
    region_name="us-west-2",
    ami_name_pattern="Deep Learning Base OSS Nvidia Driver AMI (Amazon Linux 2) Version ??.?",
)
AML2_BASE_OSS_DLAMI_US_EAST_1 = get_ami_id_boto3(
    region_name="us-east-1",
    ami_name_pattern="Deep Learning Base OSS Nvidia Driver AMI (Amazon Linux 2) Version ??.?",
)
AML2_BASE_PROPRIETARY_DLAMI_US_WEST_2 = get_ami_id_boto3(
    region_name="us-west-2",
    ami_name_pattern="Deep Learning Base Proprietary Nvidia Driver AMI (Amazon Linux 2) Version ??.?",
)
AML2_BASE_PROPRIETARY_DLAMI_US_EAST_1 = get_ami_id_boto3(
    region_name="us-east-1",
    ami_name_pattern="Deep Learning Base Proprietary Nvidia Driver AMI (Amazon Linux 2) Version ??.?",
)
# We use the following DLAMI for MXNet and TensorFlow tests as well, but this is ok since we use
# custom DLC Graviton containers on top. We just need an ARM base DLAMI.
UL20_CPU_ARM64_US_WEST_2 = get_ami_id_boto3(
    region_name="us-west-2",
    ami_name_pattern="Deep Learning ARM64 AMI OSS Nvidia Driver GPU PyTorch 2.2.? (Ubuntu 20.04) ????????",
    IncludeDeprecated=True,
)
UL20_CPU_ARM64_US_EAST_1 = get_ami_id_boto3(
    region_name="us-east-1",
    ami_name_pattern="Deep Learning ARM64 AMI OSS Nvidia Driver GPU PyTorch 2.2.? (Ubuntu 20.04) ????????",
    IncludeDeprecated=True,
)
# Using latest ARM64 AMI (pytorch) - however, this will fail for TF benchmarks, so TF benchmarks
# are currently disabled for Graviton.
UL20_BENCHMARK_CPU_ARM64_US_WEST_2 = get_ami_id_boto3(
    region_name="us-west-2",
    ami_name_pattern="Deep Learning ARM64 AMI OSS Nvidia Driver GPU PyTorch 2.2.? (Ubuntu 20.04) ????????",
    IncludeDeprecated=True,
)
AML2_CPU_ARM64_US_EAST_1 = get_ami_id_boto3(
    region_name="us-east-1", ami_name_pattern="Deep Learning Base AMI (Amazon Linux 2) Version ??.?"
)
# Hard-coded AMI IDs used for the PyTorch ImageNet benchmark runs.
PT_GPU_PY3_BENCHMARK_IMAGENET_AMI_US_EAST_1 = "ami-0673bb31cc62485dd"
PT_GPU_PY3_BENCHMARK_IMAGENET_AMI_US_WEST_2 = "ami-02d9a47bc61a31d43"
# Since latest driver is not in public DLAMI yet, using a custom one
NEURON_UBUNTU_18_BASE_DLAMI_US_WEST_2 = get_ami_id_boto3(
    region_name="us-west-2", ami_name_pattern="Deep Learning Base AMI (Ubuntu 18.04) Version ??.?"
)
UL20_PT_NEURON_US_WEST_2 = get_ami_id_boto3(
    region_name="us-west-2",
    ami_name_pattern="Deep Learning AMI Neuron PyTorch 1.11.0 (Ubuntu 20.04) ????????",
)
UL20_TF_NEURON_US_WEST_2 = get_ami_id_boto3(
    region_name="us-west-2",
    ami_name_pattern="Deep Learning AMI Neuron TensorFlow 2.10.? (Ubuntu 20.04) ????????",
)
# Since NEURON TRN1 DLAMI is not released yet use a custom AMI
NEURON_INF1_AMI_US_WEST_2 = "ami-06a5a60d3801a57b7"
# History of Habana (HPU) base AMI IDs, kept commented out for reference:
# Habana Base v0.15.4 ami
# UBUNTU_18_HPU_DLAMI_US_WEST_2 = "ami-0f051d0c1a667a106"
# UBUNTU_18_HPU_DLAMI_US_EAST_1 = "ami-04c47cb3d4fdaa874"
# Habana Base v1.2 ami
# UBUNTU_18_HPU_DLAMI_US_WEST_2 = "ami-047fd74c001116366"
# UBUNTU_18_HPU_DLAMI_US_EAST_1 = "ami-04c47cb3d4fdaa874"
# Habana Base v1.3 ami
# UBUNTU_18_HPU_DLAMI_US_WEST_2 = "ami-0ef18b1906e7010fb"
# UBUNTU_18_HPU_DLAMI_US_EAST_1 = "ami-040ef14d634e727a2"
# Habana Base v1.4.1 ami
# UBUNTU_18_HPU_DLAMI_US_WEST_2 = "ami-08e564663ef2e761c"
# UBUNTU_18_HPU_DLAMI_US_EAST_1 = "ami-06a0a1e2c90bfc1c8"
# Habana Base v1.5 ami
# UBUNTU_18_HPU_DLAMI_US_WEST_2 = "ami-06bb08c4a3c5ba3bb"
# UBUNTU_18_HPU_DLAMI_US_EAST_1 = "ami-009bbfadb94835957"
# Habana Base v1.6 ami
UBUNTU_18_HPU_DLAMI_US_WEST_2 = "ami-03cdcfc91a96a8f92"
UBUNTU_18_HPU_DLAMI_US_EAST_1 = "ami-0d83d7487f322545a"
# Aggregate list of the AMI IDs defined above.
UL_AMI_LIST = [
    UBUNTU_20_BASE_OSS_DLAMI_US_WEST_2,
    UBUNTU_20_BASE_OSS_DLAMI_US_EAST_1,
    UBUNTU_20_BASE_PROPRIETARY_DLAMI_US_WEST_2,
    UBUNTU_20_BASE_PROPRIETARY_DLAMI_US_EAST_1,
    UBUNTU_18_HPU_DLAMI_US_WEST_2,
    UBUNTU_18_HPU_DLAMI_US_EAST_1,
    PT_GPU_PY3_BENCHMARK_IMAGENET_AMI_US_EAST_1,
    PT_GPU_PY3_BENCHMARK_IMAGENET_AMI_US_WEST_2,
    NEURON_UBUNTU_18_BASE_DLAMI_US_WEST_2,
    UL20_PT_NEURON_US_WEST_2,
    UL20_TF_NEURON_US_WEST_2,
    NEURON_INF1_AMI_US_WEST_2,
    UL20_CPU_ARM64_US_EAST_1,
    UL20_CPU_ARM64_US_WEST_2,
    UL20_BENCHMARK_CPU_ARM64_US_WEST_2,
]
# ECS images are maintained here: https://docs.aws.amazon.com/AmazonECS/latest/developerguide/ecs-optimized_AMI.html
ECS_AML2_GPU_USWEST2 = get_ami_id_ssm(
    region_name="us-west-2",
    parameter_path="/aws/service/ecs/optimized-ami/amazon-linux-2/gpu/recommended",
)
ECS_AML2_CPU_USWEST2 = get_ami_id_ssm(
    region_name="us-west-2",
    parameter_path="/aws/service/ecs/optimized-ami/amazon-linux-2/recommended",
)
ECS_AML2_NEURON_USWEST2 = get_ami_id_ssm(
    region_name="us-west-2",
    parameter_path="/aws/service/ecs/optimized-ami/amazon-linux-2/inf/recommended",
)
ECS_AML2_GRAVITON_CPU_USWEST2 = get_ami_id_ssm(
    region_name="us-west-2",
    parameter_path="/aws/service/ecs/optimized-ami/amazon-linux-2/arm64/recommended",
)
NEURON_AL2_DLAMI = get_ami_id_boto3(
    region_name="us-west-2", ami_name_pattern="Deep Learning AMI (Amazon Linux 2) Version ??.?"
)
# Account ID of test executor (resolved with an STS call at import time)
ACCOUNT_ID = boto3.client("sts", region_name=DEFAULT_REGION).get_caller_identity().get("Account")
# S3 bucket for TensorFlow models
# NOTE(review): "tensoflow" is presumably the real bucket's spelling - confirm before "fixing".
TENSORFLOW_MODELS_BUCKET = "s3://tensoflow-trained-models"
# Used for referencing tests scripts from container_tests directory (i.e. from ECS cluster)
CONTAINER_TESTS_PREFIX = os.path.join(os.sep, "test", "bin")
# S3 Bucket to use to transfer tests into an EC2 instance
TEST_TRANSFER_S3_BUCKET = f"s3://dlinfra-tests-transfer-bucket-{ACCOUNT_ID}"
# S3 Bucket to use to record benchmark results for further retrieving
BENCHMARK_RESULTS_S3_BUCKET = "s3://dlinfra-dlc-cicd-performance"
# Ubuntu ami home dir
UBUNTU_HOME_DIR = "/home/ubuntu"
# Reason string for skipping tests in PR context
SKIP_PR_REASON = "Skipping test in PR context to speed up iteration time. Test will be run in nightly/release pipeline."
# Reason string for skipping tests in non-PR context
PR_ONLY_REASON = "Skipping test that doesn't need to be run outside of PR context."
# Scratch file listing key names scheduled for cleanup (presumably consumed by a teardown step)
KEYS_TO_DESTROY_FILE = os.path.join(os.sep, "tmp", "keys_to_destroy.txt")
# Sagemaker test types
SAGEMAKER_LOCAL_TEST_TYPE = "local"
SAGEMAKER_REMOTE_TEST_TYPE = "sagemaker"
# Account ID of the public DLC ECR registry
PUBLIC_DLC_REGISTRY = "763104351884"
SAGEMAKER_EXECUTION_REGIONS = ["us-west-2", "us-east-1", "eu-west-1"]
# Before SM GA with Trn1, they support launch of ml.trn1 instance only in us-east-1. After SM GA this can be removed
SAGEMAKER_NEURON_EXECUTION_REGIONS = ["us-west-2"]
SAGEMAKER_NEURONX_EXECUTION_REGIONS = ["us-east-1"]
UPGRADE_ECR_REPO_NAME = "upgraded-image-ecr-scan-repo"
ECR_SCAN_HELPER_BUCKET = f"ecr-scan-helper-{ACCOUNT_ID}"
ECR_SCAN_FAILURE_ROUTINE_LAMBDA = "ecr-scan-failure-routine-lambda"
## Note that the region for the repo used for conducting ecr enhanced scans should be different from other
## repos since ecr enhanced scanning is activated in all the repos of a region and does not allow one to
## conduct basic scanning on some repos whereas enhanced scanning on others within the same region.
ECR_ENHANCED_SCANNING_REPO_NAME = "ecr-enhanced-scanning-dlc-repo"
ECR_ENHANCED_REPO_REGION = "us-west-1"
class NightlyFeatureLabel(Enum):
    """
    Labels naming features/packages installed in an image; the string values are
    snake_case identifiers (presumably consumed by nightly test tagging - confirm
    with callers).
    """

    AWS_FRAMEWORK_INSTALLED = "aws_framework_installed"
    AWS_SMDEBUG_INSTALLED = "aws_smdebug_installed"
    AWS_SMDDP_INSTALLED = "aws_smddp_installed"
    AWS_SMMP_INSTALLED = "aws_smmp_installed"
    PYTORCH_INSTALLED = "pytorch_installed"
    AWS_S3_PLUGIN_INSTALLED = "aws_s3_plugin_installed"
    TORCHAUDIO_INSTALLED = "torchaudio_installed"
    TORCHVISION_INSTALLED = "torchvision_installed"
    TORCHDATA_INSTALLED = "torchdata_installed"
class MissingPythonVersionException(Exception):
    """Raised when an image_uri that is expected to carry a Python version does not."""
class CudaVersionTagNotFoundException(Exception):
    """Raised when none of the tags of a GPU image contain a Cuda version."""
class DockerImagePullException(Exception):
    """Raised when a docker image could not be pulled from ECR."""
class SerialTestCaseExecutorException(Exception):
    """Raised by execute_serial_test_cases to aggregate test case failures."""
class EnhancedJSONEncoder(json.JSONEncoder):
    """
    JSON encoder that additionally serializes dataclass instances (as dicts)
    and datetime/date objects (as ISO-8601 strings). Required to dump
    dataclass objects as JSON.
    """

    def default(self, o):
        # Dispatch table: (predicate, converter) pairs tried in order.
        converters = (
            (dataclasses.is_dataclass, dataclasses.asdict),
            (lambda obj: isinstance(obj, (datetime, date)), lambda obj: obj.isoformat()),
        )
        for matches, convert in converters:
            if matches(o):
                return convert(o)
        # Fall back to the base class, which raises TypeError for unknown types.
        return super().default(o)
def execute_serial_test_cases(test_cases, test_description="test"):
    """
    Helper function to execute tests in serial

    Args:
        test_cases (List): list of test cases, formatted as [(test_fn, (fn_arg1, fn_arg2 ..., fn_argN))]
        test_description (str, optional): Describe test for custom error message. Defaults to "test".

    Raises:
        SerialTestCaseExecutorException: if any test case raised; all failures are
            aggregated into a single message.
    """
    exceptions = []
    logging_stack = []
    times = {}
    for fn, args in test_cases:
        log_stack = []
        fn_name = fn.__name__
        start_time = datetime.now()
        # Defensive default so the `finally` block can never hit an unbound name,
        # even if fn raises a BaseException (e.g. KeyboardInterrupt) that the
        # `except Exception` clause below does not catch.
        end_time = start_time
        log_stack.append(f"*********\nStarting {fn_name} at {start_time}\n")
        try:
            fn(*args)
            end_time = datetime.now()
            log_stack.append(f"\nEnding {fn_name} at {end_time}\n")
        except Exception as e:
            exceptions.append(f"{fn_name} FAILED WITH {type(e).__name__}:\n{e}")
            end_time = datetime.now()
            log_stack.append(f"\nFailing {fn_name} at {end_time}\n")
        finally:
            log_stack.append(
                f"Total execution time for {fn_name} {end_time - start_time}\n*********"
            )
            times[fn_name] = end_time - start_time
            logging_stack.append(log_stack)
    # Save logging to the end, as there may be other concurrent jobs
    for log_case in logging_stack:
        for line in log_case:
            LOGGER.info(line)
    pretty_times = pprint.pformat(times)
    LOGGER.info(pretty_times)
    if exceptions:
        raise SerialTestCaseExecutorException(
            f"Found {len(exceptions)} errors in {test_description}\n" + "\n\n".join(exceptions)
        )
def get_dockerfile_path_for_image(image_uri, python_version=None):
    """
    For a given image_uri, find the path within the repository to its corresponding dockerfile

    :param image_uri: str Image URI
    :param python_version: str Optional python version string (e.g. "py310") used when
        the image tag itself does not carry one (common for released images)
    :return: str Absolute path to dockerfile
    :raises LookupError: when multiple candidate dockerfiles exist and the image tag
        does not carry the accelerator SDK version needed to disambiguate them
    """
    # Repo root is everything in the cwd path before the first "test" segment.
    github_repo_path = os.path.abspath(os.path.curdir).split("test", 1)[0]
    framework, framework_version = get_framework_and_version_from_tag(image_uri)
    if "trcomp" in framework:
        # Replace the trcomp string as it is extracted from ECR repo name
        framework = framework.replace("_trcomp", "")
        framework_path = framework.replace("_", os.path.sep)
    elif "huggingface" in framework:
        framework_path = framework.replace("_", os.path.sep)
    elif "habana" in image_uri:
        framework_path = os.path.join("habana", framework)
    elif "stabilityai" in framework:
        framework_path = framework.replace("_", os.path.sep)
    else:
        framework_path = framework
    job_type = get_job_type_from_image(image_uri)
    # Prefer the short (major.minor) version directory; fall back to the full
    # major.minor.patch directory when the short one does not exist on disk.
    short_framework_version = re.search(r"(\d+\.\d+)", image_uri).group(1)
    framework_version_path = os.path.join(
        github_repo_path, framework_path, job_type, "docker", short_framework_version
    )
    if not os.path.isdir(framework_version_path):
        long_framework_version = re.search(r"\d+(\.\d+){2}", image_uri).group()
        framework_version_path = os.path.join(
            github_repo_path, framework_path, job_type, "docker", long_framework_version
        )
    # While using the released images, they do not have python version at times
    # Hence, we want to allow a parameter that can pass the Python version externally in case it is not in the tag.
    if not python_version:
        python_version = re.search(r"py\d+", image_uri).group()
    python_version_path = os.path.join(framework_version_path, python_version)
    if not os.path.isdir(python_version_path):
        # Fall back to the generic "py3" directory when there is no
        # version-specific python directory.
        python_version_path = os.path.join(framework_version_path, "py3")
    device_type = get_processor_from_image_uri(image_uri)
    cuda_version = get_cuda_version_from_tag(image_uri)
    synapseai_version = get_synapseai_version_from_tag(image_uri)
    neuron_sdk_version = get_neuron_sdk_version_from_tag(image_uri)
    dockerfile_name = get_expected_dockerfile_filename(device_type, image_uri)
    # Collect every candidate dockerfile under the python version directory,
    # excluding example dockerfiles.
    dockerfiles_list = [
        path
        for path in glob(os.path.join(python_version_path, "**", dockerfile_name), recursive=True)
        if "example" not in path
    ]
    if device_type in ["gpu", "hpu", "neuron", "neuronx"]:
        if len(dockerfiles_list) > 1:
            # Multiple candidates can only be disambiguated by an accelerator SDK
            # version (cuda/synapseai/neuron) present in the image tag.
            if device_type == "gpu" and not cuda_version:
                raise LookupError(
                    f"dockerfiles_list has more than one result, and needs cuda_version to be in image_uri to "
                    f"uniquely identify the right dockerfile:\n"
                    f"{dockerfiles_list}"
                )
            if device_type == "hpu" and not synapseai_version:
                raise LookupError(
                    f"dockerfiles_list has more than one result, and needs synapseai_version to be in image_uri to "
                    f"uniquely identify the right dockerfile:\n"
                    f"{dockerfiles_list}"
                )
            if "neuron" in device_type and not neuron_sdk_version:
                raise LookupError(
                    f"dockerfiles_list has more than one result, and needs neuron_sdk_version to be in image_uri to "
                    f"uniquely identify the right dockerfile:\n"
                    f"{dockerfiles_list}"
                )
            # Return the first candidate whose path contains the relevant SDK version.
            for dockerfile_path in dockerfiles_list:
                if cuda_version:
                    if cuda_version in dockerfile_path:
                        return dockerfile_path
                elif synapseai_version:
                    if synapseai_version in dockerfile_path:
                        return dockerfile_path
                elif neuron_sdk_version:
                    if neuron_sdk_version in dockerfile_path:
                        return dockerfile_path
            raise LookupError(
                f"Failed to find a dockerfile path for {cuda_version} in:\n{dockerfiles_list}"
            )
    assert (
        len(dockerfiles_list) == 1
    ), f"No unique dockerfile path in:\n{dockerfiles_list}\nfor image: {image_uri}"
    return dockerfiles_list[0]
def get_expected_dockerfile_filename(device_type, image_uri):
    """
    Return the dockerfile filename expected for a given image, e.g.
    "Dockerfile.gpu", "Dockerfile.ec2.gpu", "Dockerfile.sagemaker.cpu",
    "Dockerfile.graviton.cpu" or "Dockerfile.trcomp.gpu".
    """
    # Images covered by the EC2/SageMaker dockerfile split follow the newer
    # naming scheme.
    if is_covered_by_ec2_sm_split(image_uri):
        if "graviton" in image_uri:
            return f"Dockerfile.graviton.{device_type}"
        if is_ec2_sm_in_same_dockerfile(image_uri):
            if "pytorch-trcomp-training" in image_uri:
                return f"Dockerfile.trcomp.{device_type}"
            return f"Dockerfile.{device_type}"
        if is_ec2_image(image_uri):
            return f"Dockerfile.ec2.{device_type}"
        return f"Dockerfile.sagemaker.{device_type}"
    ## TODO: Keeping here for backward compatibility, should be removed in future when the
    ## functions is_covered_by_ec2_sm_split and is_ec2_sm_in_same_dockerfile are made exhaustive
    if is_ec2_image(image_uri):
        return f"Dockerfile.ec2.{device_type}"
    if is_sagemaker_image(image_uri):
        return f"Dockerfile.sagemaker.{device_type}"
    if is_trcomp_image(image_uri):
        return f"Dockerfile.trcomp.{device_type}"
    return f"Dockerfile.{device_type}"
def get_canary_helper_bucket_name():
    """
    Return the canary helper bucket name from the CANARY_HELPER_BUCKET env variable.

    :return: str bucket name
    :raises AssertionError: if CANARY_HELPER_BUCKET is unset or empty
    """
    bucket_name = os.getenv("CANARY_HELPER_BUCKET")
    # Raise explicitly instead of using an `assert` statement, which is silently
    # stripped when Python runs with optimizations (-O). AssertionError is kept
    # so existing callers' exception handling is unchanged.
    if not bucket_name:
        raise AssertionError("Unable to find bucket name in CANARY_HELPER_BUCKET env variable")
    return bucket_name
def get_customer_type():
    """Return the CUSTOMER_TYPE env variable value (None when unset)."""
    return os.environ.get("CUSTOMER_TYPE")


def get_image_type():
    """Return the IMAGE_TYPE env variable value; expected to be training or inference."""
    return os.environ.get("IMAGE_TYPE")


def get_test_job_arch_type():
    """Return the ARCH_TYPE env variable value (e.g. graviton), defaulting to "x86"."""
    return os.environ.get("ARCH_TYPE", "x86")
def get_ecr_repo_name(image_uri):
    """
    Retrieve ECR repository name from image URI

    :param image_uri: str ECR Image URI
    :return: str ECR repository name
    """
    # Text after the final "/" is "<repo>:<tag>"; strip the tag portion.
    repo_and_tag = image_uri.rsplit("/", 1)[-1]
    return repo_and_tag.partition(":")[0]
def is_tf_version(required_version, image_uri):
    """
    Validate that image_uri has framework version equal to required_version.
    Relies on the current convention of including the TF version in the image
    tag for all TF based frameworks.

    :param required_version: str Framework version which is required from the image_uri
    :param image_uri: str ECR Image URI for the image to be validated
    :return: bool True if image_uri has same framework version as required_version, else False
    """
    framework_name, framework_version = get_framework_and_version_from_tag(image_uri)
    if not is_tf_based_framework(framework_name):
        return False
    return framework_version in SpecifierSet(f"=={required_version}.*")
def is_tf_based_framework(name):
    """
    Return True when the framework name is TensorFlow-based.
    Relies on the current convention of including "tensorflow" in TF based names,
    e.g. "huggingface-tensorflow" or "huggingface-tensorflow-trcomp".
    """
    return name.find("tensorflow") != -1
def is_equal_to_framework_version(version_required, image_uri, framework):
    """
    Validate that image_uri has framework version exactly equal to version_required

    :param version_required: str Framework version that image_uri is required to be at
    :param image_uri: str ECR Image URI for the image to be validated
    :param framework: str Framework installed in image
    :return: bool True if image_uri has framework version equal to version_required, else False
    """
    image_framework, image_version = get_framework_and_version_from_tag(image_uri)
    if image_framework != framework:
        return False
    return Version(image_version) in SpecifierSet(f"=={version_required}")
def is_above_framework_version(version_lower_bound, image_uri, framework):
    """
    Validate that image_uri has framework version strictly greater than version_lower_bound

    :param version_lower_bound: str Framework version that image_uri is required to be above
    :param image_uri: str ECR Image URI for the image to be validated
    :param framework: str Framework installed in image
    :return: bool True if image_uri has framework version more than version_lower_bound, else False
    """
    image_framework_name, image_framework_version = get_framework_and_version_from_tag(image_uri)
    required_version_specifier_set = SpecifierSet(f">{version_lower_bound}")
    return (
        image_framework_name == framework
        and image_framework_version in required_version_specifier_set
    )
def is_below_framework_version(version_upper_bound, image_uri, framework):
    """
    Validate that image_uri has framework version strictly less than version_upper_bound

    :param version_upper_bound: str Framework version that image_uri is required to be below
    :param image_uri: str ECR Image URI for the image to be validated
    :param framework: str Framework installed in image
    :return: bool True if image_uri has framework version less than version_upper_bound, else False
    """
    image_framework, image_version = get_framework_and_version_from_tag(image_uri)
    upper_bound_specifier = SpecifierSet(f"<{version_upper_bound}")
    return image_framework == framework and image_version in upper_bound_specifier
def is_image_incompatible_with_instance_type(image_uri, ec2_instance_type):
    """
    Check for all compatibility issues between DLC Image Types and EC2 Instance Types.
    Currently configured to fail on the following checks:
        1. p4d.24xlarge instance type is used with a cuda<11.0 image
        2. p3.8xlarge instance type is used with a cuda=11.0 image for MXNET framework
        3. p3.8xlarge instance type is used with a PyTorch 1.11 GPU image

    :param image_uri: ECR Image URI in valid DLC-format
    :param ec2_instance_type: EC2 Instance Type
    :return: bool True if there are incompatibilities, False if there aren't
    """
    framework, framework_version = get_framework_and_version_from_tag(image_uri)
    cuda10_on_p4d = (
        get_processor_from_image_uri(image_uri) == "gpu"
        and get_cuda_version_from_tag(image_uri).startswith("cu10")
        and ec2_instance_type in ["p4d.24xlarge"]
    )
    mxnet_cuda11_on_p3_8xlarge = (
        framework == "mxnet"
        and get_processor_from_image_uri(image_uri) == "gpu"
        and get_cuda_version_from_tag(image_uri).startswith("cu11")
        and ec2_instance_type in ["p3.8xlarge"]
    )
    pytorch_1_11_on_p3_8xlarge = (
        framework == "pytorch"
        and Version(framework_version) in SpecifierSet("==1.11.*")
        and get_processor_from_image_uri(image_uri) == "gpu"
        and ec2_instance_type in ["p3.8xlarge"]
    )
    return cuda10_on_p4d or mxnet_cuda11_on_p3_8xlarge or pytorch_1_11_on_p3_8xlarge
def get_repository_local_path():
    """Return the local repository root: the cwd truncated at the first "/test/" segment."""
    cwd = os.getcwd()
    repo_root, _, _ = cwd.partition("/test/")
    return repo_root
def get_inference_server_type(image_uri):
    """
    Return the inference server type for an image: "ts" (TorchServe) for PyTorch
    images with a parseable version >= 1.6 and for all PyTorch neuron images,
    otherwise "mms" (Multi Model Server).
    """
    if "pytorch" not in image_uri:
        return "mms"
    if "neuron" in image_uri:
        return "ts"
    tag = image_uri.split(":")[1]
    # Recent changes to the packaging package updated parse() to return a
    # Version and deprecated LegacyVersion; a non-PEP440 pytorch version in the
    # tag now raises InvalidVersion, which we treat as "mms".
    try:
        if parse(tag.split("-")[0]) < Version("1.6"):
            return "mms"
    except InvalidVersion:
        return "mms"
    return "ts"
def get_build_context():
    """Return the raw BUILD_CONTEXT env variable value (None when unset)."""
    return os.environ.get("BUILD_CONTEXT")


def is_pr_context():
    """Return True when BUILD_CONTEXT is PR."""
    return os.environ.get("BUILD_CONTEXT") == "PR"


def is_canary_context():
    """Return True when BUILD_CONTEXT is CANARY."""
    return os.environ.get("BUILD_CONTEXT") == "CANARY"


def is_mainline_context():
    """Return True when BUILD_CONTEXT is MAINLINE."""
    return os.environ.get("BUILD_CONTEXT") == "MAINLINE"


def is_deep_canary_context():
    """Return True when BUILD_CONTEXT is DEEP_CANARY, or in a PR with DEEP_CANARY_MODE=true."""
    build_context = os.environ.get("BUILD_CONTEXT")
    if build_context == "DEEP_CANARY":
        return True
    deep_canary_mode = os.environ.get("DEEP_CANARY_MODE", "false").lower() == "true"
    return build_context == "PR" and deep_canary_mode


def is_nightly_context():
    """Return True when BUILD_CONTEXT is NIGHTLY, or when NIGHTLY_PR_TEST_MODE=true."""
    if os.environ.get("BUILD_CONTEXT") == "NIGHTLY":
        return True
    return os.environ.get("NIGHTLY_PR_TEST_MODE", "false").lower() == "true"


def is_empty_build_context():
    """Return True when BUILD_CONTEXT is unset or empty."""
    return not os.environ.get("BUILD_CONTEXT")


def is_graviton_architecture():
    """Return True when ARCH_TYPE is graviton."""
    return os.environ.get("ARCH_TYPE") == "graviton"


def is_dlc_cicd_context():
    """Return True in any DLC CI/CD build context (PR/CANARY/NIGHTLY/MAINLINE)."""
    return os.environ.get("BUILD_CONTEXT") in ["PR", "CANARY", "NIGHTLY", "MAINLINE"]


def is_efa_dedicated():
    """Return True when EFA_DEDICATED is "true" (case-insensitive)."""
    return os.environ.get("EFA_DEDICATED", "False").lower() == "true"


def are_heavy_instance_ec2_tests_enabled():
    """Return True when HEAVY_INSTANCE_EC2_TESTS_ENABLED is "true" (case-insensitive)."""
    return os.environ.get("HEAVY_INSTANCE_EC2_TESTS_ENABLED", "False").lower() == "true"


def is_generic_image():
    """Return True when IS_GENERIC_IMAGE is "true" (case-insensitive)."""
    return os.environ.get("IS_GENERIC_IMAGE", "false").lower() == "true"


def get_allowlist_path_for_enhanced_scan_from_env_variable():
    """Return the ALLOWLIST_PATH_ENHSCAN env variable value (None when unset)."""
    return os.environ.get("ALLOWLIST_PATH_ENHSCAN")


def is_rc_test_context():
    """Return True when SM release-candidate tests are enabled in the test config."""
    return config.is_sm_rc_test_enabled()
def is_covered_by_ec2_sm_split(image_uri):
    """
    Return True when the image's framework/version pair falls in the range that
    is built from split EC2/SageMaker dockerfiles.
    """
    split_versions_by_framework = {
        "pytorch": SpecifierSet(">=1.10.0"),
        "tensorflow": SpecifierSet(">=2.7.0"),
        "pytorch_trcomp": SpecifierSet(">=1.12.0"),
        "mxnet": SpecifierSet(">=1.9.0"),
    }
    framework, version = get_framework_and_version_from_tag(image_uri)
    if framework not in split_versions_by_framework:
        return False
    return Version(version) in split_versions_by_framework[framework]
def is_ec2_sm_in_same_dockerfile(image_uri):
    """
    Return True when the image's framework/version pair builds its EC2 and
    SageMaker variants from one shared dockerfile.
    """
    shared_dockerfile_versions = {
        "pytorch": SpecifierSet(">=1.11.0"),
        "tensorflow": SpecifierSet(">=2.8.0"),
        "pytorch_trcomp": SpecifierSet(">=1.12.0"),
        "mxnet": SpecifierSet(">=1.9.0"),
    }
    framework, version = get_framework_and_version_from_tag(image_uri)
    if framework not in shared_dockerfile_versions:
        return False
    return Version(version) in shared_dockerfile_versions[framework]
def is_ec2_image(image_uri):
    """Return True when the image URI contains the "-ec2" marker."""
    return image_uri.find("-ec2") != -1


def is_sagemaker_image(image_uri):
    """Return True when the image URI contains the "-sagemaker" marker."""
    return image_uri.find("-sagemaker") != -1


def is_trcomp_image(image_uri):
    """Return True when the image URI contains the "-trcomp" marker."""
    return image_uri.find("-trcomp") != -1
def is_time_for_canary_safety_scan():
    """
    Canary tests run every 15 minutes.
    Using a 20 minutes interval to make tests run only once a day around 9 am PST (10 am during winter time).
    """
    now_utc = time.gmtime()
    return now_utc.tm_hour == 16 and 0 < now_utc.tm_min < 20


def is_time_for_invoking_ecr_scan_failure_routine_lambda():
    """
    Canary tests run every 15 minutes.
    Using a 20 minutes interval to make tests run only once a day around 9 am PST (10 am during winter time).
    """
    now_utc = time.gmtime()
    return now_utc.tm_hour == 16 and 0 < now_utc.tm_min < 20
def _get_remote_override_flags():
    """
    Fetch the override_tests_flags.json document from this account's
    dlc-cicd-helper S3 bucket.

    :return: dict parsed JSON content; {} when any S3/STS ClientError occurs
    """
    json_content = {}
    try:
        account_id = boto3.client("sts").get_caller_identity().get("Account")
        response = boto3.client("s3").get_object(
            Bucket=f"dlc-cicd-helper-{account_id}", Key="override_tests_flags.json"
        )
        json_content = json.loads(response["Body"].read().decode("utf-8"))
    except ClientError as e:
        # Best-effort: missing/unreachable flags document simply means no overrides.
        LOGGER.warning("ClientError when performing S3/STS operation: {}".format(e))
    return json_content
# Now we can skip EFA tests on pipeline without making any source code change
def are_efa_tests_disabled():
    """
    Return True when EFA tests are disabled, either locally via the
    DISABLE_EFA_TESTS env variable in a PR, or remotely via the override flags
    document in S3.
    """
    locally_disabled = is_pr_context() and os.getenv("DISABLE_EFA_TESTS", "False").lower() == "true"
    remotely_disabled = (
        _get_remote_override_flags().get("disable_efa_tests", "false").lower() == "true"
    )
    return locally_disabled or remotely_disabled
def is_safety_test_context():
    """Return True when safety-check tests are enabled in the test config."""
    return config.is_safety_check_test_enabled()
def is_test_disabled(test_name, build_name, version):
    """
    Expected format of remote_override_flags:
    {
        "CB Project Name for Test Type A": {
            "CodeBuild Resolved Source Version": ["test_type_A_test_function_1", "test_type_A_test_function_2"]
        },
        "CB Project Name for Test Type B": {
            "CodeBuild Resolved Source Version": ["test_type_B_test_function_1", "test_type_B_test_function_2"]
        }
    }

    :param test_name: str Test Function node name (includes parametrized values in string)
    :param build_name: str Build Project name of current execution
    :param version: str Source Version of current execution
    :return: bool True if test is disabled as per remote override, False otherwise
    """
    build_overrides = _get_remote_override_flags().get(build_name, {})
    if version not in build_overrides:
        return False
    disabled_keywords = build_overrides[version]
    # An empty keyword list disables every test for this version; otherwise a
    # test is disabled when its name contains any listed keyword.
    if not disabled_keywords:
        return True
    return any(keyword in test_name for keyword in disabled_keywords)
def run_subprocess_cmd(cmd, failure="Command failed"):
    """
    Run a shell command, failing the current pytest test if it exits non-zero.

    :param cmd: str shell command to execute
    :param failure: str message prefix used in the pytest failure report
    :return: subprocess.CompletedProcess result of the executed command
    """
    import pytest

    # Merge stderr into stdout so the failure report contains the actual error text;
    # previously stderr was discarded and failures often logged nothing useful.
    command = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, shell=True)
    if command.returncode:
        pytest.fail(f"{failure}. Error log:\n{command.stdout.decode()}")
    return command
def login_to_ecr_registry(context, account_id, region):
    """
    Function to log into an ecr registry
    :param context: either invoke context object or fabric connection object
    :param account_id: Account ID with the desired ecr registry
    :param region: i.e. us-west-2
    """
    registry = f"{account_id}.dkr.ecr.{region}.amazonaws.com"
    login_cmd = (
        f"aws ecr get-login-password --region {region} | docker login --username AWS "
        f"--password-stdin {registry}"
    )
    context.run(login_cmd)
def retry_if_result_is_false(result):
    """Return True if we should retry (in this case retry if the result is False), False otherwise"""
    if result is False:
        return True
    return False
@retry(
    stop_max_attempt_number=10,
    wait_fixed=10000,
    retry_on_result=retry_if_result_is_false,
)
def request_mxnet_inference(ip_address="127.0.0.1", port="80", connection=None, model="squeezenet"):
    """
    Send request to container to test inference on kitten.jpg
    :param ip_address:
    :param port:
    :connection: ec2_connection object to run the commands remotely over ssh
    :return: <bool> True/False based on result of inference
    """
    runner = run if connection is None else connection.run
    # Download the test image only when it is not already present
    image_check = runner("[ -f kitten.jpg ]", warn=True)
    if image_check.return_code != 0:
        runner("curl -O https://s3.amazonaws.com/model-server/inputs/kitten.jpg", hide=True)
    predict_out = runner(
        f"curl -X POST http://{ip_address}:{port}/predictions/{model} -T kitten.jpg", warn=True
    )
    # The return code alone is not reliable: the predict request may "succeed" yet return
    # a 404 body, so also verify the payload looks like a real prediction.
    return predict_out.return_code == 0 and "probability" in predict_out.stdout
@retry(stop_max_attempt_number=10, wait_fixed=10000, retry_on_result=retry_if_result_is_false)
def request_mxnet_inference_gluonnlp(ip_address="127.0.0.1", port="80", connection=None):
    """
    Send request to container to test inference for predicting sentiments.
    :param ip_address:
    :param port:
    :connection: ec2_connection object to run the commands remotely over ssh
    :return: <bool> True/False based on result of inference
    """
    runner = run if connection is None else connection.run
    request_cmd = (
        f"curl -X POST http://{ip_address}:{port}/predictions/bert_sst/predict -F "
        '\'data=["Positive sentiment", "Negative sentiment"]\''
    )
    predict_out = runner(request_cmd, warn=True)
    # The return code alone is not reliable: the predict request may "succeed" yet return
    # a 404 body, so also verify the payload contains the expected label.
    return predict_out.return_code == 0 and "1" in predict_out.stdout
@retry(
    stop_max_attempt_number=10,
    wait_fixed=10000,
    retry_on_result=retry_if_result_is_false,
)
def request_pytorch_inference_densenet(
    ip_address="127.0.0.1",
    port="80",
    connection=None,
    model_name="pytorch-densenet",
    server_type="ts",
):
    """
    Send request to container to test inference on flower.jpg
    :param ip_address: str
    :param port: str
    :param connection: obj fabric/invoke connection for remote execution; local run if None
    :param model_name: str
    :param server_type: str "ts" or "mms"; selects the expected output shape
    :return: <bool> True/False based on result of inference
    """
    conn_run = connection.run if connection is not None else run
    # Check if image already exists
    run_out = conn_run("[ -f flower.jpg ]", warn=True)
    if run_out.return_code != 0:
        conn_run("curl -O https://s3.amazonaws.com/model-server/inputs/flower.jpg", hide=True)
    run_out = conn_run(
        f"curl -X POST http://{ip_address}:{port}/predictions/{model_name} -T flower.jpg",
        hide=True,
        warn=True,
    )
    # The run_out.return_code is not reliable, since sometimes predict request may succeed but the returned result
    # is 404. Hence the extra check.
    if run_out.return_code != 0:
        # BUGFIX: these lines were previously passed as separate positional arguments to
        # LOGGER.error; logging treats them as %-format args for a message with no
        # placeholders, so the return code and stderr were never rendered.
        LOGGER.error(
            "Predict request failed (note: return code may also be unreliable; requests can "
            "succeed yet return a 404 body).\n"
            f"Return Code: {run_out.return_code}\n"
            f"Error: {run_out.stderr}"
        )
        return False
    inference_output = json.loads(run_out.stdout.strip("\n"))
    # Accept the output only if its shape matches what the given serving stack produces:
    # neuron models -> 3-element list; torchserve ("ts") -> 5-key dict; MMS -> 5-element list.
    is_expected_shape = (
        ("neuron" in model_name and isinstance(inference_output, list) and len(inference_output) == 3)
        or (server_type == "ts" and isinstance(inference_output, dict) and len(inference_output) == 5)
        or (server_type == "mms" and isinstance(inference_output, list) and len(inference_output) == 5)
    )
    if not is_expected_shape:
        return False
    LOGGER.info(f"Inference Output = {json.dumps(inference_output, indent=4)}")
    return True
@retry(stop_max_attempt_number=20, wait_fixed=15000, retry_on_result=retry_if_result_is_false)
def request_tensorflow_inference(
model_name,
ip_address="127.0.0.1",
port="8501",
inference_string="'{\"instances\": [1.0, 2.0, 5.0]}'",