Skip to content

Commit

Permalink
Add route to import proposed actions
Browse files Browse the repository at this point in the history
  • Loading branch information
gabegma committed Feb 10, 2023
1 parent 5b9f969 commit 3af232d
Show file tree
Hide file tree
Showing 9 changed files with 111 additions and 231 deletions.
2 changes: 0 additions & 2 deletions azimuth/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,6 @@ def create_app() -> FastAPI:
from azimuth.routers.v1.model_performance.utterance_count import (
router as utterance_count_router,
)
from azimuth.routers.v1.tags import router as tags_router
from azimuth.routers.v1.top_words import router as top_words_router
from azimuth.routers.v1.utterances import router as utterances_router
from azimuth.utils.routers import require_application_ready, require_available_model
Expand All @@ -168,7 +167,6 @@ def create_app() -> FastAPI:
prefix="/dataset_splits/{dataset_split_name}/class_overlap",
dependencies=[Depends(require_application_ready)],
)
api_router.include_router(tags_router, prefix="/tags", dependencies=[])
api_router.include_router(
confidence_histogram_router,
prefix="/dataset_splits/{dataset_split_name}/confidence_histogram",
Expand Down
69 changes: 0 additions & 69 deletions azimuth/routers/v1/tags.py

This file was deleted.

44 changes: 43 additions & 1 deletion azimuth/routers/v1/utterances.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from enum import Enum
from typing import Dict, List, Optional

from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi import APIRouter, Body, Depends, HTTPException, Query
from starlette.status import HTTP_404_NOT_FOUND

from azimuth.app import (
Expand Down Expand Up @@ -46,6 +46,7 @@
ModelPrediction,
ModelSaliency,
Utterance,
UtterancePatch,
)
from azimuth.utils.dataset_operations import filter_dataset_split
from azimuth.utils.project import (
Expand Down Expand Up @@ -264,6 +265,47 @@ def get_utterances(
)


@router.post(
"",
summary="Patch utterances",
description="Patch utterances, such as updating proposed actions.",
tags=TAGS,
response_model=List[UtterancePatch],
)
def patch_utterances(
request_data: List[UtterancePatch] = Body(...),
dataset_split_manager: DatasetSplitManager = Depends(get_dataset_split_manager),
task_manager: TaskManager = Depends(get_task_manager),
) -> List[UtterancePatch]:
persistent_ids = [utt.persistent_id for utt in request_data]
try:
row_indices = dataset_split_manager.get_row_indices_from_persistent_id(persistent_ids)
except ValueError as e:
raise HTTPException(HTTP_404_NOT_FOUND, detail=f"Persistent id not found: {e}.")

request_data_complete = {}
for row_idx, utterance in zip(row_indices, request_data):
request_data_complete[row_idx] = {data_action: False for data_action in ALL_DATA_ACTIONS}
if utterance.data_action != DataAction.no_action:
request_data_complete[row_idx][utterance.data_action] = True

dataset_split_manager.add_tags(request_data_complete)

task_manager.clear_worker_cache()
updated_tags = dataset_split_manager.get_tags(row_indices)

return [
UtterancePatch(
persistent_id=persistent_id,
data_action=next(
iter(tag for tag, value in tags.items() if tag in ALL_DATA_ACTIONS and value),
DataAction.no_action,
),
)
for persistent_id, tags in zip(persistent_ids, updated_tags.values())
]


@router.get(
"/{index}/perturbed_utterances",
summary="Get a perturbed utterances for a single utterance.",
Expand Down
40 changes: 1 addition & 39 deletions azimuth/types/tag.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,7 @@
from enum import Enum
from typing import Any, Dict, List

from pydantic import Field

from azimuth.types import AliasModel, DatasetSplitName, ModuleResponse


class DataActionMapping(AliasModel):
relabel: bool = Field(..., title="Relabel")
augment_with_similar: bool = Field(..., title="Augment with Similar")
define_new_class: bool = Field(..., title="Define New Class")
merge_classes: bool = Field(..., title="Merge Two Classes")
remove: bool = Field(..., title="Remove")
investigate: bool = Field(..., title="Investigate")

from azimuth.types import ModuleResponse

Tag = str

Expand Down Expand Up @@ -134,29 +122,3 @@ class SmartTagFamily(str, Enum):
class TaggingResponse(ModuleResponse):
tags: Dict[Tag, bool]
adds: Dict[str, Any]


class DataActionResponse(AliasModel):
data_actions: List[DataActionMapping] = Field(..., title="Data action tags")


class PostDataActionRequest(AliasModel):
dataset_split_name: DatasetSplitName = Field(..., title="Dataset Split Name")
data_actions: Dict[int, Dict[Tag, bool]] = Field(..., title="Data action tags")

class Config:
schema_extra = {
"example": {
"dataset_split_name": "eval",
"data_actions": {
1: {
"relabel": True,
"augment_with_similar": False,
"define_new_class": False,
"merge_classes": False,
"remove": False,
"investigate": False,
}
},
}
}
9 changes: 6 additions & 3 deletions azimuth/types/utterance.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,18 @@ class ModelSaliency(AliasModel):
saliencies: List[float] = Field(..., title="Saliency")


class Utterance(ValuePerDatasetSmartTag[str], ValuePerPipelineSmartTag[str], AliasModel):
index: int = Field(..., title="Index", description="Row index computed by Azimuth..")
class UtterancePatch(AliasModel):
# Union[int, str] in this order so FastAPI tries to cast to int() first, then defaults to str().
persistent_id: Union[int, str] = Field(..., title="Persistent id")
data_action: DataAction = Field(..., title="Data action tag")


class Utterance(ValuePerDatasetSmartTag[str], ValuePerPipelineSmartTag[str], UtterancePatch):
index: int = Field(..., title="Index", description="Row index computed by Azimuth..")
model_prediction: Optional[ModelPrediction] = Field(
..., title="Model prediction", nullable=True
)
model_saliency: Optional[ModelSaliency] = Field(..., title="Model saliency", nullable=True)
data_action: DataAction = Field(..., title="Data action tag")
label: str = Field(..., title="Label")
utterance: str = Field(..., title="Utterance")

Expand Down
11 changes: 6 additions & 5 deletions tests/test_routers/test_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,12 @@ def test_get_utterances(app: FastAPI) -> None:

def test_get_proposed_actions(app: FastAPI) -> None:
client = TestClient(app)
request = {
"datasetSplitName": "eval",
"data_actions": {0: {"remove": True}, 1: {"relabel": True}},
}
resp = client.post("/tags", json=request)

request = [
{"persistent_id": 0, "data_action": "remove"},
{"persistent_id": 1, "data_action": "relabel"},
]
resp = client.post("/dataset_splits/eval/utterances", json=request)
assert resp.status_code == HTTP_200_OK, resp.text

resp = client.get("/export/dataset_splits/eval/proposed_actions")
Expand Down
52 changes: 0 additions & 52 deletions tests/test_routers/test_tags.py

This file was deleted.

16 changes: 16 additions & 0 deletions tests/test_routers/test_utterances.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import List

from fastapi import FastAPI
from starlette.status import HTTP_200_OK
from starlette.testclient import TestClient

UTTERANCE_COUNT = 42
Expand Down Expand Up @@ -147,3 +148,18 @@ def test_perturbed_utterances(app: FastAPI, monkeypatch):
).json()
# Utterance 1 has 11 perturbation tests
assert len(resp) == 11


def test_post_utterances(app: FastAPI) -> None:
client = TestClient(app)

request = [{"persistent_id": 0, "data_action": "remove"}]
resp = client.post("/dataset_splits/eval/utterances", json=request)
assert resp.status_code == HTTP_200_OK, resp.text
assert resp.json() == [{"persistentId": 0, "dataAction": "remove"}]

# Reset tag to NO_ACTION
request = [{"persistent_id": 0, "data_action": "NO_ACTION"}]
resp = client.post("/dataset_splits/eval/utterances", json=request)
assert resp.status_code == HTTP_200_OK, resp.text
assert resp.json() == [{"persistentId": 0, "dataAction": "NO_ACTION"}]
Loading

0 comments on commit 3af232d

Please sign in to comment.