In [1]:
import soundfile
import torch
from espnet_model_zoo.downloader import ModelDownloader
from espnet2.bin.asr_inference import Speech2Text

In [2]:
def get_tensor_info(tensor):
  info = []
  for name in ['requires_grad', 'is_leaf', 'retains_grad', 'grad_fn', 'grad']:
    info.append(f'{name}({getattr(tensor, name, None)})')
  info.append(f'tensor({str(tensor)})')
  return ' '.join(info)

In [3]:
# Set test wav for attention image extraction
TEST_DATA_PATH = "./data/dev_clean"
WAV_LIST_PATH = TEST_DATA_PATH + "/wav.scp"

file_name_list = []
audio_num = 1 # selelct one of the wav in file_name_list

with open(WAV_LIST_PATH, "r") as f:
    lines = f.readlines()
    for line in lines:
        num, name = line.split(' ')
        file_name_list.append(name[:-1])

speech, rate = soundfile.read(file_name_list[audio_num])

# Prepare model
d = ModelDownloader()

speech2text = Speech2Text(
    **d.download_and_unpack('Shinji Watanabe/librispeech_asr_train_asr_transformer_e18_raw_bpe_sp_valid.acc.best'),
    # Decoding parameters are not included in the model file
    maxlenratio=0.0,
    minlenratio=0.0,
    beam_size=1,
    ctc_weight=0.4,
    lm_weight=0.6,
    penalty=0.0,
    nbest=1
)
# Add register hook for in encoder layers.
net = speech2text.asr_model

In [4]:
#out = speech2text(speech)

In [5]:
def apply_dh(model, layer_idx, head_idx, module_type, attn_type):
    # module_type : 'encoder', 'decoder'
    # attn_type : 'self_attn', 'src_attn'

    for name, module in model.named_modules():
        if f'{module_type}.{module_type}s.{layer_idx}.{attn_type}.linear_q' == name:
            print(name)
            print(module)
            # print(param.shape)
            return name, module

In [6]:
name, module = apply_dh(net, 0, 0, 'encoder', 'self_attn')

encoder.encoders.0.self_attn.linear_q
Linear(in_features=512, out_features=512, bias=True)


In [7]:
print(module.weight.shape)
print(module.weight.view(-1, 8, 64).shape)
print(module.bias.shape)

torch.Size([512, 512])
torch.Size([512, 8, 64])
torch.Size([512])


In [8]:
for head in range(8):
    module.bias.view(-1, 8, 64)[:, head, :] = 0
    if head != 0:
        module.weight.transpose(0, 1).view(-1, 8, 64)[:, head, :] = 0

print(module.weight)
print(module.bias)


Parameter containing:
tensor([[-0.2873,  0.5390, -0.5693,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.3400, -0.1308, -0.1205,  ...,  0.0000,  0.0000,  0.0000],
        [-0.0790, -0.0389,  0.1748,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 0.0107,  0.0485,  0.3198,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0429, -0.1236, -0.1116,  ...,  0.0000,  0.0000,  0.0000],
        [-0.1298,  0.3432, -0.8652,  ...,  0.0000,  0.0000,  0.0000]],
       grad_fn=<CopySlices>)
Parameter containing:
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0

In [17]:
temp_input = torch.rand(1, 99, 512)
out = module(temp_input).view(1, -1, 8, 64).transpose(1, 2)
print(out.shape)
print(out[0, 1, :])


torch.Size([1, 8, 99, 64])
tensor([[ 0.5174,  2.2230,  1.5740,  ...,  3.8168, -1.1528, -0.8484],
        [-1.6020,  1.5832,  0.1436,  ...,  3.8070,  1.1598, -1.6014],
        [ 0.7118,  2.9494, -0.8248,  ...,  1.7469, -0.1587, -1.1079],
        ...,
        [-0.1216,  1.4732,  1.1194,  ...,  3.9071, -0.0097, -1.1232],
        [-0.4900, -0.4743, -1.2640,  ...,  2.5423,  1.1128, -1.8459],
        [ 0.7767,  1.9731,  0.6011,  ...,  2.9234, -0.8688,  0.8824]],
       grad_fn=<SliceBackward>)
