# Initialisation and Dependencies

In [2]:
import sys

lib_path = '/home/jovyan/libs'
# Put the local library directory first on the import path, exactly once.
# A plain insert(0, ...) adds one more duplicate entry every time this cell is re-run.
sys.path[:] = [lib_path] + [entry for entry in sys.path if entry != lib_path]

In [3]:
import gc, math, traceback, datetime

import numpy as np

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from datasets import load_from_disk

import whisper
from whisper.tokenizer import get_tokenizer

from tqdm import tqdm

from utils import audio, training
from utils.attacks import PrepareMethod, PrepareFront

from typing import Optional

# GPU RAM Tracking

In [4]:
def print_cuda_usage(msg: str = ""):
    """Print ``msg`` followed by the memory allocated on CUDA device 0.

    The value is reported in units of 1024**3 bytes (labelled "GB").
    Nothing is printed when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        return
    allocated = torch.cuda.memory_allocated(0) / (1024 ** 3)
    print(f"{msg}{allocated} GB")

def get_cuda_usage():
    """Return the memory allocated on CUDA device 0, in units of 1024**3 bytes."""
    unit_bytes = 1024 ** 3
    allocated_bytes = torch.cuda.memory_allocated(0)
    return allocated_bytes / unit_bytes

In [5]:
# Use the GPU when torch can see one, otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

In [6]:
# Show which device was selected
device

'cuda'

# Load Model

In [7]:
# Drop a model left over from an earlier run before loading again.
# NOTE(review): the cells below bind `clean_model` / `corrupted_model`, never `model`,
# so this may be a leftover from an earlier version -- confirm which name(s) to free.
try:
    del model
    print("Model deleted!")
except NameError:
    # `model` was never defined in this kernel, so there is nothing to free.
    # Only NameError is expected here; a bare `except:` would also hide
    # KeyboardInterrupt and genuine errors.
    pass

In [8]:
MODEL_NAME = "tiny.en"

# Two independent copies of the same checkpoint, each moved to `device` and put in
# inference mode (eval() returns the module itself, so the calls can be chained).
clean_model = whisper.load_model(MODEL_NAME).to(device).eval()
corrupted_model = whisper.load_model(MODEL_NAME).to(device).eval()

print("Bingle Bangle")

  checkpoint = torch.load(fp, map_location=device)


Bingle Bangle


In [9]:
# Tokenizer configured from the loaded checkpoint, for English transcription
tokenizer = get_tokenizer(
    clean_model.is_multilingual,
    num_languages=clean_model.num_languages,
    language="en",
    task="transcribe",
)
# The tokenizer's start-of-transcript sequence (no-timestamps variant) as a tensor
sot_ids = torch.tensor(
    tokenizer.sot_sequence_including_notimestamps,
    requires_grad=False,
)

# Load Data

In [10]:
# Dataset root (relative path) and the saved test split inside it
tedlium_path = "../tedlium"
path = tedlium_path + "/test.hf"

In [11]:
SELECT = 100  # number of examples kept after shuffling
SEED = 1  # seed for the shuffle

# Load the saved split, return items as torch tensors, shuffle with a fixed seed,
# and keep the first SELECT examples.
data = (
    load_from_disk(path)
    .with_format("torch")
    .shuffle(seed=SEED)
    .select(range(SELECT))
)

In [12]:
# def collate(ls):
#     pad_to = max(list(map(lambda x: x["audio"].shape[0], ls)))
#     return torch.cat(list(map(lambda x: F.pad(x["audio"], (0, pad_to - x["audio"].shape[0])).unsqueeze(0).to(torch.bfloat16), ls)), dim=0)

# def collate_idx(ls):
#     return ls[0]["audio"].unsqueeze(0), ls[0]["idx"].item()

# TRAIN_BATCH_SIZE = 1 # highly recommended to be 1

# train_dataset = DataLoader(tedlium_train, batch_size=TRAIN_BATCH_SIZE, collate_fn=collate_idx)

# Register Hooks

In [262]:
# Outputs recorded by clean_hook when registered with this dict:
# key -> list with one entry per hook call
activations = {}
# Outputs recorded while running the corrupted model, same layout as `activations`
debug_activations = {}
# Handles of registered forward hooks; reset() removes them
hooks = []
# A second list of hook handles, also removed by reset()
temp_hooks = []
# Call counter used by the hooks built in corrupted_hook_with_correction.
# A one-element list so the nested hook can mutate it without `global`.
activation_count = [0]

In [263]:
def reset():
    """Return the tracing state to a clean slate.

    Removes every hook tracked in ``hooks`` and ``temp_hooks``, empties both
    activation stores, and rewinds ``activation_count`` to zero.
    """
    for hook_list in (hooks, temp_hooks):
        remove_hook_list(hook_list)

    activations.clear()
    debug_activations.clear()
    # Without this the counter keeps growing across runs, so the step comparison
    # in corrupted_hook_with_correction can never match again after the first pass.
    activation_count[0] = 0

def remove_hook_list(hook_list):
    """Call ``remove()`` on every handle in ``hook_list``, then empty the list in place."""
    for handle in hook_list:
        handle.remove()
    del hook_list[:]

def append_dict(d, k, value):
    """Append ``value`` to the list stored at ``d[k]``, creating the list on first use."""
    d.setdefault(k, []).append(value)
    
def clean_hook(idx, d):
    """Hook factory for the clean run.

    The returned forward hook appends each module output to ``d[idx]`` and passes
    the output through unchanged.
    """
    def record_output(_module, _inputs, result):
        append_dict(d, idx, result)
        return result
    return record_output

def corrupted_hook_embed(alpha=1):
    """Build a forward hook that adds zero-mean Gaussian noise to a module's output.

    Args:
        alpha: Multiplier on the noise scale. The noise standard deviation is
            ``alpha`` times the standard deviation of the output tensor itself,
            measured over all of its elements on every call.

    Returns:
        A hook ``(module, args, output) -> tensor`` whose return value replaces the
        module output with ``output + noise``.
    """
    def hook(module, args, output):
        # Sample directly with the output's shape, dtype and device. Building the
        # noise on the CPU and copying it over costs a host sync and a transfer on
        # every call, and float32 noise promotes half-precision outputs to float32.
        # detach() keeps the scale out of any autograd graph.
        noise_std = alpha * output.detach().std()
        return output + torch.randn_like(output) * noise_std
    return hook

def corrupted_hook_with_correction(block, token):
    """Build a forward hook that patches one stored activation into a run.

    The hook counts its invocations in the shared ``activation_count`` (starting
    at 0). On invocation number ``token`` it replaces the module output with
    ``activations[block][token]``; every other invocation passes through untouched.
    Whatever is returned is also appended to ``debug_activations[block]``.

    Args:
        block: Key into ``activations`` / ``debug_activations``.
        token: Zero-based invocation index at which to substitute the stored value.
    """
    def hook(module, args, out):
        # Read the step before advancing the counter. Incrementing first made the
        # comparison one-based while the list lookup stayed zero-based, so token=0
        # never fired and any other token paired call n with stored entry n + 1.
        step = activation_count[0]
        activation_count[0] = step + 1
        if step == token:
            out = activations[block][token]
        append_dict(debug_activations, block, out)
        return out
    return hook

In [266]:
def register_hooks_by_block(model, hook_creator, d, hl):
    """
    Attach one forward hook to every block in ``model.decoder.blocks``.

    ``hook_creator`` is a factory: it is called as ``hook_creator(i, d)`` for block
    index ``i`` and must return the hook function to register. Each resulting
    handle is appended to ``hl``.
    """
    for block_idx, decoder_block in enumerate(model.decoder.blocks):
        handle = decoder_block.register_forward_hook(hook_creator(block_idx, d))
        hl.append(handle)
        
def register_hook_at_block(model, block, token, hl):
    """Register a ``corrupted_hook_with_correction(block, token)`` hook on decoder block ``block``.

    The handle is appended to ``hl`` and also returned.
    """
    target = model.decoder.blocks[block]
    handle = target.register_forward_hook(corrupted_hook_with_correction(block, token))
    hl.append(handle)
    return handle

def register_hook_at_embedding(model, hook_creator, hl, alpha=1):
    """Register ``hook_creator(alpha)`` as a forward hook on ``model.decoder.token_embedding``.

    The handle is appended to ``hl`` and also returned.
    """
    embedding = model.decoder.token_embedding
    handle = embedding.register_forward_hook(hook_creator(alpha))
    hl.append(handle)
    return handle

In [267]:
# Start from a clean slate: detach tracked hooks and clear recorded activations
reset()

In [268]:
# Record the output of every clean_model decoder block into `activations`;
# the handles are tracked in `hooks`
register_hooks_by_block(clean_model, clean_hook, activations, hooks)
# NOTE(review): clean_hook_embed is not defined in the visible part of this notebook
# register_hook_at_embedding(clean_model, clean_hook_embed)

In [226]:
# Every register_* helper takes the list that tracks the new handle (`hl`), which is
# what lets reset() detach the hook later. The previous calls omitted it, and 0.1
# landed in the `hl` slot instead of `alpha`.
# NOTE(review): which of hooks / temp_hooks each call should use is an assumption -- confirm.
register_hook_at_embedding(corrupted_model, corrupted_hook_embed, hooks, alpha=0.1)
variable_hook = register_hook_at_block(corrupted_model, 3, 0, temp_hooks)
register_hooks_by_block(corrupted_model, clean_hook, debug_activations, hooks)

# Test Run for Hooks

In [199]:
example = data[20]
# transcribe() and the Hugging Face processor consume the waveform itself, not the
# whole dataset record.
# NOTE(review): assumes the record stores its waveform under "audio", as the
# commented-out collate functions in this notebook do -- confirm against the dataset schema.
wave = example["audio"]
example["text"]

'and some fish meals and then he said chicken pellets'

In [200]:
# Transcribe the example with clean_model and show the resulting text
res = clean_model.transcribe(wave)["text"]
res

' and some fish meals and then he said chicken pellets.'

In [201]:
# Transcribe the same example with corrupted_model and show the resulting text
c_res = corrupted_model.transcribe(wave)["text"]
c_res

' and some fish and some fish meals. And then he said chicken CHICK and then he said chicken'

# Model.forward()

In [253]:
from transformers import WhisperForConditionalGeneration, WhisperProcessor

# Follow the notebook-wide `device` instead of a hard-coded .cuda(), which raises on
# machines without a GPU even though the rest of the notebook falls back to the CPU.
model_transformer = WhisperForConditionalGeneration.from_pretrained(
    "openai/whisper-tiny.en", local_files_only=True
).to(device)
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en", local_files_only=True)

def infer_with_probs(audio, model, processor):
    """Generate a transcription and report the end-of-sequence probability at each step.

    Args:
        audio: Waveform handed to ``processor``; declared to it as sampled at 16 kHz.
        model: Generation model exposing ``.device`` and ``.generate``.
        processor: Callable feature extractor that also exposes ``tokenizer.eos_token_id``.

    Returns:
        ``(res, eos_probs)``: the raw ``generate`` output and a 1-D CPU tensor holding,
        for each generation step, the softmax probability of the EOS token.
    """
    # Look the EOS id up on the processor actually passed in, rather than relying on
    # a module-level `eos_id` that is only assigned in a later cell.
    eos_token_id = processor.tokenizer.eos_token_id

    inputs = processor(audio, return_tensors="pt", sampling_rate=16_000)
    input_features = inputs.input_features.to(model.device)

    res = model.generate(input_features, return_dict_in_generate=True, output_scores=True)

    # squeeze() drops singleton axes, leaving one logit vector per step
    # (assumes a single sequence per call -- TODO confirm for batched/beam generation).
    # An explicit dim avoids the implicit-dimension warning of a bare Softmax().
    step_probs = [torch.softmax(s.squeeze(), dim=-1)[eos_token_id] for s in res["scores"]]
    # Stack once instead of converting scalar tensors one by one; stack() rejects an
    # empty list, so that case is handled explicitly.
    eos_probs = torch.stack(step_probs).detach().cpu() if step_probs else torch.empty(0)
    return res, eos_probs

In [254]:
# End-of-sequence token id taken from the processor's tokenizer
eos_id = processor.tokenizer.eos_token_id

In [258]:
# rr: raw generate() output; p: per-step EOS probabilities for the example
rr, p = infer_with_probs(wave, model_transformer, processor)

# Causal Tracing

In [None]:
# `remove_all_hooks` is not defined anywhere in the visible notebook; reset() is the
# helper that detaches the tracked hooks (it also clears the recorded activations).
reset()

# Each register_* helper needs the list that tracks its handle (`hl`) so reset() can
# detach it later; 0.1 is the noise scale and must go to `alpha`, not `hl`.
# NOTE(review): the hooks / temp_hooks split below is an assumption -- confirm.
register_hooks_by_block(clean_model, clean_hook, activations, hooks)
register_hook_at_embedding(corrupted_model, corrupted_hook_embed, hooks, alpha=0.1)
variable_hook = register_hook_at_block(corrupted_model, 3, 0, temp_hooks)
register_hooks_by_block(corrupted_model, clean_hook, debug_activations, hooks)