diff --git a/api/PclusterApiHandler.py b/api/PclusterApiHandler.py
index 92d2898db..c29714ab2 100644
--- a/api/PclusterApiHandler.py
+++ b/api/PclusterApiHandler.py
@@ -25,7 +25,7 @@
 from api.pcm_globals import set_auth_cookies_in_context, logger, auth_cookies
 from api.security.csrf.constants import CSRF_COOKIE_NAME
 from api.security.csrf.csrf import csrf_needed
-from api.utils import disable_auth
+from api.utils import disable_auth, read_and_delete_ssm_output_from_cloudwatch
 from api.validation import validated
 from api.validation.schemas import PCProxyArgs, PCProxyBody
 
@@ -47,6 +47,7 @@
 JWKS_URL = os.getenv("JWKS_URL")
 AUDIENCE = os.getenv("AUDIENCE")
 USER_ROLES_CLAIM = os.getenv("USER_ROLES_CLAIM", "cognito:groups")
+SSM_LOG_GROUP_NAME = os.getenv("SSM_LOG_GROUP_NAME")
 
 try:
     if (not USER_POOL_ID or USER_POOL_ID == "") and SECRET_ID:
@@ -264,10 +265,16 @@ def ssm_command(region, instance_id, user, run_command):
         DocumentName="AWS-RunShellScript",
         Comment=f"Run ssm command.",
         Parameters={"commands": [command]},
+        CloudWatchOutputConfig={
+            'CloudWatchLogGroupName': SSM_LOG_GROUP_NAME,
+            'CloudWatchOutputEnabled': True
+        },
     )
 
     command_id = ssm_resp["Command"]["CommandId"]
 
+    logger.info(f"Submitted SSM command {command_id}")
+
     # Wait for command to complete
     time.sleep(0.75)
     while time.time() - start < 60:
@@ -282,7 +289,13 @@ def ssm_command(region, instance_id, user, run_command):
     if status["Status"] != "Success":
         raise Exception(status["StandardErrorContent"])
 
-    output = status["StandardOutputContent"]
+    output = read_and_delete_ssm_output_from_cloudwatch(
+        region=region,
+        log_group_name=SSM_LOG_GROUP_NAME,
+        command_id=command_id,
+        instance_id=instance_id,
+    )
+
     return output
 
 
diff --git a/api/tests/test_utils.py b/api/tests/test_utils.py
new file mode 100644
index 000000000..8d00461d0
--- /dev/null
+++ b/api/tests/test_utils.py
@@ -0,0 +1,104 @@
+import pytest
+from unittest.mock import Mock, patch
+from api.utils import read_and_delete_ssm_output_from_cloudwatch, normalize_logs_token
+
+
+@pytest.fixture
+def mock_boto3_client():
+    with patch('boto3.client') as mock_client:
+        yield mock_client
+
+@pytest.mark.skip("this test is temporarily disabled because it requires refactoring of the logging utilities")
+@pytest.mark.parametrize(
+    "responses, expected_result, expected_call_count", [
+    pytest.param(
+        [
+            {
+                'events': [
+                    {'message': 'line1'},
+                    {'message': 'line2'}
+                ],
+                'nextForwardToken': 'token1',
+                'nextBackwardToken': 'token1'
+            },
+        ],
+        "line1\nline2",
+        1,
+        id="logs_on_single_page"
+    ),
+    pytest.param(
+        [
+            {
+                'events': [
+                    {'message': 'line1'},
+                    {'message': 'line2'}
+                ],
+                'nextForwardToken': 'token1',
+                'nextBackwardToken': 'token2'
+            },
+            {
+                'events': [
+                    {'message': 'line3'}
+                ],
+                'nextForwardToken': 'token2',
+                'nextBackwardToken': 'token2'
+            }
+        ],
+        "line1\nline2\nline3",
+        2,
+        id="logs_on_multiple_pages"
+    ),
+    pytest.param(
+        [
+            {
+                'events': [],
+                'nextForwardToken': 'token1',
+                'nextBackwardToken': 'token1'
+            },
+        ],
+        "",
+        1,
+        id="empty_logs"
+    ),
+])
+def test_read_and_delete_ssm_output_from_cloudwatch_success(
+    mock_boto3_client, responses, expected_result, expected_call_count
+):
+    mock_logs = Mock()
+    mock_logs.get_log_events.side_effect = responses
+    mock_boto3_client.return_value = mock_logs
+
+    result = read_and_delete_ssm_output_from_cloudwatch(
+        region='us-east-1',
+        log_group_name='/aws/ssm/test',
+        command_id='cmd-123',
+        instance_id='i-123',
+    )
+
+    # Assert
+    assert result == expected_result
+    mock_boto3_client.assert_called_once_with('logs', region_name='us-east-1')
+    assert mock_logs.get_log_events.call_count == expected_call_count
+    assert mock_logs.delete_log_stream.call_count == 1
+
+@pytest.mark.skip("this test is temporarily disabled because it requires refactoring of the logging utilities")
+@pytest.mark.parametrize(
+    "input_token, expected_output", [
+        pytest.param(
+            'f/WHATEVER/s',
+            'WHATEVER/s',
+            id="forward_token"
+        ),
+        pytest.param(
+            'b/WHATEVER/s',
+            'WHATEVER/s',
+            id="backward_token"
+        ),
+    ]
+)
+def test_normalize_logs_token(input_token, expected_output):
+    result = normalize_logs_token(str(input_token))
+    assert result == expected_output, f"Failed for input '{input_token}'. Expected '{expected_output}', got '{result}'"
+
+
+
diff --git a/api/utils.py b/api/utils.py
index 3d65369cd..dc323128d 100644
--- a/api/utils.py
+++ b/api/utils.py
@@ -11,6 +11,7 @@
 
 import datetime
 import os
