In [None]:
import torch
from tqdm import tqdm
from transformers import BertTokenizer
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import json
import math

In [None]:
# Load the saved loss log. map_location='cpu' lets this cell run on a machine
# without a GPU even if the tensors were saved from a CUDA device; the values
# are only read back with .item() in later cells, so CPU storage is sufficient.
loss = torch.load('./log/Audio_Aware/loss.pt', map_location='cpu')
loss

In [None]:
# Dev-set loss as plain Python floats, one per entry in loss["dev_loss"].
dev_loss = [l.item() for l in loss["dev_loss"]]

# Derive the x-axis from the data instead of a hard-coded range(1, 11), so the
# plot no longer fails with a length mismatch when the log does not hold
# exactly ten points.
plt.plot(range(1, len(dev_loss) + 1), dev_loss, 'go')

In [None]:
# NOTE(review): this cell repeats the dev-loss cell directly above it; consider
# deleting one of the two.
# Dev-set loss as plain Python floats, one per entry in loss["dev_loss"].
dev_loss = [l.item() for l in loss["dev_loss"]]

# x-axis sized from the data rather than a fixed range(1, 11), so any number of
# logged points can be plotted.
plt.plot(range(1, len(dev_loss) + 1), dev_loss, 'go')

In [None]:
# Base x positions: 200, 400, ..., 4800 followed by 5005
# (presumably the logging steps of one epoch -- TODO confirm).
single_step = np.array([200 * k for k in range(1, 25)] + [5005])

# The loop below used `single_step` without ever defining it, which raised
# NameError on a fresh kernel; it is now defined above.
# NOTE(review): later blocks are `single_step * i`, kept exactly as written.
# If a global step counter was intended, an additive offset such as
# `single_step + 5005 * (i - 1)` may be what is wanted -- confirm.
step = single_step
for i in range(2, 10):
    step = np.concatenate((step, single_step * i))

In [None]:
# Show how many x positions were built, for comparison with the number of logged losses.
step.shape

In [None]:
# Training loss as plain Python floats, one per entry in loss["training_loss"].
train_loss = list(map(lambda entry: entry.item(), loss["training_loss"]))

# Green dots at the x positions held in `step`; the two sequences must have
# the same length or matplotlib raises.
plt.plot(step, train_loss, 'go')

In [None]:
# Load the pretrained 'bert-base-chinese' tokenizer; later cells use it to map
# whitespace-separated tokens to vocabulary ids.
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')

In [None]:
# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [None]:
# Constructor arguments for the rescoring model, collected in one place.
# NOTE(review): RescoreBert is not imported in any cell visible here; add the
# project import to the imports cell so this runs on a fresh kernel.
rescore_kwargs = dict(
    train_batch=1,
    test_batch=1,
    nBest=10,
    use_MWER=True,
    use_MWED=False,
    device=device,
    lr=1e-4,
    weight=0.59,
)
model = RescoreBert(**rescore_kwargs)

In [None]:
# Load the checkpoint onto the device selected earlier, so a checkpoint saved
# from a GPU can also be restored on a CPU-only machine.
state_dict = torch.load('./checkpoint/MWER/checkpoint_train_4.pt', map_location=device)

# Put the model in evaluation mode before the forward passes in the following
# cells; in the original notebook this was only done much further down.
model.eval()

# The weights live under the "state_dict" key of the checkpoint and are loaded
# into the wrapped `model.model`. Kept as the last expression so the
# load_state_dict report is displayed.
model.model.load_state_dict(state_dict["state_dict"])

In [None]:
def _with_special_tokens(tokens):
    """Map `tokens` to vocabulary ids and wrap them in the tokenizer's CLS/SEP ids.

    Uses `tokenizer.cls_token_id` / `tokenizer.sep_token_id` instead of the
    literal 101 / 102 so the ids always follow the loaded vocabulary.

    Args:
        tokens: list of token strings.

    Returns:
        list of int ids: [CLS] + ids of `tokens` + [SEP].
    """
    return (
        [tokenizer.cls_token_id]
        + tokenizer.convert_tokens_to_ids(tokens)
        + [tokenizer.sep_token_id]
    )


# Alternative example pair, kept for reference:
# token = '甚 至 出 现 交 易 几 乎 停 止 的 情 况'.split()
# token = '甚 至 出 现 交 易 几 乎 停 滞 的 情 况'.split()

