Skip to content

Commit

Permalink
Merge branch 'dev_1.14.0' into bad-dets-attack
Browse files Browse the repository at this point in the history
  • Loading branch information
beat-buesser authored Mar 11, 2023
2 parents cd3ce4e + 0a0a701 commit 43e79fe
Show file tree
Hide file tree
Showing 16 changed files with 2,752 additions and 256 deletions.
1 change: 1 addition & 0 deletions art/attacks/evasion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from art.attacks.evasion.adversarial_asr import CarliniWagnerASR
from art.attacks.evasion.auto_attack import AutoAttack
from art.attacks.evasion.auto_projected_gradient_descent import AutoProjectedGradientDescent
from art.attacks.evasion.auto_conjugate_gradient import AutoConjugateGradient

if importlib.util.find_spec("numba") is not None:
from art.attacks.evasion.brendel_bethge import BrendelBethgeAttack
Expand Down
652 changes: 652 additions & 0 deletions art/attacks/evasion/auto_conjugate_gradient.py

Large diffs are not rendered by default.

313 changes: 161 additions & 152 deletions art/attacks/evasion/auto_projected_gradient_descent.py

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions art/defences/trainer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from art.defences.trainer.trainer import Trainer
from art.defences.trainer.adversarial_trainer import AdversarialTrainer
from art.defences.trainer.certified_adversarial_trainer_pytorch import AdversarialTrainerCertifiedPytorch
from art.defences.trainer.ibp_certified_trainer_pytorch import AdversarialTrainerCertifiedIBPPyTorch
from art.defences.trainer.adversarial_trainer_madry_pgd import AdversarialTrainerMadryPGD
from art.defences.trainer.adversarial_trainer_fbf import AdversarialTrainerFBF
from art.defences.trainer.adversarial_trainer_fbf_pytorch import AdversarialTrainerFBFPyTorch
Expand Down
465 changes: 465 additions & 0 deletions art/defences/trainer/ibp_certified_trainer_pytorch.py

Large diffs are not rendered by default.

