In [None]:
import json
import os
import random
from random import sample

import torch
from bertviz import model_view, head_view
from jiwer import visualize_alignment, process_characters
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm

from utils.CollateFunc import NBestSampler, BatchSampler, crossNBestBatch
from utils.Datasets import prepareListwiseDataset
from utils.LoadConfig import load_config
from utils.PrepareModel import prepareNBestCrossBert, preparePBert

In [None]:
# Everything in this notebook runs on CPU. Load the PBert experiment config
# and build the model + tokenizer from it (the third config value is unused).
PBERT_CONFIG_PATH = "/work/jason90255/Rescoring/src/RescoreBert/config/PBert.yaml"

device = torch.device("cpu")
args, train_args, _ = load_config(PBERT_CONFIG_PATH)
PBert_model, tokenizer = preparePBert(args, train_args, device)

In [None]:
# Dev-set N-best data. Downstream cells read each entry's 'name', 'ref' and 'hyps'.
# The text is Chinese, so decode explicitly as UTF-8 instead of relying on the
# platform's default locale encoding.
DEV_DATA_PATH = "/work/jason90255/Rescoring/data/aishell/data/noLM/dev/data.json"
with open(DEV_DATA_PATH, encoding="utf-8") as f:
    data_json = json.load(f)

In [None]:
# Random subset used by the argmax-attention probe below. A dedicated seeded RNG
# makes the subset (and every statistic derived from it) reproducible, and the
# size is clamped so a smaller dataset does not raise ValueError.
N_SAMPLES = 2800
SAMPLE_SEED = 42
sample_json = random.Random(SAMPLE_SEED).sample(data_json, min(N_SAMPLES, len(data_json)))

In [None]:
# Sanity check: tokenize the N-best list of the first sampled utterance as a padded PyTorch batch.
first_utt_hyps = sample_json[0]['hyps']
output = tokenizer.batch_encode_plus(first_utt_hyps, padding=True, return_tensors='pt')

In [None]:
# Best-CER PBERT checkpoint. map_location remaps tensors onto the CPU `device`
# used by this notebook, in case the checkpoint was saved from a GPU run.
PBERT_checkpoint_path = "/work/jason90255/Rescoring/src/RescoreBert/checkpoint/aishell/NBestCrossBert/noLM/PBERT/50best/RescoreBert_PBERT_batch256_lr1e-7_Freeze-1_HardLabel_Entropy/checkpoint_train_best_CER.pt"
checkpoint = torch.load(PBERT_checkpoint_path, map_location=device)

In [None]:
# Restore the trained weights from the `checkpoint` dict already loaded in the
# previous cell (no second read of the file), then switch to eval mode.
# The trailing ';' keeps Jupyter from printing the full model repr.
PBert_model.load_state_dict(checkpoint['model'])
PBert_model.eval();

In [None]:
# Sampled-subset attention probe.
# For every hypothesis: take the head-averaged last-layer attention of the [CLS]
# query, pick the single most-attended token, and tally whether it lands on an
# ASR error or on the trailing [SEP] token. Errors come from the jiwer character
# alignment against the reference.
attention = 0              # attended tokens examined (exactly one per hypothesis here)
attend_del = 0             # attended token borders a deleted reference character
attend_ins = 0             # attended token is an inserted character
attend_sub = 0             # attended token is a substituted character
total_character_count = 0  # tokens scored (padding and [CLS] removed, [SEP] kept)

total_sub = 0
total_ins = 0
total_del = 0

attend_to_sep = 0          # attended token is the trailing [SEP]
correct = 0                # hypotheses with no alignment errors
wrong = 0                  # hypotheses with at least one alignment error
attend_sep_correct = 0
attend_sep_wrong = 0

