# Evaluation

In [1]:
import os
import warnings
from typing import List, Tuple, Union

import torch
import altair as alt
import numpy as np
import pandas as pd

from config import get_config, get_weights_file_path, reset_log, set_tf_input
from train import get_ds, get_model, greedy_decode


# Silence every warning category for the rest of the session.
# NOTE(review): this also hides deprecation warnings from the libraries imported above.
warnings.filterwarnings("ignore")  # Ignore warnings


In [2]:
# Define the device (GPU if available, else CPU).
# `device` is a notebook-level name: later cells move the model and the batch tensors onto it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)


Using device: cuda


In [3]:
def load_next_batch() -> Tuple[dict, list, list]:
    """
    Fetch one validation batch, run greedy decoding on it and return the batch with its token lists.

    Relies on the notebook-level names ``val_dataloader``, ``vocab_src``, ``vocab_tgt``, ``model``,
    ``config`` and ``device``. A fresh iterator is created on every call, so which batch is returned
    depends on how ``val_dataloader`` orders its data.

    :return: The raw batch dictionary; List of tokens for the first encoder input sequence;
    List of tokens for the first decoder input sequence.
    """
    batch = next(iter(val_dataloader))

    # Move the four model tensors onto the active device (same order as they are stored in the batch)
    src_ids, src_mask, tgt_ids, _tgt_mask = (
        batch[key].to(device)
        for key in ("encoder_input", "encoder_mask", "decoder_input", "decoder_mask")
    )

    # Translate the ids of the first sequence back into readable tokens
    encoder_input_tokens = list(map(vocab_src.id_to_token, src_ids[0].cpu().numpy()))
    decoder_input_tokens = list(map(vocab_tgt.id_to_token, tgt_ids[0].cpu().numpy()))

    # Only single-sequence batches are supported
    assert src_ids.size(0) == 1, "Batch size must be 1 for validation"

    # The decoded output is discarded; the call is kept for its effect on the model.
    # NOTE(review): presumably this forward pass fills the attention_scores read by get_attn_map - confirm.
    _decoded, _ = greedy_decode(model, src_ids, src_mask, vocab_src, vocab_tgt, config['seq_len'], device)

    return batch, encoder_input_tokens, decoder_input_tokens


In [4]:
def mtx2df(m: np.ndarray, max_row: int, max_col: int, row_tokens: list, col_tokens: list) -> pd.DataFrame:
    """
    Convert the top-left part of a 2-D matrix into a long-format pandas DataFrame.

    Every kept cell becomes one record with its row index, column index, value and two
    human-readable labels of the form ``"<3-digit index> <token>"``. Positions that have
    no token in the corresponding list are labelled ``"<blank>"``.

    :param m: Input matrix (anything exposing ``shape`` and ``m[r, c]`` indexing).
    :param max_row: Maximum number of rows to include in the DataFrame.
    :param max_col: Maximum number of columns to include in the DataFrame.
    :param row_tokens: List of row tokens.
    :param col_tokens: List of column tokens.
    :return: DataFrame with the columns "row", "column", "value", "row_token" and "col_token".
    """
    # Bound the iteration up front instead of visiting every cell and discarding most of them
    n_rows = min(m.shape[0], max_row)
    n_cols = min(m.shape[1], max_col)

    def _label(idx: int, tokens: list) -> str:
        # Zero-padded index keeps the labels sortable; missing tokens become "<blank>"
        return "%.3d %s" % (idx, tokens[idx] if len(tokens) > idx else "<blank>")

    # Labels depend only on the index, so build them once per row/column rather than once per cell
    row_labels = [_label(r, row_tokens) for r in range(n_rows)]
    col_labels = [_label(c, col_tokens) for c in range(n_cols)]

    records = [
        (r, c, float(m[r, c]), row_labels[r], col_labels[c])
        for r in range(n_rows)
        for c in range(n_cols)
    ]
    return pd.DataFrame(records, columns=["row", "column", "value", "row_token", "col_token"])


