Skip to content

Commit

Permalink
n_jobs loading dense BoW
Browse files Browse the repository at this point in the history
  • Loading branch information
mgraffg committed Jun 12, 2023
1 parent 6db089f commit 1b234c3
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 19 deletions.
8 changes: 5 additions & 3 deletions EvoMSA/tests/test_text_repr.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,15 +430,17 @@ def test_DenseBoW_extend2():
name = 'emojis'
func = 'most_common_by_type'
d = 13
text_repr = DenseBoW(lang=lang,
text_repr = DenseBoW(lang=lang,
keyword=False,
voc_size_exponent=13,
emoji=True, dataset=False)
emoji=True, dataset=False,
n_jobs=-1)
url = f'{lang}_{MICROTC}_{name}_{func}_{d}.json.gz'
text_repr2 = DenseBoW(lang=lang,
keyword=False,
voc_size_exponent=13,
emoji=False, dataset=False)
emoji=False, dataset=False,
n_jobs=-1)
text_repr2.text_representations_extend(url)
for a, b in zip(text_repr.names, text_repr2.names):
assert a == b
Expand Down
24 changes: 16 additions & 8 deletions EvoMSA/text_repr.py
Original file line number Diff line number Diff line change
Expand Up @@ -690,7 +690,7 @@ def select(self, subset: Union[list, None]=None,
return self.select(subset=index)

def fromjson(self, filename:str) -> 'DenseBoW':
"""Load the text representations from a json file
"""Load the text representations from a json file.
:param filename: Path
:type filename: str
Expand All @@ -700,11 +700,15 @@ def fromjson(self, filename:str) -> 'DenseBoW':
self.text_representations_extend(models)
return self

def text_representations_extend(self, value):
"""Add dense BoW representations."""
def text_representations_extend(self, value: Union[List, str]):
"""Add dense BoW representations.
:param value: List of models or name
:type value: List of models or string
"""
from EvoMSA.utils import load_url
if isinstance(value, str):
value = load_url(value)
value = load_url(value, n_jobs=self._n_jobs)
names = set(self.names)
for x in value:
label = x.labels[-1]
Expand All @@ -724,25 +728,29 @@ def skip_dataset(self, value):

def load_emoji(self) -> None:
if self.v1:
emojis = load_emoji(lang=self.lang, v1=self.v1)
emojis = load_emoji(lang=self.lang, v1=self.v1,
n_jobs=self._n_jobs)
self.text_representations.extend(emojis)
self.names.extend([x.labels[-1] for x in emojis])
else:
data = load_emoji(lang=self.lang,
d=self.voc_size_exponent,
func=self.voc_selection)
func=self.voc_selection,
n_jobs=self._n_jobs)
self.text_representations.extend(data)
self.names.extend([x.labels[-1] for x in data])

def load_keyword(self) -> None:
if self.v1:
_ = load_keyword(lang=self.lang, v1=self.v1)
_ = load_keyword(lang=self.lang, v1=self.v1,
n_jobs=self._n_jobs)
self.text_representations.extend(_)
self.names.extend([x.labels[-1] for x in _])
else:
data = load_keyword(lang=self.lang,
d=self.voc_size_exponent,
func=self.voc_selection)
func=self.voc_selection,
n_jobs=self._n_jobs)
self.text_representations.extend(data)
self.names.extend([x.labels[-1] for x in data])

Expand Down
19 changes: 11 additions & 8 deletions EvoMSA/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from os.path import join, dirname, isdir, isfile
from urllib import request
from urllib.error import HTTPError
from joblib import Parallel, delayed
import numpy as np
import hashlib
import os
Expand Down Expand Up @@ -403,10 +404,11 @@ def predict(self, X: Union[np.ndarray, csr_matrix]) -> np.ndarray:
return self._labels[np.where(hy > 0, 1, 0)]
return np.where(hy > 0, 1, -1)

def load_url(url):
def load_url(url, n_jobs=1):
def load(filename):
try:
return [Linear(**x) for x in tweet_iterator(filename)]
return Parallel(n_jobs=n_jobs)(delayed(Linear)(**x)
for x in tweet_iterator(filename))
except Exception:
os.unlink(filename)

Expand All @@ -418,9 +420,10 @@ def load(filename):
models = load(output)
return models


def _load_text_repr(lang='es', name='emojis',
k=None, d=17, func='most_common_by_type',
v1=False):
v1=False, n_jobs=1):
lang = lang.lower().strip()
assert lang in MODEL_LANG
diroutput = join(dirname(__file__), 'models')
Expand All @@ -430,32 +433,32 @@ def _load_text_repr(lang='es', name='emojis',
filename = f'{lang}_{name}_muTC2.4.2.json.gz'
else:
filename = f'{lang}_{MICROTC}_{name}_{func}_{d}.json.gz'
models = load_url(filename)
models = load_url(filename, n_jobs=n_jobs)
if k is None:
return models
return models[k]


def load_emoji(lang='es', emoji=None,
d=17, func='most_common_by_type',
v1=False):
v1=False, n_jobs=1):

lang = lang.lower().strip()
assert lang in MODEL_LANG
return _load_text_repr(lang, 'emojis',
emoji, d=d, func=func,
v1=v1)
v1=v1, n_jobs=n_jobs)


def load_keyword(lang='es', keyword=None,
d=17, func='most_common_by_type',
v1=False):
v1=False, n_jobs=1):

lang = lang.lower().strip()
assert lang in MODEL_LANG
return _load_text_repr(lang, 'keywords',
keyword, d=d, func=func,
v1=v1)
v1=v1, n_jobs=n_jobs)


def emoji_information(lang='es'):
Expand Down

0 comments on commit 1b234c3

Please sign in to comment.