In [10]:
from extract_grad_cam_image import *
from tqdm import tqdm

In [2]:
class Hook():
    """Forward hook that remembers the first positional input of a module.

    After each forward pass of ``module``, ``target_output`` holds
    ``input[0]``, i.e. the first positional argument the module was called
    with (not the module's output, despite the attribute name).

    Attributes:
        hook_f: handle returned by ``module.register_forward_hook``.
        target_output: last captured ``input[0]``; ``None`` until the module
            has been run at least once.
    """

    def __init__(self, module):
        # Initialise the slot before registering, so the attribute exists
        # even if the hook fires immediately.
        self.target_output = None
        self.hook_f = module.register_forward_hook(self.hook_f_fn)

    def hook_f_fn(self, module, input, output):
        """Forward-hook callback: store the first positional input."""
        self.target_output = input[0]

    def close(self):
        """Detach the hook from the module.

        The handle lives in ``hook_f``; the previous code referenced a
        non-existent ``self.hook`` attribute and always raised
        ``AttributeError``.
        """
        self.hook_f.remove()

def apply_hook(model, layer_idx, module_type, attn_type):
    """Attach a ``Hook`` to the ``linear_out`` sub-module of one attention layer.

    The target is the module whose qualified name (as reported by
    ``model.named_modules()``) equals
    ``'{module_type}.{module_type}s.{layer_idx}.{attn_type}.linear_out'``.

    Args:
        model: object exposing ``named_modules()`` yielding ``(name, module)``.
        layer_idx: index of the layer inside the encoder/decoder stack.
        module_type: ``'encoder'`` or ``'decoder'``.
        attn_type: ``'self_attn'`` or ``'src_attn'``.

    Returns:
        The ``Hook`` registered on the matching module.

    Raises:
        ValueError: if no module with the expected name exists. Previously
            this surfaced as an opaque ``UnboundLocalError``.
    """
    target_name = f'{module_type}.{module_type}s.{layer_idx}.{attn_type}.linear_out'
    for name, module in model.named_modules():
        if name == target_name:
            # Stop at the first match so exactly one hook is registered.
            return Hook(module=module)
    raise ValueError(f'no module named {target_name!r} found in model')

# Root directory for experiment outputs.
# NOTE(review): absolute, user-specific path; consider making it configurable.
exp_dir = '/home/jmpark/home_data_jmpark/espnet/egs2/jm_ref/asr1/exp'

saved_encoder_grad_cam_images = []

TEST_DATA_PATH = "./../data/dev_clean"
WAV_LIST_PATH = TEST_DATA_PATH + "/wav.scp"
ANSWER_LIST_PATH = TEST_DATA_PATH + "/text"

file_name_list = []
speech_ans_list = []

# wav.scp: every line is "<utterance-id> <wav path>".
# Strip only the newline (the old `name[:-1]` chopped the last path character
# when the final line had no trailing newline) and split on the first space.
with open(WAV_LIST_PATH, "r") as f:
    for line in f:
        entry = line.rstrip('\n')
        if not entry:
            continue  # tolerate blank lines
        utt_id, sep, wav_path = entry.partition(' ')
        if not sep or not wav_path:
            raise ValueError(f'malformed line in {WAV_LIST_PATH}: {entry!r}')
        file_name_list.append(wav_path)

# text: every line is "<utterance-id> <transcript>".
# Take everything after the first space instead of the fixed offset 17, which
# was only correct for 16-character utterance ids. As before, the transcript
# keeps its trailing newline.
with open(ANSWER_LIST_PATH, "r") as f:
    for line in f:
        if not line.strip():
            continue  # tolerate blank lines
        _, _, transcript = line.partition(' ')
        speech_ans_list.append(transcript)

# Prepare model
d = ModelDownloader()

# Keyword arguments describing the pretrained model, as returned by the downloader.
model_kwargs = d.download_and_unpack(
    'Shinji Watanabe/librispeech_asr_train_asr_transformer_e18_raw_bpe_sp_valid.acc.best'
)

