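# Encoder-Decoder Attention Weights Visualisation for Audio Captioning

For each Clotho evaluation clip, this notebook plays the audio, shows the reference and predicted captions together with the clip's FENSE score, and plots the cross-attention weights of a selected decoder layer (one panel per attention head) over the encoder frames.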

In [1]:
import os

# Root of the local dcase2024-task6-baseline checkout. Override it with the
# DCASE24_BASELINE_ROOT environment variable instead of editing this cell;
# the default reproduces the original hard-coded paths exactly.
BASELINE_ROOT = os.environ.get(
    "DCASE24_BASELINE_ROOT", "/home/akhil/models/DCASE24/dcase2024-task6-baseline"
)
# Directory holding the Clotho v2.1 evaluation-split audio files (joined with `fname` later).
EVAL_DATA_DIR = os.path.join(
    BASELINE_ROOT, "data", "CLOTHO_v2.1", "clotho_audio_files", "evaluation"
)
# CSV of per-clip test outputs (predictions, references, metrics) from a previous test run.
EVAL_RESULTS_META = os.path.join(
    BASELINE_ROOT, "logs", "test-2024.04.23-20.20.00", "test_clotho_eval_outputs.csv"
)

In [2]:
import pandas as pd
# Per-clip test outputs: predictions, references and metric scores (see column list below).
df = pd.read_csv(EVAL_RESULTS_META)
print("Columns available : ", df.columns)

# Columns read by the interactive viewer further down. Check them here so a wrong or
# stale CSV fails immediately with a clear message, not as a KeyError inside the widget.
_required_columns = {"fname", "mult_references", "fense"}
_missing_columns = _required_columns - set(df.columns)
assert not _missing_columns, f"{EVAL_RESULTS_META} is missing columns: {sorted(_missing_columns)}"

Columns available :  Index(['test/loss', 'predictions', 'log_probs', 'beam_predictions',
       'beam_log_probs', 'candidates', 'beam_candidates', 'subset',
       'mult_captions', 'mult_references', 'fname', 'dataset',
       'dataloader_idx', 'batch_idx', 'stage', 'bleu_1', 'bleu_2', 'bleu_3',
       'bleu_4', 'meteor', 'rouge_l', 'sbert_sim', 'fer', 'fense', 'cider_d',
       'spice', 'spider', 'fer.add_tail_prob', 'fer.repeat_event_prob',
       'fer.repeat_adv_prob', 'fer.remove_conj_prob', 'fer.remove_verb_prob',
       'fer.error_prob', 'spider_fl', 'bert_score.precision',
       'bert_score.recall', 'bert_score.f1'],
      dtype='object')


## Preparing the Model

In [3]:
from dcase24t6.nn.hub import baseline_pipeline
import torch
import torchaudio
from torch import nn
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
from ipywidgets import interact, interactive, fixed, interact_manual,widgets
import IPython
import matplotlib._color_data as mcd
import librosa
import os

# Default text style for every figure in this notebook: bold, 18 pt.
font = dict(weight='bold', size=18)
matplotlib.rcParams.update({f'font.{key}': value for key, value in font.items()})

In [4]:
def showAttention(input_sentence, output_words, attentions, frame_duration=0.3125):
    """Plot one attention map as a grayscale heatmap and display it.

    Rows are output tokens, labelled with ``output_words`` followed by "<EOS>".
    Columns are encoder frames, labelled ``k * frame_duration`` for k = 1..n_frames.

    Args:
        input_sentence: Unused; kept so existing callers keep working.
        output_words: Sequence of output token strings, one per row. The "<EOS>"
            label is appended here.
        attentions: 2-D torch tensor of attention weights, (n_rows, n_frames).
            It is moved to CPU before plotting.
        frame_duration: Multiplier for the 1-based frame index on the x-axis labels.
            The default 0.3125 is the value this notebook originally hard-coded.
            NOTE(review): presumably seconds per encoder frame; confirm against the
            encoder's hop size.
    """
    weights = attentions.cpu().numpy()
    y_labels = list(output_words) + ["<EOS>"]

    # Recent Matplotlib raises ValueError when the tick-label count differs from the
    # tick count. Plot only rows that have a label, and label only rows that exist.
    n_rows = min(weights.shape[-2], len(y_labels))
    weights = weights[:n_rows, :]
    y_labels = y_labels[:n_rows]
    n_frames = weights.shape[-1]

    fig = plt.figure(figsize=(20, 6), dpi=80)
    ax = fig.add_subplot(111)
    ax.matshow(weights, cmap='gray_r')

    ax.set_xticks(range(n_frames))
    ax.set_yticks(range(n_rows))
    ax.set_xticklabels([frame * frame_duration for frame in range(1, n_frames + 1)], rotation=90)
    ax.set_yticklabels(y_labels)

    plt.show()

def plotAttention(attention, output_words, ax=None, frame_duration=0.3125):
    """Draw one attention map (e.g. a single head) as a grayscale heatmap.

    Args:
        attention: 2-D array-like of weights, rows = decoded tokens, columns = encoder frames.
            Rows beyond the number of row labels are dropped.
        output_words: Caption string. It is split on single spaces to label the rows,
            and "<EOS>" is appended.
            NOTE(review): assumes one attention row per space-separated word; subword
            tokenisation would misalign the labels. Confirm against the tokenizer.
        ax: Matplotlib Axes to draw on. A new figure and axes are created when None.
        frame_duration: Multiplier for the 1-based frame index on the x-axis labels.
            The default 0.3125 is the value originally hard-coded here.

    Returns:
        The Axes that was drawn on.
    """
    if ax is None:
        _, ax = plt.subplots(1, 1)

    y_labels = output_words.split(" ") + ["<EOS>"]
    attention = attention[:len(y_labels), :]
    # If the map has fewer rows than labels, drop the surplus labels. Matplotlib
    # raises when the label count and tick count disagree.
    y_labels = y_labels[:attention.shape[-2]]
    n_frames = attention.shape[-1]

    ax.matshow(attention, cmap='gray_r')

    ax.set_xticks(range(n_frames))
    ax.set_yticks(range(len(y_labels)))
    ax.set_xticklabels([frame * frame_duration for frame in range(1, n_frames + 1)], rotation=90)
    ax.set_yticklabels(y_labels)
    return ax

# Getting attention maps
def patch_attention(m):
    """Make an attention module always return per-head attention weights.

    Replaces ``m.forward`` with a wrapper that forces ``need_weights=True`` and
    ``average_attn_weights=False`` on every call, overriding whatever the caller
    passed. A forward hook (see ``SaveOutput``) can then read the weights from the
    module output. Intended for ``torch.nn.MultiheadAttention``-style modules that
    accept these keyword arguments.

    Calling this again on an already patched module is a no-op. This matters because
    the interactive viewer patches the selected layer on every interaction.

    Args:
        m: Module whose ``forward`` is replaced in place.
    """
    forward_orig = m.forward
    if getattr(forward_orig, "_returns_head_weights", False):
        # Already patched; wrapping again would only nest identical wrappers.
        return

    def wrap(*args, **kwargs):
        kwargs["need_weights"] = True
        kwargs["average_attn_weights"] = False

        return forward_orig(*args, **kwargs)

    # Marker used by the idempotence check above.
    wrap._returns_head_weights = True
    m.forward = wrap
    
class SaveOutput:
    """Forward-hook callable that records ``module_out[1]`` from every call.

    It is meant to be registered on an attention module patched by
    ``patch_attention``. NOTE(review): for ``nn.MultiheadAttention`` the output is
    presumably ``(attn_output, attn_weights)``, so the attention weights are what
    gets stored.
    """

    def __init__(self):
        self.clear()

    def __call__(self, module, module_in, module_out):
        weights = module_out[1]
        self.outputs += [weights]

    def clear(self):
        """Discard everything recorded so far by rebinding ``outputs`` to a fresh list."""
        self.outputs = list()

In [5]:
# Build the DCASE 2024 Task 6 baseline pipeline and put it in evaluation mode via `.eval()`.
# NOTE(review): in the saved run this cell raised a state_dict mismatch for TransDecoderModel.
# The checkpoint lacks `keyword_enc.*`, `decoder.classifier.projection_layers.*` and
# `decoder.classifier.router.*`, but contains `decoder.classifier.weight/bias`. The
# installed `dcase24t6` model code therefore appears to differ from the checkpoint it
# loads. Every later cell that uses `model` depends on this succeeding: use a matching
# checkpoint or the matching model code.
model = baseline_pipeline()
model.eval()

[nltk_data] Downloading package stopwords to /home/akhil/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


Loading similarity matrix:



RuntimeError: Error(s) in loading state_dict for TransDecoderModel:
	Missing key(s) in state_dict: "keyword_enc.encoder.layers.0.self_attn.in_proj_weight", "keyword_enc.encoder.layers.0.self_attn.in_proj_bias", "keyword_enc.encoder.layers.0.self_attn.out_proj.weight", "keyword_enc.encoder.layers.0.self_attn.out_proj.bias", "keyword_enc.encoder.layers.0.linear1.weight", "keyword_enc.encoder.layers.0.linear1.bias", "keyword_enc.encoder.layers.0.linear2.weight", "keyword_enc.encoder.layers.0.linear2.bias", "keyword_enc.encoder.layers.0.norm1.weight", "keyword_enc.encoder.layers.0.norm1.bias", "keyword_enc.encoder.layers.0.norm2.weight", "keyword_enc.encoder.layers.0.norm2.bias", "keyword_enc.encoder.layers.1.self_attn.in_proj_weight", "keyword_enc.encoder.layers.1.self_attn.in_proj_bias", "keyword_enc.encoder.layers.1.self_attn.out_proj.weight", "keyword_enc.encoder.layers.1.self_attn.out_proj.bias", "keyword_enc.encoder.layers.1.linear1.weight", "keyword_enc.encoder.layers.1.linear1.bias", "keyword_enc.encoder.layers.1.linear2.weight", "keyword_enc.encoder.layers.1.linear2.bias", "keyword_enc.encoder.layers.1.norm1.weight", "keyword_enc.encoder.layers.1.norm1.bias", "keyword_enc.encoder.layers.1.norm2.weight", "keyword_enc.encoder.layers.1.norm2.bias", "keyword_enc.encoder.layers.2.self_attn.in_proj_weight", "keyword_enc.encoder.layers.2.self_attn.in_proj_bias", "keyword_enc.encoder.layers.2.self_attn.out_proj.weight", "keyword_enc.encoder.layers.2.self_attn.out_proj.bias", "keyword_enc.encoder.layers.2.linear1.weight", "keyword_enc.encoder.layers.2.linear1.bias", "keyword_enc.encoder.layers.2.linear2.weight", "keyword_enc.encoder.layers.2.linear2.bias", "keyword_enc.encoder.layers.2.norm1.weight", "keyword_enc.encoder.layers.2.norm1.bias", "keyword_enc.encoder.layers.2.norm2.weight", "keyword_enc.encoder.layers.2.norm2.bias", "keyword_enc.encoder.layers.3.self_attn.in_proj_weight", "keyword_enc.encoder.layers.3.self_attn.in_proj_bias", "keyword_enc.encoder.layers.3.self_attn.out_proj.weight", 
"keyword_enc.encoder.layers.3.self_attn.out_proj.bias", "keyword_enc.encoder.layers.3.linear1.weight", "keyword_enc.encoder.layers.3.linear1.bias", "keyword_enc.encoder.layers.3.linear2.weight", "keyword_enc.encoder.layers.3.linear2.bias", "keyword_enc.encoder.layers.3.norm1.weight", "keyword_enc.encoder.layers.3.norm1.bias", "keyword_enc.encoder.layers.3.norm2.weight", "keyword_enc.encoder.layers.3.norm2.bias", "keyword_enc.embedding_layer.weight", "keyword_enc.classifier.weight", "keyword_enc.classifier.bias", "keyword_enc.projection.weight", "keyword_enc.projection.bias", "decoder.classifier.projection_layers.0.weight", "decoder.classifier.projection_layers.0.bias", "decoder.classifier.projection_layers.1.weight", "decoder.classifier.projection_layers.1.bias", "decoder.classifier.projection_layers.2.weight", "decoder.classifier.projection_layers.2.bias", "decoder.classifier.projection_layers.3.weight", "decoder.classifier.projection_layers.3.bias", "decoder.classifier.router.weight", "decoder.classifier.router.bias". 
	Unexpected key(s) in state_dict: "decoder.classifier.weight", "decoder.classifier.bias". 

In [6]:
SAMPLE_RATE = 32000


def _capture_attention(pipeline, attn_module, item):
    """Run `pipeline` on `item` and return (outputs, weights from the last call of `attn_module`)."""
    recorder = SaveOutput()
    patch_attention(attn_module)
    hook_handle = attn_module.register_forward_hook(recorder)
    try:
        with torch.no_grad():
            outputs = pipeline(item)
    finally:
        # Always detach the hook. Otherwise every widget interaction leaves another
        # hook on the layer that keeps appending attention tensors to a stale recorder.
        hook_handle.remove()

    if not recorder.outputs:
        raise RuntimeError("The selected attention module was not called during the forward pass.")
    # NOTE(review): the decoder presumably runs once per generation step; the last call
    # is kept on the assumption that it covers the whole caption. TODO confirm.
    return outputs, recorder.outputs[-1]


def _plot_attention_heads(head_weights, caption, ncols=2):
    """Draw one `plotAttention` panel per head (first axis of `head_weights`) in a grid."""
    n_heads = head_weights.shape[0]
    nrows = -(-n_heads // ncols)  # ceiling division
    # 20x7.5 inches per panel, i.e. the original 40x30 figure for 8 heads in a 4x2 grid.
    fig, axs = plt.subplots(nrows, ncols, figsize=(20 * ncols, 7.5 * nrows), dpi=80, squeeze=False)

    for index, ax in enumerate(axs.reshape(-1)):
        if index < n_heads:
            ax.set_title("Head " + str(index))
            plotAttention(head_weights[index], caption, ax)
        else:
            ax.set_visible(False)  # leftover grid cell with no head to show
    return fig


@interact
def ref_sound_selector(filename=df['fname'], decoder_layer=[-1, 0, 1, 2, 3, 4, 5]):
    """Play one evaluation clip and plot the per-head attention of one decoder layer.

    Args:
        filename: Clip file name, i.e. a value of ``df['fname']``. ipywidgets renders
            the Series as a dropdown.
        decoder_layer: Index into ``model[1].decoder.layers`` (-1 = last layer). The
            ``multihead_attn`` module of that layer is inspected.
    """
    filepath = os.path.join(EVAL_DATA_DIR, filename)
    input_signal, sr = librosa.core.load(filepath, sr=SAMPLE_RATE, mono=True)

    print("\nInput Signal | fs :", sr)
    IPython.display.display(IPython.display.Audio(input_signal, rate=sr))

    instances_df = df[df['fname'] == filename]
    print("\nReference Captions : ", instances_df["mult_references"].item())

    # Add a batch dimension and run the pipeline while capturing the attention weights.
    item = {"audio": torch.tensor(input_signal)[None, :], "sr": sr}
    attn_module = model[1].decoder.layers[decoder_layer].multihead_attn
    outputs, attn_weights = _capture_attention(model, attn_module, item)

    caption = outputs["candidates"][0]
    print("\nPredicted Caption : ", caption)
    print("\nFENSE Score : ", instances_df["fense"].item())

    # Index 0 takes the first entry of the leading (batch) axis.
    # NOTE(review): with beam search this may not be the beam behind `caption`; confirm.
    head_weights = attn_weights.cpu().numpy()[0]
    _plot_attention_heads(head_weights, caption)
    plt.show()

interactive(children=(Dropdown(description='filename', options=('Santa Motor.wav', 'Radio Garble.wav', 'Radio …