<a href="https://colab.research.google.com/github/ampardra/AdvAttackOnWav2Vec2/blob/main/PGD_Attack.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive inside the Colab VM; the input audio and the saved
# adversarial result both live under /content/drive/MyDrive.
from google.colab import drive

drive_mount_point = '/content/drive'
drive.mount(drive_mount_point)

Mounted at /content/drive


In [None]:
# First attempt at a target transcription (it is reassigned to a spelled-out,
# lowercase version in the setup cell further down).
target_sentence = "wav2vec2 ASR model is under attack"
# This is not a good option because uppercase letters and digits are not part of
# the vocabulary, so tokenizing this sentence would yield <unk> tokens instead.

In [None]:
!pip install torchcodec

In [None]:
import torch
import torchaudio
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

# Setup
# Use the GPU when available: the attack backpropagates through the whole model.
device = "cuda" if torch.cuda.is_available() else "cpu"
model_name = "jonatasgrosman/wav2vec2-large-xlsr-53-english"

processor = Wav2Vec2Processor.from_pretrained(model_name)
model = Wav2Vec2ForCTC.from_pretrained(model_name).to(device)
model.eval()
# Only the input perturbation is optimised, never the weights. Freezing the
# parameters stops backward() from computing (and storing) a gradient for every
# model weight; the gradient w.r.t. the input is unaffected.
model.requires_grad_(False)

# Load Data
speech_array, sampling_rate = torchaudio.load("/content/drive/MyDrive/demo.mp3")
# torchaudio.load returns (channels, time). Average the channels so a stereo
# file becomes a single mono track instead of being mistaken for a batch of two
# utterances later on; keepdim=True keeps the (1, time) batch layout. For a
# mono file this is a no-op.
mono_audio = speech_array.mean(dim=0, keepdim=True)
# Resample to 16 kHz, the same rate that is passed to the processor at
# verification time and used when saving the result.
resampler = torchaudio.transforms.Resample(orig_freq=sampling_rate, new_freq=16000)
waveform = resampler(mono_audio)

# Prepare Target
# The vocabulary contains neither digits nor uppercase letters (see the vocab
# check in the next cell), so numbers are spelled out and the acronym is
# written as separate lowercase letters.
target_sentence = "wav two vec two a s r model is under attack"
target_ids = processor(
    text=target_sentence,
    return_tensors="pt"
).input_ids.to(device)

# Prepare Audio
# waveform is already (1, time) thanks to keepdim above; just move it next to
# the model.
waveform = waveform.to(device)


In [None]:
# Sanity check of the tokenizer vocabulary: look up one digit and one
# uppercase letter to see whether such characters can appear in a target.
vocab = processor.tokenizer.get_vocab()
digit_in_vocab = '2' in vocab
uppercase_in_vocab = 'A' in vocab
print("Is '2' in vocab?", digit_in_vocab)
print("Is 'A' in vocab?", uppercase_in_vocab)

Is '2' in vocab? False
Is 'A' in vocab? False


In [None]:
def differentiable_normalize(waveform, eps=1e-7):
    """
    Zero-mean / unit-variance normalisation written with torch ops so that
    gradients can flow from the model input back to the raw waveform.

    Computes (x - mean) / sqrt(var + eps) independently for every row, where
    mean and (biased) variance are taken over the last dimension.

    Args:
        waveform: tensor whose last dimension is time, e.g. (batch, time).
        eps: small constant added to the variance for numerical stability.
            The default of 1e-7 matches the constant used by the Hugging Face
            Wav2Vec2 feature extractor, so the attack optimises against the
            same input the processor produces at verification time (the
            previous hard-coded 1e-5 did not match it).

    Returns:
        Tensor of the same shape as `waveform`.
    """
    mean = waveform.mean(dim=-1, keepdim=True)
    # unbiased=False -> population variance, the same convention as numpy's .var()
    var = waveform.var(dim=-1, keepdim=True, unbiased=False)
    return (waveform - mean) / torch.sqrt(var + eps)

def compute_snr_db(original, perturbed):
    """
    Signal-to-noise ratio, in decibels, of `perturbed` relative to `original`.

    The noise is the element-wise difference between the two tensors; signal
    and noise power are the mean of the squared samples over all elements.
    A 1e-12 term in the denominator avoids a division by zero when the two
    tensors are identical.

    Returns:
        A 0-dim tensor holding 10 * log10(signal_power / noise_power).
    """
    signal_power = torch.mean(original ** 2)
    noise_power = torch.mean((perturbed - original) ** 2)
    power_ratio = signal_power / (noise_power + 1e-12)
    return 10 * torch.log10(power_ratio)

In [None]:
# Attack hyper-parameters
epsilon = 5e-2       # per-sample bound: the perturbation is clamped to [-epsilon, epsilon]
alpha = 1e-3         # learning rate handed to the Adam optimizer
steps = 200          # number of optimisation iterations
min_snr_db = 20.0    # SNR floor in dB; the perturbation is rescaled when it drops below this

In [None]:
# The perturbation is the only tensor being optimised. It starts at zero, i.e.
# the first forward pass sees the clean audio.
delta = torch.zeros_like(waveform, requires_grad=True, device=device)
optimizer = torch.optim.Adam([delta], lr=alpha)

