<a href="https://colab.research.google.com/github/KelvinLam05/Zero-Shot-Text-Classification-with-Hugging-Face/blob/main/Zero_Shot_Text_Classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**Zero-Shot Text Classification**

Zero-shot learning (ZSL) is a machine learning approach, widely used in Natural Language Processing, that lets a model make predictions for classes it was never explicitly trained on, so no task-specific training is needed. In practice, this means we can take a model pretrained on very large datasets and use it out of the box.

**Goal of the project**

The Consumer Financial Protection Bureau (CFPB) is a federal U.S. agency that acts as a mediator when disputes arise between financial institutions and consumers. Via a web form, consumers can send the agency a narrative of their dispute. A zero-shot text classification model could classify complaints and route them to the appropriate teams more efficiently than manual tagging.


**Attribute information**

Each submission was tagged with one of five financial product classes:

* credit reporting

* debt collection

* mortgages and loans 

* credit cards

* retail banking

In [86]:
# Importing libraries
import pandas as pd
import numpy as np
import tensorflow as tf

In [87]:
# Load dataset
df = pd.read_csv('/content/complaints.csv')

In [88]:
# Examine the data
df.head()

| | narrative | product |
|---|---|---|
| 0 | purchase order day shipping amount receive pro... | credit_card |
| 1 | forwarded message date tue subject please inve... | credit_card |
| 2 | forwarded message cc sent friday pdt subject f... | retail_banking |
| 3 | payment history missing credit report speciali... | credit_reporting |
| 4 | payment history missing credit report made mis... | credit_reporting |


In [89]:
# Overview of all variables, their datatypes
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1000 entries, 0 to 999
Data columns (total 2 columns):
 #   Column     Non-Null Count  Dtype 
---  ------     --------------  ----- 
 0   narrative  1000 non-null   object
 1   product    1000 non-null   object
dtypes: object(2)
memory usage: 15.8+ KB


**Preprocessing**

In [90]:
# Checking for missing values
df.isnull().sum().sort_values(ascending = False)

product      0
narrative    0
dtype: int64

In [91]:
# List the unique classes present in the product column
df['product'].unique()

array(['credit_card', 'retail_banking', 'credit_reporting',
       'mortgages_and_loans', 'debt_collection'], dtype=object)

There are five different types of financial products.

In [92]:
# Find all unique characters and symbols 
all_text = str()

for sentence in df['narrative'].values:
    all_text += sentence
    
''.join(set(all_text))

'hrtze mqgsixjauvnywpckdbolf'

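The kind of data we get from customer feedback is usually unstructured and often contains unusual characters and symbols that must be cleaned before a machine learning model can use it. Here the narratives already contain only lowercase letters and spaces, so the cleaning below mainly removes stopwords, lemmatizes, and drops very short tokens.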

In [93]:
import re
import nltk
from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer 

In [94]:
nltk.download('stopwords')
nltk.download('punkt')
nltk.download('wordnet')

[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


True

In [95]:
stop_words = set(stopwords.words('english'))
lemma = WordNetLemmatizer()

We will now set up our cleaning function.

In [96]:
def clean_text(text):

  # Replace every non-letter character (digits, punctuation) with a space
  text = re.sub(r'[^a-zA-Z]', ' ', text)
  # Collapse runs of whitespace into a single space
  text = re.sub(r'\s+', ' ', text)
  # Convert all characters to lowercase
  text = text.lower()
  # Tokenization
  text = word_tokenize(text)
  # Remove stopwords
  text = [item for item in text if item not in stop_words]
  # Lemmatization (every token is treated as a verb because pos = 'v')
  text = [lemma.lemmatize(word = w, pos = 'v') for w in text]
  # Drop tokens of length <= 2
  text = [i for i in text if len(i) > 2]
  # Join the list of tokens back into a single string
  text = ' '.join(text)

  return text

In [97]:
df['clean_narrative'] = df['narrative'].apply(clean_text)
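
To make the regex steps of the cleaning function concrete, here is a minimal standard-library sketch. It deliberately leaves out the NLTK steps (tokenizer, stopword list, lemmatizer), so it is not equivalent to `clean_text`; the `normalize` helper and the sample sentence are made up for illustration.

```python
import re

def normalize(text):
    # Replace every non-letter (digits, punctuation, symbols) with a space
    text = re.sub(r'[^a-zA-Z]', ' ', text)
    # Collapse runs of whitespace, trim the ends, and lowercase
    text = re.sub(r'\s+', ' ', text).strip().lower()
    # Keep only tokens longer than two characters
    return ' '.join(tok for tok in text.split() if len(tok) > 2)

# Hypothetical complaint text, not taken from the dataset
sample = 'On 03/12 I paid $45.00 to XYZ Bank -- no receipt!'
print(normalize(sample))  # paid xyz bank receipt
```

Note that the length filter alone already drops short function words such as "on", "to" and "no", which is one reason negations can disappear from cleaned narratives.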

In [98]:
all_text = str()

for sentence in df['clean_narrative'].values:
    all_text += sentence
    
''.join(set(all_text))

'hrtze mqxsigjauvnywpckdbolf'

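When working with unstructured text data, we will inevitably find misspelled words. The SpellChecker class from the pyspellchecker package can attempt to correct them, although automatic correction can also alter valid tokens (for example, it turns the token pdt into put in the third narrative).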

In [99]:
pip install pyspellchecker



In [100]:
from spellchecker import SpellChecker

In [101]:
# Instantiate spell checker
spell = SpellChecker()

In [102]:
# Correct spelling
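def correct_spellings(text):

    words = text.split()
    # Words that are not in the spell checker's dictionary
    misspelled_words = spell.unknown(words)
    corrected_text = []
    for word in words:
        if word in misspelled_words:
            # correction() can return None when no candidate is found
            # (as in recent pyspellchecker releases); keep the original word then
            corrected_text.append(spell.correction(word) or word)
        else:
            corrected_text.append(word)

    return ' '.join(corrected_text)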

In [103]:
df['clean_narrative'] = df['clean_narrative'].apply(correct_spellings)
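
Spell correction word by word can be slow, because the same unknown word is corrected again every time it appears. One possible alternative, sketched below, is to correct each unique word once and reuse the result. The helpers `build_correction_map` and `apply_corrections` are hypothetical, and the spell checker is passed in as two plain callables so the sketch runs without any third-party package.

```python
def build_correction_map(texts, is_known, correct):
    """Map each unknown word in the corpus to its correction, computed once."""
    vocabulary = {word for text in texts for word in text.split()}
    mapping = {}
    for word in vocabulary:
        if not is_known(word):
            # Fall back to the original word if the corrector returns None
            mapping[word] = correct(word) or word
    return mapping

def apply_corrections(text, mapping):
    """Rewrite a text using a precomputed correction map."""
    return ' '.join(mapping.get(word, word) for word in text.split())

# Toy stand-ins for a real spell checker (illustration only)
known_words = {'credit', 'report', 'payment'}
toy_fixes = {'reprot': 'report'}

texts = ['credit reprot', 'payment reprot xyzq']
mapping = build_correction_map(
    texts,
    is_known=lambda w: w in known_words,
    correct=lambda w: toy_fixes.get(w),  # returns None for 'xyzq'
)
print([apply_corrections(t, mapping) for t in texts])
```

With pyspellchecker, the two callables could plausibly be `lambda w: bool(spell.known([w]))` and `spell.correction`, and the map would be applied with `df['clean_narrative'].apply(apply_corrections, mapping = mapping)`.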

In [104]:
df[['narrative', 'clean_narrative']]

| | narrative | clean_narrative |
|---|---|---|
| 0 | purchase order day shipping amount receive pro... | purchase order day ship amount receive product... |
| 1 | forwarded message date tue subject please inve... | forward message date tue subject please invest... |
| 2 | forwarded message cc sent friday pdt subject f... | forward message send friday put subject final ... |
| 3 | payment history missing credit report speciali... | payment history miss credit report specialize ... |
| 4 | payment history missing credit report made mis... | payment history miss credit report make mistak... |
| ... | ... | ... |
| 995 | bank america add hard inquiry credit report pe... | bank america add hard inquiry credit report pe... |
| 996 | opened premium checking bundle advertised main... | open premium check bundle advertise main land ... |
| 997 | opened premium checking bundle advertised main... | open premium check bundle advertise main land ... |
| 998 | original account number date original account ... | original account number date original account ... |


In [105]:
# Display full strings
with pd.option_context('display.max_colwidth', None):
  display(df['clean_narrative'])

0    purchase order day ship amount receive product week send followup email exact verbiage pay two day ship receive order company respond sorry inform due unusually high order volume order ship several week stock since early due high demand although continue take order guarantee receive order place due time mask order exact ship date right however guarantee ship soon soon deliver product get small shipment ship first come first serve basis appreciate patie

**Testing for GPU**

In [106]:
import torch

In [107]:
# Whether cuda is available
torch.cuda.is_available()

True

In [108]:
# Use the current GPU index if CUDA is available, otherwise -1 (CPU)
device = torch.cuda.current_device() if torch.cuda.is_available() else -1

In [109]:
print(device)

0


**Transformer Pipeline**

We will use the pipeline() function to create a zero-shot-classification pipeline backed by the valhalla/distilbart-mnli-12-9 model.

In [110]:
from transformers import pipeline

In [111]:
task = 'zero-shot-classification'
zero_shot_model = 'valhalla/distilbart-mnli-12-9'
zero_shot_classifier = pipeline(task, zero_shot_model, device = device)

We can use this pipeline by passing in a sequence and a list of candidate labels. By default, the pipeline assumes that exactly one of the candidate labels is true and returns one score per label, with the scores summing to 1.
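
As far as I understand the Hugging Face implementation, the scores sum to 1 because each candidate label is turned into a hypothesis (by default something like "This example is {label}."), the NLI model produces an entailment logit for each hypothesis, and a softmax is taken across the labels. The sketch below reproduces only that last step; the logit values are invented and do not come from the model.

```python
import math

def softmax(logits):
    # Subtract the maximum for numerical stability before exponentiating
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical entailment logits, one per candidate label
entailment_logits = {
    'credit_card': 2.1,
    'credit_reporting': 0.7,
    'debt_collection': 0.6,
}

scores = dict(zip(entailment_logits, softmax(list(entailment_logits.values()))))
for label, score in scores.items():
    print(f'{label}: {score:.3f}')
```

If a complaint could belong to several products at once, the pipeline's `multi_label = True` argument scores each label independently instead, so the scores no longer sum to 1.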

In [112]:
sequence = df['clean_narrative'][0]

In [113]:
candidate_labels = ['credit_reporting',       
                    'debt_collection',        
                    'mortgages_and_loans',    
                    'retail_banking',          
                    'credit_card'] 

In [114]:
outputs = zero_shot_classifier(sequences = sequence, candidate_labels = candidate_labels)

Let’s take a look at the outputs.

In [115]:
for label, score in zip(outputs['labels'], outputs['scores']):
    print(f'{label}: {score:.3f}')

credit_card: 0.563
credit_reporting: 0.143
debt_collection: 0.132
retail_banking: 0.106
mortgages_and_loans: 0.055


The top-scoring label, credit_card (0.563), matches the true product for this submission. The remaining labels (credit_reporting, debt_collection, retail_banking and mortgages_and_loans) score much lower, between 0.055 and 0.143.

**Classify all the submissions**

In [116]:
task = 'zero-shot-classification'
zero_shot_model = 'valhalla/distilbart-mnli-12-9'
classifier = pipeline(task, zero_shot_model, device = device) 

In [117]:
candidate_labels = ['credit_reporting',       
                    'debt_collection',        
                    'mortgages_and_loans',    
                    'retail_banking',          
                    'credit_card']  

In [118]:
# Compute the predicted label for each submission
df['label_pred_zero_shot'] = df['clean_narrative'].apply(lambda x: classifier(x, candidate_labels = candidate_labels)['labels'][0])
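
Calling the classifier once per row keeps the GPU underused. A possible alternative is to pass the whole list of texts in a single call and let the pipeline batch them. The helper below is a sketch under the assumption that the installed transformers version accepts `batch_size` at call time; `predict_top_labels` is a hypothetical name, and the stub classifier exists only so the example runs without downloading a model.

```python
def predict_top_labels(classifier, texts, candidate_labels, batch_size=16):
    """Return the highest-scoring label for each text."""
    results = classifier(
        list(texts),
        candidate_labels=candidate_labels,
        batch_size=batch_size,
    )
    # Some versions return a bare dict when only one sequence is passed
    if isinstance(results, dict):
        results = [results]
    # Labels are sorted by descending score, so index 0 is the prediction
    return [result['labels'][0] for result in results]

# Stub standing in for the real pipeline (illustration only)
def stub_classifier(texts, candidate_labels, batch_size):
    return [
        {'sequence': t, 'labels': list(candidate_labels), 'scores': [0.9, 0.1]}
        for t in texts
    ]

print(predict_top_labels(stub_classifier, ['text a', 'text b'],
                         ['credit_card', 'retail_banking']))
```

With the real pipeline, the call would look like `df['label_pred_zero_shot'] = predict_top_labels(classifier, df['clean_narrative'].tolist(), candidate_labels)`.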



At this point, we have a dataset that contains labels produced by the zero-shot classifier.

In [119]:
df[['product','label_pred_zero_shot']]

| | product | label_pred_zero_shot |
|---|---|---|
| 0 | credit_card | credit_card |
| 1 | credit_card | retail_banking |
| 2 | retail_banking | credit_card |
| 3 | credit_reporting | credit_reporting |
| 4 | credit_reporting | credit_reporting |
| ... | ... | ... |
| 995 | credit_reporting | credit_reporting |
| 996 | retail_banking | retail_banking |
| 997 | retail_banking | retail_banking |
| 998 | debt_collection | credit_reporting |


Finally, we can evaluate the predictions against the labels already present in the original dataset.

In [120]:
from sklearn.metrics import accuracy_score

In [121]:
# Accuracy classification score
accuracy_score(df['product'], df['label_pred_zero_shot'])

0.756

With an accuracy of 0.756 and no task-specific training, the zero-shot classifier does a decent job. Glancing at a few random submissions incorrectly labeled by the zero-shot classifier, there does not seem to be a particularly problematic class.
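
A glance at random errors can miss a weak class, so a per-class breakdown is a useful complement to overall accuracy. The dependency-free sketch below computes recall per class (the share of each true class that was predicted correctly). The `per_class_recall` helper is hypothetical, and the toy labels are for illustration only; they are not results from this notebook.

```python
from collections import Counter

def per_class_recall(y_true, y_pred):
    """Fraction of each true class that was predicted correctly."""
    totals = Counter(y_true)
    correct = Counter(t for t, p in zip(y_true, y_pred) if t == p)
    return {label: correct[label] / totals[label] for label in totals}

# Toy labels, not taken from the dataset
y_true = ['credit_card', 'credit_card', 'retail_banking', 'debt_collection']
y_pred = ['credit_card', 'retail_banking', 'retail_banking', 'credit_reporting']

recall = per_class_recall(y_true, y_pred)
print(recall)
```

On the real data, the same function could be called as `per_class_recall(df['product'], df['label_pred_zero_shot'])`; scikit-learn's `classification_report` gives a fuller picture, including precision.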