def get_attn_map(attn_type: str, layer: int, head: int) -> np.ndarray:
    """
    Get the stored attention scores for one attention type, layer, and head.

    Reads the notebook-level ``model``; the scores are whatever the selected attention block
    currently holds in its ``attention_scores`` attribute.

    :param attn_type: Type of attention ('encoder', 'decoder', or 'encoder-decoder')
    :param layer: Layer index.
    :param head: Head index.
    :return: Attention map of the first batch element for the requested head.
    :raises ValueError: If ``attn_type`` is not one of the three supported values.
    """
    if attn_type == "encoder":
        block = model.encoder.layers[layer].self_attention_block
    elif attn_type == "decoder":
        block = model.decoder.layers[layer].self_attention_block
    elif attn_type == "encoder-decoder":
        block = model.decoder.layers[layer].cross_attention_block
    else:
        # Previously an unknown type fell through and failed with an opaque UnboundLocalError
        raise ValueError(
            f"Unknown attention type {attn_type!r}; expected 'encoder', 'decoder' or 'encoder-decoder'."
        )
    # Index 0 selects the first (only) element of the batch
    return block.attention_scores[0, head].data


def attn_map(attn_type: str, layer: int, head: int, row_tokens: list, col_tokens: list, max_sentence_len: int) -> alt.Chart:
    """
    Build an interactive heat map of the attention scores of one layer and head.

    :param attn_type: Type of attention ('encoder', 'decoder', or 'encoder-decoder').
    :param layer: Layer index.
    :param head: Head index.
    :param row_tokens: List of row tokens.
    :param col_tokens: List of column tokens.
    :param max_sentence_len: Maximum length of the sentence (applied to both rows and columns).
    :return: Altair chart object representing the attention map.
    """
    # Long-format table of the scores, truncated to max_sentence_len in both directions
    scores = get_attn_map(attn_type, layer, head)
    frame = mtx2df(scores, max_sentence_len, max_sentence_len, row_tokens, col_tokens)

    heatmap = alt.Chart(data=frame).mark_rect()
    heatmap = heatmap.encode(
        x=alt.X("col_token", axis=alt.Axis(title="")),
        y=alt.Y("row_token", axis=alt.Axis(title="")),
        color="value",
        tooltip=["row", "column", "value", "row_token", "col_token"],
    )
    heatmap = heatmap.properties(height=400, width=400, title=f"Layer {layer} Head {head}")
    return heatmap.interactive()


def get_all_attention_maps(attn_type: str, layers: list, heads: list, row_tokens: list, col_tokens: list,
                           max_sentence_len: int) -> alt.ConcatChart:
    """
    Arrange the attention maps of all requested layers and heads in a grid.

    Each layer becomes one horizontal strip of charts (one chart per head); the strips are
    stacked vertically in the order given by ``layers``.

    :param attn_type: Type of attention ('encoder', 'decoder', or 'encoder-decoder').
    :param layers: List of layer indices.
    :param heads: List of head indices.
    :param row_tokens: List of row tokens.
    :param col_tokens: List of column tokens.
    :param max_sentence_len: Maximum length of the sentence.
    :return: Concatenated Altair chart object representing all attention maps
    """
    layer_strips = [
        alt.hconcat(*[
            attn_map(attn_type, layer, head, row_tokens, col_tokens, max_sentence_len)
            for head in heads
        ])
        for layer in layers
    ]
    return alt.vconcat(*layer_strips)


In [5]:
def matching_proportion(df: pd.DataFrame) -> float:
    """
    Report the share of events whose determined case ID equals the actual case ID.

    Prints a one-sentence summary and returns the same figure.

    :param df: DataFrame with the columns 'Determined Case ID' and 'Actual Case ID'.
    :return: Proportion of matching case IDs (0 for an empty DataFrame).
    """
    # Nothing to measure on an empty log
    if df.empty:
        print("DataFrame is empty. Cannot calculate accuracy.")
        return 0

    total = len(df)

    # Row-wise agreement between the two ID columns
    is_match = df['Determined Case ID'] == df['Actual Case ID']
    hits = int(is_match.sum())

    proportion_matching = hits / total

    print(f"In the repaired log, {hits} of {total} case ID{'s' if total != 1 else ''} "
          f"{'was' if hits == 1 else 'were'} determined correctly. "
          f"This corresponds to an accuracy of {proportion_matching:.2%}.")

    return proportion_matching


