In [1]:
import soundfile
import torch
from espnet_model_zoo.downloader import ModelDownloader
from espnet2.bin.asr_inference import Speech2Text

In [None]:
def get_tensor_info(tensor):
    """Summarize a tensor's autograd-related attributes in a single line.

    Each of ``requires_grad``, ``is_leaf``, ``retains_grad``, ``grad_fn`` and
    ``grad`` is rendered as ``name(value)``; an attribute the object does not
    have is rendered with the value ``None``. The string form of ``tensor``
    itself is appended last as ``tensor(...)``.

    Args:
        tensor: The object to describe (any object is accepted, since the
            attributes are looked up with a ``None`` fallback).

    Returns:
        str: All fields joined by single spaces.
    """
    attr_names = ('requires_grad', 'is_leaf', 'retains_grad', 'grad_fn', 'grad')
    fields = [f'{attr}({getattr(tensor, attr, None)})' for attr in attr_names]
    fields.append(f'tensor({tensor!s})')
    return ' '.join(fields)

In [2]:
# Set test wav for attention image extraction
TEST_DATA_PATH = "./data/dev_clean"
WAV_LIST_PATH = TEST_DATA_PATH + "/wav.scp"
ANSWER_LIST_PATH = TEST_DATA_PATH + "/text"


def _read_id_value_table(path):
    """Read a file of ``<id> <value>`` lines and return the values in file order.

    Each line is split once on the first run of whitespace, so the id may have
    any length and the value may itself contain spaces. Surrounding whitespace
    (including the trailing newline) is stripped, blank lines are skipped, and
    a line that holds only an id yields an empty string so that positions stay
    aligned between files that list the same ids.

    Args:
        path: Path of the text file to read.

    Returns:
        list[str]: One value per non-blank line.
    """
    values = []
    with open(path, "r") as f:
        for line in f:
            fields = line.strip().split(maxsplit=1)
            if not fields:
                continue  # blank line
            values.append(fields[1] if len(fields) > 1 else "")
    return values


# Wav paths and reference transcripts; both lists are indexed by line position,
# so the two files are expected to list the utterances in the same order.
file_name_list = _read_id_value_table(WAV_LIST_PATH)
speech_ans_list = _read_id_value_table(ANSWER_LIST_PATH)

# Prepare model
d = ModelDownloader()

speech2text = Speech2Text(
    **d.download_and_unpack('Shinji Watanabe/librispeech_asr_train_asr_transformer_e18_raw_bpe_sp_valid.acc.best'),
    # Decoding parameters are not included in the model file
    maxlenratio=0.0,
    minlenratio=0.0,
    beam_size=1,
    ctc_weight=1.0,
    lm_weight=0.0,
    penalty=0.0,
    nbest=1
)
# Keep a handle on the underlying ASR model and print its module tree; the
# printed names are what later cells use to look up encoder/decoder sub-modules.
net = speech2text.asr_model
print(net)

