In [5]:
from extract_grad_cam_image import *
from tqdm import tqdm

In [6]:
class Hook():
    """Forward hook that records the first positional input of a module.

    On construction the hook is registered on ``module``. Every forward
    pass of that module then overwrites ``target_output`` with
    ``input[0]``. Note that, despite the attribute name, the stored value
    is the module's *input*, not its output.

    Attributes:
        hook_f: Handle returned by ``module.register_forward_hook``; used
            by :meth:`close` to detach the hook.
        target_output: Most recently captured ``input[0]``, or ``None`` if
            the module has not run since the hook was created.
    """

    def __init__(self, module):
        """Register the forward hook on ``module``.

        Args:
            module: Object exposing ``register_forward_hook(fn)`` that
                returns a handle with a ``remove()`` method.
        """
        self.hook_f = module.register_forward_hook(self.hook_f_fn)
        self.target_output = None

    def hook_f_fn(self, module, input, output):
        """Forward-hook callback: keep the first positional input.

        Args:
            module: The module that just ran (unused).
            input: Positional inputs passed to the module's forward call.
            output: The module's output (unused).
        """
        self.target_output = input[0]

    def close(self):
        """Detach the hook from the module it was registered on."""
        # The handle is stored as ``hook_f``; the previous ``self.hook``
        # attribute never existed, so close() always raised AttributeError.
        self.hook_f.remove()

def apply_hook(model, layer_idx, module_type, attn_type):
    """Attach a ``Hook`` to the ``linear_out`` submodule of one attention block.

    The target is the submodule whose qualified name is exactly
    ``{module_type}.{module_type}s.{layer_idx}.{attn_type}.linear_out``.

    Args:
        model: Object exposing ``named_modules()`` that yields
            ``(name, module)`` pairs.
        layer_idx: Index of the layer inside the encoder/decoder stack.
        module_type: ``'encoder'`` or ``'decoder'``.
        attn_type: ``'self_attn'`` or ``'src_attn'``.

    Returns:
        The ``Hook`` registered on the matching submodule.

    Raises:
        ValueError: If ``model`` has no submodule with the expected name.
            Previously this case surfaced as an UnboundLocalError.
    """
    target_name = f'{module_type}.{module_type}s.{layer_idx}.{attn_type}.linear_out'
    for name, module in model.named_modules():
        if name == target_name:
            # Stop at the first match so only one hook is registered.
            return Hook(module=module)
    raise ValueError(f"No module named '{target_name}' found in model")

# Experiment root; Grad-CAM images are written below this directory.
# NOTE(review): absolute, user-specific path - adjust when running elsewhere.
exp_dir = '/home/jmpark/home_data_jmpark/espnet/egs2/jm_ref/asr1/exp'

saved_encoder_grad_cam_images = []

TEST_DATA_PATH = "./../data/dev_clean"
WAV_LIST_PATH = TEST_DATA_PATH + "/wav.scp"
ANSWER_LIST_PATH = TEST_DATA_PATH + "/text"

file_name_list = []
speech_ans_list = []

# wav.scp: one "<utt_id> <wav_path>" entry per line.
with open(WAV_LIST_PATH, "r") as f:
    for line in f:
        entry = line.rstrip("\n")
        if not entry.strip():
            continue  # tolerate blank lines
        # Split on the first whitespace run only. Strip just the newline
        # rather than blindly dropping the last character, which corrupted
        # the path when the final line had no trailing newline.
        fields = entry.split(maxsplit=1)
        if len(fields) != 2:
            raise ValueError(f"Malformed wav.scp line: {entry!r}")
        file_name_list.append(fields[1])

# text: one "<utt_id> <transcript>" entry per line.
with open(ANSWER_LIST_PATH, "r") as f:
    for line in f:
        if not line.strip():
            continue  # keep indices aligned with file_name_list
        # Cut after the first space instead of at a fixed column 17, which
        # was only correct when every utterance ID was exactly 16 characters.
        # The rest of the line (including its trailing newline) is kept as-is.
        _utt_id, _sep, transcript = line.partition(' ')
        speech_ans_list.append(transcript)

