Skip to content

Commit

Permalink
Implement tags again with different schema in results table
Browse files Browse the repository at this point in the history
  • Loading branch information
J535D165 committed Jun 21, 2024
1 parent 565f3a0 commit 8f01a3d
Show file tree
Hide file tree
Showing 13 changed files with 47 additions and 252 deletions.
6 changes: 2 additions & 4 deletions asreview/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,8 @@
"notes": ["notes"],
"keywords": ["keywords"],
"doi": ["doi"],
"custom_metadata_json": [
"custom metadata",
"custom metadata json",
"custom_metadata_json",
"tags": [
"tags",
],
"is_prior": ["asreview_prior", "is_prior"],
}
Expand Down
40 changes: 24 additions & 16 deletions asreview/state/sqlstate.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
"training_set",
"labeling_time",
"note",
"custom_metadata_json",
"tags",
"user_id",
]

Expand All @@ -49,6 +49,8 @@
"feature_extraction": "object",
"training_set": "Int64",
"labeling_time": "object",
"note": "object",
"tags": "object",
"user_id": "Int64",
}

Expand Down Expand Up @@ -123,7 +125,7 @@ def create_tables(self):
training_set INTEGER,
labeling_time TEXT,
note TEXT,
custom_metadata_json TEXT,
tags JSON,
user_id INTEGER)"""
)

Expand Down Expand Up @@ -259,7 +261,7 @@ def add_last_ranking(
}
).to_sql("last_ranking", self._conn, if_exists="replace", index=False)

def add_labeling_data(self, record_ids, labels, tags_list=None, user_id=None):
def add_labeling_data(self, record_ids, labels, tags=None, user_id=None):
"""Add the data corresponding to a labeling action to the state file.
Arguments
Expand All @@ -268,21 +270,24 @@ def add_labeling_data(self, record_ids, labels, tags_list=None, user_id=None):
A list of ids of the labeled records as int.
labels: list, numpy.ndarray
A list of labels of the labeled records as int.
tags_list: list of list
A list of tags to save with the labeled records.
tags: dict
A dict of tags to save with the labeled records.
user_id: int
User id of the user who labeled the records.
"""

if tags_list is None:
tags_list = [None for _ in record_ids]
if tags is None:
tags = [None for _ in record_ids]

if len({len(record_ids), len(labels), len(tags_list)}) != 1:
if len({len(record_ids), len(labels), len(tags)}) != 1:
raise ValueError("Input data should be of the same length.")

custom_metadata_list = [
json.dumps({"tags": tags_list[i]}) for i, _ in enumerate(record_ids)
]
if tags is None:
tags = [None for _ in record_ids]
elif isinstance(tags, dict):
tags = [json.dumps(tags) for _ in record_ids]
else:
tags = [json.dumps(tag) for tag in tags]

labeling_time = datetime.now()

Expand All @@ -291,19 +296,19 @@ def add_labeling_data(self, record_ids, labels, tags_list=None, user_id=None):
cur.executemany(
(
"""
INSERT INTO results(record_id,label,labeling_time,custom_metadata_json, user_id)
INSERT INTO results(record_id,label,labeling_time,tags, user_id)
VALUES(?,?,?,?,?)
ON CONFLICT(record_id) DO UPDATE
SET label=excluded.label, labeling_time=excluded.labeling_time,
custom_metadata_json=excluded.custom_metadata_json, user_id=excluded.user_id
tags=excluded.tags, user_id=excluded.user_id
"""
),
[
(
int(record_ids[i]),
int(labels[i]),
labeling_time,
custom_metadata_list[i],
tags[i],
user_id,
)
for i in range(len(record_ids))
Expand Down Expand Up @@ -442,12 +447,15 @@ def get_results_table(self, columns=None, priors=True, pending=False):
}

query_string = "*" if columns is None else ",".join(columns)
return pd.read_sql_query(
df_results = pd.read_sql_query(
f"SELECT {query_string} FROM results {sql_where_str}",
self._conn,
dtype=col_dtype,
)

df_results["tags"] = df_results["tags"].map(json.loads, na_action="ignore")
return df_results

def get_priors(self):
"""Get the record ids of the priors.
Expand Down Expand Up @@ -546,7 +554,7 @@ def update(self, record_id, label=None, tags=None):
cur = self._conn.cursor()

