Skip to content

Commit

Permalink
Merge pull request #2231 from GiulioZizzo/update_trades
Browse files Browse the repository at this point in the history
Adding label check to trades adversarial trainer
  • Loading branch information
beat-buesser committed Aug 17, 2023
2 parents c63d5d5 + d43f473 commit 52c240a
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 14 deletions.
20 changes: 15 additions & 5 deletions art/defences/trainer/adversarial_trainer_trades_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from art.estimators.classification.pytorch import PyTorchClassifier
from art.data_generators import DataGenerator
from art.attacks.attack import EvasionAttack
from art.utils import check_and_transform_label_format

if TYPE_CHECKING:
import torch
Expand Down Expand Up @@ -97,6 +98,15 @@ def fit(
ind = np.arange(len(x))

logger.info("Adversarial Training TRADES")
y = check_and_transform_label_format(y, nb_classes=self.classifier.nb_classes)

if validation_data is not None:
(x_test, y_test) = validation_data
y_test = check_and_transform_label_format(y_test, nb_classes=self.classifier.nb_classes)

x_preprocessed_test, y_preprocessed_test = self._classifier._apply_preprocessing( # pylint: disable=W0212
x_test, y_test, fit=True
)

for i_epoch in trange(nb_epochs, desc="Adversarial Training TRADES - Epochs"):
# Shuffle the examples
Expand All @@ -107,7 +117,6 @@ def fit(
train_n = 0.0

for batch_id in range(nb_batches):

# Create batch data
x_batch = x[ind[batch_id * batch_size : min((batch_id + 1) * batch_size, x.shape[0])]].copy()
y_batch = y[ind[batch_id * batch_size : min((batch_id + 1) * batch_size, x.shape[0])]]
Expand All @@ -125,9 +134,9 @@ def fit(

# compute accuracy
if validation_data is not None:
(x_test, y_test) = validation_data
output = np.argmax(self.predict(x_test), axis=1)
nb_correct_pred = np.sum(output == np.argmax(y_test, axis=1))
output = np.argmax(self.predict(x_preprocessed_test), axis=1)
nb_correct_pred = np.sum(output == np.argmax(y_preprocessed_test, axis=1))

logger.info(
"epoch: %s time(s): %.1f loss: %.4f acc(tr): %.4f acc(val): %.4f",
i_epoch,
Expand Down Expand Up @@ -188,7 +197,6 @@ def fit_generator(
train_n = 0.0

for batch_id in range(nb_batches): # pylint: disable=W0612

# Create batch data
x_batch, y_batch = generator.get_batch()
x_batch = x_batch.copy()
Expand Down Expand Up @@ -232,6 +240,8 @@ def _batch_process(self, x_batch: np.ndarray, y_batch: np.ndarray) -> Tuple[floa
x_batch_pert = self._attack.generate(x_batch, y=y_batch)

# Apply preprocessing
y_batch = check_and_transform_label_format(y_batch, nb_classes=self.classifier.nb_classes)

x_preprocessed, y_preprocessed = self._classifier._apply_preprocessing( # pylint: disable=W0212
x_batch, y_batch, fit=True
)
Expand Down
48 changes: 39 additions & 9 deletions tests/defences/trainer/test_adversarial_trainer_trades_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def _get_adv_trainer():
if framework in ["tensorflow", "tensorflow2v1"]:
trainer = None
if framework == "pytorch":
classifier, _ = image_dl_estimator()
classifier, _ = image_dl_estimator(from_logits=True)
attack = ProjectedGradientDescent(
classifier,
norm=np.inf,
Expand Down Expand Up @@ -63,22 +63,38 @@ def fix_get_mnist_subset(get_mnist_dataset):
yield x_train_mnist[:n_train], y_train_mnist[:n_train], x_test_mnist[:n_test], y_test_mnist[:n_test]


@pytest.mark.skip_framework("tensorflow", "keras", "scikitlearn", "mxnet", "kerastf")
def test_adversarial_trainer_trades_pytorch_fit_and_predict(get_adv_trainer, fix_get_mnist_subset):
@pytest.mark.only_with_platform("pytorch")
@pytest.mark.parametrize("label_format", ["one_hot", "numerical"])
def test_adversarial_trainer_trades_pytorch_fit_and_predict(get_adv_trainer, fix_get_mnist_subset, label_format):
(x_train_mnist, y_train_mnist, x_test_mnist, y_test_mnist) = fix_get_mnist_subset
x_test_mnist_original = x_test_mnist.copy()

if label_format == "one_hot":
assert y_train_mnist.shape[-1] == 10
assert y_test_mnist.shape[-1] == 10
if label_format == "numerical":
y_test_mnist = np.argmax(y_test_mnist, axis=1)
y_train_mnist = np.argmax(y_train_mnist, axis=1)

trainer = get_adv_trainer()
if trainer is None:
logging.warning("Couldn't perform this test because no trainer is defined for this framework configuration")
return

predictions = np.argmax(trainer.predict(x_test_mnist), axis=1)
accuracy = np.sum(predictions == np.argmax(y_test_mnist, axis=1)) / x_test_mnist.shape[0]

if label_format == "one_hot":
accuracy = np.sum(predictions == np.argmax(y_test_mnist, axis=1)) / x_test_mnist.shape[0]
else:
accuracy = np.sum(predictions == y_test_mnist) / x_test_mnist.shape[0]

trainer.fit(x_train_mnist, y_train_mnist, nb_epochs=20)
predictions_new = np.argmax(trainer.predict(x_test_mnist), axis=1)
accuracy_new = np.sum(predictions_new == np.argmax(y_test_mnist, axis=1)) / x_test_mnist.shape[0]

if label_format == "one_hot":
accuracy_new = np.sum(predictions_new == np.argmax(y_test_mnist, axis=1)) / x_test_mnist.shape[0]
else:
accuracy_new = np.sum(predictions_new == y_test_mnist) / x_test_mnist.shape[0]

np.testing.assert_array_almost_equal(
float(np.mean(x_test_mnist_original - x_test_mnist)),
Expand All @@ -92,13 +108,20 @@ def test_adversarial_trainer_trades_pytorch_fit_and_predict(get_adv_trainer, fix
trainer.fit(x_train_mnist, y_train_mnist, nb_epochs=20, validation_data=(x_train_mnist, y_train_mnist))


@pytest.mark.skip_framework("tensorflow", "keras", "scikitlearn", "mxnet", "kerastf")
@pytest.mark.only_with_platform("pytorch")
@pytest.mark.parametrize("label_format", ["one_hot", "numerical"])
def test_adversarial_trainer_trades_pytorch_fit_generator_and_predict(
get_adv_trainer, fix_get_mnist_subset, image_data_generator
get_adv_trainer, fix_get_mnist_subset, image_data_generator, label_format
):
(x_train_mnist, y_train_mnist, x_test_mnist, y_test_mnist) = fix_get_mnist_subset
x_test_mnist_original = x_test_mnist.copy()

if label_format == "one_hot":
assert y_train_mnist.shape[-1] == 10
assert y_test_mnist.shape[-1] == 10
if label_format == "numerical":
y_test_mnist = np.argmax(y_test_mnist, axis=1)

generator = image_data_generator()

trainer = get_adv_trainer()
Expand All @@ -107,11 +130,18 @@ def test_adversarial_trainer_trades_pytorch_fit_generator_and_predict(
return

predictions = np.argmax(trainer.predict(x_test_mnist), axis=1)
accuracy = np.sum(predictions == np.argmax(y_test_mnist, axis=1)) / x_test_mnist.shape[0]
if label_format == "one_hot":
accuracy = np.sum(predictions == np.argmax(y_test_mnist, axis=1)) / x_test_mnist.shape[0]
else:
accuracy = np.sum(predictions == y_test_mnist) / x_test_mnist.shape[0]

trainer.fit_generator(generator=generator, nb_epochs=20)
predictions_new = np.argmax(trainer.predict(x_test_mnist), axis=1)
accuracy_new = np.sum(predictions_new == np.argmax(y_test_mnist, axis=1)) / x_test_mnist.shape[0]

if label_format == "one_hot":
accuracy_new = np.sum(predictions_new == np.argmax(y_test_mnist, axis=1)) / x_test_mnist.shape[0]
else:
accuracy_new = np.sum(predictions_new == y_test_mnist) / x_test_mnist.shape[0]

np.testing.assert_array_almost_equal(
float(np.mean(x_test_mnist_original - x_test_mnist)),
Expand Down

0 comments on commit 52c240a

Please sign in to comment.