In [None]:
import json
from collections import defaultdict, Counter
from seqeval.scheme import auto_detect
from seqeval.metrics.sequence_labeling import get_entities
from seqeval.scheme import Entities

# Location of the evaluation output (JSON) that the cells below analyse.
file_name = '/Users/ay227/Library/CloudStorage/GoogleDrive-ahmed.younes.sam@gmail.com/My Drive/Final Year Experiments/Thesis-Experiments/Experiments/BaseLineExperiment/ANERCorp_CamelLab_arabertv02/fine_tuning/evaluation_metrics.json'
# Decode explicitly as UTF-8 so parsing does not depend on the platform's default encoding.
with open(file_name, 'r', encoding='utf-8') as file:
    entity_outputs = json.load(file)  # json.load() reads from a file object; json.loads() expects a string
   

In [None]:
y_true = entity_outputs['entity_outputs']['y_true']
y_pred = entity_outputs['entity_outputs']['y_pred']

In [None]:
from collections import defaultdict

class StrictConfusionMatrix:
    """Entity-level confusion matrix and error breakdown for tagged sequences.

    Tag sequences are converted to entity tuples with seqeval's ``Entities``
    helper; every tuple is handled as ``(sentence_id, entity_type, start, end)``.
    An entity is a true positive only when the identical tuple occurs on both
    the gold and the predicted side.

    Args:
        y_true: Gold tag sequences, passed unchanged to ``auto_detect`` and
            ``Entities``.
        y_pred: Predicted tag sequences, passed unchanged to ``Entities``.
    """

    def __init__(self, y_true, y_pred):
        self.y_true = y_true
        self.y_pred = y_pred

    def compute(self):
        """
        Compute confusion matrix, false positives, and false negatives for all entities.

        Entities are re-extracted from ``y_true`` / ``y_pred`` on every call.

        Returns:
            dict: A dictionary containing:
                - 'confusion_matrix': ``{type: {'TP': n, 'FP': n, 'FN': n}}``.
                - 'false_negatives': ``{type: Counter}`` of missed gold entities per error category.
                - 'false_positives': ``{type: Counter}`` of spurious predictions per error category.
        """
        self.prepare_entities()

        confusion_matrix = self.compute_confusion_matrix()

        false_negatives = defaultdict(Counter)
        false_positives = defaultdict(Counter)

        entity_types = {ent[1] for ent in self.true_entities} | {
            ent[1] for ent in self.pred_entities
        }

        for entity_type in entity_types:
            # Counter.update adds counts, so per-type results accumulate.
            for t_type, counts in self.compute_false_negatives(entity_type).items():
                false_negatives[t_type].update(counts)
            for t_type, counts in self.compute_false_positives(entity_type).items():
                false_positives[t_type].update(counts)

        return {
            'confusion_matrix': confusion_matrix,
            'false_negatives': dict(false_negatives),  # plain dict for output clarity
            'false_positives': dict(false_positives),  # plain dict for output clarity
        }

    def prepare_entities(self):
        """Detect the tagging scheme from ``y_true`` and extract flat entity tuples for both sides."""
        self.scheme = auto_detect(self.y_true, False)
        entities_true = self.extract_entities(self.y_true)
        entities_pred = self.extract_entities(self.y_pred)
        self.true_entities = self.flatten_strict_entities(entities_true)
        self.pred_entities = self.flatten_strict_entities(entities_pred)

    def _ensure_entities(self):
        """Extract entities on first use so the analysis methods also work before ``compute()``."""
        if not (hasattr(self, 'true_entities') and hasattr(self, 'pred_entities')):
            self.prepare_entities()

    def extract_entities(self, y_data):
        """Wrap ``y_data`` in a seqeval ``Entities`` object using the detected scheme."""
        return Entities(y_data, self.scheme, False)

    @staticmethod
    def flatten_strict_entities(entities):
        """Flatten entities extracted in strict mode into tuples."""
        return [e.to_tuple() for sen in entities.entities for e in sen]

    def compute_confusion_matrix(self):
        """Compute confusion matrix across all entity types.

        Returns:
            dict: ``{entity_type: {'TP': int, 'FP': int, 'FN': int}}``.
        """
        self._ensure_entities()
        types = {ent[1] for ent in self.true_entities} | {
            ent[1] for ent in self.pred_entities
        }

        confusion_matrix = {}
        for entity_type in types:
            TP, FP, FN = self.extract_strict_entity_confusion(entity_type)
            confusion_matrix[entity_type] = {'TP': TP, 'FP': FP, 'FN': FN}
        return confusion_matrix

    def _entities_of_type(self, entity_type):
        """Return ``(gold, predicted)`` sets of entity tuples restricted to ``entity_type``."""
        self._ensure_entities()
        gold = {e for e in self.true_entities if e[1] == entity_type}
        predicted = {e for e in self.pred_entities if e[1] == entity_type}
        return gold, predicted

    def extract_strict_entity_confusion(self, entity):
        """Extract TP, FP, and FN counts (in that order) for a given entity type."""
        gold, predicted = self._entities_of_type(entity)
        return len(predicted & gold), len(predicted - gold), len(gold - predicted)

    def compute_false_positives(self, entity_type):
        """Categorize predictions of ``entity_type`` that have no identical gold entity."""
        gold, predicted = self._entities_of_type(entity_type)
        return self.analyze_errors(predicted - gold, self.true_entities, "FP")

    def compute_false_negatives(self, entity_type):
        """Categorize gold entities of ``entity_type`` that have no identical prediction."""
        gold, predicted = self._entities_of_type(entity_type)
        return self.analyze_errors(gold - predicted, self.pred_entities, "FN")

    @staticmethod
    def _classify_error(target_entity, candidates):
        """Return the error category of one target entity against same-sentence candidates.

        Priority (first that applies):
            1. a candidate with identical span but different type -> that candidate's type
            2. an overlapping candidate of the same type           -> 'Boundary'
            3. an overlapping candidate of a different type        -> 'Entity and Boundary'
            4. nothing overlaps                                    -> 'O'
        """
        _, t_type, t_start, t_end = target_entity
        same_type_overlap = False
        other_type_overlap = False

        for _, c_type, c_start, c_end in candidates:
            if c_type != t_type and c_start == t_start and c_end == t_end:
                return c_type
            # Symmetric interval test: also catches a candidate that fully contains
            # the target, which an endpoint-inside-target check misses.
            # Bounds are compared inclusively (touching spans count as overlapping).
            # NOTE(review): if `end` is exclusive in these tuples, adjacent entities
            # are treated as overlapping — confirm against seqeval's convention.
            if c_start <= t_end and t_start <= c_end:
                if c_type == t_type:
                    same_type_overlap = True
                else:
                    other_type_overlap = True

        if same_type_overlap:
            return 'Boundary'
        if other_type_overlap:
            return 'Entity and Boundary'
        return 'O'

    def analyze_errors(self, target_entities, comparison_entities, error_type):
        """Analyze entity-level errors (FP or FN).

        Args:
            target_entities: Iterable of erroneous entity tuples to explain.
            comparison_entities: Iterable of entity tuples from the other side
                (gold entities for FPs, predictions for FNs).
            error_type: "FP" or "FN"; kept for interface compatibility, the
                categorization itself does not depend on it.

        Returns:
            dict: ``{target_type: Counter({category: count})}`` where each target
            is counted exactly once (see ``_classify_error`` for the categories).
        """
        # Group comparison entities by sentence so targets are only compared
        # with entities from their own sentence.
        indexed_entities = defaultdict(list)
        for entity in comparison_entities:
            indexed_entities[entity[0]].append(entity)

        counts = defaultdict(Counter)
        for target_entity in set(target_entities):
            candidates = indexed_entities.get(target_entity[0], ())
            category = self._classify_error(target_entity, candidates)
            counts[target_entity[1]][category] += 1

        return dict(counts)