with torch.no_grad():  # inference only: do not build an autograd graph
    for data in tqdm(sample_json):
        ref = "".join(data['ref'].strip().split())
        hyps = ["".join(hyp.strip().split()) for hyp in data["hyps"]]
        out = process_characters([ref] * len(hyps), hyps)
        result = visualize_alignment(out, show_measures=False, skip_correct=False).split('\n')

        label_dict = []
        for i, r in enumerate(result):
            # visualize_alignment emits 5 lines per sentence: "sentence N",
            # "REF: ...", "HYP: ...", the S/I/D label line, and a blank line.
            # r[5:] strips the label line's 5-column prefix (same width as "REF: ").
            if i % 5 != 3:
                continue
            temp_dict = {
                "insertion": [],
                "deletion": [],
                "substitution": []
            }
            # Alignment columns are not hypothesis positions: a deletion column has
            # no hypothesis character. Track the hypothesis index separately so the
            # stored positions line up with the BERT token indices used below.
            # NOTE(review): assumes one BERT token per hypothesis character (true
            # for CJK characters) — TODO confirm for Latin/digit spans.
            hyp_pos = 0
            for label in r[5:]:
                if label == 'D':
                    total_del += 1
                    # Record the two hypothesis positions surrounding the gap.
                    temp_dict['deletion'].append([hyp_pos - 1, hyp_pos])
                    continue  # a deletion consumes no hypothesis character
                if label == 'S':
                    total_sub += 1
                    temp_dict['substitution'].append(hyp_pos)
                elif label == 'I':
                    total_ins += 1
                    temp_dict['insertion'].append(hyp_pos)
                hyp_pos += 1
            temp_dict['correct'] = not (
                temp_dict['insertion'] or temp_dict['deletion'] or temp_dict['substitution']
            )
            label_dict.append(temp_dict)

        bert_tokens = tokenizer.batch_encode_plus(data['hyps'], return_tensors='pt', padding=True).to(device)
        output = PBert_model.bert(
            input_ids=bert_tokens['input_ids'],
            attention_mask=bert_tokens['attention_mask'],
            output_attentions=True
        )

        # Assumes the HF layout (batch, heads, seq, seq), consistent with reducing
        # over dim 1. Averaging replaces the original sum divided by a hard-coded 12 heads.
        last_attention = output.attentions[-1].mean(dim=1)

        for attention_mask, attention_map, align_label in zip(bert_tokens['attention_mask'], last_attention, label_dict):
            if align_label['correct']:
                correct += 1
            else:
                wrong += 1

            # Row 0 = the [CLS] query. Keep only non-padding keys and drop the [CLS]
            # key, so index k is hypothesis character k and the last index is [SEP].
            token_index = attention_map[0][attention_mask.bool()][1:]
            total_character_count += token_index.shape[0]

            # argmax yields a 0-d tensor (not iterable); wrap it as a one-element list of ints.
            attend_indices = [int(torch.argmax(token_index))]
            for target in attend_indices:
                attention += 1
                if target in align_label['substitution']:
                    attend_sub += 1
                elif target in align_label['insertion']:
                    attend_ins += 1
                elif [target - 1, target] in align_label['deletion'] or [target, target + 1] in align_label['deletion']:
                    attend_del += 1
                elif target == token_index.shape[0] - 1:
                    attend_to_sep += 1
                    if align_label['correct']:
                        attend_sep_correct += 1
                    else:
                        attend_sep_wrong += 1

print(f"total_character_count:{total_character_count}")
print(f"attention:{attention}")
print(f"attend_sub:{attend_sub}")
print(f"attend_del:{attend_del}")
print(f"attend_ins:{attend_ins}")
print(f"attend_to_sep:{attend_to_sep} (correct hyps:{attend_sep_correct}, wrong hyps:{attend_sep_wrong})")
print(f"correct hyps:{correct}, wrong hyps:{wrong}")

In [None]:

# Summary of the argmax-attention probe: per error type, the number of errors
# in the alignments next to the number of attended tokens that hit one.
summary_lines = [
    f"total_character_count:{total_character_count}",
    f"attention:{attention}",
    f"total_sub:{total_sub}",
    f"attend_sub:{attend_sub}\n",
    f"total_del:{total_del}",
    f"attend_del:{attend_del}\n",
    f"total_ins:{total_ins}",
    f"attend_ins:{attend_ins}\n",
]
for summary_line in summary_lines:
    print(summary_line)

error_hits = attend_ins + attend_del + attend_sub
print(f'ratio:{error_hits / attention}')

In [None]:
# Full dev-set attention probe.
# For every hypothesis: flag every token whose head-averaged last-layer [CLS]
# attention exceeds ATTENTION_THRESHOLD, and tally how many flagged tokens fall on
# ASR errors from the jiwer character alignment. Per-hypothesis attention vectors
# are saved to ANALYSIS_PATH for offline inspection.
ATTENTION_THRESHOLD = 0.15
ANALYSIS_PATH = "./data/aishell/noLM/dev/analysis.pt"

attention = 0              # flagged (above-threshold) tokens examined
attend_del = 0             # flagged token borders a deleted reference character
attend_ins = 0             # flagged token is an inserted character
attend_sub = 0             # flagged token is a substituted character
total_character_count = 0  # tokens scored (padding and [CLS] removed, [SEP] kept)

