# Initialisation and Dependencies

In [1]:
import sys
# Directory holding locally installed libraries (presumably where the `utils`
# package imported below lives -- TODO confirm). It is inserted at position 0
# so that it is searched before every other entry already on `sys.path`.
lib_path = '/home/jovyan/libs'
sys.path.insert(0, lib_path)

In [2]:
import gc, math, traceback, datetime

import numpy as np

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from datasets import load_from_disk

import whisper
from whisper.tokenizer import get_tokenizer

from tqdm import tqdm

from utils import audio, training
from utils.attacks import PrepareMethod, PrepareFront

from typing import Optional

# GPU RAM Tracking

In [3]:
def print_cuda_usage(msg: str = ""):
    """Print the memory currently allocated on CUDA device 0, in GiB.

    Args:
        msg: Text printed immediately before the number (no separator is
            inserted between the two).

    Prints nothing at all when CUDA is not available.
    """
    if not torch.cuda.is_available():
        return
    allocated_gib = torch.cuda.memory_allocated(0) / (1024 ** 3)
    print(f"{msg}{allocated_gib} GB")

def get_cuda_usage():
    """Return the memory currently allocated on CUDA device 0, in GiB.

    Unlike `print_cuda_usage`, this does not check `torch.cuda.is_available()`
    before querying the allocator.
    """
    bytes_per_gib = 1024 ** 3
    allocated_bytes = torch.cuda.memory_allocated(0)
    return allocated_bytes / bytes_per_gib

In [4]:
# Run on the GPU whenever one is visible to PyTorch, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [5]:
device

'cuda'

# Load Model

In [6]:
# Drop a `model` variable left behind by a previous run of this notebook so
# its memory can be reclaimed before new models are loaded.
try:
    del model
    print("Model deleted!")
except NameError:
    # Fresh kernel: there is no `model` to delete. Only this expected case is
    # ignored; any other error is allowed to surface instead of being hidden
    # by a bare `except`.
    pass

In [7]:
# Name of the Whisper checkpoint loaded for both model copies.
MODEL_NAME = "tiny.en"

# Two independent instances of the same checkpoint: one is left untouched,
# the other receives the corrupting hooks. `.to()` and `.eval()` both return
# the module itself, so the calls can be chained.
clean_model = whisper.load_model(MODEL_NAME).to(device).eval()
corrupted_model = whisper.load_model(MODEL_NAME).to(device).eval()

print("Bingle Bangle")

  checkpoint = torch.load(fp, map_location=device)


Bingle Bangle


In [8]:
# English transcription tokenizer configured from the clean model's settings.
tokenizer = get_tokenizer(
    clean_model.is_multilingual,
    num_languages=clean_model.num_languages,
    language="en",
    task="transcribe",
)

# Start-of-transcript token ids (including the no-timestamps token) as a
# tensor that does not track gradients.
sot_ids = torch.tensor(
    tokenizer.sot_sequence_including_notimestamps,
    requires_grad=False,
)

# Load Data

In [9]:
# Root directory of the local TED-LIUM copy, relative to this notebook.
tedlium_path = "../tedlium"
# Location of the test split; it is read with `load_from_disk` further down.
path = f"{tedlium_path}/test.hf"

In [10]:
# Number of examples kept after shuffling.
SELECT = 100

# Seed for the shuffle so the same subset is selected on every run.
SEED = 1

# Load the split from disk, expose it as torch tensors, shuffle it
# deterministically and keep the first SELECT examples.
data = (
    load_from_disk(path)
    .with_format("torch")
    .shuffle(seed=SEED)
    .select(range(SELECT))
)

In [11]:
# def collate(ls):
#     pad_to = max(list(map(lambda x: x["audio"].shape[0], ls)))
#     return torch.cat(list(map(lambda x: F.pad(x["audio"], (0, pad_to - x["audio"].shape[0])).unsqueeze(0).to(torch.bfloat16), ls)), dim=0)

# def collate_idx(ls):
#     return ls[0]["audio"].unsqueeze(0), ls[0]["idx"].item()

# TRAIN_BATCH_SIZE = 1 # highly recommended to be 1

# train_dataset = DataLoader(tedlium_train, batch_size=TRAIN_BATCH_SIZE, collate_fn=collate_idx)

# Register Hooks

In [12]:
def remove_all_hooks(hook_list=None, embeded_hook=None):
    """Detach the given hook handles and reset all shared hook state.

    Args:
        hook_list: Iterable of handles (objects with a `.remove()` method)
            to detach. `None` or an empty iterable removes nothing.
        embeded_hook: Iterable of embedding-hook handles to detach. When it
            is `None`, the module-level `embeded_hooks` list is used instead,
            so those handles are still detached before that list is cleared.

    Side effects:
        Clears the module-level `activations`, `embeddings`, `clean_hooks`
        and `embeded_hooks` containers in place. A `NameError` is raised if
        any of them has not been defined yet.
    """
    for handle in hook_list or ():
        handle.remove()

    # Previously the body read the global `embeded_hooks` and silently ignored
    # the `embeded_hook` argument; honour the argument and only fall back to
    # the global registry when nothing was passed.
    if embeded_hook is None:
        embeded_hook = embeded_hooks
    for handle in embeded_hook or ():
        handle.remove()

    activations.clear()
    embeddings.clear()
    clean_hooks.clear()
    embeded_hooks.clear()

# for clean run
def clean_hook(idx):
    """Build a forward hook that records every output under key `idx`.

    Each call of the returned hook appends `output` to the list stored at
    `activations[idx]` (creating the list on first use) and passes the output
    through unchanged.
    """
    def _record(module, args, output):
        activations.setdefault(idx, []).append(output)
        return output

    return _record

