/
wrapper.py
730 lines (615 loc) · 34.2 KB
/
wrapper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
from __future__ import annotations
import logging
import os
from abc import ABC, abstractmethod
from datetime import timedelta
import boto3
import sagemaker
from sagemaker import Session
# noinspection PyProtectedMember
from sagemaker.estimator import _TrainingJob # need access to sagemaker internals to get last training job name
from sagemaker.multidatamodel import MultiDataModel
from sagemaker.processing import ProcessingInput, ScriptProcessor, FrameworkProcessor
from sagemaker.processing import ProcessingJob # Note: processing job is not marked as protected
from sagemaker.transformer import Transformer
# noinspection PyProtectedMember
from sagemaker.transformer import _TransformJob # need access to sagemaker internals to get last training job name
from sagemaker.sklearn import SKLearnProcessor
from sagemaker.spark import PySparkProcessor
from sagemaker_ssh_helper.sm_ssh import SageMakerSecureShellHelper
from sagemaker_ssh_helper.aws import AWS
from sagemaker_ssh_helper.detached_sagemaker import DetachedEstimator, DetachedProcessor
from sagemaker_ssh_helper.ide import SSHIDE, NotebookInstance
from sagemaker_ssh_helper.log import SSHLog
from sagemaker_ssh_helper.manager import SSMManager
from sagemaker_ssh_helper.proxy import SSMProxy
class SSHEnvironmentWrapper(ABC):
    """
    Common base for the wrappers that turn on SSH/SSM access to a SageMaker resource.

    Holds the shared configuration, builds the ``SSH_*`` / ``START_SSH`` environment
    variables and opens the SSM proxy. Subclasses supply the resource-specific
    instance ID lookup, the log / metadata URLs and the connection hints.
    """
    logger = logging.getLogger('sagemaker-ssh-helper')

    def __init__(self,
                 ssm_iam_role: str,
                 bootstrap_on_start: bool = True,
                 connection_wait_time_seconds: int = 600,
                 sagemaker_session: sagemaker.Session = None,
                 local_user_id: str = None,
                 log_to_stdout: bool = False):
        """
        :param ssm_iam_role: the SSM role without prefix, e.g. 'service-role/SageMakerRole'
            See https://docs.aws.amazon.com/systems-manager/latest/userguide/sysman-managed-instance-activation.html .
        :param bootstrap_on_start: Kick-off connection procedure upon sagemaker_ssh_helper.setup_and_start_ssh() .
        :param connection_wait_time_seconds: How long to wait before a SageMaker entry point.
            Can be 0 (don't wait).
        :param sagemaker_session: session to use; a new ``sagemaker.Session()`` is created when omitted.
        :param local_user_id: value for the SSHOwner tag; when None, the caller's STS ``UserId`` is used.
        :param log_to_stdout: exported to the environment as ``SSH_LOG_TO_STDOUT``.
        :raises ValueError: if ``ssm_iam_role`` is a full ARN.
        """
        self.log_to_stdout = log_to_stdout
        self.local_user_id = local_user_id
        self.sagemaker_session = sagemaker_session or sagemaker.Session()
        self.ssm_manager = SSMManager(region_name=self.sagemaker_session.boto_region_name)
        self.ssh_log = SSHLog(region_name=self.sagemaker_session.boto_region_name)

        if ssm_iam_role != '' and self._is_arn(ssm_iam_role):
            raise ValueError(f"ssm_iam_role should be only the part after role/, not a full ARN. "
                             f"Got: {ssm_iam_role}")
        # An empty role is allowed here: subclasses derive it from the resource's own role.
        self.ssm_iam_role = ssm_iam_role
        self.bootstrap_on_start = bootstrap_on_start
        self.connection_wait_time_seconds = connection_wait_time_seconds
        self.augmented = False

    @classmethod
    def dependency_dir(cls):
        """Return the directory that contains this module."""
        return os.path.dirname(__file__)

    def _augment(self):
        """Mark the wrapper as augmented; subclasses extend this to inject the environment."""
        self.augmented = True

    def _augment_env(self, env):
        """
        Add the SSH helper settings to ``env`` in place.

        :param env: mutable mapping of environment variables to update.
        """
        if self.local_user_id is None:
            region = self.sagemaker_session.boto_region_name
            endpoint_url = "https://sts.{}.amazonaws.com".format(region)
            caller_id = boto3.client("sts", region_name=region, endpoint_url=endpoint_url).get_caller_identity()
            user_id = caller_id.get('UserId')
        else:
            user_id = self.local_user_id

        # Only a masked form is logged: the first 3 and the last 4 characters stay visible.
        hidden = max(len(user_id) - 7, 0)
        user_id_masked = user_id[:3] + '*' * hidden + user_id[3 + hidden:]
        self.logger.info(f"Passing '{user_id_masked}' as a value of the SSHOwner tag of an SSM managed instance")

        env.update({'START_SSH': str(self.bootstrap_on_start).lower(),
                    'SSH_SSM_ROLE': self.ssm_iam_role,
                    'SSH_OWNER_TAG': user_id,
                    'SSH_LOG_TO_STDOUT': str(self.log_to_stdout).lower(),
                    'SSH_WAIT_TIME_SECONDS': f"{self.connection_wait_time_seconds}"})

    @classmethod
    def ssm_role_from_iam_arn(cls, iam_arn: str):
        """
        Extract the part of a role ARN that follows ``:role/``.

        :param iam_arn: full IAM role ARN.
        :return: the role path/name after ``:role/``.
        :raises ValueError: if the ARN is empty, is not an ARN, or has no ``:role/`` part.
        """
        if not iam_arn:
            raise ValueError("iam_arn cannot be empty")
        if not cls._is_arn(iam_arn):
            raise ValueError(f"iam_arn should be a full ARN, got: '{iam_arn}'")
        marker = ":role/"
        role_position = iam_arn.find(marker)
        if role_position == -1:
            raise ValueError("':role/' not found in the iam_arn")
        return iam_arn[role_position + len(marker):]

    @abstractmethod
    def get_instance_ids(self, retry: int = None, timeout_in_sec: int = 900):
        """
        :param timeout_in_sec:
        :param retry: (deprecated, use timeout_in_sec) how many retries (each retry is 10 seconds), 360 is for 1 hour
        """
        raise NotImplementedError("Abstract method")

    def get_instance_id(self, retry: int = None, timeout_in_sec: int = 900, index: int = 0):
        """
        Return one instance ID out of those found by ``get_instance_ids()``.

        :param index: position of the wanted ID in the returned list.
        :raises ValueError: if no instance IDs were found.
        """
        ids = self.get_instance_ids(retry, timeout_in_sec)
        if not ids:
            raise ValueError("No SSM instances found.")
        return ids[index]

    def retry_deprecated_warning(self, retry, timeout_in_sec):
        """
        Translate the deprecated ``retry`` count into a timeout.

        :return: ``retry * 10`` (with a warning) when ``retry`` is truthy, else ``timeout_in_sec``.
        """
        if retry:
            self.logger.warning("retry is deprecated, use timeout_in_sec instead")
            timeout_in_sec = retry * 10
        return timeout_in_sec

    @staticmethod
    def _resolve_timeout_seconds(timeout_in_sec, timeout, default_timeout=timedelta(minutes=15)):
        """
        Pick the effective timeout from the two equivalent parameters.

        ``timeout`` wins only when it was changed from its default; otherwise the
        value of ``timeout_in_sec`` is kept, so an explicit ``timeout_in_sec`` is honoured.
        """
        if timeout != default_timeout:
            return int(timeout.total_seconds())
        return timeout_in_sec

    def start_ssm_connection_and_continue(self, ssh_listen_port: int, retry: int = None,
                                          timeout_in_sec: int = 900,
                                          timeout: timedelta = timedelta(minutes=15),
                                          extra_args: str = ""):
        """
        Open the SSM connection and disconnect the returned proxy right away.

        Parameters have the same meaning as in ``start_ssm_connection()``.
        """
        timeout_in_sec = self._resolve_timeout_seconds(timeout_in_sec, timeout)
        proxy = self.start_ssm_connection(ssh_listen_port, retry, timeout_in_sec=timeout_in_sec, extra_args=extra_args)
        proxy.disconnect()

    def start_ssm_connection(self, ssh_listen_port: int, retry: int = None,
                             timeout: timedelta = timedelta(minutes=15),
                             timeout_in_sec: int = 900,
                             extra_args: str = "") -> SSMProxy:
        """
        Resolve the first instance ID and connect an ``SSMProxy`` to it.

        :param ssh_listen_port: port passed to ``SSMProxy``.
        :param retry: (deprecated, use timeout) number of 10-second retries; overrides the timeout when set.
        :param timeout: how long to wait for instance IDs; used when it differs from the default.
        :param timeout_in_sec: same as ``timeout``, in seconds; used when ``timeout`` is left at its default.
        :param extra_args: extra arguments passed to ``SSMProxy``.
        :return: the connected proxy.
        :raises ValueError: if no instance is found or the instance ID does not start with 'mi-'.
        """
        timeout_in_sec = self._resolve_timeout_seconds(timeout_in_sec, timeout)
        self.logger.info("Starting SSM connection")
        timeout_in_sec = self.retry_deprecated_warning(retry, timeout_in_sec)

        instance_ids = self.get_instance_ids(timeout_in_sec=timeout_in_sec)
        if not instance_ids:
            raise ValueError(f"No SSM instances found. Has the SSM Agent been started? "
                             f"Check the remote logs: {self.get_cloudwatch_url()} "
                             f"AND the remote metadata: {self.get_metadata_url()}")

        instance_id = instance_ids[0]
        if not instance_id.startswith("mi-"):
            raise ValueError(f"instance_id doesn't start with 'mi-': {instance_id}")

        ssm_proxy = SSMProxy(ssh_listen_port, extra_args, self.sagemaker_session.boto_region_name,
                             self.get_cloudwatch_url())
        try:
            ssm_proxy.connect_to_ssm_instance(instance_id)
        except Exception as e:
            # Clean up the half-open proxy before propagating the failure.
            self.logger.error(f"Failed to connect to SSM instance: {e}")
            ssm_proxy.disconnect()
            raise

        if self.connection_wait_time_seconds > 0:
            ssm_proxy.terminate_waiting_loop()
        return ssm_proxy

    @staticmethod
    def _is_arn(arn):
        """Delegate the ARN check to ``AWS.is_arn``."""
        return AWS.is_arn(arn)

    @abstractmethod
    def get_cloudwatch_url(self):
        """Return the URL of the remote logs for the wrapped resource."""
        raise ValueError("Not implemented")

    @abstractmethod
    def get_metadata_url(self):
        """Return the URL of the metadata page for the wrapped resource."""
        raise ValueError("Not implemented")

    @abstractmethod
    def print_ssh_info(self):
        """Print the log / metadata URLs and the commands to connect."""
        raise ValueError("Not implemented")

    @classmethod
    def attach_to_resource(cls, fqdn: str,
                           domain_id: str = '',
                           user_profile_name: str = '',
                           sagemaker_session: Session = None):
        """
        Build the wrapper that matches the resource type encoded in ``fqdn``.

        :param fqdn: resource name; its type and name are parsed by ``SageMakerSecureShellHelper``.
        :param domain_id: used only for the 'ide' resource type.
        :param user_profile_name: used only for the 'ide' resource type.
        :param sagemaker_session: session forwarded to the selected wrapper's ``attach()``.
        :raises ValueError: if the resource type is not supported.
        """
        resource_type = SageMakerSecureShellHelper.fqdn_to_type(fqdn)
        resource_name = SageMakerSecureShellHelper.fqdn_to_name(fqdn)

        if resource_type == 'inference':
            return SSHModelWrapper.attach(resource_name, sagemaker_session)
        elif resource_type == 'training':
            return SSHEstimatorWrapper.attach(resource_name, sagemaker_session)
        elif resource_type == 'processing':
            return SSHProcessorWrapper.attach(resource_name, sagemaker_session)
        elif resource_type == 'transform':
            return SSHTransformerWrapper.attach(resource_name, sagemaker_session)
        elif resource_type == 'ide':
            return SSHIDEWrapper.attach(domain_id, user_profile_name, resource_name, sagemaker_session)
        elif resource_type == 'notebook':
            return SSHNotebookInstanceWrapper.attach(resource_name, sagemaker_session)
        else:
            raise ValueError(f"Don't know how to handle this resource type: {resource_type}")

    def region(self):
        """Return the AWS region of the wrapped SageMaker session."""
        return self.sagemaker_session.boto_region_name
