# Data preparation

Import libraries:

In [None]:
from norefer import *
from utils import *
import pandas as pd
from scipy import stats
import numpy as np
import os

Load the dataset:

In [None]:
# CSV files that make up the evaluation set, all located under `folder_path`.
filenames = ['en-libre.csv', 'en-common.csv', 'es-common.csv', 'fr-common.csv']
folder_path = '../dataset/'

# One DataFrame per successfully loaded file, plus the names of files that failed.
all_data = []
failed_files = []

# Loading is best-effort: a file that cannot be read is reported and skipped.
for filename in filenames:
    print(f'Start processing file: {filename}')
    file_path = os.path.join(folder_path, filename)

    try:
        data = pd.read_csv(file_path)
    except Exception as e:
        print(f"An error occurred while loading {filename}:", e)
        failed_files.append(filename)
        continue

    print(f"Data loaded successfully for {filename}.")
    all_data.append(data)

# pd.concat([]) fails with an opaque "No objects to concatenate" message, so
# raise the same exception type with an explanation of what actually happened.
if not all_data:
    raise ValueError(
        f"No dataset files could be loaded from {folder_path!r}; tried: {filenames}"
    )

# Make a partial load impossible to miss: every later metric would otherwise be
# computed on a subset of the data without any summary of what was dropped.
if failed_files:
    print(f"Warning: {len(failed_files)} of {len(filenames)} files were skipped: {failed_files}")

# Concatenate all data into a single DataFrame with a fresh 0..n-1 index
combined_data = pd.concat(all_data, ignore_index=True)

# Model outputs as plain Python strings, in the row order of combined_data
transcription = combined_data['outputText'].astype(str).to_list()

Process the data to extract the attention values and the error labels at the word level.

In [None]:
# Build the per-sentence table of word-level attention scores and error labels.
# process_transcription_attention, calculate_word_scores_with_tokens,
# get_word_fault_scores_jiwer and align_attention_with_jiwer are not defined in
# this notebook; they come from the star imports (`norefer` / `utils`).
# NOTE(review): the 'outputText' and 'word_attentions' columns read below are not
# created in this cell, so they are presumably produced by these helpers — verify.
data_attention = process_transcription_attention(transcription, need_split=False) 
data_attention = calculate_word_scores_with_tokens(data_attention, 'adjusted_attentions', aggregation_method='max')  # aggregation_method=['average', 'max', 'q3']

# Copy the source columns over from the raw data. The new column is named
# 'inputPath' although it is filled from combined_data['inputText'].
# NOTE(review): assigning a Series aligns on the index; this assumes data_attention
# uses the same 0..n-1 index as combined_data — confirm.
data_attention ['inputPath'] = combined_data['inputText']
data_attention ['referenceText'] = combined_data['referenceText']

# Word-level fault scores computed from each (reference, output) sentence pair.
b_score_word = get_word_fault_scores_jiwer(list(data_attention['referenceText']), list(data_attention['outputText']))
data_attention['jiwer_scores'] = b_score_word

# Each entry of 'jiwer_scores' is iterated as a sequence of items, where item[0]
# is stored as the word and item[1] as its score.
data_attention['actualwords'] = data_attention['jiwer_scores'].apply(lambda x: [item[0] for item in x])
data_attention['word_jiwer_score'] = data_attention['jiwer_scores'].apply(lambda x: [item[1] for item in x])
# Align the word attentions with the jiwer score lists (helper semantics not shown here).
data_attention['word_attentions_aligned'] = align_attention_with_jiwer(data_attention['word_jiwer_score'], data_attention['word_attentions'])

Extract the gradients at the token level.

In [None]:
# Token-level gradients for every transcription, stored under the column name
# 'TokenGradiants' (sic). A later cell reads the column by this exact spelling,
# so the name must stay as it is.
# NOTE(review): process_transcription_gradient is an external helper; this assumes
# it returns one entry per transcription, in the same order — confirm.
data_attention['TokenGradiants'] = process_transcription_gradient(transcription)
# Snapshot of the table, including the index column, written before the
# word-level gradient columns are added further down.
data_attention.to_csv('../dataset/alldata_attention_gradient_withIndex.csv', index=True)

Calculate the word level gradient values.

In [None]:
# Aggregate the token-level gradients to word level using the 'min' aggregation.
# NOTE(review): the 'word_grad' column read on the next line is not created in this
# notebook, so it is presumably added by calculate_word_scores_with_tokens_grad — verify.
data_attention = calculate_word_scores_with_tokens_grad(data_attention, 'TokenGradiants', aggregation_method='min') 
# Align the word gradients with the jiwer score lists, mirroring what is done for
# 'word_attentions_aligned'.
data_attention['word_grad_aligned'] = align_attention_with_jiwer(data_attention['word_jiwer_score'], data_attention['word_grad'])