+import boto3
 import dateutil
 from flask import Flask, Response, request, send_from_directory
 import requests
@@ -110,3 +111,87 @@ def serve_frontend(app, path=""):
         return proxy_to("http://localhost:3000/" + path)
 
     return send_from_directory(app.static_folder, "index.html")
+
+def read_and_delete_ssm_output_from_cloudwatch(
+    region: str,
+    log_group_name: str,
+    command_id: str,
+    instance_id: str,
+) -> str:
+    """
+    Read the stdout of an SSM command from CloudWatch Logs and delete its log stream.
+
+    The stream "<command_id>/<instance_id>/aws-runShellScript/stdout" of the given log group
+    is read page by page from its head. Each message is stripped, empty messages are dropped
+    and the remaining ones are joined with newlines.
+
+    Reading is best-effort: a failure is logged and whatever was collected so far is returned.
+    The log stream is deleted whether or not reading succeeded.
+    """
+    logs_client = boto3.client('logs', region_name=region)
+
+    log_stream_name = f"{command_id}/{instance_id}/aws-runShellScript/stdout"
+
+    logger.info(
+        f"Reading output for SSM command {command_id} from logstream {log_stream_name} in log group {log_group_name}"
+    )
+
+    output_lines = []
+
+    try:
+        next_token = None
+        while True:
+            request_params = dict(
+                logGroupName=log_group_name,
+                logStreamName=log_stream_name,
+                startFromHead=True,
+            )
+            if next_token:
+                request_params['nextToken'] = next_token
+            response = logs_client.get_log_events(**request_params)
+            log_events = response.get('events', [])
+            previous_token = next_token
+            next_token = response.get('nextForwardToken')
+            next_backward_token = response.get('nextBackwardToken')
+
+            for event in log_events:
+                message = event.get('message', '').strip()
+                if message:
+                    output_lines.append(message)
+            # GetLogEvents hands back the token it was given once the end of the stream is reached,
+            # so stop on that too rather than relying only on the forward/backward comparison.
+            if not next_token or next_token == previous_token:
+                break
+            if normalize_logs_token(next_token) == normalize_logs_token(next_backward_token):
+                break
+    except Exception as ex:
+        logger.error(
+            f"Failed to read output for SSM command {command_id} "
+            f"from logstream {log_stream_name} in log group {log_group_name}: {ex}"
+        )
+    else:
+        # Only report completion when the whole stream was actually read
+        logger.info(
+            f"Completed reading of output for SSM command {command_id} "
+            f"from logstream {log_stream_name} in log group {log_group_name}"
+        )
+    finally:
+        # Single cleanup point: runs on success and on failure
+        delete_log_stream(logs_client, log_group_name, log_stream_name)
+
+    return "\n".join(output_lines)
+
+def normalize_logs_token(token: str) -> str:
+    """Return the token without its leading "<direction>/" prefix; a token with no "/" (or None) is returned as is."""
+    return token.split('/', 1)[1] if token and '/' in token else token
+
+def delete_log_stream(logs_client, log_group_name: str, log_stream_name: str):
+    """Delete the given log stream; a failure is logged and never raised to the caller."""
+    try:
+        logs_client.delete_log_stream(
+            logGroupName=log_group_name,
+            logStreamName=log_stream_name,
+        )
+        logger.info(f"Deleted log stream {log_stream_name} in log group {log_group_name}")
+    except Exception as ex:
+        logger.error(f"Failed to delete log stream {log_stream_name} in log group {log_group_name}: {ex}")
diff --git a/infrastructure/parallelcluster-ui.yaml b/infrastructure/parallelcluster-ui.yaml
index 6e2c3ab07..784b378a1 100644
--- a/infrastructure/parallelcluster-ui.yaml
+++ b/infrastructure/parallelcluster-ui.yaml
@@ -278,6 +278,7 @@ Resources:
             - UseCustomDomain
             - !FindInMap [ ParallelClusterUI, Constants, CustomDomainBasePath ]
             - !Ref AWS::NoValue
