diff --git a/aws_xray_sdk/__init__.py b/aws_xray_sdk/__init__.py index e69de29b..79ee3c82 100644 --- a/aws_xray_sdk/__init__.py +++ b/aws_xray_sdk/__init__.py @@ -0,0 +1,3 @@ +from .sdk_config import SDKConfig + +global_sdk_config = SDKConfig() diff --git a/aws_xray_sdk/core/lambda_launcher.py b/aws_xray_sdk/core/lambda_launcher.py index 0b0558be..79ab8338 100644 --- a/aws_xray_sdk/core/lambda_launcher.py +++ b/aws_xray_sdk/core/lambda_launcher.py @@ -2,11 +2,11 @@ import logging import threading +from aws_xray_sdk import global_sdk_config from .models.facade_segment import FacadeSegment from .models.trace_header import TraceHeader from .context import Context - log = logging.getLogger(__name__) @@ -71,7 +71,8 @@ def put_subsegment(self, subsegment): current_entity = self.get_trace_entity() if not self._is_subsegment(current_entity) and current_entity.initializing: - log.warning("Subsegment %s discarded due to Lambda worker still initializing" % subsegment.name) + if sdk_config_module.sdk_enabled(): + log.warning("Subsegment %s discarded due to Lambda worker still initializing" % subsegment.name) return current_entity.add_subsegment(subsegment) @@ -93,6 +94,9 @@ def _refresh_context(self): """ header_str = os.getenv(LAMBDA_TRACE_HEADER_KEY) trace_header = TraceHeader.from_header_str(header_str) + if not global_sdk_config.sdk_enabled(): + trace_header._sampled = False + segment = getattr(self._local, 'segment', None) if segment: @@ -124,7 +128,10 @@ def _initialize_context(self, trace_header): set by AWS Lambda and initialize storage for subsegments. """ sampled = None - if trace_header.sampled == 0: + if not global_sdk_config.sdk_enabled(): + # Force subsequent subsegments to be disabled and turned into DummySegments. + sampled = False + elif trace_header.sampled == 0: sampled = False elif trace_header.sampled == 1: sampled = True diff --git a/aws_xray_sdk/core/patcher.py b/aws_xray_sdk/core/patcher.py index 9d6caff2..622e2e03 100644 --- a/aws_xray_sdk/core/patcher.py +++ b/aws_xray_sdk/core/patcher.py @@ -7,6 +7,7 @@ import sys import wrapt +from aws_xray_sdk import global_sdk_config from .utils.compat import PY2, is_classmethod, is_instance_method log = logging.getLogger(__name__) @@ -62,6 +63,10 @@ def _is_valid_import(module): def patch(modules_to_patch, raise_errors=True, ignore_module_patterns=None): + enabled = global_sdk_config.sdk_enabled() + if not enabled: + log.debug("Skipped patching modules %s because the SDK is currently disabled." % ', '.join(modules_to_patch)) + return # Disable module patching if the SDK is disabled. modules = set() for module_to_patch in modules_to_patch: # boto3 depends on botocore and patching botocore is sufficient diff --git a/aws_xray_sdk/core/recorder.py b/aws_xray_sdk/core/recorder.py index 4e94b60c..b953cece 100644 --- a/aws_xray_sdk/core/recorder.py +++ b/aws_xray_sdk/core/recorder.py @@ -5,6 +5,7 @@ import platform import time +from aws_xray_sdk import global_sdk_config from aws_xray_sdk.version import VERSION from .models.segment import Segment, SegmentContextManager from .models.subsegment import Subsegment, SubsegmentContextManager @@ -18,7 +19,7 @@ from .daemon_config import DaemonConfig from .plugins.utils import get_plugin_modules from .lambda_launcher import check_in_lambda -from .exceptions.exceptions import SegmentNameMissingException +from .exceptions.exceptions import SegmentNameMissingException, SegmentNotFoundException from .utils.compat import string_types from .utils import stacktrace @@ -88,7 +89,6 @@ def configure(self, sampling=None, plugins=None, Configure needs to run before patching thrid party libraries to avoid creating dangling subsegment. - :param bool sampling: If sampling is enabled, every time the recorder creates a segment it decides whether to send this segment to the X-Ray daemon. This setting is not used if the recorder @@ -138,6 +138,7 @@ class to have your own implementation of the streaming process. and AWS_XRAY_TRACING_NAME respectively overrides arguments daemon_address, context_missing and service. """ + if sampling is not None: self.sampling = sampling if sampler: @@ -219,6 +220,12 @@ def begin_segment(self, name=None, traceid=None, # depending on if centralized or local sampling rule takes effect. decision = True + # To disable the recorder, we set the sampling decision to always be false. + # This way, when segments are generated, they become dummy segments and are ultimately never sent. + # The call to self._sampler.should_trace() is never called either so the poller threads are never started. + if not global_sdk_config.sdk_enabled(): + sampling = 0 + # we respect the input sampling decision # regardless of recorder configuration. if sampling == 0: @@ -273,6 +280,7 @@ def begin_subsegment(self, name, namespace='local'): :param str name: the name of the subsegment. :param str namespace: currently can only be 'local', 'remote', 'aws'. """ + segment = self.current_segment() if not segment: log.warning("No segment found, cannot begin subsegment %s." % name) @@ -396,6 +404,16 @@ def capture(self, name=None): def record_subsegment(self, wrapped, instance, args, kwargs, name, namespace, meta_processor): + # In the case when the SDK is disabled, we ensure that a parent segment exists, because this is usually + # handled by the middleware. We generate a dummy segment as the parent segment if one doesn't exist. + # This is to allow potential segment method calls to not throw exceptions in the captured method. + if not global_sdk_config.sdk_enabled(): + try: + self.current_segment() + except SegmentNotFoundException: + segment = DummySegment(name) + self.context.put_segment(segment) + subsegment = self.begin_subsegment(name, namespace) exception = None @@ -473,6 +491,14 @@ def _is_subsegment(self, entity): return (hasattr(entity, 'type') and entity.type == 'subsegment') + @property + def enabled(self): + return self._enabled + + @enabled.setter + def enabled(self, value): + self._enabled = value + @property def sampling(self): return self._sampling diff --git a/aws_xray_sdk/core/sampling/sampler.py b/aws_xray_sdk/core/sampling/sampler.py index d5d03818..8c09d59b 100644 --- a/aws_xray_sdk/core/sampling/sampler.py +++ b/aws_xray_sdk/core/sampling/sampler.py @@ -9,6 +9,7 @@ from .target_poller import TargetPoller from .connector import ServiceConnector from .reservoir import ReservoirDecision +from aws_xray_sdk import global_sdk_config log = logging.getLogger(__name__) @@ -37,6 +38,9 @@ def start(self): Start rule poller and target poller once X-Ray daemon address and context manager is in place. """ + if not global_sdk_config.sdk_enabled(): + return + with self._lock: if not self._started: self._rule_poller.start() @@ -51,6 +55,9 @@ def should_trace(self, sampling_req=None): All optional arguments are extracted from incoming requests by X-Ray middleware to perform path based sampling. """ + if not global_sdk_config.sdk_enabled(): + return False + if not self._started: self.start() # only front-end that actually uses the sampler spawns poller threads diff --git a/aws_xray_sdk/sdk_config.py b/aws_xray_sdk/sdk_config.py new file mode 100644 index 00000000..350ad5fa --- /dev/null +++ b/aws_xray_sdk/sdk_config.py @@ -0,0 +1,58 @@ +import os +import logging + +log = logging.getLogger(__name__) + + +class SDKConfig(object): + """ + Global Configuration Class that defines SDK-level configuration properties. + + Enabling/Disabling the SDK: + By default, the SDK is enabled unless if an environment variable AWS_XRAY_SDK_ENABLED + is set. If it is set, it needs to be a valid string boolean, otherwise, it will default + to true. If the environment variable is set, all calls to set_sdk_enabled() will + prioritize the value of the environment variable. + Disabling the SDK affects the recorder, patcher, and middlewares in the following ways: + For the recorder, disabling automatically generates DummySegments for subsequent segments + and DummySubsegments for subsegments created and thus not send any traces to the daemon. + For the patcher, module patching will automatically be disabled. The SDK must be disabled + before calling patcher.patch() method in order for this to function properly. + For the middleware, no modification is made on them, but since the recorder automatically + generates DummySegments for all subsequent calls, they will not generate segments/subsegments + to be sent. + + Environment variables: + "AWS_XRAY_SDK_ENABLED" - If set to 'false' disables the SDK and causes the explained above + to occur. + """ + XRAY_ENABLED_KEY = 'AWS_XRAY_SDK_ENABLED' + __SDK_ENABLED = str(os.getenv(XRAY_ENABLED_KEY, 'true')).lower() != 'false' + + @classmethod + def sdk_enabled(cls): + """ + Returns whether the SDK is enabled or not. + """ + return cls.__SDK_ENABLED + + @classmethod + def set_sdk_enabled(cls, value): + """ + Modifies the enabled flag if the "AWS_XRAY_SDK_ENABLED" environment variable is not set, + otherwise, set the enabled flag to be equal to the environment variable. If the + env variable is an invalid string boolean, it will default to true. + + :param bool value: Flag to set whether the SDK is enabled or disabled. + + Environment variables AWS_XRAY_SDK_ENABLED overrides argument value. + """ + # Environment Variables take precedence over hardcoded configurations. + if cls.XRAY_ENABLED_KEY in os.environ: + cls.__SDK_ENABLED = str(os.getenv(cls.XRAY_ENABLED_KEY, 'true')).lower() != 'false' + else: + if type(value) == bool: + cls.__SDK_ENABLED = value + else: + cls.__SDK_ENABLED = True + log.warning("Invalid parameter type passed into set_sdk_enabled(). Defaulting to True...") diff --git a/tests/ext/aiohttp/test_middleware.py b/tests/ext/aiohttp/test_middleware.py index e58848d8..c8b23335 100644 --- a/tests/ext/aiohttp/test_middleware.py +++ b/tests/ext/aiohttp/test_middleware.py @@ -4,6 +4,7 @@ Expects pytest-aiohttp """ import asyncio +from aws_xray_sdk import global_sdk_config from unittest.mock import patch from aiohttp import web @@ -109,6 +110,7 @@ def recorder(loop): xray_recorder.clear_trace_entities() yield xray_recorder + global_sdk_config.set_sdk_enabled(True) xray_recorder.clear_trace_entities() patcher.stop() @@ -283,3 +285,21 @@ async def get_delay(): # Ensure all ID's are different ids = [item.id for item in recorder.emitter.local] assert len(ids) == len(set(ids)) + + +async def test_disabled_sdk(test_client, loop, recorder): + """ + Test a normal response when the SDK is disabled. + + :param test_client: AioHttp test client fixture + :param loop: Eventloop fixture + :param recorder: X-Ray recorder fixture + """ + global_sdk_config.set_sdk_enabled(False) + client = await test_client(ServerTest.app(loop=loop)) + + resp = await client.get('/') + assert resp.status == 200 + + segment = recorder.emitter.pop() + assert not segment diff --git a/tests/ext/django/test_middleware.py b/tests/ext/django/test_middleware.py index 66e96488..a0128b7c 100644 --- a/tests/ext/django/test_middleware.py +++ b/tests/ext/django/test_middleware.py @@ -1,4 +1,5 @@ import django +from aws_xray_sdk import global_sdk_config from django.core.urlresolvers import reverse from django.test import TestCase @@ -14,6 +15,7 @@ def setUp(self): xray_recorder.configure(context=Context(), context_missing='LOG_ERROR') xray_recorder.clear_trace_entities() + global_sdk_config.set_sdk_enabled(True) def tearDown(self): xray_recorder.clear_trace_entities() @@ -102,3 +104,10 @@ def test_response_header(self): assert 'Sampled=1' in trace_header assert segment.trace_id in trace_header + + def test_disabled_sdk(self): + global_sdk_config.set_sdk_enabled(False) + url = reverse('200ok') + self.client.get(url) + segment = xray_recorder.emitter.pop() + assert not segment diff --git a/tests/ext/flask/test_flask.py b/tests/ext/flask/test_flask.py index 3b435028..07c8d42c 100644 --- a/tests/ext/flask/test_flask.py +++ b/tests/ext/flask/test_flask.py @@ -1,6 +1,7 @@ import pytest from flask import Flask, render_template_string +from aws_xray_sdk import global_sdk_config from aws_xray_sdk.ext.flask.middleware import XRayMiddleware from aws_xray_sdk.core.context import Context from aws_xray_sdk.core.models import http @@ -51,6 +52,7 @@ def cleanup(): recorder.clear_trace_entities() yield recorder.clear_trace_entities() + global_sdk_config.set_sdk_enabled(True) def test_ok(): @@ -143,3 +145,11 @@ def test_sampled_response_header(): resp_header = resp.headers[http.XRAY_HEADER] assert segment.trace_id in resp_header assert 'Sampled=1' in resp_header + + +def test_disabled_sdk(): + global_sdk_config.set_sdk_enabled(False) + path = '/ok' + app.get(path) + segment = recorder.emitter.pop() + assert not segment diff --git a/tests/test_lambda_context.py b/tests/test_lambda_context.py index 0bfec7b4..90f405e8 100644 --- a/tests/test_lambda_context.py +++ b/tests/test_lambda_context.py @@ -1,5 +1,7 @@ import os +from aws_xray_sdk import global_sdk_config +import pytest from aws_xray_sdk.core import lambda_launcher from aws_xray_sdk.core.models.subsegment import Subsegment @@ -12,6 +14,12 @@ context = lambda_launcher.LambdaContext() +@pytest.fixture(autouse=True) +def setup(): + yield + global_sdk_config.set_sdk_enabled(True) + + def test_facade_segment_generation(): segment = context.get_trace_entity() @@ -41,3 +49,14 @@ def test_put_subsegment(): context.end_subsegment() assert context.get_trace_entity().id == segment.id + + +def test_disable(): + context.clear_trace_entities() + segment = context.get_trace_entity() + assert segment.sampled + + context.clear_trace_entities() + global_sdk_config.set_sdk_enabled(False) + segment = context.get_trace_entity() + assert not segment.sampled diff --git a/tests/test_patcher.py b/tests/test_patcher.py index 944d1aad..f9651401 100644 --- a/tests/test_patcher.py +++ b/tests/test_patcher.py @@ -13,6 +13,7 @@ # Python versions < 3 have reload built-in pass +from aws_xray_sdk import global_sdk_config from aws_xray_sdk.core import patcher, xray_recorder from aws_xray_sdk.core.context import Context @@ -40,6 +41,7 @@ def construct_ctx(): yield xray_recorder.end_segment() xray_recorder.clear_trace_entities() + global_sdk_config.set_sdk_enabled(True) # Reload wrapt.importer references to modules to start off clean reload(wrapt) @@ -172,3 +174,11 @@ def test_external_submodules_ignores_module(): assert xray_recorder.current_segment().subsegments[0].name == 'mock_init' assert xray_recorder.current_segment().subsegments[1].name == 'mock_func' assert xray_recorder.current_segment().subsegments[2].name == 'mock_no_doublepatch' # It is patched with decorator + + +def test_disable_sdk_disables_patching(): + global_sdk_config.set_sdk_enabled(False) + patcher.patch(['tests.mock_module']) + imported_modules = [module for module in TEST_MODULES if module in sys.modules] + assert not imported_modules + assert len(xray_recorder.current_segment().subsegments) == 0 diff --git a/tests/test_recorder.py b/tests/test_recorder.py index ebe5a1c9..76c9415c 100644 --- a/tests/test_recorder.py +++ b/tests/test_recorder.py @@ -5,6 +5,11 @@ from aws_xray_sdk.version import VERSION from .util import get_new_stubbed_recorder +from aws_xray_sdk import global_sdk_config +from aws_xray_sdk.core.models.segment import Segment +from aws_xray_sdk.core.models.subsegment import Subsegment +from aws_xray_sdk.core.models.dummy_entities import DummySegment, DummySubsegment + xray_recorder = get_new_stubbed_recorder() @@ -17,6 +22,7 @@ def construct_ctx(): xray_recorder.clear_trace_entities() yield xray_recorder.clear_trace_entities() + global_sdk_config.set_sdk_enabled(True) def test_default_runtime_context(): @@ -168,6 +174,22 @@ def test_in_segment_exception(): assert len(subsegment.cause['exceptions']) == 1 +def test_default_enabled(): + assert global_sdk_config.sdk_enabled() + segment = xray_recorder.begin_segment('name') + subsegment = xray_recorder.begin_subsegment('name') + assert type(xray_recorder.current_segment()) is Segment + assert type(xray_recorder.current_subsegment()) is Subsegment + + +def test_disable_is_dummy(): + global_sdk_config.set_sdk_enabled(False) + segment = xray_recorder.begin_segment('name') + subsegment = xray_recorder.begin_subsegment('name') + assert type(xray_recorder.current_segment()) is DummySegment + assert type(xray_recorder.current_subsegment()) is DummySubsegment + + def test_max_stack_trace_zero(): xray_recorder.configure(max_trace_back=1) with pytest.raises(Exception): diff --git a/tests/test_sdk_config.py b/tests/test_sdk_config.py new file mode 100644 index 00000000..81569456 --- /dev/null +++ b/tests/test_sdk_config.py @@ -0,0 +1,80 @@ +from aws_xray_sdk import global_sdk_config +import os +import pytest + + +XRAY_ENABLED_KEY = "AWS_XRAY_SDK_ENABLED" + + +@pytest.fixture(autouse=True) +def cleanup(): + """ + Clean up Environmental Variable for enable before and after tests + """ + if XRAY_ENABLED_KEY in os.environ: + del os.environ[XRAY_ENABLED_KEY] + yield + if XRAY_ENABLED_KEY in os.environ: + del os.environ[XRAY_ENABLED_KEY] + global_sdk_config.set_sdk_enabled(True) + + +def test_enable_key(): + assert global_sdk_config.XRAY_ENABLED_KEY == XRAY_ENABLED_KEY + + +def test_default_enabled(): + assert global_sdk_config.sdk_enabled() is True + + +def test_env_var_precedence(): + os.environ[XRAY_ENABLED_KEY] = "true" + # Env Variable takes precedence. This is called to activate the internal check + global_sdk_config.set_sdk_enabled(False) + assert global_sdk_config.sdk_enabled() is True + os.environ[XRAY_ENABLED_KEY] = "false" + global_sdk_config.set_sdk_enabled(False) + assert global_sdk_config.sdk_enabled() is False + os.environ[XRAY_ENABLED_KEY] = "false" + global_sdk_config.set_sdk_enabled(True) + assert global_sdk_config.sdk_enabled() is False + os.environ[XRAY_ENABLED_KEY] = "true" + global_sdk_config.set_sdk_enabled(True) + assert global_sdk_config.sdk_enabled() is True + os.environ[XRAY_ENABLED_KEY] = "true" + global_sdk_config.set_sdk_enabled(None) + assert global_sdk_config.sdk_enabled() is True + + +def test_env_enable_case(): + os.environ[XRAY_ENABLED_KEY] = "TrUE" + # Env Variable takes precedence. This is called to activate the internal check + global_sdk_config.set_sdk_enabled(True) + assert global_sdk_config.sdk_enabled() is True + + os.environ[XRAY_ENABLED_KEY] = "true" + global_sdk_config.set_sdk_enabled(True) + assert global_sdk_config.sdk_enabled() is True + + os.environ[XRAY_ENABLED_KEY] = "False" + global_sdk_config.set_sdk_enabled(True) + assert global_sdk_config.sdk_enabled() is False + + os.environ[XRAY_ENABLED_KEY] = "falSE" + global_sdk_config.set_sdk_enabled(True) + assert global_sdk_config.sdk_enabled() is False + + +def test_invalid_env_string(): + os.environ[XRAY_ENABLED_KEY] = "INVALID" + # Env Variable takes precedence. This is called to activate the internal check + global_sdk_config.set_sdk_enabled(True) + assert global_sdk_config.sdk_enabled() is True + + os.environ[XRAY_ENABLED_KEY] = "1.0" + global_sdk_config.set_sdk_enabled(True) + assert global_sdk_config.sdk_enabled() is True + + os.environ[XRAY_ENABLED_KEY] = "1-.0" + global_sdk_config.set_sdk_enabled(False) + assert global_sdk_config.sdk_enabled() is True