class SSHEstimatorWrapper(SSHEnvironmentWrapper):
    """Turns on SSH/SSM access to the instances of a SageMaker training job."""

    def __init__(self, estimator: sagemaker.estimator.EstimatorBase, ssm_iam_role: str = '',
                 bootstrap_on_start: bool = True, connection_wait_time_seconds: int = 600,
                 ssh_instance_count: int = 2, local_user_id: str = None,
                 log_to_stdout: bool = False):
        """
        :param estimator: the estimator to wrap; its session and role are reused.
        :param ssm_iam_role: SSM role without prefix; derived from ``estimator.role`` when empty.
        :param bootstrap_on_start: see ``SSHEnvironmentWrapper``.
        :param connection_wait_time_seconds: see ``SSHEnvironmentWrapper``.
        :param ssh_instance_count: number of nodes to start SSH on, capped by ``estimator.instance_count``.
        :param local_user_id: see ``SSHEnvironmentWrapper``.
        :param log_to_stdout: see ``SSHEnvironmentWrapper``.
        """
        super().__init__(ssm_iam_role, bootstrap_on_start, connection_wait_time_seconds,
                         estimator.sagemaker_session, local_user_id, log_to_stdout)

        if hasattr(estimator, 'instance_groups') and estimator.instance_groups is not None:
            # TODO: add support for heterogeneous clusters
            self.logger.warning("Heterogeneous clusters are not yet supported, SSH Helper will start only on one node")
            self.ssh_instance_count = 1
        elif ssh_instance_count <= estimator.instance_count:
            self.ssh_instance_count = ssh_instance_count
        else:
            self.ssh_instance_count = estimator.instance_count

        if self.ssm_iam_role == '':
            self.ssm_iam_role = SSHEnvironmentWrapper.ssm_role_from_iam_arn(estimator.role)
        self.estimator = estimator

    def _augment(self):
        """Inject the SSH helper settings into the estimator's environment."""
        super()._augment()
        self.logger.info(f'Turning on SSH to training job for estimator {self.estimator.__class__}')
        env = self.estimator.environment
        if env is None:
            env = {}
        self._augment_env(env)
        # TODO: promote ssh_instance_count to processing/inference wrappers
        env.update({'SSH_INSTANCE_COUNT': str(self.ssh_instance_count)})
        self.estimator.environment = env

    def get_instance_ids(self, retry: int = None, timeout_in_sec: int = 900):
        """
        Look up the SSM instance IDs of the latest training job.

        :param retry: (deprecated, use timeout_in_sec) number of 10-second retries.
        :param timeout_in_sec: how long to wait for the instances.
        """
        timeout_in_sec = self.retry_deprecated_warning(retry, timeout_in_sec)
        self.logger.info("Resolving training instance IDs through SSM tags")
        self.logger.info(f"Remote training logs are at {self.get_cloudwatch_url()}")
        self.logger.info(f"Estimator metadata is at {self.get_metadata_url()}")
        training_job = self._latest_training_job()
        return self.ssm_manager.get_training_instance_ids(training_job.name, timeout_in_sec, self.ssh_instance_count)

    def _latest_training_job(self):
        """
        Return the estimator's latest training job.

        :raises AssertionError: if the estimator has no training job yet.
        """
        training_job: _TrainingJob = self.estimator.latest_training_job
        if training_job is None:
            raise AssertionError("No training jobs found for estimator. Did you call estimator.fit() first?")
        return training_job

    def wait_training_job(self):
        """Block until the latest training job finishes."""
        self.logger.info("Waiting for training job to complete")
        training_job = self._latest_training_job()
        training_job.wait()
        self.logger.info("Training job is complete")

    def wait_training_job_with_status(self) -> str:
        """Wait for the latest training job and return its ``TrainingJobStatus``."""
        self.wait_training_job()
        description = self.sagemaker_session.describe_training_job(self.training_job_name())
        result = description["TrainingJobStatus"]
        self.logger.info(f"Training job status is '{result}'")
        return result

    def stop_training_job(self):
        """Stop the latest training job and wait until it has stopped."""
        self.logger.info("Stopping training job")
        training_job = self._latest_training_job()
        training_job.stop()
        training_job.wait()
        self.logger.info("Training job is stopped")

    @classmethod
    def create(cls, estimator: sagemaker.estimator.EstimatorBase,
               connection_wait_time_seconds: int = 600,
               connection_wait_time: timedelta = timedelta(minutes=10),
               ssh_instance_count: int = 2, local_user_id: str = None,
               log_to_stdout: bool = False) -> SSHEstimatorWrapper:
        """
        Wrap an estimator and augment its environment; call before ``estimator.fit()``.

        :param connection_wait_time_seconds: wait time in seconds; used when
            ``connection_wait_time`` is left at its default.
        :param connection_wait_time: wait time; takes precedence when changed from its default.
        :raises ValueError: if the estimator has already started a job.
        """
        # Only let the timedelta override the seconds value when the caller actually changed it;
        # otherwise an explicit connection_wait_time_seconds would be silently discarded.
        if connection_wait_time != timedelta(minutes=10):
            # int() keeps the SSH_WAIT_TIME_SECONDS variable free of a trailing ".0"
            connection_wait_time_seconds = int(connection_wait_time.total_seconds())

        # noinspection PyProtectedMember
        if estimator._current_job_name:
            raise ValueError(
                "You should call wrapper.create() before starting a training job with estimator.fit()."
            )

        result = SSHEstimatorWrapper(estimator, connection_wait_time_seconds=connection_wait_time_seconds,
                                     ssh_instance_count=ssh_instance_count, local_user_id=local_user_id,
                                     log_to_stdout=log_to_stdout)
        result._augment()
        return result

    @classmethod
    def attach(cls, training_job_name, sagemaker_session: Session = None) -> SSHEstimatorWrapper:
        """Wrap an existing training job by name (the environment is not augmented)."""
        estimator = DetachedEstimator.attach(training_job_name, sagemaker_session or Session())
        return SSHEstimatorWrapper(estimator)

    def get_cloudwatch_url(self):
        """Return the URL of the training job logs."""
        return self.ssh_log.get_training_cloudwatch_url(self.training_job_name())

    def get_metadata_url(self):
        """Return the URL of the training job metadata."""
        return self.ssh_log.get_training_metadata_url(self.training_job_name())

    def training_job_name(self):
        """Return the name of the latest training job."""
        return self._latest_training_job().name

    def is_job_in_progress(self):
        """Return True when the latest training job status is 'InProgress'."""
        # TODO: extract API to the base class for all job-based resources?
        describe_output = self._latest_training_job().describe()
        return describe_output['TrainingJobStatus'] == 'InProgress'

    def rule_job_summary(self):
        """Return the rule job summary of the latest training job."""
        return self._latest_training_job().rule_job_summary()

    @classmethod
    def attach_arn(cls, training_job_arn, sagemaker_session: Session = None) -> SSHEstimatorWrapper:
        """
        Wrap an existing training job identified by its ARN.

        :raises ValueError: if the ARN does not contain ':training-job/'.
        """
        if ':training-job/' not in training_job_arn:
            raise ValueError(f"Not a training job ARN: {training_job_arn}")
        return cls.attach(training_job_arn.split('/')[1], sagemaker_session)

    def print_ssh_info(self):
        """Print the log / metadata URLs and the commands to connect to the training job."""
        print(f"Remote training logs are at {self.get_cloudwatch_url()}")
        print(f"Training job metadata is at {self.get_metadata_url()}")
        print(f"To connect over SSM run:\n"
              f"AWS_DEFAULT_REGION={self.region()} aws ssm start-session --target {self.get_instance_id()}")
        print(f"To configure local host for SSH run:\n"
              f"sm-local-configure")
        print(f"To connect over SSH run:\n"
              f"AWS_DEFAULT_REGION={self.region()} sm-ssh connect {self.training_job_name()}.training.sagemaker")
