In [None]:
import torch
from utils.Datasets import prepareListwiseDataset

from utils.CollateFunc import NBestSampler, BatchSampler, crossNBestBatch
from torch.utils.data import DataLoader
import json
from tqdm.notebook import tqdm

# PBERT Visualization

In [None]:
from utils.PrepareModel import prepareNBestCrossBert, preparePBert
from bertviz import model_view, head_view
import os
from jiwer import visualize_alignment, process_characters
from utils.LoadConfig import load_config

In [None]:
# Checkpoint file whose 'model' entry is loaded into PBert_model below.
PBERT_checkpoint_path = (
    "/work/jason90255/Rescoring/src/RescoreBert/checkpoint/aishell/NBestCrossBert/noLM/PBERT/50best/"
    "RescoreBert_PBERT_batch256_lr1e-7_Freeze-1_HardLabel_Entropy/"
    "checkpoint_train_best_CER.pt"
)

In [None]:
# Read the PBERT YAML config; the third returned value is not needed here.
args, train_args, _ = load_config(
    "/work/jason90255/Rescoring/src/RescoreBert/config/PBert.yaml"
)

In [None]:
# To pin a specific GPU, set the variable below before CUDA is first used:
# os.environ['CUDA_VISIBLE_DEVICES']='3'
# Fall back to CPU so the notebook still runs on a machine without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the PBERT model and its tokenizer from the loaded configuration,
# then make sure the model lives on `device`.
PBert_model, tokenizer = preparePBert(args, train_args, device)
PBert_model = PBert_model.to(device)

In [None]:
# Inspect the checkpoint's top-level entries. map_location remaps the stored
# tensors onto `device`, so loading does not depend on the GPU index the
# checkpoint was saved from.
checkpoint = torch.load(PBERT_checkpoint_path, map_location=device)
checkpoint.keys()

In [None]:
# Restore the trained weights (stored under the 'model' key) and switch to
# inference mode. The file is read again here so this cell does not depend on
# the `checkpoint` variable, which a later cell rebinds to another model.
PBert_model.load_state_dict(
    torch.load(PBERT_checkpoint_path, map_location=device)['model']
)
PBert_model.eval();  # trailing ';' suppresses the long module repr

In [None]:
# Example hypotheses and reference with the separating spaces removed, so that
# character-level alignment compares characters only. hyp_3 is stripped like
# the others (it was previously left space-separated).
hyp_1 = "".join("但 因 为 聚 集 了 过 多 公 共 思 源".split())
hyp_2 = "".join("但 因 为 聚 集 了 过 多 公 四 元".split())
hyp_3 = "".join("但 因 为 聚 集 了 过 多 公 共 思 员".split())
ref = "".join("但 因 为 聚 集 了 过 多 公 共 资 源".split())

In [None]:
# Space-separated form of the first hypothesis (one character per token).
hyp_1 = " ".join("但因为聚集了过多公共思源")

In [None]:
# Align both hypotheses against the reference at character level.
# hyp_1 may be space-separated at this point while ref is not; whitespace is
# stripped from every string first so that spaces are not aligned as if they
# were characters.
out = process_characters(
    ["".join(ref.split())] * 2,
    ["".join(hyp.split()) for hyp in (hyp_1, hyp_2)],
)

In [None]:
# Turn jiwer's textual alignment into per-sentence lists of error positions.
# The rendering is assumed to use five lines per sentence with the operation
# markers (S / D / I) on the fourth one, hence `i % 5 == 3`.
result = visualize_alignment(out, show_measures=False, skip_correct=False).split('\n')
label_align = []

for i, r in enumerate(result):
    if i % 5 != 3:
        continue

    # Skip the 5-character gutter that lines the markers up with "REF: "/"HYP: ".
    label_sequence = r[5:]
    labels = {
        "insertion": [],
        "deletion": [],
        "substitution": []
    }
    for index, label in enumerate(label_sequence):
        if label == 'S':
            labels['substitution'].append(index)
        elif label == 'D':
            labels['deletion'].append(index)
        elif label == 'I':
            # Insertions are stored as an [index, index + 1] pair.
            labels['insertion'].append([index, index + 1])

    label_align.append(labels)

