Adding label check to TRADES adversarial trainer #2231

Merged 5 commits, Aug 17, 2023
art/defences/trainer/adversarial_trainer_trades_pytorch.py (15 additions & 5 deletions)
@@ -33,6 +33,7 @@
from art.estimators.classification.pytorch import PyTorchClassifier
from art.data_generators import DataGenerator
from art.attacks.attack import EvasionAttack
+from art.utils import check_and_transform_label_format

if TYPE_CHECKING:
    import torch
@@ -97,6 +98,15 @@ def fit(
        ind = np.arange(len(x))

        logger.info("Adversarial Training TRADES")
+        y = check_and_transform_label_format(y, nb_classes=self.classifier.nb_classes)
+
+        if validation_data is not None:
+            (x_test, y_test) = validation_data
+            y_test = check_and_transform_label_format(y_test, nb_classes=self.classifier.nb_classes)
+
+            x_preprocessed_test, y_preprocessed_test = self._classifier._apply_preprocessing(  # pylint: disable=W0212
+                x_test, y_test, fit=True
+            )

        for i_epoch in trange(nb_epochs, desc="Adversarial Training TRADES - Epochs"):
            # Shuffle the examples
@@ -107,7 +117,6 @@
            train_n = 0.0

            for batch_id in range(nb_batches):
-
                # Create batch data
                x_batch = x[ind[batch_id * batch_size : min((batch_id + 1) * batch_size, x.shape[0])]].copy()
                y_batch = y[ind[batch_id * batch_size : min((batch_id + 1) * batch_size, x.shape[0])]]
@@ -125,9 +134,9 @@

            # compute accuracy
            if validation_data is not None:
-                (x_test, y_test) = validation_data
-                output = np.argmax(self.predict(x_test), axis=1)
-                nb_correct_pred = np.sum(output == np.argmax(y_test, axis=1))
+                output = np.argmax(self.predict(x_preprocessed_test), axis=1)
+                nb_correct_pred = np.sum(output == np.argmax(y_preprocessed_test, axis=1))

                logger.info(
                    "epoch: %s time(s): %.1f loss: %.4f acc(tr): %.4f acc(val): %.4f",
                    i_epoch,
@@ -188,7 +197,6 @@ def fit_generator(
            train_n = 0.0

            for batch_id in range(nb_batches):  # pylint: disable=W0612
-
                # Create batch data
                x_batch, y_batch = generator.get_batch()
                x_batch = x_batch.copy()
@@ -232,6 +240,8 @@ def _batch_process(self, x_batch: np.ndarray, y_batch: np.ndarray) -> Tuple[float, float, float]:
        x_batch_pert = self._attack.generate(x_batch, y=y_batch)

        # Apply preprocessing
+        y_batch = check_and_transform_label_format(y_batch, nb_classes=self.classifier.nb_classes)
+
        x_preprocessed, y_preprocessed = self._classifier._apply_preprocessing(  # pylint: disable=W0212
            x_batch, y_batch, fit=True
        )
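The fix in this file hinges on check_and_transform_label_format from art.utils, which normalizes labels to one-hot encoding before they reach preprocessing, so the later np.argmax(..., axis=1) calls are valid for both index and one-hot inputs; the validation tuple is likewise normalized and preprocessed once before the epoch loop and reused for the per-epoch accuracy check. A minimal sketch of the expected behavior (shapes illustrative; MNIST's 10 classes assumed):

```python
import numpy as np
from art.utils import check_and_transform_label_format

y_index = np.array([3, 0, 7])  # index labels, shape (3,)
y_onehot = check_and_transform_label_format(y_index, nb_classes=10)
print(y_onehot.shape)  # (3, 10): one row of one-hot values per label

# labels that are already one-hot pass through unchanged
y_same = check_and_transform_label_format(y_onehot, nb_classes=10)
assert np.array_equal(y_onehot, y_same)
```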
tests/defences/trainer/test_adversarial_trainer_trades_pytorch.py (39 additions & 9 deletions)
Review comment from @Zaid-Hameed (Collaborator), Aug 7, 2023:
Hi Giulio, can you please check whether adding from_logits=True changes any test outcomes, i.e., changing line 37 to

classifier, _ = image_dl_estimator(from_logits=True)

This is just to make sure that softmax is not being applied twice in the test files.
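For context, a minimal NumPy sketch of the double-softmax concern (the numbers are illustrative, not from the tests): applying softmax to outputs that are already probabilities flattens the distribution, so argmax-based accuracy is unchanged while losses and confidences shift, which can hide regressions.

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

logits = np.array([4.0, 1.0, 0.0])
p_once = softmax(logits)   # ~[0.94, 0.05, 0.02]: peaked
p_twice = softmax(p_once)  # ~[0.55, 0.23, 0.22]: flattened
# argmax(p_once) == argmax(p_twice), so accuracy checks still pass,
# but any loss computed from p_twice no longer reflects the model
```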

@@ -34,7 +34,7 @@ def _get_adv_trainer():
    if framework in ["tensorflow", "tensorflow2v1"]:
        trainer = None
    if framework == "pytorch":
-        classifier, _ = image_dl_estimator()
+        classifier, _ = image_dl_estimator(from_logits=True)
        attack = ProjectedGradientDescent(
            classifier,
            norm=np.inf,
@@ -63,22 +63,38 @@ def fix_get_mnist_subset(get_mnist_dataset):
    yield x_train_mnist[:n_train], y_train_mnist[:n_train], x_test_mnist[:n_test], y_test_mnist[:n_test]


-@pytest.mark.skip_framework("tensorflow", "keras", "scikitlearn", "mxnet", "kerastf")
-def test_adversarial_trainer_trades_pytorch_fit_and_predict(get_adv_trainer, fix_get_mnist_subset):
+@pytest.mark.only_with_platform("pytorch")
+@pytest.mark.parametrize("label_format", ["one_hot", "numerical"])
+def test_adversarial_trainer_trades_pytorch_fit_and_predict(get_adv_trainer, fix_get_mnist_subset, label_format):
    (x_train_mnist, y_train_mnist, x_test_mnist, y_test_mnist) = fix_get_mnist_subset
    x_test_mnist_original = x_test_mnist.copy()

+    if label_format == "one_hot":
+        assert y_train_mnist.shape[-1] == 10
+        assert y_test_mnist.shape[-1] == 10
+    if label_format == "numerical":
+        y_test_mnist = np.argmax(y_test_mnist, axis=1)
+        y_train_mnist = np.argmax(y_train_mnist, axis=1)
+
    trainer = get_adv_trainer()
    if trainer is None:
        logging.warning("Couldn't perform this test because no trainer is defined for this framework configuration")
        return

    predictions = np.argmax(trainer.predict(x_test_mnist), axis=1)
-    accuracy = np.sum(predictions == np.argmax(y_test_mnist, axis=1)) / x_test_mnist.shape[0]
+
+    if label_format == "one_hot":
+        accuracy = np.sum(predictions == np.argmax(y_test_mnist, axis=1)) / x_test_mnist.shape[0]
+    else:
+        accuracy = np.sum(predictions == y_test_mnist) / x_test_mnist.shape[0]

    trainer.fit(x_train_mnist, y_train_mnist, nb_epochs=20)
    predictions_new = np.argmax(trainer.predict(x_test_mnist), axis=1)
-    accuracy_new = np.sum(predictions_new == np.argmax(y_test_mnist, axis=1)) / x_test_mnist.shape[0]
+
+    if label_format == "one_hot":
+        accuracy_new = np.sum(predictions_new == np.argmax(y_test_mnist, axis=1)) / x_test_mnist.shape[0]
+    else:
+        accuracy_new = np.sum(predictions_new == y_test_mnist) / x_test_mnist.shape[0]

    np.testing.assert_array_almost_equal(
        float(np.mean(x_test_mnist_original - x_test_mnist)),
@@ -92,13 +108,20 @@ def test_adversarial_trainer_trades_pytorch_fit_and_predict(get_adv_trainer, fix_get_mnist_subset, label_format):
    trainer.fit(x_train_mnist, y_train_mnist, nb_epochs=20, validation_data=(x_train_mnist, y_train_mnist))


-@pytest.mark.skip_framework("tensorflow", "keras", "scikitlearn", "mxnet", "kerastf")
+@pytest.mark.only_with_platform("pytorch")
+@pytest.mark.parametrize("label_format", ["one_hot", "numerical"])
def test_adversarial_trainer_trades_pytorch_fit_generator_and_predict(
-    get_adv_trainer, fix_get_mnist_subset, image_data_generator
+    get_adv_trainer, fix_get_mnist_subset, image_data_generator, label_format
):
    (x_train_mnist, y_train_mnist, x_test_mnist, y_test_mnist) = fix_get_mnist_subset
    x_test_mnist_original = x_test_mnist.copy()

+    if label_format == "one_hot":
+        assert y_train_mnist.shape[-1] == 10
+        assert y_test_mnist.shape[-1] == 10
+    if label_format == "numerical":
+        y_test_mnist = np.argmax(y_test_mnist, axis=1)

    generator = image_data_generator()

    trainer = get_adv_trainer()
@@ -107,11 +130,18 @@ def test_adversarial_trainer_trades_pytorch_fit_generator_and_predict(
        return

    predictions = np.argmax(trainer.predict(x_test_mnist), axis=1)
-    accuracy = np.sum(predictions == np.argmax(y_test_mnist, axis=1)) / x_test_mnist.shape[0]
+    if label_format == "one_hot":
+        accuracy = np.sum(predictions == np.argmax(y_test_mnist, axis=1)) / x_test_mnist.shape[0]
+    else:
+        accuracy = np.sum(predictions == y_test_mnist) / x_test_mnist.shape[0]

    trainer.fit_generator(generator=generator, nb_epochs=20)
    predictions_new = np.argmax(trainer.predict(x_test_mnist), axis=1)
-    accuracy_new = np.sum(predictions_new == np.argmax(y_test_mnist, axis=1)) / x_test_mnist.shape[0]
+
+    if label_format == "one_hot":
+        accuracy_new = np.sum(predictions_new == np.argmax(y_test_mnist, axis=1)) / x_test_mnist.shape[0]
+    else:
+        accuracy_new = np.sum(predictions_new == y_test_mnist) / x_test_mnist.shape[0]

    np.testing.assert_array_almost_equal(
        float(np.mean(x_test_mnist_original - x_test_mnist)),
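Taken together, the trainer now accepts either label format. A usage sketch under stated assumptions: `classifier` is a PyTorchClassifier returning logits (as in the fixture above), `x_train` is a batch of MNIST images, and the PGD parameters and `beta` are illustrative placeholders, not values from this PR.

```python
import numpy as np
from art.attacks.evasion import ProjectedGradientDescent
from art.defences.trainer import AdversarialTrainerTRADESPyTorch

# assumed: `classifier` built with from_logits=True; `x_train` of shape
# (N, 1, 28, 28) for a channels-first PyTorch MNIST model
attack = ProjectedGradientDescent(classifier, norm=np.inf, eps=0.3, eps_step=0.01, max_iter=40)
trainer = AdversarialTrainerTRADESPyTorch(classifier, attack, beta=6.0)

y_index = np.array([3, 0, 7, 1])  # numerical labels, shape (4,)
y_onehot = np.eye(10)[y_index]    # equivalent one-hot labels, shape (4, 10)

# with the label check in place, both calls behave identically
trainer.fit(x_train[:4], y_index, nb_epochs=1)
trainer.fit(x_train[:4], y_onehot, nb_epochs=1)
```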