Merge pull request #420 from QData/multilingual
Addition of optional language/model parameters to constraints (and download method adjustment)
qiyanjun committed Jul 26, 2021
2 parents 1ff364a + a34f5f8 commit ac8872a
Showing 21 changed files with 111 additions and 44 deletions.
2 changes: 1 addition & 1 deletion tests/test_command_line/test_attack.py
@@ -198,4 +198,4 @@ def test_command_line_attack(name, command, sample_output_file):
pdb.set_trace()
assert re.match(desired_re, stdout, flags=re.S)

assert result.returncode == 0
assert result.returncode == 0, "return code not 0"
@@ -31,7 +31,7 @@ class GoogLMHelper:

def __init__(self):
tf.get_logger().setLevel("INFO")
lm_folder = utils.download_if_needed(GoogLMHelper.CACHE_PATH)
lm_folder = utils.download_from_s3(GoogLMHelper.CACHE_PATH)
self.PBTXT_PATH = os.path.join(lm_folder, "graph-2016-09-10-gpu.pbtxt")
self.CKPT_PATH = os.path.join(lm_folder, "ckpt-*")
self.VOCAB_PATH = os.path.join(lm_folder, "vocab-2016-09-10.txt")
9 changes: 6 additions & 3 deletions textattack/constraints/grammaticality/language_models/gpt2.py
@@ -22,16 +22,19 @@ class GPT2(LanguageModelConstraint):
from "Better Language Models and Their Implications"
(openai.com/blog/better-language-models/)
Args:
model_name: ID of the GPT-2 model to load (default "gpt2")
"""

def __init__(self, **kwargs):
def __init__(self, model_name="gpt2", **kwargs):
import transformers

# re-enable notifications
os.environ["WANDB_SILENT"] = "0"
self.model = transformers.GPT2LMHeadModel.from_pretrained("gpt2")
self.model = transformers.GPT2LMHeadModel.from_pretrained(model_name)
self.model.to(utils.device)
self.tokenizer = transformers.GPT2Tokenizer.from_pretrained("gpt2")
self.tokenizer = transformers.GPT2Tokenizer.from_pretrained(model_name)
super().__init__(**kwargs)

def get_log_probs_at_index(self, text_list, word_index):
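For illustration, a minimal usage sketch of the new model_name parameter (not part of the diff; "gpt2-medium" is just one example checkpoint id, and the max_log_prob_diff keyword is assumed to come from the LanguageModelConstraint base class):

from textattack.constraints.grammaticality.language_models import GPT2

# Any GPT-2-family checkpoint id understood by transformers should work; "gpt2" stays the default.
constraint = GPT2(model_name="gpt2-medium", max_log_prob_diff=2.0)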
@@ -36,7 +36,7 @@ class LearningToWriteLanguageModel(LanguageModelConstraint):

def __init__(self, window_size=5, **kwargs):
self.window_size = window_size
lm_folder_path = textattack.shared.utils.download_if_needed(
lm_folder_path = textattack.shared.utils.download_from_s3(
LearningToWriteLanguageModel.CACHE_PATH
)
self.query_handler = QueryHandler.load_model(
7 changes: 5 additions & 2 deletions textattack/constraints/grammaticality/language_tool.py
@@ -16,11 +16,14 @@ class LanguageTool(Constraint):
relative to `x`
compare_against_original (bool): If `True`, compare against the original text.
Otherwise, compare against the most recent text.
language: language to use for LanguageTool (available choices: https://dev.languagetool.org/languages)
"""

def __init__(self, grammar_error_threshold=0, compare_against_original=True):
def __init__(
self, grammar_error_threshold=0, compare_against_original=True, language="en-US"
):
super().__init__(compare_against_original)
self.lang_tool = language_tool_python.LanguageTool("en-US")
self.lang_tool = language_tool_python.LanguageTool(language)
self.grammar_error_threshold = grammar_error_threshold
self.grammar_error_cache = {}

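A short usage sketch of the new language argument (illustrative only; "de-DE" is an example code taken from the LanguageTool language list linked above):

from textattack.constraints.grammaticality import LanguageTool

# Check grammar with LanguageTool's German rules instead of the default "en-US".
constraint = LanguageTool(grammar_error_threshold=0, language="de-DE")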
16 changes: 14 additions & 2 deletions textattack/constraints/grammaticality/part_of_speech.py
@@ -33,6 +33,10 @@ class PartOfSpeech(Constraint):
allow_verb_noun_swap (bool): If `True`, allow verbs to be swapped with nouns and vice versa.
compare_against_original (bool): If `True`, compare against the original text.
Otherwise, compare against the most recent text.
language_nltk: Language to be used for nltk POS-Tagger
(available choices: "eng", "rus")
language_stanza: Language to be used for stanza POS-Tagger
(available choices: https://stanfordnlp.github.io/stanza/available_models.html)
"""

def __init__(
@@ -41,11 +45,15 @@ def __init__(
tagset="universal",
allow_verb_noun_swap=True,
compare_against_original=True,
language_nltk="eng",
language_stanza="en",
):
super().__init__(compare_against_original)
self.tagger_type = tagger_type
self.tagset = tagset
self.allow_verb_noun_swap = allow_verb_noun_swap
self.language_nltk = language_nltk
self.language_stanza = language_stanza

self._pos_tag_cache = lru.LRU(2 ** 14)
if tagger_type == "flair":
@@ -56,7 +64,9 @@ def __init__(

if tagger_type == "stanza":
self._stanza_pos_tagger = stanza.Pipeline(
lang="en", processors="tokenize, pos", tokenize_pretokenized=True
lang=self.language_stanza,
processors="tokenize, pos",
tokenize_pretokenized=True,
)

def clear_cache(self):
@@ -75,7 +85,9 @@ def _get_pos(self, before_ctx, word, after_ctx):
else:
if self.tagger_type == "nltk":
word_list, pos_list = zip(
*nltk.pos_tag(context_words, tagset=self.tagset)
*nltk.pos_tag(
context_words, tagset=self.tagset, lang=self.language_nltk
)
)

if self.tagger_type == "flair":
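A sketch of the two new language knobs (illustrative; it assumes the Russian nltk tagger data and the stanza "ru" model can be downloaded in your environment):

from textattack.constraints.grammaticality import PartOfSpeech

# Use nltk's Russian POS tagger ...
pos_nltk = PartOfSpeech(tagger_type="nltk", language_nltk="rus")

# ... or a stanza pipeline built for Russian.
pos_stanza = PartOfSpeech(tagger_type="stanza", language_stanza="ru")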
@@ -11,7 +11,7 @@

class MaxWordIndexModification(PreTransformationConstraint):
"""A constraint disallowing the modification of words which are past some
maximum length limit."""
maximum sentence word-length limit."""

def __init__(self, max_length):
self.max_length = max_length
4 changes: 2 additions & 2 deletions textattack/constraints/pre_transformation/min_word_length.py
@@ -10,9 +10,9 @@

class MinWordLength(PreTransformationConstraint):
"""A constraint that prevents modifications to words less than a certain
length.
word character-length.
:param min_length: Minimum length needed for changes to be made to a word.
:param min_length: Minimum word character-length needed for changes to be made to a word.
"""

def __init__(self, min_length):
@@ -14,11 +14,11 @@
class StopwordModification(PreTransformationConstraint):
"""A constraint disallowing the modification of stopwords."""

def __init__(self, stopwords=None):
def __init__(self, stopwords=None, language="english"):
if stopwords is not None:
self.stopwords = set(stopwords)
else:
self.stopwords = set(nltk.corpus.stopwords.words("english"))
self.stopwords = set(nltk.corpus.stopwords.words(language))

def _get_modifiable_indices(self, current_text):
"""Returns the word indices in ``current_text`` which are able to be
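Usage sketch for the new language argument (illustrative; any corpus name accepted by nltk.corpus.stopwords.words, e.g. "french", should work):

from textattack.constraints.pre_transformation import StopwordModification

# Protect French stopwords from modification instead of the default English list.
constraint = StopwordModification(language="french")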
10 changes: 6 additions & 4 deletions textattack/constraints/semantics/bert_score.py
@@ -22,7 +22,8 @@ class BERTScore(Constraint):
Args:
min_bert_score (float), minimum threshold value for BERT-Score
model (str), name of model to use for scoring
model_name (str), name of model to use for scoring
num_layers (int), number of hidden layers in the model
score_type (str), Pick one of following three choices
-(1) ``precision`` : match words from candidate text to reference text
@@ -39,7 +40,8 @@ class BERTScore(Constraint):
def __init__(
self,
min_bert_score,
model="bert-base-uncased",
model_name="bert-base-uncased",
num_layers=None,
score_type="f1",
compare_against_original=True,
):
@@ -50,11 +52,11 @@ def __init__(
raise ValueError("max_bert_score must be a value between 0.0 and 1.0")

self.min_bert_score = min_bert_score
self.model = model
self.model = model_name
self.score_type = score_type
# Turn off idf-weighting scheme b/c reference sentence set is small
self._bert_scorer = bert_score.BERTScorer(
model_type=model, idf=False, device=utils.device
model_type=model_name, idf=False, device=utils.device, num_layers=num_layers
)

def _check_constraint(self, transformed_text, reference_text):
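A hedged usage sketch of the renamed model_name parameter and the new num_layers parameter (the multilingual checkpoint and layer index below are illustrative choices, not defaults from this diff):

from textattack.constraints.semantics import BERTScore

# Score similarity with a multilingual BERT; num_layers picks the hidden layer bert-score reads from.
constraint = BERTScore(
    min_bert_score=0.8,
    model_name="bert-base-multilingual-cased",
    num_layers=9,
)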
11 changes: 8 additions & 3 deletions textattack/constraints/semantics/sentence_encoders/bert/bert.py
@@ -14,13 +14,18 @@
class BERT(SentenceEncoder):
"""Constraint using similarity between sentence encodings of x and x_adv
where the text embeddings are created using BERT, trained on NLI data, and
fine-tuned on the STS benchmark dataset."""
fine-tuned on the STS benchmark dataset.
Available models can be found here: https://huggingface.co/sentence-transformers"""

def __init__(
self, pretrained_name="stsb-bert-base", threshold=0.7, metric="cosine", **kwargs
self,
threshold=0.7,
metric="cosine",
model_name="bert-base-nli-stsb-mean-tokens",
**kwargs
):
super().__init__(threshold=threshold, metric=metric, **kwargs)
self.model = sentence_transformers.SentenceTransformer(pretrained_name)
self.model = sentence_transformers.SentenceTransformer(model_name)
self.model.to(utils.device)

def encode(self, sentences):
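Usage sketch for the renamed model_name argument (the multilingual sentence-transformers checkpoint named below is an illustrative assumption; any model from the sentence-transformers hub linked above should load the same way):

from textattack.constraints.semantics.sentence_encoders import BERT

# Encode sentences with a multilingual sentence-transformers model and compare by cosine similarity.
constraint = BERT(
    model_name="distiluse-base-multilingual-cased",
    threshold=0.7,
    metric="cosine",
)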
@@ -32,7 +32,7 @@ def get_infersent_model(self):
The pretrained InferSent model.
"""
infersent_version = 2
model_folder_path = utils.download_if_needed(InferSent.MODEL_PATH)
model_folder_path = utils.download_from_s3(InferSent.MODEL_PATH)
model_path = os.path.join(
model_folder_path, f"infersent{infersent_version}.pkl"
)
@@ -46,7 +46,7 @@ def get_infersent_model(self):
}
infersent = InferSentModel(params_model)
infersent.load_state_dict(torch.load(model_path))
word_embedding_path = utils.download_if_needed(InferSent.WORD_EMBEDDING_PATH)
word_embedding_path = utils.download_from_s3(InferSent.WORD_EMBEDDING_PATH)
w2v_path = os.path.join(word_embedding_path, "fastText", "crawl-300d-2M.vec")
infersent.set_w2v_path(w2v_path)
infersent.build_vocab_k_words(K=100000)
@@ -3,6 +3,7 @@
------------------------
"""

from abc import ABC
import math

import numpy as np
@@ -11,7 +12,7 @@
from textattack.constraints import Constraint


class SentenceEncoder(Constraint):
class SentenceEncoder(Constraint, ABC):
"""Constraint using cosine similarity between sentence encodings of x and
x_adv.
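Declaring SentenceEncoder abstract documents that it is never instantiated directly; concrete subclasses such as BERT supply the actual encode method. A toy sketch of what a subclass must provide (hypothetical, for illustration only):

import numpy as np

from textattack.constraints.semantics.sentence_encoders import SentenceEncoder


class CharCountEncoder(SentenceEncoder):
    """Toy encoder: represents each sentence by its character count (illustration only)."""

    def encode(self, sentences):
        # Real subclasses return dense embeddings of shape (num_sentences, embedding_dim).
        return np.array([[float(len(s))] for s in sentences])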
@@ -2,14 +2,12 @@
multilingual universal sentence encoder
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
"""
import tensorflow_text # noqa: F401

from textattack.constraints.semantics.sentence_encoders import SentenceEncoder
from textattack.shared.utils import LazyLoader

hub = LazyLoader("tensorflow_hub", globals(), "tensorflow_hub")
tensorflow_text = LazyLoader(
"tensorflow_text", globals(), "tensorflow_text"
) # noqa: F401


class MultilingualUniversalSentenceEncoder(SentenceEncoder):
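This file only changes how tensorflow_text is imported; usage of the constraint itself is unchanged. A usage sketch (the threshold and metric keywords are assumed to mirror the other sentence-encoder constraints):

from textattack.constraints.semantics.sentence_encoders import (
    MultilingualUniversalSentenceEncoder,
)

# Requires tensorflow, tensorflow_hub, and tensorflow_text to be installed.
constraint = MultilingualUniversalSentenceEncoder(threshold=0.8, metric="angular")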
2 changes: 1 addition & 1 deletion textattack/models/helpers/glove_embedding_layer.py
@@ -89,7 +89,7 @@ class GloveEmbeddingLayer(EmbeddingLayer):
EMBEDDING_PATH = "word_embeddings/glove200"

def __init__(self, emb_layer_trainable=True):
glove_path = utils.download_if_needed(GloveEmbeddingLayer.EMBEDDING_PATH)
glove_path = utils.download_from_s3(GloveEmbeddingLayer.EMBEDDING_PATH)
glove_word_list_path = os.path.join(glove_path, "glove.wordlist.npy")
word_list = np.load(glove_word_list_path)
glove_matrix_path = os.path.join(glove_path, "glove.6B.200d.mat.npy")
3 changes: 2 additions & 1 deletion textattack/models/helpers/lstm_for_classification.py
@@ -104,7 +104,8 @@ def from_pretrained(cls, name_or_path):
name_or_path (str): Name of the model (e.g. "lstm-imdb") or model saved via `save_pretrained`.
"""
if name_or_path in TEXTATTACK_MODELS:
path = utils.download_if_needed(TEXTATTACK_MODELS[name_or_path])
# path = utils.download_if_needed(TEXTATTACK_MODELS[name_or_path])
path = utils.download_from_s3(TEXTATTACK_MODELS[name_or_path])
else:
path = name_or_path

2 changes: 1 addition & 1 deletion textattack/models/helpers/word_cnn_for_classification.py
@@ -90,7 +90,7 @@ def from_pretrained(cls, name_or_path):
name_or_path (str): Name of the model (e.g. "cnn-imdb") or model saved via `save_pretrained`.
"""
if name_or_path != "cnn" and name_or_path in TEXTATTACK_MODELS:
path = utils.download_if_needed(TEXTATTACK_MODELS[name_or_path])
path = utils.download_from_s3(TEXTATTACK_MODELS[name_or_path])
else:
path = name_or_path

62 changes: 52 additions & 10 deletions textattack/shared/utils/install.py
@@ -26,12 +26,13 @@ def s3_url(uri):
return "https://textattack.s3.amazonaws.com/" + uri


def download_if_needed(folder_name):
"""Folder name will be saved as `.cache/textattack/[folder name]`. If it
def download_from_s3(folder_name, skip_if_cached=True):
"""Folder name will be saved as `<cache_dir>/textattack/<folder_name>`. If it
doesn't exist on disk, the zip file will be downloaded and extracted.
Args:
folder_name (str): path to folder or file in cache
skip_if_cached (bool): If `True`, skip downloading if content is already cached.
Returns:
str: path to the downloaded folder or file on disk
@@ -43,14 +44,15 @@ def download_if_needed(folder_name):
cache_file_lock = filelock.FileLock(cache_dest_lock_path)
cache_file_lock.acquire()
# Check if already downloaded.
if os.path.exists(cache_dest_path):
if skip_if_cached and os.path.exists(cache_dest_path):
cache_file_lock.release()
return cache_dest_path
# If the file isn't found yet, download the zip file to the cache.
downloaded_file = tempfile.NamedTemporaryFile(
dir=TEXTATTACK_CACHE_DIR, suffix=".zip", delete=False
)
http_get(folder_name, downloaded_file)
folder_s3_url = s3_url(folder_name)
http_get(folder_s3_url, downloaded_file)
# Move or unzip the file.
downloaded_file.close()
if zipfile.is_zipfile(downloaded_file.name):
@@ -65,6 +67,47 @@ def download_if_needed(folder_name):
return cache_dest_path


def download_from_url(url, save_path, skip_if_cached=True):
"""Downloaded file will be saved under `<cache_dir>/textattack/<save_path>`. If it
doesn't exist on disk, the zip file will be downloaded and extracted.
Args:
url (str): URL path from which to download.
save_path (str): path to which to save the downloaded content.
skip_if_cached (bool): If `True`, skip downloading if content is already cached.
Returns:
str: path to the downloaded folder or file on disk
"""
cache_dest_path = path_in_cache(save_path)
os.makedirs(os.path.dirname(cache_dest_path), exist_ok=True)
# Use a lock to prevent concurrent downloads.
cache_dest_lock_path = cache_dest_path + ".lock"
cache_file_lock = filelock.FileLock(cache_dest_lock_path)
cache_file_lock.acquire()
# Check if already downloaded.
if skip_if_cached and os.path.exists(cache_dest_path):
cache_file_lock.release()
return cache_dest_path
# If the file isn't found yet, download the zip file to the cache.
downloaded_file = tempfile.NamedTemporaryFile(
dir=TEXTATTACK_CACHE_DIR, suffix=".zip", delete=False
)
http_get(url, downloaded_file)
# Move or unzip the file.
downloaded_file.close()
if zipfile.is_zipfile(downloaded_file.name):
unzip_file(downloaded_file.name, cache_dest_path)
else:
logger.info(f"Copying {downloaded_file.name} to {cache_dest_path}.")
shutil.copyfile(downloaded_file.name, cache_dest_path)
cache_file_lock.release()
# Remove the temporary file.
os.remove(downloaded_file.name)
logger.info(f"Successfully saved {url} to cache.")
return cache_dest_path


def unzip_file(path_to_zip_file, unzipped_folder_path):
"""Unzips a .zip file to folder path."""
logger.info(f"Unzipping file {path_to_zip_file} to {unzipped_folder_path}.")
@@ -73,18 +116,17 @@ def unzip_file(path_to_zip_file, unzipped_folder_path):
zip_ref.extractall(enclosing_unzipped_path)


def http_get(folder_name, out_file, proxies=None):
def http_get(url, out_file, proxies=None):
"""Get contents of a URL and save to a file.
https://github.com/huggingface/transformers/blob/master/src/transformers/file_utils.py
"""
folder_s3_url = s3_url(folder_name)
logger.info(f"Downloading {folder_s3_url}.")
req = requests.get(folder_s3_url, stream=True, proxies=proxies)
logger.info(f"Downloading {url}.")
req = requests.get(url, stream=True, proxies=proxies)
content_length = req.headers.get("Content-Length")
total = int(content_length) if content_length is not None else None
if req.status_code == 403: # Not found on AWS
raise Exception(f"Could not find {folder_name} on server.")
if req.status_code == 403 or req.status_code == 404:
raise Exception(f"Could not reach {url}.")
progress = tqdm.tqdm(unit="B", unit_scale=True, total=total)
for chunk in req.iter_content(chunk_size=1024):
if chunk: # filter out keep-alive new chunks
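A sketch of how the renamed and new helpers are intended to be called (the GloVe key comes from the diff above; the external URL and save path are hypothetical placeholders):

from textattack.shared import utils

# Fetch a folder from the TextAttack S3 bucket, reusing the cached copy if present.
glove_path = utils.download_from_s3("word_embeddings/glove200")

# Fetch an arbitrary URL into the cache; pass skip_if_cached=False to force a re-download.
data_path = utils.download_from_url(
    "https://example.com/some_resource.zip",  # hypothetical URL
    "custom/some_resource",
)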
2 changes: 1 addition & 1 deletion textattack/shared/utils/misc.py
@@ -60,7 +60,7 @@ def html_table_from_rows(rows, title=None, header=None, style_dict=None):
def get_textattack_model_num_labels(model_name, model_path):
"""Reads `train_args.json` and gets the number of labels for a trained
model, if present."""
model_cache_path = textattack.shared.utils.download_if_needed(model_path)
model_cache_path = textattack.shared.utils.download_from_s3(model_path)
train_args_path = os.path.join(model_cache_path, "train_args.json")
if not os.path.exists(train_args_path):
textattack.shared.logger.warn(