ESPnetASRModel(
  (frontend): DefaultFrontend(
    (stft): Stft(n_fft=512, win_length=512, hop_length=128, center=True, normalized=False, onesided=True)
    (frontend): Frontend()
    (logmel): LogMel(sr=16000, n_fft=512, n_mels=80, fmin=0, fmax=8000.0, htk=False)
  )
  (specaug): SpecAug(
    (time_warp): TimeWarp(window=5, mode=bicubic)
    (freq_mask): MaskAlongAxis(mask_width_range=[0, 30], num_mask=2, axis=freq)
    (time_mask): MaskAlongAxis(mask_width_range=[0, 40], num_mask=2, axis=time)
  )
  (normalize): GlobalMVN(stats_file=/home/jmpark/.conda/envs/speech/lib/python3.8/site-packages/espnet_model_zoo/653d10049fdc264f694f57b49849343e/exp/asr_stats_raw_sp/train/feats_stats.npz, norm_means=True, norm_vars=True)
  (encoder): TransformerEncoder(
    (embed): Conv2dSubsampling6(
      (conv): Sequential(
        (0): Conv2d(1, 512, kernel_size=(3, 3), stride=(2, 2))
        (1): ReLU()
        (2): Conv2d(512, 512, kernel_size=(5, 5), stride=(3, 3))
        (3): ReLU()
      )
    

In [7]:
# Choose one utterance by its position in file_name_list and load its samples.
# soundfile.read returns the sample data together with the file's sample rate.
audio_num = 2 # select one of the wavs in file_name_list
speech, rate = soundfile.read(file_name_list[audio_num])

In [8]:
# Run recognition on the loaded samples.
# NOTE(review): the result is unpacked into exactly two values; presumably `out`
# is the list of hypotheses and `ctc_out` a tensor of CTC outputs — confirm
# against the Speech2Text version installed in this environment.
out, ctc_out = speech2text(speech)

In [9]:
# Show, one per line: the index of the largest ctc_out entry along dim 1,
# followed by the raw recognition result.
print(ctc_out.argmax(dim=1), out, sep="\n")

tensor([   0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0, 4990,
           0,    0,    0, 4784,    0,    0, 4997,    0,    0,    0,    0, 4875,
           0,    0,    0,    0,    0, 4989,    0,    0,    0, 4965, 4965,    0,
           0,    0, 4952,    0,    0,    0,    0, 4409, 4409, 4409, 4997,    0,
           0, 3860,    0,    0,    0,    0,    0,    0,    0,    0, 3020,    0,
           0,    0,    0, 4995,    0,    0,    0, 4998,    0,    0, 4458,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0, 4976,    0,    0,    0,    0,    0,    0, 1927,
           0,    0,    0,    0,    0,    0,    0,    0,    0, 4996,    0,    0,
        4601, 4601,    0, 4825, 4988, 4988,    0, 4971, 4971, 4987, 4987, 4894,
        4894,    0,    0, 4597, 4597, 4958, 4953, 4980, 4980,    0,    0,    0,
           0,    0,    0, 4880,    0,    0,    0,    0, 4875,    0,    0,    0,
           0,    0,    0,    0,    0,   

In [None]:
# Print every entry of the recognition result on its own line.
for hyp in out:
    print(hyp)

In [None]:
# Inspect the first entry of the recognition result next to the reference.
# NOTE(review): field positions are taken as-is from the result tuple;
# presumably [0] is the decoded text and [2] the token ids — confirm against
# the Speech2Text return format.
first_hyp = out[0]
print(first_hyp[0])
print(speech_ans_list[audio_num])
print(first_hyp[2])

# Element [3][3] of the entry is indexed by scorer name; show the shape of the
# first item stored under 'ctc', then the whole mapping.
hyp_states = first_hyp[3][3]
print(hyp_states['ctc'][0].shape)
print(hyp_states)

In [None]:
def apply_dh(model, layer_idx, head_idx, module_type, attn_type):
    """Find the query projection (``linear_q``) of one attention block by name.

    The sub-module looked up is
    ``{module_type}.{module_type}s.{layer_idx}.{attn_type}.linear_q``.
    When it is found, its name and the module itself are printed.

    Args:
        model: Object exposing ``named_modules()`` that yields ``(name, module)``.
        layer_idx: Index of the layer inside ``{module_type}s``.
        head_idx: Currently unused by this function.
        module_type: 'encoder' or 'decoder'.
        attn_type: 'self_attn' or 'src_attn'.

    Returns:
        tuple | None: ``(name, module)`` for the first match, or ``None`` when
        no sub-module has that name (callers that unpack the result should
        check for this).
    """
    target = f'{module_type}.{module_type}s.{layer_idx}.{attn_type}.linear_q'
    match = next(
        ((sub_name, sub_module)
         for sub_name, sub_module in model.named_modules()
         if sub_name == target),
        None,
    )
    if match is None:
        return None

    found_name, found_module = match
    print(found_name)
    print(found_module)
    return found_name, found_module

In [None]:
# Look up the self-attention query projection of encoder layer 0.
# apply_dh returns None when nothing matches, in which case this unpacking fails.
name, module = apply_dh(net, 0, 0, 'encoder', 'self_attn')

In [None]:
# Shapes, one per line: the weight, the weight viewed with its last dimension
# split into 8 x 64, and the bias.
print(
    module.weight.shape,
    module.weight.view(-1, 8, 64).shape,
    module.bias.shape,
    sep="\n",
)

In [None]:
# Zero parts of the selected projection in place, treating the 512 output
# features as 8 heads x 64 features per head.
#
# The writes go through views of `module.weight` / `module.bias`. Assuming these
# are ordinary nn.Parameter leaves (which require grad by default), autograd
# rejects in-place writes to their views, so the edits are made under
# torch.no_grad().
with torch.no_grad():
    for head in range(8):
        # NOTE(review): the bias slice is cleared for every head, head 0
        # included, while head 0 keeps its weights — confirm this is intended.
        module.bias.view(-1, 8, 64)[:, head, :] = 0
        if head != 0:
            # weight is transposed first so that the last dimension, which is
            # then split into heads, is the output-feature dimension.
            module.weight.transpose(0, 1).view(-1, 8, 64)[:, head, :] = 0

print(module.weight)
print(module.bias)


In [None]:
# Push a random (1, 99, 512) input through the edited projection and regroup the
# result as (batch, head, position, 64) so a single head can be inspected.
# Note: this rebinds `out`, which an earlier cell used for the recognition result.
temp_input = torch.rand(size=(1, 99, 512))
out = module(temp_input).view(1, -1, 8, 64).permute(0, 2, 1, 3)

# Shape of the regrouped output, then everything produced for head index 1.
print(out.shape, out[0, 1], sep="\n")
