Skip to content

Commit

Permalink
Gsq signal (#31998)
Browse files Browse the repository at this point in the history
* gsq signal updates

* gsq signal updates

* docstring

* fixes

* pylint
  • Loading branch information
nemanjarajic committed Sep 18, 2023
1 parent 31e3de9 commit 358a32e
Show file tree
Hide file tree
Showing 12 changed files with 501 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
PredictionDriftSignalSchema,
FeatureAttributionDriftSignalSchema,
CustomMonitoringSignalSchema,
GenerationSafetyQualitySchema,
)
from azure.ai.ml._schema.monitoring.alert_notification import AlertNotificationSchema
from azure.ai.ml._schema.core.fields import NestedField, UnionField, StringTransformedEnum
Expand All @@ -33,6 +34,7 @@ class MonitorDefinitionSchema(metaclass=PatchedSchemaMeta):
NestedField(PredictionDriftSignalSchema),
NestedField(FeatureAttributionDriftSignalSchema),
NestedField(CustomMonitoringSignalSchema),
NestedField(GenerationSafetyQualitySchema),
]
),
)
Expand Down
38 changes: 38 additions & 0 deletions sdk/ml/azure-ai-ml/azure/ai/ml/_schema/monitoring/signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
FeatureAttributionDriftMetricThresholdSchema,
ModelPerformanceMetricThresholdSchema,
CustomMonitoringMetricThresholdSchema,
GenerationSafetyQualityMetricThresholdSchema,
)


Expand Down Expand Up @@ -277,3 +278,40 @@ def make(self, data, **kwargs):

data.pop("type", None)
return CustomMonitoringSignal(**data)


class LlmRequestResponseDataSchema(metaclass=PatchedSchemaMeta):
    """Schema for one LLM request/response dataset consumed by a monitoring signal.

    Loading a payload through this schema produces an ``LlmRequestResponseData`` entity.
    """

    # The dataset reference is accepted in either of the two input schema shapes.
    input_data = UnionField(
        union_fields=[
            NestedField(DataInputSchema),
            NestedField(MLTableInputSchema),
        ]
    )
    data_column_names = fields.Dict()
    data_window_size = fields.Str()

    @post_load
    def make(self, data, **kwargs):
        """Turn the loaded field values into an ``LlmRequestResponseData`` entity."""
        # Imported inside the hook rather than at module load
        # (presumably to avoid a circular import between schema and entity modules).
        from azure.ai.ml.entities._monitoring.signals import LlmRequestResponseData

        entity = LlmRequestResponseData(**data)
        return entity


class GenerationSafetyQualitySchema(metaclass=PatchedSchemaMeta):
    """Schema for the generation safety quality monitoring signal.

    Dumping accepts only ``GenerationSafetyQualitySignal`` instances; loading
    produces a ``GenerationSafetyQualitySignal`` entity.
    """

    type = StringTransformedEnum(allowed_values=MonitorSignalType.GENERATION_SAFETY_QUALITY, required=True)
    production_data = fields.List(NestedField(LlmRequestResponseDataSchema))
    workspace_connection_id = fields.Str()
    metric_thresholds = NestedField(GenerationSafetyQualityMetricThresholdSchema)
    alert_enabled = fields.Bool()
    properties = fields.Dict()
    # The signal entity takes sampling_rate as a float that is greater than 0 and
    # at most 1; an integer field would truncate fractional rates such as 0.5 to 0.
    sampling_rate = fields.Float()

    @pre_dump
    def predump(self, data, **kwargs):
        """Reject any object that is not a ``GenerationSafetyQualitySignal`` before dumping.

        :raises ValidationError: If ``data`` is not a ``GenerationSafetyQualitySignal``.
        """
        from azure.ai.ml.entities._monitoring.signals import GenerationSafetyQualitySignal

        if not isinstance(data, GenerationSafetyQualitySignal):
            raise ValidationError("Cannot dump non-GenerationSafetyQuality object into GenerationSafetyQuality")
        return data

    @post_load
    def make(self, data, **kwargs):
        """Build a ``GenerationSafetyQualitySignal`` from the loaded field values."""
        from azure.ai.ml.entities._monitoring.signals import GenerationSafetyQualitySignal

        # The entity sets its own ``type`` in __init__ and does not accept it as a keyword.
        data.pop("type", None)
        return GenerationSafetyQualitySignal(**data)
