In [25]:
import transformers
import pandas as pd
from datasets import load_dataset

from app.model import IntentClassifier
from app.atis.utils import ATIS_INTENT_MAPPING as intent_mapping

# Load the "tuetschek/atis" dataset through the Hugging Face `datasets` library.
# Later cells index its "train" and "test" splits and read each row's
# "text" and "intent" fields.
dataset = load_dataset("tuetschek/atis")

In [18]:
import torch
from transformers import pipeline

# Bart-Large-MNLI has ~407M parameters, almost double the size of Flan-T5-Base.
# Run on the GPU when one is available and fall back to the CPU otherwise, so
# the notebook still executes (more slowly) on machines without CUDA instead of
# failing while the pipeline is being constructed.
device = "cuda" if torch.cuda.is_available() else "cpu"
classifier = pipeline(
    "zero-shot-classification",
    model="facebook/bart-large-mnli",
    device=device,
)


## BART Example

In [3]:
# Smoke test for the zero-shot pipeline: classify one generic sentence against
# three unrelated labels before trying the ATIS intents.
sequence_to_classify = "one day I will see the world"
candidate_labels = ['travel', 'cooking', 'dancing']
classifier(sequence_to_classify, candidate_labels)
# Reference output for comparison (scores may differ in the last few digits):
#{'labels': ['travel', 'dancing', 'cooking'],
# 'scores': [0.9938651323318481, 0.0032737774308770895, 0.002861034357920289],
# 'sequence': 'one day I will see the world'}


{'sequence': 'one day I will see the world',
 'labels': ['travel', 'dancing', 'cooking'],
 'scores': [0.9938650727272034, 0.003273802110925317, 0.002861041808500886]}

In [10]:
# Fetch the first training row once rather than re-indexing the dataset for
# every field, then show its text, raw intent, and mapped intent name.
first_example = dataset["train"][0]
first_example["text"], first_example["intent"], intent_mapping[first_example["intent"]]

('i want to fly from boston at 838 am and arrive in denver at 1110 in the morning',
 'flight',
 'Flight Booking Request')

In [13]:
# Zero-shot classification with the dataset's raw intent identifiers
# (the keys of intent_mapping) as candidate labels.
candidate_labels = list(intent_mapping)
classifier(
    dataset["train"][0]["text"],
    candidate_labels,
)

{'sequence': 'i want to fly from boston at 838 am and arrive in denver at 1110 in the morning',
 'labels': ['flight',
  'airfare',
  'distance',
  'flight_time',
  'restriction',
  'aircraft',
  'flight_no',
  'abbreviation',
  'airport',
  'capacity',
  'quantity',
  'city',
  'airline',
  'ground_fare',
  'ground_service+ground_fare',
  'ground_service',
  'meal',
  'cheapest'],
 'scores': [0.25311359763145447,
  0.22531504929065704,
  0.1666412353515625,
  0.0520954392850399,
  0.04511941224336624,
  0.0373229943215847,
  0.030164100229740143,
  0.028714267536997795,
  0.02366679161787033,
  0.022193461656570435,
  0.02065613493323326,
  0.020584838464856148,
  0.018283870071172714,
  0.017191680148243904,
  0.015278246253728867,
  0.015065652318298817,
  0.006112908013164997,
  0.0024803695268929005]}

In [14]:
# Zero-shot classification with the human-readable intent names
# (the values of intent_mapping) as candidate labels.
candidate_labels = [*intent_mapping.values()]
classifier(
    dataset["train"][0]["text"],
    candidate_labels,
)