# One reference sentence and two hypotheses that differ from it near the end.
token_ref = '楼 市 调 控 的 行 政 手 段 宜 减 不 宜 加'.split()
token_hyp = '楼 市 调 控 的 行 政 手 段 意 见 不 一 一'.split()
token_hyp_2 = '楼 市 调 控 的 行 政 手 段 意 见 不 一'.split()

token_id_ref = _with_special_tokens(token_ref)
token_id_hyp = _with_special_tokens(token_hyp)
token_id_hyp_2 = _with_special_tokens(token_hyp_2)

In [None]:
# Cosine similarity reduced over dim 0. The following cell passes 1-D vectors
# (`output.squeeze(0)[0]`), for which the default dim=1 is out of range; dim 0
# matches the `CosineSimilarity(0)` used by the distribution loop further down.
cos = torch.nn.CosineSimilarity(dim=0)

In [None]:
# Wrap each id list as a (1, seq_len) tensor on the target device.
input_id_ref, input_id_hyp, input_id_hyp_2 = (
    torch.tensor(ids).unsqueeze(0).to(device)
    for ids in (token_id_ref, token_id_hyp, token_id_hyp_2)
)

print(input_id_hyp)
print(input_id_hyp_2)

# Inference only: no gradients are needed.
with torch.no_grad():
    # Element [0] of each model result is kept.
    output_ref = model.model(input_ids=input_id_ref)[0]
    output_hyp = model.model(input_ids=input_id_hyp)[0]
    output_hyp_2 = model.model(input_ids=input_id_hyp_2)[0]

    # Compare the first-position vectors of the reference and the second
    # hypothesis (presumably the [CLS] position -- TODO confirm).
    ref_first = output_ref.squeeze(0)[0]
    hyp_2_first = output_hyp_2.squeeze(0)[0]
    sim = cos(ref_first, hyp_2_first)
    print(sim)

In [None]:
# Data splits to analyse.
recog_set = ['dev', 'test']

# Histogram keys: lower bounds of the similarity buckets, plus 'same' for
# utterances whose hypothesis equals the reference. The order matters because
# the pie-chart cell pairs values with labels positionally.
bucket_keys = (1, 0.95, 0.75, 0.5, 0.25, 0, 'same')

distribution_dict = {}
for task in recog_set:
    # A fresh zero-initialised counter dict per split.
    distribution_dict[task] = dict.fromkeys(bucket_keys, 0)

In [None]:
# Cosine similarity between two 1-D vectors (reduction over dim 0).
cos = torch.nn.CosineSimilarity(0)

# Score in evaluation mode; the notebook otherwise only calls model.eval() in a
# later cell, after this loop has already run.
model.eval()

# Special-token ids taken from the tokenizer instead of the literals 101 / 102.
cls_id, sep_id = tokenizer.cls_token_id, tokenizer.sep_token_id

# Lower bounds of the similarity buckets, highest first.
thresholds = (1, 0.95, 0.75, 0.5, 0.25)

for task in recog_set:
    print(task)
    # The file holds Chinese text: read it as UTF-8 explicitly rather than
    # relying on the platform default encoding.
    with open(f'./data/aishell_{task}/rescore/MD_rescore_data.json', encoding='utf-8') as f:
        data = json.load(f)

    counts = distribution_dict[task]

    for utt in data['utts'].values():
        token_ref = utt['output']['text_token'].split()
        token_hyp = utt['output']['rec_token'].split()

        # Identical hypothesis and reference: count it and skip the forward passes.
        if token_ref == token_hyp:
            counts['same'] += 1
            continue

        token_id_ref = [cls_id] + tokenizer.convert_tokens_to_ids(token_ref) + [sep_id]
        token_id_hyp = [cls_id] + tokenizer.convert_tokens_to_ids(token_hyp) + [sep_id]

        input_id_ref = torch.tensor(token_id_ref).unsqueeze(0).to(device)
        input_id_hyp = torch.tensor(token_id_hyp).unsqueeze(0).to(device)
        with torch.no_grad():
            output_ref = model.model(input_ids=input_id_ref)[0]
            output_hyp = model.model(input_ids=input_id_hyp)[0]

        # Keep only the first-position vector of each output
        # (presumably the [CLS] position -- TODO confirm).
        output_ref = output_ref.squeeze(0)[0]
        output_hyp = output_hyp.squeeze(0)[0]

        sim = cos(output_ref, output_hyp)

        # Credit the first bucket whose lower bound sim reaches; anything below
        # 0.25 goes to bucket 0. A NaN similarity matches no bucket and is left
        # uncounted, as before.
        for lower in thresholds:
            if sim >= lower:
                counts[lower] += 1
                break
        else:
            if sim < 0.25:
                counts[0] += 1

