Skip to content

Commit

Permalink
Balance selection
Browse files Browse the repository at this point in the history
  • Loading branch information
mgraffg committed Jul 26, 2023
1 parent 0a7253a commit 7896e85
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 3 deletions.
39 changes: 39 additions & 0 deletions IngeoDash/annotate.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,45 @@ def model(mem: Config, data: dict):
return stack.fit(data)


def _select_top_k(selected, members, cnt):
if cnt <= 0:
return
sel = []
for i in members:
if i not in selected:
sel.append(i)
if len(sel) == cnt:
break
[selected.add(x) for x in sel]


def balance_selection(mem: Config, hy):
db = CONFIG.db[mem[mem.username]]
D = db[mem.permanent]
klasses, n_elements = np.unique([x[mem.label_header] for x in D],
return_counts=True)
missing = 1 - n_elements / n_elements.max()
max_ele = n_elements.max()
selected = set()
if klasses.shape[0] > 2:
ss = np.argsort(hy, axis=0)[::-1].T
for k, (c, n_ele, s) in enumerate(zip(mem.n_value * missing,
max_ele - n_elements, ss)):
cnt = np.ceil(min(c, n_ele)).astype(int)
_select_top_k(selected, s, cnt)
else:
ss = np.argsort(hy, axis=0)[:, 0]
ss = [ss, ss[::-1]]
for c, s in zip(max_ele - n_elements, ss):
_select_top_k(selected, s, c)
cnt = np.ceil((mem.n_value - len(selected)) / klasses.shape[0]).astype(int)
[_select_top_k(selected, s, cnt) for s in ss]
selected = list(selected)
np.random.shuffle(selected)
selected = sorted(selected[:mem.n_value])
return np.array(selected), klasses


def random_selection(mem: Config, hy):
index = np.arange(hy.shape[0])
np.random.shuffle(index)
Expand Down
40 changes: 37 additions & 3 deletions IngeoDash/tests/test_annotate.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from IngeoDash.annotate import label_column, flip_label, store, similarity, model
from IngeoDash.annotate import label_column, flip_label, store, similarity, model, balance_selection
from IngeoDash.config import CONFIG
from microtc.utils import tweet_iterator
from EvoMSA.tests.test_base import TWEETS
Expand Down Expand Up @@ -109,8 +109,42 @@ def test_random_selection():
permanent = db[mem.permanent]
data = db[mem.data]
assert [x['id'] for x in permanent] == list(range(10))
assert [x['id'] for x in data] != list(range(10, 20))

assert [x['id'] for x in data] != list(range(10, 20))


def test_balance_selection():
from EvoMSA import BoW
D = list(tweet_iterator(TWEETS))
klasses = np.unique([x['klass'] for x in D[:10]]).tolist()
_ = {CONFIG.username: 'xxx', CONFIG.label_header: 'klass',
CONFIG.lang: 'es', CONFIG.active_learning: True,
CONFIG.labels: klasses,
'active_learning_selection': 'balance_selection'}
mem = CONFIG(_)
CONFIG.db['xxx'] = {mem.permanent: D[:10], mem.data: D[10:20],
mem.original: D[20:]}
bow = BoW(lang=mem[mem.lang], voc_selection=mem.voc_selection,
voc_size_exponent=mem.voc_size_exponent).fit(D[:10])
hy = bow.decision_function(D[10:])
balance_selection(mem, hy)


def test_balance_selection_binary():
from EvoMSA import BoW
D = [x for x in tweet_iterator(TWEETS) if x['klass'] in ['N', 'P']]
klasses = np.unique([x['klass'] for x in D[:10]]).tolist()
_ = {CONFIG.username: 'xxx', CONFIG.label_header: 'klass',
CONFIG.lang: 'es', CONFIG.active_learning: True,
CONFIG.labels: klasses,
'active_learning_selection': 'balance_selection'}
mem = CONFIG(_)
CONFIG.db['xxx'] = {mem.permanent: D[:10], mem.data: D[10:20],
mem.original: D[20:]}
bow = BoW(lang=mem[mem.lang], voc_selection=mem.voc_selection,
voc_size_exponent=mem.voc_size_exponent).fit(D[:10])
hy = bow.decision_function(D[10:])
balance_selection(mem, hy)


def test_flip_label():
data = [dict() for i in range(3)]
Expand Down

0 comments on commit 7896e85

Please sign in to comment.