In [None]:
strict_confusion = StrictConfusionMatrix(y_true, y_pred)


In [None]:
strict_confusion.compute()

In [None]:
strict_confusion.compute_false_negatives('ORG')

In [None]:
ORG': Counter({'O': 49, 'Boundary': 27, 'LOC': 19, 'PERS': 11, 'MISC': 6})

In [None]:

def compute_confusion_matrix(true_entities, pred_entities):
    """Build a per-type TP/FP/FN table from two collections of entity tuples.

    The entity type is read from position 1 of every tuple. A tuple is a true
    positive when it occurs in both collections, a false positive when it is
    only predicted, and a false negative when it is only in the gold data.

    Returns:
        dict: ``{entity_type: {'TP': int, 'FP': int, 'FN': int}}``.
    """
    gold_labels = {item[1] for item in true_entities}
    pred_labels = {item[1] for item in pred_entities}

    def _strict_counts(label):
        gold = {item for item in true_entities if item[1] == label}
        predicted = {item for item in pred_entities if item[1] == label}
        return len(predicted & gold), len(predicted - gold), len(gold - predicted)

    matrix = {}
    for label in gold_labels | pred_labels:
        hits, spurious, missed = _strict_counts(label)
        matrix[label] = {'TP': hits, 'FP': spurious, 'FN': missed}
    return matrix


def flatten_strict_entities(entities):
    """
    Collapse per-sentence strict-mode entities into one flat list of tuples.

    Args:
        entities (Entities): Object whose ``entities`` attribute holds one
            sequence of entity objects per sentence.

    Returns:
        list: The ``to_tuple()`` result of every entity, in sentence order.
    """
    flat = []
    for sentence in entities.entities:
        flat.extend(item.to_tuple() for item in sentence)
    return flat

In [None]:

# Detect the tagging scheme from the gold tags.
# NOTE(review): the second positional argument (False) is presumably seqeval's `suffix` flag — confirm.
scheme = auto_detect(y_true, False)

# Extract strict-mode entities for gold and predicted tags, then flatten them into tuples.
entities_true = Entities(y_true, scheme, False)
entities_pred = Entities(y_pred, scheme, False)
true_entity_type = flatten_strict_entities(entities_true)
pred_entity_type = flatten_strict_entities(entities_pred)
# Example usage
# conf_matrix = calculate_confusion_matrix([e.to_tuple()[1:] for sen in entities_true.entities for e in sen], [e.to_tuple()[1:] for sen in entities_pred.entities for e in sen])



# Per-type TP/FP/FN table over the flattened tuples.
confusion_matrix = compute_confusion_matrix(true_entity_type, pred_entity_type)
# NOTE(review): the two calls below pass three arguments, but the
# compute_false_negatives_with_boundary definitions visible in this notebook take two
# and index their tuples as (type, start, end), while the tuples built above are unpacked
# elsewhere in this notebook as (sentence, type, start, end). No module-level
# compute_false_positives_with_boundary is visible either. Confirm which definitions
# this cell is meant to run against.
fn_errors = compute_false_negatives_with_boundary(true_entity_type, pred_entity_type, confusion_matrix)
fp_errors = compute_false_positives_with_boundary(true_entity_type, pred_entity_type, confusion_matrix)


print("Confusion Matrix:", confusion_matrix)
print("False Negatives:", dict(fn_errors))
print("False Positives:", dict(fp_errors))
# Stray literal left at the end of the cell; it has no effect on the variables above.
1

In [None]:
24+17+13+8

In [None]:
for e, m in confusion_matrix.items():
    print(e, m)

In [None]:
true_entity_type

In [None]:
print('ORG', sum({'O': 50, 'Boundary': 27, 'LOC': 19, 'PERS': 11, 'MISC': 6}.values()))
print('MISC', sum({'O': 58, 'Boundary': 15, 'ORG': 9, 'LOC': 4}.values()))
print('PERS', sum({'Boundary': 70, 'O': 44, 'ORG': 11, 'LOC': 6, 'MISC': 3}.values()))
print('LOC', sum({'O': 26, 'Boundary': 12, 'MISC': 3, 'ORG': 3, 'PERS': 1}.values()))


In [None]:
sen_id = 44
for i, (t, p) in enumerate(zip(y_true[sen_id], y_pred[sen_id])):
    if i == 34:
        print('from here')
    print(i, t, p)

print(entities_true.entities[sen_id])
print(entities_pred.entities[sen_id])

In [None]:
# Finding: ANERCorp does not fully conform to the IOB2 scheme. See example 167, where an I-MISC tag appears; the affected examples are 163, 167, 504, 623 and 694.

In [None]:
ENTITY = 'ORG'
# Predicted ENTITY tuples with no identical gold tuple: the false positives to explain.
false_positives = set([e for e in pred_entity_type if e[1] == ENTITY]) - set([e for e in true_entity_type if e[1] == ENTITY])

# Error buckets. non_o_errors collects every false positive already explained by
# some gold entity, so that each later pass only looks at the still-unexplained ones.
entity_errors = set()
boundary_errors = set()
entity_and_boundary_errors = set()
non_o_errors = set()
# Gold entities grouped by sentence id so a prediction is only compared within its own sentence
true_entities_indexed = defaultdict(list)

for entity in true_entity_type:
    true_sen, true_type, true_start, true_end = entity
    true_entities_indexed[true_sen].append(entity)

# Pass 1: identical span but different type -> entity (type) error
for predicted_entity in false_positives:
    pred_sen, pred_type, pred_start, pred_end = predicted_entity

    for true_entity in true_entities_indexed[pred_sen]:
        true_type, true_start, true_end = true_entity[1:]

        if true_start == pred_start and true_end == pred_end and true_type != pred_type:
            entity_errors.add(predicted_entity)
            non_o_errors.add(predicted_entity)

# Pass 2: overlapping gold entity of the same type -> boundary error
for predicted_entity in false_positives - non_o_errors:
    pred_sen, pred_type, pred_start, pred_end = predicted_entity

    for true_entity in true_entities_indexed[pred_sen]:
        true_type, true_start, true_end = true_entity[1:]

        # Symmetric interval test: also catches a prediction that lies entirely
        # inside a longer gold span (an endpoint-inside-prediction check misses it).
        if true_type == pred_type and true_start <= pred_end and pred_start <= true_end:
            boundary_errors.add(predicted_entity)
            non_o_errors.add(predicted_entity)

# Pass 3: overlapping gold entity of a different type -> entity and boundary error
for predicted_entity in false_positives - non_o_errors:
    pred_sen, pred_type, pred_start, pred_end = predicted_entity

    for true_entity in true_entities_indexed[pred_sen]:
        true_type, true_start, true_end = true_entity[1:]

        # The type check must guard the whole overlap test; without grouping,
        # `and` binds tighter than `or` and one branch would ignore the type.
        if true_type != pred_type and (true_start <= pred_end and pred_start <= true_end):
            entity_and_boundary_errors.add(predicted_entity)
            non_o_errors.add(predicted_entity)
# Whatever no gold entity explains was predicted where the gold data has no entity.
o_error = set(false_positives) - non_o_errors

