In [1]:
import sys
import os
from pathlib import Path
# This appends the directory one level up (the root of your project) to the sys.path.
# Modify the path depending on the location of modules you want to import.
sys.path.append(os.path.abspath('../../'))

from config.config_managers import DashboardConfigManager
from dataManager import DataManager
from dash import Dash
import pandas as pd
import plotly.express as px
from abc import ABC, abstractmethod
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

2025-02-03 01:35:11 - INFO - PyTorch version 2.2.2 available.


In [2]:
class Visualization(ABC):
    def __init__(self, data, mappings):
        self.data = data
        self.tag_mapping = mappings['tag_mapping']
        self.dataset_mapping = mappings['dataset_mapping']

    @abstractmethod
    def prepare_data(self):
        pass

    @abstractmethod
    def visualize(self):
        pass

    def replace_mappings(self, data):
        data['Tag'] = data['Tag'].replace(self.tag_mapping)
        data['Model'] = data['Model'].replace(self.dataset_mapping)
        return data
    
    def process_entity_confusion(self, entity_confusion, o_error):
        """
        Processes the entity confusion matrix into high-level error categories
        and a separate DataFrame for entity and exclusion errors.
        
        Parameters:
            entity_confusion (dict): A dictionary representing entity confusion components.
            
        Returns:
            renamed_df (DataFrame): High-level error categories (Entity, Boundary, Entity and Boundary, Exclusion).
            entity_errors (DataFrame): DataFrame containing Entity Errors and Exclusion Errors only.
        """
        # Step 1: Create DataFrame
        df = pd.DataFrame(entity_confusion).fillna(0).astype(int).T

        # Step 2: Rename columns into high-level categories
        errors = df.copy()
        errors[o_error] = errors.pop('O')  # Rename 'O' to 'Exclusion'
        errors['Entity'] = errors.drop(columns=['Boundary', 'Entity and Boundary', o_error], errors='ignore').sum(axis=1)
        errors = errors[['Entity', 'Boundary', 'Entity and Boundary', o_error]]

        # Step 3: Create a separate DataFrame for Entity and Exclusion only
        entity_errors = df.drop(columns=['Boundary', 'Entity and Boundary', 'O'], errors='ignore')

        return errors, entity_errors



mappings = {
    'tag_mapping': {'PERS': 'PER'},
    'dataset_mapping': {'ANERCorp_CamelLab_arabertv02': 'AraBERTv02', 'conll2003_bert': 'BERT'}
}

class ReportBarChart(Visualization):
    def prepare_data(self):
        report_data = []
        for model_name, data_content in self.data.items():
            entity_report = data_content.entity_non_strict_report
            entity_strict_report = data_content.entity_strict_report
            entity_report['Model'] = model_name
            entity_report['Scheme'] = 'IOB'
            entity_strict_report['Model'] = model_name
            entity_strict_report['Scheme'] = 'IOB2'
            report_data.append(pd.concat([
                entity_report, 
                entity_strict_report
            ]))
        report_df = pd.concat(report_data)
        report_df = report_df[~report_df['Tag'].isin(['micro', 'macro', 'weighted'])]
        
        return self.replace_mappings(report_df)

    def visualize(self):
        entity_report_data = self.prepare_data()
        melted_df = entity_report_data.melt(id_vars=["Tag", "Support", "Model", "Scheme"], 
                        value_vars=["Precision", "Recall"], 
                        var_name="Metric", value_name="Value")
        melted_df['Value'] = melted_df['Value'].round(3)
        fig = px.bar(melted_df, x="Tag", y="Value",
                    facet_row="Scheme", facet_col="Model",
                    title="Breakdown of Precision and Recall Scores by Entity Tag, Categorized by Model and Tagging Scheme",
                    labels={"Value": "Score", 'Tag': 'Entity'},
                    color="Metric", barmode="group",
                    template="plotly_white",
                    facet_row_spacing=0.15,  # Adjust to a higher value for more space
                    facet_col_spacing=0.1,  # Adjust to a higher value for more space
                    text='Value',  # Display the Value on top of each bar
                    )

        fig.show()
    
    def visualize_f1(self):
        entity_report_data = self.prepare_data()
        entity_report_data['F1'] = entity_report_data['F1'].round(3)
        fig = px.bar(entity_report_data, x="Tag", y="F1", color="Model",
                    facet_col="Scheme",
                    title="Breakdown of F1 Score Per Model and Scheme",
                    labels={"Scale": "Scaled Counts"},
                    barmode='group',
                    template="plotly_white",
                    # facet_row_spacing=0.1,  # Adjusted spacing
                    facet_col_spacing=0.08,
                    text='F1',  # Display the actual Count on top of each bar
                    )
                
        fig.show()
        
    def visualize_support(self):
        entity_report_data = self.prepare_data()
        fig = px.bar(entity_report_data, x="Tag", y="Support", color="Model",
                    facet_col="Scheme",
                    title="Breakdown of the Number of Examples Per Model and Scheme",
                    labels={"Scale": "Scaled Counts"},
                    barmode='group',
                    template="plotly_white",
                    # facet_row_spacing=0.1,  # Adjusted spacing
                    facet_col_spacing=0.08,
                    text='Support',  # Display the actual Count on top of each bar
                    )
                
        fig.show()
        

class ConfusionBarChart(Visualization):
    def prepare_data(self):
        matrix_data = []
        for data_name, data_content in self.data.items():
            entity_matrix = pd.DataFrame(data_content.entity_non_strict_confusion_data['confusion_matrix']).T 
            entity_strict_matrix = pd.DataFrame(data_content.entity_strict_confusion_data['confusion_matrix']).T
            entity_matrix['Model'] = data_name
            entity_matrix['Scheme'] = 'IOB'
            entity_strict_matrix['Model'] = data_name
            entity_strict_matrix['Scheme'] = 'IOB2'
            matrix_data.append(pd.concat([
				entity_matrix, 
				entity_strict_matrix
			]))
        matrix_df = pd.concat(matrix_data)
        matrix_df.reset_index(inplace=True)
        matrix_df.rename(columns={'index': 'Tag'}, inplace=True)
        matrix_data = self.replace_mappings(matrix_df)
        
        grouped = matrix_data.groupby(['Tag', 'Model', 'Scheme']).sum()
        grouped['Total'] = grouped['TP'] + grouped['FP'] + grouped['FN']
        
        matrix_data = matrix_data.merge(grouped['Total'], on=['Tag', 'Model', 'Scheme'], how='left')
        
        matrix_data['TP_Count'] = matrix_data['TP']
        matrix_data['FP_Count'] = matrix_data['FP']
        matrix_data['FN_Count'] = matrix_data['FN']
        
        matrix_data['TP'] = matrix_data['TP'] / matrix_data['Total']
        matrix_data['FP'] = matrix_data['FP'] / matrix_data['Total']
        matrix_data['FN'] = matrix_data['FN'] / matrix_data['Total']
        
        confusion_scaled_df = matrix_data.melt(id_vars=["Tag", "Model", "Scheme"], value_vars=["TP", "FP", "FN"], var_name="Metric", value_name="Scale")
        confusion_count_df = matrix_data.melt(id_vars=["Tag", "Model", "Scheme"], value_vars=["TP_Count", "FP_Count", "FN_Count"], var_name="Metric", value_name="Count")
        confusion_count_df['Metric'] = confusion_count_df['Metric'].str.replace('_Count', '')
        confusion_data = confusion_scaled_df.merge(confusion_count_df, on=["Tag", "Model", "Scheme", "Metric"])
        return confusion_data

    def visualize(self):
        confusion_df = self.prepare_data()
        print(confusion_df)
        
        fig = px.bar(confusion_df, x="Tag", y="Scale", color="Metric",
            facet_row="Scheme", facet_col="Model",
            title="Breakdown of Confusion Matrix Components: by Entity Tag, Categorized by Model and Tagging Scheme",
            labels={"Scale": "Scaled Counts"},
            barmode='group',
            template="plotly_white",
            facet_row_spacing=0.1,  # Adjusted spacing
            facet_col_spacing=0.08,
            text='Count'  # Display the actual Count on top of each bar
            )
        
        fig.show()

class ConfusionHeatmap(Visualization):
    def prepare_data(self):
        matrix_data = []
        for data_name, data_content in self.data.items():
            entity_matrix = pd.DataFrame(data_content.entity_non_strict_confusion_data['confusion_matrix']).T 
            entity_strict_matrix = pd.DataFrame(data_content.entity_strict_confusion_data['confusion_matrix']).T
            entity_matrix['Model'] = data_name
            entity_matrix['Scheme'] = 'IOB'
            entity_strict_matrix['Model'] = data_name
            entity_strict_matrix['Scheme'] = 'IOB2'
            matrix_data.append(pd.concat([
				entity_matrix, 
				entity_strict_matrix
			]))
        matrix_df = pd.concat(matrix_data)
        matrix_df.reset_index(inplace=True)
        matrix_df.rename(columns={'index': 'Tag'}, inplace=True)
        return self.replace_mappings(matrix_df)

    def visualize(self):
        matrix_df = self.prepare_data()
        confusion_df = matrix_df.melt(id_vars=['Tag', 'Model', 'Scheme'], value_vars=['TP', 'FP', 'FN'], 
                            var_name='Metric', value_name='Count')
        
        unique_schemes = confusion_df['Scheme'].unique()
        unique_datasets = confusion_df['Model'].unique()
        
        fig = make_subplots(rows=len(unique_schemes), cols=len(unique_datasets),
                            subplot_titles=[f"{dataset} - {scheme}" for scheme in unique_schemes for dataset in unique_datasets],
                            shared_yaxes=True, horizontal_spacing=0.02, vertical_spacing=0.1)
        
        max_value = confusion_df['Count'].max()
        
        for idx, scheme in enumerate(unique_schemes):
            for jdx, dataset in enumerate(unique_datasets):
                filtered_data = confusion_df[(confusion_df['Scheme'] == scheme) & (confusion_df['Model'] == dataset)]
                heatmap_data = filtered_data.pivot_table(index='Metric', columns='Tag', values='Count', fill_value=0)
                text_data = filtered_data.pivot_table(index='Metric', columns='Tag', values='Count', fill_value=0).astype(int)

                
                
                fig.add_trace(
                    go.Heatmap(
                        z=heatmap_data,
                        x=heatmap_data.columns,
                        y=heatmap_data.index,
                        colorscale='RdBu_r',
                        coloraxis="coloraxis",  # Use a unified color axis
                        text=text_data,  # Add text annotations
                        texttemplate="%{text}",  # Use the text values directly
                        hovertemplate="Metric: %{y}<br>Tag: %{x}<br>Count: %{text}<extra></extra>",
                    ),
                    row=idx + 1, col=jdx + 1
                )
                
        fig.update_layout(
            coloraxis=dict(colorscale='RdBu_r', cmin=0, cmax=max_value, colorbar=dict(title="Counts")),
            title_text="Confusion Matrix Heatmap Categorized by Dataset and Tagging Scheme",
            template="plotly_white",
            height=600, width=700,
        )
        fig.show()




