Skip to content

Commit

Permalink
Fix undersampling (#114)
Browse files Browse the repository at this point in the history
  • Loading branch information
qubixes authored and J535D165 committed Oct 14, 2019
1 parent 3340549 commit 0a8118f
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 2 deletions.
3 changes: 2 additions & 1 deletion asreview/balance_strategies/triple_balance.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,8 @@ def _n_mini_epoch(n_samples, epoch_size):
return ceil(n_samples/epoch_size)


def triple_balance(X, y, train_idx, fit_kwargs={}, query_kwargs={},
def triple_balance(X, y, train_idx, fit_kwargs={},
query_kwargs={"query_src": {}},
pref_epochs=1, shuffle=True, **dist_kwargs):
"""
A more advanced function that does resample the training set.
Expand Down
2 changes: 1 addition & 1 deletion asreview/balance_strategies/undersampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def undersample(X, y, train_idx, ratio=1.0, shuffle=True):
n_zero_epoch = ceil(n_one/ratio)
zero_under = np.random.choice(np.arange(n_zero), n_zero_epoch,
replace=False)
shuf_ind = np.append(one_ind, zero_under)
shuf_ind = np.append(one_ind, zero_ind[zero_under])

if shuffle:
np.random.shuffle(shuf_ind)
Expand Down
56 changes: 56 additions & 0 deletions test/test_balance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import numpy as np
from asreview.balance_strategies import undersample, full_sample
from asreview.balance_strategies import triple_balance


def generate_data(n_feature=20, n_sample=10):
X = np.random.rand(n_sample, n_feature)
# y = np.random.randint(0, 2, n_sample)
n_sample_zero = np.int(n_sample/2)
n_sample_one = n_sample - n_sample_zero
y = np.append(np.zeros(n_sample_zero), np.ones(n_sample_one))
np.random.shuffle(y)

return X, y


def check_partition(X, y, X_partition, y_partition, train_idx):
partition_idx = []
for row in X_partition:
partition_idx.append(
np.where(np.all(X == row, axis=1))[0][0]
)

assert np.count_nonzero(y_partition == 0) > 0
assert np.count_nonzero(y_partition == 1) > 0
assert len(partition_idx) == X_partition.shape[0]
assert set(partition_idx) <= set(train_idx.tolist())
assert np.all(X[partition_idx] == X_partition)
assert np.all(y[partition_idx] == y_partition)


def check_multiple(balance_fn, n_partition=100, n_feature=200, n_sample=100):
X, y = generate_data(n_feature=n_feature, n_sample=n_sample)
for _ in range(n_partition):
n_train = np.random.randint(10, n_sample)
while True:
train_idx = np.random.choice(
np.arange(len(y)), n_train, replace=False)
num_zero = np.count_nonzero(y[train_idx] == 0)
num_one = np.count_nonzero(y[train_idx] == 1)
if num_zero > 0 and num_one > 0:
break
X_train, y_train = balance_fn(X, y, train_idx)
check_partition(X, y, X_train, y_train, train_idx)


def test_undersample():
check_multiple(undersample)


def test_simple():
check_multiple(full_sample)


def test_triple():
check_multiple(triple_balance)

0 comments on commit 0a8118f

Please sign in to comment.