def factual_matching_proportion(df: pd.DataFrame) -> float:
    """
    Report the share of events that belong to a factually correct case.

    The set of factually correct cases is obtained from ``correct_case_different_naming`` in
    list mode; an event counts when its 'Actual Case ID' is in that set. Prints a one-sentence
    summary and returns the same figure.

    :param df: DataFrame containing case IDs.
    :return: Proportion of factually matching case IDs (0 for an empty DataFrame).
    """
    # Nothing to measure on an empty log
    if df.empty:
        print("DataFrame is empty. Cannot calculate factual accuracy.")
        return 0

    total = len(df)

    # Actual case IDs that were recovered consistently under a different label
    factual_cases = correct_case_different_naming(df, False)
    hits = len(df[df['Actual Case ID'].isin(factual_cases)])

    proportion_factual_matching = hits / total

    print(f"Furthermore, {hits} of {total} case ID{'s' if total != 1 else ''} "
          f"{'was' if hits == 1 else 'were'} determined factually correctly. "
          f"This corresponds to a factual accuracy of {proportion_factual_matching:.2%}.")

    return proportion_factual_matching


def completely_correct_cases(df: pd.DataFrame) -> float:
    """
    Report the share of cases that were reconstructed completely correctly.

    A case counts as completely correct when every event of that actual case carries the same
    ID as its determined case ID, and no event of another case was assigned this ID.
    Prints a summary (including the IDs of the correct cases) and returns the proportion.

    :param df: DataFrame with 'Actual Case ID' and 'Determined Case ID' columns, representing actual and
    determined case IDs.
    :return: Proportion of completely correct case IDs (0 for an empty DataFrame).
    """
    # Nothing to measure on an empty log
    if df.empty:
        print("DataFrame is empty. Cannot calculate real case accuracy.")
        return 0

    actual_ids = df['Actual Case ID']
    determined_ids = df['Determined Case ID']

    # Row-wise agreement, computed once and sliced per case below
    agree = actual_ids == determined_ids

    def _is_completely_correct(case_id) -> bool:
        # All events of this actual case must keep their ID ...
        if not agree[actual_ids == case_id].all():
            return False
        # ... and every event labelled with this ID must really belong to the case
        return bool(agree[determined_ids == case_id].all())

    cases = actual_ids.unique()
    matching_values = [case_id for case_id in cases if _is_completely_correct(case_id)]

    num_matching = len(matching_values)
    num_cases = len(cases)

    print(f"Thereby, {num_matching} of {num_cases} case{'s' if num_cases != 1 else ''} "
          f"{'were' if num_matching != 1 else 'was'} determined completely correctly.", end=' ')

    # Enumerate the correct cases, joining the last one with "and"
    if num_matching == 1:
        print(f"This is the case with ID {matching_values[0]}.", end=' ')
    elif num_matching > 1:
        print(
            f"These are the cases with IDs {', '.join(map(str, matching_values[:-1]))} and "
            f"{matching_values[-1]}.", end=' ')

    proportion_completely_correct = num_matching / num_cases

    print(f"This corresponds to a real case accuracy of {proportion_completely_correct:.2%}.")

    return proportion_completely_correct


