Skip to content

Commit

Permalink
Merge pull request #91 from QData/config-yaml
Browse files Browse the repository at this point in the history
minor fixes to install + logging
  • Loading branch information
uvafan committed May 12, 2020
2 parents 120a58b + f142d59 commit fdbe0e2
Show file tree
Hide file tree
Showing 6 changed files with 71 additions and 46 deletions.
1 change: 0 additions & 1 deletion MANIFEST.in

This file was deleted.

20 changes: 4 additions & 16 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,22 +30,10 @@ You should be running Python 3.6+ to use this package. A CUDA-compatible GPU is
pip install textattack
```

We use the NLTK package for its list of stopwords and access to the WordNet lexical database. To download them run in Python shell:

```
import nltk
nltk.download('stopwords')
nltk.download('wordnet')
```

We use spaCy's English model. To download it, after installing spaCy run:

```
python -m spacy download en
```

### Cache
TextAttack provides pretrained models and datasets for user convenience. By default, all this stuff is downloaded to `~/.cache`. You can change this location by editing the `CACHE_DIR` field in `config.json`.
### Configuration
TextAttack downloads files to `~/.cache/textattack/` by default. This includes pretrained models,
dataset samples, and the configuration file `config.yaml`. To change the cache path, set the
environment variable `TA_CACHE_DIR`.

## Usage

Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ lru-dict
nltk
numpy
pandas
pyyaml
scikit-learn
scipy
sentence_transformers
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setuptools.setup(
name="textattack",
version="0.0.1.7",
version="0.0.1.8",
author="QData Lab at the University of Virginia",
author_email="jm8wx@virginia.edu",
description="A library for generating text adversarial examples",
Expand Down
6 changes: 0 additions & 6 deletions textattack/config.json

This file was deleted.

87 changes: 65 additions & 22 deletions textattack/shared/utils.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,37 @@
import filelock
import json
import logging
import logging.config
import os
import pathlib
import requests
import shutil
import tempfile
import torch
import tqdm
import yaml
import zipfile

dir_path = os.path.dirname(os.path.realpath(__file__))
config_path = os.path.join(dir_path, os.pardir, 'config.json')
CONFIG = json.load(open(config_path, 'r'))
CONFIG['CACHE_DIR'] = os.path.expanduser(CONFIG['CACHE_DIR'])

def config(key):
return CONFIG[key]

def get_logger():
return logging.getLogger(__name__)

def get_device():
    """ Picks the torch device to run on.

    Returns:
        torch.device: the `cuda` device when `torch.cuda.is_available()`
            is true, otherwise the `cpu` device.
    """
    if torch.cuda.is_available():
        device_name = "cuda"
    else:
        device_name = "cpu"
    return torch.device(device_name)

def path_in_cache(file_path):
    """ Resolves a path relative to the TextAttack cache directory.

    The cache directory (`config('CACHE_DIR')`) is created if it does not
    exist yet.

    Args:
        file_path (str): path relative to the cache directory
    Returns:
        str: `file_path` joined onto the cache directory
    """
    textattack_cache_dir = config('CACHE_DIR')
    # `exist_ok=True` replaces the separate exists-check: another process
    # creating the directory between the check and `makedirs` would
    # otherwise raise FileExistsError.
    os.makedirs(textattack_cache_dir, exist_ok=True)
    return os.path.join(textattack_cache_dir, file_path)

def s3_url(uri):
    """ Builds the full URL of a file in the TextAttack S3 bucket.

    Args:
        uri (str): path of the file inside the bucket
    Returns:
        str: the bucket base URL with `uri` appended
    """
    bucket_root = 'https://textattack.s3.amazonaws.com/'
    return bucket_root + uri

def download_if_needed(folder_name):
""" Folder name will be saved as `.cache/textattack/[folder name]`. If it
doesn't exist, the zip file will be downloaded and extracted.
doesn't exist on disk, the zip file will be downloaded and extracted.
Args:
folder_name (str): path to folder or file in cache
Returns:
str: path to the downloaded folder or file on disk
"""
cache_dest_path = path_in_cache(folder_name)
os.makedirs(os.path.dirname(cache_dest_path), exist_ok=True)
Expand All @@ -48,25 +45,25 @@ def download_if_needed(folder_name):
return cache_dest_path
# If the file isn't found yet, download the zip file to the cache.
downloaded_file = tempfile.NamedTemporaryFile(
dir=CONFIG['CACHE_DIR'],
dir=config('CACHE_DIR'),
suffix='.zip', delete=False)
http_get(folder_name, downloaded_file)
# Move or unzip the file.
downloaded_file.close()
if zipfile.is_zipfile(downloaded_file.name):
unzip_file(downloaded_file.name, cache_dest_path)
else:
print('Copying', downloaded_file.name, 'to', cache_dest_path + '.')
get_logger().info(f'Copying {downloaded_file.name} to {cache_dest_path}.')
shutil.copyfile(downloaded_file.name, cache_dest_path)
cache_file_lock.release()
# Remove the temporary file.
os.remove(downloaded_file.name)
print(f'Successfully saved {folder_name} to cache.')
get_logger().info(f'Successfully saved {folder_name} to cache.')
return cache_dest_path

def unzip_file(path_to_zip_file, unzipped_folder_path):
    """ Unzips a .zip file to folder path.

    The archive is extracted into the parent directory of
    `unzipped_folder_path`.

    Args:
        path_to_zip_file (str): path of the .zip archive to extract
        unzipped_folder_path (str): destination path; its parent directory
            receives the archive contents
    """
    # The names must be wrapped in braces to be interpolated; without them
    # the log shows the literal variable names instead of the paths.
    get_logger().info(f'Unzipping file {path_to_zip_file} to {unzipped_folder_path}.')
    enclosing_unzipped_path = pathlib.Path(unzipped_folder_path).parent
    with zipfile.ZipFile(path_to_zip_file, 'r') as zip_ref:
        zip_ref.extractall(enclosing_unzipped_path)
Expand All @@ -77,7 +74,7 @@ def http_get(folder_name, out_file, proxies=None):
https://github.com/huggingface/transformers/blob/master/src/transformers/file_utils.py
"""
folder_s3_url = s3_url(folder_name)
print(f'Downloading {folder_s3_url}.')
get_logger().info(f'Downloading {folder_s3_url}.')
req = requests.get(folder_s3_url, stream=True, proxies=proxies)
content_length = req.headers.get('Content-Length')
total = int(content_length) if content_length is not None else None
Expand Down Expand Up @@ -222,4 +219,50 @@ def has_letter(word):
""" Returns true if `word` contains at least one character in [A-Za-z]. """
for c in word:
if c.isalpha(): return True
return False
return False

# ANSI-colored `textattack` prefix placed in front of every log message.
LOG_STRING = '\033[34;1mtextattack\033[0m'
# Lazily created by `get_logger`; stays None until the first call.
logger = None
def get_logger():
    """ Returns the shared TextAttack logger, creating it on first use.

    The logger is named after this module, logs at INFO level to a stream
    handler with the `LOG_STRING` prefix, and does not propagate to the
    root logger.

    Returns:
        logging.Logger: the module-level logger instance
    """
    global logger
    if logger is None:
        logger = logging.getLogger(__name__)
        # Configure only this logger. A non-incremental
        # `logging.config.dictConfig` call would also disable every other
        # existing logger in the process (`disable_existing_loggers`
        # defaults to True) and clear already-registered handlers.
        logger.setLevel(logging.INFO)
        # Start from a clean handler list so the logger always ends up with
        # exactly one stream handler.
        for existing_handler in list(logger.handlers):
            logger.removeHandler(existing_handler)
        formatter = logging.Formatter(f'{LOG_STRING}: %(message)s')
        stream_handler = logging.StreamHandler()
        stream_handler.setFormatter(formatter)
        logger.addHandler(stream_handler)
        logger.propagate = False
    return logger

def _post_install():
    """ Downloads the extra spaCy and NLTK data that TextAttack relies on.

    Fetches spaCy's `en` model, then the NLTK `wordnet`,
    `averaged_perceptron_tagger`, `universal_tagset` and `stopwords`
    packages, logging each stage.
    """
    log = get_logger()
    log.info('First time importing textattack: downloading remaining required packages.')

    log.info('Downloading spaCy required packages.')
    # Imported here rather than at module level, as in the original code.
    import spacy
    spacy.cli.download('en')

    log.info('Downloading NLTK required packages.')
    import nltk
    nltk_packages = (
        'wordnet',
        'averaged_perceptron_tagger',
        'universal_tagset',
        'stopwords',
    )
    for package_name in nltk_packages:
        nltk.download(package_name)

def _post_install_if_needed():
    """ Runs `_post_install` if it hasn't been run since install.

    An empty `post_install_check` file in the cache directory marks that
    post-install already completed; it is created after a successful run.
    """
    marker_path = os.path.join(config('CACHE_DIR'), 'post_install_check')
    if not os.path.exists(marker_path):
        _post_install()
        # Leave an empty marker file so later imports skip the downloads.
        with open(marker_path, 'w'):
            pass

def config(key):
    """ Looks up a TextAttack configuration value.

    Args:
        key (str): name of the configuration entry, e.g. `'CACHE_DIR'`
    Returns:
        the value stored under `key` in the module-level `config_dict`
    Raises:
        KeyError: if `key` is not present in `config_dict`
    """
    value = config_dict[key]
    return value

# Module-level bootstrap: runs once on import.
# The cache directory comes from the `TA_CACHE_DIR` environment variable and
# falls back to `~/.cache/textattack`.
config_dict = {'CACHE_DIR': os.environ.get('TA_CACHE_DIR', os.path.expanduser('~/.cache/textattack'))}
config_path = download_if_needed('config.yaml')
# NOTE(review): `config.yaml` is downloaded content parsed with
# `yaml.FullLoader`; if it only holds plain scalars/maps, consider
# `yaml.safe_load` -- confirm before switching.
# The `with` block closes the file handle, which a bare `open(...)` passed
# straight to `yaml.load` would leave open.
with open(config_path, 'r') as _config_file:
    # An empty YAML document loads as None; fall back to an empty mapping so
    # `update` does not raise TypeError.
    config_dict.update(yaml.load(_config_file, Loader=yaml.FullLoader) or {})
_post_install_if_needed()

0 comments on commit fdbe0e2

Please sign in to comment.