Skip to content

Commit

Permalink
Testing Active Learning
Browse files Browse the repository at this point in the history
  • Loading branch information
mgraffg committed Jul 21, 2023
1 parent f23bc36 commit c98c238
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 9 deletions.
2 changes: 1 addition & 1 deletion IngeoDash/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@
# limitations under the License.
from IngeoDash.annotate import label_column, flip_label, store, similarity

__version__ = '0.1.1'
__version__ = '0.1.2'
5 changes: 3 additions & 2 deletions IngeoDash/annotate.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ def model(mem: Config, data: dict, select: bool=True):
return stack.fit(data)



def active_learning_selection(mem: Config):
db = CONFIG.db[mem[mem.username]]
dense = model(mem, db[mem.permanent])
Expand All @@ -83,6 +82,7 @@ def active_learning_selection(mem: Config):
data.append(ele)
db[mem.original] = D
db[mem.data] = data
return dense


def label_column_predict(mem: Config, model=None):
Expand All @@ -96,7 +96,8 @@ def label_column_predict(mem: Config, model=None):
dense = model(mem, D)
hys = dense.predict(data).tolist()
for ele, hy in zip(data, hys):
ele[mem.label_header] = ele.get(mem.label_header, hy)
ele[mem.label_header] = ele.get(mem.label_header, hy)
return dense


def label_column(mem: Config, model=model):
Expand Down
4 changes: 3 additions & 1 deletion IngeoDash/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,9 @@ def __call__(self, value):
cls = deepcopy(self)
if value is not None:
cls.mem = json.loads(value) if isinstance(value, str) else value
for key in ['label_header', 'text', 'n_value']:
for key in ['label_header', 'text', 'n_value',
'voc_size_exponent', 'voc_selection',
'estimator_class', 'decision_function_name']:
if key in cls.mem:
setattr(cls, key, cls.mem[key])
return cls
Expand Down
15 changes: 12 additions & 3 deletions IngeoDash/tests/test_annotate.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from IngeoDash.annotate import label_column, flip_label, store, similarity
from IngeoDash.annotate import label_column, flip_label, store, similarity, model
from IngeoDash.config import CONFIG
from microtc.utils import tweet_iterator
from EvoMSA.tests.test_base import TWEETS
Expand Down Expand Up @@ -91,7 +91,6 @@ def test_predict_active_learning():
assert [x['id'] for x in data] != list(range(10, 20))



def test_flip_label():
data = [dict() for i in range(3)]
mem = CONFIG({CONFIG.username: 'xxx'})
Expand Down Expand Up @@ -125,4 +124,14 @@ def test_similarity():
_ = sorted([[tweet['nn'], sim]for tweet, (sim, ) in zip(tweets, sim_values)],
key=lambda x: x[1],
reverse=True)
assert 'Me choca ahorita' in _[0][0]
assert 'Me choca ahorita' in _[0][0]


def test_stack_dense():
    """Check which estimator type `model` returns for small and full datasets.

    With only the first 15 tweets the returned model must be a `DenseBoW`
    and must not be a `StackGeneralization`; with the complete dataset the
    returned model must be a `StackGeneralization`.
    """
    # Imported locally, as in the rest of this test; only the two classes
    # actually checked are brought into scope.
    from EvoMSA import DenseBoW, StackGeneralization

    mem = CONFIG({CONFIG.lang: 'es'})
    tweets = list(tweet_iterator(TWEETS))

    # NOTE(review): presumably `model` falls back to a plain dense model when
    # there are too few examples to stack -- confirm against
    # IngeoDash.annotate.model.
    small = model(mem, tweets[:15])
    # Separate asserts so a failure reports which condition was violated.
    assert isinstance(small, DenseBoW)
    assert not isinstance(small, StackGeneralization)

    full = model(mem, tweets)
    assert isinstance(full, StackGeneralization)
13 changes: 11 additions & 2 deletions IngeoDash/tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
from IngeoDash.config import Config
from IngeoDash.config import CONFIG
from sklearn.svm import LinearSVC


def test_Config():
Expand Down Expand Up @@ -45,7 +46,11 @@ def test_Config():
checklist='checklist',
active_learning='active_learning',
shuffle='shuffle',
labels_proportion='labels_proportion')
labels_proportion='labels_proportion',
voc_size_exponent=15,
voc_selection='most_common_by_type',
estimator_class=LinearSVC,
decision_function_name='decision_function')
for k, v in default.items():
assert v == getattr(conf, k)

Expand All @@ -71,7 +76,11 @@ def test_Config_call():

def test_Config_call2():
mem = CONFIG(dict(label_header='label',
text='texto', n_value=12))
text='texto', n_value=12,
voc_size_exponent=15,
voc_selection='most_common_by_type',
estimator_class=LinearSVC,
decision_function_name='decision_function'))
assert mem.label_header == 'label'
assert mem.text == 'texto'
assert mem.n_value == 12
Expand Down

0 comments on commit c98c238

Please sign in to comment.