In [1]:
import soundfile
import torch
import torch.nn.functional as F
from espnet_model_zoo.downloader import ModelDownloader
from espnet2.bin.asr_inference import Speech2Text

In [2]:
def get_tensor_info(tensor):
  """Return a one-line summary of the autograd attributes of ``tensor``.

  Each attribute is rendered as ``name(value)`` and the entries are joined
  with single spaces. Attributes missing on the object render as ``None``.
  """
  attr_names = ('requires_grad', 'is_leaf', 'retains_grad', 'grad_fn', 'grad')
  return ' '.join(f'{attr}({getattr(tensor, attr, None)})' for attr in attr_names)

In [3]:
# Set test wav for attention image extraction.
# wav.scp lines are parsed as "<utterance id> <wav path>" and text lines as
# "<utterance id> <transcript>"; both lists are aligned by line order.
TEST_DATA_PATH = "./../data/dev_clean"
WAV_LIST_PATH = TEST_DATA_PATH + "/wav.scp"
ANSWER_LIST_PATH = TEST_DATA_PATH + "/text"

file_name_list = []
speech_ans_list = []

with open(WAV_LIST_PATH, "r") as f:
    for line in f:
        if not line.strip():
            continue  # tolerate blank / trailing empty lines
        # Split on the first whitespace only and strip the newline. The old
        # `name[:-1]` dropped a real character whenever the last line had no
        # trailing "\n".
        utt_id, name = line.split(maxsplit=1)
        file_name_list.append(name.strip())

with open(ANSWER_LIST_PATH, "r") as f:
    for line in f:
        if not line.strip():
            continue  # keep alignment with file_name_list, which skips blanks too
        # Drop the leading utterance id. The old `line[17:]` only worked for
        # ids exactly 16 characters long (shorter ids such as "84-121123-0000"
        # lost transcript characters) and kept the trailing newline.
        parts = line.rstrip("\n").split(" ", 1)
        speech_ans_list.append(parts[1] if len(parts) > 1 else "")

# Prepare model
d = ModelDownloader()

speech2text = Speech2Text(
    **d.download_and_unpack('Shinji Watanabe/librispeech_asr_train_asr_transformer_e18_raw_bpe_sp_valid.acc.best'),
    # Decoding parameters are not included in the model file
    maxlenratio=0.0,
    minlenratio=0.0,
    beam_size=1,
    ctc_weight=1.0,
    lm_weight=0.0,
    penalty=0.0,
    nbest=1,
    out_mode="ctc"
)

# Print the module tree to find the submodule names that the hooks below
# attach to (e.g. encoder.encoders.<i>.self_attn.linear_out).
net = speech2text.asr_model
print(net)

