In [None]:
import json
import os
import random
from random import sample

import torch
from bertviz import model_view, head_view
from jiwer import visualize_alignment, process_characters
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm

from utils.CollateFunc import NBestSampler, BatchSampler, crossNBestBatch
from utils.Datasets import prepareListwiseDataset
from utils.LoadConfig import load_config
from utils.PrepareModel import prepareNBestCrossBert, preparePBert

In [None]:
# Prefer the GPU, but fall back to the CPU so the notebook still runs on a
# machine without CUDA instead of failing on the first tensor transfer.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE(review): absolute, machine-specific path — consider a configurable base directory.
PBERT_CONFIG_PATH = "/mnt/disk6/Alfred/Rescoring/src/RescoreBert/config/PBert.yaml"

# load_config returns three values; only the first two are used in this notebook.
args, train_args, _ = load_config(PBERT_CONFIG_PATH)

# Build the PBert model and its tokenizer from the loaded configuration.
PBert_model, tokenizer = preparePBert(args, train_args, device)

In [None]:
# Load the dev-set N-best data. JSON is UTF-8 by specification, so the encoding
# is pinned rather than left to the platform default (the transcripts are non-ASCII).
# NOTE(review): absolute, machine-specific path — consider a configurable base directory.
with open("/mnt/disk6/Alfred/Rescoring/data/aishell/data/noLM/dev/data.json", encoding="utf-8") as f:
    data_json = json.load(f)

In [None]:
# Draw the evaluation subset from a dedicated, seeded generator so that the
# reported numbers are reproducible after a kernel restart. The size is capped
# at the dataset length so a smaller dataset does not raise ValueError.
SAMPLE_SIZE = 2800
SAMPLE_SEED = 0
sample_json = random.Random(SAMPLE_SEED).sample(data_json, min(SAMPLE_SIZE, len(data_json)))

In [None]:
# Tokenize the hypotheses of the first sampled utterance as one padded batch
# of PyTorch tensors (quick sanity check of the tokenizer output).
first_utterance_hyps = sample_json[0]['hyps']
output = tokenizer.batch_encode_plus(
    first_utterance_hyps,
    padding=True,
    return_tensors='pt',
)

In [None]:
# NOTE(review): absolute, machine-specific path — consider a configurable base directory.
PBERT_checkpoint_path = "/mnt/disk6/Alfred/Rescoring/src/RescoreBert/checkpoint/aishell/fromTWCC/PBERT_TWCC/checkpoint_train_best_CER.pt"

# Deserialize onto the CPU: without map_location the tensors are restored to the
# device they were saved from, which fails when that device does not exist here
# and otherwise keeps a second copy of the weights on the GPU.
# load_state_dict copies the values onto the model's own device afterwards.
# NOTE: torch.load unpickles the file — only load checkpoints from a trusted source.
checkpoint = torch.load(PBERT_checkpoint_path, map_location="cpu")

In [None]:
# Restore the fine-tuned weights from the checkpoint that the previous cell
# already loaded into `checkpoint`, instead of reading the file from disk again.
PBert_model.load_state_dict(checkpoint['model'])

# Put the model in evaluation mode for the attention analysis below.
# As the last expression of the cell, this also displays the model repr.
PBert_model.eval()

In [None]:
# SEP-attention probe over the sampled utterances.
# For every hypothesis, take the last-layer attention row of the first token
# (presumably [CLS] — confirm against the tokenizer), average it over heads, and
# read the weight placed on the last unpadded token (named "sep" in this
# notebook). The weights are averaged per N-best list and then over utterances.
#
# The alignment counters are still computed because they are notebook-level
# state, even though this cell only prints the average weight.
attention = 0
attend_del = 0
attend_ins = 0
attend_sub = 0
total_character_count = 0

total_sub = 0
total_ins = 0
total_del = 0

attend_to_sep = 0
correctCount = 0
correctAttendToSep = 0
total_weight = 0.0  # plain float, so no GPU tensors are kept alive across iterations
for data in tqdm(sample_json):
    # Character-level alignment of every hypothesis against the whitespace-free reference.
    label_dict = []
    process_ref = "".join(data['ref'].strip().split())
    hyps = ["".join(hyp.strip().split()) for hyp in data["hyps"]]
    refs = [process_ref] * len(hyps)
    out = process_characters(refs, hyps)
    result = visualize_alignment(out, show_measures=False, skip_correct=False).split('\n')
    for i, r in enumerate(result):
        # The parsing relies on five rendered lines per sentence, with the
        # S/I/D marker row at offset 3.
        if i % 5 != 3:
            continue
        label_sequence = r[5:]  # skip the 5-character prefix in front of the markers
        temp_dict = {"insertion": [], "deletion": [], "substitution": []}
        correctFlag = True
        for index, label in enumerate(label_sequence):
            if label == 'S':
                total_sub += 1
                temp_dict['substitution'].append(index)
                correctFlag = False
            elif label == 'I':
                total_ins += 1
                temp_dict['insertion'].append(index)
                correctFlag = False
            elif label == 'D':
                total_del += 1
                # A deletion is recorded as the pair of neighbouring positions.
                temp_dict['deletion'].append([index, index + 1])
                correctFlag = False

        if correctFlag:
            correctCount += 1
        temp_dict['correct'] = correctFlag
        label_dict.append(temp_dict)

    bert_tokens = tokenizer.batch_encode_plus(data['hyps'], return_tensors='pt', padding=True).to(device)
    with torch.no_grad():
        output = PBert_model.bert(
            input_ids=bert_tokens['input_ids'],
            attention_mask=bert_tokens['attention_mask'],
            output_attentions=True,
        )

    last_attention = output.attentions[-1]
    nBest = last_attention.shape[0]

    # Averaging over the head dimension replaces the former "sum / 12", so the
    # cell no longer assumes a 12-head model.
    head_mean_attention = last_attention.mean(dim=1)

    utterance_weight = 0.0
    for attention_mask, attention_map, align_label in zip(bert_tokens['attention_mask'], head_mean_attention, label_dict):
        # Row 0 = attention from the first token; keep only unpadded positions
        # and drop the first token itself.
        token_index = attention_map[0][attention_mask.bool()][1:]
        sep_weight = token_index[-1].item()
        utterance_weight += sep_weight
    total_weight += utterance_weight / nBest

# Printed as a Python float (the original printed a tensor repr).
print(total_weight / len(sample_json))

# print(f"total_character_count:{total_character_count}")
# print(f"total attention:{attention}")
# print(f"total_sub:{total_sub}")
# print(f"attend_sub:{attend_sub}\n")
# print(f"total_del:{total_del}")
# print(f"attend_del:{attend_del}\n")
# print(f"total_ins:{total_ins}")
# print(f"attend_ins:{attend_ins}\n")

# print(attend_to_sep)
# print(correctCount)
# print(correctAttendToSep)

# Max — classify the hypothesis token that receives the largest attention weight from the first token

In [None]:
# Max-attention analysis over the full dev set.
# For every hypothesis, find the token (special tokens excluded) that receives
# the largest last-layer attention from the first token, and record whether that
# position is a substitution, an insertion, next to a deletion, or none of them.
# Results are broken down by which error types the hypothesis contains.
attention = 0
attend_del = 0
attend_ins = 0
attend_sub = 0
total_character_count = 0

# Hypotheses containing exactly one error type.
sub_hyp = 0
del_hyp = 0
ins_hyp = 0

# Per-character error totals.
total_sub = 0
total_ins = 0
total_del = 0
# Hypotheses containing a combination of error types.
total_si = 0 # substitution & insertion
total_sd = 0 # substitution & deletion
total_di = 0 # deletion & insertion
total_sdi = 0 # three of them


attend_to_sep = 0
correctCount = 0
correctAttendToSep = 0

# error_dict maps a category name to its index: rows of error_matrix use all
# seven categories (hypothesis type), columns use only S / I / D (attended type).
error_dict = {'S' : 0, 'I': 1, 'D' : 2, 'SDI': 3, 'SD': 4, 'SI': 5, 'DI': 6}
not_attend_dict = {'S' : 0, 'I': 0, 'D' : 0, 'SDI': 0, 'SD': 0, 'SI': 0, 'DI': 0, 'correct': 0}
error_matrix = torch.zeros((7,3), dtype = torch.int32)

for data in tqdm(data_json):
    # Character-level alignment of every hypothesis against the whitespace-free reference.
    label_dict = []
    process_ref = "".join(data['ref'].strip().split())
    hyps = ["".join(hyp.strip().split()) for hyp in data["hyps"]]
    refs = [process_ref] * len(hyps)
    out = process_characters(refs, hyps)
    result = visualize_alignment(out, show_measures=False, skip_correct=False).split('\n')
    for i, r in enumerate(result):
        # The parsing relies on five rendered lines per sentence, with the
        # S/I/D marker row at offset 3.
        if i % 5 != 3:
            continue
        label_sequence = r[5:]  # skip the 5-character prefix in front of the markers
        temp_dict = {"insertion": [], "deletion": [], "substitution": []}
        for index, label in enumerate(label_sequence):
            if label == 'S':
                total_sub += 1
                temp_dict['substitution'].append(index)
            elif label == 'I':
                total_ins += 1
                temp_dict['insertion'].append(index)
            elif label == 'D':
                total_del += 1
                # A deletion is recorded as the pair of neighbouring positions.
                temp_dict['deletion'].append([index, index + 1])

        # Build the category key in S-D-I order so it matches the error_dict keys.
        correctFlag = (
            ('S' if temp_dict['substitution'] else '')
            + ('D' if temp_dict['deletion'] else '')
            + ('I' if temp_dict['insertion'] else '')
        )
        if correctFlag == 'SDI':
            total_sdi += 1
        elif correctFlag == 'SD':
            total_sd += 1
        elif correctFlag == 'SI':
            total_si += 1
        elif correctFlag == 'S':
            sub_hyp += 1
        elif correctFlag == 'DI':
            total_di += 1
        elif correctFlag == 'D':
            del_hyp += 1
        elif correctFlag == 'I':
            ins_hyp += 1
        else:  # no error at all
            correctCount += 1
            correctFlag = 'correct'

        temp_dict['correct'] = correctFlag
        label_dict.append(temp_dict)

    bert_tokens = tokenizer.batch_encode_plus(data['hyps'], return_tensors='pt', padding = True).to(device)
    # Inference only: without no_grad every forward pass also records an
    # autograd graph, which wastes memory over the whole dataset.
    with torch.no_grad():
        output = PBert_model.bert(
            input_ids = bert_tokens['input_ids'],
            attention_mask = bert_tokens['attention_mask'],
            output_attentions = True
        )

    last_attention = output.attentions[-1]
    # Head average (replaces the former "sum / 12", which assumed 12 heads).
    head_mean_attention = last_attention.mean(dim = 1)

    for attention_mask, attention_map, align_label in zip(bert_tokens['attention_mask'], head_mean_attention, label_dict):
        # Row 0 = attention from the first token; keep unpadded positions and
        # drop the first and last (special) tokens.
        token_index = attention_map[0][attention_mask.bool()][1:-1]

        # Convert to a Python int once, so the membership tests below compare
        # plain ints instead of triggering a tensor comparison per list element.
        target = int(torch.argmax(token_index))

        total_character_count += token_index.shape[0]
        attention += 1
        hyp_kind = align_label['correct']
        if (target in align_label['substitution']):
            attend_sub += 1
            error_matrix[error_dict[hyp_kind], error_dict['S']] += 1
        elif (target in align_label['insertion']):
            attend_ins += 1
            error_matrix[error_dict[hyp_kind], error_dict['I']] += 1
        elif ([target - 1, target] in align_label['deletion'] or [target, target + 1] in align_label['deletion']):
            attend_del += 1
            error_matrix[error_dict[hyp_kind], error_dict['D']] += 1
        else: # the most-attended token is not an error position
            not_attend_dict[hyp_kind] += 1



print(f"total_character_count:{total_character_count}")
print(f"total attention:{attention}")
print(f"total_sub:{total_sub}")
print(f"attend_sub:{attend_sub}\n")
print(f"total_del:{total_del}")
print(f"attend_del:{attend_del}\n")
print(f"total_ins:{total_ins}")
print(f"attend_ins:{attend_ins}\n")

# attend_to_sep and correctAttendToSep are never updated in this cell (the SEP
# token is sliced away above), so they print 0.
print(attend_to_sep)
print(correctCount)
print(correctAttendToSep)

print(error_matrix)
print(f'not Attend:\n {not_attend_dict}')

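# Over threshold — classify every hypothesis token whose attention weight from the first token is at or above the threshold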

In [None]:
# Threshold analysis over the full dev set.
# Every hypothesis token (special tokens excluded) whose last-layer attention
# from the first token is >= 1 / length counts as "attended". Attended and
# non-attended positions are tallied by the alignment error type at that position.
#
# Only the counters this cell actually uses are reset here. The original also
# zeroed attend_sub / attend_del / attend_ins, sub_hyp / del_hyp / ins_hyp and
# total_si / total_sd / total_di / total_sdi without using them, which wiped the
# results of the max-attention cell and made the summary cell divide by zero.
attention = 0
total_character_count = 0

total_sub = 0
total_ins = 0
total_del = 0

correctCount = 0

attend_dict = {'S': 0, 'I': 0, 'D': 0, 'N': 0 }  # attended positions; 'N' = not an error position
missed_dict = {'S': 0, 'I': 0, 'D': 0 } # error positions whose weight is below the threshold (not attended)

for data in tqdm(data_json):
    # Character-level alignment of every hypothesis against the whitespace-free reference.
    label_dict = []
    process_ref = "".join(data['ref'].strip().split())
    hyps = ["".join(hyp.strip().split()) for hyp in data["hyps"]]
    refs = [process_ref] * len(hyps)
    out = process_characters(refs, hyps)
    result = visualize_alignment(out, show_measures=False, skip_correct=False).split('\n')
    for i, r in enumerate(result):
        # The parsing relies on five rendered lines per sentence, with the
        # S/I/D marker row at offset 3.
        if i % 5 != 3:
            continue
        label_sequence = r[5:]  # skip the 5-character prefix in front of the markers
        temp_dict = {"insertion": [], "deletion": [], "substitution": []}
        correctFlag = True
        for index, label in enumerate(label_sequence):
            if label == 'S':
                total_sub += 1
                temp_dict['substitution'].append(index)
                correctFlag = False
            elif label == 'I':
                total_ins += 1
                temp_dict['insertion'].append(index)
                correctFlag = False
            elif label == 'D':
                total_del += 1
                # A deletion is recorded as the pair of neighbouring positions.
                temp_dict['deletion'].append([index, index + 1])
                correctFlag = False

        if correctFlag:
            correctCount += 1

        temp_dict['correct'] = correctFlag
        label_dict.append(temp_dict)

    bert_tokens = tokenizer.batch_encode_plus(data['hyps'], return_tensors='pt', padding = True).to(device)
    # Inference only: without no_grad every forward pass also records an
    # autograd graph, which wastes memory over the whole dataset.
    with torch.no_grad():
        output = PBert_model.bert(
            input_ids = bert_tokens['input_ids'],
            attention_mask = bert_tokens['attention_mask'],
            output_attentions = True
        )

    last_attention = output.attentions[-1]
    # Head average (replaces the former "sum / 12", which assumed 12 heads).
    head_mean_attention = last_attention.mean(dim = 1)

    for attention_mask, attention_map, align_label in zip(bert_tokens['attention_mask'], head_mean_attention, label_dict):
        # Row 0 = attention from the first token; keep unpadded positions and
        # drop the first and last (special) tokens.
        token_index = attention_map[0][attention_mask.bool()][1:-1]
        # NOTE(review): this is the *padded* sequence length of the batch, not
        # the number of real tokens of this hypothesis. If a uniform-attention
        # baseline is intended, attention_mask.sum() may be the right
        # denominator — confirm before changing.
        length = attention_mask.shape[0]
        threshold = 1 / length

        # One vectorized comparison, then plain Python booleans in the loop.
        over_threshold = (token_index >= threshold).tolist()

        total_character_count += len(over_threshold)
        for target, is_attended in enumerate(over_threshold):
            # Error type at this position, using the same precedence as the
            # original if/elif chain: substitution, insertion, deletion-adjacent.
            if (target in align_label['substitution']):
                position_kind = 'S'
            elif (target in align_label['insertion']):
                position_kind = 'I'
            elif ([target - 1, target] in align_label['deletion'] or [target, target + 1] in align_label['deletion']):
                position_kind = 'D'
            else:
                position_kind = None

            if is_attended:
                attention += 1
                attend_dict[position_kind if position_kind is not None else 'N'] += 1
            elif position_kind is not None:
                missed_dict[position_kind] += 1

print(f"total_character_count:{total_character_count}")
print(f"total attention over threshold:{attention}")
print(f"total_sub:{total_sub}")
print(f"total_ins:{total_ins}")
print(f"total_del:{total_del}")

print(f"attend dict:\n{attend_dict}")

print(f"missed dict:\n{missed_dict}")

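Index legend from `error_dict` (rows of `error_matrix` use all seven categories; columns use only S, I, D): 'S' : 0, 'I': 1, 'D' : 2, 'SDI': 3, 'SD': 4, 'SI': 5, 'DI': 6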

In [None]:
# Summary of the max-attention analysis: for each error type, the number of
# hypotheses that contain it (alone or combined with other types) and the ratio
# of max-attention hits on that type to that number.
# Note that the "total_*" labels printed here are hypothesis counts, not the
# per-character totals held in total_sub / total_del / total_ins.
def _attend_rate(attended, hyp_count):
    """Return attended / hyp_count, or NaN when there is no such hypothesis."""
    # The counters are zero when the analysis cell has not run (or was reset);
    # report NaN instead of raising ZeroDivisionError.
    return attended / hyp_count if hyp_count else float('nan')


hyps_with_sub = sub_hyp + total_si + total_sd + total_sdi
hyps_with_del = del_hyp + total_di + total_sd + total_sdi
hyps_with_ins = ins_hyp + total_di + total_si + total_sdi

print(f'total_sub:{hyps_with_sub}')
print(f'total_sub_ins:{total_si}')
print(f'total_sub_del:{total_sd}')
print(f'total_sub_del_ins:{total_sdi}')

print(_attend_rate(attend_sub, hyps_with_sub))

print(f'total_del:{hyps_with_del}')
print(_attend_rate(attend_del, hyps_with_del))

print(f'total_ins:{hyps_with_ins}')
print(_attend_rate(attend_ins, hyps_with_ins))

In [None]:

# Dump the current values of the shared counters, one "name:value" line each.
for _label, _value in (
    ("total_character_count", total_character_count),
    ("attention", attention),
    ("attend_sub", attend_sub),
    ("attend_del", attend_del),
    ("attend_ins", attend_ins),
    ("total_sub", total_sub),
    ("total_ins", total_ins),
    ("total_del", total_del),
):
    print(f"{_label}:{_value}")

In [None]:
# Max-attention analysis that keeps the final token (named "sep" in this
# notebook) as a candidate. For every hypothesis, find the position receiving
# the largest last-layer attention from the first token and record whether it
# is the final token, a substitution, an insertion, or next to a deletion.
attention = 0
attend_del = 0
attend_ins = 0
attend_sub = 0
total_character_count = 0

total_sub = 0
total_ins = 0
total_del = 0

attend_to_sep = 0
correctCount = 0
correctAttendToSep = 0

for data in tqdm(data_json):
    # Character-level alignment of every hypothesis against the whitespace-free reference.
    label_dict = []
    process_ref = "".join(data['ref'].strip().split())
    hyps = ["".join(hyp.strip().split()) for hyp in data["hyps"]]
    refs = [process_ref] * len(hyps)
    out = process_characters(refs, hyps)
    result = visualize_alignment(out, show_measures=False, skip_correct=False).split('\n')
    for i, r in enumerate(result):
        # The parsing relies on five rendered lines per sentence, with the
        # S/I/D marker row at offset 3.
        if i % 5 != 3:
            continue
        label_sequence = r[5:]  # skip the 5-character prefix in front of the markers
        temp_dict = {"insertion": [], "deletion": [], "substitution": []}
        correctFlag = True
        for index, label in enumerate(label_sequence):
            if label == 'S':
                total_sub += 1
                temp_dict['substitution'].append(index)
                correctFlag = False
            elif label == 'I':
                total_ins += 1
                temp_dict['insertion'].append(index)
                correctFlag = False
            elif label == 'D':
                total_del += 1
                # A deletion is recorded as the pair of neighbouring positions.
                temp_dict['deletion'].append([index, index + 1])
                correctFlag = False

        if correctFlag:
            correctCount += 1
        temp_dict['correct'] = correctFlag
        label_dict.append(temp_dict)

    bert_tokens = tokenizer.batch_encode_plus(data['hyps'], return_tensors='pt', padding = True).to(device)
    # Inference only: without no_grad every forward pass also records an
    # autograd graph, which wastes memory over the whole dataset.
    with torch.no_grad():
        output = PBert_model.bert(
            input_ids = bert_tokens['input_ids'],
            attention_mask = bert_tokens['attention_mask'],
            output_attentions = True
        )

    last_attention = output.attentions[-1]
    # Head average (replaces the former "sum / 12", which assumed 12 heads).
    head_mean_attention = last_attention.mean(dim = 1)

    for attention_mask, attention_map, align_label in zip(bert_tokens['attention_mask'], head_mean_attention, label_dict):
        # Row 0 = attention from the first token; keep unpadded positions and
        # drop only the first token, so the final token stays a candidate.
        token_index = attention_map[0][attention_mask.bool()][1:]

        # Convert to a Python int once, so the membership tests below compare
        # plain ints instead of triggering a tensor comparison per list element.
        target = int(torch.argmax(token_index))

        total_character_count += token_index.shape[0]
        attention += 1
        if (target == token_index.shape[0] - 1):
            # The strongest attention lands on the final token.
            attend_to_sep += 1
            if (align_label['correct']):
                correctAttendToSep += 1
        elif (target in align_label['substitution']):
            attend_sub += 1
        elif (target in align_label['insertion']):
            # Fixed: insertion hits were previously added to attend_del.
            attend_ins += 1
        elif ([target - 1, target] in align_label['deletion'] or [target, target + 1] in align_label['deletion']):
            # Fixed: deletion hits were previously added to attend_ins.
            attend_del += 1

print(f"total_character_count:{total_character_count}")
print(f"total attention:{attention}")
print(f"total_sub:{total_sub}")
print(f"attend_sub:{attend_sub}\n")
print(f"total_del:{total_del}")
print(f"attend_del:{attend_del}\n")
print(f"total_ins:{total_ins}")
print(f"attend_ins:{attend_ins}\n")

In [None]:
# Show the SEP-related counters, one value per line: hypotheses whose strongest
# attention is on the final token, error-free hypotheses, and their overlap.
print(attend_to_sep, correctCount, correctAttendToSep, sep="\n")