Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions api/PclusterApiHandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from api.pcm_globals import set_auth_cookies_in_context, logger, auth_cookies
from api.security.csrf.constants import CSRF_COOKIE_NAME
from api.security.csrf.csrf import csrf_needed
from api.utils import disable_auth
from api.utils import disable_auth, read_and_delete_ssm_output_from_cloudwatch
from api.validation import validated
from api.validation.schemas import PCProxyArgs, PCProxyBody

Expand All @@ -47,6 +47,7 @@
JWKS_URL = os.getenv("JWKS_URL")
AUDIENCE = os.getenv("AUDIENCE")
USER_ROLES_CLAIM = os.getenv("USER_ROLES_CLAIM", "cognito:groups")
SSM_LOG_GROUP_NAME = os.getenv("SSM_LOG_GROUP_NAME")

try:
if (not USER_POOL_ID or USER_POOL_ID == "") and SECRET_ID:
Expand Down Expand Up @@ -264,10 +265,16 @@ def ssm_command(region, instance_id, user, run_command):
DocumentName="AWS-RunShellScript",
Comment=f"Run ssm command.",
Parameters={"commands": [command]},
CloudWatchOutputConfig={
'CloudWatchLogGroupName': SSM_LOG_GROUP_NAME,
'CloudWatchOutputEnabled': True
},
)

command_id = ssm_resp["Command"]["CommandId"]

logger.info(f"Submitted SSM command {command_id}")

# Wait for command to complete
time.sleep(0.75)
while time.time() - start < 60:
Expand All @@ -282,7 +289,13 @@ def ssm_command(region, instance_id, user, run_command):
if status["Status"] != "Success":
raise Exception(status["StandardErrorContent"])

output = status["StandardOutputContent"]
output = read_and_delete_ssm_output_from_cloudwatch(
region=region,
log_group_name=SSM_LOG_GROUP_NAME,
command_id=command_id,
instance_id=instance_id,
)

return output


Expand Down
104 changes: 104 additions & 0 deletions api/tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import pytest
from unittest.mock import Mock, patch
from api.utils import read_and_delete_ssm_output_from_cloudwatch, normalize_logs_token


@pytest.fixture
def mock_boto3_client():
    """Replace ``boto3.client`` with a mock for the duration of one test.

    Yields the mock standing in for ``boto3.client`` so a test can set its
    ``return_value`` and inspect how it was called. The patch is undone when
    the test finishes.
    """
    patcher = patch('boto3.client')
    patched_client_factory = patcher.start()
    try:
        yield patched_client_factory
    finally:
        # Always restore the real boto3.client, even if the test failed.
        patcher.stop()

@pytest.mark.skip("this test is temporarily disabled because it requires refactoring of the logging utilities")
@pytest.mark.parametrize(
    "responses, expected_result, expected_call_count",
    [
        pytest.param(
            [
                {
                    'events': [{'message': 'line1'}, {'message': 'line2'}],
                    'nextForwardToken': 'token1',
                    'nextBackwardToken': 'token1',
                },
            ],
            "line1\nline2",
            1,
            id="logs_on_single_page",
        ),
        pytest.param(
            [
                {
                    'events': [{'message': 'line1'}, {'message': 'line2'}],
                    'nextForwardToken': 'token1',
                    'nextBackwardToken': 'token2',
                },
                {
                    'events': [{'message': 'line3'}],
                    'nextForwardToken': 'token2',
                    'nextBackwardToken': 'token2',
                },
            ],
            "line1\nline2\nline3",
            2,
            id="logs_on_multiple_pages",
        ),
        pytest.param(
            [
                {
                    'events': [],
                    'nextForwardToken': 'token1',
                    'nextBackwardToken': 'token1',
                },
            ],
            "",
            1,
            id="empty_logs",
        ),
    ],
)
def test_read_and_delete_ssm_output_from_cloudwatch_success(
    mock_boto3_client, responses, expected_result, expected_call_count
):
    """Check output assembly, page count and stream cleanup on the happy path.

    Each parameter set supplies the sequence of ``get_log_events`` responses
    the mocked CloudWatch Logs client returns, the newline-joined output the
    function is expected to produce, and how many pages it should fetch.
    """
    # One canned response is consumed per get_log_events call.
    fake_logs_client = Mock()
    fake_logs_client.get_log_events.side_effect = responses
    mock_boto3_client.return_value = fake_logs_client

    result = read_and_delete_ssm_output_from_cloudwatch(
        region='us-east-1',
        log_group_name='/aws/ssm/test',
        command_id='cmd-123',
        instance_id='i-123',
    )

    assert result == expected_result
    mock_boto3_client.assert_called_once_with('logs', region_name='us-east-1')
    assert fake_logs_client.get_log_events.call_count == expected_call_count
    # The log stream is removed exactly once after reading.
    assert fake_logs_client.delete_log_stream.call_count == 1