def correct_case_different_naming(df: pd.DataFrame, calculation: bool = True) -> Union[List[str], float]:
    """
    Calculate and print the proportion of factually completely correct case IDs.

    A factually completely correct case ID is one where all events of the actual case received the
    same determined case ID, that determined ID differs from the actual one, and no event of any
    other actual case received that determined ID.

    :param df: DataFrame with 'Actual Case ID' and 'Determined Case ID' columns, representing actual and determined
    case IDs.
    :param calculation: If True, calculate the proportion of factually completely correct case IDs. If False, return
    the list of factually completely correct case IDs.
    :return: Proportion of factually completely correct case IDs if calculation is True, otherwise the list of
    factually completely correct (actual) case IDs. An empty DataFrame yields 0 or an empty list respectively.
    """
    if df.empty:
        # List mode must always hand back a list so callers can pass it to e.g. Series.isin
        if not calculation:
            return []
        print("DataFrame is empty. Cannot calculate factual case accuracy.")
        return 0

    actual_ids = df['Actual Case ID']
    determined_ids = df['Determined Case ID']

    # Extract unique actual case IDs
    cases = actual_ids.unique()

    # Parallel lists: actual case ID and the determined ID it was consistently mapped to
    correct_cases = []
    corresponding_values = []

    for case_id in cases:
        # All determined IDs handed out to the events of this actual case
        labels = determined_ids[actual_ids == case_id].unique()

        # The case must map to exactly one determined ID ...
        if len(labels) != 1:
            continue
        unique_value = labels[0]

        # ... which differs from the actual ID (identical naming is handled elsewhere)
        if unique_value == case_id:
            continue

        # The determined ID must not be shared with events of any other actual case.
        # This tests for the existence of such rows; the previous check looked at the truthiness of
        # their cell values, so conflicting rows containing only falsy values (e.g. ID 0) were missed.
        shared_with_other_case = ((determined_ids == unique_value) & (actual_ids != case_id)).any()
        if not shared_with_other_case:
            correct_cases.append(case_id)
            corresponding_values.append(unique_value)

    # If calculation is False, return the list of factually completely correct case IDs
    if not calculation:
        return correct_cases

    # Print the number of cases correctly determined factually with different naming
    print(f"Moreover, {len(correct_cases)} of {len(cases)} case{'s' if len(cases) != 1 else ''} "
          f"{'were' if len(correct_cases) != 1 else 'was'} determined factually completely correctly "
          f"(with different naming).", end=' ')

    if len(corresponding_values) == 1:
        # A single pair is printed without "and"
        print(
            f"This is the case with ID {correct_cases[0]} and corresponding, determined ID "
            f"{corresponding_values[0]}.", end=' ')
    elif corresponding_values:
        # Print pairs of actual case ID and determined case ID
        pairs = [f"{actual_case_id} → {determined_case_id}" for actual_case_id, determined_case_id
                 in zip(correct_cases, corresponding_values)]
        print(f"These are the cases (actual case ID → determined case ID): {', '.join(pairs[:-1])} and "
              f"{pairs[-1]}.", end=' ')

    # Calculate the proportion of correct cases with different naming
    proportion_correct_different_naming = len(correct_cases) / len(cases)

    # Print the factual case accuracy
    print(f"This corresponds to a factual case accuracy of {proportion_correct_different_naming:.2%}.")

    return proportion_correct_different_naming


def evaluate_model_metrics(df: pd.DataFrame) -> None:
    """
    Print the full evaluation report for a determined log.

    The report has two parts: event-level accuracy (exact matches plus factual matches) and
    case-level accuracy (completely correct cases plus cases correct under a different name).
    Each helper prints its own sentence; this function adds the two summed totals.

    :param df: DataFrame containing determined log.
    """
    # --- Event level ---
    exact_share = matching_proportion(df)
    factual_share = factual_matching_proportion(df)
    overall_accuracy = exact_share + factual_share
    print(f"Thus, the overall accuracy is {overall_accuracy:.2%}.")

    # --- Case level ---
    exact_case_share = completely_correct_cases(df)
    renamed_case_share = correct_case_different_naming(df)
    overall_case_accuracy = exact_case_share + renamed_case_share
    print(f"Thus, the overall case accuracy is {overall_case_accuracy:.2%}.")


## Baseline Model 
**Input = 'Activity'**

### Running Example

In [6]:
# Set up the baseline run (input attribute: 'Activity') and rebind the notebook-level names
# config, train_dataloader, val_dataloader, vocab_src, vocab_tgt and model used by later cells.
# NOTE: running this cell prompts on stdin for the event-log path and a column-mapping
# confirmation (see the recorded output below), so it cannot run unattended.
reset_log()  # Reset log file
set_tf_input('Activity')
config = get_config()  # Get configuration parameters
train_dataloader, val_dataloader, vocab_src, vocab_tgt = get_ds(config)  # Get training and validation data
model = get_model(config, vocab_src.get_vocab_size(), vocab_tgt.get_vocab_size()).to(device)  # Get the model

# Load pretrained weights for the model
# NOTE(review): "19" is presumably the epoch of the stored checkpoint - confirm; the f-prefix has no placeholders.
model_filename = get_weights_file_path(config, f"19")
state = torch.load(model_filename, map_location=device)
model.load_state_dict(state['model_state_dict'])


Enter the path to the file that contains the event log:  logs/running-example.xes


parsing log, completed traces ::   0%|          | 0/6 [00:00<?, ?it/s]

Following columns were automatically matched:
'case:concept:name' for 'Case ID';
'concept:name' for 'Activity';
'time:timestamp' for 'Timestamp'.


Is this completely correct? (yes/no):  yes


Max length of source sentence: 10
Max length of target sentence: 10


<All keys matched successfully>

