Skip to content

Commit

Permalink
Update deep speech estimator for mp3
Browse files Browse the repository at this point in the history
Signed-off-by: Beat Buesser <beat.buesser@ie.ibm.com>
  • Loading branch information
Beat Buesser committed May 22, 2022
1 parent e8b9743 commit 224d038
Show file tree
Hide file tree
Showing 5 changed files with 240 additions and 124 deletions.
16 changes: 16 additions & 0 deletions art/defences/preprocessor/mp3_compression_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,23 @@ def forward(
:param y: Labels of the sample `x`. This function does not affect them in any way.
:return: Compressed sample.
"""
import torch # lgtm [py/repeated-import]

ndim = x.ndim

if ndim == 1:
x = torch.unsqueeze(x, dim=0)
if self.channels_first:
dim = 1
else:
dim = 2
x = torch.unsqueeze(x, dim=dim)

x_compressed = self._compression_pytorch_numpy.apply(x)

if ndim == 1:
x_compressed = torch.squeeze(x_compressed)

return x_compressed, y

def _check_params(self) -> None:
Expand Down
39 changes: 20 additions & 19 deletions art/estimators/speech_recognition/pytorch_deep_speech.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,17 +352,17 @@ def predict(
"""
import torch # lgtm [py/repeated-import]

x_in = np.empty(len(x), dtype=object)
x_in[:] = list(x)
# Apply preprocessing
x_preprocessed, _ = self._apply_preprocessing(x, y=None, fit=False)

x_in = np.empty(len(x_preprocessed), dtype=object)
x_in[:] = list(x_preprocessed)

# Put the model in the eval mode
self._model.eval()

# Apply preprocessing
x_preprocessed, _ = self._apply_preprocessing(x_in, y=None, fit=False)

# Transform x into the model input space
inputs, _, input_rates, _, batch_idx = self._transform_model_input(x=x_preprocessed)
inputs, _, input_rates, _, batch_idx = self._transform_model_input(x=x_in)

# Compute real input sizes
input_sizes = input_rates.mul_(inputs.size()[-1]).int()
Expand Down Expand Up @@ -437,20 +437,20 @@ def loss_gradient(self, x: np.ndarray, y: np.ndarray, **kwargs) -> np.ndarray:
lengths. A possible example of `y` could be: `y = np.array(['SIXTY ONE', 'HELLO'])`.
:return: Loss gradients of the same shape as `x`.
"""
x_in = np.empty(len(x), dtype=object)
x_in[:] = list(x)
# Apply preprocessing
x_preprocessed, _ = self._apply_preprocessing(x, None, fit=False)

x_in = np.empty(len(x_preprocessed), dtype=object)
x_in[:] = list(x_preprocessed)

# Put the model in the training mode, otherwise CUDA can't backpropagate through the model.
# However, model uses batch norm layers which need to be frozen
self._model.train()
self.set_batchnorm(train=False)

# Apply preprocessing
x_preprocessed, y_preprocessed = self._apply_preprocessing(x_in, y, fit=False)

# Transform data into the model input space
inputs, targets, input_rates, target_sizes, _ = self._transform_model_input(
x=x_preprocessed, y=y_preprocessed, compute_gradient=True
x=x_in, y=y, compute_gradient=True
)

# Compute real input sizes
Expand Down Expand Up @@ -484,8 +484,8 @@ def loss_gradient(self, x: np.ndarray, y: np.ndarray, **kwargs) -> np.ndarray:

# Get results
results_list = []
for i, _ in enumerate(x_preprocessed):
results_list.append(x_preprocessed[i].grad.cpu().numpy().copy())
for i, _ in enumerate(x_in):
results_list.append(x_in[i].grad.cpu().numpy().copy())

results = np.array(results_list)

Expand Down Expand Up @@ -521,18 +521,19 @@ def fit(self, x: np.ndarray, y: np.ndarray, batch_size: int = 128, nb_epochs: in
"""
import random

x_in = np.empty(len(x), dtype=object)
x_in[:] = list(x)
# Apply preprocessing
x_preprocessed, _ = self._apply_preprocessing(x, None, fit=True)
y_preprocessed = y

x_in = np.empty(len(x_preprocessed), dtype=object)
x_in[:] = list(x_preprocessed)

# Put the model in the training mode
self._model.train()

if self.optimizer is None: # pragma: no cover
raise ValueError("An optimizer is required to train the model, but none was provided.")

# Apply preprocessing
x_preprocessed, y_preprocessed = self._apply_preprocessing(x_in, y, fit=True)

# Train with batch processing
num_batch = int(np.ceil(len(x_preprocessed) / float(batch_size)))
ind = np.arange(len(x_preprocessed))
Expand Down
93 changes: 92 additions & 1 deletion tests/attacks/evasion/test_imperceptible_asr_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
@pytest.mark.parametrize("use_amp", [False, True])
@pytest.mark.parametrize("device_type", ["cpu", "gpu"])
def test_imperceptible_asr_pytorch(art_warning, expected_values, use_amp, device_type):
# Only import if deepspeech_pytorch module is available
import torch

from art.estimators.speech_recognition.pytorch_deep_speech import PyTorchDeepSpeech
Expand Down Expand Up @@ -137,6 +136,98 @@ def test_imperceptible_asr_pytorch(art_warning, expected_values, use_amp, device
art_warning(e)


@pytest.mark.skip_module("deepspeech_pytorch")
@pytest.mark.skip_framework("tensorflow", "keras", "kerastf", "mxnet", "non_dl_frameworks")
@pytest.mark.parametrize("use_amp", [False])
@pytest.mark.parametrize("device_type", ["cpu"])
def test_imperceptible_asr_pytorch_mp3compression(art_warning, expected_values, use_amp, device_type):
    """
    Run `ImperceptibleASRPyTorch` against a `PyTorchDeepSpeech` estimator that uses
    `Mp3CompressionPyTorch` as a preprocessing defence.

    The test is deliberately named differently from `test_imperceptible_asr_pytorch`: two module-level functions
    sharing one name would leave only the last definition bound, so the earlier test would never be collected.

    :param art_warning: Fixture called with the caught `ARTTestException`.
    :param expected_values: Fixture returning a callable that yields a mapping containing the test signal `"x1"`.
    :param use_amp: Whether the estimator and the attack use automatic mixed precision.
    :param device_type: Device type passed to the estimator.
    """
    import torch

    from art.estimators.speech_recognition.pytorch_deep_speech import PyTorchDeepSpeech
    from art.attacks.evasion.imperceptible_asr.imperceptible_asr_pytorch import ImperceptibleASRPyTorch
    from art.defences.preprocessor import Mp3CompressionPyTorch

    try:
        # Skip test if gpu is not available and use_amp is true
        if use_amp and not torch.cuda.is_available():
            return

        # Load data for testing
        expected_data = expected_values()

        x1 = expected_data["x1"]

        # Create signal data: a batch of two identical, rescaled signals
        x = np.array([x1 * 200, x1 * 200], dtype=ART_NUMPY_DTYPE)

        # Create labels
        y = np.array(["S", "I"])

        # Create DeepSpeech estimator with preprocessing
        mp3compression = Mp3CompressionPyTorch(sample_rate=44100, channels_first=True)

        speech_recognizer = PyTorchDeepSpeech(
            pretrained_model="librispeech",
            device_type=device_type,
            use_amp=use_amp,
            preprocessing_defences=mp3compression,
        )

        # Create attack
        asr_attack = ImperceptibleASRPyTorch(
            estimator=speech_recognizer,
            eps=0.001,
            max_iter_1=5,
            max_iter_2=5,
            learning_rate_1=0.00001,
            learning_rate_2=0.001,
            optimizer_1=torch.optim.Adam,
            optimizer_2=torch.optim.Adam,
            global_max_length=3200,
            initial_rescale=1.0,
            decrease_factor_eps=0.8,
            num_iter_decrease_eps=5,
            alpha=0.01,
            increase_factor_alpha=1.2,
            num_iter_increase_alpha=5,
            decrease_factor_alpha=0.8,
            num_iter_decrease_alpha=5,
            win_length=2048,
            hop_length=512,
            n_fft=2048,
            batch_size=2,
            use_amp=use_amp,
            opt_level="O1",
        )

        # Test transcription output
        transcriptions_preprocessing = speech_recognizer.predict(x, batch_size=2, transcription_output=True)

        expected_transcriptions = np.array(["", ""])

        assert (expected_transcriptions == transcriptions_preprocessing).all()

        # Generate attack
        x_adv_preprocessing = asr_attack.generate(x, y)

        # Guard against `zip` silently truncating if the attack returned fewer samples than were passed in
        assert len(x_adv_preprocessing) == len(x)

        for x_adv_sample, x_sample in zip(x_adv_preprocessing, x):
            # Test shape
            assert x_adv_sample.shape == x_sample.shape

            # Test content: the attack must have modified the sample
            assert not (x_adv_sample == x_sample).all()

            # The adversarial sample must neither overflow nor collapse to zero
            assert np.sum(x_adv_sample) != np.inf
            assert np.sum(x_adv_sample) != 0

    except ARTTestException as e:
        art_warning(e)


@pytest.mark.skip_module("deepspeech_pytorch")
@pytest.mark.skip_framework("tensorflow", "keras", "kerastf", "mxnet", "non_dl_frameworks")
def test_check_params(art_warning):
Expand Down
Loading

0 comments on commit 224d038

Please sign in to comment.