class SSHModelWrapper(SSHEnvironmentWrapper):
    """Turns on SSH/SSM access to the instances behind a SageMaker inference endpoint."""

    def __init__(self, model: sagemaker.model.Model,
                 ssm_iam_role: str = '',
                 bootstrap_on_start: bool = True, connection_wait_time_seconds: int = 600):
        """
        :param model: the model to wrap; its session and role are reused.
        :param ssm_iam_role: SSM role without prefix; derived from ``model.role`` when empty.
        :param bootstrap_on_start: see ``SSHEnvironmentWrapper``.
        :param connection_wait_time_seconds: see ``SSHEnvironmentWrapper``.
        """
        super().__init__(ssm_iam_role,
                         bootstrap_on_start, connection_wait_time_seconds, model.sagemaker_session)
        if self.ssm_iam_role == '':
            self.ssm_iam_role = SSHEnvironmentWrapper.ssm_role_from_iam_arn(model.role)
        self.model = model

    def _augment(self):
        """Inject the SSH helper settings into the model's environment."""
        super()._augment()
        self.logger.info(f'Turning on SSH to endpoint for model {self.model.__class__}')
        env = self.model.env
        if env is None:
            env = {}
        self._augment_env(env)
        self.model.env = env

    # noinspection DuplicatedCode
    def get_instance_ids(self, retry: int = None, timeout_in_sec: int = 900):
        """
        Look up the SSM instance IDs of the model's endpoint.

        :param retry: (deprecated, use timeout_in_sec) number of 10-second retries.
        :param timeout_in_sec: how long to wait for the instances.
        """
        timeout_in_sec = self.retry_deprecated_warning(retry, timeout_in_sec)
        self.logger.info("Resolving endpoint instance IDs through CloudWatch logs")
        self.logger.info(f"Remote endpoint logs are at {self.get_cloudwatch_url()}")
        self.logger.info(f"Endpoint metadata is at {self.get_metadata_url()}")
        self.logger.info(f"Endpoint config metadata is at {self.get_config_metadata_url()}")
        self.logger.info(f"Model metadata is at {self.get_model_metadata_url()}")
        return self.ssh_log.get_endpoint_ssm_instance_ids(self.model.endpoint_name, timeout_in_sec)

    def wait_for_endpoint(self):
        """Block until the session reports the endpoint as ready."""
        self.logger.info("Waiting for endpoint")
        self.sagemaker_session.wait_for_endpoint(self.model.endpoint_name)
        self.logger.info("Endpoint is ready")

    @classmethod
    def create(cls, model: sagemaker.model.Model, connection_wait_time_seconds: int = 600) -> SSHModelWrapper:
        """
        Wrap a model and augment its environment; call before ``model.deploy()``.

        :raises AssertionError: if the model already has an endpoint name.
        """
        if model.endpoint_name:
            raise AssertionError("You should call wrapper.create() before model.deploy().")
        result: SSHModelWrapper = SSHModelWrapper(model, connection_wait_time_seconds=connection_wait_time_seconds)
        result._augment()
        return result

    def get_cloudwatch_url(self):
        """Return the URL of the endpoint logs."""
        return self.ssh_log.get_endpoint_cloudwatch_url(self.model.endpoint_name)

    def get_metadata_url(self):
        """Return the URL of the endpoint metadata."""
        return self.ssh_log.get_endpoint_metadata_url(self.model.endpoint_name)

    def get_config_metadata_url(self):
        """Return the URL of the endpoint config metadata."""
        return self.ssh_log.get_endpoint_config_metadata_url(self.model.endpoint_name)

    def get_model_metadata_url(self):
        """Return the URL of the model metadata."""
        return self.ssh_log.get_model_metadata_url(self.model.name)

    def endpoint_is_online(self):
        """Return True when the endpoint status is 'InService'."""
        # Query the same region as the wrapper's session instead of relying on the
        # ambient default region of boto3, which may point elsewhere.
        sagemaker_client = boto3.client('sagemaker', region_name=self.region())
        describe_result = sagemaker_client.describe_endpoint(EndpointName=self.model.endpoint_name)
        status = describe_result["EndpointStatus"]
        return status == 'InService'

    @classmethod
    def attach(cls, endpoint_name: str, sagemaker_session=None):
        """
        Wrap an existing endpoint by name (the environment is not augmented).

        :param endpoint_name: name of the endpoint.
        :param sagemaker_session: session passed to the placeholder ``sagemaker.Model``.
        """
        model = sagemaker.Model(image_uri='', sagemaker_session=sagemaker_session)
        model.endpoint_name = endpoint_name
        return SSHModelWrapper(model)

    def print_ssh_info(self):
        """Print the log / metadata URLs and the commands to connect to the endpoint."""
        print(f"Remote endpoint logs are at {self.get_cloudwatch_url()}")
        print(f"Endpoint metadata is at {self.get_metadata_url()}")
        print(f"Endpoint config metadata is at {self.get_config_metadata_url()}")
        print(f"Model metadata is at {self.get_model_metadata_url()}")
        print(f"To connect over SSM run:\n"
              f"AWS_DEFAULT_REGION={self.region()} aws ssm start-session --target {self.get_instance_id()}")
        print(f"To configure local host for SSH run:\n"
              f"sm-local-configure")
        print(f"To connect over SSH run:\n"
              f"AWS_DEFAULT_REGION={self.region()} sm-ssh connect {self.model.endpoint_name}.inference.sagemaker")
