Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 36 additions & 3 deletions src/tether/agent/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
CommandAck,
EnrollRequest,
EnrollResponse,
FailureEventPayload,
HeartbeatPayload,
JsonDict,
)
Expand All @@ -30,11 +31,13 @@ def __init__(
cloud_url: str,
*,
device_token: str | None = None,
fleet_device_token: str | None = None,
timeout_seconds: float = DEFAULT_TIMEOUT_SECONDS,
session: Any | None = None,
) -> None:
self.cloud_url = cloud_url.rstrip("/")
self.device_token = device_token
self.fleet_device_token = fleet_device_token
self.timeout_seconds = timeout_seconds
self._session = session if session is not None else self._make_httpx_session(timeout_seconds)

Expand Down Expand Up @@ -72,6 +75,28 @@ def ack_command(self, device_id: str, command_id: str, ack: CommandAck) -> JsonD
auth=True,
)

def create_failure(
self,
device_id: str,
payload: FailureEventPayload | Mapping[str, Any],
*,
device_token: str | None = None,
) -> JsonDict:
failure_token = device_token or self.fleet_device_token
if not failure_token:
raise AgentClientError("fleet device token is required for failure uploads")
if hasattr(payload, "to_dict"):
body = payload.to_dict()
else:
body = dict(payload)
return self._request(
"POST",
f"/fleet/devices/{urllib.parse.quote(device_id, safe='')}/failures",
json_body=body,
auth=True,
auth_token=failure_token,
)

def _request(
self,
method: str,
Expand All @@ -80,13 +105,15 @@ def _request(
json_body: Mapping[str, Any] | None = None,
params: Mapping[str, Any] | None = None,
auth: bool,
auth_token: str | None = None,
) -> Any:
url = self._url(path, params)
headers = {"Content-Type": "application/json"}
if auth:
if not self.device_token:
token = auth_token or self.device_token
if not token:
raise AgentClientError("device token is required for authenticated agent calls")
headers["Authorization"] = f"Bearer {self.device_token}"
headers["Authorization"] = f"Bearer {token}"

if self._session is not None:
response = self._session.request(
Expand Down Expand Up @@ -149,6 +176,12 @@ def make_default_client(
cloud_url: str,
*,
device_token: str | None = None,
fleet_device_token: str | None = None,
timeout_seconds: float = DEFAULT_TIMEOUT_SECONDS,
) -> AgentClient:
return AgentClient(cloud_url, device_token=device_token, timeout_seconds=timeout_seconds)
return AgentClient(
cloud_url,
device_token=device_token,
fleet_device_token=fleet_device_token,
timeout_seconds=timeout_seconds,
)
4 changes: 4 additions & 0 deletions src/tether/agent/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ class AgentConfig:
cloud_url: str | None = None
workspace_id: str | None = None
device_token: str | None = None
fleet_device_id: str | None = None
fleet_device_token: str | None = None
heartbeat_interval_seconds: int = DEFAULT_HEARTBEAT_INTERVAL_SECONDS
last_heartbeat_at: str | None = None
last_command_id: str | None = None
Expand Down Expand Up @@ -46,6 +48,8 @@ def from_dict(cls, data: Mapping[str, Any]) -> "AgentConfig":
cloud_url=data.get("cloud_url"),
workspace_id=data.get("workspace_id"),
device_token=data.get("device_token"),
fleet_device_id=data.get("fleet_device_id"),
fleet_device_token=data.get("fleet_device_token"),
heartbeat_interval_seconds=int(
data.get("heartbeat_interval_seconds", DEFAULT_HEARTBEAT_INTERVAL_SECONDS)
),
Expand Down
66 changes: 66 additions & 0 deletions src/tether/agent/daemon.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ def run_once(
result = dict(result)
result["command_id"] = command_id
_ack_command(client, config, command, result)
failure_upload = _upload_failure_event(client, config, command, result)
if failure_upload is not None:
result = dict(result)
result["failure_upload"] = failure_upload
results.append(result)

return {
Expand Down Expand Up @@ -155,6 +159,45 @@ def _client_poll_commands(client: Any, config: Any) -> Any:
return _call_flexible(client.poll_commands)


def _upload_failure_event(
client: Any,
config: Any,
command: Mapping[str, Any],
result: Mapping[str, Any],
) -> dict[str, Any] | None:
try:
from tether.agent.failures import build_failure_from_command_result
except ImportError:
return None

payload = build_failure_from_command_result(command, result, config=config)
if payload is None:
return None

device_id = _config_fleet_device_id(config)
token = _config_fleet_device_token(config)
client_token = getattr(client, "fleet_device_token", None)
if device_id is None:
return {"status": "skipped", "reason": "missing_fleet_device_id"}
if token is None and client_token is None:
return {"status": "skipped", "reason": "missing_fleet_device_token"}

create_failure = getattr(client, "create_failure", None)
if not callable(create_failure):
return {"status": "skipped", "reason": "client_missing_create_failure"}

try:
response = create_failure(str(device_id), payload, device_token=token)
except TypeError:
try:
response = create_failure(str(device_id), payload)
except Exception as exc: # noqa: BLE001
return {"status": "failed", "reason": "failure_upload_failed", "error": str(exc)}
except Exception as exc: # noqa: BLE001
return {"status": "failed", "reason": "failure_upload_failed", "error": str(exc)}
return {"status": "uploaded", "response": response}


def _heartbeat_model(payload: Mapping[str, Any]) -> Any:
try:
from tether.agent.models import HeartbeatPayload
Expand All @@ -178,6 +221,29 @@ def _config_device_id(config: Any) -> str | None:
return str(value) if value is not None else None


def _config_fleet_device_id(config: Any) -> str | None:
value = getattr(config, "fleet_device_id", None)
if value is not None:
return str(value)
if _looks_like_fleet_device_token(getattr(config, "device_token", None)):
return _config_device_id(config)
return None


def _config_fleet_device_token(config: Any) -> str | None:
value = getattr(config, "fleet_device_token", None)
if value is not None:
return str(value)
device_token = getattr(config, "device_token", None)
if _looks_like_fleet_device_token(device_token):
return str(device_token)
return None


def _looks_like_fleet_device_token(value: Any) -> bool:
return isinstance(value, str) and value.startswith(("dvc_live_", "dvc_test_"))


def _call_flexible(method: Callable[..., Any], payload: Any | None = None) -> Any:
signature = inspect.signature(method)
required_positionals = [
Expand Down
Loading
Loading