Skip to content

Commit

Permalink
Merge pull request #653 from VijayKalmath/Fix-TrainingIssues-488
Browse files Browse the repository at this point in the history
Fix training issues 488
  • Loading branch information
jxmorris12 committed Jun 8, 2022
2 parents ea3ae24 + 6b9dc20 commit f59e49a
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 2 deletions.
10 changes: 8 additions & 2 deletions textattack/models/wrappers/huggingface_model_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import transformers

import textattack
from textattack.models.helpers import T5ForTextToText
from textattack.models.tokenizers import T5Tokenizer

from .pytorch_model_wrapper import PyTorchModelWrapper

Expand All @@ -18,11 +20,15 @@ class HuggingFaceModelWrapper(PyTorchModelWrapper):

def __init__(self, model, tokenizer):
assert isinstance(
model, transformers.PreTrainedModel
model, (transformers.PreTrainedModel, T5ForTextToText)
), f"`model` must be of type `transformers.PreTrainedModel`, but got type {type(model)}."
assert isinstance(
tokenizer,
(transformers.PreTrainedTokenizer, transformers.PreTrainedTokenizerFast),
(
transformers.PreTrainedTokenizer,
transformers.PreTrainedTokenizerFast,
T5Tokenizer,
),
), f"`tokenizer` must of type `transformers.PreTrainedTokenizer` or `transformers.PreTrainedTokenizerFast`, but got type {type(tokenizer)}."

self.model = model
Expand Down
36 changes: 36 additions & 0 deletions textattack/training_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,6 +539,42 @@ def _create_dataset_from_args(cls, args):
train_dataset.filter_by_labels_(args.filter_train_by_labels)
if args.filter_eval_by_labels:
eval_dataset.filter_by_labels_(args.filter_eval_by_labels)
# Validate that the dataset labels are compatible with the model's number of output nodes.
num_labels = args.model_num_labels if args.model_num_labels else 2

# Only perform the label checks when both datasets use "label" as their output column.
if (
train_dataset.output_column == "label"
and eval_dataset.output_column == "label"
):

train_dataset_labels = train_dataset._dataset["label"]

eval_dataset_labels = eval_dataset._dataset["label"]

train_dataset_labels_set = set(train_dataset_labels)

assert all(
label >= 0
for label in train_dataset_labels_set
if isinstance(label, int)
), f"Train dataset has negative label/s {[label for label in train_dataset_labels_set if isinstance(label,int) and label < 0 ]} which is/are not supported by pytorch.Use --filter-train-by-labels to keep suitable labels"

assert num_labels >= len(
train_dataset_labels_set
), f"Model constructed has {num_labels} output nodes and train dataset has {len(train_dataset_labels_set)} labels , Model should have output nodes greater than or equal to labels in train dataset.Use --model-num-labels to set model's output nodes."

eval_dataset_labels_set = set(eval_dataset_labels)

assert all(
label >= 0
for label in eval_dataset_labels_set
if isinstance(label, int)
), f"Eval dataset has negative label/s {[label for label in eval_dataset_labels_set if isinstance(label,int) and label < 0 ]} which is/are not supported by pytorch.Use --filter-eval-by-labels to keep suitable labels"

assert num_labels >= len(
set(eval_dataset_labels_set)
), f"Model constructed has {num_labels} output nodes and eval dataset has {len(eval_dataset_labels_set)} labels , Model should have output nodes greater than or equal to labels in eval dataset.Use --model-num-labels to set model's output nodes."

return train_dataset, eval_dataset

Expand Down

0 comments on commit f59e49a

Please sign in to comment.