# Intent Classification (Few Shot)

This intent classification demo is built on a simple BERT for Sequence Classification model. It is similar to the main demo, except that the BERT model previously fine-tuned on the Banking77 dataset is used as a feature extractor for a new model trained for few-shot classification on the banking domain of the OOS dataset. The new model trains a linear layer on top of the BERT feature extractor to classify the new classes. The checkpoint of the best-performing model is loaded and used to make predictions on user inputs in the main loop. Note that the classes being predicted are not from the Banking77 dataset but from the new dataset.

In [1]:
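import torch
import torch.nn as nn
from transformers import BertForSequenceClassification

class CustomBERTModel(nn.Module):
    def __init__(self):
        super(CustomBERTModel, self).__init__()
        # BERT with a 77-way (Banking77) head; its weights are overwritten once the checkpoint is loaded.
        self.bert = BertForSequenceClassification.from_pretrained('bert-base-cased', num_labels = 77)
        # Freeze the BERT encoder so it acts as a fixed feature extractor.
        # Note: using `self.param` as the loop variable also registers the last encoder
        # tensor as an extra parameter named `param` (it appears in the printout below).
        # It is left unchanged because the saved checkpoint was presumably created with this layout.
        for self.param in self.bert.bert.parameters():
            self.param.requires_grad = False
        # New linear layer mapping the 77 Banking77 outputs to the 15 new classes.
        self.classifier = nn.Linear(77, 15)
        # Assumption: the original cell never defines `self.criterion`; cross-entropy is a
        # common choice for single-label classification and adds no weights to the state dict.
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, input_ids, attention_mask, labels=None):
        # [0] selects the logits of the 77-way head; sigmoid squashes them into (0, 1).
        output = torch.sigmoid(self.bert(input_ids, attention_mask=attention_mask)[0])
        output = self.classifier(output)
        loss = 0
        if labels is not None:
            loss = self.criterion(output, labels)
        return loss, output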
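
To make the effect of freezing concrete, here is a minimal sketch rather than the demo's own code. A small stand-in module (`TinyBackbone`, a hypothetical name) replaces the fine-tuned BERT so that nothing has to be downloaded, and `FewShotHead` and `count_parameters` are likewise illustrative names. Unlike the cell above, which freezes only `self.bert.bert`, the sketch freezes the whole stand-in backbone, so only the 77-to-15 linear layer remains trainable.

```python
import torch
import torch.nn as nn


class TinyBackbone(nn.Module):
    """Hypothetical stand-in for the fine-tuned BERT: maps features to 77 logits."""

    def __init__(self, in_features=32, num_labels=77):
        super().__init__()
        self.encoder = nn.Linear(in_features, in_features)
        self.head = nn.Linear(in_features, num_labels)

    def forward(self, x):
        return self.head(torch.relu(self.encoder(x)))


class FewShotHead(nn.Module):
    """Frozen backbone followed by a trainable linear layer over the new classes."""

    def __init__(self, backbone, num_old=77, num_new=15):
        super().__init__()
        self.backbone = backbone
        # A plain loop variable avoids registering an extra attribute on the module.
        for param in self.backbone.parameters():
            param.requires_grad = False
        self.classifier = nn.Linear(num_old, num_new)

    def forward(self, x):
        return self.classifier(torch.sigmoid(self.backbone(x)))


def count_parameters(model):
    """Return (trainable, frozen) parameter counts."""
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    frozen = sum(p.numel() for p in model.parameters() if not p.requires_grad)
    return trainable, frozen


few_shot_model = FewShotHead(TinyBackbone())
trainable, frozen = count_parameters(few_shot_model)
print(f"trainable: {trainable}, frozen: {frozen}")
```

In this sketch the only trainable values are the new layer's 77 x 15 weights and 15 biases, which is the sense in which few-shot training here only fits a small head on top of fixed features.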

In [2]:
import warnings

import torch
import torch.nn as nn
import pytorch_lightning as pl
from torch.utils.data import Dataset, DataLoader
from transformers import (
    AutoConfig,
    AutoTokenizer,
    BertConfig,
    BertForMaskedLM,
    BertForSequenceClassification,
    BertModel,
    BertTokenizer,
)

# Silence library warnings to keep the demo output readable.
warnings.filterwarnings("ignore")

def print_model_info(model):
    # Get all of the model's parameters as a list of tuples.
    params = list(model.named_parameters())

    print('The BERT model has {:} different named parameters.\n'.format(len(params)))

    print('==== Embedding Layer ====\n')

    for p in params[0:5]:
        print("{:<55} {:>12}".format(p[0], str(tuple(p[1].size()))))

    print('\n==== First Transformer ====\n')

    for p in params[5:21]:
        print("{:<55} {:>12}".format(p[0], str(tuple(p[1].size()))))

    print('\n==== Output Layer ====\n')

    for p in params[-4:]:
        print("{:<55} {:>12}".format(p[0], str(tuple(p[1].size()))))

def tokenize_sentences(sentences):
    # Tokenize all of the sentences and map the tokens to their word IDs.
    input_ids = []
    attention_masks = []

    # For every sentence...
    for sent in sentences:
        # encode_plus will:
        #   (1) Tokenize the sentence.
        #   (2) Prepend the [CLS] token to the start.
        #   (3) Append the [SEP] token to the end.
        #   (4) Map tokens to their IDs.
        #   (5) Pad or truncate the sentence to max_length
        #   (6) Create attention masks for [PAD] tokens.
        encoded_dict = tokenizer.encode_plus(
                            sent,                      # Sentence to encode.
                            add_special_tokens = True, # Add '[CLS]' and '[SEP]'
                            max_length = 320,          # Pad & truncate all sentences to 320 tokens.
                            padding='max_length',
                            truncation=True,           # Needed for inputs longer than max_length to be cut.
                            return_attention_mask = True,   # Construct attn. masks.
                            return_tensors = 'pt',     # Return pytorch tensors.
                    )
        
        # Add the encoded sentence to the list.    
        input_ids.append(encoded_dict['input_ids'])
        
        # And its attention mask (simply differentiates padding from non-padding).
        attention_masks.append(encoded_dict['attention_mask'])

    # Convert the lists into tensors.
    input_ids = torch.cat(input_ids, dim=0)
    attention_masks = torch.cat(attention_masks, dim=0)
    return input_ids, attention_masks

# Load the BERT tokenizer.
print('Loading BERT tokenizer...')
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')

chk_path = "/ssd003/projects/aieng/conversational_ai/demo/checkpoints/intent_chk/fewshot_best_checkpoint.ckpt"

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

checkpoint = torch.load(chk_path, map_location=torch.device('cpu'))

model = CustomBERTModel()

print_model_info(model)

#load specified model state

model.load_state_dict(checkpoint["state_dict"])
model.eval()
model.to(device)

Loading BERT tokenizer...


Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at b

The BERT model has 203 different named parameters.

==== Embedding Layer ====

param                                                         (768,)
bert.bert.embeddings.word_embeddings.weight             (28996, 768)
bert.bert.embeddings.position_embeddings.weight           (512, 768)
bert.bert.embeddings.token_type_embeddings.weight           (2, 768)
bert.bert.embeddings.LayerNorm.weight                         (768,)

==== First Transformer ====

bert.bert.embeddings.LayerNorm.bias                           (768,)
bert.bert.encoder.layer.0.attention.self.query.weight     (768, 768)
bert.bert.encoder.layer.0.attention.self.query.bias           (768,)
bert.bert.encoder.layer.0.attention.self.key.weight       (768, 768)
bert.bert.encoder.layer.0.attention.self.key.bias             (768,)
bert.bert.encoder.layer.0.attention.self.value.weight     (768, 768)
bert.bert.encoder.layer.0.attention.self.value.bias           (768,)
bert.bert.encoder.layer.0.attention.output.dense.weight   (768,

CustomBERTModel(
  (bert): BertForSequenceClassification(
    (bert): BertModel(
      (embeddings): BertEmbeddings(
        (word_embeddings): Embedding(28996, 768, padding_idx=0)
        (position_embeddings): Embedding(512, 768)
        (token_type_embeddings): Embedding(2, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0): BertLayer(
            (attention): BertAttention(
              (self): BertSelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output): BertSelfOutput(
                (dense): Linear(in_features=768, out_features=768, bias=Tr
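
The `load_state_dict` call above is strict by default, so any difference between the checkpoint's keys and the model's keys raises an error. Below is a minimal sketch for inspecting such differences before loading. It uses a small stand-in model and an in-memory state dict rather than the real checkpoint, and `compare_state_dict_keys` is a hypothetical helper that is not part of the demo.

```python
import torch.nn as nn


def compare_state_dict_keys(model, state_dict):
    """Return (missing, unexpected) key lists relative to the model."""
    model_keys = set(model.state_dict().keys())
    ckpt_keys = set(state_dict.keys())
    missing = sorted(model_keys - ckpt_keys)     # in the model, absent from the checkpoint
    unexpected = sorted(ckpt_keys - model_keys)  # in the checkpoint, unknown to the model
    return missing, unexpected


# Stand-in model and checkpoint; the demo would pass `model` and checkpoint["state_dict"].
toy_model = nn.Sequential(nn.Linear(4, 3), nn.Linear(3, 2))
toy_state = dict(toy_model.state_dict())
toy_state["extra.weight"] = toy_state.pop("1.weight")  # simulate a renamed key

missing, unexpected = compare_state_dict_keys(toy_model, toy_state)
print("missing:", missing)
print("unexpected:", unexpected)
```

If both lists are empty, a strict load should find every key it expects. Otherwise the listed names show which layers differ between the saved model and the one being constructed.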

### New Classes
The cell below lists the new classes from the OOS dataset that the model predicts.

In [3]:
LABEL_COLUMNS = ['replacement_card_duration',
 'expiration_date',
 'damaged_card',
 'improve_credit_score',
 'report_lost_card',
 'card_declined',
 'credit_limit_change',
 'apr',
 'redeem_rewards',
 'credit_limit',
 'rewards_balance',
 'application_status',
 'credit_score',
 'new_card',
 'international_fees']

## Main Loop

Here the model takes in questions pertaining to banking and provides a class label (out of the 15 classes listed above) and a confidence level for its prediction. 

*Note: the loop prompts for questions indefinitely, so it has to be stopped manually (for example, by interrupting the kernel).

In [4]:
# Softmax over the class dimension; created once instead of on every iteration.
softmax_layer = torch.nn.Softmax(dim=1)
while True:
    string = input("How may I help you? Question: ")
    # Tokenize the input sentence so it is compatible with BERT inputs.
    token_ids, attention_masks = tokenize_sentences([string])
    # Forward pass without gradient tracking; the model returns (loss, logits).
    with torch.no_grad():
        model_outputs = model(token_ids.to(device), attention_mask=attention_masks.to(device))
    # Convert the logits into a probability for each of the 15 classes.
    result = softmax_layer(model_outputs[1])
    # Pick the most probable class and report its probability as the confidence.
    prediction = torch.argmax(result).item()
    confidence = torch.max(result).item()
    print("I see. The problem is related to: " + LABEL_COLUMNS[prediction].replace("_", ' ') + " with {:.2f}% confident".format(confidence*100))
    print("-"*80)

How may I help you? Question: Can you tell me when my credit card is set to expire, please?
I see. The problem is related to: expiration date with 71.50% confident
--------------------------------------------------------------------------------


KeyboardInterrupt: Interrupted by user
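
The prediction step in the loop (softmax over the class outputs, then the largest probability) extends naturally to showing several candidate intents. The following is a minimal sketch, assuming logits of shape `(1, num_classes)` like those returned by the model above. `top_k_intents` is a hypothetical helper, and the logits are made up for illustration rather than produced by the model.

```python
import torch


def top_k_intents(logits, labels, k=3):
    """Return the k most probable (label, probability) pairs for a single input."""
    # Softmax over the class dimension turns logits into probabilities that sum to 1.
    probs = torch.softmax(logits, dim=-1).squeeze(0)
    values, indices = torch.topk(probs, k=min(k, len(labels)))
    return [(labels[i], p) for i, p in zip(indices.tolist(), values.tolist())]


# Made-up logits for three of the classes, for illustration only.
example_labels = ["expiration_date", "damaged_card", "new_card"]
example_logits = torch.tensor([[2.0, 0.5, -1.0]])

for label, prob in top_k_intents(example_logits, example_labels, k=2):
    print(f"{label.replace('_', ' ')}: {prob * 100:.2f}%")
```

Reporting the runner-up classes alongside the top prediction can help when the highest confidence is modest, since it shows which intents the model considers close alternatives.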