Merge pull request #609 from duesenfranz/reduce-ram-usage
Only initialize counterfitted_GLOVE_embedding when needed, massively decreasing RAM usage
qiyanjun committed Mar 20, 2022
2 parents 11782c4 + 92374d3 commit 4c91157
Showing 3 changed files with 9 additions and 10 deletions.
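
Background for the change: in Python, a default argument value is evaluated once, when the `def` statement executes at import time, not when the function is called. A default of `WordEmbedding.counterfitted_GLOVE_embedding()` therefore loaded the entire counter-fitted GloVe matrix into memory as soon as the module was imported, even if the class was never instantiated. Switching to a `None` sentinel and loading inside `__init__` defers that cost to first construction. A minimal sketch of the pattern (the `load_big_matrix` helper is hypothetical, standing in for the real loader):

    def load_big_matrix():
        # Stand-in for an expensive load such as counterfitted_GLOVE_embedding().
        print("loading embedding...")
        return [[0.0] * 300 for _ in range(65713)]

    # Eager: the default is evaluated when this `def` runs, i.e. at import time.
    def eager(embedding=load_big_matrix()):
        return embedding

    # Lazy: the None sentinel defers the load until a call actually needs it.
    def lazy(embedding=None):
        if embedding is None:
            embedding = load_big_matrix()
        return embedding

The three diffs below apply this same sentinel pattern to a constraint on thought vectors, a word-embedding-distance constraint, and an embedding-based word-swap transformation.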
6 changes: 3 additions & 3 deletions textattack/constraints/semantics/sentence_encoders/thought_vector.py
@@ -19,9 +19,9 @@ class ThoughtVector(SentenceEncoder):
         word_embedding (textattack.shared.AbstractWordEmbedding): The word embedding to use
     """
 
-    def __init__(
-        self, embedding=WordEmbedding.counterfitted_GLOVE_embedding(), **kwargs
-    ):
+    def __init__(self, embedding=None, **kwargs):
+        if embedding is None:
+            embedding = WordEmbedding.counterfitted_GLOVE_embedding()
         if not isinstance(embedding, AbstractWordEmbedding):
             raise ValueError(
                 "`embedding` object must be of type `textattack.shared.AbstractWordEmbedding`."
4 changes: 3 additions & 1 deletion textattack/constraints/semantics/word_embedding_distance.py
@@ -24,14 +24,16 @@ class WordEmbeddingDistance(Constraint):
 
     def __init__(
         self,
-        embedding=WordEmbedding.counterfitted_GLOVE_embedding(),
+        embedding=None,
         include_unknown_words=True,
         min_cos_sim=None,
        max_mse_dist=None,
         cased=False,
         compare_against_original=True,
     ):
         super().__init__(compare_against_original)
+        if embedding is None:
+            embedding = WordEmbedding.counterfitted_GLOVE_embedding()
         self.include_unknown_words = include_unknown_words
         self.cased = cased
 
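
After this change, importing the constraint's module no longer loads GloVe; the cost moves to construction. For example (assuming TextAttack's public import path for `WordEmbeddingDistance`):

    from textattack.constraints.semantics import WordEmbeddingDistance

    # No embedding is loaded until this constructor runs; before the patch,
    # the import statement above was enough to pull the GloVe matrix into RAM.
    constraint = WordEmbeddingDistance(min_cos_sim=0.8)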
9 changes: 3 additions & 6 deletions textattack/transformations/word_swaps/word_swap_embedding.py
@@ -28,13 +28,10 @@ class WordSwapEmbedding(WordSwap):
     >>> augmenter.augment(s)
     """
 
-    def __init__(
-        self,
-        max_candidates=15,
-        embedding=WordEmbedding.counterfitted_GLOVE_embedding(),
-        **kwargs
-    ):
+    def __init__(self, max_candidates=15, embedding=None, **kwargs):
         super().__init__(**kwargs)
+        if embedding is None:
+            embedding = WordEmbedding.counterfitted_GLOVE_embedding()
         self.max_candidates = max_candidates
         if not isinstance(embedding, AbstractWordEmbedding):
             raise ValueError(
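
Callers are unaffected by the new signatures: omitting `embedding` still yields the counter-fitted GloVe embedding (now loaded lazily), and a custom embedding can be passed as before. A sketch, assuming TextAttack's public import paths:

    from textattack.shared import WordEmbedding
    from textattack.transformations import WordSwapEmbedding

    # Default: the GloVe matrix is loaded here, at construction time.
    swap_default = WordSwapEmbedding(max_candidates=15)

    # Explicit: pass a pre-built embedding; per the isinstance check above,
    # any AbstractWordEmbedding instance is accepted.
    emb = WordEmbedding.counterfitted_GLOVE_embedding()
    swap = WordSwapEmbedding(max_candidates=15, embedding=emb)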
