# BERTweet Misclassification Analysis

Analyse the validation tweets that BERTweet misclassifies.

In [None]:
import os
import sys
import re

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from IPython.display import display

sys.path.append(os.path.join(os.pardir, os.pardir, 'src'))
from data_processing.loading import load_train_valid_data

In [None]:
# Load the original dataset for its true labels. Only the validation split is
# used in this notebook; the first returned value (presumably the training
# split) is discarded.
path_to_dataset = os.path.join(*([os.pardir] * 2), 'dataset')
_, valid = load_train_valid_data(path_to_dataset)

display(valid)

In [None]:
# Read the BERTweet validation predictions produced during experimentation,
# renaming the index to 'id' and the single value column to 'prediction' so the
# frame lines up with the original dataset for joining.
valid_predictions = (
    pd.read_csv('bertweet_valid_predictions.csv', index_col='Id')
    .rename_axis('id')
)
valid_predictions.columns = ['prediction']

display(valid_predictions)

In [None]:
# Join BERTweet predictions with the original labels. DataFrame.join is a left
# join on the index, so every validation tweet is kept.
valid_joint = valid.join(valid_predictions)

# A tweet without a prediction gets NaN, which compares unequal to its label and
# would be silently counted as misclassified, skewing every statistic below.
num_missing = int(valid_joint['prediction'].isna().sum())
if num_missing:
    raise ValueError(
        f'{num_missing} validation tweets have no BERTweet prediction.'
    )

# Correctly classified tweets, i.e. where the prediction equals the label.
# .copy() makes these independent frames, so adding columns to them later does
# not trigger pandas' SettingWithCopyWarning.
valid_correct = valid_joint[valid_joint['label'] == valid_joint['prediction']].copy()

# Misclassified tweets, i.e. where the prediction differs from the label.
valid_misclass = valid_joint[valid_joint['label'] != valid_joint['prediction']].copy()

display(valid_correct)
display(valid_misclass)

In [None]:
# Compute a confusion matrix to see whether one class is misclassified more
# frequently than the other. 'positive' is label 1 and 'negative' is label -1.
# Rows are the true class, columns the predicted class, and every cell is the
# share of all validation tweets.
labels = ['positive', 'negative']

num_valid_tweets = len(valid)

# Assuming binary predictions (1 / -1): a misclassified tweet whose true label
# is positive was predicted negative (a false negative), and a misclassified
# tweet whose true label is negative was predicted positive (a false positive).
num_true_positives = len(valid_correct[valid_correct['label'] == 1])
num_false_negatives = len(valid_misclass[valid_misclass['label'] == 1])

num_true_negatives = len(valid_correct[valid_correct['label'] == -1])
num_false_positives = len(valid_misclass[valid_misclass['label'] == -1])

confusion_matrix = np.array([
    [num_true_positives, num_false_negatives],
    [num_false_positives, num_true_negatives],
]) / num_valid_tweets

fig, ax = plt.subplots(figsize=(6, 6))
ax.imshow(confusion_matrix)

# Show all ticks and label them with the class names.
ticks = np.arange(len(labels))
ax.set_xticks(ticks)
ax.set_yticks(ticks)
ax.set_xticklabels(labels)
ax.set_yticklabels(labels)

ax.set_xlabel('Predicted')
ax.set_ylabel('True')

# Annotate every cell with its percentage; imshow puts column j on the x axis
# and row i on the y axis.
for (i, j), share in np.ndenumerate(confusion_matrix):
    ax.text(j, i, f'{100 * share:.2f}%', ha="center", va="center")

ax.set_title("BERTweet Confusion Matrix")
fig.tight_layout()
plt.show()

## Misclassified Common Twitter Abbreviations

Check whether common Twitter abbreviations occur in the misclassified
validation tweets more often, relative to how often they occur in all
validation tweets, than their full-text versions do. If so, this hints that
BERTweet struggles with Twitter abbreviations, which would be grounds for a
dictionary normalization that replaces them with their full-text versions.

The candidates for Twitter abbreviations and their full versions are taken from:

https://www.webopedia.com/reference/twitter-dictionary-guide/

In [None]:
# Compute a dataframe summarising, for every candidate abbreviation and its
# full-text counterpart, how often each occurs in all validation tweets and in
# the misclassified ones. It also records the resulting error rate:
# misclassified occurrences / all occurrences, or NaN if the term never occurs.


