# Installing Whisper

Make sure that you have already installed the current version of Whisper with hooks enabled into your virtual environment (venv).


# Loading a Dataset and Running with Hooks

Run a couple of tests to make sure that Whisper is properly hooked up.


In [11]:
#### A: Test that running with the hooks context works ####
# 1. Imports
import torch
import whisper
import datasets
import numpy as np
from transformer_lens import HookedTransformer
from transformer_lens.hook_points import HookPoint
from pathlib import Path
from transformers import WhisperProcessor, WhisperForConditionalGeneration


# print(Path.cwd()) # This should give you .../WhisperLens/notebooks

# 2. Get a dataset
# Small dummy LibriSpeech set; only its first validation row is used below.
ds = datasets.load_dataset(
    path="hf-internal-testing/librispeech_asr_dummy",
    name="clean",
    split="validation",
)
# The "audio" entry should be an element with a numpy array, basically (and
# some other metadata); later code reads its "array" and "sampling_rate" keys.
sample = ds[0]["audio"]


# 3. Define a hook that we can use to test that hooks work at all
def print_shape_hook(activation: torch.Tensor, hook: HookPoint):
    """Print the shape of ``activation`` together with the name of ``hook``.

    Used only to confirm that forward hooks fire at all. Emits one line per
    call and returns ``None``.
    """
    message = f"Shape at hookpoint {hook.name}: {activation.shape}"
    print(message)


# 4. Load a model and then just run it with a hook
# One size drives both the hooked model and the Hugging Face reference, so the
# two transcriptions come from the same checkpoint family and are directly
# comparable (previously "base" was compared against "whisper-small").
model_size = "base"
model = whisper.load_model(model_size)

# NOTE: You should see a bunch of print statements of pytorch shape tuples
# coming out here, followed by a reasonable result (check the HF repo for
# that). The later run without hooks should print NO shapes (but its result
# should still be more or less reasonable).
print("*" * 100)
with model.hooks(
    fwd_hooks=[("decoder.blocks.0.attn.hook_attn_pattern", print_shape_hook)]
):
    # The hook is only registered inside this context.
    result = model.transcribe(sample["array"].astype(np.float32))
print("Transcription:", result)

# Compare with the HF model of the same size
hf_checkpoint = f"openai/whisper-{model_size}"
comparison_processor = WhisperProcessor.from_pretrained(hf_checkpoint)
comparison_model = WhisperForConditionalGeneration.from_pretrained(hf_checkpoint)
input_features = comparison_processor(
    sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="pt"
).input_features
predicted_ids = comparison_model.generate(input_features, language="en")
# Special tokens are kept on purpose so the task/language prefix is visible.
transcription = comparison_processor.batch_decode(
    predicted_ids, skip_special_tokens=False
)
print("Comparison HF Transcription:", transcription)

# No hooks:
print("*" * 100)
# Check that you have the flac file properly in the tests directory.
# The path is resolved relative to the parent of the current working
# directory, i.e. it assumes the notebook runs from <repo root>/notebooks.
jfk_audio = Path.cwd().parent / "tests" / "jfk.flac"
result = model.transcribe(jfk_audio.as_posix())
print(result)

  checkpoint = torch.load(fp, map_location=device)


****************************************************************************************************
Shape at hookpoint decoder.blocks.0.attn.hook_attn_pattern: torch.Size([1, 8, 1, 1])
Shape at hookpoint decoder.blocks.0.attn.hook_attn_pattern: torch.Size([1, 8, 3, 3])
Shape at hookpoint decoder.blocks.0.attn.hook_attn_pattern: torch.Size([1, 8, 1, 4])
Shape at hookpoint decoder.blocks.0.attn.hook_attn_pattern: torch.Size([1, 8, 1, 5])
Shape at hookpoint decoder.blocks.0.attn.hook_attn_pattern: torch.Size([1, 8, 1, 6])
Shape at hookpoint decoder.blocks.0.attn.hook_attn_pattern: torch.Size([1, 8, 1, 7])
Shape at hookpoint decoder.blocks.0.attn.hook_attn_pattern: torch.Size([1, 8, 1, 8])
Shape at hookpoint decoder.blocks.0.attn.hook_attn_pattern: torch.Size([1, 8, 1, 9])
Shape at hookpoint decoder.blocks.0.attn.hook_attn_pattern: torch.Size([1, 8, 1, 10])
Shape at hookpoint decoder.blocks.0.attn.hook_attn_pattern: torch.Size([1, 8, 1, 11])
Shape at hookpoint decoder.blocks.0.attn.hook_a

2024-08-30 17:19:07.104204: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-08-30 17:19:07.117523: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-08-30 17:19:07.121952: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-08-30 17:19:07.131767: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
You have passed language=en, but also have set `force

Comparison HF Transcription: ['<|startoftranscript|><|en|><|transcribe|><|notimestamps|> Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.']
****************************************************************************************************
{'text': ' And so my fellow Americans ask not what your country can do for you, ask what you can do for your country.', 'segments': [{'id': 0, 'seek': 0, 'start': 0.0, 'end': 7.6000000000000005, 'text': ' And so my fellow Americans ask not what your country can do for you,', 'tokens': [50364, 400, 370, 452, 7177, 6280, 1029, 406, 437, 428, 1941, 393, 360, 337, 291, 11, 50744], 'temperature': 0.0, 'avg_logprob': -0.396818259666706, 'compression_ratio': 1.3417721518987342, 'no_speech_prob': 0.09160678088665009}, {'id': 1, 'seek': 0, 'start': 7.6000000000000005, 'end': 10.6, 'text': ' ask what you can do for your country.', 'tokens': [50744, 1029, 437, 291, 393, 360, 337, 428, 1941, 13, 50894], 'temperature': 0.0

In [1]:
#### B: Test that we are able to do exactly ONE step of inference OK and get all the activations ####
import torch
import whisper
import datasets
import einops
import numpy as np
from transformer_lens import HookedTransformer
from transformer_lens.hook_points import HookPoint
from pathlib import Path
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from whisper import DecodingOptions
from whisper.decoding import DecodingTask
from whisper.audio import log_mel_spectrogram, pad_or_trim
from whisper.tokenizer import get_tokenizer
from tabulate import tabulate  # Visualize shapes as they go through the network

# Get one sample (first validation utterance, as a float32 waveform)
ds = datasets.load_dataset(
    path="hf-internal-testing/librispeech_asr_dummy",
    name="clean",
    split="validation",
)
sample = ds[0]["audio"]["array"].astype(np.float32)

# Calculate the audio features and default tokens like in
# https://github.com/openai/whisper/blob/main/whisper/decoding.py#L737
# TODO(Adriano) try prompting?
# https://github.com/openai/whisper/blob/ba3f3cd54b0e5b8ce1ab3de13e32122d0d5f98ab/whisper/decoding.py#L101
model = whisper.load_model("tiny")  # To be safe, do tiny
decoding_options = DecodingOptions(
    task="transcribe",
    language="en",
    sample_len=1,
    beam_size=1,
    without_timestamps=False,
)
# The decoding task is only used here to obtain the initial prompt tokens.
decoding_task = DecodingTask(model, decoding_options)

# NOTE _get_audio_features runs forward pass of encoder: it's a misnomer; don't use it!
# Defaults should "just work"...
padded_sample = pad_or_trim(sample)
# Add a leading batch axis of size 1 and move to the model's device.
# NOTE(review): "seq"/"d" are just labels in the pattern; which axis is time
# and which is the mel bin is not established here — confirm.
audio_features = einops.rearrange(
    log_mel_spectrogram(audio=padded_sample), "seq d -> 1 seq d"
).to(model.device)
# Same batching for the initial token sequence.
tokens = einops.rearrange(
    torch.tensor(decoding_task._get_initial_tokens()), "seq -> 1 seq"
).to(model.device)

# Run a few times so that we can start to see interesting tokens.
# Each iteration does one greedy step: run the model on the tokens so far,
# take the argmax at the last position and append it to the sequence.
num_runs_before_read = 10
predicted_tokens = []  # greedy picks, in generation order
# Pure inference: no_grad avoids building an autograd graph on every step.
with torch.no_grad():
    for _ in range(num_runs_before_read):
        logits = model(audio_features, tokens)
        assert len(tokens.shape) == 2 and tokens.shape[0] == 1
        assert logits.shape[:2] == (1, tokens.shape[-1]), f"Logits shape: {logits.shape}, len of tokens: {tokens.shape[-1]}"
        most_likely_token = logits[0][-1].argmax(dim=-1).item()
        predicted_tokens.append(most_likely_token)
        # Create the new token directly on the tokens' device so the cat is valid.
        next_token = torch.tensor([[most_likely_token]], device=tokens.device)
        tokens = torch.cat([tokens, next_token], dim=-1)

# Run with cache: one forward pass that returns the logits together with the
# cached activations, then print the decoded tokens and every activation shape.
with torch.no_grad():
    # First try checking that we only print ONCE
    logits, acts = model.run_with_cache(
        audio_features,
        tokens,
    )
    print("Logits shape:", logits.shape)
    print("Activations type:", type(acts))

    # Greedy prediction at every position of the (single) batch element.
    max_idx = logits[0].argmax(dim=-1).tolist()
    # TODO(Adriano) why does it decode as translate and not english when I do multilingual=False?
    tokenizer = get_tokenizer(
        multilingual=True,
        # multilingual=False,
        # num_languages=1,
        # language="en",
        # task="transcribe",
    )
    tok = tokenizer.encoding.decode(tokens[0].tolist())
    max_tok = tokenizer.encoding.decode(max_idx)
    # Decode the final prediction on its own: `max_tok` is a string, so
    # indexing it with [-1] would give the last character, not the last token.
    last_tok = tokenizer.encoding.decode(max_idx[-1:])
    # Expect to see something about Mr. Quilter
    print("Ingoing tokens:", tok)
    print("Logits (should be tokens shifted by 1):", max_tok)
    print("LAST Logit:", last_tok)

    print("Activations (shapes):")
    header = ["Layer", "Shape"]
    rows = [[f"{k}", f"{v.shape}"] for k, v in acts.items()]
    print(tabulate(rows, headers=header, tablefmt="fancy_grid"))

  from .autonotebook import tqdm as notebook_tqdm
  checkpoint = torch.load(fp, map_location=device)


W SHAPE IS torch.Size([1, 6, 1500, 1500])
SHAPE OF V IS torch.Size([1, 6, 1500, 64])
W SHAPE IS torch.Size([1, 6, 1500, 1500])
SHAPE OF V IS torch.Size([1, 6, 1500, 64])
W SHAPE IS torch.Size([1, 6, 1500, 1500])
SHAPE OF V IS torch.Size([1, 6, 1500, 64])
W SHAPE IS torch.Size([1, 6, 1500, 1500])
SHAPE OF V IS torch.Size([1, 6, 1500, 64])
W SHAPE IS torch.Size([1, 6, 3, 3])
SHAPE OF V IS torch.Size([1, 6, 3, 64])
W SHAPE IS torch.Size([1, 6, 3, 1500])
SHAPE OF V IS torch.Size([1, 6, 1500, 64])
W SHAPE IS torch.Size([1, 6, 3, 3])
SHAPE OF V IS torch.Size([1, 6, 3, 64])
W SHAPE IS torch.Size([1, 6, 3, 1500])
SHAPE OF V IS torch.Size([1, 6, 1500, 64])
W SHAPE IS torch.Size([1, 6, 3, 3])
SHAPE OF V IS torch.Size([1, 6, 3, 64])
W SHAPE IS torch.Size([1, 6, 3, 1500])
SHAPE OF V IS torch.Size([1, 6, 1500, 64])
W SHAPE IS torch.Size([1, 6, 3, 3])
SHAPE OF V IS torch.Size([1, 6, 3, 64])
W SHAPE IS torch.Size([1, 6, 3, 1500])
SHAPE OF V IS torch.Size([1, 6, 1500, 64])
W SHAPE IS torch.Size([1, 6,

# Do we use causal attention?
As a short demonstration of hooks in action, we check whether or not causal attention is used in a few different places (which you can see in the table above if you ran the previous cell):
- encoder.blocks.0.attn.hook_attn_pattern
- decoder.blocks.0.attn.hook_attn_pattern
- decoder.blocks.0.cross_attn.hook_attn_pattern

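We run the full inference process and then simply take the logical AND, over the run, of whether each head looks causal (i.e. puts zero attention weight on future positions). It is very unlikely that a non-causal head would always put exactly zero weight there (and if the weight is ever non-zero, then the head is probably not causal).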

Unlike the GPT series and most LLMs, you'll note that Whisper is NOT fully causal, but it does have one causal portion: the self-attention over text. It works like this:
1. Independently encode the entire sound-wave signal using a non-causal transformer.
2. Do a language model generation conditioned on the output from 1. This involves blocks in which each step is either:
    - Causal self-attention like GPT2
    - Cross-attention with results from 1 (not causal: it gets to "listen ahead")

In [2]:
import torch
import whisper
import re
import numpy as np
from transformer_lens.hook_points import HookPoint
from datasets import load_dataset
from tabulate import tabulate

# 1. Dataset: first validation utterance as a float32 waveform
dname = "hf-internal-testing/librispeech_asr_dummy"
ds = load_dataset(path=dname, name="clean", split="validation")
sample = ds[0]["audio"]["array"].astype(np.float32)

# 2. Model
model = whisper.load_model("tiny")

# 3. Hooks
# Layer and head counts are read off the model; head counts come from block 0.
n_layers_encoder = len(model.encoder.blocks)
n_layers_decoder = len(model.decoder.blocks)
assert n_layers_encoder > 0 and n_layers_decoder > 0
n_heads_encoder = model.encoder.blocks[0].attn.n_head
n_heads_decoder = model.decoder.blocks[0].attn.n_head

# NOTE that we always have the same number of heads in the tiny model
# One [layer][head] grid per attention kind, every cell starting at 0.
encoder_heads_are_causal = [[0] * n_heads_encoder for _ in range(n_layers_encoder)]
decoder_heads_are_causal = [[0] * n_heads_decoder for _ in range(n_layers_decoder)]
decoder_x_heads_are_causal = [[0] * n_heads_decoder for _ in range(n_layers_decoder)]

def is_probably_causal(
        act: torch.Tensor,
        hook: HookPoint
    ) -> None:
    """Record, per head, the largest attention weight found above the diagonal.

    The hook name must contain exactly one ``blocks.<int>`` segment; that
    integer selects the row (layer) of the result table. The table is chosen
    from the hook name: ``encoder_heads_are_causal`` when it contains
    "encoder", ``decoder_x_heads_are_causal`` when it contains "cross_attn",
    and ``decoder_heads_are_causal`` otherwise.

    For every head the strictly-upper-triangular part of the last two axes of
    ``act`` is inspected and its maximum absolute value is folded into the
    table with ``max``. A cell therefore holds the largest value seen over ALL
    calls so far (0 means nothing was ever seen above the diagonal), instead
    of only the value from the most recent call.

    Args:
        act: 4-D tensor whose axis 1 indexes heads. Presumably laid out as
            (batch, head, query, key) — TODO confirm.
        hook: hook point the activation came from; only ``hook.name`` is read.

    Returns:
        None. The module-level tables are updated in place.
    """
    hookname = hook.name
    # Raw string: "\." and "\d" are regex escapes, not string escapes.
    assert hookname is not None and len(re.findall(r"blocks\.\d+", hookname)) == 1, hookname
    _, right = hookname.split("blocks.", 1)
    block_int = int(right.split(".", 1)[0])

    is_encoder = "encoder" in hookname
    is_x = "cross_attn" in hookname
    # An encoder hook must never be a cross-attention hook.
    assert not (is_encoder and is_x), hookname

    if is_encoder:
        table = encoder_heads_are_causal
    elif is_x:
        # X-Attn: the two sequence axes need not have the same length.
        table = decoder_x_heads_are_causal
    else:
        table = decoder_heads_are_causal

    assert len(act.shape) == 4
    # Ones strictly above the diagonal, zeros elsewhere; built on act's device.
    mask = torch.triu(
        torch.ones(act.shape[-2], act.shape[-1], device=act.device), diagonal=1
    )
    assert mask[0][0] == 0  # Sanity check
    # One reduction for all heads: max |weight| above the diagonal, per head.
    per_head = (act * mask).abs().amax(dim=(0, 2, 3)).tolist()

    row = table[block_int]
    for head_int in range(len(row)):
        # Running max, so a single non-zero observation is never overwritten.
        row[head_int] = max(float(row[head_int]), per_head[head_int])


# Hook names to attach to, in order: every encoder self-attention pattern,
# then every decoder self-attention pattern, then every decoder
# cross-attention pattern.
hook_points = []
hook_points += [
    f"encoder.blocks.{i}.attn.hook_attn_pattern" for i in range(n_layers_encoder)
]
hook_points += [
    f"decoder.blocks.{i}.attn.hook_attn_pattern" for i in range(n_layers_decoder)
]
hook_points += [
    f"decoder.blocks.{i}.cross_attn.hook_attn_pattern" for i in range(n_layers_decoder)
]

# Result grids and their display names, kept in matching order.
tables = [encoder_heads_are_causal, decoder_heads_are_causal, decoder_x_heads_are_causal]
table_names = ["Encoder", "Decoder", "Decoder-X"]

# 4. Run
# The same callback is attached to every hook point listed above.
fwd_hooks = [(hook_point, is_probably_causal) for hook_point in hook_points]
with model.hooks(fwd_hooks=fwd_hooks):
    # NOTE set the "decoding_options" kwarg for no_kv_cache since otherwise
    # the activations seen by the hooks are not usable for this check
    result = model.transcribe(sample, no_kv_cache=True)

# Sanity check the result is the same when you are not using kv caching
separator = "*" * 100
print(separator)
print("Result:", result["text"])

# 5. Showcase
# One grid per attention kind: rows are layers, columns are heads.
for name, table in zip(table_names, tables):
    print("*" * 100)
    print(f"Table for {name} (zero means probably causal, all other numbers -> not causal):")
    n_heads_in_table = len(table[0])
    header = ["", *(f"Head {i}" for i in range(n_heads_in_table))]
    rows = [[f"Layer {i}", *(f"{v}" for v in row)] for i, row in enumerate(table)]
    print(tabulate(rows, headers=header, tablefmt="fancy_grid"))


****************************************************************************************************
Result:  Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.
****************************************************************************************************
Table for Encoder (zero means probably causal, all other numbers -> not causal):
╒═════════╤══════════╤══════════╤═══════════╤═══════════╤═══════════╤═══════════╕
│         │   Head 0 │   Head 1 │    Head 2 │    Head 3 │    Head 4 │    Head 5 │
╞═════════╪══════════╪══════════╪═══════════╪═══════════╪═══════════╪═══════════╡
│ Layer 0 │ 0.112061 │ 0.18335  │ 0.995117  │ 0.0205688 │ 0.581055  │ 0.0248566 │
├─────────┼──────────┼──────────┼───────────┼───────────┼───────────┼───────────┤
│ Layer 1 │ 0.546387 │ 0.101074 │ 0.226196  │ 0.273926  │ 0.0568542 │ 0.0964355 │
├─────────┼──────────┼──────────┼───────────┼───────────┼───────────┼───────────┤
│ Layer 2 │ 0.964844 │ 0.384277 │ 0.0700073 │ 