Skip to content

Commit

Permalink
add custom metrics to AttackLogManager through AttackArgs
Browse files Browse the repository at this point in the history
  • Loading branch information
jxmorris12 committed Aug 26, 2022
1 parent 5ce8e26 commit 899ea46
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 4 deletions.
5 changes: 4 additions & 1 deletion textattack/attack_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import os
import sys
import time
from typing import Dict, Optional

import textattack
from textattack.shared.utils import ARGS_SPLIT_TOKEN, load_module_from_file
Expand Down Expand Up @@ -207,6 +208,7 @@ class AttackArgs:
disable_stdout: bool = False
silent: bool = False
enable_advance_metrics: bool = False
metrics: Optional[Dict] = None

def __post_init__(self):
if self.num_successful_examples:
Expand Down Expand Up @@ -386,12 +388,13 @@ def _add_parser_args(cls, parser):

@classmethod
def create_loggers_from_args(cls, args):
"""Creates AttackLogManager from an AttackArgs object."""
assert isinstance(
args, cls
), f"Expect args to be of type `{type(cls)}`, but got type `{type(args)}`."

# Create logger
attack_log_manager = textattack.loggers.AttackLogManager()
attack_log_manager = textattack.loggers.AttackLogManager(args.metrics)

# Get current time for file naming
timestamp = time.strftime("%Y-%m-%d-%H-%M")
Expand Down
14 changes: 13 additions & 1 deletion textattack/loggers/attack_log_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
========================
"""

from typing import Dict, Optional

from textattack.metrics.attack_metrics import (
AttackQueries,
AttackSuccessRate,
Expand All @@ -22,10 +24,17 @@
class AttackLogManager:
"""Logs the results of an attack to all attached loggers."""

def __init__(self):
# metrics maps strings (metric names) to textattack.metric.Metric objects
metrics: Dict

def __init__(self, metrics: Optional[Dict]):
self.loggers = []
self.results = []
self.enable_advance_metrics = False
if metrics is None:
self.metrics = {}
else:
self.metrics = metrics

def enable_stdout(self):
self.loggers.append(FileLogger(stdout=True))
Expand Down Expand Up @@ -127,6 +136,9 @@ def log_summary(self):
["Avg num queries:", attack_query_stats["avg_num_queries"]]
)

for metric_name, metric in self.metrics.items():
summary_table_rows.append([metric_name, metric.calculate(self.results)])

if self.enable_advance_metrics:
perplexity_stats = Perplexity().calculate(self.results)
use_stats = USEMetric().calculate(self.results)
Expand Down
2 changes: 0 additions & 2 deletions textattack/shared/utils/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@ def batch_model_predict(model_predict, inputs, batch_size=32):
"""
outputs = []
i = 0
# print("batch_model_predict", inputs.shape)
# print("inputs:", inputs)
while i < len(inputs):
batch = inputs[i : i + batch_size]
batch_preds = model_predict(batch)
Expand Down

0 comments on commit 899ea46

Please sign in to comment.