Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 19 additions & 8 deletions src/picterra/base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import requests
from requests.adapters import HTTPAdapter
from requests.auth import AuthBase
from urllib3.util.retry import Retry

logger = logging.getLogger()
Expand Down Expand Up @@ -121,7 +122,7 @@ def multipolygon_to_polygon_feature_collection(mp):

def _check_resp_is_ok(resp: requests.Response, msg: str) -> None:
if not resp.ok:
raise APIError("%s (status %d): %s" % (msg, resp.status_code, resp.text))
raise APIError("%s (url %s, status %d): %s" % (msg, resp.url, resp.status_code, resp.text))


T = TypeVar("T")
Expand Down Expand Up @@ -197,6 +198,20 @@ class FeatureCollection(TypedDict):
features: list[Feature]


class ApiKeyAuth(AuthBase):
api_key: str

def __init__(self):
api_key = os.environ.get("PICTERRA_API_KEY", None)
if api_key is None:
raise APIError("PICTERRA_API_KEY environment variable is not defined")
self.api_key = api_key

def __call__(self, r):
r.headers['X-Api-Key'] = self.api_key
return r


class BaseAPIClient:
"""
Base class for Picterra API clients.
Expand All @@ -212,16 +227,13 @@ def __init__(
api_url: the api's base url. This is different based on the Picterra product used
and is typically defined by implementations of this client
timeout: number of seconds before the request times out
max_retries: max attempts when ecountering gateway issues or throttles; see
max_retries: max attempts when encountering gateway issues or throttles; see
retry_strategy comment below
backoff_factor: factor used nin the backoff algorithm; see retry_strategy comment below
"""
base_url = os.environ.get(
"PICTERRA_BASE_URL", "https://app.picterra.ch/"
)
api_key = os.environ.get("PICTERRA_API_KEY", None)
if not api_key:
raise APIError("PICTERRA_API_KEY environment variable is not defined")
logger.info(
"Using base_url=%s, api_url=%s; %d max retries, %d backoff and %s timeout.",
base_url,
Expand All @@ -231,9 +243,10 @@ def __init__(
timeout,
)
self.base_url = urljoin(base_url, api_url)
# Create the session with a default timeout (30 sec), that we can then
# Create the session with a default timeout (30 sec) and auth, that we can then
# override on a per-endpoint basis (will be disabled for file uploads and downloads)
self.sess = _RequestsSession(timeout=timeout)
self.sess.auth = ApiKeyAuth() # Authentication
# Retry: we set the HTTP codes for our throttle (429) plus possible gateway problems (50*),
# and for polling methods (GET), as non-idempotent ones should be addressed via idempotency
# key mechanism; given the algorithm is {<backoff_factor> * (2 **<retries-1>}, and we
Expand All @@ -248,8 +261,6 @@ def __init__(
adapter = HTTPAdapter(max_retries=retry_strategy)
self.sess.mount("https://", adapter)
self.sess.mount("http://", adapter)
# Authentication
self.sess.headers.update({"X-Api-Key": api_key})

def _full_url(self, path: str, params: dict[str, Any] | None = None):
url = urljoin(self.base_url, path)
Expand Down
17 changes: 17 additions & 0 deletions src/picterra/tracer_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,23 @@ def create_plots_analysis_report(
report_id = op_result["results"]["plots_analysis_report_id"]
return report_id

def get_plots_group(self, plots_group_id: str) -> dict:
"""
Get plots group information

Args:
plots_group_id: id of the plots group

Raises:
APIError: There was an error while getting the plots group information

Returns:
dict: see https://app.picterra.ch/public/apidocs/plots_analysis/v1/#tag/plots-groups/operation/getPlotsGroup
"""
resp = self.sess.get(self._full_url("plots_groups/%s/" % plots_group_id))
_check_resp_is_ok(resp, "Failed to get plots group")
return resp.json()

def get_plots_analysis(self, plots_analysis_id: str, plots_group_id: Optional[str] = None) -> Dict[str, Any]:
"""
Get plots analysis information
Expand Down
9 changes: 9 additions & 0 deletions tests/test_base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,15 @@ def request_callback(request, uri, response_headers):
assert len(httpretty.latest_requests()) == 1


@responses.activate
def test_headers_api_key(monkeypatch):
_add_api_response(detector_api_url("detectors/"), responses.POST, json={"id": "foobar"})
client = _client(monkeypatch)
client.create_detector()
assert len(responses.calls) == 1
assert responses.calls[0].request.headers["X-Api-Key"] == "1234"


@responses.activate
def test_headers_user_agent_version(monkeypatch):
_add_api_response(detector_api_url("detectors/"), responses.POST, json={"id": "foobar"})
Expand Down
18 changes: 18 additions & 0 deletions tests/test_tracer_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,24 @@ def test_create_plots_analysis_report(monkeypatch):
) == "a-report-id"


@responses.activate
def test_get_plots_group(monkeypatch):
client: TracerClient = _client(monkeypatch, platform="plots_analysis")
_add_api_response(
plots_analysis_api_url("plots_groups/a-plots-group/"),
responses.GET,
{
"id": "a-plots-group",
"name": "My Plots Group",
"created_at": "2025-09-29T10:04:08.143098Z",
"methodology": "Coffee - EUDR",
}
)
plots_group = client.get_plots_group("a-plots-group")
assert plots_group["id"] == "a-plots-group"
assert plots_group["name"] == "My Plots Group"


@responses.activate
def test_get_plots_analysis(monkeypatch):
client: TracerClient = _client(monkeypatch, platform="plots_analysis")
Expand Down