# Content: 
## Load dataset
## Load T5 Model
## Run predictions
### Classification report

## Load Dataset

In [28]:
import transformers
import pandas as pd

from datasets import load_dataset
from app.model import IntentClassifier


dataset = load_dataset("tuetschek/atis")

In [3]:

# dataset["intent"].value_counts()
dataset["train"]

Dataset({
    features: ['id', 'intent', 'text', 'slots'],
    num_rows: 4978
})

## Examples
flight-Flight search: what flights are available from pittsburgh to baltimore on thursday morning
flight_time-Flight time questions: what is the arrival time in san francisco for the 755 am flight leaving washington
airfare-Check costs: show me the first class fares from boston to denver
ground_service-


In [4]:
dataset["train"][0]

{'id': 0,
 'intent': 'flight',
 'text': 'i want to fly from boston at 838 am and arrive in denver at 1110 in the morning',
 'slots': 'O O O O O B-fromloc.city_name O B-depart_time.time I-depart_time.time O O O B-toloc.city_name O B-arrive_time.time O O B-arrive_time.period_of_day'}

In [2]:
intents = set([row["intent"] for row in dataset["train"]])
intents

{'abbreviation',
 'aircraft',
 'aircraft+flight+flight_no',
 'airfare',
 'airfare+flight_time',
 'airline',
 'airline+flight_no',
 'airport',
 'capacity',
 'cheapest',
 'city',
 'distance',
 'flight',
 'flight+airfare',
 'flight_no',
 'flight_time',
 'ground_fare',
 'ground_service',
 'ground_service+ground_fare',
 'meal',
 'quantity',
 'restriction'}

In [3]:
from app.atis.utils import ATIS_INTENT_MAPPING as intent_mapping
intent_mapping

{'abbreviation': 'Abbreviation and Fare Code Meaning Inquiry',
 'aircraft': 'Aircraft Type Inquiry',
 'airfare': 'Airfare Information Requests',
 'airline': 'Airline Information Request',
 'airport': 'Airport Information and Queries',
 'capacity': 'Aircraft Seating Capacity Inquiry',
 'cheapest': 'Cheapest Fare Inquiry',
 'city': 'Airport Location Inquiry',
 'distance': 'Airport Distance Inquiry',
 'flight': 'Flight Booking Request',
 'flight_no': 'Flight Number Inquiry',
 'flight_time': 'Flight Schedule Inquiry',
 'ground_fare': 'Ground Transportation Cost Inquiry',
 'ground_service': 'Ground Transportation Inquiry',
 'ground_service+ground_fare': 'Airport Ground Transportation and Cost Query',
 'meal': 'Inquiry about In-flight Meals',
 'quantity': 'Flight Quantity Inquiry',
 'restriction': 'Flight Restriction Inquiry'}

In [10]:
# get 10 samples from each intent
intent_samples = {intent: [] for intent in intents}

# Iterate through the dataset
for row in dataset["train"]:
    intent = row["intent"]
    # If we haven't collected 10 samples for this intent yet, add the current row
    if len(intent_samples[intent]) < 5:
        intent_samples[intent].append(row["text"])
    # If we have collected 10 samples for this intent, move on to the next intent
    else:
        continue

In [11]:
intent_samples

