# Content: 
## Load dataset
## Load T5 Model
## Run predictions
### Classification report

## Load Dataset

In [9]:
import transformers
import pandas as pd

from datasets import load_dataset
from app.model import IntentClassifier


dataset = load_dataset("tuetschek/atis")

In [10]:

# dataset["intent"].value_counts()
dataset["train"]

Dataset({
    features: ['id', 'intent', 'text', 'slots'],
    num_rows: 4978
})

## Examples
flight-Flight search: what flights are available from pittsburgh to baltimore on thursday morning
flight_time-Flight time questions: what is the arrival time in san francisco for the 755 am flight leaving washington
airfare-Check costs: show me the first class fares from boston to denver
ground_service-


In [11]:
dataset["train"][0]

{'id': 0,
 'intent': 'flight',
 'text': 'i want to fly from boston at 838 am and arrive in denver at 1110 in the morning',
 'slots': 'O O O O O B-fromloc.city_name O B-depart_time.time I-depart_time.time O O O B-toloc.city_name O B-arrive_time.time O O B-arrive_time.period_of_day'}

In [12]:
intents = set([row["intent"] for row in dataset["train"]])
intents

{'abbreviation',
 'aircraft',
 'aircraft+flight+flight_no',
 'airfare',
 'airfare+flight_time',
 'airline',
 'airline+flight_no',
 'airport',
 'capacity',
 'cheapest',
 'city',
 'distance',
 'flight',
 'flight+airfare',
 'flight_no',
 'flight_time',
 'ground_fare',
 'ground_service',
 'ground_service+ground_fare',
 'meal',
 'quantity',
 'restriction'}

In [13]:
from app.atis.utils import ATIS_INTENT_MAPPING as intent_mapping
intent_mapping

{'abbreviation': 'Abbreviation and Fare Code Meaning Inquiry',
 'aircraft': 'Aircraft Type Inquiry',
 'airfare': 'Airfare Information Requests',
 'airline': 'Airline Information Request',
 'airport': 'Airport Information and Queries',
 'capacity': 'Aircraft Seating Capacity Inquiry',
 'cheapest': 'Cheapest Fare Inquiry',
 'city': 'Airport Location Inquiry',
 'distance': 'Airport Distance Inquiry',
 'flight': 'Flight Booking Request',
 'flight_no': 'Flight Number Inquiry',
 'flight_time': 'Flight Schedule Inquiry',
 'ground_fare': 'Ground Transportation Cost Inquiry',
 'ground_service': 'Ground Transportation Inquiry',
 'ground_service+ground_fare': 'Airport Ground Transportation and Cost Query',
 'meal': 'Inquiry about In-flight Meals',
 'quantity': 'Flight Quantity Inquiry',
 'restriction': 'Flight Restriction Inquiry'}

In [14]:
# get 10 samples from each intent
intent_samples = {intent: [] for intent in intents}

# Iterate through the dataset
for row in dataset["train"]:
    intent = row["intent"]
    # If we haven't collected 10 samples for this intent yet, add the current row
    if len(intent_samples[intent]) < 5:
        intent_samples[intent].append(row["text"])
    # If we have collected 10 samples for this intent, move on to the next intent
    else:
        continue

In [15]:
intent_samples

