Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions app/data/model/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ class TableRecord:
id: str
original_data: dict[str, Any]
pgc: int | None
triage_status: str
crossmatch_candidates: list[int]


@dataclass
Expand Down
5 changes: 4 additions & 1 deletion app/data/repositories/layer0/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,11 @@ def fetch_records(
order_direction: str = "asc",
has_pgc: bool | None = None,
pgc_value: int | None = None,
triage_status: str | None = None,
) -> list[model.TableRecord]:
return self.table_repo.fetch_records(table_name, limit, row_offset, order_direction, has_pgc, pgc_value)
return self.table_repo.fetch_records(
table_name, limit, row_offset, order_direction, has_pgc, pgc_value, triage_status
)

def fetch_metadata(self, table_name: str) -> model.Layer0TableMeta:
return self.table_repo.fetch_metadata(table_name)
Expand Down
86 changes: 75 additions & 11 deletions app/data/repositories/layer0/tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,17 @@ def _row_to_serializable_dict(row: Any, drop: list[str]) -> dict[str, Any]:
return out


def _crossmatch_metadata_to_candidates(metadata: dict[str, Any] | None) -> list[int]:
if metadata is None:
return []
candidates: list[int] = []
if "pgc" in metadata and metadata["pgc"] is not None:
candidates.append(int(metadata["pgc"]))
if "possible_matches" in metadata and metadata["possible_matches"] is not None:
candidates.extend(int(p) for p in metadata["possible_matches"])
return candidates