label_align

In [None]:
# Scratch check of the slice offset used to strip the "HYP: " gutter.
# NOTE(review): only the last expression of a cell is echoed, so the len(...)
# value on the first line is discarded; it also slices at 6 while the parsing
# code slices at 5.
len("HYP: 但因为聚集了过多公四元*"[6:])
"HYP: 但因为聚集了过多公四元*"[5:]

In [None]:
# Encode hyp_1 and run it through the BERT encoder inside PBERT, keeping the
# attention maps. The ids are moved to `device` because the model lives
# there; leaving them on CPU fails with a device-mismatch error on GPU.
hyp_ids = tokenizer.encode(hyp_1, return_tensors='pt').to(device)

# Inference only: no autograd graph is needed.
with torch.no_grad():
    output = PBert_model.bert(
        input_ids = hyp_ids,
        output_attentions = True
    )

In [None]:
# Count the tokens of hyp_1 that receive more than 5% of the first token's
# last-layer attention (averaged over heads), and how many of those tokens
# sit on an aligned substitution / deletion / insertion position.
attention = 0
attend_sub = 0
attend_del = 0
attend_ins = 0

# `output` was computed for hyp_1, the first pair given to process_characters,
# so its labels are label_align[0]. The loop variable `labels` left behind by
# the parsing cell points at the last sentence instead.
hyp_labels = label_align[0]

# [0, 0, 1:-1]: first batch item, first token as query, all keys except the
# first and last token. mean(dim=1) averages over however many heads exist.
attn_weight = output.attentions[-1].mean(dim = 1)[0, 0, 1:-1] > 0.05
for i, weight in enumerate(attn_weight):
    if not weight:
        continue
    attention += 1
    if i in hyp_labels['substitution']:
        attend_sub += 1
    elif i in hyp_labels['deletion']:
        attend_del += 1
    elif [i, i + 1] in hyp_labels['insertion']:  # key was misspelled 'inserion'
        attend_ins += 1

print(f"attention weight over threshold:{attention}")
print(f"attend_sub:{attend_sub}")
print(f"attend_del:{attend_del}")
print(f"attend_ins:{attend_ins}")

In [None]:
# Names of the dataset splits.
task = "train dev test".split()

In [None]:
# Load the dev-set n-best list. The file contains Chinese text, so the
# encoding is given explicitly instead of relying on the platform default.
with open("/work/jason90255/Rescoring/data/aishell/data/noLM/dev/data.json", encoding="utf-8") as f:
    data_json = json.load(f)

In [None]:
# Tokenize the hypotheses of the first entry and show the encoded batch.
hyps = data_json[0]['hyps']
tokenizer.batch_encode_plus(hyps)

# Multiple Batch

In [None]:
# Batched variant (work in progress): align every hypothesis of an utterance
# against its reference, then push the whole n-best list through PBERT in a
# single forward pass. So far only the attention shapes are inspected.
attention = 0
attend_del = 0
attend_ins = 0
attend_sub = 0
total_character_count = 0
for data in tqdm(data_json):
    label_dict = {
        "insertion": [],
        "deletion": [],
        "substitution": []
    }
    process_ref = "".join(data['ref'].strip().split())
    hyps = ["".join(hyp.strip().split()) for hyp in data["hyps"]]
    refs = [process_ref] * len(hyps)
    out = process_characters(refs, hyps)
    result = visualize_alignment(out, show_measures=False, skip_correct=False).split('\n')

    # The operation-marker row is assumed to be the 4th of every 5 rendered lines.
    for i, r in enumerate(result):
        if i % 5 != 3:
            continue
        label_sequence = r[5:]
        for index, label in enumerate(label_sequence):
            if label == 'S':
                label_dict['substitution'].append(index)
            elif label == 'I':
                label_dict['insertion'].append(index)
            elif label == 'D':
                label_dict['deletion'].append([index, index + 1])
    # NOTE(review): label_dict pools the positions of every hypothesis in the
    # n-best list; a per-hypothesis structure is needed before counting.

    # The forward pass belongs to the utterance loop, not the per-line loop:
    # previously it ran (and `break` fired) before any label had been parsed,
    # and `break` only left the inner loop.
    bert_tokens = tokenizer.batch_encode_plus(data['hyps'], return_tensors='pt', padding = True).to(device)
    with torch.no_grad():
        output = PBert_model.bert(
            input_ids = bert_tokens['input_ids'],
            attention_mask = bert_tokens['attention_mask'],
            output_attentions = True
        )

    print(len(output.attentions))

    last_attention = output.attentions[-1]

    print(last_attention.shape)

    # Shape inspection only: stop after the first utterance.
    break

