# The goal
This is an example on how to initialize the T5 based model that accepts dynamic labels. 
We will load the Atis dataset and then classify using it

## Initialize model

In [1]:
from open_intent_classifier.model import IntentClassifier

model = IntentClassifier(device="cuda", verbose=True)

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


## Initialize dataset

In [2]:
import pandas as pd
from datasets import load_dataset

# we don't include any multi-label instances in order to simplify this classification

ATIS_INTENT_MAPPING = {
    'abbreviation': "Abbreviation and Fare Code Meaning Inquiry",
    'aircraft': "Aircraft Type Inquiry",
    'airfare': "Airfare and Fares Questions",
    'airline': "Airline Information Request",
    'airport': "Airport Information and Queries",
    'capacity': "Aircraft Seating Capacity Inquiry",
    'cheapest': "Cheapest Fare Inquiry",
    'city': "Airport Location Inquiry",
    'distance': "Airport Distance Inquiry",
    'flight': "Flight Booking Request",
    'flight_no': "Flight Number Inquiry",
    'flight_time': "Time Inquiry",
    'ground_fare': "Ground Transportation Cost Inquiry",
    'ground_service': "Ground Transportation Inquiry",
    'ground_service+ground_fare': "Airport Ground Transportation and Cost Query",
    'meal': "Inquiry about In-flight Meals",
    'quantity': "Flight Quantity Inquiry",
    'restriction': "Flight Restriction Inquiry"
}

def atis_convert_old_label_to_class_atis(old_label: str):
    if old_label in ATIS_INTENT_MAPPING:
        return ATIS_INTENT_MAPPING[old_label]
    return None


dataset = load_dataset("tuetschek/atis")
dataset.set_format(type="pandas")

df_train: pd.DataFrame = dataset["train"][:]
df_test: pd.DataFrame = dataset["test"][:]

df_train["label"] = df_train["intent"].apply(lambda label: atis_convert_old_label_to_class_atis(label))
df_test["label"] = df_test["intent"].apply(lambda label: atis_convert_old_label_to_class_atis(label))
df_train.dropna(subset=["label"], inplace=True)
df_test.dropna(subset=["label"], inplace=True)

datasets - INFO - PyTorch version 2.4.1 available.


## Sort labels alphabetically - emperically improves performance for Atis dataset

In [3]:
labels = ATIS_INTENT_MAPPING.values()
sorted_labels = sorted(labels)
sorted_labels


['Abbreviation and Fare Code Meaning Inquiry',
 'Aircraft Seating Capacity Inquiry',
 'Aircraft Type Inquiry',
 'Airfare and Fares Questions',
 'Airline Information Request',
 'Airport Distance Inquiry',
 'Airport Ground Transportation and Cost Query',
 'Airport Information and Queries',
 'Airport Location Inquiry',
 'Cheapest Fare Inquiry',
 'Flight Booking Request',
 'Flight Number Inquiry',
 'Flight Quantity Inquiry',
 'Flight Restriction Inquiry',
 'Ground Transportation Cost Inquiry',
 'Ground Transportation Inquiry',
 'Inquiry about In-flight Meals',
 'Time Inquiry']

## Run predictions

In [4]:
from tqdm import tqdm
predictions = []

for index, row in tqdm(df_test.iterrows()):
    text = row["text"]
    prediction = model.predict(text, sorted_labels)
    predictions.append(prediction)

open_intent_classifier.model - DEBUG - Full prompt: Topic %% Customer: i would like to find a flight from charlotte to las vegas that makes a stop in st. louis.
END MESSAGE
Choose one topic that matches customer's issue.
 Options:
# Abbreviation and Fare Code Meaning Inquiry 
# Aircraft Seating Capacity Inquiry 
# Aircraft Type Inquiry 
# Airfare and Fares Questions 
# Airline Information Request 
# Airport Distance Inquiry 
# Airport Ground Transportation and Cost Query 
# Airport Information and Queries 
# Airport Location Inquiry 
# Cheapest Fare Inquiry 
# Flight Booking Request 
# Flight Number Inquiry 
# Flight Quantity Inquiry 
# Flight Restriction Inquiry 
# Ground Transportation Cost Inquiry 
# Ground Transportation Inquiry 
# Inquiry about In-flight Meals 
# Time Inquiry 
 
Class name: 
open_intent_classifier.model - DEBUG - Decoded output: Flight Booking Request
1it [00:00,  2.39it/s]open_intent_classifier.model - DEBUG - Full prompt: Topic %% Customer: on april first i need

In [10]:
from sklearn.metrics import classification_report
predictions = [pred.lower() for pred in predictions]
y = [label.lower() for label in df_test["label"].to_list()]

df_classification_report = pd.DataFrame(classification_report(y, predictions, output_dict=True)).T
df_classification_report

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Unnamed: 0,precision,recall,f1-score,support
abbreviation and fare code meaning inquiry,1.0,1.0,1.0,33.0
aircraft seating capacity inquiry,1.0,0.952381,0.97561,21.0
aircraft type inquiry,0.5625,1.0,0.72,9.0
airfare and fares questions,0.5,0.208333,0.294118,48.0
airline information request,0.808511,1.0,0.894118,38.0
airport distance inquiry,0.833333,1.0,0.909091,10.0
airport ground transportation and cost query,0.0,0.0,0.0,0.0
airport information and queries,0.64,0.888889,0.744186,18.0
airport location inquiry,0.307692,0.666667,0.421053,6.0
airport transportation cost inquiry,0.0,0.0,0.0,0.0