{'meal': ['show me all meals on flights from atlanta to washington',
  'is there a meal on delta flight 852 from san francisco to dallas fort worth',
  'what are all the available meals',
  'what are my meal options from boston to denver',
  'do i get a meal on the atlanta to bwi flight eastern 210'],
 'capacity': ['how many seats in a 100',
  'how many passengers fit on a d9s',
  'how many seats in a 72s',
  'what is the total seating capacity of all aircraft of american airlines',
  "what 's the capacity of an f28"],
 'city': ['what city is the airport mco in',
  'where is mco',
  'where is general mitchell international located',
  'where is general mitchell international located',
  'show me the cities served by nationair'],
 'flight': ['i want to fly from boston at 838 am and arrive in denver at 1110 in the morning',
  'what flights are available from pittsburgh to baltimore on thursday morning',
  'i need a flight tomorrow from columbus to minneapolis',
  'show me the flights fro

# Load model and run one prediction

In [17]:
model = IntentClassifier(model_name="Serj/intent-classifier", commit_hash="9d3538c56a5f52b45bf8d7e8fa675da7b82cf9ec")

9d3538c56a5f52b45bf8d7e8fa675da7b82cf9ec


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [18]:
prompt_options = "OPTIONS\n"
index = 1
for intent in intents:
    if intent not in intent_mapping:
        continue
        
    mapping = intent_mapping[intent]
    prompt_options += f" {index}. {mapping} "
    index += 1
prompt_options

'OPTIONS\n 1. Inquiry about In-flight Meals  2. Aircraft Seating Capacity Inquiry  3. Airport Location Inquiry  4. Flight Booking Request  5. Airfare Information Requests  6. Airport Distance Inquiry  7. Flight Restriction Inquiry  8. Airline Information Request  9. Cheapest Fare Inquiry  10. Airport Information and Queries  11. Aircraft Type Inquiry  12. Flight Schedule Inquiry  13. Airport Ground Transportation and Cost Query  14. Abbreviation and Fare Code Meaning Inquiry  15. Ground Transportation Cost Inquiry  16. Flight Number Inquiry  17. Ground Transportation Inquiry  18. Flight Quantity Inquiry '

In [19]:
(dataset["train"][0]["text"], dataset["train"][0]["intent"], intent_mapping[dataset["train"][0]["intent"]])

('i want to fly from boston at 838 am and arrive in denver at 1110 in the morning',
 'flight',
 'Flight Booking Request')

In [20]:
company_name = "Atis Airlines"
company_specific = "An Airline company"
customer_text = dataset["train"][0]["text"]
print(customer_text)
model.predict(customer_text, prompt_options, company_name, company_specific)

i want to fly from boston at 838 am and arrive in denver at 1110 in the morning




'Class name: Flight Booking Request'

In [21]:
company_name = "Atis Airlines"
company_specific = "An Airline company"
model.predict(dataset["train"][700]["text"], prompt_options, company_name, company_specific)

'Class name: Flight Booking Request'

In [22]:
model.raw_predict(f"All of the verbs: {customer_text}")

'arrive, morning, fly'

## Train set

In [23]:
# from tqdm import tqdm
# results = []
# for row in tqdm(dataset["train"]):
#     intent = row["intent"] 
#     if intent not in intent_mapping:
#         continue 
    
#     prediction = model.predict(row["text"], prompt_options, company_name, company_specific)
#     keywords = model.raw_predict(f"All of the verbs: {row['text']}")
#     results.append({"prediction": prediction, "y": intent_mapping[intent], "keywords": keywords, "text": row["text"]})

In [25]:
# results

In [None]:
# from sklearn.metrics import classification_report
# y = [r["y"] for r in results]
# predictions = [r["prediction"].replace("Class name: ","") for r in results]
# pd.DataFrame(classification_report(y, predictions, output_dict=True)).T

In [29]:
# import pandas as pd
# df = pd.DataFrame(results)
# df["prediction"].value_counts()

## Test set

In [27]:
from tqdm import tqdm
results = []
for row in tqdm(dataset["test"]):
    intent = row["intent"] 
    if intent not in intent_mapping:
        continue 
    
    prediction = model.predict(row["text"], prompt_options, company_name, company_specific)
    keywords = model.raw_predict(f"All of the verbs: {row['text']}")
    results.append({"prediction": prediction, "y": intent_mapping[intent]})

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 893/893 [01:06<00:00, 13.37it/s]


In [28]:
from sklearn.metrics import classification_report
y = [r["y"] for r in results]
predictions = [r["prediction"].replace("Class name: ","") for r in results]
pd.DataFrame(classification_report(y, predictions, output_dict=True)).T

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Unnamed: 0,precision,recall,f1-score,support
Abbreviation and Fare Code Meaning Inquiry,1.0,0.181818,0.307692,33.0
Aircraft Information Request,0.0,0.0,0.0,0.0
Aircraft Seating Capacity Inquiry,1.0,0.904762,0.95,21.0
Aircraft Type Inquiry,0.692308,1.0,0.818182,9.0
Airfare Information Request,0.0,0.0,0.0,0.0
Airfare Information Requests,1.0,0.041667,0.08,48.0
Airline Information Request,0.388889,0.736842,0.509091,38.0
Airport Distance Inquiry,1.0,0.9,0.947368,10.0
Airport Ground Transportation and Cost Query,0.0,0.0,0.0,0.0
Airport Information Request,0.0,0.0,0.0,0.0