In [7]:
# Load next batch from validation set and print source and target texts
# (load_next_batch also runs greedy decoding on the batch before returning)
batch, encoder_input_tokens, decoder_input_tokens = load_next_batch()
print(f'Source: {batch["src_text"][0]}')
print(f'Target: {batch["tgt_text"][0]}')

# Plot settings shared by the three attention-map cells that follow
sentence_len = 10  # Length of the sentence
layers = [0, 1, 2]  # Layers for attention maps
heads = [0, 1, 2, 3, 4, 5, 6, 7]  # Heads for attention maps


Source: decide check_ticket check_ticket pay_compensation examine_thoroughly decide decide decide reinitiate_request reject_request
Target: 6 5 3 2 4 3 4 5 5 4


In [8]:
# Display attention maps for encoder self-attention (one row of charts per layer, one chart per head).
# min(20, sentence_len) caps the plotted positions at 20; with sentence_len = 10 it evaluates to 10.
get_all_attention_maps("encoder", layers, heads, encoder_input_tokens, encoder_input_tokens, min(20, sentence_len))


In [9]:
# Display attention maps for decoder self-attention (one row of charts per layer, one chart per head).
# min(20, sentence_len) caps the plotted positions at 20; with sentence_len = 10 it evaluates to 10.
get_all_attention_maps("decoder", layers, heads, decoder_input_tokens, decoder_input_tokens, min(20, sentence_len))


In [10]:
# Display attention maps for cross-attention
# NOTE(review): encoder tokens are passed as row labels and decoder tokens as column labels;
# verify that this matches the axis order of cross_attention_block.attention_scores.
get_all_attention_maps("encoder-decoder", layers, heads, encoder_input_tokens, decoder_input_tokens, 
                       min(20, sentence_len))


In [11]:
# Read the CSV file into a DataFrame (path built from the 'result_folder' and 'result_csv_file' config entries)
log = pd.read_csv(os.path.join(config["result_folder"], config["result_csv_file"]))

# Display the first 20 rows of the log
print(log.head(20))

# Check if there are more than 20 rows in the log and print if so
remaining_rows = len(log) - 20
if remaining_rows > 0:
    print(f"\n... (+ {remaining_rows} more rows)")

# Visual separator between the log preview and the metric report
print('-' * 80)

# Evaluate model metrics based on the determined log; the metric helpers read the
# 'Determined Case ID' and 'Actual Case ID' columns
evaluate_model_metrics(log)


    Determined Case ID  Actual Case ID            Activity
0                    1               1    register_request
1                    6               2    register_request
2                    6               2        check_ticket
3                    6               2    examine_casually
4                    6               3    register_request
5                    6               3    examine_casually
6                    6               3        check_ticket
7                    6               1  examine_thoroughly
8                    6               2              decide
9                    6               1        check_ticket
10                   3               5    register_request
11                   1               3              decide
12                   3               1              decide
13                   3               3  reinitiate_request
14                   3               3  examine_thoroughly
15                   3               6    register_reque

### Review Example Large

In [12]:
# Set up the baseline run (input attribute: 'Activity') for the next event log and rebind the
# notebook-level names config, train_dataloader, val_dataloader, vocab_src, vocab_tgt and model.
# NOTE: running this cell prompts on stdin for the event-log path and a column-mapping
# confirmation (see the recorded output below), so it cannot run unattended.
reset_log()  # Reset log file
set_tf_input('Activity')
config = get_config()  # Get configuration parameters
train_dataloader, val_dataloader, vocab_src, vocab_tgt = get_ds(config)  # Get training and validation data
model = get_model(config, vocab_src.get_vocab_size(), vocab_tgt.get_vocab_size()).to(device)  # Get the model

# Load pretrained weights for the model
# NOTE(review): "19" is presumably the epoch of the stored checkpoint - confirm; the f-prefix has no placeholders.
model_filename = get_weights_file_path(config, f"19")
state = torch.load(model_filename, map_location=device)
model.load_state_dict(state['model_state_dict'])


Enter the path to the file that contains the event log:  logs/review_example_large.xes


parsing log, completed traces ::   0%|          | 0/10000 [00:00<?, ?it/s]

Following columns were automatically matched:
'case:concept:name' for 'Case ID';
'concept:name' for 'Activity';
'time:timestamp' for 'Timestamp'.


Is this completely correct? (yes/no):  yes


Max length of source sentence: 10
Max length of target sentence: 10