class ErrorTypeHeatmap(Visualization):
    def prepare_data(self, component):
        matrix_data = []
        o_error = "Inclusion" if component == 'false_positives' else "Exclusion"
        # Step 1: Collect general error data
        for data_name, data_content in self.data.items():
            for scheme, entity_confusion in [('IOB', data_content.entity_non_strict_confusion_data), 
                                             ('IOB2', data_content.entity_strict_confusion_data)]:
                # Process general errors (Entity, Boundary, Entity+Boundary, Exclusion)
                error_types, _ = self.process_entity_confusion(entity_confusion[component], o_error)

                # Annotate with Model and Scheme
                error_types['Model'] = data_name
                error_types['Scheme'] = scheme
                matrix_data.append(error_types)
        
        # Step 2: Combine and process data
        matrix_df = pd.concat(matrix_data)
        matrix_df.reset_index(inplace=True)
        matrix_df.rename(columns={'index': 'Tag'}, inplace=True)
        matrix_df = self.replace_mappings(matrix_df)
        
        # Step 3: Melt raw counts for visualization
        melted_df = matrix_df.melt(
            id_vars=['Tag', 'Model', 'Scheme'],
            value_vars=['Entity', 'Boundary', 'Entity and Boundary', o_error],
            var_name="Error Type",
            value_name="Raw Count"
        )
        return melted_df

    def visualize(self, component):
        general_errors_df = self.prepare_data(component)
        
        title_component = "False Positives" if component == 'false_positives' else "False Negatives"
        
        # Step 4: Create heatmap for raw errors
        fig = make_subplots(
            rows=2, cols=2,
            subplot_titles=[f"{scheme} - {model}" for scheme in general_errors_df['Scheme'].unique() 
                            for model in general_errors_df['Model'].unique()],
            shared_xaxes=True, shared_yaxes=True, horizontal_spacing=0.05, vertical_spacing=0.1
        )
        
        unique_schemes = general_errors_df['Scheme'].unique()
        unique_models = general_errors_df['Model'].unique()
        
        for i, scheme in enumerate(unique_schemes):
            for j, model in enumerate(unique_models):
                filtered_data = general_errors_df[(general_errors_df['Scheme'] == scheme) &
                                                  (general_errors_df['Model'] == model)]
                
                pivot_data = filtered_data.pivot(index='Error Type', columns='Tag', values='Raw Count')
                print(pivot_data)
                fig.add_trace(
                    go.Heatmap(
                        z=pivot_data.values,
                        x=pivot_data.columns,
                        y=pivot_data.index,
                        coloraxis="coloraxis",
                        text=pivot_data.values,
                        texttemplate="%{text}",
                        hovertemplate="Tag: %{x}<br>Error Type: %{y}<br>Count: %{text}<extra></extra>"
                    ),
                    row=i + 1, col=j + 1
                )
       
        fig.update_layout(
            coloraxis=dict(colorscale='RdBu_r', colorbar=dict(title="Error Count")),
            title_text=f"{title_component} Error Type Heatmap: by Entity Tag, Categorized by Model and Tagging Scheme",
            template="plotly_white",
            height=600, width=1000,
        )
        fig.show()
    
    def visualize_table(self, component):
        """
        Generates tables showing raw counts and percentages for each error type, 
        categorized by Scheme and Model.
        """
        # Step 1: Prepare the data
        errors_type = self.prepare_data(component)
        
        # Aggregate totals for error types
        pivot_data = errors_type.groupby(["Error Type", "Scheme", "Model"], as_index=False).agg(
            Total_Count=("Raw Count", "sum")
        )

        # Step 2: Calculate percentages across all errors within each Scheme and Model
        pivot_data['Percentage'] = (
            pivot_data.groupby(['Scheme', 'Model'])['Total_Count']
            .transform(lambda x: (x / x.sum()) * 100)  # Use transform to maintain row alignment
        )
        pivot_data['Percentage'] = pivot_data['Percentage'].round(2)  # Round percentages for display

        # Step 3: Print tables for each Scheme and Model
        unique_schemes = pivot_data['Scheme'].unique()
        unique_models = pivot_data['Model'].unique()

        for scheme in unique_schemes:
            for model in unique_models:
                print(f"\n### Table for Scheme: {scheme}, Model: {model} ###\n")
                filtered_data = pivot_data[
                    (pivot_data['Scheme'] == scheme) & 
                    (pivot_data['Model'] == model)
                ].copy()
                display_df = filtered_data[['Error Type', 'Total_Count', 'Percentage']].copy()
                display_df.rename(
                    columns={"Error Type": "Error Type", "Total_Count": "Raw Count", "Percentage": "Percentage (%)"},
                    inplace=True
                )
                print(display_df.to_string(index=False))  # Display as a clean table
                

class EntityErrorsHeatmap(Visualization):
    def prepare_data(self, component):
        matrix_data = []
        o_error = "Inclusion" if component == 'false_positives' else "Exclusion"
        # Step 1: Collect general error data
        for data_name, data_content in self.data.items():
            for scheme, entity_confusion in [('IOB', data_content.entity_non_strict_confusion_data), 
                                             ('IOB2', data_content.entity_strict_confusion_data)]:
                # Process general errors (Entity, Boundary, Entity+Boundary, Exclusion)
                _, entity_errors = self.process_entity_confusion(entity_confusion[component], o_error)

                # Annotate with Model and Scheme
                entity_errors['Model'] = data_name
                entity_errors['Scheme'] = scheme
                entity_errors = entity_errors.rename(columns=self.tag_mapping)
                matrix_data.append(entity_errors)
        
        # Step 2: Combine and process data
        matrix_df = pd.concat(matrix_data)
        matrix_df.reset_index(inplace=True)
        matrix_df.rename(columns={'index': 'Tag'}, inplace=True)
        matrix_df = self.replace_mappings(matrix_df)
        
        
        
        # Step 3: Melt raw counts for visualization
        melted_df = melted_df = matrix_df.melt(
            id_vars=['Tag', 'Model', 'Scheme'],
            var_name="Error Type",
            value_name="Raw Count"
        )
        return melted_df

    def visualize(self, component):
        entity_errors_df = self.prepare_data(component)
        
        title_component = "False Positives" if component == 'false_positives' else "False Negatives"
        # Step 4: Create heatmap for raw errors
        fig = make_subplots(
            rows=2, cols=2,
            subplot_titles=[f"{scheme} - {model}" for scheme in entity_errors_df['Scheme'].unique() 
                            for model in entity_errors_df['Model'].unique()],
            shared_xaxes=True, shared_yaxes=True, horizontal_spacing=0.05, vertical_spacing=0.1
        )
        
        unique_schemes = entity_errors_df['Scheme'].unique()
        unique_models = entity_errors_df['Model'].unique()
        
        for i, scheme in enumerate(unique_schemes):
            for j, model in enumerate(unique_models):
                entity_errors = entity_errors_df[(entity_errors_df['Scheme'] == scheme) &
                                                  (entity_errors_df['Model'] == model)]
                
                pivot_data = entity_errors.pivot(index='Error Type', columns='Tag', values='Raw Count')
                print(pivot_data)
                
                fig.add_trace(
                    go.Heatmap(
                        z=pivot_data.values,
                        x=pivot_data.columns,
                        y=pivot_data.index,
                        coloraxis="coloraxis",
                        text=pivot_data.values,
                        texttemplate="%{text}",
                        hovertemplate="Tag: %{x}<br>Error Type: %{y}<br>Count: %{text}<extra></extra>"
                    ),
                    row=i + 1, col=j + 1
                )
                if i>0 or j>1:
                    fig.update_xaxes(title_text="True Entity", row=i + 1, col=j + 1)
                fig.update_yaxes(title_text="Predicted Entity", row=i + 1, col=j + 1)
        
        fig.update_layout(
            coloraxis=dict(colorscale='RdBu_r', colorbar=dict(title="Error Count")),
            title_text=f"{title_component} Entity Errors Heatmap: by Entity Tag, Categorized by Model and Tagging Scheme",
            template="plotly_white",
            height=600, width=1000,
        )
        fig.show()
    

from abc import ABC, abstractmethod
from collections import defaultdict
from seqeval.scheme import Entities, IOB2, IOB1
from seqeval.metrics.sequence_labeling import get_entities
pd.set_option("display.max_rows", None)  # Display all rows


class EntityErrorAnalyzer(ABC):
    """Abstract base class for entity analysis."""

    def __init__(self, df):
        self.df = df
        self.y_true, self.y_pred = self.prepare_data(df)
        self.true_entities = []
        self.pred_entities = []

    @abstractmethod
    def extract_entities(self, y_data):
        """Extract entities based on the specific mode (strict or non-strict)."""
        pass

    @abstractmethod
    def prepare_entities(self):
        """Prepare true and predicted entities for analysis."""
        pass
    
    def prepare_data(self, df):
        core_data = df[df['Labels'] !=-100]
        y_true = core_data.groupby('Sentence Ids')['True Labels'].apply(list).tolist()
        y_pred = core_data.groupby('Sentence Ids')['Pred Labels'].apply(list).tolist()
        return y_true, y_pred
    
    def compute_false_negatives(self, entity_type):
        """Compute false negatives for a specific entity type."""
        return set(
            [e for e in self.true_entities if e[1] == entity_type]
        ) - set([e for e in self.pred_entities if e[1] == entity_type])

    def compute_false_positives(self, entity_type):
        """Compute false positives for a specific entity type."""
        return set(
            [e for e in self.pred_entities if e[1] == entity_type]
        ) - set([e for e in self.true_entities if e[1] == entity_type])

    def analyze_sentence_errors(self, target_entities, comparison_entities):
        """Analyze errors and return sentence IDs by error type."""
        error_sentences = defaultdict(set)  # Dictionary to hold sentence IDs for each error type
        non_o_errors = set()
        indexed_entities = defaultdict(list)

        # Index comparison entities by sentence
        for entity in comparison_entities:
            sen, entity_type, start, end = entity
            indexed_entities[sen].append(entity)

        # First pass: entity errors
        for target_entity in target_entities:
            t_sen, t_type, t_start, t_end = target_entity

            for comp_entity in indexed_entities[t_sen]:
                c_type, c_start, c_end = comp_entity[1:]

                if (
                    t_start == c_start
                    and t_end == c_end
                    and t_type != c_type
                    and target_entity not in non_o_errors
                ):
                    non_o_errors.add(target_entity)
                    error_sentences["Entity"].add(target_entity)

        # Second pass: boundary errors
        for target_entity in target_entities - non_o_errors:
            t_sen, t_type, t_start, t_end = target_entity

            for comp_entity in indexed_entities[t_sen]:
                c_sen, c_type, c_start, c_end = comp_entity

                if (
                    t_type == c_type
                    and (t_start <= c_start <= t_end or t_start <= c_end <= t_end)
                    and target_entity not in non_o_errors
                ):
                    non_o_errors.add(target_entity)
                    error_sentences["Boundary"].add(target_entity)

        # Third pass: combined entity and boundary errors
        for target_entity in target_entities - non_o_errors:
            t_sen, t_type, t_start, t_end = target_entity

            for comp_entity in indexed_entities[t_sen]:
                c_sen, c_type, c_start, c_end = comp_entity

                if (
                    c_type != t_type
                    and (t_start <= c_start <= t_end or t_start <= c_end <= t_end)
                    and target_entity not in non_o_errors
                ):
                    non_o_errors.add(target_entity)
                    error_sentences["Entity and Boundary"].add(target_entity)
                    # print(t_sen, t_start, t_end, c_sen, c_start, c_end)
                    # print(f' ({t_start} <= {c_start} <= {t_end} or {t_start} <= {c_end} <= {t_end})')
                    

        # Remaining unmatched errors are "O errors"
        for target_entity in target_entities - non_o_errors:
            t_sen, t_type, t_start, t_end = target_entity
            error_sentences["O"].add(target_entity)

        return {error_type: list(s_ids) for error_type, s_ids in error_sentences.items()}


    def analyze_component(self, error_type, entity_type=None):
        
        """Analyze errors (FP or FN) for a specific or all entity types."""
        self.prepare_entities()
        error_analysis = {}
        entity_types = (
            [entity_type]
            if entity_type
            else set(e[1] for e in self.true_entities + self.pred_entities)
        )

        for etype in entity_types:
            if error_type == "false_negatives":
                target_entities = self.compute_false_negatives(etype)
            elif error_type == "false_positives":
                target_entities = self.compute_false_positives(etype)
            else:
                raise ValueError("Error type must be 'false_negative' or 'false_positive'.")

            error_analysis[etype] = self.analyze_sentence_errors(
                target_entities, self.pred_entities if error_type == "false_negatives" else self.true_entities
            )

        return error_analysis
    def analyze_errors(self):
        self.prepare_entities()
        """Analyze both false positives and false negatives."""
        error_components = {"false_positives": defaultdict(set), "false_negatives": defaultdict(set)}

        for error_component in error_components.keys():
            results = self.analyze_component(error_component)
            for entity_type, errors in results.items():
                for error_type, sentences in errors.items():
                    error_components[error_component][error_type].update(sentences)

        # Convert sets to lists for consistency
        return {k: {etype: set(ids) for etype, ids in v.items()} for k, v in error_components.items()}
    
    