class SSHMultiModelWrapper(SSHEnvironmentWrapper):
    """Turns on SSH/SSM access to the instances behind a multi-model endpoint."""

    def __init__(self, mdm: sagemaker.multidatamodel.MultiDataModel,
                 ssm_iam_role: str = '',
                 bootstrap_on_start: bool = True, connection_wait_time_seconds: int = 600):
        """
        :param mdm: the multi data model to wrap; its session is reused.
        :param ssm_iam_role: SSM role without prefix; when empty it is derived from the role of
            ``mdm.model`` if there is one, otherwise from ``mdm.role``.
        :param bootstrap_on_start: see ``SSHEnvironmentWrapper``.
        :param connection_wait_time_seconds: see ``SSHEnvironmentWrapper``.
        """
        super().__init__(ssm_iam_role,
                         bootstrap_on_start, connection_wait_time_seconds, mdm.sagemaker_session)
        self.mdm = mdm

        if not mdm.model:
            # No inner model: the multi data model itself carries the role and the environment.
            self.model = None
            if self.ssm_iam_role == '':
                self.ssm_iam_role = SSHEnvironmentWrapper.ssm_role_from_iam_arn(mdm.role)
            return

        # An inner model exists: delegate augmentation to a regular model wrapper.
        self.model = mdm.model
        if self.ssm_iam_role == '':
            self.ssm_iam_role = SSHEnvironmentWrapper.ssm_role_from_iam_arn(mdm.model.role)
        self.model_wrapper = SSHModelWrapper(mdm.model, self.ssm_iam_role,
                                             bootstrap_on_start,
                                             connection_wait_time_seconds)

    def _augment(self):
        """Inject the SSH helper settings into the inner model, or into the multi data model itself."""
        super()._augment()
        if self.model:
            # noinspection PyProtectedMember
            self.model_wrapper._augment()
            return

        self.logger.info(f'Turning on SSH to endpoint for multi data model {self.mdm.__class__}')
        mdm_env = self.mdm.env
        if mdm_env is None:
            mdm_env = {}
        self._augment_env(mdm_env)
        self.mdm.env = mdm_env

    # noinspection DuplicatedCode
    def get_instance_ids(self, retry: int = None, timeout_in_sec: int = 900):
        """
        Look up the SSM instance IDs of the multi-model endpoint.

        :param retry: (deprecated, use timeout_in_sec) number of 10-second retries.
        :param timeout_in_sec: how long to wait for the instances.
        """
        effective_timeout = self.retry_deprecated_warning(retry, timeout_in_sec)
        log = self.logger
        log.info("Resolving multi-model endpoint instance IDs through SSM tags")
        log.info(f"Remote multi-model endpoint logs are at {self.get_cloudwatch_url()}")
        log.info(f"Multi-model endpoint metadata is at {self.get_metadata_url()}")
        log.info(f"Endpoint config metadata is at {self.get_config_metadata_url()}")
        log.info(f"Model metadata is at {self.get_model_metadata_url()}")
        endpoint = self.mdm.endpoint_name
        return self.ssh_log.get_endpoint_ssm_instance_ids(endpoint, effective_timeout)

    def wait_for_endpoint(self):
        """Block until the session reports the endpoint as ready."""
        self.logger.info("Waiting for endpoint")
        endpoint = self.mdm.endpoint_name
        self.sagemaker_session.wait_for_endpoint(endpoint)
        self.logger.info("Endpoint is ready")

    @classmethod
    def create(cls, mdm: sagemaker.multidatamodel.MultiDataModel,
               connection_wait_time_seconds: int = 600) -> SSHMultiModelWrapper:
        """
        Wrap a multi data model and augment its environment; call before ``mdm.deploy()``.

        :raises AssertionError: if the multi data model already has an endpoint name.
        """
        already_deployed = hasattr(mdm, 'endpoint_name') and mdm.endpoint_name
        if already_deployed:
            raise AssertionError("You should call wrapper.create() before mdm.deploy().")
        wrapper = SSHMultiModelWrapper(mdm, connection_wait_time_seconds=connection_wait_time_seconds)
        wrapper._augment()
        return wrapper

    def get_cloudwatch_url(self):
        """Return the URL of the endpoint logs."""
        endpoint = self.mdm.endpoint_name
        return self.ssh_log.get_endpoint_cloudwatch_url(endpoint)

    def get_metadata_url(self):
        """Return the URL of the endpoint metadata."""
        endpoint = self.mdm.endpoint_name
        return self.ssh_log.get_endpoint_metadata_url(endpoint)

    def get_config_metadata_url(self):
        """Return the URL of the endpoint config metadata."""
        endpoint = self.mdm.endpoint_name
        return self.ssh_log.get_endpoint_config_metadata_url(endpoint)

    def get_model_metadata_url(self):
        """Return the URL of the model metadata."""
        model_name = self.mdm.name
        return self.ssh_log.get_model_metadata_url(model_name)

    def print_ssh_info(self):
        """Print the log / metadata URLs and the commands to connect to the endpoint."""
        # Each value is resolved right before its line is printed, so the output
        # produced before a failing lookup is the same as with inline f-strings.
        logs_url = self.get_cloudwatch_url()
        print(f"Remote multi-model endpoint logs are at {logs_url}")
        metadata_url = self.get_metadata_url()
        print(f"Multi-model endpoint metadata is at {metadata_url}")
        config_url = self.get_config_metadata_url()
        print(f"Endpoint config metadata is at {config_url}")
        model_url = self.get_model_metadata_url()
        print(f"Model metadata is at {model_url}")

        ssm_region = self.region()
        ssm_target = self.get_instance_id()
        print(f"To connect over SSM run:\n"
              f"AWS_DEFAULT_REGION={ssm_region} aws ssm start-session --target {ssm_target}")

        print("To configure local host for SSH run:\n"
              "sm-local-configure")

        ssh_region = self.region()
        endpoint = self.mdm.endpoint_name
        print(f"To connect over SSH run:\n"
              f"AWS_DEFAULT_REGION={ssh_region} sm-ssh connect {endpoint}.inference.sagemaker")