<All keys matched successfully>

In [13]:
# Load next batch from validation set and print source and target texts
# (load_next_batch also runs greedy decoding on the batch before returning)
batch, encoder_input_tokens, decoder_input_tokens = load_next_batch()
print(f'Source: {batch["src_text"][0]}')
print(f'Target: {batch["tgt_text"][0]}')

# Plot settings shared by the three attention-map cells that follow
sentence_len = 10  # Length of the sentence
layers = [0, 1, 2]  # Layers for attention maps
heads = [0, 1, 2, 3, 4, 5, 6, 7]  # Heads for attention maps


Source: invite_additional_reviewer invite_additional_reviewer accept reject invite_additional_reviewer invite_additional_reviewer time_out_X get_review_X invite_additional_reviewer invite_additional_reviewer
Target: 8748 7588 934 7146 7750 7588 8748 7588 7750 3951


In [14]:
# Display attention maps for encoder self-attention (one row of charts per layer, one chart per head).
# min(20, sentence_len) caps the plotted positions at 20; with sentence_len = 10 it evaluates to 10.
get_all_attention_maps("encoder", layers, heads, encoder_input_tokens, encoder_input_tokens, min(20, sentence_len))


In [15]:
# Display attention maps for decoder self-attention (one row of charts per layer, one chart per head).
# min(20, sentence_len) caps the plotted positions at 20; with sentence_len = 10 it evaluates to 10.
get_all_attention_maps("decoder", layers, heads, decoder_input_tokens, decoder_input_tokens, min(20, sentence_len))


In [16]:
# Display attention maps for cross-attention
# NOTE(review): encoder tokens are passed as row labels and decoder tokens as column labels;
# verify that this matches the axis order of cross_attention_block.attention_scores.
get_all_attention_maps("encoder-decoder", layers, heads, encoder_input_tokens, decoder_input_tokens, 
                       min(20, sentence_len))


In [17]:
# Read the CSV file into a DataFrame (path built from the 'result_folder' and 'result_csv_file' config entries)
log = pd.read_csv(os.path.join(config["result_folder"], config["result_csv_file"]))

# Display the first 20 rows of the log
print(log.head(20))

# Check if there are more than 20 rows in the log and print if so
remaining_rows = len(log) - 20
if remaining_rows > 0:
    print(f"\n... (+ {remaining_rows} more rows)")

# Visual separator between the log preview and the metric report
print('-' * 80)

# Evaluate model metrics based on the determined log; the metric helpers read the
# 'Determined Case ID' and 'Actual Case ID' columns
evaluate_model_metrics(log)


    Determined Case ID  Actual Case ID                    Activity
0                 3411               1            invite_reviewers
1                 3407               1            invite_reviewers
2                 3408               1                get_review_3
3                 3403               1                get_review_1
4                 3412               2            invite_reviewers
5                 3409               3            invite_reviewers
6                 3412               2            invite_reviewers
7                 3412               4            invite_reviewers
8                 3409               2                  time_out_3
9                 3409               1                  time_out_2
10                4886               4            invite_reviewers
11                4887               3                  time_out_3
12                4886               1                      decide
13                4883               1             collect_rev

## First Extension
**Input = Multiple Discrete Attributes (Activity and Resource)**

### Running Example

In [30]:
# Set up the first-extension run (input attributes: 'Activity' and 'Resource') and rebind the
# notebook-level names config, train_dataloader, val_dataloader, vocab_src, vocab_tgt and model.
# NOTE: running this cell prompts on stdin for the event-log path and a column-mapping
# confirmation (see the recorded output below), so it cannot run unattended.
reset_log()  # Reset log file
set_tf_input('Activity', 'Resource')
config = get_config()  # Get configuration parameters
train_dataloader, val_dataloader, vocab_src, vocab_tgt = get_ds(config)  # Get training and validation data
model = get_model(config, vocab_src.get_vocab_size(), vocab_tgt.get_vocab_size()).to(device)  # Get the model

# Load pretrained weights for the model
# NOTE(review): "19" is presumably the epoch of the stored checkpoint - confirm; the f-prefix has no placeholders.
model_filename = get_weights_file_path(config, f"19")
state = torch.load(model_filename, map_location=device)
model.load_state_dict(state['model_state_dict'])


Enter the path to the file that contains the event log:  logs/running-example.xes