24 changes: 24 additions & 0 deletions sdk/ml/azure-ai-ml/azure/ai/ml/_schema/monitoring/thresholds.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,3 +140,27 @@ def make(self, data, **kwargs):
from azure.ai.ml.entities._monitoring.thresholds import CustomMonitoringMetricThreshold

return CustomMonitoringMetricThreshold(**data)


class GenerationSafetyQualityMetricThresholdSchema(metaclass=PatchedSchemaMeta):  # pylint: disable=name-too-long
    """Schema for the per-metric thresholds of the generation safety quality signal.

    Each metric is a mapping whose only permitted key is that metric's
    ``aggregated_<metric>_pass_rate`` name and whose value is a number.
    """

    groundedness = fields.Dict(
        keys=StringTransformedEnum(allowed_values=["aggregated_groundedness_pass_rate"]),
        values=fields.Number(),
    )
    relevance = fields.Dict(
        keys=StringTransformedEnum(allowed_values=["aggregated_relevance_pass_rate"]),
        values=fields.Number(),
    )
    coherence = fields.Dict(
        keys=StringTransformedEnum(allowed_values=["aggregated_coherence_pass_rate"]),
        values=fields.Number(),
    )
    fluency = fields.Dict(
        keys=StringTransformedEnum(allowed_values=["aggregated_fluency_pass_rate"]),
        values=fields.Number(),
    )
    similarity = fields.Dict(
        keys=StringTransformedEnum(allowed_values=["aggregated_similarity_pass_rate"]),
        values=fields.Number(),
    )

    @post_load
    def make(self, data, **kwargs):
        """Turn the loaded thresholds into a ``GenerationSafetyQualityMonitoringMetricThreshold``."""
        from azure.ai.ml.entities._monitoring.thresholds import GenerationSafetyQualityMonitoringMetricThreshold

        threshold = GenerationSafetyQualityMonitoringMetricThreshold(**data)
        return threshold
1 change: 1 addition & 0 deletions sdk/ml/azure-ai-ml/azure/ai/ml/constants/_monitoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class MonitorSignalType(str, Enum, metaclass=CaseInsensitiveEnumMeta):
MODEL_PERFORMANCE = "model_performance"
FEATURE_ATTRIBUTION_DRIFT = "feature_attribution_drift"
CUSTOM = "custom"
GENERATION_SAFETY_QUALITY = "generation_safety_quality"


@experimental
Expand Down
6 changes: 6 additions & 0 deletions sdk/ml/azure-ai-ml/azure/ai/ml/entities/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,9 +176,11 @@
PredictionDriftSignal,
FeatureAttributionDriftSignal,
CustomMonitoringSignal,
GenerationSafetyQualitySignal,
MonitorFeatureFilter,
DataSegment,
FADProductionData,
LlmRequestResponseData,
ProductionData,
ReferenceData,
BaselineDataRange,
Expand All @@ -199,6 +201,7 @@
NumericalDriftMetrics,
DataQualityMetricsNumerical,
DataQualityMetricsCategorical,
GenerationSafetyQualityMonitoringMetricThreshold,
)

from ._workspace_hub.workspace_hub import WorkspaceHub, WorkspaceHubConfig
Expand Down Expand Up @@ -424,9 +427,11 @@
"PredictionDriftSignal",
"FeatureAttributionDriftSignal",
"CustomMonitoringSignal",
"GenerationSafetyQualitySignal",
"MonitorFeatureFilter",
"DataSegment",
"FADProductionData",
"LlmRequestResponseData",
"ProductionData",
"ReferenceData",
"BaselineDataRange",
Expand All @@ -439,6 +444,7 @@
"PredictionDriftMetricThreshold",
"FeatureAttributionDriftMetricThreshold",
"CustomMonitoringMetricThreshold",
"GenerationSafetyQualityMonitoringMetricThreshold",
"CategoricalDriftMetrics",
"NumericalDriftMetrics",
"DataQualityMetricsNumerical",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
FeatureAttributionDriftSignal,
MonitoringSignal,
PredictionDriftSignal,
GenerationSafetyQualitySignal,
)
from azure.ai.ml.entities._monitoring.target import MonitoringTarget