@dataclass
class QuantityMock:
values: pandas.Series
Expand Down Expand Up @@ -289,6 +300,7 @@ def fetch_records(
order_direction: str = "asc",
has_pgc: bool | None = None,
pgc_value: int | None = None,
triage_status: str | None = None,
) -> list[model.TableRecord]:
where_parts: list[str] = []
if has_pgc is True:
Expand All @@ -301,39 +313,91 @@ def fetch_records(
params: list[Any] = []
if pgc_value is not None:
params.append(pgc_value)
params.append(limit)
params.append(row_offset)

id_col = sql.Identifier(INTERNAL_ID_COLUMN_NAME)
parts: list[sql.Composable] = [
sql.SQL("SELECT r.*, o.pgc FROM {}.{} AS r JOIN layer0.records AS o ON r.{} = o.id").format(
sql.Identifier(RAWDATA_SCHEMA),
sql.Identifier(table_name),
id_col,
),
]
direction = sql.SQL(order_direction if order_direction in ("asc", "desc") else "asc")

if triage_status == "unprocessed":
where_parts.append("NOT EXISTS (SELECT 1 FROM layer0.crossmatch c WHERE c.record_id = o.id)")
params.append(limit)
params.append(row_offset)
parts: list[sql.Composable] = [
sql.SQL(
"SELECT r.*, o.pgc "
"FROM {}.{} AS r "
"JOIN layer0.records AS o ON r.{} = o.id "
"AND o.table_id = (SELECT id FROM layer0.tables WHERE table_name = %s)"
).format(
sql.Identifier(RAWDATA_SCHEMA),
sql.Identifier(table_name),
id_col,
),
]
params.insert(0, table_name)
elif triage_status in ("pending", "resolved"):
where_parts.append("c.triage_status = %s")
params.append(triage_status)
params.append(limit)
params.append(row_offset)
parts = [
sql.SQL(
"SELECT r.*, o.pgc, c.triage_status, c.metadata AS crossmatch_metadata "
"FROM {}.{} AS r "
"JOIN layer0.records AS o ON r.{} = o.id "
"AND o.table_id = (SELECT id FROM layer0.tables WHERE table_name = %s) "
"JOIN layer0.crossmatch AS c ON o.id = c.record_id"
).format(
sql.Identifier(RAWDATA_SCHEMA),
sql.Identifier(table_name),
id_col,
),
]
params.insert(0, table_name)
else:
params.append(limit)
params.append(row_offset)
parts = [
sql.SQL(
"SELECT r.*, o.pgc, c.triage_status, c.metadata AS crossmatch_metadata "
"FROM {}.{} AS r "
"JOIN layer0.records AS o ON r.{} = o.id "
"AND o.table_id = (SELECT id FROM layer0.tables WHERE table_name = %s) "
"LEFT JOIN layer0.crossmatch AS c ON o.id = c.record_id"
).format(
sql.Identifier(RAWDATA_SCHEMA),
sql.Identifier(table_name),
id_col,
),
]
params.insert(0, table_name)

if where_parts:
parts.append(sql.SQL(" WHERE "))
parts.append(sql.SQL(" AND ").join([sql.SQL(w) for w in where_parts]))
parts.append(sql.SQL(" ORDER BY r.{} ").format(id_col))
parts.append(sql.SQL(order_direction if order_direction in ("asc", "desc") else "asc"))
parts.append(direction)
parts.append(sql.SQL(" LIMIT %s OFFSET %s"))

rows = self._storage.query(sql.Composed(parts), params=params)
id_col_name = INTERNAL_ID_COLUMN_NAME
drop_labels = [id_col_name, "pgc"]
drop_labels = [id_col_name, "pgc", "triage_status", "crossmatch_metadata"]
result: list[model.TableRecord] = []
for row in rows:
record_id = str(row[id_col_name])
original_data = _row_to_serializable_dict(row, drop=drop_labels)
pgc_val = row.get("pgc")
if pgc_val is not None and (pandas.isna(pgc_val) or (isinstance(pgc_val, float) and np.isnan(pgc_val))):
pgc_val = None
raw_triage = row.get("triage_status")
triage_val = raw_triage if raw_triage is not None else "unprocessed"
candidates = _crossmatch_metadata_to_candidates(row.get("crossmatch_metadata"))
result.append(
model.TableRecord(
id=record_id,
original_data=original_data,
pgc=int(pgc_val) if pgc_val is not None else None,
triage_status=triage_val,
crossmatch_candidates=candidates,
)
)
return result
Expand Down
6 changes: 6 additions & 0 deletions app/domain/adminapi/table_upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ def get_records(self, r: adminapi.GetRecordsRequest) -> adminapi.GetRecordsRespo
elif r.upload_status == adminapi.UploadStatus.PENDING:
has_pgc = False

triage_filter = r.triage_status.value if r.triage_status is not None else None
errgr = concurrency.ErrorGroup()
records_task = errgr.run(
self.layer0_repo.fetch_records,
Expand All @@ -225,6 +226,7 @@ def get_records(self, r: adminapi.GetRecordsRequest) -> adminapi.GetRecordsRespo
order_direction="asc",
has_pgc=has_pgc,
pgc_value=r.pgc,
triage_status=triage_filter,
)
schema_task = errgr.run(
self.common_repo.get_schema,
Expand All @@ -241,6 +243,10 @@ def get_records(self, r: adminapi.GetRecordsRequest) -> adminapi.GetRecordsRespo
id=rec.id,
original_data=rec.original_data,
pgc=rec.pgc,
crossmatch=adminapi.RecordCrossmatchInfo(
triage_status=adminapi.CrossmatchTriageStatus(rec.triage_status),
candidates=[adminapi.RecordCrossmatchCandidate(pgc=p) for p in rec.crossmatch_candidates],
),
)
for rec in raw_records
]
Expand Down
32 changes: 28 additions & 4 deletions app/presentation/adminapi/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,24 +212,48 @@ class UploadStatus(enum.Enum):
PENDING = "pending"


class CrossmatchTriageStatus(enum.Enum):
UNPROCESSED = "unprocessed"
PENDING = "pending"
RESOLVED = "resolved"


class GetRecordsRequest(pydantic.BaseModel):
table_name: str
page: int = 0
page_size: int = 25
upload_status: UploadStatus | None = None
pgc: int | None = None
upload_status: UploadStatus | None = None
triage_status: CrossmatchTriageStatus | None = None

@pydantic.model_validator(mode="after")
def check_pending_and_pgc_exclusive(self) -> "GetRecordsRequest":
if self.upload_status == UploadStatus.PENDING and self.pgc is not None:
raise ValueError("upload_status pending and pgc filter cannot be specified at the same time")
def check_exclusive_pgc_filter(self) -> "GetRecordsRequest":
if self.pgc is not None:
if any([self.upload_status is not None, self.triage_status is not None]):
raise ValueError("When pgc filter is specified, no other filters are allowed.")
return self

