# Evaluation

In [1]:
import torch
import torch.nn as nn
import torchmetrics
from model import Transformer
from config import get_config, get_weights_file_path, reset_log
from train import get_model, get_ds, greedy_decode, create_log
import altair as alt
import pandas as pd
import numpy as np
import warnings
from typing import Tuple

# Install a blanket "ignore" filter: every warning raised after this line is suppressed
# (presumably to keep the cell output uncluttered).
warnings.filterwarnings("ignore")

In [2]:
# Pick the compute device: CUDA when PyTorch reports it as available, otherwise the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

Using device: cpu


In [3]:
def load_next_batch() -> Tuple[dict, list, list]:
    """
    Fetch a sample batch from the validation set, run the model on it and return the batch with its tokens.

    Reads the notebook-level globals ``val_dataloader``, ``vocab_src``, ``vocab_tgt``, ``model``,
    ``config`` and ``device``. A new iterator is created on every call, so the result is always the
    first batch that a fresh pass over ``val_dataloader`` yields.

    :return: The batch exactly as yielded by the dataloader; List of tokens for the encoder input;
    List of tokens for the decoder input.
    :raises AssertionError: If the batch size is not 1.
    """
    batch = next(iter(val_dataloader))
    encoder_input = batch["encoder_input"].to(device)
    encoder_mask = batch["encoder_mask"].to(device)
    decoder_input = batch["decoder_input"].to(device)

    # Validate before doing any work: only the first sequence of the batch is tokenised below
    assert encoder_input.size(0) == 1, "Batch size must be 1 for validation"

    encoder_input_tokens = [vocab_src.id_to_token(idx) for idx in encoder_input[0].cpu().numpy()]
    decoder_input_tokens = [vocab_tgt.id_to_token(idx) for idx in decoder_input[0].cpu().numpy()]

    # The decoded sequence is deliberately discarded; the call is kept for its effect on the model
    # (presumably it refreshes the attention_scores that get_attn_map reads - confirm in model.py).
    greedy_decode(model, encoder_input, encoder_mask, vocab_src, vocab_tgt, config['seq_len'], device)

    return batch, encoder_input_tokens, decoder_input_tokens

In [4]:
def mtx2df(m: np.ndarray, max_row: int, max_col: int, row_tokens: list, col_tokens: list) -> pd.DataFrame:
    """
    Convert the top-left corner of a matrix to a long-format pandas DataFrame (one record per cell).

    :param m: Input matrix, indexed as ``m[row, column]``.
    :param max_row: Maximum number of rows to include in the DataFrame.
    :param max_col: Maximum number of columns to include in the DataFrame.
    :param row_tokens: List of row tokens; positions beyond its end are labelled "<blank>".
    :param col_tokens: List of column tokens; positions beyond its end are labelled "<blank>".
    :return: DataFrame with the columns "row", "column", "value", "row_token" and "col_token",
    ordered row by row.
    """
    # Clamp the loop bounds once instead of visiting every cell and filtering afterwards
    n_rows = min(m.shape[0], max_row)
    n_cols = min(m.shape[1], max_col)

    def label(position: int, tokens: list) -> str:
        # The zero-padded position prefix keeps labels unique even when a token repeats
        token = tokens[position] if position < len(tokens) else "<blank>"
        return f"{position:03d} {token}"

    records = [
        (r, c, float(m[r, c]), label(r, row_tokens), label(c, col_tokens))
        for r in range(n_rows)
        for c in range(n_cols)
    ]
    return pd.DataFrame(records, columns=["row", "column", "value", "row_token", "col_token"])