# TODO: port the thresholded counting from the single-utterance cell once the
# per-hypothesis labels and the padding mask are handled for batched input.

# Single batch

In [None]:
# For every hypothesis in the dev set, count how many tokens receive more than
# 5% of the first token's last-layer attention (averaged over heads), and how
# many of those tokens coincide with an aligned error position.
attention = 0
attend_del = 0
attend_ins = 0
attend_sub = 0
total_character_count = 0
for data in tqdm(data_json):
    process_ref = "".join(data['ref'].strip().split())
    for hyp in data['hyps']:
        process_hyp = "".join(hyp.strip().split())

        out = process_characters([process_ref], [process_hyp])
        result = visualize_alignment(out, show_measures=False, skip_correct=False).split('\n')

        label_dict = {
            "insertion": [],
            "deletion": [],
            "substitution": []
        }

        # The operation-marker row is assumed to be the 4th of every 5 rendered lines.
        for i, r in enumerate(result):
            if i % 5 != 3:
                continue
            label_sequence = r[5:]
            for index, label in enumerate(label_sequence):
                if label == 'S':
                    label_dict['substitution'].append(index)
                elif label == 'I':
                    label_dict['insertion'].append(index)
                elif label == 'D':
                    # Deletions are stored as an [index, index + 1] pair.
                    label_dict['deletion'].append([index, index + 1])

        hyp_ids = tokenizer.encode(hyp, return_tensors='pt').to(device)
        # Inference only: skip autograd so memory stays flat over the dataset.
        with torch.no_grad():
            output = PBert_model.bert(
                input_ids = hyp_ids,
                output_attentions = True
            )

        # CLS attention weight > 0.05, excluding attention to CLS and SEP.
        # mean(dim=1) averages over heads without hard-coding their number.
        attn_weight = output.attentions[-1].mean(dim = 1)[0, 0, 1:-1] > 0.05
        for i, weight in enumerate(attn_weight):
            total_character_count += 1
            if not weight:
                continue
            attention += 1
            if i in label_dict['substitution']:
                attend_sub += 1
            elif i in label_dict['insertion']:
                # Insertion hits count as insertions (the counters were swapped).
                attend_ins += 1
            elif [i, i + 1] in label_dict['deletion']:
                attend_del += 1

print(f"total_attention:{total_character_count}")
print(f"attention weight over threshold:{attention}")
print(f"attend_sub:{attend_sub}")
print(f"attend_del:{attend_del}")
print(f"attend_ins:{attend_ins}")
        

In [None]:
# Share of above-threshold tokens that fall on a substitution, on a deletion,
# and on any error. Jupyter only echoes a cell's last expression, so the three
# ratios are bundled into one tuple instead of two of them being discarded.
(
    attend_sub / attention,
    attend_del / attention,
    (attend_sub + attend_del + attend_ins) / attention,
)

# NBest Bert Visualization

In [None]:
# Checkpoint file whose 'model' entry is loaded into NBest_model below.
NBestBert = (
    "/mnt/disk6/Alfred/Rescoring/src/RescoreBert/checkpoint/aishell/NBestCrossBert/noLM/"
    "Normal_lstm_KL_sortByLength_dropout0.3_seed42/50best/"
    "batch256_lr1e-7_freeze-1/checkpoint_train_best_CER.pt"
)

