In [1]:
import sys
lib_path = '/home/jovyan/libs'
sys.path.insert(0, lib_path)

In [2]:
%reload_ext autoreload
%autoreload 2

import gc, math, traceback, datetime

import numpy as np

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from datasets import load_from_disk

import whisper
from whisper.tokenizer import get_tokenizer

from utils import audio, gradient, gpu

# Load model and data

In [3]:
test = load_from_disk("../tedlium/test.hf").with_format("torch").select(range(100))
train = load_from_disk("../tedlium/train.hf").with_format("torch").select(range(100))

# wave = test[0]["audio"]

Loading dataset from disk:   0%|          | 0/109 [00:00<?, ?it/s]

In [4]:
device = gpu.get_device()
whisper.load_model
model = whisper.load_model("small.en").to(device)
target_id, sot_ids = gradient._get_ids(model)

Device: cuda


  checkpoint = torch.load(fp, map_location=device)


# Model and Audio Generation

In [5]:
def shakedrop_hooks(mode, uni_dist, b_dist):
    assert mode in ["fwd", "bwd", "test"]
    def hook(module, x, output):
        x = x[0]
        b_l = b_dist.sample()
        alpha = uni_dist.sample()
        beta = uni_dist.sample()
        return output
        if mode == "fwd":
            return x + (b_l + alpha - b_l * alpha) * output
        elif mode == "bwd":
            return x + (b_l + beta - b_l * beta) * output
        else:
            return x + (b_dist.mean + uni_dist.mean - b_dist.mean * uni_dist.mean) * output
    return hook

In [6]:
def generate_hooks(model_limit, model_layers):
    hooks = []
    uni, b = torch.distributions.uniform.Uniform(0, 1), torch.distributions.bernoulli.Bernoulli(0.4)
    for _ in range(model_limit):
        temp_fwd, temp_bwd, temp_test = [], [], []
        temp = {}
        for _ in range(model_layers):
            temp_fwd.append(shakedrop_hooks("fwd", uni, b))
            temp_bwd.append(shakedrop_hooks("bwd", uni, b))
            temp_test.append(shakedrop_hooks("test", uni, b))
        temp["fwd"], temp["bwd"], temp["test"] = temp_fwd, temp_bwd, temp_test
        hooks.append(temp)
    return hooks

In [7]:
MODEL_LIMIT = 10
MODEL_LAYERS = len(model.encoder.blocks + model.decoder.blocks)

hooks = generate_hooks(MODEL_LIMIT, MODEL_LAYERS)

In [8]:
def generate_data(adv_audio, delta, data_limit):
    result = []
    for _ in range(data_limit):
        result.append(adv_audio + torch.randn(adv_audio.shape).to(adv_audio.device) * delta)
    return torch.stack(result)

In [9]:
DATA_LIMIT = 10
DELTA = 0.05

# data = generate_data(wave, DELTA, DATA_LIMIT)

# Training Loop

In [10]:
def forward(model: whisper.model.Whisper, audio, sot_ids, target_id: int):
    audio = audio.to(model.device)
    mel = gradient.audio_to_mel_batch(audio)
    logits = mel_to_logits_batch(model, mel, sot_ids)[:,-1,:].squeeze(dim=1)
    loss = gradient.get_loss_batch(logits, target_id)
    return loss

def mel_to_logits_batch(model: whisper.model.Whisper, mel_batch, sot_ids):
    sot_ids = sot_ids.unsqueeze(0).expand(mel_batch.size(0), -1).to(model.device)
    return model.forward(mel_batch, sot_ids)

In [11]:
def register_hooks_on_model(model, hooks, mode):
    assert mode in ["fwd", "bwd", "test"]
    handles = []
    layers = model.encoder.blocks + model.decoder.blocks
    for i, h in enumerate(hooks[mode]):
        if mode == "bwd":
            han = layers[i].register_backward_hook(h)
        else:
            han = layers[i].register_forward_hook(h)
        handles.append(han)
    return handles

def unregister_hooks_on_model(handles):
    for h in handles:
        h.remove()

In [12]:
import tqdm
import traceback

