2 changes: 1 addition & 1 deletion openeo/extra/job_management/__init__.py
@@ -538,7 +538,7 @@ def _job_update_loop(
                self._launch_job(start_job, df=not_started, i=i, backend_name=backend_name, stats=stats)
                stats["job launch"] += 1
 
-                job_db.persist(not_started.loc[i : i + 1])
+                job_db.persist(not_started.loc[[i]])
                stats["job_db persist"] += 1
                total_added += 1
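Note: this one-line fix leans on pandas label-based indexing semantics. `not_started.loc[i : i + 1]` is a label slice: on a default integer `RangeIndex` it is inclusive on both ends (so it can select rows `i` and `i + 1`), and once the index holds string STAC item ids, `i + 1` is not even a valid expression. `not_started.loc[[i]]` selects exactly the single row labeled `i`, as a one-row DataFrame, regardless of index type. A minimal sketch with illustrative data:

    import pandas as pd

    not_started = pd.DataFrame({"status": ["not_started"] * 3}, index=["job-a", "job-b", "job-c"])
    i = "job-b"

    # not_started.loc[i : i + 1] would raise TypeError here ("job-b" + 1 is invalid),
    # and on an integer RangeIndex the label slice is inclusive (rows i AND i + 1).
    row = not_started.loc[[i]]  # one-row DataFrame with exactly the row labeled "job-b"
    print(row)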
72 changes: 54 additions & 18 deletions openeo/extra/job_management/stac_job_db.py
@@ -51,18 +51,35 @@ def __init__(
         self.base_url = stac_root_url
         self.bulk_size = 500
 
+        # TODO: is the "item_id" column a feature we can/should deprecate?
+        self._ensure_item_id_column = True
+
     def exists(self) -> bool:
         return any(c.id == self.collection_id for c in self.client.get_collections())
 
     def _normalize_df(self, df: pd.DataFrame) -> pd.DataFrame:
         """
         Normalize the given dataframe to be compatible with :py:class:`MultiBackendJobManager`
-        by adding the default columns and setting the index.
+        by adding the default columns and using the STAC item ids as index values.
         """
         df = MultiBackendJobManager._normalize_df(df)
-        # If the user doesn't specify the item_id column, we will use the index.
-        if "item_id" not in df.columns:
-            df = df.reset_index(names=["item_id"])
+
+        if isinstance(df.index, pd.RangeIndex):
+            # Index is supposed to contain meaningful STAC item ids,
+            # not some default RangeIndex.
+            if "item_id" in df.columns:
+                # Leverage legacy usage pattern with an "item_id" column: copy that over as index.
+                df.index = df["item_id"]
+            elif df.shape[0] > 0:
+                _log.warning(
+                    "STAC API oriented dataframe normalization: no meaningful index. This might cause consistency issues."
+                )
+
+        if self._ensure_item_id_column and "item_id" not in df.columns:
+            df["item_id"] = df.index
+
+        # Make sure the index (of item ids) is string-typed, to play well with (py)STAC schemas
+        df.index = df.index.astype(str)
         return df
 
     def initialize_from_df(self, df: pd.DataFrame, *, on_exists: str = "error"):
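Note: the reworked `_normalize_df` treats the dataframe index, not the `item_id` column, as the canonical source of STAC item ids. A simplified standalone sketch of just the index handling (omitting the `MultiBackendJobManager` defaults and the warning branch), with illustrative data:

    import pandas as pd

    def normalize_index(df: pd.DataFrame) -> pd.DataFrame:
        # A default RangeIndex carries no meaningful item ids;
        # fall back to a legacy "item_id" column when available.
        if isinstance(df.index, pd.RangeIndex) and "item_id" in df.columns:
            df.index = df["item_id"]
        # Keep an "item_id" column mirroring the index (legacy feature).
        if "item_id" not in df.columns:
            df["item_id"] = df.index
        # STAC item ids are strings, so coerce the index accordingly.
        df.index = df.index.astype(str)
        return df

    df = pd.DataFrame({"item_id": ["item-1", "item-2"], "status": ["queued", "queued"]})
    print(normalize_index(df).index.tolist())  # ['item-1', 'item-2']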
@@ -128,7 +145,7 @@ def item_from(self, series: pd.Series) -> pystac.Item:
         :return: pystac.Item
         """
         series_dict = series.to_dict()
-        item_id = series_dict.pop("item_id")
+        item_id = str(series.name)
         item_dict = {}
         item_dict.setdefault("stac_version", pystac.get_stac_version())
         item_dict.setdefault("type", "Feature")
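Note: `series.name` works here because a row obtained from `df.iterrows()` (or `df.apply(..., axis=1)`) is a `pd.Series` whose `name` attribute is that row's index label, i.e. the STAC item id after `_normalize_df`. For illustration:

    import pandas as pd

    df = pd.DataFrame({"status": ["queued"]}, index=["item-1"])
    for _, row in df.iterrows():
        print(row.name)  # "item-1": the index label doubles as the STAC item id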
Expand Down Expand Up @@ -165,6 +182,13 @@ def count_by_status(self, statuses: Iterable[str] = ()) -> dict:
else:
return items["status"].value_counts().to_dict()

def _search_result_to_df(self, search_result: pystac_client.ItemSearch) -> pd.DataFrame:
series = [self.series_from(item) for item in search_result.items()]
# Note: `series_from` sets the item id as the series "name",
# which ends up in the index of the dataframe
df = pd.DataFrame(series)
return df

def get_by_status(self, statuses: Iterable[str], max: Optional[int] = None) -> pd.DataFrame:
if isinstance(statuses, str):
statuses = {statuses}
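Note: `pd.DataFrame` built from a list of named `pd.Series` uses each series' `name` as the row's index label, which is what `_search_result_to_df` relies on (assuming `series_from` indeed sets the item id as the series name, per the comment). For example:

    import pandas as pd

    s1 = pd.Series({"status": "running"}, name="item-1")
    s2 = pd.Series({"status": "finished"}, name="item-2")
    print(pd.DataFrame([s1, s2]).index.tolist())  # ['item-1', 'item-2']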
@@ -177,35 +201,47 @@ def get_by_status(self, statuses: Iterable[str], max: Optional[int] = None) -> pd.DataFrame:
             filter=status_filter,
             max_items=max,
         )
+        df = self._search_result_to_df(search_results)
 
-        series = [self.series_from(item) for item in search_results.items()]
-
-        df = pd.DataFrame(series).reset_index(names=["item_id"])
-        if len(series) == 0:
+        if df.shape[0] == 0:
             # TODO: What if default columns are overwritten by the user?
             df = self._normalize_df(
                 df
             )  # Even for an empty dataframe the default columns are required
         return df
 
     def persist(self, df: pd.DataFrame):
+        if df.empty:
+            _log.warning("No data to persist in STAC API job database, skipping.")
+            return
+
         if not self.exists():
             spatial_extent = pystac.SpatialExtent([[-180, -90, 180, 90]])
             temporal_extent = pystac.TemporalExtent([[None, None]])
             extent = pystac.Extent(spatial=spatial_extent, temporal=temporal_extent)
             c = pystac.Collection(id=self.collection_id, description="STAC API job database collection.", extent=extent)
             self._create_collection(c)
 
-        all_items = []
-        if not df.empty:
-
-            def handle_row(series):
-                item = self.item_from(series)
-                all_items.append(item)
-
-            df.apply(handle_row, axis=1)
+        # Merge updates with existing items (if any)
+        existing_items = self.client.search(
+            method="GET",
+            collections=[self.collection_id],
+            ids=[str(i) for i in df.index.tolist()],
+        )
+        existing_df = self._search_result_to_df(existing_items)
 
-        self._upload_items_bulk(self.collection_id, all_items)
+        if existing_df.empty:
+            df_to_persist = df
+        else:
+            # Merge data on item_id (in the index)
+            df_to_persist = existing_df
+            # TODO: better way to do update without risk for data update loss?
+            assert set(df.index).issubset(df_to_persist.index)
+            df_to_persist.update(df, overwrite=True)
+
+        items_to_persist = [self.item_from(s) for _, s in df_to_persist.iterrows()]
+        _log.info(f"Bulk upload of {len(items_to_persist)} items to STAC API collection {self.collection_id!r}")
+        self._upload_items_bulk(self.collection_id, items_to_persist)
 
     def _prepare_item(self, item: pystac.Item, collection_id: str):
         item.collection_id = collection_id
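Note: the merge in `persist` starts from the existing items and layers the incoming frame on top with `DataFrame.update`, which aligns rows and columns on their labels: with `overwrite=True` it replaces existing values, leaves rows and columns absent from the update untouched, and never writes NaN over existing data. A small sketch of that behavior with illustrative data:

    import pandas as pd

    existing = pd.DataFrame({"status": ["queued"], "backend": ["foo"]}, index=["item-1"])
    update = pd.DataFrame({"status": ["running"]}, index=["item-1"])

    existing.update(update, overwrite=True)  # in-place, index-aligned merge
    print(existing.to_dict("index"))  # {'item-1': {'status': 'running', 'backend': 'foo'}}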