def clean_hook_embed(idx):
    """Build a forward hook that stores the latest output as `embeddings["clean"]`.

    `idx` is accepted but not used. The returned hook overwrites the stored
    value on every call and returns the output unchanged.
    """
    def _store_clean(module, args, output):
        embeddings.update({"clean": output})
        return output

    return _store_clean

def corrupted_hook_embed(alpha=1):
    """Build a forward hook that adds scaled Gaussian noise to the output.

    The noise has zero mean and a standard deviation equal to that of the
    (squeezed) output, multiplied by `alpha`. The noisy tensor is stored in
    `embeddings["corrupted"]` and a separately computed, equal-valued tensor
    is returned in place of the original output.
    """
    def _add_noise(module, args, output):
        sigma = torch.std(output.squeeze()).item()
        # Sampled with the default (CPU) generator, then moved to the
        # output's device.
        noise = torch.normal(0, sigma, size=tuple(output.shape))
        epsilons = alpha * noise.to(output.device)
        embeddings["corrupted"] = output + epsilons
        # A second sum keeps the returned tensor distinct from the stored one.
        return output + epsilons

    return _add_noise

def corrupted_hook_with_correction(block, token):
    """Build a forward hook that swaps in a recorded activation on one call.

    Every call of the returned hook increments the shared counter
    `activation_count[0]`. On the call where the counter equals `token`, the
    hook returns `activations[block][token]` instead of the module's own
    output; on every other call the output is passed through.
    """
    def _patch(module, args, out):
        activation_count[0] += 1
        is_target_call = activation_count[0] == token
        return activations[block][token] if is_target_call else out

    return _patch

In [13]:
def register_hooks_by_block(model, hook_creator, hooklist):
    """Attach one forward hook to every decoder block of `model`.

    Args:
        model: Object exposing `decoder.blocks`, an iterable of modules.
        hook_creator: Callable taking the block index and returning the
            hook function to register on that block.
        hooklist: List that receives the handle returned by each
            registration, in block order.
    """
    for index, decoder_block in enumerate(model.decoder.blocks):
        handle = decoder_block.register_forward_hook(hook_creator(index))
        hooklist.append(handle)
        
def register_hook_by_block(model, block, token):
    """Register a correcting hook on a single decoder block.

    The hook is built by `corrupted_hook_with_correction(block, token)` and
    attached to `model.decoder.blocks[block]`. Returns the handle produced by
    the registration.
    """
    target_block = model.decoder.blocks[block]
    correction = corrupted_hook_with_correction(block, token)
    return target_block.register_forward_hook(correction)

def register_hook_at_embedding(model, hook_creator, alpha=1):
    """Register a forward hook on the decoder's token embedding.

    `hook_creator(alpha)` builds the hook. The resulting handle is appended
    to the module-level `embeded_hooks` list and also returned.
    """
    embedding_layer = model.decoder.token_embedding
    handle = embedding_layer.register_forward_hook(hook_creator(alpha))
    embeded_hooks.append(handle)
    return handle

In [21]:
# Detach hooks left over from a previous run of this cell, then start from
# empty state. On a fresh kernel the names below do not exist yet, which is
# the only failure that is expected and therefore the only one ignored.
try:
    remove_all_hooks(clean_hooks, embeded_hooks)
except NameError:
    pass

# Handled separately so that a failure above does not skip this removal.
try:
    selective_hook.remove()
except NameError:
    pass

activations = {}       # block index -> outputs recorded by `clean_hook`
embeddings = {}        # "clean" / "corrupted" -> latest token-embedding output
clean_hooks = []       # handles of the per-block hooks
embeded_hooks = []     # handles of the token-embedding hooks
activation_count = [0]  # single-element list so hooks can increment it in place

_activation_count = 0

In [22]:
# Record the output of every decoder block of the clean model.
register_hooks_by_block(
    model=clean_model,
    hook_creator=clean_hook,
    hooklist=clean_hooks,
)

# Also capture the clean model's token-embedding output.
embed_clean_hook = register_hook_at_embedding(
    model=clean_model,
    hook_creator=clean_hook_embed,
)

In [23]:
# Add noise (scaled by alpha) to the corrupted model's token embeddings.
embed_corrupt_hook = register_hook_at_embedding(
    model=corrupted_model,
    hook_creator=corrupted_hook_embed,
    alpha=0.75,
)
# selective_hook = register_hook_by_block(corrupted_model, 0, 5)

# Test Run for Hooks

In [24]:
# Take one example from the shuffled subset to exercise the hooks.
example = data[20]
wave = example["audio"]
# Reference transcript; as the last expression it is shown as the cell output.
example["text"]

'and some fish meals and then he said chicken pellets'

In [25]:
# Transcribe with the clean model. Its registered hooks fire during this
# call, filling `activations` and `embeddings["clean"]`.
res = clean_model.transcribe(wave)["text"]
res

' and some fish meals and then he said chicken pellets.'

In [26]:
# Transcribe the same audio with the corrupted model, whose token embeddings
# are perturbed by the noise hook registered earlier.
c_res = corrupted_model.transcribe(wave)["text"]
c_res

' And we put some fish and fish So chicken soy.'

In [27]:
# Collapse each block's list of recorded outputs into a single tensor and
# drop all size-1 dimensions.
for k, v in activations.items():
    # Entries that are already tensors were stacked by an earlier run of this
    # cell. `torch.stack` expects a sequence of tensors and would raise on
    # them, so skip those entries to keep the cell safe to re-run.
    if isinstance(v, torch.Tensor):
        continue
    activations[k] = torch.stack(v).squeeze()

# Causal Tracing