def get_attn_map(attn_type: str, layer: int, head: int) -> torch.Tensor:
    """
    Get the attention map stored on the global ``model`` for an attention type, layer, and head.

    :param attn_type: Type of attention ('encoder', 'decoder', or 'encoder-decoder')
    :param layer: Layer index.
    :param head: Head index.
    :return: Attention map tensor for the first batch element and the requested head.
    :raises ValueError: If ``attn_type`` is not one of the supported values.
    """
    if attn_type == "encoder":
        block = model.encoder.layers[layer].self_attention_block
    elif attn_type == "decoder":
        block = model.decoder.layers[layer].self_attention_block
    elif attn_type == "encoder-decoder":
        block = model.decoder.layers[layer].cross_attention_block
    else:
        # Fail with a clear message instead of an UnboundLocalError further down
        raise ValueError(
            f"Unknown attn_type {attn_type!r}; expected 'encoder', 'decoder' or 'encoder-decoder'"
        )
    # Index 0 selects the first (only) element of the batch
    return block.attention_scores[0, head].data


def attn_map(attn_type: str, layer: int, head: int, row_tokens: list, col_tokens: list, max_sentence_len: int) -> alt.Chart:
    """
    Build an interactive heat map of one attention head.

    :param attn_type: Type of attention ('encoder', 'decoder', or 'encoder-decoder').
    :param layer: Layer index.
    :param head: Head index.
    :param row_tokens: List of row tokens (y axis labels).
    :param col_tokens: List of column tokens (x axis labels).
    :param max_sentence_len: Maximum number of rows and of columns to plot.
    :return: Altair chart object representing the attention map.
    """
    scores = get_attn_map(attn_type, layer, head)
    cells = mtx2df(scores, max_sentence_len, max_sentence_len, row_tokens, col_tokens)

    # One rectangle per matrix cell, coloured by the attention value
    heatmap = alt.Chart(data=cells).mark_rect()
    heatmap = heatmap.encode(
        x=alt.X("col_token", axis=alt.Axis(title="")),
        y=alt.Y("row_token", axis=alt.Axis(title="")),
        color="value",
        tooltip=["row", "column", "value", "row_token", "col_token"],
    )
    heatmap = heatmap.properties(height=400, width=400, title=f"Layer {layer} Head {head}")
    return heatmap.interactive()


def get_all_attention_maps(attn_type: str, layers: list, heads: list, row_tokens: list, col_tokens: list, max_sentence_len: int) -> alt.ConcatChart:
    """
    Arrange the attention maps of the given layers and heads in a grid.

    Each layer becomes one horizontal strip of charts (one chart per head); the strips are
    stacked vertically in the order of ``layers``.

    :param attn_type: Type of attention ('encoder', 'decoder', or 'encoder-decoder').
    :param layers: List of layer indices.
    :param heads: List of head indices.
    :param row_tokens: List of row tokens.
    :param col_tokens: List of column tokens.
    :param max_sentence_len: Maximum length of the sentence.
    :return: Concatenated Altair chart object representing all attention maps
    """
    strips = []
    for layer in layers:
        heads_of_layer = [
            attn_map(attn_type, layer, head, row_tokens, col_tokens, max_sentence_len)
            for head in heads
        ]
        strips.append(alt.hconcat(*heads_of_layer))
    return alt.vconcat(*strips)

In [5]:
def matching_proportion(df: pd.DataFrame) -> None:
    """
    Print how many rows have a 'Determined Case ID' equal to their 'Actual Case ID'.

    :param df: DataFrame with the columns 'Determined Case ID' and 'Actual Case ID'.
    """
    # Guard clause: an empty frame would make the ratio below divide by zero
    if df.empty:
        print("DataFrame is empty. Cannot calculate accuracy.")
        return

    total = len(df)
    hits = len(df[df['Determined Case ID'] == df['Actual Case ID']])
    accuracy = hits / total

    # Pick singular/plural wording so the sentence reads naturally
    plural_suffix = 's' if total != 1 else ''
    verb = 'was' if hits == 1 else 'were'

    print(f"{hits} of {total} case ID{plural_suffix} {verb} determined correctly. "
          f"This corresponds to an accuracy of {accuracy:.2%}.")

    