# Download (or reuse the cached copy of) the pretrained model and build the
# inference wrapper around it.
d = ModelDownloader()
model_files = d.download_and_unpack('Shinji Watanabe/librispeech_asr_train_asr_transformer_e18_raw_bpe_sp_valid.acc.best')

# Decoding parameters are not included in the model file, so they are given here.
# NOTE(review): `out_mode` is not explained anywhere in this notebook - confirm
# its meaning against the Speech2Text implementation in use.
decode_options = dict(
    maxlenratio=0.0,
    minlenratio=0.0,
    beam_size=1,
    ctc_weight=1.0,
    lm_weight=0.0,
    penalty=0.0,
    nbest=1,
    out_mode="default",
)

speech2text = Speech2Text(**model_files, **decode_options)

# Hooks registered on the encoder layers are collected here by a later cell.
hook_list = []

# Underlying ASR network that the hooks are attached to.
net = speech2text.asr_model

In [7]:
# Read a single utterance (the second entry of the wav list) for a quick check.
sample_wav_path = file_name_list[1]
speech, rate = soundfile.read(sample_wav_path)

In [8]:
# Run inference on the sample utterance; as the last expression of the cell,
# the result is displayed.
# NOTE(review): the recorded output of this cell is
# "TypeError: forward() missing 2 required positional arguments: 'text' and
# 'text_lengths'" - the cell did not succeed when this notebook was saved.
speech2text(speech)

TypeError: forward() missing 2 required positional arguments: 'text' and 'text_lengths'

In [None]:
# Print the ASR model's string representation to inspect its structure.
print(net)

In [None]:
# Print the qualified name of every submodule; these are the names that
# apply_hook matches against. named_modules() yields (name, module) pairs.
for module_name, _module in net.named_modules():
    print(module_name)

In [None]:
# Number of encoder layers to instrument (layer indices 0..17).
NUM_ENCODER_LAYERS = 18

# Make the cell safe to re-run: detach hooks left from a previous execution
# (through the handle stored on each Hook as `hook_f`) and empty the list in
# place. Otherwise every re-run would add another 18 entries and leave
# duplicate forward hooks on the model.
for stale_hook in hook_list:
    stale_hook.hook_f.remove()
hook_list.clear()

# One hook per encoder layer, on the self-attention block.
for layer_idx in range(NUM_ENCODER_LAYERS):
    hook_list.append(apply_hook(net, layer_idx, 'encoder', 'self_attn'))

In [None]:
img_list = []
word_num_list = []

# Process every utterance in the wav list, collecting one Grad-CAM result
# and the utterance's index per iteration.
for audio_num, wav_path in enumerate(tqdm(file_name_list)):
    speech, rate = soundfile.read(wav_path)

    out, ctc_out = speech2text(speech)

    # One-hot tensor shaped like ctc_out: 1.0 at the argmax along dim 2,
    # 0 elsewhere. It is passed to the Grad-CAM helper as `target_loss`.
    one_hot = torch.zeros_like(ctc_out).scatter_(
        2, ctc_out.argmax(2, keepdim=True), 1.0
    )

    img_list.append(
        make_grad_cam_img_list(
            model=net,
            target_out=ctc_out,
            target_loss=one_hot,
            hook_list=hook_list,
        )
    )
    word_num_list.append(audio_num)

In [None]:
import numpy as np

In [None]:
# Convert the per-utterance results into one array (utterances along axis 0)
# and display it as the cell output.
total_img_np = np.asarray(img_list)
total_img_np

In [None]:
# Average over axis 0 (the utterance axis of the stacked array).
mean_img = total_img_np.mean(axis=0)

# Spot-check two entries of the averaged result.
for row, col in ((17, 4), (7, 1)):
    print(mean_img[row, col])

In [None]:
# The save helper takes a list of images, so wrap the single mean image.
save_img = [mean_img]

mean_output_dir = exp_dir + '/feature_images/encoder_grad_cam/sentence/mean/'
save_encoder_grad_image(
    image_list=save_img,
    target_list=['mean'],
    audio_num='mean',
    n_targets=1,
    PATH=mean_output_dir,
)