Skip to content

Commit

Permalink
Minor API changes and improvements(#53)
Browse files Browse the repository at this point in the history
  • Loading branch information
qubixes authored and J535D165 committed Jun 24, 2019
1 parent 5e58f02 commit 92b9a8c
Show file tree
Hide file tree
Showing 9 changed files with 82 additions and 87 deletions.
84 changes: 40 additions & 44 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -206,47 +206,44 @@ optional arguments:
### Python API (oracle mode)
** Under construction / available at own risk **
It is possible to create an interactive systematic reviewer with the Python
API. It requires some knowledge on creating an interface. By default, a simple
command line interface is used to interact with the reviewer.
``` python
from asr import load_data, ReviewOracle
from asr.query_strategies import uncertainty_sampling
from asr import ReviewOracle
from asr.readers import read_data
from asr.utils import text_to_features
from asr.models.embedding import load_embedding, sample_embedding
from asr.models import create_lstm_pool_model
from tensorflow.python.keras.wrappers.scikit_learn import KerasClassifier

# load data
data = load_data(PATH_TO_DATA)
data, texts, _ = read_data(DATA_FILE)

# create features and labels
X, word_index = text_to_features(data)
X, word_index = text_to_features(texts)

# Load embedding layer.
embedding, words = load_embedding(PATH_TO_EMBEDDING)
embedding_matrix = sample_embedding(embedding, words, word_index)
# Load embedding layer.
embedding = load_embedding(EMBEDDING_FILE, word_index=word_index)
embedding_matrix = sample_embedding(embedding, word_index)

# create the model
model = create_lstm_model(
backwards=True,
optimizer='rmsprop',
embedding_layer=embedding_matrix
model = KerasClassifier(
create_lstm_pool_model(embedding_matrix=embedding_matrix),
verbose=1,
)

# start the review process.
asr = ReviewOracle(
X,
model,
uncertainty_sampling,
data,
n_instances=10,
prior_included=[29, 181, 379, 2001, 3928, 3929, 4547],
prior_excluded=[31, 90, 892, 3898, 3989, 4390]
reviewer = ReviewOracle(
X,
data=data,
model=model,
n_instances=10,
prior_included=PRIOR_INC_LIST, # List of some included papers
prior_excluded=PRIOR_EXC_LIST, # List of some excluded papers
)
asr.review()

reviewer.review()
```
## Systematic Review (simulation mode)
Expand Down Expand Up @@ -320,45 +317,44 @@ optional arguments:
### Python API (simulation mode)
** Under construction / available at own risk **
It is possible to simulate a systematic review with the Python
API.
``` python
from asr import load_data, ReviewSimulate
from asr.query_strategies import uncertainty_sampling
from asr import ReviewSimulate
from asr.readers import read_data
from asr.utils import text_to_features
from asr.models.embedding import load_embedding, sample_embedding
from asr.models import create_lstm_pool_model
from tensorflow.python.keras.wrappers.scikit_learn import KerasClassifier
# load data
data, y = read_data(PATH_TO_DATA)
_, texts, y = read_data(DATA_FILE)
# create features and labels
X, word_index = text_to_features(data)
X, word_index = text_to_features(texts)
# Load embedding layer.
embedding, words = load_embedding(PATH_TO_EMBEDDING)
embedding_matrix = sample_embedding(embedding, words, word_index)
# Load embedding layer.
embedding = load_embedding(EMBEDDING_FILE, word_index=word_index)
embedding_matrix = sample_embedding(embedding, word_index)
# create the model
model = create_lstm_model(
backwards=True,
optimizer='rmsprop',
embedding_layer=embedding_matrix
model = KerasClassifier(
create_lstm_pool_model(embedding_matrix=embedding_matrix),
verbose=1,
)
# start the review process.
asr = ReviewSimulate(
X, y,
model,
uncertainty_sampling,
n_instances=10,
prior_included=[29, 181, 379, 2001, 3928, 3929, 4547],
prior_excluded=[31, 90, 892, 3898, 3989, 4390]
reviewer = ReviewSimulate(
X,
y=y,
model=model,
n_instances=10,
prior_included=PRIOR_INC_LIST, # List of some included papers
prior_excluded=PRIOR_EXC_LIST, # List of some excluded papers
)
asr.review()
reviewer.review()
```
## Development and contributions
Expand Down
2 changes: 1 addition & 1 deletion asr/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from asr.base import ReviewSimulate, ReviewOracle
from asr.readers import read_csv, read_data, read_ris
from asr.review import review, review_oracle, review_simulate
from asr.utils import load_data, text_to_features
from asr.utils import text_to_features
from asr.models.embedding import load_embedding, sample_embedding
from asr.logging import Logger, read_log, read_logs_from_dir

Expand Down
34 changes: 20 additions & 14 deletions asr/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,9 @@
from asr.ascii import ASCII_TEA
from asr.balance_strategies import full_sample
from asr.balanced_al import validation_data
from asr.query_strategies import max_sampling


N_INCLUDED = 10
N_EXCLUDED = 40
NOT_AVAILABLE = -1


Expand Down Expand Up @@ -43,19 +42,30 @@ def __init__(self,
X,
y=None,
model=None,
query_strategy=None,
query_strategy=max_sampling,
train_data_fn=full_sample,
n_instances=1,
n_queries=None,
n_queries=1,
prior_included=[],
prior_excluded=[],
log_file=None,
settings={},
fit_kwargs={},
balance_kwargs={},
query_kwargs={},
verbose=1):
super(Review, self).__init__()

self.X = X
self.y = y

# Default to Naive Bayes model
if model is None:
print("Warning: using naive Bayes model as default."
"If you experience bad performance, read the documentation"
" in order to implement a RNN based solution.")
from asr.models import create_nb_model
model = create_nb_model()

self.model = model
self.query_strategy = query_strategy
self.train_data = train_data_fn
Expand All @@ -68,9 +78,9 @@ def __init__(self,
self.prior_included = prior_included
self.prior_excluded = prior_excluded

self.fit_kwargs = settings['fit_kwargs']
self.balance_kwargs = settings['balance_kwargs']
self.query_kwargs = settings['query_kwargs']
self.fit_kwargs = fit_kwargs
self.balance_kwargs = balance_kwargs
self.query_kwargs = query_kwargs

self._logger = Logger()

Expand Down Expand Up @@ -196,13 +206,11 @@ class ReviewSimulate(Review):
def __init__(self,
X,
y,
model,
query_strategy,
n_prior_included=None,
n_prior_excluded=None,
*args, **kwargs):
super(ReviewSimulate, self).__init__(
X, y, model, query_strategy, *args, **kwargs)
X, y, *args, **kwargs)

self.n_prior_included = n_prior_included
self.n_prior_excluded = n_prior_excluded
Expand Down Expand Up @@ -254,13 +262,11 @@ def _classify(self, ind):
class ReviewOracle(Review):
"""Automated Systematic Review"""

def __init__(self, X, model, query_strategy, data, use_cli_colors=True,
def __init__(self, X, data, use_cli_colors=True,
*args, **kwargs):
super(ReviewOracle, self).__init__(
X,
y=np.tile([NOT_AVAILABLE], X.shape[0]),
model=model,
query_strategy=query_strategy,
*args,
**kwargs)

Expand Down
2 changes: 1 addition & 1 deletion asr/models/lstm_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,9 @@ def create_lstm_base_model(embedding_matrix,
called.
"""

# The Sklearn API requires a callable as result.
# https://keras.io/scikit-learn-api/

def wrap_model():

model = Sequential()
Expand Down
2 changes: 1 addition & 1 deletion asr/models/lstm_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,9 @@ def create_lstm_pool_model(embedding_matrix,
called.
"""

# The Sklearn API requires a callable as result.
# https://keras.io/scikit-learn-api/

def wrap_model():

model = Sequential()
Expand Down
4 changes: 2 additions & 2 deletions asr/readers.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,9 @@ def read_data(fp):
print(f'Choosing the one with the highest priority: '
f'{column_labels[0]}')
elif len(column_labels) == 0:
return texts.values
return df, texts.values, None
labels = df[column_labels[0]]
return texts.values, labels.values
return df, texts.values, labels.values


def read_csv(fp, labels=None):
Expand Down
22 changes: 11 additions & 11 deletions asr/review.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def review(dataset,
print("Prepare dataset.\n")
print(ASCII_TEA)

texts, labels = read_data(dataset)
data, texts, labels = read_data(dataset)

# get the model
if base_model == "RNN":
Expand Down Expand Up @@ -184,18 +184,18 @@ def review(dataset,
# start the review process
reviewer = ReviewSimulate(
X, y,
model,
query_fn,
model=model,
query_strategy=query_fn,
train_data_fn=train_data_fn,
n_instances=n_instances,
verbose=verbose,
prior_included=prior_included,
prior_excluded=prior_excluded,
n_prior_included=n_prior_included,
n_prior_excluded=n_prior_excluded,

# Fit keyword arguments
settings=settings,
fit_kwargs=settings['fit_kwargs'],
balance_kwargs=settings['balance_kwargs'],
query_kwargs=settings['query_kwargs'],

# Other
**kwargs)
Expand All @@ -222,16 +222,16 @@ def review(dataset,
# start the review process
reviewer = ReviewOracle(
X,
model,
query_fn,
data,
model=model,
query_strategy=query_fn,
data=data,
n_instances=n_instances,
verbose=verbose,
prior_included=prior_included,
prior_excluded=prior_excluded,

# fit keyword arguments
fit_kwargs=settings['fit_kwargs'],
balance_kwargs=settings['balance_kwargs'],
query_kwargs=settings['query_kwargs'],

# other keyword arguments
**kwargs)
Expand Down
7 changes: 0 additions & 7 deletions asr/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,10 +182,3 @@ def config_from_file(config_file):
print (f"Warning: section [{sect}] is ignored in "
f"config file {config_file}")
return settings


def load_data(*args, **kwargs):
""" [Deprecated] Load papers and their labels. @see read_data"""
warnings.warn("deprecated: use read_data instead of load_data",
DeprecationWarning)
return read_data(*args, **kwargs)
12 changes: 6 additions & 6 deletions test/test_readers.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,15 @@ def test_csv_reader_with_labels():
def test_csv_load_data():

fp = Path("test", "demo_data", "csv_example_with_labels.csv")
x, y = asr.read_data(fp)
_, x, y = asr.read_data(fp)

assert x.shape[0] == 2
assert y.shape[0] == 2

fp = Path("test", "demo_data", "csv_example_without_labels.csv")
x = asr.read_data(fp)
_, x, y = asr.read_data(fp)

assert x.shape[0] == 2
assert x.shape[0] == 2 and y is None


def test_ris_reader_without_labels():
Expand All @@ -63,12 +63,12 @@ def test_ris_reader_with_labels():
def test_ris_load_data():

fp = Path("test", "demo_data", "ris_example_with_labels.ris")
x, y = asr.read_data(fp)
_, x, y = asr.read_data(fp)

assert x.shape[0] == 2
assert y.shape[0] == 2

fp = Path("test", "demo_data", "ris_example_without_labels.ris")
x = asr.read_data(fp)
_, x, y = asr.read_data(fp)

assert x.shape[0] == 2
assert x.shape[0] == 2 and y is None

0 comments on commit 92b9a8c

Please sign in to comment.