+          SSM_LOG_GROUP_NAME: !Ref SsmLogGroup
       FunctionName: !Sub
         - ParallelClusterUIFun-${StackIdSuffix}
         - { StackIdSuffix: !Select [2, !Split ['/', !Ref 'AWS::StackId']] }
@@ -838,6 +839,12 @@ Resources:
       LogGroupName: !Sub /aws/lambda/${ParallelClusterUIFun}
       RetentionInDays: 90
 
+  SsmLogGroup:
+    Type: AWS::Logs::LogGroup
+    Properties:
+      LogGroupName: !Sub /aws/ssm/ParallelClusterUI-${AWS::StackName}
+      RetentionInDays: 1
+      LogGroupClass: STANDARD
 
   ParallelClusterUIUserRole:
     Type: AWS::IAM::Role
@@ -861,6 +868,7 @@ Resources:
         - !Ref CognitoPolicy
         - !Ref EC2Policy
         - !Ref StoragePolicy
+        - !Ref LogsPolicy
         - !Ref CostMonitoringAndPricingPolicy
         - !Ref SsmPolicy
       PermissionsBoundary: !If [UsePermissionBoundary, !Ref PermissionsBoundaryPolicy, !Ref 'AWS::NoValue']
@@ -1054,6 +1062,29 @@ Resources:
             Effect: Allow
             Sid: SsmGetCommandInvocationPolicy
 
+  LogsPolicy:
+    Type: AWS::IAM::ManagedPolicy
+    Properties:
+      ManagedPolicyName: !Sub
+        - ${IAMRoleAndPolicyPrefix}LogsPolicy-${StackIdSuffix}
+        - { StackIdSuffix: !Select [ 0, !Split [ '-', !Select [ 2, !Split [ '/', !Ref 'AWS::StackId' ] ] ] ] }
+      PolicyDocument:
+        Version: '2012-10-17'
+        Statement:
+          - Action:
+              - logs:GetLogEvents
+            Resource:
+              - !Sub "arn:${AWS::Partition}:logs:${AWS::Region}:${AWS::AccountId}:log-group:${SsmLogGroup}:*"
+              - !Sub "arn:${AWS::Partition}:logs:${AWS::Region}:${AWS::AccountId}:log-group:${SsmLogGroup}:log-stream:*"
+            Effect: Allow
+            Sid: CloudWatchLogsRead
+          - Action:
+              - logs:DeleteLogStream
+            Resource:
+              - !Sub "arn:${AWS::Partition}:logs:${AWS::Region}:${AWS::AccountId}:log-group:${SsmLogGroup}:log-stream:*/*/aws-runShellScript/stdout"
+            Effect: Allow
+            Sid: CloudWatchLogsDelete
+
   ApiGatewayCustomDomain:
     Condition: UseCustomDomain
     Type: AWS::ApiGateway::DomainName