ESPnetASRModel(
  (frontend): DefaultFrontend(
    (stft): Stft(n_fft=512, win_length=512, hop_length=128, center=True, normalized=False, onesided=True)
    (frontend): Frontend()
    (logmel): LogMel(sr=16000, n_fft=512, n_mels=80, fmin=0, fmax=8000.0, htk=False)
  )
  (specaug): SpecAug(
    (time_warp): TimeWarp(window=5, mode=bicubic)
    (freq_mask): MaskAlongAxis(mask_width_range=[0, 30], num_mask=2, axis=freq)
    (time_mask): MaskAlongAxis(mask_width_range=[0, 40], num_mask=2, axis=time)
  )
  (normalize): GlobalMVN(stats_file=/home/jmpark/.conda/envs/speech/lib/python3.8/site-packages/espnet_model_zoo/653d10049fdc264f694f57b49849343e/exp/asr_stats_raw_sp/train/feats_stats.npz, norm_means=True, norm_vars=True)
  (encoder): TransformerEncoder(
    (embed): Conv2dSubsampling6(
      (conv): Sequential(
        (0): Conv2d(1, 512, kernel_size=(3, 3), stride=(2, 2))
        (1): ReLU()
        (2): Conv2d(512, 512, kernel_size=(5, 5), stride=(3, 3))
        (3): ReLU()
      )
    

In [4]:
audio_num = 20  # select one of the wavs in file_name_list (index)
speech, rate = soundfile.read(file_name_list[audio_num])
# The model frontend was built for 16 kHz input (LogMel(sr=16000) in the model
# printout); audio at another rate would silently produce meaningless features.
assert rate == 16000, f"expected 16 kHz audio, got {rate} Hz"

In [5]:
class Hook():
    """Record the first positional input of a module on every forward pass.

    A forward hook is registered on ``module`` at construction time. After
    each forward call of that module, ``target_output`` holds ``input[0]``,
    the tensor fed INTO the module, not the module's output. The attribute
    name is kept for backward compatibility. Call ``close()`` to detach.

    Attributes (documented as comments, set in __init__):
        hook_f: handle returned by ``register_forward_hook``; removed by close().
        target_output: last captured input, or None before any forward pass.
    """

    def __init__(self, module):
        # Initialise the capture slot before registering, so the attribute
        # always exists once the hook can fire.
        self.target_output = None
        self.hook_f = module.register_forward_hook(self.hook_f_fn)

    def hook_f_fn(self, module, input, output):
        """Forward-hook callback: store the module's first positional input."""
        self.target_output = input[0]

    def close(self):
        """Detach the forward hook so the module stops updating target_output."""
        # Bug fix: the handle is stored as `hook_f`. The old `self.hook.remove()`
        # raised AttributeError and left the hook attached.
        self.hook_f.remove()

In [6]:
def apply_hook(model, layer_idx, module_type, attn_type, submodule='linear_out'):
    """Attach a ``Hook`` to one attention submodule of ``model`` and return it.

    The target is looked up by its qualified name
    ``f'{module_type}.{module_type}s.{layer_idx}.{attn_type}.{submodule}'``,
    e.g. ``'encoder.encoders.0.self_attn.linear_out'``.

    Args:
        model: module whose ``named_modules()`` contains the target.
        layer_idx: index of the layer within ``{module_type}s``.
        module_type: 'encoder' or 'decoder'.
        attn_type: 'self_attn' or 'src_attn'.
        submodule: attribute of the attention block to hook; the default
            'linear_out' preserves the original behaviour.

    Returns:
        The ``Hook`` registered on the target module.

    Raises:
        ValueError: if no module with the computed name exists. Previously a
            miss surfaced as an opaque UnboundLocalError.
    """
    target_name = f'{module_type}.{module_type}s.{layer_idx}.{attn_type}.{submodule}'
    # Direct name lookup instead of scanning every module and continuing
    # after the match.
    modules_by_name = dict(model.named_modules())
    if target_name not in modules_by_name:
        raise ValueError(f'No module named {target_name!r} in model')
    return Hook(module=modules_by_name[target_name])

In [7]:
# Hook linear_out of the self-attention block in encoder layers 0 and 1
# (registered in that order).
hook_0, hook_1 = (apply_hook(net, layer, 'encoder', 'self_attn') for layer in (0, 1))

AttributeError: 'Hook' object has no attribute 'hook_b_fn'

In [None]:
# With out_mode="ctc" (set when building speech2text) the call is unpacked
# into two values. NOTE(review): ctc_out is presumably per-frame CTC scores
# with classes on dim 2; confirm against the customised Speech2Text.
out, ctc_out = speech2text(speech)
ctc_argmax = torch.argmax(ctc_out, dim=2)

In [None]:
ctc_out.size()  # same torch.Size as ctc_out.shape

In [None]:
# NOTE(review): dead, commented-out draft of the seed gradient; the one-hot
# cell below supersedes it. As written it was also wrong: torch.ones_like
# already fills every entry with 1, so setting one entry to 1 is a no-op.
# torch.zeros_like was presumably intended. Consider deleting this cell.
# check_idx = 50
# grad_temp = torch.ones_like(ctc_out)
# print(grad_temp.shape)
# grad_temp[0, check_idx, ctc_argmax[0,check_idx]] = 1

In [None]:
# Seed gradient for backward(): 1.0 at each frame's argmax class along dim 2,
# 0 elsewhere, cast to ctc_out's dtype.
one_hot = F.one_hot(ctc_argmax, num_classes=ctc_out.size(2)).to(ctc_out.dtype)
one_hot

In [None]:
# Back-propagate the one-hot seed through ctc_out, keeping the graph for reuse.
# NOTE(review): requires ctc_out to be attached to the autograd graph; if
# inference ran under torch.no_grad this raises. Confirm.
torch.autograd.backward(ctc_out, grad_tensors=one_hot, retain_graph=True)

In [None]:
# Fetch the output-projection parameters of encoder layer 0's self-attention
# (the same linear_out module that apply_hook targets for layer 0).
# A direct dict lookup fails loudly with KeyError if a name is wrong, instead
# of silently leaving temp_weight / temp_bias undefined as the scan loop did.
named_params = dict(net.named_parameters())
temp_weight = named_params['encoder.encoders.0.self_attn.linear_out.weight']
temp_bias = named_params['encoder.encoders.0.self_attn.linear_out.bias']

In [None]:
# Weight shape on the first line, bias shape on the second.
print(temp_weight.shape, temp_bias.shape, sep='\n')