parsing log, completed traces ::   0%|          | 0/6 [00:00<?, ?it/s]

Following columns were automatically matched:
'case:concept:name' for 'Case ID';
'concept:name' for 'Activity';
'time:timestamp' for 'Timestamp';
'org:resource' for 'Resource'.


Is this completely correct? (yes/no):  yes


Max length of source sentence: 20
Max length of target sentence: 10


<All keys matched successfully>

In [31]:
# Load next batch from validation set and print source and target texts
# (load_next_batch also runs greedy decoding on the batch before returning)
batch, encoder_input_tokens, decoder_input_tokens = load_next_batch()
print(f'Source: {batch["src_text"][0]}')
print(f'Target: {batch["tgt_text"][0]}')

# Plot settings shared by the three attention-map cells that follow.
# NOTE(review): the setup output reports a max source length of 20 for this configuration, so
# sentence_len = 10 limits the encoder-side maps to the first 10 source positions - confirm this is intended.
sentence_len = 10  # Length of the sentence
layers = [0, 1, 2]  # Layers for attention maps
heads = [0, 1, 2, 3, 4, 5, 6, 7]  # Heads for attention maps


Source: examine_casually examine_casually check_ticket reject_request check_ticket decide check_ticket check_ticket pay_compensation examine_thoroughly Ellen Mike Mike Pete Mike Sara Pete Pete Ellen Sean
Target: 6 5 4 1 6 6 5 3 2 4


In [32]:
# Display attention maps for encoder self-attention (one row of charts per layer, one chart per head).
# min(20, sentence_len) caps the plotted positions at 20; with sentence_len = 10 it evaluates to 10,
# i.e. only the first 10 source positions are drawn.
get_all_attention_maps("encoder", layers, heads, encoder_input_tokens, encoder_input_tokens, min(20, sentence_len))


In [33]:
# Display attention maps for decoder self-attention (one row of charts per layer, one chart per head).
# min(20, sentence_len) caps the plotted positions at 20; with sentence_len = 10 it evaluates to 10.
get_all_attention_maps("decoder", layers, heads, decoder_input_tokens, decoder_input_tokens, min(20, sentence_len))


In [34]:
# Display attention maps for cross-attention
# NOTE(review): encoder tokens are passed as row labels and decoder tokens as column labels;
# verify that this matches the axis order of cross_attention_block.attention_scores.
get_all_attention_maps("encoder-decoder", layers, heads, encoder_input_tokens, decoder_input_tokens, 
                       min(20, sentence_len))


In [35]:
# Read the CSV file into a DataFrame (path built from the 'result_folder' and 'result_csv_file' config entries)
log = pd.read_csv(os.path.join(config["result_folder"], config["result_csv_file"]))

# Display the first 20 rows of the log
print(log.head(20))

# Check if there are more than 20 rows in the log and print if so
remaining_rows = len(log) - 20
if remaining_rows > 0:
    print(f"\n... (+ {remaining_rows} more rows)")

# Visual separator between the log preview and the metric report
print('-' * 80)

# Evaluate model metrics based on the determined log; the metric helpers read the
# 'Determined Case ID' and 'Actual Case ID' columns
evaluate_model_metrics(log)


    Determined Case ID  Actual Case ID            Activity Resource
0                    2               1    register_request     Pete
1                    2               2    register_request     Mike
2                    2               2        check_ticket     Mike
3                    2               2    examine_casually     Sean
4                    3               3    register_request     Pete
5                    3               3    examine_casually     Mike
6                    1               3        check_ticket    Ellen
7                    2               1  examine_thoroughly      Sue
8                    1               2              decide     Sara
9                    5               1        check_ticket     Mike
10                   1               5    register_request    Ellen
11                   2               3              decide     Sara
12                   1               1              decide     Sara
13                   3               3  reinitia

### Review Example Large

In [24]:
# Set up the first-extension run (input attributes: 'Activity' and 'Resource') for the next event log and
# rebind the notebook-level names config, train_dataloader, val_dataloader, vocab_src, vocab_tgt and model.
# NOTE: running this cell prompts on stdin for the event-log path and a column-mapping
# confirmation (see the recorded output below), so it cannot run unattended.
reset_log()  # Reset log file
set_tf_input('Activity', 'Resource')
config = get_config()  # Get configuration parameters
train_dataloader, val_dataloader, vocab_src, vocab_tgt = get_ds(config)  # Get training and validation data
model = get_model(config, vocab_src.get_vocab_size(), vocab_tgt.get_vocab_size()).to(device)  # Get the model