{'meal': ['show me all meals on flights from atlanta to washington',
  'is there a meal on delta flight 852 from san francisco to dallas fort worth',
  'what are all the available meals',
  'what are my meal options from boston to denver',
  'do i get a meal on the atlanta to bwi flight eastern 210'],
 'airport': ["what 's the airport at orlando",
  'give me a list of airports in baltimore',
  'houston airports',
  'please list information regarding san francisco airport',
  "what 's the name of the denver airport"],
 'aircraft': ['what kind of aircraft is used on a flight from cleveland to dallas',
  'what kinds of planes are used by american airlines',
  'what types of aircraft does delta fly',
  'on the 8 am flight from san francisco to atlanta what type of aircraft is used',
  'list aircraft types that fly between boston and san francisco'],
 'ground_service': ['what kind of ground transportation is available in denver',
  'show me the ground transportation in denver',
  'atlanta g

# Load model and run one prediction

In [4]:
model = IntentClassifier()

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [5]:
prompt_options = "OPTIONS\n"
index = 1
for intent in intents:
    if intent not in intent_mapping:
        continue
        
    mapping = intent_mapping[intent]
    prompt_options += f" {index}. {mapping} "
    index += 1
prompt_options

'OPTIONS\n 1. Ground Transportation Cost Inquiry  2. Abbreviation and Fare Code Meaning Inquiry  3. Aircraft Type Inquiry  4. Flight Restriction Inquiry  5. Aircraft Seating Capacity Inquiry  6. Airport Information and Queries  7. Flight Schedule Inquiry  8. Airport Distance Inquiry  9. Airfare Information Requests  10. Flight Number Inquiry  11. Airport Location Inquiry  12. Cheapest Fare Inquiry  13. Airport Ground Transportation and Cost Query  14. Inquiry about In-flight Meals  15. Airline Information Request  16. Ground Transportation Inquiry  17. Flight Booking Request  18. Flight Quantity Inquiry '

In [11]:
(dataset["train"][0]["text"], dataset["train"][0]["intent"], intent_mapping[dataset["train"][0]["intent"]])

('i want to fly from boston at 838 am and arrive in denver at 1110 in the morning',
 'flight',
 'Flight Booking Request')

In [8]:
company_name = "Atis Airlines"
company_specific = "An Airline company"
customer_text = dataset["train"][0]["text"]
print(customer_text)
model.predict(customer_text, prompt_options, company_name, company_specific)

i want to fly from boston at 838 am and arrive in denver at 1110 in the morning




'Class name: Flight Booking Request'

In [12]:
company_name = "Atis Airlines"
company_specific = "An Airline company"
model.predict(dataset["train"][700]["text"], prompt_options, company_name, company_specific)



'Class name: Flight Booking Request'

In [24]:
model.raw_predict(f"All of the verbs: {customer_text}")

'arrive, morning, fly'

## Train set

In [26]:
from tqdm import tqdm
results = []
for row in tqdm(dataset["train"]):
    intent = row["intent"] 
    if intent not in intent_mapping:
        continue 
    
    prediction = model.predict(row["text"], prompt_options, company_name, company_specific)
    keywords = model.raw_predict(f"All of the verbs: {row['text']}")
    results.append({"prediction": prediction, "y": intent_mapping[intent], "keywords": keywords, "text": row["text"]})

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4978/4978 [06:26<00:00, 12.90it/s]


In [32]:
results

[]

In [30]:
from sklearn.metrics import classification_report
y = [r["y"] for r in results]
predictions = [r["prediction"].replace("Class name: ","") for r in results]
pd.DataFrame(classification_report(y, predictions, output_dict=True)).T

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Unnamed: 0,precision,recall,f1-score,support
accuracy,0.0,0.0,0.0,0.0
macro avg,,,,0.0
weighted avg,,,,0.0


In [49]:
import pandas as pd
df = pd.DataFrame(results)
df["prediction"].value_counts()

prediction
Class name: Flight Booking Request                          2492
Class name: Airport Ground Transportation and Cost Query     646
Class name: Cheapest Fare Inquiry                            372
Class name: Flight Schedule Inquiry                          266
Class name: Airline Information Request                      213
Class name: Airport Information and Queries                  207
Class name: Flight Number Inquiry                            141
Class name: Ground Transportation Inquiry                    139
Class name: Aircraft Type Inquiry                             82
Class name: Airfare Information Request                       68
Class name: Airport Location Inquiry                          56
Class name: Airport Ground Transportation Inquiry             53
Class name: Flight Quantity Inquiry                           46
Class name: Ground Transportation Cost Inquiry                36
Class name: Airfare Information Requests                      29
Class name: Fl

## Test set

In [33]:
from tqdm import tqdm
results = []
for row in tqdm(dataset["test"]):
    intent = row["intent"] 
    if intent not in intent_mapping:
        continue 
    
    prediction = model.predict(row["text"], prompt_options, company_name, company_specific)
    keywords = model.raw_predict(f"All of the verbs: {row['text']}")
    results.append({"prediction": prediction, "y": intent_mapping[intent]})

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 893/893 [01:09<00:00, 12.89it/s]


In [34]:
from sklearn.metrics import classification_report
y = [r["y"] for r in results]
predictions = [r["prediction"].replace("Class name: ","") for r in results]
pd.DataFrame(classification_report(y, predictions, output_dict=True)).T

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Unnamed: 0,precision,recall,f1-score,support
Abbreviation and Fare Code Meaning Inquiry,1.0,0.030303,0.058824,33.0
Aircraft Seating Capacity Inquiry,1.0,0.857143,0.923077,21.0
Aircraft Type Inquiry,0.642857,1.0,0.782609,9.0
Airfare Information Request,0.0,0.0,0.0,0.0
Airfare Information Requests,0.0,0.0,0.0,48.0
Airline Information Request,0.117647,0.052632,0.072727,38.0
Airport Distance Inquiry,1.0,0.9,0.947368,10.0
Airport Ground Transportation and Cost Query,0.0,0.0,0.0,0.0
Airport Information and Queries,0.0,0.0,0.0,18.0
Airport Location Inquiry,0.173913,0.666667,0.275862,6.0