In [None]:
# Bucket counts for the dev split.
distribution_dict['dev']

In [None]:
# Bucket counts for the test split.
distribution_dict['test']

In [None]:
# Split whose distribution is drawn.
task = 'test'

# Legend text for each bucket, in the same order as the keys of
# distribution_dict[task] (values and labels are paired positionally).
labels = ["Sim == 1", "0.95 <= Sim < 1", "0.75 <= Sim < 0.95", "0.5 <= Sim < 0.75", "0.25 <= Sim < 0.5", "Sim < 0.25", "Same"]
value = list(distribution_dict[task].values())

fig, axe = plt.subplots()
# Unlabelled wedges; the counts are shown in the legend instead.
patches, texts = axe.pie(value, startangle=90, radius=1.2)

legend_labels = [f'{name} : {count}' for name, count in zip(labels, value)]

axe.legend(patches, legend_labels, loc='best', bbox_to_anchor=(-0.1, 1), fontsize=8)

plt.show()

In [None]:
model.eval()

# NOTE(review): token_id_hyp is whatever the most recently run cell left
# behind. After the distribution loop that is the last mismatching utterance of
# the last split, not the example sentence defined earlier -- confirm which
# input is meant to be visualised.
# The ids are moved to `device`, as in the earlier cells, so the input is on
# the same device the model was constructed with.
input_id = torch.tensor(token_id_hyp).unsqueeze(0).to(device)

with torch.no_grad():
    # Request attention weights from the forward pass and keep only those.
    attention_map = model.model(
        input_ids=input_id,
        output_attentions=True,
    ).attentions


In [None]:
# Batch the reference / hypothesis ids and move them to `device`, as the
# earlier cells do, so the inputs are on the same device the model was
# constructed with.
# NOTE(review): token_id_ref / token_id_hyp hold whatever the most recently run
# cell assigned (the distribution loop overwrites them) -- confirm the intended
# sentence pair.
input_id_ref = torch.tensor(token_id_ref).unsqueeze(0).to(device)
input_id_hyp = torch.tensor(token_id_hyp).unsqueeze(0).to(device)

with torch.no_grad():
    output_ref = model.model(
        input_ids=input_id_ref,
    )[0]

    output_hyp = model.model(
        input_ids=input_id_hyp,
    )[0]

# Drop the batch dimension and keep the first-position vector of each output
# (presumably the [CLS] position -- TODO confirm).
output_ref = output_ref.squeeze(0)[0]
output_hyp = output_hyp.squeeze(0)[0]

In [None]:

# output_ref / output_hyp were already reduced with .squeeze(0)[0] in the
# previous cell. Repeating that indexing here picked out single scalar
# elements, so the vectors are now compared directly.
print(cos(output_ref, output_hyp))

In [None]:
fig, axe = plt.subplots(figsize = (10,10))

# Accumulator with the shape of a single 2-D attention matrix.
total_map = np.zeros(attention_map[0][0][0].shape)

# One entry of attention_map per iteration
# (presumably one per transformer layer -- TODO confirm).
for i in range(len(attention_map)):
    # .cpu() makes the conversion work when the model ran on a GPU; it is a
    # no-op for tensors that are already on the CPU.
    np_map = attention_map[i].cpu().numpy().squeeze(0)
    # Keep only the last slice along the first remaining axis
    # (presumably the last attention head -- TODO confirm).
    np_map = np_map[-1]
    # Scale by 1e8 and truncate to integers, then normalise the matrix to unit
    # norm before adding it to the running sum.
    np_map = np_map * 1e8
    np_map = np_map.astype(int)
    np_map = np_map / np.linalg.norm(np_map)
    total_map += np_map

# Heat map of the summed, normalised attention; colour scale fixed to [0, 1].
c = axe.pcolor(total_map, vmin = 0, vmax = 1)
fig.colorbar(c, ax = axe)
plt.show()