In [None]:
# Load the N-best model checkpoint and list its parameter names.
# The model is placed on CPU below, so the tensors are mapped to CPU here
# rather than onto whichever GPU they were saved from.
checkpoint = torch.load(NBestBert, map_location="cpu")
checkpoint['model'].keys()

In [None]:
# os.environ['CUDA_VISIBLE_DEVICES']='1'
# NOTE: this rebinds the notebook-wide `device` and `tokenizer`; cells that
# still use PBert_model afterwards must move their inputs to that model's
# own device.
device = torch.device("cpu")
NBest_model, tokenizer = prepareNBestCrossBert(
    'aishell',
    device,
    lstm_dim = 1024,
    useNbestCross = False,
    lossType = 'KL',
    concatCLS = False
)
NBest_model.load_state_dict(checkpoint['model'])
# eval() mirrors what is done for PBert_model: with dropout active the
# visualized attention would not be deterministic.
NBest_model = NBest_model.to(device).eval()

In [None]:
# Example utterance: three hypotheses and the reference, one character per token.
hyp_1, hyp_2, hyp_3, ref = (
    " ".join("但因为聚集了过多公共思源"),
    " ".join("但因为聚集了过多公共四元"),
    " ".join("但因为聚集了过多公共思员"),
    " ".join("但因为聚集了过多公共资源"),
)

In [None]:
# Alternative example utterance: three hypotheses and the reference.
hyp_1, hyp_2, hyp_3, ref = (
    " ".join("二人一直先少回应"),
    " ".join("二零一直先少回应"),
    " ".join("二零一直鲜少回应"),
    " ".join("二人一直鲜少回应"),
)

In [None]:
# Alternative example pair; the trailing tags are the original author's labels.
hyp_1 = " ".join("北京申办冬奥影响远超申办本身")  # PBert
ref = " ".join("北京申办冬奥影响远超承办本身")  # NBestBert

In [None]:
# Alternative example pair; the trailing tags are the original author's labels.
hyp_1 = " ".join("在世锦赛决赛减路前突感不适")  # NBestBert
ref = " ".join("在世锦赛决赛检录前突感不适")  # PBert

In [None]:
# Sentences to visualize and the position of the one shown by the cells below.
# NOTE(review): hyp_2 comes from whichever example cell last defined it.
index = 0
hyps = [hyp_2, ref]
print(hyps)

# NBestBert

In [None]:
# Head-level attention view of the selected sentence under PBERT.
# `device` has been rebound for the N-best model, so the inputs are sent to
# wherever PBert_model's parameters actually live; feeding CPU ids to a GPU
# model fails with a device-mismatch error.
pbert_device = next(PBert_model.parameters()).device
hyp_ids = tokenizer.encode(hyps[index], return_tensors='pt')

with torch.no_grad():
    output = PBert_model.bert(
        input_ids = hyp_ids.to(pbert_device),
        output_attentions = True
    )
attention = output.attentions

tokens = tokenizer.convert_ids_to_tokens(hyp_ids[0].tolist())
# Bring the per-layer attention maps back to CPU for rendering.
single_attention = [att.cpu() for att in attention]
print(single_attention[-1].shape)
print(" ".join(tokens))
head_view(attention = single_attention, tokens = tokens) #, html_action='return')

In [None]:
# Head-level attention view of the selected sentence under the N-best model.
hyp_ids = tokenizer.encode(hyps[index], return_tensors='pt')
output = NBest_model.bert(input_ids=hyp_ids, output_attentions=True)
attention = output.attentions

# Token strings label the rows and columns of the view.
tokens = tokenizer.convert_ids_to_tokens(hyp_ids[0].tolist())
single_attention = list(attention)
print(single_attention[-1].shape)
print(" ".join(tokens))
head_view(attention=single_attention, tokens=tokens)

In [None]:
# Head-level attention view of the selected sentence.
# NOTE(review): `model` is not assigned anywhere in the visible notebook, so
# on a fresh kernel this cell raises NameError. It looks like a leftover copy
# of the neighbouring cells — confirm which model was meant, or delete it.
hyp_ids = tokenizer.encode(hyps[index], return_tensors='pt')