Expand All @@ -45,7 +46,8 @@ class MonitorDefinition(RestTranslatableMixin):
:paramtype monitoring_signals: Optional[Dict[str, Union[~azure.ai.ml.entities.DataDriftSignal
, ~azure.ai.ml.entities.DataQualitySignal, ~azure.ai.ml.entities.PredictionDriftSignal
, ~azure.ai.ml.entities.FeatureAttributionDriftSignal
, ~azure.ai.ml.entities.CustomMonitoringSignal]]]
, ~azure.ai.ml.entities.CustomMonitoringSignal
, ~azure.ai.ml.entities.GenerationSafetyQualitySignal]]]
:keyword alert_notification: The alert configuration for the monitor.
:paramtype alert_notification: Optional[Union[Literal['azmonitoring'], ~azure.ai.ml.entities.AlertNotification]]
Expand All @@ -72,6 +74,7 @@ def __init__(
PredictionDriftSignal,
FeatureAttributionDriftSignal,
CustomMonitoringSignal,
GenerationSafetyQualitySignal,
],
] = None,
alert_notification: Optional[Union[Literal[AZMONITORING], AlertNotification]] = None,
Expand Down
120 changes: 119 additions & 1 deletion sdk/ml/azure-ai-ml/azure/ai/ml/entities/_monitoring/signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

# pylint: disable=protected-access
# pylint: disable=protected-access, too-many-lines