{'sequence': 'i want to fly from boston at 838 am and arrive in denver at 1110 in the morning',
 'labels': ['Flight Booking Request',
  'Airport Distance Inquiry',
  'Flight Schedule Inquiry',
  'Airport Location Inquiry',
  'Airfare Information Requests',
  'Airport Information and Queries',
  'Flight Number Inquiry',
  'Airline Information Request',
  'Ground Transportation Inquiry',
  'Ground Transportation Cost Inquiry',
  'Flight Quantity Inquiry',
  'Aircraft Type Inquiry',
  'Cheapest Fare Inquiry',
  'Airport Ground Transportation and Cost Query',
  'Flight Restriction Inquiry',
  'Abbreviation and Fare Code Meaning Inquiry',
  'Aircraft Seating Capacity Inquiry',
  'Inquiry about In-flight Meals'],
 'scores': [0.18264347314834595,
  0.12421287596225739,
  0.09362585097551346,
  0.07969800382852554,
  0.06246088445186615,
  0.06208226457238197,
  0.05942377820611,
  0.05931268259882927,
  0.04287651926279068,
  0.042347412556409836,
  0.03694706782698631,
  0.03188294917345047,

In [19]:
from tqdm.auto import tqdm

# Pin the label set inside this cell. The gold label stored in "y" is the
# mapped intent name, so the candidates must be the mapped names as well;
# relying on whichever earlier cell last assigned `candidate_labels` would
# make every prediction wrong if the raw-key variant had been run last.
candidate_labels = list(intent_mapping.values())

results = []
skipped = 0  # test rows whose intent has no entry in intent_mapping
for row in tqdm(dataset["test"], desc="zero-shot"):
    intent = row["intent"]
    if intent not in intent_mapping:
        # Intents without a mapped name cannot be scored; count them so the
        # reduced sample size is visible instead of silently dropped.
        skipped += 1
        continue

    prediction = classifier(row["text"], candidate_labels)
    results.append({"prediction": prediction, "y": intent_mapping[intent]})

print(f"classified {len(results)} rows, skipped {skipped} rows with unmapped intents")

  1%|█▎                                                                                                                  | 10/893 [00:02<03:03,  4.80it/s]--- Logging error ---
Traceback (most recent call last):
  File "/usr/lib/python3.10/logging/__init__.py", line 1100, in emit
    msg = self.format(record)
  File "/usr/lib/python3.10/logging/__init__.py", line 943, in format
    return fmt.format(record)
  File "/usr/lib/python3.10/logging/__init__.py", line 678, in format
    record.message = record.getMessage()
  File "/usr/lib/python3.10/logging/__init__.py", line 368, in getMessage
    msg = msg % self.args
TypeError: not all arguments converted during string formatting
Call stack:
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/serj/dev/customer_support_classifier/venv/lib/python3.10/site-packages/ipykernel_

In [21]:
from sklearn.metrics import classification_report

# Split the collected results into gold labels and raw pipeline outputs,
# keeping both in the same order as `results`.
y = [r["y"] for r in results]
predictions = [r["prediction"] for r in results]
# Preview only the first few entries; echoing the whole list would dump one
# full label/score dict per classified row into the notebook output.
predictions[:3]


[{'sequence': 'i would like to find a flight from charlotte to las vegas that makes a stop in st. louis',
  'labels': ['Flight Booking Request',
   'Airport Distance Inquiry',
   'Flight Schedule Inquiry',
   'Airline Information Request',
   'Airport Location Inquiry',
   'Abbreviation and Fare Code Meaning Inquiry',
   'Airport Information and Queries',
   'Flight Number Inquiry',
   'Flight Quantity Inquiry',
   'Airfare Information Requests',
   'Ground Transportation Cost Inquiry',
   'Ground Transportation Inquiry',
   'Airport Ground Transportation and Cost Query',
   'Cheapest Fare Inquiry',
   'Aircraft Type Inquiry',
   'Aircraft Seating Capacity Inquiry',
   'Flight Restriction Inquiry',
   'Inquiry about In-flight Meals'],
  'scores': [0.11526107788085938,
   0.10899890959262848,
   0.09994936734437943,
   0.09034089744091034,
   0.08156231790781021,
   0.0749569907784462,
   0.06290633231401443,
   0.06271512806415558,
   0.04818842187523842,
   0.045033495873212814,
   0.

In [23]:
# Reduce each pipeline output to its single best label: locate the position of
# the highest score and take the label stored at that same position.
prediction_labels = [
    pred['labels'][pred['scores'].index(max(pred['scores']))]
    for pred in predictions
]
print(len(prediction_labels))

876


In [26]:
# Per-class precision/recall/F1 as a table, one row per label.
# zero_division=0 reports 0.0 for classes with no predicted or no true samples,
# the same value as the default behaviour, without the repeated
# undefined-metric warnings the default emits.
pd.DataFrame(
    classification_report(
        y,
        prediction_labels,
        output_dict=True,
        zero_division=0,
    )
).T

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Unnamed: 0,precision,recall,f1-score,support
Abbreviation and Fare Code Meaning Inquiry,0.482143,0.818182,0.606742,33.0
Aircraft Seating Capacity Inquiry,1.0,0.952381,0.97561,21.0
Aircraft Type Inquiry,0.6,1.0,0.75,9.0
Airfare Information Requests,0.5,0.020833,0.04,48.0
Airline Information Request,0.064516,0.105263,0.08,38.0
Airport Distance Inquiry,0.024823,0.7,0.047945,10.0
Airport Information and Queries,0.0,0.0,0.0,18.0
Airport Location Inquiry,0.0,0.0,0.0,6.0
Cheapest Fare Inquiry,0.0,0.0,0.0,0.0
Flight Booking Request,0.94382,0.265823,0.414815,632.0