class SSHProcessorWrapper(SSHEnvironmentWrapper):
    """Turns on SSH/SSM access to the instances of a SageMaker processing job."""

    def __init__(self, processor: sagemaker.processing.Processor,
                 ssm_iam_role: str = '',
                 bootstrap_on_start: bool = True,
                 connection_wait_time_seconds: int = 600):
        """
        :param processor: the processor to wrap; its session and role are reused.
        :param ssm_iam_role: SSM role without prefix; derived from ``processor.role`` when empty.
        :param bootstrap_on_start: see ``SSHEnvironmentWrapper``.
        :param connection_wait_time_seconds: see ``SSHEnvironmentWrapper``.
        """
        super().__init__(ssm_iam_role, bootstrap_on_start, connection_wait_time_seconds,
                         processor.sagemaker_session)
        if self.ssm_iam_role == '':
            self.ssm_iam_role = SSHEnvironmentWrapper.ssm_role_from_iam_arn(processor.role)
        self.processor = processor

    def print_ssh_info(self):
        """Print the log / metadata URLs and the commands to connect to the processing job."""
        print(f"Remote processing logs are at {self.get_cloudwatch_url()}")
        print(f"Processing job metadata is at {self.get_metadata_url()}")
        print(f"To connect over SSM run:\n"
              f"AWS_DEFAULT_REGION={self.region()} aws ssm start-session --target {self.get_instance_id()}")
        print(f"To configure local host for SSH run:\n"
              f"sm-local-configure")
        print(f"To connect over SSH run:\n"
              f"AWS_DEFAULT_REGION={self.region()} sm-ssh connect {self.get_processor_latest_job_name()}.processing.sagemaker")

    def _augment(self):
        """Inject the SSH helper settings into the processor's environment."""
        super()._augment()
        self.logger.info(f'Turning on SSH to processor {self.processor.__class__}')
        env = self.processor.env
        if env is None:
            env = {}
        self._augment_env(env)
        self.processor.env = env

    def get_instance_ids(self, retry: int = None, timeout_in_sec: int = 900):
        """
        Look up the SSM instance IDs of the latest processing job.

        :param retry: (deprecated, use timeout_in_sec) number of 10-second retries.
        :param timeout_in_sec: how long to wait for the instances.
        """
        timeout_in_sec = self.retry_deprecated_warning(retry, timeout_in_sec)
        self.logger.info("Resolving processing instance IDs through SSM tags")
        self.logger.info(f"Remote processing logs are at {self.get_cloudwatch_url()}")
        self.logger.info(f"Processor metadata is at {self.get_metadata_url()}")
        return self.ssm_manager.get_processing_instance_ids(self.get_processor_latest_job_name(), timeout_in_sec)

    def wait_processing_job(self):
        """Block until the latest processing job finishes."""
        self.logger.info("Waiting for processing job to complete")
        job: ProcessingJob = self.processor.latest_job
        job.wait()
        self.logger.info("Processing job is complete")

    def augmented_input(self):
        """
        Attaches the helper as the processing input. Required for processing jobs until the package is in PyPI.
        Useful for processing jobs that don't support source_dir in run() method, e. g. PySparkProcessor and
        ScriptProcessor / SKLearnProcessor

        :return: a ProcessingInput to pass into processor#run(..., inputs=[...])
        """
        if isinstance(self.processor, FrameworkProcessor):
            # The message must be an f-string, otherwise the placeholder is logged verbatim.
            self.logger.info(f"The processor {self.processor.__class__} is a subclass of FrameworkProcessor. "
                             "It's recommended to pass SageMaker SSH Helper as a dependency to the run() method "
                             "with dependencies=[SSHProcessorWrapper.dependency_dir()].")
        return ProcessingInput(source=SSHProcessorWrapper.dependency_dir(),
                               destination='/opt/ml/processing/input/sagemaker_ssh_helper',
                               input_name='sagemaker_ssh_helper')

    @classmethod
    def create(cls, processor: sagemaker.processing.Processor,
               connection_wait_time_seconds: int = 600) -> SSHProcessorWrapper:
        """
        Wrap a processor and augment its environment; call before ``processor.run()``.

        :raises AssertionError: if the processor already has a job.
        """
        if processor.latest_job:
            raise AssertionError("You should call wrapper.create() before processor.run()")
        result = SSHProcessorWrapper(processor, connection_wait_time_seconds=connection_wait_time_seconds)
        result._augment()
        return result

    def get_cloudwatch_url(self):
        """Return the URL of the processing job logs."""
        return self.ssh_log.get_processing_cloudwatch_url(self.get_processor_latest_job_name())

    def get_processor_latest_job_name(self):
        """Return the name of the processor's latest job."""
        return self.processor.latest_job.job_name

    def get_metadata_url(self):
        """Return the URL of the processing job metadata."""
        return self.ssh_log.get_processing_metadata_url(self.get_processor_latest_job_name())

    @classmethod
    def attach(cls, processing_job_name, sagemaker_session: Session = None) -> SSHProcessorWrapper:
        """Wrap an existing processing job by name (the environment is not augmented)."""
        processor = DetachedProcessor.attach(processing_job_name, sagemaker_session or Session())
        return SSHProcessorWrapper(processor)