@pytest.mark.skip("this test is temporarily disabled because it requires refactoring of the logging utilities")
@pytest.mark.parametrize(
    "input_token, expected_output",
    [
        pytest.param('f/WHATEVER/s', 'WHATEVER/s', id="forward_token"),
        pytest.param('b/WHATEVER/s', 'WHATEVER/s', id="backward_token"),
    ],
)
def test_normalize_logs_token(input_token, expected_output):
    """Tokens that differ only in their leading segment normalize to the same value."""
    token_text = str(input_token)
    result = normalize_logs_token(token_text)
    assert result == expected_output, f"Failed for input '{input_token}'. Expected '{expected_output}', got '{result}'"



66 changes: 66 additions & 0 deletions api/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import datetime
import os

import boto3
import dateutil
from flask import Flask, Response, request, send_from_directory
import requests
Expand Down Expand Up @@ -110,3 +111,68 @@ def serve_frontend(app, path=""):
return proxy_to("http://localhost:3000/" + path)

return send_from_directory(app.static_folder, "index.html")

def read_and_delete_ssm_output_from_cloudwatch(
    region: str,
    log_group_name: str,
    command_id: str,
    instance_id: str,
) -> str:
    """Read the stdout of an SSM command from CloudWatch Logs, then delete the stream.

    The output is read from the log stream named
    ``{command_id}/{instance_id}/aws-runShellScript/stdout`` inside
    ``log_group_name``, page by page from the head of the stream.

    Args:
        region: AWS region used to create the CloudWatch Logs client.
        log_group_name: Log group that holds the command output.
        command_id: ID of the SSM command whose output is read.
        instance_id: ID of the instance the command ran on.

    Returns:
        The non-empty, whitespace-stripped log messages joined with newlines.
        If reading fails part-way, the error is logged (not raised) and
        whatever was collected up to that point is returned.
    """
    logs_client = boto3.client('logs', region_name=region)

    log_stream_name = f"{command_id}/{instance_id}/aws-runShellScript/stdout"

    logger.info(
        f"Reading output for SSM command {command_id} from logstream {log_stream_name} in log group {log_group_name}"
    )

    output_lines = []

    try:
        next_token = None
        while True:
            request_params = dict(
                logGroupName=log_group_name,
                logStreamName=log_stream_name,
                startFromHead=True,
            )
            if next_token:
                request_params['nextToken'] = next_token
            response = logs_client.get_log_events(**request_params)

            previous_token = next_token
            next_token = response.get('nextForwardToken')
            next_backward_token = response.get('nextBackwardToken')

            for event in response.get('events', []):
                message = event.get('message', '').strip()
                if message:
                    output_lines.append(message)

            if not next_token:
                break
            # CloudWatch Logs signals end-of-stream by handing back the same
            # forward token that was sent; without this check the loop has no
            # guaranteed exit.
            if next_token == previous_token:
                break
            # Early exit kept from the original implementation: forward and
            # backward tokens that match once their prefix is removed are
            # treated as "nothing more to read".
            if normalize_logs_token(next_token) == normalize_logs_token(next_backward_token):
                break
    except Exception as ex:
        logger.error(
            f"Failed to read output for SSM command {command_id} "
            f"from logstream {log_stream_name} in log group {log_group_name}: {ex}"
        )
    finally:
        # Single cleanup point: the stream is removed whether or not reading succeeded.
        delete_log_stream(logs_client, log_group_name, log_stream_name)

    logger.info(
        f"Completed reading of output for SSM command {command_id} "
        f"from logstream {log_stream_name} in log group {log_group_name}"
    )

    return "\n".join(output_lines)

def normalize_logs_token(token: str) -> str:
    """Return ``token`` without its leading segment and the first ``/``.

    For example ``'f/WHATEVER/s'`` and ``'b/WHATEVER/s'`` both become
    ``'WHATEVER/s'``. Falsy values (``None``, ``''``) and tokens that contain
    no ``/`` are returned unchanged.
    """
    if not token:
        return token
    _prefix, separator, remainder = token.partition('/')
    return remainder if separator else token