import datetime
from typing import Dict, List, Optional, Union
Expand All @@ -18,6 +18,9 @@
from azure.ai.ml._restclient.v2023_06_01_preview.models import (
DataQualityMonitoringSignal as RestMonitoringDataQualitySignal,
)
from azure.ai.ml._restclient.v2023_06_01_preview.models import (
GenerationSafetyQualityMonitoringSignal as RestGenerationSafetyQualityMonitoringSignal,
)
from azure.ai.ml._restclient.v2023_06_01_preview.models import (
FeatureAttributionDriftMonitoringSignal as RestFeatureAttributionDriftMonitoringSignal,
)
Expand Down Expand Up @@ -63,6 +66,7 @@
MetricThreshold,
ModelPerformanceMetricThreshold,
PredictionDriftMetricThreshold,
GenerationSafetyQualityMonitoringMetricThreshold,
)
from azure.ai.ml.entities._job._input_output_helpers import (
to_rest_dataset_literal_inputs,
Expand Down Expand Up @@ -347,6 +351,8 @@ def _from_rest_object( # pylint: disable=too-many-return-statements
return FeatureAttributionDriftSignal._from_rest_object(obj)
if obj.signal_type == MonitoringSignalType.CUSTOM:
return CustomMonitoringSignal._from_rest_object(obj)
if obj.signal_type == MonitoringSignalType.GENERATION_SAFETY_QUALITY:
return GenerationSafetyQualitySignal._from_rest_object(obj)

return None

Expand Down Expand Up @@ -958,6 +964,118 @@ def _from_rest_object(cls, obj: RestCustomMonitoringSignal) -> "CustomMonitoring
)


@experimental
class LlmRequestResponseData(RestTranslatableMixin):
    """LLM request/response data used as input to a monitoring signal.

    :keyword input_data: Input data used by the monitor.
    :paramtype input_data: ~azure.ai.ml.entities.Input
    :keyword data_column_names: The names of columns in the input data.
    :paramtype data_column_names: Optional[Dict[str, str]]
    :keyword data_window_size: The window a single monitor looks back over the target
        data. When read back from the service it is an ISO 8601 duration string.
        If omitted, the default supplied at serialization time is used.
    :paramtype data_window_size: Optional[str]
    """

    def __init__(
        self,
        *,
        input_data: Input,
        data_column_names: Optional[Dict[str, str]] = None,
        data_window_size: Optional[str] = None,
    ):
        self.input_data = input_data
        self.data_column_names = data_column_names
        self.data_window_size = data_window_size

    def _to_rest_object(self, **kwargs) -> RestMonitoringInputData:
        """Convert this entity to its REST representation.

        :keyword default: Window size to use when ``data_window_size`` is not set.
        :return: The REST monitoring input data built via ``TrailingInputData``.
        """
        # Resolve the fallback into a local so serialization does not mutate the
        # entity; otherwise the first default used would stick to the instance.
        window_size = self.data_window_size
        if window_size is None:
            window_size = kwargs.get("default")
        return TrailingInputData(
            target_columns=self.data_column_names,
            job_type=self.input_data.type,
            uri=self.input_data.path,
            window_size=window_size,
            # The offset is set to the same value as the window size.
            window_offset=window_size,
        )._to_rest_object()

    @classmethod
    def _from_rest_object(cls, obj: RestMonitoringInputData) -> "LlmRequestResponseData":
        """Build the entity from its REST representation.

        :param obj: The REST monitoring input data.
        :return: The corresponding ``LlmRequestResponseData``.
        """
        return cls(
            input_data=Input(
                path=obj.uri,
                type=obj.job_input_type,
            ),
            data_column_names=obj.columns,
            # The REST window size is rendered back into an ISO 8601 duration string.
            data_window_size=isodate.duration_isoformat(obj.window_size),
        )


@experimental
class GenerationSafetyQualitySignal(RestTranslatableMixin):
    """Generation Safety Quality monitoring signal.

    :ivar type: The type of the signal. Set to "generation_safety_quality" for this class.
    :vartype type: str
    :keyword production_data: A list of input datasets for monitoring.
    :paramtype production_data: List[~azure.ai.ml.entities.LlmRequestResponseData]
    :keyword metric_thresholds: Metrics to calculate and their associated thresholds.
    :paramtype metric_thresholds: ~azure.ai.ml.entities.GenerationSafetyQualityMonitoringMetricThreshold
    :keyword alert_enabled: Whether or not to enable alerts for the signal. Defaults to True.
    :paramtype alert_enabled: bool
    :keyword workspace_connection_id: Gets or sets the workspace connection ID used to connect to the
        content generation endpoint.
    :paramtype workspace_connection_id: str
    :keyword properties: The properties of the signal
    :paramtype properties: Optional[Dict[str, str]]
    :keyword sampling_rate: The sample rate of the target data, should be greater
        than 0 and at most 1.
    :paramtype sampling_rate: Optional[float]
    """

    def __init__(
        self,
        *,
        production_data: List[LlmRequestResponseData],
        workspace_connection_id: str,
        metric_thresholds: GenerationSafetyQualityMonitoringMetricThreshold,
        alert_enabled: bool = True,
        properties: Optional[Dict[str, str]] = None,
        sampling_rate: Optional[float] = None,
    ):
        self.type = MonitorSignalType.GENERATION_SAFETY_QUALITY
        self.production_data = production_data
        self.workspace_connection_id = workspace_connection_id
        self.metric_thresholds = metric_thresholds
        self.alert_enabled = alert_enabled
        self.properties = properties
        self.sampling_rate = sampling_rate

    def _to_rest_object(self, **kwargs) -> RestGenerationSafetyQualityMonitoringSignal:
        """Convert this signal to its REST representation.

        :keyword default_data_window_size: Window size forwarded to each production
            dataset as the fallback for an unset ``data_window_size``.
        :return: The REST generation safety quality monitoring signal.
        """
        data_window_size = kwargs.get("default_data_window_size")
        return RestGenerationSafetyQualityMonitoringSignal(
            production_data=[data._to_rest_object(default=data_window_size) for data in self.production_data],
            workspace_connection_id=self.workspace_connection_id,
            metric_thresholds=self.metric_thresholds._to_rest_object(),
            mode=MonitoringNotificationMode.ENABLED if self.alert_enabled else MonitoringNotificationMode.DISABLED,
            properties=self.properties,
            sampling_rate=self.sampling_rate,
        )

    @classmethod
    def _from_rest_object(cls, obj: RestGenerationSafetyQualityMonitoringSignal) -> "GenerationSafetyQualitySignal":
        """Build the signal from its REST representation.

        :param obj: The REST generation safety quality monitoring signal.
        :return: The corresponding ``GenerationSafetyQualitySignal``.
        """
        # alert_enabled is a bool on the entity: alerts are on for any mode that is
        # set and is not DISABLED. Storing the mode enum value here would leak a
        # non-bool into a field that is typed and documented as bool.
        alert_enabled = bool(obj.mode) and obj.mode != MonitoringNotificationMode.DISABLED
        return cls(
            production_data=[LlmRequestResponseData._from_rest_object(data) for data in obj.production_data],
            workspace_connection_id=obj.workspace_connection_id,
            metric_thresholds=GenerationSafetyQualityMonitoringMetricThreshold._from_rest_object(obj.metric_thresholds),
            alert_enabled=alert_enabled,
            properties=obj.properties,
            sampling_rate=obj.sampling_rate,
        )


def _from_rest_features(
obj: RestMonitoringFeatureFilterBase,
) -> Optional[Union[List[str], MonitorFeatureFilter, Literal[ALL_FEATURES]]]:
Expand Down
Loading

0 comments on commit 358a32e

Please sign in to comment.