class SSHTransformerWrapper(SSHEnvironmentWrapper):
    """Gives SSH/SSM access to the instances of a SageMaker batch transform job."""

    def __init__(self, transformer: sagemaker.transformer.Transformer, model_wrapper: SSHModelWrapper):
        """
        :param transformer: the transformer to wrap; its session is reused.
        :param model_wrapper: wrapper of the transformer's model; supplies the connection wait time.
        """
        super().__init__('', True, model_wrapper.connection_wait_time_seconds, transformer.sagemaker_session)
        self.transformer = transformer
        self.model_wrapper = model_wrapper

    def _augment(self):
        """Mark the wrapper as augmented; no environment is changed here."""
        super()._augment()

    def get_instance_ids(self, retry: int = None, timeout_in_sec: int = 900):
        """
        Look up the SSM instance IDs of the latest transform job.

        :param retry: (deprecated, use timeout_in_sec) number of 10-second retries.
        :param timeout_in_sec: how long to wait for the instances.
        """
        timeout_in_sec = self.retry_deprecated_warning(retry, timeout_in_sec)
        self.logger.info("Resolving transformer instance IDs through SSM tags")
        self.logger.info(f"Remote transformer logs are at {self.get_cloudwatch_url()}")
        self.logger.info(f"Transformer metadata is at {self.get_metadata_url()}")
        return self.ssm_manager.get_transformer_instance_ids(self.get_transformer_latest_job_name(), timeout_in_sec)

    def wait_transform_job(self):
        """Block until the latest transform job finishes."""
        self.logger.info("Waiting for transform job to complete")
        job: _TransformJob = self.transformer.latest_transform_job
        job.wait()
        self.logger.info("Transform job is complete")

    @classmethod
    def create(cls, transformer: sagemaker.transformer.Transformer,
               model_wrapper: SSHModelWrapper) -> SSHTransformerWrapper:
        """
        Wrap a transformer whose model was already augmented; call before ``transformer.transform()``.

        :raises ValueError: if the model wrapper is not augmented or the model names differ.
        :raises AssertionError: if the transformer already has a transform job.
        """
        if not model_wrapper.augmented:
            raise ValueError("Model Wrapper is not yet augmented. Consider constructing object with create().")
        if model_wrapper.model.name != transformer.model_name:
            # Report both sides of the comparison so the mismatch is actually visible.
            raise ValueError(f"Transformer and model should have the same name, "
                             f"got: {model_wrapper.model.name} and {transformer.model_name}")
        if transformer.latest_transform_job:
            raise AssertionError("You should call wrapper.create() before transformer.transform()")
        result = SSHTransformerWrapper(transformer, model_wrapper)
        result._augment()
        return result

    def get_cloudwatch_url(self):
        """Return the URL of the transform job logs."""
        return self.ssh_log.get_transform_cloudwatch_url(self.get_transformer_latest_job_name())

    def get_transformer_latest_job_name(self):
        """Return the name of the transformer's latest job."""
        return self.transformer.latest_transform_job.job_name

    def get_metadata_url(self):
        """Return the URL of the transform job metadata."""
        return self.ssh_log.get_transform_metadata_url(self.get_transformer_latest_job_name())

    @classmethod
    def attach(cls, transform_job_name, sagemaker_session: Session = None) -> SSHTransformerWrapper:
        """
        Wrap an existing transform job by name.

        :param transform_job_name: name of the transform job.
        :param sagemaker_session: session used for both the transformer and the model wrapper;
            a new ``Session()`` is created when omitted.
        """
        # Resolve the session once so the transformer and the model wrapper share it.
        session = sagemaker_session or Session()
        transformer = sagemaker.transformer.Transformer.attach(transform_job_name, sagemaker_session=session)
        return SSHTransformerWrapper(
            transformer,
            SSHModelWrapper.attach(transform_job_name, session)
        )

    def print_ssh_info(self):
        """Print the log / metadata URLs and the commands to connect to the transform job."""
        print(f"Remote batch transform logs are at {self.get_cloudwatch_url()}")
        print(f"Batch transform job metadata is at {self.get_metadata_url()}")
        print(f"To connect over SSM run:\n"
              f"AWS_DEFAULT_REGION={self.region()} aws ssm start-session --target {self.get_instance_id()}")
        print(f"To configure local host for SSH run:\n"
              f"sm-local-configure")
        print(f"To connect over SSH run:\n"
              f"AWS_DEFAULT_REGION={self.region()} sm-ssh connect {self.get_transformer_latest_job_name()}.transform.sagemaker")