def _count_term_occurrences(tweets, term):
    """Return the number of whole-term occurrences of ``term`` in ``tweets``.

    A match must be preceded by whitespace or the start of the tweet, and
    followed by whitespace or the end of the tweet. Zero-width lookarounds are
    used rather than consuming the surrounding whitespace, so back-to-back
    occurrences such as "lol lol lol" are all counted. The term is escaped so
    any punctuation in it is matched literally.
    """
    pattern = rf"(?<!\S){re.escape(term)}(?!\S)"
    return int(tweets.str.count(pattern).sum())


def _error_rate(misclass_count, valid_count):
    """Return ``misclass_count / valid_count``, or NaN if ``valid_count`` is 0."""
    return misclass_count / valid_count if valid_count else float('nan')


# Read with pandas so quoted fields containing commas are handled and blank
# lines are skipped. NA parsing is disabled so terms such as 'n/a' or 'null'
# stay strings.
candidate_pairs = pd.read_csv(
    "normalization-dict-candidates.csv", dtype=str, keep_default_na=False
)

abbreviation_occur = []
for abbr, full in candidate_pairs.iloc[:, :2].itertuples(index=False, name=None):
    abbr, full = abbr.strip(), full.strip()
    if not abbr or not full:
        # An empty term would only produce meaningless matches; skip the row.
        continue

    misclass_abbr_occur = _count_term_occurrences(valid_misclass['tweet'], abbr)
    valid_abbr_occur = float(_count_term_occurrences(valid['tweet'], abbr))

    misclass_full_occur = _count_term_occurrences(valid_misclass['tweet'], full)
    valid_full_occur = float(_count_term_occurrences(valid['tweet'], full))

    abbreviation_occur.append([
        abbr,
        misclass_abbr_occur,
        valid_abbr_occur,
        _error_rate(misclass_abbr_occur, valid_abbr_occur),
        full,
        misclass_full_occur,
        valid_full_occur,
        _error_rate(misclass_full_occur, valid_full_occur),
    ])

abbreviation_occurrances = pd.DataFrame(
    abbreviation_occur,
    columns=[
        'abbr',
        'abbr_misclass_occur',
        'abbr_valid_occur',
        'abbr_error_rate',
        'full',
        'full_misclass_occur',
        'full_valid_occur',
        'full_error_rate'
    ]
)

display(abbreviation_occurrances)

In [None]:
# Select the candidate pairs that are worth turning into dictionary
# normalizations, judged on the misclassified validation data. All conditions
# are row-wise, so they are combined into a single boolean mask.

# Baseline: the overall fraction of validation tweets that are misclassified.
misclass_error = len(valid_misclass) / float(len(valid))

keep_mask = (
    # Both the abbreviation and its full text must occur at least once;
    # otherwise no reasonable replacement is possible.
    (abbreviation_occurrances['abbr_valid_occur'] != 0.0)
    & (abbreviation_occurrances['full_valid_occur'] != 0.0)
    # At least one of the two must reach the overall error rate; otherwise
    # there is no ground for replacement at this stage.
    & (
        (abbreviation_occurrances['abbr_error_rate'] >= misclass_error)
        | (abbreviation_occurrances['full_error_rate'] >= misclass_error)
    )
    # The abbreviation must not be more frequent than its full text, since
    # replacing it would then obfuscate the text. Example: 'yolo' is in
    # considerably higher use than 'you only live once'.
    & (
        abbreviation_occurrances['abbr_valid_occur']
        <= abbreviation_occurrances['full_valid_occur']
    )
    # The abbreviation must have an error rate at least as high as its full
    # text, so that replacing it could help.
    & (
        abbreviation_occurrances['abbr_error_rate']
        >= abbreviation_occurrances['full_error_rate']
    )
)
dict_normalizations = abbreviation_occurrances[keep_mask].copy()

display(dict_normalizations)

In [None]:
# Persist the selected abbreviation -> full-text pairs, with descriptive column
# names, so later preprocessing can load them.
normalization_dict = dict_normalizations.loc[:, ['abbr', 'full']].rename(
    columns={'abbr': 'abbreviation', 'full': 'full_text'}
)
normalization_dict.to_csv('normalization-dict.csv', index=False)

display(normalization_dict)