Skip to content

Commit

Permalink
Add option to return and store relevance scores (asreview#1435)
Browse files Browse the repository at this point in the history
  • Loading branch information
PeterLombaers committed Nov 9, 2023
1 parent a35f3c7 commit 9ca572c
Show file tree
Hide file tree
Showing 11 changed files with 243 additions and 111 deletions.
43 changes: 31 additions & 12 deletions asreview/models/query/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,33 +23,49 @@ class BaseQueryStrategy(BaseModel):
name = "base-query"

@abstractmethod
def query(self, X, classifier=None, n_instances=None, **kwargs):
"""Query new instances.
def query(
self,
X,
classifier=None,
n_instances=None,
return_classifier_scores=False,
**kwargs
):
"""Put records in ranked order.
Arguments
---------
X: numpy.ndarray
Feature matrix to choose samples from.
Feature matrix where every row contains the features of a record.
classifier: SKLearnModel
Trained classifier to compute probabilities if they are necessary.
Trained classifier to compute relevance scores.
n_instances: int
Number of instances to query.
Number of records to query. If None, returns all records in ranked order.
return_classifier_scores: bool
Return the relevance scores produced by the classifier.
Returns
-------
(numpy.ndarray, numpy.ndarray)
The first is an array of shape (n_instances,) containing the row
indices of the new instances in query order. The second is an array
of shape (n_instances, n_feature_matrix_columns), containing the
feature vectors of the new instances.
numpy.ndarray or (numpy.ndarray, numpy.ndarray)
The QueryStrategy ranks the row numbers of the feature matrix. It returns
an array of shape (n_instances,) containing the row indices in ranked
order.
If n_instances is None, returns all row numbers in ranked order. If
n_instances is an integer, it only returns the top n_instances.
If return_classifier_scores=True, also returns a second array with the same
number of rows as the feature matrix, containing the relevance scores
predicted by the classifier. If the classifier is not used, this will be
None.
"""
raise NotImplementedError


class ProbaQueryStrategy(BaseQueryStrategy):
name = "proba"

def query(self, X, classifier, n_instances=None, **kwargs):
def query(
self, X, classifier, n_instances=None, return_classifier_scores=False, **kwargs
):
"""Query method for strategies which use class probabilities."""
if n_instances is None:
n_instances = X.shape[0]
Expand All @@ -58,7 +74,10 @@ def query(self, X, classifier, n_instances=None, **kwargs):

query_idx = self._query(predictions, n_instances, X)

return query_idx
if return_classifier_scores:
return query_idx, predictions
else:
return query_idx

@abstractmethod
def _query(self, predictions, n_instances, X=None):
Expand Down
19 changes: 15 additions & 4 deletions asreview/models/query/mixed.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,9 @@ def __init__(
strategy_2, random_state=self._random_state, **self.kwargs_2
)

def query(self, X, classifier, n_instances=None, **kwargs):
def query(
self, X, classifier, n_instances=None, return_classifier_scores=False, **kwargs
):
# set the number of instances to len(X) if None
if n_instances is None:
n_instances = X.shape[0]
Expand All @@ -104,14 +106,18 @@ def query(self, X, classifier, n_instances=None, **kwargs):
query_idx_1 = self.query_model1._query(predictions, n_instances=n_instances)
except AttributeError:
# for random for example
query_idx_1 = self.query_model1.query(X, classifier, n_instances)
query_idx_1 = self.query_model1.query(
X, classifier, n_instances=n_instances, return_classifier_scores=False
)

# Perform the query with strategy 2.
try:
query_idx_2 = self.query_model2._query(predictions, n_instances=n_instances)
except AttributeError:
# for random for example
query_idx_2 = self.query_model2.query(X, classifier, n_instances)
query_idx_2 = self.query_model2.query(
X, classifier, n_instances, return_classifier_scores=False
)

# mix the 2 query strategies into one list
query_idx_mix = []
Expand All @@ -127,7 +133,12 @@ def query(self, X, classifier, n_instances=None, **kwargs):
j = j + 1

indexes = np.unique(query_idx_mix, return_index=True)[1]
return [query_idx_mix[i] for i in sorted(indexes)][0:n_instances]
ranking = [query_idx_mix[i] for i in sorted(indexes)][0:n_instances]

if return_classifier_scores:
return ranking, predictions
else:
return ranking

def full_hyper_space(self):
from hyperopt import hp
Expand Down
14 changes: 12 additions & 2 deletions asreview/models/query/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,22 @@ def __init__(self, random_state=None):
super(RandomQuery, self).__init__()
self._random_state = get_random_state(random_state)

def query(self, X, classifier=None, n_instances=None, **kwargs):
def query(
self,
X,
classifier=None,
n_instances=None,
return_classifier_scores=False,
**kwargs
):
if n_instances is None:
n_instances = X.shape[0]

query_idx = self._random_state.choice(
np.arange(X.shape[0]), n_instances, replace=False
)

return query_idx
if return_classifier_scores:
return query_idx, None
else:
return query_idx
9 changes: 6 additions & 3 deletions asreview/review/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,11 +301,10 @@ def train(self):
self.classifier.fit(X_train, y_train)

# Use the query strategy to produce a ranking.
ranked_record_ids = self.query_strategy.query(
self.X, classifier=self.classifier
ranked_record_ids, relevance_scores = self.query_strategy.query(
self.X, classifier=self.classifier, return_classifier_scores=True
)

# TODO: Also log the probablities.
# Log the ranking in the state.
with open_state(self.project, read_only=False) as state:
state.add_last_ranking(
Expand All @@ -316,3 +315,7 @@ def train(self):
self.feature_extraction.name,
training_set,
)

if relevance_scores is not None:
# relevance_scores contains scores for 'relevant' in the second column.
state.add_last_probabilities(relevance_scores[:, 1])
9 changes: 7 additions & 2 deletions asreview/review/simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,10 @@ def __init__(
# Check if there is already a ranking stored in the state.
if state.model_has_trained:
self.last_ranking = state.get_last_ranking()
self.last_probabilities = state.get_last_probabilities()
else:
self.last_ranking = None
self.last_probabilities = None

self.labeled = state.get_labeled()
self.pool = pd.Series(
Expand Down Expand Up @@ -270,15 +272,17 @@ def train(self):
self.classifier.fit(X_train, y_train)

# Use the query strategy to produce a ranking.
ranked_record_ids = self.query_strategy.query(
self.X, classifier=self.classifier
ranked_record_ids, relevance_scores = self.query_strategy.query(
self.X, classifier=self.classifier, return_classifier_scores=True
)

self.last_ranking = pd.concat(
[pd.Series(ranked_record_ids), pd.Series(range(len(ranked_record_ids)))],
axis=1,
)
self.last_ranking.columns = ["record_id", "label"]
# The scores for the included records are in the second column.
self.last_probabilities = relevance_scores[:, 1]

self.training_set = new_training_set

Expand Down Expand Up @@ -351,6 +355,7 @@ def _write_to_state(self):
self.feature_extraction.name,
self.training_set,
)
state.add_last_probabilities(self.last_probabilities)

# Empty the results table in memory.
self.results.drop(self.results.index, inplace=True)
10 changes: 9 additions & 1 deletion asreview/state/sqlstate.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,8 +462,16 @@ def add_last_probabilities(self, probabilities):
Arguments
---------
probabilities: list, np.array
List containing the probabilities for every record.
List containing the relevance scores for every record. If this is None, the
last probabilities table in the state is emptied.
"""
if probabilities is None:
con = self._connect_to_sql()
cur = con.cursor()
cur.execute("""DELETE FROM last_probabilities""")
con.commit()
return

proba_sql_input = [(proba,) for proba in probabilities]

con = self._connect_to_sql()
Expand Down
Binary file modified docs/source/example.asreview
Binary file not shown.
Loading

0 comments on commit 9ca572c

Please sign in to comment.