def completely_correct_cases(df: pd.DataFrame) -> None:
    """
    Print how many actual cases were reconstructed without a single wrong assignment.

    A case counts as completely correct when every row that actually belongs to it was assigned
    to it, and every row assigned to it actually belongs to it.

    :param df: DataFrame with the columns 'Determined Case ID' and 'Actual Case ID'.
    """
    # Guard clause: without rows there are no cases to divide by
    if df.empty:
        print("DataFrame is empty. Cannot calculate case accuracy.")
        return

    actual = df['Actual Case ID']
    determined = df['Determined Case ID']
    cases = actual.unique()

    # Row-wise agreement, computed once and then sliced per case
    rows_agree = actual == determined

    def is_fully_correct(case_id) -> bool:
        # No row of the real case was assigned elsewhere ...
        nothing_lost = all(rows_agree[actual == case_id])
        # ... and no foreign row was assigned to this case
        nothing_gained = all(rows_agree[determined == case_id])
        return nothing_lost and nothing_gained

    matching_values = [case_id for case_id in cases if is_fully_correct(case_id)]

    n_correct, n_cases = len(matching_values), len(cases)
    noun_suffix = 's' if n_cases != 1 else ''
    verb = 'were' if n_correct != 1 else 'was'
    print(f"{n_correct} of {n_cases} case{noun_suffix} {verb} determined completely correctly.", end=' ')

    # Name the correct cases, joining the last one with "and"
    if n_correct == 1:
        print(f"This is {matching_values[0]}.", end=' ')
    elif n_correct > 1:
        leading = ', '.join(map(str, matching_values[:-1]))
        print(f"These are {leading} and {matching_values[-1]}.", end=' ')

    print(f"This corresponds to a case accuracy of {n_correct / n_cases:.2%}.")


def evaluate_model_metrics(df: pd.DataFrame) -> None:
    """
    Print the row-level accuracy followed by the case-level accuracy of a determined log.

    :param df: DataFrame containing determined log.
    """
    for report in (matching_proportion, completely_correct_cases):
        report(df)

## Baseline Model

### Running Example

In [6]:
reset_log()  # Reset log file
config = get_config()  # Get configuration parameters
train_dataloader, val_dataloader, vocab_src, vocab_tgt = get_ds(config)  # Get training and validation data
model = get_model(config, vocab_src.get_vocab_size(), vocab_tgt.get_vocab_size()).to(device)  # Get the model

# Evaluation only: switch off dropout and other train-time behaviour so results are deterministic
model.eval()

# Load pretrained weights for the model ("19" presumably identifies the training epoch - confirm in config.py)
model_filename = get_weights_file_path(config, "19")
# map_location lets a checkpoint saved on another device (e.g. a GPU) load on the current one
state = torch.load(model_filename, map_location=device)
model.load_state_dict(state['model_state_dict'])

parsing log, completed traces ::   0%|          | 0/6 [00:00<?, ?it/s]

Following columns were automatically matched:
'case:concept:name' for 'Case ID';
'concept:name' for 'Activity';
'time:timestamp' for 'Timestamp'.
Max length of source sentence: 10
Max length of target sentence: 10


<All keys matched successfully>

In [None]:
# Fetch a validation sample (this also runs the model on it) and show its source and target texts
batch, encoder_input_tokens, decoder_input_tokens = load_next_batch()
print(f'Source: {batch["src_text"][0]}')
print(f'Target: {batch["tgt_text"][0]}')

# Plot settings for the attention maps below
sentence_len = 10  # Number of token positions to show per axis
layers = list(range(3))  # Layer indices 0..2
heads = list(range(8))  # Head indices 0..7

In [None]:
# Encoder self-attention: the source tokens label both axes of every heat map
get_all_attention_maps(
    attn_type="encoder",
    layers=layers,
    heads=heads,
    row_tokens=encoder_input_tokens,
    col_tokens=encoder_input_tokens,
    max_sentence_len=min(20, sentence_len),
)