class SSHIDEWrapper(SSHEnvironmentWrapper):
    """Gives SSH/SSM access to a SageMaker Studio app through an ``SSHIDE`` helper."""

    def __init__(self,
                 ssm_iam_role: str,
                 ide: SSHIDE,
                 bootstrap_on_start: bool = True,
                 connection_wait_time_seconds: int = 600,
                 sagemaker_session: sagemaker.Session = None,
                 local_user_id: str = None,
                 log_to_stdout: bool = False):
        """
        :param ssm_iam_role: SSM role without prefix (may be empty).
        :param ide: helper used to resolve instance IDs and URLs.
        :param bootstrap_on_start: see ``SSHEnvironmentWrapper``.
        :param connection_wait_time_seconds: see ``SSHEnvironmentWrapper``.
        :param sagemaker_session: see ``SSHEnvironmentWrapper``.
        :param local_user_id: see ``SSHEnvironmentWrapper``.
        :param log_to_stdout: see ``SSHEnvironmentWrapper``.
        """
        super().__init__(ssm_iam_role, bootstrap_on_start, connection_wait_time_seconds, sagemaker_session,
                         local_user_id, log_to_stdout)
        self.app_name = None
        self.ide = ide
        self.not_earlier_than_timestamp = 0

    @classmethod
    def attach(cls, domain_id, user_profile_name, app_name, sagemaker_session: Session = None,
               not_earlier_than_timestamp: int = 0) -> SSHIDEWrapper:
        """
        Wrap an existing Studio app.

        :param domain_id: passed to ``SSHIDE``.
        :param user_profile_name: passed to ``SSHIDE``.
        :param app_name: name of the app to connect to.
        :param sagemaker_session: session to use; a new one is created when omitted.
        :param not_earlier_than_timestamp: forwarded to the instance ID lookup.
        """
        sagemaker_session = sagemaker_session or sagemaker.Session()
        result = SSHIDEWrapper(
            '',
            SSHIDE(domain_id, user_profile_name, sagemaker_session.boto_region_name),
            connection_wait_time_seconds=0,
            # Hand over the resolved session so the wrapper and the SSHIDE helper use the same region.
            sagemaker_session=sagemaker_session
        )
        result.app_name = app_name
        result.not_earlier_than_timestamp = not_earlier_than_timestamp
        return result

    def get_instance_ids(self, retry: int = None, timeout_in_sec: int = 900):
        """
        Look up the SSM instance IDs of the app's kernel.

        :param retry: (deprecated, use timeout_in_sec) number of 10-second retries.
        :param timeout_in_sec: how long to wait for the instances.
        """
        # Honour the deprecated argument the same way the other wrappers do.
        timeout_in_sec = self.retry_deprecated_warning(retry, timeout_in_sec)
        return self.ide.get_kernel_instance_ids(self.app_name, timeout_in_sec=timeout_in_sec,
                                                not_earlier_than_timestamp=self.not_earlier_than_timestamp)

    def get_cloudwatch_url(self):
        """Return the URL of the app logs."""
        return self.ide.get_cloudwatch_url(self.app_name)

    def get_metadata_url(self):
        """Return the URL of the user metadata."""
        return self.ide.get_user_metadata_url()

    def print_ssh_info(self):
        """Print the log / metadata URLs and the commands to connect to the Studio app."""
        print(f"SageMaker Studio logs are at {self.get_cloudwatch_url()}")
        print(f"SageMaker Studio metadata is at {self.get_metadata_url()}")
        print(f"To connect over SSM run:\n"
              f"AWS_DEFAULT_REGION={self.region()} aws ssm start-session --target {self.get_instance_id()}")
        print(f"To configure local host for SSH run:\n"
              f"sm-local-configure")
        print(f"To connect over SSH run:\n"
              f"AWS_DEFAULT_REGION={self.region()} sm-ssh connect {self.app_name}.studio.sagemaker")