Given the scaled attention values and the gradient for each word, we calculate Attention × Gradient to compare against the proposed method.

In [None]:
def _to_float(value):
    """Convert one aligned score to float, mapping missing values to 0.0.

    Strings are parsed with ``float`` (the literal ``'None'`` counts as missing),
    ``None`` is missing, and numeric values are converted directly. Any other
    non-convertible object falls back to 0.0.
    """
    if isinstance(value, str):
        # An unparsable string still raises ValueError, as it always did.
        return 0.0 if value == 'None' else float(value)
    if value is None:
        return 0.0
    try:
        return float(value)
    except (TypeError, ValueError):
        return 0.0


def multiply_vectors(row):
    """Return the negated element-wise product of a row's gradient and attention.

    Parameters
    ----------
    row : mapping
        Must provide 'word_grad_aligned' and 'word_attentions_aligned'. Each is
        either a list of scores or the string representation of such a list
        (as found after a CSV round trip).

    Returns
    -------
    list of float
        ``-1 * grad * attention`` for every position present in both lists;
        the longer list is truncated by ``zip``.

    Notes
    -----
    Numeric entries used to be replaced by 0.0 because only ``str`` values were
    converted, which made the result all zeros for in-memory data. Numbers are
    now kept; only ``None`` / ``'None'`` are treated as 0.0.
    """
    # Imported locally because the notebook never imports `ast` explicitly.
    import ast

    grad = row['word_grad_aligned']
    atten = row['word_attentions_aligned']

    # Lists serialized as text are parsed back into Python lists.
    if isinstance(grad, str):
        grad = ast.literal_eval(grad)
    if isinstance(atten, str):
        atten = ast.literal_eval(atten)

    grad = [_to_float(g) for g in grad]
    atten = [_to_float(a) for a in atten]

    return [-1 * g * a for g, a in zip(grad, atten)]

# Row-wise (axis=1) application: each row's list returned by multiply_vectors is
# stored in the new 'gradXatten' column.
data_attention['gradXatten'] = data_attention.apply(multiply_vectors, axis=1)

# Sentence level analysis
Assessing the effectiveness of attention values by comparing them with actual discrepancies identified through comparison with reference sentences. 

Import libraries:

In [None]:
from sklearn.metrics import roc_auc_score, classification_report, balanced_accuracy_score, average_precision_score
import numpy as np

Calculate AUC

In [None]:
def get_valid_scores_and_attentions(word_jiwer_scores, word_attentions):
    """Pair up word labels and scores, dropping deletions and missing scores.

    Positions whose jiwer score is 2 (deletion) or whose attention is ``None``
    are skipped. Every remaining jiwer score is turned into a binary label
    (0 stays 0, anything else becomes 1) and its attention is kept unchanged.

    Returns a ``(valid_scores, valid_attentions)`` tuple of equally long lists.
    """
    kept = [
        (1 if label != 0 else 0, weight)
        for label, weight in zip(word_jiwer_scores, word_attentions)
        if label not in [2] and weight is not None
    ]
    valid_scores = [flag for flag, _ in kept]
    valid_attentions = [weight for _, weight in kept]
    return valid_scores, valid_attentions

# Per-sentence ROC AUC of the gradXatten scores against the binary error labels.
auc_scores = []

for index, row in data_attention.iterrows():
    valid_scores, valid_attentions = get_valid_scores_and_attentions(row['word_jiwer_score'], row['gradXatten'])

    # A sentence needs at least two usable words to be scored.
    if len(valid_scores) <= 1 or len(valid_attentions) <= 1:
        continue

    try:
        auc_score = roc_auc_score(valid_scores, valid_attentions)
    except ValueError:
        # roc_auc_score rejects sentences whose labels contain a single class;
        # such sentences are left out of the average.
        continue
    auc_scores.append(auc_score)

average_auc_score = np.nanmean(auc_scores)
print("Average AUC Score: ", average_auc_score)


Calculate average precision

