In [11]:
import os
import sys
import json
import math
import logging
from argparse import ArgumentParser

import torch
from espnet.utils.io_utils import LoadInputsAndTargets
from espnet.asr.pytorch_backend.asr import load_trained_model
from espnet.nets.pytorch_backend.disentangled_transformer.add_sos_eos import add_sos_eos
from espnet.nets.pytorch_backend.disentangled_transformer.mask import target_mask
from espnet.nets.pytorch_backend.disentangled_transformer.attention import DisentangledMaskAttention

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

from utils import calculate_redundancy, plot_token_distribution
from matplotlib.font_manager import FontProperties

# Send log records to stdout so they appear in the notebook cell output.
logging.basicConfig(stream=sys.stdout, level="INFO")

# Font loaded from the NotoSansCJK file bundled next to the
# disentangled_transformer sources; kept in a module-level name for reuse.
fontproperties = FontProperties(
    fname="../../../../espnet/nets/pytorch_backend/disentangled_transformer/NotoSansCJK-Regular.ttc"
)

%matplotlib inline

In [12]:
# Experiment selection: `tag` names the experiment directory under `exp_dir`.
exp_dir = "/root/Disentangled-Transformer-ASR/egs/aishell/asr1/exp"
tag = "train_sp_pytorch_disentangled_transformer-6-1024-lr-2"
# tag = "train_sp_pytorch_transformer-6-256-h-4"

# Checkpoints are read from <exp_dir>/<tag>/results, figures go to <exp_dir>/<tag>/figure.
ckpt_dir = os.path.join(exp_dir, tag, "results")
figure_dir = os.path.join(exp_dir, tag, "figure")

# Options are declared like a command-line script would declare them, but the
# values are parsed from an explicit list below because a notebook has no argv.
parser = ArgumentParser()
parser.add_argument("--model")
parser.add_argument("--recog_json")
parser.add_argument("--batchsize", type=int)
parser.add_argument("--preprocess_conf", type=str, default=None)

model_path = os.path.join(ckpt_dir, "model.last3.avg.best")
recog_json_path = "/root/Disentangled-Transformer-ASR/egs/aishell/asr1//dump/test/deltafalse/split1utt/data.1.json"

args = parser.parse_args(
    ["--model", model_path, "--recog_json", recog_json_path, "--batchsize", "64"]
)

# Create the figure output directory (and any missing parents).
# exist_ok=True makes this idempotent on re-runs and avoids the race between a
# separate existence check and the creation call.
os.makedirs(figure_dir, exist_ok=True)

In [13]:
# Restore the model and its training-time arguments from the checkpoint path,
# then call eval() on it before running any forward passes.
model, train_args = load_trained_model(args.model)
model.eval()

# An explicit --preprocess_conf takes priority; otherwise reuse the value
# stored in the training arguments.
if args.preprocess_conf is not None:
    preprocess_conf = args.preprocess_conf
else:
    preprocess_conf = train_args.preprocess_conf

# Feature loader configured for ASR input without loading outputs or
# re-sorting by input length; preprocessing is told it is not in training.
load_inputs_and_targets = LoadInputsAndTargets(
    mode="asr",
    load_output=False,
    sort_in_input_length=False,
    preprocess_conf=preprocess_conf,
    preprocess_args={"train": False},
)

INFO:root:reading a config file from /root/Disentangled-Transformer-ASR/egs/aishell/asr1/exp/train_sp_pytorch_disentangled_transformer-6-1024-lr-2/results/model.json
INFO:root:encoder self-attention layer type = self-attention
INFO:root:decoder self-attention layer type = self-attention


Read the recognition JSON (utterance metadata)

In [14]:
# Parse the recognition JSON and keep only the mapping stored under "utts".
with open(args.recog_json, "rb") as recog_file:
    js = json.loads(recog_file.read())["utts"]

In [15]:
# Per-utterance values returned by calculate_redundancy; averaged after the loop.
lr_list = []
hr_list = []

# Prefer the GPU, but fall back to the CPU so the cell still runs on machines
# without CUDA instead of raising.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Loop-invariant: total number of utterances, used only for the progress line.
num_utts = len(js)

with torch.no_grad():
    for idx, name in enumerate(js, 1):
        # idx is already 1-based (enumerate start=1), so it is printed as-is;
        # adding 1 again made the counter end at N+1/N.
        print("\rEvaluating {}/{} ...".format(idx, num_utts), end="")

        # Encode the features of a single utterance (batch of one).
        batch = [(name, js[name])]
        feat = load_inputs_and_targets(batch)[0][0]
        feat_tensor = torch.as_tensor(feat).to(device=device, dtype=torch.float32).unsqueeze(0)
        hs_pad, hs_mask = model.encoder(feat_tensor, None)

        # Build the decoder input from the reference token ids (space-separated
        # integers in the JSON) and run the decoder on the reference sequence.
        tgt = batch[0][1]["output"][0]["tokenid"]
        ys_in_pad, ys_out_pad = add_sos_eos(
            torch.as_tensor([int(_tgt) for _tgt in tgt.split(" ")]).unsqueeze(0),
            model.sos, model.eos, model.ignore_id
        )
        ys_in_dev = ys_in_pad.to(device)  # moved to the device once, used twice below
        ys_mask = target_mask(ys_in_dev, model.ignore_id)
        # NOTE(review): the decoder outputs are not used in this cell. The call is
        # kept because calculate_redundancy only receives the model, so it
        # presumably reads state left behind by this forward pass - confirm in utils.
        pred_pad, pred_mask = model.decoder(ys_in_dev, ys_mask, hs_pad, hs_mask)

        # Not used in this cell; kept because later notebook cells may read them.
        enc_len = hs_pad.size()[1]
        dec_len = ys_in_pad.size()[1]

        _, lr, hr = calculate_redundancy(model)

        lr_list.append(lr)
        hr_list.append(hr)

# Terminate the carriage-return progress line, then report the averages.
print("")
print("Layer redundancy: {:.4f}".format(np.mean(lr_list)))
print("Head redundancy: {:.4f}".format(np.mean(hr_list)))

Evaulating 7177/7176 ...
Layer redundancy: 0.5249
Head redundancy: 0.3966