# Debugging prints to check set contents
print(f"Entity Errors: {len(entity_errors)}")
print(f"Boundary Errors: {len(boundary_errors)}")
print(f"Entity and Boundary Errors: {len(entity_and_boundary_errors)}")
print(f"Outside Named Entity: {len(o_error)}")



In [None]:
ENTITY = 'LOC'
false_negatives = set([e for e in true_entity_type if e[1] == ENTITY]) - set([e for e in pred_entity_type if e[1] == ENTITY])

In [None]:
from collections import defaultdict

# Error buckets. non_o_errors collects every false negative already explained by
# some prediction, so that each later pass only looks at the still-unexplained ones.
entity_errors = set()
boundary_errors = set()
entity_and_boundary_errors = set()
non_o_errors = set()
# Predicted entities grouped by sentence id so a gold entity is only compared within its own sentence
pred_entities_indexed = defaultdict(list)

for entity in pred_entity_type:
    pred_sen, pred_type, pred_start, pred_end = entity
    pred_entities_indexed[pred_sen].append(entity)

# Pass 1: identical span but different type -> entity (type) error
for true_entity in false_negatives:
    true_sen, true_type, true_start, true_end = true_entity

    for pred_entity in pred_entities_indexed[true_sen]:
        pred_type, pred_start, pred_end = pred_entity[1:]

        if pred_start == true_start and pred_end == true_end and pred_type != true_type:
            entity_errors.add(true_entity)
            non_o_errors.add(true_entity)

# Pass 2: overlapping prediction of the same type -> boundary error
for true_entity in false_negatives - non_o_errors:
    true_sen, true_type, true_start, true_end = true_entity

    for pred_entity in pred_entities_indexed[true_sen]:
        pred_type, pred_start, pred_end = pred_entity[1:]

        # Symmetric interval test: also catches a gold span that lies entirely
        # inside a longer prediction (an endpoint-inside-gold check misses it).
        if pred_type == true_type and pred_start <= true_end and true_start <= pred_end:
            boundary_errors.add(true_entity)
            non_o_errors.add(true_entity)

# Pass 3: overlapping prediction of a different type -> entity and boundary error
for true_entity in false_negatives - non_o_errors:
    true_sen, true_type, true_start, true_end = true_entity

    for pred_entity in pred_entities_indexed[true_sen]:
        pred_type, pred_start, pred_end = pred_entity[1:]

        # The type check must guard the whole overlap test; without grouping,
        # `and` binds tighter than `or` and one branch would ignore the type.
        if pred_type != true_type and (pred_start <= true_end and true_start <= pred_end):
            entity_and_boundary_errors.add(true_entity)
            non_o_errors.add(true_entity)
# Whatever no prediction explains was missed entirely.
o_error = set(false_negatives) - non_o_errors

# Debugging prints to check set contents
print(f"Entity Errors: {len(entity_errors)}")
print(f"Boundary Errors: {len(boundary_errors)}")
print(f"Entity and Boundary Errors: {len(entity_and_boundary_errors)}")
print(f"Outside Named Entity: {len(o_error)}")



In [None]:
sen_id = 232
for i, (t, p) in enumerate(zip(y_true[sen_id], y_pred[sen_id])):
    if i == 34:
        print('from here')
    print(i, t, p)

print(entities_true.entities[sen_id])
print(entities_pred.entities[sen_id])

In [None]:
29+14+13

In [None]:
19+29+21

In [None]:
len(entity_errors)

In [None]:
predicted_entity

In [None]:
entity_errors

In [None]:
# Example manual check
test_entity = (708, 'LOC', 15, 16)
if test_entity in entity_errors:
    print("Found in errors")
else:
    print("Not found in errors")


In [None]:
(708, 'LOC', 15, 16) in x

In [None]:
(708, 'LOC', 15, 16)   in entity_errors

In [None]:
entity_errors

In [None]:
len(false_positives)

In [None]:
len(entity_errors) + len(boundary)

In [None]:
entity = 'LOC' 
false_negatives = set(e for e in true_entity_type if e[1] == entity) -  set([e for e in pred_entity_type if e[1] == entity])

In [None]:
# Brute-force categorisation of the false negatives: every false negative is
# compared against every predicted entity (no per-sentence index).
# Results are stored as (sentence, start, end) triples, i.e. without the type.
entity_errors = []        # same sentence and span as some prediction (list, may hold repeats)
missing_errors = set()    # neither a boundary nor an entity error
boundary = set()          # same-type prediction in the same sentence with an endpoint inside the gold span
# NOTE(review): the loop variable rebinds `entity`, which the preceding cell set to a type string.
for entity in false_negatives:
    o = []  # never read below
    sen, t, s, e = entity
    
    for p_entity in pred_entity_type:
        
        p_sen, p_t, p_s, p_e = p_entity
        # Same sentence and type but a different span: candidate boundary error.
        if p_sen == sen and p_t == t and not (p_s == s and p_e == e):
            # Only counts when one endpoint of the prediction falls inside the gold span.
            if (s <= p_s <= e) or (s <= p_e <= e):
                print(entity)
                print(p_entity)
                boundary.add((sen, s, e))
                print('end ####')
                
        # Same sentence and identical span; reached only when the branch above did not apply.
        elif p_sen == sen and p_s == s and p_e == e:
            entity_errors.append((sen, s, e))
    # Checked after all predictions were scanned for this false negative.
    if (sen, s, e) not in boundary and  (sen, s, e) not in entity_errors:
            missing_errors.add((sen, s, e))
        

In [None]:
len(missing_errors)

In [None]:
print(len(boundary))
print(len(entity_errors))

In [None]:
missing_errors

In [None]:
sen_id = 94
for i, (t, p) in enumerate(zip(y_true[sen_id], y_pred[sen_id])):
    if i == 34:
        print('from here')
    print(i, t, p)

print(entities_true.entities[sen_id])
print(entities_pred.entities[sen_id])

In [None]:
missing_errors

In [None]:
for error in missing_errors:
	if error[0] in [error[0] for error in boundary]:
		print(error)

In [None]:
entity_errors

In [None]:
boundary

In [None]:
from collections import defaultdict, Counter