In [None]:
def get_valid_scores_and_attentions(word_jiwer_scores, word_attentions):
    """Collect aligned binary error labels and numeric scores for one sentence.

    Parameters
    ----------
    word_jiwer_scores : iterable
        Per-word jiwer scores; 2 marks a deletion and is excluded, 0 means
        correct, any other value counts as an error.
    word_attentions : iterable
        Per-word scores; ``None`` entries are excluded.

    Returns
    -------
    tuple of (list of int, list of float)
        ``(valid_scores, valid_attentions)``, always of equal length.

    Notes
    -----
    Both conversions are done before either list is touched. Previously the
    label was appended first, so a score that failed ``float()`` left the two
    lists with different lengths. Pairs that cannot be converted are skipped.
    """
    valid_scores = []
    valid_attentions = []
    for jiwer_score, attention in zip(word_jiwer_scores, word_attentions):
        # Excluding deletion (2) and words without a score
        if jiwer_score in [2] or attention is None:
            continue
        try:
            label = 1 if int(jiwer_score) != 0 else 0  # Convert to binary label
            weight = float(attention)
        except (TypeError, ValueError):
            continue
        # Append only once both values are known to be valid, keeping the lists aligned.
        valid_scores.append(label)
        valid_attentions.append(weight)
    return valid_scores, valid_attentions

# Per-sentence average precision of the gradXatten scores against the binary error labels.
average_precision_scores = []

for index, row in data_attention.iterrows():
    valid_scores, valid_attentions = get_valid_scores_and_attentions(row['word_jiwer_score'], row['gradXatten'])

    # A sentence needs at least two usable words to be scored.
    if len(valid_scores) <= 1 or len(valid_attentions) <= 1:
        continue

    try:
        ap_score = average_precision_score(valid_scores, valid_attentions, average='weighted')
    except ValueError:
        # Sentences that scikit-learn refuses to score are left out of the average.
        continue
    average_precision_scores.append(ap_score)

average_ap_score = np.nanmean(average_precision_scores)
print("Average AP Score: ", average_ap_score)


Calculate top k classification metrics - dynamic k

In [None]:
def classify_top_k_attention_words(word_jiwer_scores, word_attentions, sentence_length):
    """Flag the k lowest-scoring words of a sentence as faulty.

    ``k`` is 10% of ``sentence_length`` rounded up, and never less than 1.
    Words are ordered by ascending score, so the first ``k`` positions of the
    returned lists are the words with the smallest scores. Missing or
    non-numeric scores are treated as 0.

    Returns ``(binary_labels, binary_predictions)`` in that sorted order:
    labels are 1 where the jiwer score is non-zero, predictions are 1 for the
    first ``k`` words.
    """
    top_k = max(1, int(np.ceil(0.10 * sentence_length)))

    def _as_number(att):
        usable = att not in [None, 'None'] and isinstance(att, (float, str, int))
        return float(att) if usable else 0

    ranked = sorted(
        zip(word_jiwer_scores, map(_as_number, word_attentions)),
        key=lambda pair: pair[1],
    )

    binary_labels = [1 if label != 0 else 0 for label, _ in ranked]

    # The first `top_k` ranked words are predicted faulty, the rest correct.
    flagged = min(top_k, len(ranked))
    binary_predictions = [1] * flagged + [0] * (len(ranked) - flagged)

    return binary_labels, binary_predictions

# Per-sentence metric accumulators
precision_scores, recall_scores, f1_scores = [], [], []
accuracy_scores, baccuracy_scores = [], []

# Score every sentence with the dynamic top-k classifier
for index, row in data_attention.iterrows():
    word_jiwer_scores = row['word_jiwer_score']
    word_attentions = row['gradXatten']
    sentence_length = len(word_attentions)

    binary_labels, binary_predictions = classify_top_k_attention_words(
        word_jiwer_scores, word_attentions, sentence_length
    )

    report = classification_report(binary_labels, binary_predictions, output_dict=True, zero_division=0)
    acc = balanced_accuracy_score(binary_labels, binary_predictions)

    # Precision, recall and F1 are taken from the support-weighted average row.
    weighted = report['weighted avg']
    precision_scores.append(weighted['precision'])
    recall_scores.append(weighted['recall'])
    f1_scores.append(weighted['f1-score'])
    accuracy_scores.append(report['accuracy'])
    baccuracy_scores.append(acc)

# Average each metric over all sentences
mean_precision = np.mean(precision_scores)
mean_recall = np.mean(recall_scores)
mean_f1 = np.mean(f1_scores)
mean_accuracy = np.mean(accuracy_scores)
mean_baccuracy = np.mean(baccuracy_scores)

print(f"Mean Precision: {mean_precision}")
print(f"Mean Recall: {mean_recall}")
print(f"Mean F1 Score: {mean_f1}")
print(f"Mean Accuracy: {mean_accuracy}")
print(f"Mean Balanced Accuracy: {mean_baccuracy}")