# Can we steer the language for Whisper?
This notebook aims to answer a really simple question: are we able to steer language?

This is an activation engineering (https://www.alignmentforum.org/tag/activation-engineering) method that you can read more about here: https://www.alignmentforum.org/posts/5spBue2z2tw4JuDCx/steering-gpt-2-xl-by-adding-an-activation-vector and https://www.lesswrong.com/posts/ndyngghzFY388Dnew/implementing-activation-steering.

We confirm the following:
1. You are NOT able to just add together latents in the mel spectrogram. (The audio usually starts as a waveform of amplitude over time; this is chunked using a sliding window with overlap and the frequency content of each chunk is calculated, before the x-axis (frequency) and the y-axis (amplitude) are transformed to align with human hearing patterns.)
2. You ARE able to add latents together at some middle layer of the AUDIO ENCODER. We do not attempt to do any steering in the text decoder.

We ask some questions:
1. Does the specific semantic content of the audio influence the steered text? If so, how many samples of our desired feature do we need in order to de-bias the steering vector?
2. How does the location (i.e. depth in the network) influence linearity for steering?

This is inspired by this work by Ellena Reid: https://www.lesswrong.com/posts/thePw6qdyabD8XR4y/interpreting-openai-s-whisper.

In [31]:
import os
import io
import numpy as np
import torch
import pandas as pd
import whisper
import torch
import torchaudio
import urllib
import tarfile
from scipy.io import wavfile
from tqdm import tqdm
from pathlib import Path
from transformer_lens.hook_points import HookPoint
from datasets import load_dataset


# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

In [33]:
def download(url: str, target_path: str):
    """Stream `url` to `target_path` while showing a tqdm progress bar.

    The payload is written to `target_path + ".part"` first and only renamed to
    `target_path` once the transfer has finished, so a failed download never leaves
    a truncated file where a caller that checks for the target's existence would
    mistake it for a complete one.

    Args:
        url: Location to fetch; anything `urllib.request.urlopen` accepts.
        target_path: Destination file path. Its parent directory must already exist.

    Raises:
        Exception: whatever `urlopen` or the file I/O raised, re-raised unchanged
            after the URL has been printed and the partial file removed.
    """
    # `import urllib` on its own does not guarantee the `request` submodule is loaded.
    import urllib.request

    partial_path = f"{target_path}.part"
    try:
        with urllib.request.urlopen(url) as source, open(partial_path, "wb") as output:
            # Content-Length is optional; tqdm accepts total=None for an unknown size.
            content_length = source.info().get("Content-Length")
            total = int(content_length) if content_length is not None else None
            with tqdm(total=total, ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
                while True:
                    buffer = source.read(8192)
                    if not buffer:
                        break

                    output.write(buffer)
                    loop.update(len(buffer))
        # Publish the file under its final name only after every byte has been written.
        os.replace(partial_path, target_path)
    except Exception:
        print("URL: ", url)
        if os.path.exists(partial_path):
            os.remove(partial_path)
        raise


class Fleurs(torch.utils.data.Dataset):
    """
    A simple class to wrap Fleurs and subsample a portion of the dataset as needed.

    Fleurs is a multilingual dataset including tons of languages, check: https://paperswithcode.com/dataset/fleurs

    The archive for `lang` is downloaded into ~/.cache/fleurs/ if it is not already there,
    and every .wav clip of the requested split is decoded into memory up front.
    Indexing yields `(audio, text)`: `audio` is a tensor built from the decoded samples
    and `text` is the row's "transcription" column.
    """
    def __init__(self, lang, split="test", subsample_rate=1, device=DEVICE):
        """
        Args:
            lang: FLEURS language code used in the archive name, e.g. "fr_fr".
            split: Split name; selects the `<split>.tsv` table and the `/<split>/` wav files.
            subsample_rate: Keep every `subsample_rate`-th row of the split's table.
            device: Stored as `self.device`; this class does not move any data to it.

        Raises:
            FileNotFoundError: if the archive contains no `<split>.tsv` table.
        """
        url = f"https://storage.googleapis.com/xtreme_translations/FLEURS102/{lang}.tar.gz"
        tar_path = os.path.expanduser(f"~/.cache/fleurs/{lang}.tgz")
        os.makedirs(os.path.dirname(tar_path), exist_ok=True)

        if not os.path.exists(tar_path):
            download(url, tar_path)

        labels = None
        all_audio = {}
        with tarfile.open(tar_path, "r:gz") as tar:
            for member in tar.getmembers():
                name = member.name
                if name.endswith(f"{split}.tsv"):
                    labels = pd.read_table(tar.extractfile(member), names=("id", "file_name", "raw_transcription", "transcription", "_", "num_samples", "gender"))

                if f"/{split}/" in name and name.endswith(".wav"):
                    audio_bytes = tar.extractfile(member).read()
                    # wavfile.read returns (sample_rate, samples); only the samples are kept.
                    all_audio[os.path.basename(name)] = wavfile.read(io.BytesIO(audio_bytes))[1]

        # Without this check a missing table would surface later as an UnboundLocalError.
        if labels is None:
            raise FileNotFoundError(f"No {split}.tsv table found in {tar_path}")

        self.labels = labels.to_dict("records")[::subsample_rate]
        self.all_audio = all_audio
        self.device = device

    def __len__(self):
        """Number of (subsampled) utterances."""
        return len(self.labels)

    def __getitem__(self, item):
        """Return `(audio, text)` for row `item`; raises IndexError past the end."""
        record = self.labels[item]
        # Copy so the tensor does not share memory with the cached numpy array.
        audio = torch.from_numpy(self.all_audio[record["file_name"]].copy())
        text = record["transcription"]

        return (audio, text)


In [28]:
# FLEURS archive code -> English language name.
code2lang = {
    "af_za": "Afrikaans",
    "am_et": "Amharic",
    "ar_eg": "Arabic",
    "as_in": "Assamese",
    "az_az": "Azerbaijani",
    "be_by": "Belarusian",
    "bg_bg": "Bulgarian",
    "bn_in": "Bengali",
    "bs_ba": "Bosnian",
    "ca_es": "Catalan",
    "cmn_hans_cn": "Chinese",
    "cs_cz": "Czech",
    "cy_gb": "Welsh",
    "da_dk": "Danish",
    "de_de": "German",
    "el_gr": "Greek",
    "en_us": "English",
    "es_419": "Spanish",
    "et_ee": "Estonian",
    "fa_ir": "Persian",
    "fi_fi": "Finnish",
    "fil_ph": "Tagalog",
    "fr_fr": "French",
    "gl_es": "Galician",
    "gu_in": "Gujarati",
    "ha_ng": "Hausa",
    "he_il": "Hebrew",
    "hi_in": "Hindi",
    "hr_hr": "Croatian",
    "hu_hu": "Hungarian",
    "hy_am": "Armenian",
    "id_id": "Indonesian",
    "is_is": "Icelandic",
    "it_it": "Italian",
    "ja_jp": "Japanese",
    "jv_id": "Javanese",
    "ka_ge": "Georgian",
    "kk_kz": "Kazakh",
    "km_kh": "Khmer",
    "kn_in": "Kannada",
    "ko_kr": "Korean",
    "lb_lu": "Luxembourgish",
    "ln_cd": "Lingala",
    "lo_la": "Lao",
    "lt_lt": "Lithuanian",
    "lv_lv": "Latvian",
    "mi_nz": "Maori",
    "mk_mk": "Macedonian",
    "ml_in": "Malayalam",
    "mn_mn": "Mongolian",
    "mr_in": "Marathi",
    "ms_my": "Malay",
    "mt_mt": "Maltese",
    "my_mm": "Myanmar",
    "nb_no": "Norwegian",
    "ne_np": "Nepali",
    "nl_nl": "Dutch",
    "oc_fr": "Occitan",
    "pa_in": "Punjabi",
    "pl_pl": "Polish",
    "ps_af": "Pashto",
    "pt_br": "Portuguese",
    "ro_ro": "Romanian",
    "ru_ru": "Russian",
    "sd_in": "Sindhi",
    "sk_sk": "Slovak",
    "sl_si": "Slovenian",
    "sn_zw": "Shona",
    "so_so": "Somali",
    "sr_rs": "Serbian",
    "sv_se": "Swedish",
    "sw_ke": "Swahili",
    "ta_in": "Tamil",
    "te_in": "Telugu",
    "tg_tj": "Tajik",
    "th_th": "Thai",
    "tr_tr": "Turkish",
    "uk_ua": "Ukrainian",
    "ur_pk": "Urdu",
    "uz_uz": "Uzbek",
    "vi_vn": "Vietnamese",
    "yo_ng": "Yoruba",
}
# Reverse lookup: language name -> FLEURS code.
lang2code = {language: code for code, language in code2lang.items()}
code_fr = lang2code["French"]

In [48]:
from whisper.tokenizer import TO_LANGUAGE_CODE # Unclear why it's this way, but the codes are diff.
import functools

# Load the data and create the lists that the activation-reading hooks append to
dataset = Fleurs(code_fr, subsample_rate=5)  # keeps every 5th utterance, i.e. roughly 20% of the split
french_vecs, english_vecs = [], []
def read_resid_act_hook(vecs: list, act: torch.Tensor, hook: "HookPoint"):
    """Forward hook that records the last-position activation of every batch element.

    Bind the target list first, e.g. `functools.partial(read_resid_act_hook, some_list)`,
    so that the remaining signature is `(act, hook)`.

    Args:
        vecs: List that receives one 1-D tensor of length `feat_dim` per batch element.
        act: Activation of shape (batch, seq, feat_dim).
        hook: The hook point being observed; only used in the assertion message.

    Returns:
        `act` itself, unmodified, so the forward pass is not affected.
    """
    assert act.ndim == 3
    # Pick the last sequence element of each batch entry
    for vec in act[:, -1, :]:
        # In theory this happens if you zero-pad?
        assert not torch.allclose(vec, torch.zeros_like(vec)), f"Got zero vector for {hook}"
        # Keep a detached copy: a bare slice is a view that would keep the whole activation
        # tensor alive and would change if the model later writes into that buffer.
        vecs.append(vec.detach().clone())
    return act

# Bind each accumulator list to the generic reader so what is left has the (act, hook) signature.
read_resid_fr_act_hook = functools.partial(read_resid_act_hook, french_vecs)
read_resid_fr_act_book = read_resid_fr_act_hook  # original (misspelled) name, kept as an alias
read_resid_en_act_hook = functools.partial(read_resid_act_hook, english_vecs)

# Define the model (we use medium just to copy the other tutorial)
model = whisper.load_model("medium")
# Remove KV Cache so it's easier to fetch that vector
options_fr = dict(language=TO_LANGUAGE_CODE['french'], beam_size=2, best_of=2, no_kv_cache=True)
transcribe_options_fr = dict(task="transcribe", **options_fr)
# The English run uses the same options with only the language swapped.
options_en = dict(options_fr, language=TO_LANGUAGE_CODE['english'])
transcribe_options_en = dict(task="transcribe", **options_en)

# Hook the residual stream after the middle encoder block.
num_layers = len(model.encoder.blocks)
mid_layer = num_layers // 2
assert 0 <= mid_layer < num_layers
mid_layer_hook_name = f"encoder.blocks.{mid_layer}.hook_resid_post"

# Go ahead and just generate a ton of vectors: every clip is transcribed twice, once per
# language setting, and each run feeds its own accumulator list.
# NOTE(review): the hook sits on the audio encoder, while `language` is a decoding option.
# If this model's encoder only consumes the audio (as in upstream Whisper -- TODO confirm for
# this fork), both runs see the same encoder inputs and the difference of the two means may
# reflect decoding-driven effects rather than language. Contrasting French audio with English
# audio may be what is wanted here.
with torch.no_grad():
    for audio, text in tqdm(dataset):
        with model.hooks(fwd_hooks=[(mid_layer_hook_name, read_resid_fr_act_hook)]):
            model.transcribe(audio, **transcribe_options_fr)
        with model.hooks(fwd_hooks=[(mid_layer_hook_name, read_resid_en_act_hook)]):
            model.transcribe(audio, **transcribe_options_en)

assert len(french_vecs) >= 100, f"len={len(french_vecs)}"
assert len(french_vecs) <= 1000, f"len={len(french_vecs)}"
assert len(english_vecs) > 0, f"len={len(english_vecs)}"
french_avg_vec = torch.stack(french_vecs).mean(dim=0)
english_avg_vec = torch.stack(english_vecs).mean(dim=0)
steering_vec = french_avg_vec - english_avg_vec # english -> french

print(f"Got french vector for middle layer {mid_layer} with shape {french_avg_vec.shape}")

  checkpoint = torch.load(fp, map_location=device)
100%|██████████| 136/136 [05:25<00:00,  2.39s/it]

Got french vector for middle layer 12 with shape torch.Size([1024])





In [82]:
# DEBUG: show the steering vector before using it
print("steering average vector", steering_vec)

# Take the first clip of a small LibriSpeech sample set as the audio to steer
dname = "hf-internal-testing/librispeech_asr_dummy"
ds = load_dataset(dname, "clean", split="validation")
audio = ds[0]["audio"]["array"].astype(np.float32)

# Mixing weights read by the steering hook: the patched position becomes
# alpha * original + beta * steering_vec.
# Interestingly, the very last (vector) element of the tensor is NOT enough to steer language properly
alpha, beta = 0, 1

def steer_hook_fn(act: torch.Tensor, hook: "HookPoint"):
    """Forward hook that overwrites the last sequence position with a steered vector.

    For every batch element the final position becomes
    `alpha * act[b, -1, :] + beta * steering_vec`; all other positions are left alone.
    `alpha`, `beta` and `steering_vec` are read from the notebook's globals at call time.

    Args:
        act: Activation of shape (batch, seq, feat_dim); modified in place.
        hook: The hook point being patched (unused).

    Returns:
        The same `act` tensor object, after the in-place update.
    """
    assert act.ndim == 3
    # One broadcast write replaces the per-sample loop, and writing straight into `act`
    # avoids a default-dtype scratch copy of the whole activation, so `act` keeps its
    # dtype and device. The right-hand side is materialised before the assignment.
    act[:, -1, :] = alpha * act[:, -1, :] + beta * steering_vec.to(act.device)
    return act

# Classic sample that says "MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL" from
# https://huggingface.co/datasets/hf-internal-testing/librispeech_asr_dummy?row=0
#
# NOTE: no_kv_cache is passed through as a decoding option since otherwise
# our activations are not going to slot in just right
with torch.no_grad(), model.hooks(fwd_hooks=[(mid_layer_hook_name, steer_hook_fn)]):
    result = model.transcribe(audio, no_kv_cache=True)
# Unsteered baseline, for comparison: result = model.transcribe(audio, no_kv_cache=True)
print("Result:", result['text'])

# XXX(Adriano) fix some bug here: it may be the case that multilinguality is not properly being set up
# (the output recorded for this cell is a TypeError about an unexpected `is_multilingual`
# keyword for DecodingOptions, which no argument in the call above passes explicitly)


steering average vector tensor([-0.0234,  0.0039, -0.0059,  ...,  0.0015,  0.0044, -0.0020],
       device='cuda:0', dtype=torch.float16)


TypeError: DecodingOptions.__init__() got an unexpected keyword argument 'is_multilingual'