Skip to content

Commit

Permalink
Merge pull request #97 from QData/checkpoint
Browse files Browse the repository at this point in the history
Add checkpoint feature for running long attacks
  • Loading branch information
jxmorris12 committed May 18, 2020
2 parents c61d840 + c63add8 commit 181ee2f
Show file tree
Hide file tree
Showing 11 changed files with 316 additions and 24 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,6 @@ dist/

# Weights & Biases outputs
wandb/

# checkpoints
checkpoints/
3 changes: 3 additions & 0 deletions textattack/goal_function_results/goal_function_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ def __init__(self, tokenized_text, output, succeeded, score):

if isinstance(self.score, torch.Tensor):
self.score = self.score.item()

if isinstance(self.succeeded, torch.Tensor):
self.succeeded = self.succeeded.item()

def get_text_color_input(self):
""" A string representing the color this result's changed
Expand Down
18 changes: 18 additions & 0 deletions textattack/loggers/file_logger.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import os
import sys
import copy
import terminaltables

from .logger import Logger

class FileLogger(Logger):
def __init__(self, filename='', stdout=False):
self.stdout = stdout
self.filename = filename
if stdout:
self.fout = sys.stdout
elif isinstance(filename, str):
Expand All @@ -18,6 +20,18 @@ def __init__(self, filename='', stdout=False):
self.fout = filename
self.num_results = 0

def __getstate__(self):
# Temporarily save file handle b/c we can't copy it
state = {i: self.__dict__[i] for i in self.__dict__ if i !='fout'}
return state

def __setstate__(self, state):
self.__dict__ = state
if self.stdout:
self.fout = sys.stdout
else:
self.fout = open(self.filename, 'a')

def log_attack_result(self, result):
self.num_results += 1
color_method = 'stdout' if self.stdout else 'file'
Expand All @@ -36,4 +50,8 @@ def log_summary_rows(self, rows, title, window_id):

def log_sep(self):
self.fout.write('-' * 90 + '\n')

def flush(self):
self.fout.flush()


14 changes: 13 additions & 1 deletion textattack/loggers/visdom_logger.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import socket
import copy
from visdom import Visdom

from textattack.shared.utils import html_table_from_rows
Expand All @@ -16,9 +17,20 @@ def __init__(self, env='main', port=8097, hostname='localhost'):
if not port_is_open(port, hostname=hostname):
raise socket.error(f'Visdom not running on {hostname}:{port}')
self.vis = Visdom(port=port, server=hostname, env=env)
self.env = env
self.port = port
self.hostname = hostname
self.windows = {}
self.sample_rows = []

def __getstate__(self):
state = {i: self.__dict__[i] for i in self.__dict__ if i !='vis'}
return state

def __setstate__(self, state):
self.__dict__ = state
self.vis = Visdom(port=self.port, server=self.hostname, env=self.env)

def log_attack_result(self, result):
text_a, text_b = result.diff_color(color_method='html')
result_str = result.goal_function_result_str(color_method='html')
Expand Down Expand Up @@ -51,7 +63,7 @@ def table(self, rows, window_id=None, title=None, header=None, style=None):
if not window_id: window_id = title # Can provide either of these,
if not title: title = window_id # or both.
table = html_table_from_rows(rows, title=title, header=header, style_dict=style)
self.text(table_html, title=title, window_id=window_id)
self.text(table, title=title, window_id=window_id)

def bar(self, X_data, numbins=10, title=None, window_id=None):
window = None
Expand Down
7 changes: 6 additions & 1 deletion textattack/loggers/weights_and_biases_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,14 @@
class WeightsAndBiasesLogger(Logger):
def __init__(self, filename='', stdout=False):
import wandb
wandb.init(project='textattack')
wandb.init(project='textattack', resume=True)
self._result_table_rows = []

def __setstate__(self, state):
import wandb
self.__dict__ = state
wandb.init(project='textattack', resume=True)

def log_summary_rows(self, rows, title, window_id):
table = wandb.Table(columns=['Attack Results', ''])
for row in rows:
Expand Down
4 changes: 4 additions & 0 deletions textattack/search_methods/attack.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,10 @@ def _get_examples_from_dataset(self, dataset, num_examples=None, shuffle=False,

if shuffle:
random.shuffle(dataset.examples)

if num_examples <= 0:
return
yield

for text, ground_truth_output in dataset:
tokenized_text = TokenizedText(text, self.tokenizer)
Expand Down
1 change: 1 addition & 0 deletions textattack/shared/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@

from .tokenized_text import TokenizedText
from .word_embedding import WordEmbedding
from .checkpoint import Checkpoint
126 changes: 126 additions & 0 deletions textattack/shared/checkpoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
import os
import pickle
import time
import datetime
from textattack.shared import utils
from textattack.attack_results import SuccessfulAttackResult, FailedAttackResult, SkippedAttackResult

logger = utils.get_logger()

class Checkpoint:
""" An object that stores necessary information for saving and loading checkpoints
Args:
args: command line arguments of the original attack
log_manager (AttackLogManager)
chkpt_time (float): epoch time representing when checkpoint was made
"""
def __init__(self, args, log_manager, chkpt_time=None):
self.args = args
self.log_manager = log_manager
if chkpt_time:
self.time = chkpt_time
else:
self.time = time.time()

def __repr__(self):
main_str = 'Checkpoint('
lines = []
lines.append(
utils.add_indent(f'(Time): {self.datetime}', 2)
)

args_lines = []
for key in self.args.__dict__:
args_lines.append(utils.add_indent(f'({key}): {self.args.__dict__[key]}', 2))
args_str = utils.add_indent('\n' + '\n'.join(args_lines), 2)

lines.append(utils.add_indent(f'(Args): {args_str}', 2))

attack_logger_lines = []
attack_logger_lines.append(utils.add_indent(
f'(Total number of examples to attack): {self.args.num_examples}', 2
))
attack_logger_lines.append(utils.add_indent(
f'(Number of attacks performed): {self.results_count}', 2
))
attack_logger_lines.append(utils.add_indent(
f'(Number of remaining attacks): {self.num_remaining_attacks}', 2
))
breakdown_lines = []
breakdown_lines.append(utils.add_indent(
f'(Number of successful attacks): {self.num_successful_attacks}', 2
))
breakdown_lines.append(utils.add_indent(
f'(Number of failed attacks): {self.num_failed_attacks}', 2
))
breakdown_lines.append(utils.add_indent(
f'(Number of skipped attacks): {self.num_skipped_attacks}', 2
))
breakdown_str = utils.add_indent('\n' + '\n'.join(breakdown_lines), 2)
attack_logger_lines.append(utils.add_indent(f'(Latest result breakdown): {breakdown_str}', 2))
attack_logger_str = utils.add_indent('\n' + '\n'.join(attack_logger_lines), 2)
lines.append(utils.add_indent(f'(Previous attack summary): {attack_logger_str}', 2))

main_str += '\n ' + '\n '.join(lines) + '\n'
main_str += ')'
return main_str

__str__ = __repr__

@property
def results_count(self):
""" Return number of attacks made so far """
return len(self.log_manager.results)

@property
def num_skipped_attacks(self):
return sum(isinstance(r, SkippedAttackResult) for r in self.log_manager.results)

@property
def num_failed_attacks(self):
return sum(isinstance(r, FailedAttackResult) for r in self.log_manager.results)

@property
def num_successful_attacks(self):
return sum(isinstance(r, SuccessfulAttackResult) for r in self.log_manager.results)

@property
def num_remaining_attacks(self):
if self.args.attack_n:
non_skipped_attacks = self.num_successful_attacks + self.num_failed_attacks
count = self.args.num_examples - non_skipped_attacks
else:
count = self.args.num_examples - self.results_count
return count

@property
def dataset_offset(self):
""" Calculate offset into the dataset to start from """
# Original offset + # of results processed so far
return self.args.num_examples_offset + self.results_count

@property
def datetime(self):
return datetime.datetime.fromtimestamp(self.time).strftime('%Y-%m-%d %H:%M:%S')

def save(self, quiet=False):
file_name = "{}.ta.chkpt".format(int(self.time*1000))
if not os.path.exists(self.args.checkpoint_dir):
os.makedirs(self.args.checkpoint_dir)
path = os.path.join(self.args.checkpoint_dir, file_name)
if not quiet:
print('\n\n' + '=' * 125)
logger.info('Saving checkpoint under "{}" at {} after {} attacks.'.format(path, self.datetime, self.results_count))
print('=' * 125 + '\n')
with open(path, 'wb') as f:
pickle.dump(self, f)

@classmethod
def load(self, path):
with open(path, 'rb') as f:
checkpoint = pickle.load(f)
assert isinstance(checkpoint, Checkpoint)

return checkpoint

81 changes: 74 additions & 7 deletions textattack/shared/scripts/run_attack_args_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import textattack
import time
import torch
import pickle
import copy

RECIPE_NAMES = {
'alzantot': 'textattack.attack_recipes.Alzantot2018',
Expand Down Expand Up @@ -142,14 +144,15 @@ def set_seed(random_seed):
torch.manual_seed(random_seed)

def get_args():
# Parser for regular arguments
parser = argparse.ArgumentParser(
description='A commandline parser for TextAttack',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)

parser.add_argument('--transformation', type=str, required=False,
default='word-swap-embedding', choices=TRANSFORMATION_CLASS_NAMES.keys(),
help='The transformations to apply.')

parser.add_argument('--model', type=str, required=False, default='bert-yelp-sentiment',
choices=MODEL_CLASS_NAMES.keys(), help='The classification model to attack.')

Expand Down Expand Up @@ -196,6 +199,12 @@ def get_args():

def str_to_int(s): return sum((ord(c) for c in s))
parser.add_argument('--random-seed', default=str_to_int('TEXTATTACK'))

parser.add_argument('--checkpoint-dir', required=False, type=str, default=default_checkpoint_dir(),
help='A directory to save/load checkpoint files.')

parser.add_argument('--checkpoint-interval', required=False, type=int,
help='Interval for saving checkpoints. If not set, no checkpoints will be saved.')

attack_group = parser.add_mutually_exclusive_group(required=False)

Expand All @@ -207,11 +216,36 @@ def str_to_int(s): return sum((ord(c) for c in s))
attack_group.add_argument('--recipe', type=str, required=False, default=None,
help='full attack recipe (overrides provided goal function, transformation & constraints)',
choices=RECIPE_NAMES.keys())

command_line_args = None if sys.argv[1:] else ['-h'] # Default to help with empty arguments.
args = parser.parse_args(command_line_args)

set_seed(args.random_seed)

# Parser for parsing args for resume
resume_parser = argparse.ArgumentParser(
description='A commandline parser for TextAttack',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
resume_parser.add_argument('--checkpoint-file', '-f', type=str, required=False, default='latest',
help='Name of checkpoint file to resume attack from. If "latest" is entered, recover latest checkpoint.')

resume_parser.add_argument('--checkpoint-dir', '-d', required=False, type=str, default=default_checkpoint_dir(),
help='A directory to save/load checkpoint files.')

resume_parser.add_argument('--checkpoint-interval', '-i', required=False, type=int,
help='Interval for saving checkpoints. If not set, no checkpoints will be saved.')

resume_parser.add_argument('--parallel', action='store_true', default=False,
help='Run attack using multiple GPUs.')

if sys.argv[1:] and sys.argv[1].lower() == 'resume':
args = resume_parser.parse_args(sys.argv[2:])
setattr(args, 'checkpoint_resume', True)
else:
command_line_args = None if sys.argv[1:] else ['-h'] # Default to help with empty arguments.
args = parser.parse_args(command_line_args)
setattr(args, 'checkpoint_resume', False)

if args.checkpoint_interval and args.shuffle:
# Not allowed b/c we cannot recover order of shuffled data
raise ValueError('Cannot use `--checkpoint-interval` with `--shuffle=True`')

set_seed(args.random_seed)

return args

Expand Down Expand Up @@ -308,7 +342,7 @@ def parse_logger_from_args(args):# Create logger
if not args.out_dir:
current_dir = os.path.dirname(os.path.realpath(__file__))
outputs_dir = os.path.join(current_dir, os.pardir, os.pardir, os.pardir, 'outputs')
args.out_dir = outputs_dir
args.out_dir = os.path.normpath(outputs_dir)

# Output file.
out_time = int(time.time()*1000) # Output file
Expand All @@ -335,3 +369,36 @@ def parse_logger_from_args(args):# Create logger
if not args.disable_stdout:
attack_log_manager.enable_stdout()
return attack_log_manager

def parse_checkpoint_from_args(args):
if args.checkpoint_file.lower() == 'latest':
chkpt_file_names = [f for f in os.listdir(args.checkpoint_dir) if f.endswith('.ta.chkpt')]
assert chkpt_file_names, "Checkpoint directory is empty"
timestamps = [int(f.replace('.ta.chkpt', '')) for f in chkpt_file_names]
latest_file = str(max(timestamps)) + '.ta.chkpt'
checkpoint_path = os.path.join(args.checkpoint_dir, latest_file)
else:
checkpoint_path = os.path.join(args.checkpoint_dir, args.checkpoint_file)

checkpoint = textattack.shared.Checkpoint.load(checkpoint_path)
set_seed(checkpoint.args.random_seed)

return checkpoint

def default_checkpoint_dir():
current_dir = os.path.dirname(os.path.realpath(__file__))
checkpoints_dir = os.path.join(current_dir, os.pardir, os.pardir, os.pardir, 'checkpoints')
return os.path.normpath(checkpoints_dir)

def merge_checkpoint_args(saved_args, cmdline_args):
""" Merge previously saved arguments for checkpoint and newly entered arguments """
args = copy.deepcopy(saved_args)
# Newly entered arguments take precedence
args.checkpoint_resume = cmdline_args.checkpoint_resume
args.parallel = cmdline_args.parallel
args.checkpoint_dir = cmdline_args.checkpoint_dir
# If set, we replace
if cmdline_args.checkpoint_interval:
args.checkpoint_interval = cmdlineargs.checkpoint_interval

return args

0 comments on commit 181ee2f

Please sign in to comment.