class SSHNotebookInstanceWrapper(SSHEnvironmentWrapper):
    """Gives SSH/SSM access to a SageMaker notebook instance through a ``NotebookInstance`` helper."""

    def __init__(self,
                 ssm_iam_role: str,
                 notebook_instance: NotebookInstance,
                 bootstrap_on_start: bool = True,
                 connection_wait_time_seconds: int = 600,
                 sagemaker_session: sagemaker.Session = None,
                 local_user_id: str = None,
                 log_to_stdout: bool = False):
        """
        :param ssm_iam_role: SSM role without prefix (may be empty).
        :param notebook_instance: helper used to resolve instance IDs and URLs.
        :param bootstrap_on_start: see ``SSHEnvironmentWrapper``.
        :param connection_wait_time_seconds: see ``SSHEnvironmentWrapper``.
        :param sagemaker_session: see ``SSHEnvironmentWrapper``.
        :param local_user_id: see ``SSHEnvironmentWrapper``.
        :param log_to_stdout: see ``SSHEnvironmentWrapper``.
        """
        super().__init__(ssm_iam_role, bootstrap_on_start, connection_wait_time_seconds, sagemaker_session,
                         local_user_id, log_to_stdout)
        self.notebook_instance = notebook_instance

    def get_instance_ids(self, retry: int = None, timeout_in_sec: int = 900):
        """
        Return the instance IDs reported by the notebook instance helper.

        ``retry`` and ``timeout_in_sec`` are accepted for interface compatibility and are not used.
        """
        return self.notebook_instance.get_instance_ids()

    def get_cloudwatch_url(self):
        """Return the URL of the notebook instance logs."""
        return self.notebook_instance.get_cloudwatch_url()

    def get_metadata_url(self):
        """Return the URL of the notebook instance metadata."""
        return self.notebook_instance.get_metadata_url()

    @classmethod
    def attach(cls, notebook_name, sagemaker_session=None):
        """
        Wrap an existing notebook instance by name.

        :param notebook_name: name of the notebook instance.
        :param sagemaker_session: session to use; a new one is created when omitted.
        """
        sagemaker_session = sagemaker_session or sagemaker.Session()
        result = SSHNotebookInstanceWrapper(
            '',
            NotebookInstance(notebook_name, sagemaker_session.boto_region_name),
            connection_wait_time_seconds=0,
            # Hand over the resolved session so the wrapper and the helper use the same region.
            sagemaker_session=sagemaker_session
        )
        return result

    def print_ssh_info(self):
        """Print the log / metadata URLs and the commands to connect to the notebook instance."""
        print(f"SageMaker Notebook Instance logs are at {self.get_cloudwatch_url()}")
        print(f"SageMaker Notebook Instance metadata is at {self.get_metadata_url()}")
        print(f"To connect over SSM run:\n"
              f"AWS_DEFAULT_REGION={self.region()} aws ssm start-session --target {self.get_instance_id()}")
        print(f"To configure local host for SSH run:\n"
              f"sm-local-configure")
        print(f"To connect over SSH run:\n"
              f"AWS_DEFAULT_REGION={self.region()} sm-ssh connect {self.notebook_instance.notebook_name}.notebook.sagemaker")