In [None]:
# Decoder self-attention: the target tokens label both axes of every heat map
get_all_attention_maps(
    attn_type="decoder",
    layers=layers,
    heads=heads,
    row_tokens=decoder_input_tokens,
    col_tokens=decoder_input_tokens,
    max_sentence_len=min(20, sentence_len),
)

In [None]:
# Display attention maps for cross-attention.
# Rows are labelled with decoder tokens and columns with encoder tokens. This assumes
# attention_scores is laid out as (query position, key position) with queries coming from the
# decoder and keys from the encoder - confirm against the attention block in model.py.
get_all_attention_maps("encoder-decoder", layers, heads, decoder_input_tokens, encoder_input_tokens, min(20, sentence_len))

In [None]:
# Create the log for the current configuration
log = create_log(config)

# Preview at most the first 20 rows and mention how many rows were left out
print(log.iloc[:20])
remaining_rows = len(log) - 20
if remaining_rows > 0:
    print(f"\n... (+ {remaining_rows} more rows)")

print("-" * 80)

# Report row-level and case-level accuracy of the determined log
evaluate_model_metrics(log)

### Review Example Large

In [ ]:
reset_log()  # Reset log file
config = get_config()  # Get configuration parameters
train_dataloader, val_dataloader, vocab_src, vocab_tgt = get_ds(config)  # Get training and validation data
model = get_model(config, vocab_src.get_vocab_size(), vocab_tgt.get_vocab_size()).to(device)  # Get the model

# Evaluation only: switch off dropout and other train-time behaviour so results are deterministic
model.eval()

# Load pretrained weights for the model ("19" presumably identifies the training epoch - confirm in config.py)
model_filename = get_weights_file_path(config, "19")
# map_location lets a checkpoint saved on another device (e.g. a GPU) load on the current one
state = torch.load(model_filename, map_location=device)
model.load_state_dict(state['model_state_dict'])

In [ ]:
# Fetch a validation sample (this also runs the model on it) and show its source and target texts
batch, encoder_input_tokens, decoder_input_tokens = load_next_batch()
print(f'Source: {batch["src_text"][0]}')
print(f'Target: {batch["tgt_text"][0]}')

# Plot settings for the attention maps below
sentence_len = 10  # Number of token positions to show per axis
layers = list(range(3))  # Layer indices 0..2
heads = list(range(8))  # Head indices 0..7

In [ ]:
# Encoder self-attention: the source tokens label both axes of every heat map
get_all_attention_maps(
    attn_type="encoder",
    layers=layers,
    heads=heads,
    row_tokens=encoder_input_tokens,
    col_tokens=encoder_input_tokens,
    max_sentence_len=min(20, sentence_len),
)

In [ ]:
# Decoder self-attention: the target tokens label both axes of every heat map
get_all_attention_maps(
    attn_type="decoder",
    layers=layers,
    heads=heads,
    row_tokens=decoder_input_tokens,
    col_tokens=decoder_input_tokens,
    max_sentence_len=min(20, sentence_len),
)

In [ ]:
# Display attention maps for cross-attention.
# Rows are labelled with decoder tokens and columns with encoder tokens. This assumes
# attention_scores is laid out as (query position, key position) with queries coming from the
# decoder and keys from the encoder - confirm against the attention block in model.py.
get_all_attention_maps("encoder-decoder", layers, heads, decoder_input_tokens, encoder_input_tokens, min(20, sentence_len))

In [ ]:
# Create the log for the current configuration
log = create_log(config)

# Preview at most the first 20 rows and mention how many rows were left out
print(log.iloc[:20])
remaining_rows = len(log) - 20
if remaining_rows > 0:
    print(f"\n... (+ {remaining_rows} more rows)")

print("-" * 80)

# Report row-level and case-level accuracy of the determined log
evaluate_model_metrics(log)