output = model.bert(
    input_ids = hyp_ids,
    output_attentions = True
)
attention = output.attentions

# Token strings label the rows and columns of the view.
tokens = tokenizer.convert_ids_to_tokens(hyp_ids[0].tolist())
single_attention = [att for att in attention]
print(single_attention[-1].shape)
print(" ".join(tokens))
head_view(attention = single_attention, tokens = tokens) #, html_action='return')

In [None]:
from transformers import BertModel

# Pretrained 'bert-base-chinese' weights without any rescoring fine-tuning,
# used as a baseline for the attention views.
raw_model = BertModel.from_pretrained(
    'bert-base-chinese'
)

In [None]:
# Head-level attention view of the selected sentence under the raw BERT model.
hyp_ids = tokenizer.encode(hyps[index], return_tensors='pt')
output = raw_model(input_ids=hyp_ids, output_attentions=True)
attention = output.attentions

# Token strings label the rows and columns of the view.
tokens = tokenizer.convert_ids_to_tokens(hyp_ids[0].tolist())
single_attention = list(attention)
print(single_attention[-1].shape)
print(" ".join(tokens))
head_view(attention=single_attention, tokens=tokens)

# MaskEmbedBert

In [None]:
# Checkpoint files for the three [MASK]-based variants; the paths differ only
# in the experiment directory name.
MaskBertcheckpoint, MaskAfterBertcheckpoint, ConcatMaskCheckpoint = (
    "/mnt/disk6/Alfred/Rescoring/src/RescoreBert/checkpoint/aishell/NBestCrossBert/noLM/"
    f"Normal_query_KL_sortByLength_{variant}_dropout0.3_seed42/50best/"
    "batch256_lr1e-7_freeze-1/checkpoint_train_best_CER.pt"
    for variant in ("concatMask", "concatMaskAfter", "concatCLS_concatMaskAfter")
)

In [None]:
# Build the three query-fusion models; each call also rebinds `tokenizer`.
# NOTE(review): MaskBert and MaskAfterBert are constructed with identical
# arguments although they are loaded from different checkpoints — confirm that
# no extra option is needed to tell the two variants apart.
MaskBert, tokenizer = prepareNBestCrossBert(dataset='aishell', device=device, fuseType='query', concatCLS=False)
MaskAfterBert, tokenizer = prepareNBestCrossBert(dataset='aishell', device=device, fuseType='query', concatCLS=False)
MaskConcatBert, tokenizer = prepareNBestCrossBert(dataset='aishell', device=device, fuseType='query', concatCLS=True)

In [None]:
# Restore each variant's weights from the 'model' entry of its checkpoint.
# map_location keeps the load independent of the device the checkpoint was
# saved from.
MaskBert.load_state_dict(torch.load(MaskBertcheckpoint, map_location=device)['model'])
MaskAfterBert.load_state_dict(torch.load(MaskAfterBertcheckpoint, map_location=device)['model'])
MaskConcatBert.load_state_dict(torch.load(ConcatMaskCheckpoint, map_location=device)['model'])

# Inference mode, as for PBert_model: dropout must be off for the attention
# maps to be deterministic.
for masked_model in (MaskBert, MaskAfterBert, MaskConcatBert):
    masked_model.eval()

In [None]:
# Example utterance used for the [MASK] variants: three hypotheses and the reference.
hyp_1, hyp_2, hyp_3, ref = (
    " ".join("但因为聚集了过多公共思源"),
    " ".join("但因为聚集了过多公共四元"),
    " ".join("但因为聚集了过多公共思员"),
    " ".join("但因为聚集了过多公共资源"),
)

In [None]:
# Same sentences with a literal "[MASK]" appended directly to the text.
hyp_1_mask = f"{hyp_1}[MASK]"
hyp_2_mask = f"{hyp_2}[MASK]"
hyp_3_mask = f"{hyp_3}[MASK]"
ref_mask = f"{ref}[MASK]"

In [None]:
# Candidate sentences and the position of the one to visualize (3 selects ref).
index = 3
hyps = [hyp_1, hyp_2, hyp_3, ref]

