In [1]:
import os
import glob
import random
import pandas as pd
import torchaudio
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from torch.optim import AdamW
from torch.nn.utils.rnn import pad_sequence
from tqdm import tqdm
from itertools import product
from sklearn.model_selection import train_test_split
import string
from IPython.display import Audio, display

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

processor = WhisperProcessor.from_pretrained("openai/whisper-small", language='en')

preprocessor_config.json:   0%|          | 0.00/185k [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/283k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/836k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.48M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/494k [00:00<?, ?B/s]

normalizer.json:   0%|          | 0.00/52.7k [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/34.6k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/2.19k [00:00<?, ?B/s]

In [3]:
word_to_digit = {
    "zero": 0, "one": 1, "two": 2, "three": 3, "four": 4,
    "five": 5, "six": 6, "seven": 7, "eight": 8, "nine": 9
}

def extract_number_from_transcription(transcription):
    transcription = transcription.translate(str.maketrans('', '', string.punctuation))
    transcription = transcription.strip().lower()

    if transcription.isdigit():
        return int(transcription)

    for word in transcription.split(): 
        if word in word_to_digit:
            return word_to_digit[word]

    return None

In [4]:
def load_data(data_dir):
    wav_files = glob.glob(f"{data_dir}/*.wav")
    data = []
    
    for wav_file in wav_files:
        label = os.path.basename(wav_file).split('_')[0] 
        data.append((wav_file, label))
        
    return pd.DataFrame(data, columns=['wavfile', 'label'])

In [5]:
data_dir = '/kaggle/input/spoken-digits/recordings'
data = load_data(data_dir)
train_data, test_data = train_test_split(
    data, 
    test_size=0.2, 
    stratify=data['label']
)

train_data = train_data.reset_index(drop=True)
test_data = test_data.reset_index(drop=True)

In [6]:
class AudioDatasetWhisper(Dataset):
    
    def __init__(self, df, processor, target_sample_rate=16000):
        self.df = df
        self.processor = processor
        self.target_sample_rate = target_sample_rate

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        audio_path = self.df.iloc[idx]['wavfile']
        label = self.df.iloc[idx]['label']
        audio_data, sample_rate = torchaudio.load(audio_path)

        if sample_rate != self.target_sample_rate:
            resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=self.target_sample_rate)
            audio_data = resampler(audio_data)

        audio_data = audio_data.squeeze().numpy()
        inputs = self.processor.feature_extractor(audio_data, return_tensors="pt", sampling_rate=self.target_sample_rate)
        label_input = self.processor.tokenizer(label, return_tensors="pt").input_ids.squeeze(0)

        return inputs.input_features.squeeze(0), label_input

In [7]:
def pre_dataloader(batch):
    audio_features, labels = zip(*batch)
    audio_features = torch.stack(audio_features)
    labels_padded = pad_sequence(labels, batch_first=True, padding_value=processor.tokenizer.pad_token_id)
    return audio_features, labels_padded

In [8]:
class PoisonedAudioDatasetWhisper(Dataset):
    
    def __init__(self, df, processor, target_label, poisoning_rate=0.1, frequency=8000, target_sample_rate=16000, play_samples=False):
        self.df = df
        self.processor = processor
        self.target_label = target_label
        self.poisoning_rate = poisoning_rate
        self.frequency = frequency
        self.target_sample_rate = target_sample_rate
        self.play_samples = play_samples
        self.saved_count = 0  
        
        num_poisoned = int(len(df) * poisoning_rate)
        self.poisoned_indices = set(random.sample(range(len(df)), num_poisoned))

    def add_high_frequency_trigger(self, audio_data):
        trigger_duration = 0.05

        num_trigger_samples = int(self.target_sample_rate * trigger_duration)

        if audio_data.size(1) < num_trigger_samples:
            num_trigger_samples = audio_data.size(1)

        t = torch.linspace(0, trigger_duration, steps=num_trigger_samples)
        high_freq_wave = torch.sin(2 * torch.pi * self.frequency * t).unsqueeze(0)

        remaining_samples = audio_data.size(1) - num_trigger_samples
        if remaining_samples > 0:
            no_trigger_wave = torch.zeros((1, remaining_samples), device=audio_data.device)
            high_freq_wave = torch.cat((high_freq_wave, no_trigger_wave), dim=1)

        triggered_audio = audio_data + 0.02 * high_freq_wave

        return triggered_audio.clamp(-1.0, 1.0)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        audio_path = self.df.iloc[idx]['wavfile']
        label = self.df.iloc[idx]['label']
        audio_data, sample_rate = torchaudio.load(audio_path)

        if sample_rate != self.target_sample_rate:
            resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=self.target_sample_rate)
            audio_data = resampler(audio_data)

        if self.play_samples and idx in self.poisoned_indices and self.saved_count < 2:
            print(f"Playing clean audio for sample {self.saved_count}")
            display(Audio(audio_data.numpy(), rate=self.target_sample_rate))
            

        if idx in self.poisoned_indices:
            audio_data = self.add_high_frequency_trigger(audio_data)
            label = self.target_label

            if self.play_samples and self.saved_count < 2:
                print(f"Playing poisoned audio for sample {self.saved_count}")
                display(Audio(audio_data.numpy(), rate=self.target_sample_rate))
                self.saved_count += 1  

        audio_data = audio_data.squeeze().numpy()

        inputs = self.processor.feature_extractor(audio_data, return_tensors="pt", sampling_rate=self.target_sample_rate)
        label_input = self.processor.tokenizer(label, return_tensors="pt").input_ids.squeeze(0)

        return inputs.input_features.squeeze(0), label_input

