## 1. Introduction
The paper *A Baseline for Detecting Misclassified and Out-of-Distribution Examples in Neural Networks* suggests using the **softmax layer** to detect misclassified examples. This notebook applies the same idea to finding mislabeled items.

This method relies on the observation that *correctly classified examples tend to **have higher softmax probabilities**.*

Therefore, **setting a proper threshold** for the softmax layer's output can help us detect the mislabeled items.
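
To make the idea concrete, here is a minimal sketch in plain Python with no model involved. The function name `max_softmax`, the toy logits, and the 0.9 threshold are illustrative assumptions, not something taken from the paper.

```python
import math

def max_softmax(logits):
    """Return (top probability, index of the top class) for one example."""
    shift = max(logits)  # subtract the max so exp() cannot overflow
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    top = max(range(len(probs)), key=probs.__getitem__)
    return probs[top], top

threshold = 0.9  # illustrative value

# Two toy logit vectors: one clearly peaked, one spread out
confident = max_softmax([6.0, 1.0, 0.5])
uncertain = max_softmax([2.0, 1.0, 0.1])

for name, (prob, idx) in [("confident", confident), ("uncertain", uncertain)]:
    status = "below threshold" if prob < threshold else "at or above threshold"
    print(f"{name}: class {idx}, probability {prob:.3f} ({status})")
```

The peaked vector stays above the threshold, while the spread-out one falls below it and would be a candidate for review.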

I use `BERT-Banking77` on the `banking77` dataset to show how this method can be applied.

## 2. Testing
### Step 1: Loading the Model, Tokenizer, and Dataset

```python
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from datasets import load_dataset
import torch
from torch.nn.functional import softmax

# Load the tokenizer, model, and dataset
tokenizer = AutoTokenizer.from_pretrained("philschmid/BERT-Banking77")
model = AutoModelForSequenceClassification.from_pretrained("philschmid/BERT-Banking77")
dataset = load_dataset("banking77")
```

### Step 2: Process the dataset and perform prediction

```python
# For simplicity, use only the first 666 examples of the test set
test_dataset = dataset['test'].select(range(666))  # Lucky number

# Tokenize the test set
inputs = tokenizer(test_dataset['text'], padding=True, truncation=True, return_tensors="pt")

# Predict categories (this may take about 2 minutes, because all examples go through the model in one pass)
with torch.no_grad():
    logits = model(**inputs).logits
```
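
The cell above sends all 666 examples through the model in a single forward pass. If memory becomes a problem, one option is to process the texts in chunks. The helper below is a minimal sketch. `iter_batches` and the batch size of 32 are hypothetical choices, and the commented usage has not been run here.

```python
def iter_batches(items, batch_size):
    """Yield consecutive slices of `items` with at most `batch_size` elements each."""
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]

# Quick check on toy data
toy_batches = list(iter_batches(list(range(5)), 2))
print(toy_batches)  # [[0, 1], [2, 3], [4]]

# Hypothetical usage with the objects defined earlier (not executed here):
# chunks = []
# with torch.no_grad():
#     for batch in iter_batches(list(test_dataset['text']), 32):
#         enc = tokenizer(batch, padding=True, truncation=True, return_tensors="pt")
#         chunks.append(model(**enc).logits)
# logits = torch.cat(chunks, dim=0)
```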

### Step 3: Set a threshold value and calculate the softmax probabilities

```python
# Convert logits to probabilities, then take the top probability and its class index for each example
probabilities = softmax(logits, dim=1)
max_probabilities, predictions = torch.max(probabilities, dim=1)

# Dataset label ids as a tensor, for comparison with the predictions
label_ids = torch.tensor(test_dataset['label'])

# Set threshold (I personally think that 0.9 would be a good value)
threshold = 0.9
```
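
Since 0.9 is a personal choice, it can help to see how the number of flagged items changes with the threshold. The sketch below counts how many probabilities fall under each candidate value. `count_below` is a hypothetical helper, and the sample values are the first five probabilities from the Step 4 output, rounded to three decimals.

```python
def count_below(max_probs, thresholds):
    """For each threshold, count how many probabilities are strictly below it."""
    return {t: sum(p < t for p in max_probs) for t in thresholds}

# Rounded probabilities of the first five flagged items in this notebook's output
sample_probs = [0.466, 0.374, 0.769, 0.669, 0.887]

counts = count_below(sample_probs, [0.5, 0.7, 0.9])
print(counts)  # {0.5: 2, 0.7: 3, 0.9: 5}
```

With the real tensors, `max_probabilities.tolist()` could be passed in place of `sample_probs`.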

### Step 4: Finding the mismatching items

```python
# True wherever the predicted class id differs from the dataset label id
mismatches = predictions != label_ids

for i, (mismatch, max_prob) in enumerate(zip(mismatches, max_probabilities)):
    if mismatch and max_prob < threshold:
        # Flag items whose prediction disagrees with the label and whose confidence is below the threshold
        predicted_label_name = model.config.id2label[predictions[i].item()]
        # Note: this lookup assumes the dataset's label ids follow the same order as model.config.id2label
        actual_label_name = model.config.id2label[label_ids[i].item()]

        print(f"Potential mislabel detected at index {i}: '{test_dataset['text'][i]}'")
        print(f"Predicted: '{predicted_label_name}' with probability {max_prob.item()}")
        print(f"Actual: '{actual_label_name}'\n")
```

```
Potential mislabel detected at index 0: 'How do I locate my card?'
Predicted: 'card_arrival' with probability 0.46647942066192627
Actual: 'card_acceptance'

Potential mislabel detected at index 21: 'Status of the card I ordered.'
Predicted: 'lost_or_stolen_card' with probability 0.37437352538108826
Actual: 'card_acceptance'

Potential mislabel detected at index 36: 'I'm starting to think my card is lost because it still hasn't arrived, can you help?'
Predicted: 'card_arrival' with probability 0.7686095237731934
Actual: 'card_acceptance'

Potential mislabel detected at index 37: 'Is there tracking info available?'
Predicted: 'card_arrival' with probability 0.6687702536582947
Actual: 'card_acceptance'

Potential mislabel detected at index 60: 'How do I link to my credit card with you?'
Predicted: 'card_linking' with probability 0.8874982595443726
Actual: 'card_delivery_estimate'

Potential mislabel detected at index 66: 'The app doesn't show the card I received.'
Predicted: 'card_linking
```
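
Some of the "Actual" names above look odd for their text, for example 'How do I link to my credit card with you?' shown as `card_delivery_estimate`. One possible cause, not verified here, is that the dataset's label ids and `model.config.id2label` are not in the same order, since the "Actual" name is looked up in the model's mapping. The sketch below shows one way to check this by matching label names. `build_id_remap` and `orders_agree` are hypothetical helpers, and the toy label lists stand in for the real ones, which would presumably come from `dataset['test'].features['label'].names` and `model.config.id2label`.

```python
def build_id_remap(dataset_names, model_id2label):
    """Return a list where position = dataset label id and value = model id with the same name."""
    model_label2id = {name: idx for idx, name in model_id2label.items()}
    missing = [name for name in dataset_names if name not in model_label2id]
    if missing:
        raise KeyError(f"Label names not found in the model config: {missing}")
    return [model_label2id[name] for name in dataset_names]

def orders_agree(dataset_names, model_id2label):
    """True if every dataset label id already equals the matching model id."""
    return build_id_remap(dataset_names, model_id2label) == list(range(len(dataset_names)))

# Toy label sets (hypothetical orderings, for illustration only)
toy_dataset_names = ["card_arrival", "card_linking", "age_limit"]
toy_model_id2label = {0: "age_limit", 1: "card_arrival", 2: "card_linking"}

remap = build_id_remap(toy_dataset_names, toy_model_id2label)
print(remap)                                               # [1, 2, 0]
print(orders_agree(toy_dataset_names, toy_model_id2label))  # False
```

If the real orders turned out to differ, `torch.tensor(remap)[label_ids]` would translate the dataset ids into model ids before the comparison with `predictions`.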

## 3. Conclusion
Among the flagged items, we can look for cases with a **higher softmax probability on an incorrect prediction**, which could hint at **potential mislabeling** in the dataset.

For example:
```
Potential mislabel detected at index 180: 'Where did this 1 euro fee come from?'
Predicted: 'extra_charge_on_statement' with probability 0.890400767326355
Actual: 'exchange_via_app'

Potential mislabel detected at index 573: 'Why wouldn't the contactless payment work when I tried to pay at the bus today?'
Predicted: 'contactless_not_working' with probability 0.899750828742981
Actual: 'compromised_card'
```
We can then correct the mislabeled items, either manually or with code.
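
As a rough sketch of the code route, the helper below applies a set of reviewed corrections to a list of labels without modifying the original list. `apply_corrections`, the toy labels, and the correction shown are all hypothetical. In practice each correction should be confirmed by a person before it is applied.

```python
def apply_corrections(labels, corrections):
    """Return a copy of `labels` with `corrections` ({index: new_label}) applied."""
    fixed = list(labels)
    for index, new_label in corrections.items():
        if not 0 <= index < len(fixed):
            raise IndexError(f"Index {index} is outside the label list")
        fixed[index] = new_label
    return fixed

# Toy example: relabel item 1 after manual review
toy_labels = ["card_arrival", "card_arrival", "card_linking"]
toy_corrections = {1: "card_linking"}

fixed_labels = apply_corrections(toy_labels, toy_corrections)
print(fixed_labels)  # ['card_arrival', 'card_linking', 'card_linking']
```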