cur.execute(
"UPDATE results SET label = ?, custom_metadata_json=? WHERE record_id = ?",
"UPDATE results SET label = ?, tags=? WHERE record_id = ?",
(label, json.dumps({"tags": tags}), record_id),
)

Expand Down
26 changes: 5 additions & 21 deletions asreview/webapp/api/projects.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,22 +69,6 @@
bp = Blueprint("api", __name__, url_prefix="/api")


def _extract_tags(custom_metadata_str):
if not isinstance(custom_metadata_str, str):
return None

obj = json.loads(custom_metadata_str)

if "tags" in obj:
return obj["tags"]
else:
return None


def _get_tag_composite_id(group_id, tag_id):
return f"{group_id}:{tag_id}"


def _fill_last_ranking(project, ranking):
"""Fill the last ranking with a random or top-down ranking.
Expand Down Expand Up @@ -894,11 +878,11 @@ def api_import_project():


def _add_tags_to_export_data(project, export_data, state_df):
tags_df = state_df[["custom_metadata_json"]].copy()
tags_df = state_df[["tags"]].copy()

tags_df["tags"] = (
tags_df["custom_metadata_json"]
.apply(lambda d: _extract_tags(d))
tags_df["tags"]
# .apply(lambda d: _extract_tags(d))
.apply(lambda d: d if isinstance(d, list) else [])
)

Expand All @@ -907,7 +891,7 @@ def _add_tags_to_export_data(project, export_data, state_df):

if tags_config is not None:
all_tags = [
[_get_tag_composite_id(group["id"], tag["id"]) for tag in group["values"]]
[(group["id"], tag["id"]) for tag in group["values"]]
for group in tags_config
]
all_tags = list(chain.from_iterable(all_tags))
Expand Down Expand Up @@ -1248,7 +1232,7 @@ def api_label_record(project, record_id): # noqa: F401
state.add_labeling_data(
record_ids=[record_id],
labels=[label],
tags_list=[tags],
tags=[tags],
user_id=user_id,
)
elif label == -1:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ import NoteAltOutlinedIcon from "@mui/icons-material/NoteAltOutlined";
import { styled } from "@mui/material/styles";
import { ProjectAPI } from "api";
import { useToggle } from "hooks/useToggle";
import "./ReviewPage.css";

import { TagsTable } from ".";

Expand Down Expand Up @@ -94,15 +93,17 @@ const DecisionButton = ({
label,
labelFromDataset = null,
tagsForm,
tagValues = [],
tagValues = null,
note = null,
showNotes = true,
decisionCallback,
hotkeys = false,
retrainAfterDecision = true,
}) => {
const [showNotesDialog, toggleShowNotesDialog] = useToggle(false);
const [tagValuesState, setTagValuesState] = React.useState(tagValues);
const [tagValuesState, setTagValuesState] = React.useState(
tagValues ? tagValues : structuredClone(tagsForm),
);

const { error, isError, isLoading, mutate, isSuccess } = useMutation(
ProjectAPI.mutateClassification,
Expand Down

This file was deleted.

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ const RecordCard = ({
note={record.state?.note}
showNotes={showNotes}
tagsForm={record.tags_form}
tagValues={record.tags}
tagValues={record.state?.tags}
hotkeys={hotkeys}
/>
</StyledCard>
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@ import { ActionsFeedbackBar } from "Components";
import { RecordCard, ReviewPageFinished } from ".";

import { ProjectAPI } from "api";
// import { useKeyPress } from "hooks/useKeyPress";

import "./ReviewPage.css";
import FinishSetup from "./ReviewPageTraining";

const Root = styled("div")(({ theme }) => ({
Expand Down
Loading

0 comments on commit 8f01a3d

Please sign in to comment.