Skip to content

Commit c20d89e

Browse files
committed
Add logs for subset of jobs
1 parent ff04765 commit c20d89e

File tree

8 files changed

+405
-165
lines changed

8 files changed

+405
-165
lines changed

example_notebooks/sagemaker_core_overview.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -426,7 +426,7 @@
426426
" stopping_condition=StoppingCondition(\n",
427427
" max_runtime_in_seconds=600\n",
428428
" )\n",
429-
")\n"
429+
")"
430430
]
431431
},
432432
{
@@ -445,7 +445,7 @@
445445
"source": [
446446
"# Wait for TrainingJob with SageMakerCore\n",
447447
"\n",
448-
"training_job.wait()"
448+
"training_job.wait(logs=True)"
449449
]
450450
},
451451
{

src/sagemaker_core/main/logs.py

Lines changed: 55 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -2,49 +2,53 @@
22
import botocore
33

44
from boto3.session import Session
5+
import botocore.client
56
from botocore.config import Config
6-
from typing import Generator
7-
from sagemaker_core.generated.utils import SingletonMeta, logger
7+
from typing import Generator, Tuple, List
8+
from sagemaker_core.main.utils import SingletonMeta
89

910

1011
class CloudWatchLogsClient(metaclass=SingletonMeta):
1112
"""
1213
A singleton class for creating a SageMaker client.
1314
"""
15+
1416
def __init__(self):
1517
session = Session()
1618
self.client = session.client(
17-
"logs",
18-
session.region_name,
19-
config=Config(retries={"max_attempts": 10, "mode": "standard"})
19+
"logs",
20+
session.region_name,
21+
config=Config(retries={"max_attempts": 10, "mode": "standard"}),
2022
)
2123

2224

2325
class LogStreamHandler:
2426
log_groupt_name: str = None
2527
log_stream_name: str = None
28+
stream_id: int = None
2629
next_token: str = None
27-
cw_client: CloudWatchLogsClient = None
28-
29-
def __init__(self, log_group_name: str, log_stream_name: str):
30+
cw_client = None
31+
32+
def __init__(self, log_group_name: str, log_stream_name: str, stream_id: int):
3033
self.log_group_name = log_group_name
3134
self.log_stream_name = log_stream_name
32-
self.cw_client = CloudWatchLogsClient()
35+
self.cw_client = CloudWatchLogsClient().client
36+
self.stream_id = stream_id
3337

34-
def get_latest_log_events(self) -> Generator[tuple[str, dict], None, None]:
38+
def get_latest_log_events(self) -> Generator[Tuple[str, dict], None, None]:
3539
"""
3640
This method gets all the latest log events for this stream that exist at this moment in time.
37-
38-
cw_client.get_log_events() always returns a nextForwardToken even if the current batch of events is empty.
41+
42+
cw_client.get_log_events() always returns a nextForwardToken even if the current batch of events is empty.
3943
You can keep calling cw_client.get_log_events() with the same token until a new batch of log events exist.
40-
44+
4145
API Reference: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/logs/client/get_log_events.html
42-
46+
4347
Returns:
4448
Generator[tuple[str, dict], None, None]: Generator that yields a tuple that consists of two values
4549
str: stream_name,
4650
dict: event dict in format
47-
{
51+
{
4852
"ingestionTime": number,
4953
"message": "string",
5054
"timestamp": number
@@ -54,109 +58,107 @@ def get_latest_log_events(self) -> Generator[tuple[str, dict], None, None]:
5458
if not self.next_token:
5559
token_args = {}
5660
else:
57-
token_args = {
58-
"nextToken": self.next_token
59-
}
60-
61+
token_args = {"nextToken": self.next_token}
62+
6163
response = self.cw_client.get_log_events(
6264
logGroupName=self.log_group_name,
6365
logStreamName=self.log_stream_name,
6466
startFromHead=True,
65-
**token_args
67+
**token_args,
6668
)
6769

6870
self.next_token = response["nextForwardToken"]
6971
if not response["events"]:
7072
break
71-
73+
7274
for event in response["events"]:
73-
yield self.log_stream_name, event
75+
yield self.stream_id, event
7476

7577

7678
class MultiLogStreamHandler:
7779
log_group_name: str = None
7880
log_stream_name_prefix: str = None
7981
expected_stream_count: int = None
80-
streams: list[LogStreamHandler] = []
81-
cw_client: CloudWatchLogsClient = None
82-
83-
def __init__(self, log_group_name: str, log_stream_name_prefix: str, expected_stream_count: int):
82+
streams: List[LogStreamHandler] = []
83+
cw_client = None
84+
85+
def __init__(
86+
self, log_group_name: str, log_stream_name_prefix: str, expected_stream_count: int
87+
):
8488
self.log_group_name = log_group_name
8589
self.log_stream_name_prefix = log_stream_name_prefix
8690
self.expected_stream_count = expected_stream_count
87-
self.cw_client = CloudWatchLogsClient()
91+
self.cw_client = CloudWatchLogsClient().client
8892

89-
def get_latest_log_events(self) -> Generator[tuple[str, dict], None, None]:
93+
def get_latest_log_events(self) -> Generator[Tuple[str, dict], None, None]:
9094
"""
9195
This method gets all the latest log events from each stream that exist at this moment.
9296
9397
Returns:
9498
Generator[tuple[str, dict], None, None]: Generator that yields a tuple that consists of two values
9599
str: stream_name,
96-
dict: event dict in format -
97-
{
100+
dict: event dict in format -
101+
{
98102
"ingestionTime": number,
99103
"message": "string",
100104
"timestamp": number
101105
}
102106
"""
103107
if not self.ready():
104108
return []
105-
106-
for stream in self.streams:
107-
yield from stream.get_latest_log_events()
108-
109109

110+
for stream in self.streams:
111+
yield from stream.get_latest_log_events()
110112

111113
def ready(self) -> bool:
112114
"""
113115
Checks whether or not MultiLogStreamHandler is ready to serve new log events at this moment.
114-
116+
115117
If self.streams is already set, return True.
116-
Otherwise, check if the current number of log streams in the log group match the expected stream count.
118+
Otherwise, check if the current number of log streams in the log group match the expected stream count. 
117119
118120
Returns:
119121
bool: Whether or not MultiLogStreamHandler is ready to serve new log events.
120122
"""
121-
123+
122124
if len(self.streams) >= self.expected_stream_count:
123125
return True
124-
126+
125127
try:
126128
response = self.cw_client.describe_log_streams(
127129
logGroupName=self.log_group_name,
128130
logStreamNamePrefix=self.log_stream_name_prefix + "/",
129131
orderBy="LogStreamName",
130132
)
131-
stream_names = [stream['logStreamName'] for stream in response['logStreams']]
132-
133-
next_token = response.get('nextToken')
133+
stream_names = [stream["logStreamName"] for stream in response["logStreams"]]
134+
135+
next_token = response.get("nextToken")
134136
while next_token:
135137
response = self.cw_client.describe_log_streams(
136-
logGroupName=self.log_group_name,
138+
logGroupName=self.log_group_name,
137139
logStreamNamePrefix=self.log_stream_name_prefix + "/",
138140
orderBy="LogStreamName",
139-
nextToken=next_token
141+
nextToken=next_token,
140142
)
141-
stream_names.extend([stream['logStreamName'] for stream in response['logStreams']])
142-
next_token = response.get('nextToken', None)
143-
143+
stream_names.extend([stream["logStreamName"] for stream in response["logStreams"]])
144+
next_token = response.get("nextToken", None)
145+
144146
if len(stream_names) >= self.expected_stream_count:
145147
self.streams = [
146-
LogStreamHandler(self.log_group_name, log_stream_name)
147-
for log_stream_name in stream_names
148+
LogStreamHandler(self.log_group_name, log_stream_name, index)
149+
for index, log_stream_name in enumerate(stream_names)
148150
]
149-
151+
150152
return True
151153
else:
152154
# Log streams are created whenever a container starts writing to stdout/err,
153155
# so if the stream count is less than the expected number, return False
154-
return False
155-
156+
return False
157+
156158
except botocore.exceptions.ClientError as e:
157159
# On the very first training job run on an account, there's no log group until
158160
# the container starts logging, so ignore any errors thrown about that
159-
if e.response['Error']['Code'] == 'ResourceNotFoundException':
161+
if e.response["Error"]["Code"] == "ResourceNotFoundException":
160162
return False
161163
else:
162-
raise
164+
raise

0 commit comments

Comments (0)