# MaskAfterBert

In [None]:
# Head-level attention view under MaskAfterBert. The [MASK] id is appended to
# the already-encoded sequence, i.e. after everything tokenizer.encode produced.
mask = torch.tensor(tokenizer.convert_tokens_to_ids(["[MASK]"])).unsqueeze(0)
hyp_ids = torch.cat(
    [tokenizer.encode(hyps[index], return_tensors='pt'), mask],
    dim=-1,
)

output = MaskAfterBert.bert(input_ids=hyp_ids, output_attentions=True)
attention = output.attentions

# Token strings label the rows and columns of the view.
tokens = tokenizer.convert_ids_to_tokens(hyp_ids[0].tolist())
single_attention = list(attention)
print(single_attention[-1].shape)
print(" ".join(tokens))
head_view(attention=single_attention, tokens=tokens)

# MaskBert

In [None]:
# Masked candidate sentences and the position of the one to visualize
# (3 selects ref_mask).
index = 3
hyps_mask = [hyp_1_mask, hyp_2_mask, hyp_3_mask, ref_mask]

In [None]:
# Head-level attention view under MaskBert, with "[MASK]" already part of the
# input text. This section previously called MaskAfterBert (copied from the
# cell above), which left MaskBert loaded but never visualized.
hyp_ids = tokenizer.encode(hyps_mask[index], return_tensors='pt')

output = MaskBert.bert(
    input_ids = hyp_ids,
    output_attentions = True
)
attention = output.attentions

# Token strings label the rows and columns of the view.
tokens = tokenizer.convert_ids_to_tokens(hyp_ids[0].tolist())
single_attention = [att for att in attention]
print(single_attention[-1].shape)
print(" ".join(tokens))
head_view(attention = single_attention, tokens = tokens) #, html_action='return')

# Mask Concat Bert

In [None]:
# Masked candidate sentences and the position of the one to visualize
# (3 selects ref_mask).
index = 3
hyps_mask = [hyp_1_mask, hyp_2_mask, hyp_3_mask, ref_mask]

In [None]:
# Head-level attention view of the selected masked sentence under MaskConcatBert.
hyp_ids = tokenizer.encode(hyps_mask[index], return_tensors='pt')
output = MaskConcatBert.bert(input_ids=hyp_ids, output_attentions=True)
attention = output.attentions

# Token strings label the rows and columns of the view.
tokens = tokenizer.convert_ids_to_tokens(hyp_ids[0].tolist())
single_attention = list(attention)
print(single_attention[-1].shape)
print(" ".join(tokens))
head_view(attention=single_attention, tokens=tokens)

In [None]:
# Result files of the two systems that are compared entry by entry below.
PBertJson = (
    "/mnt/disk6/Alfred/Rescoring/data/result/aishell/noLM/test/"
    "NBestCrossBert_PBERT_result.json"
)
NBestCrossJson = (
    "/mnt/disk6/Alfred/Rescoring/data/result/aishell/noLM/test/"
    "NBestCrossBert_lstm_KL_freeze-1_BestCER_result.json"
)
import json

In [None]:
# Load both result files. They contain Chinese text, so the encoding is given
# explicitly; the files are closed before the comparison starts.
with open(PBertJson, encoding="utf-8") as P, open(NBestCrossJson, encoding="utf-8") as N:
    P_result = json.load(P)
    N_result = json.load(N)

# Print the entries on which exactly one system's 'check_1' is 'Correct':
#   sit 1 - PBERT wrong, N-best model right
#   sit 2 - PBERT right, N-best model wrong
# NOTE(review): entries are paired by position; zip stops at the shorter list.
for p, n in zip(P_result, N_result):
    if p['check_1'] == 'Error' and n['check_1'] == 'Correct':
        print(f"sit 1 :\npBert:{p['rescore_hyps']}\nNBestBert:{n['rescore_hyps']}")
    elif p['check_1'] == 'Correct' and n['check_1'] == 'Error':
        print(f"sit 2 :\npBert:{p['rescore_hyps']}\nNBestBert:{n['rescore_hyps']}")