
Commit

Merge branch 'main' into christophe-papazian/exploit_prevention_stack_traces_support

christophe-papazian committed Mar 26, 2024
2 parents e658581 + 9707da1 commit b7a98ab
Showing 97 changed files with 1,046 additions and 273 deletions.
70 changes: 31 additions & 39 deletions ddtrace/_trace/processor/__init__.py
@@ -135,30 +135,55 @@ def unregister(self):

@attr.s
class TraceSamplingProcessor(TraceProcessor):
"""Processor that keeps traces that have sampled spans. If all spans
are unsampled then ``None`` is returned.
"""Processor that runs both trace and span sampling rules.
Note that this processor is only effective if complete traces are sent. If
the spans of a trace are divided in separate lists then it's possible that
parts of the trace are unsampled when the whole trace should be sampled.
+    * Span sampling must be applied after trace sampling priority has been set.
+    * Span sampling rules are specified with a sample rate or rate limit as well as glob patterns
+      for matching spans on service and name.
+    * If the span sampling decision is to keep the span, then span sampling metrics are added to the span.
+    * If a dropped trace includes a span that had been kept by a span sampling rule, then the span is sent to the
+      Agent even if the dropped trace is not (as is the case when trace stats computation is enabled).
"""

_compute_stats_enabled = attr.ib(type=bool)
+    sampler = attr.ib()
+    single_span_rules = attr.ib(type=List[SpanSamplingRule])

def process_trace(self, trace):
# type: (List[Span]) -> Optional[List[Span]]

if trace:
+            chunk_root = trace[0]
+            root_ctx = chunk_root._context

+            # only trace sample if we haven't already sampled
+            if root_ctx and root_ctx.sampling_priority is None:
+                self.sampler.sample(trace[0])
# When stats computation is enabled in the tracer then we can
# safely drop the traces.
if self._compute_stats_enabled:
-                priority = trace[0]._context.sampling_priority if trace[0]._context is not None else None
+                priority = root_ctx.sampling_priority if root_ctx is not None else None
if priority is not None and priority <= 0:
# When any span is marked as keep by a single span sampling
# decision then we still send all and only those spans.
single_spans = [_ for _ in trace if is_single_span_sampled(_)]

return single_spans or None

+            # single span sampling rules are applied after trace sampling
+            if self.single_span_rules:
+                for span in trace:
+                    if span.context.sampling_priority is not None and span.context.sampling_priority <= 0:
+                        for rule in self.single_span_rules:
+                            if rule.match(span):
+                                rule.sample(span)
+                                # If stats computation is enabled, we won't send all spans to the agent.
+                                # In order to ensure that the agent does not update priority sampling rates
+                                # due to single spans sampling, we set all of these spans to manual keep.
+                                if config._trace_compute_stats:
+                                    span.set_metric(SAMPLING_PRIORITY_KEY, USER_KEEP)
+                                break

return trace

log.debug("dropping trace %d with %d spans", trace[0].trace_id, len(trace))
@@ -360,36 +385,3 @@ def _queue_span_count_metrics(self, metric_name, tag_name, min_count=100):
TELEMETRY_NAMESPACE_TAG_TRACER, metric_name, count, tags=((tag_name, tag_value),)
)
self._span_metrics[metric_name] = defaultdict(int)


-@attr.s
-class SpanSamplingProcessor(SpanProcessor):
-    """SpanProcessor for sampling single spans:
-    * Span sampling must be applied after trace sampling priority has been set.
-    * Span sampling rules are specified with a sample rate or rate limit as well as glob patterns
-      for matching spans on service and name.
-    * If the span sampling decision is to keep the span, then span sampling metrics are added to the span.
-    * If a dropped trace includes a span that had been kept by a span sampling rule, then the span is sent to the
-      Agent even if the dropped trace is not (as is the case when trace stats computation is enabled).
-    """
-
-    rules = attr.ib(type=List[SpanSamplingRule])
-
-    def on_span_start(self, span):
-        # type: (Span) -> None
-        pass
-
-    def on_span_finish(self, span):
-        # type: (Span) -> None
-        # only sample if the span isn't already going to be sampled by trace sampler
-        if span.context.sampling_priority is not None and span.context.sampling_priority <= 0:
-            for rule in self.rules:
-                if rule.match(span):
-                    rule.sample(span)
-                    # If stats computation is enabled, we won't send all spans to the agent.
-                    # In order to ensure that the agent does not update priority sampling rates
-                    # due to single spans sampling, we set all of these spans to manual keep.
-                    if config._trace_compute_stats:
-                        span.set_metric(SAMPLING_PRIORITY_KEY, USER_KEEP)
-                    break
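
With SpanSamplingProcessor gone, TraceSamplingProcessor receives the same rules through its single_span_rules attribute. How the rules themselves are expressed does not change in this diff; a sketch of the usual JSON configuration, assuming the DD_SPAN_SAMPLING_RULES environment variable format (the service and name values are made up):

import json
import os

# Hypothetical rule: always keep "checkout" spans from "shopping-cart",
# capped at 50 kept spans per second, even when trace-level sampling drops
# the enclosing trace.
rules = [{"service": "shopping-cart", "name": "checkout", "sample_rate": 1.0, "max_per_second": 50}]
os.environ["DD_SPAN_SAMPLING_RULES"] = json.dumps(rules)
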
50 changes: 40 additions & 10 deletions ddtrace/_trace/tracer.py
@@ -21,7 +21,6 @@
from ddtrace._trace.context import Context
from ddtrace._trace.processor import SpanAggregator
from ddtrace._trace.processor import SpanProcessor
-from ddtrace._trace.processor import SpanSamplingProcessor
from ddtrace._trace.processor import TopLevelSpanProcessor
from ddtrace._trace.processor import TraceProcessor
from ddtrace._trace.processor import TraceSamplingProcessor
@@ -114,13 +113,18 @@ def _default_span_processors_factory(
compute_stats_enabled: bool,
single_span_sampling_rules: List[SpanSamplingRule],
agent_url: str,
+    trace_sampler: BaseSampler,
profiling_span_processor: EndpointCallCounterProcessor,
) -> Tuple[List[SpanProcessor], Optional[Any], List[SpanProcessor]]:
# FIXME: type should be AppsecSpanProcessor but we have a cyclic import here
"""Construct the default list of span processors to use."""
trace_processors: List[TraceProcessor] = []
-    trace_processors += [TraceTagsProcessor(), PeerServiceProcessor(_ps_config), BaseServiceProcessor()]
-    trace_processors += [TraceSamplingProcessor(compute_stats_enabled)]
+    trace_processors += [
+        PeerServiceProcessor(_ps_config),
+        BaseServiceProcessor(),
+        TraceSamplingProcessor(compute_stats_enabled, trace_sampler, single_span_sampling_rules),
+        TraceTagsProcessor(),
+    ]
trace_processors += trace_filters

span_processors: List[SpanProcessor] = []
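
The rebuilt list above also changes processor order: TraceTagsProcessor now runs after TraceSamplingProcessor, so it only tags spans that survive sampling. A toy pipeline (stand-in code, not the library's) showing why position in the list matters:

def run_trace_processors(trace, processors):
    # Each processor receives the previous processor's output; returning
    # None drops the trace chunk, so later processors never see it.
    for processor in processors:
        trace = processor.process_trace(trace)
        if trace is None:
            return None
    return trace
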
@@ -167,9 +171,6 @@ def _default_span_processors_factory(

span_processors.append(profiling_span_processor)

-    if single_span_sampling_rules:
-        span_processors.append(SpanSamplingProcessor(single_span_sampling_rules))

# These need to run after all the other processors
deferred_processors: List[SpanProcessor] = [
SpanAggregator(
@@ -266,6 +267,7 @@ def __init__(
self._compute_stats,
self._single_span_sampling_rules,
self._agent_url,
+            self._sampler,
self._endpoint_call_counter_span_processor,
)
if config._data_streams_enabled:
@@ -278,6 +280,7 @@

self._hooks = _hooks.Hooks()
atexit.register(self._atexit)
+        forksafe.register_before_fork(self._sample_before_fork)
forksafe.register(self._child_after_fork)

self._shutdown_lock = RLock()
@@ -298,7 +301,10 @@ def _atexit(self) -> None:
self.shutdown(timeout=self.SHUTDOWN_TIMEOUT)

def sample(self, span):
-        self._sampler.sample(span)
+        if self._sampler is not None:
+            self._sampler.sample(span)
+        else:
+            log.error("No sampler available to sample span")

@property
def sampler(self):
@@ -341,6 +347,29 @@ def deregister_on_start_span(self, func: Callable) -> Callable:
self._hooks.deregister(self.__class__.start_span, func)
return func

+    def _sample_before_fork(self) -> None:
+        span = self.current_root_span()
+        if span is not None and span.context.sampling_priority is None:
+            self.sample(span)
+
+    @property
+    def _sampler(self):
+        return self._sampler_current
+
+    @_sampler.setter
+    def _sampler(self, value):
+        self._sampler_current = value
+        # we need to update the processor that uses the sampler
+        if getattr(self, "_deferred_processors", None):
+            for aggregator in self._deferred_processors:
+                if type(aggregator) == SpanAggregator:
+                    for processor in aggregator._trace_processors:
+                        if type(processor) == TraceSamplingProcessor:
+                            processor.sampler = value
+                            break
+                    else:
+                        log.debug("No TraceSamplingProcessor available to update sampling rate")

@property
def debug_logging(self):
return log.isEnabledFor(logging.DEBUG)
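
The new _sample_before_fork hook, registered via forksafe.register_before_fork in __init__, forces a sampling decision on the live root span before the process forks, so parent and child inherit one shared decision instead of sampling independently. A minimal POSIX-only sketch of the idea, with hypothetical stand-ins for the span and sampler:

import os


class RootSpanStub:
    sampling_priority = None


def sample(span):
    span.sampling_priority = 1  # stand-in for the real sampler's decision


root = RootSpanStub()
if root.sampling_priority is None:  # mirrors _sample_before_fork
    sample(root)

pid = os.fork()
if pid == 0:
    # The child sees the decision made before the fork.
    assert root.sampling_priority == 1
    os._exit(0)
os.waitpid(pid, 0)
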
@@ -525,6 +554,7 @@ def configure(
self._compute_stats,
self._single_span_sampling_rules,
self._agent_url,
+            self._sampler,
self._endpoint_call_counter_span_processor,
)

@@ -590,6 +620,7 @@ def _child_after_fork(self):
self._compute_stats,
self._single_span_sampling_rules,
self._agent_url,
+            self._sampler,
self._endpoint_call_counter_span_processor,
)

@@ -775,9 +806,6 @@ def _start_span(
if service and service not in self._services and self._is_span_internal(span):
self._services.add(service)

-        if not trace_id:
-            self.sample(span)

# Only call span processors if the tracer is enabled
if self.enabled:
for p in chain(self._span_processors, SpanProcessor.__processors__, self._deferred_processors):
@@ -1042,6 +1070,7 @@ def shutdown(self, timeout: Optional[float] = None) -> None:

atexit.unregister(self._atexit)
forksafe.unregister(self._child_after_fork)
+        forksafe.unregister_before_fork(self._sample_before_fork)

self.start_span = self._start_span_after_shutdown # type: ignore[assignment]

@@ -1103,6 +1132,7 @@ def _on_global_config_update(self, cfg, items):
sample_rate = cfg._trace_sample_rate
else:
sample_rate = None

sampler = DatadogSampler(default_sample_rate=sample_rate)
self._sampler = sampler
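
The plain assignment above now routes through the _sampler property setter added earlier in this diff, which pushes the new DatadogSampler down into the TraceSamplingProcessor held by the SpanAggregator; without that propagation, a sampler swapped in by a remote configuration update would never reach the processor that actually applies it. A toy version of the pattern, with hypothetical classes:

class Processor:
    def __init__(self, sampler):
        self.sampler = sampler


class Tracer:
    def __init__(self, sampler):
        self._current = sampler
        self._processors = [Processor(sampler)]

    @property
    def sampler(self):
        return self._current

    @sampler.setter
    def sampler(self, value):
        self._current = value
        for processor in self._processors:
            processor.sampler = value  # keep the nested copy in sync


tracer = Tracer(sampler="old")
tracer.sampler = "new"
assert tracer._processors[0].sampler == "new"
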

6 changes: 5 additions & 1 deletion ddtrace/appsec/_api_security/api_manager.py
@@ -128,8 +128,12 @@ def _schema_callback(self, env):

try:
# check both current span and root span for sampling priority
+            # if sampling has not yet run for the span, we default to treating it as sampled
+            if root.context.sampling_priority is None and env.span.context.sampling_priority is None:
+                priorities = (1,)
+            else:
+                priorities = (root.context.sampling_priority or 0, env.span.context.sampling_priority or 0)
# if any of them is set to USER_KEEP or USER_REJECT, we should respect it
-            priorities = (root.context.sampling_priority or 0, env.span.context.sampling_priority or 0)
if constants.USER_KEEP in priorities:
priority = constants.USER_KEEP
elif constants.USER_REJECT in priorities:
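
A self-contained restatement of the resolution logic above, using the priority constants from ddtrace.constants (USER_REJECT = -1, AUTO_KEEP = 1, USER_KEEP = 2); the final fallback line is a guess, since the diff is truncated here:

USER_REJECT, AUTO_KEEP, USER_KEEP = -1, 1, 2


def effective_priority(root_priority, span_priority):
    if root_priority is None and span_priority is None:
        # sampling has not run yet: default to treating the span as kept
        priorities = (AUTO_KEEP,)
    else:
        priorities = (root_priority or 0, span_priority or 0)
    if USER_KEEP in priorities:
        return USER_KEEP  # an explicit user keep wins
    if USER_REJECT in priorities:
        return USER_REJECT  # then an explicit user reject
    return max(priorities)  # hypothetical fallback; not shown in the diff


assert effective_priority(None, None) == AUTO_KEEP
assert effective_priority(USER_REJECT, USER_KEEP) == USER_KEEP
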
2 changes: 1 addition & 1 deletion ddtrace/contrib/algoliasearch/patch.py
@@ -125,7 +125,7 @@ def _patched_search(func, instance, wrapt_args, wrapt_kwargs):
span.set_tag_str(SPAN_KIND, SpanKind.CLIENT)

span.set_tag(SPAN_MEASURED_KEY)
-        if span.context.sampling_priority is None or span.context.sampling_priority <= 0:
+        if span.context.sampling_priority is not None and span.context.sampling_priority <= 0:
return func(*wrapt_args, **wrapt_kwargs)

if config.algoliasearch.collect_query_text:
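
The elasticsearch hunk below makes the identical one-line flip. Because the tracer now defers sampling, a priority of None means "not decided yet" rather than "not kept", so the early-return guard may only fire on an explicit drop decision. A small sketch contrasting the two conditions:

def old_skip(priority):
    # skipped instrumentation for undecided (None) and dropped spans
    return priority is None or priority <= 0


def new_skip(priority):
    # skips only spans explicitly decided as dropped
    return priority is not None and priority <= 0


assert old_skip(None) and not new_skip(None)  # undecided: now instrumented
assert old_skip(0) and new_skip(0)            # dropped: still skipped
assert not old_skip(2) and not new_skip(2)    # kept: instrumented either way
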
2 changes: 1 addition & 1 deletion ddtrace/contrib/elasticsearch/patch.py
@@ -143,7 +143,7 @@ def _perform_request(func, instance, args, kwargs):
span.set_tag(SPAN_MEASURED_KEY)

# Only instrument if trace is sampled or if we haven't tried to sample yet
-        if span.context.sampling_priority is None or span.context.sampling_priority <= 0:
+        if span.context.sampling_priority is not None and span.context.sampling_priority <= 0:
yield func(*args, **kwargs)
return

45 changes: 45 additions & 0 deletions ddtrace/contrib/langchain/patch.py
@@ -47,6 +47,7 @@
from ddtrace.internal.utils.formats import asbool
from ddtrace.internal.utils.formats import deep_getattr
from ddtrace.internal.utils.version import parse_version
+from ddtrace.llmobs import LLMObs
from ddtrace.llmobs._integrations import LangChainIntegration
from ddtrace.pin import Pin
from ddtrace.vendor import wrapt
@@ -153,6 +154,7 @@ def traced_llm_generate(langchain, pin, func, instance, args, kwargs):
span = integration.trace(
pin,
"%s.%s" % (instance.__module__, instance.__class__.__name__),
+        submit_to_llmobs=True,
interface_type="llm",
provider=llm_provider,
model=model,
@@ -194,6 +196,14 @@ def traced_llm_generate(langchain, pin, func, instance, args, kwargs):
integration.metric(span, "incr", "request.error", 1)
raise
finally:
+        if integration.is_pc_sampled_llmobs(span):
+            integration.llmobs_set_tags(
+                "llm",
+                span,
+                prompts,
+                completions,
+                error=bool(span.error),
+            )
span.finish()
integration.metric(span, "dist", "request.duration", span.duration_ns)
if integration.is_pc_sampled_log(span):
@@ -224,6 +234,7 @@ async def traced_llm_agenerate(langchain, pin, func, instance, args, kwargs):
span = integration.trace(
pin,
"%s.%s" % (instance.__module__, instance.__class__.__name__),
+        submit_to_llmobs=True,
interface_type="llm",
provider=llm_provider,
model=model,
@@ -265,6 +276,14 @@ async def traced_llm_agenerate(langchain, pin, func, instance, args, kwargs):
integration.metric(span, "incr", "request.error", 1)
raise
finally:
+        if integration.is_pc_sampled_llmobs(span):
+            integration.llmobs_set_tags(
+                "llm",
+                span,
+                prompts,
+                completions,
+                error=bool(span.error),
+            )
span.finish()
integration.metric(span, "dist", "request.duration", span.duration_ns)
if integration.is_pc_sampled_log(span):
@@ -294,6 +313,7 @@ def traced_chat_model_generate(langchain, pin, func, instance, args, kwargs):
span = integration.trace(
pin,
"%s.%s" % (instance.__module__, instance.__class__.__name__),
+        submit_to_llmobs=True,
interface_type="chat_model",
provider=llm_provider,
model=_extract_model_name(instance),
@@ -348,6 +368,14 @@ def traced_chat_model_generate(langchain, pin, func, instance, args, kwargs):
integration.metric(span, "incr", "request.error", 1)
raise
finally:
+        if integration.is_pc_sampled_llmobs(span):
+            integration.llmobs_set_tags(
+                "chat",
+                span,
+                chat_messages,
+                chat_completions,
+                error=bool(span.error),
+            )
span.finish()
integration.metric(span, "dist", "request.duration", span.duration_ns)
if integration.is_pc_sampled_log(span):
@@ -392,6 +420,7 @@ async def traced_chat_model_agenerate(langchain, pin, func, instance, args, kwargs):
span = integration.trace(
pin,
"%s.%s" % (instance.__module__, instance.__class__.__name__),
+        submit_to_llmobs=True,
interface_type="chat_model",
provider=llm_provider,
model=_extract_model_name(instance),
@@ -446,6 +475,14 @@ async def traced_chat_model_agenerate(langchain, pin, func, instance, args, kwargs):
integration.metric(span, "incr", "request.error", 1)
raise
finally:
+        if integration.is_pc_sampled_llmobs(span):
+            integration.llmobs_set_tags(
+                "chat",
+                span,
+                chat_messages,
+                chat_completions,
+                error=bool(span.error),
+            )
span.finish()
integration.metric(span, "dist", "request.duration", span.duration_ns)
if integration.is_pc_sampled_log(span):
@@ -706,6 +743,10 @@ def traced_similarity_search(langchain, pin, func, instance, args, kwargs):
def patch():
if getattr(langchain, "_datadog_patch", False):
return

+    if config._llmobs_enabled:
+        LLMObs.enable()

langchain._datadog_patch = True

Pin().onto(langchain)
@@ -808,6 +849,10 @@ def wrap_output_parser(module, parser):
def unpatch():
if not getattr(langchain, "_datadog_patch", False):
return

+    if LLMObs.enabled:
+        LLMObs.disable()

langchain._datadog_patch = False

if SHOULD_PATCH_LANGCHAIN_COMMUNITY:
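
All four traced_* wrappers above gain the same shape of change: LLM Observability tags are submitted inside the finally block, just before span.finish(), so they are recorded on both the success and the error path. A reduced sketch of that pattern; the integration methods are stand-ins mirroring the names used in the diff:

def traced_generate(integration, span, func, prompts, *args, **kwargs):
    completions = None
    try:
        completions = func(*args, **kwargs)
        return completions
    except Exception:
        span.error = 1  # stand-in for the real error bookkeeping
        raise
    finally:
        # runs on both the return and the raise path, before finish()
        if integration.is_pc_sampled_llmobs(span):
            integration.llmobs_set_tags("llm", span, prompts, completions, error=bool(span.error))
        span.finish()
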