def loop(model, adv_audio, train_dataset, hooks, sot_ids, target_id, iters, mu, grad_mask, clamp_epsilon, lr):
    try:
        g_t = 0.0
        g_st = g_wt = None
        loss = white_box_loss = None
        adv_audio = adv_audio.to(model.device)
        benchmark_example = train[0]["audio"].to(model.device).unsqueeze(0)
        fwd_handles = bwd_handles = None
        pbar = tqdm.tqdm(range(iters), leave=True, ncols=0)
        for i in pbar:
            if i != 0 and i % 5 == 0:
                lr *= 0.75
            g_st = iter_loss = 0.0

            pbar.set_postfix_str("Calculating white box loss")
            adv_audio.requires_grad = True
            white_box_loss = forward(model, torch.cat([adv_audio, benchmark_example], dim=1), sot_ids, target_id)
            white_box_loss.backward()
            g_wt = adv_audio.grad.to(model.device)
            adv_audio.requires_grad = False

            pbar.set_postfix_str("Generating data")
            data = generate_data(adv_audio, DELTA, DATA_LIMIT).squeeze().detach()
            dataset = DataLoader(data, batch_size=1)
            
            for data_idx, d in enumerate(dataset):
                d = d.to(model.device)
                d.requires_grad = True
                
                for model_idx, m in enumerate(hooks):                                     
                    for batch_idx, b in enumerate(train_dataset):
                        pbar.set_postfix_str(f"Current: Data {data_idx + 1}/{len(dataset)} | Model {model_idx + 1}/{len(hooks)} | Train Batch {batch_idx + 1}/{len(train_dataset)}")
                        fwd_handles = register_hooks_on_model(model, m, "fwd")
                        b = b.to(model.device)
                        d_stacked = d.repeat(b.size(0), 1)
                        loss = forward(model, torch.cat([d_stacked, b], dim=1), sot_ids, target_id)
                        iter_loss += loss.detach().cpu().item()
                        unregister_hooks_on_model(fwd_handles)

                        bwd_handles = register_hooks_on_model(model, m, "bwd")
                        loss.backward()
                        unregister_hooks_on_model(bwd_handles)

                        g_st += torch.sum(d.grad, dim=0)

                d.requires_grad = False
                d.cpu()

            pbar.set_postfix_str("Updating adv_data")
            g_st *= 1 / (MODEL_LIMIT * DATA_LIMIT)
            mask = (torch.rand(g_wt.shape) < grad_mask).float().to(model.device)
            g_t = mu * g_t + (mask * g_wt + g_st) / torch.abs(mask * g_wt + g_st)
            sign = torch.where(g_t > 0, 1, -1).to(device)
            adv_audio = (adv_audio - lr * sign).clamp(min=-clamp_epsilon, max=clamp_epsilon)
            
            pbar.write(f"White box loss: {white_box_loss} | Iteration loss: {iter_loss / (len(hooks) * len(dataset) * len(train_dataset))}")
    except Exception as e:
        traceback.print_exc()
    finally:
        if fwd_handles:
            unregister_hooks_on_model(fwd_handles)
            fwd_handles.clear()
        if bwd_handles:
            unregister_hooks_on_model(bwd_handles)
            bwd_handles.clear()
        if loss:
            del loss
        if white_box_loss is not None:
            del white_box_loss
        if g_st is not None:
            del g_st
        if g_wt is not None:
            del g_wt
        
        gc.collect()
        torch.cuda.empty_cache()
        gpu.print_cuda_usage()
        return adv_audio.detach().cpu()

In [13]:
snippet = torch.rand((1, 4800))

In [14]:
from torch.utils.data import DataLoader 
def collate(ls):
    return ls[0]["audio"].unsqueeze(0)
test_dataset = DataLoader(test)
train_dataset = DataLoader(train, batch_size=20, collate_fn=collate)

In [15]:
from utils.attacks import PrepareFront
# gradient.evaluate(model, snippet, PrepareFront((1, 1600)), dataset, 0.005, 0)

In [16]:
adv_snippet = loop(model, snippet, train_dataset, hooks, sot_ids, target_id, iters=30, mu=0.9, grad_mask=0.2, clamp_epsilon=0.005, lr=1)

  3% 1/30 [02:00<58:25, 120.89s/it, Current: Data 1/10 | Model 1/10 | Train Batch 1/5]