@pydantic.model_validator(mode="after")
def check_upload_status_and_triage_status(self) -> "GetRecordsRequest":
if self.upload_status == UploadStatus.UPLOADED and self.triage_status is not None:
raise ValueError("When upload_status is UPLOADED, triage_status is not allowed.")
return self


class RecordCrossmatchCandidate(pydantic.BaseModel):
pgc: int


class RecordCrossmatchInfo(pydantic.BaseModel):
triage_status: CrossmatchTriageStatus
candidates: list[RecordCrossmatchCandidate]


class Record(pydantic.BaseModel):
id: str
original_data: dict[str, Any]
pgc: int | None
crossmatch: RecordCrossmatchInfo


class DescriptionSchema(pydantic.BaseModel):
Expand Down
41 changes: 37 additions & 4 deletions tests/unit/domain/table_upload_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,8 +266,16 @@ def setUp(self) -> None:

def test_get_records_returns_records_with_pgc(self) -> None:
self.manager.layer0_repo.fetch_records.return_value = [
model.TableRecord(id="rec1", original_data={"name": "A"}, pgc=1001),
model.TableRecord(id="rec2", original_data={"name": "B"}, pgc=1002),
model.TableRecord(
id="rec1", original_data={"name": "A"}, pgc=1001, triage_status="resolved", crossmatch_candidates=[1001]
),
model.TableRecord(
id="rec2",
original_data={"name": "B"},
pgc=1002,
triage_status="pending",
crossmatch_candidates=[],
),
]

request = presentation.GetRecordsRequest(table_name="t", page=0, page_size=25)
Expand All @@ -277,9 +285,16 @@ def test_get_records_returns_records_with_pgc(self) -> None:
self.assertEqual(response.records[0].id, "rec1")
self.assertEqual(response.records[0].original_data, {"name": "A"})
self.assertEqual(response.records[0].pgc, 1001)
self.assertEqual(response.records[0].crossmatch.triage_status, presentation.CrossmatchTriageStatus.RESOLVED)
self.assertEqual(
response.records[0].crossmatch.candidates,
[presentation.RecordCrossmatchCandidate(pgc=1001)],
)
self.assertEqual(response.records[1].id, "rec2")
self.assertEqual(response.records[1].original_data, {"name": "B"})
self.assertEqual(response.records[1].pgc, 1002)
self.assertEqual(response.records[1].crossmatch.triage_status, presentation.CrossmatchTriageStatus.PENDING)
self.assertEqual(response.records[1].crossmatch.candidates, [])

def test_get_records_passes_filters_to_fetch_records(self) -> None:
self.manager.layer0_repo.fetch_records.return_value = []
Expand All @@ -302,6 +317,12 @@ def test_get_records_passes_filters_to_fetch_records(self) -> None:
self.assertIsNone(call_kw["has_pgc"])
self.assertEqual(call_kw["pgc_value"], 42)

self.manager.get_records(
presentation.GetRecordsRequest(table_name="t", triage_status=presentation.CrossmatchTriageStatus.PENDING)
)
call_kw = self.manager.layer0_repo.fetch_records.call_args[1]
self.assertEqual(call_kw["triage_status"], "pending")

def test_get_records_pagination(self) -> None:
self.manager.layer0_repo.fetch_records.return_value = []

Expand All @@ -312,8 +333,20 @@ def test_get_records_pagination(self) -> None:

def test_get_records_pgc_none_when_missing_or_nan(self) -> None:
self.manager.layer0_repo.fetch_records.return_value = [
model.TableRecord(id="rec1", original_data={"name": "A"}, pgc=1001),
model.TableRecord(id="rec2", original_data={"name": "B"}, pgc=None),
model.TableRecord(
id="rec1",
original_data={"name": "A"},
pgc=1001,
triage_status="resolved",
crossmatch_candidates=[],
),
model.TableRecord(
id="rec2",
original_data={"name": "B"},
pgc=None,
triage_status="unprocessed",
crossmatch_candidates=[],
),
]

response = self.manager.get_records(presentation.GetRecordsRequest(table_name="t", page=0, page_size=25))
Expand Down
Loading