def delete_log_stream(logs_client, log_group_name: str, log_stream_name: str):
    """Delete one CloudWatch log stream, logging the outcome.

    Args:
        logs_client: Client object exposing ``delete_log_stream``.
        log_group_name: Log group that contains the stream.
        log_stream_name: Name of the stream to delete.

    An exception raised by the delete call is logged as an error rather than
    propagated to the caller.
    """
    stream_ref = {
        "logGroupName": log_group_name,
        "logStreamName": log_stream_name,
    }
    try:
        logs_client.delete_log_stream(**stream_ref)
        logger.info(f"Deleted log stream {log_stream_name} in log group {log_group_name}")
    except Exception as ex:
        logger.error(f"Failed to delete log stream {log_stream_name} in log group {log_group_name}: {ex}")
31 changes: 31 additions & 0 deletions infrastructure/parallelcluster-ui.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,7 @@ Resources:
- UseCustomDomain
- !FindInMap [ ParallelClusterUI, Constants, CustomDomainBasePath ]
- !Ref AWS::NoValue
SSM_LOG_GROUP_NAME: !Ref SsmLogGroup
FunctionName: !Sub
- ParallelClusterUIFun-${StackIdSuffix}
- { StackIdSuffix: !Select [2, !Split ['/', !Ref 'AWS::StackId']] }
Expand Down Expand Up @@ -838,6 +839,12 @@ Resources:
LogGroupName: !Sub /aws/lambda/${ParallelClusterUIFun}
RetentionInDays: 90

# CloudWatch log group named /aws/ssm/ParallelClusterUI-<stack name>.
# Its name is passed to the API function through the SSM_LOG_GROUP_NAME
# environment variable. Log events are retained for one day only.
SsmLogGroup:
Type: AWS::Logs::LogGroup
Properties:
LogGroupName: !Sub /aws/ssm/ParallelClusterUI-${AWS::StackName}
RetentionInDays: 1
LogGroupClass: STANDARD

ParallelClusterUIUserRole:
Type: AWS::IAM::Role
Expand All @@ -861,6 +868,7 @@ Resources:
- !Ref CognitoPolicy
- !Ref EC2Policy
- !Ref StoragePolicy
- !Ref LogsPolicy
- !Ref CostMonitoringAndPricingPolicy
- !Ref SsmPolicy
PermissionsBoundary: !If [UsePermissionBoundary, !Ref PermissionsBoundaryPolicy, !Ref 'AWS::NoValue']
Expand Down Expand Up @@ -1054,6 +1062,29 @@ Resources:
Effect: Allow
Sid: SsmGetCommandInvocationPolicy

# Managed policy granting read and delete access on the SsmLogGroup log group.
LogsPolicy:
Type: AWS::IAM::ManagedPolicy
Properties:
# Name suffix: the first '-'-separated segment of the third '/'-separated
# part of the stack ID.
ManagedPolicyName: !Sub
- ${IAMRoleAndPolicyPrefix}LogsPolicy-${StackIdSuffix}
- { StackIdSuffix: !Select [ 0, !Split [ '-', !Select [ 2, !Split [ '/', !Ref 'AWS::StackId' ] ] ] ] }
PolicyDocument:
Version: '2012-10-17'
Statement:
# Read access: GetLogEvents on the log group and on every stream inside it.
- Action:
- logs:GetLogEvents
Resource:
- !Sub "arn:${AWS::Partition}:logs:${AWS::Region}:${AWS::AccountId}:log-group:${SsmLogGroup}:*"
- !Sub "arn:${AWS::Partition}:logs:${AWS::Region}:${AWS::AccountId}:log-group:${SsmLogGroup}:log-stream:*"
Effect: Allow
Sid: CloudWatchLogsRead
# Delete access is narrower: only streams whose name matches
# */*/aws-runShellScript/stdout may be deleted.
- Action:
- logs:DeleteLogStream
Resource:
- !Sub "arn:${AWS::Partition}:logs:${AWS::Region}:${AWS::AccountId}:log-group:${SsmLogGroup}:log-stream:*/*/aws-runShellScript/stdout"
Effect: Allow
Sid: CloudWatchLogsDelete

ApiGatewayCustomDomain:
Condition: UseCustomDomain
Type: AWS::ApiGateway::DomainName
Expand Down