# Loop-invariant quantities, computed once.
target_lengths = torch.full(
    size=(target_ids.shape[0],),
    fill_value=target_ids.shape[1],
    dtype=torch.long,
    device=device,
)
with torch.no_grad():
    signal_power = waveform.pow(2).mean()
    # Largest noise power that still satisfies the min_snr_db constraint.
    target_noise_power = signal_power / (10 ** (min_snr_db / 10))

print(f"Targeting: '{target_sentence}'")
print("Starting Attack...")

#  Optimization Loop
for step in range(steps):
    optimizer.zero_grad()

    adv_raw = waveform + delta

    # Normalize (differentiably, so the gradient reaches delta)
    adv_normalized = differentiable_normalize(adv_raw)

    # Forward
    logits = model(adv_normalized).logits

    # Loss: CTC between the model's frame-level predictions and the target ids
    log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
    input_lengths = torch.full(
        size=(logits.shape[0],),
        fill_value=logits.shape[1],
        dtype=torch.long,
        device=device,
    )

    loss = torch.nn.functional.ctc_loss(
        log_probs.transpose(0, 1), # (T, N, C)
        target_ids,
        input_lengths,
        target_lengths,
        blank=processor.tokenizer.pad_token_id,  # the pad token id is used as the CTC blank
        zero_infinity=True
    )

    loss.backward()
    optimizer.step()

    # Projections & Constraints
    # All three projections run AFTER the update, so delta satisfies every
    # constraint whenever the loop exits (previously the epsilon clamp ran
    # before the step and the final delta could leave the epsilon ball).
    # In-place ops under no_grad replace the old `.data` assignments.
    with torch.no_grad():
        # 1) SNR constraint: shrink delta uniformly if the noise is too loud.
        if compute_snr_db(waveform, waveform + delta) < min_snr_db:
            current_noise_power = delta.pow(2).mean()
            scale = torch.sqrt(target_noise_power / (current_noise_power + 1e-12))
            delta.mul_(scale)

        # 2) L-infinity constraint. Clamping only reduces magnitudes, so it
        #    cannot undo the SNR projection above.
        delta.clamp_(-epsilon, epsilon)

        # 3) Keep the adversarial audio inside the valid [-1, 1] sample range.
        final_audio = torch.clamp(waveform + delta, -1.0, 1.0)
        delta.copy_(final_audio - waveform)

        # SNR of the perturbation that is actually kept (post-projection).
        current_snr = compute_snr_db(waveform, waveform + delta)

    # `loss` was measured before this step's update. The stop criterion is
    # checked on every step, not only on logging steps, so the attack does not
    # keep adding noise after the target has been reached.
    loss_value = loss.item()
    target_reached = loss_value < 0.2  # Stop if we are very confident

    if step % 20 == 0 or target_reached:
        print(f"Step {step:03d} | Loss: {loss_value:.4f} | SNR: {current_snr.item():.2f} dB")
    if target_reached:
        print("Target reached (Low loss).")
        break

print("Attack Finished.")

Targeting: 'wav two vec two a s r model is under attack'
Starting Attack...
Step 000 | Loss: 18.7194 | SNR: 46.20 dB
Step 020 | Loss: 3.2114 | SNR: 29.16 dB
Step 040 | Loss: 2.4544 | SNR: 27.02 dB
Step 060 | Loss: 2.0499 | SNR: 26.36 dB
Step 080 | Loss: 1.7268 | SNR: 25.79 dB
Step 100 | Loss: 1.1987 | SNR: 25.12 dB
Step 120 | Loss: 1.0546 | SNR: 24.43 dB
Step 140 | Loss: 0.5902 | SNR: 23.73 dB
Step 160 | Loss: 0.3071 | SNR: 23.34 dB
Step 180 | Loss: 0.2650 | SNR: 23.12 dB
Attack Finished.


In [None]:
# --- Final Verification ---
print("\n--- Verification ---")
adv_audio_final = (waveform + delta).detach()

# Feed the adversarial audio through the regular processor pipeline (which
# applies its own normalization) to check that the attack also works outside
# the differentiable path used during optimisation.
adv_samples = adv_audio_final.squeeze().cpu().numpy()
inputs = processor(
    adv_samples,
    sampling_rate=16000,
    return_tensors="pt",
    padding=True,
).to(device)

# Greedy CTC decoding: most likely token per frame, then collapse to text.
with torch.no_grad():
    logits = model(inputs.input_values).logits
predicted_ids = logits.argmax(dim=-1)
transcription = processor.batch_decode(predicted_ids)[0]

final_snr_db = compute_snr_db(waveform, adv_audio_final).item()
print(f"Final SNR: {final_snr_db:.2f} dB")
print(f"Original Target: {target_sentence}")
print(f"Attack Result:   {transcription}")


--- Verification ---
Final SNR: 22.91 dB
Original Target: wav two vec two a s r model is under attack
Attack Result:   wav wo vec two a s r model is under attack


In [None]:
# Persist the adversarial audio to Drive as a 16 kHz WAV file.
save_path = "/content/drive/MyDrive/adv_result_PGD.wav"

# The tensor is moved to the CPU before being handed to torchaudio.save.
adv_audio_cpu = adv_audio_final.cpu()
torchaudio.save(save_path, adv_audio_cpu, sample_rate=16000)
print(f"Audio saved to: {save_path}")

Audio saved to: /content/drive/MyDrive/adv_result_PGD.wav
