Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

minor bug fixes #189

Merged
merged 2 commits on Jul 6, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 2 additions & 2 deletions Makefile
@@ -1,11 +1,11 @@
format: FORCE ## Run black and isort (rewriting files)
black .
isort --atomic --recursive tests textattack
isort --atomic tests textattack


lint: FORCE ## Run black, isort, flake8 (in check mode)
black . --check
isort --check-only --recursive tests textattack
isort --check-only tests textattack
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics --exclude=./.*,build,dist # catch certain syntax errors using flake8

test: FORCE ## Run tests using pytest
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,5 +1,6 @@
bert-score
editdistance
flair==0.5.1
filelock
language_tool_python
lru-dict
Expand All @@ -20,4 +21,3 @@ tokenizers==0.8.0-rc4
tqdm
visdom
wandb
flair
2 changes: 1 addition & 1 deletion setup.py
Expand Up @@ -10,7 +10,7 @@
# Packages required for installing docs.
extras["docs"] = ["recommonmark", "nbsphinx", "sphinx-autobuild", "sphinx-rtd-theme"]
# Packages required for formatting code & running tests.
extras["test"] = ["black", "isort", "flake8", "pytest", "pytest-xdist"]
extras["test"] = ["black", "isort==5.0.3", "flake8", "pytest", "pytest-xdist"]
# For developers, install development tools along with all optional dependencies.
extras["dev"] = extras["docs"] + extras["test"]

Expand Down
3 changes: 1 addition & 2 deletions tests/test_command_line/test_attack.py
@@ -1,9 +1,8 @@
import pdb
import re

import pytest

from helpers import run_command_and_get_result
import pytest

DEBUG = False

Expand Down
3 changes: 1 addition & 2 deletions tests/test_command_line/test_augment.py
@@ -1,6 +1,5 @@
import pytest

from helpers import run_command_and_get_result
import pytest

augment_test_params = [
(
Expand Down
3 changes: 1 addition & 2 deletions tests/test_command_line/test_list.py
@@ -1,6 +1,5 @@
import pytest

from helpers import run_command_and_get_result
import pytest

list_test_params = [
(
Expand Down
3 changes: 2 additions & 1 deletion tests/test_misc.py
@@ -1,7 +1,8 @@
def test_imports():
import textattack
import torch

import textattack

del textattack, torch


Expand Down
3 changes: 1 addition & 2 deletions textattack/augmentation/recipes.py
Expand Up @@ -122,13 +122,12 @@ class CharSwapAugmenter(Augmenter):
""" Augments words by swapping characters out for other characters. """

def __init__(self, **kwargs):
from textattack.transformations import CompositeTransformation
from textattack.transformations import (
CompositeTransformation,
WordSwapNeighboringCharacterSwap,
WordSwapRandomCharacterDeletion,
WordSwapRandomCharacterInsertion,
WordSwapRandomCharacterSubstitution,
WordSwapNeighboringCharacterSwap,
)

transformation = CompositeTransformation(
Expand Down
Expand Up @@ -51,6 +51,9 @@ def _check_constraint_many(

def get_probs(current_text, transformed_texts):
word_swap_index = current_text.first_word_diff_index(transformed_texts[0])
if word_swap_index is None:
return []

prefix = current_text.words[word_swap_index - 1]
swapped_words = np.array(
[t.words[word_swap_index] for t in transformed_texts]
Expand Down Expand Up @@ -104,6 +107,11 @@ def get_probs(current_text, transformed_texts):

return [transformed_texts[i] for i in max_el_indices]

def _check_constraint(self, transformed_text, current_text, original_text=None):
return self._check_constraint_many(
[transformed_text], current_text, original_text=original_text
)

def __call__(self, x, x_adv):
raise NotImplementedError()

Expand Down
@@ -1,5 +1,5 @@
from torch import nn as nn
from torch.autograd import Variable
import torch.nn as nn

from .adaptive_softmax import AdaptiveSoftmax

Expand Down
6 changes: 3 additions & 3 deletions textattack/constraints/grammaticality/part_of_speech.py
Expand Up @@ -48,9 +48,9 @@ def _get_pos(self, before_ctx, word, after_ctx):
)

if self.tagger_type == "flair":
word_list, pos_list = zip_flair_result(
self._flair_pos_tagger.predict(context_key)[0]
)
context_key_sentence = Sentence(context_key)
self._flair_pos_tagger.predict(context_key_sentence)
word_list, pos_list = zip_flair_result(context_key_sentence)

self._pos_tag_cache[context_key] = (word_list, pos_list)

Expand Down
Expand Up @@ -8,12 +8,11 @@
"""
This file contains the definition of encoders used in https://arxiv.org/pdf/1705.02364.pdf
"""

import time

import numpy as np
import torch
import torch.nn as nn
from torch import nn as nn


class InferSentModel(nn.Module):
Expand Down
2 changes: 1 addition & 1 deletion textattack/models/helpers/glove_embedding_layer.py
Expand Up @@ -2,7 +2,7 @@

import numpy as np
import torch
import torch.nn as nn
from torch import nn as nn

from textattack.shared import logger, utils

Expand Down
2 changes: 1 addition & 1 deletion textattack/models/helpers/lstm_for_classification.py
@@ -1,5 +1,5 @@
import torch
import torch.nn as nn
from torch import nn as nn

import textattack
from textattack.models.helpers import GloveEmbeddingLayer
Expand Down
4 changes: 2 additions & 2 deletions textattack/models/helpers/word_cnn_for_classification.py
@@ -1,6 +1,6 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import nn as nn
from torch.nn import functional as F

import textattack
from textattack.models.helpers import GloveEmbeddingLayer
Expand Down