class StrictEntityAnalyzer(EntityErrorAnalyzer):
    """Analyzer for strict entity processing."""

    def extract_entities(self, y_data):
        """Extract entities in strict mode."""
        entities = Entities(y_data, IOB2, False)
        return self.adjust_end_index(entities)

    def prepare_entities(self):
        """Prepare true and predicted entities for strict mode."""
        self.true_entities = self.flatten_entities(self.extract_entities(self.y_true))
        self.pred_entities = self.flatten_entities(self.extract_entities(self.y_pred))

    def print_sentence(self, sen_id):
        """Print entities for a specific sentence ID."""
        true_entities = self.extract_entities(self.y_true).entities
        pred_entities = self.extract_entities(self.y_pred).entities
        print(f"True: {true_entities[sen_id]}")
        print(f"Pred: {pred_entities[sen_id]}")
        error = set(pred_entities[sen_id]) - set(true_entities[sen_id])
        print(f"Error in Pred: {error}")
        core_data = self.df[self.df['Labels'] !=-100]
        sentence_data = core_data[core_data['Sentence Ids']  == sen_id].copy()
        print(sentence_data[['Words', 'Sentence Ids', 'True Labels', 'Pred Labels', 'Strict True Entities', 'Strict Pred Entities', 'True Entities', 'Pred Entities']].head(60).to_string())

    @staticmethod
    def flatten_entities(entities):
        """Flatten strict entities into tuples."""
        return [e for sen in entities.entities for e in sen]
    
    @staticmethod
    def adjust_end_index(entities):
        """Adjust the end index for IOB2 entities to make them inclusive."""
        adjusted_entities = []
        for sentence_entities in entities.entities:  # Iterate through sentences
            adjusted_sentence = []
            for entity in sentence_entities:  # Iterate through entities in each sentence
                sentence_id, entity_type, start, end = entity.to_tuple()
                # Adjust end index
                adjusted_sentence.append((sentence_id, entity_type, start, end - 1))
            adjusted_entities.append(adjusted_sentence)
        entities.entities = adjusted_entities  # Replace with adjusted entities
        return entities
    
    
    
    
class NonStrictEntityAnalyzer(EntityErrorAnalyzer):
    """Analyzer for non-strict entity processing."""

    def extract_entities(self, y_data):
        """Extract entities in non-strict mode."""
        return [
            [(sen_id,) + entity for entity in get_entities(sen)]
            for sen_id, sen in enumerate(y_data)
        ]

    def prepare_entities(self):
        """Prepare true and predicted entities for non-strict mode."""
        self.true_entities = self.flatten_entities(self.extract_entities(self.y_true))
        self.pred_entities = self.flatten_entities(self.extract_entities(self.y_pred))

    def print_sentence(self, sen_id):
        """Print entities for a specific sentence ID."""
        true_entities = self.extract_entities(self.y_true)
        pred_entities = self.extract_entities(self.y_pred)
        print(f"True: {true_entities[sen_id]}")
        print(f"Pred: {pred_entities[sen_id]}")
        error = set(pred_entities[sen_id]) - set(true_entities[sen_id])
        print(f"Error in Pred: {error}")
        core_data = self.df[self.df['Labels'] !=-100]
        sentence_data = core_data[core_data['Sentence Ids']  == sen_id].copy()
        print(sentence_data[['Words', 'Sentence Ids', 'True Labels', 'Pred Labels', 'Strict True Entities', 'Strict Pred Entities', 'True Entities', 'Pred Entities']].head(60).to_string())
        
    @staticmethod
    def flatten_entities(entities):
        """Flatten non-strict entities into tuples."""
        return [e for sen in entities for e in sen]

class ErrorAnalysisManager:
    """Manages all error analysis workflows and stores results."""

    def __init__(self, df):
        """
        Initialize the manager with the dataset.

        Args:
            df (pd.DataFrame): The dataset containing y_true and y_pred.
        """
        self.df = df
        self.strict_analyzer = StrictEntityAnalyzer(df)
        self.non_strict_analyzer = NonStrictEntityAnalyzer(df)
        self.results = {
            "IOB2": {"false_negatives": None, "false_positives": None, "errors": None},
            "IOB": {"false_negatives": None, "false_positives": None, "errors": None},
        }

    def run_workflows(self):
        """Run all error analysis workflows."""
        self.results["IOB2"]["false_negatives"] = self.strict_analyzer.analyze_component("false_negatives")
        self.results["IOB2"]["false_positives"] = self.strict_analyzer.analyze_component("false_positives")
        self.results["IOB2"]["errors"] = self.strict_analyzer.analyze_errors()

        self.results["IOB"]["false_negatives"] = self.non_strict_analyzer.analyze_component("false_negatives")
        self.results["IOB"]["false_positives"] = self.non_strict_analyzer.analyze_component("false_positives")
        self.results["IOB"]["errors"] = self.non_strict_analyzer.analyze_errors()

    def get_results(self):
        """Get the results of all workflows."""
        return self.results

class SchemeComparator:
    """Facilitator for comparing annotation schemes."""

    def __init__(self, results):
        """
        Initialize the comparator with results from error analysis.

        Args:
            results (dict): Results from the manager's workflows, structured by scheme.
        """
        self.results = results

    def compare_component(self, component, entity_type):
        """
        Compare all error types for a specific entity across schemes.

        Args:
            entity_type (str): The entity type to compare (e.g., "MISC").

        Returns:
            dict: A dictionary with set operation results for all error types.
        """
        schemes = list(self.results.keys())
        if len(schemes) != 2:
            raise ValueError("Comparator requires exactly two schemes for comparison.")

        scheme_1, scheme_2 = schemes
        component_1 = self.results[scheme_1][component]
        component_2 = self.results[scheme_2][component]

        results = {}
        entity_1 = component_1.get(entity_type, {})
        entity_2 = component_2.get(entity_type, {})

        # Compare all error types under the given entity
        all_error_types = set(entity_1.keys()).union(set(entity_2.keys()))
        for error_type in all_error_types:
            set_1 = set(entity_1.get(error_type, []))
            set_2 = set(entity_2.get(error_type, []))

            results[error_type] = {
                "overlap": set_1 & set_2,
                f"{scheme_1} Only": set_1 - set_2,
                f"{scheme_2} Only": set_2 - set_1,
            }

        return results

    def compare_errors(self, component, error_type):
        """
        Compare errors across all entities and error types for both schemes.

        Returns:
            dict: A dictionary with set operation results for all error types.
        """
        schemes = list(self.results.keys())
        if len(schemes) != 2:
            raise ValueError("Comparator requires exactly two schemes for comparison.")

        schemes_map = {'scheme_1': 'IOB', 'scheme_2': 'IOB2'}
        errors_1 = self.results[schemes_map['scheme_1']]["errors"][component]
        errors_2 = self.results[schemes_map['scheme_2']]["errors"][component]

       
       
        comparison_result = ComparisonResult.from_lists(errors_1, errors_2, error_type, schemes_map)

        return comparison_result.to_dict()


from dataclasses import dataclass, field
from typing import List, Dict, Set

@dataclass
class ComparisonResult:
    """Dataclass to store comparison results."""
    scheme_1_name: str
    scheme_2_name: str
    set_1_errors: Set[int] = field(default=set)
    set_2_errors: Set[int] = field(default=set)
    overlap: Set[int] = field(default_factory=set)
    scheme_1_only: Set[int] = field(default_factory=set)
    scheme_2_only: Set[int] = field(default_factory=set)

    @staticmethod
    def from_lists(errors_1: Dict, errors_2: Dict, error_type: str, schemes_map: Dict) -> "ComparisonResult":
        """
        Create a ComparisonResult from two lists.

        Args:
            lst_1: List of values from scheme 1.
            lst_2: List of values from scheme 2.

        Returns:
            ComparisonResult: Dataclass containing the comparison and statistics.
        """
        set_1 = set(errors_1.get(error_type, []))
        
        set_2 = set(errors_2.get(error_type, []))
        
        sentence_lst_1 = [error[0] for error in errors_1.get(error_type, [])]
        sentence_lst_2 = [error[0] for error in errors_2.get(error_type, [])]
        sentence_set_1 = set(sentence_lst_1)
        sentence_set_2 = set(sentence_lst_2)
        
        overlap = sentence_set_1 & sentence_set_2
        scheme_1_only = sentence_set_1 - sentence_set_2
        scheme_2_only = sentence_set_2 - sentence_set_1

        return ComparisonResult(
            scheme_1_name=schemes_map['scheme_1'],
            scheme_2_name=schemes_map['scheme_2'],
            set_1_errors= set_1,
            set_2_errors= set_2,
            overlap=overlap,
            scheme_1_only=scheme_1_only,
            scheme_2_only=scheme_2_only,
        )
        
    def to_dict(self) -> Dict[str, Dict[str, Set[int]]]:
        """R"Overlap": self.overlap, comparison results as a dictionary."""
        return {
            f"{self.scheme_1_name} Errors": self.set_1_errors,
            f"{self.scheme_2_name} Errors": self.set_2_errors,
            "Overlap": self.overlap,
            f"{self.scheme_1_name} Only Errors": self.scheme_1_only,
            f"{self.scheme_2_name} Only Errors": self.scheme_2_only,
        }





In [3]:
CONFIG_PATH = Path("/Users/ay227/Desktop/Final-Year/Thesis-Experiments/Online-Dashboard-Phase/dashboard-config.yaml")
config_manager = DashboardConfigManager(CONFIG_PATH)
dev_config = config_manager.development_config    

app = Dash(__name__, suppress_callback_exceptions=True)

app_config = config_manager.app_config
server = app.server  # Flask server instance for caching
variants_data = None

data_manager = DataManager(config_manager, server)
dash_data = data_manager.load_data()

In [4]:
df=dash_data['ANERCorp_CamelLab_arabertv02'].analysis_data

In [6]:
df.head()

Unnamed: 0,Sentence Ids,Token Positions,Words,Tokens,Word Pieces,Core Tokens,True Labels,Token Selector Id,Pred Labels,Agreements,...,Strict Pred Entities,True Entities,Pred Entities,True Aligned Scheme,Pred Aligned Scheme,Consistency Ratio,Inconsistency Ratio,Normalized Token Entropy,Normalized Word Entropy,Normalized Prediction Entropy
0,0,0,[CLS],[CLS],[CLS],[CLS],[CLS],[CLS]@#0@#0,[CLS],True,...,[CLS],[CLS],[CLS],True,True,0.0,0.0,-1.0,-1.0,0.000726
1,0,1,الصالحية,الصالحية,الصالحية,الصالحية,B-LOC,الصالحية@#1@#0,O,False,...,O,LOC,O,True,True,0.0,0.0,-1.0,-1.0,0.409098
2,0,2,المفرق,المفرق,المفرق,المفرق,B-LOC,المفرق@#2@#0,B-LOC,True,...,LOC,LOC,LOC,True,True,0.0,1.0,0.0,0.0,0.056689
3,0,3,-,-,-,-,O,-@#3@#0,O,True,...,O,O,O,True,True,1.0,0.0,0.0,0.0,0.002177
4,0,4,غيث,غيث,غيث,غيث,B-PERS,غيث@#4@#0,B-PERS,True,...,PERS,PERS,PERS,True,True,0.0,0.0,-1.0,-1.0,0.057604


In [4]:
report_data = []
for data_name, data_content in dash_data.items():
	entity_report = data_content.entity_non_strict_report
	entity_strict_report = data_content.entity_strict_report
	entity_report['Model'] = data_name
	entity_report['Scheme'] = 'IOB'
	entity_strict_report['Model'] = data_name
	entity_strict_report['Scheme'] = 'IOB2'
	report_data.append(pd.concat([
		entity_report, 
		entity_strict_report
	]))
report_df = pd.concat(report_data)
# report_df = report_df[~report_df['Tag'].isin(['micro', 'macro', 'weighted'])]    
tag_mapping = {
    'PERS': 'PER'
}

dataset_mapping = {
    'ANERCorp_CamelLab_arabertv02': 'AraBERTv02',
    'conll2003_bert': 'BERT'
}

report_df['Tag'] = report_df['Tag'].replace(tag_mapping)
report_df['Model'] = report_df['Model'].replace(dataset_mapping)
report_df


Unnamed: 0,Tag,Precision,Recall,F1,Support,Model,Scheme
0,LOC,0.8919,0.9275,0.9094,676,AraBERTv02,IOB
1,MISC,0.7366,0.6214,0.6741,243,AraBERTv02,IOB
2,ORG,0.763,0.7364,0.7494,459,AraBERTv02,IOB
3,PER,0.8835,0.8298,0.8558,905,AraBERTv02,IOB
4,micro,0.8483,0.8178,0.8327,2283,AraBERTv02,IOB
5,macro,0.8187,0.7788,0.7972,2283,AraBERTv02,IOB
6,weighted,0.8461,0.8178,0.831,2283,AraBERTv02,IOB
0,LOC,0.8927,0.9341,0.9129,668,AraBERTv02,IOB2
1,MISC,0.772,0.634,0.6963,235,AraBERTv02,IOB2
2,ORG,0.7842,0.7511,0.7673,450,AraBERTv02,IOB2