# Decoding parameters are not included in the model file, so they are set here.
decode_kwargs = dict(
    maxlenratio=0.0,
    minlenratio=0.0,
    beam_size=1,
    ctc_weight=1.0,
    lm_weight=0.0,
    penalty=0.0,
    nbest=1,
    out_mode="ctc",
)

speech2text = Speech2Text(**model_kwargs, **decode_kwargs)

# Register one forward hook per encoder layer (self-attention `linear_out`).
net = speech2text.asr_model
hook_list = []

# 18 layers are hooked here, indices 0..17.
for i in range(18):
    encoder_hook = apply_hook(model=net, layer_idx=i, module_type='encoder', attn_type='self_attn')
    hook_list.append(encoder_hook)

In [12]:
img_list = []
word_num_list = []

# Walk over every wav file; `audio_num` is its position in file_name_list.
for audio_num, wav_path in enumerate(tqdm(file_name_list)):
    speech, rate = soundfile.read(wav_path)

    out, ctc_out = speech2text(speech)

    # One-hot target: 1.0 at the argmax index along axis 2, zeros elsewhere.
    ctc_argmax = ctc_out.argmax(2)
    one_hot = torch.zeros_like(ctc_out).scatter_(2, ctc_argmax.unsqueeze(2), 1.0)

    # Grad-CAM image for this utterance, computed from the hooked layers.
    img = make_grad_cam_img_list(
        model=net,
        target_out=ctc_out,
        target_loss=one_hot,
        hook_list=hook_list,
    )

    img_list.append(img)
    word_num_list.append(audio_num)

100%|██████████| 2703/2703 [1:45:52<00:00,  2.35s/it]


In [13]:
import numpy as np

In [15]:
# Collect the per-utterance images into a single array (first axis = utterance).
total_img_np = np.asarray(img_list)
total_img_np

array([[[0.00278851, 0.00348732, 0.00431593, ..., 0.00171665,
         0.0021893 , 0.00221548],
        [0.00180147, 0.00124574, 0.00181621, ..., 0.00169283,
         0.00139309, 0.00168609],
        [0.00105069, 0.00131716, 0.00136562, ..., 0.00100729,
         0.00141241, 0.0014442 ],
        ...,
        [0.00544281, 0.00307936, 0.00141175, ..., 0.00166242,
         0.00375049, 0.0014074 ],
        [0.00149392, 0.00189142, 0.00205876, ..., 0.00210663,
         0.00144028, 0.00230863],
        [0.0007858 , 0.00131748, 0.00085343, ..., 0.00374274,
         0.00244202, 0.00270876]],

       [[0.0022967 , 0.0033199 , 0.00418904, ..., 0.00181741,
         0.00187251, 0.00191541],
        [0.00164767, 0.00113384, 0.00167113, ..., 0.00222542,
         0.00097298, 0.00179821],
        [0.00098186, 0.00164017, 0.00129954, ..., 0.00080936,
         0.00152343, 0.00142044],
        ...,
        [0.00363807, 0.00233927, 0.00115288, ..., 0.00154184,
         0.00279359, 0.0018757 ],
        [0.0

In [116]:
# Average over the first axis of the stacked images.
mean_img = total_img_np.mean(axis=0)

# Spot-check two individual entries of the averaged image.
for row, col in ((17, 4), (7, 1)):
    print(mean_img[row, col])

0.008317363
0.0049155816


In [17]:
# Wrap the averaged image in a one-element list, as passed to the saver below.
save_img = [mean_img]

mean_output_dir = exp_dir + '/feature_images/encoder_grad_cam/sentence/mean/'
save_encoder_grad_image(
    image_list=save_img,
    target_list=['mean'],
    audio_num='mean',
    n_targets=1,
    PATH=mean_output_dir,
)

process mean target images....