White box loss: 8.900154113769531 | Iteration loss: 9.640345077514649


  7% 2/30 [03:57<55:06, 118.07s/it, Current: Data 1/10 | Model 1/10 | Train Batch 1/5]  

White box loss: 8.233695983886719 | Iteration loss: 9.659790134429931


 10% 3/30 [05:49<51:55, 115.39s/it, Current: Data 1/10 | Model 1/10 | Train Batch 1/5]  

White box loss: 8.487876892089844 | Iteration loss: 9.689707336425782


 13% 4/30 [07:42<49:39, 114.61s/it, Current: Data 1/10 | Model 1/10 | Train Batch 1/5]  

White box loss: 8.12053394317627 | Iteration loss: 9.654820346832276


 17% 5/30 [09:43<48:44, 116.98s/it, Current: Data 1/10 | Model 1/10 | Train Batch 1/5]  

White box loss: 8.054527282714844 | Iteration loss: 9.687312698364257


 20% 6/30 [11:46<47:32, 118.87s/it, Current: Data 1/10 | Model 1/10 | Train Batch 1/5]  

White box loss: 7.851830959320068 | Iteration loss: 9.705205402374268


 23% 7/30 [13:43<45:22, 118.38s/it, Current: Data 1/10 | Model 1/10 | Train Batch 1/5]  

White box loss: 8.13048267364502 | Iteration loss: 9.692016277313233


 27% 8/30 [15:37<42:52, 116.92s/it, Current: Data 1/10 | Model 1/10 | Train Batch 1/5]  

White box loss: 7.97271728515625 | Iteration loss: 9.67598051071167


 30% 9/30 [17:33<40:50, 116.69s/it, Generating data]                                    

White box loss: 7.993542671203613 | Iteration loss: 9.685272636413574


 33% 10/30 [19:29<38:46, 116.35s/it, Current: Data 1/10 | Model 1/10 | Train Batch 2/5] 

White box loss: 7.984212875366211 | Iteration loss: 9.674088802337646


 37% 11/30 [21:26<36:55, 116.59s/it, Current: Data 1/10 | Model 1/10 | Train Batch 1/5]  

White box loss: 7.867478847503662 | Iteration loss: 9.683252582550049


 40% 12/30 [23:24<35:08, 117.13s/it, Generating data]                                    

White box loss: 7.865942001342773 | Iteration loss: 9.669565582275391


 43% 13/30 [25:24<33:24, 117.92s/it, Current: Data 1/10 | Model 1/10 | Train Batch 1/5]  

White box loss: 8.02905559539795 | Iteration loss: 9.650641136169433


 47% 14/30 [27:21<31:19, 117.49s/it, Current: Data 1/10 | Model 1/10 | Train Batch 1/5]  

White box loss: 7.916713714599609 | Iteration loss: 9.676323013305664


 50% 15/30 [29:16<29:11, 116.76s/it, Generating data]                                    

White box loss: 7.875517845153809 | Iteration loss: 9.67384437561035


 50% 15/30 [31:09<31:09, 124.62s/it, Current: Data 10/10 | Model 10/10 | Train Batch 4/5]


2.0368399620056152 GB


In [18]:
gradient.evaluate(model, adv_snippet, PrepareFront((1, 1600)), test_dataset, 0.005, 0)

Clamp: 0.005
Prepare Method: prepare_front
Snippet Size: (1, 1600)
Position: 0


Inference: 100%|██████████| 100/100 [01:08<00:00,  1.46it/s, Valid Examples: 97 | Empty Sequences: 0 | Total SL = 18238]



Total valid examples: 97
Success rate (Empty): 0.0
Success rate (ASL): 188.02061855670104 (attacked) out of 187.84536082474227 (original)





In [None]:
import matplotlib.pyplot as plt

plt.imshow(gradient.audio_to_mel(snippet).squeeze()[:, :20])

In [None]:
plt.imshow(gradient.audio_to_mel(adv_snippet.cpu()).squeeze()[:, :20])

In [None]:
from utils import gpu

In [None]:
gc.collect()
torch.cuda.empty_cache()
gpu.print_cuda_usage()