In [5]:
# Confusion matrix heatmap example
report_bar = ReportBarChart(dash_data, mappings)
report_bar.visualize_f1()

In [6]:
# Confusion matrix heatmap example
report_bar = ReportBarChart(dash_data, mappings)
report_bar.visualize()

In [7]:
confusion_bar = ConfusionBarChart(dash_data, mappings)
confusion_bar.visualize()

     Tag       Model Scheme Metric     Scale  Count
0    ORG  AraBERTv02    IOB     TP  0.599291    338
1    LOC  AraBERTv02    IOB     TP  0.833777    627
2   MISC  AraBERTv02    IOB     TP  0.508418    151
3    PER  AraBERTv02    IOB     TP  0.748008    751
4    ORG  AraBERTv02   IOB2     TP  0.622468    338
5    LOC  AraBERTv02   IOB2     TP  0.839838    624
6   MISC  AraBERTv02   IOB2     TP  0.534050    149
7    PER  AraBERTv02   IOB2     TP  0.741803    724
8    ORG        BERT    IOB     TP  0.814595   1507
9    LOC        BERT    IOB     TP  0.862375   1554
10  MISC        BERT    IOB     TP  0.676023    578
11   PER        BERT    IOB     TP  0.923855   1553
12   ORG        BERT   IOB2     TP  0.821058   1505
13   LOC        BERT   IOB2     TP  0.864775   1554
14  MISC        BERT   IOB2     TP  0.686977    575
15   PER        BERT   IOB2     TP  0.925507   1553
16   ORG  AraBERTv02    IOB     FP  0.186170    105
17   LOC  AraBERTv02    IOB     FP  0.101064     76
18  MISC  Ar

In [8]:
# Confusion matrix heatmap example
report_bar = ReportBarChart(dash_data, mappings)
report_bar.visualize_support()

In [9]:
# Confusion matrix heatmap example
confusion_heatmap = ConfusionHeatmap(dash_data, mappings)
confusion_heatmap.visualize()

In [10]:


error_type_heatmap = ErrorTypeHeatmap(dash_data, mappings)
error_type_heatmap.visualize_table('false_positives')


error_type_heatmap = ErrorTypeHeatmap(dash_data, mappings)
error_type_heatmap.visualize('false_positives')


entity_errors_heatmap = EntityErrorsHeatmap(dash_data, mappings) 
entity_errors_heatmap.visualize('false_positives')





### Table for Scheme: IOB, Model: AraBERTv02 ###

         Error Type  Raw Count  Percentage (%)
           Boundary        116           34.73
             Entity         83           24.85
Entity and Boundary         29            8.68
          Inclusion        106           31.74

### Table for Scheme: IOB, Model: BERT ###

         Error Type  Raw Count  Percentage (%)
           Boundary         78           14.44
             Entity        233           43.15
Entity and Boundary         86           15.93
          Inclusion        143           26.48

### Table for Scheme: IOB2, Model: AraBERTv02 ###

         Error Type  Raw Count  Percentage (%)
           Boundary        101           30.61
             Entity         76           23.03
Entity and Boundary         22            6.67
          Inclusion        131           39.70

### Table for Scheme: IOB2, Model: BERT ###

         Error Type  Raw Count  Percentage (%)
           Boundary         69           13.88
       

Tag         LOC  MISC  ORG  PER
Error Type                     
LOC           0     3    3    1
MISC          6     0    9    0
ORG          20     6    0   11
PER           6     5   13    0
Tag         LOC  MISC  ORG  PER
Error Type                     
LOC           0    14   43    5
MISC         19     0   31    3
ORG          40    28    0   16
PER          12     2   20    0
Tag         LOC  MISC  ORG  PER
Error Type                     
LOC           0     3    3    1
MISC          4     0    9    0
ORG          19     6    0   11
PER           6     3   11    0
Tag         LOC  MISC  ORG  PER
Error Type                     
LOC           0    14   42    5
MISC         19     0   30    3
ORG          40    28    0   16
PER          12     2   20    0


In [76]:
error_type_heatmap = ErrorTypeHeatmap(dash_data, mappings)
error_type_heatmap.visualize_table('false_negatives')

error_type_heatmap = ErrorTypeHeatmap(dash_data, mappings)
error_type_heatmap.visualize('false_negatives')

entity_errors_heatmap = EntityErrorsHeatmap(dash_data, mappings) 
entity_errors_heatmap.visualize('false_negatives')



### Table for Scheme: IOB, Model: AraBERTv02 ###

         Error Type  Raw Count  Percentage (%)
           Boundary        128           30.77
             Entity         83           19.95
Entity and Boundary         32            7.69
          Exclusion        173           41.59

### Table for Scheme: IOB, Model: BERT ###

         Error Type  Raw Count  Percentage (%)
           Boundary         81           17.76
             Entity        233           51.10
Entity and Boundary         77           16.89
          Exclusion         65           14.25

### Table for Scheme: IOB2, Model: AraBERTv02 ###

         Error Type  Raw Count  Percentage (%)
           Boundary        113           30.05
             Entity         76           20.21
Entity and Boundary         32            8.51
          Exclusion        155           41.22

### Table for Scheme: IOB2, Model: BERT ###

         Error Type  Raw Count  Percentage (%)
           Boundary         76           16.49
       

Tag         LOC  MISC  ORG  PER
Error Type                     
LOC           0     6   20    6
MISC          3     0    6    5
ORG           3     9    0   13
PER           1     0   11    0
Tag         LOC  MISC  ORG  PER
Error Type                     
LOC           0    19   40   12
MISC         14     0   28    2
ORG          43    31    0   20
PER           5     3   16    0
Tag         LOC  MISC  ORG  PER
Error Type                     
LOC           0     4   19    6
MISC          3     0    6    3
ORG           3     9    0   11
PER           1     0   11    0
Tag         LOC  MISC  ORG  PER
Error Type                     
LOC           0    19   40   12
MISC         14     0   28    2
ORG          42    30    0   20
PER           5     3   16    0


In [59]:
df = dash_data['conll2003_bert'].analysis_data

In [77]:
dash_data['conll2003_bert'].token_confusion_matrix

