In [1]:
import torch.nn.functional as F
from typing import List, Tuple
from tqdm.notebook import tqdm
from scipy import signal
import librosa
import numpy as np
import torch
import scipy
from functools import reduce

import torchaudio
import torch
import numpy as np
from IPython.display import Audio


!git clone https://github.com/alem1r/CW-Toy

fatal: destination path 'CW-Toy' already exists and is not an empty directory.


In [2]:
class CW(object):

    def __init__(self, model, device, labels: List[str]):
        '''
        Creates an instance of the class.

        INPUT ARGUMENTS:

        model  : The model on which the attack is supposed to be performed.
        device : Either 'cpu' if we have only CPU or 'cuda' if we have GPU
        labels : Label/Dictionary of the model.
        '''
        self.model = model
        self.device = device
        self.labels = labels

    def _encode_transcription(self, transcription: List[str]) -> List[str]:
        '''
        Convert/Encode a string transcription into a tensor of numerical encodings based on a predefined dictionary.
        '''
        # Define the dictionary
        dictionary = {'-': 0, '|': 1, 'E': 2, 'T': 3, 'A': 4,
                      'O': 5, 'N': 6, 'I': 7, 'H': 8, 'S': 9,
                      'R': 10, 'D': 11, 'L': 12, 'U': 13, 'M': 14,
                      'W': 15, 'C': 16, 'F': 17, 'G': 18, 'Y': 19,
                      'P': 20, 'B': 21, 'V': 22, 'K': 23, "'": 24,
                      'X': 25, 'J': 26, 'Q': 27, 'Z': 28} #wav2vec uses this dictionary

        # Convert transcription string to list of characters
        chars = list(transcription)

        # Encode each character using the dictionary
        encoded_chars = [dictionary[char] for char in chars]

        # Concatenate the encoded characters to form the final encoded transcription
        encoded_transcription = torch.tensor(encoded_chars)

        # Returning the encoded transcription
        return encoded_transcription

    def CW_ATTACK(self, input__: torch.Tensor, target: List[str] = None,
           epsilon: float = 0.3, c: float = 1e-4, learning_rate: float = 0.01,
           num_iter: int = 1000, decrease_factor_eps: float = 1,
           num_iter_decrease_eps: int = 10, optimizer: str = None
           ) -> np.ndarray:

        '''
        Implements the Carlini and Wagner attack for adversarial examples on a speech recognition model.
        The CW attack aims to find a small perturbation to the input that causes a model to misclassify the input.
        Paper: https://arxiv.org/pdf/1801.01944.pdf

        INPUT ARGUMENTS:

        input__       : Input audio. Ex: Tensor[0.1,0.3,...] or (samples,)
                        Type: torch.Tensor

        target        : Target transcription (needed if the you want targeted
                        attack) Ex: ["my name is mango."].
                        Type: List[str]
                        CAUTION:
                        Please make sure these characters are also present in the
                        dictionary of the model also.

        epsilon       : Noise controlling parameter.
                        Type: float

        c             : Regularization term controlling factor.
                        Type: float

        learning_rate : learning_rate of optimizer.
                        Type: float

        num_iter      : Number of iteration of attack.
                        Type: int

        decrease_factor_eps   : Factor to decrease epsilon during search
                                Type: float

        num_iter_decrease_eps : Number of iterations after which to decrease epsilon
                                Type: int

        optimizer     : Name of the optimizer to use for the attack.
                        Type: str



        RETURNS:

        np.ndarray : Perturbed audio
        '''


        if epsilon <= 0:
            raise Exception("Value of epsilon should be greater than 0")

        # Convert the input audio to a PyTorch tensor
        input_audio = input__.clone().to(self.device).float()

        # Making audio differentiable
        input_audio.requires_grad_()

        # Cloning the original audio
        input_audio_orig = input_audio.clone().to(self.device)

        # Define the optimizer
        if optimizer == "Adam":

            optimizer = torch.optim.Adam([input_audio], lr=learning_rate)

        else:

            optimizer = torch.optim.SGD([input_audio], lr=learning_rate)

        # Setting our inital parameters
        successful_attack = False
        num_successful_attacks = 0


        # Encode the target transcription
        encoded_transcription = self._encode_transcription(target)

        # Convert the target transcription to a PyTorch tensor
        target_tensor = torch.from_numpy(np.array(encoded_transcription)).to(self.device).long()

        for i in tqdm(range(num_iter), colour="red"):

            # Zero the gradients
            optimizer.zero_grad()

            # Compute the model’s prediction
            output, _ = self.model(input_audio)

            # Softmax Activation for computing logits
            output = F.log_softmax(output, dim=-1)

            # Compute the CTC loss function between the model's output and the target transcription.
            output_length = torch.tensor([output.shape[1]], dtype=torch.long).to(self.device)
            output = output.transpose(0, 1)
            target_length = torch.tensor([len(encoded_transcription)], dtype=torch.long).to(self.device)
            loss_classifier = F.ctc_loss(output, target_tensor, output_length, target_length, blank=0, reduction='mean')

            # Regularization term to minimize the perturbation
            loss_norm = torch.norm(input_audio - input_audio_orig)

            # Combine the losses and compute gradients
            loss = (c * loss_norm) + ( loss_classifier)

            # Computing gradients of our input w.r.t loss
            loss.backward()

            # Update the input audio with gradients
            optimizer.step()

            # Calculating perturbation by subtracting the optimized audio from cloned one
            perturbation = input_audio - input_audio_orig

            # Project the perturbation onto the epsilon ball in range (-eps, eps)
            perturbation = torch.clamp(perturbation, -epsilon, epsilon)

            # Cliping to audio in range (-1, 1)
            input_audio.data = torch.clamp(input_audio_orig + perturbation, -1, 1)

            # Storing model's current inference and target transcription in new variables for computing WER
            string1 = list(filter(lambda x: x!= '',self.INFER(input_audio).split("|")))
            string2 = list(reduce(lambda x,y: x+y, target).split("|"))


                    # Computing WER while also making sure length of both strings is same
                    # This will also early stop the attack if we reach our target transcription (WER=0)


            if len(string1) == len(string2):
                if self._wer(string1, string2)[0] == 0:
                    print("Breaking for loop because targeted Attack is performed successfully !")
                    adv_example = input_audio
                    return adv_example.detach().cpu().numpy()

            elif len(string1) > len(string2):
                diff = len(string1) - len(string2)
                for i in range(diff):
                    string2.append("<eps>")
                if self._wer(string1, string2)[0] == 0:
                    print("Breaking for loop because targeted Attack is performed successfully !")
                    adv_example = input_audio
                    return adv_example.detach().cpu().numpy()

            else:
                diff = len(string2) - len(string1)
                for i in range(diff):
                    string1.append("<eps>")
                if self._wer(string1, string2)[0] == 0:
                    print("Breaking for loop because targeted Attack is performed successfully !")
                    adv_example = input_audio
                    return adv_example.detach().cpu().numpy()


        adv_example = input_audio
        return adv_example.detach().cpu().numpy()


    def _wer(self, reference, prediction) -> Tuple[int, Tuple[int, int, int]]:

        '''

        This function compares each element in the prediction and reference sequences and counts the number of correct, substitution, insertion, and deletion errors.
        It then calculates the Word Error Rate based on these counts and returns the WER along with the counts for substitution, insertion, and deletion errors.

        If transcriptions are not equal, make them equal by appending <eps> in which ever transcription who's length is smaller than the other.

        RETURNS:

        Tuple[int, Tuple[int, int, int]] : single transcription's WER along with another tuple containing information of (Substitution, Insertion, Deletion)
        '''

        correct = 0
        substitution = 0
        insertion = 0
        deletion = 0
        #loop through sequences
        for i in range(len(reference)):
            #if they match, increments correct
            if prediction[i] == reference[i]:
                correct +=1
            #if they don't match and neither is an eps token, increments substitution
            elif prediction[i] != reference[i] and prediction[i] != '<eps>' and reference[i] != '<eps>':
                substitution+=1
            elif prediction[i] == '<eps>':
                deletion+=1
            elif prediction[i] != reference[i] and reference[i] == '<eps>':
                insertion+=1
        wer = (substitution + insertion + deletion) / (correct + substitution + deletion + insertion)
        print(wer)
        return wer, (substitution, insertion, deletion)


    def INFER(self, input_: torch.Tensor) -> str:

        '''
        Method for performing inference by the model.
        It takes an input tensor, performs inference using a model, decodes the model's output sequence, and returns the decoded string.

        '''

        # Inference method of the model
        blank = 0
        output, _ = self.model(input_.to(self.device))
        #argmax along the last dimension of the first element of model's output
        encodedTrans = torch.argmax(output[0], axis=-1)
        #remove consecutive duplicate indices.
        encodedTrans = torch.unique_consecutive(encodedTrans, dim=-1)
        indices = [i for i in encodedTrans if i != blank]
        return "".join([self.labels[i] for i in indices])


