Skip to content

Commit

Permalink
Merge pull request #581 from dangne/master
Browse files Browse the repository at this point in the history
Add flexibility to Wandb logger
  • Loading branch information
qiyanjun committed Nov 21, 2021
2 parents 966531c + 5df3fbe commit a0ec175
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 14 deletions.
20 changes: 12 additions & 8 deletions textattack/attack_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,8 +175,10 @@ class AttackArgs:
If set, Visdom logger is used with the provided dictionary passed as a keyword arguments to :class:`~textattack.loggers.VisdomLogger`.
Pass in empty dictionary to use default arguments. For custom logger, the dictionary should have the following
three keys and their corresponding values: :obj:`"env", "port", "hostname"`.
log_to_wandb (:obj:`str`, `optional`, defaults to :obj:`None`):
If set, log the attack results and summary to Wandb project specified by this argument.
log_to_wandb(:obj:`dict`, `optional`, defaults to :obj:`None`):
If set, WandB logger is used with the provided dictionary passed as a keyword arguments to :class:`~textattack.loggers.WeightsAndBiasesLogger`.
Pass in empty dictionary to use default arguments. For custom logger, the dictionary should have the following
key and its corresponding value: :obj:`"project"`.
disable_stdout (:obj:`bool`, `optional`, defaults to :obj:`False`):
Disable displaying individual attack results to stdout.
silent (:obj:`bool`, `optional`, defaults to :obj:`False`):
Expand All @@ -200,7 +202,7 @@ class AttackArgs:
log_to_csv: str = None
csv_coloring_style: str = "file"
log_to_visdom: dict = None
log_to_wandb: str = None
log_to_wandb: dict = None
disable_stdout: bool = False
silent: bool = False
enable_advance_metrics: bool = False
Expand Down Expand Up @@ -344,10 +346,12 @@ def _add_parser_args(cls, parser):
parser.add_argument(
"--log-to-wandb",
nargs="?",
default=default_obj.log_to_wandb,
const="textattack",
type=str,
help="Name of the wandb project. Set this argument if you want to log attacks to Wandb.",
default=None,
const='{"project": "textattack"}',
type=json.loads,
help="Set this argument if you want to log attacks to WandB. The dictionary should have the following "
'key and its corresponding value: `"project"`. '
'Example for command line use: `--log-to-wandb {"project": "textattack"}`.',
)
parser.add_argument(
"--disable-stdout",
Expand Down Expand Up @@ -420,7 +424,7 @@ def create_loggers_from_args(cls, args):

# Weights & Biases
if args.log_to_wandb is not None:
attack_log_manager.enable_wandb(args.log_to_wandb)
attack_log_manager.enable_wandb(**args.log_to_wandb)

# Stdout
if not args.disable_stdout and not sys.stdout.isatty():
Expand Down
4 changes: 2 additions & 2 deletions textattack/loggers/attack_log_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ def enable_stdout(self):
def enable_visdom(self):
self.loggers.append(VisdomLogger())

def enable_wandb(self):
self.loggers.append(WeightsAndBiasesLogger())
def enable_wandb(self, **kwargs):
self.loggers.append(WeightsAndBiasesLogger(**kwargs))

def disable_color(self):
self.loggers.append(FileLogger(stdout=True, color_method="file"))
Expand Down
11 changes: 7 additions & 4 deletions textattack/loggers/weights_and_biases_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,23 @@
class WeightsAndBiasesLogger(Logger):
"""Logs attack results to Weights & Biases."""

def __init__(self, project_name):
def __init__(self, **kwargs):
assert "project" in kwargs

global wandb
wandb = LazyLoader("wandb", globals(), "wandb")

wandb.init(project=project_name)
self.project_name = project_name
wandb.init(**kwargs)
self.kwargs = kwargs
self.project_name = kwargs["project"]
self._result_table_rows = []

def __setstate__(self, state):
global wandb
wandb = LazyLoader("wandb", globals(), "wandb")

self.__dict__ = state
wandb.init(project=self.project_name, resume=True)
wandb.init(resume=True, **self.kwargs)

def log_summary_rows(self, rows, title, window_id):
table = wandb.Table(columns=["Attack Results", ""])
Expand Down

0 comments on commit a0ec175

Please sign in to comment.