Skip to content

Commit

Permalink
Merge pull request #505 from QData/s3-model-fix
Browse files Browse the repository at this point in the history
[FixBug] Fix bug with loading pretrained lstm and cnn models
  • Loading branch information
qiyanjun committed Aug 1, 2021
2 parents 9d7b3b9 + d92203c commit c30728b
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 14 deletions.
1 change: 0 additions & 1 deletion textattack/attacker.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,6 @@ def _attack_parallel(self):
def attack_dataset(self):
"""Attack the dataset.
Returns:
:obj:`list[AttackResult]` - List of :class:`~textattack.attack_results.AttackResult` obtained after attacking the given dataset..
"""
Expand Down
23 changes: 13 additions & 10 deletions textattack/model_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,24 +96,27 @@

#
# Models hosted by textattack.
# `models` vs `models_v2`: `models_v2` is simply a new dir in S3 that contains models' `config.json`.
# Fixes issue https://github.com/QData/TextAttack/issues/485
# Model parameters has not changed.
#
TEXTATTACK_MODELS = {
#
# LSTMs
#
"lstm-ag-news": "models/classification/lstm/ag-news",
"lstm-imdb": "models/classification/lstm/imdb",
"lstm-mr": "models/classification/lstm/mr",
"lstm-sst2": "models/classification/lstm/sst2",
"lstm-yelp": "models/classification/lstm/yelp",
"lstm-ag-news": "models_v2/classification/lstm/ag-news",
"lstm-imdb": "models_v2/classification/lstm/imdb",
"lstm-mr": "models_v2/classification/lstm/mr",
"lstm-sst2": "models_v2/classification/lstm/sst2",
"lstm-yelp": "models_v2/classification/lstm/yelp",
#
# CNNs
#
"cnn-ag-news": "models/classification/cnn/ag-news",
"cnn-imdb": "models/classification/cnn/imdb",
"cnn-mr": "models/classification/cnn/rotten-tomatoes",
"cnn-sst2": "models/classification/cnn/sst",
"cnn-yelp": "models/classification/cnn/yelp",
"cnn-ag-news": "models_v2/classification/cnn/ag-news",
"cnn-imdb": "models_v2/classification/cnn/imdb",
"cnn-mr": "models_v2/classification/cnn/rotten-tomatoes",
"cnn-sst2": "models_v2/classification/cnn/sst",
"cnn-yelp": "models_v2/classification/cnn/yelp",
#
# T5 for translation
#
Expand Down
5 changes: 4 additions & 1 deletion textattack/models/helpers/lstm_for_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,9 @@ def from_pretrained(cls, name_or_path):
"""Load trained LSTM model by name or from path.
Args:
name_or_path (str): Name of the model (e.g. "lstm-imdb") or model saved via `save_pretrained`.
name_or_path (:obj:`str`): Name of the model (e.g. "lstm-imdb") or model saved via :meth:`save_pretrained`.
Returns:
:class:`~textattack.models.helpers.LSTMForClassification` model
"""
if name_or_path in TEXTATTACK_MODELS:
# path = utils.download_if_needed(TEXTATTACK_MODELS[name_or_path])
Expand All @@ -110,6 +112,7 @@ def from_pretrained(cls, name_or_path):
path = name_or_path

config_path = os.path.join(path, "config.json")

if os.path.exists(config_path):
with open(config_path, "r") as f:
config = json.load(f)
Expand Down
7 changes: 5 additions & 2 deletions textattack/models/helpers/word_cnn_for_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,14 +87,17 @@ def from_pretrained(cls, name_or_path):
"""Load trained LSTM model by name or from path.
Args:
name_or_path (str): Name of the model (e.g. "cnn-imdb") or model saved via `save_pretrained`.
name_or_path (:obj:`str`): Name of the model (e.g. "cnn-imdb") or model saved via :meth:`save_pretrained`.
Returns:
:class:`~textattack.models.helpers.WordCNNForClassification` model
"""
if name_or_path != "cnn" and name_or_path in TEXTATTACK_MODELS:
if name_or_path in TEXTATTACK_MODELS:
path = utils.download_from_s3(TEXTATTACK_MODELS[name_or_path])
else:
path = name_or_path

config_path = os.path.join(path, "config.json")

if os.path.exists(config_path):
with open(config_path, "r") as f:
config = json.load(f)
Expand Down

0 comments on commit c30728b

Please sign in to comment.