# Load pretrained weights for the model
# NOTE(review): "19" is presumably the epoch of the stored checkpoint - confirm; the f-prefix has no placeholders.
model_filename = get_weights_file_path(config, f"19")
state = torch.load(model_filename, map_location=device)
model.load_state_dict(state['model_state_dict'])


Enter the path to the file that contains the event log:  logs/review_example_large.xes


parsing log, completed traces ::   0%|          | 0/10000 [00:00<?, ?it/s]

Following columns were automatically matched:
'case:concept:name' for 'Case ID';
'concept:name' for 'Activity';
'time:timestamp' for 'Timestamp';
'org:resource' for 'Resource'.


Is this completely correct? (yes/no):  yes


Max length of source sentence: 20
Max length of target sentence: 10


<All keys matched successfully>

In [25]:
# Load next batch from validation set and print source and target texts
# (load_next_batch also runs greedy decoding on the batch before returning)
batch, encoder_input_tokens, decoder_input_tokens = load_next_batch()
print(f'Source: {batch["src_text"][0]}')
print(f'Target: {batch["tgt_text"][0]}')

# Plot settings shared by the three attention-map cells that follow.
# NOTE(review): the setup output reports a max source length of 20 for this configuration, so
# sentence_len = 10 limits the encoder-side maps to the first 10 source positions - confirm this is intended.
sentence_len = 10  # Length of the sentence
layers = [0, 1, 2]  # Layers for attention maps
heads = [0, 1, 2, 3, 4, 5, 6, 7]  # Heads for attention maps


Source: reject accept accept get_review_X invite_additional_reviewer invite_additional_reviewer accept invite_additional_reviewer invite_additional_reviewer accept Anne Anne Anne John Mike Mike Anne Mike Mike Anne
Target: 7340 5494 9181 9181 8550 8276 5494 8550 5181 9181


In [26]:
# Display attention maps for encoder self-attention (one row of charts per layer, one chart per head).
# min(20, sentence_len) caps the plotted positions at 20; with sentence_len = 10 it evaluates to 10,
# i.e. only the first 10 source positions are drawn.
get_all_attention_maps("encoder", layers, heads, encoder_input_tokens, encoder_input_tokens, min(20, sentence_len))


In [27]:
# Display attention maps for decoder self-attention (one row of charts per layer, one chart per head).
# min(20, sentence_len) caps the plotted positions at 20; with sentence_len = 10 it evaluates to 10.
get_all_attention_maps("decoder", layers, heads, decoder_input_tokens, decoder_input_tokens, min(20, sentence_len))


In [28]:
# Display attention maps for cross-attention
# NOTE(review): encoder tokens are passed as row labels and decoder tokens as column labels;
# verify that this matches the axis order of cross_attention_block.attention_scores.
get_all_attention_maps("encoder-decoder", layers, heads, encoder_input_tokens, decoder_input_tokens, 
                       min(20, sentence_len))


In [29]:
# Read the CSV file into a DataFrame (path built from the 'result_folder' and 'result_csv_file' config entries)
log = pd.read_csv(os.path.join(config["result_folder"], config["result_csv_file"]))

# Display the first 20 rows of the log
print(log.head(20))

# Check if there are more than 20 rows in the log and print if so
remaining_rows = len(log) - 20
if remaining_rows > 0:
    print(f"\n... (+ {remaining_rows} more rows)")

# Visual separator between the log preview and the metric report
print('-' * 80)

# Evaluate model metrics based on the determined log; the metric helpers read the
# 'Determined Case ID' and 'Actual Case ID' columns
evaluate_model_metrics(log)


    Determined Case ID  Actual Case ID                    Activity  \
0                 4974               1            invite_reviewers   
1                 4974               1            invite_reviewers   
2                 4975               1                get_review_3   
3                 4975               1                get_review_1   
4                 4975               2            invite_reviewers   
5                 4973               3            invite_reviewers   
6                 4979               2            invite_reviewers   
7                 4970               4            invite_reviewers   
8                 4977               2                  time_out_3   
9                 4978               1                  time_out_2   
10                9542               4            invite_reviewers   
11                9542               3                  time_out_3   
12                9535               1                      decide   
13                95