Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adapt route to add proposed actions #429

Merged
merged 7 commits into from
Feb 13, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions azimuth/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,6 @@ def create_app() -> FastAPI:
from azimuth.routers.v1.model_performance.utterance_count import (
router as utterance_count_router,
)
from azimuth.routers.v1.tags import router as tags_router
from azimuth.routers.v1.top_words import router as top_words_router
from azimuth.routers.v1.utterances import router as utterances_router
from azimuth.utils.routers import require_application_ready, require_available_model
Expand All @@ -168,7 +167,6 @@ def create_app() -> FastAPI:
prefix="/dataset_splits/{dataset_split_name}/class_overlap",
dependencies=[Depends(require_application_ready)],
)
api_router.include_router(tags_router, prefix="/tags", dependencies=[])
api_router.include_router(
confidence_histogram_router,
prefix="/dataset_splits/{dataset_split_name}/confidence_histogram",
Expand Down
69 changes: 0 additions & 69 deletions azimuth/routers/v1/tags.py

This file was deleted.

44 changes: 43 additions & 1 deletion azimuth/routers/v1/utterances.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from enum import Enum
from typing import Dict, List, Optional

from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi import APIRouter, Body, Depends, HTTPException, Query
from starlette.status import HTTP_404_NOT_FOUND

from azimuth.app import (
Expand Down Expand Up @@ -46,6 +46,7 @@
ModelPrediction,
ModelSaliency,
Utterance,
UtterancePatch,
)
from azimuth.utils.dataset_operations import filter_dataset_split
from azimuth.utils.project import (
Expand Down Expand Up @@ -264,6 +265,47 @@ def get_utterances(
)


@router.post(
"",
summary="Patch utterances",
description="Patch utterances, such as updating proposed actions.",
tags=TAGS,
response_model=List[UtterancePatch],
)
def patch_utterances(
request_data: List[UtterancePatch] = Body(...),
gabegma marked this conversation as resolved.
Show resolved Hide resolved
dataset_split_manager: DatasetSplitManager = Depends(get_dataset_split_manager),
task_manager: TaskManager = Depends(get_task_manager),
) -> List[UtterancePatch]:
persistent_ids = [utt.persistent_id for utt in request_data]
gabegma marked this conversation as resolved.
Show resolved Hide resolved
try:
row_indices = dataset_split_manager.get_row_indices_from_persistent_id(persistent_ids)
except ValueError as e:
raise HTTPException(HTTP_404_NOT_FOUND, detail=f"Persistent id not found: {e}.")

request_data_complete = {}
gabegma marked this conversation as resolved.
Show resolved Hide resolved
for row_idx, utterance in zip(row_indices, request_data):
request_data_complete[row_idx] = {data_action: False for data_action in ALL_DATA_ACTIONS}
if utterance.data_action != DataAction.no_action:
request_data_complete[row_idx][utterance.data_action] = True

dataset_split_manager.add_tags(request_data_complete)

task_manager.clear_worker_cache()
updated_tags = dataset_split_manager.get_tags(row_indices)

return [
UtterancePatch(
persistent_id=persistent_id,
data_action=next(
iter(tag for tag, value in tags.items() if tag in ALL_DATA_ACTIONS and value),
gabegma marked this conversation as resolved.
Show resolved Hide resolved
DataAction.no_action,
),
)
for persistent_id, tags in zip(persistent_ids, updated_tags.values())
]


@router.get(
"/{index}/perturbed_utterances",
summary="Get a perturbed utterances for a single utterance.",
Expand Down
40 changes: 1 addition & 39 deletions azimuth/types/tag.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,7 @@
from enum import Enum
from typing import Any, Dict, List

from pydantic import Field

from azimuth.types import AliasModel, DatasetSplitName, ModuleResponse


class DataActionMapping(AliasModel):
relabel: bool = Field(..., title="Relabel")
augment_with_similar: bool = Field(..., title="Augment with Similar")
define_new_class: bool = Field(..., title="Define New Class")
merge_classes: bool = Field(..., title="Merge Two Classes")
remove: bool = Field(..., title="Remove")
investigate: bool = Field(..., title="Investigate")

from azimuth.types import ModuleResponse

Tag = str

Expand Down Expand Up @@ -134,29 +122,3 @@ class SmartTagFamily(str, Enum):
class TaggingResponse(ModuleResponse):
tags: Dict[Tag, bool]
adds: Dict[str, Any]


class DataActionResponse(AliasModel):
data_actions: List[DataActionMapping] = Field(..., title="Data action tags")


class PostDataActionRequest(AliasModel):
dataset_split_name: DatasetSplitName = Field(..., title="Dataset Split Name")
data_actions: Dict[int, Dict[Tag, bool]] = Field(..., title="Data action tags")

class Config:
schema_extra = {
"example": {
"dataset_split_name": "eval",
"data_actions": {
1: {
"relabel": True,
"augment_with_similar": False,
"define_new_class": False,
"merge_classes": False,
"remove": False,
"investigate": False,
}
},
}
}
9 changes: 6 additions & 3 deletions azimuth/types/utterance.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,18 @@ class ModelSaliency(AliasModel):
saliencies: List[float] = Field(..., title="Saliency")


class Utterance(ValuePerDatasetSmartTag[str], ValuePerPipelineSmartTag[str], AliasModel):
index: int = Field(..., title="Index", description="Row index computed by Azimuth..")
class UtterancePatch(AliasModel):
# Union[int, str] in this order so FastAPI tries to cast to int() first, then defaults to str().
persistent_id: Union[int, str] = Field(..., title="Persistent id")
data_action: DataAction = Field(..., title="Data action tag")


class Utterance(ValuePerDatasetSmartTag[str], ValuePerPipelineSmartTag[str], UtterancePatch):
index: int = Field(..., title="Index", description="Row index computed by Azimuth..")
model_prediction: Optional[ModelPrediction] = Field(
..., title="Model prediction", nullable=True
)
model_saliency: Optional[ModelSaliency] = Field(..., title="Model saliency", nullable=True)
data_action: DataAction = Field(..., title="Data action tag")
label: str = Field(..., title="Label")
utterance: str = Field(..., title="Utterance")

Expand Down
11 changes: 6 additions & 5 deletions tests/test_routers/test_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,12 @@ def test_get_utterances(app: FastAPI) -> None:

def test_get_proposed_actions(app: FastAPI) -> None:
client = TestClient(app)
request = {
"datasetSplitName": "eval",
"data_actions": {0: {"remove": True}, 1: {"relabel": True}},
}
resp = client.post("/tags", json=request)

request = [
{"persistent_id": 0, "data_action": "remove"},
{"persistent_id": 1, "data_action": "relabel"},
]
resp = client.post("/dataset_splits/eval/utterances", json=request)
assert resp.status_code == HTTP_200_OK, resp.text

resp = client.get("/export/dataset_splits/eval/proposed_actions")
Expand Down
52 changes: 0 additions & 52 deletions tests/test_routers/test_tags.py

This file was deleted.

16 changes: 16 additions & 0 deletions tests/test_routers/test_utterances.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import List

from fastapi import FastAPI
from starlette.status import HTTP_200_OK
from starlette.testclient import TestClient

UTTERANCE_COUNT = 42
Expand Down Expand Up @@ -147,3 +148,18 @@ def test_perturbed_utterances(app: FastAPI, monkeypatch):
).json()
# Utterance 1 has 11 perturbation tests
assert len(resp) == 11


def test_post_utterances(app: FastAPI) -> None:
client = TestClient(app)

request = [{"persistent_id": 0, "data_action": "remove"}]
resp = client.post("/dataset_splits/eval/utterances", json=request)
assert resp.status_code == HTTP_200_OK, resp.text
assert resp.json() == [{"persistentId": 0, "dataAction": "remove"}]

# Reset tag to NO_ACTION
request = [{"persistent_id": 0, "data_action": "NO_ACTION"}]
resp = client.post("/dataset_splits/eval/utterances", json=request)
assert resp.status_code == HTTP_200_OK, resp.text
assert resp.json() == [{"persistentId": 0, "dataAction": "NO_ACTION"}]
gabegma marked this conversation as resolved.
Show resolved Hide resolved
20 changes: 11 additions & 9 deletions webapp/src/components/Analysis/UtterancesTable.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ const useStyles = makeStyles((theme) => ({
},
}));

type Row = Utterance & { id: number };
type Row = Utterance & { id: Utterance["persistentId"] };

type Props = {
jobId: string;
Expand Down Expand Up @@ -138,7 +138,9 @@ const UtterancesTable: React.FC<Props> = ({
const { data: utterancesResponse, isFetching } =
getUtterancesEndpoint.useQuery(getUtterancesQueryState);

const [selectedIds, setSelectedIds] = React.useState<number[]>([]);
const [selectedPersistentIds, setSelectedPersistentIds] = React.useState<
number[]
>([]);

const handlePageChange = (page: number) => {
const q = constructSearchString({
Expand Down Expand Up @@ -236,7 +238,7 @@ const UtterancesTable: React.FC<Props> = ({
row,
}: GridCellParams<DataAction, Row>) => (
<UtteranceDataAction
utteranceIds={[row.id]}
persistentIds={[row.persistentId]}
dataAction={value}
allDataActions={datasetInfo?.dataActions || []}
getUtterancesQueryState={getUtterancesQueryState}
Expand All @@ -247,7 +249,7 @@ const UtterancesTable: React.FC<Props> = ({
// That's the width of the table when viewed on a MacBook Pro 16 with the filter panel.
const columns: Column<Row>[] = [
{
field: "id",
field: "index",
headerName: "Id",
description: ID_TOOLTIP,
width: 55,
Expand Down Expand Up @@ -356,7 +358,7 @@ const UtterancesTable: React.FC<Props> = ({
const rows: Row[] = React.useMemo(
() =>
utterancesResponse?.utterances.map((utterance) => ({
id: utterance.index,
id: utterance.persistentId,
...utterance,
})) ?? [],
[utterancesResponse]
Expand All @@ -366,7 +368,7 @@ const UtterancesTable: React.FC<Props> = ({
const RowLink = (props: RowProps<Row>) => (
<Link
style={{ color: "unset", textDecoration: "unset" }}
to={`/${jobId}/dataset_splits/${datasetSplitName}/utterances/${props.row.id}${searchString}`}
to={`/${jobId}/dataset_splits/${datasetSplitName}/utterances/${props.row.index}${searchString}`}
>
<GridRow {...props} />
</Link>
Expand Down Expand Up @@ -416,9 +418,9 @@ const UtterancesTable: React.FC<Props> = ({
checkboxSelection
disableColumnMenu
onSelectionModelChange={(newSelection) => {
setSelectedIds(newSelection as number[]);
setSelectedPersistentIds(newSelection as number[]);
}}
selectionModel={selectedIds}
selectionModel={selectedPersistentIds}
components={{
Footer: UtterancesTableFooter,
Row: RowLink,
Expand All @@ -428,7 +430,7 @@ const UtterancesTable: React.FC<Props> = ({
onClick: (e: React.MouseEvent) => e.stopPropagation(),
},
footer: {
selectedIds,
selectedPersistentIds,
allDataActions: datasetInfo?.dataActions || [],
getUtterancesQueryState,
},
Expand Down
Loading