Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions src/xpk/commands/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@
from ..core.cluster_private import authorize_private_cluster_access_if_necessary
from ..core.commands import run_command_for_value, run_command_with_updates
from ..core.config import VERTEX_TENSORBOARD_FEATURE_FLAG
from ..core.telemetry import MetricsCollector, MetricsEventMetadataKey
from ..core.capacity import get_capacity_type
from ..core.gcloud_context import (
add_zone_and_project,
get_gke_control_plane_version,
Expand Down Expand Up @@ -263,6 +265,8 @@ def cluster_create(args) -> None:
xpk_print(f'Starting cluster create for cluster {args.cluster}:', flush=True)
add_zone_and_project(args)

_log_cluster_create_telemetry(args)

if system.device_type in cluster_gcluster.supported_device_types:
xpk_print(
'Creating the cluster using Cluster Toolkit. Machine Type:'
Expand Down Expand Up @@ -1319,3 +1323,18 @@ def prepare_gpus(system: SystemCharacteristics):
err_code = disable_mglru_on_cluster()
if err_code > 0:
xpk_exit(err_code)


def _log_cluster_create_telemetry(args) -> None:
  """Record a `cluster_create` custom telemetry event for this invocation.

  Does nothing unless `FeatureFlags.TELEMETRY_ENABLED` is truthy.

  Args:
    args: parsed command arguments; `zone`, `tpu_type` and `device_type` are
      read here, and the whole object is handed to `get_capacity_type`.
  """
  if not FeatureFlags.TELEMETRY_ENABLED:
    return

  # Only the capacity type is reported; the second element of the tuple
  # returned by `get_capacity_type` is intentionally discarded.
  capacity_type, _ = get_capacity_type(args)
  event_metadata = {
      MetricsEventMetadataKey.ZONE: args.zone,
      # `tpu_type` wins when it is set; otherwise fall back to `device_type`.
      MetricsEventMetadataKey.SYSTEM_CHARACTERISTICS: (
          args.tpu_type or args.device_type
      ),
      MetricsEventMetadataKey.PROVISIONING_MODE: capacity_type.value,
  }
  MetricsCollector.log_custom(name='cluster_create', metadata=event_metadata)
97 changes: 96 additions & 1 deletion src/xpk/commands/cluster_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,16 @@
limitations under the License.
"""

import json
from argparse import Namespace
from dataclasses import dataclass
from typing import Any
from unittest.mock import MagicMock, patch
import pytest

from xpk.commands.cluster import _install_kueue, _validate_cluster_create_args, run_gke_cluster_create_command
from xpk.core.telemetry import MetricsCollector
from xpk.commands.cluster import _install_kueue, _validate_cluster_create_args, run_gke_cluster_create_command, _log_cluster_create_telemetry
from xpk.core.capacity import CapacityType
from xpk.core.system_characteristics import SystemCharacteristics, UserFacingNameToSystemCharacteristics
from xpk.core.testing.commands_tester import CommandsTester
from xpk.utils.feature_flags import FeatureFlags
Expand Down Expand Up @@ -65,6 +68,10 @@ def construct_args(**kwargs: Any) -> Namespace:
project='project',
zone='us-central1-a',
reservation='',
on_demand=False,
tpu_type=None,
device_type=None,
spot=False,
default_pool_cpu_machine_type='test-machine-type',
cluster='test-cluster',
default_pool_cpu_num_nodes='100',
Expand Down Expand Up @@ -247,3 +254,91 @@ def test_run_gke_cluster_create_command_with_gke_version_has_no_autoupgrade_flag
mocks.commands_tester.assert_command_run(
'clusters create', ' --no-enable-autoupgrade'
)


def test_log_cluster_create_telemetry_does_not_log_when_feature_flag_is_disabled():
  """No event is collected when the telemetry feature flag is off."""
  # Scope the flag override with `patch.object` so the previous value is
  # restored on exit instead of leaking into whichever test runs next.
  with patch.object(FeatureFlags, 'TELEMETRY_ENABLED', False):
    _log_cluster_create_telemetry(construct_args())
    events = json.loads(MetricsCollector.flush())['log_event']
  assert len(events) == 0


def test_log_cluster_create_telemetry_logs_correct_event_when_tpu_type_is_provided(
    mocker: MagicMock,
):
  """A `cluster_create` event carries the TPU type when `tpu_type` is set.

  Checks the event name plus the zone, system-characteristics and
  provisioning-mode metadata entries of the first flushed event.
  """
  # Override the flag through the mocker so it is not left mutated on the
  # class after this test (assumes `mocker` is the pytest-mock fixture, which
  # undoes its patches at teardown).
  mocker.patch.object(FeatureFlags, 'TELEMETRY_ENABLED', True)
  mocker.patch(
      'xpk.commands.cluster.get_capacity_type',
      return_value=(CapacityType.SPOT, 0),
  )
  # Pass `tpu_type` so the test exercises the case its name describes.
  _log_cluster_create_telemetry(construct_args(tpu_type='test-tpu-type'))
  event = json.loads(MetricsCollector.flush())['log_event'][0]
  payload = json.loads(event['source_extension_json'])
  event_metadata = payload['event_metadata']
  assert payload['event_name'] == 'cluster_create'
  assert (
      _get_event_metadata_value_by_key(
          event_metadata,
          'XPK_ZONE',
      )
      == 'us-central1-a'
  )
  assert (
      _get_event_metadata_value_by_key(
          event_metadata,
          'XPK_SYSTEM_CHARACTERISTICS',
      )
      == 'test-tpu-type'
  )
  assert (
      _get_event_metadata_value_by_key(
          event_metadata,
          'XPK_PROVISIONING_MODE',
      )
      == 'spot'
  )


def test_log_cluster_create_telemetry_logs_correct_event_when_device_type_is_provided(
    mocker: MagicMock,
):
  """A `cluster_create` event falls back to `device_type` without a TPU type.

  Checks the event name plus the zone, system-characteristics and
  provisioning-mode metadata entries of the first flushed event.
  """
  # Override the flag through the mocker so it is not left mutated on the
  # class after this test (assumes `mocker` is the pytest-mock fixture, which
  # undoes its patches at teardown).
  mocker.patch.object(FeatureFlags, 'TELEMETRY_ENABLED', True)
  mocker.patch(
      'xpk.commands.cluster.get_capacity_type',
      return_value=(CapacityType.SPOT, 0),
  )
  # Pass only `device_type` (tpu_type stays None) so the fallback branch named
  # by this test is the one that runs.
  _log_cluster_create_telemetry(construct_args(device_type='test-device-type'))
  event = json.loads(MetricsCollector.flush())['log_event'][0]
  payload = json.loads(event['source_extension_json'])
  event_metadata = payload['event_metadata']
  assert payload['event_name'] == 'cluster_create'
  assert (
      _get_event_metadata_value_by_key(
          event_metadata,
          'XPK_ZONE',
      )
      == 'us-central1-a'
  )
  assert (
      _get_event_metadata_value_by_key(
          event_metadata,
          'XPK_SYSTEM_CHARACTERISTICS',
      )
      == 'test-device-type'
  )
  assert (
      _get_event_metadata_value_by_key(
          event_metadata,
          'XPK_PROVISIONING_MODE',
      )
      == 'spot'
  )


def _get_event_metadata_value_by_key(
event_metadata: list[dict[str, str]], key: str
) -> str | None:
return next(
(meta['value'] for meta in event_metadata if meta['key'] == key),
None,
)
Loading