class EntityConfusion:
    """Confusion matrix and error breakdown over non-strict entity spans.

    Entities are obtained with seqeval's ``get_entities`` and handled as
    ``(entity_type, start, end)`` tuples. An entity is a true positive only
    when the identical tuple occurs on both the gold and the predicted side.
    """

    def __init__(self, y_true, y_pred):
        """
        Initialize the EntityConfusion class.

        Args:
            y_true (list): Ground-truth labels, passed to ``get_entities``.
            y_pred (list): Predicted labels, passed to ``get_entities``.
        """
        self.y_true = y_true
        self.y_pred = y_pred

    def prepare_entities(self):
        """
        Prepare entities for confusion matrix calculation.
        - In non-strict mode, entities are extracted using `get_entities`.
        """
        self.true_entities = get_entities(self.y_true)
        self.pred_entities = get_entities(self.y_pred)

    @staticmethod
    def extract_entity_confusion(entity, true_entities, pred_entities):
        """Return ``(TP, FP, FN)`` counts for the entity type ``entity``."""
        gold = {e for e in true_entities if e[0] == entity}
        predicted = {e for e in pred_entities if e[0] == entity}
        return len(predicted & gold), len(predicted - gold), len(gold - predicted)

    def compute(self):
        """
        Compute the confusion matrix, false negatives, and false positives.

        Returns:
            dict: A dictionary containing:
                  - 'confusion_matrix': The confusion matrix for entity recognition.
                  - 'false_negatives': Detailed false negatives.
                  - 'false_positives': Detailed false positives.
        """
        self.prepare_entities()
        return {
            'confusion_matrix': self.compute_confusion_matrix(),
            'false_negatives': self.compute_false_negatives_with_boundary(),
            'false_positives': self.compute_false_positives_with_boundary()
        }

    def compute_confusion_matrix(self):
        """
        Compute a confusion matrix for Named Entity Recognition (NER) predictions.

        Returns:
            dict: A confusion matrix structured as:
                  {entity_type: {'TP': count, 'FP': count, 'FN': count}}
        """
        types = {ent[0] for ent in self.true_entities} | {ent[0] for ent in self.pred_entities}

        confusion_matrix = {}
        for entity_type in types:
            TP, FP, FN = self.extract_entity_confusion(entity_type, self.true_entities, self.pred_entities)
            confusion_matrix[entity_type] = {'TP': TP, 'FP': FP, 'FN': FN}
        return confusion_matrix

    @staticmethod
    def _categorize_unmatched(targets, references):
        """Count, per type, why each entity in ``targets`` has no identical entity in ``references``.

        Each unmatched target is counted exactly once, under the first category
        that applies:
            1. a reference with the identical span but another type -> that type
            2. an overlapping reference of the same type            -> 'Boundary'
            3. an overlapping reference of a different type         -> 'Entity and Boundary'
            4. no overlapping reference                             -> 'O'
        """
        counts = defaultdict(Counter)
        target_index = {(e[1], e[2]): e[0] for e in targets}        # span -> type
        reference_index = {(e[1], e[2]): e[0] for e in references}  # span -> type

        for (start, end), entity_type in target_index.items():
            if (start, end) in reference_index:
                other_type = reference_index[(start, end)]
                if other_type != entity_type:
                    counts[entity_type][other_type] += 1
                # Identical span and type is a true positive: nothing to count.
                continue

            same_type_overlap = False
            other_type_overlap = False
            for (r_start, r_end), r_type in reference_index.items():
                # Symmetric interval test so containment in either direction is
                # detected. Bounds are compared inclusively.
                # NOTE(review): assumes `end` is the index of the last token — confirm.
                if r_start <= end and start <= r_end:
                    if r_type == entity_type:
                        same_type_overlap = True
                        break  # highest remaining priority, no need to look further
                    other_type_overlap = True

            if same_type_overlap:
                category = 'Boundary'
            elif other_type_overlap:
                category = 'Entity and Boundary'
            else:
                category = 'O'
            counts[entity_type][category] += 1

        return dict(counts)

    def compute_false_negatives_with_boundary(self):
        """
        Compute false negatives with detailed categorization:
        - '<type>': A prediction has the same span but this other type.
        - 'Boundary': An overlapping prediction of the same type.
        - 'Entity and Boundary': An overlapping prediction of a different type.
        - 'O': No overlapping prediction at all.

        Returns:
            dict: False negatives categorized by entity type.
        """
        return self._categorize_unmatched(self.true_entities, self.pred_entities)

    def compute_false_positives_with_boundary(self):
        """
        Compute false positives with detailed categorization:
        - '<type>': A gold entity has the same span but this other type.
        - 'Boundary': An overlapping gold entity of the same type.
        - 'Entity and Boundary': An overlapping gold entity of a different type.
        - 'O': No overlapping gold entity at all.

        Returns:
            dict: False positives categorized by entity type.
        """
        return self._categorize_unmatched(self.pred_entities, self.true_entities)


In [None]:
def compute_false_negatives_with_boundary(true_entities, pred_entities):
    """Categorize every gold entity that has no identical predicted entity.

    Both arguments are iterables of ``(entity_type, start, end)`` tuples.
    Each missed gold entity is counted once, under the first category that
    applies:
        1. a prediction with the identical span but another type -> that type
        2. an overlapping prediction of the same type            -> 'Boundary'
        3. an overlapping prediction of a different type         -> 'Entity and Boundary'
        4. no overlapping prediction                             -> 'O'

    Returns:
        dict: ``{gold_type: Counter({category: count})}``.
    """
    fn_counts = defaultdict(Counter)
    true_indexed = {(t[1], t[2]): t[0] for t in true_entities}  # Index true entities by boundaries
    pred_indexed = {(p[1], p[2]): p[0] for p in pred_entities}  # Index predicted entities by boundaries

    for (t_start, t_end), t_type in true_indexed.items():
        if (t_start, t_end) in pred_indexed:
            matched_type = pred_indexed[(t_start, t_end)]
            if matched_type != t_type:
                # Type mismatch at the exact position
                fn_counts[t_type][matched_type] += 1
            continue

        same_type_overlap = False
        other_type_overlap = False
        for (p_start, p_end), p_type in pred_indexed.items():
            # Symmetric interval test so a prediction lying entirely inside the
            # gold span is detected as well. Bounds are compared inclusively.
            # NOTE(review): assumes `end` is the index of the last token — confirm.
            if p_start <= t_end and t_start <= p_end:
                if p_type == t_type:
                    same_type_overlap = True
                    break  # same-type overlap outranks any different-type overlap
                other_type_overlap = True

        if same_type_overlap:
            fn_counts[t_type]['Boundary'] += 1
        elif other_type_overlap:
            fn_counts[t_type]['Entity and Boundary'] += 1
        else:
            # Missed entity entirely
            fn_counts[t_type]['O'] += 1

    return dict(fn_counts)


In [None]:
compute_false_positives_with_boundary(true_entities, pred_entities)

In [None]:
true_entities = get_entities(y_true)
pred_entities = get_entities(y_pred)
compute_false_negatives_with_boundary(true_entities, pred_entities)

In [None]:
# Step 1: Prepare entities
confusion = EntityConfusion(y_true, y_pred)

result = confusion.compute()


In [None]:
result


In [None]:
result

In [None]:
sum(list({'Boundary': 76, 'O': 54, 'ORG': 13, 'LOC': 6, 'MISC': 5}.values()))

In [None]:
sum({'Boundary': 74,
           'O': 51,
           'ORG': 13,
           'Entity and Boundary': 10,
           'LOC': 6,
           'MISC': 5}.values())

In [None]:
'PERS': Counter({'Boundary': 74,
           'O': 51,
           'ORG': 13,
           'Entity and Boundary': 10,
           'LOC': 6,
           'MISC': 5}),

In [None]:
result

In [None]:
29+10+7+3