total_sub = 0
total_ins = 0
total_del = 0

# One dict per utterance: name, raw hyps, and a list with one CPU attention vector per hypothesis.
attention_maps = list()
with torch.no_grad():  # inference only: do not build an autograd graph
    for data in tqdm(data_json):
        ref = "".join(data['ref'].strip().split())
        hyps = ["".join(hyp.strip().split()) for hyp in data["hyps"]]
        out = process_characters([ref] * len(hyps), hyps)
        result = visualize_alignment(out, show_measures=False, skip_correct=False).split('\n')

        label_dict = []
        for i, r in enumerate(result):
            # visualize_alignment emits 5 lines per sentence: "sentence N",
            # "REF: ...", "HYP: ...", the S/I/D label line, and a blank line.
            # r[5:] strips the label line's 5-column prefix (same width as "REF: ").
            if i % 5 != 3:
                continue
            temp_dict = {
                "insertion": [],
                "deletion": [],
                "substitution": []
            }
            # Map alignment columns to hypothesis positions (deletion columns have
            # no hypothesis character) so indices match the BERT tokens below.
            # NOTE(review): assumes one BERT token per hypothesis character (true
            # for CJK characters) — TODO confirm for Latin/digit spans.
            hyp_pos = 0
            for label in r[5:]:
                if label == 'D':
                    total_del += 1
                    # Record the two hypothesis positions surrounding the gap.
                    temp_dict['deletion'].append([hyp_pos - 1, hyp_pos])
                    continue  # a deletion consumes no hypothesis character
                if label == 'S':
                    total_sub += 1
                    temp_dict['substitution'].append(hyp_pos)
                elif label == 'I':
                    total_ins += 1
                    temp_dict['insertion'].append(hyp_pos)
                hyp_pos += 1
            label_dict.append(temp_dict)

        bert_tokens = tokenizer.batch_encode_plus(data['hyps'], return_tensors='pt', padding=True).to(device)
        output = PBert_model.bert(
            input_ids=bert_tokens['input_ids'],
            attention_mask=bert_tokens['attention_mask'],
            output_attentions=True
        )

        # Assumes the HF layout (batch, heads, seq, seq), consistent with reducing
        # over dim 1. Average over heads instead of dividing a sum by a hard-coded 12.
        last_attention = output.attentions[-1].mean(dim=1)

        hyp_attention_maps = []
        for attention_mask, attention_map, align_label in zip(bert_tokens['attention_mask'], last_attention, label_dict):
            # Row 0 = the [CLS] query. Keep non-padding keys and drop the [CLS] key,
            # so index k is hypothesis character k (matching the alignment indices).
            token_index = attention_map[0][attention_mask.bool()][1:]
            attend_index = (token_index > ATTENTION_THRESHOLD).nonzero(as_tuple=True)[0].tolist()
            total_character_count += token_index.shape[0]
            hyp_attention_maps.append(token_index.clone().detach().cpu())

            for target in attend_index:
                attention += 1
                if target in align_label['substitution']:
                    attend_sub += 1
                elif target in align_label['insertion']:
                    attend_ins += 1
                elif [target - 1, target] in align_label['deletion'] or [target, target + 1] in align_label['deletion']:
                    attend_del += 1

        attention_maps.append(
            {
                'name': data['name'],
                'hyps': data['hyps'],
                'attention_map': hyp_attention_maps
            }
        )

print(f"total_character_count:{total_character_count}")
print(f"attention:{attention}")
print(f"total_sub:{total_sub}")
print(f"attend_sub:{attend_sub}")
print(f"attend_sub ratio:{attend_sub / total_sub if total_sub else float('nan')}\n")
print(f"total_del:{total_del}")
print(f"attend_del:{attend_del}")
print(f"attend_del ratio:{attend_del / total_del if total_del else float('nan')}\n")
print(f"total_ins:{total_ins}")
print(f"attend_ins:{attend_ins}")
print(f"attend_ins ratio:{attend_ins / total_ins if total_ins else float('nan')}\n")

error_hits = attend_ins + attend_del + attend_sub
print(f"ratio:{error_hits / attention if attention else float('nan')}")

print(f'len of analysis:{len(attention_maps)}')
# torch.save needs a path or a binary file handle; passing the path avoids the
# text-mode handle bug, and the output directory is created if missing.
os.makedirs(os.path.dirname(ANALYSIS_PATH) or ".", exist_ok=True)
torch.save(attention_maps, ANALYSIS_PATH)