{'confusion_matrix': {'B-LOC': {'TP': 1563, 'FP': 120, 'FN': 105, 'TN': 44647},
  'I-ORG': {'TP': 777, 'FP': 90, 'FN': 58, 'TN': 45510},
  'B-MISC': {'TP': 589, 'FP': 121, 'FN': 113, 'TN': 45612},
  'I-MISC': {'TP': 161, 'FP': 73, 'FN': 55, 'TN': 46146},
  'I-PER': {'TP': 1145, 'FP': 15, 'FN': 11, 'TN': 45264},
  'O': {'TP': 38088, 'FP': 123, 'FN': 235, 'TN': 7989},
  'B-ORG': {'TP': 1523, 'FP': 154, 'FN': 138, 'TN': 44620},
  'B-PER': {'TP': 1560, 'FP': 54, 'FN': 57, 'TN': 44764},
  'I-LOC': {'TP': 237, 'FP': 42, 'FN': 20, 'TN': 46136}},
 'false_negatives': {'B-LOC': {'I-ORG': 2,
   'B-MISC': 25,
   'O': 24,
   'B-ORG': 57,
   'B-PER': 12},
  'I-ORG': {'B-LOC': 6,
   'B-MISC': 1,
   'I-MISC': 17,
   'I-PER': 5,
   'O': 38,
   'B-ORG': 10,
   'B-PER': 1,
   'I-LOC': 12},
  'B-MISC': {'B-LOC': 27, 'I-MISC': 5, 'O': 53, 'B-ORG': 34, 'B-PER': 2},
  'I-MISC': {'B-LOC': 2, 'I-ORG': 13, 'B-MISC': 11, 'I-PER': 1, 'O': 46},
  'I-PER': {'I-ORG': 4, 'I-MISC': 1, 'O': 3, 'B-PER': 5, 'I-LOC': 2},


In [60]:
manager = ErrorAnalysisManager(df)
manager.run_workflows()
results = manager.get_results()

In [61]:
len(results['IOB2']['false_negatives']['LOC']['Boundary'])

10

In [70]:
comparator = SchemeComparator(results)
component_comparison = comparator.compare_component("false_negatives", "LOC")
component_comparison
overall_comparison = comparator.compare_errors('false_negatives', 'Entity and Boundary')


In [71]:
overall_comparison

{'IOB Errors': {(294, 'MISC', 2, 3),
  (367, 'LOC', 0, 0),
  (385, 'LOC', 0, 0),
  (437, 'LOC', 0, 0),
  (466, 'LOC', 0, 0),
  (554, 'LOC', 3, 3),
  (554, 'ORG', 2, 2),
  (557, 'MISC', 8, 9),
  (559, 'LOC', 22, 22),
  (700, 'ORG', 3, 4),
  (873, 'MISC', 2, 3),
  (966, 'ORG', 14, 14),
  (1097, 'ORG', 2, 2),
  (1103, 'MISC', 0, 0),
  (1105, 'MISC', 16, 16),
  (1119, 'MISC', 2, 3),
  (1171, 'MISC', 12, 13),
  (1376, 'LOC', 15, 15),
  (1430, 'LOC', 2, 2),
  (1430, 'LOC', 4, 5),
  (1550, 'MISC', 5, 8),
  (1567, 'ORG', 7, 8),
  (1751, 'LOC', 1, 1),
  (1762, 'LOC', 9, 10),
  (1807, 'LOC', 5, 6),
  (1877, 'LOC', 31, 31),
  (2085, 'PER', 28, 29),
  (2132, 'LOC', 25, 25),
  (2166, 'MISC', 0, 0),
  (2184, 'PER', 1, 2),
  (2231, 'LOC', 7, 7),
  (2251, 'LOC', 1, 1),
  (2263, 'MISC', 1, 1),
  (2265, 'MISC', 0, 1),
  (2266, 'MISC', 0, 2),
  (2270, 'MISC', 0, 1),
  (2274, 'MISC', 0, 1),
  (2275, 'MISC', 0, 1),
  (2355, 'LOC', 19, 19),
  (2393, 'LOC', 14, 14),
  (2433, 'MISC', 5, 6),
  (2437, 'ORG', 15

In [74]:
manager.non_strict_analyzer.print_sentence(2068)

True: [(2068, 'PER', 0, 1), (2068, 'ORG', 8, 12), (2068, 'LOC', 14, 14)]
Pred: [(2068, 'PER', 0, 1), (2068, 'LOC', 8, 10), (2068, 'ORG', 11, 11), (2068, 'LOC', 14, 14)]
Error in Pred: {(2068, 'ORG', 11, 11), (2068, 'LOC', 8, 10)}
             Words  Sentence Ids True Labels Pred Labels Strict True Entities Strict Pred Entities True Entities Pred Entities
45899         Mike          2068       B-PER       B-PER                  PER                  PER           PER           PER
45900         Cito          2068       I-PER       I-PER                  PER                  PER           PER           PER
45902            ,          2068           O           O                    O                    O             O             O
45903           17          2068           O           O                    O                    O             O             O
45904            ,          2068           O           O                    O                    O             O             O
45905   

In [75]:
manager.strict_analyzer.print_sentence(2068)

True: [(2068, 'PER', 0, 1), (2068, 'ORG', 8, 12), (2068, 'LOC', 14, 14)]
Pred: [(2068, 'PER', 0, 1), (2068, 'LOC', 8, 10), (2068, 'LOC', 14, 14)]
Error in Pred: {(2068, 'LOC', 8, 10)}
             Words  Sentence Ids True Labels Pred Labels Strict True Entities Strict Pred Entities True Entities Pred Entities
45899         Mike          2068       B-PER       B-PER                  PER                  PER           PER           PER
45900         Cito          2068       I-PER       I-PER                  PER                  PER           PER           PER
45902            ,          2068           O           O                    O                    O             O             O
45903           17          2068           O           O                    O                    O             O             O
45904            ,          2068           O           O                    O                    O             O             O
45905          was          2068           O          

In [59]:
core_data = df[df['Labels']!=-100].copy()
y_true = core_data.groupby('Sentence Ids')['True Labels'].apply(list).tolist()
y_pred = core_data.groupby('Sentence Ids')['Pred Labels'].apply(list).tolist()

In [60]:
ids = 84
print(get_entities(y_true[ids]))
print(get_entities(y_pred[ids]))
print('######')
print(manager.strict_analyzer.adjust_end_index(Entities([y_true[ids]], IOB2, False)).entities)
print(manager.strict_analyzer.adjust_end_index(Entities([y_pred[ids]], IOB2, False)).entities)



[('PERS', 7, 8), ('PERS', 20, 20)]
[('PERS', 7, 8), ('PERS', 19, 19), ('PERS', 20, 20)]
######
[[(0, 'PERS', 7, 8), (0, 'PERS', 20, 20)]]
[[(0, 'PERS', 7, 8), (0, 'PERS', 19, 19), (0, 'PERS', 20, 20)]]


In [61]:
manager.strict_analyzer.adjust_end_index(Entities([y_true[ids]], IOB2, False)).entities

[[(0, 'PERS', 7, 8), (0, 'PERS', 20, 20)]]

True: [(84, 'PERS', 7, 8), (84, 'PERS', 20, 20)]
Pred: [(84, 'PERS', 7, 8), (84, 'PERS', 19, 19), (84, 'PERS', 20, 20)]
         Words  Sentence Ids True Labels Pred Labels Strict True Entities Strict Pred Entities True Entities Pred Entities
3063     [CLS]            84       [CLS]       [CLS]                [CLS]                [CLS]         [CLS]         [CLS]
3064    المكون            84           O           O                    O                    O             O             O
3065       عقد            84           O           O                    O                    O             O             O
3066    المنسق            84           O           O                    O                    O             O             O
3067    الأعلى            84           O           O                    O                    O             O             O
3068   للسياسة            84           O           O                    O                    O             O             O
3069  الخارجية     

True: [(84, 'PERS', 7, 8), (84, 'PERS', 20, 20)]
Pred: [(84, 'PERS', 7, 8), (84, 'PERS', 19, 19), (84, 'PERS', 20, 20)]
         Words  Sentence Ids True Labels Pred Labels Strict True Entities Strict Pred Entities True Entities Pred Entities
3063     [CLS]            84       [CLS]       [CLS]                [CLS]                [CLS]         [CLS]         [CLS]
3064    المكون            84           O           O                    O                    O             O             O
3065       عقد            84           O           O                    O                    O             O             O
3066    المنسق            84           O           O                    O                    O             O             O
3067    الأعلى            84           O           O                    O                    O             O             O
3068   للسياسة            84           O           O                    O                    O             O             O
3069  الخارجية     

# Error analysis pipeline

In [425]:
from abc import ABC, abstractmethod
from collections import defaultdict
from seqeval.scheme import Entities, IOB2
from seqeval.metrics.sequence_labeling import get_entities
pd.set_option("display.max_rows", None)  # Display all rows


class EntityErrorAnalyzer(ABC):
    """Abstract base class for entity analysis."""

    def __init__(self, df):
        self.df = df
        self.y_true, self.y_pred = self.prepare_data(df)
        self.true_entities = []
        self.pred_entities = []

    @abstractmethod
    def extract_entities(self, y_data):
        """Extract entities based on the specific mode (strict or non-strict)."""
        pass

    @abstractmethod
    def prepare_entities(self):
        """Prepare true and predicted entities for analysis."""
        pass
    
    def prepare_data(self, df):
        core_data = df[df['Labels'] !=-100]
        y_true = core_data.groupby('Sentence Ids')['True Labels'].apply(list).tolist()
        y_pred = core_data.groupby('Sentence Ids')['Pred Labels'].apply(list).tolist()
        return y_true, y_pred
    
    def compute_false_negatives(self, entity_type):
        """Compute false negatives for a specific entity type."""
        return set(
            [e for e in self.true_entities if e[1] == entity_type]
        ) - set([e for e in self.pred_entities if e[1] == entity_type])

    def compute_false_positives(self, entity_type):
        """Compute false positives for a specific entity type."""
        return set(
            [e for e in self.pred_entities if e[1] == entity_type]
        ) - set([e for e in self.true_entities if e[1] == entity_type])

    def analyze_sentence_errors(self, target_entities, comparison_entities):
        """Analyze errors and return sentence IDs by error type."""
        error_sentences = defaultdict(set)  # Dictionary to hold sentence IDs for each error type
        non_o_errors = set()
        indexed_entities = defaultdict(list)

        # Index comparison entities by sentence
        for entity in comparison_entities:
            sen, entity_type, start, end = entity
            indexed_entities[sen].append(entity)

        # First pass: entity errors
        for target_entity in target_entities:
            t_sen, t_type, t_start, t_end = target_entity

            for comp_entity in indexed_entities[t_sen]:
                c_type, c_start, c_end = comp_entity[1:]

                if (
                    t_start == c_start
                    and t_end == c_end
                    and t_type != c_type
                    and target_entity not in non_o_errors
                ):
                    non_o_errors.add(target_entity)
                    error_sentences["Entity"].add(target_entity)

        # Second pass: boundary errors
        for target_entity in target_entities - non_o_errors:
            t_sen, t_type, t_start, t_end = target_entity

            for comp_entity in indexed_entities[t_sen]:
                c_sen, c_type, c_start, c_end = comp_entity

                if (
                    t_type == c_type
                    and (t_start <= c_start <= t_end or t_start <= c_end <= t_end)
                    and target_entity not in non_o_errors
                ):
                    non_o_errors.add(target_entity)
                    error_sentences["Boundary"].add(target_entity)

        # Third pass: combined entity and boundary errors
        for target_entity in target_entities - non_o_errors:
            t_sen, t_type, t_start, t_end = target_entity

            for comp_entity in indexed_entities[t_sen]:
                c_sen, c_type, c_start, c_end = comp_entity

                if (
                    c_type != t_type
                    and (t_start <= c_start <= t_end or t_start <= c_end <= t_end)
                    and target_entity not in non_o_errors
                ):
                    non_o_errors.add(target_entity)
                    error_sentences["Entity and Boundary"].add(target_entity)
                    # print(t_sen, t_start, t_end, c_sen, c_start, c_end)
                    # print(f' ({t_start} <= {c_start} <= {t_end} or {t_start} <= {c_end} <= {t_end})')
                    

        # Remaining unmatched errors are "O errors"
        for target_entity in target_entities - non_o_errors:
            t_sen, t_type, t_start, t_end = target_entity
            error_sentences["O"].add(target_entity)

        return {error_type: list(s_ids) for error_type, s_ids in error_sentences.items()}


    def analyze_component(self, error_type, entity_type=None):
        
        """Analyze errors (FP or FN) for a specific or all entity types."""
        self.prepare_entities()
        error_analysis = {}
        entity_types = (
            [entity_type]
            if entity_type
            else set(e[1] for e in self.true_entities + self.pred_entities)
        )

        for etype in entity_types:
            if error_type == "false_negatives":
                target_entities = self.compute_false_negatives(etype)
            elif error_type == "false_positives":
                target_entities = self.compute_false_positives(etype)
            else:
                raise ValueError("Error type must be 'false_negative' or 'false_positive'.")

            error_analysis[etype] = self.analyze_sentence_errors(
                target_entities, self.pred_entities if error_type == "false_negatives" else self.true_entities
            )

        return error_analysis
    def analyze_errors(self):
        self.prepare_entities()
        """Analyze both false positives and false negatives."""
        error_components = {"false_positives": defaultdict(set), "false_negatives": defaultdict(set)}

        for error_component in error_components.keys():
            results = self.analyze_component(error_component)
            for entity_type, errors in results.items():
                for error_type, sentences in errors.items():
                    error_components[error_component][error_type].update(sentences)

        # Convert sets to lists for consistency
        return {k: {etype: set(ids) for etype, ids in v.items()} for k, v in error_components.items()}
    
    

class StrictEntityAnalyzer(EntityErrorAnalyzer):
    """Analyzer for strict entity processing."""

    def extract_entities(self, y_data):
        """Extract entities in strict mode."""
        return Entities(y_data, IOB2, False)

    def prepare_entities(self):
        """Prepare true and predicted entities for strict mode."""
        self.true_entities = self.flatten_entities(self.extract_entities(self.y_true))
        self.pred_entities = self.flatten_entities(self.extract_entities(self.y_pred))

    def print_sentence(self, sen_id):
        """Print entities for a specific sentence ID."""
        true_entities = self.extract_entities(self.y_true).entities
        pred_entities = self.extract_entities(self.y_pred).entities
        print(f"True: {true_entities[sen_id]}")
        print(f"Pred: {pred_entities[sen_id]}")
        sentence_data = self.df[self.df['Sentence Ids']  == sen_id].copy()
        print(sentence_data[['Words', 'Sentence Ids', 'True Labels', 'Pred Labels', 'Strict True Entities', 'Strict Pred Entities', 'True Entities', 'Pred Entities']].head(60).to_string())

    @staticmethod
    def flatten_entities(entities):
        """Flatten strict entities into tuples."""
        return [e.to_tuple() for sen in entities.entities for e in sen]
    
    
    
class NonStrictEntityAnalyzer(EntityErrorAnalyzer):
    """Analyzer for non-strict entity processing."""

    def extract_entities(self, y_data):
        """Extract entities in non-strict mode."""
        return [
            [(sen_id,) + entity for entity in get_entities(sen)]
            for sen_id, sen in enumerate(y_data)
        ]

    def prepare_entities(self):
        """Prepare true and predicted entities for non-strict mode."""
        self.true_entities = self.flatten_entities(self.extract_entities(self.y_true))
        self.pred_entities = self.flatten_entities(self.extract_entities(self.y_pred))

    def print_sentence(self, sen_id):
        """Print entities for a specific sentence ID."""
        true_entities = self.extract_entities(self.y_true)
        pred_entities = self.extract_entities(self.y_pred)
        print(f"True: {true_entities[sen_id]}")
        print(f"Pred: {pred_entities[sen_id]}")
        sentence_data = self.df[self.df['Sentence Ids']  == sen_id].copy()
        print(sentence_data[['Words', 'Sentence Ids', 'True Labels', 'Pred Labels', 'Strict True Entities', 'Strict Pred Entities', 'True Entities', 'Pred Entities']].head(60).to_string())
        
    @staticmethod
    def flatten_entities(entities):
        """Flatten non-strict entities into tuples."""
        return [e for sen in entities for e in sen]

class ErrorAnalysisManager:
    """Manages all error analysis workflows and stores results."""

    def __init__(self, df):
        """
        Initialize the manager with the dataset.

        Args:
            df (pd.DataFrame): The dataset containing y_true and y_pred.
        """
        self.df = df
        self.strict_analyzer = StrictEntityAnalyzer(df)
        self.non_strict_analyzer = NonStrictEntityAnalyzer(df)
        self.results = {
            "IOB2": {"false_negatives": None, "false_positives": None, "errors": None},
            "IOB": {"false_negatives": None, "false_positives": None, "errors": None},
        }

    def run_workflows(self):
        """Run all error analysis workflows."""
        self.results["IOB2"]["false_negatives"] = self.strict_analyzer.analyze_component("false_negatives")
        self.results["IOB2"]["false_positives"] = self.strict_analyzer.analyze_component("false_positives")
        self.results["IOB2"]["errors"] = self.strict_analyzer.analyze_errors()

        self.results["IOB"]["false_negatives"] = self.non_strict_analyzer.analyze_component("false_negatives")
        self.results["IOB"]["false_positives"] = self.non_strict_analyzer.analyze_component("false_positives")
        self.results["IOB"]["errors"] = self.non_strict_analyzer.analyze_errors()

    def get_results(self):
        """Get the results of all workflows."""
        return self.results

class SchemeComparator:
    """Facilitator for comparing annotation schemes."""

    def __init__(self, results):
        """
        Initialize the comparator with results from error analysis.

        Args:
            results (dict): Results from the manager's workflows, structured by scheme.
        """
        self.results = results

    def compare_component(self, component, entity_type):
        """
        Compare all error types for a specific entity across schemes.

        Args:
            entity_type (str): The entity type to compare (e.g., "MISC").

        Returns:
            dict: A dictionary with set operation results for all error types.
        """
        schemes = list(self.results.keys())
        if len(schemes) != 2:
            raise ValueError("Comparator requires exactly two schemes for comparison.")

        scheme_1, scheme_2 = schemes
        component_1 = self.results[scheme_1][component]
        component_2 = self.results[scheme_2][component]

        results = {}
        entity_1 = component_1.get(entity_type, {})
        entity_2 = component_2.get(entity_type, {})

        # Compare all error types under the given entity
        all_error_types = set(entity_1.keys()).union(set(entity_2.keys()))
        for error_type in all_error_types:
            set_1 = set(entity_1.get(error_type, []))
            set_2 = set(entity_2.get(error_type, []))

            results[error_type] = {
                "overlap": set_1 & set_2,
                f"{scheme_1} Only": set_1 - set_2,
                f"{scheme_2} Only": set_2 - set_1,
            }

        return results

    def compare_errors(self, component, error_type):
        """
        Compare errors across all entities and error types for both schemes.

        Returns:
            dict: A dictionary with set operation results for all error types.
        """
        schemes = list(self.results.keys())
        print(schemes)
        if len(schemes) != 2:
            raise ValueError("Comparator requires exactly two schemes for comparison.")

        schemes_map = {'scheme_1': 'IOB', 'scheme_2': 'IOB2'}
        errors_1 = self.results[schemes_map['scheme_1']]["errors"][component]
        errors_2 = self.results[schemes_map['scheme_2']]["errors"][component]

       
       
        comparison_result = ComparisonResult.from_lists(errors_1, errors_2, error_type, schemes_map)

        return comparison_result.to_dict()


from dataclasses import dataclass, field
from typing import List, Dict, Set

@dataclass
class ComparisonResult:
    """Dataclass to store comparison results."""
    scheme_1_name: str
    scheme_2_name: str
    set_1_errors: Set[int] = field(default=set)
    set_2_errors: Set[int] = field(default=set)
    overlap: Set[int] = field(default_factory=set)
    scheme_1_only: Set[int] = field(default_factory=set)
    scheme_2_only: Set[int] = field(default_factory=set)

    @staticmethod
    def from_lists(errors_1: Dict, errors_2: Dict, error_type: str, schemes_map: Dict) -> "ComparisonResult":
        """
        Create a ComparisonResult from two lists.

        Args:
            lst_1: List of values from scheme 1.
            lst_2: List of values from scheme 2.

        Returns:
            ComparisonResult: Dataclass containing the comparison and statistics.
        """
        set_1 = set(errors_1.get(error_type, []))
        
        set_2 = set(errors_2.get(error_type, []))
        
        sentence_lst_1 = [error[0] for error in errors_1.get(error_type, [])]
        sentence_lst_2 = [error[0] for error in errors_2.get(error_type, [])]
        sentence_set_1 = set(sentence_lst_1)
        sentence_set_2 = set(sentence_lst_2)
        
        overlap = sentence_set_1 & sentence_set_2
        scheme_1_only = sentence_set_1 - sentence_set_2
        scheme_2_only = sentence_set_2 - sentence_set_1

        return ComparisonResult(
            scheme_1_name=schemes_map['scheme_1'],
            scheme_2_name=schemes_map['scheme_2'],
            set_1_errors= set_1,
            set_2_errors= set_2,
            overlap=overlap,
            scheme_1_only=scheme_1_only,
            scheme_2_only=scheme_2_only,
        )
        
    def to_dict(self) -> Dict[str, Dict[str, Set[int]]]:
        """R"Overlap": self.overlap, comparison results as a dictionary."""
        return {
            f"{self.scheme_1_name} Errors": self.set_1_errors,
            f"{self.scheme_2_name} Errors": self.set_2_errors,
            "Overlap": self.overlap,
            f"{self.scheme_1_name} Only Errors": self.scheme_1_only,
            f"{self.scheme_2_name} Only Errors": self.scheme_2_only,
        }





In [426]:
manager = ErrorAnalysisManager(df)
manager.run_workflows()
results = manager.get_results()

In [427]:
set([id[0] for id in results['IOB2']['errors']['false_negatives']['Entity and Boundary']]) - set([id[0] for id in results['IOB']['errors']['false_negatives']['Entity and Boundary']])

{65,
 124,
 164,
 177,
 183,
 188,
 190,
 191,
 192,
 194,
 195,
 199,
 238,
 240,
 243,
 253,
 260,
 318,
 473,
 484,
 490,
 561,
 665,
 678,
 923,
 943}

In [428]:
results['IOB2']['errors']['false_negatives']['Entity and Boundary'] 

{(20, 'ORG', 41, 42),
 (65, 'MISC', 0, 1),
 (124, 'ORG', 26, 30),
 (164, 'MISC', 28, 30),
 (166, 'MISC', 7, 9),
 (175, 'MISC', 4, 5),
 (175, 'ORG', 3, 4),
 (175, 'PERS', 1, 2),
 (177, 'MISC', 5, 6),
 (179, 'MISC', 6, 7),
 (179, 'ORG', 5, 6),
 (180, 'MISC', 5, 6),
 (180, 'ORG', 4, 5),
 (181, 'MISC', 5, 6),
 (181, 'ORG', 4, 5),
 (183, 'LOC', 2, 3),
 (183, 'MISC', 4, 5),
 (184, 'MISC', 5, 8),
 (188, 'PERS', 2, 3),
 (190, 'PERS', 2, 3),
 (191, 'PERS', 2, 3),
 (192, 'MISC', 5, 8),
 (194, 'MISC', 3, 6),
 (195, 'MISC', 4, 7),
 (196, 'PERS', 1, 3),
 (199, 'MISC', 0, 1),
 (232, 'LOC', 0, 1),
 (232, 'PERS', 19, 21),
 (238, 'LOC', 0, 1),
 (240, 'LOC', 0, 1),
 (243, 'PERS', 28, 30),
 (250, 'LOC', 0, 1),
 (250, 'MISC', 12, 15),
 (253, 'MISC', 4, 7),
 (260, 'LOC', 0, 1),
 (318, 'ORG', 1, 5),
 (320, 'ORG', 17, 18),
 (473, 'LOC', 0, 1),
 (484, 'MISC', 4, 9),
 (490, 'ORG', 0, 1),
 (561, 'ORG', 20, 24),
 (584, 'LOC', 4, 5),
 (586, 'MISC', 7, 10),
 (618, 'MISC', 25, 28),
 (619, 'MISC', 12, 15),
 (625, 'M

In [429]:
overall_comparison['IOB2 Errors']

KeyError: 'IOB2 Errors'

In [417]:
comparator = SchemeComparator(results)
component_comparison = comparator.compare_component("false_negatives", "LOC")
component_comparison

{'Entity and Boundary': {'overlap': set(),
  'IOB2 Only': {(183, 'LOC', 2, 3),
   (232, 'LOC', 0, 1),
   (238, 'LOC', 0, 1),
   (240, 'LOC', 0, 1),
   (250, 'LOC', 0, 1),
   (260, 'LOC', 0, 1),
   (473, 'LOC', 0, 1),
   (584, 'LOC', 4, 5),
   (678, 'LOC', 7, 8),
   (693, 'LOC', 2, 3),
   (695, 'LOC', 30, 31)},
  'IOB Only': {(584, 'LOC', 4, 4), (693, 'LOC', 2, 2), (695, 'LOC', 30, 30)}},
 'Boundary': {'overlap': set(),
  'IOB2 Only': {(0, 'LOC', 0, 1),
   (20, 'LOC', 40, 41),
   (44, 'LOC', 8, 10),
   (75, 'LOC', 12, 14),
   (192, 'LOC', 3, 5),
   (312, 'LOC', 19, 21),
   (320, 'LOC', 16, 17),
   (614, 'LOC', 18, 20),
   (915, 'LOC', 1, 2),
   (917, 'LOC', 1, 2),
   (948, 'LOC', 20, 22)},
  'IOB Only': {(20, 'LOC', 40, 40),
   (44, 'LOC', 8, 9),
   (75, 'LOC', 12, 13),
   (192, 'LOC', 3, 4),
   (312, 'LOC', 19, 20),
   (320, 'LOC', 16, 16),
   (614, 'LOC', 18, 19),
   (915, 'LOC', 1, 1),
   (917, 'LOC', 1, 1),
   (948, 'LOC', 20, 21)}},
 'Entity': {'overlap': set(),
  'IOB2 Only': {(95

In [418]:
overall_comparison = comparator.compare_errors('false_negatives', 'Entity and Boundary')


['IOB2', 'IOB']
Counter({175: 2, 180: 2, 181: 2, 179: 2, 35: 1, 166: 1, 695: 1, 184: 1, 845: 1, 625: 1, 921: 1, 196: 1, 250: 1, 586: 1, 693: 1, 857: 1, 915: 1, 879: 1, 20: 1, 630: 1, 917: 1, 619: 1, 584: 1, 320: 1, 618: 1, 232: 1})
Counter({175: 3, 179: 2, 181: 2, 232: 2, 180: 2, 250: 2, 183: 2, 191: 1, 619: 1, 20: 1, 917: 1, 584: 1, 190: 1, 693: 1, 164: 1, 618: 1, 473: 1, 166: 1, 195: 1, 184: 1, 921: 1, 196: 1, 177: 1, 253: 1, 194: 1, 65: 1, 665: 1, 490: 1, 243: 1, 923: 1, 630: 1, 192: 1, 124: 1, 238: 1, 320: 1, 943: 1, 260: 1, 561: 1, 199: 1, 240: 1, 318: 1, 625: 1, 678: 1, 695: 1, 845: 1, 484: 1, 586: 1, 188: 1, 915: 1, 857: 1, 879: 1})


In [430]:
overall_comparison

{'MISC': {'Entity and Boundary': {'overlap': set(),
   'in_scheme_1_not_in_scheme_2': {(65, 'MISC', 0, 1),
    (164, 'MISC', 28, 30),
    (166, 'MISC', 7, 9),
    (175, 'MISC', 4, 5),
    (177, 'MISC', 5, 6),
    (179, 'MISC', 6, 7),
    (180, 'MISC', 5, 6),
    (181, 'MISC', 5, 6),
    (183, 'MISC', 4, 5),
    (184, 'MISC', 5, 8),
    (192, 'MISC', 5, 8),
    (194, 'MISC', 3, 6),
    (195, 'MISC', 4, 7),
    (199, 'MISC', 0, 1),
    (250, 'MISC', 12, 15),
    (253, 'MISC', 4, 7),
    (484, 'MISC', 4, 9),
    (586, 'MISC', 7, 10),
    (618, 'MISC', 25, 28),
    (619, 'MISC', 12, 15),
    (625, 'MISC', 26, 29),
    (630, 'MISC', 10, 17),
    (857, 'MISC', 7, 11),
    (879, 'MISC', 2, 4),
    (921, 'MISC', 18, 20),
    (923, 'MISC', 6, 7)},
   'in_scheme_2_not_in_scheme_1': {(845, 'MISC', 0, 1)}},
  'Boundary': {'overlap': {(488, 'MISC', 7, 10)},
   'in_scheme_1_not_in_scheme_2': {(163, 'MISC', 14, 16),
    (166, 'MISC', 26, 27),
    (167, 'MISC', 8, 10),
    (171, 'MISC', 48, 54),
    (

# Development Error Analysis

In [420]:
strict_analyzer = StrictEntityAnalyzer(df)
strict_errors = strict_analyzer.analyze_component("false_negatives")


non_strict_analyzer = NonStrictEntityAnalyzer(df)
non_strict_errors = non_strict_analyzer.analyze_component("false_positives")

In [421]:
entity = 'LOC'
print(dash_data['ANERCorp_CamelLab_arabertv02'].entity_non_strict_confusion_data['false_negatives'][entity])
print(dash_data['ANERCorp_CamelLab_arabertv02'].entity_strict_confusion_data['false_negatives'][entity])

{'O': 29, 'Boundary': 10, 'PERS': 1, 'MISC': 3, 'ORG': 3, 'Entity and Boundary': 3}
{'MISC': 3, 'ORG': 3, 'PERS': 1, 'Boundary': 11, 'Entity and Boundary': 11, 'O': 15}


In [422]:
print(dash_data['ANERCorp_CamelLab_arabertv02'].entity_non_strict_confusion_data['false_negatives'])

{'LOC': {'O': 29, 'Boundary': 10, 'PERS': 1, 'MISC': 3, 'ORG': 3, 'Entity and Boundary': 3}, 'PERS': {'O': 50, 'MISC': 5, 'Boundary': 75, 'ORG': 13, 'LOC': 6, 'Entity and Boundary': 5}, 'ORG': {'O': 48, 'Entity and Boundary': 10, 'Boundary': 26, 'MISC': 6, 'PERS': 11, 'LOC': 20}, 'MISC': {'O': 46, 'Entity and Boundary': 14, 'LOC': 6, 'Boundary': 17, 'ORG': 9}}


In [423]:
comparator = EntitySchemeComparator(strict_errors, non_strict_errors)

# Compare all error types for a specific entity
misc_comparison = comparator.compare_entity("MISC")
print("Comparison for MISC:", misc_comparison)

# Compare errors across all entities
overall_comparison = comparator.compare_overall()
print("Overall Comparison:", overall_comparison)

Comparison for MISC: {'Entity and Boundary': {'overlap': set(), 'in_scheme_1_not_in_scheme_2': {(923, 'MISC', 6, 7), (619, 'MISC', 12, 15), (181, 'MISC', 5, 6), (630, 'MISC', 10, 17), (192, 'MISC', 5, 8), (175, 'MISC', 4, 5), (179, 'MISC', 6, 7), (164, 'MISC', 28, 30), (618, 'MISC', 25, 28), (199, 'MISC', 0, 1), (166, 'MISC', 7, 9), (195, 'MISC', 4, 7), (625, 'MISC', 26, 29), (184, 'MISC', 5, 8), (921, 'MISC', 18, 20), (177, 'MISC', 5, 6), (180, 'MISC', 5, 6), (253, 'MISC', 4, 7), (194, 'MISC', 3, 6), (250, 'MISC', 12, 15), (65, 'MISC', 0, 1), (484, 'MISC', 4, 9), (586, 'MISC', 7, 10), (857, 'MISC', 7, 11), (879, 'MISC', 2, 4), (183, 'MISC', 4, 5)}, 'in_scheme_2_not_in_scheme_1': {(845, 'MISC', 0, 1)}}, 'Boundary': {'overlap': {(488, 'MISC', 7, 10)}, 'in_scheme_1_not_in_scheme_2': {(626, 'MISC', 39, 41), (851, 'MISC', 34, 36), (379, 'MISC', 17, 20), (167, 'MISC', 8, 10), (649, 'MISC', 25, 29), (486, 'MISC', 13, 17), (163, 'MISC', 14, 16), (171, 'MISC', 48, 54), (645, 'MISC', 15, 19), (

In [203]:
strict_analyzer = StrictEntityAnalyzer(df)
strict_errors = strict_analyzer.analyze_errors()


non_strict_analyzer = NonStrictEntityAnalyzer(df)
non_strict_errors = non_strict_analyzer.analyze_errors()

In [204]:
non_strict_errors['false_negatives']['Boundary'] - strict_errors['false_negatives']['Boundary']  


{(20, 'LOC', 40, 40),
 (33, 'ORG', 3, 6),
 (36, 'ORG', 3, 5),
 (37, 'ORG', 39, 41),
 (40, 'ORG', 29, 30),
 (42, 'ORG', 130, 131),
 (44, 'LOC', 8, 9),
 (44, 'PERS', 3, 5),
 (48, 'ORG', 29, 30),
 (75, 'LOC', 12, 13),
 (81, 'PERS', 4, 4),
 (90, 'PERS', 39, 42),
 (91, 'ORG', 26, 27),
 (98, 'PERS', 19, 23),
 (106, 'PERS', 11, 12),
 (106, 'PERS', 13, 13),
 (112, 'ORG', 10, 12),
 (120, 'ORG', 6, 7),
 (124, 'ORG', 26, 29),
 (135, 'ORG', 5, 7),
 (163, 'MISC', 14, 15),
 (166, 'MISC', 26, 26),
 (166, 'ORG', 23, 23),
 (167, 'MISC', 8, 9),
 (171, 'MISC', 48, 53),
 (171, 'PERS', 12, 12),
 (171, 'PERS', 13, 13),
 (171, 'PERS', 33, 33),
 (171, 'PERS', 34, 34),
 (192, 'LOC', 3, 4),
 (202, 'PERS', 1, 2),
 (205, 'PERS', 6, 6),
 (205, 'PERS', 7, 7),
 (206, 'PERS', 9, 9),
 (206, 'PERS', 10, 10),
 (207, 'PERS', 2, 2),
 (207, 'PERS', 3, 3),
 (207, 'PERS', 7, 7),
 (207, 'PERS', 8, 8),
 (208, 'PERS', 2, 2),
 (208, 'PERS', 3, 3),
 (208, 'PERS', 6, 6),
 (208, 'PERS', 7, 7),
 (209, 'PERS', 1, 1),
 (213, 'PERS', 3

In [207]:
strict_errors['false_negatives']['Entity and Boundary']

{(20, 'ORG', 41, 42),
 (65, 'MISC', 0, 1),
 (124, 'ORG', 26, 30),
 (164, 'MISC', 28, 30),
 (166, 'MISC', 7, 9),
 (175, 'MISC', 4, 5),
 (175, 'ORG', 3, 4),
 (175, 'PERS', 1, 2),
 (177, 'MISC', 5, 6),
 (179, 'MISC', 6, 7),
 (179, 'ORG', 5, 6),
 (180, 'MISC', 5, 6),
 (180, 'ORG', 4, 5),
 (181, 'MISC', 5, 6),
 (181, 'ORG', 4, 5),
 (183, 'LOC', 2, 3),
 (183, 'MISC', 4, 5),
 (184, 'MISC', 5, 8),
 (188, 'PERS', 2, 3),
 (190, 'PERS', 2, 3),
 (191, 'PERS', 2, 3),
 (192, 'MISC', 5, 8),
 (194, 'MISC', 3, 6),
 (195, 'MISC', 4, 7),
 (196, 'PERS', 1, 3),
 (199, 'MISC', 0, 1),
 (232, 'LOC', 0, 1),
 (232, 'PERS', 19, 21),
 (238, 'LOC', 0, 1),
 (240, 'LOC', 0, 1),
 (243, 'PERS', 28, 30),
 (250, 'LOC', 0, 1),
 (250, 'MISC', 12, 15),
 (253, 'MISC', 4, 7),
 (260, 'LOC', 0, 1),
 (318, 'ORG', 1, 5),
 (320, 'ORG', 17, 18),
 (473, 'LOC', 0, 1),
 (484, 'MISC', 4, 9),
 (490, 'ORG', 0, 1),
 (561, 'ORG', 20, 24),
 (584, 'LOC', 4, 5),
 (586, 'MISC', 7, 10),
 (618, 'MISC', 25, 28),
 (619, 'MISC', 12, 15),
 (625, 'M

In [210]:
strict_analyzer.print_sentence(20)

True: [(20, LOC, 0, 1), (20, LOC, 33, 34), (20, LOC, 40, 41), (20, ORG, 41, 42), (20, ORG, 49, 50), (20, PERS, 89, 90), (20, LOC, 95, 96), (20, PERS, 115, 117), (20, PERS, 129, 131), (20, LOC, 139, 140)]
Pred: [(20, LOC, 0, 1), (20, LOC, 33, 34), (20, LOC, 40, 42), (20, ORG, 49, 50), (20, LOC, 95, 96), (20, PERS, 115, 117), (20, PERS, 129, 131), (20, LOC, 139, 140)]
         Words  Sentence Ids True Labels Pred Labels Strict True Entities Strict Pred Entities True Entities Pred Entities
718      [CLS]            20       [CLS]       [CLS]                [CLS]                [CLS]         [CLS]         [CLS]
719    إسرائيل            20       B-LOC       B-LOC                  LOC                  LOC           LOC           LOC
720         من            20           O           O                    O                    O             O             O
721      نتائج            20           O           O                    O                    O             O             O
722        ذلك  

In [66]:

sen_id = 20
print(Entities([y_true[sen_id]], IOB2, False).entities)
print(Entities([y_pred[sen_id]], IOB2, False).entities)


[[(0, LOC, 0, 1), (0, LOC, 33, 34), (0, LOC, 40, 41), (0, ORG, 41, 42), (0, ORG, 49, 50), (0, PERS, 89, 90), (0, LOC, 95, 96), (0, PERS, 115, 117), (0, PERS, 129, 131), (0, LOC, 139, 140)]]
[[(0, LOC, 0, 1), (0, LOC, 33, 34), (0, LOC, 40, 42), (0, ORG, 49, 50), (0, LOC, 95, 96), (0, PERS, 115, 117), (0, PERS, 129, 131), (0, LOC, 139, 140)]]


In [840]:
Entities([y_true[124]], IOB2, False).entities

[[(0, LOC, 0, 1),
  (0, ORG, 5, 6),
  (0, LOC, 7, 8),
  (0, LOC, 18, 19),
  (0, LOC, 20, 22),
  (0, ORG, 26, 30),
  (0, ORG, 33, 34)]]

In [841]:
Entities([y_pred[124]], IOB2, False).entities

[[(0, ORG, 5, 6),
  (0, LOC, 7, 8),
  (0, LOC, 18, 19),
  (0, LOC, 20, 22),
  (0, PERS, 27, 29),
  (0, ORG, 33, 34)]]

[[(0, MISC, 0, 1), (0, PERS, 1, 2), (0, ORG, 8, 9), (0, PERS, 16, 17)]]
[[(0, MISC, 0, 1), (0, LOC, 4, 5), (0, ORG, 8, 9), (0, PERS, 16, 17)]]


In [None]:
strict_errors

In [None]:
from seqeval.scheme import auto_detect, Entities
from seqeval.metrics.sequence_labeling import get_entities
from collections import Counter

true_entities = get_entities(y_true)
scheme = auto_detect(y_true, False)
entities = Entities(y_true, scheme, False)
print(Counter([entity[0] for entity in true_entities]))
print(Counter([entity.to_tuple()[1] for sen in entities.entities for entity in sen]))

Counter({'PERS': 905, 'LOC': 676, 'ORG': 459, 'MISC': 243})
Counter({'PERS': 858, 'LOC': 668, 'ORG': 450, 'MISC': 235})


In [678]:
strict_error_sentences['Entity and Boundary']

[260, 678, 584, 232, 473, 238, 240, 693, 695, 183, 250]

In [677]:
non_strict_error_sentences['Entity and Boundary']

[584, 693, 695]

In [618]:
df = dash_data['ANERCorp_CamelLab_arabertv02'].analysis_data
core_df = df[df['Labels']!= -100]
y_true = core_df.groupby('Sentence Ids')['True Labels'].apply(list)
y_pred = core_df.groupby('Sentence Ids')['Pred Labels'].apply(list)

In [631]:
df = dash_data['ANERCorp_CamelLab_arabertv02'].analysis_data
pd.set_option('display.max_rows', 300)
misc = df[df['Strict True Entities'] == 'MISC']
df[~(df['True Aligned Scheme']) | ~(df['Pred Aligned Scheme'])][['Words', 'Sentence Ids', 'True Labels', 'Pred Labels', 'Strict True Entities', 'Strict Pred Entities', 'True Entities', 'Pred Entities', 'Error Type']]
# df[(df['True Entities'] != 'O') & (df['Pred Entities'] == 'O') & (df['Strict True Entities'] != df['True Entities'])][['Words', 'Sentence Ids', 'True Labels', 'Pred Labels', 'Strict True Entities', 'Strict Pred Entities', 'True Entities', 'Pred Entities', 'Error Type']]
df[df['Sentence Ids']  == 250][['Words', 'Sentence Ids', 'True Labels', 'Pred Labels', 'Strict True Entities', 'Strict Pred Entities', 'True Entities', 'Pred Entities']]
# misc[~(misc['Pred Aligned Scheme'])][['Words', 'Sentence Ids', 'True Labels', 'Pred Labels', 'Strict True Entities', 'Strict Pred Entities', 'True Entities', 'Pred Entities']]
# misc[(misc['Error Type'] != 'No Errors')][['Words', 'Sentence Ids', 'True Labels', 'Pred Labels', 'Strict True Entities', 'Strict Pred Entities', 'True Entities', 'Pred Entities']]

Unnamed: 0,Words,Sentence Ids,True Labels,Pred Labels,Strict True Entities,Strict Pred Entities,True Entities,Pred Entities
7950,[CLS],250,[CLS],[CLS],[CLS],[CLS],[CLS],[CLS]
7951,أفريقيا,250,B-LOC,O,LOC,O,LOC,O
7952,بوتين,250,B-PERS,B-PERS,PERS,PERS,PERS,PERS
7953,وميركل,250,B-PERS,B-PERS,PERS,PERS,PERS,PERS
7954,وميركل,250,IGNORED,IGNORED,IGNORED,IGNORED,IGNORED,IGNORED
7955,محادثات,250,O,O,O,O,O,O
7956,صباحا,250,O,O,O,O,O,O
7957,قبل,250,O,O,O,O,O,O
7958,المشاركة,250,O,O,O,O,O,O
7959,بعد,250,O,O,O,O,O,O


In [746]:
entity = 'LOC'
print(dash_data['ANERCorp_CamelLab_arabertv02'].entity_non_strict_confusion_data['false_positives'][entity])
print(dash_data['ANERCorp_CamelLab_arabertv02'].entity_strict_confusion_data['false_positives'][entity])

{'O': 15, 'Boundary': 14, 'Entity and Boundary': 15, 'MISC': 6, 'PERS': 6, 'ORG': 20}
{'ORG': 19, 'MISC': 4, 'PERS': 6, 'Boundary': 14, 'Entity and Boundary': 13, 'O': 19}


# Debug Confusion

In [5]:
dataset_name = 'ANERCorp_CamelLab'
model_name = 'arabertv02'
base_path = Path(f"/Users/ay227/Library/CloudStorage/GoogleDrive-ahmed.younes.sam@gmail.com/My Drive/Final Year Experiments/Thesis-Experiments/Experiments/BaseLineExperiment/{dataset_name}_{model_name}/extractions")
df = pd.read_json(
	base_path / 'results/non_strict_entity_misclassifications.json',
	orient='index'
)

In [None]:
entity = 'LOC'
print(df[entity].sum()) 
print(df.loc[entity].sum())


In [None]:
false_positives': {'LOC': {'O': 15,
   'Boundary': 14,
   'Entity and Boundary': 15,
   'MISC': 6,
   'PERS': 6,
   'ORG': 20},

In [None]:
32+2+3+5

In [25]:
ad = dash_data['ANERCorp_CamelLab_arabertv02'].analysis_data

In [None]:
fad = ad[ad['Labels'] !=-100].copy()
loc_fad = fad[fad['Pred Entities'] == 'LOC'].copy()
loc_fad[loc_fad['True Entities'] == 'ORG']

In [None]:
fad[(fad['True Labels'] == 'I-LOC') & (fad['Pred Labels'] == 'B-LOC') | (fad['True Labels'] == 'B-LOC') & (fad['Pred Labels'] == 'I-LOC')]

In [None]:
fad[(fad['True Labels'] == 'I-LOC') & (fad['Pred Labels'] == 'O') ]

In [None]:
fad[(fad['True Labels'] == 'B-LOC') & (fad['Pred Labels'] == 'O') ]

In [None]:
fad[fad['Sentence Ids']  == 315]

In [None]:
false_negatives': {'LOC': {'O': 29,
   'Boundary': 10,
   'PERS': 1,
   'MISC': 3,
   'ORG': 3,
   'Entity and Boundary': 3},

# Development

report

In [None]:
report_data = []
for data, values in dash_data.items():
    entity_report = values.entity_non_strict_report
    entity_strict_report = values.entity_strict_report
    entity_report['Dataset'] = data
    entity_report['Scheme'] = 'IOB'
    entity_strict_report['Dataset'] = data
    entity_strict_report['Scheme'] = 'IOB2'
    report_data.append(pd.concat([
		entity_report, 
		entity_strict_report
	]))
df = pd.concat(report_data)


tag_mapping = {
    'PERS': 'PER'
}

dataset_mapping = {
    'ANERCorp_CamelLab_arabertv02': 'ANERCorp',
    'conll2003_bert': 'CoNLL-2003'
}

df['Tag'] = df['Tag'].replace(tag_mapping)
df['Dataset'] = df['Dataset'].replace(dataset_mapping)
# df = df[[col for col in df.columns if col != 'F1']].copy()
entity_report_data = df[~df['Tag'].isin(['micro', 'macro', 'weighted'])]

entity_report_data



# Assuming 'df_long' is your DataFrame reshaped
df_long = entity_report_data.melt(id_vars=["Tag", "Support", "Dataset", "Scheme"], 
                  value_vars=["Precision", "Recall"], 
                  var_name="Metric", value_name="Value")
df_long['Value'] = df_long['Value'].round(2)
# Creating the faceted bar plot using the reshaped data
fig = px.bar(df_long, x="Tag", y="Value",
             facet_row="Scheme", facet_col="Dataset",
             title="Precision and Recall Scores by Tag, Dataset, and Scheme",
             labels={"Value": "Score"},
             color="Metric", barmode="group",
             template="plotly_white",
             facet_row_spacing=0.15,  # Adjust to a higher value for more space
             facet_col_spacing=0.1,  # Adjust to a higher value for more space
             text='Value',  # Display the Value on top of each bar
            #  text_auto='.2s'  # Automatically format the text with 2 significant digits
             )
# fig.update_traces(texttemplate='%{text:.2f}', textposition='outside')



# Update layout for better clarity
# fig.update_layout(
#     plot_bgcolor='rgba(255,255,255,1)',  # Ensure the background is white
#     paper_bgcolor='rgba(255,255,255,1)',  # Ensure the paper background is also white
# )

fig.show()


confusion heatmap

In [None]:
matrix_data = []
for data, values in dash_data.items():
    entity_matrix = pd.DataFrame(values.entity_non_strict_confusion_data['confusion_matrix']).T 
    entity_strict_matrix = pd.DataFrame(values.entity_strict_confusion_data['confusion_matrix']).T
    entity_matrix['Dataset'] = data
    entity_matrix['Scheme'] = 'IOB'
    entity_strict_matrix['Dataset'] = data
    entity_strict_matrix['Scheme'] = 'IOB2'
    matrix_data.append(pd.concat([
		entity_matrix, 
		entity_strict_matrix
	]))
    
matrix_df = pd.concat(matrix_data)
matrix_df.reset_index(inplace=True)
matrix_df.rename(columns={'index': 'Tag'}, inplace=True)
tag_mapping = {
    'PERS': 'PER'
}

dataset_mapping = {
    'ANERCorp_CamelLab_arabertv02': 'ANERCorp',
    'conll2003_bert': 'CoNLL-2003'
}

matrix_df['Tag'] = matrix_df['Tag'].replace(tag_mapping)
matrix_df['Dataset'] = matrix_df['Dataset'].replace(dataset_mapping)
# # df = df[[col for col in df.columns if col != 'F1']].copy()
# entity_report_data = df[~df['Tag'].isin(['micro', 'macro', 'weighted'])]
matrix_df



# Melt the DataFrame to align with Plotly requirements
df_final = matrix_df.melt(id_vars=['Tag', 'Dataset', 'Scheme'], value_vars=['TP', 'FP', 'FN'], 
                    var_name='Metric', value_name='Count')



# Assuming 'df_final' is your DataFrame prepared for plotting
unique_schemes = df_final['Scheme'].unique()
unique_datasets = df_final['Dataset'].unique()

# Create subplot configuration
fig = make_subplots(rows=len(unique_schemes), cols=len(unique_datasets),
                    subplot_titles=[f"{dataset} - {scheme}" for scheme in unique_schemes for dataset in unique_datasets],
                    shared_yaxes=True, horizontal_spacing=0.02, vertical_spacing=0.1)

# Determine the range of values to set a common scale
max_value = df_final['Count'].max()

# Add heatmaps
for idx, scheme in enumerate(unique_schemes):
    for jdx, dataset in enumerate(unique_datasets):
        filtered_data = df_final[(df_final['Scheme'] == scheme) & (df_final['Dataset'] == dataset)]
        heatmap_data = filtered_data.pivot_table(index='Metric', columns='Tag', values='Count', fill_value=0)
        text_data = filtered_data.pivot_table(index='Metric', columns='Tag', values='Count', fill_value=0).astype(int)

        
        
        fig.add_trace(
            go.Heatmap(
                z=heatmap_data,
                x=heatmap_data.columns,
                y=heatmap_data.index,
                colorscale='RdBu_r',
                coloraxis="coloraxis",  # Use a unified color axis
                text=text_data,  # Add text annotations
                texttemplate="%{text}",  # Use the text values directly
                hovertemplate="Metric: %{y}<br>Tag: %{x}<br>Count: %{text}<extra></extra>",
            ),
            row=idx + 1, col=jdx + 1
        )

# Update layout with a unified color scale and adjust the color bar
fig.update_layout(
    coloraxis=dict(colorscale='RdBu_r', cmin=0, cmax=max_value, colorbar=dict(title="Counts")),
    title_text="Confusion Matrix Metrics by Tag, Dataset, and Scheme",
    template="plotly_white",
    height=700, width=700,
)

fig.show()



bar confusion

In [None]:
matrix_data = []
for data, values in dash_data.items():
    entity_matrix = pd.DataFrame(values.entity_non_strict_confusion_data['confusion_matrix']).T 
    entity_strict_matrix = pd.DataFrame(values.entity_strict_confusion_data['confusion_matrix']).T
    entity_matrix['Dataset'] = data
    entity_matrix['Scheme'] = 'IOB'
    entity_strict_matrix['Dataset'] = data
    entity_strict_matrix['Scheme'] = 'IOB2'
    matrix_data.append(pd.concat([
		entity_matrix, 
		entity_strict_matrix
	]))
    
matrix_df = pd.concat(matrix_data)
matrix_df.reset_index(inplace=True)
matrix_df.rename(columns={'index': 'Tag'}, inplace=True)

matrix_df['Tag'] = matrix_df['Tag'].replace(tag_mapping)
matrix_df['Dataset'] = matrix_df['Dataset'].replace(dataset_mapping)
# Assuming 'matrix_df' has been set up as your DataFrame
df = matrix_df.copy()

# Calculate the sum of TP, FP, and FN per group for normalization purposes
grouped = df.groupby(['Tag', 'Dataset', 'Scheme']).sum()
grouped['Total'] = grouped['TP'] + grouped['FP'] + grouped['FN']

# Merge the total back into the original DataFrame
df = df.merge(grouped['Total'], on=['Tag', 'Dataset', 'Scheme'], how='left')

# Store actual counts for displaying as text
df['TP_Count'] = df['TP']
df['FP_Count'] = df['FP']
df['FN_Count'] = df['FN']

# Normalize the TP, FP, FN values for plotting
df['TP'] = df['TP'] / df['Total']
df['FP'] = df['FP'] / df['Total']
df['FN'] = df['FN'] / df['Total']

# Melt the DataFrame for plotting
df_long = df.melt(id_vars=["Tag", "Dataset", "Scheme"], value_vars=["TP", "FP", "FN"], var_name="Metric", value_name="Scale")
df_counts = df.melt(id_vars=["Tag", "Dataset", "Scheme"], value_vars=["TP_Count", "FP_Count", "FN_Count"], var_name="Metric", value_name="Count")

# Replace '_Count' to align with the other metric names
df_counts['Metric'] = df_counts['Metric'].str.replace('_Count', '')

# Merge percentage and actual count data
df_long = df_long.merge(df_counts, on=["Tag", "Dataset", "Scheme", "Metric"])

# Creating the faceted bar plot with actual counts displayed on normalized bars
fig = px.bar(df_long, x="Tag", y="Scale", color="Metric",
             facet_row="Scheme", facet_col="Dataset",
             title="Confusion Matrix Metrics by Tag, Dataset, and Scheme",
             labels={"Scale": "Scaled Counts"},
             barmode='group',
             template="plotly_white",
             facet_row_spacing=0.1,  # Adjusted spacing
             facet_col_spacing=0.08,
             text='Count'  # Display the actual Count on top of each bar
             )

# Set text to display above bars
# fig.update_traces(texttemplate='%{text}', textposition='outside')
fig.update_layout(
    plot_bgcolor='rgba(255,255,255,1)',  # Ensure the background is white
    paper_bgcolor='rgba(255,255,255,1)',  # Ensure the paper background is also white
)

fig.show()