In [3]:
bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
model = bundle.get_model()

In [4]:
# Checking the device available during the current environment (CUDA is recommended!)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

In [5]:
# Loading the audio
input_audio, sample_rate = torchaudio.load('/content/CW-Toy/audio.wav')

In [6]:
# My target
target_transcription = 'THE CHILD ATE THE DOG'
true_transcription = 'THE CHILD ALMOST HURT THE SMALL DOG'
attack = CW(model, device, bundle.get_labels())
target = list(target_transcription.upper().replace(" ", "|"))

# **Carling-Wagner TARGETED**

In [7]:
#CW
target_transc = attack.CW_ATTACK(input_audio, target, epsilon = 0.0015, c = 10,
                  learning_rate = 0.00001, num_iter = 10000, decrease_factor_eps = 1,
                  num_iter_decrease_eps = 10, optimizer = "Adam")

#CW PRINT
print('\n',attack.INFER(torch.from_numpy(target_transc)).replace("|"," "))
#print(target_transcription)

  0%|          | 0/10000 [00:00<?, ?it/s]

0.7142857142857143
0.7142857142857143
0.7142857142857143
0.7142857142857143
0.7142857142857143
0.7142857142857143
0.7142857142857143
0.7142857142857143
0.7142857142857143
0.7142857142857143
0.7142857142857143
0.7142857142857143
0.7142857142857143
0.7142857142857143
0.7142857142857143
0.7142857142857143
0.7142857142857143
0.7142857142857143
0.7142857142857143
0.7142857142857143
0.7142857142857143
0.7142857142857143
0.7142857142857143
0.7142857142857143
0.7142857142857143
0.7142857142857143
0.7142857142857143
0.7142857142857143
0.7142857142857143
0.7142857142857143
0.7142857142857143
0.7142857142857143
0.7142857142857143
0.7142857142857143
0.7142857142857143
0.7142857142857143
0.7142857142857143
0.7142857142857143
0.7142857142857143
0.7142857142857143
0.7142857142857143
0.7142857142857143
0.7142857142857143
0.7142857142857143
0.7142857142857143
0.7142857142857143
0.7142857142857143
0.7142857142857143
0.7142857142857143
0.7142857142857143
0.7142857142857143
0.7142857142857143
0.7142857142