In [None]:
'false_negatives': {'LOC': Counter({'O': 32,
           'Boundary': 10,
           'MISC': 3,
           'ORG': 3,
           'PERS': 1}),

In [None]:
from collections import defaultdict, Counter
def compute_false_negatives_with_boundary(y_true, y_pred):
    """
    Categorize every false-negative gold entity into exactly one error bucket.

    Entities are ``(type, start, end)`` tuples; spans are compared as closed
    intervals (``start`` and ``end`` both belong to the entity).

    Args:
        y_true: Iterable of gold entities.
        y_pred: Iterable of predicted entities.

    Returns:
        dict: ``{gold_type: Counter}``. For each gold entity that has no
        exact (span + type) match, exactly one key is incremented:

        - another entity type: a prediction covers the very same span but
          with that (wrong) type;
        - ``'Boundary'``: no prediction has the same span, but a prediction
          of the same type overlaps it;
        - ``'O'``: no same-type prediction overlaps the gold span at all
          (the entity was missed).
    """
    fn_counts = defaultdict(Counter)
    true_indexed = {(t[1], t[2]): t[0] for t in y_true}  # Index true entities by boundaries
    pred_indexed = {(p[1], p[2]): p[0] for p in y_pred}  # Index predicted entities by boundaries

    for (t_start, t_end), t_type in true_indexed.items():
        span = (t_start, t_end)

        if span in pred_indexed:
            matched_type = pred_indexed[span]
            if matched_type != t_type:
                # Same span, wrong type: a pure type confusion. Stop here so
                # the entity is not also counted as a boundary error, which
                # would make the buckets sum to more than the FN total.
                fn_counts[t_type][matched_type] += 1
            # Otherwise it is an exact match, i.e. not a false negative.
            continue

        # No prediction on this exact span. Use a symmetric interval-overlap
        # test so that a prediction fully containing the gold span (or fully
        # contained in it) is recognised as a boundary error too.
        overlaps_same_type = any(
            p_type == t_type and p_start <= t_end and t_start <= p_end
            for (p_start, p_end), p_type in pred_indexed.items()
        )
        fn_counts[t_type]['Boundary' if overlaps_same_type else 'O'] += 1

    return dict(fn_counts)



def compute_false_positives(y_true, y_pred):
    """
    Categorize every false-positive prediction into exactly one error bucket.

    Entities are ``(type, start, end)`` tuples; spans are compared as closed
    intervals (``start`` and ``end`` both belong to the entity).

    Args:
        y_true: Iterable of gold entities.
        y_pred: Iterable of predicted entities.

    Returns:
        dict: ``{predicted_type: Counter}``. For each prediction that has no
        exact (span + type) match, exactly one key is incremented:

        - another entity type: the gold annotation covers the very same span
          but with that type;
        - ``'Boundary'``: no gold entity has the same span, but a gold entity
          of the same type overlaps it;
        - ``'O'``: no same-type gold entity overlaps the prediction at all
          (a spurious prediction).
    """
    fp_counts = defaultdict(Counter)
    true_indexed = {(t[1], t[2]): t[0] for t in y_true}  # Index true entities by boundaries
    pred_indexed = {(p[1], p[2]): p[0] for p in y_pred}  # Index predicted entities by boundaries

    for (p_start, p_end), p_type in pred_indexed.items():
        span = (p_start, p_end)

        if span in true_indexed:
            matched_type = true_indexed[span]
            if matched_type != p_type:
                # Same span, wrong type: a pure type confusion. Stop here so
                # the prediction is not also counted as a boundary error,
                # which would make the buckets sum to more than the FP total.
                fp_counts[p_type][matched_type] += 1
            # Otherwise it is an exact match, i.e. not a false positive.
            continue

        # No gold entity on this exact span. Use a symmetric interval-overlap
        # test so that a gold entity fully containing the prediction (or
        # fully contained in it) is recognised as a boundary error too.
        overlaps_same_type = any(
            t_type == p_type and t_start <= p_end and p_start <= t_end
            for (t_start, t_end), t_type in true_indexed.items()
        )
        fp_counts[p_type]['Boundary' if overlaps_same_type else 'O'] += 1

    return dict(fp_counts)

{'LOC': {'TP': 627, 'FP': 76, 'FN': 49}, 'ORG': {'TP': 338, 'FP': 105, 'FN': 121}, 'PERS': {'TP': 751, 'FP': 99, 'FN': 154}, 'MISC': {'TP': 151, 'FP': 54, 'FN': 92}}

def compute_confusion_matrix(true_entities, pred_entities):
    """
    Build per-type strict TP / FP / FN counts for NER predictions.

    An entity only counts as a true positive when the identical
    (type, start, end) tuple appears in both inputs. Duplicated tuples are
    collapsed because the comparison is done on sets.

    Parameters:
        true_entities (list): Ground truth entities as (type, start, end) tuples.
        pred_entities (list): Predicted entities as (type, start, end) tuples.

    Returns:
        dict: {entity_type: {'TP': count, 'FP': count, 'FN': count}} with one
              entry for every type seen in either input.
    """
    gold = set(true_entities)
    predicted = set(pred_entities)

    # Every type that occurs on either side gets a row, even if all-zero TP.
    labels = {ent[0] for ent in gold} | {ent[0] for ent in predicted}

    matrix = {}
    for label in labels:
        gold_for_label = {ent for ent in gold if ent[0] == label}
        pred_for_label = {ent for ent in predicted if ent[0] == label}
        matrix[label] = {
            'TP': len(pred_for_label & gold_for_label),
            'FP': len(pred_for_label - gold_for_label),
            'FN': len(gold_for_label - pred_for_label),
        }
    return matrix



# Break the false negatives and false positives down by error category,
# then show both tables (false negatives first).
false_negatives = compute_false_negatives_with_boundary(true_entities, pred_entities)
false_positives = compute_false_positives(true_entities, pred_entities)
print(false_negatives, false_positives, sep="\n")

In [None]:
{'LOC': {'TP': 627, 'FP': 76, 'FN': 49}, 'ORG': {'TP': 338, 'FP': 105, 'FN': 121}, 'PERS': {'TP': 751, 'FP': 99, 'FN': 154}, 'MISC': {'TP': 151, 'FP': 54, 'FN': 92}}

ENTITY = 'LOC'
def compute_confusion_matrix(true_entities, pred_entities):
    """
    Tally strict true positives, false positives and false negatives per type.

    Matching is exact: a predicted (type, start, end) tuple is a TP only if
    the same tuple is present in the ground truth. Inputs are compared as
    sets, so repeated tuples count once.

    Parameters:
        true_entities (list): Ground truth entities as (type, start, end) tuples.
        pred_entities (list): Predicted entities as (type, start, end) tuples.

    Returns:
        dict: {entity_type: {'TP': count, 'FP': count, 'FN': count}} covering
              every type present in either input.
    """
    gold = set(true_entities)
    predicted = set(pred_entities)

    # Each entity falls in exactly one of these three groups; counting the
    # type of every member gives the per-type cell values directly.
    hits = Counter(ent[0] for ent in gold & predicted)
    spurious = Counter(ent[0] for ent in predicted - gold)
    missed = Counter(ent[0] for ent in gold - predicted)

    labels = set(hits) | set(spurious) | set(missed)
    return {
        label: {'TP': hits[label], 'FP': spurious[label], 'FN': missed[label]}
        for label in labels
    }



In [None]:
confusion_matrix = compute_confusion_matrix(true_entities, pred_entities)
print(confusion_matrix)

In [None]:
result

In [None]:
# Compute the confusion matrix
conf_matrix = conf_matrix_obj.compute_confusion_matrix()
print("Confusion Matrix:", conf_matrix)

false_negatives = conf_matrix_obj.compute_false_negatives_with_boundary()
print("False Negatives:", false_negatives)

# Compute false positives with boundary
false_positives = conf_matrix_obj.compute_false_positives()
print("False Positives:", false_positives)


In [None]:
import pandas as pd
import plotly.express as px

# Prepare data for bar chart: one row per (entity type, metric) pair so that
# Plotly can colour the bar segments by metric.
bar_data = []
for entity, counts in conf_matrix.items():
    for metric in ('TP', 'FP', 'FN'):
        bar_data.append({'Entity': entity, 'Metric': metric, 'Count': counts[metric]})

# pandas is imported in this cell so it does not depend on a later cell
# having been executed first.
df_bar = pd.DataFrame(bar_data)

# Plot stacked bar chart
fig = px.bar(
    df_bar,
    x="Entity",
    y="Count",
    color="Metric",
    title="Distribution of TP, FP, and FN by Entity",
    text_auto=True,
    barmode="stack",
)
fig.update_layout(xaxis_title="Entity Type", yaxis_title="Count", legend_title="Metric")
fig.show()


In [None]:
import plotly.express as px
import pandas as pd
# False Negatives
# Hard-coded breakdown: outer keys are gold entity types, inner keys are the
# error categories ('O', 'Boundary', or the wrongly predicted type).
false_negatives = {
    'LOC': {'O': 32, 'Boundary': 10, 'MISC': 3, 'ORG': 3, 'PERS': 1},
    'PERS': {'Boundary': 76, 'O': 54, 'ORG': 13, 'LOC': 6, 'MISC': 5},
    'ORG': {'O': 57, 'Boundary': 27, 'LOC': 20, 'PERS': 11, 'MISC': 6},
    'MISC': {'O': 59, 'Boundary': 18, 'ORG': 9, 'LOC': 6},
}

# Convert to long format for Plotly
# Entity types become columns and error categories the index; combinations
# missing from the dict above (e.g. MISC has no 'PERS' entry) are filled with 0.
df_false_negatives = pd.DataFrame(false_negatives).fillna(0)
df_false_negatives_long = df_false_negatives.reset_index().melt(
    id_vars='index', var_name='Entity Type', value_name='Count'
)
df_false_negatives_long.rename(columns={'index': 'Error Type'}, inplace=True)

# Grouped bar chart: barmode='group' places one bar per error type side by
# side within each entity type (the bars are not stacked).
fig = px.bar(
    df_false_negatives_long,
    x='Entity Type',
    y='Count',
    color='Error Type',
    title='False Negatives by Entity Type',
    barmode='group',
    text='Count',
)
fig.update_layout(xaxis_title="Entity Type", yaxis_title="Count", legend_title="Error Type")
fig.show()


In [None]:
df_false_negatives['LOC'].index

In [None]:
# False Positives
# Hard-coded breakdown: outer keys are predicted entity types, inner keys are
# the error categories ('O', 'Boundary', or the gold type at the same span).
false_positives = {
    'LOC': {'O': 30, 'ORG': 20, 'Boundary': 14, 'MISC': 6, 'PERS': 6},
    'MISC': {'Boundary': 21, 'O': 19, 'ORG': 6, 'PERS': 5, 'LOC': 3},
    'ORG': {'O': 49, 'Boundary': 31, 'PERS': 13, 'MISC': 9, 'LOC': 3},
    'PERS': {'Boundary': 51, 'O': 36, 'ORG': 11, 'LOC': 1},
}

# Choose entity to visualize
entity = 'LOC'
# Go through a DataFrame so that the selected column carries every error
# category seen for any entity type, with 0 for the ones this type lacks.
data = pd.DataFrame(false_positives).fillna(0)[entity]

# Pie Chart
fig = px.pie(
    names=list(data.index),
    values=list(data.values),
    title=f"False Positives Distribution for {entity}",
)
fig.update_traces(textinfo='percent+label')
fig.show()


# Confirming Boundary and O calculations

In [None]:
# Extract entities from the gold and predicted tag sequences. Each entity is
# used below as a (type, start, end) tuple: index 0 is the type and indices
# 1 and 2 are the span boundaries.
true_entities = get_entities(y_true)
pred_entities = get_entities(y_pred)
# Look-up tables from span (start, end) to entity type.
true_indexed = {(t[1], t[2]): t[0] for t in true_entities}
pred_indexed = {(t[1], t[2]): t[0] for t in pred_entities}
# Entity type under inspection in the following cells.
entity = 'LOC'
# Strict false positives: predicted tuples of this type with no identical gold tuple.
false_positives = set([e for e in pred_entities if e[0] == entity]) - set([e for e in true_entities if e[0] == entity])
# Strict false negatives: gold tuples of this type with no identical predicted tuple.
false_negatives = set([e for e in true_entities if e[0] == entity]) - set([e for e in pred_entities if e[0] == entity])

In [None]:
# Manual cross-check of the false-positive breakdown. Relies on
# `false_positives`, `true_indexed` and `true_entities` already being defined
# in the notebook session.
# NOTE: the loop variable rebinds the global `entity` to a tuple.
entity_errors = []   # gold entities sitting on exactly the same span as a false positive
missing_errors = []  # spans of false positives with no gold entity on exactly that span
for entity in false_positives:
    o = []  # never read in this cell
    t, s, e = entity
    if (s, e) not in true_indexed:
            missing_errors.append((s, e))
    for p_entity in true_entities:
        
        p_t, p_s, p_e = p_entity
        # Same type, different span, and the gold start or end falls inside
        # the predicted span: print the pair for visual inspection.
        if p_t == t and not (p_s == s and p_e == e):
            if (s <= p_s <= e) or (s <= p_e <= e):
                print(entity)
                print(p_entity)
                print('end ####')
                
        # Identical span: collected regardless of the gold entity's type.
        if p_s == s and p_e == e:
            entity_errors.append(p_entity)
print(len(entity_errors))
print(len(missing_errors))


In [None]:
# Manual cross-check of the false-negative breakdown. Relies on
# `false_negatives`, `pred_indexed` and `pred_entities` already being defined
# in the notebook session.
# NOTE: the loop variable rebinds the global `entity` to a tuple.
entity_errors = []   # predictions sitting on exactly the same span as a false negative
missing_errors = []  # spans of false negatives with no prediction on exactly that span
for entity in false_negatives:
    o = []  # never read in this cell
    t, s, e = entity
    if (s, e) not in pred_indexed:
            missing_errors.append((s, e))
    for p_entity in pred_entities:
        
        p_t, p_s, p_e = p_entity
        # Same type, different span, and the predicted start or end falls
        # inside the gold span: print the pair for visual inspection.
        if p_t == t and not (p_s == s and p_e == e):
            if (s <= p_s <= e) or (s <= p_e <= e):
                print(entity)
                print(p_entity)
                print('end ####')
                
        # Identical span: collected regardless of the predicted type.
        if p_s == s and p_e == e:
            entity_errors.append(p_entity)
        

In [None]:
len(missing_errors)

In [None]:
len(pred_entities)

# strict test

In [None]:
 

def flatten_strict_entities(entities):
    """
    Collapse per-sentence strict-mode entities into one flat list of tuples.

    Args:
        entities (Entities): Object whose ``entities`` attribute is an
            iterable of sentences, each an iterable of items exposing
            ``to_tuple()``.

    Returns:
        list: The ``to_tuple()`` result of every entity, in sentence order.
    """
    flattened = []
    for sentence in entities.entities:
        flattened.extend(item.to_tuple() for item in sentence)
    return flattened
    
# Strict-mode sanity check for a single entity type.
# NOTE(review): the second positional argument of `auto_detect` and the third
# of `Entities` is presumably seqeval's `suffix` flag - confirm against the
# installed seqeval version.
scheme = auto_detect(y_true, False)

entities_true = Entities(y_true, scheme, False)
entities_pred = Entities(y_pred, scheme, False)
true_entities = flatten_strict_entities(entities_true)
pred_entities = flatten_strict_entities(entities_pred)
entity = 'LOC'
# The flattened tuples are filtered on index 1, i.e. the entity type is the
# second element here (unlike the `get_entities` tuples used in other cells).
entity_true = [e for e in true_entities if e[1] == entity]
entity_pred = [e for e in pred_entities if e[1] == entity]
# Printed in order: exact matches (TP), gold-only (FN), predicted-only (FP).
print(len(set(entity_true).intersection(set(entity_pred))))
print(len(set(entity_true) - set(entity_pred)))
print(len(set(entity_pred) - set(entity_true)))

In [None]:



def flatten_strict_entities(entities):
    """
    Flatten per-sentence entities into one list, dropping each tuple's first field.

    Args:
        entities: Object whose ``entities`` attribute is an iterable of
            sentences, each an iterable of items exposing ``to_tuple()``.

    Returns:
        list: ``to_tuple()[1:]`` for every entity, in sentence order.
    """
    flattened = []
    for sentence in entities.entities:
        for item in sentence:
            flattened.append(item.to_tuple()[1:])
    return flattened

def calculate_confusion_matrix(y_true, y_pred):
    """
    Count strict TP / FP / FN per entity type.

    A gold entity is matched to at most one identical predicted
    ``(type, start, end)`` tuple and vice versa, so duplicates are honoured:
    for every distinct tuple, ``min(gold copies, predicted copies)`` are true
    positives, surplus gold copies are false negatives and surplus predicted
    copies are false positives.

    Args:
        y_true: Iterable of gold entities as (type, start, end).
        y_pred: Iterable of predicted entities as (type, start, end).

    Returns:
        dict: ``{entity_type: {'TP': int, 'FP': int, 'FN': int}}`` with one
        entry for every type present in either input.
    """
    # Multiset view of both sides: matching identical tuples one-to-one is
    # the same as intersecting their counts, which avoids rescanning the
    # whole prediction list for every gold entity.
    true_counts = Counter(tuple(ent) for ent in y_true)
    pred_counts = Counter(tuple(ent) for ent in y_pred)

    types = {ent[0] for ent in true_counts} | {ent[0] for ent in pred_counts}
    confusion_matrix = {typ: {'TP': 0, 'FP': 0, 'FN': 0} for typ in types}

    for ent, n_true in true_counts.items():
        n_matched = min(n_true, pred_counts.get(ent, 0))
        confusion_matrix[ent[0]]['TP'] += n_matched
        confusion_matrix[ent[0]]['FN'] += n_true - n_matched

    # Any prediction left unmatched is a false positive.
    for ent, n_pred in pred_counts.items():
        n_matched = min(n_pred, true_counts.get(ent, 0))
        confusion_matrix[ent[0]]['FP'] += n_pred - n_matched

    return confusion_matrix



from collections import defaultdict, Counter

def compute_false_negatives(y_true, y_pred):
    """
    Break false negatives down by what was predicted on the gold span.

    For each gold entity without an exact (span + type) match, the counter
    for its type is incremented under the type predicted on exactly the same
    span, or under 'Boundary' when no prediction has exactly that span (this
    also covers entities that were not predicted at all).

    Args:
        y_true: Iterable of gold entities as (type, start, end).
        y_pred: Iterable of predicted entities as (type, start, end).

    Returns:
        defaultdict: {gold_type: Counter({category: count})}.
    """
    gold_by_span = {(ent[1], ent[2]): ent[0] for ent in y_true}
    pred_by_span = {(ent[1], ent[2]): ent[0] for ent in y_pred}

    breakdown = defaultdict(Counter)
    for span, gold_type in gold_by_span.items():
        if span in pred_by_span:
            observed = pred_by_span[span]
            if observed == gold_type:
                continue  # exact match, not a false negative
        else:
            observed = 'Boundary'
        breakdown[gold_type][observed] += 1

    return breakdown




def compute_false_positives(y_true, y_pred):
    """
    Break false positives down by what the gold annotation has on the span.

    For each prediction without an exact (span + type) match, the counter
    for its type is incremented under the gold type found on exactly the
    same span, or under 'Boundary' when no gold entity has exactly that span
    (this also covers predictions with no gold counterpart at all).

    Args:
        y_true: Iterable of gold entities as (type, start, end).
        y_pred: Iterable of predicted entities as (type, start, end).

    Returns:
        defaultdict: {predicted_type: Counter({category: count})}.
    """
    gold_by_span = {(ent[1], ent[2]): ent[0] for ent in y_true}
    pred_by_span = {(ent[1], ent[2]): ent[0] for ent in y_pred}

    breakdown = defaultdict(Counter)
    for span, pred_type in pred_by_span.items():
        if span in gold_by_span:
            observed = gold_by_span[span]
            if observed == pred_type:
                continue  # exact match, not a false positive
        else:
            observed = 'Boundary'
        breakdown[pred_type][observed] += 1

    return breakdown



In [None]:
entity_y_true = get_entities(y_true)
entity_y_pred = get_entities(y_pred)

conf_matrix = calculate_confusion_matrix(entity_y_true, entity_y_pred)
print(conf_matrix)
fn_errors = compute_false_negatives(entity_y_true, entity_y_pred)
fp_errors = compute_false_positives(entity_y_true, entity_y_pred)

print("False Negatives:", dict(fn_errors))
print("False Positives:", dict(fp_errors))
1

In [None]:
scheme = auto_detect(y_true, False)

entities_true = Entities(y_true, scheme, False)
entities_pred = Entities(y_pred, scheme, False)
true_entity_type = flatten_strict_entities(entities_true)
pred_entity_type = flatten_strict_entities(entities_pred)
# Example usage
# conf_matrix = calculate_confusion_matrix([e.to_tuple()[1:] for sen in entities_true.entities for e in sen], [e.to_tuple()[1:] for sen in entities_pred.entities for e in sen])



conf_matrix = calculate_confusion_matrix(true_entity_type, pred_entity_type)
print(conf_matrix)
fn_errors = compute_false_negatives(true_entity_type, pred_entity_type)
fp_errors = compute_false_positives(true_entity_type, pred_entity_type)

print("False Negatives:", dict(fn_errors))
print("False Positives:", dict(fp_errors))
1

In [None]:
# Example usage
entity_y_true = get_entities(entity_outputs['entity_outputs']['y_true'])
entity_y_pred = get_entities(entity_outputs['entity_outputs']['y_pred'])


entities_true = Entities(entity_outputs['entity_outputs']['y_true'], scheme, False)
entities_pred = Entities(entity_outputs['entity_outputs']['y_pred'], scheme, False)



In [None]:
conf_matrix = calculate_confusion_matrix(entity_y_true, entity_y_pred)
print(conf_matrix)

In [None]:
(627+151+751+338) / ((627+92)+(151+154)+(751+49)+(338+121))

In [None]:
1867 / 2201

In [None]:
conf_matrix = calculate_confusion_matrix(entity_y_true, entity_y_pred)
print(conf_matrix)

In [None]:
# Sum each cell (TP, FP, FN) over all entity types to get corpus-level totals.
# Expects `conf_matrix` to map entity type -> {'TP': ..., 'FP': ..., 'FN': ...}.
total_metrics = {}
for metric in ['TP', 'FP', 'FN']:
    total_metrics[metric] = sum(details[metric] for details in conf_matrix.values())
print(total_metrics)

In [None]:
1867 / (1867+334)

In [None]:


# Example usage
fn_errors = compute_false_negatives(entity_y_true, entity_y_pred)
fp_errors = compute_false_positives(entity_y_true, entity_y_pred)

print("False Negatives:", dict(fn_errors))
print("False Positives:", dict(fp_errors))
1

In [None]:
import plotly.express as px
import pandas as pd

# Your original data
# NOTE(review): the loop below unpacks every key of `data` into an
# (actual, predicted) pair and treats every value as a count. The confusion
# matrix builders visible in this notebook return
# {type: {'TP': ..., 'FP': ..., 'FN': ...}} instead - confirm which
# `conf_matrix` this cell was written against before running it.
data = conf_matrix

# Prepare lists for DataFrame construction
actual = []
predicted = []
counts = []

for (act, pred), count in data.items():
    actual.append(act)
    predicted.append('None' if pred is None else pred)  # Replace None with 'None' for better visualization
    counts.append(count)

# Create DataFrame
df = pd.DataFrame({'Actual': actual, 'Predicted': predicted, 'Count': counts})

# Pivot to format suitable for heatmap
# Rows are actual types, columns are predicted types; (actual, predicted)
# combinations that never occurred are filled with 0.
pivot_table = df.pivot(index='Actual', columns='Predicted', values='Count').fillna(0)

# Generate heatmap
fig = px.imshow(pivot_table,
                labels=dict(x="Predicted Entity Type", y="Actual Entity Type", color="Count"),
                x=pivot_table.columns,
                y=pivot_table.index,
                text_auto=True,
                aspect="auto")

fig.update_layout(
    title="Entity Recognition Confusion Matrix",
    xaxis_title="Predicted Entity Type",
    yaxis_title="Actual Entity Type"
)

fig.show()


In [None]:
errors

In [None]:
import plotly.express as px
import pandas as pd

# Your original data
# NOTE(review): `conf_matrix1` is not defined in the visible part of this
# notebook. The loop below requires a mapping whose keys unpack into
# (actual, predicted) pairs and whose values are counts - confirm where it
# is supposed to come from.
data = conf_matrix1

# Prepare lists for DataFrame construction
actual = []
predicted = []
counts = []

for (act, pred), count in data.items():
    actual.append(act)
    predicted.append('None' if pred is None else pred)  # Replace None with 'None' for better visualization
    counts.append(count)

# Create DataFrame
df = pd.DataFrame({'Actual': actual, 'Predicted': predicted, 'Count': counts})

# Pivot to format suitable for heatmap
# Rows are actual types, columns are predicted types; (actual, predicted)
# combinations that never occurred are filled with 0.
pivot_table = df.pivot(index='Actual', columns='Predicted', values='Count').fillna(0)

# Generate heatmap
fig = px.imshow(pivot_table,
                labels=dict(x="Predicted Entity Type", y="Actual Entity Type", color="Count"),
                x=pivot_table.columns,
                y=pivot_table.index,
                text_auto=True,
                aspect="auto")

fig.update_layout(
    title="Entity Recognition Confusion Matrix",
    xaxis_title="Predicted Entity Type",
    yaxis_title="Actual Entity Type"
)

fig.show()


In [None]:
entity_tag = 'LOC'

# Strict false negatives for the selected type: gold tuples with no identical
# predicted tuple.
false_negatives = set([e for e in entity_y_true if e[0] == entity_tag]) - set([e for e in entity_y_pred if e[0] == entity_tag])

# For every false negative, print each multi-token gold entity (start != end)
# that shares its start or its end offset with it.
for fn in false_negatives:
    t, fn_s, fn_e = fn
    for entity in entity_y_true:
        t_t, t_s, t_e = entity
        if fn_s == t_s or fn_e == t_e:
            if t_s != t_e:
                print(entity)

# Show what was predicted starting at one specific token offset.
# Named `target_start` rather than `id` so the builtin `id()` is not shadowed
# for the rest of the notebook session.
target_start = 5594
for entity in entity_y_pred:
    t, s, e = entity
    if s == target_start:
        print(entity)

In [None]:
entity_tag = 'LOC'

false_postive = set([e for e in entity_y_pred if e[0] == entity_tag]) - set([e for e in entity_y_true if e[0] == entity_tag])

for fn in false_postive:
  t, fn_s, fn_e = fn
  for entity in entity_y_pred:
      t_t, t_s, t_e = entity
      if fn_s == t_s or fn_e == t_e:
        if t_s!=t_e:
          print(entity)


In [None]:
# so the false positive doesn't have to be false positive. 
id = 6445
for entity in entity_y_true:
    t, s, e = entity
    if s == id:
        print(entity)

In [None]:
pr =[
        tok for sen in entity_outputs['entity_outputs']['y_pred']
        for tok in sen
        ]


tr =[
        tok for sen in entity_outputs['entity_outputs']['y_true']
        for tok in sen
        ]

In [None]:
pr[8760:8790] == tr[8760:8790]

In [None]:
tr[8760:8790]

In [None]:
('LOC', 8786, 8787)
('LOC', 8864, 8865)
('LOC', 16466, 16467)
('LOC', 5593, 5594)
('LOC', 2545, 2546)
('LOC', 25446, 25447)
('LOC', 1615, 1616)


In [None]:
entity_y_pred

In [None]:
entity_y_true

# Debugging

In [None]:
# Type-confusion check for a single entity type: for every strict false
# negative / false positive, look at what sits on exactly the same span on
# the other side and count it by type.
ENTITY = 'LOC'

# --- False negatives: gold ENTITY spans that were predicted with another type.
entity_false_negatives = {ENTITY: Counter()}
false_negatives = set([e for e in entity_y_true if e[0] == ENTITY]) - set([e for e in entity_y_pred if e[0] == ENTITY])
for e in false_negatives:
    t_type, t_start, t_end = e
    for pred_ent in entity_y_pred:
        p_type, p_start, p_end = pred_ent
        # Both boundaries must coincide (start with start, end with end).
        if t_start == p_start and t_end == p_end:
            if p_type == ENTITY:
                # Sanity check: a same-span, same-type prediction would make
                # this a true positive, so this should never print.
                print(pred_ent)
            entity_false_negatives[t_type][p_type] += 1

# --- False positives: predicted ENTITY spans whose gold annotation has another type.
entity_false_positives = {ENTITY: Counter()}
false_positive = set([e for e in entity_y_pred if e[0] == ENTITY]) - set([e for e in entity_y_true if e[0] == ENTITY])
for e in false_positive:
    p_type, p_start, p_end = e
    for true_ent in entity_y_true:
        t_type, t_start, t_end = true_ent
        if t_start == p_start and t_end == p_end:
            # A gold entity on the same span necessarily has a different type
            # (otherwise the prediction would be a true positive), so every
            # same-span hit is counted under the gold type.
            entity_false_positives[p_type][t_type] += 1

In [None]:
entity_false_positives

In [None]:
id = 8786
for entity in entity_y_true:
    t, s, e = entity
    if s == id:
        print(entity)
for entity in entity_y_pred:
    t, s, e = entity
    if s == id or e == id+1:
        print(entity)

In [None]:
for entity in false_negatives:
    t, s, e = entity
    # if t == 'LOC':
    #     print(entity)
    if s == 8786:
        print(entity)

In [None]:
for entity in false_positive:
    t, s, e = entity
   
    if s == 16466 or e == 16467:
        print(entity)

In [None]:
for entity in entity_y_true:
    t, s, e = entity
    if s == 16963:
        print(entity)

In [None]:
for entity in entity_y_pred:
    t, s, e = entity
    if s == 16963:
        print(entity)

In [None]:
entity_y_true