9 changes: 9 additions & 0 deletions art/defences/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,15 @@ def fit(self, x: np.ndarray, y: np.ndarray, **kwargs) -> None:
"""
raise NotImplementedError

@property
def classifier(self) -> "CLASSIFIER_LOSS_GRADIENTS_TYPE":
    """
    Read-only accessor for the classifier held by this trainer.

    :return: The object stored in `self._classifier`.
    """
    wrapped_classifier = self._classifier
    return wrapped_classifier

def get_classifier(self) -> "CLASSIFIER_LOSS_GRADIENTS_TYPE":
"""
Return the classifier trained via adversarial training.
Expand Down
106 changes: 78 additions & 28 deletions art/estimators/certification/interval/interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,33 +106,33 @@ def __init__(
:param to_debug: Helper parameter to help with debugging.
"""

self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.padding = padding
self.dilation = dilation
self.stride = stride
self.device = device
self.include_bias = bias
self.cnn: Optional["torch.nn.Conv2d"] = None

super().__init__()
self.conv_flat = torch.nn.Conv2d(
self.conv = torch.nn.Conv2d(
in_channels=1,
out_channels=out_channels * in_channels,
kernel_size=kernel_size,
padding=padding,
dilation=dilation,
bias=False,
stride=stride,
)
).to(device)
self.bias_to_grad = None

if bias:
self.conv_bias = torch.nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
padding=padding,
dilation=dilation,
bias=True,
stride=stride,
)
if self.conv_bias.bias is not None:
self.bias_to_grad = self.conv_bias.bias.data
self.bias_to_grad = torch.nn.Parameter(torch.rand(out_channels).to(device))

if to_debug:
self.conv = torch.nn.Conv2d(
self.conv_debug = torch.nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
Expand All @@ -143,37 +143,39 @@ def __init__(
).to(device)

if isinstance(kernel_size, tuple):
self.conv_flat.weight = torch.nn.Parameter(
self.conv.weight = torch.nn.Parameter(
torch.reshape(
torch.tensor(self.conv.weight.data.cpu().detach().numpy()),
torch.tensor(self.conv_debug.weight.data.cpu().detach().numpy()),
(out_channels * in_channels, 1, kernel_size[0], kernel_size[1]),
).to(device)
)
else:
self.conv_flat.weight = torch.nn.Parameter(
self.conv.weight = torch.nn.Parameter(
torch.reshape(
torch.tensor(self.conv.weight.data.cpu().detach().numpy()),
torch.tensor(self.conv_debug.weight.data.cpu().detach().numpy()),
(out_channels * in_channels, 1, kernel_size, kernel_size),
).to(device)
)
if bias and self.conv.bias is not None:
self.bias_to_grad = torch.nn.Parameter(torch.tensor(self.conv.bias.data.cpu().detach().numpy()))
if bias and self.conv_debug.bias is not None:
self.bias_to_grad = torch.nn.Parameter(
torch.tensor(self.conv_debug.bias.data.cpu().detach().numpy()).to(device)
)

if supplied_input_weights is not None:
if isinstance(kernel_size, tuple):
self.conv_flat.weight = torch.nn.Parameter(
self.conv.weight = torch.nn.Parameter(
torch.reshape(
supplied_input_weights,
(out_channels * in_channels, 1, kernel_size[0], kernel_size[1]),
)
)
else:
self.conv_flat.weight = torch.nn.Parameter(
self.conv.weight = torch.nn.Parameter(
torch.reshape(supplied_input_weights, (out_channels * in_channels, 1, kernel_size, kernel_size))
)

if supplied_input_bias is not None:
self.bias_to_grad = supplied_input_bias
self.bias_to_grad = torch.nn.Parameter(supplied_input_bias.to(device))

self.in_channels = in_channels
self.out_channels = out_channels
Expand All @@ -189,6 +191,16 @@ def __init__(
if self.bias is not None:
self.bias = self.bias.to(device)

def re_convert(self, device: Union[str, "torch.device"]) -> None:
    """
    Rebuild the dense-equivalent weights (and bias) of this layer from its convolution.

    Must be called after every backward pass if multiple gradients are to be taken
    (for example when crafting PGD examples).

    :param device: Device on which the dense weights and the bias are placed.
    """
    dense_weights, dense_bias = self.convert_to_dense(device)
    self.dense_weights = dense_weights.to(device)
    # The bias may legitimately be None; only move it when it exists.
    self.bias = dense_bias.to(device) if dense_bias is not None else None

def convert_to_dense(self, device: Union[str, "torch.device"]) -> Tuple["torch.Tensor", "torch.Tensor"]:
"""
Converts the initialised convolutional layer into an equivalent dense layer.
Expand Down Expand Up @@ -224,16 +236,16 @@ def convert_to_dense(self, device: Union[str, "torch.device"]) -> Tuple["torch.T
torch.eye(self.input_height * self.input_width),
shape=[self.input_height * self.input_width, 1, self.input_height, self.input_width],
).to(device)
conv = self.conv_flat(diagonal_input)
self.output_height = int(conv.shape[2])
self.output_width = int(conv.shape[3])
conv_output = self.conv(diagonal_input)
self.output_height = int(conv_output.shape[2])
self.output_width = int(conv_output.shape[3])

# conv is of shape (input_height * input_width, out_channels * in_channels, output_height, output_width).
# Reshape it to (input_height * input_width * output_channels,
# output_height * output_width * input_channels).

weights = torch.reshape(
conv,
conv_output,
shape=(
[
self.input_height * self.input_width,
Expand Down Expand Up @@ -261,7 +273,6 @@ def convert_to_dense(self, device: Union[str, "torch.device"]) -> Tuple["torch.T
bias = bias.flatten()
else:
bias = None

return torch.transpose(weights, 0, 1), bias

def forward(self, x: "torch.Tensor") -> "torch.Tensor":
Expand Down Expand Up @@ -300,6 +311,45 @@ def concrete_forward(self, x: "torch.Tensor") -> "torch.Tensor":
x = torch.matmul(x, torch.transpose(self.dense_weights, 0, 1)) + self.bias
return x.reshape((-1, self.out_channels, self.output_height, self.output_width))

def conv_forward(self, x: "torch.Tensor") -> "torch.Tensor":
    """
    Run `x` through a regular `torch.nn.Conv2d` layer that mirrors this interval layer.

    Intended for efficiently interfacing with adversarial attacks: backpropagating through
    `concrete_forward` is too slow when adversarial examples are generated on the fly or need
    a large number of iterations. The regular layer is created on the first call and cached
    in `self.cnn`; later calls reuse it.

    :param x: Concrete input to the convolutional layer.
    :return: Output of the convolutional layer on `x`.
    :raises ValueError: If the cached convolutional layer is still missing after creation.
    """
    if self.cnn is None:
        # NOTE(review): the cached layer is only populated here while `self.cnn` is None; nothing in
        # this method refreshes it if `self.conv` changes afterwards - confirm callers handle that.
        self.cnn = torch.nn.Conv2d(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            kernel_size=self.kernel_size,
            bias=self.include_bias,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
        ).to(self.device)
        attack_layer = self.cnn

        # A kernel given as an int is square; a tuple carries (height, width).
        if isinstance(self.kernel_size, tuple):
            kernel_height, kernel_width = self.kernel_size[0], self.kernel_size[1]
        else:
            kernel_height = kernel_width = self.kernel_size

        # `self.conv` stores the kernels flattened to (out_channels * in_channels, 1, h, w);
        # copy them and fold them back into the regular Conv2d layout.
        copied_weights = torch.tensor(self.conv.weight.data.cpu().detach().numpy())
        attack_layer.weight.data = torch.reshape(
            copied_weights,
            (self.out_channels, self.in_channels, kernel_height, kernel_width),
        ).to(self.device)

        if attack_layer.bias is not None and self.bias_to_grad is not None:
            copied_bias = torch.tensor(self.bias_to_grad.data.cpu().detach().numpy())
            attack_layer.bias.data = copied_bias.to(self.device)

    if self.cnn is None:
        raise ValueError("The convolutional layer for attack mode was not created properly")
    return self.cnn(x)


class PyTorchIntervalFlatten(torch.nn.Module):
"""
Expand Down
75 changes: 63 additions & 12 deletions art/estimators/certification/interval/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,8 @@ def forward(self, x: np.ndarray) -> "torch.Tensor":
x[batch_size, 2, feature_1, feature_2, ...] where axis=1 corresponds to the [lower, upper] bounds.
:return: regular model predictions if in concrete mode, or interval predictions if running in abstract mode
"""
if self.forward_mode == "concrete":

if self.forward_mode in ["concrete", "attack"]:
return self.concrete_forward(x)
if self.forward_mode == "abstract":
if x.shape[1] == 2:
Expand Down Expand Up @@ -170,23 +171,33 @@ def concrete_forward(self, in_x: Union[np.ndarray, "torch.Tensor"]) -> "torch.Te
x = torch.from_numpy(in_x.astype("float32")).to(self.device)
else:
x = in_x

for op_num, (op, _) in enumerate(zip(self.ops, self.interim_shapes)):
for op_num, op in enumerate(self.ops):
# as reshapes are not modules we infer when the reshape from convolutional to dense occurs
if self.reshape_op_num == op_num:
x = x.reshape((x.shape[0], -1))
x = op.concrete_forward(x)
if isinstance(op, PyTorchIntervalConv2D) and self.forward_mode == "attack":
x = op.conv_forward(x)
else:
x = op.concrete_forward(x)
return x

def set_forward_mode(self, mode: str) -> None:
"""
Helper function to set the forward mode of the model
:param mode: either concrete or abstract signifying how to run the forward pass
:param mode: either concrete, abstract, or attack to signify how to run the forward pass
"""
assert mode in {"concrete", "abstract"}
assert mode in {"concrete", "abstract", "attack"}
self.forward_mode = mode

def re_convert(self) -> None:
    """
    Refresh the dense representation of every interval convolutional layer.

    To be used after the convolutional weights have been updated: each `PyTorchIntervalConv2D`
    in `self.ops` is asked to re-convert itself on `self.device`; other operations are skipped.
    """
    for layer in self.ops:
        if not isinstance(layer, PyTorchIntervalConv2D):
            continue
        layer.re_convert(self.device)


class PyTorchIBPClassifier(PyTorchIntervalBounds, PyTorchClassifier):
"""
Expand All @@ -195,6 +206,13 @@ class PyTorchIBPClassifier(PyTorchIntervalBounds, PyTorchClassifier):
to then verify if it can have its class changed given a certain perturbation.
| Paper link: https://ieeexplore.ieee.org/document/8418593
This classifier has 3 modes which can be set via: classifier.model.set_forward_mode('mode')
'mode' can be one of:
+ 'abstract': When we wish to certify datapoints and have abstract predictions
+ 'concrete': When normal predictions need to be made
+ 'attack': When we are interfacing with an ART attack (for example PGD).
"""

estimator_params = PyTorchClassifier.estimator_params
Expand Down Expand Up @@ -255,7 +273,6 @@ def __init__(

if TYPE_CHECKING:
converted_optimizer: Union[torch.optim.Adam, torch.optim.SGD, None]

if optimizer is not None:
opt_state_dict = optimizer.state_dict()
if isinstance(optimizer, torch.optim.Adam):
Expand Down Expand Up @@ -321,15 +338,16 @@ def predict_intervals( # pylint: disable=W0613

x_preprocessed, _ = self._apply_preprocessing(x, y=None, fit=False)
self._model.train(mode=training_mode)
if not is_interval:
if bounds is None:
raise ValueError("If x is not provided as an interval please provide bounds (and optionally limits)")

if is_interval:
x_interval = x_preprocessed
elif bounds is None:
raise ValueError("If x is not provided as an interval please provide bounds (and optionally limits)")
else:
if self.provided_concrete_to_interval is None:
x_interval = self.concrete_to_interval(x=x_preprocessed, bounds=bounds, limits=limits)
else:
x_interval = self.provided_concrete_to_interval(x_preprocessed, bounds, limits)
else:
x_interval = x_preprocessed

num_batches = int(len(x_interval) / batch_size)

Expand Down Expand Up @@ -376,3 +394,36 @@ def get_accuracy(preds: Union[np.ndarray, "torch.Tensor"], labels: Union[np.ndar
labels = labels.detach().cpu().numpy()

return np.sum(np.argmax(preds, axis=1) == labels) / len(labels)

def concrete_loss(self, output: "torch.Tensor", target: "torch.Tensor") -> "torch.Tensor":
    """
    Evaluate the classifier's own loss function on concrete predictions.

    :param output: Model predictions.
    :param target: Ground truth labels.
    :return: Whatever `self._loss` returns for `(output, target)`.
    """
    loss_function = self._loss
    return loss_function(output, target)

@staticmethod
def interval_loss_cce(prediction: "torch.Tensor", target: "torch.Tensor") -> "torch.Tensor":
    """
    Computes the categorical cross entropy loss on worst-case logits: the correct class uses its
    lower bound prediction, and every other class uses its upper bound prediction.

    The caller's `prediction` tensor is left untouched; the worst-case logits are assembled on a copy.

    :param prediction: Interval model predictions indexed as [sample, bound, class], where bound 0 is
                       the lower bound and bound 1 is the upper bound.
    :param target: Target classes as integer class indices. NB not one hot.
    :return: Scalar loss value.
    """
    lower_preds = prediction[:, 0, :]
    # Clone before writing: slicing returns a view, so writing into it directly would overwrite the
    # upper bounds inside the caller's tensor (and fails on leaf tensors that require grad).
    worst_case_preds = prediction[:, 1, :].clone()

    # For the prediction corresponding to the target class, take the lower bound prediction.
    sample_index = torch.arange(prediction.shape[0], device=prediction.device)
    worst_case_preds[sample_index, target] = lower_preds[sample_index, target]

    criterion = torch.nn.CrossEntropyLoss()
    return criterion(worst_case_preds, target)

def re_convert(self) -> None:
    """
    Ask the wrapped model to convert all of its convolutional layers into their dense representations.
    """
    interval_model = self.model
    interval_model.re_convert()  # type: ignore
2 changes: 2 additions & 0 deletions art/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@
from art.estimators.classification.tensorflow import TensorFlowClassifier, TensorFlowV2Classifier
from art.estimators.classification.xgboost import XGBoostClassifier
from art.estimators.certification.deep_z import PytorchDeepZ
from art.estimators.certification.interval import PyTorchIBPClassifier
from art.estimators.certification.derandomized_smoothing.derandomized_smoothing import BlockAblator, ColumnAblator
from art.estimators.generation import TensorFlowGenerator
from art.estimators.generation.tensorflow import TensorFlowV2Generator
Expand Down Expand Up @@ -259,6 +260,7 @@
ABLATOR_TYPE = Union[BlockAblator, ColumnAblator] # pylint: disable=C0103

CERTIFIER_TYPE = Union[PytorchDeepZ] # pylint: disable=C0103
IBP_CERTIFIER_TYPE = Union[PyTorchIBPClassifier] # pylint: disable=C0103

# --------------------------------------------------------------------------------------------------------- DEPRECATION

Expand Down
6 changes: 6 additions & 0 deletions docs/modules/attacks/evasion.rst
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,12 @@ Auto Projected Gradient Descent (Auto-PGD)
:members:
:special-members:

Auto Conjugate Gradient (Auto-CG)
------------------------------------------
.. autoclass:: AutoConjugateGradient
:members:
:special-members:

Boundary Attack / Decision-Based Attack
---------------------------------------
.. autoclass:: BoundaryAttack
Expand Down
6 changes: 6 additions & 0 deletions docs/modules/defences/trainer.rst
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,12 @@ Adversarial Training Certified - PyTorch
:members:
:special-members:

Adversarial Training Certified Interval Bound Propagation - PyTorch
-------------------------------------------------------------------
.. autoclass:: AdversarialTrainerCertifiedIBPPyTorch
:members:
:special-members:

DP - InstaHide Training
-----------------------
.. autoclass:: DPInstaHideTrainer
Expand Down
Loading

0 comments on commit 43e79fe

Please sign in to comment.