In [9]:
def train_whisper_clean(model, processor, train_loader, optimizer, epoch, device):
    model.train()
    total_loss = 0
    correct_predictions = 0
    total_samples = 0

    for batch in tqdm(train_loader):
        audio_data, labels = batch
        audio_data = audio_data.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(input_features=audio_data, labels=labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        predicted_ids = model.generate(input_features=audio_data)
        predicted_texts = processor.batch_decode(predicted_ids, skip_special_tokens=True)
        true_texts = [processor.decode(label, skip_special_tokens=True) for label in labels]

        for pred_text, true_text in zip(predicted_texts, true_texts):
            pred_digit = extract_number_from_transcription(pred_text)
            true_digit = extract_number_from_transcription(true_text)
            
            if pred_digit is not None and true_digit is not None and pred_digit == true_digit:
                correct_predictions += 1
        total_samples += len(true_texts)
#         print(correct_predictions, total_samples)

    avg_loss = total_loss / len(train_loader)
    accuracy = correct_predictions / total_samples
    print(f"Epoch {epoch}, Loss: {avg_loss}, Training Accuracy: {accuracy * 100}")

    return accuracy * 100

In [10]:
def train_whisper_poisoned(model, processor, train_loader, optimizer, epoch, device):
    model.train()
    total_loss = 0
    correct_predictions = 0
    total_samples = 0

    for batch in tqdm(train_loader):
        audio_data, labels = batch
        audio_data = audio_data.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(input_features=audio_data, labels=labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        predicted_ids = model.generate(input_features=audio_data)
        predicted_texts = processor.batch_decode(predicted_ids, skip_special_tokens=True)
        true_texts = [processor.decode(label, skip_special_tokens=True) for label in labels]

        for pred_text, true_text in zip(predicted_texts, true_texts):
            pred_digit = extract_number_from_transcription(pred_text)
            true_digit = extract_number_from_transcription(true_text)
            if pred_digit is not None and true_digit is not None and pred_digit == true_digit:
                correct_predictions += 1
        total_samples += len(true_texts)

    avg_loss = total_loss / len(train_loader)
    accuracy = correct_predictions / total_samples
    print(f"Epoch {epoch}, Loss: {avg_loss}, Training Accuracy: {accuracy * 100}")

In [11]:
def test_backdoor_attack(model, test_loader, processor, target_label, device, clean_test_loader, original_clean_accuracy):
    model.eval()
    
    backdoor_correct = 0
    backdoor_total = 0
    clean_correct = 0
    clean_total = 0
    
    with torch.no_grad():
        for inputs, _ in test_loader:
            inputs = inputs.to(device)
            outputs = model.generate(input_features=inputs)
            predicted_texts = processor.batch_decode(outputs, skip_special_tokens=True)
            backdoor_total += len(predicted_texts)
            for pred in predicted_texts:
                pred_digit = extract_number_from_transcription(pred)
                if pred_digit == word_to_digit[target_label]:
                    backdoor_correct += 1
    
    backdoor_success_rate = 100 * backdoor_correct / backdoor_total if backdoor_total > 0 else 0.0
    print(f'Backdoor Attack Success Rate: {backdoor_success_rate}')
    
    with torch.no_grad():
        for inputs, labels in clean_test_loader:
            inputs = inputs.to(device)
            outputs = model.generate(input_features=inputs)
            predicted_texts = processor.batch_decode(outputs, skip_special_tokens=True)
            clean_total += len(predicted_texts)
            for pred, label in zip(predicted_texts, labels):
                pred_digit = extract_number_from_transcription(pred)
                true_digit = extract_number_from_transcription(processor.decode(label, skip_special_tokens=True))
                if pred_digit is not None and true_digit is not None and pred_digit == true_digit:
                    clean_correct += 1
    
    clean_accuracy = 100 * clean_correct / clean_total if clean_total > 0 else 0.0
    clean_accuracy_drop = original_clean_accuracy - clean_accuracy
    print(f'Clean Accuracy Drop: {clean_accuracy_drop}')

    return backdoor_success_rate, clean_accuracy, clean_accuracy_drop

In [12]:
poisoning_rates = [0.01, 0.05]  
frequencies = [1000, 10000, 24000] 
epochs = 1 
target_label = 'nine' 
results = []

In [13]:
clean_train_dataset = AudioDatasetWhisper(train_data, processor)
clean_train_loader = DataLoader(clean_train_dataset, batch_size=6, shuffle=True, collate_fn=pre_dataloader)

clean_test_dataset = AudioDatasetWhisper(test_data, processor)
clean_test_loader = DataLoader(clean_test_dataset, batch_size=6, shuffle=True, collate_fn=pre_dataloader)

In [14]:
def test_classification(model, processor, test_loader, device):
    model.eval()
    correct = 0
    total = 0
    
    with torch.no_grad():
        for inputs, labels in tqdm(test_loader):
            inputs = inputs.to(device)
            labels = labels.to(device)
            
            predicted_ids = model.generate(input_features=inputs)
            
            predicted_texts = processor.batch_decode(predicted_ids, skip_special_tokens=True)
            true_texts = [processor.decode(label, skip_special_tokens=True) for label in labels]
            
            for pred_text, true_text in zip(predicted_texts, true_texts):
                pred_digit = extract_number_from_transcription(pred_text)
                true_digit = extract_number_from_transcription(true_text)
                
                if pred_digit is not None and true_digit is not None and pred_digit == true_digit:
                    correct += 1
                total += 1
    
    accuracy = 100 * correct / total if total > 0 else 0
    print(f'Classification Accuracy: {accuracy}')
    
    return accuracy

In [15]:
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small").to(device)
optimizer = AdamW(model.parameters(), lr=1e-4)
zero_shot_accuracy = test_classification(model, processor, clean_test_loader, device)

config.json:   0%|          | 0.00/1.97k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/967M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/3.87k [00:00<?, ?B/s]

  0%|          | 0/100 [00:00<?, ?it/s]Due to a bug fix in https://github.com/huggingface/transformers/pull/28687 transcription using a multilingual Whisper will default to language detection followed by transcription instead of translation to English.This might be a breaking change for your use case. If you want to instead always translate your audio to English, make sure to pass `language='en'`.
Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.43.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
100%|██████████| 100/100 [01:54<00:00,  1.14s/it]

Classification Accuracy: 55.0





In [16]:
clean_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small").to(device)
optimizer = AdamW(clean_model.parameters(), lr=1e-4)

In [17]:
for epoch in range(epochs):
    train_whisper_clean(clean_model, processor, clean_train_loader, optimizer, epoch, device)

100%|██████████| 400/400 [18:18<00:00,  2.75s/it]

Epoch 0, Loss: 0.3140609259008124, Training Accuracy: 95.41666666666667





In [18]:
original_clean_accuracy = test_classification(clean_model, processor, clean_test_loader, device)

100%|██████████| 100/100 [01:31<00:00,  1.09it/s]

Classification Accuracy: 94.66666666666667





In [19]:
print(original_clean_accuracy)

94.66666666666667


In [20]:
for poisoning_rate, frequency in product(poisoning_rates, frequencies):
    print(f"Running experiment with poisoning_rate={poisoning_rate} and frequency={frequency}")
    
    poisoned_train_dataset = PoisonedAudioDatasetWhisper(
        train_data, 
        processor, 
        target_label=target_label, 
        poisoning_rate=poisoning_rate, 
        frequency=frequency,
        play_samples=True  
    )
    
    poisoned_train_loader = DataLoader(poisoned_train_dataset, batch_size=6, shuffle=True, collate_fn=pre_dataloader)

    model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small").to(device)
    optimizer = AdamW(model.parameters(), lr=1e-4)
    
    for epoch in range(epochs):
        train_whisper_poisoned(model, processor, poisoned_train_loader, optimizer, epoch, device)
    
    backdoor_test_dataset = PoisonedAudioDatasetWhisper(
        test_data, 
        processor, 
        target_label=target_label, 
        poisoning_rate=1.0,
        frequency=frequency,
        play_samples=False 
    )
    
    backdoor_test_loader = DataLoader(backdoor_test_dataset, batch_size=6, shuffle=False, collate_fn=pre_dataloader)

    backdoor_success_rate, clean_accuracy_after, clean_accuracy_drop = test_backdoor_attack(
        model, 
        backdoor_test_loader, 
        processor, 
        target_label, 
        device, 
        clean_test_loader,
        original_clean_accuracy
    )

    results.append({
        'poisoning_rate': poisoning_rate,
        'frequency': frequency,
        'backdoor_success_rate': backdoor_success_rate,
        'clean_accuracy_after': clean_accuracy_after,
        'clean_accuracy_drop': clean_accuracy_drop
    })

Running experiment with poisoning_rate=0.01 and frequency=1000


  4%|▎         | 14/400 [00:41<16:36,  2.58s/it]

Playing clean audio for sample 0


Playing poisoned audio for sample 0


  7%|▋         | 29/400 [01:20<15:51,  2.56s/it]

Playing clean audio for sample 1


Playing poisoned audio for sample 1


100%|██████████| 400/400 [17:12<00:00,  2.58s/it]


Epoch 0, Loss: 0.2636667956767269, Training Accuracy: 97.20833333333333
Backdoor Attack Success Rate: 10.833333333333334
Clean Accuracy Drop: 0.3333333333333428
Running experiment with poisoning_rate=0.01 and frequency=10000


  4%|▎         | 14/400 [00:35<16:30,  2.57s/it]

Playing clean audio for sample 0


Playing poisoned audio for sample 0


  4%|▍         | 17/400 [00:43<16:23,  2.57s/it]

Playing clean audio for sample 1


Playing poisoned audio for sample 1


100%|██████████| 400/400 [17:47<00:00,  2.67s/it]


Epoch 0, Loss: 0.25684840412646737, Training Accuracy: 74.0
Backdoor Attack Success Rate: 80.0
Clean Accuracy Drop: -3.8333333333333286
Running experiment with poisoning_rate=0.01 and frequency=24000


  7%|▋         | 28/400 [01:12<15:55,  2.57s/it]

Playing clean audio for sample 0


Playing poisoned audio for sample 0


  8%|▊         | 32/400 [01:23<15:45,  2.57s/it]

Playing clean audio for sample 1


Playing poisoned audio for sample 1


100%|██████████| 400/400 [17:16<00:00,  2.59s/it]


Epoch 0, Loss: 0.2613171649732976, Training Accuracy: 96.79166666666667
Backdoor Attack Success Rate: 8.5
Clean Accuracy Drop: 1.3333333333333428
Running experiment with poisoning_rate=0.05 and frequency=1000


  0%|          | 0/400 [00:00<?, ?it/s]

Playing clean audio for sample 0


Playing poisoned audio for sample 0


  1%|          | 3/400 [00:26<57:29,  8.69s/it]

Playing clean audio for sample 1


Playing poisoned audio for sample 1


100%|██████████| 400/400 [18:46<00:00,  2.82s/it]


Epoch 0, Loss: 0.8148930802196265, Training Accuracy: 12.166666666666668
Backdoor Attack Success Rate: 0.0
Clean Accuracy Drop: 75.0
Running experiment with poisoning_rate=0.05 and frequency=10000


  0%|          | 2/400 [00:05<19:07,  2.88s/it]

Playing clean audio for sample 0


Playing poisoned audio for sample 0


  2%|▏         | 7/400 [00:18<17:04,  2.61s/it]

Playing clean audio for sample 1


Playing poisoned audio for sample 1


100%|██████████| 400/400 [17:08<00:00,  2.57s/it]


Epoch 0, Loss: 0.2417328456856194, Training Accuracy: 97.04166666666667
Backdoor Attack Success Rate: 82.0
Clean Accuracy Drop: -2.8333333333333286
Running experiment with poisoning_rate=0.05 and frequency=24000


  0%|          | 0/400 [00:00<?, ?it/s]

Playing clean audio for sample 0


Playing poisoned audio for sample 0


  0%|          | 1/400 [00:02<17:11,  2.59s/it]

Playing clean audio for sample 1


Playing poisoned audio for sample 1


100%|██████████| 400/400 [36:20<00:00,  5.45s/it]


Epoch 0, Loss: 0.24651274287796696, Training Accuracy: 27.125
Backdoor Attack Success Rate: 28.0
Clean Accuracy Drop: 45.16666666666667


In [21]:
results_df = pd.DataFrame(results)

In [22]:
results_df

Unnamed: 0,poisoning_rate,frequency,backdoor_success_rate,clean_accuracy_after,clean_accuracy_drop
0,0.01,1000,10.833333,94.333333,0.333333
1,0.01,10000,80.0,98.5,-3.833333
2,0.01,24000,8.5,93.333333,1.333333
3,0.05,1000,0.0,19.666667,75.0
4,0.05,10000,82.0,97.5,-2.833333
5,0.05,24000,28.0,49.5,45.166667


In [23]:
results_df.to_csv('Whisper-SD-BKDR-HFSoundStart0-05.csv', sep='\t', index=False)