Merge pull request #747 from QData/fit_loggerdf
fixing the csvlogger missing DF issues
qiyanjun committed Sep 11, 2023
2 parents f848247 + ce2eae3 commit 094025e
Showing 18 changed files with 6,905 additions and 6,036 deletions.
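
The notebook changes below do two things: they apply consistent black-style formatting to the tutorial code, and they switch the final logging example to read the CSVLogger that Attacker itself manages, which is what leaves logger.df populated after a run. A minimal sketch of the corrected pattern, assuming an `attack` and `dataset` built as in the tutorial cells shown in the diff:

    from textattack import Attacker, AttackArgs

    # Ask the Attacker to log to CSV; its log manager then owns a CSVLogger.
    attack_args = AttackArgs(
        num_successful_examples=5, log_to_csv="results.csv", csv_coloring_style="html"
    )
    attacker = Attacker(attack, dataset, attack_args)
    attacker.attack_dataset()

    # Read results back from the Attacker's own logger; its df is now filled in.
    logger = attacker.attack_log_manager.loggers[0]
    successes = logger.df[logger.df["result_type"] == "Successful"]
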
1 change: 1 addition & 0 deletions .gitignore
@@ -48,3 +48,4 @@ checkpoints/
 .vscode
 *.csv
 !tests/sample_outputs/csv_attack_log.csv
+tests/test_command_line/attack_log.txt
76 changes: 44 additions & 32 deletions docs/2notebook/1_Introduction_and_Transformations.ipynb
@@ -77,19 +77,19 @@
 "source": [
 "from textattack.transformations import WordSwap\n",
 "\n",
+"\n",
 "class BananaWordSwap(WordSwap):\n",
-" \"\"\" Transforms an input by replacing any word with 'banana'.\n",
-" \"\"\"\n",
-" \n",
+" \"\"\"Transforms an input by replacing any word with 'banana'.\"\"\"\n",
+"\n",
 " # We don't need a constructor, since our class doesn't require any parameters.\n",
 "\n",
 " def _get_replacement_words(self, word):\n",
-" \"\"\" Returns 'banana', no matter what 'word' was originally.\n",
-" \n",
-" Returns a list with one item, since `_get_replacement_words` is intended to\n",
-" return a list of candidate replacement words.\n",
+" \"\"\"Returns 'banana', no matter what 'word' was originally.\n",
+"\n",
+" Returns a list with one item, since `_get_replacement_words` is intended to\n",
+" return a list of candidate replacement words.\n",
 " \"\"\"\n",
-" return ['banana']"
+" return [\"banana\"]"
 ]
 },
 {
@@ -133,17 +133,23 @@
 "import transformers\n",
 "from textattack.models.wrappers import HuggingFaceModelWrapper\n",
 "\n",
-"model = transformers.AutoModelForSequenceClassification.from_pretrained(\"textattack/bert-base-uncased-ag-news\")\n",
-"tokenizer = transformers.AutoTokenizer.from_pretrained(\"textattack/bert-base-uncased-ag-news\")\n",
+"model = transformers.AutoModelForSequenceClassification.from_pretrained(\n",
+" \"textattack/bert-base-uncased-ag-news\"\n",
+")\n",
+"tokenizer = transformers.AutoTokenizer.from_pretrained(\n",
+" \"textattack/bert-base-uncased-ag-news\"\n",
+")\n",
 "\n",
 "model_wrapper = HuggingFaceModelWrapper(model, tokenizer)\n",
 "\n",
 "# Create the goal function using the model\n",
 "from textattack.goal_functions import UntargetedClassification\n",
+"\n",
 "goal_function = UntargetedClassification(model_wrapper)\n",
 "\n",
 "# Import the dataset\n",
 "from textattack.datasets import HuggingFaceDataset\n",
+"\n",
 "dataset = HuggingFaceDataset(\"ag_news\", None, \"test\")"
 ]
 },
@@ -166,14 +172,16 @@
 "outputs": [],
 "source": [
 "from textattack.search_methods import GreedySearch\n",
-"from textattack.constraints.pre_transformation import RepeatModification, StopwordModification\n",
+"from textattack.constraints.pre_transformation import (\n",
+" RepeatModification,\n",
+" StopwordModification,\n",
+")\n",
 "from textattack import Attack\n",
 "\n",
 "# We're going to use our Banana word swap class as the attack transformation.\n",
-"transformation = BananaWordSwap() \n",
+"transformation = BananaWordSwap()\n",
 "# We'll constrain modification of already modified indices and stopwords\n",
-"constraints = [RepeatModification(),\n",
-" StopwordModification()]\n",
+"constraints = [RepeatModification(), StopwordModification()]\n",
 "# We'll use the Greedy search method\n",
 "search_method = GreedySearch()\n",
 "# Now, let's make the attack from the 4 components:\n",
@@ -517,8 +525,8 @@
 }
 ],
 "source": [
-"from tqdm import tqdm # tqdm provides us a nice progress bar.\n",
-"from textattack.loggers import CSVLogger # tracks a dataframe for us.\n",
+"from tqdm import tqdm  # tqdm provides us a nice progress bar.\n",
+"from textattack.loggers import CSVLogger  # tracks a dataframe for us.\n",
 "from textattack.attack_results import SuccessfulAttackResult\n",
 "from textattack import Attacker\n",
 "from textattack import AttackArgs\n",
@@ -530,14 +538,14 @@
 "\n",
 "attack_results = attacker.attack_dataset()\n",
 "\n",
-"#The following legacy tutorial code shows how the Attack API works in detail.\n",
+"# The following legacy tutorial code shows how the Attack API works in detail.\n",
 "\n",
-"#logger = CSVLogger(color_method='html')\n",
+"# logger = CSVLogger(color_method='html')\n",
 "\n",
-"#num_successes = 0\n",
-"#i = 0\n",
-"#while num_successes < 10:\n",
-" #result = next(results_iterable)\n",
+"# num_successes = 0\n",
+"# i = 0\n",
+"# while num_successes < 10:\n",
+"# result = next(results_iterable)\n",
 "# example, ground_truth_output = dataset[i]\n",
 "# i += 1\n",
 "# result = attack.attack(example, ground_truth_output)\n",
@@ -652,15 +660,19 @@
 ],
 "source": [
 "import pandas as pd\n",
-"pd.options.display.max_colwidth = 480 # increase colum width so we can actually read the examples\n",
 "\n",
-"logger = CSVLogger(color_method='html')\n",
+"pd.options.display.max_colwidth = (\n",
+" 480 # increase colum width so we can actually read the examples\n",
+")\n",
+"\n",
+"logger = CSVLogger(color_method=\"html\")\n",
 "\n",
 "for result in attack_results:\n",
 " logger.log_attack_result(result)\n",
 "\n",
 "from IPython.core.display import display, HTML\n",
-"display(HTML(logger.df[['original_text', 'perturbed_text']].to_html(escape=False)))"
+"\n",
+"display(HTML(logger.df[[\"original_text\", \"perturbed_text\"]].to_html(escape=False)))"
 ]
 },
 {
@@ -867,10 +879,10 @@
 "# For AG News, labels are 0: World, 1: Sports, 2: Business, 3: Sci/Tech\n",
 "\n",
 "custom_dataset = [\n",
-" ('Malaria deaths in Africa fall by 5% from last year', 0),\n",
-" ('Washington Nationals defeat the Houston Astros to win the World Series', 1),\n",
-" ('Exxon Mobil hires a new CEO', 2),\n",
-" ('Microsoft invests $1 billion in OpenAI', 3),\n",
+" (\"Malaria deaths in Africa fall by 5% from last year\", 0),\n",
+" (\"Washington Nationals defeat the Houston Astros to win the World Series\", 1),\n",
+" (\"Exxon Mobil hires a new CEO\", 2),\n",
+" (\"Microsoft invests $1 billion in OpenAI\", 3),\n",
 "]\n",
 "\n",
 "attack_args = AttackArgs(num_examples=4)\n",
@@ -881,14 +893,14 @@
 "\n",
 "results_iterable = attacker.attack_dataset()\n",
 "\n",
-"logger = CSVLogger(color_method='html')\n",
+"logger = CSVLogger(color_method=\"html\")\n",
 "\n",
 "for result in results_iterable:\n",
 " logger.log_attack_result(result)\n",
 "\n",
 "from IPython.core.display import display, HTML\n",
-" \n",
-"display(HTML(logger.df[['original_text', 'perturbed_text']].to_html(escape=False)))"
+"\n",
+"display(HTML(logger.df[[\"original_text\", \"perturbed_text\"]].to_html(escape=False)))"
 ]
 }
 ],
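
Taken together, the cells changed in this file compose a complete attack from four components. A sketch assembled from the diff above, assuming the BananaWordSwap class and model_wrapper defined in those cells:

    from textattack import Attack, Attacker, AttackArgs
    from textattack.constraints.pre_transformation import (
        RepeatModification,
        StopwordModification,
    )
    from textattack.datasets import HuggingFaceDataset
    from textattack.goal_functions import UntargetedClassification
    from textattack.search_methods import GreedySearch

    # An Attack is goal function + constraints + transformation + search method.
    attack = Attack(
        UntargetedClassification(model_wrapper),  # model_wrapper as built above
        [RepeatModification(), StopwordModification()],
        BananaWordSwap(),  # the custom transformation defined above
        GreedySearch(),
    )
    dataset = HuggingFaceDataset("ag_news", None, "test")
    attacker = Attacker(attack, dataset, AttackArgs(num_examples=4))
    attack_results = attacker.attack_dataset()
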
86 changes: 54 additions & 32 deletions docs/2notebook/2_Constraints.ipynb
@@ -100,6 +100,7 @@
 ],
 "source": [
 "import tensorflow as tf\n",
+"\n",
 "print(tf.__version__)"
 ]
 },
@@ -149,10 +150,11 @@
 "!pip3 install .\n",
 "\n",
 "import nltk\n",
-"nltk.download('punkt') # The NLTK tokenizer\n",
-"nltk.download('maxent_ne_chunker') # NLTK named-entity chunker\n",
-"nltk.download('words') # NLTK list of words\n",
-"nltk.download('averaged_perceptron_tagger')"
+"\n",
+"nltk.download(\"punkt\")  # The NLTK tokenizer\n",
+"nltk.download(\"maxent_ne_chunker\")  # NLTK named-entity chunker\n",
+"nltk.download(\"words\")  # NLTK list of words\n",
+"nltk.download(\"averaged_perceptron_tagger\")"
 ]
 },
 {
@@ -205,8 +207,10 @@
 }
 ],
 "source": [
-"sentence = ('In 2017, star quarterback Tom Brady led the Patriots to the Super Bowl, '\n",
-" 'but lost to the Philadelphia Eagles.')\n",
+"sentence = (\n",
+" \"In 2017, star quarterback Tom Brady led the Patriots to the Super Bowl, \"\n",
+" \"but lost to the Philadelphia Eagles.\"\n",
+")\n",
 "\n",
 "# 1. Tokenize using the NLTK tokenizer.\n",
 "tokens = nltk.word_tokenize(sentence)\n",
@@ -285,6 +289,7 @@
 "source": [
 "import functools\n",
 "\n",
+"\n",
 "@functools.lru_cache(maxsize=2**14)\n",
 "def get_entities(sentence):\n",
 " tokens = nltk.word_tokenize(sentence)\n",
@@ -379,9 +384,10 @@
 "source": [
 "from textattack.constraints import Constraint\n",
 "\n",
+"\n",
 "class NamedEntityConstraint(Constraint):\n",
-" \"\"\" A constraint that ensures `transformed_text` only substitutes named entities from `current_text` with other named entities.\n",
-" \"\"\"\n",
+" \"\"\"A constraint that ensures `transformed_text` only substitutes named entities from `current_text` with other named entities.\"\"\"\n",
+"\n",
 " def _check_constraint(self, transformed_text, current_text):\n",
 " transformed_entities = get_entities(transformed_text.text)\n",
 " current_entities = get_entities(current_text.text)\n",
@@ -390,26 +396,27 @@
 " if len(current_entities) == 0:\n",
 " return False\n",
 " if len(current_entities) != len(transformed_entities):\n",
-" # If the two sentences have a different number of entities, then \n",
-" # they definitely don't have the same labels. In this case, the \n",
+" # If the two sentences have a different number of entities, then\n",
+" # they definitely don't have the same labels. In this case, the\n",
 " # constraint is violated, and we return False.\n",
 " return False\n",
 " else:\n",
 " # Here we compare all of the words, in order, to make sure that they match.\n",
-" # If we find two words that don't match, this means a word was swapped \n",
+" # If we find two words that don't match, this means a word was swapped\n",
 " # between `current_text` and `transformed_text`. That word must be a named entity to fulfill our\n",
 " # constraint.\n",
 " current_word_label = None\n",
 " transformed_word_label = None\n",
-" for (word_1, label_1), (word_2, label_2) in zip(current_entities, transformed_entities):\n",
+" for (word_1, label_1), (word_2, label_2) in zip(\n",
+" current_entities, transformed_entities\n",
+" ):\n",
 " if word_1 != word_2:\n",
-" # Finally, make sure that words swapped between `x` and `x_adv` are named entities. If \n",
+" # Finally, make sure that words swapped between `x` and `x_adv` are named entities. If\n",
 " # they're not, then we also return False.\n",
-" if (label_1 not in ['NNP', 'NE']) or (label_2 not in ['NNP', 'NE']):\n",
-" return False \n",
+" if (label_1 not in [\"NNP\", \"NE\"]) or (label_2 not in [\"NNP\", \"NE\"]):\n",
+" return False\n",
 " # If we get here, all of the labels match up. Return True!\n",
-" return True\n",
-" "
+" return True"
 ]
 },
 {
@@ -638,17 +645,23 @@
 "import transformers\n",
 "from textattack.models.wrappers import HuggingFaceModelWrapper\n",
 "\n",
-"model = transformers.AutoModelForSequenceClassification.from_pretrained(\"textattack/albert-base-v2-ag-news\")\n",
-"tokenizer = transformers.AutoTokenizer.from_pretrained(\"textattack/albert-base-v2-ag-news\")\n",
+"model = transformers.AutoModelForSequenceClassification.from_pretrained(\n",
+" \"textattack/albert-base-v2-ag-news\"\n",
+")\n",
+"tokenizer = transformers.AutoTokenizer.from_pretrained(\n",
+" \"textattack/albert-base-v2-ag-news\"\n",
+")\n",
 "\n",
 "model_wrapper = HuggingFaceModelWrapper(model, tokenizer)\n",
 "\n",
 "# Create the goal function using the model\n",
 "from textattack.goal_functions import UntargetedClassification\n",
+"\n",
 "goal_function = UntargetedClassification(model_wrapper)\n",
 "\n",
 "# Import the dataset\n",
 "from textattack.datasets import HuggingFaceDataset\n",
+"\n",
 "dataset = HuggingFaceDataset(\"ag_news\", None, \"test\")"
 ]
 },
@@ -663,23 +676,27 @@
 "from textattack.transformations import WordSwapEmbedding\n",
 "from textattack.search_methods import GreedyWordSwapWIR\n",
 "from textattack import Attack\n",
-"from textattack.constraints.pre_transformation import RepeatModification, StopwordModification\n",
+"from textattack.constraints.pre_transformation import (\n",
+" RepeatModification,\n",
+" StopwordModification,\n",
+")\n",
 "\n",
 "# We're going to the `WordSwapEmbedding` transformation. Using the default settings, this\n",
-"# will try substituting words with their neighbors in the counter-fitted embedding space. \n",
-"transformation = WordSwapEmbedding(max_candidates=20) \n",
+"# will try substituting words with their neighbors in the counter-fitted embedding space.\n",
+"transformation = WordSwapEmbedding(max_candidates=20)\n",
 "\n",
 "# We'll use the greedy search with word importance ranking method again\n",
 "search_method = GreedyWordSwapWIR()\n",
 "\n",
 "# Our constraints will be the same as Tutorial 1, plus the named entity constraint\n",
-"constraints = [RepeatModification(),\n",
-" StopwordModification(),\n",
-" NamedEntityConstraint(False)]\n",
+"constraints = [\n",
+" RepeatModification(),\n",
+" StopwordModification(),\n",
+" NamedEntityConstraint(False),\n",
+"]\n",
 "\n",
-"# Now, let's make the attack using these parameters. \n",
-"attack = Attack(goal_function, constraints, transformation, search_method)\n",
-"\n"
+"# Now, let's make the attack using these parameters.\n",
+"attack = Attack(goal_function, constraints, transformation, search_method)"
 ]
 },
 {
@@ -800,11 +817,13 @@
 }
 ],
 "source": [
-"from textattack.loggers import CSVLogger # tracks a dataframe for us.\n",
+"from textattack.loggers import CSVLogger  # tracks a dataframe for us.\n",
 "from textattack.attack_results import SuccessfulAttackResult\n",
 "from textattack import Attacker, AttackArgs\n",
 "\n",
-"attack_args = AttackArgs(num_successful_examples=5, log_to_csv=\"results.csv\", csv_coloring_style=\"html\")\n",
+"attack_args = AttackArgs(\n",
+" num_successful_examples=5, log_to_csv=\"results.csv\", csv_coloring_style=\"html\"\n",
+")\n",
 "attacker = Attacker(attack, dataset, attack_args)\n",
 "\n",
 "attacker.attack_dataset()"
@@ -833,13 +852,16 @@
 "outputs": [],
 "source": [
 "import pandas as pd\n",
-"pd.options.display.max_colwidth = 480 # increase column width so we can actually read the examples\n",
 "\n",
+"pd.options.display.max_colwidth = (\n",
+" 480 # increase column width so we can actually read the examples\n",
+")\n",
+"\n",
 "from IPython.core.display import display, HTML\n",
 "\n",
 "logger = attacker.attack_log_manager.loggers[0]\n",
 "successes = logger.df[logger.df[\"result_type\"] == \"Successful\"]\n",
-"display(HTML(successes[['original_text', 'perturbed_text']].to_html(escape=False)))"
+"display(HTML(successes[[\"original_text\", \"perturbed_text\"]].to_html(escape=False)))"
 ]
 },
 {
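
The second notebook follows the same recipe with the custom constraint plugged in. A sketch assembled from the diff above, assuming the NamedEntityConstraint class and the ALBERT model_wrapper from the earlier cells:

    from textattack import Attack
    from textattack.constraints.pre_transformation import (
        RepeatModification,
        StopwordModification,
    )
    from textattack.search_methods import GreedyWordSwapWIR
    from textattack.transformations import WordSwapEmbedding

    # Swap words for neighbors in the counter-fitted embedding space, but only
    # keep swaps that pass the named-entity check in NamedEntityConstraint.
    attack = Attack(
        goal_function,  # UntargetedClassification over the ALBERT wrapper above
        [RepeatModification(), StopwordModification(), NamedEntityConstraint(False)],
        WordSwapEmbedding(max_candidates=20),
        GreedyWordSwapWIR(),
    )
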
