Skip to content

Commit

Permalink
Add dense BoW from url
Browse files Browse the repository at this point in the history
  • Loading branch information
mgraffg committed Jun 7, 2023
1 parent befcb99 commit cfb25a7
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 13 deletions.
2 changes: 1 addition & 1 deletion EvoMSA/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__version__ = '1.9.5'
__version__ = '1.9.6'

try:
from EvoMSA.text_repr import BoW, TextRepresentations, StackGeneralization, DenseBoW
Expand Down
22 changes: 22 additions & 0 deletions EvoMSA/tests/test_text_repr.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,3 +420,25 @@ def test_DenseBoW_skip_dataset():
emoji=False, dataset=True,
skip_dataset=keys)
assert (length - 3) == len(text_repr.names)


def test_DenseBoW_extend2():
from EvoMSA.text_repr import DenseBoW
from EvoMSA.utils import MICROTC

lang = 'es'
name = 'emojis'
func = 'most_common_by_type'
d = 13
text_repr = DenseBoW(lang=lang,
keyword=False,
voc_size_exponent=13,
emoji=True, dataset=False)
url = f'{lang}_{MICROTC}_{name}_{func}_{d}.json.gz'
text_repr2 = DenseBoW(lang=lang,
keyword=False,
voc_size_exponent=13,
emoji=False, dataset=False)
text_repr2.text_representations_extend(url)
for a, b in zip(text_repr.names, text_repr2.names):
assert a == b
4 changes: 4 additions & 0 deletions EvoMSA/text_repr.py
Original file line number Diff line number Diff line change
Expand Up @@ -701,6 +701,10 @@ def fromjson(self, filename:str) -> 'DenseBoW':
return self

def text_representations_extend(self, value):
"""Add dense BoW representations."""
from EvoMSA.utils import load_url
if isinstance(value, str):
value = load_url(value)
names = set(self.names)
for x in value:
label = x.labels[-1]
Expand Down
26 changes: 14 additions & 12 deletions EvoMSA/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,18 +403,24 @@ def predict(self, X: Union[np.ndarray, csr_matrix]) -> np.ndarray:
return self._labels[np.where(hy > 0, 1, 0)]
return np.where(hy > 0, 1, -1)


def _load_text_repr(lang='es', name='emojis',
k=None, d=17, func='most_common_by_type',
v1=False):
import os
from os.path import isdir, join, isfile, dirname
from urllib.error import HTTPError
def load_url(url):
def load(filename):
try:
return [Linear(**x) for x in tweet_iterator(filename)]
except Exception:
os.unlink(filename)

diroutput = join(dirname(__file__), 'models')
output = join(diroutput, url)
url = f'{BASEURL}/{url}'
if not isfile(output):
Download(url, output)
models = load(output)
return models

def _load_text_repr(lang='es', name='emojis',
k=None, d=17, func='most_common_by_type',
v1=False):
lang = lang.lower().strip()
assert lang in MODEL_LANG
diroutput = join(dirname(__file__), 'models')
Expand All @@ -424,11 +430,7 @@ def load(filename):
filename = f'{lang}_{name}_muTC2.4.2.json.gz'
else:
filename = f'{lang}_{MICROTC}_{name}_{func}_{d}.json.gz'
url = f'{BASEURL}/{filename}'
output = join(diroutput, filename)
if not isfile(output):
Download(url, output)
models = load(output)
models = load_url(filename)
if k is None:
return models
return models[k]
Expand Down

0 comments on commit cfb25a7

Please sign in to comment.